diff --git a/docs/async-await.rst b/docs/async-await.rst index 23b418de..34751d47 100644 --- a/docs/async-await.rst +++ b/docs/async-await.rst @@ -92,6 +92,21 @@ not work with async views because they will not await the function or be awaitable. Other functions they provide will not be awaitable either and will probably be blocking if called within an async view. +Extension authors can support async functions by utilising the +:meth:`flask.Flask.ensure_sync` method. For example, if the extension +provides a view function decorator add ``ensure_sync`` before calling +the decorated function, + +.. code-block:: python + + def extension(func): + @wraps(func) + def wrapper(*args, **kwargs): + ... # Extension logic + return current_app.ensure_sync(func)(*args, **kwargs) + + return wrapper + Check the changelog of the extension you want to use to see if they've implemented async support, or make a feature request or PR to them. diff --git a/src/flask/app.py b/src/flask/app.py index 7afb0a1e..85306d7c 100644 --- a/src/flask/app.py +++ b/src/flask/app.py @@ -16,6 +16,7 @@ from werkzeug.exceptions import BadRequest from werkzeug.exceptions import BadRequestKeyError from werkzeug.exceptions import HTTPException from werkzeug.exceptions import InternalServerError +from werkzeug.local import ContextVar from werkzeug.routing import BuildError from werkzeug.routing import Map from werkzeug.routing import MapAdapter @@ -40,7 +41,6 @@ from .helpers import get_env from .helpers import get_flashed_messages from .helpers import get_load_dotenv from .helpers import locked_cached_property -from .helpers import run_async from .helpers import url_for from .json import jsonify from .logging import create_logger @@ -1080,14 +1080,12 @@ class Flask(Scaffold): self.url_map.add(rule) if view_func is not None: old_func = self.view_functions.get(endpoint) - if getattr(old_func, "_flask_sync_wrapper", False): - old_func = old_func.__wrapped__ # type: ignore if old_func is not None and old_func != view_func: raise AssertionError( "View function mapping is overwriting an existing" f" endpoint function: {endpoint}" ) - self.view_functions[endpoint] = self.ensure_sync(view_func) + self.view_functions[endpoint] = view_func @setupmethod def template_filter(self, name: t.Optional[str] = None) -> t.Callable: @@ -1208,7 +1206,7 @@ class Flask(Scaffold): .. versionadded:: 0.8 """ - self.before_first_request_funcs.append(self.ensure_sync(f)) + self.before_first_request_funcs.append(f) return f @setupmethod @@ -1241,7 +1239,7 @@ class Flask(Scaffold): .. versionadded:: 0.9 """ - self.teardown_appcontext_funcs.append(self.ensure_sync(f)) + self.teardown_appcontext_funcs.append(f) return f @setupmethod @@ -1308,7 +1306,7 @@ class Flask(Scaffold): handler = self._find_error_handler(e) if handler is None: return e - return handler(e) + return self.ensure_sync(handler)(e) def trap_http_exception(self, e: Exception) -> bool: """Checks if an HTTP exception should be trapped or not. By default @@ -1375,7 +1373,7 @@ class Flask(Scaffold): if handler is None: raise - return handler(e) + return self.ensure_sync(handler)(e) def handle_exception(self, e: Exception) -> Response: """Handle an exception that did not have an error handler @@ -1422,7 +1420,7 @@ class Flask(Scaffold): handler = self._find_error_handler(server_error) if handler is not None: - server_error = handler(server_error) + server_error = self.ensure_sync(handler)(server_error) return self.finalize_request(server_error, from_error_handler=True) @@ -1484,7 +1482,7 @@ class Flask(Scaffold): ): return self.make_default_options_response() # otherwise dispatch to the handler for that endpoint - return self.view_functions[rule.endpoint](**req.view_args) + return self.ensure_sync(self.view_functions[rule.endpoint])(**req.view_args) def full_dispatch_request(self) -> Response: """Dispatches the request and on top of that performs request @@ -1545,7 +1543,7 @@ class Flask(Scaffold): if self._got_first_request: return for func in self.before_first_request_funcs: - func() + self.ensure_sync(func)() self._got_first_request = True def make_default_options_response(self) -> Response: @@ -1581,10 +1579,40 @@ class Flask(Scaffold): .. versionadded:: 2.0 """ if iscoroutinefunction(func): - return run_async(func) + return self.async_to_sync(func) return func + def async_to_sync( + self, func: t.Callable[..., t.Coroutine] + ) -> t.Callable[..., t.Any]: + """Return a sync function that will run the coroutine function. + + .. code-block:: python + + result = app.async_to_sync(func)(*args, **kwargs) + + Override this method to change how the app converts async code + to be synchronously callable. + + .. versionadded:: 2.0 + """ + try: + from asgiref.sync import async_to_sync as asgiref_async_to_sync + except ImportError: + raise RuntimeError( + "Install Flask with the 'async' extra in order to use async views." + ) + + # Check that Werkzeug isn't using its fallback ContextVar class. + if ContextVar.__module__ == "werkzeug.local": + raise RuntimeError( + "Async cannot be used with this combination of Python " + "and Greenlet versions." + ) + + return asgiref_async_to_sync(func) + def make_response(self, rv: ResponseReturnValue) -> Response: """Convert the return value from a view function to an instance of :attr:`response_class`. @@ -1807,7 +1835,7 @@ class Flask(Scaffold): if bp in self.before_request_funcs: funcs = chain(funcs, self.before_request_funcs[bp]) for func in funcs: - rv = func() + rv = self.ensure_sync(func)() if rv is not None: return rv @@ -1834,7 +1862,7 @@ class Flask(Scaffold): if None in self.after_request_funcs: funcs = chain(funcs, reversed(self.after_request_funcs[None])) for handler in funcs: - response = handler(response) + response = self.ensure_sync(handler)(response) if not self.session_interface.is_null_session(ctx.session): self.session_interface.save_session(self, ctx.session, response) return response @@ -1871,7 +1899,7 @@ class Flask(Scaffold): if bp in self.teardown_request_funcs: funcs = chain(funcs, reversed(self.teardown_request_funcs[bp])) for func in funcs: - func(exc) + self.ensure_sync(func)(exc) request_tearing_down.send(self, exc=exc) def do_teardown_appcontext( @@ -1894,7 +1922,7 @@ class Flask(Scaffold): if exc is _sentinel: exc = sys.exc_info()[1] for func in reversed(self.teardown_appcontext_funcs): - func(exc) + self.ensure_sync(func)(exc) appcontext_tearing_down.send(self, exc=exc) def app_context(self) -> AppContext: diff --git a/src/flask/blueprints.py b/src/flask/blueprints.py index a2b6c0f5..5fb84d86 100644 --- a/src/flask/blueprints.py +++ b/src/flask/blueprints.py @@ -292,13 +292,10 @@ class Blueprint(Scaffold): # Merge blueprint data into parent. if first_registration: - def extend(bp_dict, parent_dict, ensure_sync=False): + def extend(bp_dict, parent_dict): for key, values in bp_dict.items(): key = self.name if key is None else f"{self.name}.{key}" - if ensure_sync: - values = [app.ensure_sync(func) for func in values] - parent_dict[key].extend(values) for key, value in self.error_handler_spec.items(): @@ -307,8 +304,7 @@ class Blueprint(Scaffold): dict, { code: { - exc_class: app.ensure_sync(func) - for exc_class, func in code_values.items() + exc_class: func for exc_class, func in code_values.items() } for code, code_values in value.items() }, @@ -316,16 +312,13 @@ class Blueprint(Scaffold): app.error_handler_spec[key] = value for endpoint, func in self.view_functions.items(): - app.view_functions[endpoint] = app.ensure_sync(func) + app.view_functions[endpoint] = func - extend( - self.before_request_funcs, app.before_request_funcs, ensure_sync=True - ) - extend(self.after_request_funcs, app.after_request_funcs, ensure_sync=True) + extend(self.before_request_funcs, app.before_request_funcs) + extend(self.after_request_funcs, app.after_request_funcs) extend( self.teardown_request_funcs, app.teardown_request_funcs, - ensure_sync=True, ) extend(self.url_default_functions, app.url_default_functions) extend(self.url_value_preprocessors, app.url_value_preprocessors) @@ -478,9 +471,7 @@ class Blueprint(Scaffold): before each request, even if outside of a blueprint. """ self.record_once( - lambda s: s.app.before_request_funcs.setdefault(None, []).append( - s.app.ensure_sync(f) - ) + lambda s: s.app.before_request_funcs.setdefault(None, []).append(f) ) return f @@ -490,9 +481,7 @@ class Blueprint(Scaffold): """Like :meth:`Flask.before_first_request`. Such a function is executed before the first request to the application. """ - self.record_once( - lambda s: s.app.before_first_request_funcs.append(s.app.ensure_sync(f)) - ) + self.record_once(lambda s: s.app.before_first_request_funcs.append(f)) return f def after_app_request(self, f: AfterRequestCallable) -> AfterRequestCallable: @@ -500,9 +489,7 @@ class Blueprint(Scaffold): is executed after each request, even if outside of the blueprint. """ self.record_once( - lambda s: s.app.after_request_funcs.setdefault(None, []).append( - s.app.ensure_sync(f) - ) + lambda s: s.app.after_request_funcs.setdefault(None, []).append(f) ) return f @@ -553,14 +540,3 @@ class Blueprint(Scaffold): lambda s: s.app.url_default_functions.setdefault(None, []).append(f) ) return f - - def ensure_sync(self, f: t.Callable) -> t.Callable: - """Ensure the function is synchronous. - - Override if you would like custom async to sync behaviour in - this blueprint. Otherwise the app's - :meth:`~flask.Flask.ensure_sync` is used. - - .. versionadded:: 2.0 - """ - return f diff --git a/src/flask/helpers.py b/src/flask/helpers.py index f7d37ed7..109f544f 100644 --- a/src/flask/helpers.py +++ b/src/flask/helpers.py @@ -6,12 +6,10 @@ import typing as t import warnings from datetime import timedelta from functools import update_wrapper -from functools import wraps from threading import RLock import werkzeug.utils from werkzeug.exceptions import NotFound -from werkzeug.local import ContextVar from werkzeug.routing import BuildError from werkzeug.urls import url_quote @@ -801,51 +799,3 @@ def is_ip(value: str) -> bool: return True return False - - -def run_async(func: t.Callable[..., t.Coroutine]) -> t.Callable[..., t.Any]: - """Return a sync function that will run the coroutine function *func*.""" - try: - from asgiref.sync import async_to_sync - except ImportError: - raise RuntimeError( - "Install Flask with the 'async' extra in order to use async views." - ) - - # Check that Werkzeug isn't using its fallback ContextVar class. - if ContextVar.__module__ == "werkzeug.local": - raise RuntimeError( - "Async cannot be used with this combination of Python & Greenlet versions." - ) - - @wraps(func) - def outer(*args: t.Any, **kwargs: t.Any) -> t.Any: - """This function grabs the current context for the inner function. - - This is similar to the copy_current_xxx_context functions in the - ctx module, except it has an async inner. - """ - ctx = None - - if _request_ctx_stack.top is not None: - ctx = _request_ctx_stack.top.copy() - - @wraps(func) - async def inner(*a: t.Any, **k: t.Any) -> t.Any: - """This restores the context before awaiting the func. - - This is required as the function must be awaited within the - context. Only calling ``func`` (as per the - ``copy_current_xxx_context`` functions) doesn't work as the - with block will close before the coroutine is awaited. - """ - if ctx is not None: - with ctx: - return await func(*a, **k) - else: - return await func(*a, **k) - - return async_to_sync(inner)(*args, **kwargs) - - outer._flask_sync_wrapper = True # type: ignore - return outer diff --git a/src/flask/scaffold.py b/src/flask/scaffold.py index d40dfdd8..56d37ddd 100644 --- a/src/flask/scaffold.py +++ b/src/flask/scaffold.py @@ -521,7 +521,7 @@ class Scaffold: """ def decorator(f): - self.view_functions[endpoint] = self.ensure_sync(f) + self.view_functions[endpoint] = f return f return decorator @@ -545,7 +545,7 @@ class Scaffold: return value from the view, and further request handling is stopped. """ - self.before_request_funcs.setdefault(None, []).append(self.ensure_sync(f)) + self.before_request_funcs.setdefault(None, []).append(f) return f @setupmethod @@ -561,7 +561,7 @@ class Scaffold: should not be used for actions that must execute, such as to close resources. Use :meth:`teardown_request` for that. """ - self.after_request_funcs.setdefault(None, []).append(self.ensure_sync(f)) + self.after_request_funcs.setdefault(None, []).append(f) return f @setupmethod @@ -600,7 +600,7 @@ class Scaffold: debugger can still access it. This behavior can be controlled by the ``PRESERVE_CONTEXT_ON_EXCEPTION`` configuration variable. """ - self.teardown_request_funcs.setdefault(None, []).append(self.ensure_sync(f)) + self.teardown_request_funcs.setdefault(None, []).append(f) return f @setupmethod @@ -706,7 +706,7 @@ class Scaffold: " instead." ) - self.error_handler_spec[None][code][exc_class] = self.ensure_sync(f) + self.error_handler_spec[None][code][exc_class] = f @staticmethod def _get_exc_class_and_code( @@ -734,9 +734,6 @@ class Scaffold: else: return exc_class, None - def ensure_sync(self, func: t.Callable) -> t.Callable: - raise NotImplementedError() - def _endpoint_from_view_func(view_func: t.Callable) -> str: """Internal helper that returns the default endpoint for a given diff --git a/tests/test_async.py b/tests/test_async.py index 8c096f69..26a91118 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -6,7 +6,6 @@ import pytest from flask import Blueprint from flask import Flask from flask import request -from flask.helpers import run_async pytest.importorskip("asgiref") @@ -136,5 +135,6 @@ def test_async_before_after_request(): @pytest.mark.skipif(sys.version_info >= (3, 7), reason="should only raise Python < 3.7") def test_async_runtime_error(): + app = Flask(__name__) with pytest.raises(RuntimeError): - run_async(None) + app.async_to_sync(None)