1
0
mirror of https://github.com/dbcli/pgcli synced 2024-06-16 01:42:23 +00:00
pgcli/pgcli/pgexecute.py

192 lines
7.3 KiB
Python
Raw Normal View History

import logging
2014-11-22 07:43:11 +00:00
import psycopg2
2015-01-09 09:56:22 +00:00
import psycopg2.extras
import psycopg2.extensions
import sqlparse
2015-01-18 19:32:30 +00:00
from pandas import DataFrame
from .packages import pgspecial
2014-11-22 07:43:11 +00:00
# Module-level logger, named after this module; used by every function and
# method below for debug output.
_logger = logging.getLogger(__name__)
# Cast all database input to unicode automatically.
# See http://initd.org/psycopg/docs/usage.html#unicode-handling for more info.
# No connection or cursor is passed as a second argument, so (per the
# psycopg2 ``register_type`` documentation) both typecasters are registered
# globally. This happens as a side effect of importing this module.
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE)
psycopg2.extensions.register_type(psycopg2.extensions.UNICODEARRAY)
2015-01-09 09:56:22 +00:00
# When running a query, make pressing CTRL+C raise a KeyboardInterrupt
# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
2015-01-09 09:56:22 +00:00
# Runs at import time: the callback is installed through the module-level
# psycopg2.extensions.set_wait_callback hook, not on a specific connection.
psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select)
2014-12-13 02:28:55 +00:00
def _parse_dsn(dsn, default_user, default_password, default_host,
default_port):
"""
2014-12-14 22:18:26 +00:00
This function parses a postgres url to get the different components.
2014-12-13 02:28:55 +00:00
"""
user = password = host = port = dbname = None
2014-12-14 22:18:26 +00:00
if dsn.startswith('postgres://'): # Check if the string is a database url.
dsn = dsn[len('postgres://'):]
elif dsn.startswith('postgresql://'):
dsn = dsn[len('postgresql://'):]
2014-12-13 02:28:55 +00:00
2014-12-14 22:18:26 +00:00
if '/' in dsn:
2014-12-13 02:28:55 +00:00
host, dbname = dsn.split('/', 1)
if '@' in host:
user, _, host = host.partition('@')
if ':' in host:
host, _, port = host.partition(':')
2014-12-14 22:18:26 +00:00
if user and ':' in user:
2014-12-13 02:28:55 +00:00
user, _, password = user.partition(':')
user = user or default_user
password = password or default_password
host = host or default_host
port = port or default_port
dbname = dbname or dsn
_logger.debug('Parsed connection params:'
'dbname: %r, user: %r, password: %r, host: %r, port: %r',
dbname, user, password, host, port)
2014-12-13 02:28:55 +00:00
return (dbname, user, password, host, port)
2014-11-22 07:43:11 +00:00
class PGExecute(object):
    """Owns one psycopg2 connection and runs user input against it.

    The constructor parses its ``database`` argument with ``_parse_dsn`` (so
    it may be a bare name or a postgres URL) and connects immediately.
    ``run`` is the main entry point; it dispatches database-switch commands,
    ``pgspecial`` commands and regular SQL.
    """

    # Tables, views and materialized views outside pg_toast*, pg_catalog and
    # information_schema, with a flag telling whether each is visible.
    tables_query = '''
SELECT n.nspname schema_name,
c.relname table_name,
pg_catalog.pg_table_is_visible(c.oid) is_visible
FROM pg_catalog.pg_class c
LEFT JOIN pg_catalog.pg_namespace n
ON n.oid = c.relnamespace
WHERE c.relkind IN ('r','v', 'm') -- table, view, materialized view
AND n.nspname !~ '^pg_toast'
AND n.nspname NOT IN ('information_schema', 'pg_catalog')
ORDER BY 1,2;'''

    # Non-dropped columns with attnum > 0 of every table, view and
    # materialized view outside the pg_* and information_schema schemas.
    columns_query = '''
SELECT nsp.nspname schema_name,
cls.relname table_name,
att.attname column_name
FROM pg_catalog.pg_attribute att
INNER JOIN pg_catalog.pg_class cls
ON att.attrelid = cls.oid
INNER JOIN pg_catalog.pg_namespace nsp
ON cls.relnamespace = nsp.oid
WHERE cls.relkind IN ('r', 'v', 'm')
AND nsp.nspname !~ '^pg_'
AND nsp.nspname <> 'information_schema'
AND NOT att.attisdropped
AND att.attnum > 0
ORDER BY 1, 2, 3'''

    # One row per database; only the first column (the name) is consumed by
    # ``databases()``.
    databases_query = """SELECT d.datname as "Name",
pg_catalog.pg_get_userbyid(d.datdba) as "Owner",
pg_catalog.pg_encoding_to_char(d.encoding) as "Encoding",
d.datcollate as "Collate",
d.datctype as "Ctype",
pg_catalog.array_to_string(d.datacl, E'\n') AS "Access privileges"
FROM pg_catalog.pg_database d
ORDER BY 1;"""

    def __init__(self, database, user, password, host, port):
        """Parse the connection parameters and open the connection.

        ``database`` may be a database name or a postgres URL; the other
        arguments act as defaults for whatever the URL does not specify.
        """
        (self.dbname, self.user, self.password, self.host, self.port) = \
            _parse_dsn(database, default_user=user,
                       default_password=password, default_host=host,
                       default_port=port)
        self.connect()

    def connect(self, database=None, user=None, password=None, host=None,
                port=None):
        """(Re)connect, using stored parameters for any argument not given.

        The new connection is opened first; only when that succeeds is the
        previous connection (if any) closed and replaced, so a failed
        reconnect leaves the current connection usable.  The new connection
        is put in autocommit mode.
        """
        conn = psycopg2.connect(
            database=database or self.dbname,
            user=user or self.user,
            password=password or self.password,
            host=host or self.host,
            port=port or self.port)
        if hasattr(self, 'conn'):
            self.conn.close()
        self.conn = conn
        self.conn.autocommit = True

    @staticmethod
    def _is_db_change_command(sql):
        """Return True if ``sql`` is a ``\\c``, ``\\connect`` or ``use`` command.

        Only the first whitespace-delimited word is compared, so other
        commands that merely share the prefix (``\\copy``, ``\\cd``,
        ``\\conninfo``, ...) are not mistaken for a database switch.
        ``use`` is matched case-insensitively; the backslash commands are
        matched exactly.
        """
        words = sql.split(None, 1)
        if not words:
            return False
        command = words[0]
        return command in ('\\c', '\\connect') or command.lower() == 'use'

    def run(self, sql):
        """Execute the sql in the database and return the results. The results
        are a list of tuples. Each tuple has 3 values (rows, headers, status).

        Raises ``RuntimeError`` when a database-switch command is given
        without a database name.
        """
        # Remove spaces and EOL
        sql = sql.strip()
        if not sql:  # Empty string
            return [(None, None, None)]

        # Drop trailing semi-colons.
        sql = sql.rstrip(';')

        # Check if the command is a \c or 'use'. This is a special exception
        # that cannot be offloaded to `pgspecial` lib. Because we have to
        # change the database connection that we're connected to.
        if self._is_db_change_command(sql):
            _logger.debug('Database change command detected.')
            try:
                dbname = sql.split()[1]
            except IndexError:
                _logger.debug('Database name missing.')
                raise RuntimeError('Database name missing.')
            self.connect(database=dbname)
            self.dbname = dbname
            _logger.debug('Successfully switched to DB: %r', dbname)
            return [(None, None, 'You are now connected to database "%s" as '
                     'user "%s"' % (self.dbname, self.user))]

        try:  # Special command
            _logger.debug('Trying a pgspecial command. sql: %r', sql)
            cur = self.conn.cursor()
            return pgspecial.execute(cur, sql)
        except KeyError:  # Regular SQL
            # Split the sql into separate queries and run each one. If any
            # single query fails, the rest of them are not run and no results
            # are shown.
            queries = sqlparse.split(sql)
            return [self.execute_normal_sql(query) for query in queries]

    def execute_normal_sql(self, split_sql):
        """Run a single SQL statement and return ``(rows, headers, status)``.

        ``rows`` is the live cursor (the caller iterates it) and ``headers``
        the column names; both are ``None`` for statements that return no
        rows.  ``status`` is the cursor's status message.
        """
        _logger.debug('Regular sql statement. sql: %r', split_sql)
        cur = self.conn.cursor()
        cur.execute(split_sql)
        # cur.description will be None for operations that do not return
        # rows.
        if cur.description:
            headers = [x[0] for x in cur.description]
            return (cur, headers, cur.statusmessage)
        else:
            _logger.debug('No rows in result.')
            return (None, None, cur.statusmessage)

    def get_metadata(self):
        """ Returns a list [tables, columns] of DataFrames
            tables: DataFrame with columns [schema, table, is_visible]
            columns: DataFrame with columns [schema, table, column]
        """
        with self.conn.cursor() as cur:
            _logger.debug('Tables Query. sql: %r', self.tables_query)
            cur.execute(self.tables_query)
            tables = DataFrame.from_records(
                cur, columns=['schema', 'table', 'is_visible'])

        with self.conn.cursor() as cur:
            _logger.debug('Columns Query. sql: %r', self.columns_query)
            cur.execute(self.columns_query)
            columns = DataFrame.from_records(
                cur, columns=['schema', 'table', 'column'])

        return [tables, columns]

    def databases(self):
        """Return the list of database names reported by ``databases_query``."""
        with self.conn.cursor() as cur:
            _logger.debug('Databases Query. sql: %r', self.databases_query)
            cur.execute(self.databases_query)
            return [x[0] for x in cur.fetchall()]