diff --git a/amanuensis/database.py b/amanuensis/database.py index 98afcaa..74ee911 100644 --- a/amanuensis/database.py +++ b/amanuensis/database.py @@ -6,43 +6,31 @@ import sqlite3 import uuid -# Register GUID as a known type with sqlite -sqlite3.register_converter('GUID', lambda h: uuid.UUID(hex=h)) -sqlite3.register_adapter(uuid.UUID, lambda u: u.hex) - class Uuid(TypeDecorator): - """ - A uuid backed by a char(32) field in sqlite. - """ - impl = CHAR(32) + """ + A uuid backed by a char(32) field in sqlite. + """ + impl = CHAR(32) - def process_bind_param(self, value, dialect): - if value is None: - return value - elif not isinstance(value, uuid.UUID): - return f'{uuid.UUID(value).int:32x}' - else: - return f'{value.int:32x}' + def process_bind_param(self, value, dialect): + if value is None: + return value + elif not isinstance(value, uuid.UUID): + return f'{uuid.UUID(value).int:32x}' + else: + return f'{value.int:32x}' - def process_result_value(self, value, dialect): - if value is None: - return value - elif not isinstance(value, uuid.UUID): - return uuid.UUID(value) - else: - return value + def process_result_value(self, value, dialect): + if value is None: + return value + elif not isinstance(value, uuid.UUID): + return uuid.UUID(value) + else: + return value -engine = create_engine('sqlite:///:memory:', connect_args={'detect_types': sqlite3.PARSE_DECLTYPES}) - -# Enable foreign key constraints -@event.listens_for(engine, "connect") -def set_sqlite_pragma(dbapi_connection, connection_record): - cursor = dbapi_connection.cursor() - cursor.execute("PRAGMA foreign_keys=ON") - cursor.close() # Define naming conventions for generated constraints -metadata = MetaData(bind=engine, naming_convention={ +metadata = MetaData(naming_convention={ "ix": "ix_%(column_0_label)s", "uq": "uq_%(table_name)s_%(column_0_name)s", "ck": "ck_%(table_name)s_%(constraint_name)s", @@ -50,8 +38,22 @@ metadata = MetaData(bind=engine, naming_convention={ "pk": "pk_%(table_name)s" }) -# Thread-safe db session -session = scoped_session(sessionmaker(bind=engine)) - # Base class for ORM models ModelBase = declarative_base(metadata=metadata) + + +class DbContext(): + def __init__(self, db_uri, debug=False): + # Create an engine and enable foreign key constraints in sqlite + self.engine = create_engine(db_uri, echo=debug) + @event.listens_for(self.engine, "connect") + def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + # Create a thread-safe session factory + self.session = scoped_session(sessionmaker(bind=self.engine)) + + def create_all(self): + ModelBase.metadata.create_all(self.engine)