pass context through dispatch methods

This commit is contained in:
David Lord 2025-09-19 16:58:48 -07:00
parent adf363679d
commit 6a64969009
No known key found for this signature in database
GPG key ID: 43368A7AA8CC5926
6 changed files with 167 additions and 72 deletions

View file

@ -9,6 +9,11 @@ Unreleased
a deprecated alias. If an app context is already pushed, it is not reused a deprecated alias. If an app context is already pushed, it is not reused
when dispatching a request. This greatly simplifies the internal code for tracking when dispatching a request. This greatly simplifies the internal code for tracking
the active context. :issue:`5639` the active context. :issue:`5639`
- Many ``Flask`` methods involved in request dispatch now take the current
``AppContext`` as the first parameter, instead of using the proxy objects.
If subclasses were overriding these methods, the old signature is detected,
shows a deprecation warning, and will continue to work during the
deprecation period. :issue:`5815`
- ``template_filter``, ``template_test``, and ``template_global`` decorators - ``template_filter``, ``template_test``, and ``template_global`` decorators
can be used without parentheses. :issue:`5729` can be used without parentheses. :issue:`5729`

View file

@ -1,11 +1,13 @@
from __future__ import annotations from __future__ import annotations
import collections.abc as cabc import collections.abc as cabc
import inspect
import os import os
import sys import sys
import typing as t import typing as t
import weakref import weakref
from datetime import timedelta from datetime import timedelta
from functools import update_wrapper
from inspect import iscoroutinefunction from inspect import iscoroutinefunction
from itertools import chain from itertools import chain
from types import TracebackType from types import TracebackType
@ -30,6 +32,7 @@ from . import cli
from . import typing as ft from . import typing as ft
from .ctx import AppContext from .ctx import AppContext
from .globals import _cv_app from .globals import _cv_app
from .globals import app_ctx
from .globals import g from .globals import g
from .globals import request from .globals import request
from .globals import session from .globals import session
@ -73,6 +76,35 @@ def _make_timedelta(value: timedelta | int | None) -> timedelta | None:
return timedelta(seconds=value) return timedelta(seconds=value)
F = t.TypeVar("F", bound=t.Callable[..., t.Any])
# Other methods may call the overridden method with the new ctx arg. Remove it
# and call the method with the remaining args.
def remove_ctx(f: F) -> F:
def wrapper(self: Flask, *args: t.Any, **kwargs: t.Any) -> t.Any:
if args and isinstance(args[0], AppContext):
args = args[1:]
return f(self, *args, **kwargs)
return update_wrapper(wrapper, f) # type: ignore[return-value]
# The overridden method may call super().base_method without the new ctx arg.
# Add it to the args for the call.
def add_ctx(f: F) -> F:
def wrapper(self: Flask, *args: t.Any, **kwargs: t.Any) -> t.Any:
if not args:
args = (app_ctx._get_current_object(),)
elif not isinstance(args[0], AppContext):
args = (app_ctx._get_current_object(), *args)
return f(self, *args, **kwargs)
return update_wrapper(wrapper, f) # type: ignore[return-value]
class Flask(App): class Flask(App):
"""The flask object implements a WSGI application and acts as the central """The flask object implements a WSGI application and acts as the central
object. It is passed the name of the module or package of the object. It is passed the name of the module or package of the
@ -218,6 +250,62 @@ class Flask(App):
#: .. versionadded:: 0.8 #: .. versionadded:: 0.8
session_interface: SessionInterface = SecureCookieSessionInterface() session_interface: SessionInterface = SecureCookieSessionInterface()
def __init_subclass__(cls, **kwargs: t.Any) -> None:
import warnings
# These method signatures were updated to take a ctx param. Detect
# overridden methods in subclasses that still have the old signature.
# Show a deprecation warning and wrap to call with correct args.
for method in (
cls.handle_http_exception,
cls.handle_user_exception,
cls.handle_exception,
cls.log_exception,
cls.dispatch_request,
cls.full_dispatch_request,
cls.finalize_request,
cls.make_default_options_response,
cls.preprocess_request,
cls.process_response,
cls.do_teardown_request,
cls.do_teardown_appcontext,
):
base_method = getattr(Flask, method.__name__)
if method is base_method:
# not overridden
continue
# get the second parameter (first is self)
iter_params = iter(inspect.signature(method).parameters.values())
next(iter_params)
param = next(iter_params, None)
# must have second parameter named ctx or annotated AppContext
if param is None or not (
# no annotation, match name
(param.annotation is inspect.Parameter.empty and param.name == "ctx")
or (
# string annotation, access path ends with AppContext
isinstance(param.annotation, str)
and param.annotation.rpartition(".")[2] == "AppContext"
)
or (
# class annotation
inspect.isclass(param.annotation)
and issubclass(param.annotation, AppContext)
)
):
warnings.warn(
f"The '{method.__name__}' method now takes 'ctx: AppContext'"
" as the first parameter. The old signature is deprecated"
" and will not be supported in Flask 4.0.",
DeprecationWarning,
stacklevel=2,
)
setattr(cls, method.__name__, remove_ctx(method))
setattr(Flask, method.__name__, add_ctx(base_method))
def __init__( def __init__(
self, self,
import_name: str, import_name: str,
@ -498,7 +586,9 @@ class Flask(App):
raise FormDataRoutingRedirect(request) raise FormDataRoutingRedirect(request)
def update_template_context(self, context: dict[str, t.Any]) -> None: def update_template_context(
self, ctx: AppContext, context: dict[str, t.Any]
) -> None:
"""Update the template context with some commonly used variables. """Update the template context with some commonly used variables.
This injects request, session, config and g into the template This injects request, session, config and g into the template
context as well as everything template context processors want context as well as everything template context processors want
@ -512,7 +602,7 @@ class Flask(App):
names: t.Iterable[str | None] = (None,) names: t.Iterable[str | None] = (None,)
# A template may be rendered outside a request context. # A template may be rendered outside a request context.
if (ctx := _cv_app.get(None)) is not None and ctx.has_request: if ctx.has_request:
names = chain(names, reversed(ctx.request.blueprints)) names = chain(names, reversed(ctx.request.blueprints))
# The values passed to render_template take precedence. Keep a # The values passed to render_template take precedence. Keep a
@ -737,7 +827,7 @@ class Flask(App):
return cls(self, **kwargs) # type: ignore return cls(self, **kwargs) # type: ignore
def handle_http_exception( def handle_http_exception(
self, e: HTTPException self, ctx: AppContext, e: HTTPException
) -> HTTPException | ft.ResponseReturnValue: ) -> HTTPException | ft.ResponseReturnValue:
"""Handles an HTTP exception. By default this will invoke the """Handles an HTTP exception. By default this will invoke the
registered error handlers and fall back to returning the registered error handlers and fall back to returning the
@ -766,13 +856,13 @@ class Flask(App):
if isinstance(e, RoutingException): if isinstance(e, RoutingException):
return e return e
handler = self._find_error_handler(e, request.blueprints) handler = self._find_error_handler(e, ctx.request.blueprints)
if handler is None: if handler is None:
return e return e
return self.ensure_sync(handler)(e) # type: ignore[no-any-return] return self.ensure_sync(handler)(e) # type: ignore[no-any-return]
def handle_user_exception( def handle_user_exception(
self, e: Exception self, ctx: AppContext, e: Exception
) -> HTTPException | ft.ResponseReturnValue: ) -> HTTPException | ft.ResponseReturnValue:
"""This method is called whenever an exception occurs that """This method is called whenever an exception occurs that
should be handled. A special case is :class:`~werkzeug should be handled. A special case is :class:`~werkzeug
@ -794,16 +884,16 @@ class Flask(App):
e.show_exception = True e.show_exception = True
if isinstance(e, HTTPException) and not self.trap_http_exception(e): if isinstance(e, HTTPException) and not self.trap_http_exception(e):
return self.handle_http_exception(e) return self.handle_http_exception(ctx, e)
handler = self._find_error_handler(e, request.blueprints) handler = self._find_error_handler(e, ctx.request.blueprints)
if handler is None: if handler is None:
raise raise
return self.ensure_sync(handler)(e) # type: ignore[no-any-return] return self.ensure_sync(handler)(e) # type: ignore[no-any-return]
def handle_exception(self, e: Exception) -> Response: def handle_exception(self, ctx: AppContext, e: Exception) -> Response:
"""Handle an exception that did not have an error handler """Handle an exception that did not have an error handler
associated with it, or that was raised from an error handler. associated with it, or that was raised from an error handler.
This always causes a 500 ``InternalServerError``. This always causes a 500 ``InternalServerError``.
@ -846,19 +936,20 @@ class Flask(App):
raise e raise e
self.log_exception(exc_info) self.log_exception(ctx, exc_info)
server_error: InternalServerError | ft.ResponseReturnValue server_error: InternalServerError | ft.ResponseReturnValue
server_error = InternalServerError(original_exception=e) server_error = InternalServerError(original_exception=e)
handler = self._find_error_handler(server_error, request.blueprints) handler = self._find_error_handler(server_error, ctx.request.blueprints)
if handler is not None: if handler is not None:
server_error = self.ensure_sync(handler)(server_error) server_error = self.ensure_sync(handler)(server_error)
return self.finalize_request(server_error, from_error_handler=True) return self.finalize_request(ctx, server_error, from_error_handler=True)
def log_exception( def log_exception(
self, self,
exc_info: (tuple[type, BaseException, TracebackType] | tuple[None, None, None]), ctx: AppContext,
exc_info: tuple[type, BaseException, TracebackType] | tuple[None, None, None],
) -> None: ) -> None:
"""Logs an exception. This is called by :meth:`handle_exception` """Logs an exception. This is called by :meth:`handle_exception`
if debugging is disabled and right before the handler is called. if debugging is disabled and right before the handler is called.
@ -868,10 +959,10 @@ class Flask(App):
.. versionadded:: 0.8 .. versionadded:: 0.8
""" """
self.logger.error( self.logger.error(
f"Exception on {request.path} [{request.method}]", exc_info=exc_info f"Exception on {ctx.request.path} [{ctx.request.method}]", exc_info=exc_info
) )
def dispatch_request(self) -> ft.ResponseReturnValue: def dispatch_request(self, ctx: AppContext) -> ft.ResponseReturnValue:
"""Does the request dispatching. Matches the URL and returns the """Does the request dispatching. Matches the URL and returns the
return value of the view or error handler. This does not have to return value of the view or error handler. This does not have to
be a response object. In order to convert the return value to a be a response object. In order to convert the return value to a
@ -881,7 +972,7 @@ class Flask(App):
This no longer does the exception handling, this code was This no longer does the exception handling, this code was
moved to the new :meth:`full_dispatch_request`. moved to the new :meth:`full_dispatch_request`.
""" """
req = _cv_app.get().request req = ctx.request
if req.routing_exception is not None: if req.routing_exception is not None:
self.raise_routing_exception(req) self.raise_routing_exception(req)
@ -892,12 +983,12 @@ class Flask(App):
getattr(rule, "provide_automatic_options", False) getattr(rule, "provide_automatic_options", False)
and req.method == "OPTIONS" and req.method == "OPTIONS"
): ):
return self.make_default_options_response() return self.make_default_options_response(ctx)
# otherwise dispatch to the handler for that endpoint # otherwise dispatch to the handler for that endpoint
view_args: dict[str, t.Any] = req.view_args # type: ignore[assignment] view_args: dict[str, t.Any] = req.view_args # type: ignore[assignment]
return self.ensure_sync(self.view_functions[rule.endpoint])(**view_args) # type: ignore[no-any-return] return self.ensure_sync(self.view_functions[rule.endpoint])(**view_args) # type: ignore[no-any-return]
def full_dispatch_request(self) -> Response: def full_dispatch_request(self, ctx: AppContext) -> Response:
"""Dispatches the request and on top of that performs request """Dispatches the request and on top of that performs request
pre and postprocessing as well as HTTP exception catching and pre and postprocessing as well as HTTP exception catching and
error handling. error handling.
@ -908,15 +999,16 @@ class Flask(App):
try: try:
request_started.send(self, _async_wrapper=self.ensure_sync) request_started.send(self, _async_wrapper=self.ensure_sync)
rv = self.preprocess_request() rv = self.preprocess_request(ctx)
if rv is None: if rv is None:
rv = self.dispatch_request() rv = self.dispatch_request(ctx)
except Exception as e: except Exception as e:
rv = self.handle_user_exception(e) rv = self.handle_user_exception(ctx, e)
return self.finalize_request(rv) return self.finalize_request(ctx, rv)
def finalize_request( def finalize_request(
self, self,
ctx: AppContext,
rv: ft.ResponseReturnValue | HTTPException, rv: ft.ResponseReturnValue | HTTPException,
from_error_handler: bool = False, from_error_handler: bool = False,
) -> Response: ) -> Response:
@ -934,7 +1026,7 @@ class Flask(App):
""" """
response = self.make_response(rv) response = self.make_response(rv)
try: try:
response = self.process_response(response) response = self.process_response(ctx, response)
request_finished.send( request_finished.send(
self, _async_wrapper=self.ensure_sync, response=response self, _async_wrapper=self.ensure_sync, response=response
) )
@ -946,15 +1038,14 @@ class Flask(App):
) )
return response return response
def make_default_options_response(self) -> Response: def make_default_options_response(self, ctx: AppContext) -> Response:
"""This method is called to create the default ``OPTIONS`` response. """This method is called to create the default ``OPTIONS`` response.
This can be changed through subclassing to change the default This can be changed through subclassing to change the default
behavior of ``OPTIONS`` responses. behavior of ``OPTIONS`` responses.
.. versionadded:: 0.7 .. versionadded:: 0.7
""" """
adapter = _cv_app.get().url_adapter methods = ctx.url_adapter.allowed_methods() # type: ignore[union-attr]
methods = adapter.allowed_methods() # type: ignore[union-attr]
rv = self.response_class() rv = self.response_class()
rv.allow.update(methods) rv.allow.update(methods)
return rv return rv
@ -1260,7 +1351,7 @@ class Flask(App):
return rv return rv
def preprocess_request(self) -> ft.ResponseReturnValue | None: def preprocess_request(self, ctx: AppContext) -> ft.ResponseReturnValue | None:
"""Called before the request is dispatched. Calls """Called before the request is dispatched. Calls
:attr:`url_value_preprocessors` registered with the app and the :attr:`url_value_preprocessors` registered with the app and the
current blueprint (if any). Then calls :attr:`before_request_funcs` current blueprint (if any). Then calls :attr:`before_request_funcs`
@ -1270,7 +1361,7 @@ class Flask(App):
value is handled as if it was the return value from the view, and value is handled as if it was the return value from the view, and
further request handling is stopped. further request handling is stopped.
""" """
req = _cv_app.get().request req = ctx.request
names = (None, *reversed(req.blueprints)) names = (None, *reversed(req.blueprints))
for name in names: for name in names:
@ -1288,7 +1379,7 @@ class Flask(App):
return None return None
def process_response(self, response: Response) -> Response: def process_response(self, ctx: AppContext, response: Response) -> Response:
"""Can be overridden in order to modify the response object """Can be overridden in order to modify the response object
before it's sent to the WSGI server. By default this will before it's sent to the WSGI server. By default this will
call all the :meth:`after_request` decorated functions. call all the :meth:`after_request` decorated functions.
@ -1301,8 +1392,6 @@ class Flask(App):
:return: a new response object or the same, has to be an :return: a new response object or the same, has to be an
instance of :attr:`response_class`. instance of :attr:`response_class`.
""" """
ctx = _cv_app.get()
for func in ctx._after_request_functions: for func in ctx._after_request_functions:
response = self.ensure_sync(func)(response) response = self.ensure_sync(func)(response)
@ -1316,7 +1405,9 @@ class Flask(App):
return response return response
def do_teardown_request(self, exc: BaseException | None = None) -> None: def do_teardown_request(
self, ctx: AppContext, exc: BaseException | None = None
) -> None:
"""Called after the request is dispatched and the response is finalized, """Called after the request is dispatched and the response is finalized,
right before the request context is popped. Called by right before the request context is popped. Called by
:meth:`.AppContext.pop`. :meth:`.AppContext.pop`.
@ -1331,16 +1422,16 @@ class Flask(App):
.. versionchanged:: 0.9 .. versionchanged:: 0.9
Added the ``exc`` argument. Added the ``exc`` argument.
""" """
req = _cv_app.get().request for name in chain(ctx.request.blueprints, (None,)):
for name in chain(req.blueprints, (None,)):
if name in self.teardown_request_funcs: if name in self.teardown_request_funcs:
for func in reversed(self.teardown_request_funcs[name]): for func in reversed(self.teardown_request_funcs[name]):
self.ensure_sync(func)(exc) self.ensure_sync(func)(exc)
request_tearing_down.send(self, _async_wrapper=self.ensure_sync, exc=exc) request_tearing_down.send(self, _async_wrapper=self.ensure_sync, exc=exc)
def do_teardown_appcontext(self, exc: BaseException | None = None) -> None: def do_teardown_appcontext(
self, ctx: AppContext, exc: BaseException | None = None
) -> None:
"""Called right before the application context is popped. Called by """Called right before the application context is popped. Called by
:meth:`.AppContext.pop`. :meth:`.AppContext.pop`.
@ -1473,17 +1564,17 @@ class Flask(App):
try: try:
try: try:
ctx.push() ctx.push()
response = self.full_dispatch_request() response = self.full_dispatch_request(ctx)
except Exception as e: except Exception as e:
error = e error = e
response = self.handle_exception(e) response = self.handle_exception(ctx, e)
except: # noqa: B001 except: # noqa: B001
error = sys.exc_info()[1] error = sys.exc_info()[1]
raise raise
return response(environ, start_response) return response(environ, start_response)
finally: finally:
if "werkzeug.debug.preserve_context" in environ: if "werkzeug.debug.preserve_context" in environ:
environ["werkzeug.debug.preserve_context"](_cv_app.get()) environ["werkzeug.debug.preserve_context"](ctx)
if error is not None and self.should_ignore_error(error): if error is not None and self.should_ignore_error(error):
error = None error = None

View file

@ -471,10 +471,10 @@ class AppContext:
try: try:
if self._request is not None: if self._request is not None:
self.app.do_teardown_request(exc) self.app.do_teardown_request(self, exc)
self._request.close() self._request.close()
finally: finally:
self.app.do_teardown_appcontext(exc) self.app.do_teardown_appcontext(self, exc)
_cv_app.reset(self._cv_token) _cv_app.reset(self._cv_token)
self._cv_token = None self._cv_token = None
appcontext_popped.send(self.app, _async_wrapper=self.app.ensure_sync) appcontext_popped.send(self.app, _async_wrapper=self.app.ensure_sync)

View file

@ -7,14 +7,13 @@ from jinja2 import Environment as BaseEnvironment
from jinja2 import Template from jinja2 import Template
from jinja2 import TemplateNotFound from jinja2 import TemplateNotFound
from .globals import _cv_app from .ctx import AppContext
from .globals import current_app from .globals import app_ctx
from .helpers import stream_with_context from .helpers import stream_with_context
from .signals import before_render_template from .signals import before_render_template
from .signals import template_rendered from .signals import template_rendered
if t.TYPE_CHECKING: # pragma: no cover if t.TYPE_CHECKING: # pragma: no cover
from .app import Flask
from .sansio.app import App from .sansio.app import App
from .sansio.scaffold import Scaffold from .sansio.scaffold import Scaffold
@ -23,15 +22,12 @@ def _default_template_ctx_processor() -> dict[str, t.Any]:
"""Default template context processor. Injects `request`, """Default template context processor. Injects `request`,
`session` and `g`. `session` and `g`.
""" """
ctx = _cv_app.get(None) ctx = app_ctx._get_current_object()
rv: dict[str, t.Any] = {} rv: dict[str, t.Any] = {"g": ctx.g}
if ctx is not None: if ctx.has_request:
rv["g"] = ctx.g rv["request"] = ctx.request
rv["session"] = ctx.session
if ctx.has_request:
rv["request"] = ctx.request
rv["session"] = ctx.session
return rv return rv
@ -123,8 +119,9 @@ class DispatchingJinjaLoader(BaseLoader):
return list(result) return list(result)
def _render(app: Flask, template: Template, context: dict[str, t.Any]) -> str: def _render(ctx: AppContext, template: Template, context: dict[str, t.Any]) -> str:
app.update_template_context(context) app = ctx.app
app.update_template_context(ctx, context)
before_render_template.send( before_render_template.send(
app, _async_wrapper=app.ensure_sync, template=template, context=context app, _async_wrapper=app.ensure_sync, template=template, context=context
) )
@ -145,9 +142,9 @@ def render_template(
a list is given, the first name to exist will be rendered. a list is given, the first name to exist will be rendered.
:param context: The variables to make available in the template. :param context: The variables to make available in the template.
""" """
app = current_app._get_current_object() ctx = app_ctx._get_current_object()
template = app.jinja_env.get_or_select_template(template_name_or_list) template = ctx.app.jinja_env.get_or_select_template(template_name_or_list)
return _render(app, template, context) return _render(ctx, template, context)
def render_template_string(source: str, **context: t.Any) -> str: def render_template_string(source: str, **context: t.Any) -> str:
@ -157,15 +154,16 @@ def render_template_string(source: str, **context: t.Any) -> str:
:param source: The source code of the template to render. :param source: The source code of the template to render.
:param context: The variables to make available in the template. :param context: The variables to make available in the template.
""" """
app = current_app._get_current_object() ctx = app_ctx._get_current_object()
template = app.jinja_env.from_string(source) template = ctx.app.jinja_env.from_string(source)
return _render(app, template, context) return _render(ctx, template, context)
def _stream( def _stream(
app: Flask, template: Template, context: dict[str, t.Any] ctx: AppContext, template: Template, context: dict[str, t.Any]
) -> t.Iterator[str]: ) -> t.Iterator[str]:
app.update_template_context(context) app = ctx.app
app.update_template_context(ctx, context)
before_render_template.send( before_render_template.send(
app, _async_wrapper=app.ensure_sync, template=template, context=context app, _async_wrapper=app.ensure_sync, template=template, context=context
) )
@ -193,9 +191,9 @@ def stream_template(
.. versionadded:: 2.2 .. versionadded:: 2.2
""" """
app = current_app._get_current_object() ctx = app_ctx._get_current_object()
template = app.jinja_env.get_or_select_template(template_name_or_list) template = ctx.app.jinja_env.get_or_select_template(template_name_or_list)
return _stream(app, template, context) return _stream(ctx, template, context)
def stream_template_string(source: str, **context: t.Any) -> t.Iterator[str]: def stream_template_string(source: str, **context: t.Any) -> t.Iterator[str]:
@ -208,6 +206,6 @@ def stream_template_string(source: str, **context: t.Any) -> t.Iterator[str]:
.. versionadded:: 2.2 .. versionadded:: 2.2
""" """
app = current_app._get_current_object() ctx = app_ctx._get_current_object()
template = app.jinja_env.from_string(source) template = ctx.app.jinja_env.from_string(source)
return _stream(app, template, context) return _stream(ctx, template, context)

View file

@ -288,8 +288,9 @@ def test_bad_environ_raises_bad_request():
# use a non-printable character in the Host - this is key to this test # use a non-printable character in the Host - this is key to this test
environ["HTTP_HOST"] = "\x8a" environ["HTTP_HOST"] = "\x8a"
with app.request_context(environ): with app.request_context(environ) as ctx:
response = app.full_dispatch_request() response = app.full_dispatch_request(ctx)
assert response.status_code == 400 assert response.status_code == 400
@ -308,8 +309,8 @@ def test_environ_for_valid_idna_completes():
# these characters are all IDNA-compatible # these characters are all IDNA-compatible
environ["HTTP_HOST"] = "ąśźäüжŠßя.com" environ["HTTP_HOST"] = "ąśźäüжŠßя.com"
with app.request_context(environ): with app.request_context(environ) as ctx:
response = app.full_dispatch_request() response = app.full_dispatch_request(ctx)
assert response.status_code == 200 assert response.status_code == 200

View file

@ -5,7 +5,7 @@ import flask
def test_suppressed_exception_logging(): def test_suppressed_exception_logging():
class SuppressedFlask(flask.Flask): class SuppressedFlask(flask.Flask):
def log_exception(self, exc_info): def log_exception(self, ctx, exc_info):
pass pass
out = StringIO() out = StringIO()