Refactor db to lazy-load at the top level
This commit is contained in:
parent
c4f133434d
commit
651ab1d72f
|
@ -1,10 +1,13 @@
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser, Namespace
|
||||||
import logging
|
import logging
|
||||||
import logging.config
|
import logging.config
|
||||||
|
import os
|
||||||
|
from typing import Callable
|
||||||
|
|
||||||
import amanuensis.cli.admin
|
import amanuensis.cli.admin
|
||||||
import amanuensis.cli.lexicon
|
import amanuensis.cli.lexicon
|
||||||
import amanuensis.cli.user
|
import amanuensis.cli.user
|
||||||
|
from amanuensis.db import DbContext
|
||||||
|
|
||||||
|
|
||||||
LOGGING_CONFIG = {
|
LOGGING_CONFIG = {
|
||||||
|
@ -76,6 +79,18 @@ def init_logger(args):
|
||||||
logging.config.dictConfig(LOGGING_CONFIG)
|
logging.config.dictConfig(LOGGING_CONFIG)
|
||||||
|
|
||||||
|
|
||||||
|
def get_db_factory(parser: ArgumentParser, args: Namespace) -> Callable[[], DbContext]:
|
||||||
|
"""Factory function for lazy-loading the database in subcommands."""
|
||||||
|
|
||||||
|
def get_db() -> DbContext:
|
||||||
|
"""Lazy loader for the database connection."""
|
||||||
|
if not os.path.exists(args.db_path):
|
||||||
|
parser.error(f"No database found at {args.db_path}")
|
||||||
|
return DbContext(path=args.db_path, echo=args.verbose)
|
||||||
|
|
||||||
|
return get_db
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""CLI entry point"""
|
"""CLI entry point"""
|
||||||
# Set up the top-level parser
|
# Set up the top-level parser
|
||||||
|
@ -83,8 +98,12 @@ def main():
|
||||||
parser.set_defaults(
|
parser.set_defaults(
|
||||||
parser=parser,
|
parser=parser,
|
||||||
func=lambda args: parser.print_usage(),
|
func=lambda args: parser.print_usage(),
|
||||||
|
get_db=None,
|
||||||
)
|
)
|
||||||
parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
|
parser.add_argument("--verbose", "-v", action="store_true", help="Verbose output")
|
||||||
|
parser.add_argument(
|
||||||
|
"--db", dest="db_path", default="db.sqlite", help="Path to Amanuensis database"
|
||||||
|
)
|
||||||
|
|
||||||
# Add commands from cli submodules
|
# Add commands from cli submodules
|
||||||
subparsers = parser.add_subparsers(metavar="COMMAND")
|
subparsers = parser.add_subparsers(metavar="COMMAND")
|
||||||
|
@ -92,7 +111,10 @@ def main():
|
||||||
add_subcommand(subparsers, amanuensis.cli.lexicon)
|
add_subcommand(subparsers, amanuensis.cli.lexicon)
|
||||||
add_subcommand(subparsers, amanuensis.cli.user)
|
add_subcommand(subparsers, amanuensis.cli.user)
|
||||||
|
|
||||||
# Parse args and execute the desired action
|
# Parse args and perform top-level arg processing
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
init_logger(args)
|
init_logger(args)
|
||||||
|
args.get_db = get_db_factory(parser, args)
|
||||||
|
|
||||||
|
# Execute the desired action
|
||||||
args.func(args)
|
args.func(args)
|
||||||
|
|
|
@ -14,23 +14,17 @@ COMMAND_HELP = "Interact with Amanuensis."
|
||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@add_argument(
|
@add_argument("--drop", "-d", action="store_true", help="Overwrite existing database")
|
||||||
"path", metavar="DB_PATH", help="Path to where the database should be created"
|
|
||||||
)
|
|
||||||
@add_argument("--force", "-f", action="store_true", help="Overwrite existing database")
|
|
||||||
@add_argument("--verbose", "-v", action="store_true", help="Enable db echo")
|
|
||||||
def command_init_db(args) -> int:
|
def command_init_db(args) -> int:
|
||||||
"""
|
"""
|
||||||
Initialize the Amanuensis database.
|
Initialize the Amanuensis database.
|
||||||
"""
|
"""
|
||||||
# Check if force is required
|
if args.drop:
|
||||||
if not args.force and os.path.exists(args.path):
|
open(args.db_path, mode="w").close()
|
||||||
args.parser.error(f"{args.path} already exists and --force was not specified")
|
|
||||||
|
|
||||||
# Initialize the database
|
# Initialize the database
|
||||||
LOG.info(f"Creating database at {args.path}")
|
LOG.info(f"Creating database at {args.db_path}")
|
||||||
db = DbContext(path=args.path, echo=args.verbose)
|
args.get_db().create_all()
|
||||||
db.create_all()
|
|
||||||
|
|
||||||
LOG.info("Done")
|
LOG.info("Done")
|
||||||
return 0
|
return 0
|
||||||
|
|
Loading…
Reference in New Issue