Compare commits
2 Commits
fc28718846
...
8a4d666248
Author | SHA1 | Date |
---|---|---|
Tim Van Baak | 8a4d666248 | |
Tim Van Baak | 9f2c9d14d3 |
|
@ -6,43 +6,31 @@ import sqlite3
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
|
||||||
# Register GUID as a known type with sqlite
|
|
||||||
sqlite3.register_converter('GUID', lambda h: uuid.UUID(hex=h))
|
|
||||||
sqlite3.register_adapter(uuid.UUID, lambda u: u.hex)
|
|
||||||
|
|
||||||
class Uuid(TypeDecorator):
|
class Uuid(TypeDecorator):
|
||||||
"""
|
"""
|
||||||
A uuid backed by a char(32) field in sqlite.
|
A uuid backed by a char(32) field in sqlite.
|
||||||
"""
|
"""
|
||||||
impl = CHAR(32)
|
impl = CHAR(32)
|
||||||
|
|
||||||
def process_bind_param(self, value, dialect):
|
def process_bind_param(self, value, dialect):
|
||||||
if value is None:
|
if value is None:
|
||||||
return value
|
return value
|
||||||
elif not isinstance(value, uuid.UUID):
|
elif not isinstance(value, uuid.UUID):
|
||||||
return f'{uuid.UUID(value).int:32x}'
|
return f'{uuid.UUID(value).int:32x}'
|
||||||
else:
|
else:
|
||||||
return f'{value.int:32x}'
|
return f'{value.int:32x}'
|
||||||
|
|
||||||
def process_result_value(self, value, dialect):
|
def process_result_value(self, value, dialect):
|
||||||
if value is None:
|
if value is None:
|
||||||
return value
|
return value
|
||||||
elif not isinstance(value, uuid.UUID):
|
elif not isinstance(value, uuid.UUID):
|
||||||
return uuid.UUID(value)
|
return uuid.UUID(value)
|
||||||
else:
|
else:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
engine = create_engine('sqlite:///:memory:', connect_args={'detect_types': sqlite3.PARSE_DECLTYPES})
|
|
||||||
|
|
||||||
# Enable foreign key constraints
|
|
||||||
@event.listens_for(engine, "connect")
|
|
||||||
def set_sqlite_pragma(dbapi_connection, connection_record):
|
|
||||||
cursor = dbapi_connection.cursor()
|
|
||||||
cursor.execute("PRAGMA foreign_keys=ON")
|
|
||||||
cursor.close()
|
|
||||||
|
|
||||||
# Define naming conventions for generated constraints
|
# Define naming conventions for generated constraints
|
||||||
metadata = MetaData(bind=engine, naming_convention={
|
metadata = MetaData(naming_convention={
|
||||||
"ix": "ix_%(column_0_label)s",
|
"ix": "ix_%(column_0_label)s",
|
||||||
"uq": "uq_%(table_name)s_%(column_0_name)s",
|
"uq": "uq_%(table_name)s_%(column_0_name)s",
|
||||||
"ck": "ck_%(table_name)s_%(constraint_name)s",
|
"ck": "ck_%(table_name)s_%(constraint_name)s",
|
||||||
|
@ -50,8 +38,22 @@ metadata = MetaData(bind=engine, naming_convention={
|
||||||
"pk": "pk_%(table_name)s"
|
"pk": "pk_%(table_name)s"
|
||||||
})
|
})
|
||||||
|
|
||||||
# Thread-safe db session
|
|
||||||
session = scoped_session(sessionmaker(bind=engine))
|
|
||||||
|
|
||||||
# Base class for ORM models
|
# Base class for ORM models
|
||||||
ModelBase = declarative_base(metadata=metadata)
|
ModelBase = declarative_base(metadata=metadata)
|
||||||
|
|
||||||
|
|
||||||
|
class DbContext():
|
||||||
|
def __init__(self, db_uri, debug=False):
|
||||||
|
# Create an engine and enable foreign key constraints in sqlite
|
||||||
|
self.engine = create_engine(db_uri, echo=debug)
|
||||||
|
@event.listens_for(self.engine, "connect")
|
||||||
|
def set_sqlite_pragma(dbapi_connection, connection_record):
|
||||||
|
cursor = dbapi_connection.cursor()
|
||||||
|
cursor.execute("PRAGMA foreign_keys=ON")
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
# Create a thread-safe session factory
|
||||||
|
self.session = scoped_session(sessionmaker(bind=self.engine))
|
||||||
|
|
||||||
|
def create_all(self):
|
||||||
|
ModelBase.metadata.create_all(self.engine)
|
||||||
|
|
|
@ -1,5 +0,0 @@
|
||||||
from amanuensis import __version__
|
|
||||||
|
|
||||||
|
|
||||||
def test_version():
|
|
||||||
assert __version__ == '0.1.0'
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import func
|
||||||
|
|
||||||
|
from amanuensis.database import DbContext
|
||||||
|
from amanuensis.models import *
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def session():
|
||||||
|
db = DbContext('sqlite:///:memory:', debug=True)
|
||||||
|
db.create_all()
|
||||||
|
return db.session
|
||||||
|
|
||||||
|
|
||||||
|
def test_create(session):
|
||||||
|
"""Simple test that the database creates fine from scratch."""
|
||||||
|
assert session.query(func.count(User.id)).scalar() == 0
|
||||||
|
assert session.query(func.count(Lexicon.id)).scalar() == 0
|
||||||
|
assert session.query(func.count(Membership.id)).scalar() == 0
|
||||||
|
assert session.query(func.count(Character.id)).scalar() == 0
|
||||||
|
assert session.query(func.count(Article.id)).scalar() == 0
|
||||||
|
assert session.query(func.count(ArticleIndex.id)).scalar() == 0
|
||||||
|
assert session.query(func.count(ArticleIndexRule.id)).scalar() == 0
|
||||||
|
assert session.query(func.count(ArticleContentRule.id)).scalar() == 0
|
||||||
|
assert session.query(func.count(Post.id)).scalar() == 0
|
Loading…
Reference in New Issue