diff --git a/pgcli/main.py b/pgcli/main.py index f98bff3f..2170324a 100755 --- a/pgcli/main.py +++ b/pgcli/main.py @@ -118,12 +118,11 @@ class PGCli(object): def connect(self, database='', host='', user='', port='', passwd=''): # Connect to the database. + if not user: + user = getuser() + if not database: - if user: - database = user - else: - # default to current OS username just like psql - database = user = getuser() + database = user # Prompt for a password immediately if requested via the -W flag. This # avoids wasting time trying to connect to the database and catching a diff --git a/pgcli/packages/pgspecial/dbcommands.py b/pgcli/packages/pgspecial/dbcommands.py index 8be3c203..6e7c8b16 100644 --- a/pgcli/packages/pgspecial/dbcommands.py +++ b/pgcli/packages/pgspecial/dbcommands.py @@ -8,16 +8,16 @@ TableInfo = namedtuple("TableInfo", ['checks', 'relkind', 'hasindex', log = logging.getLogger(__name__) -def change_db(cur, arg, verbose, db_obj): - if arg is None: +def change_db(cur, pattern, verbose, db_obj): + if pattern is None: db_obj.connect() else: - db_obj.connect(database=arg) + db_obj.connect(database=pattern) yield (None, None, None, 'You are now connected to database "%s" as ' 'user "%s"' % (db_obj.dbname, db_obj.user)) -def list_roles(cur, pattern, verbose): +def list_roles(cur, pattern, verbose, **kwargs): """ Returns (title, rows, headers, status) """ @@ -45,7 +45,7 @@ def list_roles(cur, pattern, verbose): headers = [x[0] for x in cur.description] return [(None, cur, headers, cur.statusmessage)] -def list_schemas(cur, pattern, verbose): +def list_schemas(cur, pattern, verbose, **kwargs): """ Returns (title, rows, headers, status) """ @@ -130,21 +130,21 @@ def list_objects(cur, pattern, verbose, relkinds): return [(None, cur, headers, cur.statusmessage)] -def list_tables(cur, pattern, verbose): +def list_tables(cur, pattern, verbose, **kwargs): return list_objects(cur, pattern, verbose, ['r', '']) -def list_views(cur, pattern, verbose): +def list_views(cur, pattern, verbose, **kwargs): return list_objects(cur, pattern, verbose, ['v', 's', '']) -def list_sequences(cur, pattern, verbose): +def list_sequences(cur, pattern, verbose, **kwargs): return list_objects(cur, pattern, verbose, ['S', 's', '']) -def list_indexes(cur, pattern, verbose): +def list_indexes(cur, pattern, verbose, **kwargs): return list_objects(cur, pattern, verbose, ['i', 's', '']) -def list_functions(cur, pattern, verbose): +def list_functions(cur, pattern, verbose, **kwargs): if verbose: verbose_columns = ''' @@ -209,7 +209,7 @@ def list_functions(cur, pattern, verbose): return [(None, cur, headers, cur.statusmessage)] -def list_datatypes(cur, pattern, verbose): +def list_datatypes(cur, pattern, verbose, **kwargs): assert True sql = '''SELECT n.nspname as "Schema", pg_catalog.format_type(t.oid, NULL) AS "Name", ''' @@ -276,7 +276,7 @@ def list_datatypes(cur, pattern, verbose): headers = [x[0] for x in cur.description] return [(None, cur, headers, cur.statusmessage)] -def describe_table_details(cur, pattern, verbose): +def describe_table_details(cur, pattern, verbose, **kwargs): """ Returns (title, rows, headers, status) """ diff --git a/pgcli/packages/pgspecial/iocommands.py b/pgcli/packages/pgspecial/iocommands.py index 1facb512..1bd325f5 100644 --- a/pgcli/packages/pgspecial/iocommands.py +++ b/pgcli/packages/pgspecial/iocommands.py @@ -16,14 +16,14 @@ use_expanded_output = False def is_expanded_output(): return use_expanded_output -def toggle_expanded_output(cur, arg, verbose): +def toggle_expanded_output(cur, arg, verbose, **kwargs): global use_expanded_output use_expanded_output = not use_expanded_output message = u"Expanded display is " message += u"on." if use_expanded_output else u"off." return [(None, None, None, message)] -def toggle_timing(cur, arg, verbose): +def toggle_timing(cur, arg, verbose, **kwargs): global TIMING_ENABLED TIMING_ENABLED = not TIMING_ENABLED message = "Timing is " @@ -89,7 +89,7 @@ def open_external_editor(filename=None, sql=''): return (query, message) -def execute_named_query(cur, arg, verbose): +def execute_named_query(cur, arg, verbose, **kwargs): """Returns (title, rows, headers, status)""" if arg == '': return list_named_queries(cur, arg, verbose) @@ -106,7 +106,7 @@ def execute_named_query(cur, arg, verbose): else: return [(title, None, None, cur.statusmessage)] -def list_named_queries(cur, arg, verbose): +def list_named_queries(cur, arg, verbose, **kwargs): """List of all named queries. Returns (title, rows, headers, status)""" if not verbose: @@ -117,7 +117,7 @@ def list_named_queries(cur, arg, verbose): rows = [[r, namedqueries.get(r)] for r in namedqueries.list()] return [('', rows, headers, "")] -def save_named_query(cur, arg, verbose): +def save_named_query(cur, arg, verbose, **kwargs): """Save a new named query. Returns (title, rows, headers, status)""" if ' ' not in arg: @@ -126,7 +126,7 @@ def save_named_query(cur, arg, verbose): namedqueries.save(name, query) return [(None, None, None, "Saved.")] -def delete_named_query(cur, arg, verbose): +def delete_named_query(cur, arg, verbose, **kwargs): """Delete an existing named query. """ if len(arg) == 0: diff --git a/pgcli/packages/pgspecial/main.py b/pgcli/packages/pgspecial/main.py index d60dd0c2..8ebaf4a0 100644 --- a/pgcli/packages/pgspecial/main.py +++ b/pgcli/packages/pgspecial/main.py @@ -2,7 +2,7 @@ from . import export from .iocommands import * from .dbcommands import * -def show_help(*args): # All the parameters are ignored. +def show_help(**kwargs): # All the parameters are ignored. headers = ['Command', 'Description'] result = [] @@ -27,7 +27,7 @@ def in_progress(*args): COMMANDS = { '\?': (show_help, ['\?', 'Help on pgcli commands.']), - '\c': (change_db, ['\c database_name', 'Connect to a new database.']), + '\c': (change_db, ['\c[onnect] database_name', 'Connect to a new database.']), '\l': ('''SELECT datname FROM pg_database;''', ['\l', 'List databases.']), '\d': (describe_table_details, ['\d [pattern]', 'List or describe tables, views and sequences.']), '\dn': (list_schemas, ['\dn[+] [pattern]', 'List schemas.']), @@ -53,6 +53,8 @@ COMMANDS = { # Commands not shown via help. HIDDEN_COMMANDS = { 'describe': (describe_table_details, ['DESCRIBE [pattern]', '']), + 'use': (change_db, ['\c database_name', 'Connect to a new database.']), + '\connect': (change_db, ['\c database_name', 'Connect to a new database.']), } @export @@ -80,7 +82,7 @@ def execute(cur=None, sql='', db_obj=None): # If the command executor is a function, then call the function with the # args. If it's a string, then assume it's an SQL command and run it. if callable(command_executor): - return command_executor(cur, arg, verbose) + return command_executor(cur=cur, pattern=arg, verbose=verbose, db_obj=db_obj) elif isinstance(command_executor, str): cur.execute(command_executor) if cur.description: diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py index bec601fe..97d30015 100644 --- a/pgcli/pgexecute.py +++ b/pgcli/pgexecute.py @@ -162,6 +162,11 @@ class PGExecute(object): self.conn.close() self.conn = conn self.conn.autocommit = True + self.dbname = db + self.user = user + self.password = password + self.host = host + self.port = port register_json_typecasters(self.conn, self._json_typecaster) register_hstore_typecaster(self.conn) @@ -194,30 +199,13 @@ class PGExecute(object): # Remove spaces, eol and semi-colons. sql = sql.rstrip(';') - # Check if the command is a \c or 'use'. This is a special - # exception that cannot be offloaded to `pgspecial` lib. Because we - # have to change the database connection that we're connected to. - command = sql.split()[0] - if command == '\c' or command == '\connect' or command.lower() == 'use': - _logger.debug('Database change command detected.') - try: - dbname = sql.split()[1] - except: - _logger.debug('Database name missing.') - raise RuntimeError('Database name missing.') - self.connect(database=dbname) - self.dbname = dbname - _logger.debug('Successfully switched to DB: %r', dbname) - yield (None, None, None, 'You are now connected to database "%s" as ' - 'user "%s"' % (self.dbname, self.user)) - else: - try: # Special command - _logger.debug('Trying a pgspecial command. sql: %r', sql) - cur = self.conn.cursor() - for result in special.execute(cur, sql): - yield result - except KeyError: # Regular SQL - yield self.execute_normal_sql(sql) + try: # Special command + _logger.debug('Trying a pgspecial command. sql: %r', sql) + cur = self.conn.cursor() + for result in special.execute(cur, sql, self): + yield result + except KeyError: # Regular SQL + yield self.execute_normal_sql(sql) def execute_normal_sql(self, split_sql): _logger.debug('Regular sql statement. sql: %r', split_sql)