diff --git a/amanuensis/db/database.py b/amanuensis/db/database.py index a2ec806..94da947 100644 --- a/amanuensis/db/database.py +++ b/amanuensis/db/database.py @@ -44,16 +44,26 @@ class DbContext: # Create an engine and enable foreign key constraints in sqlite self.engine = create_engine(self.db_uri, echo=echo) - @event.listens_for(self.engine, "connect") def set_sqlite_pragma(dbapi_connection, connection_record): cursor = dbapi_connection.cursor() cursor.execute("PRAGMA foreign_keys=ON") cursor.close() + event.listens_for(self.engine, "connect")(set_sqlite_pragma) + # Create a thread-safe session factory - self.session = scoped_session( - sessionmaker(bind=self.engine), scopefunc=get_ident - ) + sm = sessionmaker(bind=self.engine) + + def add_lifecycle_hook(sm, from_state, to_state): + def object_lifecycle_hook(_, obj): + print(f"object moved from {from_state} to {to_state}: {obj}") + + event.listens_for(sm, f"{from_state}_to_{to_state}")(object_lifecycle_hook) + + if echo: + add_lifecycle_hook(sm, "persistent", "detached") + + self.session = scoped_session(sm, scopefunc=get_ident) def __call__(self, *args, **kwargs): """Provides shortcut access to session.execute.""" diff --git a/tests/conftest.py b/tests/conftest.py index f8b2a29..2dccf33 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,10 @@ pytest test fixtures import os import pytest import tempfile +from typing import Optional +from bs4 import BeautifulSoup +from flask.testing import FlaskClient from sqlalchemy.orm.session import close_all_sessions import amanuensis.backend.character as charq @@ -12,7 +15,7 @@ import amanuensis.backend.lexicon as lexiq import amanuensis.backend.membership as memq import amanuensis.backend.user as userq from amanuensis.config import AmanuensisConfig -from amanuensis.db import DbContext +from amanuensis.db import DbContext, User, Lexicon, Membership, Character from amanuensis.server import get_app @@ -33,12 +36,53 @@ def db(request) -> DbContext: return db -@pytest.fixture -def make_user(db: DbContext): - """Provides a factory function for creating users, with valid default values.""" +class UserClient: + """Class encapsulating user web operations.""" - def user_factory(state={"nonce": 1}, **kwargs): - default_kwargs = { + def __init__(self, db: DbContext, user_id: int): + self.db = db + self.user_id = user_id + + def login(self, client: FlaskClient): + """Log the user in.""" + user: Optional[User] = userq.from_id(self.db, self.user_id) + assert user is not None + + # Set the user's password so we know what it is later + password = os.urandom(8).hex() + userq.password_set(self.db, user.username, password) + + # Log in + response = client.get("/auth/login/") + assert response.status_code == 200 + soup = BeautifulSoup(response.data, features="html.parser") + csrf_token = soup.find(id="csrf_token") + assert csrf_token is not None + response = client.post( + "/auth/login/", + data={ + "username": user.username, + "password": password, + "csrf_token": csrf_token["value"], + }, + ) + assert 300 <= response.status_code <= 399 + + def logout(self, client: FlaskClient): + """Log the user out.""" + response = client.get("/auth/logout/") + assert 300 <= response.status_code <= 399 + + +class ObjectFactory: + """Factory class.""" + + def __init__(self, db): + self.db = db + + def user(self, state={"nonce": 1}, **kwargs) -> User: + """Factory function for creating users, with valid default values.""" + default_kwargs: dict = { "username": f'test_user_{state["nonce"]}', "password": "password", "display_name": None, @@ -46,87 +90,54 @@ def make_user(db: DbContext): "is_site_admin": False, } state["nonce"] += 1 - updated_kwargs = {**default_kwargs, **kwargs} - return userq.create(db, **updated_kwargs) + updated_kwargs: dict = {**default_kwargs, **kwargs} + return userq.create(self.db, **updated_kwargs) - return user_factory - - -@pytest.fixture -def make_lexicon(db: DbContext): - """Provides a factory function for creating lexicons, with valid default values.""" - - def lexicon_factory(state={"nonce": 1}, **kwargs): - default_kwargs = { + def lexicon(self, state={"nonce": 1}, **kwargs) -> Lexicon: + """Factory function for creating lexicons, with valid default values.""" + default_kwargs: dict = { "name": f'Test_{state["nonce"]}', "title": None, "prompt": f'Test Lexicon game {state["nonce"]}', } state["nonce"] += 1 - updated_kwargs = {**default_kwargs, **kwargs} - lex = lexiq.create(db, **updated_kwargs) + updated_kwargs: dict = {**default_kwargs, **kwargs} + lex = lexiq.create(self.db, **updated_kwargs) lex.joinable = True - db.session.commit() + self.db.session.commit() return lex - return lexicon_factory - - -@pytest.fixture -def make_membership(db: DbContext): - """Provides a factory function for creating memberships, with valid default values.""" - - def membership_factory(**kwargs): - default_kwargs = { + def membership(self, **kwargs) -> Membership: + """Factory function for creating memberships, with valid default values.""" + default_kwargs: dict = { "is_editor": False, } - updated_kwargs = {**default_kwargs, **kwargs} - return memq.create(db, **updated_kwargs) + updated_kwargs: dict = {**default_kwargs, **kwargs} + return memq.create(self.db, **updated_kwargs) - return membership_factory - - -@pytest.fixture -def make_character(db: DbContext): - """Provides a factory function for creating characters, with valid default values.""" - - def character_factory(state={"nonce": 1}, **kwargs): - default_kwargs = { + def character(self, state={"nonce": 1}, **kwargs) -> Character: + """Factory function for creating characters, with valid default values.""" + default_kwargs: dict = { "name": f'Character {state["nonce"]}', "signature": None, } state["nonce"] += 1 - updated_kwargs = {**default_kwargs, **kwargs} - return charq.create(db, **updated_kwargs) + updated_kwargs: dict = {**default_kwargs, **kwargs} + return charq.create(self.db, **updated_kwargs) - return character_factory - - -class TestFactory: - def __init__(self, db, **factories): - self.db = db - self.factories = factories - - def __getattr__(self, name): - return self.factories[name] + def client(self, user_id: int) -> UserClient: + """Factory function for user test clients.""" + return UserClient(self.db, user_id) @pytest.fixture -def make( - db: DbContext, make_user, make_lexicon, make_membership, make_character -) -> TestFactory: - """Fixture that groups all factory fixtures together.""" - return TestFactory( - db, - user=make_user, - lexicon=make_lexicon, - membership=make_membership, - character=make_character, - ) +def make(db: DbContext) -> ObjectFactory: + """Fixture that provides a factory class.""" + return ObjectFactory(db) @pytest.fixture -def lexicon_with_editor(make): +def lexicon_with_editor(make: ObjectFactory): """Shortcut setup for a lexicon game with an editor.""" editor = make.user() assert editor diff --git a/tests/test_auth.py b/tests/test_auth.py index dc9a392..4cecfd1 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -11,7 +11,7 @@ def test_auth_circuit(app: Flask, make): """Test the user login/logout path.""" username: str = f"user_{os.urandom(8).hex()}" ub: bytes = username.encode("utf8") - user: User = make.user(username=username, password=username) + assert make.user(username=username, password=username) with app.test_client() as client: # User should not be logged in diff --git a/tests/test_home.py b/tests/test_home.py new file mode 100644 index 0000000..cb52d86 --- /dev/null +++ b/tests/test_home.py @@ -0,0 +1,45 @@ +import os +from urllib.parse import urlsplit + +from flask import Flask + +from amanuensis.db import DbContext, User, Lexicon + +from .conftest import ObjectFactory, UserClient + + +def test_game_visibility(db: DbContext, app: Flask, make: ObjectFactory): + """Test lexicon visibility settings.""" + user: User = make.user() + auth: UserClient = make.client(user.id) + + public_joined: Lexicon = make.lexicon() + public_joined.public = True + make.membership(user_id=auth.user_id, lexicon_id=public_joined.id) + public_joined_title = public_joined.full_title + + private_joined: Lexicon = make.lexicon() + private_joined.public = False + make.membership(user_id=auth.user_id, lexicon_id=private_joined.id) + private_joined_title = private_joined.full_title + + public_open: Lexicon = make.lexicon() + public_open.public = True + db.session.commit() + public_open_title = public_open.full_title + + private_open: Lexicon = make.lexicon() + private_open.public = False + db.session.commit() + private_open_title = private_open.full_title + + with app.test_client() as client: + auth.login(client) + + # Check that lexicons appear if they should + response = client.get("/home/") + assert response.status_code == 200 + assert public_joined_title.encode("utf8") in response.data + assert private_joined_title.encode("utf8") in response.data + assert public_open_title.encode("utf8") in response.data + assert private_open_title.encode("utf8") not in response.data