Move async to sync conversion to helpers
Co-authored-by: Uma Annamalai <umaannamalai@users.noreply.github.com> Co-authored-by: Kevin Yang <kkaiyang94@users.noreply.github.com> Co-authored-by: Katherine Kelly <kat-star@users.noreply.github.com>
This commit is contained in:
parent
aac67289e5
commit
60e3d34d19
2 changed files with 64 additions and 67 deletions
|
|
@ -16,7 +16,6 @@ from werkzeug.exceptions import BadRequest
|
||||||
from werkzeug.exceptions import BadRequestKeyError
|
from werkzeug.exceptions import BadRequestKeyError
|
||||||
from werkzeug.exceptions import HTTPException
|
from werkzeug.exceptions import HTTPException
|
||||||
from werkzeug.exceptions import InternalServerError
|
from werkzeug.exceptions import InternalServerError
|
||||||
from werkzeug.local import ContextVar
|
|
||||||
from werkzeug.routing import BuildError
|
from werkzeug.routing import BuildError
|
||||||
from werkzeug.routing import Map
|
from werkzeug.routing import Map
|
||||||
from werkzeug.routing import MapAdapter
|
from werkzeug.routing import MapAdapter
|
||||||
|
|
@ -37,6 +36,7 @@ from .globals import g
|
||||||
from .globals import request
|
from .globals import request
|
||||||
from .globals import session
|
from .globals import session
|
||||||
from .helpers import _split_blueprint_path
|
from .helpers import _split_blueprint_path
|
||||||
|
from .helpers import ensure_sync
|
||||||
from .helpers import get_debug_flag
|
from .helpers import get_debug_flag
|
||||||
from .helpers import get_env
|
from .helpers import get_env
|
||||||
from .helpers import get_flashed_messages
|
from .helpers import get_flashed_messages
|
||||||
|
|
@ -79,19 +79,6 @@ if t.TYPE_CHECKING:
|
||||||
from .testing import FlaskClient
|
from .testing import FlaskClient
|
||||||
from .testing import FlaskCliRunner
|
from .testing import FlaskCliRunner
|
||||||
|
|
||||||
if sys.version_info >= (3, 8):
|
|
||||||
iscoroutinefunction = inspect.iscoroutinefunction
|
|
||||||
else:
|
|
||||||
|
|
||||||
def iscoroutinefunction(func: t.Any) -> bool:
|
|
||||||
while inspect.ismethod(func):
|
|
||||||
func = func.__func__
|
|
||||||
|
|
||||||
while isinstance(func, functools.partial):
|
|
||||||
func = func.func
|
|
||||||
|
|
||||||
return inspect.iscoroutinefunction(func)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_timedelta(value: t.Optional[timedelta]) -> t.Optional[timedelta]:
|
def _make_timedelta(value: t.Optional[timedelta]) -> t.Optional[timedelta]:
|
||||||
if value is None or isinstance(value, timedelta):
|
if value is None or isinstance(value, timedelta):
|
||||||
|
|
@ -1323,7 +1310,7 @@ class Flask(Scaffold):
|
||||||
handler = self._find_error_handler(e)
|
handler = self._find_error_handler(e)
|
||||||
if handler is None:
|
if handler is None:
|
||||||
return e
|
return e
|
||||||
return self.ensure_sync(handler)(e)
|
return ensure_sync(handler)(e)
|
||||||
|
|
||||||
def trap_http_exception(self, e: Exception) -> bool:
|
def trap_http_exception(self, e: Exception) -> bool:
|
||||||
"""Checks if an HTTP exception should be trapped or not. By default
|
"""Checks if an HTTP exception should be trapped or not. By default
|
||||||
|
|
@ -1390,7 +1377,7 @@ class Flask(Scaffold):
|
||||||
if handler is None:
|
if handler is None:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
return self.ensure_sync(handler)(e)
|
return ensure_sync(handler)(e)
|
||||||
|
|
||||||
def handle_exception(self, e: Exception) -> Response:
|
def handle_exception(self, e: Exception) -> Response:
|
||||||
"""Handle an exception that did not have an error handler
|
"""Handle an exception that did not have an error handler
|
||||||
|
|
@ -1437,7 +1424,7 @@ class Flask(Scaffold):
|
||||||
handler = self._find_error_handler(server_error)
|
handler = self._find_error_handler(server_error)
|
||||||
|
|
||||||
if handler is not None:
|
if handler is not None:
|
||||||
server_error = self.ensure_sync(handler)(server_error)
|
server_error = ensure_sync(handler)(server_error)
|
||||||
|
|
||||||
return self.finalize_request(server_error, from_error_handler=True)
|
return self.finalize_request(server_error, from_error_handler=True)
|
||||||
|
|
||||||
|
|
@ -1499,7 +1486,7 @@ class Flask(Scaffold):
|
||||||
):
|
):
|
||||||
return self.make_default_options_response()
|
return self.make_default_options_response()
|
||||||
# otherwise dispatch to the handler for that endpoint
|
# otherwise dispatch to the handler for that endpoint
|
||||||
return self.ensure_sync(self.view_functions[rule.endpoint])(**req.view_args)
|
return ensure_sync(self.view_functions[rule.endpoint])(**req.view_args)
|
||||||
|
|
||||||
def full_dispatch_request(self) -> Response:
|
def full_dispatch_request(self) -> Response:
|
||||||
"""Dispatches the request and on top of that performs request
|
"""Dispatches the request and on top of that performs request
|
||||||
|
|
@ -1560,7 +1547,7 @@ class Flask(Scaffold):
|
||||||
if self._got_first_request:
|
if self._got_first_request:
|
||||||
return
|
return
|
||||||
for func in self.before_first_request_funcs:
|
for func in self.before_first_request_funcs:
|
||||||
self.ensure_sync(func)()
|
ensure_sync(func)()
|
||||||
self._got_first_request = True
|
self._got_first_request = True
|
||||||
|
|
||||||
def make_default_options_response(self) -> Response:
|
def make_default_options_response(self) -> Response:
|
||||||
|
|
@ -1586,50 +1573,6 @@ class Flask(Scaffold):
|
||||||
"""
|
"""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def ensure_sync(self, func: t.Callable) -> t.Callable:
|
|
||||||
"""Ensure that the function is synchronous for WSGI workers.
|
|
||||||
Plain ``def`` functions are returned as-is. ``async def``
|
|
||||||
functions are wrapped to run and wait for the response.
|
|
||||||
|
|
||||||
Override this method to change how the app runs async views.
|
|
||||||
|
|
||||||
.. versionadded:: 2.0
|
|
||||||
"""
|
|
||||||
if iscoroutinefunction(func):
|
|
||||||
return self.async_to_sync(func)
|
|
||||||
|
|
||||||
return func
|
|
||||||
|
|
||||||
def async_to_sync(
|
|
||||||
self, func: t.Callable[..., t.Coroutine]
|
|
||||||
) -> t.Callable[..., t.Any]:
|
|
||||||
"""Return a sync function that will run the coroutine function.
|
|
||||||
|
|
||||||
.. code-block:: python
|
|
||||||
|
|
||||||
result = app.async_to_sync(func)(*args, **kwargs)
|
|
||||||
|
|
||||||
Override this method to change how the app converts async code
|
|
||||||
to be synchronously callable.
|
|
||||||
|
|
||||||
.. versionadded:: 2.0
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from asgiref.sync import async_to_sync as asgiref_async_to_sync
|
|
||||||
except ImportError:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Install Flask with the 'async' extra in order to use async views."
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check that Werkzeug isn't using its fallback ContextVar class.
|
|
||||||
if ContextVar.__module__ == "werkzeug.local":
|
|
||||||
raise RuntimeError(
|
|
||||||
"Async cannot be used with this combination of Python "
|
|
||||||
"and Greenlet versions."
|
|
||||||
)
|
|
||||||
|
|
||||||
return asgiref_async_to_sync(func)
|
|
||||||
|
|
||||||
def make_response(self, rv: ResponseReturnValue) -> Response:
|
def make_response(self, rv: ResponseReturnValue) -> Response:
|
||||||
"""Convert the return value from a view function to an instance of
|
"""Convert the return value from a view function to an instance of
|
||||||
:attr:`response_class`.
|
:attr:`response_class`.
|
||||||
|
|
@ -1857,7 +1800,7 @@ class Flask(Scaffold):
|
||||||
if bp in self.before_request_funcs:
|
if bp in self.before_request_funcs:
|
||||||
funcs = chain(funcs, self.before_request_funcs[bp])
|
funcs = chain(funcs, self.before_request_funcs[bp])
|
||||||
for func in funcs:
|
for func in funcs:
|
||||||
rv = self.ensure_sync(func)()
|
rv = ensure_sync(func)()
|
||||||
if rv is not None:
|
if rv is not None:
|
||||||
return rv
|
return rv
|
||||||
|
|
||||||
|
|
@ -1884,7 +1827,7 @@ class Flask(Scaffold):
|
||||||
if None in self.after_request_funcs:
|
if None in self.after_request_funcs:
|
||||||
funcs = chain(funcs, reversed(self.after_request_funcs[None]))
|
funcs = chain(funcs, reversed(self.after_request_funcs[None]))
|
||||||
for handler in funcs:
|
for handler in funcs:
|
||||||
response = self.ensure_sync(handler)(response)
|
response = ensure_sync(handler)(response)
|
||||||
if not self.session_interface.is_null_session(ctx.session):
|
if not self.session_interface.is_null_session(ctx.session):
|
||||||
self.session_interface.save_session(self, ctx.session, response)
|
self.session_interface.save_session(self, ctx.session, response)
|
||||||
return response
|
return response
|
||||||
|
|
@ -1921,7 +1864,7 @@ class Flask(Scaffold):
|
||||||
if bp in self.teardown_request_funcs:
|
if bp in self.teardown_request_funcs:
|
||||||
funcs = chain(funcs, reversed(self.teardown_request_funcs[bp]))
|
funcs = chain(funcs, reversed(self.teardown_request_funcs[bp]))
|
||||||
for func in funcs:
|
for func in funcs:
|
||||||
self.ensure_sync(func)(exc)
|
ensure_sync(func)(exc)
|
||||||
request_tearing_down.send(self, exc=exc)
|
request_tearing_down.send(self, exc=exc)
|
||||||
|
|
||||||
def do_teardown_appcontext(
|
def do_teardown_appcontext(
|
||||||
|
|
@ -1944,7 +1887,7 @@ class Flask(Scaffold):
|
||||||
if exc is _sentinel:
|
if exc is _sentinel:
|
||||||
exc = sys.exc_info()[1]
|
exc = sys.exc_info()[1]
|
||||||
for func in reversed(self.teardown_appcontext_funcs):
|
for func in reversed(self.teardown_appcontext_funcs):
|
||||||
self.ensure_sync(func)(exc)
|
ensure_sync(func)(exc)
|
||||||
appcontext_tearing_down.send(self, exc=exc)
|
appcontext_tearing_down.send(self, exc=exc)
|
||||||
|
|
||||||
def app_context(self) -> AppContext:
|
def app_context(self) -> AppContext:
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
import inspect
|
||||||
import os
|
import os
|
||||||
import pkgutil
|
import pkgutil
|
||||||
import socket
|
import socket
|
||||||
|
|
@ -12,6 +13,7 @@ from threading import RLock
|
||||||
|
|
||||||
import werkzeug.utils
|
import werkzeug.utils
|
||||||
from werkzeug.exceptions import NotFound
|
from werkzeug.exceptions import NotFound
|
||||||
|
from werkzeug.local import ContextVar
|
||||||
from werkzeug.routing import BuildError
|
from werkzeug.routing import BuildError
|
||||||
from werkzeug.urls import url_quote
|
from werkzeug.urls import url_quote
|
||||||
|
|
||||||
|
|
@ -25,6 +27,19 @@ from .signals import message_flashed
|
||||||
if t.TYPE_CHECKING:
|
if t.TYPE_CHECKING:
|
||||||
from .wrappers import Response
|
from .wrappers import Response
|
||||||
|
|
||||||
|
if sys.version_info >= (3, 8):
|
||||||
|
iscoroutinefunction = inspect.iscoroutinefunction
|
||||||
|
else:
|
||||||
|
|
||||||
|
def iscoroutinefunction(func: t.Any) -> bool:
|
||||||
|
while inspect.ismethod(func):
|
||||||
|
func = func.__func__
|
||||||
|
|
||||||
|
while isinstance(func, functools.partial):
|
||||||
|
func = func.func
|
||||||
|
|
||||||
|
return inspect.iscoroutinefunction(func)
|
||||||
|
|
||||||
|
|
||||||
def get_env() -> str:
|
def get_env() -> str:
|
||||||
"""Get the environment the app is running in, indicated by the
|
"""Get the environment the app is running in, indicated by the
|
||||||
|
|
@ -834,3 +849,42 @@ def _split_blueprint_path(name: str) -> t.List[str]:
|
||||||
out.extend(_split_blueprint_path(name.rpartition(".")[0]))
|
out.extend(_split_blueprint_path(name.rpartition(".")[0]))
|
||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_sync(func: t.Callable) -> t.Callable:
|
||||||
|
"""Ensure that the function is synchronous for WSGI workers.
|
||||||
|
Plain ``def`` functions are returned as-is. ``async def``
|
||||||
|
functions are wrapped to run and wait for the response.
|
||||||
|
|
||||||
|
.. versionadded:: 2.0
|
||||||
|
"""
|
||||||
|
if iscoroutinefunction(func):
|
||||||
|
return async_to_sync(func)
|
||||||
|
|
||||||
|
return func
|
||||||
|
|
||||||
|
|
||||||
|
def async_to_sync(func: t.Callable[..., t.Coroutine]) -> t.Callable[..., t.Any]:
|
||||||
|
"""Return a sync function that will run the coroutine function.
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
result = async_to_sync(func)(*args, **kwargs)
|
||||||
|
|
||||||
|
.. versionadded:: 2.0
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
from asgiref.sync import async_to_sync as asgiref_async_to_sync
|
||||||
|
except ImportError:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Install Flask with the 'async' extra in order to use async views."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check that Werkzeug isn't using its fallback ContextVar class.
|
||||||
|
if ContextVar.__module__ == "werkzeug.local":
|
||||||
|
raise RuntimeError(
|
||||||
|
"Async cannot be used with this combination of Python "
|
||||||
|
"and Greenlet versions."
|
||||||
|
)
|
||||||
|
|
||||||
|
return asgiref_async_to_sync(func)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue