1
0
Fork 0
pgcli/pgcli/main.py

1529 lines
53 KiB
Python

import platform
import warnings
from os.path import expanduser
from configobj import ConfigObj
from pgspecial.namedqueries import NamedQueries
warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2")
import os
import re
import sys
import traceback
import logging
import threading
import shutil
import functools
import pendulum
import datetime as dt
import itertools
import platform
from time import time, sleep
from codecs import open
keyring = None # keyring will be loaded later
from cli_helpers.tabular_output import TabularOutputFormatter
from cli_helpers.tabular_output.preprocessors import align_decimals, format_numbers
import click
try:
import setproctitle
except ImportError:
setproctitle = None
from prompt_toolkit.completion import DynamicCompleter, ThreadedCompleter
from prompt_toolkit.enums import DEFAULT_BUFFER, EditingMode
from prompt_toolkit.shortcuts import PromptSession, CompleteStyle
from prompt_toolkit.document import Document
from prompt_toolkit.filters import HasFocus, IsDone
from prompt_toolkit.formatted_text import ANSI
from prompt_toolkit.lexers import PygmentsLexer
from prompt_toolkit.layout.processors import (
ConditionalProcessor,
HighlightMatchingBracketProcessor,
TabsProcessor,
)
from prompt_toolkit.history import FileHistory
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from pygments.lexers.sql import PostgresLexer
from pgspecial.main import PGSpecial, NO_QUERY, PAGER_OFF, PAGER_LONG_OUTPUT
import pgspecial as special
from .pgcompleter import PGCompleter
from .pgtoolbar import create_toolbar_tokens_func
from .pgstyle import style_factory, style_factory_output
from .pgexecute import PGExecute
from .completion_refresher import CompletionRefresher
from .config import (
get_casing_file,
load_config,
config_location,
ensure_dir_exists,
get_config,
)
from .key_bindings import pgcli_bindings
from .packages.prompt_utils import confirm_destructive_query
from .__init__ import __version__
click.disable_unicode_literals_warning = True
try:
from urlparse import urlparse, unquote, parse_qs
except ImportError:
from urllib.parse import urlparse, unquote, parse_qs
from getpass import getuser
from psycopg2 import OperationalError, InterfaceError
import psycopg2
from collections import namedtuple
from textwrap import dedent
# Matches terminal escape sequences that start with ESC (\x1b): either
# "[" ... up to a final byte in the range "@".."~", or "]" ... terminated by
# BEL (\x07) or ESC followed by a backslash. Used to strip such sequences
# before measuring how wide a line of output is.
# Ref: https://stackoverflow.com/questions/30425105/filter-special-chars-such-as-color-codes-from-shell-output
COLOR_CODE_REGEX = re.compile(r"\x1b(\[.*?[@-~]|\].*?(\x07|\x1b\\))")
# Query tuples are used for maintaining history.
MetaQuery = namedtuple(
    "Query",
    [
        "query",  # The entire text of the command
        "successful",  # True If all subqueries were successful
        "total_time",  # Time elapsed executing the query and formatting results
        "execution_time",  # Time elapsed executing the query
        "meta_changed",  # True if any subquery executed create/alter/drop
        "db_changed",  # True if any subquery changed the database
        "path_changed",  # True if any subquery changed the search path
        "mutated",  # True if any subquery executed insert/update/delete
        "is_special",  # True if the query is a special command
    ],
)
# Defaults are right-aligned against the field list, so there must be exactly
# one default per field. With only eight values for nine fields every default
# is shifted by one position (e.g. ``successful=""``, ``total_time=False``)
# and ``MetaQuery()`` fails because ``query`` has no default at all.
MetaQuery.__new__.__defaults__ = (
    "",  # query
    False,  # successful
    0,  # total_time
    0,  # execution_time
    False,  # meta_changed
    False,  # db_changed
    False,  # path_changed
    False,  # mutated
    False,  # is_special
)
# Bundle of formatting options passed along with each result set when it is
# rendered. Every field is optional; the defaults below are positional and
# line up one-to-one with the field list.
OutputSettings = namedtuple(
    "OutputSettings",
    [
        "table_format",
        "dcmlfmt",
        "floatfmt",
        "missingval",
        "expanded",
        "max_width",
        "case_function",
        "style_output",
    ],
)
OutputSettings.__new__.__defaults__ = (
    None,  # table_format
    None,  # dcmlfmt
    None,  # floatfmt
    "<null>",  # missingval
    False,  # expanded
    None,  # max_width
    lambda value: value,  # case_function: identity, leaves text unchanged
    None,  # style_output
)
class PgCliQuitError(Exception):
    """Raised by the quit command to ask the interactive loop to stop."""
class PGCli(object):
    """Interactive PostgreSQL client session.

    Holds the loaded configuration, the database executor, the completer and
    the prompt application, and implements the special commands registered
    with pgspecial as well as the read-eval-print loop.
    """

    # Prompt format used when neither the caller nor the config provides one.
    default_prompt = "\\u@\\h:\\d> "
    # When the rendered default prompt is longer than this many characters a
    # shorter prompt format is used instead.
    max_len_prompt = 30
def set_default_pager(self, config):
    """Publish the pager choice through the process environment.

    The ``pager`` entry of the ``[main]`` config section takes precedence
    over an already exported ``PAGER`` variable. When neither is present,
    ``PAGER`` is left untouched. In every case ``LESS`` receives a default
    option set if it is unset or empty.

    :param config: mapping with a ``"main"`` section supporting ``.get``
    """
    from_config = config["main"].get("pager")
    from_env = os.environ.get("PAGER")

    if from_config:
        chosen = from_config
        log_format = 'Default pager found in config file: "%s"'
    elif from_env:
        chosen = from_env
        log_format = 'Default pager found in PAGER environment variable: "%s"'
    else:
        chosen = None
        log_format = None

    if chosen is None:
        self.logger.info(
            "No default pager found in environment. Using os default pager"
        )
    else:
        self.logger.info(log_format, chosen)
        os.environ["PAGER"] = chosen

    # Recommended options for less; other pagers ignore this variable.
    if not os.environ.get("LESS"):
        os.environ["LESS"] = "-SRXF"
def __init__(
    self,
    force_passwd_prompt=False,
    never_passwd_prompt=False,
    pgexecute=None,
    pgclirc_file=None,
    row_limit=None,
    single_connection=False,
    less_chatty=None,
    prompt=None,
    prompt_dsn=None,
    auto_vertical_output=False,
    warn=None,
    histfile=None,
    alias_dsn=None,
):
    """Load the configuration and set up session state.

    Explicit arguments override the corresponding config-file values where
    both exist (``row_limit``, ``warn``, ``prompt``, ``histfile``);
    ``less_chatty`` and ``auto_vertical_output`` are OR-ed with the config.

    :param force_passwd_prompt: prompt for a password before connecting
    :param never_passwd_prompt: never prompt for a password
    :param pgexecute: existing executor to use, if any
    :param pgclirc_file: path handed to ``get_config``
    :param row_limit: overrides ``[main] row_limit`` when not None
    :param single_connection: stored in the completer settings
    :param less_chatty: suppress the intro and goodbye messages
    :param prompt: prompt format; falls back to config, then default_prompt
    :param prompt_dsn: prompt format used when a DSN alias is active
    :param auto_vertical_output: enable auto-expanded output
    :param warn: overrides ``[main] destructive_warning`` when not None
    :param histfile: history file path; overrides ``[main] history_file``
    :param alias_dsn: DSN alias name, exposed to the prompt as \\dsn_alias
    """
    self.force_passwd_prompt = force_passwd_prompt
    self.never_passwd_prompt = never_passwd_prompt
    self.pgexecute = pgexecute
    self.dsn_alias = None
    self.watch_command = None

    # Load config.
    c = self.config = get_config(pgclirc_file)

    NamedQueries.instance = NamedQueries.from_config(self.config)

    # Logging must be initialised before set_default_pager, which logs.
    self.logger = logging.getLogger(__name__)
    self.initialize_logging()

    self.set_default_pager(c)
    self.output_file = None
    self.pgspecial = PGSpecial()

    self.multi_line = c["main"].as_bool("multi_line")
    self.multiline_mode = c["main"].get("multi_line_mode", "psql")
    self.vi_mode = c["main"].as_bool("vi")
    self.auto_expand = auto_vertical_output or c["main"].as_bool("auto_expand")
    self.expanded_output = c["main"].as_bool("expand")
    self.pgspecial.timing_enabled = c["main"].as_bool("timing")
    # An explicit row_limit (including 0) wins over the config value.
    if row_limit is not None:
        self.row_limit = row_limit
    else:
        self.row_limit = c["main"].as_int("row_limit")

    self.min_num_menu_lines = c["main"].as_int("min_num_menu_lines")
    self.multiline_continuation_char = c["main"]["multiline_continuation_char"]
    self.table_format = c["main"]["table_format"]
    self.syntax_style = c["main"]["syntax_style"]
    self.cli_style = c["colors"]
    self.wider_completion_menu = c["main"].as_bool("wider_completion_menu")
    c_dest_warning = c["main"].as_bool("destructive_warning")
    self.destructive_warning = c_dest_warning if warn is None else warn
    self.less_chatty = bool(less_chatty) or c["main"].as_bool("less_chatty")
    self.null_string = c["main"].get("null_string", "<null>")
    self.prompt_format = (
        prompt
        if prompt is not None
        else c["main"].get("prompt", self.default_prompt)
    )
    self.prompt_dsn_format = prompt_dsn
    self.on_error = c["main"]["on_error"].upper()
    self.decimal_format = c["data_formats"]["decimal"]
    self.float_format = c["data_formats"]["float"]
    self.initialize_keyring()
    self.show_bottom_toolbar = c["main"].as_bool("show_bottom_toolbar")

    # pset_pager receives the string "on" or "off" depending on the flag.
    self.pgspecial.pset_pager(
        self.config["main"].as_bool("enable_pager") and "on" or "off"
    )

    self.style_output = style_factory_output(self.syntax_style, c["colors"])

    # Timestamp rendered by the \t prompt escape.
    self.now = dt.datetime.today()

    self.completion_refresher = CompletionRefresher()

    # history file location: --histfile > pgclirc:history_file
    # NOTE(review): the nesting of the "default" check was reconstructed
    # from unindented source; confirm whether it should only apply to the
    # config value.
    if histfile:
        self.history_file = histfile
    else:
        self.history_file = self.config["main"]["history_file"]
    if self.history_file == "default":
        self.history_file = config_location() + "history"

    if alias_dsn:
        self.dsn_alias = alias_dsn

    # Queries run in this session, appended to by run_cli.
    self.query_history = []

    # Initialize completer
    smart_completion = c["main"].as_bool("smart_completion")
    keyword_casing = c["main"]["keyword_casing"]
    self.settings = {
        "casing_file": get_casing_file(c),
        "generate_casing_file": c["main"].as_bool("generate_casing_file"),
        "generate_aliases": c["main"].as_bool("generate_aliases"),
        "asterisk_column_order": c["main"]["asterisk_column_order"],
        "qualify_columns": c["main"]["qualify_columns"],
        "case_column_headers": c["main"].as_bool("case_column_headers"),
        "search_path_filter": c["main"].as_bool("search_path_filter"),
        "single_connection": single_connection,
        "less_chatty": less_chatty,
        "keyword_casing": keyword_casing,
    }

    completer = PGCompleter(
        smart_completion, pgspecial=self.pgspecial, settings=self.settings
    )
    self.completer = completer
    # Guards reads and swaps of self.completer.
    self._completer_lock = threading.Lock()
    self.register_special_commands()

    # Created later by run_cli via _build_cli.
    self.prompt_app = None
def quit(self):
    """Signal the interactive loop to stop by raising PgCliQuitError."""
    raise PgCliQuitError()
def register_special_commands(self):
    """Register this session's backslash and keyword commands with pgspecial.

    Each registration supplies the handler, the command name, its syntax,
    and a description; some also pass aliases, an argument type, or case
    sensitivity.
    """
    register = self.pgspecial.register

    def refresh_all():
        # Shared handler for both refresh commands.
        return self.refresh_completions(persist_priorities="all")

    register(
        self.change_db,
        "\\c",
        "\\c[onnect] database_name",
        "Change to a new database.",
        aliases=("use", "\\connect", "USE"),
    )

    # Two spellings of quit: a case-sensitive backslash form and a
    # case-insensitive keyword form.
    register(
        self.quit,
        "\\q",
        "\\q",
        "Quit pgcli.",
        arg_type=NO_QUERY,
        case_sensitive=True,
        aliases=(":q",),
    )
    register(
        self.quit,
        "quit",
        "quit",
        "Quit pgcli.",
        arg_type=NO_QUERY,
        case_sensitive=False,
        aliases=("exit",),
    )

    register(
        refresh_all,
        "\\#",
        "\\#",
        "Refresh auto-completions.",
        arg_type=NO_QUERY,
    )
    register(
        refresh_all,
        "\\refresh",
        "\\refresh",
        "Refresh auto-completions.",
        arg_type=NO_QUERY,
    )

    register(
        self.execute_from_file,
        "\\i",
        "\\i filename",
        "Execute commands from file.",
    )
    register(
        self.write_to_file,
        "\\o",
        "\\o [filename]",
        "Send all query results to file.",
    )
    register(
        self.info_connection,
        "\\conninfo",
        "\\conninfo",
        "Get connection details",
    )
    register(
        self.change_table_format,
        "\\T",
        "\\T [format]",
        "Change the table format used to output results",
    )
def change_table_format(self, pattern, **_):
    """Handler for ``\\T``: switch the table format used for results.

    Yields a single ``(None, None, None, message)`` tuple. When *pattern*
    is not a supported format the current format is kept and the message
    lists the allowed formats.
    """
    if pattern in TabularOutputFormatter().supported_formats:
        self.table_format = pattern
        yield (None, None, None, "Changed table format to {}".format(pattern))
        return

    pieces = ["Table format {} not recognized. Allowed formats:".format(pattern)]
    pieces.extend(
        "\n\t{}".format(name)
        for name in TabularOutputFormatter().supported_formats
    )
    pieces.append("\nCurrently set to: %s" % self.table_format)
    yield (None, None, None, "".join(pieces))
def info_connection(self, **_):
    """Handler for ``\\conninfo``: describe the current connection.

    Yields a single ``(None, None, None, message)`` tuple. A host starting
    with ``/`` is reported as a socket, anything else as a host.
    """
    # The host may be unset (get_prompt also treats it as optional); fall
    # back to an empty string instead of failing on None.startswith.
    raw_host = self.pgexecute.host or ""
    if raw_host.startswith("/"):
        host = 'socket "%s"' % raw_host
    else:
        host = 'host "%s"' % raw_host

    yield (
        None,
        None,
        None,
        'You are connected to database "%s" as user '
        '"%s" on %s at port "%s".'
        % (self.pgexecute.dbname, self.pgexecute.user, host, self.pgexecute.port),
    )
def change_db(self, pattern, **_):
    """Handler for ``\\c``: reconnect, optionally with new parameters.

    *pattern* holds up to four whitespace-separated values, in the order
    database, user, host, port. A value may be wrapped in double quotes to
    include spaces. Missing values are passed as None. With an empty
    pattern the executor reconnects with its current parameters.

    Yields a single ``(None, None, None, message)`` tuple naming the
    database and user of the connection now in use.
    """
    if pattern:
        # Get all the parameters in pattern, handling double quotes if any.
        infos = re.findall(r'"[^"]*"|[^"\'\s]+', pattern)
        # Remove the surrounding quotes. The result must be kept: the
        # previous code computed the stripped list and discarded it, so
        # quoted names reached the server with their quotes attached.
        infos = [value.strip('"') for value in infos]

        # Pad to exactly four entries so the unpacking below succeeds.
        infos.extend([None] * (4 - len(infos)))
        db, user, host, port = infos
        try:
            self.pgexecute.connect(
                database=db,
                user=user,
                host=host,
                port=port,
                **self.pgexecute.extra_args,
            )
        except OperationalError as e:
            click.secho(str(e), err=True, fg="red")
            click.echo("Previous connection kept")
    else:
        self.pgexecute.connect()

    yield (
        None,
        None,
        None,
        'You are now connected to database "%s" as '
        'user "%s"' % (self.pgexecute.dbname, self.pgexecute.user),
    )
def execute_from_file(self, pattern, **_):
    """Handler for ``\\i``: run the SQL stored in the file named by *pattern*.

    Returns a list with one result tuple when the argument is missing, the
    file cannot be read, or the user declines a destructive query;
    otherwise returns whatever the executor's ``run`` returns.
    """
    if not pattern:
        return [(None, None, None, "\\i: missing required argument", "", False, True)]

    script_path = os.path.expanduser(pattern)
    try:
        with open(script_path, encoding="utf-8") as script:
            query = script.read()
    except IOError as err:
        return [(None, None, None, str(err), "", False, True)]

    # Ask for confirmation only when warnings are on; an explicit False
    # from the confirmation helper means the user declined.
    if self.destructive_warning and confirm_destructive_query(query) is False:
        return [(None, None, None, "Wise choice. Command execution stopped.")]

    return self.pgexecute.run(
        query, self.pgspecial, on_error_resume=(self.on_error == "RESUME")
    )
def write_to_file(self, pattern, **_):
    """Handler for ``\\o``: enable or disable copying results to a file.

    An empty *pattern* turns file output off. Otherwise the path is
    expanded and made absolute, the file is created if it does not exist
    yet, and the path is stored in ``self.output_file``. Returns a list
    with one seven-element result tuple describing the outcome.
    """
    if not pattern:
        self.output_file = None
        return [(None, None, None, "File output disabled", "", True, True)]

    target = os.path.abspath(os.path.expanduser(pattern))
    if not os.path.isfile(target):
        # Create the file up front so problems surface immediately.
        try:
            open(target, "w").close()
        except IOError as err:
            self.output_file = None
            failure = str(err) + "\nFile output disabled"
            return [(None, None, None, failure, "", False, True)]

    self.output_file = target
    return [
        (None, None, None, 'Writing to file "%s"' % self.output_file, "", True, True)
    ]
def initialize_logging(self):
    """Attach a handler to the ``pgcli`` and ``pgspecial`` loggers.

    The log file and level come from the ``[main]`` config section. The
    file value ``"default"`` selects ``<config_location()>log``. A level
    of ``NONE`` installs a ``NullHandler`` and sets the level to CRITICAL.
    """
    main_section = self.config["main"]

    log_file = main_section["log_file"]
    if log_file == "default":
        log_file = config_location() + "log"
    ensure_dir_exists(log_file)

    level_name = main_section["log_level"].upper()

    # NONE disables logging by switching to a no-op handler; the level is
    # additionally mapped to a high value so few calls get past the filter.
    if level_name == "NONE":
        handler = logging.NullHandler()
    else:
        handler = logging.FileHandler(os.path.expanduser(log_file))

    levels = {
        "CRITICAL": logging.CRITICAL,
        "ERROR": logging.ERROR,
        "WARNING": logging.WARNING,
        "INFO": logging.INFO,
        "DEBUG": logging.DEBUG,
        "NONE": logging.CRITICAL,
    }
    numeric_level = levels[level_name]

    handler.setFormatter(
        logging.Formatter(
            "%(asctime)s (%(process)d/%(threadName)s) "
            "%(name)s %(levelname)s - %(message)s"
        )
    )

    pgcli_logger = logging.getLogger("pgcli")
    pgcli_logger.addHandler(handler)
    pgcli_logger.setLevel(numeric_level)

    pgcli_logger.debug("Initializing pgcli logging.")
    pgcli_logger.debug("Log file %r.", log_file)

    special_logger = logging.getLogger("pgspecial")
    special_logger.addHandler(handler)
    special_logger.setLevel(numeric_level)
def initialize_keyring(self):
    """Import the ``keyring`` module into the module-level ``keyring`` name.

    Does nothing unless ``[main] keyring`` is enabled in the config. Any
    exception raised by the import is logged as a warning and the
    module-level name keeps its previous value.
    """
    global keyring

    if not self.config["main"].as_bool("keyring"):
        return

    # Try best to load keyring (issue #1041).
    from importlib import import_module

    try:
        keyring = import_module("keyring")
    except Exception as err:  # ImportError for Python 2, ModuleNotFoundError for Python 3
        self.logger.warning("import keyring failed: %r.", err)
def connect_dsn(self, dsn, **kwargs):
    """Connect using a DSN string, forwarding any extra keyword arguments."""
    connect_args = dict(kwargs, dsn=dsn)
    self.connect(**connect_args)
def connect_service(self, service, user):
    """Connect using the parameters of a named service definition.

    :param service: service name, looked up with ``parse_service_info``
    :param user: user name; when falsy the service's own ``user`` is used

    Prints an error and exits the process with status 1 when the service
    is not found.
    """
    # ``service_file`` rather than ``file`` to avoid shadowing a builtin name.
    service_config, service_file = parse_service_info(service)
    if service_config is None:
        click.secho(
            "service '%s' was not found in %s" % (service, service_file),
            err=True,
            fg="red",
        )
        # sys.exit is always available; the bare exit() helper is injected
        # by the site module and is meant for interactive use.
        sys.exit(1)
    self.connect(
        database=service_config.get("dbname"),
        host=service_config.get("host"),
        user=user or service_config.get("user"),
        port=service_config.get("port"),
        passwd=service_config.get("password"),
    )
def connect_uri(self, uri):
    """Connect using a connection URI.

    The URI is split into keyword arguments by psycopg2's ``parse_dsn``;
    ``dbname`` and ``password`` are renamed to the ``database`` and
    ``passwd`` names that ``connect`` expects.
    """
    renames = {"dbname": "database", "password": "passwd"}
    connect_args = {}
    for name, value in psycopg2.extensions.parse_dsn(uri).items():
        connect_args[renames.get(name, name)] = value
    self.connect(**connect_args)
