forked from orbit-oss/flask
Alter ensure_sync implementation to support extensions
This allows extensions to override the Flask.ensure_sync method and have the change apply to blueprints as well. Without this change it is possible for differing blueprints to have differing ensure_sync approaches depending on the extension used - which would likely result in event-loop blocking issues. This also allows blueprints to have a custom ensure_sync, although this is a by product rather than an expected use case.
This commit is contained in:
parent
c6c6408c3f
commit
00f5a3e55c
5 changed files with 169 additions and 32 deletions
|
|
@ -2,6 +2,7 @@ import os
|
||||||
import sys
|
import sys
|
||||||
import weakref
|
import weakref
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from inspect import iscoroutinefunction
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
|
|
||||||
|
|
@ -34,6 +35,7 @@ from .helpers import get_env
|
||||||
from .helpers import get_flashed_messages
|
from .helpers import get_flashed_messages
|
||||||
from .helpers import get_load_dotenv
|
from .helpers import get_load_dotenv
|
||||||
from .helpers import locked_cached_property
|
from .helpers import locked_cached_property
|
||||||
|
from .helpers import run_async
|
||||||
from .helpers import url_for
|
from .helpers import url_for
|
||||||
from .json import jsonify
|
from .json import jsonify
|
||||||
from .logging import create_logger
|
from .logging import create_logger
|
||||||
|
|
@ -1517,6 +1519,19 @@ class Flask(Scaffold):
|
||||||
"""
|
"""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
def ensure_sync(self, func):
|
||||||
|
"""Ensure that the returned function is sync and calls the async func.
|
||||||
|
|
||||||
|
.. versionadded:: 2.0
|
||||||
|
|
||||||
|
Override if you wish to change how asynchronous functions are
|
||||||
|
run.
|
||||||
|
"""
|
||||||
|
if iscoroutinefunction(func):
|
||||||
|
return run_async(func)
|
||||||
|
|
||||||
|
return func
|
||||||
|
|
||||||
def make_response(self, rv):
|
def make_response(self, rv):
|
||||||
"""Convert the return value from a view function to an instance of
|
"""Convert the return value from a view function to an instance of
|
||||||
:attr:`response_class`.
|
:attr:`response_class`.
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
from collections import defaultdict
|
||||||
from functools import update_wrapper
|
from functools import update_wrapper
|
||||||
|
|
||||||
from .scaffold import _endpoint_from_view_func
|
from .scaffold import _endpoint_from_view_func
|
||||||
|
|
@ -235,24 +236,44 @@ class Blueprint(Scaffold):
|
||||||
# Merge blueprint data into parent.
|
# Merge blueprint data into parent.
|
||||||
if first_registration:
|
if first_registration:
|
||||||
|
|
||||||
def extend(bp_dict, parent_dict):
|
def extend(bp_dict, parent_dict, ensure_sync=False):
|
||||||
for key, values in bp_dict.items():
|
for key, values in bp_dict.items():
|
||||||
key = self.name if key is None else f"{self.name}.{key}"
|
key = self.name if key is None else f"{self.name}.{key}"
|
||||||
|
|
||||||
|
if ensure_sync:
|
||||||
|
values = [app.ensure_sync(func) for func in values]
|
||||||
|
|
||||||
parent_dict[key].extend(values)
|
parent_dict[key].extend(values)
|
||||||
|
|
||||||
def update(bp_dict, parent_dict):
|
for key, value in self.error_handler_spec.items():
|
||||||
for key, value in bp_dict.items():
|
key = self.name if key is None else f"{self.name}.{key}"
|
||||||
key = self.name if key is None else f"{self.name}.{key}"
|
value = defaultdict(
|
||||||
parent_dict[key] = value
|
dict,
|
||||||
|
{
|
||||||
|
code: {
|
||||||
|
exc_class: app.ensure_sync(func)
|
||||||
|
for exc_class, func in code_values.items()
|
||||||
|
}
|
||||||
|
for code, code_values in value.items()
|
||||||
|
},
|
||||||
|
)
|
||||||
|
app.error_handler_spec[key] = value
|
||||||
|
|
||||||
app.view_functions.update(self.view_functions)
|
for endpoint, func in self.view_functions.items():
|
||||||
extend(self.before_request_funcs, app.before_request_funcs)
|
app.view_functions[endpoint] = app.ensure_sync(func)
|
||||||
extend(self.after_request_funcs, app.after_request_funcs)
|
|
||||||
extend(self.teardown_request_funcs, app.teardown_request_funcs)
|
extend(
|
||||||
|
self.before_request_funcs, app.before_request_funcs, ensure_sync=True
|
||||||
|
)
|
||||||
|
extend(self.after_request_funcs, app.after_request_funcs, ensure_sync=True)
|
||||||
|
extend(
|
||||||
|
self.teardown_request_funcs,
|
||||||
|
app.teardown_request_funcs,
|
||||||
|
ensure_sync=True,
|
||||||
|
)
|
||||||
extend(self.url_default_functions, app.url_default_functions)
|
extend(self.url_default_functions, app.url_default_functions)
|
||||||
extend(self.url_value_preprocessors, app.url_value_preprocessors)
|
extend(self.url_value_preprocessors, app.url_value_preprocessors)
|
||||||
extend(self.template_context_processors, app.template_context_processors)
|
extend(self.template_context_processors, app.template_context_processors)
|
||||||
update(self.error_handler_spec, app.error_handler_spec)
|
|
||||||
|
|
||||||
for deferred in self.deferred_functions:
|
for deferred in self.deferred_functions:
|
||||||
deferred(state)
|
deferred(state)
|
||||||
|
|
@ -380,7 +401,9 @@ class Blueprint(Scaffold):
|
||||||
before each request, even if outside of a blueprint.
|
before each request, even if outside of a blueprint.
|
||||||
"""
|
"""
|
||||||
self.record_once(
|
self.record_once(
|
||||||
lambda s: s.app.before_request_funcs.setdefault(None, []).append(f)
|
lambda s: s.app.before_request_funcs.setdefault(None, []).append(
|
||||||
|
s.app.ensure_sync(f)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
|
@ -388,7 +411,9 @@ class Blueprint(Scaffold):
|
||||||
"""Like :meth:`Flask.before_first_request`. Such a function is
|
"""Like :meth:`Flask.before_first_request`. Such a function is
|
||||||
executed before the first request to the application.
|
executed before the first request to the application.
|
||||||
"""
|
"""
|
||||||
self.record_once(lambda s: s.app.before_first_request_funcs.append(f))
|
self.record_once(
|
||||||
|
lambda s: s.app.before_first_request_funcs.append(s.app.ensure_sync(f))
|
||||||
|
)
|
||||||
return f
|
return f
|
||||||
|
|
||||||
def after_app_request(self, f):
|
def after_app_request(self, f):
|
||||||
|
|
@ -396,7 +421,9 @@ class Blueprint(Scaffold):
|
||||||
is executed after each request, even if outside of the blueprint.
|
is executed after each request, even if outside of the blueprint.
|
||||||
"""
|
"""
|
||||||
self.record_once(
|
self.record_once(
|
||||||
lambda s: s.app.after_request_funcs.setdefault(None, []).append(f)
|
lambda s: s.app.after_request_funcs.setdefault(None, []).append(
|
||||||
|
s.app.ensure_sync(f)
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
|
@ -443,3 +470,14 @@ class Blueprint(Scaffold):
|
||||||
lambda s: s.app.url_default_functions.setdefault(None, []).append(f)
|
lambda s: s.app.url_default_functions.setdefault(None, []).append(f)
|
||||||
)
|
)
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
def ensure_sync(self, f):
|
||||||
|
"""Ensure the function is synchronous.
|
||||||
|
|
||||||
|
Override if you would like custom async to sync behaviour in
|
||||||
|
this blueprint. Otherwise :meth:`~flask.Flask..ensure_sync` is
|
||||||
|
used.
|
||||||
|
|
||||||
|
.. versionadded:: 2.0
|
||||||
|
"""
|
||||||
|
return f
|
||||||
|
|
|
||||||
|
|
@ -755,6 +755,7 @@ def run_async(func):
|
||||||
ctx module, except it has an async inner.
|
ctx module, except it has an async inner.
|
||||||
"""
|
"""
|
||||||
ctx = None
|
ctx = None
|
||||||
|
|
||||||
if _request_ctx_stack.top is not None:
|
if _request_ctx_stack.top is not None:
|
||||||
ctx = _request_ctx_stack.top.copy()
|
ctx = _request_ctx_stack.top.copy()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ import pkgutil
|
||||||
import sys
|
import sys
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from functools import update_wrapper
|
from functools import update_wrapper
|
||||||
from inspect import iscoroutinefunction
|
|
||||||
|
|
||||||
from jinja2 import FileSystemLoader
|
from jinja2 import FileSystemLoader
|
||||||
from werkzeug.exceptions import default_exceptions
|
from werkzeug.exceptions import default_exceptions
|
||||||
|
|
@ -13,7 +12,6 @@ from werkzeug.exceptions import HTTPException
|
||||||
from .cli import AppGroup
|
from .cli import AppGroup
|
||||||
from .globals import current_app
|
from .globals import current_app
|
||||||
from .helpers import locked_cached_property
|
from .helpers import locked_cached_property
|
||||||
from .helpers import run_async
|
|
||||||
from .helpers import send_from_directory
|
from .helpers import send_from_directory
|
||||||
from .templating import _default_template_ctx_processor
|
from .templating import _default_template_ctx_processor
|
||||||
|
|
||||||
|
|
@ -687,17 +685,7 @@ class Scaffold:
|
||||||
return exc_class, None
|
return exc_class, None
|
||||||
|
|
||||||
def ensure_sync(self, func):
|
def ensure_sync(self, func):
|
||||||
"""Ensure that the returned function is sync and calls the async func.
|
raise NotImplementedError()
|
||||||
|
|
||||||
.. versionadded:: 2.0
|
|
||||||
|
|
||||||
Override if you wish to change how asynchronous functions are
|
|
||||||
run.
|
|
||||||
"""
|
|
||||||
if iscoroutinefunction(func):
|
|
||||||
return run_async(func)
|
|
||||||
else:
|
|
||||||
return func
|
|
||||||
|
|
||||||
|
|
||||||
def _endpoint_from_view_func(view_func):
|
def _endpoint_from_view_func(view_func):
|
||||||
|
|
|
||||||
|
|
@ -3,12 +3,20 @@ import sys
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from flask import abort
|
from flask import Blueprint
|
||||||
from flask import Flask
|
from flask import Flask
|
||||||
from flask import request
|
from flask import request
|
||||||
from flask.helpers import run_async
|
from flask.helpers import run_async
|
||||||
|
|
||||||
|
|
||||||
|
class AppError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class BlueprintError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="async_app")
|
@pytest.fixture(name="async_app")
|
||||||
def _async_app():
|
def _async_app():
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
|
|
@ -18,24 +26,111 @@ def _async_app():
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
return request.method
|
return request.method
|
||||||
|
|
||||||
|
@app.errorhandler(AppError)
|
||||||
|
async def handle(_):
|
||||||
|
return "", 412
|
||||||
|
|
||||||
@app.route("/error")
|
@app.route("/error")
|
||||||
async def error():
|
async def error():
|
||||||
abort(412)
|
raise AppError()
|
||||||
|
|
||||||
|
blueprint = Blueprint("bp", __name__)
|
||||||
|
|
||||||
|
@blueprint.route("/", methods=["GET", "POST"])
|
||||||
|
async def bp_index():
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
return request.method
|
||||||
|
|
||||||
|
@blueprint.errorhandler(BlueprintError)
|
||||||
|
async def bp_handle(_):
|
||||||
|
return "", 412
|
||||||
|
|
||||||
|
@blueprint.route("/error")
|
||||||
|
async def bp_error():
|
||||||
|
raise BlueprintError()
|
||||||
|
|
||||||
|
app.register_blueprint(blueprint, url_prefix="/bp")
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires Python >= 3.7")
|
@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires Python >= 3.7")
|
||||||
def test_async_request_context(async_app):
|
@pytest.mark.parametrize("path", ["/", "/bp/"])
|
||||||
|
def test_async_route(path, async_app):
|
||||||
test_client = async_app.test_client()
|
test_client = async_app.test_client()
|
||||||
response = test_client.get("/")
|
response = test_client.get(path)
|
||||||
assert b"GET" in response.get_data()
|
assert b"GET" in response.get_data()
|
||||||
response = test_client.post("/")
|
response = test_client.post(path)
|
||||||
assert b"POST" in response.get_data()
|
assert b"POST" in response.get_data()
|
||||||
response = test_client.get("/error")
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires Python >= 3.7")
|
||||||
|
@pytest.mark.parametrize("path", ["/error", "/bp/error"])
|
||||||
|
def test_async_error_handler(path, async_app):
|
||||||
|
test_client = async_app.test_client()
|
||||||
|
response = test_client.get(path)
|
||||||
assert response.status_code == 412
|
assert response.status_code == 412
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires Python >= 3.7")
|
||||||
|
def test_async_before_after_request():
|
||||||
|
app_first_called = False
|
||||||
|
app_before_called = False
|
||||||
|
app_after_called = False
|
||||||
|
bp_before_called = False
|
||||||
|
bp_after_called = False
|
||||||
|
|
||||||
|
app = Flask(__name__)
|
||||||
|
|
||||||
|
@app.route("/")
|
||||||
|
def index():
|
||||||
|
return ""
|
||||||
|
|
||||||
|
@app.before_first_request
|
||||||
|
async def before_first():
|
||||||
|
nonlocal app_first_called
|
||||||
|
app_first_called = True
|
||||||
|
|
||||||
|
@app.before_request
|
||||||
|
async def before():
|
||||||
|
nonlocal app_before_called
|
||||||
|
app_before_called = True
|
||||||
|
|
||||||
|
@app.after_request
|
||||||
|
async def after(response):
|
||||||
|
nonlocal app_after_called
|
||||||
|
app_after_called = True
|
||||||
|
return response
|
||||||
|
|
||||||
|
blueprint = Blueprint("bp", __name__)
|
||||||
|
|
||||||
|
@blueprint.route("/")
|
||||||
|
def bp_index():
|
||||||
|
return ""
|
||||||
|
|
||||||
|
@blueprint.before_request
|
||||||
|
async def bp_before():
|
||||||
|
nonlocal bp_before_called
|
||||||
|
bp_before_called = True
|
||||||
|
|
||||||
|
@blueprint.after_request
|
||||||
|
async def bp_after(response):
|
||||||
|
nonlocal bp_after_called
|
||||||
|
bp_after_called = True
|
||||||
|
return response
|
||||||
|
|
||||||
|
app.register_blueprint(blueprint, url_prefix="/bp")
|
||||||
|
|
||||||
|
test_client = app.test_client()
|
||||||
|
test_client.get("/")
|
||||||
|
assert app_first_called
|
||||||
|
assert app_before_called
|
||||||
|
assert app_after_called
|
||||||
|
test_client.get("/bp/")
|
||||||
|
assert bp_before_called
|
||||||
|
assert bp_after_called
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.skipif(sys.version_info >= (3, 7), reason="should only raise Python < 3.7")
|
@pytest.mark.skipif(sys.version_info >= (3, 7), reason="should only raise Python < 3.7")
|
||||||
def test_async_runtime_error():
|
def test_async_runtime_error():
|
||||||
with pytest.raises(RuntimeError):
|
with pytest.raises(RuntimeError):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue