Merge pull request #3989 from pgjones/async

Async improvements
This commit is contained in:
David Lord 2021-05-03 06:23:00 -07:00 committed by GitHub
commit 47f0e799db
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 74 additions and 108 deletions

View file

@ -92,6 +92,21 @@ not work with async views because they will not await the function or be
awaitable. Other functions they provide will not be awaitable either and awaitable. Other functions they provide will not be awaitable either and
will probably be blocking if called within an async view. will probably be blocking if called within an async view.
Extension authors can support async functions by utilising the
:meth:`flask.Flask.ensure_sync` method. For example, if the extension
provides a view function decorator add ``ensure_sync`` before calling
the decorated function,
.. code-block:: python
def extension(func):
@wraps(func)
def wrapper(*args, **kwargs):
... # Extension logic
return current_app.ensure_sync(func)(*args, **kwargs)
return wrapper
Check the changelog of the extension you want to use to see if they've Check the changelog of the extension you want to use to see if they've
implemented async support, or make a feature request or PR to them. implemented async support, or make a feature request or PR to them.

View file

@ -16,6 +16,7 @@ from werkzeug.exceptions import BadRequest
from werkzeug.exceptions import BadRequestKeyError from werkzeug.exceptions import BadRequestKeyError
from werkzeug.exceptions import HTTPException from werkzeug.exceptions import HTTPException
from werkzeug.exceptions import InternalServerError from werkzeug.exceptions import InternalServerError
from werkzeug.local import ContextVar
from werkzeug.routing import BuildError from werkzeug.routing import BuildError
from werkzeug.routing import Map from werkzeug.routing import Map
from werkzeug.routing import MapAdapter from werkzeug.routing import MapAdapter
@ -40,7 +41,6 @@ from .helpers import get_env
from .helpers import get_flashed_messages from .helpers import get_flashed_messages
from .helpers import get_load_dotenv from .helpers import get_load_dotenv
from .helpers import locked_cached_property from .helpers import locked_cached_property
from .helpers import run_async
from .helpers import url_for from .helpers import url_for
from .json import jsonify from .json import jsonify
from .logging import create_logger from .logging import create_logger
@ -1080,14 +1080,12 @@ class Flask(Scaffold):
self.url_map.add(rule) self.url_map.add(rule)
if view_func is not None: if view_func is not None:
old_func = self.view_functions.get(endpoint) old_func = self.view_functions.get(endpoint)
if getattr(old_func, "_flask_sync_wrapper", False):
old_func = old_func.__wrapped__ # type: ignore
if old_func is not None and old_func != view_func: if old_func is not None and old_func != view_func:
raise AssertionError( raise AssertionError(
"View function mapping is overwriting an existing" "View function mapping is overwriting an existing"
f" endpoint function: {endpoint}" f" endpoint function: {endpoint}"
) )
self.view_functions[endpoint] = self.ensure_sync(view_func) self.view_functions[endpoint] = view_func
@setupmethod @setupmethod
def template_filter(self, name: t.Optional[str] = None) -> t.Callable: def template_filter(self, name: t.Optional[str] = None) -> t.Callable:
@ -1208,7 +1206,7 @@ class Flask(Scaffold):
.. versionadded:: 0.8 .. versionadded:: 0.8
""" """
self.before_first_request_funcs.append(self.ensure_sync(f)) self.before_first_request_funcs.append(f)
return f return f
@setupmethod @setupmethod
@ -1241,7 +1239,7 @@ class Flask(Scaffold):
.. versionadded:: 0.9 .. versionadded:: 0.9
""" """
self.teardown_appcontext_funcs.append(self.ensure_sync(f)) self.teardown_appcontext_funcs.append(f)
return f return f
@setupmethod @setupmethod
@ -1308,7 +1306,7 @@ class Flask(Scaffold):
handler = self._find_error_handler(e) handler = self._find_error_handler(e)
if handler is None: if handler is None:
return e return e
return handler(e) return self.ensure_sync(handler)(e)
def trap_http_exception(self, e: Exception) -> bool: def trap_http_exception(self, e: Exception) -> bool:
"""Checks if an HTTP exception should be trapped or not. By default """Checks if an HTTP exception should be trapped or not. By default
@ -1375,7 +1373,7 @@ class Flask(Scaffold):
if handler is None: if handler is None:
raise raise
return handler(e) return self.ensure_sync(handler)(e)
def handle_exception(self, e: Exception) -> Response: def handle_exception(self, e: Exception) -> Response:
"""Handle an exception that did not have an error handler """Handle an exception that did not have an error handler
@ -1422,7 +1420,7 @@ class Flask(Scaffold):
handler = self._find_error_handler(server_error) handler = self._find_error_handler(server_error)
if handler is not None: if handler is not None:
server_error = handler(server_error) server_error = self.ensure_sync(handler)(server_error)
return self.finalize_request(server_error, from_error_handler=True) return self.finalize_request(server_error, from_error_handler=True)
@ -1484,7 +1482,7 @@ class Flask(Scaffold):
): ):
return self.make_default_options_response() return self.make_default_options_response()
# otherwise dispatch to the handler for that endpoint # otherwise dispatch to the handler for that endpoint
return self.view_functions[rule.endpoint](**req.view_args) return self.ensure_sync(self.view_functions[rule.endpoint])(**req.view_args)
def full_dispatch_request(self) -> Response: def full_dispatch_request(self) -> Response:
"""Dispatches the request and on top of that performs request """Dispatches the request and on top of that performs request
@ -1545,7 +1543,7 @@ class Flask(Scaffold):
if self._got_first_request: if self._got_first_request:
return return
for func in self.before_first_request_funcs: for func in self.before_first_request_funcs:
func() self.ensure_sync(func)()
self._got_first_request = True self._got_first_request = True
def make_default_options_response(self) -> Response: def make_default_options_response(self) -> Response:
@ -1581,10 +1579,40 @@ class Flask(Scaffold):
.. versionadded:: 2.0 .. versionadded:: 2.0
""" """
if iscoroutinefunction(func): if iscoroutinefunction(func):
return run_async(func) return self.async_to_sync(func)
return func return func
def async_to_sync(
self, func: t.Callable[..., t.Coroutine]
) -> t.Callable[..., t.Any]:
"""Return a sync function that will run the coroutine function.
.. code-block:: python
result = app.async_to_sync(func)(*args, **kwargs)
Override this method to change how the app converts async code
to be synchronously callable.
.. versionadded:: 2.0
"""
try:
from asgiref.sync import async_to_sync as asgiref_async_to_sync
except ImportError:
raise RuntimeError(
"Install Flask with the 'async' extra in order to use async views."
)
# Check that Werkzeug isn't using its fallback ContextVar class.
if ContextVar.__module__ == "werkzeug.local":
raise RuntimeError(
"Async cannot be used with this combination of Python "
"and Greenlet versions."
)
return asgiref_async_to_sync(func)
def make_response(self, rv: ResponseReturnValue) -> Response: def make_response(self, rv: ResponseReturnValue) -> Response:
"""Convert the return value from a view function to an instance of """Convert the return value from a view function to an instance of
:attr:`response_class`. :attr:`response_class`.
@ -1807,7 +1835,7 @@ class Flask(Scaffold):
if bp in self.before_request_funcs: if bp in self.before_request_funcs:
funcs = chain(funcs, self.before_request_funcs[bp]) funcs = chain(funcs, self.before_request_funcs[bp])
for func in funcs: for func in funcs:
rv = func() rv = self.ensure_sync(func)()
if rv is not None: if rv is not None:
return rv return rv
@ -1834,7 +1862,7 @@ class Flask(Scaffold):
if None in self.after_request_funcs: if None in self.after_request_funcs:
funcs = chain(funcs, reversed(self.after_request_funcs[None])) funcs = chain(funcs, reversed(self.after_request_funcs[None]))
for handler in funcs: for handler in funcs:
response = handler(response) response = self.ensure_sync(handler)(response)
if not self.session_interface.is_null_session(ctx.session): if not self.session_interface.is_null_session(ctx.session):
self.session_interface.save_session(self, ctx.session, response) self.session_interface.save_session(self, ctx.session, response)
return response return response
@ -1871,7 +1899,7 @@ class Flask(Scaffold):
if bp in self.teardown_request_funcs: if bp in self.teardown_request_funcs:
funcs = chain(funcs, reversed(self.teardown_request_funcs[bp])) funcs = chain(funcs, reversed(self.teardown_request_funcs[bp]))
for func in funcs: for func in funcs:
func(exc) self.ensure_sync(func)(exc)
request_tearing_down.send(self, exc=exc) request_tearing_down.send(self, exc=exc)
def do_teardown_appcontext( def do_teardown_appcontext(
@ -1894,7 +1922,7 @@ class Flask(Scaffold):
if exc is _sentinel: if exc is _sentinel:
exc = sys.exc_info()[1] exc = sys.exc_info()[1]
for func in reversed(self.teardown_appcontext_funcs): for func in reversed(self.teardown_appcontext_funcs):
func(exc) self.ensure_sync(func)(exc)
appcontext_tearing_down.send(self, exc=exc) appcontext_tearing_down.send(self, exc=exc)
def app_context(self) -> AppContext: def app_context(self) -> AppContext:

View file

@ -292,13 +292,10 @@ class Blueprint(Scaffold):
# Merge blueprint data into parent. # Merge blueprint data into parent.
if first_registration: if first_registration:
def extend(bp_dict, parent_dict, ensure_sync=False): def extend(bp_dict, parent_dict):
for key, values in bp_dict.items(): for key, values in bp_dict.items():
key = self.name if key is None else f"{self.name}.{key}" key = self.name if key is None else f"{self.name}.{key}"
if ensure_sync:
values = [app.ensure_sync(func) for func in values]
parent_dict[key].extend(values) parent_dict[key].extend(values)
for key, value in self.error_handler_spec.items(): for key, value in self.error_handler_spec.items():
@ -307,8 +304,7 @@ class Blueprint(Scaffold):
dict, dict,
{ {
code: { code: {
exc_class: app.ensure_sync(func) exc_class: func for exc_class, func in code_values.items()
for exc_class, func in code_values.items()
} }
for code, code_values in value.items() for code, code_values in value.items()
}, },
@ -316,16 +312,13 @@ class Blueprint(Scaffold):
app.error_handler_spec[key] = value app.error_handler_spec[key] = value
for endpoint, func in self.view_functions.items(): for endpoint, func in self.view_functions.items():
app.view_functions[endpoint] = app.ensure_sync(func) app.view_functions[endpoint] = func
extend( extend(self.before_request_funcs, app.before_request_funcs)
self.before_request_funcs, app.before_request_funcs, ensure_sync=True extend(self.after_request_funcs, app.after_request_funcs)
)
extend(self.after_request_funcs, app.after_request_funcs, ensure_sync=True)
extend( extend(
self.teardown_request_funcs, self.teardown_request_funcs,
app.teardown_request_funcs, app.teardown_request_funcs,
ensure_sync=True,
) )
extend(self.url_default_functions, app.url_default_functions) extend(self.url_default_functions, app.url_default_functions)
extend(self.url_value_preprocessors, app.url_value_preprocessors) extend(self.url_value_preprocessors, app.url_value_preprocessors)
@ -478,9 +471,7 @@ class Blueprint(Scaffold):
before each request, even if outside of a blueprint. before each request, even if outside of a blueprint.
""" """
self.record_once( self.record_once(
lambda s: s.app.before_request_funcs.setdefault(None, []).append( lambda s: s.app.before_request_funcs.setdefault(None, []).append(f)
s.app.ensure_sync(f)
)
) )
return f return f
@ -490,9 +481,7 @@ class Blueprint(Scaffold):
"""Like :meth:`Flask.before_first_request`. Such a function is """Like :meth:`Flask.before_first_request`. Such a function is
executed before the first request to the application. executed before the first request to the application.
""" """
self.record_once( self.record_once(lambda s: s.app.before_first_request_funcs.append(f))
lambda s: s.app.before_first_request_funcs.append(s.app.ensure_sync(f))
)
return f return f
def after_app_request(self, f: AfterRequestCallable) -> AfterRequestCallable: def after_app_request(self, f: AfterRequestCallable) -> AfterRequestCallable:
@ -500,9 +489,7 @@ class Blueprint(Scaffold):
is executed after each request, even if outside of the blueprint. is executed after each request, even if outside of the blueprint.
""" """
self.record_once( self.record_once(
lambda s: s.app.after_request_funcs.setdefault(None, []).append( lambda s: s.app.after_request_funcs.setdefault(None, []).append(f)
s.app.ensure_sync(f)
)
) )
return f return f
@ -553,14 +540,3 @@ class Blueprint(Scaffold):
lambda s: s.app.url_default_functions.setdefault(None, []).append(f) lambda s: s.app.url_default_functions.setdefault(None, []).append(f)
) )
return f return f
def ensure_sync(self, f: t.Callable) -> t.Callable:
"""Ensure the function is synchronous.
Override if you would like custom async to sync behaviour in
this blueprint. Otherwise the app's
:meth:`~flask.Flask.ensure_sync` is used.
.. versionadded:: 2.0
"""
return f

View file

@ -6,12 +6,10 @@ import typing as t
import warnings import warnings
from datetime import timedelta from datetime import timedelta
from functools import update_wrapper from functools import update_wrapper
from functools import wraps
from threading import RLock from threading import RLock
import werkzeug.utils import werkzeug.utils
from werkzeug.exceptions import NotFound from werkzeug.exceptions import NotFound
from werkzeug.local import ContextVar
from werkzeug.routing import BuildError from werkzeug.routing import BuildError
from werkzeug.urls import url_quote from werkzeug.urls import url_quote
@ -801,51 +799,3 @@ def is_ip(value: str) -> bool:
return True return True
return False return False
def run_async(func: t.Callable[..., t.Coroutine]) -> t.Callable[..., t.Any]:
"""Return a sync function that will run the coroutine function *func*."""
try:
from asgiref.sync import async_to_sync
except ImportError:
raise RuntimeError(
"Install Flask with the 'async' extra in order to use async views."
)
# Check that Werkzeug isn't using its fallback ContextVar class.
if ContextVar.__module__ == "werkzeug.local":
raise RuntimeError(
"Async cannot be used with this combination of Python & Greenlet versions."
)
@wraps(func)
def outer(*args: t.Any, **kwargs: t.Any) -> t.Any:
"""This function grabs the current context for the inner function.
This is similar to the copy_current_xxx_context functions in the
ctx module, except it has an async inner.
"""
ctx = None
if _request_ctx_stack.top is not None:
ctx = _request_ctx_stack.top.copy()
@wraps(func)
async def inner(*a: t.Any, **k: t.Any) -> t.Any:
"""This restores the context before awaiting the func.
This is required as the function must be awaited within the
context. Only calling ``func`` (as per the
``copy_current_xxx_context`` functions) doesn't work as the
with block will close before the coroutine is awaited.
"""
if ctx is not None:
with ctx:
return await func(*a, **k)
else:
return await func(*a, **k)
return async_to_sync(inner)(*args, **kwargs)
outer._flask_sync_wrapper = True # type: ignore
return outer

