1
0
Fork 0

Modernize code to Python 3.6+ (#1229)

1. `class A(object)` can be written as `class A:`
2. replace `dict([…])` and `set([…])` with `{…}`
3. use f-strings or compact `.format`
4. use `yield from` instead of `yield` in a `for` loop
5. import `mock` from `unittest`
6. expect `OSError` instead of `IOError` or `select` error
7. use Python3 defaults for file reading or `super()`
8. remove redundant parenthesis (keep those in tuples though)
9. shorten set intersection instead of creating lists
10. backslashes in strings do not have to be escaped if prepended with `r`
This commit is contained in:
Miroslav Šedivý 2021-02-12 20:34:56 +01:00 committed by GitHub
parent 87ffae295e
commit 762fb4b8da
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
31 changed files with 367 additions and 416 deletions

View File

@ -115,6 +115,7 @@ Contributors:
* Jan Brun Rasmussen (janbrunrasmussen)
* Kevin Marsh (kevinmarsh)
* Eero Ruohola (ruohola)
* Miroslav Šedivý (eumiro)
Creator:
--------

View File

@ -6,7 +6,7 @@ from .pgcompleter import PGCompleter
from .pgexecute import PGExecute
class CompletionRefresher(object):
class CompletionRefresher:
refreshers = OrderedDict()
@ -141,7 +141,7 @@ def refresh_casing(completer, executor):
with open(casing_file, "w") as f:
f.write(casing_prefs)
if os.path.isfile(casing_file):
with open(casing_file, "r") as f:
with open(casing_file) as f:
completer.extend_casing([line.strip() for line in f])

View File

@ -43,7 +43,7 @@ def pgcli_line_magic(line):
conn._pgcli = pgcli
# For convenience, print the connection alias
print("Connected: {}".format(conn.name))
print(f"Connected: {conn.name}")
try:
pgcli.run_cli()

View File

@ -122,7 +122,7 @@ class PgCliQuitError(Exception):
pass
class PGCli(object):
class PGCli:
default_prompt = "\\u@\\h:\\d> "
max_len_prompt = 30
@ -325,11 +325,11 @@ class PGCli(object):
if pattern not in TabularOutputFormatter().supported_formats:
raise ValueError()
self.table_format = pattern
yield (None, None, None, "Changed table format to {}".format(pattern))
yield (None, None, None, f"Changed table format to {pattern}")
except ValueError:
msg = "Table format {} not recognized. Allowed formats:".format(pattern)
msg = f"Table format {pattern} not recognized. Allowed formats:"
for table_type in TabularOutputFormatter().supported_formats:
msg += "\n\t{}".format(table_type)
msg += f"\n\t{table_type}"
msg += "\nCurrently set to: %s" % self.table_format
yield (None, None, None, msg)
@ -386,7 +386,7 @@ class PGCli(object):
try:
with open(os.path.expanduser(pattern), encoding="utf-8") as f:
query = f.read()
except IOError as e:
except OSError as e:
return [(None, None, None, str(e), "", False, True)]
if self.destructive_warning and confirm_destructive_query(query) is False:
@ -407,7 +407,7 @@ class PGCli(object):
if not os.path.isfile(filename):
try:
open(filename, "w").close()
except IOError as e:
except OSError as e:
self.output_file = None
message = str(e) + "\nFile output disabled"
return [(None, None, None, message, "", False, True)]
@ -479,7 +479,7 @@ class PGCli(object):
service_config, file = parse_service_info(service)
if service_config is None:
click.secho(
"service '%s' was not found in %s" % (service, file), err=True, fg="red"
f"service '{service}' was not found in {file}", err=True, fg="red"
)
exit(1)
self.connect(
@ -515,7 +515,7 @@ class PGCli(object):
passwd = os.environ.get("PGPASSWORD", "")
# Find password from store
key = "%s@%s" % (user, host)
key = f"{user}@{host}"
keyring_error_message = dedent(
"""\
{}
@ -677,7 +677,7 @@ class PGCli(object):
click.echo(text, file=f)
click.echo("\n".join(output), file=f)
click.echo("", file=f) # extra newline
except IOError as e:
except OSError as e:
click.secho(str(e), err=True, fg="red")
else:
if output:
@ -753,11 +753,7 @@ class PGCli(object):
while self.watch_command:
try:
query = self.execute_command(self.watch_command)
click.echo(
"Waiting for {0} seconds before repeating".format(
timing
)
)
click.echo(f"Waiting for {timing} seconds before repeating")
sleep(timing)
except KeyboardInterrupt:
self.watch_command = None
@ -1049,7 +1045,7 @@ class PGCli(object):
str(self.pgexecute.port) if self.pgexecute.port is not None else "5432",
)
string = string.replace("\\i", str(self.pgexecute.pid) or "(none)")
string = string.replace("\\#", "#" if (self.pgexecute.superuser) else ">")
string = string.replace("\\#", "#" if self.pgexecute.superuser else ">")
string = string.replace("\\n", "\n")
return string
@ -1384,7 +1380,7 @@ def is_mutating(status):
if not status:
return False
mutating = set(["insert", "update", "delete"])
mutating = {"insert", "update", "delete"}
return status.split(None, 1)[0].lower() in mutating

View File

@ -50,7 +50,7 @@ def parse_defaults(defaults_string):
yield current
class FunctionMetadata(object):
class FunctionMetadata:
def __init__(
self,
schema_name,

View File

@ -42,8 +42,7 @@ def extract_from_part(parsed, stop_at_punctuation=True):
for item in parsed.tokens:
if tbl_prefix_seen:
if is_subselect(item):
for x in extract_from_part(item, stop_at_punctuation):
yield x
yield from extract_from_part(item, stop_at_punctuation)
elif stop_at_punctuation and item.ttype is Punctuation:
return
# An incomplete nested select won't be recognized correctly as a

View File

@ -16,10 +16,10 @@ def _compile_regex(keyword):
keywords = get_literals("keywords")
keyword_regexs = dict((kw, _compile_regex(kw)) for kw in keywords)
keyword_regexs = {kw: _compile_regex(kw) for kw in keywords}
class PrevalenceCounter(object):
class PrevalenceCounter:
def __init__(self):
self.keyword_counts = defaultdict(int)
self.name_counts = defaultdict(int)

View File

@ -47,7 +47,7 @@ Alias = namedtuple("Alias", ["aliases"])
Path = namedtuple("Path", [])
class SqlStatement(object):
class SqlStatement:
def __init__(self, full_text, text_before_cursor):
self.identifier = None
self.word_before_cursor = word_before_cursor = last_word(

View File

@ -83,7 +83,7 @@ class PGCompleter(Completer):
reserved_words = set(get_literals("reserved"))
def __init__(self, smart_completion=True, pgspecial=None, settings=None):
super(PGCompleter, self).__init__()
super().__init__()
self.smart_completion = smart_completion
self.pgspecial = pgspecial
self.prioritizer = PrevalenceCounter()
@ -177,7 +177,7 @@ class PGCompleter(Completer):
:return:
"""
# casing should be a dict {lowercasename:PreferredCasingName}
self.casing = dict((word.lower(), word) for word in words)
self.casing = {word.lower(): word for word in words}
def extend_relations(self, data, kind):
"""extend metadata for tables or views.
@ -279,8 +279,8 @@ class PGCompleter(Completer):
fk = ForeignKey(
parentschema, parenttable, parcol, childschema, childtable, childcol
)
childcolmeta.foreignkeys.append((fk))
parcolmeta.foreignkeys.append((fk))
childcolmeta.foreignkeys.append(fk)
parcolmeta.foreignkeys.append(fk)
def extend_datatypes(self, type_data):
@ -424,7 +424,7 @@ class PGCompleter(Completer):
# the same priority as unquoted names.
lexical_priority = (
tuple(
0 if c in (" _") else -ord(c)
0 if c in " _" else -ord(c)
for c in self.unescape_name(item.lower())
)
+ (1,)
@ -517,9 +517,9 @@ class PGCompleter(Completer):
# require_last_table is used for 'tb11 JOIN tbl2 USING (...' which should
# suggest only columns that appear in the last table and one more
ltbl = tables[-1].ref
other_tbl_cols = set(
other_tbl_cols = {
c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs
)
}
scoped_cols = {
t: [col for col in cols if col.name in other_tbl_cols]
for t, cols in scoped_cols.items()
@ -574,7 +574,7 @@ class PGCompleter(Completer):
tbls - TableReference iterable of tables already in query
"""
tbl = self.case(tbl)
tbls = set(normalize_ref(t.ref) for t in tbls)
tbls = {normalize_ref(t.ref) for t in tbls}
if self.generate_aliases:
tbl = generate_alias(self.unescape_name(tbl))
if normalize_ref(tbl) not in tbls:
@ -589,10 +589,10 @@ class PGCompleter(Completer):
tbls = suggestion.table_refs
cols = self.populate_scoped_cols(tbls)
# Set up some data structures for efficient access
qualified = dict((normalize_ref(t.ref), t.schema) for t in tbls)
ref_prio = dict((normalize_ref(t.ref), n) for n, t in enumerate(tbls))
refs = set(normalize_ref(t.ref) for t in tbls)
other_tbls = set((t.schema, t.name) for t in list(cols)[:-1])
qualified = {normalize_ref(t.ref): t.schema for t in tbls}
ref_prio = {normalize_ref(t.ref): n for n, t in enumerate(tbls)}
refs = {normalize_ref(t.ref) for t in tbls}
other_tbls = {(t.schema, t.name) for t in list(cols)[:-1]}
joins = []
# Iterate over FKs in existing tables to find potential joins
fks = (
@ -667,7 +667,7 @@ class PGCompleter(Completer):
return d
# Tables that are closer to the cursor get higher prio
ref_prio = dict((tbl.ref, num) for num, tbl in enumerate(suggestion.table_refs))
ref_prio = {tbl.ref: num for num, tbl in enumerate(suggestion.table_refs)}
# Map (schema, table, col) to tables
coldict = list_dict(
((t.schema, t.name, c.name), t) for t, c in cols if t.ref != lref
@ -721,9 +721,7 @@ class PGCompleter(Completer):
# Function overloading means we way have multiple functions of the same
# name at this point, so keep unique names only
all_functions = self.populate_functions(suggestion.schema, filt)
funcs = set(
self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions
)
funcs = {self._make_cand(f, alias, suggestion, arg_mode) for f in all_functions}
matches = self.find_matches(word_before_cursor, funcs, meta="function")
@ -953,7 +951,7 @@ class PGCompleter(Completer):
:return: {TableReference:{colname:ColumnMetaData}}
"""
ctes = dict((normalize_ref(t.name), t.columns) for t in local_tbls)
ctes = {normalize_ref(t.name): t.columns for t in local_tbls}
columns = OrderedDict()
meta = self.dbmetadata

View File

@ -49,7 +49,7 @@ def _wait_select(conn):
conn.cancel()
# the loop will be broken by a server error
continue
except select.error as e:
except OSError as e:
errno = e.args[0]
if errno != 4:
raise
@ -127,7 +127,7 @@ def register_hstore_typecaster(conn):
pass
class PGExecute(object):
class PGExecute:
# The boolean argument to the current_schemas function indicates whether
# implicit schemas, e.g. pg_catalog
@ -485,7 +485,7 @@ class PGExecute(object):
try:
cur.execute(sql, (spec,))
except psycopg2.ProgrammingError:
raise RuntimeError("View {} does not exist.".format(spec))
raise RuntimeError(f"View {spec} does not exist.")
result = cur.fetchone()
view_type = "MATERIALIZED" if result[2] == "m" else ""
return template.format(*result + (view_type,))
@ -501,7 +501,7 @@ class PGExecute(object):
result = cur.fetchone()
return result[0]
except psycopg2.ProgrammingError:
raise RuntimeError("Function {} does not exist.".format(spec))
raise RuntimeError(f"Function {spec} does not exist.")
def schemata(self):
"""Returns a list of schema names in the database"""
@ -527,21 +527,18 @@ class PGExecute(object):
sql = cur.mogrify(self.tables_query, [kinds])
_logger.debug("Tables Query. sql: %r", sql)
cur.execute(sql)
for row in cur:
yield row
yield from cur
def tables(self):
"""Yields (schema_name, table_name) tuples"""
for row in self._relations(kinds=["r", "p", "f"]):
yield row
yield from self._relations(kinds=["r", "p", "f"])
def views(self):
"""Yields (schema_name, view_name) tuples.
Includes both views and and materialized views
"""
for row in self._relations(kinds=["v", "m"]):
yield row
yield from self._relations(kinds=["v", "m"])
def _columns(self, kinds=("r", "p", "f", "v", "m")):
"""Get column metadata for tables and views
@ -599,16 +596,13 @@ class PGExecute(object):
sql = cur.mogrify(columns_query, [kinds])
_logger.debug("Columns Query. sql: %r", sql)
cur.execute(sql)
for row in cur:
yield row
yield from cur
def table_columns(self):
for row in self._columns(kinds=["r", "p", "f"]):
yield row
yield from self._columns(kinds=["r", "p", "f"])
def view_columns(self):
for row in self._columns(kinds=["v", "m"]):
yield row
yield from self._columns(kinds=["v", "m"])
def databases(self):
with self.conn.cursor() as cur:
@ -804,8 +798,7 @@ class PGExecute(object):
"""
_logger.debug("Datatypes Query. sql: %r", query)
cur.execute(query)
for row in cur:
yield row
yield from cur
def casing(self):
"""Yields the most common casing for names used in db functions"""

View File

@ -1,5 +1,4 @@
pytest>=2.7.0
mock>=1.0.1
tox>=1.9.2
behave>=1.2.4
pexpect==3.3

View File

@ -44,7 +44,7 @@ def create_cn(hostname, password, username, dbname, port):
host=hostname, user=username, database=dbname, password=password, port=port
)
print("Created connection: {0}.".format(cn.dsn))
print(f"Created connection: {cn.dsn}.")
return cn
@ -75,4 +75,4 @@ def close_cn(cn=None):
"""
if cn:
cn.close()
print("Closed connection: {0}.".format(cn.dsn))
print(f"Closed connection: {cn.dsn}.")

View File

@ -38,7 +38,7 @@ def before_all(context):
vi = "_".join([str(x) for x in sys.version_info[:3]])
db_name = context.config.userdata.get("pg_test_db", "pgcli_behave_tests")
db_name_full = "{0}_{1}".format(db_name, vi)
db_name_full = f"{db_name}_{vi}"
# Store get params from config.
context.conf = {
@ -122,12 +122,12 @@ def before_all(context):
def show_env_changes(env_old, env_new):
"""Print out all test-specific env values."""
print("--- os.environ changed values: ---")
all_keys = set(list(env_old.keys()) + list(env_new.keys()))
all_keys = env_old.keys() | env_new.keys()
for k in sorted(all_keys):
old_value = env_old.get(k, "")
new_value = env_new.get(k, "")
if new_value and old_value != new_value:
print('{}="{}"'.format(k, new_value))
print(f'{k}="{new_value}"')
print("-" * 20)
@ -173,13 +173,13 @@ def after_scenario(context, scenario):
# Quit nicely.
if not context.atprompt:
dbname = context.currentdb
context.cli.expect_exact("{0}> ".format(dbname), timeout=15)
context.cli.expect_exact(f"{dbname}> ", timeout=15)
context.cli.sendcontrol("c")
context.cli.sendcontrol("d")
try:
context.cli.expect_exact(pexpect.EOF, timeout=15)
except pexpect.TIMEOUT:
print("--- after_scenario {}: kill cli".format(scenario.name))
print(f"--- after_scenario {scenario.name}: kill cli")
context.cli.kill(signal.SIGKILL)
if hasattr(context, "tmpfile_sql_help") and context.tmpfile_sql_help:
context.tmpfile_sql_help.close()

View File

@ -18,7 +18,7 @@ def read_fixture_files():
"""Read all files inside fixture_data directory."""
current_dir = os.path.dirname(__file__)
fixture_dir = os.path.join(current_dir, "fixture_data/")
print("reading fixture data: {}".format(fixture_dir))
print(f"reading fixture data: {fixture_dir}")
fixture_dict = {}
for filename in os.listdir(fixture_dir):
if filename not in [".", ".."]:

View File

@ -66,19 +66,19 @@ def step_ctrl_d(context):
"""
# turn off pager before exiting
context.cli.sendcontrol("c")
context.cli.sendline("\pset pager off")
context.cli.sendline(r"\pset pager off")
wrappers.wait_prompt(context)
context.cli.sendcontrol("d")
context.cli.expect(pexpect.EOF, timeout=15)
context.exit_sent = True
@when('we send "\?" command')
@when(r'we send "\?" command')
def step_send_help(context):
"""
r"""
Send \? to see help.
"""
context.cli.sendline("\?")
context.cli.sendline(r"\?")
@when("we send partial select command")
@ -97,9 +97,9 @@ def step_see_error_message(context):
@when("we send source command")
def step_send_source_command(context):
context.tmpfile_sql_help = tempfile.NamedTemporaryFile(prefix="pgcli_")
context.tmpfile_sql_help.write(b"\?")
context.tmpfile_sql_help.write(br"\?")
context.tmpfile_sql_help.flush()
context.cli.sendline("\i {0}".format(context.tmpfile_sql_help.name))
context.cli.sendline(fr"\i {context.tmpfile_sql_help.name}")
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)

View File

@ -14,7 +14,7 @@ def step_db_create(context):
"""
Send create database.
"""
context.cli.sendline("create database {0};".format(context.conf["dbname_tmp"]))
context.cli.sendline("create database {};".format(context.conf["dbname_tmp"]))
context.response = {"database_name": context.conf["dbname_tmp"]}
@ -24,7 +24,7 @@ def step_db_drop(context):
"""
Send drop database.
"""
context.cli.sendline("drop database {0};".format(context.conf["dbname_tmp"]))
context.cli.sendline("drop database {};".format(context.conf["dbname_tmp"]))
@when("we connect to test database")
@ -33,7 +33,7 @@ def step_db_connect_test(context):
Send connect to database.
"""
db_name = context.conf["dbname"]
context.cli.sendline("\\connect {0}".format(db_name))
context.cli.sendline(f"\\connect {db_name}")
@when("we connect to dbserver")
@ -59,7 +59,7 @@ def step_see_prompt(context):
Wait to see the prompt.
"""
db_name = getattr(context, "currentdb", context.conf["dbname"])
wrappers.expect_exact(context, "{0}> ".format(db_name), timeout=5)
wrappers.expect_exact(context, f"{db_name}> ", timeout=5)
context.atprompt = True

View File

@ -31,7 +31,7 @@ def step_prepare_data(context):
@when("we set expanded {mode}")
def step_set_expanded(context, mode):
"""Set expanded to mode."""
context.cli.sendline("\\" + "x {}".format(mode))
context.cli.sendline("\\" + f"x {mode}")
wrappers.expect_exact(context, "Expanded display is", timeout=2)
wrappers.wait_prompt(context)

View File

@ -13,7 +13,7 @@ def step_edit_file(context):
)
if os.path.exists(context.editor_file_name):
os.remove(context.editor_file_name)
context.cli.sendline("\e {0}".format(os.path.basename(context.editor_file_name)))
context.cli.sendline(r"\e {}".format(os.path.basename(context.editor_file_name)))
wrappers.expect_exact(
context, 'Entering Ex mode. Type "visual" to go to Normal mode.', timeout=2
)
@ -53,7 +53,7 @@ def step_tee_ouptut(context):
)
if os.path.exists(context.tee_file_name):
os.remove(context.tee_file_name)
context.cli.sendline("\o {0}".format(os.path.basename(context.tee_file_name)))
context.cli.sendline(r"\o {}".format(os.path.basename(context.tee_file_name)))
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
wrappers.expect_exact(context, "Writing to file", timeout=5)
wrappers.expect_exact(context, context.conf["pager_boundary"] + "\r\n", timeout=5)
@ -67,7 +67,7 @@ def step_query_select_123456(context):
@when("we stop teeing output")
def step_notee_output(context):
context.cli.sendline("\o")
context.cli.sendline(r"\o")
wrappers.expect_exact(context, "Time", timeout=5)

View File

@ -57,7 +57,7 @@ def run_cli(context, run_args=None, prompt_check=True, currentdb=None):
context.cli.logfile = context.logfile
context.exit_sent = False
context.currentdb = currentdb or context.conf["dbname"]
context.cli.sendline("\pset pager always")
context.cli.sendline(r"\pset pager always")
if prompt_check:
wait_prompt(context)

View File

@ -3,7 +3,7 @@ from itertools import product
from pgcli.packages.parseutils.meta import FunctionMetadata, ForeignKey
from prompt_toolkit.completion import Completion
from prompt_toolkit.document import Document
from mock import Mock
from unittest.mock import Mock
import pytest
parametrize = pytest.mark.parametrize
@ -59,7 +59,7 @@ def wildcard_expansion(cols, pos=-1):
return Completion(cols, start_position=pos, display_meta="columns", display="*")
class MetaData(object):
class MetaData:
def __init__(self, metadata):
self.metadata = metadata
@ -128,7 +128,7 @@ class MetaData(object):
]
def schemas(self, pos=0):
schemas = set(sch for schs in self.metadata.values() for sch in schs)
schemas = {sch for schs in self.metadata.values() for sch in schs}
return [schema(escape(s), pos=pos) for s in schemas]
def functions_and_keywords(self, parent="public", pos=0):

View File

@ -34,12 +34,12 @@ def test_simple_select_single_table_double_quoted():
def test_simple_select_multiple_tables():
tables = extract_tables("select * from abc, def")
assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)])
assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)}
def test_simple_select_multiple_tables_double_quoted():
tables = extract_tables('select * from "Abc", "Def"')
assert set(tables) == set([(None, "Abc", None, False), (None, "Def", None, False)])
assert set(tables) == {(None, "Abc", None, False), (None, "Def", None, False)}
def test_simple_select_single_table_deouble_quoted_aliased():
@ -49,14 +49,12 @@ def test_simple_select_single_table_deouble_quoted_aliased():
def test_simple_select_multiple_tables_deouble_quoted_aliased():
tables = extract_tables('select * from "Abc" a, "Def" d')
assert set(tables) == set([(None, "Abc", "a", False), (None, "Def", "d", False)])
assert set(tables) == {(None, "Abc", "a", False), (None, "Def", "d", False)}
def test_simple_select_multiple_tables_schema_qualified():
tables = extract_tables("select * from abc.def, ghi.jkl")
assert set(tables) == set(
[("abc", "def", None, False), ("ghi", "jkl", None, False)]
)
assert set(tables) == {("abc", "def", None, False), ("ghi", "jkl", None, False)}
def test_simple_select_with_cols_single_table():
@ -71,14 +69,12 @@ def test_simple_select_with_cols_single_table_schema_qualified():
def test_simple_select_with_cols_multiple_tables():
tables = extract_tables("select a,b from abc, def")
assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)])
assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)}
def test_simple_select_with_cols_multiple_qualified_tables():
tables = extract_tables("select a,b from abc.def, def.ghi")
assert set(tables) == set(
[("abc", "def", None, False), ("def", "ghi", None, False)]
)
assert set(tables) == {("abc", "def", None, False), ("def", "ghi", None, False)}
def test_select_with_hanging_comma_single_table():
@ -88,14 +84,12 @@ def test_select_with_hanging_comma_single_table():
def test_select_with_hanging_comma_multiple_tables():
tables = extract_tables("select a, from abc, def")
assert set(tables) == set([(None, "abc", None, False), (None, "def", None, False)])
assert set(tables) == {(None, "abc", None, False), (None, "def", None, False)}
def test_select_with_hanging_period_multiple_tables():
tables = extract_tables("SELECT t1. FROM tabl1 t1, tabl2 t2")
assert set(tables) == set(
[(None, "tabl1", "t1", False), (None, "tabl2", "t2", False)]
)
assert set(tables) == {(None, "tabl1", "t1", False), (None, "tabl2", "t2", False)}
def test_simple_insert_single_table():
@ -126,14 +120,14 @@ def test_simple_update_table_with_schema():
@pytest.mark.parametrize("join_type", ["", "INNER", "LEFT", "RIGHT OUTER"])
def test_join_table(join_type):
sql = "SELECT * FROM abc a {0} JOIN def d ON a.id = d.num".format(join_type)
sql = f"SELECT * FROM abc a {join_type} JOIN def d ON a.id = d.num"
tables = extract_tables(sql)
assert set(tables) == set([(None, "abc", "a", False), (None, "def", "d", False)])
assert set(tables) == {(None, "abc", "a", False), (None, "def", "d", False)}
def test_join_table_schema_qualified():
tables = extract_tables("SELECT * FROM abc.def x JOIN ghi.jkl y ON x.id = y.num")
assert set(tables) == set([("abc", "def", "x", False), ("ghi", "jkl", "y", False)])
assert set(tables) == {("abc", "def", "x", False), ("ghi", "jkl", "y", False)}
def test_incomplete_join_clause():
@ -177,25 +171,25 @@ def test_extract_no_tables(text):
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
def test_simple_function_as_table(arg_list):
tables = extract_tables("SELECT * FROM foo({0})".format(arg_list))
tables = extract_tables(f"SELECT * FROM foo({arg_list})")
assert tables == ((None, "foo", None, True),)
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
def test_simple_schema_qualified_function_as_table(arg_list):
tables = extract_tables("SELECT * FROM foo.bar({0})".format(arg_list))
tables = extract_tables(f"SELECT * FROM foo.bar({arg_list})")
assert tables == (("foo", "bar", None, True),)
@pytest.mark.parametrize("arg_list", ["", "arg1", "arg1, arg2, arg3"])
def test_simple_aliased_function_as_table(arg_list):
tables = extract_tables("SELECT * FROM foo({0}) bar".format(arg_list))
tables = extract_tables(f"SELECT * FROM foo({arg_list}) bar")
assert tables == ((None, "foo", "bar", True),)
def test_simple_table_and_function():
tables = extract_tables("SELECT * FROM foo JOIN bar()")
assert set(tables) == set([(None, "foo", None, False), (None, "bar", None, True)])
assert set(tables) == {(None, "foo", None, False), (None, "bar", None, True)}
def test_complex_table_and_function():
@ -203,9 +197,7 @@ def test_complex_table_and_function():
"""SELECT * FROM foo.bar baz
JOIN bar.qux(x, y, z) quux"""
)
assert set(tables) == set(
[("foo", "bar", "baz", False), ("bar", "qux", "quux", True)]
)
assert set(tables) == {("foo", "bar", "baz", False), ("bar", "qux", "quux", True)}
def test_find_prev_keyword_using():

