diff --git a/amanuensis/backend/lexicon.py b/amanuensis/backend/lexicon.py index 5efd406..073a4cf 100644 --- a/amanuensis/backend/lexicon.py +++ b/amanuensis/backend/lexicon.py @@ -55,11 +55,6 @@ def create( return new_lexicon -def from_name(db: DbContext, name: str) -> Lexicon: - """Get a lexicon by its name.""" - return db(select(Lexicon).where(Lexicon.name == name)).scalar_one() - - def get_all(db: DbContext) -> Sequence[Lexicon]: """Get all lexicons.""" return db(select(Lexicon)).scalars() @@ -75,3 +70,8 @@ def get_joined(db: DbContext, user_id: int) -> Sequence[Lexicon]: def get_public(db: DbContext) -> Sequence[Lexicon]: """Get all publicly visible lexicons.""" return db(select(Lexicon).where(Lexicon.public == True)).scalars() + + +def try_from_name(db: DbContext, name: str) -> Optional[Lexicon]: + """Get a lexicon by its name, or None if no such lexicon was found.""" + return db(select(Lexicon).where(Lexicon.name == name)).scalar_one_or_none() diff --git a/amanuensis/backend/user.py b/amanuensis/backend/user.py index 1283fb2..5542c8c 100644 --- a/amanuensis/backend/user.py +++ b/amanuensis/backend/user.py @@ -71,24 +71,6 @@ def create( return new_user -def from_id(db: DbContext, user_id: int) -> Optional[User]: - """ - Get a user by the user's id. - Returns None if no user was found. - """ - user: User = db(select(User).where(User.id == user_id)).scalar_one_or_none() - return user - - -def from_username(db: DbContext, username: str) -> Optional[User]: - """ - Get a user by the user's username. - Returns None if no user was found. - """ - user: User = db(select(User).where(User.username == username)).scalar_one_or_none() - return user - - def get_all(db: DbContext) -> Sequence[User]: """Get all users.""" return db(select(User)).scalars() @@ -109,6 +91,16 @@ def password_check(db: DbContext, username: str, password: str) -> bool: return check_password_hash(user_password_hash, password) +def try_from_id(db: DbContext, user_id: int) -> Optional[User]: + """Get a user by the user's id, or None is no such user was found.""" + return db(select(User).where(User.id == user_id)).scalar_one_or_none() + + +def try_from_username(db: DbContext, username: str) -> Optional[User]: + """Get a user by the user's username, or None is no such user was found.""" + return db(select(User).where(User.username == username)).scalar_one_or_none() + + def update_logged_in(db: DbContext, username: str) -> None: """Bump the value of the last_login column for a user.""" db( diff --git a/amanuensis/cli/lexicon.py b/amanuensis/cli/lexicon.py index 2d580d2..4419593 100644 --- a/amanuensis/cli/lexicon.py +++ b/amanuensis/cli/lexicon.py @@ -24,9 +24,12 @@ def command_add(args) -> int: Add a user to a lexicon. """ db: DbContext = args.get_db() - lexicon = lexiq.from_name(db, args.lexicon) - user = userq.from_username(db, args.user) - assert user is not None + lexicon = lexiq.try_from_name(db, args.lexicon) + if not lexicon: + raise ValueError("Lexicon does not exist") + user = userq.try_from_username(db, args.user) + if not user: + raise ValueError("User does not exist") memq.create(db, user.id, lexicon.id, args.editor) LOG.info(f"Added {args.user} to lexicon {args.lexicon}") return 0 diff --git a/amanuensis/cli/user.py b/amanuensis/cli/user.py index 79518eb..e28455c 100644 --- a/amanuensis/cli/user.py +++ b/amanuensis/cli/user.py @@ -29,7 +29,7 @@ def command_create(args) -> int: def command_promote(args) -> int: """Make a user a site admin.""" db: DbContext = args.get_db() - user: Optional[User] = userq.from_username(db, args.username) + user: Optional[User] = userq.try_from_username(db, args.username) if user is None: args.parser.error("User not found") return -1 @@ -46,7 +46,7 @@ def command_promote(args) -> int: def command_demote(args): """Revoke a user's site admin status.""" db: DbContext = args.get_db() - user: Optional[User] = userq.from_username(db, args.username) + user: Optional[User] = userq.try_from_username(db, args.username) if user is None: args.parser.error("User not found") return -1 diff --git a/amanuensis/server/auth/__init__.py b/amanuensis/server/auth/__init__.py index f8fc748..b88466f 100644 --- a/amanuensis/server/auth/__init__.py +++ b/amanuensis/server/auth/__init__.py @@ -39,7 +39,7 @@ def get_login_manager() -> LoginManager: user_id = int(user_id_str) except: return None - return userq.from_id(g.db, user_id) + return userq.try_from_id(g.db, user_id) login_manager.user_loader(load_user) @@ -58,7 +58,7 @@ def login(): # POST with valid data username: str = form.username.data password: str = form.password.data - user: User = userq.from_username(g.db, username) + user: User = userq.try_from_username(g.db, username) if not user or not userq.password_check(g.db, username, password): # Bad creds flash("Login not recognized") diff --git a/tests/backend/test_lexicon.py b/tests/backend/test_lexicon.py index b2f07c9..3c73d9e 100644 --- a/tests/backend/test_lexicon.py +++ b/tests/backend/test_lexicon.py @@ -58,8 +58,8 @@ def test_lexicon_from(db: DbContext, make: ObjectFactory): """Test lexiq.from_*.""" lexicon1: Lexicon = make.lexicon() lexicon2: Lexicon = make.lexicon() - assert lexiq.from_name(db, lexicon1.name) == lexicon1 - assert lexiq.from_name(db, lexicon2.name) == lexicon2 + assert lexiq.try_from_name(db, lexicon1.name) == lexicon1 + assert lexiq.try_from_name(db, lexicon2.name) == lexicon2 def test_get_lexicon(db: DbContext, make: ObjectFactory): diff --git a/tests/backend/test_user.py b/tests/backend/test_user.py index e5fc571..f657a33 100644 --- a/tests/backend/test_user.py +++ b/tests/backend/test_user.py @@ -57,10 +57,10 @@ def test_user_from(db: DbContext, make): """Test userq.from_*.""" user1: User = make.user() user2: User = make.user() - assert userq.from_id(db, user1.id) == user1 - assert userq.from_username(db, user1.username) == user1 - assert userq.from_id(db, user2.id) == user2 - assert userq.from_username(db, user2.username) == user2 + assert userq.try_from_id(db, user1.id) == user1 + assert userq.try_from_username(db, user1.username) == user1 + assert userq.try_from_id(db, user2.id) == user2 + assert userq.try_from_username(db, user2.username) == user2 def test_user_password(db: DbContext, make): diff --git a/tests/conftest.py b/tests/conftest.py index 2dccf33..ff601bb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -45,7 +45,7 @@ class UserClient: def login(self, client: FlaskClient): """Log the user in.""" - user: Optional[User] = userq.from_id(self.db, self.user_id) + user: Optional[User] = userq.try_from_id(self.db, self.user_id) assert user is not None # Set the user's password so we know what it is later