1
0
Fork 0

Another attempt to fix pgbouncer error (1093.) (#1097)

* Another attempt to fix pgbouncer error (1093.)

* Fixes for various pgbouncer problems.

* different approach with custom cursor.

* Fix rebase.

* Missed this.

* Fix completion refresher test.

* Black.

* Unused import.

* Changelog.

* Fix race condition in test.

* Switch from is_pgbouncer to more generic is_virtual_database, and duck-type it. Add very dumb unit test for virtual cursor.

* Remove debugger code.
This commit is contained in:
Irina Truong 2021-05-21 15:32:34 -07:00 committed by GitHub
parent d8532df22e
commit e0a4c18c4a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 163 additions and 56 deletions

View File

@ -19,6 +19,7 @@ Bug fixes:
* Fix pager not being used when output format is set to csv. (#1238)
* Add function literals random, generate_series, generate_subscripts
* Fix ANSI escape codes in first line make the cli choose expanded output incorrectly
* Fix pgcli crashing with virtual `pgbouncer` database. (#1093)
3.1.0
=====

View File

@ -3,7 +3,6 @@ import os
from collections import OrderedDict
from .pgcompleter import PGCompleter
from .pgexecute import PGExecute
class CompletionRefresher:
@ -27,6 +26,10 @@ class CompletionRefresher:
has completed the refresh. The newly created completion
object will be passed in as an argument to each callback.
"""
if executor.is_virtual_database():
# do nothing
return [(None, None, None, "Auto-completion refresh can't be started.")]
if self.is_refreshing():
self._restart_refresh.set()
return [(None, None, None, "Auto-completion refresh restarted.")]

View File

@ -988,16 +988,13 @@ class PGCli:
callback = functools.partial(
self._on_completions_refreshed, persist_priorities=persist_priorities
)
self.completion_refresher.refresh(
return self.completion_refresher.refresh(
self.pgexecute,
self.pgspecial,
callback,
history=history,
settings=self.settings,
)
return [
(None, None, None, "Auto-completion refresh started in the background.")
]
def _on_completions_refreshed(self, new_completer, persist_priorities):
self._swap_completer_objects(new_completer, persist_priorities)

View File

@ -1,13 +1,15 @@
import traceback
import logging
import select
import traceback
import pgspecial as special
import psycopg2
import psycopg2.extras
import psycopg2.errorcodes
import psycopg2.extensions as ext
import psycopg2.extras
import sqlparse
import pgspecial as special
import select
from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE, make_dsn
from .packages.parseutils.meta import FunctionMetadata, ForeignKey
_logger = logging.getLogger(__name__)
@ -27,6 +29,7 @@ ext.register_type(ext.new_type((17,), "BYTEA_TEXT", psycopg2.STRING))
# TODO: Get default timeout from pgclirc?
_WAIT_SELECT_TIMEOUT = 1
_wait_callback_is_set = False
def _wait_select(conn):
@ -34,31 +37,41 @@ def _wait_select(conn):
copy-pasted from psycopg2.extras.wait_select
the default implementation doesn't define a timeout in the select calls
"""
while 1:
try:
state = conn.poll()
if state == POLL_OK:
break
elif state == POLL_READ:
select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT)
elif state == POLL_WRITE:
select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT)
else:
raise conn.OperationalError("bad state from poll: %s" % state)
except KeyboardInterrupt:
conn.cancel()
# the loop will be broken by a server error
continue
except OSError as e:
errno = e.args[0]
if errno != 4:
raise
try:
while 1:
try:
state = conn.poll()
if state == POLL_OK:
break
elif state == POLL_READ:
select.select([conn.fileno()], [], [], _WAIT_SELECT_TIMEOUT)
elif state == POLL_WRITE:
select.select([], [conn.fileno()], [], _WAIT_SELECT_TIMEOUT)
else:
raise conn.OperationalError("bad state from poll: %s" % state)
except KeyboardInterrupt:
conn.cancel()
# the loop will be broken by a server error
continue
except OSError as e:
errno = e.args[0]
if errno != 4:
raise
except psycopg2.OperationalError:
pass
# When running a query, make pressing CTRL+C raise a KeyboardInterrupt
# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
# See also https://github.com/psycopg/psycopg2/issues/468
ext.set_wait_callback(_wait_select)
def _set_wait_callback(is_virtual_database):
global _wait_callback_is_set
if _wait_callback_is_set:
return
_wait_callback_is_set = True
if is_virtual_database:
return
# When running a query, make pressing CTRL+C raise a KeyboardInterrupt
# See http://initd.org/psycopg/articles/2014/07/20/cancelling-postgresql-statements-python/
# See also https://github.com/psycopg/psycopg2/issues/468
ext.set_wait_callback(_wait_select)
def register_date_typecasters(connection):
@ -72,6 +85,8 @@ def register_date_typecasters(connection):
cursor = connection.cursor()
cursor.execute("SELECT NULL::date")
if cursor.description is None:
return
date_oid = cursor.description[0][1]
cursor.execute("SELECT NULL::timestamp")
timestamp_oid = cursor.description[0][1]
@ -103,7 +118,7 @@ def register_json_typecasters(conn, loads_fn):
try:
psycopg2.extras.register_json(conn, loads=loads_fn, name=name)
available.add(name)
except psycopg2.ProgrammingError:
except (psycopg2.ProgrammingError, psycopg2.errors.ProtocolViolation):
pass
return available
@ -127,6 +142,38 @@ def register_hstore_typecaster(conn):
pass
class ProtocolSafeCursor(psycopg2.extensions.cursor):
def __init__(self, *args, **kwargs):
self.protocol_error = False
self.protocol_message = ""
super().__init__(*args, **kwargs)
def __iter__(self):
if self.protocol_error:
raise StopIteration
return super().__iter__()
def fetchall(self):
if self.protocol_error:
return [(self.protocol_message,)]
return super().fetchall()
def fetchone(self):
if self.protocol_error:
return (self.protocol_message,)
return super().fetchone()
def execute(self, sql, args=None):
try:
psycopg2.extensions.cursor.execute(self, sql, args)
self.protocol_error = False
self.protocol_message = ""
except psycopg2.errors.ProtocolViolation as ex:
self.protocol_error = True
self.protocol_message = ex.pgerror
_logger.debug("%s: %s" % (ex.__class__.__name__, ex))
class PGExecute:
# The boolean argument to the current_schemas function indicates whether
@ -190,8 +237,6 @@ class PGExecute:
SELECT pg_catalog.pg_get_functiondef(f.f_oid)
FROM f"""
version_query = "SELECT version();"
def __init__(
self,
database=None,
@ -203,6 +248,7 @@ class PGExecute:
**kwargs,
):
self._conn_params = {}
self._is_virtual_database = None
self.conn = None
self.dbname = None
self.user = None
@ -214,6 +260,11 @@ class PGExecute:
self.connect(database, user, password, host, port, dsn, **kwargs)
self.reset_expanded = None
def is_virtual_database(self):
if self._is_virtual_database is None:
self._is_virtual_database = self.is_protocol_error()
return self._is_virtual_database
def copy(self):
"""Returns a clone of the current executor."""
return self.__class__(**self._conn_params)
@ -250,9 +301,9 @@ class PGExecute:
)
conn_params.update({k: v for k, v in new_params.items() if v})
conn_params["cursor_factory"] = ProtocolSafeCursor
conn = psycopg2.connect(**conn_params)
cursor = conn.cursor()
conn.set_client_encoding("utf8")
self._conn_params = conn_params
@ -293,16 +344,22 @@ class PGExecute:
self.extra_args = kwargs
if not self.host:
self.host = self.get_socket_directory()
self.host = (
"pgbouncer"
if self.is_virtual_database()
else self.get_socket_directory()
)
pid = self._select_one(cursor, "select pg_backend_pid()")[0]
self.pid = pid
self.pid = conn.get_backend_pid()
self.superuser = conn.get_parameter_status("is_superuser") in ("on", "1")
self.server_version = conn.get_parameter_status("server_version")
self.server_version = conn.get_parameter_status("server_version") or ""
register_date_typecasters(conn)
register_json_typecasters(self.conn, self._json_typecaster)
register_hstore_typecaster(self.conn)
_set_wait_callback(self.is_virtual_database())
if not self.is_virtual_database():
register_date_typecasters(conn)
register_json_typecasters(self.conn, self._json_typecaster)
register_hstore_typecaster(self.conn)
@property
def short_host(self):
@ -395,7 +452,13 @@ class PGExecute:
# See https://github.com/dbcli/pgcli/issues/1014.
cur = None
try:
for result in pgspecial.execute(cur, sql):
response = pgspecial.execute(cur, sql)
if cur and cur.protocol_error:
yield None, None, None, cur.protocol_message, statement, False, False
# this would close connection. We should reconnect.
self.connect()
continue
for result in response:
# e.g. execute_from_file already appends these
if len(result) < 7:
yield result + (sql, True, True)
@ -453,6 +516,9 @@ class PGExecute:
if cur.description:
headers = [x[0] for x in cur.description]
return title, cur, headers, cur.statusmessage
elif cur.protocol_error:
_logger.debug("Protocol error, unsupported command.")
return title, None, None, cur.protocol_message
else:
_logger.debug("No rows in result.")
return title, None, None, cur.statusmessage
@ -617,6 +683,13 @@ class PGExecute:
headers = [x[0] for x in cur.description]
return cur.fetchall(), headers, cur.statusmessage
def is_protocol_error(self):
query = "SELECT 1"
with self.conn.cursor() as cur:
_logger.debug("Simple Query. sql: %r", query)
cur.execute(query)
return bool(cur.protocol_error)
def get_socket_directory(self):
with self.conn.cursor() as cur:
_logger.debug(

View File

@ -22,5 +22,10 @@ def step_see_refresh_started(context):
Wait to see refresh output.
"""
wrappers.expect_pager(
context, "Auto-completion refresh started in the background.\r\n", timeout=2
context,
[
"Auto-completion refresh started in the background.\r\n",
"Auto-completion refresh restarted.\r\n",
],
timeout=2,
)

View File

@ -39,9 +39,15 @@ def expect_exact(context, expected, timeout):
def expect_pager(context, expected, timeout):
formatted = expected if isinstance(expected, list) else [expected]
formatted = [
f"{context.conf['pager_boundary']}\r\n{t}{context.conf['pager_boundary']}\r\n"
for t in formatted
]
expect_exact(
context,
"{0}\r\n{1}{0}\r\n".format(context.conf["pager_boundary"], expected),
formatted,
timeout=timeout,
)

View File

@ -37,7 +37,7 @@ def test_refresh_called_once(refresher):
:return:
"""
callbacks = Mock()
pgexecute = Mock()
pgexecute = Mock(**{"is_virtual_database.return_value": False})
special = Mock()
with patch.object(refresher, "_bg_refresh") as bg_refresh:
@ -57,7 +57,7 @@ def test_refresh_called_twice(refresher):
"""
callbacks = Mock()
pgexecute = Mock()
pgexecute = Mock(**{"is_virtual_database.return_value": False})
special = Mock()
def dummy_bg_refresh(*args):
@ -84,14 +84,12 @@ def test_refresh_with_callbacks(refresher):
:param refresher:
"""
callbacks = [Mock()]
pgexecute_class = Mock()
pgexecute = Mock()
pgexecute = Mock(**{"is_virtual_database.return_value": False})
pgexecute.extra_args = {}
special = Mock()
with patch("pgcli.completion_refresher.PGExecute", pgexecute_class):
# Set refreshers to 0: we're not testing refresh logic here
refresher.refreshers = {}
refresher.refresh(pgexecute, special, callbacks)
time.sleep(1) # Wait for the thread to work.
assert callbacks[0].call_count == 1
# Set refreshers to 0: we're not testing refresh logic here
refresher.refreshers = {}
refresher.refresh(pgexecute, special, callbacks)
time.sleep(1) # Wait for the thread to work.
assert callbacks[0].call_count == 1

View File

@ -520,6 +520,21 @@ class BrokenConnection:
raise psycopg2.InterfaceError("I'm broken!")
class VirtualCursor:
"""Mock a cursor to virtual database like pgbouncer."""
def __init__(self):
self.protocol_error = False
self.protocol_message = ""
self.description = None
self.status = None
self.statusmessage = "Error"
def execute(self, *args, **kwargs):
self.protocol_error = True
self.protocol_message = "Command not supported"
@dbtest
def test_exit_without_active_connection(executor):
quit_handler = MagicMock()
@ -542,3 +557,12 @@ def test_exit_without_active_connection(executor):
# an exception should be raised when running a query without active connection
with pytest.raises(psycopg2.InterfaceError):
run(executor, "select 1", pgspecial=pgspecial)
@dbtest
def test_virtual_database(executor):
virtual_connection = MagicMock()
virtual_connection.cursor.return_value = VirtualCursor()
with patch.object(executor, "conn", virtual_connection):
result = run(executor, "select 1")
assert "Command not supported" in result