Move async to sync conversion to helpers

Co-authored-by: Uma Annamalai <umaannamalai@users.noreply.github.com>
Co-authored-by: Kevin Yang <kkaiyang94@users.noreply.github.com>
Co-authored-by: Katherine Kelly <kat-star@users.noreply.github.com>
This commit is contained in:
Tim Pansino 2021-06-03 14:10:31 -07:00
parent aac67289e5
commit 60e3d34d19
2 changed files with 64 additions and 67 deletions

View file

@ -16,7 +16,6 @@ from werkzeug.exceptions import BadRequest
from werkzeug.exceptions import BadRequestKeyError
from werkzeug.exceptions import HTTPException
from werkzeug.exceptions import InternalServerError
from werkzeug.local import ContextVar
from werkzeug.routing import BuildError
from werkzeug.routing import Map
from werkzeug.routing import MapAdapter
@ -37,6 +36,7 @@ from .globals import g
from .globals import request
from .globals import session
from .helpers import _split_blueprint_path
from .helpers import ensure_sync
from .helpers import get_debug_flag
from .helpers import get_env
from .helpers import get_flashed_messages
@ -79,19 +79,6 @@ if t.TYPE_CHECKING:
from .testing import FlaskClient
from .testing import FlaskCliRunner
if sys.version_info >= (3, 8):
iscoroutinefunction = inspect.iscoroutinefunction
else:
def iscoroutinefunction(func: t.Any) -> bool:
while inspect.ismethod(func):
func = func.__func__
while isinstance(func, functools.partial):
func = func.func
return inspect.iscoroutinefunction(func)
def _make_timedelta(value: t.Optional[timedelta]) -> t.Optional[timedelta]:
if value is None or isinstance(value, timedelta):
@ -1323,7 +1310,7 @@ class Flask(Scaffold):
handler = self._find_error_handler(e)
if handler is None:
return e
return self.ensure_sync(handler)(e)
return ensure_sync(handler)(e)
def trap_http_exception(self, e: Exception) -> bool:
"""Checks if an HTTP exception should be trapped or not. By default
@ -1390,7 +1377,7 @@ class Flask(Scaffold):
if handler is None:
raise
return self.ensure_sync(handler)(e)
return ensure_sync(handler)(e)
def handle_exception(self, e: Exception) -> Response:
"""Handle an exception that did not have an error handler
@ -1437,7 +1424,7 @@ class Flask(Scaffold):
handler = self._find_error_handler(server_error)
if handler is not None:
server_error = self.ensure_sync(handler)(server_error)
server_error = ensure_sync(handler)(server_error)
return self.finalize_request(server_error, from_error_handler=True)
@ -1499,7 +1486,7 @@ class Flask(Scaffold):
):
return self.make_default_options_response()
# otherwise dispatch to the handler for that endpoint
return self.ensure_sync(self.view_functions[rule.endpoint])(**req.view_args)
return ensure_sync(self.view_functions[rule.endpoint])(**req.view_args)
def full_dispatch_request(self) -> Response:
"""Dispatches the request and on top of that performs request
@ -1560,7 +1547,7 @@ class Flask(Scaffold):
if self._got_first_request:
return
for func in self.before_first_request_funcs:
self.ensure_sync(func)()
ensure_sync(func)()
self._got_first_request = True
def make_default_options_response(self) -> Response:
@ -1586,50 +1573,6 @@ class Flask(Scaffold):
"""
return False
def ensure_sync(self, func: t.Callable) -> t.Callable:
"""Ensure that the function is synchronous for WSGI workers.
Plain ``def`` functions are returned as-is. ``async def``
functions are wrapped to run and wait for the response.
Override this method to change how the app runs async views.
.. versionadded:: 2.0
"""
if iscoroutinefunction(func):
return self.async_to_sync(func)
return func
def async_to_sync(
self, func: t.Callable[..., t.Coroutine]
) -> t.Callable[..., t.Any]:
"""Return a sync function that will run the coroutine function.
.. code-block:: python
result = app.async_to_sync(func)(*args, **kwargs)
Override this method to change how the app converts async code
to be synchronously callable.
.. versionadded:: 2.0
"""
try:
from asgiref.sync import async_to_sync as asgiref_async_to_sync
except ImportError:
raise RuntimeError(
"Install Flask with the 'async' extra in order to use async views."
)
# Check that Werkzeug isn't using its fallback ContextVar class.
if ContextVar.__module__ == "werkzeug.local":
raise RuntimeError(
"Async cannot be used with this combination of Python "
"and Greenlet versions."
)
return asgiref_async_to_sync(func)
def make_response(self, rv: ResponseReturnValue) -> Response:
"""Convert the return value from a view function to an instance of
:attr:`response_class`.
@ -1857,7 +1800,7 @@ class Flask(Scaffold):
if bp in self.before_request_funcs:
funcs = chain(funcs, self.before_request_funcs[bp])
for func in funcs:
rv = self.ensure_sync(func)()
rv = ensure_sync(func)()
if rv is not None:
return rv
@ -1884,7 +1827,7 @@ class Flask(Scaffold):
if None in self.after_request_funcs:
funcs = chain(funcs, reversed(self.after_request_funcs[None]))
for handler in funcs:
response = self.ensure_sync(handler)(response)
response = ensure_sync(handler)(response)
if not self.session_interface.is_null_session(ctx.session):
self.session_interface.save_session(self, ctx.session, response)
return response
@ -1921,7 +1864,7 @@ class Flask(Scaffold):
if bp in self.teardown_request_funcs:
funcs = chain(funcs, reversed(self.teardown_request_funcs[bp]))
for func in funcs:
self.ensure_sync(func)(exc)
ensure_sync(func)(exc)
request_tearing_down.send(self, exc=exc)
def do_teardown_appcontext(
@ -1944,7 +1887,7 @@ class Flask(Scaffold):
if exc is _sentinel:
exc = sys.exc_info()[1]
for func in reversed(self.teardown_appcontext_funcs):
self.ensure_sync(func)(exc)
ensure_sync(func)(exc)
appcontext_tearing_down.send(self, exc=exc)
def app_context(self) -> AppContext:

View file

@ -1,3 +1,4 @@
import inspect
import os
import pkgutil
import socket
@ -12,6 +13,7 @@ from threading import RLock
import werkzeug.utils
from werkzeug.exceptions import NotFound
from werkzeug.local import ContextVar
from werkzeug.routing import BuildError
from werkzeug.urls import url_quote
@ -25,6 +27,19 @@ from .signals import message_flashed
if t.TYPE_CHECKING:
from .wrappers import Response
if sys.version_info >= (3, 8):
iscoroutinefunction = inspect.iscoroutinefunction
else:
def iscoroutinefunction(func: t.Any) -> bool:
while inspect.ismethod(func):
func = func.__func__
while isinstance(func, functools.partial):
func = func.func
return inspect.iscoroutinefunction(func)
def get_env() -> str:
"""Get the environment the app is running in, indicated by the
@ -834,3 +849,42 @@ def _split_blueprint_path(name: str) -> t.List[str]:
out.extend(_split_blueprint_path(name.rpartition(".")[0]))
return out
def ensure_sync(func: t.Callable) -> t.Callable:
"""Ensure that the function is synchronous for WSGI workers.
Plain ``def`` functions are returned as-is. ``async def``
functions are wrapped to run and wait for the response.
.. versionadded:: 2.0
"""
if iscoroutinefunction(func):
return async_to_sync(func)
return func
def async_to_sync(func: t.Callable[..., t.Coroutine]) -> t.Callable[..., t.Any]:
"""Return a sync function that will run the coroutine function.
.. code-block:: python
result = async_to_sync(func)(*args, **kwargs)
.. versionadded:: 2.0
"""
try:
from asgiref.sync import async_to_sync as asgiref_async_to_sync
except ImportError:
raise RuntimeError(
"Install Flask with the 'async' extra in order to use async views."
)
# Check that Werkzeug isn't using its fallback ContextVar class.
if ContextVar.__module__ == "werkzeug.local":
raise RuntimeError(
"Async cannot be used with this combination of Python "
"and Greenlet versions."
)
return asgiref_async_to_sync(func)