View File

@ -1,6 +1,6 @@
import time
import pytest
from mock import Mock, patch
from unittest.mock import Mock, patch
@pytest.fixture

View File

@ -1,6 +1,6 @@
import os
import platform
import mock
from unittest import mock
import pytest

View File

@ -13,7 +13,7 @@ def completer():
@pytest.fixture
def complete_event():
from mock import Mock
from unittest.mock import Mock
return Mock()

View File

@ -2,7 +2,7 @@ from textwrap import dedent
import psycopg2
import pytest
from mock import patch, MagicMock
from unittest.mock import patch, MagicMock
from pgspecial.main import PGSpecial, NO_QUERY
from utils import run, dbtest, requires_json, requires_jsonb
@ -89,7 +89,7 @@ def test_expanded_slash_G(executor, pgspecial):
# Tests whether we reset the expanded output after a \G.
run(executor, """create table test(a boolean)""")
run(executor, """insert into test values(True)""")
results = run(executor, """select * from test \G""", pgspecial=pgspecial)
results = run(executor, r"""select * from test \G""", pgspecial=pgspecial)
assert pgspecial.expanded_output == False
@ -105,31 +105,35 @@ def test_schemata_table_views_and_columns_query(executor):
# schemata
# don't enforce all members of the schemas since they may include postgres
# temporary schemas
assert set(executor.schemata()) >= set(
["public", "pg_catalog", "information_schema", "schema1", "schema2"]
)
assert set(executor.schemata()) >= {
"public",
"pg_catalog",
"information_schema",
"schema1",
"schema2",
}
assert executor.search_path() == ["pg_catalog", "public"]
# tables
assert set(executor.tables()) >= set(
[("public", "a"), ("public", "b"), ("schema1", "c")]
)
assert set(executor.tables()) >= {
("public", "a"),
("public", "b"),
("schema1", "c"),
}
assert set(executor.table_columns()) >= set(
[
("public", "a", "x", "text", False, None),
("public", "a", "y", "text", False, None),
("public", "b", "z", "text", False, None),
("schema1", "c", "w", "text", True, "'meow'::text"),
]
)
assert set(executor.table_columns()) >= {
("public", "a", "x", "text", False, None),
("public", "a", "y", "text", False, None),
("public", "b", "z", "text", False, None),
("schema1", "c", "w", "text", True, "'meow'::text"),
}
# views
assert set(executor.views()) >= set([("public", "d")])
assert set(executor.views()) >= {("public", "d")}
assert set(executor.view_columns()) >= set(
[("public", "d", "e", "integer", False, None)]
)
assert set(executor.view_columns()) >= {
("public", "d", "e", "integer", False, None)
}
@dbtest
@ -142,9 +146,9 @@ def test_foreign_key_query(executor):
"create table schema2.child(childid int PRIMARY KEY, motherid int REFERENCES schema1.parent)",
)
assert set(executor.foreignkeys()) >= set(
[("schema1", "parent", "parentid", "schema2", "child", "motherid")]
)
assert set(executor.foreignkeys()) >= {
("schema1", "parent", "parentid", "schema2", "child", "motherid")
}
@dbtest
@ -175,30 +179,28 @@ def test_functions_query(executor):
)
funcs = set(executor.functions())
assert funcs >= set(
[
function_meta_data(func_name="func1", return_type="integer"),
function_meta_data(
func_name="func3",
arg_names=["x", "y"],
arg_types=["integer", "integer"],
arg_modes=["t", "t"],
return_type="record",
is_set_returning=True,
),
function_meta_data(
schema_name="public",
func_name="func4",
arg_names=("x",),
arg_types=("integer",),
return_type="integer",
is_set_returning=True,
),
function_meta_data(
schema_name="schema1", func_name="func2", return_type="integer"
),
]
)
assert funcs >= {
function_meta_data(func_name="func1", return_type="integer"),
function_meta_data(
func_name="func3",
arg_names=["x", "y"],
arg_types=["integer", "integer"],
arg_modes=["t", "t"],
return_type="record",
is_set_returning=True,
),
function_meta_data(
schema_name="public",
func_name="func4",
arg_names=("x",),
arg_types=("integer",),
return_type="integer",
is_set_returning=True,
),
function_meta_data(
schema_name="schema1", func_name="func2", return_type="integer"
),
}
@dbtest
@ -257,8 +259,8 @@ def test_not_is_special(executor, pgspecial):
@dbtest
def test_execute_from_file_no_arg(executor, pgspecial):
"""\i without a filename returns an error."""
result = list(executor.run("\i", pgspecial=pgspecial))
r"""\i without a filename returns an error."""
result = list(executor.run(r"\i", pgspecial=pgspecial))
status, sql, success, is_special = result[0][3:]
assert "missing required argument" in status
assert success == False
@ -268,12 +270,12 @@ def test_execute_from_file_no_arg(executor, pgspecial):
@dbtest
@patch("pgcli.main.os")
def test_execute_from_file_io_error(os, executor, pgspecial):
"""\i with an io_error returns an error."""
# Inject an IOError.
os.path.expanduser.side_effect = IOError("test")
r"""\i with an os_error returns an error."""
# Inject an OSError.
os.path.expanduser.side_effect = OSError("test")
# Check the result.
result = list(executor.run("\i test", pgspecial=pgspecial))
result = list(executor.run(r"\i test", pgspecial=pgspecial))
status, sql, success, is_special = result[0][3:]
assert status == "test"
assert success == False
@ -290,7 +292,7 @@ def test_multiple_queries_same_line(executor):
@dbtest
def test_multiple_queries_with_special_command_same_line(executor, pgspecial):
result = run(executor, "select 'foo'; \d", pgspecial=pgspecial)
result = run(executor, r"select 'foo'; \d", pgspecial=pgspecial)
assert len(result) == 11 # 2 * (output+status) * 3 lines
assert "foo" in result[3]
# This is a lame check. :(
@ -408,7 +410,7 @@ def test_date_time_types(executor):
@pytest.mark.parametrize("value", ["10000000", "10000000.0", "10000000000000"])
def test_large_numbers_render_directly(executor, value):
run(executor, "create table numbertest(a numeric)")
run(executor, "insert into numbertest (a) values ({0})".format(value))
run(executor, f"insert into numbertest (a) values ({value})")
assert value in run(executor, "select * from numbertest", join=True)
@ -511,7 +513,7 @@ def test_short_host(executor):
assert executor.short_host == "localhost1"
class BrokenConnection(object):
class BrokenConnection:
"""Mock a connection that failed."""
def cursor(self):

