wrap ContextVar more directly instead of Werkzeug LocalStack

This commit is contained in:
David Lord 2022-06-22 12:06:22 -07:00
parent ab36542260
commit af0feaa116
No known key found for this signature in database
GPG key ID: 7A1C87E3F5BC42A8
2 changed files with 59 additions and 30 deletions

View file

@ -1,12 +1,13 @@
import typing as t
from functools import partial
from contextvars import ContextVar
from werkzeug.local import LocalProxy
from werkzeug.local import LocalStack
if t.TYPE_CHECKING: # pragma: no cover
from .app import Flask
from .ctx import _AppCtxGlobals
from .ctx import AppContext
from .ctx import RequestContext
from .sessions import SessionMixin
from .wrappers import Request
@ -26,34 +27,61 @@ this, set up an application context with app.app_context(). See the
documentation for more information.\
"""
def _lookup_req_object(name):
top = _request_ctx_stack.top
if top is None:
raise RuntimeError(_request_ctx_err_msg)
return getattr(top, name)
_T = t.TypeVar("_T")
def _lookup_app_object(name):
top = _app_ctx_stack.top
if top is None:
raise RuntimeError(_app_ctx_err_msg)
return getattr(top, name)
class CtxStack(t.Generic[_T]):
def __init__(self, var: ContextVar[t.List[_T]], error: str) -> None:
self.var = var
self.error = error
def push(self, ctx: _T) -> t.List[_T]:
stack = self.var.get(None)
if stack is None:
stack = []
self.var.set(stack)
stack.append(ctx)
return stack
def pop(self) -> t.Optional[_T]:
stack = self.var.get(None)
if stack is None or len(stack) == 0:
return None
return stack.pop()
@property
def top(self) -> _T:
stack = self.var.get(None)
if stack is None or len(stack) == 0:
return None
return stack[-1]
def require(self) -> _T:
top = self.top
if top is None:
raise RuntimeError(self.error)
return top
def _find_app():
top = _app_ctx_stack.top
if top is None:
raise RuntimeError(_app_ctx_err_msg)
return top.app
# context locals
_request_ctx_stack = LocalStack()
_app_ctx_stack = LocalStack()
current_app: "Flask" = LocalProxy(_find_app) # type: ignore
request: "Request" = LocalProxy(partial(_lookup_req_object, "request")) # type: ignore
session: "SessionMixin" = LocalProxy( # type: ignore
partial(_lookup_req_object, "session")
_app_var: ContextVar[t.List["AppContext"]] = ContextVar("_app_var")
_app_ctx_stack: CtxStack["AppContext"] = CtxStack(_app_var, _app_ctx_err_msg)
current_app: "Flask" = LocalProxy(lambda: _app_ctx_stack.require().app) # type: ignore
g: "_AppCtxGlobals" = LocalProxy(lambda: _app_ctx_stack.require().g) # type: ignore
_req_var: ContextVar[t.List["RequestContext"]] = ContextVar("_req_var")
_request_ctx_stack: CtxStack["RequestContext"] = CtxStack(
_req_var, _request_ctx_err_msg
)
request: "Request" = LocalProxy( # type: ignore
lambda: _request_ctx_stack.require().request
)
session: "SessionMixin" = LocalProxy( # type: ignore
lambda: _request_ctx_stack.require().session
)
g: "_AppCtxGlobals" = LocalProxy(partial(_lookup_app_object, "g")) # type: ignore

View file

@ -323,12 +323,13 @@ def test_app_cli_has_app_context(app, runner):
# the loaded app should be the same as current_app
same_app = current_app._get_current_object() is app
# only one app context should be pushed
stack_size = len(_app_ctx_stack._local.stack)
stack = _app_ctx_stack.var.get(None)
stack_size = stack is not None and len(stack) == 1
return same_app, stack_size, value
cli = FlaskGroup(create_app=lambda: app)
result = runner.invoke(cli, ["check", "x"], standalone_mode=False)
assert result.return_value == (True, 1, True)
assert result.return_value == (True, True, True)
def test_with_appcontext(runner):