diff --git a/.gitignore b/.gitignore index 170585df..b993cb9a 100644 --- a/.gitignore +++ b/.gitignore @@ -69,3 +69,6 @@ target/ .vscode/ venv/ + +.ropeproject/ + diff --git a/AUTHORS b/AUTHORS index 2bf68e87..c573369b 100644 --- a/AUTHORS +++ b/AUTHORS @@ -122,6 +122,7 @@ Contributors: * Daniele Varrazzo * Daniel Kukula (dkuku) * Kian-Meng Ang (kianmeng) + * Liu Zhao (astroshot) Creator: -------- diff --git a/changelog.rst b/changelog.rst index 43a4c043..7e466bcb 100644 --- a/changelog.rst +++ b/changelog.rst @@ -1,6 +1,11 @@ Upcoming: ========= +Features: +--------- + +* New formatter is added to export query result to sql format (such as sql-insert, sql-update) like mycli. + Bug fixes: ---------- diff --git a/pgcli/main.py b/pgcli/main.py index 3d42dca8..0fa264f4 100644 --- a/pgcli/main.py +++ b/pgcli/main.py @@ -62,6 +62,7 @@ from .config import ( get_config_filename, ) from .key_bindings import pgcli_bindings +from .packages.formatter.sqlformatter import register_new_formatter from .packages.prompt_utils import confirm_destructive_query from .__init__ import __version__ @@ -283,6 +284,10 @@ class PGCli: self.ssh_tunnel_url = ssh_tunnel_url self.ssh_tunnel = None + # formatter setup + self.formatter = TabularOutputFormatter(format_name=c["main"]["table_format"]) + register_new_formatter(self.formatter) + def quit(self): raise PgCliQuitError @@ -940,6 +945,8 @@ class PGCli: logger = self.logger logger.debug("sql: %r", text) + # set query to formatter in order to parse table name + self.formatter.query = text all_success = True meta_changed = False # CREATE, ALTER, DROP, etc mutated = False # INSERT, DELETE, etc diff --git a/pgcli/packages/formatter/__init__.py b/pgcli/packages/formatter/__init__.py new file mode 100644 index 00000000..9bad5790 --- /dev/null +++ b/pgcli/packages/formatter/__init__.py @@ -0,0 +1 @@ +# coding=utf-8 diff --git a/pgcli/packages/formatter/sqlformatter.py b/pgcli/packages/formatter/sqlformatter.py new file mode 100644 index 00000000..5bf25fec --- /dev/null +++ b/pgcli/packages/formatter/sqlformatter.py @@ -0,0 +1,71 @@ +# coding=utf-8 + +from pgcli.packages.parseutils.tables import extract_tables + + +supported_formats = ( + "sql-insert", + "sql-update", + "sql-update-1", + "sql-update-2", +) + +preprocessors = () + + +def escape_for_sql_statement(value): + if isinstance(value, bytes): + return f"X'{value.hex()}'" + else: + return "'{}'".format(value) + + +def adapter(data, headers, table_format=None, **kwargs): + tables = extract_tables(formatter.query) + if len(tables) > 0: + table = tables[0] + if table[0]: + table_name = "{}.{}".format(*table[:2]) + else: + table_name = table[1] + else: + table_name = '"DUAL"' + if table_format == "sql-insert": + h = '", "'.join(headers) + yield 'INSERT INTO "{}" ("{}") VALUES'.format(table_name, h) + prefix = " " + for d in data: + values = ", ".join(escape_for_sql_statement(v) for i, v in enumerate(d)) + yield "{}({})".format(prefix, values) + if prefix == " ": + prefix = ", " + yield ";" + if table_format.startswith("sql-update"): + s = table_format.split("-") + keys = 1 + if len(s) > 2: + keys = int(s[-1]) + for d in data: + yield 'UPDATE "{}" SET'.format(table_name) + prefix = " " + for i, v in enumerate(d[keys:], keys): + yield '{}"{}" = {}'.format( + prefix, headers[i], escape_for_sql_statement(v) + ) + if prefix == " ": + prefix = ", " + f = '"{}" = {}' + where = ( + f.format(headers[i], escape_for_sql_statement(d[i])) + for i in range(keys) + ) + yield "WHERE {};".format(" AND ".join(where)) + + +def register_new_formatter(TabularOutputFormatter): + global formatter + formatter = TabularOutputFormatter + for sql_format in supported_formats: + TabularOutputFormatter.register_new_formatter( + sql_format, adapter, preprocessors, {"table_format": sql_format} + ) diff --git a/pgcli/pgclirc b/pgcli/pgclirc index 6654ce92..dcff63d2 100644 --- a/pgcli/pgclirc +++ b/pgcli/pgclirc @@ -95,7 +95,9 @@ show_bottom_toolbar = True # Table format. Possible values: psql, plain, simple, grid, fancy_grid, pipe, # ascii, double, github, orgtbl, rst, mediawiki, html, latex, latex_booktabs, -# textile, moinmoin, jira, vertical, tsv, csv. +# textile, moinmoin, jira, vertical, tsv, csv, sql-insert, sql-update, +# sql-update-1, sql-update-2 (formatter with sql-* prefix can format query +# output to executable insertion or updating sql). # Recommended: psql, fancy_grid and grid. table_format = psql diff --git a/tests/formatter/__init__.py b/tests/formatter/__init__.py new file mode 100644 index 00000000..9bad5790 --- /dev/null +++ b/tests/formatter/__init__.py @@ -0,0 +1 @@ +# coding=utf-8 diff --git a/tests/formatter/test_sqlformatter.py b/tests/formatter/test_sqlformatter.py new file mode 100644 index 00000000..b8cd9c2b --- /dev/null +++ b/tests/formatter/test_sqlformatter.py @@ -0,0 +1,111 @@ +# coding=utf-8 + +from pgcli.packages.formatter.sqlformatter import escape_for_sql_statement + +from cli_helpers.tabular_output import TabularOutputFormatter +from pgcli.packages.formatter.sqlformatter import adapter, register_new_formatter + + +def test_escape_for_sql_statement_bytes(): + bts = b"837124ab3e8dc0f" + escaped_bytes = escape_for_sql_statement(bts) + assert escaped_bytes == "X'383337313234616233653864633066'" + + +def test_escape_for_sql_statement_number(): + num = 2981 + escaped_bytes = escape_for_sql_statement(num) + assert escaped_bytes == "'2981'" + + +def test_escape_for_sql_statement_str(): + example_str = "example str" + escaped_bytes = escape_for_sql_statement(example_str) + assert escaped_bytes == "'example str'" + + +def test_output_sql_insert(): + global formatter + formatter = TabularOutputFormatter + register_new_formatter(formatter) + data = [ + [ + 1, + "Jackson", + "jackson_test@gmail.com", + "132454789", + "", + "2022-09-09 19:44:32.712343+08", + "2022-09-09 19:44:32.712343+08", + ] + ] + header = ["id", "name", "email", "phone", "description", "created_at", "updated_at"] + table_format = "sql-insert" + kwargs = { + "column_types": [int, str, str, str, str, str, str], + "sep_title": "RECORD {n}", + "sep_character": "-", + "sep_length": (1, 25), + "missing_value": "", + "integer_format": "", + "float_format": "", + "disable_numparse": True, + "preserve_whitespace": True, + "max_field_width": 500, + } + formatter.query = 'SELECT * FROM "user";' + output = adapter(data, header, table_format=table_format, **kwargs) + output_list = [l for l in output] + expected = [ + 'INSERT INTO "user" ("id", "name", "email", "phone", "description", "created_at", "updated_at") VALUES', + " ('1', 'Jackson', 'jackson_test@gmail.com', '132454789', '', " + + "'2022-09-09 19:44:32.712343+08', '2022-09-09 19:44:32.712343+08')", + ";", + ] + assert expected == output_list + + +def test_output_sql_update(): + global formatter + formatter = TabularOutputFormatter + register_new_formatter(formatter) + data = [ + [ + 1, + "Jackson", + "jackson_test@gmail.com", + "132454789", + "", + "2022-09-09 19:44:32.712343+08", + "2022-09-09 19:44:32.712343+08", + ] + ] + header = ["id", "name", "email", "phone", "description", "created_at", "updated_at"] + table_format = "sql-update" + kwargs = { + "column_types": [int, str, str, str, str, str, str], + "sep_title": "RECORD {n}", + "sep_character": "-", + "sep_length": (1, 25), + "missing_value": "", + "integer_format": "", + "float_format": "", + "disable_numparse": True, + "preserve_whitespace": True, + "max_field_width": 500, + } + formatter.query = 'SELECT * FROM "user";' + output = adapter(data, header, table_format=table_format, **kwargs) + output_list = [l for l in output] + print(output_list) + expected = [ + 'UPDATE "user" SET', + " \"name\" = 'Jackson'", + ", \"email\" = 'jackson_test@gmail.com'", + ", \"phone\" = '132454789'", + ", \"description\" = ''", + ", \"created_at\" = '2022-09-09 19:44:32.712343+08'", + ", \"updated_at\" = '2022-09-09 19:44:32.712343+08'", + "WHERE \"id\" = '1';", + ] + assert expected == output_list