From 651ab1d72f1856c4e16ade4b9155e5444193e1c4 Mon Sep 17 00:00:00 2001 From: Tim Van Baak Date: Tue, 15 Jun 2021 23:02:51 -0700 Subject: [PATCH] Refactor db to lazy-load at the top level --- amanuensis/cli/__init__.py | 26 ++++++++++++++++++++++++-- amanuensis/cli/admin.py | 16 +++++----------- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/amanuensis/cli/__init__.py b/amanuensis/cli/__init__.py index 7f50868..df98ead 100644 --- a/amanuensis/cli/__init__.py +++ b/amanuensis/cli/__init__.py @@ -1,10 +1,13 @@ -from argparse import ArgumentParser +from argparse import ArgumentParser, Namespace import logging import logging.config +import os +from typing import Callable import amanuensis.cli.admin import amanuensis.cli.lexicon import amanuensis.cli.user +from amanuensis.db import DbContext LOGGING_CONFIG = { @@ -76,6 +79,18 @@ def init_logger(args): logging.config.dictConfig(LOGGING_CONFIG) +def get_db_factory(parser: ArgumentParser, args: Namespace) -> Callable[[], DbContext]: + """Factory function for lazy-loading the database in subcommands.""" + + def get_db() -> DbContext: + """Lazy loader for the database connection.""" + if not os.path.exists(args.db_path): + parser.error(f"No database found at {args.db_path}") + return DbContext(path=args.db_path, echo=args.verbose) + + return get_db + + def main(): """CLI entry point""" # Set up the top-level parser @@ -83,8 +98,12 @@ def main(): parser.set_defaults( parser=parser, func=lambda args: parser.print_usage(), + get_db=None, ) parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output") + parser.add_argument( + "--db", dest="db_path", default="db.sqlite", help="Path to Amanuensis database" + ) # Add commands from cli submodules subparsers = parser.add_subparsers(metavar="COMMAND") @@ -92,7 +111,10 @@ def main(): add_subcommand(subparsers, amanuensis.cli.lexicon) add_subcommand(subparsers, amanuensis.cli.user) - # Parse args and execute the desired action + # Parse args and perform top-level arg processing args = parser.parse_args() init_logger(args) + args.get_db = get_db_factory(parser, args) + + # Execute the desired action args.func(args) diff --git a/amanuensis/cli/admin.py b/amanuensis/cli/admin.py index dfc92d0..7eb1d99 100644 --- a/amanuensis/cli/admin.py +++ b/amanuensis/cli/admin.py @@ -14,23 +14,17 @@ COMMAND_HELP = "Interact with Amanuensis." LOG = logging.getLogger(__name__) -@add_argument( - "path", metavar="DB_PATH", help="Path to where the database should be created" -) -@add_argument("--force", "-f", action="store_true", help="Overwrite existing database") -@add_argument("--verbose", "-v", action="store_true", help="Enable db echo") +@add_argument("--drop", "-d", action="store_true", help="Overwrite existing database") def command_init_db(args) -> int: """ Initialize the Amanuensis database. """ - # Check if force is required - if not args.force and os.path.exists(args.path): - args.parser.error(f"{args.path} already exists and --force was not specified") + if args.drop: + open(args.db_path, mode="w").close() # Initialize the database - LOG.info(f"Creating database at {args.path}") - db = DbContext(path=args.path, echo=args.verbose) - db.create_all() + LOG.info(f"Creating database at {args.db_path}") + args.get_db().create_all() LOG.info("Done") return 0