mirror of https://github.com/dbcli/pgcli
Merge pull request #83 from qwesda/master
support for table/column names which should be escaped
This commit is contained in:
commit
1464a3396f
|
@ -172,7 +172,7 @@ def cli(database, user, host, port, prompt_passwd, never_prompt):
|
|||
def need_completion_refresh(sql):
|
||||
try:
|
||||
first_token = sql.split()[0]
|
||||
return first_token in ('alter', 'create', 'use', '\c', 'drop')
|
||||
return first_token.lower() in ('alter', 'create', 'use', '\c', 'drop')
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
@ -219,6 +219,7 @@ def refresh_completions(pgexecute, completer):
|
|||
tables = pgexecute.tables()
|
||||
completer.extend_table_names(tables)
|
||||
for table in tables:
|
||||
table = table[1:-1] if table[0] == '"' and table[-1] == '"' else table
|
||||
completer.extend_column_names(table, pgexecute.columns(table))
|
||||
completer.extend_database_names(pgexecute.databases())
|
||||
|
||||
|
|
|
@ -114,8 +114,12 @@ def extract_table_identifiers(token_stream):
|
|||
yield (real_name, identifier.get_alias() or real_name)
|
||||
elif isinstance(item, Identifier):
|
||||
real_name = item.get_real_name()
|
||||
|
||||
if real_name:
|
||||
yield (real_name, item.get_alias() or real_name)
|
||||
else:
|
||||
name = item.get_name()
|
||||
yield (name, item.get_alias() or name)
|
||||
elif isinstance(item, Function):
|
||||
yield (item.get_name(), item.get_name())
|
||||
|
||||
|
|
|
@ -72,7 +72,7 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text):
|
|||
return 'columns', extract_tables(full_text)
|
||||
elif token_v.lower() in ('select', 'where', 'having'):
|
||||
return 'columns-and-functions', extract_tables(full_text)
|
||||
elif token_v.lower() in ('from', 'update', 'into', 'describe', 'join'):
|
||||
elif token_v.lower() in ('from', 'update', 'into', 'describe', 'join', 'table'):
|
||||
return 'tables', []
|
||||
elif token_v in ('d',): # \d
|
||||
return 'tables', []
|
||||
|
|
|
@ -4,6 +4,7 @@ from collections import defaultdict
|
|||
from prompt_toolkit.completion import Completer, Completion
|
||||
from .packages.sqlcompletion import suggest_type
|
||||
from .packages.parseutils import last_word
|
||||
from re import compile
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -45,12 +46,31 @@ class PGCompleter(Completer):
|
|||
super(self.__class__, self).__init__()
|
||||
self.smart_completion = smart_completion
|
||||
|
||||
self.name_pattern = compile("^[_a-z][_a-z0-9\$]*$")
|
||||
|
||||
def extend_escape_name(self, name):
|
||||
if not self.name_pattern.match(name) or name in self.keywords or name in self.functions:
|
||||
name = '"%s"' % name
|
||||
|
||||
return name
|
||||
|
||||
def extend_unescape_name(self, name):
|
||||
if name[0] == '"' and name[-1] == '"':
|
||||
name = name[1:-1]
|
||||
|
||||
return name
|
||||
|
||||
def extend_escaped_names(self, names):
|
||||
return [self.extend_escape_name(name) for name in names]
|
||||
|
||||
def extend_special_commands(self, special_commands):
|
||||
# Special commands are not part of all_completions since they can only
|
||||
# be at the beginning of a line.
|
||||
self.special_commands.extend(special_commands)
|
||||
|
||||
def extend_database_names(self, databases):
|
||||
databases = self.extend_escaped_names(databases)
|
||||
|
||||
self.databases.extend(databases)
|
||||
|
||||
def extend_keywords(self, additional_keywords):
|
||||
|
@ -58,11 +78,17 @@ class PGCompleter(Completer):
|
|||
self.all_completions.update(additional_keywords)
|
||||
|
||||
def extend_table_names(self, tables):
|
||||
tables = self.extend_escaped_names(tables)
|
||||
|
||||
self.tables.extend(tables)
|
||||
self.all_completions.update(tables)
|
||||
|
||||
def extend_column_names(self, table, columns):
|
||||
self.columns[table].extend(columns)
|
||||
columns = self.extend_escaped_names(columns)
|
||||
|
||||
unescaped_table_name = self.extend_unescape_name(table)
|
||||
|
||||
self.columns[unescaped_table_name].extend(columns)
|
||||
self.all_completions.update(columns)
|
||||
|
||||
def reset_completions(self):
|
||||
|
@ -75,7 +101,9 @@ class PGCompleter(Completer):
|
|||
def find_matches(text, collection):
|
||||
text = last_word(text, include='most_punctuations')
|
||||
for item in collection:
|
||||
if item.startswith(text) or item.startswith(text.upper()):
|
||||
item_unescaped = item[1:] if item[0] == '"' else item
|
||||
|
||||
if item_unescaped.startswith(text) or item_unescaped.startswith(text.upper()):
|
||||
yield Completion(item, -len(text))
|
||||
|
||||
def get_completions(self, document, complete_event, smart_completion=None):
|
||||
|
@ -115,5 +143,6 @@ class PGCompleter(Completer):
|
|||
def populate_scoped_cols(self, tables):
|
||||
scoped_cols = []
|
||||
for table in tables:
|
||||
scoped_cols.extend(self.columns[table])
|
||||
unescaped_table_name = self.extend_unescape_name(table)
|
||||
scoped_cols.extend(self.columns[unescaped_table_name])
|
||||
return scoped_cols
|
||||
|
|
Loading…
Reference in New Issue