View File

@ -13,12 +13,12 @@ from pgcli.packages.sqlcompletion import (
def test_slash_suggests_special():
suggestions = suggest_type("\\", "\\")
assert set(suggestions) == set([Special()])
assert set(suggestions) == {Special()}
def test_slash_d_suggests_special():
suggestions = suggest_type("\\d", "\\d")
assert set(suggestions) == set([Special()])
assert set(suggestions) == {Special()}
def test_dn_suggests_schemata():
@ -30,24 +30,24 @@ def test_dn_suggests_schemata():
def test_d_suggests_tables_views_and_schemas():
suggestions = suggest_type("\d ", "\d ")
assert set(suggestions) == set([Schema(), Table(schema=None), View(schema=None)])
suggestions = suggest_type(r"\d ", r"\d ")
assert set(suggestions) == {Schema(), Table(schema=None), View(schema=None)}
suggestions = suggest_type("\d xxx", "\d xxx")
assert set(suggestions) == set([Schema(), Table(schema=None), View(schema=None)])
suggestions = suggest_type(r"\d xxx", r"\d xxx")
assert set(suggestions) == {Schema(), Table(schema=None), View(schema=None)}
def test_d_dot_suggests_schema_qualified_tables_or_views():
suggestions = suggest_type("\d myschema.", "\d myschema.")
assert set(suggestions) == set([Table(schema="myschema"), View(schema="myschema")])
suggestions = suggest_type(r"\d myschema.", r"\d myschema.")
assert set(suggestions) == {Table(schema="myschema"), View(schema="myschema")}
suggestions = suggest_type("\d myschema.xxx", "\d myschema.xxx")
assert set(suggestions) == set([Table(schema="myschema"), View(schema="myschema")])
suggestions = suggest_type(r"\d myschema.xxx", r"\d myschema.xxx")
assert set(suggestions) == {Table(schema="myschema"), View(schema="myschema")}
def test_df_suggests_schema_or_function():
suggestions = suggest_type("\\df xxx", "\\df xxx")
assert set(suggestions) == set([Function(schema=None, usage="special"), Schema()])
assert set(suggestions) == {Function(schema=None, usage="special"), Schema()}
suggestions = suggest_type("\\df myschema.xxx", "\\df myschema.xxx")
assert suggestions == (Function(schema="myschema", usage="special"),)
@ -63,7 +63,7 @@ def test_leading_whitespace_ok():
def test_dT_suggests_schema_or_datatypes():
text = "\\dT "
suggestions = suggest_type(text, text)
assert set(suggestions) == set([Schema(), Datatype(schema=None)])
assert set(suggestions) == {Schema(), Datatype(schema=None)}
def test_schema_qualified_dT_suggests_datatypes():

View File

@ -1,5 +1,5 @@
import pytest
from mock import Mock
from unittest.mock import Mock
from pgcli.main import PGCli

View File

@ -193,7 +193,7 @@ def test_suggested_joins(completer, query, tbl):
result = get_result(completer, query.format(tbl))
assert completions_to_set(result) == completions_to_set(
testdata.schemas_and_from_clause_items()
+ [join("custom.shipments ON shipments.user_id = {0}.id".format(tbl))]
+ [join(f"custom.shipments ON shipments.user_id = {tbl}.id")]
)

View File

@ -53,7 +53,7 @@ metadata = {
],
}
metadata = dict((k, {"public": v}) for k, v in metadata.items())
metadata = {k: {"public": v} for k, v in metadata.items()}
testdata = MetaData(metadata)
@ -296,7 +296,7 @@ def test_suggested_cased_always_qualified_column_names(completer):
def test_suggested_column_names_in_function(completer):
result = get_result(completer, "SELECT MAX( from users", len("SELECT MAX("))
assert completions_to_set(result) == completions_to_set(
(testdata.columns_functions_and_keywords("users"))
testdata.columns_functions_and_keywords("users")
)
@ -316,7 +316,7 @@ def test_suggested_column_names_with_alias(completer):
def test_suggested_multiple_column_names(completer):
result = get_result(completer, "SELECT id, from users u", len("SELECT id, "))
assert completions_to_set(result) == completions_to_set(
(testdata.columns_functions_and_keywords("users"))
testdata.columns_functions_and_keywords("users")
)

