# File: pgcli/pgcli/pgexecute.py
# Mirror of https://github.com/dbcli/pgcli (synced 2024-05-31 01:17:54 +00:00)
import traceback
import logging
import itertools

import psycopg2
import psycopg2.extras
import psycopg2.extensions as ext
import sqlparse
import pgspecial as special

from .packages.function_metadata import FunctionMetadata, ForeignKey
from .encodingutils import unicode2utf8, PY2, utf8tounicode

_logger = logging.getLogger(__name__)

# Cast all database input to unicode automatically.
# See http://initd.org/psycopg/docs/usage.html#unicode-handling for more info.
ext.register_type(ext.UNICODE)
ext.register_type(ext.UNICODEARRAY)
# Values reported with type OID 705 ("UNKNOWN") are cast to unicode as well.
ext.register_type(ext.new_type((705,), "UNKNOWN", ext.UNICODE))

# See https://github.com/dbcli/pgcli/issues/426 for more details.
# This registers a unicode type caster for datatype 'RECORD'.
ext.register_type(ext.new_type((2249,), "RECORD", ext.UNICODE))

# Cast bytea fields to text. By default, this will render as hex strings with
# Postgres 9+ and as escaped binary in earlier versions.
ext.register_type(ext.new_type((17,), 'BYTEA_TEXT', psycopg2.STRING))

# When running a query, make pressing CTRL+C raise a KeyboardInterrupt
# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
ext.set_wait_callback(psycopg2.extras.wait_select)
def register_date_typecasters(connection):
    """
    Casts date and timestamp values to string, resolves issues with out of
    range dates (e.g. BC) which psycopg2 can't handle

    The OIDs of ``date``, ``timestamp`` and ``timestamptz`` are looked up on
    the given connection and a pass-through typecaster is registered for
    them, so the values are returned exactly as the caster receives them.

    :param connection: open psycopg2 connection used for the OID lookup
    """
    def cast_date(value, cursor):
        # Pass-through: hand the value back without parsing it.
        return value

    oid_queries = (
        'SELECT NULL::date',
        'SELECT NULL::timestamp',
        'SELECT NULL::timestamptz',
    )
    oids = []
    # Use the cursor as a context manager so it is closed once the OIDs are
    # known (the original left it open).
    with connection.cursor() as cursor:
        for query in oid_queries:
            cursor.execute(query)
            # description[0][1] is the type code (OID) of the first column.
            oids.append(cursor.description[0][1])

    # NOTE: register_type is called without a scope argument, exactly as
    # before, so the caster is not limited to `connection`.
    new_type = psycopg2.extensions.new_type(tuple(oids), 'DATE', cast_date)
    psycopg2.extensions.register_type(new_type)
def register_json_typecasters(conn, loads_fn):
    """Install `loads_fn` as the JSON decoder for a connection.

    psycopg2 decodes JSON columns with json.loads by default
    (http://initd.org/psycopg/docs/extras.html#json-adaptation). This swaps
    in the supplied callable instead; it is handed one argument, the JSON
    data as a string in the database's character encoding.

    Registration is attempted for both the JSON and the JSONB type.

    :param conn: psycopg2 connection to register the typecasters on
    :param loads_fn: callable used to decode the JSON data
    :return: set, a subset of {'json', 'jsonb'}, naming the types whose
             typecaster was registered successfully
    """
    registered = set()
    for type_name in ('json', 'jsonb'):
        try:
            psycopg2.extras.register_json(
                conn, loads=loads_fn, name=type_name)
        except psycopg2.ProgrammingError:
            # Registration failed for this type: leave it out of the result.
            continue
        registered.add(type_name)
    return registered
def register_hstore_typecaster(conn):
    """
    Instead of using register_hstore() which converts hstore into a python
    dict, we query the 'oid' of hstore which will be different for each
    database and register a type caster that converts it to unicode.
    http://initd.org/psycopg/docs/extras.html#psycopg2.extras.register_hstore

    This is best effort: if the lookup or the registration fails, nothing is
    registered and no exception is propagated to the caller.

    :param conn: open psycopg2 connection used to look up the hstore oid
    """
    with conn.cursor() as cur:
        try:
            cur.execute("SELECT 'hstore'::regtype::oid")
            oid = cur.fetchone()[0]
            ext.register_type(ext.new_type((oid,), "HSTORE", ext.UNICODE))
        except Exception:
            # Still non-fatal, but leave a trace instead of swallowing the
            # failure silently.
            _logger.debug('hstore typecaster was not registered',
                          exc_info=True)
