mirror of https://github.com/dbcli/pgcli
171 lines
5.3 KiB
Python
171 lines
5.3 KiB
Python
from collections import namedtuple
|
|
|
|
# Raw record type; build instances through the ColumnMetadata() factory below
# so that the optional fields get their defaults.
_ColumnMetadata = namedtuple(
    "ColumnMetadata", "name datatype foreignkeys default has_default"
)


def ColumnMetadata(name, datatype, foreignkeys=None, default=None, has_default=False):
    """Build a ColumnMetadata namedtuple.

    A falsy ``foreignkeys`` value (None, empty sequence, ...) is replaced by a
    fresh empty list, so every record gets its own list object.
    """
    if not foreignkeys:
        foreignkeys = []
    return _ColumnMetadata(
        name=name,
        datatype=datatype,
        foreignkeys=foreignkeys,
        default=default,
        has_default=has_default,
    )
|
|
|
|
|
|
# One foreign-key relationship: the referenced (parent) column followed by the
# referencing (child) column, each identified by schema, table and column.
ForeignKey = namedtuple(
    "ForeignKey",
    "parentschema parenttable parentcolumn childschema childtable childcolumn",
)

# A table name together with its columns.
TableMetadata = namedtuple("TableMetadata", ["name", "columns"])
|
|
|
|
|
|
def parse_defaults(defaults_string):
    """Yields default values for a function, given the string provided by
    pg_get_expr(pg_catalog.pg_proc.proargdefaults, 0)

    The string is a comma-separated list of expressions. A comma only ends an
    expression when it is outside of quotes (single or double) and outside of
    any parentheses or brackets, so defaults such as ``'a,b'``,
    ``make_date(2000, 1, 1)`` or ``ARRAY[1, 2]`` are each yielded whole.

    Spaces at the start of an expression (i.e. after a separating comma) are
    dropped; all other characters are kept verbatim. Nothing is yielded for an
    empty or None input.
    """
    if not defaults_string:
        return

    current = ""
    in_quote = None  # the quote character we are inside of, if any
    depth = 0  # nesting level of ( ) and [ ] outside of quotes

    for char in defaults_string:
        if current == "" and char == " ":
            # Skip space after comma separating default expressions
            continue

        if in_quote:
            # Inside a quoted section only the matching quote is special.
            if char == in_quote:
                in_quote = None
        elif char == '"' or char == "'":
            # Begin quote
            in_quote = char
        elif char in "([":
            depth += 1
        elif char in ")]":
            # Clamp at zero so a stray closer cannot hide later separators.
            depth = max(depth - 1, 0)
        elif char == "," and depth == 0:
            # End of expression
            yield current
            current = ""
            continue

        current += char

    yield current
|
|
|
|
|
|
class FunctionMetadata(object):
    """Class for describing a postgresql function"""

    def __init__(
        self,
        schema_name,
        func_name,
        arg_names,
        arg_types,
        arg_modes,
        return_type,
        is_aggregate,
        is_window,
        is_set_returning,
        is_extension,
        arg_defaults,
    ):
        """Store the description of a postgresql function.

        ``arg_names``, ``arg_types`` and ``arg_modes`` are sequences (stored as
        tuples) or a falsy value, which is stored as None. ``arg_modes`` holds
        one-letter codes: "i" IN, "o" OUT, "b" INOUT, "v" VARIADIC, "t" TABLE.
        ``arg_defaults`` is the string accepted by ``parse_defaults``; the
        parsed expressions are stored as a tuple. ``return_type`` is stripped
        of surrounding whitespace.
        """
        self.schema_name = schema_name
        self.func_name = func_name

        self.arg_modes = tuple(arg_modes) if arg_modes else None
        self.arg_names = tuple(arg_names) if arg_names else None

        # Be flexible in not requiring arg_types -- use None as a placeholder
        # for each arg. (Used for compatibility with old versions of postgresql
        # where such info is hard to get.)
        if arg_types:
            self.arg_types = tuple(arg_types)
        elif arg_modes:
            self.arg_types = tuple([None] * len(arg_modes))
        elif arg_names:
            self.arg_types = tuple([None] * len(arg_names))
        else:
            self.arg_types = None

        self.arg_defaults = tuple(parse_defaults(arg_defaults))

        self.return_type = return_type.strip()
        self.is_aggregate = is_aggregate
        self.is_window = is_window
        self.is_set_returning = is_set_returning
        self.is_extension = bool(is_extension)
        # Always a real bool, even when schema_name is None or empty.
        self.is_public = self.schema_name == "public"

    def __eq__(self, other):
        return isinstance(other, self.__class__) and self.__dict__ == other.__dict__

    def __ne__(self, other):
        return not self.__eq__(other)

    def _signature(self):
        """Return the constructor-level attributes as a tuple.

        The order matches the placeholders used by ``__repr__``.
        """
        return (
            self.schema_name,
            self.func_name,
            self.arg_names,
            self.arg_types,
            self.arg_modes,
            self.return_type,
            self.is_aggregate,
            self.is_window,
            self.is_set_returning,
            self.is_extension,
            self.arg_defaults,
        )

    def __hash__(self):
        return hash(self._signature())

    def __repr__(self):
        return (
            "%s(schema_name=%r, func_name=%r, arg_names=%r, "
            "arg_types=%r, arg_modes=%r, return_type=%r, is_aggregate=%r, "
            "is_window=%r, is_set_returning=%r, is_extension=%r, arg_defaults=%r)"
        ) % ((self.__class__.__name__,) + self._signature())

    def has_variadic(self):
        """Return True if any argument mode is "v" (VARIADIC), else False."""
        return bool(self.arg_modes) and any(
            arg_mode == "v" for arg_mode in self.arg_modes
        )

    def args(self):
        """Returns a list of input-parameter ColumnMetadata namedtuples."""
        if not self.arg_names:
            return []

        # Without explicit modes every argument is treated as IN.
        modes = self.arg_modes or ["i"] * len(self.arg_names)
        inputs = [
            (name, typ)
            for name, typ, mode in zip(self.arg_names, self.arg_types, modes)
            if mode in ("i", "b", "v")  # IN, INOUT, VARIADIC
        ]

        # The defaults are matched against the trailing input parameters:
        # input number `num` has a default once it reaches this position.
        first_default = len(inputs) - len(self.arg_defaults)

        result = []
        for num, (name, typ) in enumerate(inputs):
            has_default = num >= first_default
            default = self.arg_defaults[num - first_default] if has_default else None
            result.append(ColumnMetadata(name, typ, [], default, has_default))
        return result

    def fields(self):
        """Returns a list of output-field ColumnMetadata namedtuples"""
        if self.return_type.lower() == "void":
            return []

        # Any of the three sequences may be None (e.g. unnamed arguments with
        # a VARIADIC mode); treat a missing one as empty instead of failing.
        outputs = [
            ColumnMetadata(name, typ, [])
            for name, typ, mode in zip(
                self.arg_names or (), self.arg_types or (), self.arg_modes or ()
            )
            if mode in ("o", "b", "t")  # OUT, INOUT, TABLE
        ]
        if outputs:
            return outputs

        # For functions without output parameters, the function name
        # is used as the name of the output column.
        # E.g. 'SELECT unnest FROM unnest(...);'
        return [ColumnMetadata(self.func_name, self.return_type, [])]
|