# Mirror of https://github.com/dbcli/pgcli (synced 2024-06-16 01:42:23 +00:00)
# File: pgcli/pgcli/pgcompleter.py
# Last modified: 2017-10-11 12:58:32 -07:00 (974 lines, 40 KiB, Python)

from __future__ import print_function, unicode_literals
import logging
import re
from itertools import count, repeat, chain
import operator
from collections import namedtuple, defaultdict, OrderedDict
from pgspecial.namedqueries import NamedQueries
from prompt_toolkit.completion import Completer, Completion
from prompt_toolkit.contrib.completers import PathCompleter
from prompt_toolkit.document import Document
from .packages.sqlcompletion import (FromClauseItem,
suggest_type, Special, Database, Schema, Table, Function, Column, View,
Keyword, NamedQuery, Datatype, Alias, Path, JoinCondition, Join)
from .packages.parseutils.meta import ColumnMetadata, ForeignKey
from .packages.parseutils.utils import last_word
from .packages.parseutils.tables import TableReference
from .packages.pgliterals.main import get_literals
from .packages.prioritization import PrevalenceCounter
from .config import load_config, config_location
_logger = logging.getLogger(__name__)

# Load the user's saved named queries from the pgcli config file once, at
# import time, and install them as the shared NamedQueries instance.
NamedQueries.instance = NamedQueries.from_config(
    load_config(config_location() + 'config')
)

# A prompt_toolkit Completion paired with the tuple used to rank it;
# get_completions sorts by priority in descending order.
Match = namedtuple('Match', 'completion priority')
_SchemaObject = namedtuple('SchemaObject', ['name', 'schema', 'meta'])


def SchemaObject(name, schema=None, meta=None):
    """Build a SchemaObject record.

    :param name: object name (table, view, function or type name)
    :param schema: schema qualifier to show in the completion, or None
    :param meta: extra metadata (e.g. a function's metadata object), or None
    """
    return _SchemaObject(name=name, schema=schema, meta=meta)
_Candidate = namedtuple(
    'Candidate', ['completion', 'prio', 'meta', 'synonyms', 'prio2', 'display']
)


def Candidate(completion, prio=None, meta=None, synonyms=None, prio2=None,
              display=None):
    """Build a Candidate record for find_matches.

    :param completion: text inserted when the candidate is accepted
    :param prio: primary priority (higher ranks first), or None
    :param meta: display-meta text; find_matches substitutes its own `meta`
        when this is None
    :param synonyms: strings matched against the user's input; any falsy
        value falls back to ``[completion]``
    :param prio2: secondary priority, or None
    :param display: text shown in the menu; any falsy value falls back to
        ``completion``
    """
    if not synonyms:
        synonyms = [completion]
    if not display:
        display = completion
    return _Candidate(completion, prio, meta, synonyms, prio2, display)
# Used to strip trailing '::some_type' from default-value expressions
arg_default_type_strip_regex = re.compile(r'::[\w\.]+(\[\])?$')


def normalize_ref(ref):
    """Return a canonical, double-quoted form of a table reference.

    A ref that already starts with a double quote is returned unchanged
    (quoted identifiers keep their exact case); anything else is lower-cased
    and wrapped in double quotes, so that e.g. ``Users``, ``users`` and
    ``"users"`` all compare equal.
    """
    # startswith() instead of ref[0] so an empty ref doesn't raise IndexError
    if ref.startswith('"'):
        return ref
    return '"' + ref.lower() + '"'
def generate_alias(tbl):
    """Generate a table alias from an unescaped table name.

    If the name contains upper-case letters, the alias is all of them
    (``EntryTags`` -> ``ET``). Otherwise it is the first character plus every
    character that directly follows an underscore, with underscores
    themselves never included (``entry_tags`` -> ``et``).

    :param tbl: unescaped name of the table to alias
    """
    capitals = [ch for ch in tbl if ch.isupper()]
    if capitals:
        return ''.join(capitals)

    initials = []
    # Pretend the name is preceded by '_' so its first character qualifies.
    previous = '_'
    for ch in tbl:
        if previous == '_' and ch != '_':
            initials.append(ch)
        previous = ch
    return ''.join(initials)
