diff --git a/amanuensis/server/__init__.py b/amanuensis/server/__init__.py index 31f0e50..bb0e0b4 100644 --- a/amanuensis/server/__init__.py +++ b/amanuensis/server/__init__.py @@ -9,7 +9,7 @@ from amanuensis.config import AmanuensisConfig, CommandLineConfig from amanuensis.db import DbContext from amanuensis.parser import filesafe_title import amanuensis.server.auth as auth -from amanuensis.server.helpers import UuidConverter +from amanuensis.server.helpers import UuidConverter, current_lexicon, current_membership import amanuensis.server.home as home import amanuensis.server.lexicon as lexicon @@ -68,13 +68,21 @@ def get_app( app.teardown_appcontext(db_teardown) # Configure jinja options - def include_backend(): - return {"db": db, "lexiq": lexiq, "userq": userq, "memq": memq, "charq": charq} + def add_jinja_context(): + return { + "db": db, + "lexiq": lexiq, + "userq": userq, + "memq": memq, + "charq": charq, + "current_lexicon": current_lexicon, + "current_membership": current_membership + } app.jinja_options.update(trim_blocks=True, lstrip_blocks=True) app.template_filter("date")(date_format) app.template_filter("articlelink")(article_link) - app.context_processor(include_backend) + app.context_processor(add_jinja_context) # Set up uuid route converter app.url_map.converters["uuid"] = UuidConverter diff --git a/amanuensis/server/helpers.py b/amanuensis/server/helpers.py index 8328ade..b622a00 100644 --- a/amanuensis/server/helpers.py +++ b/amanuensis/server/helpers.py @@ -2,8 +2,17 @@ from functools import wraps from typing import Optional, Any from uuid import UUID -from flask import g, flash, redirect, url_for +from flask import ( + _request_ctx_stack, + flash, + g, + has_request_context, + redirect, + request, + url_for, +) from flask_login import current_user +from werkzeug.local import LocalProxy from werkzeug.routing import BaseConverter, ValidationError from amanuensis.backend import lexiq, memq @@ -26,6 +35,45 @@ class UuidConverter(BaseConverter): return str(value) +def get_current_lexicon(): + # Check if the request context is for a lexicon page + if not has_request_context(): + return None + lexicon_name = request.view_args.get("lexicon_name") + if not lexicon_name: + return None + # Pull up the lexicon if it exists and cache it in the request context + if not hasattr(_request_ctx_stack.top, "lexicon"): + db: DbContext = g.db + lexicon: Optional[Lexicon] = lexiq.try_from_name(db, lexicon_name) + setattr(_request_ctx_stack.top, "lexicon", lexicon) + # Return the cached lexicon + return getattr(_request_ctx_stack.top, "lexicon", None) + + +current_lexicon = LocalProxy(get_current_lexicon) + + +def get_current_membership(): + # Base the current membership on the current user and the current lexicon + user: User = current_user + if not user or not user.is_authenticated: + return None + lexicon: Lexicon = current_lexicon + if not lexicon: + return None + # Pull up the membership and cache it in the request context + if not hasattr(_request_ctx_stack.top, "membership"): + db: DbContext = g.db + mem: Membership = memq.try_from_ids(db, user.id, lexicon.id) + setattr(_request_ctx_stack.top, "membership", mem) + # Return cached membership + return getattr(_request_ctx_stack.top, "membership", None) + + +current_membership = LocalProxy(get_current_membership) + + def lexicon_param(route): """ Wrapper for loading a route's lexicon to `g`.