diff --git a/amanuensis/backend/user.py b/amanuensis/backend/user.py index 4ff2264..6c9d235 100644 --- a/amanuensis/backend/user.py +++ b/amanuensis/backend/user.py @@ -3,7 +3,7 @@ User query interface """ import re -from typing import Sequence +from typing import Optional, Sequence from sqlalchemy import select, func @@ -72,3 +72,12 @@ def create( def get_all_users(db: DbContext) -> Sequence[User]: """Get all users.""" return db(select(User)).scalars() + + +def get_user_by_username(db: DbContext, username: str) -> Optional[User]: + """ + Get a user by the user's username. + Returns None if no user was found. + """ + user: User = db(select(User).where(User.username == username)).scalar_one_or_none() + return user diff --git a/amanuensis/cli/__init__.py b/amanuensis/cli/__init__.py index df98ead..ae72d5f 100644 --- a/amanuensis/cli/__init__.py +++ b/amanuensis/cli/__init__.py @@ -79,13 +79,13 @@ def init_logger(args): logging.config.dictConfig(LOGGING_CONFIG) -def get_db_factory(parser: ArgumentParser, args: Namespace) -> Callable[[], DbContext]: +def get_db_factory(args: Namespace) -> Callable[[], DbContext]: """Factory function for lazy-loading the database in subcommands.""" def get_db() -> DbContext: """Lazy loader for the database connection.""" if not os.path.exists(args.db_path): - parser.error(f"No database found at {args.db_path}") + args.parser.error(f"No database found at {args.db_path}") return DbContext(path=args.db_path, echo=args.verbose) return get_db @@ -114,7 +114,7 @@ def main(): # Parse args and perform top-level arg processing args = parser.parse_args() init_logger(args) - args.get_db = get_db_factory(parser, args) + args.get_db = get_db_factory(args) # Execute the desired action args.func(args) diff --git a/amanuensis/cli/user.py b/amanuensis/cli/user.py index 91d16ce..34a72ab 100644 --- a/amanuensis/cli/user.py +++ b/amanuensis/cli/user.py @@ -1,4 +1,8 @@ import logging +from typing import Optional + +import amanuensis.backend.user as userq +from amanuensis.db import DbContext, User from .helpers import add_argument @@ -9,11 +13,48 @@ COMMAND_HELP = "Interact with users." LOG = logging.getLogger(__name__) -def command_create(args): - """ - Create a user. - """ - raise NotImplementedError() +@add_argument("username") +@add_argument("--password", default="password") +@add_argument("--email", default="") +def command_create(args) -> int: + """Create a user.""" + db: DbContext = args.get_db() + userq.create(db, args.username, args.password, args.username, args.email, False) + return 0 + + +@add_argument("username") +def command_promote(args) -> int: + """Make a user a site admin.""" + db: DbContext = args.get_db() + user: Optional[User] = userq.get_user_by_username(db, args.username) + if user is None: + args.parser.error("User not found") + return -1 + if user.is_site_admin: + LOG.info(f"{user.username} is already a site admin.") + else: + user.is_site_admin = True + LOG.info(f"Promoting {user.username} to site admin.") + db.session.commit() + return 0 + + +@add_argument("username") +def command_demote(args): + """Revoke a user's site admin status.""" + db: DbContext = args.get_db() + user: Optional[User] = userq.get_user_by_username(db, args.username) + if user is None: + args.parser.error("User not found") + return -1 + if not user.is_site_admin: + LOG.info(f"{user.username} is not a site admin.") + else: + user.is_site_admin = False + LOG.info(f"Revoking site admin status for {user.username}.") + db.session.commit() + return 0 def command_delete(args):