def connect(
    self, database="", host="", user="", port="", passwd="", dsn="", **kwargs
):
    """Create the database executor, prompting for a password if needed.

    :param database: database name; defaults to the user name
    :param host: server host or socket directory
    :param user: user name; defaults to the current OS user
    :param port: server port
    :param passwd: password; when empty it is looked up in ``PGPASSWORD``
        (unless a prompt is forced) and then in the keyring, if loaded
    :param dsn: DSN string passed through to the executor
    :param kwargs: extra connection arguments; ``application_name``
        defaults to ``"pgcli"``

    On success the new executor is stored in ``self.pgexecute``. Any
    failure is logged, printed in red, and ends the process with status 1.
    """
    # Connect to the database.

    if not user:
        user = getuser()

    if not database:
        database = user

    kwargs.setdefault("application_name", "pgcli")

    # If password prompt is not forced but no password is provided, try
    # getting it from environment variable.
    if not self.force_passwd_prompt and not passwd:
        passwd = os.environ.get("PGPASSWORD", "")

    # Find password from store
    key = "%s@%s" % (user, host)
    keyring_error_message = dedent(
        """\
        {}
        {}
        To remove this message do one of the following:
        - prepare keyring as described at: https://keyring.readthedocs.io/en/stable/
        - uninstall keyring: pip uninstall keyring
        - disable keyring in our configuration: add keyring = False to [main]"""
    )
    if not passwd and keyring:
        try:
            passwd = keyring.get_password("pgcli", key)
        except (RuntimeError, keyring.errors.InitError) as e:
            click.secho(
                keyring_error_message.format(
                    "Load your password from keyring returned:", str(e)
                ),
                err=True,
                fg="red",
            )

    # Prompt for a password immediately if requested via the -W flag. This
    # avoids wasting time trying to connect to the database and catching a
    # no-password exception.
    # If we successfully parsed a password from a URI, there's no need to
    # prompt for it, even with the -W flag
    if self.force_passwd_prompt and not passwd:
        passwd = click.prompt(
            "Password for %s" % user, hide_input=True, show_default=False, type=str
        )

    def should_ask_for_password(exc):
        # Prompt for a password after 1st attempt to connect
        # fails. Don't prompt if the -w flag is supplied
        if self.never_passwd_prompt:
            return False
        error_msg = exc.args[0]
        if "no password supplied" in error_msg:
            return True
        if "password authentication failed" in error_msg:
            return True
        return False

    # Attempt to connect to the database.
    # Note that passwd may be empty on the first attempt. If connection
    # fails because of a missing or incorrect password, but we're allowed to
    # prompt for a password (no -w flag), prompt for a passwd and try again.
    try:
        try:
            pgexecute = PGExecute(database, user, passwd, host, port, dsn, **kwargs)
        except (OperationalError, InterfaceError) as e:
            if should_ask_for_password(e):
                passwd = click.prompt(
                    "Password for %s" % user,
                    hide_input=True,
                    show_default=False,
                    type=str,
                )
                pgexecute = PGExecute(
                    database, user, passwd, host, port, dsn, **kwargs
                )
            else:
                # Bare raise keeps the original traceback intact.
                raise
        # Remember a password that worked so the next start can reuse it.
        if passwd and keyring:
            try:
                keyring.set_password("pgcli", key, passwd)
            except (RuntimeError, keyring.errors.KeyringError) as e:
                click.secho(
                    keyring_error_message.format(
                        "Set password in keyring returned:", str(e)
                    ),
                    err=True,
                    fg="red",
                )

    except Exception as e:  # Connecting to a database could fail.
        self.logger.debug("Database connection failed: %r.", e)
        self.logger.error("traceback: %r", traceback.format_exc())
        click.secho(str(e), err=True, fg="red")
        # sys.exit is always available; the bare exit() helper is injected
        # by the site module and is meant for interactive use.
        sys.exit(1)

    self.pgexecute = pgexecute
def handle_editor_command(self, text):
    r"""Resolve editor commands in *text* by opening an external editor.

    An editor command is a query prefixed or suffixed by '\e', or the
    '\ev' / '\ef' forms followed by an object name. The edited text is
    offered back at the prompt, and because the user may append another
    editor command to it, the process repeats until none is left.

    :param text: text entered at the prompt
    :return: the final text, free of editor commands
    :raises RuntimeError: when the external editor reports a problem
    """
    pending = special.editor_command(text)
    while pending:
        if pending == "\\e":
            filename = special.get_filename(text)
            # With no query text of its own, \e edits the previous query.
            query = special.get_editor_query(text) or self.get_last_query()
        else:  # \ev or \ef
            filename = None
            object_name = text.split()[1]
            if pending == "\\ev":
                query = self.pgexecute.view_definition(object_name)
            elif pending == "\\ef":
                query = self.pgexecute.function_definition(object_name)

        sql, message = special.open_external_editor(filename, sql=query)
        if message:
            # Something went wrong. Raise an exception and bail.
            raise RuntimeError(message)

        # Offer the edited text at the prompt; Ctrl-C clears the buffer
        # and asks again instead of aborting.
        while True:
            try:
                text = self.prompt_app.prompt(default=sql)
                break
            except KeyboardInterrupt:
                sql = ""

        pending = special.editor_command(text)
    return text
def execute_command(self, text):
    """Run *text*, print or store its output, and return a MetaQuery.

    When execution fails, the MetaQuery created before the attempt is
    returned (the entered text, ``successful=False``). PgCliQuitError and
    EOFError propagate to the caller; other exceptions are reported to the
    user here.

    NOTE(review): the nesting below was reconstructed from source whose
    indentation was lost; timing output and completion refresh are assumed
    to belong to the success (``else``) path -- confirm.
    """
    logger = self.logger

    # Returned unchanged if evaluation raises.
    query = MetaQuery(query=text, successful=False)

    try:
        if self.destructive_warning:
            destroy = confirm = confirm_destructive_query(text)
            if destroy is False:
                click.secho("Wise choice!")
                # Reuse the cancellation path below.
                raise KeyboardInterrupt
            elif destroy:
                click.secho("Your call!")
        output, query = self._evaluate_command(text)
    except KeyboardInterrupt:
        # Restart connection to the database
        self.pgexecute.connect()
        logger.debug("cancelled query, sql: %r", text)
        click.secho("cancelled query", err=True, fg="red")
    except NotImplementedError:
        click.secho("Not Yet Implemented.", fg="yellow")
    except OperationalError as e:
        logger.error("sql: %r, error: %r", text, e)
        logger.error("traceback: %r", traceback.format_exc())
        # Try to reconnect and run the command again.
        self._handle_server_closed_connection(text)
    except (PgCliQuitError, EOFError) as e:
        # Handled by the main loop.
        raise
    except Exception as e:
        logger.error("sql: %r, error: %r", text, e)
        logger.error("traceback: %r", traceback.format_exc())
        click.secho(str(e), err=True, fg="red")
    else:
        try:
            # With \o active, results are appended to the file instead of
            # being shown; \o and \? commands themselves are still shown.
            if self.output_file and not text.startswith(("\\o ", "\\? ")):
                try:
                    with open(self.output_file, "a", encoding="utf-8") as f:
                        click.echo(text, file=f)
                        click.echo("\n".join(output), file=f)
                        click.echo("", file=f)  # extra newline
                except IOError as e:
                    click.secho(str(e), err=True, fg="red")
            else:
                if output:
                    self.echo_via_pager("\n".join(output))
        except KeyboardInterrupt:
            # Interrupting the output is not an error.
            pass

        if self.pgspecial.timing_enabled:
            # Only add humanized time display if > 1 second
            if query.total_time > 1:
                print(
                    "Time: %0.03fs (%s), executed in: %0.03fs (%s)"
                    % (
                        query.total_time,
                        pendulum.Duration(seconds=query.total_time).in_words(),
                        query.execution_time,
                        pendulum.Duration(seconds=query.execution_time).in_words(),
                    )
                )
            else:
                print("Time: %0.03fs" % query.total_time)

        # Check if we need to update completions, in order of most
        # to least drastic changes
        if query.db_changed:
            with self._completer_lock:
                self.completer.reset_completions()
            self.refresh_completions(persist_priorities="keywords")
        elif query.meta_changed:
            self.refresh_completions(persist_priorities="all")
        elif query.path_changed:
            logger.debug("Refreshing search path")
            with self._completer_lock:
                self.completer.set_search_path(self.pgexecute.search_path())
            logger.debug("Search path: %r", self.completer.search_path)
    return query
def run_cli(self):
    """Run the interactive loop until the user quits.

    Builds the prompt application, prints the intro unless ``less_chatty``
    is set, then repeatedly reads a command, resolves editor commands,
    executes it (repeatedly, for watch commands) and records the resulting
    MetaQuery in ``self.query_history``. PgCliQuitError or EOFError ends
    the loop.
    """
    logger = self.logger

    history = FileHistory(os.path.expanduser(self.history_file))
    self.refresh_completions(history=history, persist_priorities="none")

    self.prompt_app = self._build_cli(history)

    if not self.less_chatty:
        print("Server: PostgreSQL", self.pgexecute.server_version)
        print("Version:", __version__)
        print("Chat: https://gitter.im/dbcli/pgcli")
        print("Home: http://pgcli.com")

    try:
        while True:
            try:
                text = self.prompt_app.prompt()
            except KeyboardInterrupt:
                # Ctrl-C at the prompt just starts a fresh line.
                continue

            try:
                text = self.handle_editor_command(text)
            except RuntimeError as e:
                logger.error("sql: %r, error: %r", text, e)
                logger.error("traceback: %r", traceback.format_exc())
                click.secho(str(e), err=True, fg="red")
                continue

            # Initialize default metaquery in case execution fails. Without
            # this, interrupting a watch loop before its first result left
            # ``query`` unbound (first command) or stale (previous command).
            query = MetaQuery(query=text, successful=False)

            self.watch_command, timing = special.get_watch_command(text)
            if self.watch_command:
                # Repeat the watched command until the user presses Ctrl-C.
                while self.watch_command:
                    try:
                        query = self.execute_command(self.watch_command)
                        click.echo(
                            "Waiting for {0} seconds before repeating".format(
                                timing
                            )
                        )
                        sleep(timing)
                    except KeyboardInterrupt:
                        self.watch_command = None
            else:
                query = self.execute_command(text)

            # Timestamp shown by the \t prompt escape.
            self.now = dt.datetime.today()

            # Allow PGCompleter to learn user's preferred keywords, etc.
            with self._completer_lock:
                self.completer.extend_query_history(text)

            self.query_history.append(query)

    except (PgCliQuitError, EOFError):
        if not self.less_chatty:
            print("Goodbye!")
def _build_cli(self, history):
    """Create and return the PromptSession used by the interactive loop.

    :param history: history object handed to the prompt session
    """
    key_bindings = pgcli_bindings(self)

    def render_prompt():
        # The DSN-specific format applies only when an alias is active
        # and such a format was supplied.
        use_dsn_format = self.dsn_alias and self.prompt_dsn_format is not None
        prompt_format = (
            self.prompt_dsn_format if use_dsn_format else self.prompt_format
        )

        prompt = self.get_prompt(prompt_format)

        # An over-long default prompt is shortened to the database name.
        too_long = len(prompt) > self.max_len_prompt
        if prompt_format == self.default_prompt and too_long:
            prompt = self.get_prompt("\\d> ")

        # Turn a literal "\x1b" typed in the format into a real ESC byte.
        prompt = prompt.replace("\\x1b", "\x1b")
        return ANSI(prompt)

    def render_continuation(width, line_number, is_soft_wrap):
        filler = self.multiline_continuation_char * (width - 1) + " "
        return [("class:continuation", filler)]

    get_toolbar_tokens = create_toolbar_tokens_func(self)

    complete_style = (
        CompleteStyle.MULTI_COLUMN
        if self.wider_completion_menu
        else CompleteStyle.COLUMN
    )

    with self._completer_lock:
        session = PromptSession(
            lexer=PygmentsLexer(PostgresLexer),
            reserve_space_for_menu=self.min_num_menu_lines,
            message=render_prompt,
            prompt_continuation=render_continuation,
            bottom_toolbar=get_toolbar_tokens if self.show_bottom_toolbar else None,
            complete_style=complete_style,
            input_processors=[
                # Highlight matching brackets while editing.
                ConditionalProcessor(
                    processor=HighlightMatchingBracketProcessor(chars="[](){}"),
                    filter=HasFocus(DEFAULT_BUFFER) & ~IsDone(),
                ),
                # Tab characters are drawn with these two characters
                # instead of "^I".
                TabsProcessor(char1=" ", char2=" "),
            ],
            auto_suggest=AutoSuggestFromHistory(),
            tempfile_suffix=".sql",
            # N.b. pgcli's multi-line mode controls submit-on-Enter (which
            # overrides the default behaviour of prompt_toolkit) and is
            # distinct from prompt_toolkit's multiline mode here, which
            # controls layout/display of the prompt/buffer
            multiline=True,
            history=history,
            # The lambda is evaluated on each use, so a completer swapped
            # in later is picked up automatically.
            completer=ThreadedCompleter(DynamicCompleter(lambda: self.completer)),
            complete_while_typing=True,
            style=style_factory(self.syntax_style, self.cli_style),
            include_default_pygments_style=False,
            key_bindings=key_bindings,
            enable_open_in_editor=True,
            enable_system_prompt=True,
            enable_suspend=True,
            editing_mode=EditingMode.VI if self.vi_mode else EditingMode.EMACS,
            search_ignore_case=True,
        )

    return session
def _should_limit_output(self, sql, cur):
    """Return a truthy value when the output of *sql* should be truncated.

    That is the case for a select statement without its own limit, when a
    non-zero row limit is configured and the cursor reports more rows than
    that limit.
    """
    if not is_select(sql):
        return False
    if self._has_limit(sql):
        return False
    if self.row_limit == 0:
        # Zero disables the limit entirely.
        return False
    if not cur:
        # Hand back the falsy cursor itself, as the chained ``and`` did.
        return cur
    return cur.rowcount > self.row_limit
def _has_limit(self, sql):
if not sql:
return False
return "limit " in sql.lower()
def _limit_output(self, cur):
    """Truncate *cur* to the configured row limit.

    Returns ``(rows, status)`` where *rows* yields at most ``row_limit``
    rows from *cur* and *status* is a ``SELECT <n>`` status line. A notice
    about the truncation is printed in red.
    """
    shown = min(self.row_limit, cur.rowcount)
    truncated = itertools.islice(cur, shown)
    status_line = "SELECT " + str(shown)
    click.secho("The result was limited to %s rows" % shown, fg="red")
    return truncated, status_line
def _evaluate_command(self, text):
    """Used to run a command entered by the user during CLI operation
    (Puts the E in REPL)

    Every result produced for *text* is formatted and collected, and flags
    describing what the successful statements did are accumulated.

    returns (results, MetaQuery)
    """
    logger = self.logger
    logger.debug("sql: %r", text)

    all_success = True
    meta_changed = False  # CREATE, ALTER, DROP, etc
    mutated = False  # INSERT, DELETE, etc
    db_changed = False
    path_changed = False
    output = []
    total = 0
    execution = 0

    # Run the query.
    start = time()
    on_error_resume = self.on_error == "RESUME"
    res = self.pgexecute.run(
        text, self.pgspecial, exception_formatter, on_error_resume
    )

    # Stays None when no result is produced; otherwise holds the value from
    # the last result.
    is_special = None

    for title, cur, headers, status, sql, success, is_special in res:
        logger.debug("headers: %r", headers)
        logger.debug("rows: %r", cur)
        logger.debug("status: %r", status)

        if self._should_limit_output(sql, cur):
            cur, status = self._limit_output(cur)

        # The terminal width is only needed when auto-expansion is on.
        if self.pgspecial.auto_expand or self.auto_expand:
            max_width = self.prompt_app.output.get_size().columns
        else:
            max_width = None

        expanded = self.pgspecial.expanded_output or self.expanded_output
        settings = OutputSettings(
            table_format=self.table_format,
            dcmlfmt=self.decimal_format,
            floatfmt=self.float_format,
            missingval=self.null_string,
            expanded=expanded,
            max_width=max_width,
            case_function=(
                self.completer.case
                if self.settings["case_column_headers"]
                else lambda x: x
            ),
            style_output=self.style_output,
        )
        # Elapsed time up to the point just before formatting starts.
        execution = time() - start
        formatted = format_output(title, cur, headers, status, settings)

        output.extend(formatted)
        # Elapsed time including formatting.
        total = time() - start

        # Keep track of whether any of the queries are mutating or changing
        # the database
        if success:
            mutated = mutated or is_mutating(status)
            db_changed = db_changed or has_change_db_cmd(sql)
            meta_changed = meta_changed or has_meta_cmd(sql)
            path_changed = path_changed or has_change_path_cmd(sql)
        else:
            all_success = False

    meta_query = MetaQuery(
        text,
        all_success,
        total,
        execution,
        meta_changed,
        db_changed,
        path_changed,
        mutated,
        is_special,
    )

    return output, meta_query
