Refactor path-to-uri calculation into DbContext

This commit is contained in:
Tim Van Baak 2021-06-15 22:51:23 -07:00
parent aaa7982f67
commit c4f133434d
4 changed files with 19 additions and 7 deletions

View File

@ -28,9 +28,8 @@ def command_init_db(args) -> int:
args.parser.error(f"{args.path} already exists and --force was not specified") args.parser.error(f"{args.path} already exists and --force was not specified")
# Initialize the database # Initialize the database
db_uri = f"sqlite:///{os.path.abspath(args.path)}" LOG.info(f"Creating database at {args.path}")
LOG.info(f"Creating database at {db_uri}") db = DbContext(path=args.path, echo=args.verbose)
db = DbContext(db_uri, debug=args.verbose)
db.create_all() db.create_all()
LOG.info("Done") LOG.info("Done")

View File

@ -1,6 +1,8 @@
""" """
Database connection setup Database connection setup
""" """
import os
from sqlalchemy import create_engine, MetaData, event from sqlalchemy import create_engine, MetaData, event
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.orm import scoped_session, sessionmaker
@ -27,9 +29,20 @@ ModelBase = declarative_base(metadata=metadata)
class DbContext: class DbContext:
def __init__(self, db_uri, debug=False): """Class encapsulating connections to the database."""
def __init__(self, path=None, uri=None, echo=False):
"""
Create a database context.
Exactly one of `path` and `uri` should be specified.
"""
if path and uri:
raise ValueError("Only one of path and uri may be specified")
db_uri = uri if uri else f"sqlite:///{os.path.abspath(path)}"
# Create an engine and enable foreign key constraints in sqlite # Create an engine and enable foreign key constraints in sqlite
self.engine = create_engine(db_uri, echo=debug) self.engine = create_engine(db_uri, echo=echo)
@event.listens_for(self.engine, "connect") @event.listens_for(self.engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record): def set_sqlite_pragma(dbapi_connection, connection_record):

View File

@ -30,7 +30,7 @@ def get_app(
# Create the database context, if one wasn't already given # Create the database context, if one wasn't already given
if db is None: if db is None:
db = DbContext(app.config["DATABASE_URI"]) db = DbContext(uri=app.config["DATABASE_URI"])
# Make the database connection available to requests via g # Make the database connection available to requests via g
def db_setup(): def db_setup():

View File

@ -15,7 +15,7 @@ from amanuensis.server import get_app
@pytest.fixture @pytest.fixture
def db() -> DbContext: def db() -> DbContext:
"""Provides an initialized database in memory.""" """Provides an initialized database in memory."""
db = DbContext("sqlite:///:memory:", debug=False) db = DbContext(uri="sqlite:///:memory:", echo=False)
db.create_all() db.create_all()
return db return db