mirror of https://github.com/dbcli/pgcli
Add pg_service.conf handling (#1155)
* add parse_service_info * added tests * changelog + AUTHORS * py35
This commit is contained in:
parent
005fd2fcee
commit
f3ac559844
|
@ -7,6 +7,7 @@ Features:
|
|||
* Add `__main__.py` file to execute pgcli as a package directly (#1123).
|
||||
* Add support for ANSI escape sequences for coloring the prompt (#1122).
|
||||
* Add support for partitioned tables (relkind "p").
|
||||
* Add support for `pg_service.conf` files
|
||||
|
||||
Bug fixes:
|
||||
|
||||
|
|
|
@ -1,5 +1,8 @@
|
|||
import platform
|
||||
import warnings
|
||||
from os.path import expanduser
|
||||
|
||||
from configobj import ConfigObj
|
||||
from pgspecial.namedqueries import NamedQueries
|
||||
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module="psycopg2")
|
||||
|
@ -470,6 +473,21 @@ class PGCli(object):
|
|||
def connect_dsn(self, dsn, **kwargs):
|
||||
self.connect(dsn=dsn, **kwargs)
|
||||
|
||||
def connect_service(self, service, user):
|
||||
service_config, file = parse_service_info(service)
|
||||
if service_config is None:
|
||||
click.secho(
|
||||
"service '%s' was not found in %s" % (service, file), err=True, fg="red"
|
||||
)
|
||||
exit(1)
|
||||
self.connect(
|
||||
database=service_config.get("dbname"),
|
||||
host=service_config.get("host"),
|
||||
user=user or service_config.get("user"),
|
||||
port=service_config.get("port"),
|
||||
passwd=service_config.get("password"),
|
||||
)
|
||||
|
||||
def connect_uri(self, uri):
|
||||
kwargs = psycopg2.extensions.parse_dsn(uri)
|
||||
remap = {"dbname": "database", "password": "passwd"}
|
||||
|
@ -1248,7 +1266,11 @@ def cli(
|
|||
username = dbname
|
||||
database = dbname_opt or dbname or ""
|
||||
user = username_opt or username
|
||||
|
||||
service = None
|
||||
if database.startswith("service="):
|
||||
service = database[8:]
|
||||
elif os.getenv("PGSERVICE") is not None:
|
||||
service = os.getenv("PGSERVICE")
|
||||
# because option --list or -l are not supposed to have a db name
|
||||
if list_databases:
|
||||
database = "postgres"
|
||||
|
@ -1269,10 +1291,10 @@ def cli(
|
|||
pgcli.dsn_alias = dsn
|
||||
elif "://" in database:
|
||||
pgcli.connect_uri(database)
|
||||
elif "=" in database:
|
||||
elif "=" in database and service is None:
|
||||
pgcli.connect_dsn(database, user=user)
|
||||
elif os.environ.get("PGSERVICE", None):
|
||||
pgcli.connect_dsn("service={0}".format(os.environ["PGSERVICE"]))
|
||||
elif service is not None:
|
||||
pgcli.connect_service(service, user)
|
||||
else:
|
||||
pgcli.connect(database, host, user, port)
|
||||
|
||||
|
@ -1446,5 +1468,26 @@ def format_output(title, cur, headers, status, settings):
|
|||
return output
|
||||
|
||||
|
||||
def parse_service_info(service):
|
||||
service = service or os.getenv("PGSERVICE")
|
||||
service_file = os.getenv("PGSERVICEFILE")
|
||||
if not service_file:
|
||||
# try ~/.pg_service.conf (if that exists)
|
||||
if platform.system() == "Windows":
|
||||
service_file = os.getenv("PGSYSCONFDIR") + "\\pg_service.conf"
|
||||
elif os.getenv("PGSYSCONFDIR"):
|
||||
service_file = os.path.join(os.getenv("PGSYSCONFDIR"), ".pg_service.conf")
|
||||
else:
|
||||
service_file = expanduser("~/.pg_service.conf")
|
||||
if not service:
|
||||
# nothing to do
|
||||
return None, service_file
|
||||
service_file_config = ConfigObj(service_file)
|
||||
if service not in service_file_config:
|
||||
return None, service_file
|
||||
service_conf = service_file_config.get(service)
|
||||
return service_conf, service_file
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
|
|
|
@ -282,6 +282,54 @@ def test_quoted_db_uri(tmpdir):
|
|||
)
|
||||
|
||||
|
||||
def test_pg_service_file(tmpdir):
|
||||
|
||||
with mock.patch.object(PGCli, "connect") as mock_connect:
|
||||
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
|
||||
with open(tmpdir.join(".pg_service.conf").strpath, "w") as service_conf:
|
||||
service_conf.write(
|
||||
"""[myservice]
|
||||
host=a_host
|
||||
user=a_user
|
||||
port=5433
|
||||
password=much_secure
|
||||
dbname=a_dbname
|
||||
|
||||
[my_other_service]
|
||||
host=b_host
|
||||
user=b_user
|
||||
port=5435
|
||||
dbname=b_dbname
|
||||
"""
|
||||
)
|
||||
os.environ["PGSERVICEFILE"] = tmpdir.join(".pg_service.conf").strpath
|
||||
cli.connect_service("myservice", "another_user")
|
||||
mock_connect.assert_called_with(
|
||||
database="a_dbname",
|
||||
host="a_host",
|
||||
user="another_user",
|
||||
port="5433",
|
||||
passwd="much_secure",
|
||||
)
|
||||
|
||||
with mock.patch.object(PGExecute, "__init__") as mock_pgexecute:
|
||||
mock_pgexecute.return_value = None
|
||||
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
|
||||
os.environ["PGPASSWORD"] = "very_secure"
|
||||
cli.connect_service("my_other_service", None)
|
||||
mock_pgexecute.assert_called_with(
|
||||
"b_dbname",
|
||||
"b_user",
|
||||
"very_secure",
|
||||
"b_host",
|
||||
"5435",
|
||||
"",
|
||||
application_name="pgcli",
|
||||
)
|
||||
del os.environ["PGPASSWORD"]
|
||||
del os.environ["PGSERVICEFILE"]
|
||||
|
||||
|
||||
def test_ssl_db_uri(tmpdir):
|
||||
with mock.patch.object(PGCli, "connect") as mock_connect:
|
||||
cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile")))
|
||||
|
|
Loading…
Reference in New Issue