def _handle_server_closed_connection(self, text):
    """Reconnect after an operational error and run *text* again.

    Used during CLI execution. If an OperationalError is raised while
    reconnecting or re-running, a failure notice and the error text are
    printed instead.
    """
    try:
        click.secho("Reconnecting...", fg="green")
        self.pgexecute.connect()
        click.secho("Reconnected!", fg="green")
        self.execute_command(text)
    except OperationalError as err:
        click.secho("Reconnect Failed", fg="red")
        click.secho(str(err), err=True, fg="red")
def refresh_completions(self, history=None, persist_priorities="all"):
    """Ask the completion refresher to build a new completer.

    :param history: A prompt_toolkit.history.FileHistory object. Used to
                    load keyword and identifier preferences
    :param persist_priorities: 'all', 'keywords' or 'none'; forwarded to
                    the callback that installs the new completer
    :return: a list with one result tuple announcing the refresh
    """
    self.completion_refresher.refresh(
        self.pgexecute,
        self.pgspecial,
        functools.partial(
            self._on_completions_refreshed,
            persist_priorities=persist_priorities,
        ),
        history=history,
        settings=self.settings,
    )
    notice = "Auto-completion refresh started in the background."
    return [(None, None, None, notice)]
def _on_completions_refreshed(self, new_completer, persist_priorities):
self._swap_completer_objects(new_completer, persist_priorities)
if self.prompt_app:
# After refreshing, redraw the CLI to clear the statusbar
# "Refreshing completions..." indicator
self.prompt_app.app.invalidate()
def _swap_completer_objects(self, new_completer, persist_priorities):
"""Swap the completer object with the newly created completer.
persist_priorities is a string specifying how the old completer's
learned prioritizer should be transferred to the new completer.
'none' - The new prioritizer is left in a new/clean state
'all' - The new prioritizer is updated to exactly reflect
the old one
'keywords' - The new prioritizer is updated with old keyword
priorities, but not any other.
"""
with self._completer_lock:
old_completer = self.completer
self.completer = new_completer
if persist_priorities == "all":
# Just swap over the entire prioritizer
new_completer.prioritizer = old_completer.prioritizer
elif persist_priorities == "keywords":
# Swap over the entire prioritizer, but clear name priorities,
# leaving learned keyword priorities alone
new_completer.prioritizer = old_completer.prioritizer
new_completer.prioritizer.clear_names()
elif persist_priorities == "none":
# Leave the new prioritizer as is
pass
self.completer = new_completer
def get_completions(self, text, cursor_positition):
    """Return the active completer's completions for *text* at the cursor.

    The completer is queried while holding the completer lock.
    """
    document = Document(text=text, cursor_position=cursor_positition)
    with self._completer_lock:
        return self.completer.get_completions(document, None)
def get_prompt(self, string):
    """Expand the prompt escapes in *string* and return the result.

    Supported escapes: ``\\dsn_alias``, ``\\t`` (timestamp), ``\\u`` (user),
    ``\\H`` (host), ``\\h`` (short host), ``\\d`` (database), ``\\p``
    (port, "5432" when unset), ``\\i`` (backend pid), ``\\#`` ("#" for a
    superuser, ">" otherwise) and ``\\n`` (newline). Missing values are
    rendered as "(none)".
    """
    # should be before replacing \\d
    string = string.replace("\\dsn_alias", self.dsn_alias or "")
    string = string.replace("\\t", self.now.strftime("%x %X"))
    string = string.replace("\\u", self.pgexecute.user or "(none)")
    string = string.replace("\\H", self.pgexecute.host or "(none)")
    string = string.replace("\\h", self.pgexecute.short_host or "(none)")
    string = string.replace("\\d", self.pgexecute.dbname or "(none)")
    string = string.replace(
        "\\p",
        str(self.pgexecute.port) if self.pgexecute.port is not None else "5432",
    )
    # Test the pid before converting: str(None) is the truthy string
    # "None", so ``str(pid) or "(none)"`` could never fall back.
    pid = self.pgexecute.pid
    string = string.replace("\\i", str(pid) if pid is not None else "(none)")
    string = string.replace("\\#", "#" if (self.pgexecute.superuser) else ">")
    string = string.replace("\\n", "\n")
    return string
def get_last_query(self):
    """Return the text of the most recent history entry, or None if empty."""
    if not self.query_history:
        return None
    latest = self.query_history[-1]
    return latest[0]
def is_too_wide(self, line):
    """Tell whether *line* is wider than the terminal.

    Escape sequences matched by COLOR_CODE_REGEX are removed before the
    length is measured. Without a prompt application the answer is False.
    """
    if not self.prompt_app:
        return False
    visible = COLOR_CODE_REGEX.sub("", line)
    terminal_columns = self.prompt_app.output.get_size().columns
    return len(visible) > terminal_columns
def is_too_tall(self, lines):
    """Tell whether *lines* has too many entries to fit the terminal.

    Four rows are held back from the terminal height. Without a prompt
    application the answer is False.
    """
    if not self.prompt_app:
        return False
    usable_rows = self.prompt_app.output.get_size().rows - 4
    return len(lines) >= usable_rows
def echo_via_pager(self, text, color=None):
    """Print *text*, going through the pager when the settings call for it.

    The pager is bypassed when paging is off or a watch command is running.
    It is always used for csv output with a pspg pager. In long-output
    mode it is used only when the text is too tall or too wide for the
    terminal. In every remaining case the pager is used.
    """
    pager_mode = self.pgspecial.pager_config

    if pager_mode == PAGER_OFF or self.watch_command:
        use_pager = False
    elif "pspg" in os.environ.get("PAGER", "") and self.table_format == "csv":
        use_pager = True
    elif pager_mode == PAGER_LONG_OUTPUT:
        lines = text.split("\n")
        # The last 4 lines are reserved for the pgcli menu and padding
        use_pager = self.is_too_tall(lines) or any(
            self.is_too_wide(line) for line in lines
        )
    else:
        use_pager = True

    if use_pager:
        click.echo_via_pager(text, color=color)
    else:
        click.echo(text, color=color)
