From 60e3d34d195592fea612ff6e8d688f0200dd3d61 Mon Sep 17 00:00:00 2001 From: Tim Pansino Date: Thu, 3 Jun 2021 14:10:31 -0700 Subject: [PATCH] Move async to sync conversion to helpers Co-authored-by: Uma Annamalai Co-authored-by: Kevin Yang Co-authored-by: Katherine Kelly --- src/flask/app.py | 77 ++++++-------------------------------------- src/flask/helpers.py | 54 +++++++++++++++++++++++++++++++ 2 files changed, 64 insertions(+), 67 deletions(-) diff --git a/src/flask/app.py b/src/flask/app.py index cacb40a5..e4e0a0e7 100644 --- a/src/flask/app.py +++ b/src/flask/app.py @@ -16,7 +16,6 @@ from werkzeug.exceptions import BadRequest from werkzeug.exceptions import BadRequestKeyError from werkzeug.exceptions import HTTPException from werkzeug.exceptions import InternalServerError -from werkzeug.local import ContextVar from werkzeug.routing import BuildError from werkzeug.routing import Map from werkzeug.routing import MapAdapter @@ -37,6 +36,7 @@ from .globals import g from .globals import request from .globals import session from .helpers import _split_blueprint_path +from .helpers import ensure_sync from .helpers import get_debug_flag from .helpers import get_env from .helpers import get_flashed_messages @@ -79,19 +79,6 @@ if t.TYPE_CHECKING: from .testing import FlaskClient from .testing import FlaskCliRunner -if sys.version_info >= (3, 8): - iscoroutinefunction = inspect.iscoroutinefunction -else: - - def iscoroutinefunction(func: t.Any) -> bool: - while inspect.ismethod(func): - func = func.__func__ - - while isinstance(func, functools.partial): - func = func.func - - return inspect.iscoroutinefunction(func) - def _make_timedelta(value: t.Optional[timedelta]) -> t.Optional[timedelta]: if value is None or isinstance(value, timedelta): @@ -1323,7 +1310,7 @@ class Flask(Scaffold): handler = self._find_error_handler(e) if handler is None: return e - return self.ensure_sync(handler)(e) + return ensure_sync(handler)(e) def trap_http_exception(self, e: Exception) -> bool: """Checks if an HTTP exception should be trapped or not. By default @@ -1390,7 +1377,7 @@ class Flask(Scaffold): if handler is None: raise - return self.ensure_sync(handler)(e) + return ensure_sync(handler)(e) def handle_exception(self, e: Exception) -> Response: """Handle an exception that did not have an error handler @@ -1437,7 +1424,7 @@ class Flask(Scaffold): handler = self._find_error_handler(server_error) if handler is not None: - server_error = self.ensure_sync(handler)(server_error) + server_error = ensure_sync(handler)(server_error) return self.finalize_request(server_error, from_error_handler=True) @@ -1499,7 +1486,7 @@ class Flask(Scaffold): ): return self.make_default_options_response() # otherwise dispatch to the handler for that endpoint - return self.ensure_sync(self.view_functions[rule.endpoint])(**req.view_args) + return ensure_sync(self.view_functions[rule.endpoint])(**req.view_args) def full_dispatch_request(self) -> Response: """Dispatches the request and on top of that performs request @@ -1560,7 +1547,7 @@ class Flask(Scaffold): if self._got_first_request: return for func in self.before_first_request_funcs: - self.ensure_sync(func)() + ensure_sync(func)() self._got_first_request = True def make_default_options_response(self) -> Response: @@ -1586,50 +1573,6 @@ class Flask(Scaffold): """ return False - def ensure_sync(self, func: t.Callable) -> t.Callable: - """Ensure that the function is synchronous for WSGI workers. - Plain ``def`` functions are returned as-is. ``async def`` - functions are wrapped to run and wait for the response. - - Override this method to change how the app runs async views. - - .. versionadded:: 2.0 - """ - if iscoroutinefunction(func): - return self.async_to_sync(func) - - return func - - def async_to_sync( - self, func: t.Callable[..., t.Coroutine] - ) -> t.Callable[..., t.Any]: - """Return a sync function that will run the coroutine function. - - .. code-block:: python - - result = app.async_to_sync(func)(*args, **kwargs) - - Override this method to change how the app converts async code - to be synchronously callable. - - .. versionadded:: 2.0 - """ - try: - from asgiref.sync import async_to_sync as asgiref_async_to_sync - except ImportError: - raise RuntimeError( - "Install Flask with the 'async' extra in order to use async views." - ) - - # Check that Werkzeug isn't using its fallback ContextVar class. - if ContextVar.__module__ == "werkzeug.local": - raise RuntimeError( - "Async cannot be used with this combination of Python " - "and Greenlet versions." - ) - - return asgiref_async_to_sync(func) - def make_response(self, rv: ResponseReturnValue) -> Response: """Convert the return value from a view function to an instance of :attr:`response_class`. @@ -1857,7 +1800,7 @@ class Flask(Scaffold): if bp in self.before_request_funcs: funcs = chain(funcs, self.before_request_funcs[bp]) for func in funcs: - rv = self.ensure_sync(func)() + rv = ensure_sync(func)() if rv is not None: return rv @@ -1884,7 +1827,7 @@ class Flask(Scaffold): if None in self.after_request_funcs: funcs = chain(funcs, reversed(self.after_request_funcs[None])) for handler in funcs: - response = self.ensure_sync(handler)(response) + response = ensure_sync(handler)(response) if not self.session_interface.is_null_session(ctx.session): self.session_interface.save_session(self, ctx.session, response) return response @@ -1921,7 +1864,7 @@ class Flask(Scaffold): if bp in self.teardown_request_funcs: funcs = chain(funcs, reversed(self.teardown_request_funcs[bp])) for func in funcs: - self.ensure_sync(func)(exc) + ensure_sync(func)(exc) request_tearing_down.send(self, exc=exc) def do_teardown_appcontext( @@ -1944,7 +1887,7 @@ class Flask(Scaffold): if exc is _sentinel: exc = sys.exc_info()[1] for func in reversed(self.teardown_appcontext_funcs): - self.ensure_sync(func)(exc) + ensure_sync(func)(exc) appcontext_tearing_down.send(self, exc=exc) def app_context(self) -> AppContext: diff --git a/src/flask/helpers.py b/src/flask/helpers.py index 7b8b0870..a8782cac 100644 --- a/src/flask/helpers.py +++ b/src/flask/helpers.py @@ -1,3 +1,4 @@ +import inspect import os import pkgutil import socket @@ -12,6 +13,7 @@ from threading import RLock import werkzeug.utils from werkzeug.exceptions import NotFound +from werkzeug.local import ContextVar from werkzeug.routing import BuildError from werkzeug.urls import url_quote @@ -25,6 +27,19 @@ from .signals import message_flashed if t.TYPE_CHECKING: from .wrappers import Response +if sys.version_info >= (3, 8): + iscoroutinefunction = inspect.iscoroutinefunction +else: + + def iscoroutinefunction(func: t.Any) -> bool: + while inspect.ismethod(func): + func = func.__func__ + + while isinstance(func, functools.partial): + func = func.func + + return inspect.iscoroutinefunction(func) + def get_env() -> str: """Get the environment the app is running in, indicated by the @@ -834,3 +849,42 @@ def _split_blueprint_path(name: str) -> t.List[str]: out.extend(_split_blueprint_path(name.rpartition(".")[0])) return out + + +def ensure_sync(func: t.Callable) -> t.Callable: + """Ensure that the function is synchronous for WSGI workers. + Plain ``def`` functions are returned as-is. ``async def`` + functions are wrapped to run and wait for the response. + + .. versionadded:: 2.0 + """ + if iscoroutinefunction(func): + return async_to_sync(func) + + return func + + +def async_to_sync(func: t.Callable[..., t.Coroutine]) -> t.Callable[..., t.Any]: + """Return a sync function that will run the coroutine function. + + .. code-block:: python + + result = async_to_sync(func)(*args, **kwargs) + + .. versionadded:: 2.0 + """ + try: + from asgiref.sync import async_to_sync as asgiref_async_to_sync + except ImportError: + raise RuntimeError( + "Install Flask with the 'async' extra in order to use async views." + ) + + # Check that Werkzeug isn't using its fallback ContextVar class. + if ContextVar.__module__ == "werkzeug.local": + raise RuntimeError( + "Async cannot be used with this combination of Python " + "and Greenlet versions." + ) + + return asgiref_async_to_sync(func)