Merge pull request #4230 from matipau/blueprint-fix

Fix callback order for nested blueprints
This commit is contained in:
David Lord 2021-10-03 20:38:28 -07:00 committed by GitHub
commit 6d637f0fdb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 140 additions and 60 deletions

View file

@ -21,6 +21,9 @@ Unreleased
:issue:`4096` :issue:`4096`
- The CLI loader handles ``**kwargs`` in a ``create_app`` function. - The CLI loader handles ``**kwargs`` in a ``create_app`` function.
:issue:`4170` :issue:`4170`
- Fix the order of ``before_request`` and other callbacks that trigger
before the view returns. They are called from the app down to the
closest nested blueprint. :issue:`4229`
Version 2.0.1 Version 2.0.1

View file

@ -58,17 +58,12 @@ from .signals import request_started
from .signals import request_tearing_down from .signals import request_tearing_down
from .templating import DispatchingJinjaLoader from .templating import DispatchingJinjaLoader
from .templating import Environment from .templating import Environment
from .typing import AfterRequestCallable
from .typing import BeforeFirstRequestCallable from .typing import BeforeFirstRequestCallable
from .typing import BeforeRequestCallable
from .typing import ResponseReturnValue from .typing import ResponseReturnValue
from .typing import TeardownCallable from .typing import TeardownCallable
from .typing import TemplateContextProcessorCallable
from .typing import TemplateFilterCallable from .typing import TemplateFilterCallable
from .typing import TemplateGlobalCallable from .typing import TemplateGlobalCallable
from .typing import TemplateTestCallable from .typing import TemplateTestCallable
from .typing import URLDefaultCallable
from .typing import URLValuePreprocessorCallable
from .wrappers import Request from .wrappers import Request
from .wrappers import Response from .wrappers import Response
@ -745,20 +740,21 @@ class Flask(Scaffold):
:param context: the context as a dictionary that is updated in place :param context: the context as a dictionary that is updated in place
to add extra variables. to add extra variables.
""" """
funcs: t.Iterable[ names: t.Iterable[t.Optional[str]] = (None,)
TemplateContextProcessorCallable
] = self.template_context_processors[None] # A template may be rendered outside a request context.
reqctx = _request_ctx_stack.top if request:
if reqctx is not None: names = chain(names, reversed(request.blueprints))
for bp in request.blueprints:
if bp in self.template_context_processors: # The values passed to render_template take precedence. Keep a
funcs = chain(funcs, self.template_context_processors[bp]) # copy to re-apply after all context functions.
orig_ctx = context.copy() orig_ctx = context.copy()
for func in funcs:
context.update(func()) for name in names:
# make sure the original values win. This makes it possible to if name in self.template_context_processors:
# easier add new variables in context processors without breaking for func in self.template_context_processors[name]:
# existing views. context.update(func())
context.update(orig_ctx) context.update(orig_ctx)
def make_shell_context(self) -> dict: def make_shell_context(self) -> dict:
@ -1278,9 +1274,10 @@ class Flask(Scaffold):
class, or ``None`` if a suitable handler is not found. class, or ``None`` if a suitable handler is not found.
""" """
exc_class, code = self._get_exc_class_and_code(type(e)) exc_class, code = self._get_exc_class_and_code(type(e))
names = (*request.blueprints, None)
for c in [code, None] if code is not None else [None]: for c in (code, None) if code is not None else (None,):
for name in chain(request.blueprints, [None]): for name in names:
handler_map = self.error_handler_spec[name][c] handler_map = self.error_handler_spec[name][c]
if not handler_map: if not handler_map:
@ -1800,17 +1797,19 @@ class Flask(Scaffold):
.. versionadded:: 0.7 .. versionadded:: 0.7
""" """
funcs: t.Iterable[URLDefaultCallable] = self.url_default_functions[None] names: t.Iterable[t.Optional[str]] = (None,)
# url_for may be called outside a request context, parse the
# passed endpoint instead of using request.blueprints.
if "." in endpoint: if "." in endpoint:
# This is called by url_for, which can be called outside a names = chain(
# request, can't use request.blueprints. names, reversed(_split_blueprint_path(endpoint.rpartition(".")[0]))
bps = _split_blueprint_path(endpoint.rpartition(".")[0]) )
bp_funcs = chain.from_iterable(self.url_default_functions[bp] for bp in bps)
funcs = chain(funcs, bp_funcs)
for func in funcs: for name in names:
func(endpoint, values) if name in self.url_default_functions:
for func in self.url_default_functions[name]:
func(endpoint, values)
def handle_url_build_error( def handle_url_build_error(
self, error: Exception, endpoint: str, values: dict self, error: Exception, endpoint: str, values: dict
@ -1845,24 +1844,20 @@ class Flask(Scaffold):
value is handled as if it was the return value from the view, and value is handled as if it was the return value from the view, and
further request handling is stopped. further request handling is stopped.
""" """
names = (None, *reversed(request.blueprints))
funcs: t.Iterable[URLValuePreprocessorCallable] = self.url_value_preprocessors[ for name in names:
None if name in self.url_value_preprocessors:
] for url_func in self.url_value_preprocessors[name]:
for bp in request.blueprints: url_func(request.endpoint, request.view_args)
if bp in self.url_value_preprocessors:
funcs = chain(funcs, self.url_value_preprocessors[bp])
for func in funcs:
func(request.endpoint, request.view_args)
funcs: t.Iterable[BeforeRequestCallable] = self.before_request_funcs[None] for name in names:
for bp in request.blueprints: if name in self.before_request_funcs:
if bp in self.before_request_funcs: for before_func in self.before_request_funcs[name]:
funcs = chain(funcs, self.before_request_funcs[bp]) rv = self.ensure_sync(before_func)()
for func in funcs:
rv = self.ensure_sync(func)() if rv is not None:
if rv is not None: return rv
return rv
return None return None
@ -1880,16 +1875,18 @@ class Flask(Scaffold):
instance of :attr:`response_class`. instance of :attr:`response_class`.
""" """
ctx = _request_ctx_stack.top ctx = _request_ctx_stack.top
funcs: t.Iterable[AfterRequestCallable] = ctx._after_request_functions
for bp in request.blueprints: for func in ctx._after_request_functions:
if bp in self.after_request_funcs: response = self.ensure_sync(func)(response)
funcs = chain(funcs, reversed(self.after_request_funcs[bp]))
if None in self.after_request_funcs: for name in chain(request.blueprints, (None,)):
funcs = chain(funcs, reversed(self.after_request_funcs[None])) if name in self.after_request_funcs:
for handler in funcs: for func in reversed(self.after_request_funcs[name]):
response = self.ensure_sync(handler)(response) response = self.ensure_sync(func)(response)
if not self.session_interface.is_null_session(ctx.session): if not self.session_interface.is_null_session(ctx.session):
self.session_interface.save_session(self, ctx.session, response) self.session_interface.save_session(self, ctx.session, response)
return response return response
def do_teardown_request( def do_teardown_request(
@ -1917,14 +1914,12 @@ class Flask(Scaffold):
""" """
if exc is _sentinel: if exc is _sentinel:
exc = sys.exc_info()[1] exc = sys.exc_info()[1]
funcs: t.Iterable[TeardownCallable] = reversed(
self.teardown_request_funcs[None] for name in chain(request.blueprints, (None,)):
) if name in self.teardown_request_funcs:
for bp in request.blueprints: for func in reversed(self.teardown_request_funcs[name]):
if bp in self.teardown_request_funcs: self.ensure_sync(func)(exc)
funcs = chain(funcs, reversed(self.teardown_request_funcs[bp]))
for func in funcs:
self.ensure_sync(func)(exc)
request_tearing_down.send(self, exc=exc) request_tearing_down.send(self, exc=exc)
def do_teardown_appcontext( def do_teardown_appcontext(
@ -1946,8 +1941,10 @@ class Flask(Scaffold):
""" """
if exc is _sentinel: if exc is _sentinel:
exc = sys.exc_info()[1] exc = sys.exc_info()[1]
for func in reversed(self.teardown_appcontext_funcs): for func in reversed(self.teardown_appcontext_funcs):
self.ensure_sync(func)(exc) self.ensure_sync(func)(exc)
appcontext_tearing_down.send(self, exc=exc) appcontext_tearing_down.send(self, exc=exc)
def app_context(self) -> AppContext: def app_context(self) -> AppContext:

View file

@ -837,6 +837,86 @@ def test_nested_blueprint(app, client):
assert client.get("/parent/child/grandchild/no").data == b"Grandchild no" assert client.get("/parent/child/grandchild/no").data == b"Grandchild no"
def test_nested_callback_order(app, client):
parent = flask.Blueprint("parent", __name__)
child = flask.Blueprint("child", __name__)
@app.before_request
def app_before1():
flask.g.setdefault("seen", []).append("app_1")
@app.teardown_request
def app_teardown1(e=None):
assert flask.g.seen.pop() == "app_1"
@app.before_request
def app_before2():
flask.g.setdefault("seen", []).append("app_2")
@app.teardown_request
def app_teardown2(e=None):
assert flask.g.seen.pop() == "app_2"
@app.context_processor
def app_ctx():
return dict(key="app")
@parent.before_request
def parent_before1():
flask.g.setdefault("seen", []).append("parent_1")
@parent.teardown_request
def parent_teardown1(e=None):
assert flask.g.seen.pop() == "parent_1"
@parent.before_request
def parent_before2():
flask.g.setdefault("seen", []).append("parent_2")
@parent.teardown_request
def parent_teardown2(e=None):
assert flask.g.seen.pop() == "parent_2"
@parent.context_processor
def parent_ctx():
return dict(key="parent")
@child.before_request
def child_before1():
flask.g.setdefault("seen", []).append("child_1")
@child.teardown_request
def child_teardown1(e=None):
assert flask.g.seen.pop() == "child_1"
@child.before_request
def child_before2():
flask.g.setdefault("seen", []).append("child_2")
@child.teardown_request
def child_teardown2(e=None):
assert flask.g.seen.pop() == "child_2"
@child.context_processor
def child_ctx():
return dict(key="child")
@child.route("/a")
def a():
return ", ".join(flask.g.seen)
@child.route("/b")
def b():
return flask.render_template_string("{{ key }}")
parent.register_blueprint(child)
app.register_blueprint(parent)
assert (
client.get("/a").data == b"app_1, app_2, parent_1, parent_2, child_1, child_2"
)
assert client.get("/b").data == b"child"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"parent_init, child_init, parent_registration, child_registration", "parent_init, child_init, parent_registration, child_registration",
[ [