# Command-line entry point for pgcli.
#
# NOTE: this function is documented with comments rather than a docstring on
# purpose: click uses a command's docstring as its ``--help`` text, so adding
# one would change user-visible output.
#
# Flow: handle --version, make sure the config directory exists, migrate a
# legacy ~/.pgclirc, handle --list-dsn, build the PGCli object, work out the
# database/user/service to use, connect with the matching PGCli method,
# handle --list, and finally start the interactive session.
@click.command()
# Default host is '' so psycopg2 can default to either localhost or unix socket
@click.option(
    "-h",
    "--host",
    default="",
    envvar="PGHOST",
    help="Host address of the postgres database.",
)
@click.option(
    "-p",
    "--port",
    default=5432,
    help="Port number at which the " "postgres instance is listening.",
    envvar="PGPORT",
    type=click.INT,
)
@click.option(
    "-U",
    "--username",
    "username_opt",
    help="Username to connect to the postgres database.",
)
@click.option(
    "-u", "--user", "username_opt", help="Username to connect to the postgres database."
)
@click.option(
    "-W",
    "--password",
    "prompt_passwd",
    is_flag=True,
    default=False,
    help="Force password prompt.",
)
@click.option(
    "-w",
    "--no-password",
    "never_prompt",
    is_flag=True,
    default=False,
    help="Never prompt for password.",
)
@click.option(
    "--single-connection",
    "single_connection",
    is_flag=True,
    default=False,
    help="Do not use a separate connection for completions.",
)
@click.option("-v", "--version", is_flag=True, help="Version of pgcli.")
@click.option("-d", "--dbname", "dbname_opt", help="database name to connect to.")
@click.option(
    "--pgclirc",
    default=config_location() + "config",
    envvar="PGCLIRC",
    help="Location of pgclirc file.",
    type=click.Path(dir_okay=False),
)
@click.option(
    "-D",
    "--dsn",
    default="",
    envvar="DSN",
    help="Use DSN configured into the [alias_dsn] section of pgclirc file.",
)
@click.option(
    "--list-dsn",
    "list_dsn",
    is_flag=True,
    help="list of DSN configured into the [alias_dsn] section of pgclirc file.",
)
@click.option(
    "--row-limit",
    default=None,
    envvar="PGROWLIMIT",
    type=click.INT,
    help="Set threshold for row limit prompt. Use 0 to disable prompt.",
)
@click.option(
    "--less-chatty",
    "less_chatty",
    is_flag=True,
    default=False,
    help="Skip intro on startup and goodbye on exit.",
)
@click.option("--prompt", help='Prompt format (Default: "\\u@\\h:\\d> ").')
@click.option(
    "--prompt-dsn",
    help='Prompt format for connections using DSN aliases (Default: "\\u@\\h:\\d> ").',
)
@click.option(
    "-l",
    "--list",
    "list_databases",
    is_flag=True,
    help="list " "available databases, then exit.",
)
@click.option(
    "--auto-vertical-output",
    is_flag=True,
    help="Automatically switch to vertical output mode if the result is wider than the terminal width.",
)
@click.option(
    "--warn/--no-warn", default=None, help="Warn before running a destructive query."
)
@click.option("--histfile", default=None, help="Specify history file location.")
@click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1)
@click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1)
def cli(
    dbname,
    username_opt,
    host,
    port,
    prompt_passwd,
    never_prompt,
    single_connection,
    dbname_opt,
    username,
    version,
    pgclirc,
    dsn,
    row_limit,
    less_chatty,
    prompt,
    prompt_dsn,
    list_databases,
    auto_vertical_output,
    list_dsn,
    warn,
    histfile,
):
    # --version: print and stop before touching the config or the database.
    if version:
        print("Version:", __version__)
        sys.exit(0)

    # Make sure the directory that holds the config file exists.
    config_dir = os.path.dirname(config_location())
    if not os.path.exists(config_dir):
        os.makedirs(config_dir)

    # Migrate the config file from old location.
    config_full_path = config_location() + "config"
    if os.path.exists(os.path.expanduser("~/.pgclirc")):
        if not os.path.exists(config_full_path):
            shutil.move(os.path.expanduser("~/.pgclirc"), config_full_path)
            print("Config file (~/.pgclirc) moved to new location", config_full_path)
        else:
            # Both files exist: do not overwrite, just tell the user.
            print("Config file is now located at", config_full_path)
            print(
                "Please move the existing config file ~/.pgclirc to",
                config_full_path,
            )

    # --list-dsn: print every alias from the [alias_dsn] section and exit.
    if list_dsn:
        try:
            cfg = load_config(pgclirc, config_full_path)
            for alias in cfg["alias_dsn"]:
                click.secho(alias + " : " + cfg["alias_dsn"][alias])
            # SystemExit is not an Exception subclass, so the handler below
            # does not intercept this successful exit.
            sys.exit(0)
        except Exception:
            click.secho(
                "Invalid DSNs found in the config file. "
                'Please check the "[alias_dsn]" section in pgclirc.',
                err=True,
                fg="red",
            )
            # sys.exit rather than the bare exit(): the latter is injected by
            # the site module and is not guaranteed to be available.
            sys.exit(1)

    pgcli = PGCli(
        prompt_passwd,
        never_prompt,
        pgclirc_file=pgclirc,
        row_limit=row_limit,
        single_connection=single_connection,
        less_chatty=less_chatty,
        prompt=prompt,
        prompt_dsn=prompt_dsn,
        auto_vertical_output=auto_vertical_output,
        warn=warn,
        histfile=histfile,
        alias_dsn=dsn,
    )

    # Choose which ever one has a valid value.
    if dbname_opt and dbname:
        # work as psql: when database is given as option and argument use the argument as user
        username = dbname
    database = dbname_opt or dbname or ""
    user = username_opt or username

    # A service can be named inline ("service=<name>") or via PGSERVICE.
    service = None
    if database.startswith("service="):
        service = database[8:]  # strip the "service=" prefix
    elif os.getenv("PGSERVICE") is not None:
        service = os.getenv("PGSERVICE")

    # because option --list or -l are not supposed to have a db name
    if list_databases:
        database = "postgres"

    # Pick the connection method. Order matters: an explicit DSN alias wins,
    # then a URI, then a key=value string, then a service, then plain
    # host/port parameters.
    if dsn != "":
        try:
            cfg = load_config(pgclirc, config_full_path)
            dsn_config = cfg["alias_dsn"][dsn]
        except KeyError:
            click.secho(
                f"Could not find a DSN with alias {dsn}. "
                'Please check the "[alias_dsn]" section in pgclirc.',
                err=True,
                fg="red",
            )
            sys.exit(1)
        except Exception:
            click.secho(
                "Invalid DSNs found in the config file. "
                'Please check the "[alias_dsn]" section in pgclirc.',
                err=True,
                fg="red",
            )
            sys.exit(1)
        pgcli.connect_uri(dsn_config)
    elif "://" in database:
        pgcli.connect_uri(database)
    elif "=" in database and service is None:
        pgcli.connect_dsn(database, user=user)
    elif service is not None:
        pgcli.connect_service(service, user)
    else:
        pgcli.connect(database, host, user, port)

    # --list: print the database listing through the pager and exit.
    if list_databases:
        cur, headers, status = pgcli.pgexecute.full_databases()

        title = "List of databases"
        settings = OutputSettings(table_format="ascii", missingval="<null>")
        formatted = format_output(title, cur, headers, status, settings)
        pgcli.echo_via_pager("\n".join(formatted))

        sys.exit(0)

    pgcli.logger.debug(
        "Launch Params: \n" "\tdatabase: %r" "\tuser: %r" "\thost: %r" "\tport: %r",
        database,
        user,
        host,
        port,
    )

    # setproctitle is an optional dependency; it is None when not installed.
    if setproctitle:
        obfuscate_process_password()

    pgcli.run_cli()
def obfuscate_process_password():
    """Replace a password visible in the process title with ``xxxx``.

    The current title is read through ``setproctitle``. If it contains
    ``://`` the URI-style substitution is applied; otherwise, if it contains
    ``=``, the ``password=<value>`` substitution is applied. The (possibly
    unchanged) title is then written back.
    """
    uri_password = re.compile(r":(.*):(.*)@")
    keyword_password = re.compile(r"password=(.+?)((\s[a-zA-Z]+=)|$)")

    title = setproctitle.getproctitle()
    if "://" in title:
        # Keep the part before the second colon, mask what follows it.
        title = uri_password.sub(r":\1:xxxx@", title)
    elif "=" in title:
        # Group 2 is the next " key=" (or end of string) and is preserved.
        title = keyword_password.sub(r"password=xxxx\2", title)

    setproctitle.setproctitle(title)
def has_meta_cmd(query):
    """Tell whether running ``query`` means completions should be refreshed.

    Returns True when the first whitespace-separated word of ``query`` is
    (case-insensitively) one of alter, create, drop, commit or rollback.
    Anything that cannot be inspected (empty string, ``None``, ...) yields
    False.
    """
    refresh_keywords = ("alter", "create", "drop", "commit", "rollback")
    try:
        leading_word = query.split()[0].lower()
    except Exception:
        # Empty or non-string input: there is no first word to look at.
        return False
    return leading_word in refresh_keywords