class PGCompleter(Completer):
    """prompt_toolkit completer for SQL typed against a PostgreSQL database.

    Database metadata (schemas, tables, views, columns, functions, datatypes
    and foreign keys) is pushed in through the ``extend_*`` methods and kept
    in ``self.dbmetadata`` as ``{kind: {schema: {name: ...}}}``.
    ``get_completions`` asks ``suggest_type`` which kinds of objects fit the
    cursor position and dispatches each suggestion to the matching
    ``get_*_matches`` method via ``suggestion_matchers``.
    """

    # keywords_tree: A dict mapping keywords to well known following keywords.
    # e.g. 'CREATE': ['TABLE', 'USER', ...],
    keywords_tree = get_literals('keywords', type_=dict)
    keywords = tuple(set(chain(keywords_tree.keys(), *keywords_tree.values())))
    functions = get_literals('functions')
    datatypes = get_literals('datatypes')
    reserved_words = set(get_literals('reserved'))

    def __init__(self, smart_completion=True, pgspecial=None, settings=None):
        """
        :param smart_completion: when False, get_completions only does prefix
            matching against every known name (see get_completions)
        :param pgspecial: object whose ``commands`` dict (values carrying a
            ``description``) feeds special-command completion, or None
        :param settings: optional dict of completion options; recognised keys
            and their defaults are the ``settings.get`` calls below
        """
        super(PGCompleter, self).__init__()
        self.smart_completion = smart_completion
        self.pgspecial = pgspecial
        self.prioritizer = PrevalenceCounter()
        settings = settings or {}
        self.signature_arg_style = settings.get(
            'signature_arg_style', '{arg_name} {arg_type}'
        )
        self.call_arg_style = settings.get(
            'call_arg_style', '{arg_name: <{max_arg_len}} := {arg_default}'
        )
        self.call_arg_display_style = settings.get(
            'call_arg_display_style', '{arg_name}'
        )
        self.call_arg_oneliner_max = settings.get('call_arg_oneliner_max', 2)
        self.search_path_filter = settings.get('search_path_filter')
        self.generate_aliases = settings.get('generate_aliases')
        self.casing_file = settings.get('casing_file')
        self.insert_col_skip_patterns = [
            re.compile(pattern) for pattern in settings.get(
                'insert_col_skip_patterns',
                [r'^now\(\)$', r'^nextval\(']
            )
        ]
        self.generate_casing_file = settings.get('generate_casing_file')
        self.qualify_columns = settings.get(
            'qualify_columns', 'if_more_than_one_table')
        self.asterisk_column_order = settings.get(
            'asterisk_column_order', 'table_order')

        keyword_casing = settings.get('keyword_casing', 'upper').lower()
        if keyword_casing not in ('upper', 'lower', 'auto'):
            keyword_casing = 'upper'
        self.keyword_casing = keyword_casing

        # Names matching this pattern can be used without double quotes
        self.name_pattern = re.compile(r"^[_a-z][_a-z0-9\$]*$")

        self.databases = []
        self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {},
                           'datatypes': {}}
        self.search_path = []
        self.casing = {}

        self.all_completions = set(self.keywords + self.functions)

    def escape_name(self, name):
        """Double-quote ``name`` if it can't be used as a bare identifier:
        it doesn't match ``name_pattern``, or it is a reserved word or the
        name of a built-in function."""
        if name and ((not self.name_pattern.match(name))
                     or (name.upper() in self.reserved_words)
                     or (name.upper() in self.functions)):
            name = '"%s"' % name

        return name

    def escape_schema(self, name):
        """Return ``name`` (with any double quotes removed) as a
        single-quoted string literal."""
        return "'{}'".format(self.unescape_name(name))

    def unescape_name(self, name):
        """ Unquote a string."""
        if name and name[0] == '"' and name[-1] == '"':
            name = name[1:-1]

        return name

    def escaped_names(self, names):
        """Apply escape_name to every element of ``names``."""
        return [self.escape_name(name) for name in names]

    def extend_database_names(self, databases):
        self.databases.extend(databases)

    def extend_keywords(self, additional_keywords):
        """Add extra keywords to the completion pool.

        ``keywords`` is a class-level tuple, which cannot be extended in place
        (tuples have no ``extend``); an extended copy is bound on this
        instance instead, so other completer instances are unaffected.
        """
        # Materialise once: the input may be a one-shot iterator and is used
        # twice below.
        additional_keywords = tuple(additional_keywords)
        self.keywords = tuple(self.keywords) + additional_keywords
        self.all_completions.update(additional_keywords)

    def extend_schemata(self, schemata):
        """Register schemas (a list of schema names) in every metadata kind.

        Each schema gets an empty dict under 'tables', 'views', 'functions'
        and 'datatypes', so later extend_* calls can index into it.
        """
        schemata = self.escaped_names(schemata)
        for metadata in self.dbmetadata.values():
            for schema in schemata:
                metadata[schema] = {}

        self.all_completions.update(schemata)

    def extend_casing(self, words):
        """Set the preferred casing data from ``words``.

        Builds ``{lowercasename: PreferredCasingName}``. Note that this
        replaces, rather than merges with, any previous casing data.
        :return:
        """
        # casing should be a dict {lowercasename:PreferredCasingName}
        self.casing = dict((word.lower(), word) for word in words)

    def extend_relations(self, data, kind):
        """extend metadata for tables or views.

        :param data: list of (schema_name, rel_name) tuples
        :param kind: either 'tables' or 'views'
        :return:
        """
        data = [self.escaped_names(d) for d in data]

        # dbmetadata['tables']['schema_name']['table_name'] should be an
        # OrderedDict {column_name:ColumnMetaData}.
        metadata = self.dbmetadata[kind]
        for schema, relname in data:
            try:
                metadata[schema][relname] = OrderedDict()
            except KeyError:
                _logger.error('%r %r listed in unrecognized schema %r',
                              kind, relname, schema)
            self.all_completions.add(relname)

    def extend_columns(self, column_data, kind):
        """extend column metadata.

        :param column_data: list of (schema_name, rel_name, column_name,
            column_type, has_default, default) tuples
        :param kind: either 'tables' or 'views'
        :return:
        """
        metadata = self.dbmetadata[kind]
        for schema, relname, colname, datatype, has_default, default in column_data:
            (schema, relname, colname) = self.escaped_names(
                [schema, relname, colname])
            column = ColumnMetadata(
                name=colname,
                datatype=datatype,
                has_default=has_default,
                default=default
            )
            metadata[schema][relname][colname] = column
            self.all_completions.add(colname)

    def extend_functions(self, func_data):
        """Add function metadata.

        ``func_data`` is a list of function metadata namedtuples; they are
        stored as ``dbmetadata['functions'][schema_name][function_name]``,
        a list because overloads share a name. Rebuilds the arg-list cache.
        """
        metadata = self.dbmetadata['functions']

        for f in func_data:
            schema, func = self.escaped_names([f.schema_name, f.func_name])

            if func in metadata[schema]:
                metadata[schema][func].append(f)
            else:
                metadata[schema][func] = [f]

            self.all_completions.add(func)

        self._refresh_arg_list_cache()

    def _refresh_arg_list_cache(self):
        # We keep a cache of {function_usage:{function_metadata: function_arg_list_string}}
        # This is used when suggesting functions, to avoid the latency that would result
        # if we'd recalculate the arg lists each time we suggest functions (in large DBs)
        self._arg_list_cache = {
            usage: {
                meta: self._arg_list(meta, usage)
                for sch, funcs in self.dbmetadata['functions'].items()
                for func, metas in funcs.items()
                for meta in metas
            }
            for usage in ('call', 'call_display', 'signature')
        }

    def extend_foreignkeys(self, fk_data):
        """Attach foreign keys to the column metadata of both sides.

        ``fk_data`` is a list of ForeignKey namedtuples (parentschema,
        childschema, parenttable, childtable, parentcolumn, childcolumn).
        An escaped copy of each is appended to the ``foreignkeys`` list of
        both the child and the parent ColumnMetadata.
        """
        meta = self.dbmetadata['tables']
        e = self.escaped_names

        for fk in fk_data:
            parentschema, childschema = e([fk.parentschema, fk.childschema])
            parenttable, childtable = e([fk.parenttable, fk.childtable])
            childcol, parcol = e([fk.childcolumn, fk.parentcolumn])
            childcolmeta = meta[childschema][childtable][childcol]
            parcolmeta = meta[parentschema][parenttable][parcol]
            fk = ForeignKey(parentschema, parenttable, parcol,
                            childschema, childtable, childcol)
            childcolmeta.foreignkeys.append(fk)
            parcolmeta.foreignkeys.append(fk)

    def extend_datatypes(self, type_data):
        """Add custom datatypes from (schema_name, type_name) pairs.

        Only the type name is recorded; the stored value is always None.
        """
        meta = self.dbmetadata['datatypes']

        for t in type_data:
            schema, type_name = self.escaped_names(t)
            meta[schema][type_name] = None
            self.all_completions.add(type_name)

    def extend_query_history(self, text, is_init=False):
        """Feed executed query text to the prevalence prioritizer."""
        if is_init:
            # During completer initialization, only load keyword preferences,
            # not names
            self.prioritizer.update_keywords(text)
        else:
            self.prioritizer.update(text)

    def set_search_path(self, search_path):
        self.search_path = self.escaped_names(search_path)

    def reset_completions(self):
        """Forget all database-derived completion data."""
        self.databases = []
        self.special_commands = []
        self.search_path = []
        self.dbmetadata = {'tables': {}, 'views': {}, 'functions': {},
                           'datatypes': {}}
        self.all_completions = set(self.keywords + self.functions)

    def find_matches(self, text, collection, mode='fuzzy', meta=None):
        """Find completion matches for the given text.

        Given the user's input text and a collection of available
        completions, find completions matching the last word of the
        text.

        `collection` can be either a list of strings or a list of Candidate
        namedtuples.
        `mode` can be either 'fuzzy', or 'strict'
            'fuzzy': fuzzy matching, ties broken by name prevalence
            'strict': start-only matching, ties broken by keyword prevalence
            (any value other than 'fuzzy' behaves like 'strict')

        Returns a list of Match namedtuples (a prompt_toolkit Completion plus
        its priority tuple) for any matches found in the collection.
        """
        if not collection:
            return []
        prio_order = [
            'keyword', 'function', 'view', 'table', 'datatype', 'database',
            'schema', 'column', 'table alias', 'join', 'name join', 'fk join'
        ]
        type_priority = prio_order.index(meta) if meta in prio_order else -1
        text = last_word(text, include='most_punctuations').lower()
        text_len = len(text)

        if text and text[0] == '"':
            # text starts with double quote; user is manually escaping a name
            # Match on everything that follows the double-quote. Note that
            # text_len is calculated before removing the quote, so the
            # Completion.position value is correct
            text = text[1:]

        if mode == 'fuzzy':
            fuzzy = True
            priority_func = self.prioritizer.name_count
        else:
            fuzzy = False
            priority_func = self.prioritizer.keyword_count

        # Construct a `_match` function for either fuzzy or non-fuzzy matching
        # The match function returns a 2-tuple used for sorting the matches,
        # or None if the item doesn't match
        # Note: higher priority values mean more important, so use negative
        # signs to flip the direction of the tuple
        if fuzzy:
            regex = '.*?'.join(map(re.escape, text))
            pat = re.compile('(%s)' % regex)

            def _match(item):
                if item.lower()[:len(text) + 1] in (text, text + ' '):
                    # Exact match of first word in suggestion
                    # This is to get exact alias matches to the top
                    # E.g. for input `e`, 'Entries E' should be on top
                    # (before e.g. `EndUsers EU`)
                    return float('Infinity'), -1
                r = pat.search(self.unescape_name(item.lower()))
                if r:
                    return -len(r.group()), -r.start()
        else:
            match_end_limit = len(text)

            def _match(item):
                # Searching only within the first len(text) characters makes
                # this a prefix test.
                match_point = item.lower().find(text, 0, match_end_limit)
                if match_point >= 0:
                    # Use negative infinity to force keywords to sort after all
                    # fuzzy matches
                    return -float('Infinity'), -match_point

        matches = []
        for cand in collection:
            if isinstance(cand, _Candidate):
                item, prio, display_meta, synonyms, prio2, display = cand
                if display_meta is None:
                    display_meta = meta
                syn_matches = (_match(x) for x in synonyms)
                # Nones need to be removed to avoid max() crashing in Python 3
                syn_matches = [m for m in syn_matches if m]
                sort_key = max(syn_matches) if syn_matches else None
            else:
                item, display_meta, prio, prio2, display = cand, meta, 0, 0, cand
                sort_key = _match(cand)

            if sort_key:
                if display_meta and len(display_meta) > 50:
                    # Truncate meta-text to 50 characters, if necessary
                    display_meta = display_meta[:47] + u'...'

                # Lexical order of items in the collection, used for
                # tiebreaking items with the same match group length and start
                # position. Since we use *higher* priority to mean "more
                # important," we use -ord(c) to prioritize "aa" > "ab" and end
                # with 1 to prioritize shorter strings (ie "user" > "users").
                # We first do a case-insensitive sort and then a
                # case-sensitive one as a tie breaker.
                # We also use the unescape_name to make sure quoted names have
                # the same priority as unquoted names.
                lexical_priority = (
                    tuple(0 if c in ' _' else -ord(c)
                          for c in self.unescape_name(item.lower()))
                    + (1,)
                    + tuple(c for c in item)
                )

                item = self.case(item)
                display = self.case(display)
                priority = (
                    sort_key, type_priority, prio, priority_func(item),
                    prio2, lexical_priority
                )

                matches.append(
                    Match(
                        completion=Completion(
                            text=item,
                            start_position=-text_len,
                            display_meta=display_meta,
                            display=display
                        ),
                        priority=priority
                    )
                )
        return matches

    def case(self, word):
        """Return the preferred casing for ``word`` (unchanged if unknown)."""
        return self.casing.get(word, word)

    def get_completions(self, document, complete_event, smart_completion=None):
        """prompt_toolkit entry point: return Completions for ``document``.

        With smart completion off, every known name is prefix-matched against
        the word before the cursor and the results are sorted by text.
        Otherwise the suggestions from suggest_type are each resolved by
        their matcher and all matches are sorted by descending priority.
        """
        word_before_cursor = document.get_word_before_cursor(WORD=True)

        if smart_completion is None:
            smart_completion = self.smart_completion

        # If smart_completion is off then match any word that starts with
        # 'word_before_cursor'.
        if not smart_completion:
            matches = self.find_matches(word_before_cursor, self.all_completions,
                                        mode='strict')
            completions = [m.completion for m in matches]
            return sorted(completions, key=operator.attrgetter('text'))

        matches = []
        suggestions = suggest_type(document.text, document.text_before_cursor)

        for suggestion in suggestions:
            suggestion_type = type(suggestion)
            _logger.debug('Suggestion type: %r', suggestion_type)

            # Map suggestion type to method
            # e.g. 'table' -> self.get_table_matches
            matcher = self.suggestion_matchers[suggestion_type]
            matches.extend(matcher(self, suggestion, word_before_cursor))

        # Sort matches so highest priorities are first
        matches = sorted(matches, key=operator.attrgetter('priority'),
                         reverse=True)

        return [m.completion for m in matches]

    def get_column_matches(self, suggestion, word_before_cursor):
        """Suggest columns of the tables in scope.

        When the last word is ``*``, a single Match is returned whose text is
        the expanded, comma-separated column list.
        """
        tables = suggestion.table_refs
        do_qualify = suggestion.qualifiable and {
            'always': True,
            'never': False,
            'if_more_than_one_table': len(tables) > 1,
        }[self.qualify_columns]

        def qualify(col, tbl):
            return (tbl + '.' + self.case(col)) if do_qualify else self.case(col)

        _logger.debug("Completion column scope: %r", tables)
        scoped_cols = self.populate_scoped_cols(tables, suggestion.local_tables)

        def make_cand(name, ref):
            synonyms = (name, generate_alias(self.case(name)))
            return Candidate(qualify(name, ref), 0, 'column', synonyms)

        def flat_cols():
            # Reads `scoped_cols` at call time, so it sees the filtered dicts
            # assigned below.
            return [make_cand(c.name, t.ref)
                    for t, cols in scoped_cols.items() for c in cols]

        if suggestion.require_last_table:
            # require_last_table is used for 'tb11 JOIN tbl2 USING (...' which should
            # suggest only columns that appear in the last table and one more
            ltbl = tables[-1].ref
            other_tbl_cols = set(
                c.name for t, cs in scoped_cols.items() if t.ref != ltbl for c in cs)
            scoped_cols = {
                t: [col for col in cols if col.name in other_tbl_cols]
                for t, cols in scoped_cols.items()
                if t.ref == ltbl
            }

        lastword = last_word(word_before_cursor, include='most_punctuations')
        if lastword == '*':
            if suggestion.context == 'insert':
                def keep_col(col):
                    # Skip columns whose default matches a skip pattern
                    # (e.g. now() or nextval(...)) when expanding for INSERT.
                    if not col.has_default:
                        return True
                    return not any(
                        p.match(col.default)
                        for p in self.insert_col_skip_patterns
                    )
                scoped_cols = {
                    t: [col for col in cols if keep_col(col)]
                    for t, cols in scoped_cols.items()
                }
            if self.asterisk_column_order == 'alphabetic':
                for cols in scoped_cols.values():
                    cols.sort(key=operator.attrgetter('name'))
            if (lastword != word_before_cursor and len(tables) == 1
                    and word_before_cursor[-len(lastword) - 1] == '.'):
                # User typed x.*; replicate "x." for all columns except the
                # first, which gets the original (as we only replace the "*"")
                sep = ', ' + word_before_cursor[:-1]
                collist = sep.join(self.case(c.completion)
                                   for c in flat_cols())
            else:
                collist = ', '.join(qualify(c.name, t.ref)
                                    for t, cs in scoped_cols.items() for c in cs)

            return [Match(
                completion=Completion(
                    collist,
                    -1,
                    display_meta='columns',
                    display='*'
                ),
                priority=(1, 1, 1)
            )]

        return self.find_matches(word_before_cursor, flat_cols(),
                                 meta='column')

    def alias(self, tbl, tbls):
        """ Generate a unique table alias
        tbl - name of the table to alias, quoted if it needs to be
        tbls - TableReference iterable of tables already in query
        """
        tbl = self.case(tbl)
        tbls = set(normalize_ref(t.ref) for t in tbls)
        if self.generate_aliases:
            tbl = generate_alias(self.unescape_name(tbl))
        if normalize_ref(tbl) not in tbls:
            return tbl
        elif tbl[0] == '"':
            aliases = ('"' + tbl[1:-1] + str(i) + '"' for i in count(2))
        else:
            aliases = (tbl + str(i) for i in count(2))
        return next(a for a in aliases if normalize_ref(a) not in tbls)

    def get_join_matches(self, suggestion, word_before_cursor):
        """Suggest ``<table> ON <condition>`` joins derived from the foreign
        keys of the tables already in the query."""
        tbls = suggestion.table_refs
        cols = self.populate_scoped_cols(tbls)
        # Set up some data structures for efficient access
        qualified = dict((normalize_ref(t.ref), t.schema) for t in tbls)
        ref_prio = dict((normalize_ref(t.ref), n) for n, t in enumerate(tbls))
        refs = set(normalize_ref(t.ref) for t in tbls)
        other_tbls = set((t.schema, t.name)
                         for t in list(cols)[:-1])
        joins = []
        # Iterate over FKs in existing tables to find potential joins
        fks = ((fk, rtbl, rcol) for rtbl, rcols in cols.items()
               for rcol in rcols for fk in rcol.foreignkeys)
        col = namedtuple('col', 'schema tbl col')
        for fk, rtbl, rcol in fks:
            right = col(rtbl.schema, rtbl.name, rcol.name)
            child = col(fk.childschema, fk.childtable, fk.childcolumn)
            parent = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
            left = child if parent == right else parent
            if suggestion.schema and left.schema != suggestion.schema:
                continue
            c = self.case
            if self.generate_aliases or normalize_ref(left.tbl) in refs:
                lref = self.alias(left.tbl, suggestion.table_refs)
                join = '{0} {4} ON {4}.{1} = {2}.{3}'.format(
                    c(left.tbl), c(left.col), rtbl.ref, c(right.col), lref)
            else:
                join = '{0} ON {0}.{1} = {2}.{3}'.format(
                    c(left.tbl), c(left.col), rtbl.ref, c(right.col))
            alias = generate_alias(self.case(left.tbl))
            synonyms = [join, '{0} ON {0}.{1} = {2}.{3}'.format(
                alias, c(left.col), rtbl.ref, c(right.col))]
            # Schema-qualify if (1) new table in same schema as old, and old
            # is schema-qualified, or (2) new in other schema, except public
            if not suggestion.schema and (
                    (qualified[normalize_ref(rtbl.ref)]
                     and left.schema == right.schema)
                    or left.schema not in (right.schema, 'public')):
                join = left.schema + '.' + join
            prio = ref_prio[normalize_ref(rtbl.ref)] * 2 + (
                0 if (left.schema, left.tbl) in other_tbls else 1)
            joins.append(Candidate(join, prio, 'join', synonyms=synonyms))

        return self.find_matches(word_before_cursor, joins, meta='join')

    def get_join_condition_matches(self, suggestion, word_before_cursor):
        """Suggest ``a.x = b.y`` join conditions for the left table, from its
        foreign keys ('fk join') and from same-name, same-type columns
        ('name join')."""
        col = namedtuple('col', 'schema tbl col')
        tbls = self.populate_scoped_cols(suggestion.table_refs).items
        cols = [(t, c) for t, cs in tbls() for c in cs]
        try:
            lref = (suggestion.parent or suggestion.table_refs[-1]).ref
            ltbl, lcols = [(t, cs) for (t, cs) in tbls() if t.ref == lref][-1]
        except IndexError:  # The user typed an incorrect table qualifier
            return []
        conds, found_conds = [], set()

        def add_cond(lcol, rcol, rref, prio, meta):
            prefix = '' if suggestion.parent else ltbl.ref + '.'
            case = self.case
            cond = prefix + case(lcol) + ' = ' + rref + '.' + case(rcol)
            if cond not in found_conds:
                found_conds.add(cond)
                conds.append(Candidate(cond, prio + ref_prio[rref], meta))

        def list_dict(pairs):  # Turns [(a, b), (a, c)] into {a: [b, c]}
            d = defaultdict(list)
            for pair in pairs:
                d[pair[0]].append(pair[1])
            return d

        # Tables that are closer to the cursor get higher prio
        ref_prio = dict((tbl.ref, num) for num, tbl
                        in enumerate(suggestion.table_refs))
        # Map (schema, table, col) to tables
        coldict = list_dict(((t.schema, t.name, c.name), t)
                            for t, c in cols if t.ref != lref)
        # For each fk from the left table, generate a join condition if
        # the other table is also in the scope
        fks = ((fk, lcol.name) for lcol in lcols for fk in lcol.foreignkeys)
        for fk, lcol in fks:
            left = col(ltbl.schema, ltbl.name, lcol)
            child = col(fk.childschema, fk.childtable, fk.childcolumn)
            par = col(fk.parentschema, fk.parenttable, fk.parentcolumn)
            left, right = (child, par) if left == child else (par, child)
            for rtbl in coldict[right]:
                add_cond(left.col, right.col, rtbl.ref, 2000, 'fk join')
        # For name matching, use a {(colname, coltype): TableReference} dict
        coltyp = namedtuple('coltyp', 'name datatype')
        col_table = list_dict((coltyp(c.name, c.datatype), t) for t, c in cols)
        # Find all name-match join conditions
        for c in (coltyp(c.name, c.datatype) for c in lcols):
            for rtbl in (t for t in col_table[c] if t.ref != ltbl.ref):
                prio = 1000 if c.datatype in (
                    'integer', 'bigint', 'smallint') else 0
                add_cond(c.name, c.name, rtbl.ref, prio, 'name join')

        return self.find_matches(word_before_cursor, conds, meta='join')

    def get_function_matches(self, suggestion, word_before_cursor, alias=False):
        """Suggest database functions, plus built-in function names when the
        suggestion has neither a schema nor a usage."""
        if suggestion.usage == 'from':
            # Only suggest functions allowed in FROM clause
            def filt(f):
                return not f.is_aggregate and not f.is_window
        else:
            alias = False

            def filt(f):
                return True

        arg_mode = {
            'signature': 'signature',
            'special': None,
        }.get(suggestion.usage, 'call')

        # Function overloading means we way have multiple functions of the same
        # name at this point, so keep unique names only
        funcs = set(
            self._make_cand(f, alias, suggestion, arg_mode)
            for f in self.populate_functions(suggestion.schema, filt)
        )

        matches = self.find_matches(word_before_cursor, funcs, meta='function')

        if not suggestion.schema and not suggestion.usage:
            # also suggest hardcoded functions using startswith matching
            predefined_funcs = self.find_matches(
                word_before_cursor, self.functions, mode='strict',
                meta='function')
            matches.extend(predefined_funcs)

        return matches

    def get_schema_matches(self, suggestion, word_before_cursor):
        """Suggest schema names (as string literals if ``suggestion.quoted``)."""
        schema_names = self.dbmetadata['tables'].keys()

        # Unless we're sure the user really wants them, hide schema names
        # starting with pg_, which are mostly temporary schemas
        if not word_before_cursor.startswith('pg_'):
            schema_names = [s
                            for s in schema_names
                            if not s.startswith('pg_')]

        if suggestion.quoted:
            schema_names = [self.escape_schema(s) for s in schema_names]

        return self.find_matches(word_before_cursor, schema_names, meta='schema')

    def get_from_clause_item_matches(self, suggestion, word_before_cursor):
        """Suggest tables, views and FROM-clause-capable functions."""
        alias = self.generate_aliases
        s = suggestion
        t_sug = Table(s.schema, s.table_refs, s.local_tables)
        v_sug = View(s.schema, s.table_refs)
        f_sug = Function(s.schema, s.table_refs, usage='from')
        return (
            self.get_table_matches(t_sug, word_before_cursor, alias)
            + self.get_view_matches(v_sug, word_before_cursor, alias)
            + self.get_function_matches(f_sug, word_before_cursor, alias)
        )

    def _arg_list(self, func, usage):
        """Returns a an arg list string, e.g. `(_foo:=23)` for a func.

        :param func is a FunctionMetadata object
        :param usage is 'call', 'call_display' or 'signature'
        """
        template = {
            'call': self.call_arg_style,
            'call_display': self.call_arg_display_style,
            'signature': self.signature_arg_style
        }[usage]
        args = func.args()
        if not template:
            return '()'
        elif usage == 'call' and len(args) < 2:
            return '()'
        elif usage == 'call' and func.has_variadic():
            return '()'
        multiline = usage == 'call' and len(args) > self.call_arg_oneliner_max
        max_arg_len = max(len(a.name) for a in args) if multiline else 0
        args = (
            self._format_arg(template, arg, arg_num + 1, max_arg_len)
            for arg_num, arg in enumerate(args)
        )
        if multiline:
            return '(' + ','.join('\n ' + a for a in args if a) + '\n)'
        else:
            return '(' + ', '.join(a for a in args if a) + ')'

    def _format_arg(self, template, arg, arg_num, max_arg_len):
        """Render one argument with ``template``; None if there is no template."""
        if not template:
            return None
        if arg.has_default:
            arg_default = 'NULL' if arg.default is None else arg.default
            # Remove trailing ::(schema.)type
            arg_default = arg_default_type_strip_regex.sub('', arg_default)
        else:
            arg_default = ''
        return template.format(
            max_arg_len=max_arg_len,
            arg_name=self.case(arg.name),
            arg_num=arg_num,
            arg_type=arg.datatype,
            arg_default=arg_default
        )

    def _make_cand(self, tbl, do_alias, suggestion, arg_mode=None):
        """Returns a Candidate namedtuple.

        :param tbl is a SchemaObject
        :param arg_mode determines what type of arg list to suffix for functions.
            Possible values: call, signature
        """
        cased_tbl = self.case(tbl.name)
        if do_alias:
            alias = self.alias(cased_tbl, suggestion.table_refs)
        synonyms = (cased_tbl, generate_alias(cased_tbl))
        maybe_alias = (' ' + alias) if do_alias else ''
        maybe_schema = (self.case(tbl.schema) + '.') if tbl.schema else ''
        suffix = self._arg_list_cache[arg_mode][tbl.meta] if arg_mode else ''
        if arg_mode == 'call':
            display_suffix = self._arg_list_cache['call_display'][tbl.meta]
        elif arg_mode == 'signature':
            display_suffix = self._arg_list_cache['signature'][tbl.meta]
        else:
            display_suffix = ''
        item = maybe_schema + cased_tbl + suffix + maybe_alias
        display = maybe_schema + cased_tbl + display_suffix + maybe_alias
        # Objects shown without a schema prefix rank above qualified ones
        prio2 = 0 if tbl.schema else 1
        return Candidate(item, synonyms=synonyms, prio2=prio2, display=display)

    def get_table_matches(self, suggestion, word_before_cursor, alias=False):
        """Suggest tables (including local tables such as CTEs)."""
        tables = self.populate_schema_objects(suggestion.schema, 'tables')
        tables.extend(SchemaObject(tbl.name) for tbl in suggestion.local_tables)

        # Unless we're sure the user really wants them, don't suggest the
        # pg_catalog tables that are implicitly on the search path
        if not suggestion.schema and (
                not word_before_cursor.startswith('pg_')):
            tables = [t for t in tables if not t.name.startswith('pg_')]
        tables = [self._make_cand(t, alias, suggestion) for t in tables]
        return self.find_matches(word_before_cursor, tables, meta='table')

    def get_view_matches(self, suggestion, word_before_cursor, alias=False):
        """Suggest views, hiding pg_* names unless typed explicitly."""
        views = self.populate_schema_objects(suggestion.schema, 'views')

        if not suggestion.schema and (
                not word_before_cursor.startswith('pg_')):
            views = [v for v in views if not v.name.startswith('pg_')]
        views = [self._make_cand(v, alias, suggestion) for v in views]
        return self.find_matches(word_before_cursor, views, meta='view')

    def get_alias_matches(self, suggestion, word_before_cursor):
        aliases = suggestion.aliases
        return self.find_matches(word_before_cursor, aliases,
                                 meta='table alias')

    def get_database_matches(self, _, word_before_cursor):
        return self.find_matches(word_before_cursor, self.databases,
                                 meta='database')

    def get_keyword_matches(self, suggestion, word_before_cursor):
        """Suggest keywords, narrowed to the known followers of the last
        token when there are any, in the configured keyword casing."""
        keywords = self.keywords_tree.keys()
        # Get well known following keywords for the last token. If any, narrow
        # candidates to this list.
        next_keywords = self.keywords_tree.get(suggestion.last_token, [])
        if next_keywords:
            keywords = next_keywords

        casing = self.keyword_casing
        if casing == 'auto':
            if word_before_cursor and word_before_cursor[-1].islower():
                casing = 'lower'
            else:
                casing = 'upper'

        if casing == 'upper':
            keywords = [k.upper() for k in keywords]
        else:
            keywords = [k.lower() for k in keywords]

        return self.find_matches(word_before_cursor, keywords,
                                 mode='strict', meta='keyword')

    def get_path_matches(self, _, word_before_cursor):
        """Yield filesystem path completions (a generator)."""
        completer = PathCompleter(expanduser=True)
        document = Document(text=word_before_cursor,
                            cursor_position=len(word_before_cursor))
        for c in completer.get_completions(document, None):
            yield Match(completion=c, priority=(0,))

    def get_special_matches(self, _, word_before_cursor):
        """Suggest special (backslash) commands from ``self.pgspecial``."""
        if not self.pgspecial:
            return []

        commands = self.pgspecial.commands
        cmds = commands.keys()
        cmds = [Candidate(cmd, 0, commands[cmd].description) for cmd in cmds]
        return self.find_matches(word_before_cursor, cmds, mode='strict')

    def get_datatype_matches(self, suggestion, word_before_cursor):
        """Suggest custom datatypes, plus built-in ones when unqualified."""
        # suggest custom datatypes
        types = self.populate_schema_objects(suggestion.schema, 'datatypes')
        types = [self._make_cand(t, False, suggestion) for t in types]
        matches = self.find_matches(word_before_cursor, types, meta='datatype')

        if not suggestion.schema:
            # Also suggest hardcoded types
            matches.extend(self.find_matches(word_before_cursor, self.datatypes,
                                             mode='strict', meta='datatype'))

        return matches

    def get_namedquery_matches(self, _, word_before_cursor):
        return self.find_matches(
            word_before_cursor, NamedQueries.instance.list(), meta='named query')

    # Maps each suggestion type returned by suggest_type to its matcher.
    suggestion_matchers = {
        FromClauseItem: get_from_clause_item_matches,
        JoinCondition: get_join_condition_matches,
        Join: get_join_matches,
        Column: get_column_matches,
        Function: get_function_matches,
        Schema: get_schema_matches,
        Table: get_table_matches,
        View: get_view_matches,
        Alias: get_alias_matches,
        Database: get_database_matches,
        Keyword: get_keyword_matches,
        Special: get_special_matches,
        Datatype: get_datatype_matches,
        NamedQuery: get_namedquery_matches,
        Path: get_path_matches,
    }

    def populate_scoped_cols(self, scoped_tbls, local_tbls=()):
        """Find all columns in a set of scoped_tables.

        :param scoped_tbls: list of TableReference namedtuples
        :param local_tbls: tuple(TableMetadata)
        :return: OrderedDict {TableReference: [ColumnMetadata, ...]}
        """
        ctes = dict((normalize_ref(t.name), t.columns) for t in local_tbls)
        columns = OrderedDict()
        meta = self.dbmetadata

        def addcols(schema, rel, alias, reltype, cols):
            tbl = TableReference(schema, rel, alias, reltype == 'functions')
            if tbl not in columns:
                columns[tbl] = []
            columns[tbl].extend(cols)

        for tbl in scoped_tbls:
            # Local tables should shadow database tables
            if tbl.schema is None and normalize_ref(tbl.name) in ctes:
                cols = ctes[normalize_ref(tbl.name)]
                # Pass the real alias; 'CTE' is the relation type. (These two
                # arguments used to be swapped, which gave every CTE the
                # alias 'CTE'.)
                addcols(None, tbl.name, tbl.alias, 'CTE', cols)
                continue
            schemas = [tbl.schema] if tbl.schema else self.search_path
            for schema in schemas:
                relname = self.escape_name(tbl.name)
                schema = self.escape_name(schema)
                if tbl.is_function:
                    # Return column names from a set-returning function
                    # Get an array of FunctionMetadata objects
                    functions = meta['functions'].get(schema, {}).get(relname)
                    for func in (functions or []):
                        # func is a FunctionMetadata object
                        cols = func.fields()
                        addcols(schema, relname, tbl.alias, 'functions', cols)
                else:
                    for reltype in ('tables', 'views'):
                        cols = meta[reltype].get(schema, {}).get(relname)
                        if cols:
                            cols = cols.values()
                            addcols(schema, relname, tbl.alias, reltype, cols)
                            break

        return columns

    def _get_schemas(self, obj_typ, schema):
        """Returns a list of schemas from which to suggest objects.

        :param schema is the schema qualification input by the user (if any)
        """
        metadata = self.dbmetadata[obj_typ]
        if schema:
            schema = self.escape_name(schema)
            return [schema] if schema in metadata else []
        return self.search_path if self.search_path_filter else metadata.keys()

    def _maybe_schema(self, schema, parent):
        """Schema to show in a completion: None when the user already typed
        a qualifier (``parent``) or ``schema`` is on the search path."""
        return None if parent or schema in self.search_path else schema

    def populate_schema_objects(self, schema, obj_type):
        """Returns a list of SchemaObjects representing tables or views.

        :param schema is the schema qualification input by the user (if any)
        """
        return [
            SchemaObject(
                name=obj,
                schema=(self._maybe_schema(schema=sch, parent=schema))
            )
            for sch in self._get_schemas(obj_type, schema)
            for obj in self.dbmetadata[obj_type][sch].keys()
        ]

    def populate_functions(self, schema, filter_func):
        """Returns a list of function SchemaObjects.

        :param filter_func is a function that accepts a FunctionMetadata
            namedtuple and returns a boolean indicating whether that
            function should be kept or discarded
        """
        # Because of multiple dispatch, we can have multiple functions
        # with the same name, which is why `for meta in metas` is necessary
        # in the comprehensions below
        return [
            SchemaObject(
                name=func,
                schema=(self._maybe_schema(schema=sch, parent=schema)),
                meta=meta
            )
            for sch in self._get_schemas('functions', schema)
            for (func, metas) in self.dbmetadata['functions'][sch].items()
            for meta in metas
            if filter_func(meta)
        ]