
Merge pull request #127 from darikg/schema_autocomplete

Make autocomplete schema-aware
Amjith Ramanujam 2015-01-26 20:26:52 -08:00
commit 6944ef60f8
13 changed files with 737 additions and 209 deletions
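In short, the completion pipeline becomes schema-aware end to end: suggest_type() now returns a list of suggestion dicts instead of a single (category, scope) tuple, extract_tables() yields (schema, table, alias) triples, and PGCompleter stores metadata in a nested schema -> table -> columns dict, resolving unqualified names through the current search_path. A minimal sketch of the old versus new suggestion format, with return values taken from the tests further down in this diff:

from pgcli.packages.sqlcompletion import suggest_type

# Old API: a single (category, scope) tuple, e.g.
#   suggest_type('SELECT * FROM ', 'SELECT * FROM ')  ->  ('tables', [])
# New API: a list of suggestion dicts, one per kind of completion to offer.
suggestions = suggest_type('SELECT * FROM ', 'SELECT * FROM ')
# -> [{'type': 'schema'}, {'type': 'table', 'schema': []}]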

View File

@ -214,6 +214,12 @@ class PGCli(object):
end = time()
total += end - start
mutating = mutating or is_mutating(status)
if need_search_path_refresh(document.text, status):
logger.debug('Refreshing search path')
completer.set_search_path(pgexecute.search_path())
logger.debug('Search path: %r', completer.search_path)
except KeyboardInterrupt:
# Restart connection to the database
pgexecute.connect()
@ -262,13 +268,16 @@ class PGCli(object):
return less_opts
def refresh_completions(self):
self.completer.reset_completions()
tables, columns = self.pgexecute.tables()
self.completer.extend_table_names(tables)
for table in tables:
table = table[1:-1] if table[0] == '"' and table[-1] == '"' else table
self.completer.extend_column_names(table, columns[table])
self.completer.extend_database_names(self.pgexecute.databases())
completer = self.completer
completer.reset_completions()
pgexecute = self.pgexecute
completer.set_search_path(pgexecute.search_path())
completer.extend_schemata(pgexecute.schemata())
completer.extend_tables(pgexecute.tables())
completer.extend_columns(pgexecute.columns())
completer.extend_database_names(pgexecute.databases())
def get_completions(self, text, cursor_positition):
return self.completer.get_completions(
@ -329,6 +338,22 @@ def need_completion_refresh(sql):
except Exception:
return False
def need_search_path_refresh(sql, status):
# note that sql may be a multi-command query, but status belongs to an
# individual query, since pgexecute handles splitting up multi-commands
try:
status = status.split()[0]
if status.lower() == 'set':
# Since sql could be a multi-line query, it's hard to robustly
# pick out the variable name that's been set. Err on the side of
# false positives here, since the worst case is we refresh the
# search path when it's not necessary
return 'search_path' in sql.lower()
else:
return False
except Exception:
return False
def is_mutating(status):
"""Determines if the statement is mutating based on the status."""
if not status:
@ -349,6 +374,5 @@ def quit_command(sql):
or sql.strip() == '\q'
or sql.strip() == ':q')
if __name__ == "__main__":
cli()
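A small usage sketch of the new search-path refresh hook above. The import path is an assumption (the diff does not show file names, but need_search_path_refresh sits alongside PGCli), and the status strings are cursor statusmessage values of the kind pgexecute returns:

from pgcli.main import need_search_path_refresh  # assumed module path

# A successful SET whose SQL mentions search_path triggers a refresh ...
need_search_path_refresh('set search_path to schema1, public', 'SET')  # True
# ... while ordinary statements leave the completer's search path alone.
need_search_path_refresh('select * from schema1.tbl', 'SELECT 1')      # False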

View File

@ -101,50 +101,50 @@ def extract_from_part(parsed, stop_at_punctuation=True):
break
def extract_table_identifiers(token_stream):
"""yields tuples of (schema_name, table_name, table_alias)"""
for item in token_stream:
if isinstance(item, IdentifierList):
for identifier in item.get_identifiers():
# Sometimes Keywords (such as FROM ) are classified as
# identifiers which don't have the get_real_name() method.
try:
schema_name = identifier.get_parent_name()
real_name = identifier.get_real_name()
except AttributeError:
continue
if real_name:
yield (real_name, identifier.get_alias() or real_name)
yield (schema_name, real_name, identifier.get_alias())
elif isinstance(item, Identifier):
real_name = item.get_real_name()
schema_name = item.get_parent_name()
if real_name:
yield (real_name, item.get_alias() or real_name)
yield (schema_name, real_name, item.get_alias())
else:
name = item.get_name()
yield (name, item.get_alias() or name)
yield (None, name, item.get_alias() or name)
elif isinstance(item, Function):
yield (item.get_name(), item.get_name())
yield (None, item.get_name(), item.get_name())
# extract_tables is inspired from examples in the sqlparse lib.
def extract_tables(sql, include_alias=False):
def extract_tables(sql):
"""Extract the table names from an SQL statment.
Returns a list of table names if include_alias=False (default).
If include_alias=True, then a dictionary is returned where the keys are
aliases and values are real table names.
Returns a list of (schema, table, alias) tuples
"""
parsed = sqlparse.parse(sql)
if not parsed:
return []
# INSERT statements must stop looking for tables at the first sign of
# punctuation, e.g. INSERT INTO abc (col1, col2) VALUES (1, 2):
# abc is the table name, but if we don't stop at the first lparen, then
# we'll identify abc, col1 and col2 as table names.
insert_stmt = parsed[0].token_first().value.lower() == 'insert'
stream = extract_from_part(parsed[0], stop_at_punctuation=insert_stmt)
if include_alias:
return dict((alias, t) for t, alias in extract_table_identifiers(stream))
else:
return [x[0] for x in extract_table_identifiers(stream)]
return list(extract_table_identifiers(stream))
def find_prev_keyword(sql):
if not sql.strip():
@ -156,4 +156,4 @@ def find_prev_keyword(sql):
if __name__ == '__main__':
sql = 'select * from (select t. from tabl t'
print (extract_tables(sql, True))
print (extract_tables(sql))
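A short usage sketch of the new return format; the expected values mirror the updated test_parseutils assertions later in this diff:

from pgcli.packages.parseutils import extract_tables

# Unqualified tables come back with schema=None and alias=None.
extract_tables('select * from abc')
# -> [(None, 'abc', None)]

# Schema qualifiers and aliases are both preserved.
extract_tables('SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num')
# -> [('abc', 'def', 'x'), ('ghi', 'jkl', 'y')]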

View File

@ -66,28 +66,65 @@ def suggest_based_on_last_token(token, text_before_cursor, full_text):
# If the lparen is preceded by a space, chances are we're about to
# do a sub-select.
if last_word(text_before_cursor, 'all_punctuations').startswith('('):
return 'keywords', []
return 'columns', extract_tables(full_text)
return [{'type': 'keyword'}]
return [{'type': 'column', 'tables': extract_tables(full_text)}]
if token_v.lower() in ('set', 'by', 'distinct'):
return 'columns', extract_tables(full_text)
return [{'type': 'column', 'tables': extract_tables(full_text)}]
elif token_v.lower() in ('select', 'where', 'having'):
return 'columns-and-functions', extract_tables(full_text)
return [{'type': 'column', 'tables': extract_tables(full_text)},
{'type': 'function'}]
elif token_v.lower() in ('from', 'update', 'into', 'describe', 'join', 'table'):
return 'tables', []
return [{'type': 'schema'}, {'type': 'table', 'schema': []}]
elif token_v.lower() == 'on':
tables = extract_tables(full_text, include_alias=True)
return 'tables-or-aliases', tables.keys()
tables = extract_tables(full_text) # [(schema, table, alias), ...]
# Use table alias if there is one, otherwise the table name
alias = [t[2] or t[1] for t in tables]
return [{'type': 'alias', 'aliases': alias}]
elif token_v in ('d',): # \d
return 'tables', []
# Apparently "\d <other>" is parsed by sqlparse as
# Identifier('d', Whitespace, '<other>')
if len(token.tokens) > 2:
other = token.tokens[-1].value
identifiers = other.split('.')
if len(identifiers) == 1:
# "\d table" or "\d schema"
return [{'type': 'schema'}, {'type': 'table', 'schema': []}]
elif len(identifiers) == 2:
# \d schema.table
return [{'type': 'table', 'schema': identifiers[0]}]
else:
return [{'type': 'schema'}, {'type': 'table', 'schema': []}]
elif token_v.lower() in ('c', 'use'): # \c
return 'databases', []
return [{'type': 'database'}]
elif token_v.endswith(',') or token_v == '=':
prev_keyword = find_prev_keyword(text_before_cursor)
if prev_keyword:
return suggest_based_on_last_token(prev_keyword, text_before_cursor, full_text)
return suggest_based_on_last_token(
prev_keyword, text_before_cursor, full_text)
elif token_v.endswith('.'):
current_alias = last_word(token_v[:-1])
tables = extract_tables(full_text, include_alias=True)
return 'columns', [tables.get(current_alias) or current_alias]
return 'keywords', []
suggestions = []
identifier = last_word(token_v[:-1], 'all_punctuations')
# TABLE.<suggestion> or SCHEMA.TABLE.<suggestion>
tables = extract_tables(full_text)
tables = [t for t in tables if identifies(identifier, *t)]
suggestions.append({'type': 'column', 'tables': tables})
# SCHEMA.<suggestion>
suggestions.append({'type': 'table', 'schema': identifier})
return suggestions
return [{'type': 'keyword'}]
def identifies(id, schema, table, alias):
return id == alias or id == table or (
schema and (id == schema + '.' + table))
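To illustrate the dotted-identifier branch and the identifies() helper above, here is a sketch of the two suggestions emitted for an alias followed by a dot; the dict order is not significant (the tests compare sorted dicts), and the values match test_dot_suggests_cols_of_an_alias later in this diff:

from pgcli.packages.sqlcompletion import suggest_type

# 't1.' might be a table/alias prefix (complete columns) or a schema prefix
# (complete tables), so both suggestions are returned and the completer
# decides which ones actually resolve against its metadata.
suggest_type('SELECT t1. FROM tabl1 t1, tabl2 t2', 'SELECT t1.')
# -> [{'type': 'column', 'tables': [(None, 'tabl1', 't1')]},
#     {'type': 'table', 'schema': 't1'}]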

View File

@ -1,11 +1,11 @@
from __future__ import print_function
import logging
from collections import defaultdict
from prompt_toolkit.completion import Completer, Completion
from .packages.sqlcompletion import suggest_type
from .packages.parseutils import last_word
from re import compile
_logger = logging.getLogger(__name__)
class PGCompleter(Completer):
@ -21,7 +21,7 @@ class PGCompleter(Completer):
'MAXEXTENTS', 'MINUS', 'MLSLABEL', 'MODE', 'MODIFY', 'NOAUDIT',
'NOCOMPRESS', 'NOT', 'NOWAIT', 'NULL', 'NUMBER', 'OF', 'OFFLINE',
'ON', 'ONLINE', 'OPTION', 'OR', 'ORDER BY', 'OUTER', 'PCTFREE',
'PRIMARY', 'PRIOR', 'PRIVILEGES', 'PUBLIC', 'RAW', 'RENAME',
'PRIMARY', 'PRIOR', 'PRIVILEGES', 'RAW', 'RENAME',
'RESOURCE', 'REVOKE', 'RIGHT', 'ROW', 'ROWID', 'ROWNUM', 'ROWS',
'SELECT', 'SESSION', 'SET', 'SHARE', 'SIZE', 'SMALLINT', 'START',
'SUCCESSFUL', 'SYNONYM', 'SYSDATE', 'TABLE', 'THEN', 'TO',
@ -33,15 +33,6 @@ class PGCompleter(Completer):
'LCASE', 'LEN', 'MAX', 'MIN', 'MID', 'NOW', 'ROUND', 'SUM', 'TOP',
'UCASE']
special_commands = []
databases = []
tables = []
# This will create a defaultdict which is initialized with a list that has
# a '*' by default.
columns = defaultdict(lambda: ['*'])
all_completions = set(keywords + functions)
def __init__(self, smart_completion=True):
super(self.__class__, self).__init__()
self.smart_completion = smart_completion
@ -50,8 +41,15 @@ class PGCompleter(Completer):
self.reserved_words.update(x.split())
self.name_pattern = compile("^[_a-z][_a-z0-9\$]*$")
self.special_commands = []
self.databases = []
self.dbmetadata = {}
self.search_path = []
self.all_completions = set(self.keywords + self.functions)
def escape_name(self, name):
if ((not self.name_pattern.match(name))
if name and ((not self.name_pattern.match(name))
or (name.upper() in self.reserved_words)
or (name.upper() in self.functions)):
name = '"%s"' % name
@ -60,7 +58,7 @@ class PGCompleter(Completer):
def unescape_name(self, name):
""" Unquote a string."""
if name[0] == '"' and name[-1] == '"':
if name and name[0] == '"' and name[-1] == '"':
name = name[1:-1]
return name
@ -75,31 +73,50 @@ class PGCompleter(Completer):
def extend_database_names(self, databases):
databases = self.escaped_names(databases)
self.databases.extend(databases)
def extend_keywords(self, additional_keywords):
self.keywords.extend(additional_keywords)
self.all_completions.update(additional_keywords)
def extend_table_names(self, tables):
tables = self.escaped_names(tables)
self.tables.extend(tables)
self.all_completions.update(tables)
def extend_column_names(self, table, columns):
columns = self.escaped_names(columns)
unescaped_table_name = self.unescape_name(table)
self.columns[unescaped_table_name].extend(columns)
self.all_completions.update(columns)
def extend_schemata(self, schemata):
# schemata is a list of schema names
schemata = self.escaped_names(schemata)
for schema in schemata:
self.dbmetadata[schema] = {}
self.all_completions.update(schemata)
def extend_tables(self, table_data):
# table_data is a list of (schema_name, table_name) tuples
table_data = [self.escaped_names(d) for d in table_data]
# dbmetadata['schema_name']['table_name'] should be a list of column
# names. Default to an asterisk
for schema, table in table_data:
self.dbmetadata[schema][table] = ['*']
self.all_completions.update(t[1] for t in table_data)
def extend_columns(self, column_data):
# column_data is a list of (schema_name, table_name, column_name) tuples
column_data = [self.escaped_names(d) for d in column_data]
for schema, table, column in column_data:
self.dbmetadata[schema][table].append(column)
self.all_completions.update(t[2] for t in column_data)
def set_search_path(self, search_path):
self.search_path = self.escaped_names(search_path)
def reset_completions(self):
self.databases = []
self.tables = []
self.columns = defaultdict(lambda: ['*'])
self.search_path = []
self.dbmetadata = {}
self.all_completions = set(self.keywords)
@staticmethod
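Taken together, the extend_* methods above populate a single nested dict. A hedged sketch of the resulting shape, with schema, table, and column names borrowed from the test fixtures later in this diff:

# dbmetadata maps schema -> table -> list of column names, seeded with '*':
# extend_schemata(['public'])
# extend_tables([('public', 'users')])
# extend_columns([('public', 'users', 'id'), ('public', 'users', 'email')])
# leaves the completer holding roughly:
dbmetadata = {
    'public': {
        'users': ['*', 'id', 'email'],
    },
}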
@ -119,36 +136,90 @@ class PGCompleter(Completer):
if not smart_completion:
return self.find_matches(word_before_cursor, self.all_completions)
category, scope = suggest_type(document.text,
document.text_before_cursor)
completions = []
suggestions = suggest_type(document.text, document.text_before_cursor)
for suggestion in suggestions:
_logger.debug('Suggestion type: %r', suggestion['type'])
if suggestion['type'] == 'column':
tables = suggestion['tables']
_logger.debug("Completion column scope: %r", tables)
scoped_cols = self.populate_scoped_cols(tables)
cols = self.find_matches(word_before_cursor, scoped_cols)
completions.extend(cols)
elif suggestion['type'] == 'function':
funcs = self.find_matches(word_before_cursor, self.functions)
completions.extend(funcs)
elif suggestion['type'] == 'schema':
schema_names = self.dbmetadata.keys()
schema_names = self.find_matches(word_before_cursor, schema_names)
completions.extend(schema_names)
elif suggestion['type'] == 'table':
if suggestion['schema']:
try:
tables = self.dbmetadata[suggestion['schema']].keys()
except KeyError:
# schema doesn't exist
tables = []
else:
schemas = self.search_path
meta = self.dbmetadata
tables = [tbl for schema in schemas
for tbl in meta[schema].keys()]
tables = self.find_matches(word_before_cursor, tables)
completions.extend(tables)
elif suggestion['type'] == 'alias':
aliases = suggestion['aliases']
aliases = self.find_matches(word_before_cursor, aliases)
completions.extend(aliases)
elif suggestion['type'] == 'database':
dbs = self.find_matches(word_before_cursor, self.databases)
completions.extend(dbs)
elif suggestion['type'] == 'keyword':
keywords = self.keywords + self.special_commands
keywords = self.find_matches(word_before_cursor, keywords)
completions.extend(keywords)
return completions
def populate_scoped_cols(self, scoped_tbls):
""" Find all columns in a set of scoped_tables
:param scoped_tbls: list of (schema, table, alias) tuples
:return: list of column names
"""
columns = []
meta = self.dbmetadata
for tbl in scoped_tbls:
if tbl[0]:
# A fully qualified schema.table reference
schema = self.escape_name(tbl[0])
table = self.escape_name(tbl[1])
try:
# Get columns from the corresponding schema.table
columns.extend(meta[schema][table])
except KeyError:
# Either the schema or table doesn't exist
pass
else:
for schema in self.search_path:
table = self.escape_name(tbl[1])
try:
columns.extend(meta[schema][table])
break
except KeyError:
pass
return columns
if category == 'columns':
_logger.debug("Completion: 'columns' Scope: %r", scope)
scoped_cols = self.populate_scoped_cols(scope)
return self.find_matches(word_before_cursor, scoped_cols)
elif category == 'columns-and-functions':
_logger.debug("Completion: 'columns-and-functions' Scope: %r",
scope)
scoped_cols = self.populate_scoped_cols(scope)
return self.find_matches(word_before_cursor, scoped_cols +
self.functions)
elif category == 'tables':
_logger.debug("Completion: 'tables' Scope: %r", scope)
return self.find_matches(word_before_cursor, self.tables)
elif category == 'tables-or-aliases':
_logger.debug("Completion: 'tables-or-aliases' Scope: %r", scope)
return self.find_matches(word_before_cursor, scope)
elif category == 'databases':
_logger.debug("Completion: 'databases' Scope: %r", scope)
return self.find_matches(word_before_cursor, self.databases)
elif category == 'keywords':
_logger.debug("Completion: 'keywords' Scope: %r", scope)
return self.find_matches(word_before_cursor, self.keywords +
self.special_commands)
def populate_scoped_cols(self, tables):
scoped_cols = []
for table in tables:
unescaped_table_name = self.unescape_name(table)
scoped_cols.extend(self.columns[unescaped_table_name])
return scoped_cols

View File

@ -4,7 +4,6 @@ import psycopg2
import psycopg2.extras
import psycopg2.extensions as ext
import sqlparse
from collections import defaultdict
from .packages import pgspecial
PY2 = sys.version_info[0] == 2
@ -30,13 +29,43 @@ psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select)
class PGExecute(object):
tables_query = '''SELECT c.relname as "Name" FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace WHERE
c.relkind IN ('r','') AND n.nspname <> 'pg_catalog' AND n.nspname <>
'information_schema' AND n.nspname !~ '^pg_toast' AND
pg_catalog.pg_table_is_visible(c.oid) ORDER BY 1;'''
search_path_query = '''
SELECT * FROM unnest(current_schemas(false))'''
schemata_query = '''
SELECT nspname
FROM pg_catalog.pg_namespace
WHERE nspname !~ '^pg_'
AND nspname <> 'information_schema'
ORDER BY 1 '''
tables_query = '''
SELECT n.nspname schema_name,
c.relname table_name
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n
ON n.oid = c.relnamespace
WHERE c.relkind IN ('r','v', 'm') -- table, view, materialized view
AND n.nspname !~ '^pg_toast'
AND n.nspname NOT IN ('information_schema', 'pg_catalog')
ORDER BY 1,2;'''
columns_query = '''
SELECT nsp.nspname schema_name,
cls.relname table_name,
att.attname column_name
FROM pg_catalog.pg_attribute att
INNER JOIN pg_catalog.pg_class cls
ON att.attrelid = cls.oid
INNER JOIN pg_catalog.pg_namespace nsp
ON cls.relnamespace = nsp.oid
WHERE cls.relkind IN ('r', 'v', 'm')
AND nsp.nspname !~ '^pg_'
AND nsp.nspname <> 'information_schema'
AND NOT att.attisdropped
AND att.attnum > 0
ORDER BY 1, 2, 3'''
columns_query = '''SELECT table_name, column_name FROM information_schema.columns'''
databases_query = """SELECT d.datname as "Name",
pg_catalog.pg_get_userbyid(d.datdba) as "Owner",
@ -133,22 +162,37 @@ class PGExecute(object):
_logger.debug('No rows in result.')
return (None, None, cur.statusmessage)
def search_path(self):
"""Returns the current search path as a list of schema names"""
with self.conn.cursor() as cur:
_logger.debug('Search path query. sql: %r', self.search_path_query)
cur.execute(self.search_path_query)
return [x[0] for x in cur.fetchall()]
def schemata(self):
"""Returns a list of schema names in the database"""
with self.conn.cursor() as cur:
_logger.debug('Schemata Query. sql: %r', self.schemata_query)
cur.execute(self.schemata_query)
return [x[0] for x in cur.fetchall()]
def tables(self):
""" Returns tuple (sorted_tables, columns). Columns is a dictionary of
table name -> list of columns """
columns = defaultdict(list)
"""Returns a list of (schema_name, table_name) tuples """
with self.conn.cursor() as cur:
_logger.debug('Tables Query. sql: %r', self.tables_query)
cur.execute(self.tables_query)
tables = [x[0] for x in cur.fetchall()]
return cur.fetchall()
table_set = set(tables)
def columns(self):
"""Returns a list of (schema_name, table_name, column_name) tuples"""
with self.conn.cursor() as cur:
_logger.debug('Columns Query. sql: %r', self.columns_query)
cur.execute(self.columns_query)
for table, column in cur.fetchall():
if table in table_set:
columns[table].append(column)
return tables, columns
return cur.fetchall()
def databases(self):
with self.conn.cursor() as cur:

View File

@ -28,7 +28,7 @@ setup(
'jedi == 0.8.1', # Temporary fix for installation woes.
'prompt_toolkit==0.26',
'psycopg2 >= 2.5.4',
'sqlparse >= 0.1.14',
'sqlparse >= 0.1.14'
],
entry_points='''
[console_scripts]

View File

@ -1,3 +1,4 @@
import pytest
from pgcli.packages.parseutils import extract_tables
@ -7,48 +8,77 @@ def test_empty_string():
def test_simple_select_single_table():
tables = extract_tables('select * from abc')
assert tables == ['abc']
assert tables == [(None, 'abc', None)]
def test_simple_select_single_table_schema_qualified():
tables = extract_tables('select * from abc.def')
assert tables == [('abc', 'def', None)]
def test_simple_select_multiple_tables():
tables = extract_tables('select * from abc, def')
assert tables == ['abc', 'def']
assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
def test_simple_select_multiple_tables_schema_qualified():
tables = extract_tables('select * from abc.def, ghi.jkl')
assert sorted(tables) == [('abc', 'def', None), ('ghi', 'jkl', None)]
def test_simple_select_with_cols_single_table():
tables = extract_tables('select a,b from abc')
assert tables == ['abc']
assert tables == [(None, 'abc', None)]
def test_simple_select_with_cols_single_table_schema_qualified():
tables = extract_tables('select a,b from abc.def')
assert tables == [('abc', 'def', None)]
def test_simple_select_with_cols_multiple_tables():
tables = extract_tables('select a,b from abc, def')
assert tables == ['abc', 'def']
assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
def test_simple_select_with_cols_multiple_tables_schema_qualified():
tables = extract_tables('select a,b from abc.def, def.ghi')
assert sorted(tables) == [('abc', 'def', None), ('def', 'ghi', None)]
def test_select_with_hanging_comma_single_table():
tables = extract_tables('select a, from abc')
assert tables == ['abc']
assert tables == [(None, 'abc', None)]
def test_select_with_hanging_comma_multiple_tables():
tables = extract_tables('select a, from abc, def')
assert tables == ['abc', 'def']
assert sorted(tables) == [(None, 'abc', None), (None, 'def', None)]
def test_select_with_hanging_period_multiple_tables():
tables = extract_tables('SELECT t1. FROM tabl1 t1, tabl2 t2')
assert sorted(tables) == [(None, 'tabl1', 't1'), (None, 'tabl2', 't2')]
def test_simple_insert_single_table():
tables = extract_tables('insert into abc (id, name) values (1, "def")')
assert tables == ['abc']
# sqlparse mistakenly assigns an alias to the table
# assert tables == [(None, 'abc', None)]
assert tables == [(None, 'abc', 'abc')]
@pytest.mark.xfail
def test_simple_insert_single_table_schema_qualified():
tables = extract_tables('insert into abc.def (id, name) values (1, "def")')
assert tables == [('abc', 'def', None)]
def test_simple_update_table():
tables = extract_tables('update abc set id = 1')
assert tables == ['abc']
assert tables == [(None, 'abc', None)]
def test_simple_update_table_schema_qualified():
tables = extract_tables('update abc.def set id = 1')
assert tables == [('abc', 'def', None)]
def test_join_table():
expected = {'a': 'abc', 'd': 'def'}
tables = extract_tables('SELECT * FROM abc a JOIN def d ON a.id = d.num')
tables_aliases = extract_tables(
'SELECT * FROM abc a JOIN def d ON a.id = d.num', True)
assert tables == sorted(expected.values())
assert tables_aliases == expected
assert sorted(tables) == [(None, 'abc', 'a'), (None, 'def', 'd')]
def test_join_table_schema_qualified():
tables = extract_tables('SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num')
assert tables == [('abc', 'def', 'x'), ('ghi', 'jkl', 'y')]
def test_join_as_table():
expected = {'m': 'my_table'}
assert extract_tables(
'SELECT * FROM my_table AS m WHERE m.a > 5') == \
sorted(expected.values())
assert extract_tables(
'SELECT * FROM my_table AS m WHERE m.a > 5', True) == expected
tables = extract_tables('SELECT * FROM my_table AS m WHERE m.a > 5')
assert tables == [(None, 'my_table', 'm')]

View File

@ -16,14 +16,22 @@ def test_conn(executor):
SELECT 1""")
@dbtest
def test_table_and_columns_query(executor):
def test_schemata_table_and_columns_query(executor):
run(executor, "create table a(x text, y text)")
run(executor, "create table b(z text)")
run(executor, "create schema schema1")
run(executor, "create table schema1.c (w text)")
run(executor, "create schema schema2")
tables, columns = executor.tables()
assert tables == ['a', 'b']
assert columns['a'] == ['x', 'y']
assert columns['b'] == ['z']
assert executor.schemata() == ['public', 'schema1', 'schema2']
assert executor.tables() == [
('public', 'a'), ('public', 'b'), ('schema1', 'c')]
assert executor.columns() == [
('public', 'a', 'x'), ('public', 'a', 'y'),
('public', 'b', 'z'), ('schema1', 'c', 'w')]
assert executor.search_path() == ['public']
@dbtest
def test_database_list(executor):

View File

@ -0,0 +1,19 @@
from pgcli.packages.sqlcompletion import suggest_type
from test_sqlcompletion import sorted_dicts
def test_d_suggests_tables_and_schemas():
suggestions = suggest_type('\d ', '\d ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'schema'}, {'type': 'table', 'schema': []}])
suggestions = suggest_type('\d xxx', '\d xxx')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'schema'}, {'type': 'table', 'schema': []}])
def test_d_dot_suggests_schema_qualified_tables():
suggestions = suggest_type('\d myschema.', '\d myschema.')
assert suggestions == [{'type': 'table', 'schema': 'myschema'}]
suggestions = suggest_type('\d myschema.xxx', '\d myschema.xxx')
assert suggestions == [{'type': 'table', 'schema': 'myschema'}]

View File

@ -0,0 +1,227 @@
import pytest
from prompt_toolkit.completion import Completion
from prompt_toolkit.document import Document
metadata = {
'public': {
'users': ['id', 'email', 'first_name', 'last_name'],
'orders': ['id', 'ordered_date', 'status'],
'select': ['id', 'insert', 'ABC']
},
'custom': {
'users': ['id', 'phone_number'],
'products': ['id', 'product_name', 'price'],
'shipments': ['id', 'address', 'user_id']
}
}
@pytest.fixture
def completer():
import pgcli.pgcompleter as pgcompleter
comp = pgcompleter.PGCompleter(smart_completion=True)
schemata, tables, columns = [], [], []
for schema, tbls in metadata.items():
schemata.append(schema)
for table, cols in tbls.items():
tables.append((schema, table))
columns.extend([(schema, table, col) for col in cols])
comp.extend_schemata(schemata)
comp.extend_tables(tables)
comp.extend_columns(columns)
comp.set_search_path(['public'])
return comp
@pytest.fixture
def complete_event():
from mock import Mock
return Mock()
def test_schema_or_visible_table_completion(completer, complete_event):
text = 'SELECT * FROM '
position = len(text)
result = completer.get_completions(
Document(text=text, cursor_position=position), complete_event)
assert set(result) == set([Completion(text='public', start_position=0),
Completion(text='custom', start_position=0),
Completion(text='users', start_position=0),
Completion(text='"select"', start_position=0),
Completion(text='orders', start_position=0)])
def test_suggested_column_names_from_shadowed_visible_table(completer, complete_event):
"""
Suggest column and function names when selecting from table
:param completer:
:param complete_event:
:return:
"""
text = 'SELECT from users'
position = len('SELECT ')
result = set(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert set(result) == set([
Completion(text='*', start_position=0),
Completion(text='id', start_position=0),
Completion(text='email', start_position=0),
Completion(text='first_name', start_position=0),
Completion(text='last_name', start_position=0)] +
list(map(Completion, completer.functions)))
def test_suggested_column_names_from_qualified_shadowed_table(completer, complete_event):
text = 'SELECT from custom.users'
position = len('SELECT ')
result = set(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert set(result) == set([
Completion(text='*', start_position=0),
Completion(text='id', start_position=0),
Completion(text='phone_number', start_position=0)] +
list(map(Completion, completer.functions)))
def test_suggested_column_names_from_schema_qualified_table(completer, complete_event):
"""
Suggest column and function names when selecting from a schema-qualified table
:param completer:
:param complete_event:
:return:
"""
text = 'SELECT from custom.products'
position = len('SELECT ')
result = set(completer.get_completions(
Document(text=text, cursor_position=position), complete_event))
assert set(result) == set([
Completion(text='*', start_position=0),
Completion(text='id', start_position=0),
Completion(text='product_name', start_position=0),
Completion(text='price', start_position=0)] +
list(map(Completion, completer.functions)))
def test_suggested_column_names_in_function(completer, complete_event):
"""
Suggest column names when selecting a column inside a function call
:param completer:
:param complete_event:
:return:
"""
text = 'SELECT MAX( from custom.products'
position = len('SELECT MAX(')
result = completer.get_completions(
Document(text=text, cursor_position=position),
complete_event)
assert set(result) == set([
Completion(text='*', start_position=0),
Completion(text='id', start_position=0),
Completion(text='product_name', start_position=0),
Completion(text='price', start_position=0)])
def test_suggested_table_names_with_schema_dot(completer, complete_event):
text = 'SELECT * FROM custom.'
position = len(text)
result = completer.get_completions(
Document(text=text, cursor_position=position), complete_event)
assert set(result) == set([
Completion(text='users', start_position=0),
Completion(text='products', start_position=0),
Completion(text='shipments', start_position=0)])
def test_suggested_column_names_with_qualified_alias(completer, complete_event):
"""
Suggest column names on table alias and dot
:param completer:
:param complete_event:
:return:
"""
text = 'SELECT p. from custom.products p'
position = len('SELECT p.')
result = set(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert set(result) == set([
Completion(text='*', start_position=0),
Completion(text='id', start_position=0),
Completion(text='product_name', start_position=0),
Completion(text='price', start_position=0)])
def test_suggested_multiple_column_names(completer, complete_event):
"""
Suggest column and function names when selecting multiple
columns from table
:param completer:
:param complete_event:
:return:
"""
text = 'SELECT id, from custom.products'
position = len('SELECT id, ')
result = set(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert set(result) == set([
Completion(text='*', start_position=0),
Completion(text='id', start_position=0),
Completion(text='product_name', start_position=0),
Completion(text='price', start_position=0)] +
list(map(Completion, completer.functions)))
def test_suggested_multiple_column_names_with_alias(completer, complete_event):
"""
Suggest column names on table alias and dot
when selecting multiple columns from table
:param completer:
:param complete_event:
:return:
"""
text = 'SELECT p.id, p. from custom.products p'
position = len('SELECT p.id, p.')
result = set(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert set(result) == set([
Completion(text='*', start_position=0),
Completion(text='id', start_position=0),
Completion(text='product_name', start_position=0),
Completion(text='price', start_position=0)])
def test_suggested_aliases_after_on(completer, complete_event):
text = 'SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON '
position = len(text)
result = set(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert set(result) == set([
Completion(text='x', start_position=0),
Completion(text='y', start_position=0)])
def test_suggested_aliases_after_on_right_side(completer, complete_event):
text = 'SELECT x.id, y.product_name FROM custom.products x JOIN custom.products y ON x.id = '
position = len(text)
result = set(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert set(result) == set([
Completion(text='x', start_position=0),
Completion(text='y', start_position=0)])
def test_table_names_after_from(completer, complete_event):
text = 'SELECT * FROM '
position = len('SELECT * FROM ')
result = set(completer.get_completions(
Document(text=text, cursor_position=position),
complete_event))
assert set(result) == set([
Completion(text='public', start_position=0),
Completion(text='custom', start_position=0),
Completion(text='users', start_position=0),
Completion(text='orders', start_position=0),
Completion(text='"select"', start_position=0),
])

View File

@ -2,20 +2,30 @@ import pytest
from prompt_toolkit.completion import Completion
from prompt_toolkit.document import Document
tables = {
'users': ['id', 'email', 'first_name', 'last_name'],
'orders': ['id', 'user_id', 'ordered_date', 'status'],
'select': ['id', 'insert', 'ABC'],
}
metadata = {
'users': ['id', 'email', 'first_name', 'last_name'],
'orders': ['id', 'ordered_date', 'status'],
'select': ['id', 'insert', 'ABC']
}
@pytest.fixture
def completer():
import pgcli.pgcompleter as pgcompleter
comp = pgcompleter.PGCompleter(smart_completion=True)
comp.extend_table_names(tables.keys())
for t in tables:
comp.extend_column_names(t, tables[t])
schemata = ['public']
tables, columns = [], []
for table, cols in metadata.items():
tables.append(('public', table))
columns.extend([('public', table, col) for col in cols])
comp.extend_schemata(schemata)
comp.extend_tables(tables)
comp.extend_columns(columns)
comp.set_search_path(['public'])
return comp
@pytest.fixture
@ -40,15 +50,26 @@ def test_select_keyword_completion(completer, complete_event):
complete_event)
assert set(result) == set([Completion(text='SELECT', start_position=-3)])
def test_schema_or_visible_table_completion(completer, complete_event):
text = 'SELECT * FROM '
position = len(text)
result = completer.get_completions(
Document(text=text, cursor_position=position), complete_event)
assert set(result) == set([Completion(text='public', start_position=0),
Completion(text='users', start_position=0),
Completion(text='"select"', start_position=0),
Completion(text='orders', start_position=0)])
def test_function_name_completion(completer, complete_event):
text = 'SELECT MA'
position = len('SELECT MA')
result = completer.get_completions(
Document(text=text, cursor_position=position),
complete_event)
Document(text=text, cursor_position=position), complete_event)
assert set(result) == set([Completion(text='MAX', start_position=-2)])
def test_suggested_column_names(completer, complete_event):
def test_suggested_column_names_from_visible_table(completer, complete_event):
"""
Suggest column and function names when selecting from table
:param completer:
@ -88,7 +109,7 @@ def test_suggested_column_names_in_function(completer, complete_event):
Completion(text='first_name', start_position=0),
Completion(text='last_name', start_position=0)])
def test_suggested_column_names_with_dot(completer, complete_event):
def test_suggested_column_names_with_table_dot(completer, complete_event):
"""
Suggest column names on table name and dot
:param completer:
@ -234,6 +255,7 @@ def test_table_names_after_from(completer, complete_event):
Document(text=text, cursor_position=position),
complete_event))
assert set(result) == set([
Completion(text='public', start_position=0),
Completion(text='users', start_position=0),
Completion(text='orders', start_position=0),
Completion(text='"select"', start_position=0),

View File

@ -1,143 +1,185 @@
from pgcli.packages.sqlcompletion import suggest_type
import pytest
def test_select_suggests_cols_with_table_scope():
suggestion = suggest_type('SELECT FROM tabl', 'SELECT ')
assert suggestion == ('columns-and-functions', ['tabl'])
def sorted_dicts(dicts):
"""input is a list of dicts"""
return sorted(tuple(x.items()) for x in dicts)
def test_select_suggests_cols_with_visible_table_scope():
suggestions = suggest_type('SELECT FROM tabl', 'SELECT ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'tabl', None)]},
{'type': 'function'}])
def test_select_suggests_cols_with_qualified_table_scope():
suggestions = suggest_type('SELECT FROM sch.tabl', 'SELECT ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [('sch', 'tabl', None)]},
{'type': 'function'}])
def test_where_suggests_columns_functions():
suggestion = suggest_type('SELECT * FROM tabl WHERE ',
suggestions = suggest_type('SELECT * FROM tabl WHERE ',
'SELECT * FROM tabl WHERE ')
assert suggestion == ('columns-and-functions', ['tabl'])
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'tabl', None)]},
{'type': 'function'}])
def test_lparen_suggests_cols():
suggestion = suggest_type('SELECT MAX( FROM tbl', 'SELECT MAX(')
assert suggestion == ('columns', ['tbl'])
assert suggestion == [
{'type': 'column', 'tables': [(None, 'tbl', None)]}]
def test_select_suggests_cols_and_funcs():
suggestion = suggest_type('SELECT ', 'SELECT ')
assert suggestion == ('columns-and-functions', [])
suggestions = suggest_type('SELECT ', 'SELECT ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': []},
{'type': 'function'}])
def test_from_suggests_tables():
suggestion = suggest_type('SELECT * FROM ', 'SELECT * FROM ')
assert suggestion == ('tables', [])
def test_from_suggests_tables_and_schemas():
suggestions = suggest_type('SELECT * FROM ', 'SELECT * FROM ')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'schema'}])
def test_distinct_suggests_cols():
suggestion = suggest_type('SELECT DISTINCT ', 'SELECT DISTINCT ')
assert suggestion == ('columns', [])
suggestions = suggest_type('SELECT DISTINCT ', 'SELECT DISTINCT ')
assert suggestions == [{'type': 'column', 'tables': []}]
def test_col_comma_suggests_cols():
suggestion = suggest_type('SELECT a, b, FROM tbl', 'SELECT a, b,')
assert suggestion == ('columns-and-functions', ['tbl'])
suggestions = suggest_type('SELECT a, b, FROM tbl', 'SELECT a, b,')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'tbl', None)]},
{'type': 'function'}])
def test_table_comma_suggests_tables():
suggestion = suggest_type('SELECT a, b FROM tbl1, ',
def test_table_comma_suggests_tables_and_schemas():
suggestions = suggest_type('SELECT a, b FROM tbl1, ',
'SELECT a, b FROM tbl1, ')
assert suggestion == ('tables', [])
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'schema'}])
def test_into_suggests_tables():
def test_into_suggests_tables_and_schemas():
suggestion = suggest_type('INSERT INTO ', 'INSERT INTO ')
assert suggestion == ('tables', [])
assert sorted_dicts(suggestion) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'schema'}])
def test_insert_into_lparen_suggests_cols():
suggestion = suggest_type('INSERT INTO abc (', 'INSERT INTO abc (')
assert suggestion == ('columns', ['abc'])
suggestions = suggest_type('INSERT INTO abc (', 'INSERT INTO abc (')
assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}]
def test_insert_into_lparen_partial_text_suggests_cols():
suggestion = suggest_type('INSERT INTO abc (i', 'INSERT INTO abc (i')
assert suggestion == ('columns', ['abc'])
suggestions = suggest_type('INSERT INTO abc (i', 'INSERT INTO abc (i')
assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}]
def test_insert_into_lparen_comma_suggests_cols():
suggestion = suggest_type('INSERT INTO abc (id,', 'INSERT INTO abc (id,')
assert suggestion == ('columns', ['abc'])
suggestions = suggest_type('INSERT INTO abc (id,', 'INSERT INTO abc (id,')
assert suggestions == [{'type': 'column', 'tables': [(None, 'abc', None)]}]
def test_partially_typed_col_name_suggests_col_names():
suggestion = suggest_type('SELECT * FROM tabl WHERE col_n',
suggestions = suggest_type('SELECT * FROM tabl WHERE col_n',
'SELECT * FROM tabl WHERE col_n')
assert suggestion == ('columns-and-functions', ['tabl'])
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'tabl', None)]},
{'type': 'function'}])
def test_dot_suggests_cols_of_a_table():
suggestion = suggest_type('SELECT tabl. FROM tabl', 'SELECT tabl.')
assert suggestion == ('columns', ['tabl'])
def test_dot_suggests_cols_of_a_table_or_schema_qualified_table():
suggestions = suggest_type('SELECT tabl. FROM tabl', 'SELECT tabl.')
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'tabl', None)]},
{'type': 'table', 'schema': 'tabl'}])
def test_dot_suggests_cols_of_an_alias():
suggestion = suggest_type('SELECT t1. FROM tabl1 t1, tabl2 t2',
suggestions = suggest_type('SELECT t1. FROM tabl1 t1, tabl2 t2',
'SELECT t1.')
assert suggestion == ('columns', ['tabl1'])
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'table', 'schema': 't1'},
{'type': 'column', 'tables': [(None, 'tabl1', 't1')]}])
def test_dot_col_comma_suggests_cols():
suggestion = suggest_type('SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2',
def test_dot_col_comma_suggests_cols_or_schema_qualified_table():
suggestions = suggest_type('SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2',
'SELECT t1.a, t2.')
assert suggestion == ('columns', ['tabl2'])
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'tabl2', 't2')]},
{'type': 'table', 'schema': 't2'}])
def test_sub_select_suggests_keyword():
suggestion = suggest_type('SELECT * FROM (', 'SELECT * FROM (')
assert suggestion == ('keywords', [])
assert suggestion == [{'type': 'keyword'}]
def test_sub_select_partial_text_suggests_keyword():
suggestion = suggest_type('SELECT * FROM (S', 'SELECT * FROM (S')
assert suggestion == ('keywords', [])
assert suggestion == [{'type': 'keyword'}]
def test_sub_select_table_name_completion():
suggestion = suggest_type('SELECT * FROM (SELECT * FROM ',
'SELECT * FROM (SELECT * FROM ')
assert suggestion == ('tables', [])
assert sorted_dicts(suggestion) == sorted_dicts([
{'type': 'table', 'schema': []}, {'type': 'schema'}])
def test_sub_select_col_name_completion():
suggestion = suggest_type('SELECT * FROM (SELECT FROM abc',
suggestions = suggest_type('SELECT * FROM (SELECT FROM abc',
'SELECT * FROM (SELECT ')
assert suggestion == ('columns-and-functions', ['abc'])
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'abc', None)]},
{'type': 'function'}])
@pytest.mark.xfail
def test_sub_select_multiple_col_name_completion():
suggestion = suggest_type('SELECT * FROM (SELECT a, FROM abc',
suggestions = suggest_type('SELECT * FROM (SELECT a, FROM abc',
'SELECT * FROM (SELECT a, ')
assert suggestion == ('columns-and-functions', ['abc'])
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'abc', None)]},
{'type': 'function'}])
def test_sub_select_dot_col_name_completion():
suggestion = suggest_type('SELECT * FROM (SELECT t. FROM tabl t',
suggestions = suggest_type('SELECT * FROM (SELECT t. FROM tabl t',
'SELECT * FROM (SELECT t.')
assert suggestion == ('columns', ['tabl'])
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'tabl', 't')]},
{'type': 'table', 'schema': 't'}])
def test_join_suggests_tables():
def test_join_suggests_tables_and_schemas():
suggestion = suggest_type('SELECT * FROM abc a JOIN ',
'SELECT * FROM abc a JOIN ')
assert suggestion == ('tables', [])
assert sorted_dicts(suggestion) == sorted_dicts([
{'type': 'table', 'schema': []},
{'type': 'schema'}])
def test_join_alias_dot_suggests_cols1():
suggestion = suggest_type('SELECT * FROM abc a JOIN def d ON a.',
suggestions = suggest_type('SELECT * FROM abc a JOIN def d ON a.',
'SELECT * FROM abc a JOIN def d ON a.')
assert suggestion == ('columns', ['abc'])
assert sorted_dicts(suggestions) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'abc', 'a')]},
{'type': 'table', 'schema': 'a'}])
def test_join_alias_dot_suggests_cols2():
suggestion = suggest_type('SELECT * FROM abc a JOIN def d ON a.',
'SELECT * FROM abc a JOIN def d ON a.id = d.')
assert suggestion == ('columns', ['def'])
assert sorted_dicts(suggestion) == sorted_dicts([
{'type': 'column', 'tables': [(None, 'def', 'd')]},
{'type': 'table', 'schema': 'd'}])
def test_on_suggests_aliases():
category, scope = suggest_type(
suggestions = suggest_type(
'select a.x, b.y from abc a join bcd b on ',
'select a.x, b.y from abc a join bcd b on ')
assert category == 'tables-or-aliases'
assert set(scope) == set(['a', 'b'])
assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}]
def test_on_suggests_tables():
category, scope = suggest_type(
suggestions = suggest_type(
'select abc.x, bcd.y from abc join bcd on ',
'select abc.x, bcd.y from abc join bcd on ')
assert category == 'tables-or-aliases'
assert set(scope) == set(['abc', 'bcd'])
assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}]
def test_on_suggests_aliases_right_side():
category, scope = suggest_type(
suggestions = suggest_type(
'select a.x, b.y from abc a join bcd b on a.id = ',
'select a.x, b.y from abc a join bcd b on a.id = ')
assert category == 'tables-or-aliases'
assert set(scope) == set(['a', 'b'])
assert suggestions == [{'type': 'alias', 'aliases': ['a', 'b']}]
def test_on_suggests_tables_right_side():
category, scope = suggest_type(
suggestions = suggest_type(
'select abc.x, bcd.y from abc join bcd on abc.id = ',
'select abc.x, bcd.y from abc join bcd on abc.id = ')
assert category == 'tables-or-aliases'
assert set(scope) == set(['abc', 'bcd'])
assert suggestions == [{'type': 'alias', 'aliases': ['abc', 'bcd']}]

View File

@ -32,7 +32,11 @@ def create_db(dbname):
def drop_tables(conn):
with conn.cursor() as cur:
cur.execute('''DROP SCHEMA public CASCADE; CREATE SCHEMA public''')
cur.execute('''
DROP SCHEMA public CASCADE;
CREATE SCHEMA public;
DROP SCHEMA IF EXISTS schema1 CASCADE;
DROP SCHEMA IF EXISTS schema2 CASCADE''')
def run(executor, sql, join=False):