Implement user passwd command
This commit is contained in:
parent
3c7fc4b5f8
commit
6b5463b702
|
@ -5,7 +5,8 @@ User query interface
|
||||||
import re
|
import re
|
||||||
from typing import Optional, Sequence
|
from typing import Optional, Sequence
|
||||||
|
|
||||||
from sqlalchemy import select, func
|
from sqlalchemy import select, func, update
|
||||||
|
from werkzeug.security import generate_password_hash, check_password_hash
|
||||||
|
|
||||||
from amanuensis.db import DbContext, User
|
from amanuensis.db import DbContext, User
|
||||||
from amanuensis.errors import ArgumentError
|
from amanuensis.errors import ArgumentError
|
||||||
|
@ -81,3 +82,19 @@ def get_user_by_username(db: DbContext, username: str) -> Optional[User]:
|
||||||
"""
|
"""
|
||||||
user: User = db(select(User).where(User.username == username)).scalar_one_or_none()
|
user: User = db(select(User).where(User.username == username)).scalar_one_or_none()
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def password_set(db: DbContext, username: str, new_password: str) -> None:
|
||||||
|
"""Set a user's password."""
|
||||||
|
password_hash = generate_password_hash(new_password)
|
||||||
|
db(update(User).where(User.username == username).values(password=password_hash))
|
||||||
|
db.session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def password_check(db: DbContext, username: str, password: str) -> bool:
|
||||||
|
"""Check if a password is correct."""
|
||||||
|
user_password_hash: str = db(
|
||||||
|
select(User.password).where(User.username == username)
|
||||||
|
).scalar_one()
|
||||||
|
return check_password_hash(user_password_hash, password)
|
||||||
|
|
||||||
|
|
|
@ -66,7 +66,7 @@ def add_subcommand(subparsers, module) -> None:
|
||||||
sc_name, help=sc_help, description=obj.__doc__
|
sc_name, help=sc_help, description=obj.__doc__
|
||||||
)
|
)
|
||||||
subcommand.set_defaults(func=obj)
|
subcommand.set_defaults(func=obj)
|
||||||
for args, kwargs in obj.__dict__.get("add_argument", []):
|
for args, kwargs in reversed(obj.__dict__.get("add_argument", [])):
|
||||||
subcommand.add_argument(*args, **kwargs)
|
subcommand.add_argument(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,8 @@ LOG = logging.getLogger(__name__)
|
||||||
def command_create(args) -> int:
|
def command_create(args) -> int:
|
||||||
"""Create a user."""
|
"""Create a user."""
|
||||||
db: DbContext = args.get_db()
|
db: DbContext = args.get_db()
|
||||||
userq.create(db, args.username, args.password, args.username, args.email, False)
|
userq.create(db, args.username, "password", args.username, args.email, False)
|
||||||
|
userq.password_set(db, args.username, args.password)
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
@ -71,8 +72,13 @@ def command_list(args):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
def command_passwd(args):
|
@add_argument("username")
|
||||||
|
@add_argument("password")
|
||||||
|
def command_passwd(args) -> int:
|
||||||
"""
|
"""
|
||||||
Set a user's password.
|
Set a user's password.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
db: DbContext = args.get_db()
|
||||||
|
userq.password_set(db, args.username, args.password)
|
||||||
|
LOG.info(f"Updated password for {args.username}")
|
||||||
|
return 0
|
||||||
|
|
Loading…
Reference in New Issue