diff --git a/amanuensis/backend/user.py b/amanuensis/backend/user.py index 6c9d235..1fc5e17 100644 --- a/amanuensis/backend/user.py +++ b/amanuensis/backend/user.py @@ -5,7 +5,8 @@ User query interface import re from typing import Optional, Sequence -from sqlalchemy import select, func +from sqlalchemy import select, func, update +from werkzeug.security import generate_password_hash, check_password_hash from amanuensis.db import DbContext, User from amanuensis.errors import ArgumentError @@ -81,3 +82,19 @@ def get_user_by_username(db: DbContext, username: str) -> Optional[User]: """ user: User = db(select(User).where(User.username == username)).scalar_one_or_none() return user + + +def password_set(db: DbContext, username: str, new_password: str) -> None: + """Set a user's password.""" + password_hash = generate_password_hash(new_password) + db(update(User).where(User.username == username).values(password=password_hash)) + db.session.commit() + + +def password_check(db: DbContext, username: str, password: str) -> bool: + """Check if a password is correct.""" + user_password_hash: str = db( + select(User.password).where(User.username == username) + ).scalar_one() + return check_password_hash(user_password_hash, password) + diff --git a/amanuensis/cli/__init__.py b/amanuensis/cli/__init__.py index ae72d5f..eb9e111 100644 --- a/amanuensis/cli/__init__.py +++ b/amanuensis/cli/__init__.py @@ -66,7 +66,7 @@ def add_subcommand(subparsers, module) -> None: sc_name, help=sc_help, description=obj.__doc__ ) subcommand.set_defaults(func=obj) - for args, kwargs in obj.__dict__.get("add_argument", []): + for args, kwargs in reversed(obj.__dict__.get("add_argument", [])): subcommand.add_argument(*args, **kwargs) diff --git a/amanuensis/cli/user.py b/amanuensis/cli/user.py index 34a72ab..ccd9f9d 100644 --- a/amanuensis/cli/user.py +++ b/amanuensis/cli/user.py @@ -19,7 +19,8 @@ LOG = logging.getLogger(__name__) def command_create(args) -> int: """Create a user.""" db: DbContext = args.get_db() - userq.create(db, args.username, args.password, args.username, args.email, False) + userq.create(db, args.username, "password", args.username, args.email, False) + userq.password_set(db, args.username, args.password) return 0 @@ -71,8 +72,13 @@ def command_list(args): raise NotImplementedError() -def command_passwd(args): +@add_argument("username") +@add_argument("password") +def command_passwd(args) -> int: """ Set a user's password. """ - raise NotImplementedError() + db: DbContext = args.get_db() + userq.password_set(db, args.username, args.password) + LOG.info(f"Updated password for {args.username}") + return 0