1
0
Fork 0

Merge pull request #83 from qwesda/master

support for table/column names which should be escaped
This commit is contained in:
Amjith Ramanujam 2015-01-08 22:50:32 -08:00
commit 1464a3396f
4 changed files with 39 additions and 5 deletions

View File

@ -172,7 +172,7 @@ def cli(database, user, host, port, prompt_passwd, never_prompt):
def need_completion_refresh(sql):
try:
first_token = sql.split()[0]
return first_token in ('alter', 'create', 'use', '\c', 'drop')
return first_token.lower() in ('alter', 'create', 'use', '\c', 'drop')
except Exception:
return False
@ -219,6 +219,7 @@ def refresh_completions(pgexecute, completer):
tables = pgexecute.tables()
completer.extend_table_names(tables)
for table in tables:
table = table[1:-1] if table[0] == '"' and table[-1] == '"' else table
completer.extend_column_names(table, pgexecute.columns(table))
completer.extend_database_names(pgexecute.databases())

View File

@ -114,8 +114,12 @@ def extract_table_identifiers(token_stream):
yield (real_name, identifier.get_alias() or real_name)
elif isinstance(item, Identifier):
real_name = item.get_real_name()
if real_name:
yield (real_name, item.get_alias() or real_name)
else:
name = item.get_name()
yield (name, item.get_alias() or name)
elif isinstance(item, Function):
yield (item.get_name(), item.get_name())

View File

@ -72,7 +72,7 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text):
return 'columns', extract_tables(full_text)
elif token_v.lower() in ('select', 'where', 'having'):
return 'columns-and-functions', extract_tables(full_text)
elif token_v.lower() in ('from', 'update', 'into', 'describe', 'join'):
elif token_v.lower() in ('from', 'update', 'into', 'describe', 'join', 'table'):
return 'tables', []
elif token_v in ('d',): # \d
return 'tables', []

View File

@ -4,6 +4,7 @@ from collections import defaultdict
from prompt_toolkit.completion import Completer, Completion
from .packages.sqlcompletion import suggest_type
from .packages.parseutils import last_word
from re import compile
_logger = logging.getLogger(__name__)
@ -45,12 +46,31 @@ class PGCompleter(Completer):
super(self.__class__, self).__init__()
self.smart_completion = smart_completion
self.name_pattern = compile("^[_a-z][_a-z0-9\$]*$")
def extend_escape_name(self, name):
if not self.name_pattern.match(name) or name in self.keywords or name in self.functions:
name = '"%s"' % name
return name
def extend_unescape_name(self, name):
if name[0] == '"' and name[-1] == '"':
name = name[1:-1]
return name
def extend_escaped_names(self, names):
return [self.extend_escape_name(name) for name in names]
def extend_special_commands(self, special_commands):
# Special commands are not part of all_completions since they can only
# be at the beginning of a line.
self.special_commands.extend(special_commands)
def extend_database_names(self, databases):
databases = self.extend_escaped_names(databases)
self.databases.extend(databases)
def extend_keywords(self, additional_keywords):
@ -58,11 +78,17 @@ class PGCompleter(Completer):
self.all_completions.update(additional_keywords)
def extend_table_names(self, tables):
tables = self.extend_escaped_names(tables)
self.tables.extend(tables)
self.all_completions.update(tables)
def extend_column_names(self, table, columns):
self.columns[table].extend(columns)
columns = self.extend_escaped_names(columns)
unescaped_table_name = self.extend_unescape_name(table)
self.columns[unescaped_table_name].extend(columns)
self.all_completions.update(columns)
def reset_completions(self):
@ -75,7 +101,9 @@ class PGCompleter(Completer):
def find_matches(text, collection):
text = last_word(text, include='most_punctuations')
for item in collection:
if item.startswith(text) or item.startswith(text.upper()):
item_unescaped = item[1:] if item[0] == '"' else item
if item_unescaped.startswith(text) or item_unescaped.startswith(text.upper()):
yield Completion(item, -len(text))
def get_completions(self, document, complete_event, smart_completion=None):
@ -115,5 +143,6 @@ class PGCompleter(Completer):
def populate_scoped_cols(self, tables):
scoped_cols = []
for table in tables:
scoped_cols.extend(self.columns[table])
unescaped_table_name = self.extend_unescape_name(table)
scoped_cols.extend(self.columns[unescaped_table_name])
return scoped_cols