mirror of https://github.com/dbcli/pgcli
Another attempt to fix pgbouncer error (1093.) (#1097)
* Another attempt to fix pgbouncer error (1093.) * Fixes for various pgbouncer problems. * different approach with custom cursor. * Fix rebase. * Missed this. * Fix completion refresher test. * Black. * Unused import. * Changelog. * Fix race condition in test. * Switch from is_pgbouncer to more generic is_virtual_database, and duck-type it. Add very dumb unit test for virtual cursor. * Remove debugger code.
This commit is contained in:
parent
d8532df22e
commit
e0a4c18c4a
|
@ -19,6 +19,7 @@ Bug fixes:
|
|||
* Fix pager not being used when output format is set to csv. (#1238)
|
||||
* Add function literals random, generate_series, generate_subscripts
|
||||
* Fix ANSI escape codes in first line make the cli choose expanded output incorrectly
|
||||
* Fix pgcli crashing with virtual `pgbouncer` database. (#1093)
|
||||
|
||||
3.1.0
|
||||
=====
|
||||
|
|
|
@ -3,7 +3,6 @@ import os
|
|||
from collections import OrderedDict
|
||||
|
||||
from .pgcompleter import PGCompleter
|
||||
from .pgexecute import PGExecute
|
||||
|
||||
|
||||
class CompletionRefresher:
|
||||
|
@ -27,6 +26,10 @@ class CompletionRefresher:
|
|||
has completed the refresh. The newly created completion
|
||||
object will be passed in as an argument to each callback.
|
||||
"""
|
||||
if executor.is_virtual_database():
|
||||
# do nothing
|
||||
return [(None, None, None, "Auto-completion refresh can't be started.")]
|
||||
|
||||
if self.is_refreshing():
|
||||
self._restart_refresh.set()
|
||||
return [(None, None, None, "Auto-completion refresh restarted.")]
|
||||
|
|
|
@ -988,16 +988,13 @@ class PGCli:
|
|||
callback = functools.partial(
|
||||
self._on_completions_refreshed, persist_priorities=persist_priorities
|
||||
)
|
||||
self.completion_refresher.refresh(
|
||||
return self.completion_refresher.refresh(
|
||||
self.pgexecute,
|
||||
self.pgspecial,
|
||||
callback,
|
||||
history=history,
|
||||
settings=self.settings,
|
||||
)
|
||||
return [
|
||||
(None, None, None, "Auto-completion refresh started in the background.")
|
||||
]
|
||||
|
||||
def _on_completions_refreshed(self, new_completer, persist_priorities):
|
||||
self._swap_completer_objects(new_completer, persist_priorities)
|
||||
|
|
|
@ -1,13 +1,15 @@
|
|||
import traceback
|
||||
import logging
|
||||
import select
|
||||
import traceback
|
||||
|
||||
import pgspecial as special
|
||||
import psycopg2
|
||||
import psycopg2.extras
|
||||
import psycopg2.errorcodes
|
||||
import psycopg2.extensions as ext
|
||||
import psycopg2.extras
|
||||
import sqlparse
|
||||
import pgspecial as special
|
||||
import select
|
||||
from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE, make_dsn
|
||||
|
||||
from .packages.parseutils.meta import FunctionMetadata, ForeignKey
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
@ -27,6 +29,7 @@ ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING))
|
|||
|
||||
# TODO: Get default timeout from pgclirc?
|
||||
_WAIT_SELECT_TIMEOUT = 1
|
||||
_wait_callback_is_set = False
|
||||
|
||||
|
||||
def _wait_select(conn):
|
||||
|
@ -34,31 +37,41 @@ def _wait_select(conn):
|
|||
copy-pasted from psycopg2.extras.wait_select
|
||||
the default implementation doesn't define a timeout in the select calls
|
||||
"""
|
||||
while 1:
|
||||
try:
|
||||
state = conn.poll()
|
||||
if state == POLL_OK:
|
||||
break
|
||||
elif state == POLL_READ:
|
||||
select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT)
|
||||
elif state == POLL_WRITE:
|
||||
select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT)
|
||||
else:
|
||||
raise conn.OperationalError("bad state from poll: %s" % state)
|
||||
except KeyboardInterrupt:
|
||||
conn.cancel()
|
||||
# the loop will be broken by a server error
|
||||
continue
|
||||
except OSError as e:
|
||||
errno = e.args[0]
|
||||
if errno != 4:
|
||||
raise
|
||||
try:
|
||||
while 1:
|
||||
try:
|
||||
state = conn.poll()
|
||||
if state == POLL_OK:
|
||||
break
|
||||
elif state == POLL_READ:
|
||||
select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT)
|
||||
elif state == POLL_WRITE:
|
||||
select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT)
|
||||
else:
|
||||
raise conn.OperationalError("bad state from poll: %s" % state)
|
||||
except KeyboardInterrupt:
|
||||
conn.cancel()
|
||||
# the loop will be broken by a server error
|
||||
continue
|
||||
except OSError as e:
|
||||
errno = e.args[0]
|
||||
if errno != 4:
|
||||
raise
|
||||
except psycopg2.OperationalError:
|
||||
pass
|
||||
|
||||
|
||||
# When running a query, make pressing CTRL+C raise a KeyboardInterrupt
|
||||
# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
|
||||
# See also https://github.com/psycopg/psycopg2/issues/468
|
||||
ext.set_wait_callback(_wait_select)
|
||||
def _set_wait_callback(is_virtual_database):
|
||||
global _wait_callback_is_set
|
||||
if _wait_callback_is_set:
|
||||
return
|
||||
_wait_callback_is_set = True
|
||||
if is_virtual_database:
|
||||
return
|
||||
# When running a query, make pressing CTRL+C raise a KeyboardInterrupt
|
||||
# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
|
||||
# See also https://github.com/psycopg/psycopg2/issues/468
|
||||
ext.set_wait_callback(_wait_select)
|
||||
|
||||
|
||||
def register_date_typecasters(connection):
|
||||
|
@ -72,6 +85,8 @@ def register_date_typecasters(connection):
|
|||
|
||||
cursor = connection.cursor()
|
||||
cursor.execute("SELECT NULL::date")
|
||||
if cursor.description is None:
|
||||
return
|
||||
date_oid = cursor.description[0][1]
|
||||
cursor.execute("SELECT NULL::timestamp")
|
||||
timestamp_oid = cursor.description[0][1]
|
||||
|
@ -103,7 +118,7 @@ def register_json_typecasters(conn, loads_fn):
|
|||
try:
|
||||
psycopg2.extras.register_json(conn, loads=loads_fn, name=name)
|
||||
available.add(name)
|
||||
except psycopg2.ProgrammingError:
|
||||
except (psycopg2.ProgrammingError, psycopg2.errors.ProtocolViolation):
|
||||
pass
|
||||
|
||||
return available
|
||||
|
@ -127,6 +142,38 @@ def register_hstore_typecaster(conn):
|
|||
pass
|
||||
|
||||
|
||||
class ProtocolSafeCursor(psycopg2.extensions.cursor):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.protocol_error = False
|
||||
self.protocol_message = ""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __iter__(self):
|
||||
if self.protocol_error:
|
||||
raise StopIteration
|
||||
return super().__iter__()
|
||||
|
||||
def fetchall(self):
|
||||
if self.protocol_error:
|
||||
return [(self.protocol_message,)]
|
||||
return super().fetchall()
|
||||
|
||||
def fetchone(self):
|
||||
if self.protocol_error:
|
||||
return (self.protocol_message,)
|
||||
return super().fetchone()
|
||||
|
||||
def execute(self, sql, args=None):
|
||||
try:
|
||||
psycopg2.extensions.cursor.execute(self, sql, args)
|
||||
self.protocol_error = False
|
||||
self.protocol_message = ""
|
||||
except psycopg2.errors.ProtocolViolation as ex:
|
||||
self.protocol_error = True
|
||||
self.protocol_message = ex.pgerror
|
||||
_logger.debug("%s: %s" % (ex.__class__.__name__, ex))
|
||||
|
||||
|
||||
class PGExecute:
|
||||
|
||||
# The boolean argument to the current_schemas function indicates whether
|
||||
|
@ -190,8 +237,6 @@ class PGExecute:
|
|||
SELECT pg_catalog.pg_get_functiondef(f.f_oid)
|
||||
FROM f"""
|
||||
|
||||
version_query = "SELECT version();"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
database=None,
|
||||
|
@ -203,6 +248,7 @@ class PGExecute:
|
|||
**kwargs,
|
||||
):
|
||||
self._conn_params = {}
|
||||
self._is_virtual_database = None
|
||||
self.conn = None
|
||||
self.dbname = None
|
||||
self.user = None
|
||||
|
@ -214,6 +260,11 @@ class PGExecute:
|
|||
self.connect(database, user, password, host, port, dsn, **kwargs)
|
||||
self.reset_expanded = None
|
||||
|
||||
def is_virtual_database(self):
|
||||
if self._is_virtual_database is None:
|
||||
self._is_virtual_database = self.is_protocol_error()
|
||||
return self._is_virtual_database
|
||||
|
||||
def copy(self):
|
||||
"""Returns a clone of the current executor."""
|
||||
return self.__class__(**self._conn_params)
|
||||
|
@ -250,9 +301,9 @@ class PGExecute:
|
|||
)
|
||||
|
||||
conn_params.update({k: v for k, v in new_params.items() if v})
|
||||
conn_params["cursor_factory"] = ProtocolSafeCursor
|
||||
|
||||
conn = psycopg2.connect(**conn_params)
|
||||
cursor = conn.cursor()
|
||||
conn.set_client_encoding("utf8")
|
||||
|
||||
self._conn_params = conn_params
|
||||
|
@ -293,16 +344,22 @@ class PGExecute:
|
|||
self.extra_args = kwargs
|
||||
|
||||
if not self.host:
|
||||
self.host = self.get_socket_directory()
|
||||
self.host = (
|
||||
"pgbouncer"
|
||||
if self.is_virtual_database()
|
||||
else self.get_socket_directory()
|
||||
)
|
||||
|
||||
pid = self._select_one(cursor, "select pg_backend_pid()")[0]
|
||||
self.pid = pid
|
||||
self.pid = conn.get_backend_pid()
|
||||
self.superuser = conn.get_parameter_status("is_superuser") in ("on", "1")
|
||||
self.server_version = conn.get_parameter_status("server_version")
|
||||
self.server_version = conn.get_parameter_status("server_version") or ""
|
||||
|
||||
register_date_typecasters(conn)
|
||||
register_json_typecasters(self.conn, self._json_typecaster)
|
||||
register_hstore_typecaster(self.conn)
|
||||
_set_wait_callback(self.is_virtual_database())
|
||||
|
||||
if not self.is_virtual_database():
|
||||
register_date_typecasters(conn)
|
||||
register_json_typecasters(self.conn, self._json_typecaster)
|
||||
register_hstore_typecaster(self.conn)
|
||||
|
||||
@property
|
||||
def short_host(self):
|
||||
|
@ -395,7 +452,13 @@ class PGExecute:
|
|||
# See https://github.com/dbcli/pgcli/issues/1014.
|
||||
cur = None
|
||||
try:
|
||||
for result in pgspecial.execute(cur, sql):
|
||||
response = pgspecial.execute(cur, sql)
|
||||
if cur and cur.protocol_error:
|
||||
yield None, None, None, cur.protocol_message, statement, False, False
|
||||
# this would close connection. We should reconnect.
|
||||
self.connect()
|
||||
continue
|
||||
for result in response:
|
||||
# e.g. execute_from_file already appends these
|
||||
if len(result) < 7:
|
||||
yield result + (sql, True, True)
|
||||
|
@ -453,6 +516,9 @@ class PGExecute:
|
|||
if cur.description:
|
||||
headers = [x[0] for x in cur.description]
|
||||
return title, cur, headers, cur.statusmessage
|
||||
elif cur.protocol_error:
|
||||
_logger.debug("Protocol error, unsupported command.")
|
||||
return title, None, None, cur.protocol_message
|
||||
else:
|
||||
_logger.debug("No rows in result.")
|
||||
return title, None, None, cur.statusmessage
|
||||
|
@ -617,6 +683,13 @@ class PGExecute:
|
|||
headers = [x[0] for x in cur.description]
|
||||
return cur.fetchall(), headers, cur.statusmessage
|
||||
|
||||
def is_protocol_error(self):
|
||||
query = "SELECT 1"
|
||||
with self.conn.cursor() as cur:
|
||||
_logger.debug("Simple Query. sql: %r", query)
|
||||
cur.execute(query)
|
||||
return bool(cur.protocol_error)
|
||||
|
||||
def get_socket_directory(self):
|
||||
with self.conn.cursor() as cur:
|
||||
_logger.debug(
|
||||
|
|
|
@ -22,5 +22,10 @@ def step_see_refresh_started(context):
|
|||
Wait to see refresh output.
|
||||
"""
|
||||
wrappers.expect_pager(
|
||||
context, "Auto-completion refresh started in the background.\r\n", timeout=2
|
||||
context,
|
||||
[
|
||||
"Auto-completion refresh started in the background.\r\n",
|
||||
"Auto-completion refresh restarted.\r\n",
|
||||
],
|
||||
timeout=2,
|
||||
)
|
||||
|
|
|
@ -39,9 +39,15 @@ def expect_exact(context, expected, timeout):
|
|||
|
||||
|
||||
def expect_pager(context, expected, timeout):
|
||||
formatted = expected if isinstance(expected, list) else [expected]
|
||||
formatted = [
|
||||
f"{context.conf['pager_boundary']}\r\n{t}{context.conf['pager_boundary']}\r\n"
|
||||
for t in formatted
|
||||
]
|
||||
|
||||
expect_exact(
|
||||
context,
|
||||
"{0}\r\n{1}{0}\r\n".format(context.conf["pager_boundary"], expected),
|
||||
formatted,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
||||
|
|
|
@ -37,7 +37,7 @@ def test_refresh_called_once(refresher):
|
|||
:return:
|
||||
"""
|
||||
callbacks = Mock()
|
||||
pgexecute = Mock()
|
||||
pgexecute = Mock(**{"is_virtual_database.return_value": False})
|
||||
special = Mock()
|
||||
|
||||
with patch.object(refresher, "_bg_refresh") as bg_refresh:
|
||||
|
@ -57,7 +57,7 @@ def test_refresh_called_twice(refresher):
|
|||
"""
|
||||
callbacks = Mock()
|
||||
|
||||
pgexecute = Mock()
|
||||
pgexecute = Mock(**{"is_virtual_database.return_value": False})
|
||||
special = Mock()
|
||||
|
||||
def dummy_bg_refresh(*args):
|
||||
|
@ -84,14 +84,12 @@ def test_refresh_with_callbacks(refresher):
|
|||
:param refresher:
|
||||
"""
|
||||
callbacks = [Mock()]
|
||||
pgexecute_class = Mock()
|
||||
pgexecute = Mock()
|
||||
pgexecute = Mock(**{"is_virtual_database.return_value": False})
|
||||
pgexecute.extra_args = {}
|
||||
special = Mock()
|
||||
|
||||
with patch("pgcli.completion_refresher.PGExecute", pgexecute_class):
|
||||
# Set refreshers to 0: we're not testing refresh logic here
|
||||
refresher.refreshers = {}
|
||||
refresher.refresh(pgexecute, special, callbacks)
|
||||
time.sleep(1) # Wait for the thread to work.
|
||||
assert callbacks[0].call_count == 1
|
||||
# Set refreshers to 0: we're not testing refresh logic here
|
||||
refresher.refreshers = {}
|
||||
refresher.refresh(pgexecute, special, callbacks)
|
||||
time.sleep(1) # Wait for the thread to work.
|
||||
assert callbacks[0].call_count == 1
|
||||
|
|
|
@ -520,6 +520,21 @@ class BrokenConnection:
|
|||
raise psycopg2.InterfaceError("I'm broken!")
|
||||
|
||||
|
||||
class VirtualCursor:
|
||||
"""Mock a cursor to virtual database like pgbouncer."""
|
||||
|
||||
def __init__(self):
|
||||
self.protocol_error = False
|
||||
self.protocol_message = ""
|
||||
self.description = None
|
||||
self.status = None
|
||||
self.statusmessage = "Error"
|
||||
|
||||
def execute(self, *args, **kwargs):
|
||||
self.protocol_error = True
|
||||
self.protocol_message = "Command not supported"
|
||||
|
||||
|
||||
@dbtest
|
||||
def test_exit_without_active_connection(executor):
|
||||
quit_handler = MagicMock()
|
||||
|
@ -542,3 +557,12 @@ def test_exit_without_active_connection(executor):
|
|||
# an exception should be raised when running a query without active connection
|
||||
with pytest.raises(psycopg2.InterfaceError):
|
||||
run(executor, "select 1", pgspecial=pgspecial)
|
||||
|
||||
|
||||
@dbtest
|
||||
def test_virtual_database(executor):
|
||||
virtual_connection = MagicMock()
|
||||
virtual_connection.cursor.return_value = VirtualCursor()
|
||||
with patch.object(executor, "conn", virtual_connection):
|
||||
result = run(executor, "select 1")
|
||||
assert "Command not supported" in result
|
||||
|
|
Loading…
Reference in New Issue