Get home page and login working #14
|
@ -5,7 +5,8 @@ User query interface
|
|||
import re
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from sqlalchemy import select, func, update
|
||||
from werkzeug.security import generate_password_hash, check_password_hash
|
||||
|
||||
from amanuensis.db import DbContext, User
|
||||
from amanuensis.errors import ArgumentError
|
||||
|
@ -81,3 +82,19 @@ def get_user_by_username(db: DbContext, username: str) -> Optional[User]:
|
|||
"""
|
||||
user: User = db(select(User).where(User.username == username)).scalar_one_or_none()
|
||||
return user
|
||||
|
||||
|
||||
def password_set(db: DbContext, username: str, new_password: str) -> None:
|
||||
"""Set a user's password."""
|
||||
password_hash = generate_password_hash(new_password)
|
||||
db(update(User).where(User.username == username).values(password=password_hash))
|
||||
db.session.commit()
|
||||
|
||||
|
||||
def password_check(db: DbContext, username: str, password: str) -> bool:
|
||||
"""Check if a password is correct."""
|
||||
user_password_hash: str = db(
|
||||
select(User.password).where(User.username == username)
|
||||
).scalar_one()
|
||||
return check_password_hash(user_password_hash, password)
|
||||
|
||||
|
|
|
@ -66,7 +66,7 @@ def add_subcommand(subparsers, module) -> None:
|
|||
sc_name, help=sc_help, description=obj.__doc__
|
||||
)
|
||||
subcommand.set_defaults(func=obj)
|
||||
for args, kwargs in obj.__dict__.get("add_argument", []):
|
||||
for args, kwargs in reversed(obj.__dict__.get("add_argument", [])):
|
||||
subcommand.add_argument(*args, **kwargs)
|
||||
|
||||
|
||||
|
|
|
@ -19,7 +19,8 @@ LOG = logging.getLogger(__name__)
|
|||
def command_create(args) -> int:
|
||||
"""Create a user."""
|
||||
db: DbContext = args.get_db()
|
||||
userq.create(db, args.username, args.password, args.username, args.email, False)
|
||||
userq.create(db, args.username, "password", args.username, args.email, False)
|
||||
userq.password_set(db, args.username, args.password)
|
||||
return 0
|
||||
|
||||
|
||||
|
@ -71,8 +72,13 @@ def command_list(args):
|
|||
raise NotImplementedError()
|
||||
|
||||
|
||||
def command_passwd(args):
|
||||
@add_argument("username")
|
||||
@add_argument("password")
|
||||
def command_passwd(args) -> int:
|
||||
"""
|
||||
Set a user's password.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
db: DbContext = args.get_db()
|
||||
userq.password_set(db, args.username, args.password)
|
||||
LOG.info(f"Updated password for {args.username}")
|
||||
return 0
|
||||
|
|
Loading…
Reference in New Issue