class PGExecute(object):
    """Wrapper around a single psycopg2 connection.

    Runs user supplied SQL (optionally routed through pgspecial first) and
    exposes catalog queries for schemas, relations, columns, databases,
    foreign keys, functions, datatypes and identifier casing.
    """

    # The boolean argument to the current_schemas function indicates whether
    # implicit schemas, e.g. pg_catalog, are included.
    search_path_query = '''
SELECT * FROM unnest(current_schemas(true))'''

    schemata_query = '''
SELECT nspname
FROM pg_catalog.pg_namespace
ORDER BY 1 '''

    # %s receives the list of relkind codes to select.
    tables_query = '''
SELECT n.nspname schema_name,
c.relname table_name
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n
ON n.oid = c.relnamespace
WHERE c.relkind = ANY(%s)
ORDER BY 1,2;'''

    databases_query = '''
SELECT d.datname
FROM pg_catalog.pg_database d
ORDER BY 1'''

    def __init__(self, database, user, password, host, port, dsn):
        """Remember the connection parameters and connect immediately."""
        self.dbname = database
        self.user = user
        self.password = password
        self.host = host
        self.port = port
        self.dsn = dsn
        self.connect()

    def connect(self, database=None, user=None, password=None, host=None,
                port=None, dsn=None):
        """Open a new connection and make it the active one.

        Every argument that is omitted (or falsy) falls back to the value
        stored on the instance. Once the new connection is ready, a
        previously held connection is closed, the stored parameters are
        updated and the date, JSON and hstore typecasters are registered.

        When a DSN is used, the database, user, host and port actually in
        effect are read back from the server.
        """
        db = (database or self.dbname)
        user = (user or self.user)
        password = (password or self.password)
        host = (host or self.host)
        port = (port or self.port)
        dsn = (dsn or self.dsn)

        if dsn:
            if password:
                # NOTE(review): the password is appended verbatim; values
                # containing spaces or quotes, or a URI-style DSN, are
                # presumably not handled by this -- confirm.
                dsn = "{0} password={1}".format(dsn, password)
            conn = psycopg2.connect(dsn=unicode2utf8(dsn))
        else:
            conn = psycopg2.connect(
                database=unicode2utf8(db),
                user=unicode2utf8(user),
                password=unicode2utf8(password),
                host=unicode2utf8(host),
                port=unicode2utf8(port))

        try:
            # Enable autocommit before the first query is sent: psycopg2
            # does not allow changing it while a transaction is in progress,
            # and the lookups below would otherwise open one.
            conn.autocommit = True
            conn.set_client_encoding('utf8')

            if dsn:
                # When we connect using a DSN, we don't really know what db,
                # user, etc. we connected to. Let's read it.
                with conn.cursor() as cursor:
                    db = self._select_one(
                        cursor, 'select current_database()')
                    user = self._select_one(cursor, 'select current_user')
                    host = self._select_one(
                        cursor, 'select inet_server_addr()')
                    port = self._select_one(
                        cursor, 'select inet_server_port()')
        except Exception:
            # Do not leak the half-initialised connection; the previously
            # active connection (if any) is left untouched.
            conn.close()
            raise

        if hasattr(self, 'conn'):
            self.conn.close()
        self.conn = conn

        self.dbname = db
        self.user = user
        self.password = password
        self.host = host
        self.port = port

        register_date_typecasters(conn)
        register_json_typecasters(self.conn, self._json_typecaster)
        register_hstore_typecaster(self.conn)

    def _select_one(self, cur, sql):
        """
        Helper method to run a select and retrieve a single field value
        :param cur: cursor
        :param sql: string
        :return: first column of the first row, or None if no row came back
        """
        cur.execute(sql)
        row = cur.fetchone()
        # fetchone() returns the whole row; callers want the field itself.
        return row[0] if row else None

    def _json_typecaster(self, json_data):
        """Interpret incoming JSON data as a string.

        The raw data is decoded using the connection's encoding, which
        defaults to the database's encoding.
        See http://initd.org/psycopg/docs/connection.html#connection.encoding
        """
        if PY2:
            return json_data.decode(self.conn.encoding)
        else:
            return json_data

    def run(self, statement, pgspecial=None, exception_formatter=None,
            on_error_resume=False):
        """Execute the sql in the database and return the results.

        :param statement: A string containing one or more sql statements
        :param pgspecial: PGSpecial object
        :param exception_formatter: A callable that accepts an Exception and
            returns a formatted (title, rows, headers, status) tuple that can
            act as a query result. If an exception_formatter is not supplied,
            psycopg2 exceptions are always raised.
        :param on_error_resume: Bool. If true, queries following an exception
            (assuming exception_formatter has been supplied) continue to
            execute.

        :return: Generator yielding tuples containing
                 (title, rows, headers, status, query, success)
        """
        # Remove spaces and EOL
        statement = statement.strip()
        if not statement:  # Empty string
            yield (None, None, None, None, statement, False)
            # Nothing left to split or execute.
            return

        # Split the sql into separate queries and run each one.
        for sql in sqlparse.split(statement):
            # Remove trailing semi-colons.
            sql = sql.rstrip(';')

            try:
                if pgspecial:
                    # First try to run each query as special
                    _logger.debug('Trying a pgspecial command. sql: %r', sql)
                    cur = self.conn.cursor()
                    try:
                        for result in pgspecial.execute(cur, sql):
                            # e.g. execute_from_file already appends these
                            if len(result) < 6:
                                yield result + (sql, True)
                            else:
                                yield result
                        # Handled as a special command; go to the next one.
                        continue
                    except special.CommandNotFound:
                        pass

                # Not a special command, so execute as normal sql
                yield self.execute_normal_sql(sql) + (sql, True)
            except psycopg2.DatabaseError as e:
                _logger.error("sql: %r, error: %r", sql, e)
                _logger.error("traceback: %r", traceback.format_exc())

                if (isinstance(e, psycopg2.OperationalError)
                        or not exception_formatter):
                    # Always raise operational errors, regardless of on_error
                    # specification
                    raise

                yield None, None, None, exception_formatter(e), sql, False

                if not on_error_resume:
                    break

    def execute_normal_sql(self, split_sql):
        """Returns tuple (title, rows, headers, status)"""
        _logger.debug('Regular sql statement. sql: %r', split_sql)
        cur = self.conn.cursor()
        cur.execute(split_sql)

        # conn.notices persist between queries, we use pop to clear out the
        # list. Popping from the end and prepending keeps the notices in
        # their original order.
        title = ''
        while self.conn.notices:
            title = utf8tounicode(self.conn.notices.pop()) + title

        # cur.description will be None for operations that do not return
        # rows.
        if cur.description:
            headers = [x[0] for x in cur.description]
            return title, cur, headers, cur.statusmessage
        else:
            _logger.debug('No rows in result.')
            return title, None, None, cur.statusmessage

    def search_path(self):
        """Returns the current search path as a list of schema names"""
        try:
            with self.conn.cursor() as cur:
                _logger.debug('Search path query. sql: %r',
                              self.search_path_query)
                cur.execute(self.search_path_query)
                return [x[0] for x in cur.fetchall()]
        except psycopg2.ProgrammingError:
            # The unnest() form failed; read the array value directly.
            fallback = 'SELECT * FROM current_schemas(true)'
            with self.conn.cursor() as cur:
                _logger.debug('Search path query. sql: %r', fallback)
                cur.execute(fallback)
                return cur.fetchone()[0]

    def schemata(self):
        """Returns a list of schema names in the database"""
        with self.conn.cursor() as cur:
            _logger.debug('Schemata Query. sql: %r', self.schemata_query)
            cur.execute(self.schemata_query)
            return [x[0] for x in cur.fetchall()]

    def _relations(self, kinds=('r', 'v', 'm')):
        """Get table or view name metadata

        :param kinds: list of postgres relkind filters:
                'r' - table
                'v' - view
                'm' - materialized view
        :return: (schema_name, rel_name) tuples
        """
        with self.conn.cursor() as cur:
            # psycopg2 adapts lists to arrays but tuples to IN-style
            # "(a, b)" syntax, which ANY() rejects; always pass a list so
            # the tuple default works too.
            sql = cur.mogrify(self.tables_query, [list(kinds)])
            _logger.debug('Tables Query. sql: %r', sql)
            cur.execute(sql)
            for row in cur:
                yield row

    def tables(self):
        """Yields (schema_name, table_name) tuples"""
        for row in self._relations(kinds=['r']):
            yield row

    def views(self):
        """Yields (schema_name, view_name) tuples.

        Includes both views and materialized views
        """
        for row in self._relations(kinds=['v', 'm']):
            yield row

    def _columns(self, kinds=('r', 'v', 'm')):
        """Get column metadata for tables and views

        :param kinds: list of postgres relkind filters:
                'r' - table
                'v' - view
                'm' - materialized view
        :return: list of (schema_name, relation_name, column_name,
                 column_type) tuples
        """
        if self.conn.server_version >= 80400:
            columns_query = '''
SELECT nsp.nspname schema_name,
cls.relname table_name,
att.attname column_name,
att.atttypid::regtype::text type_name
FROM pg_catalog.pg_attribute att
INNER JOIN pg_catalog.pg_class cls
ON att.attrelid = cls.oid
INNER JOIN pg_catalog.pg_namespace nsp
ON cls.relnamespace = nsp.oid
WHERE cls.relkind = ANY(%s)
AND NOT att.attisdropped
AND att.attnum > 0
ORDER BY 1, 2, att.attnum'''
        else:
            columns_query = '''
SELECT nsp.nspname schema_name,
cls.relname table_name,
att.attname column_name,
typ.typname type_name
FROM pg_catalog.pg_attribute att
INNER JOIN pg_catalog.pg_class cls
ON att.attrelid = cls.oid
INNER JOIN pg_catalog.pg_namespace nsp
ON cls.relnamespace = nsp.oid
INNER JOIN pg_catalog.pg_type typ
ON typ.oid = att.atttypid
WHERE cls.relkind = ANY(%s)
AND NOT att.attisdropped
AND att.attnum > 0
ORDER BY 1, 2, att.attnum'''

        with self.conn.cursor() as cur:
            # See _relations: ANY() needs an array, so pass a list.
            sql = cur.mogrify(columns_query, [list(kinds)])
            _logger.debug('Columns Query. sql: %r', sql)
            cur.execute(sql)
            for row in cur:
                yield row

    def table_columns(self):
        """Yields column metadata rows for tables (relkind 'r')"""
        for row in self._columns(kinds=['r']):
            yield row

    def view_columns(self):
        """Yields column metadata rows for views and materialized views"""
        for row in self._columns(kinds=['v', 'm']):
            yield row

    def databases(self):
        """Returns a list of database names"""
        with self.conn.cursor() as cur:
            _logger.debug('Databases Query. sql: %r', self.databases_query)
            cur.execute(self.databases_query)
            return [x[0] for x in cur.fetchall()]

    def foreignkeys(self):
        """Yields ForeignKey named tuples"""
        # Nothing is yielded for servers older than 9.0.
        if self.conn.server_version < 90000:
            return

        with self.conn.cursor() as cur:
            query = '''
SELECT s_p.nspname AS parentschema,
t_p.relname AS parenttable,
unnest((
select
array_agg(attname ORDER BY i)
from
(select unnest(confkey) as attnum, generate_subscripts(confkey, 1) as i) x
JOIN pg_catalog.pg_attribute c USING(attnum)
WHERE c.attrelid = fk.confrelid
)) AS parentcolumn,
s_c.nspname AS childschema,
t_c.relname AS childtable,
unnest((
select
array_agg(attname ORDER BY i)
from
(select unnest(conkey) as attnum, generate_subscripts(conkey, 1) as i) x
JOIN pg_catalog.pg_attribute c USING(attnum)
WHERE c.attrelid = fk.conrelid
)) AS childcolumn
FROM pg_catalog.pg_constraint fk
JOIN pg_catalog.pg_class t_p ON t_p.oid = fk.confrelid
JOIN pg_catalog.pg_namespace s_p ON s_p.oid = t_p.relnamespace
JOIN pg_catalog.pg_class t_c ON t_c.oid = fk.conrelid
JOIN pg_catalog.pg_namespace s_c ON s_c.oid = t_c.relnamespace
WHERE fk.contype = 'f';
'''
            _logger.debug('Foreign Keys Query. sql: %r', query)
            cur.execute(query)
            for row in cur:
                yield ForeignKey(*row)

    def functions(self):
        """Yields FunctionMetadata named tuples"""
        # >= so that 9.0.0 itself (90000) gets the 9.0+ query, matching the
        # boundary used by foreignkeys().
        if self.conn.server_version >= 90000:
            query = '''
SELECT n.nspname schema_name,
p.proname func_name,
p.proargnames,
COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[],
p.proargmodes,
prorettype::regtype::text return_type,
p.proisagg is_aggregate,
p.proiswindow is_window,
p.proretset is_set_returning
FROM pg_catalog.pg_proc p
INNER JOIN pg_catalog.pg_namespace n
ON n.oid = p.pronamespace
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
'''
        elif self.conn.server_version >= 80400:
            query = '''
SELECT n.nspname schema_name,
p.proname func_name,
p.proargnames,
COALESCE(proallargtypes::regtype[], proargtypes::regtype[])::text[],
p.proargmodes,
prorettype::regtype::text,
p.proisagg is_aggregate,
false is_window,
p.proretset is_set_returning
FROM pg_catalog.pg_proc p
INNER JOIN pg_catalog.pg_namespace n
ON n.oid = p.pronamespace
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
'''
        else:
            query = '''
SELECT n.nspname schema_name,
p.proname func_name,
p.proargnames,
NULL,
p.proargmodes,
'',
p.proisagg is_aggregate,
false is_window,
p.proretset is_set_returning
FROM pg_catalog.pg_proc p
INNER JOIN pg_catalog.pg_namespace n
ON n.oid = p.pronamespace
WHERE p.prorettype::regtype != 'trigger'::regtype
ORDER BY 1, 2
'''

        with self.conn.cursor() as cur:
            _logger.debug('Functions Query. sql: %r', query)
            cur.execute(query)
            for row in cur:
                yield FunctionMetadata(*row)

    def datatypes(self):
        """Yields tuples of (schema_name, type_name)"""
        with self.conn.cursor() as cur:
            # >= so that 9.0.0 itself (90000) gets the 9.0+ query, matching
            # the boundary used by foreignkeys().
            if self.conn.server_version >= 90000:
                query = '''
SELECT n.nspname schema_name,
t.typname type_name
FROM pg_catalog.pg_type t
INNER JOIN pg_catalog.pg_namespace n
ON n.oid = t.typnamespace
WHERE ( t.typrelid = 0 -- non-composite types
OR ( -- composite type, but not a table
SELECT c.relkind = 'c'
FROM pg_catalog.pg_class c
WHERE c.oid = t.typrelid
)
)
AND NOT EXISTS( -- ignore array types
SELECT 1
FROM pg_catalog.pg_type el
WHERE el.oid = t.typelem AND el.typarray = t.oid
)
AND n.nspname <> 'pg_catalog'
AND n.nspname <> 'information_schema'
ORDER BY 1, 2;
'''
            else:
                query = '''
SELECT n.nspname schema_name,
pg_catalog.format_type(t.oid, NULL) type_name
FROM pg_catalog.pg_type t
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
WHERE (t.typrelid = 0 OR (SELECT c.relkind = 'c' FROM pg_catalog.pg_class c WHERE c.oid = t.typrelid))
AND t.typname !~ '^_'
AND n.nspname <> 'pg_catalog'
AND n.nspname <> 'information_schema'
AND pg_catalog.pg_type_is_visible(t.oid)
ORDER BY 1, 2;
'''
            _logger.debug('Datatypes Query. sql: %r', query)
            cur.execute(query)
            for row in cur:
                yield row

    def casing(self):
        """Yields the most common casing for names used in db functions"""
        with self.conn.cursor() as cur:
            # Raw string: the SQL contains the regex '\W+', which is not a
            # valid escape sequence in a normal Python string literal. The
            # resulting text is unchanged.
            query = r'''
WITH Words AS (
SELECT regexp_split_to_table(prosrc, '\W+') AS Word, COUNT(1)
FROM pg_catalog.pg_proc P
JOIN pg_catalog.pg_namespace N ON N.oid = P.pronamespace
JOIN pg_catalog.pg_language L ON L.oid = P.prolang
WHERE L.lanname IN ('sql', 'plpgsql')
AND N.nspname NOT IN ('pg_catalog', 'information_schema')
GROUP BY Word
),
OrderWords AS (
SELECT Word,
ROW_NUMBER() OVER(PARTITION BY LOWER(Word) ORDER BY Count DESC)
FROM Words
WHERE Word ~* '.*[a-z].*'
),
Names AS (
--Column names
SELECT attname AS Name
FROM pg_catalog.pg_attribute
UNION -- Table/view names
SELECT relname
FROM pg_catalog.pg_class
UNION -- Function names
SELECT proname
FROM pg_catalog.pg_proc
UNION -- Type names
SELECT typname
FROM pg_catalog.pg_type
UNION -- Schema names
SELECT nspname
FROM pg_catalog.pg_namespace
)
SELECT Word
FROM OrderWords
WHERE LOWER(Word) IN (SELECT Name FROM Names)
AND Row_Number = 1;
'''
            _logger.debug('Casing Query. sql: %r', query)
            cur.execute(query)
            for row in cur:
                yield row[0]