1
0
Fork 0

Preserve comments when writing to config file.

This commit is contained in:
Irina Truong 2021-02-26 12:20:45 -08:00
parent c9fd72449e
commit d67272214c
2 changed files with 19 additions and 6 deletions

View File

@ -18,11 +18,15 @@ def config_location():
def load_config(usr_cfg, def_cfg=None):
cfg = ConfigObj()
cfg.merge(ConfigObj(def_cfg, interpolation=False))
cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8"))
# avoid config merges when possible. For writing, we need an umerged config instance.
# see https://github.com/dbcli/pgcli/issues/1240 and https://github.com/DiffSK/configobj/issues/171
if def_cfg:
cfg = ConfigObj()
cfg.merge(ConfigObj(def_cfg, interpolation=False))
cfg.merge(ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8"))
else:
cfg = ConfigObj(expanduser(usr_cfg), interpolation=False, encoding="utf-8")
cfg.filename = expanduser(usr_cfg)
return cfg
@ -46,12 +50,16 @@ def upgrade_config(config, def_config):
cfg.write()
def get_config_filename(pgclirc_file=None):
return pgclirc_file or "%sconfig" % config_location()
def get_config(pgclirc_file=None):
from pgcli import __file__ as package_root
package_root = os.path.dirname(package_root)
pgclirc_file = pgclirc_file or "%sconfig" % config_location()
pgclirc_file = get_config_filename(pgclirc_file)
default_config = os.path.join(package_root, "pgclirc")
write_default_config(default_config, pgclirc_file)

View File

@ -63,6 +63,7 @@ from .config import (
config_location,
ensure_dir_exists,
get_config,
get_config_filename
)
from .key_bindings import pgcli_bindings
from .packages.prompt_utils import confirm_destructive_query
@ -176,7 +177,11 @@ class PGCli:
# Load config.
c = self.config = get_config(pgclirc_file)
NamedQueries.instance = NamedQueries.from_config(self.config)
# at this point, config should be written to pgclirc_file if it did not exist. Read it.
self.config_writer = load_config(get_config_filename(pgclirc_file))
# make sure to use self.config_writer, not self.config
NamedQueries.instance = NamedQueries.from_config(self.config_writer)
self.logger = logging.getLogger(__name__)
self.initialize_logging()