wrap ContextVar more directly instead of Werkzeug LocalStack
This commit is contained in:
parent
ab36542260
commit
af0feaa116
2 changed files with 59 additions and 30 deletions
|
|
@ -1,12 +1,13 @@
|
|||
import typing as t
|
||||
from functools import partial
|
||||
from contextvars import ContextVar
|
||||
|
||||
from werkzeug.local import LocalProxy
|
||||
from werkzeug.local import LocalStack
|
||||
|
||||
if t.TYPE_CHECKING: # pragma: no cover
|
||||
from .app import Flask
|
||||
from .ctx import _AppCtxGlobals
|
||||
from .ctx import AppContext
|
||||
from .ctx import RequestContext
|
||||
from .sessions import SessionMixin
|
||||
from .wrappers import Request
|
||||
|
||||
|
|
@ -26,34 +27,61 @@ this, set up an application context with app.app_context(). See the
|
|||
documentation for more information.\
|
||||
"""
|
||||
|
||||
|
||||
def _lookup_req_object(name):
|
||||
top = _request_ctx_stack.top
|
||||
if top is None:
|
||||
raise RuntimeError(_request_ctx_err_msg)
|
||||
return getattr(top, name)
|
||||
_T = t.TypeVar("_T")
|
||||
|
||||
|
||||
def _lookup_app_object(name):
|
||||
top = _app_ctx_stack.top
|
||||
if top is None:
|
||||
raise RuntimeError(_app_ctx_err_msg)
|
||||
return getattr(top, name)
|
||||
class CtxStack(t.Generic[_T]):
|
||||
def __init__(self, var: ContextVar[t.List[_T]], error: str) -> None:
|
||||
self.var = var
|
||||
self.error = error
|
||||
|
||||
def push(self, ctx: _T) -> t.List[_T]:
|
||||
stack = self.var.get(None)
|
||||
|
||||
if stack is None:
|
||||
stack = []
|
||||
self.var.set(stack)
|
||||
|
||||
stack.append(ctx)
|
||||
return stack
|
||||
|
||||
def pop(self) -> t.Optional[_T]:
|
||||
stack = self.var.get(None)
|
||||
|
||||
if stack is None or len(stack) == 0:
|
||||
return None
|
||||
|
||||
return stack.pop()
|
||||
|
||||
@property
|
||||
def top(self) -> _T:
|
||||
stack = self.var.get(None)
|
||||
|
||||
if stack is None or len(stack) == 0:
|
||||
return None
|
||||
|
||||
return stack[-1]
|
||||
|
||||
def require(self) -> _T:
|
||||
top = self.top
|
||||
|
||||
if top is None:
|
||||
raise RuntimeError(self.error)
|
||||
|
||||
return top
|
||||
|
||||
|
||||
def _find_app():
|
||||
top = _app_ctx_stack.top
|
||||
if top is None:
|
||||
raise RuntimeError(_app_ctx_err_msg)
|
||||
return top.app
|
||||
|
||||
|
||||
# context locals
|
||||
_request_ctx_stack = LocalStack()
|
||||
_app_ctx_stack = LocalStack()
|
||||
current_app: "Flask" = LocalProxy(_find_app) # type: ignore
|
||||
request: "Request" = LocalProxy(partial(_lookup_req_object, "request")) # type: ignore
|
||||
session: "SessionMixin" = LocalProxy( # type: ignore
|
||||
partial(_lookup_req_object, "session")
|
||||
_app_var: ContextVar[t.List["AppContext"]] = ContextVar("_app_var")
|
||||
_app_ctx_stack: CtxStack["AppContext"] = CtxStack(_app_var, _app_ctx_err_msg)
|
||||
current_app: "Flask" = LocalProxy(lambda: _app_ctx_stack.require().app) # type: ignore
|
||||
g: "_AppCtxGlobals" = LocalProxy(lambda: _app_ctx_stack.require().g) # type: ignore
|
||||
_req_var: ContextVar[t.List["RequestContext"]] = ContextVar("_req_var")
|
||||
_request_ctx_stack: CtxStack["RequestContext"] = CtxStack(
|
||||
_req_var, _request_ctx_err_msg
|
||||
)
|
||||
request: "Request" = LocalProxy( # type: ignore
|
||||
lambda: _request_ctx_stack.require().request
|
||||
)
|
||||
session: "SessionMixin" = LocalProxy( # type: ignore
|
||||
lambda: _request_ctx_stack.require().session
|
||||
)
|
||||
g: "_AppCtxGlobals" = LocalProxy(partial(_lookup_app_object, "g")) # type: ignore
|
||||
|
|
|
|||
|
|
@ -323,12 +323,13 @@ def test_app_cli_has_app_context(app, runner):
|
|||
# the loaded app should be the same as current_app
|
||||
same_app = current_app._get_current_object() is app
|
||||
# only one app context should be pushed
|
||||
stack_size = len(_app_ctx_stack._local.stack)
|
||||
stack = _app_ctx_stack.var.get(None)
|
||||
stack_size = stack is not None and len(stack) == 1
|
||||
return same_app, stack_size, value
|
||||
|
||||
cli = FlaskGroup(create_app=lambda: app)
|
||||
result = runner.invoke(cli, ["check", "x"], standalone_mode=False)
|
||||
assert result.return_value == (True, 1, True)
|
||||
assert result.return_value == (True, True, True)
|
||||
|
||||
|
||||
def test_with_appcontext(runner):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue