Alter ensure_sync implementation to support extensions

This allows extensions to override the Flask.ensure_sync method and
have the change apply to blueprints as well. Without this change it is
possible for differing blueprints to have differing ensure_sync
approaches depending on the extension used - which would likely result
in event-loop blocking issues.

This also allows blueprints to have a custom ensure_sync, although
this is a by product rather than an expected use case.
This commit is contained in:
pgjones 2021-03-24 20:47:55 +00:00 committed by David Lord
parent c6c6408c3f
commit 00f5a3e55c
No known key found for this signature in database
GPG key ID: 7A1C87E3F5BC42A8
5 changed files with 169 additions and 32 deletions

View file

@ -2,6 +2,7 @@ import os
import sys import sys
import weakref import weakref
from datetime import timedelta from datetime import timedelta
from inspect import iscoroutinefunction
from itertools import chain from itertools import chain
from threading import Lock from threading import Lock
@ -34,6 +35,7 @@ from .helpers import get_env
from .helpers import get_flashed_messages from .helpers import get_flashed_messages
from .helpers import get_load_dotenv from .helpers import get_load_dotenv
from .helpers import locked_cached_property from .helpers import locked_cached_property
from .helpers import run_async
from .helpers import url_for from .helpers import url_for
from .json import jsonify from .json import jsonify
from .logging import create_logger from .logging import create_logger
@ -1517,6 +1519,19 @@ class Flask(Scaffold):
""" """
return False return False
def ensure_sync(self, func):
"""Ensure that the returned function is sync and calls the async func.
.. versionadded:: 2.0
Override if you wish to change how asynchronous functions are
run.
"""
if iscoroutinefunction(func):
return run_async(func)
return func
def make_response(self, rv): def make_response(self, rv):
"""Convert the return value from a view function to an instance of """Convert the return value from a view function to an instance of
:attr:`response_class`. :attr:`response_class`.

View file

@ -1,3 +1,4 @@
from collections import defaultdict
from functools import update_wrapper from functools import update_wrapper
from .scaffold import _endpoint_from_view_func from .scaffold import _endpoint_from_view_func
@ -235,24 +236,44 @@ class Blueprint(Scaffold):
# Merge blueprint data into parent. # Merge blueprint data into parent.
if first_registration: if first_registration:
def extend(bp_dict, parent_dict): def extend(bp_dict, parent_dict, ensure_sync=False):
for key, values in bp_dict.items(): for key, values in bp_dict.items():
key = self.name if key is None else f"{self.name}.{key}" key = self.name if key is None else f"{self.name}.{key}"
if ensure_sync:
values = [app.ensure_sync(func) for func in values]
parent_dict[key].extend(values) parent_dict[key].extend(values)
def update(bp_dict, parent_dict): for key, value in self.error_handler_spec.items():
for key, value in bp_dict.items(): key = self.name if key is None else f"{self.name}.{key}"
key = self.name if key is None else f"{self.name}.{key}" value = defaultdict(
parent_dict[key] = value dict,
{
code: {
exc_class: app.ensure_sync(func)
for exc_class, func in code_values.items()
}
for code, code_values in value.items()
},
)
app.error_handler_spec[key] = value
app.view_functions.update(self.view_functions) for endpoint, func in self.view_functions.items():
extend(self.before_request_funcs, app.before_request_funcs) app.view_functions[endpoint] = app.ensure_sync(func)
extend(self.after_request_funcs, app.after_request_funcs)
extend(self.teardown_request_funcs, app.teardown_request_funcs) extend(
self.before_request_funcs, app.before_request_funcs, ensure_sync=True
)
extend(self.after_request_funcs, app.after_request_funcs, ensure_sync=True)
extend(
self.teardown_request_funcs,
app.teardown_request_funcs,
ensure_sync=True,
)
extend(self.url_default_functions, app.url_default_functions) extend(self.url_default_functions, app.url_default_functions)
extend(self.url_value_preprocessors, app.url_value_preprocessors) extend(self.url_value_preprocessors, app.url_value_preprocessors)
extend(self.template_context_processors, app.template_context_processors) extend(self.template_context_processors, app.template_context_processors)
update(self.error_handler_spec, app.error_handler_spec)
for deferred in self.deferred_functions: for deferred in self.deferred_functions:
deferred(state) deferred(state)
@ -380,7 +401,9 @@ class Blueprint(Scaffold):
before each request, even if outside of a blueprint. before each request, even if outside of a blueprint.
""" """
self.record_once( self.record_once(
lambda s: s.app.before_request_funcs.setdefault(None, []).append(f) lambda s: s.app.before_request_funcs.setdefault(None, []).append(
s.app.ensure_sync(f)
)
) )
return f return f
@ -388,7 +411,9 @@ class Blueprint(Scaffold):
"""Like :meth:`Flask.before_first_request`. Such a function is """Like :meth:`Flask.before_first_request`. Such a function is
executed before the first request to the application. executed before the first request to the application.
""" """
self.record_once(lambda s: s.app.before_first_request_funcs.append(f)) self.record_once(
lambda s: s.app.before_first_request_funcs.append(s.app.ensure_sync(f))
)
return f return f
def after_app_request(self, f): def after_app_request(self, f):
@ -396,7 +421,9 @@ class Blueprint(Scaffold):
is executed after each request, even if outside of the blueprint. is executed after each request, even if outside of the blueprint.
""" """
self.record_once( self.record_once(
lambda s: s.app.after_request_funcs.setdefault(None, []).append(f) lambda s: s.app.after_request_funcs.setdefault(None, []).append(
s.app.ensure_sync(f)
)
) )
return f return f
@ -443,3 +470,14 @@ class Blueprint(Scaffold):
lambda s: s.app.url_default_functions.setdefault(None, []).append(f) lambda s: s.app.url_default_functions.setdefault(None, []).append(f)
) )
return f return f
def ensure_sync(self, f):
"""Ensure the function is synchronous.
Override if you would like custom async to sync behaviour in
this blueprint. Otherwise :meth:`~flask.Flask..ensure_sync` is
used.
.. versionadded:: 2.0
"""
return f

View file

@ -755,6 +755,7 @@ def run_async(func):
ctx module, except it has an async inner. ctx module, except it has an async inner.
""" """
ctx = None ctx = None
if _request_ctx_stack.top is not None: if _request_ctx_stack.top is not None:
ctx = _request_ctx_stack.top.copy() ctx = _request_ctx_stack.top.copy()

View file

@ -4,7 +4,6 @@ import pkgutil
import sys import sys
from collections import defaultdict from collections import defaultdict
from functools import update_wrapper from functools import update_wrapper
from inspect import iscoroutinefunction
from jinja2 import FileSystemLoader from jinja2 import FileSystemLoader
from werkzeug.exceptions import default_exceptions from werkzeug.exceptions import default_exceptions
@ -13,7 +12,6 @@ from werkzeug.exceptions import HTTPException
from .cli import AppGroup from .cli import AppGroup
from .globals import current_app from .globals import current_app
from .helpers import locked_cached_property from .helpers import locked_cached_property
from .helpers import run_async
from .helpers import send_from_directory from .helpers import send_from_directory
from .templating import _default_template_ctx_processor from .templating import _default_template_ctx_processor
@ -687,17 +685,7 @@ class Scaffold:
return exc_class, None return exc_class, None
def ensure_sync(self, func): def ensure_sync(self, func):
"""Ensure that the returned function is sync and calls the async func. raise NotImplementedError()
.. versionadded:: 2.0
Override if you wish to change how asynchronous functions are
run.
"""
if iscoroutinefunction(func):
return run_async(func)
else:
return func
def _endpoint_from_view_func(view_func): def _endpoint_from_view_func(view_func):

View file

@ -3,12 +3,20 @@ import sys
import pytest import pytest
from flask import abort from flask import Blueprint
from flask import Flask from flask import Flask
from flask import request from flask import request
from flask.helpers import run_async from flask.helpers import run_async
class AppError(Exception):
pass
class BlueprintError(Exception):
pass
@pytest.fixture(name="async_app") @pytest.fixture(name="async_app")
def _async_app(): def _async_app():
app = Flask(__name__) app = Flask(__name__)
@ -18,24 +26,111 @@ def _async_app():
await asyncio.sleep(0) await asyncio.sleep(0)
return request.method return request.method
@app.errorhandler(AppError)
async def handle(_):
return "", 412
@app.route("/error") @app.route("/error")
async def error(): async def error():
abort(412) raise AppError()
blueprint = Blueprint("bp", __name__)
@blueprint.route("/", methods=["GET", "POST"])
async def bp_index():
await asyncio.sleep(0)
return request.method
@blueprint.errorhandler(BlueprintError)
async def bp_handle(_):
return "", 412
@blueprint.route("/error")
async def bp_error():
raise BlueprintError()
app.register_blueprint(blueprint, url_prefix="/bp")
return app return app
@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires Python >= 3.7") @pytest.mark.skipif(sys.version_info < (3, 7), reason="requires Python >= 3.7")
def test_async_request_context(async_app): @pytest.mark.parametrize("path", ["/", "/bp/"])
def test_async_route(path, async_app):
test_client = async_app.test_client() test_client = async_app.test_client()
response = test_client.get("/") response = test_client.get(path)
assert b"GET" in response.get_data() assert b"GET" in response.get_data()
response = test_client.post("/") response = test_client.post(path)
assert b"POST" in response.get_data() assert b"POST" in response.get_data()
response = test_client.get("/error")
@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires Python >= 3.7")
@pytest.mark.parametrize("path", ["/error", "/bp/error"])
def test_async_error_handler(path, async_app):
test_client = async_app.test_client()
response = test_client.get(path)
assert response.status_code == 412 assert response.status_code == 412
@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires Python >= 3.7")
def test_async_before_after_request():
app_first_called = False
app_before_called = False
app_after_called = False
bp_before_called = False
bp_after_called = False
app = Flask(__name__)
@app.route("/")
def index():
return ""
@app.before_first_request
async def before_first():
nonlocal app_first_called
app_first_called = True
@app.before_request
async def before():
nonlocal app_before_called
app_before_called = True
@app.after_request
async def after(response):
nonlocal app_after_called
app_after_called = True
return response
blueprint = Blueprint("bp", __name__)
@blueprint.route("/")
def bp_index():
return ""
@blueprint.before_request
async def bp_before():
nonlocal bp_before_called
bp_before_called = True
@blueprint.after_request
async def bp_after(response):
nonlocal bp_after_called
bp_after_called = True
return response
app.register_blueprint(blueprint, url_prefix="/bp")
test_client = app.test_client()
test_client.get("/")
assert app_first_called
assert app_before_called
assert app_after_called
test_client.get("/bp/")
assert bp_before_called
assert bp_after_called
@pytest.mark.skipif(sys.version_info >= (3, 7), reason="should only raise Python < 3.7") @pytest.mark.skipif(sys.version_info >= (3, 7), reason="should only raise Python < 3.7")
def test_async_runtime_error(): def test_async_runtime_error():
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):