diff --git a/src/flask/app.py b/src/flask/app.py index 65ec5046..c1743e6e 100644 --- a/src/flask/app.py +++ b/src/flask/app.py @@ -2,6 +2,7 @@ import os import sys import weakref from datetime import timedelta +from inspect import iscoroutinefunction from itertools import chain from threading import Lock @@ -34,6 +35,7 @@ from .helpers import get_env from .helpers import get_flashed_messages from .helpers import get_load_dotenv from .helpers import locked_cached_property +from .helpers import run_async from .helpers import url_for from .json import jsonify from .logging import create_logger @@ -1517,6 +1519,19 @@ class Flask(Scaffold): """ return False + def ensure_sync(self, func): + """Ensure that the returned function is sync and calls the async func. + + .. versionadded:: 2.0 + + Override if you wish to change how asynchronous functions are + run. + """ + if iscoroutinefunction(func): + return run_async(func) + + return func + def make_response(self, rv): """Convert the return value from a view function to an instance of :attr:`response_class`. diff --git a/src/flask/blueprints.py b/src/flask/blueprints.py index 7c314203..c4f94a06 100644 --- a/src/flask/blueprints.py +++ b/src/flask/blueprints.py @@ -1,3 +1,4 @@ +from collections import defaultdict from functools import update_wrapper from .scaffold import _endpoint_from_view_func @@ -235,24 +236,44 @@ class Blueprint(Scaffold): # Merge blueprint data into parent. if first_registration: - def extend(bp_dict, parent_dict): + def extend(bp_dict, parent_dict, ensure_sync=False): for key, values in bp_dict.items(): key = self.name if key is None else f"{self.name}.{key}" + + if ensure_sync: + values = [app.ensure_sync(func) for func in values] + parent_dict[key].extend(values) - def update(bp_dict, parent_dict): - for key, value in bp_dict.items(): - key = self.name if key is None else f"{self.name}.{key}" - parent_dict[key] = value + for key, value in self.error_handler_spec.items(): + key = self.name if key is None else f"{self.name}.{key}" + value = defaultdict( + dict, + { + code: { + exc_class: app.ensure_sync(func) + for exc_class, func in code_values.items() + } + for code, code_values in value.items() + }, + ) + app.error_handler_spec[key] = value - app.view_functions.update(self.view_functions) - extend(self.before_request_funcs, app.before_request_funcs) - extend(self.after_request_funcs, app.after_request_funcs) - extend(self.teardown_request_funcs, app.teardown_request_funcs) + for endpoint, func in self.view_functions.items(): + app.view_functions[endpoint] = app.ensure_sync(func) + + extend( + self.before_request_funcs, app.before_request_funcs, ensure_sync=True + ) + extend(self.after_request_funcs, app.after_request_funcs, ensure_sync=True) + extend( + self.teardown_request_funcs, + app.teardown_request_funcs, + ensure_sync=True, + ) extend(self.url_default_functions, app.url_default_functions) extend(self.url_value_preprocessors, app.url_value_preprocessors) extend(self.template_context_processors, app.template_context_processors) - update(self.error_handler_spec, app.error_handler_spec) for deferred in self.deferred_functions: deferred(state) @@ -380,7 +401,9 @@ class Blueprint(Scaffold): before each request, even if outside of a blueprint. """ self.record_once( - lambda s: s.app.before_request_funcs.setdefault(None, []).append(f) + lambda s: s.app.before_request_funcs.setdefault(None, []).append( + s.app.ensure_sync(f) + ) ) return f @@ -388,7 +411,9 @@ class Blueprint(Scaffold): """Like :meth:`Flask.before_first_request`. Such a function is executed before the first request to the application. """ - self.record_once(lambda s: s.app.before_first_request_funcs.append(f)) + self.record_once( + lambda s: s.app.before_first_request_funcs.append(s.app.ensure_sync(f)) + ) return f def after_app_request(self, f): @@ -396,7 +421,9 @@ class Blueprint(Scaffold): is executed after each request, even if outside of the blueprint. """ self.record_once( - lambda s: s.app.after_request_funcs.setdefault(None, []).append(f) + lambda s: s.app.after_request_funcs.setdefault(None, []).append( + s.app.ensure_sync(f) + ) ) return f @@ -443,3 +470,14 @@ class Blueprint(Scaffold): lambda s: s.app.url_default_functions.setdefault(None, []).append(f) ) return f + + def ensure_sync(self, f): + """Ensure the function is synchronous. + + Override if you would like custom async to sync behaviour in + this blueprint. Otherwise :meth:`~flask.Flask..ensure_sync` is + used. + + .. versionadded:: 2.0 + """ + return f diff --git a/src/flask/helpers.py b/src/flask/helpers.py index 5933f42a..3b4377c3 100644 --- a/src/flask/helpers.py +++ b/src/flask/helpers.py @@ -755,6 +755,7 @@ def run_async(func): ctx module, except it has an async inner. """ ctx = None + if _request_ctx_stack.top is not None: ctx = _request_ctx_stack.top.copy() diff --git a/src/flask/scaffold.py b/src/flask/scaffold.py index 7911bc71..bfa53068 100644 --- a/src/flask/scaffold.py +++ b/src/flask/scaffold.py @@ -4,7 +4,6 @@ import pkgutil import sys from collections import defaultdict from functools import update_wrapper -from inspect import iscoroutinefunction from jinja2 import FileSystemLoader from werkzeug.exceptions import default_exceptions @@ -13,7 +12,6 @@ from werkzeug.exceptions import HTTPException from .cli import AppGroup from .globals import current_app from .helpers import locked_cached_property -from .helpers import run_async from .helpers import send_from_directory from .templating import _default_template_ctx_processor @@ -687,17 +685,7 @@ class Scaffold: return exc_class, None def ensure_sync(self, func): - """Ensure that the returned function is sync and calls the async func. - - .. versionadded:: 2.0 - - Override if you wish to change how asynchronous functions are - run. - """ - if iscoroutinefunction(func): - return run_async(func) - else: - return func + raise NotImplementedError() def _endpoint_from_view_func(view_func): diff --git a/tests/test_async.py b/tests/test_async.py index 12784c34..de7d89b6 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -3,12 +3,20 @@ import sys import pytest -from flask import abort +from flask import Blueprint from flask import Flask from flask import request from flask.helpers import run_async +class AppError(Exception): + pass + + +class BlueprintError(Exception): + pass + + @pytest.fixture(name="async_app") def _async_app(): app = Flask(__name__) @@ -18,24 +26,111 @@ def _async_app(): await asyncio.sleep(0) return request.method + @app.errorhandler(AppError) + async def handle(_): + return "", 412 + @app.route("/error") async def error(): - abort(412) + raise AppError() + + blueprint = Blueprint("bp", __name__) + + @blueprint.route("/", methods=["GET", "POST"]) + async def bp_index(): + await asyncio.sleep(0) + return request.method + + @blueprint.errorhandler(BlueprintError) + async def bp_handle(_): + return "", 412 + + @blueprint.route("/error") + async def bp_error(): + raise BlueprintError() + + app.register_blueprint(blueprint, url_prefix="/bp") return app @pytest.mark.skipif(sys.version_info < (3, 7), reason="requires Python >= 3.7") -def test_async_request_context(async_app): +@pytest.mark.parametrize("path", ["/", "/bp/"]) +def test_async_route(path, async_app): test_client = async_app.test_client() - response = test_client.get("/") + response = test_client.get(path) assert b"GET" in response.get_data() - response = test_client.post("/") + response = test_client.post(path) assert b"POST" in response.get_data() - response = test_client.get("/error") + + +@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires Python >= 3.7") +@pytest.mark.parametrize("path", ["/error", "/bp/error"]) +def test_async_error_handler(path, async_app): + test_client = async_app.test_client() + response = test_client.get(path) assert response.status_code == 412 +@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires Python >= 3.7") +def test_async_before_after_request(): + app_first_called = False + app_before_called = False + app_after_called = False + bp_before_called = False + bp_after_called = False + + app = Flask(__name__) + + @app.route("/") + def index(): + return "" + + @app.before_first_request + async def before_first(): + nonlocal app_first_called + app_first_called = True + + @app.before_request + async def before(): + nonlocal app_before_called + app_before_called = True + + @app.after_request + async def after(response): + nonlocal app_after_called + app_after_called = True + return response + + blueprint = Blueprint("bp", __name__) + + @blueprint.route("/") + def bp_index(): + return "" + + @blueprint.before_request + async def bp_before(): + nonlocal bp_before_called + bp_before_called = True + + @blueprint.after_request + async def bp_after(response): + nonlocal bp_after_called + bp_after_called = True + return response + + app.register_blueprint(blueprint, url_prefix="/bp") + + test_client = app.test_client() + test_client.get("/") + assert app_first_called + assert app_before_called + assert app_after_called + test_client.get("/bp/") + assert bp_before_called + assert bp_after_called + + @pytest.mark.skipif(sys.version_info >= (3, 7), reason="should only raise Python < 3.7") def test_async_runtime_error(): with pytest.raises(RuntimeError):