all teardown callbacks are called despite errors

This commit is contained in:
David Lord 2026-02-19 19:41:50 -08:00
parent 7b0088693e
commit fbb6f0bc4c
No known key found for this signature in database
GPG key ID: 43368A7AA8CC5926
10 changed files with 195 additions and 81 deletions

View file

@ -36,6 +36,7 @@ from .globals import app_ctx
from .globals import g
from .globals import request
from .globals import session
from .helpers import _CollectErrors
from .helpers import get_debug_flag
from .helpers import get_flashed_messages
from .helpers import get_load_dotenv
@ -1430,15 +1431,24 @@ class Flask(App):
:param exc: An unhandled exception raised while dispatching the request.
Passed to each teardown function.
.. versionchanged:: 3.2
All callbacks are called rather than stopping on the first error.
.. versionchanged:: 0.9
Added the ``exc`` argument.
"""
collect_errors = _CollectErrors()
for name in chain(ctx.request.blueprints, (None,)):
if name in self.teardown_request_funcs:
for func in reversed(self.teardown_request_funcs[name]):
self.ensure_sync(func)(exc)
with collect_errors:
self.ensure_sync(func)(exc)
request_tearing_down.send(self, _async_wrapper=self.ensure_sync, exc=exc)
with collect_errors:
request_tearing_down.send(self, _async_wrapper=self.ensure_sync, exc=exc)
collect_errors.raise_any("Errors during request teardown")
def do_teardown_appcontext(
self, ctx: AppContext, exc: BaseException | None = None
@ -1452,12 +1462,21 @@ class Flask(App):
:param exc: An unhandled exception raised while the context was active.
Passed to each teardown function.
.. versionchanged:: 3.2
All callbacks are called rather than stopping on the first error.
.. versionadded:: 0.9
"""
for func in reversed(self.teardown_appcontext_funcs):
self.ensure_sync(func)(exc)
collect_errors = _CollectErrors()
appcontext_tearing_down.send(self, _async_wrapper=self.ensure_sync, exc=exc)
for func in reversed(self.teardown_appcontext_funcs):
with collect_errors:
self.ensure_sync(func)(exc)
with collect_errors:
appcontext_tearing_down.send(self, _async_wrapper=self.ensure_sync, exc=exc)
collect_errors.raise_any("Errors during app teardown")
def app_context(self) -> AppContext:
"""Create an :class:`.AppContext`. When the context is pushed,

View file

@ -10,6 +10,7 @@ from werkzeug.routing import MapAdapter
from . import typing as ft
from .globals import _cv_app
from .helpers import _CollectErrors
from .signals import appcontext_popped
from .signals import appcontext_pushed
@ -482,16 +483,26 @@ class AppContext:
if self._push_count > 0:
return
try:
if self._request is not None:
collect_errors = _CollectErrors()
if self._request is not None:
with collect_errors:
self.app.do_teardown_request(self, exc)
with collect_errors:
self._request.close()
finally:
with collect_errors:
self.app.do_teardown_appcontext(self, exc)
_cv_app.reset(self._cv_token)
self._cv_token = None
_cv_app.reset(self._cv_token)
self._cv_token = None
with collect_errors:
appcontext_popped.send(self.app, _async_wrapper=self.app.ensure_sync)
collect_errors.raise_any("Errors during context teardown")
def __enter__(self) -> te.Self:
self.push()
return self

View file

@ -7,6 +7,7 @@ import typing as t
from datetime import datetime
from functools import cache
from functools import update_wrapper
from types import TracebackType
import werkzeug.utils
from werkzeug.exceptions import abort as _wz_abort
@ -636,3 +637,34 @@ def _split_blueprint_path(name: str) -> list[str]:
out.extend(_split_blueprint_path(name.rpartition(".")[0]))
return out
class _CollectErrors:
"""A context manager that records and silences an error raised within it.
Used to run all teardown functions, then raise any errors afterward.
"""
def __init__(self) -> None:
self.errors: list[BaseException] = []
def __enter__(self) -> None:
pass
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> bool:
if exc_val is not None:
self.errors.append(exc_val)
return True
def raise_any(self, message: str) -> None:
"""Raise if any errors were collected."""
if self.errors:
if sys.version_info >= (3, 11):
raise BaseExceptionGroup(message, self.errors) # noqa: F821
else:
raise self.errors[0]