Simplify the async handling code

Firstly `run_sync` was a misleading name as it didn't run anything,
instead I think `async_to_sync` is much clearer as it converts a
coroutine function to a function. (Name stolen from asgiref).

Secondly trying to run the ensure_sync during registration made the
code more complex and brittle, e.g. the _flask_async_wrapper
usage. This was done to pay any setup costs during registration rather
than runtime, however this only saved a iscoroutne check. It allows
the weirdness of the Blueprint and Scaffold ensure_sync methods to be
removed.

Switching to runtime ensure_sync usage provides a method for
extensions to also support async, as now documented.
This commit is contained in:
pgjones 2021-05-02 20:49:46 +01:00 committed by David Lord
parent cb13128cf0
commit 7f87f3dd93
No known key found for this signature in database
GPG key ID: 7A1C87E3F5BC42A8
6 changed files with 53 additions and 68 deletions

View file

@ -35,12 +35,12 @@ from .globals import _request_ctx_stack
from .globals import g
from .globals import request
from .globals import session
from .helpers import async_to_sync
from .helpers import get_debug_flag
from .helpers import get_env
from .helpers import get_flashed_messages
from .helpers import get_load_dotenv
from .helpers import locked_cached_property
from .helpers import run_async
from .helpers import url_for
from .json import jsonify
from .logging import create_logger
@ -1080,14 +1080,12 @@ class Flask(Scaffold):
self.url_map.add(rule)
if view_func is not None:
old_func = self.view_functions.get(endpoint)
if getattr(old_func, "_flask_sync_wrapper", False):
old_func = old_func.__wrapped__ # type: ignore
if old_func is not None and old_func != view_func:
raise AssertionError(
"View function mapping is overwriting an existing"
f" endpoint function: {endpoint}"
)
self.view_functions[endpoint] = self.ensure_sync(view_func)
self.view_functions[endpoint] = view_func
@setupmethod
def template_filter(self, name: t.Optional[str] = None) -> t.Callable:
@ -1208,7 +1206,7 @@ class Flask(Scaffold):
.. versionadded:: 0.8
"""
self.before_first_request_funcs.append(self.ensure_sync(f))
self.before_first_request_funcs.append(f)
return f
@setupmethod
@ -1241,7 +1239,7 @@ class Flask(Scaffold):
.. versionadded:: 0.9
"""
self.teardown_appcontext_funcs.append(self.ensure_sync(f))
self.teardown_appcontext_funcs.append(f)
return f
@setupmethod
@ -1308,7 +1306,7 @@ class Flask(Scaffold):
handler = self._find_error_handler(e)
if handler is None:
return e
return handler(e)
return self.ensure_sync(handler)(e)
def trap_http_exception(self, e: Exception) -> bool:
"""Checks if an HTTP exception should be trapped or not. By default
@ -1375,7 +1373,7 @@ class Flask(Scaffold):
if handler is None:
raise
return handler(e)
return self.ensure_sync(handler)(e)
def handle_exception(self, e: Exception) -> Response:
"""Handle an exception that did not have an error handler
@ -1422,7 +1420,7 @@ class Flask(Scaffold):
handler = self._find_error_handler(server_error)
if handler is not None:
server_error = handler(server_error)
server_error = self.ensure_sync(handler)(server_error)
return self.finalize_request(server_error, from_error_handler=True)
@ -1484,7 +1482,7 @@ class Flask(Scaffold):
):
return self.make_default_options_response()
# otherwise dispatch to the handler for that endpoint
return self.view_functions[rule.endpoint](**req.view_args)
return self.ensure_sync(self.view_functions[rule.endpoint])(**req.view_args)
def full_dispatch_request(self) -> Response:
"""Dispatches the request and on top of that performs request
@ -1545,7 +1543,7 @@ class Flask(Scaffold):
if self._got_first_request:
return
for func in self.before_first_request_funcs:
func()
self.ensure_sync(func)()
self._got_first_request = True
def make_default_options_response(self) -> Response:
@ -1581,7 +1579,7 @@ class Flask(Scaffold):
.. versionadded:: 2.0
"""
if iscoroutinefunction(func):
return run_async(func)
return async_to_sync(func)
return func
@ -1807,7 +1805,7 @@ class Flask(Scaffold):
if bp in self.before_request_funcs:
funcs = chain(funcs, self.before_request_funcs[bp])
for func in funcs:
rv = func()
rv = self.ensure_sync(func)()
if rv is not None:
return rv
@ -1834,7 +1832,7 @@ class Flask(Scaffold):
if None in self.after_request_funcs:
funcs = chain(funcs, reversed(self.after_request_funcs[None]))
for handler in funcs:
response = handler(response)
response = self.ensure_sync(handler)(response)
if not self.session_interface.is_null_session(ctx.session):
self.session_interface.save_session(self, ctx.session, response)
return response
@ -1871,7 +1869,7 @@ class Flask(Scaffold):
if bp in self.teardown_request_funcs:
funcs = chain(funcs, reversed(self.teardown_request_funcs[bp]))
for func in funcs:
func(exc)
self.ensure_sync(func)(exc)
request_tearing_down.send(self, exc=exc)
def do_teardown_appcontext(
@ -1894,7 +1892,7 @@ class Flask(Scaffold):
if exc is _sentinel:
exc = sys.exc_info()[1]
for func in reversed(self.teardown_appcontext_funcs):
func(exc)
self.ensure_sync(func)(exc)
appcontext_tearing_down.send(self, exc=exc)
def app_context(self) -> AppContext:

View file

@ -292,13 +292,10 @@ class Blueprint(Scaffold):
# Merge blueprint data into parent.
if first_registration:
def extend(bp_dict, parent_dict, ensure_sync=False):
def extend(bp_dict, parent_dict):
for key, values in bp_dict.items():
key = self.name if key is None else f"{self.name}.{key}"
if ensure_sync:
values = [app.ensure_sync(func) for func in values]
parent_dict[key].extend(values)
for key, value in self.error_handler_spec.items():
@ -307,8 +304,7 @@ class Blueprint(Scaffold):
dict,
{
code: {
exc_class: app.ensure_sync(func)
for exc_class, func in code_values.items()
exc_class: func for exc_class, func in code_values.items()
}
for code, code_values in value.items()
},
@ -316,16 +312,13 @@ class Blueprint(Scaffold):
app.error_handler_spec[key] = value
for endpoint, func in self.view_functions.items():
app.view_functions[endpoint] = app.ensure_sync(func)
app.view_functions[endpoint] = func
extend(
self.before_request_funcs, app.before_request_funcs, ensure_sync=True
)
extend(self.after_request_funcs, app.after_request_funcs, ensure_sync=True)
extend(self.before_request_funcs, app.before_request_funcs)
extend(self.after_request_funcs, app.after_request_funcs)
extend(
self.teardown_request_funcs,
app.teardown_request_funcs,
ensure_sync=True,
)
extend(self.url_default_functions, app.url_default_functions)
extend(self.url_value_preprocessors, app.url_value_preprocessors)
@ -478,9 +471,7 @@ class Blueprint(Scaffold):
before each request, even if outside of a blueprint.
"""
self.record_once(
lambda s: s.app.before_request_funcs.setdefault(None, []).append(
s.app.ensure_sync(f)
)
lambda s: s.app.before_request_funcs.setdefault(None, []).append(f)
)
return f
@ -490,9 +481,7 @@ class Blueprint(Scaffold):
"""Like :meth:`Flask.before_first_request`. Such a function is
executed before the first request to the application.
"""
self.record_once(
lambda s: s.app.before_first_request_funcs.append(s.app.ensure_sync(f))
)
self.record_once(lambda s: s.app.before_first_request_funcs.append(f))
return f
def after_app_request(self, f: AfterRequestCallable) -> AfterRequestCallable:
@ -500,9 +489,7 @@ class Blueprint(Scaffold):
is executed after each request, even if outside of the blueprint.
"""
self.record_once(
lambda s: s.app.after_request_funcs.setdefault(None, []).append(
s.app.ensure_sync(f)
)
lambda s: s.app.after_request_funcs.setdefault(None, []).append(f)
)
return f
@ -553,14 +540,3 @@ class Blueprint(Scaffold):
lambda s: s.app.url_default_functions.setdefault(None, []).append(f)
)
return f
def ensure_sync(self, f: t.Callable) -> t.Callable:
"""Ensure the function is synchronous.
Override if you would like custom async to sync behaviour in
this blueprint. Otherwise the app's
:meth:`~flask.Flask.ensure_sync` is used.
.. versionadded:: 2.0
"""
return f

View file

@ -6,7 +6,6 @@ import typing as t
import warnings
from datetime import timedelta
from functools import update_wrapper
from functools import wraps
from threading import RLock
import werkzeug.utils
@ -803,10 +802,15 @@ def is_ip(value: str) -> bool:
return False
def run_async(func: t.Callable[..., t.Coroutine]) -> t.Callable[..., t.Any]:
"""Return a sync function that will run the coroutine function *func*."""
def async_to_sync(func: t.Callable[..., t.Coroutine]) -> t.Callable[..., t.Any]:
"""Return a sync function that will run the coroutine function *func*.
This can be used as so
result = async_to_async(func)(*args, **kwargs)
"""
try:
from asgiref.sync import async_to_sync
from asgiref.sync import async_to_sync as asgiref_async_to_sync
except ImportError:
raise RuntimeError(
"Install Flask with the 'async' extra in order to use async views."
@ -818,9 +822,4 @@ def run_async(func: t.Callable[..., t.Coroutine]) -> t.Callable[..., t.Any]:
"Async cannot be used with this combination of Python & Greenlet versions."
)
@wraps(func)
def wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any:
return async_to_sync(func)(*args, **kwargs)
wrapper._flask_sync_wrapper = True # type: ignore
return wrapper
return asgiref_async_to_sync(func)

View file

@ -521,7 +521,7 @@ class Scaffold:
"""
def decorator(f):
self.view_functions[endpoint] = self.ensure_sync(f)
self.view_functions[endpoint] = f
return f
return decorator
@ -545,7 +545,7 @@ class Scaffold:
return value from the view, and further request handling is
stopped.
"""
self.before_request_funcs.setdefault(None, []).append(self.ensure_sync(f))
self.before_request_funcs.setdefault(None, []).append(f)
return f
@setupmethod
@ -561,7 +561,7 @@ class Scaffold:
should not be used for actions that must execute, such as to
close resources. Use :meth:`teardown_request` for that.
"""
self.after_request_funcs.setdefault(None, []).append(self.ensure_sync(f))
self.after_request_funcs.setdefault(None, []).append(f)
return f
@setupmethod
@ -600,7 +600,7 @@ class Scaffold:
debugger can still access it. This behavior can be controlled
by the ``PRESERVE_CONTEXT_ON_EXCEPTION`` configuration variable.
"""
self.teardown_request_funcs.setdefault(None, []).append(self.ensure_sync(f))
self.teardown_request_funcs.setdefault(None, []).append(f)
return f
@setupmethod
@ -706,7 +706,7 @@ class Scaffold:
" instead."
)
self.error_handler_spec[None][code][exc_class] = self.ensure_sync(f)
self.error_handler_spec[None][code][exc_class] = f
@staticmethod
def _get_exc_class_and_code(
@ -734,9 +734,6 @@ class Scaffold:
else:
return exc_class, None
def ensure_sync(self, func: t.Callable) -> t.Callable:
raise NotImplementedError()
def _endpoint_from_view_func(view_func: t.Callable) -> str:
"""Internal helper that returns the default endpoint for a given