Move async to sync conversion to helpers
Co-authored-by: Uma Annamalai <umaannamalai@users.noreply.github.com> Co-authored-by: Kevin Yang <kkaiyang94@users.noreply.github.com> Co-authored-by: Katherine Kelly <kat-star@users.noreply.github.com>
This commit is contained in:
parent
aac67289e5
commit
60e3d34d19
2 changed files with 64 additions and 67 deletions
|
|
@ -16,7 +16,6 @@ from werkzeug.exceptions import BadRequest
|
|||
from werkzeug.exceptions import BadRequestKeyError
|
||||
from werkzeug.exceptions import HTTPException
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
from werkzeug.local import ContextVar
|
||||
from werkzeug.routing import BuildError
|
||||
from werkzeug.routing import Map
|
||||
from werkzeug.routing import MapAdapter
|
||||
|
|
@ -37,6 +36,7 @@ from .globals import g
|
|||
from .globals import request
|
||||
from .globals import session
|
||||
from .helpers import _split_blueprint_path
|
||||
from .helpers import ensure_sync
|
||||
from .helpers import get_debug_flag
|
||||
from .helpers import get_env
|
||||
from .helpers import get_flashed_messages
|
||||
|
|
@ -79,19 +79,6 @@ if t.TYPE_CHECKING:
|
|||
from .testing import FlaskClient
|
||||
from .testing import FlaskCliRunner
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
iscoroutinefunction = inspect.iscoroutinefunction
|
||||
else:
|
||||
|
||||
def iscoroutinefunction(func: t.Any) -> bool:
|
||||
while inspect.ismethod(func):
|
||||
func = func.__func__
|
||||
|
||||
while isinstance(func, functools.partial):
|
||||
func = func.func
|
||||
|
||||
return inspect.iscoroutinefunction(func)
|
||||
|
||||
|
||||
def _make_timedelta(value: t.Optional[timedelta]) -> t.Optional[timedelta]:
|
||||
if value is None or isinstance(value, timedelta):
|
||||
|
|
@ -1323,7 +1310,7 @@ class Flask(Scaffold):
|
|||
handler = self._find_error_handler(e)
|
||||
if handler is None:
|
||||
return e
|
||||
return self.ensure_sync(handler)(e)
|
||||
return ensure_sync(handler)(e)
|
||||
|
||||
def trap_http_exception(self, e: Exception) -> bool:
|
||||
"""Checks if an HTTP exception should be trapped or not. By default
|
||||
|
|
@ -1390,7 +1377,7 @@ class Flask(Scaffold):
|
|||
if handler is None:
|
||||
raise
|
||||
|
||||
return self.ensure_sync(handler)(e)
|
||||
return ensure_sync(handler)(e)
|
||||
|
||||
def handle_exception(self, e: Exception) -> Response:
|
||||
"""Handle an exception that did not have an error handler
|
||||
|
|
@ -1437,7 +1424,7 @@ class Flask(Scaffold):
|
|||
handler = self._find_error_handler(server_error)
|
||||
|
||||
if handler is not None:
|
||||
server_error = self.ensure_sync(handler)(server_error)
|
||||
server_error = ensure_sync(handler)(server_error)
|
||||
|
||||
return self.finalize_request(server_error, from_error_handler=True)
|
||||
|
||||
|
|
@ -1499,7 +1486,7 @@ class Flask(Scaffold):
|
|||
):
|
||||
return self.make_default_options_response()
|
||||
# otherwise dispatch to the handler for that endpoint
|
||||
return self.ensure_sync(self.view_functions[rule.endpoint])(**req.view_args)
|
||||
return ensure_sync(self.view_functions[rule.endpoint])(**req.view_args)
|
||||
|
||||
def full_dispatch_request(self) -> Response:
|
||||
"""Dispatches the request and on top of that performs request
|
||||
|
|
@ -1560,7 +1547,7 @@ class Flask(Scaffold):
|
|||
if self._got_first_request:
|
||||
return
|
||||
for func in self.before_first_request_funcs:
|
||||
self.ensure_sync(func)()
|
||||
ensure_sync(func)()
|
||||
self._got_first_request = True
|
||||
|
||||
def make_default_options_response(self) -> Response:
|
||||
|
|
@ -1586,50 +1573,6 @@ class Flask(Scaffold):
|
|||
"""
|
||||
return False
|
||||
|
||||
def ensure_sync(self, func: t.Callable) -> t.Callable:
|
||||
"""Ensure that the function is synchronous for WSGI workers.
|
||||
Plain ``def`` functions are returned as-is. ``async def``
|
||||
functions are wrapped to run and wait for the response.
|
||||
|
||||
Override this method to change how the app runs async views.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
if iscoroutinefunction(func):
|
||||
return self.async_to_sync(func)
|
||||
|
||||
return func
|
||||
|
||||
def async_to_sync(
|
||||
self, func: t.Callable[..., t.Coroutine]
|
||||
) -> t.Callable[..., t.Any]:
|
||||
"""Return a sync function that will run the coroutine function.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
result = app.async_to_sync(func)(*args, **kwargs)
|
||||
|
||||
Override this method to change how the app converts async code
|
||||
to be synchronously callable.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
try:
|
||||
from asgiref.sync import async_to_sync as asgiref_async_to_sync
|
||||
except ImportError:
|
||||
raise RuntimeError(
|
||||
"Install Flask with the 'async' extra in order to use async views."
|
||||
)
|
||||
|
||||
# Check that Werkzeug isn't using its fallback ContextVar class.
|
||||
if ContextVar.__module__ == "werkzeug.local":
|
||||
raise RuntimeError(
|
||||
"Async cannot be used with this combination of Python "
|
||||
"and Greenlet versions."
|
||||
)
|
||||
|
||||
return asgiref_async_to_sync(func)
|
||||
|
||||
def make_response(self, rv: ResponseReturnValue) -> Response:
|
||||
"""Convert the return value from a view function to an instance of
|
||||
:attr:`response_class`.
|
||||
|
|
@ -1857,7 +1800,7 @@ class Flask(Scaffold):
|
|||
if bp in self.before_request_funcs:
|
||||
funcs = chain(funcs, self.before_request_funcs[bp])
|
||||
for func in funcs:
|
||||
rv = self.ensure_sync(func)()
|
||||
rv = ensure_sync(func)()
|
||||
if rv is not None:
|
||||
return rv
|
||||
|
||||
|
|
@ -1884,7 +1827,7 @@ class Flask(Scaffold):
|
|||
if None in self.after_request_funcs:
|
||||
funcs = chain(funcs, reversed(self.after_request_funcs[None]))
|
||||
for handler in funcs:
|
||||
response = self.ensure_sync(handler)(response)
|
||||
response = ensure_sync(handler)(response)
|
||||
if not self.session_interface.is_null_session(ctx.session):
|
||||
self.session_interface.save_session(self, ctx.session, response)
|
||||
return response
|
||||
|
|
@ -1921,7 +1864,7 @@ class Flask(Scaffold):
|
|||
if bp in self.teardown_request_funcs:
|
||||
funcs = chain(funcs, reversed(self.teardown_request_funcs[bp]))
|
||||
for func in funcs:
|
||||
self.ensure_sync(func)(exc)
|
||||
ensure_sync(func)(exc)
|
||||
request_tearing_down.send(self, exc=exc)
|
||||
|
||||
def do_teardown_appcontext(
|
||||
|
|
@ -1944,7 +1887,7 @@ class Flask(Scaffold):
|
|||
if exc is _sentinel:
|
||||
exc = sys.exc_info()[1]
|
||||
for func in reversed(self.teardown_appcontext_funcs):
|
||||
self.ensure_sync(func)(exc)
|
||||
ensure_sync(func)(exc)
|
||||
appcontext_tearing_down.send(self, exc=exc)
|
||||
|
||||
def app_context(self) -> AppContext:
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import inspect
|
||||
import os
|
||||
import pkgutil
|
||||
import socket
|
||||
|
|
@ -12,6 +13,7 @@ from threading import RLock
|
|||
|
||||
import werkzeug.utils
|
||||
from werkzeug.exceptions import NotFound
|
||||
from werkzeug.local import ContextVar
|
||||
from werkzeug.routing import BuildError
|
||||
from werkzeug.urls import url_quote
|
||||
|
||||
|
|
@ -25,6 +27,19 @@ from .signals import message_flashed
|
|||
if t.TYPE_CHECKING:
|
||||
from .wrappers import Response
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
iscoroutinefunction = inspect.iscoroutinefunction
|
||||
else:
|
||||
|
||||
def iscoroutinefunction(func: t.Any) -> bool:
|
||||
while inspect.ismethod(func):
|
||||
func = func.__func__
|
||||
|
||||
while isinstance(func, functools.partial):
|
||||
func = func.func
|
||||
|
||||
return inspect.iscoroutinefunction(func)
|
||||
|
||||
|
||||
def get_env() -> str:
|
||||
"""Get the environment the app is running in, indicated by the
|
||||
|
|
@ -834,3 +849,42 @@ def _split_blueprint_path(name: str) -> t.List[str]:
|
|||
out.extend(_split_blueprint_path(name.rpartition(".")[0]))
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def ensure_sync(func: t.Callable) -> t.Callable:
|
||||
"""Ensure that the function is synchronous for WSGI workers.
|
||||
Plain ``def`` functions are returned as-is. ``async def``
|
||||
functions are wrapped to run and wait for the response.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
if iscoroutinefunction(func):
|
||||
return async_to_sync(func)
|
||||
|
||||
return func
|
||||
|
||||
|
||||
def async_to_sync(func: t.Callable[..., t.Coroutine]) -> t.Callable[..., t.Any]:
|
||||
"""Return a sync function that will run the coroutine function.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
result = async_to_sync(func)(*args, **kwargs)
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
try:
|
||||
from asgiref.sync import async_to_sync as asgiref_async_to_sync
|
||||
except ImportError:
|
||||
raise RuntimeError(
|
||||
"Install Flask with the 'async' extra in order to use async views."
|
||||
)
|
||||
|
||||
# Check that Werkzeug isn't using its fallback ContextVar class.
|
||||
if ContextVar.__module__ == "werkzeug.local":
|
||||
raise RuntimeError(
|
||||
"Async cannot be used with this combination of Python "
|
||||
"and Greenlet versions."
|
||||
)
|
||||
|
||||
return asgiref_async_to_sync(func)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue