diff --git a/amanuensis/cli/admin.py b/amanuensis/cli/admin.py index c7e7f30..dfc92d0 100644 --- a/amanuensis/cli/admin.py +++ b/amanuensis/cli/admin.py @@ -28,9 +28,8 @@ def command_init_db(args) -> int: args.parser.error(f"{args.path} already exists and --force was not specified") # Initialize the database - db_uri = f"sqlite:///{os.path.abspath(args.path)}" - LOG.info(f"Creating database at {db_uri}") - db = DbContext(db_uri, debug=args.verbose) + LOG.info(f"Creating database at {args.path}") + db = DbContext(path=args.path, echo=args.verbose) db.create_all() LOG.info("Done") diff --git a/amanuensis/db/database.py b/amanuensis/db/database.py index 0fb68f3..90eaa49 100644 --- a/amanuensis/db/database.py +++ b/amanuensis/db/database.py @@ -1,6 +1,8 @@ """ Database connection setup """ +import os + from sqlalchemy import create_engine, MetaData, event from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import scoped_session, sessionmaker @@ -27,9 +29,20 @@ ModelBase = declarative_base(metadata=metadata) class DbContext: - def __init__(self, db_uri, debug=False): + """Class encapsulating connections to the database.""" + + def __init__(self, path=None, uri=None, echo=False): + """ + Create a database context. + Exactly one of `path` and `uri` should be specified. + """ + + if path and uri: + raise ValueError("Only one of path and uri may be specified") + db_uri = uri if uri else f"sqlite:///{os.path.abspath(path)}" + # Create an engine and enable foreign key constraints in sqlite - self.engine = create_engine(db_uri, echo=debug) + self.engine = create_engine(db_uri, echo=echo) @event.listens_for(self.engine, "connect") def set_sqlite_pragma(dbapi_connection, connection_record): diff --git a/amanuensis/server/__init__.py b/amanuensis/server/__init__.py index e144471..eeb6a29 100644 --- a/amanuensis/server/__init__.py +++ b/amanuensis/server/__init__.py @@ -30,7 +30,7 @@ def get_app( # Create the database context, if one wasn't already given if db is None: - db = DbContext(app.config["DATABASE_URI"]) + db = DbContext(uri=app.config["DATABASE_URI"]) # Make the database connection available to requests via g def db_setup(): diff --git a/tests/conftest.py b/tests/conftest.py index 6328261..b5bc8d3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,7 @@ from amanuensis.server import get_app @pytest.fixture def db() -> DbContext: """Provides an initialized database in memory.""" - db = DbContext("sqlite:///:memory:", debug=False) + db = DbContext(uri="sqlite:///:memory:", echo=False) db.create_all() return db