def has_change_db_cmd(query):
    """Tell whether ``query`` switches database, e.g. 'use' or '\\c'.

    Returns True when the first whitespace-separated word of ``query`` is
    (case-insensitively) ``use``, ``\\c`` or ``\\connect``. Input that cannot
    be inspected (empty string, ``None``, ...) yields False.
    """
    switch_commands = ("use", "\\c", "\\connect")
    try:
        leading_word = query.split()[0].lower()
    except Exception:
        # Empty or non-string input: there is no first word to look at.
        return False
    return leading_word in switch_commands
def has_change_path_cmd(sql):
    """Tell whether the search_path should be refreshed after ``sql``.

    True when the lower-cased statement contains the text
    'set search_path' anywhere.
    """
    lowered = sql.lower()
    return "set search_path" in lowered
def is_mutating(status):
    """Determines if the statement is mutating based on the status.

    :param status: status text of an executed statement; may be ``None``
        or empty.
    :return: True when the first word of ``status`` is (case-insensitively)
        insert, update or delete; False otherwise, including for empty or
        whitespace-only input.
    """
    if not status:
        return False

    words = status.split(None, 1)
    if not words:
        # Whitespace-only status: split() yields no words, so indexing [0]
        # would raise IndexError.
        return False

    return words[0].lower() in {"insert", "update", "delete"}
def is_select(status):
    """Returns true if the first word in status is 'select'.

    :param status: status text of an executed statement; may be ``None``
        or empty.
    :return: True when the first word of ``status`` is 'select'
        (case-insensitive); False otherwise, including for empty or
        whitespace-only input.
    """
    if not status:
        return False

    words = status.split(None, 1)
    if not words:
        # Whitespace-only status: split() yields no words, so indexing [0]
        # would raise IndexError.
        return False

    return words[0].lower() == "select"
def exception_formatter(e):
    """Return the text of ``e`` styled by click with a red foreground."""
    message = str(e)
    return click.style(message, fg="red")
def format_output(title, cur, headers, status, settings):
    """Build the lines to display for one query result.

    :param title: optional title, emitted first when truthy.
    :param cur: rows to format (any iterable of rows, or a cursor-like
        object); skipped entirely when falsy.
    :param headers: column headers; each is passed through
        ``settings.case_function`` before formatting.
    :param status: optional status line, appended last unless the table
        format is "csv".
    :param settings: object providing ``expanded``, ``table_format``,
        ``max_width``, ``case_function``, ``missingval``, ``dcmlfmt``,
        ``floatfmt`` and ``style_output``.
    :return: an iterable of output lines (a list when there are no rows and
        no status, otherwise an ``itertools.chain``).
    """
    output = []
    expanded = settings.expanded or settings.table_format == "vertical"
    table_format = "vertical" if settings.expanded else settings.table_format
    max_width = settings.max_width
    case_function = settings.case_function
    formatter = TabularOutputFormatter(format_name=table_format)

    def format_array(val):
        """Render a (possibly nested) list as ``{a,b,...}`` text."""
        if val is None:
            return settings.missingval
        if not isinstance(val, list):
            return val
        return "{" + ",".join(str(format_array(e)) for e in val) + "}"

    def format_arrays(data, headers, **_):
        """Preprocessor: replace list cells with their ``{...}`` rendering."""
        data = list(data)
        for row in data:
            # Slice assignment mutates the row in place.
            row[:] = [
                format_array(val) if isinstance(val, list) else val for val in row
            ]

        return data, headers

    output_kwargs = {
        "sep_title": "RECORD {n}",
        "sep_character": "-",
        "sep_length": (1, 25),
        "missing_value": settings.missingval,
        "integer_format": settings.dcmlfmt,
        "float_format": settings.floatfmt,
        "preprocessors": (format_numbers, format_arrays),
        "disable_numparse": True,
        "preserve_whitespace": True,
        "style": settings.style_output,
    }
    # Without a float format the preprocessor tuple is replaced wholesale.
    if not settings.floatfmt:
        output_kwargs["preprocessors"] = (align_decimals,)

    if table_format == "csv":
        # The default CSV dialect is "excel" which is not handling newline values correctly
        # Nevertheless, we want to keep on using "excel" on Windows since it uses '\r\n'
        # as the line terminator
        # https://github.com/dbcli/pgcli/issues/1102
        dialect = "excel" if platform.system() == "Windows" else "unix"
        output_kwargs["dialect"] = dialect

    if title:  # Only print the title if it's not None.
        output.append(title)

    if cur:
        headers = [case_function(x) for x in headers]
        if max_width is not None:
            # Materialise the rows: they may be formatted twice below.
            cur = list(cur)

        # Classify each column as float / int / str from its type code.
        # Exactly one entry is appended per column so the list stays aligned
        # with the columns.
        # NOTE(review): column_types is not passed to the formatter calls
        # below, so it currently has no effect on the output - confirm
        # whether it should be wired in or removed.
        column_types = None
        if hasattr(cur, "description"):
            pg_types = psycopg2.extensions
            column_types = []
            for d in cur.description:
                type_code = d[1]
                if (
                    type_code in pg_types.DECIMAL.values
                    or type_code in pg_types.FLOAT.values
                ):
                    column_types.append(float)
                elif (
                    type_code in pg_types.INTEGER.values
                    or type_code in pg_types.LONGINTEGER.values
                ):
                    column_types.append(int)
                else:
                    column_types.append(str)

        formatted = formatter.format_output(cur, headers, **output_kwargs)
        if isinstance(formatted, str):
            formatted = iter(formatted.splitlines())

        # Peek at the first line to measure the rendered width, then put it
        # back in front of the remaining lines.
        first_line = next(formatted)
        formatted = itertools.chain([first_line], formatted)

        # Too wide for max_width: redo the formatting in vertical layout.
        if not expanded and max_width and len(first_line) > max_width and headers:
            formatted = formatter.format_output(
                cur, headers, format_name="vertical", column_types=None, **output_kwargs
            )
            if isinstance(formatted, str):
                formatted = iter(formatted.splitlines())

        output = itertools.chain(output, formatted)

    # Only print the status if it's not None and we are not producing CSV
    if status and table_format != "csv":
        output = itertools.chain(output, [status])

    return output
def parse_service_info(service):
    """Look up a connection service entry in the service file.

    :param service: service name; when falsy, the ``PGSERVICE`` environment
        variable is used instead.
    :return: a ``(service_conf, service_file)`` tuple. ``service_conf`` is
        the section for the service read from ``service_file``, or ``None``
        when no service name is available or the file has no such section.

    The service file is ``PGSERVICEFILE`` when set. Otherwise it is derived
    from ``PGSYSCONFDIR`` when that is set, and falls back to
    ``~/.pg_service.conf``.
    """
    service = service or os.getenv("PGSERVICE")
    service_file = os.getenv("PGSERVICEFILE")
    if not service_file:
        sysconf_dir = os.getenv("PGSYSCONFDIR")
        if sysconf_dir and platform.system() == "Windows":
            service_file = sysconf_dir + "\\pg_service.conf"
        elif sysconf_dir:
            service_file = os.path.join(sysconf_dir, ".pg_service.conf")
        else:
            # try ~/.pg_service.conf (if that exists)
            # Also reached on Windows when PGSYSCONFDIR is unset, where
            # concatenating None with the file name would raise TypeError.
            service_file = expanduser("~/.pg_service.conf")
    if not service:
        # nothing to do
        return None, service_file
    service_file_config = ConfigObj(service_file)
    if service not in service_file_config:
        return None, service_file
    service_conf = service_file_config.get(service)
    return service_conf, service_file
# Run the click command only when this module is executed as a script.
if __name__ == "__main__":
    cli()