View File

@ -23,16 +23,14 @@ def cols_etc(
):
"""Returns the expected select-clause suggestions for a single-table
select."""
return set(
[
Column(
table_refs=(TableReference(schema, table, alias, is_function),),
qualifiable=True,
),
Function(schema=parent),
Keyword(last_keyword),
]
)
return {
Column(
table_refs=(TableReference(schema, table, alias, is_function),),
qualifiable=True,
),
Function(schema=parent),
Keyword(last_keyword),
}
def test_select_suggests_cols_with_visible_table_scope():
@ -103,24 +101,20 @@ def test_where_equals_any_suggests_columns_or_keywords():
def test_lparen_suggests_cols_and_funcs():
suggestion = suggest_type("SELECT MAX( FROM tbl", "SELECT MAX(")
assert set(suggestion) == set(
[
Column(table_refs=((None, "tbl", None, False),), qualifiable=True),
Function(schema=None),
Keyword("("),
]
)
assert set(suggestion) == {
Column(table_refs=((None, "tbl", None, False),), qualifiable=True),
Function(schema=None),
Keyword("("),
}
def test_select_suggests_cols_and_funcs():
suggestions = suggest_type("SELECT ", "SELECT ")
assert set(suggestions) == set(
[
Column(table_refs=(), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
]
)
assert set(suggestions) == {
Column(table_refs=(), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
}
@pytest.mark.parametrize(
@ -128,13 +122,13 @@ def test_select_suggests_cols_and_funcs():
)
def test_suggests_tables_views_and_schemas(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set([Table(schema=None), View(schema=None), Schema()])
assert set(suggestions) == {Table(schema=None), View(schema=None), Schema()}
@pytest.mark.parametrize("expression", ["SELECT * FROM "])
def test_suggest_tables_views_schemas_and_functions(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
@pytest.mark.parametrize(
@ -147,9 +141,11 @@ def test_suggest_tables_views_schemas_and_functions(expression):
def test_suggest_after_join_with_two_tables(expression):
suggestions = suggest_type(expression, expression)
tables = tuple([(None, "foo", None, False), (None, "bar", None, False)])
assert set(suggestions) == set(
[FromClauseItem(schema=None, table_refs=tables), Join(tables, None), Schema()]
)
assert set(suggestions) == {
FromClauseItem(schema=None, table_refs=tables),
Join(tables, None),
Schema(),
}
@pytest.mark.parametrize(
@ -158,13 +154,11 @@ def test_suggest_after_join_with_two_tables(expression):
def test_suggest_after_join_with_one_table(expression):
suggestions = suggest_type(expression, expression)
tables = ((None, "foo", None, False),)
assert set(suggestions) == set(
[
FromClauseItem(schema=None, table_refs=tables),
Join(((None, "foo", None, False),), None),
Schema(),
]
)
assert set(suggestions) == {
FromClauseItem(schema=None, table_refs=tables),
Join(((None, "foo", None, False),), None),
Schema(),
}
@pytest.mark.parametrize(
@ -172,13 +166,13 @@ def test_suggest_after_join_with_one_table(expression):
)
def test_suggest_qualified_tables_and_views(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set([Table(schema="sch"), View(schema="sch")])
assert set(suggestions) == {Table(schema="sch"), View(schema="sch")}
@pytest.mark.parametrize("expression", ["UPDATE sch."])
def test_suggest_qualified_aliasable_tables_and_views(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set([Table(schema="sch"), View(schema="sch")])
assert set(suggestions) == {Table(schema="sch"), View(schema="sch")}
@pytest.mark.parametrize(
@ -193,26 +187,27 @@ def test_suggest_qualified_aliasable_tables_and_views(expression):
)
def test_suggest_qualified_tables_views_and_functions(expression):
suggestions = suggest_type(expression, expression)
assert set(suggestions) == set([FromClauseItem(schema="sch")])
assert set(suggestions) == {FromClauseItem(schema="sch")}
@pytest.mark.parametrize("expression", ["SELECT * FROM foo JOIN sch."])
def test_suggest_qualified_tables_views_functions_and_joins(expression):
suggestions = suggest_type(expression, expression)
tbls = tuple([(None, "foo", None, False)])
assert set(suggestions) == set(
[FromClauseItem(schema="sch", table_refs=tbls), Join(tbls, "sch")]
)
assert set(suggestions) == {
FromClauseItem(schema="sch", table_refs=tbls),
Join(tbls, "sch"),
}
def test_truncate_suggests_tables_and_schemas():
suggestions = suggest_type("TRUNCATE ", "TRUNCATE ")
assert set(suggestions) == set([Table(schema=None), Schema()])
assert set(suggestions) == {Table(schema=None), Schema()}
def test_truncate_suggests_qualified_tables():
suggestions = suggest_type("TRUNCATE sch.", "TRUNCATE sch.")
assert set(suggestions) == set([Table(schema="sch")])
assert set(suggestions) == {Table(schema="sch")}
@pytest.mark.parametrize(
@ -220,13 +215,11 @@ def test_truncate_suggests_qualified_tables():
)
def test_distinct_suggests_cols(text):
suggestions = suggest_type(text, text)
assert set(suggestions) == set(
[
Column(table_refs=(), local_tables=(), qualifiable=True),
Function(schema=None),
Keyword("DISTINCT"),
]
)
assert set(suggestions) == {
Column(table_refs=(), local_tables=(), qualifiable=True),
Function(schema=None),
Keyword("DISTINCT"),
}
@pytest.mark.parametrize(
@ -244,20 +237,18 @@ def test_distinct_and_order_by_suggestions_with_aliases(
text, text_before, last_keyword
):
suggestions = suggest_type(text, text_before)
assert set(suggestions) == set(
[
Column(
table_refs=(
TableReference(None, "tbl", "x", False),
TableReference(None, "tbl1", "y", False),
),
local_tables=(),
qualifiable=True,
assert set(suggestions) == {
Column(
table_refs=(
TableReference(None, "tbl", "x", False),
TableReference(None, "tbl1", "y", False),
),
Function(schema=None),
Keyword(last_keyword),
]
)
local_tables=(),
qualifiable=True,
),
Function(schema=None),
Keyword(last_keyword),
}
@pytest.mark.parametrize(
@ -272,56 +263,50 @@ def test_distinct_and_order_by_suggestions_with_aliases(
)
def test_distinct_and_order_by_suggestions_with_alias_given(text, text_before):
suggestions = suggest_type(text, text_before)
assert set(suggestions) == set(
[
Column(
table_refs=(TableReference(None, "tbl", "x", False),),
local_tables=(),
qualifiable=False,
),
Table(schema="x"),
View(schema="x"),
Function(schema="x"),
]
)
assert set(suggestions) == {
Column(
table_refs=(TableReference(None, "tbl", "x", False),),
local_tables=(),
qualifiable=False,
),
Table(schema="x"),
View(schema="x"),
Function(schema="x"),
}
def test_function_arguments_with_alias_given():
suggestions = suggest_type("SELECT avg(x. FROM tbl x, tbl2 y", "SELECT avg(x.")
assert set(suggestions) == set(
[
Column(
table_refs=(TableReference(None, "tbl", "x", False),),
local_tables=(),
qualifiable=False,
),
Table(schema="x"),
View(schema="x"),
Function(schema="x"),
]
)
assert set(suggestions) == {
Column(
table_refs=(TableReference(None, "tbl", "x", False),),
local_tables=(),
qualifiable=False,
),
Table(schema="x"),
View(schema="x"),
Function(schema="x"),
}
def test_col_comma_suggests_cols():
suggestions = suggest_type("SELECT a, b, FROM tbl", "SELECT a, b,")
assert set(suggestions) == set(
[
Column(table_refs=((None, "tbl", None, False),), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
]
)
assert set(suggestions) == {
Column(table_refs=((None, "tbl", None, False),), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
}
def test_table_comma_suggests_tables_and_schemas():
suggestions = suggest_type("SELECT a, b FROM tbl1, ", "SELECT a, b FROM tbl1, ")
assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
def test_into_suggests_tables_and_schemas():
suggestion = suggest_type("INSERT INTO ", "INSERT INTO ")
assert set(suggestion) == set([Table(schema=None), View(schema=None), Schema()])
assert set(suggestion) == {Table(schema=None), View(schema=None), Schema()}
@pytest.mark.parametrize(
@ -357,14 +342,12 @@ def test_partially_typed_col_name_suggests_col_names():
def test_dot_suggests_cols_of_a_table_or_schema_qualified_table():
suggestions = suggest_type("SELECT tabl. FROM tabl", "SELECT tabl.")
assert set(suggestions) == set(
[
Column(table_refs=((None, "tabl", None, False),)),
Table(schema="tabl"),
View(schema="tabl"),
Function(schema="tabl"),
]
)
assert set(suggestions) == {
Column(table_refs=((None, "tabl", None, False),)),
Table(schema="tabl"),
View(schema="tabl"),
Function(schema="tabl"),
}
@pytest.mark.parametrize(
@ -378,14 +361,12 @@ def test_dot_suggests_cols_of_a_table_or_schema_qualified_table():
)
def test_dot_suggests_cols_of_an_alias(sql):
suggestions = suggest_type(sql, "SELECT t1.")
assert set(suggestions) == set(
[
Table(schema="t1"),
View(schema="t1"),
Column(table_refs=((None, "tabl1", "t1", False),)),
Function(schema="t1"),
]
)
assert set(suggestions) == {
Table(schema="t1"),
View(schema="t1"),
Column(table_refs=((None, "tabl1", "t1", False),)),
Function(schema="t1"),
}
@pytest.mark.parametrize(
@ -399,28 +380,24 @@ def test_dot_suggests_cols_of_an_alias(sql):
)
def test_dot_suggests_cols_of_an_alias_where(sql):
suggestions = suggest_type(sql, sql)
assert set(suggestions) == set(
[
Table(schema="t1"),
View(schema="t1"),
Column(table_refs=((None, "tabl1", "t1", False),)),
Function(schema="t1"),
]
)
assert set(suggestions) == {
Table(schema="t1"),
View(schema="t1"),
Column(table_refs=((None, "tabl1", "t1", False),)),
Function(schema="t1"),
}
def test_dot_col_comma_suggests_cols_or_schema_qualified_table():
suggestions = suggest_type(
"SELECT t1.a, t2. FROM tabl1 t1, tabl2 t2", "SELECT t1.a, t2."
)
assert set(suggestions) == set(
[
Column(table_refs=((None, "tabl2", "t2", False),)),
Table(schema="t2"),
View(schema="t2"),
Function(schema="t2"),
]
)
assert set(suggestions) == {
Column(table_refs=((None, "tabl2", "t2", False),)),
Table(schema="t2"),
View(schema="t2"),
Function(schema="t2"),
}
@pytest.mark.parametrize(
@ -452,20 +429,18 @@ def test_sub_select_partial_text_suggests_keyword(expression):
def test_outer_table_reference_in_exists_subquery_suggests_columns():
q = "SELECT * FROM foo f WHERE EXISTS (SELECT 1 FROM bar WHERE f."
suggestions = suggest_type(q, q)
assert set(suggestions) == set(
[
Column(table_refs=((None, "foo", "f", False),)),
Table(schema="f"),
View(schema="f"),
Function(schema="f"),
]
)
assert set(suggestions) == {
Column(table_refs=((None, "foo", "f", False),)),
Table(schema="f"),
View(schema="f"),
Function(schema="f"),
}
@pytest.mark.parametrize("expression", ["SELECT * FROM (SELECT * FROM "])
def test_sub_select_table_name_completion(expression):
suggestion = suggest_type(expression, expression)
assert set(suggestion) == set([FromClauseItem(schema=None), Schema()])
assert set(suggestion) == {FromClauseItem(schema=None), Schema()}
@pytest.mark.parametrize(
@ -478,22 +453,18 @@ def test_sub_select_table_name_completion(expression):
def test_sub_select_table_name_completion_with_outer_table(expression):
suggestion = suggest_type(expression, expression)
tbls = tuple([(None, "foo", None, False)])
assert set(suggestion) == set(
[FromClauseItem(schema=None, table_refs=tbls), Schema()]
)
assert set(suggestion) == {FromClauseItem(schema=None, table_refs=tbls), Schema()}
def test_sub_select_col_name_completion():
suggestions = suggest_type(
"SELECT * FROM (SELECT FROM abc", "SELECT * FROM (SELECT "
)
assert set(suggestions) == set(
[
Column(table_refs=((None, "abc", None, False),), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
]
)
assert set(suggestions) == {
Column(table_refs=((None, "abc", None, False),), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
}
@pytest.mark.xfail
@ -508,25 +479,25 @@ def test_sub_select_dot_col_name_completion():
suggestions = suggest_type(
"SELECT * FROM (SELECT t. FROM tabl t", "SELECT * FROM (SELECT t."
)
assert set(suggestions) == set(
[
Column(table_refs=((None, "tabl", "t", False),)),
Table(schema="t"),
View(schema="t"),
Function(schema="t"),
]
)
assert set(suggestions) == {
Column(table_refs=((None, "tabl", "t", False),)),
Table(schema="t"),
View(schema="t"),
Function(schema="t"),
}
@pytest.mark.parametrize("join_type", ("", "INNER", "LEFT", "RIGHT OUTER"))
@pytest.mark.parametrize("tbl_alias", ("", "foo"))
def test_join_suggests_tables_and_schemas(tbl_alias, join_type):
text = "SELECT * FROM abc {0} {1} JOIN ".format(tbl_alias, join_type)
text = f"SELECT * FROM abc {tbl_alias} {join_type} JOIN "
suggestion = suggest_type(text, text)
tbls = tuple([(None, "abc", tbl_alias or None, False)])
assert set(suggestion) == set(
[FromClauseItem(schema=None, table_refs=tbls), Schema(), Join(tbls, None)]
)
assert set(suggestion) == {
FromClauseItem(schema=None, table_refs=tbls),
Schema(),
Join(tbls, None),
}
def test_left_join_with_comma():
@ -535,9 +506,7 @@ def test_left_join_with_comma():
# tbls should also include (None, 'bar', 'b', False)
# but there's a bug with commas
tbls = tuple([(None, "foo", "f", False)])
assert set(suggestions) == set(
[FromClauseItem(schema=None, table_refs=tbls), Schema()]
)
assert set(suggestions) == {FromClauseItem(schema=None, table_refs=tbls), Schema()}
@pytest.mark.parametrize(
@ -550,15 +519,13 @@ def test_left_join_with_comma():
def test_join_alias_dot_suggests_cols1(sql):
suggestions = suggest_type(sql, sql)
tables = ((None, "abc", "a", False), (None, "def", "d", False))
assert set(suggestions) == set(
[
Column(table_refs=((None, "abc", "a", False),)),
Table(schema="a"),
View(schema="a"),
Function(schema="a"),
JoinCondition(table_refs=tables, parent=(None, "abc", "a", False)),
]
)
assert set(suggestions) == {
Column(table_refs=((None, "abc", "a", False),)),
Table(schema="a"),
View(schema="a"),
Function(schema="a"),
JoinCondition(table_refs=tables, parent=(None, "abc", "a", False)),
}
@pytest.mark.parametrize(
@ -570,14 +537,12 @@ def test_join_alias_dot_suggests_cols1(sql):
)
def test_join_alias_dot_suggests_cols2(sql):
suggestion = suggest_type(sql, sql)
assert set(suggestion) == set(
[
Column(table_refs=((None, "def", "d", False),)),
Table(schema="d"),
View(schema="d"),
Function(schema="d"),
]
)
assert set(suggestion) == {
Column(table_refs=((None, "def", "d", False),)),
Table(schema="d"),
View(schema="d"),
Function(schema="d"),
}
@pytest.mark.parametrize(
@ -598,9 +563,10 @@ on """,
def test_on_suggests_aliases_and_join_conditions(sql):
suggestions = suggest_type(sql, sql)
tables = ((None, "abc", "a", False), (None, "bcd", "b", False))
assert set(suggestions) == set(
(JoinCondition(table_refs=tables, parent=None), Alias(aliases=("a", "b")))
)
assert set(suggestions) == {
JoinCondition(table_refs=tables, parent=None),
Alias(aliases=("a", "b")),
}
@pytest.mark.parametrize(
@ -613,9 +579,10 @@ def test_on_suggests_aliases_and_join_conditions(sql):
def test_on_suggests_tables_and_join_conditions(sql):
suggestions = suggest_type(sql, sql)
tables = ((None, "abc", None, False), (None, "bcd", None, False))
assert set(suggestions) == set(
(JoinCondition(table_refs=tables, parent=None), Alias(aliases=("abc", "bcd")))
)
assert set(suggestions) == {
JoinCondition(table_refs=tables, parent=None),
Alias(aliases=("abc", "bcd")),
}
@pytest.mark.parametrize(
@ -640,9 +607,10 @@ def test_on_suggests_aliases_right_side(sql):
def test_on_suggests_tables_and_join_conditions_right_side(sql):
suggestions = suggest_type(sql, sql)
tables = ((None, "abc", None, False), (None, "bcd", None, False))
assert set(suggestions) == set(
(JoinCondition(table_refs=tables, parent=None), Alias(aliases=("abc", "bcd")))
)
assert set(suggestions) == {
JoinCondition(table_refs=tables, parent=None),
Alias(aliases=("abc", "bcd")),
}
@pytest.mark.parametrize(
@ -659,9 +627,9 @@ def test_on_suggests_tables_and_join_conditions_right_side(sql):
)
def test_join_using_suggests_common_columns(text):
tables = ((None, "abc", None, False), (None, "def", None, False))
assert set(suggest_type(text, text)) == set(
[Column(table_refs=tables, require_last_table=True)]
)
assert set(suggest_type(text, text)) == {
Column(table_refs=tables, require_last_table=True)
}
def test_suggest_columns_after_multiple_joins():
@ -678,29 +646,27 @@ def test_2_statements_2nd_current():
suggestions = suggest_type(
"select * from a; select * from ", "select * from a; select * from "
)
assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
suggestions = suggest_type(
"select * from a; select from b", "select * from a; select "
)
assert set(suggestions) == set(
[
Column(table_refs=((None, "b", None, False),), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
]
)
assert set(suggestions) == {
Column(table_refs=((None, "b", None, False),), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
}
# Should work even if first statement is invalid
suggestions = suggest_type(
"select * from; select * from ", "select * from; select * from "
)
assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
def test_2_statements_1st_current():
suggestions = suggest_type("select * from ; select * from b", "select * from ")
assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
suggestions = suggest_type("select from a; select * from b", "select ")
assert set(suggestions) == cols_etc("a", last_keyword="SELECT")
@ -711,7 +677,7 @@ def test_3_statements_2nd_current():
"select * from a; select * from ; select * from c",
"select * from a; select * from ",
)
assert set(suggestions) == set([FromClauseItem(schema=None), Schema()])
assert set(suggestions) == {FromClauseItem(schema=None), Schema()}
suggestions = suggest_type(
"select * from a; select from b; select * from c", "select * from a; select "
@ -768,13 +734,11 @@ SELECT * FROM qux;
)
def test_statements_in_function_body(text):
suggestions = suggest_type(text, text[: text.find(" ") + 1])
assert set(suggestions) == set(
[
Column(table_refs=((None, "foo", None, False),), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
]
)
assert set(suggestions) == {
Column(table_refs=((None, "foo", None, False),), qualifiable=True),
Function(schema=None),
Keyword("SELECT"),
}
functions = [
@ -799,13 +763,13 @@ SELECT 1 FROM foo;
@pytest.mark.parametrize("text", functions)
def test_statements_with_cursor_after_function_body(text):
suggestions = suggest_type(text, text[: text.find("; ") + 1])
assert set(suggestions) == set([Keyword(), Special()])
assert set(suggestions) == {Keyword(), Special()}
@pytest.mark.parametrize("text", functions)
def test_statements_with_cursor_before_function_body(text):
suggestions = suggest_type(text, "")
assert set(suggestions) == set([Keyword(), Special()])
assert set(suggestions) == {Keyword(), Special()}
def test_create_db_with_template():
@ -813,14 +777,14 @@ def test_create_db_with_template():
"create database foo with template ", "create database foo with template "
)
assert set(suggestions) == set((Database(),))
assert set(suggestions) == {Database()}
@pytest.mark.parametrize("initial_text", ("", " ", "\t \t", "\n"))
def test_specials_included_for_initial_completion(initial_text):
suggestions = suggest_type(initial_text, initial_text)
assert set(suggestions) == set([Keyword(), Special()])
assert set(suggestions) == {Keyword(), Special()}
def test_drop_schema_qualified_table_suggests_only_tables():
@ -843,25 +807,30 @@ def test_drop_schema_suggests_schemas():
@pytest.mark.parametrize("text", ["SELECT x::", "SELECT x::y", "SELECT (x + y)::"])
def test_cast_operator_suggests_types(text):
assert set(suggest_type(text, text)) == set(
[Datatype(schema=None), Table(schema=None), Schema()]
)
assert set(suggest_type(text, text)) == {
Datatype(schema=None),
Table(schema=None),
Schema(),
}
@pytest.mark.parametrize(
"text", ["SELECT foo::bar.", "SELECT foo::bar.baz", "SELECT (x + y)::bar."]
)
def test_cast_operator_suggests_schema_qualified_types(text):
assert set(suggest_type(text, text)) == set(
[Datatype(schema="bar"), Table(schema="bar")]
)
assert set(suggest_type(text, text)) == {
Datatype(schema="bar"),
Table(schema="bar"),
}
def test_alter_column_type_suggests_types():
q = "ALTER TABLE foo ALTER COLUMN bar TYPE "
assert set(suggest_type(q, q)) == set(
[Datatype(schema=None), Table(schema=None), Schema()]
)
assert set(suggest_type(q, q)) == {
Datatype(schema=None),
Table(schema=None),
Schema(),
}
@pytest.mark.parametrize(
@ -880,9 +849,11 @@ def test_alter_column_type_suggests_types():
],
)
def test_identifier_suggests_types_in_parentheses(text):
assert set(suggest_type(text, text)) == set(
[Datatype(schema=None), Table(schema=None), Schema()]
)
assert set(suggest_type(text, text)) == {
Datatype(schema=None),
Table(schema=None),
Schema(),
}
@pytest.mark.parametrize(
@ -977,7 +948,7 @@ def test_ignore_leading_double_quotes(sql):
)
def test_column_keyword_suggests_columns(sql):
suggestions = suggest_type(sql, sql)
assert set(suggestions) == set([Column(table_refs=((None, "foo", None, False),))])
assert set(suggestions) == {Column(table_refs=((None, "foo", None, False),))}
def test_handle_unrecognized_kw_generously():

View File

@ -89,7 +89,7 @@ def run(
def completions_to_set(completions):
return set(
return {
(completion.display_text, completion.display_meta_text)
for completion in completions
)
}