View file

@ -521,7 +521,7 @@ class Scaffold:
""" """
def decorator(f): def decorator(f):
self.view_functions[endpoint] = self.ensure_sync(f) self.view_functions[endpoint] = f
return f return f
return decorator return decorator
@ -545,7 +545,7 @@ class Scaffold:
return value from the view, and further request handling is return value from the view, and further request handling is
stopped. stopped.
""" """
self.before_request_funcs.setdefault(None, []).append(self.ensure_sync(f)) self.before_request_funcs.setdefault(None, []).append(f)
return f return f
@setupmethod @setupmethod
@ -561,7 +561,7 @@ class Scaffold:
should not be used for actions that must execute, such as to should not be used for actions that must execute, such as to
close resources. Use :meth:`teardown_request` for that. close resources. Use :meth:`teardown_request` for that.
""" """
self.after_request_funcs.setdefault(None, []).append(self.ensure_sync(f)) self.after_request_funcs.setdefault(None, []).append(f)
return f return f
@setupmethod @setupmethod
@ -600,7 +600,7 @@ class Scaffold:
debugger can still access it. This behavior can be controlled debugger can still access it. This behavior can be controlled
by the ``PRESERVE_CONTEXT_ON_EXCEPTION`` configuration variable. by the ``PRESERVE_CONTEXT_ON_EXCEPTION`` configuration variable.
""" """
self.teardown_request_funcs.setdefault(None, []).append(self.ensure_sync(f)) self.teardown_request_funcs.setdefault(None, []).append(f)
return f return f
@setupmethod @setupmethod
@ -706,7 +706,7 @@ class Scaffold:
" instead." " instead."
) )
self.error_handler_spec[None][code][exc_class] = self.ensure_sync(f) self.error_handler_spec[None][code][exc_class] = f
@staticmethod @staticmethod
def _get_exc_class_and_code( def _get_exc_class_and_code(
@ -734,9 +734,6 @@ class Scaffold:
else: else:
return exc_class, None return exc_class, None
def ensure_sync(self, func: t.Callable) -> t.Callable:
raise NotImplementedError()
def _endpoint_from_view_func(view_func: t.Callable) -> str: def _endpoint_from_view_func(view_func: t.Callable) -> str:
"""Internal helper that returns the default endpoint for a given """Internal helper that returns the default endpoint for a given

View file

@ -6,7 +6,6 @@ import pytest
from flask import Blueprint from flask import Blueprint
from flask import Flask from flask import Flask
from flask import request from flask import request
from flask.helpers import run_async
pytest.importorskip("asgiref") pytest.importorskip("asgiref")
@ -136,5 +135,6 @@ def test_async_before_after_request():
@pytest.mark.skipif(sys.version_info >= (3, 7), reason="should only raise Python < 3.7") @pytest.mark.skipif(sys.version_info >= (3, 7), reason="should only raise Python < 3.7")
def test_async_runtime_error(): def test_async_runtime_error():
app = Flask(__name__)
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
run_async(None) app.async_to_sync(None)