Improve decorator typing (#4676)

* Add a missing setupmethod decorator

* Improve the decorator typing

This will allow type checkers to understand that the decorators return
the same function signature as passed as an argument. This follows the
guidelines from
https://mypy.readthedocs.io/en/stable/generics.html#declaring-decorators.

I've chosen to keep a TypeVar per module and usage as I think
encouraged by PEP 695, which I hope is accepted as the syntax is much
nicer.
This commit is contained in:
Phil Jones 2022-07-06 22:05:20 +01:00 committed by GitHub
parent d7482cd765
commit 9b44bf2818
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 123 additions and 76 deletions

View file

@ -74,6 +74,17 @@ if t.TYPE_CHECKING: # pragma: no cover
from .testing import FlaskClient
from .testing import FlaskCliRunner
T_before_first_request = t.TypeVar(
"T_before_first_request", bound=ft.BeforeFirstRequestCallable
)
T_shell_context_processor = t.TypeVar(
"T_shell_context_processor", bound=ft.ShellContextProcessorCallable
)
T_teardown = t.TypeVar("T_teardown", bound=ft.TeardownCallable)
T_template_filter = t.TypeVar("T_template_filter", bound=ft.TemplateFilterCallable)
T_template_global = t.TypeVar("T_template_global", bound=ft.TemplateGlobalCallable)
T_template_test = t.TypeVar("T_template_test", bound=ft.TemplateTestCallable)
if sys.version_info >= (3, 8):
iscoroutinefunction = inspect.iscoroutinefunction
else:
@ -472,7 +483,7 @@ class Flask(Scaffold):
#: when a shell context is created.
#:
#: .. versionadded:: 0.11
self.shell_context_processors: t.List[t.Callable[[], t.Dict[str, t.Any]]] = []
self.shell_context_processors: t.List[ft.ShellContextProcessorCallable] = []
#: Maps registered blueprint names to blueprint objects. The
#: dict retains the order the blueprints were registered in.
@ -1075,7 +1086,7 @@ class Flask(Scaffold):
self,
rule: str,
endpoint: t.Optional[str] = None,
view_func: t.Optional[ft.ViewCallable] = None,
view_func: t.Optional[ft.RouteCallable] = None,
provide_automatic_options: t.Optional[bool] = None,
**options: t.Any,
) -> None:
@ -1132,7 +1143,7 @@ class Flask(Scaffold):
@setupmethod
def template_filter(
self, name: t.Optional[str] = None
) -> t.Callable[[ft.TemplateFilterCallable], ft.TemplateFilterCallable]:
) -> t.Callable[[T_template_filter], T_template_filter]:
"""A decorator that is used to register custom template filter.
You can specify a name for the filter, otherwise the function
name will be used. Example::
@ -1145,7 +1156,7 @@ class Flask(Scaffold):
function name will be used.
"""
def decorator(f: ft.TemplateFilterCallable) -> ft.TemplateFilterCallable:
def decorator(f: T_template_filter) -> T_template_filter:
self.add_template_filter(f, name=name)
return f
@ -1166,7 +1177,7 @@ class Flask(Scaffold):
@setupmethod
def template_test(
self, name: t.Optional[str] = None
) -> t.Callable[[ft.TemplateTestCallable], ft.TemplateTestCallable]:
) -> t.Callable[[T_template_test], T_template_test]:
"""A decorator that is used to register custom template test.
You can specify a name for the test, otherwise the function
name will be used. Example::
@ -1186,7 +1197,7 @@ class Flask(Scaffold):
function name will be used.
"""
def decorator(f: ft.TemplateTestCallable) -> ft.TemplateTestCallable:
def decorator(f: T_template_test) -> T_template_test:
self.add_template_test(f, name=name)
return f
@ -1209,7 +1220,7 @@ class Flask(Scaffold):
@setupmethod
def template_global(
self, name: t.Optional[str] = None
) -> t.Callable[[ft.TemplateGlobalCallable], ft.TemplateGlobalCallable]:
) -> t.Callable[[T_template_global], T_template_global]:
"""A decorator that is used to register a custom template global function.
You can specify a name for the global function, otherwise the function
name will be used. Example::
@ -1224,7 +1235,7 @@ class Flask(Scaffold):
function name will be used.
"""
def decorator(f: ft.TemplateGlobalCallable) -> ft.TemplateGlobalCallable:
def decorator(f: T_template_global) -> T_template_global:
self.add_template_global(f, name=name)
return f
@ -1245,9 +1256,7 @@ class Flask(Scaffold):
self.jinja_env.globals[name or f.__name__] = f
@setupmethod
def before_first_request(
self, f: ft.BeforeFirstRequestCallable
) -> ft.BeforeFirstRequestCallable:
def before_first_request(self, f: T_before_first_request) -> T_before_first_request:
"""Registers a function to be run before the first request to this
instance of the application.
@ -1273,7 +1282,7 @@ class Flask(Scaffold):
return f
@setupmethod
def teardown_appcontext(self, f: ft.TeardownCallable) -> ft.TeardownCallable:
def teardown_appcontext(self, f: T_teardown) -> T_teardown:
"""Registers a function to be called when the application context
ends. These functions are typically also called when the request
context is popped.
@ -1306,7 +1315,9 @@ class Flask(Scaffold):
return f
@setupmethod
def shell_context_processor(self, f: t.Callable) -> t.Callable:
def shell_context_processor(
self, f: T_shell_context_processor
) -> T_shell_context_processor:
"""Registers a shell context processor function.
.. versionadded:: 0.11

View file

@ -13,6 +13,23 @@ if t.TYPE_CHECKING: # pragma: no cover
from .app import Flask
DeferredSetupFunction = t.Callable[["BlueprintSetupState"], t.Callable]
T_after_request = t.TypeVar("T_after_request", bound=ft.AfterRequestCallable)
T_before_first_request = t.TypeVar(
"T_before_first_request", bound=ft.BeforeFirstRequestCallable
)
T_before_request = t.TypeVar("T_before_request", bound=ft.BeforeRequestCallable)
T_error_handler = t.TypeVar("T_error_handler", bound=ft.ErrorHandlerCallable)
T_teardown = t.TypeVar("T_teardown", bound=ft.TeardownCallable)
T_template_context_processor = t.TypeVar(
"T_template_context_processor", bound=ft.TemplateContextProcessorCallable
)
T_template_filter = t.TypeVar("T_template_filter", bound=ft.TemplateFilterCallable)
T_template_global = t.TypeVar("T_template_global", bound=ft.TemplateGlobalCallable)
T_template_test = t.TypeVar("T_template_test", bound=ft.TemplateTestCallable)
T_url_defaults = t.TypeVar("T_url_defaults", bound=ft.URLDefaultCallable)
T_url_value_preprocessor = t.TypeVar(
"T_url_value_preprocessor", bound=ft.URLValuePreprocessorCallable
)
class BlueprintSetupState:
@ -236,7 +253,7 @@ class Blueprint(Scaffold):
if state.first_registration:
func(state)
return self.record(update_wrapper(wrapper, func))
self.record(update_wrapper(wrapper, func))
def make_setup_state(
self, app: "Flask", options: dict, first_registration: bool = False
@ -392,7 +409,7 @@ class Blueprint(Scaffold):
self,
rule: str,
endpoint: t.Optional[str] = None,
view_func: t.Optional[ft.ViewCallable] = None,
view_func: t.Optional[ft.RouteCallable] = None,
provide_automatic_options: t.Optional[bool] = None,
**options: t.Any,
) -> None:
@ -418,7 +435,7 @@ class Blueprint(Scaffold):
@setupmethod
def app_template_filter(
self, name: t.Optional[str] = None
) -> t.Callable[[ft.TemplateFilterCallable], ft.TemplateFilterCallable]:
) -> t.Callable[[T_template_filter], T_template_filter]:
"""Register a custom template filter, available application wide. Like
:meth:`Flask.template_filter` but for a blueprint.
@ -426,7 +443,7 @@ class Blueprint(Scaffold):
function name will be used.
"""
def decorator(f: ft.TemplateFilterCallable) -> ft.TemplateFilterCallable:
def decorator(f: T_template_filter) -> T_template_filter:
self.add_app_template_filter(f, name=name)
return f
@ -452,7 +469,7 @@ class Blueprint(Scaffold):
@setupmethod
def app_template_test(
self, name: t.Optional[str] = None
) -> t.Callable[[ft.TemplateTestCallable], ft.TemplateTestCallable]:
) -> t.Callable[[T_template_test], T_template_test]:
"""Register a custom template test, available application wide. Like
:meth:`Flask.template_test` but for a blueprint.
@ -462,7 +479,7 @@ class Blueprint(Scaffold):
function name will be used.
"""
def decorator(f: ft.TemplateTestCallable) -> ft.TemplateTestCallable:
def decorator(f: T_template_test) -> T_template_test:
self.add_app_template_test(f, name=name)
return f
@ -490,7 +507,7 @@ class Blueprint(Scaffold):
@setupmethod
def app_template_global(
self, name: t.Optional[str] = None
) -> t.Callable[[ft.TemplateGlobalCallable], ft.TemplateGlobalCallable]:
) -> t.Callable[[T_template_global], T_template_global]:
"""Register a custom template global, available application wide. Like
:meth:`Flask.template_global` but for a blueprint.
@ -500,7 +517,7 @@ class Blueprint(Scaffold):
function name will be used.
"""
def decorator(f: ft.TemplateGlobalCallable) -> ft.TemplateGlobalCallable:
def decorator(f: T_template_global) -> T_template_global:
self.add_app_template_global(f, name=name)
return f
@ -526,9 +543,7 @@ class Blueprint(Scaffold):
self.record_once(register_template)
@setupmethod
def before_app_request(
self, f: ft.BeforeRequestCallable
) -> ft.BeforeRequestCallable:
def before_app_request(self, f: T_before_request) -> T_before_request:
"""Like :meth:`Flask.before_request`. Such a function is executed
before each request, even if outside of a blueprint.
"""
@ -539,8 +554,8 @@ class Blueprint(Scaffold):
@setupmethod
def before_app_first_request(
self, f: ft.BeforeFirstRequestCallable
) -> ft.BeforeFirstRequestCallable:
self, f: T_before_first_request
) -> T_before_first_request:
"""Like :meth:`Flask.before_first_request`. Such a function is
executed before the first request to the application.
@ -560,7 +575,8 @@ class Blueprint(Scaffold):
self.record_once(lambda s: s.app.before_first_request_funcs.append(f))
return f
def after_app_request(self, f: ft.AfterRequestCallable) -> ft.AfterRequestCallable:
@setupmethod
def after_app_request(self, f: T_after_request) -> T_after_request:
"""Like :meth:`Flask.after_request` but for a blueprint. Such a function
is executed after each request, even if outside of the blueprint.
"""
@ -570,7 +586,7 @@ class Blueprint(Scaffold):
return f
@setupmethod
def teardown_app_request(self, f: ft.TeardownCallable) -> ft.TeardownCallable:
def teardown_app_request(self, f: T_teardown) -> T_teardown:
"""Like :meth:`Flask.teardown_request` but for a blueprint. Such a
function is executed when tearing down each request, even if outside of
the blueprint.
@ -582,8 +598,8 @@ class Blueprint(Scaffold):
@setupmethod
def app_context_processor(
self, f: ft.TemplateContextProcessorCallable
) -> ft.TemplateContextProcessorCallable:
self, f: T_template_context_processor
) -> T_template_context_processor:
"""Like :meth:`Flask.context_processor` but for a blueprint. Such a
function is executed each request, even if outside of the blueprint.
"""
@ -595,12 +611,12 @@ class Blueprint(Scaffold):
@setupmethod
def app_errorhandler(
self, code: t.Union[t.Type[Exception], int]
) -> t.Callable[[ft.ErrorHandlerDecorator], ft.ErrorHandlerDecorator]:
) -> t.Callable[[T_error_handler], T_error_handler]:
"""Like :meth:`Flask.errorhandler` but for a blueprint. This
handler is used for all requests, even if outside of the blueprint.
"""
def decorator(f: ft.ErrorHandlerDecorator) -> ft.ErrorHandlerDecorator:
def decorator(f: T_error_handler) -> T_error_handler:
self.record_once(lambda s: s.app.errorhandler(code)(f))
return f
@ -608,8 +624,8 @@ class Blueprint(Scaffold):
@setupmethod
def app_url_value_preprocessor(
self, f: ft.URLValuePreprocessorCallable
) -> ft.URLValuePreprocessorCallable:
self, f: T_url_value_preprocessor
) -> T_url_value_preprocessor:
"""Same as :meth:`url_value_preprocessor` but application wide."""
self.record_once(
lambda s: s.app.url_value_preprocessors.setdefault(None, []).append(f)
@ -617,7 +633,7 @@ class Blueprint(Scaffold):
return f
@setupmethod
def app_url_defaults(self, f: ft.URLDefaultCallable) -> ft.URLDefaultCallable:
def app_url_defaults(self, f: T_url_defaults) -> T_url_defaults:
"""Same as :meth:`url_defaults` but application wide."""
self.record_once(
lambda s: s.app.url_default_functions.setdefault(None, []).append(f)

View file

@ -28,6 +28,18 @@ if t.TYPE_CHECKING: # pragma: no cover
_sentinel = object()
F = t.TypeVar("F", bound=t.Callable[..., t.Any])
T_after_request = t.TypeVar("T_after_request", bound=ft.AfterRequestCallable)
T_before_request = t.TypeVar("T_before_request", bound=ft.BeforeRequestCallable)
T_error_handler = t.TypeVar("T_error_handler", bound=ft.ErrorHandlerCallable)
T_teardown = t.TypeVar("T_teardown", bound=ft.TeardownCallable)
T_template_context_processor = t.TypeVar(
"T_template_context_processor", bound=ft.TemplateContextProcessorCallable
)
T_url_defaults = t.TypeVar("T_url_defaults", bound=ft.URLDefaultCallable)
T_url_value_preprocessor = t.TypeVar(
"T_url_value_preprocessor", bound=ft.URLValuePreprocessorCallable
)
T_route = t.TypeVar("T_route", bound=ft.RouteCallable)
def setupmethod(f: F) -> F:
@ -352,16 +364,14 @@ class Scaffold:
method: str,
rule: str,
options: dict,
) -> t.Callable[[ft.RouteDecorator], ft.RouteDecorator]:
) -> t.Callable[[T_route], T_route]:
if "methods" in options:
raise TypeError("Use the 'route' decorator to use the 'methods' argument.")
return self.route(rule, methods=[method], **options)
@setupmethod
def get(
self, rule: str, **options: t.Any
) -> t.Callable[[ft.RouteDecorator], ft.RouteDecorator]:
def get(self, rule: str, **options: t.Any) -> t.Callable[[T_route], T_route]:
"""Shortcut for :meth:`route` with ``methods=["GET"]``.
.. versionadded:: 2.0
@ -369,9 +379,7 @@ class Scaffold:
return self._method_route("GET", rule, options)
@setupmethod
def post(
self, rule: str, **options: t.Any
) -> t.Callable[[ft.RouteDecorator], ft.RouteDecorator]:
def post(self, rule: str, **options: t.Any) -> t.Callable[[T_route], T_route]:
"""Shortcut for :meth:`route` with ``methods=["POST"]``.
.. versionadded:: 2.0
@ -379,9 +387,7 @@ class Scaffold:
return self._method_route("POST", rule, options)
@setupmethod
def put(
self, rule: str, **options: t.Any
) -> t.Callable[[ft.RouteDecorator], ft.RouteDecorator]:
def put(self, rule: str, **options: t.Any) -> t.Callable[[T_route], T_route]:
"""Shortcut for :meth:`route` with ``methods=["PUT"]``.
.. versionadded:: 2.0
@ -389,9 +395,7 @@ class Scaffold:
return self._method_route("PUT", rule, options)
@setupmethod
def delete(
self, rule: str, **options: t.Any
) -> t.Callable[[ft.RouteDecorator], ft.RouteDecorator]:
def delete(self, rule: str, **options: t.Any) -> t.Callable[[T_route], T_route]:
"""Shortcut for :meth:`route` with ``methods=["DELETE"]``.
.. versionadded:: 2.0
@ -399,9 +403,7 @@ class Scaffold:
return self._method_route("DELETE", rule, options)
@setupmethod
def patch(
self, rule: str, **options: t.Any
) -> t.Callable[[ft.RouteDecorator], ft.RouteDecorator]:
def patch(self, rule: str, **options: t.Any) -> t.Callable[[T_route], T_route]:
"""Shortcut for :meth:`route` with ``methods=["PATCH"]``.
.. versionadded:: 2.0
@ -409,9 +411,7 @@ class Scaffold:
return self._method_route("PATCH", rule, options)
@setupmethod
def route(
self, rule: str, **options: t.Any
) -> t.Callable[[ft.RouteDecorator], ft.RouteDecorator]:
def route(self, rule: str, **options: t.Any) -> t.Callable[[T_route], T_route]:
"""Decorate a view function to register it with the given URL
rule and options. Calls :meth:`add_url_rule`, which has more
details about the implementation.
@ -435,7 +435,7 @@ class Scaffold:
:class:`~werkzeug.routing.Rule` object.
"""
def decorator(f: ft.RouteDecorator) -> ft.RouteDecorator:
def decorator(f: T_route) -> T_route:
endpoint = options.pop("endpoint", None)
self.add_url_rule(rule, endpoint, f, **options)
return f
@ -447,7 +447,7 @@ class Scaffold:
self,
rule: str,
endpoint: t.Optional[str] = None,
view_func: t.Optional[ft.ViewCallable] = None,
view_func: t.Optional[ft.RouteCallable] = None,
provide_automatic_options: t.Optional[bool] = None,
**options: t.Any,
) -> None:
@ -511,7 +511,7 @@ class Scaffold:
raise NotImplementedError
@setupmethod
def endpoint(self, endpoint: str) -> t.Callable:
def endpoint(self, endpoint: str) -> t.Callable[[F], F]:
"""Decorate a view function to register it for the given
endpoint. Used if a rule is added without a ``view_func`` with
:meth:`add_url_rule`.
@ -528,14 +528,14 @@ class Scaffold:
function.
"""
def decorator(f):
def decorator(f: F) -> F:
self.view_functions[endpoint] = f
return f
return decorator
@setupmethod
def before_request(self, f: ft.BeforeRequestCallable) -> ft.BeforeRequestCallable:
def before_request(self, f: T_before_request) -> T_before_request:
"""Register a function to run before each request.
For example, this can be used to open a database connection, or
@ -557,7 +557,7 @@ class Scaffold:
return f
@setupmethod
def after_request(self, f: ft.AfterRequestCallable) -> ft.AfterRequestCallable:
def after_request(self, f: T_after_request) -> T_after_request:
"""Register a function to run after each request to this object.
The function is called with the response object, and must return
@ -573,7 +573,7 @@ class Scaffold:
return f
@setupmethod
def teardown_request(self, f: ft.TeardownCallable) -> ft.TeardownCallable:
def teardown_request(self, f: T_teardown) -> T_teardown:
"""Register a function to be run at the end of each request,
regardless of whether there was an exception or not. These functions
are executed when the request context is popped, even if not an
@ -606,16 +606,18 @@ class Scaffold:
@setupmethod
def context_processor(
self, f: ft.TemplateContextProcessorCallable
) -> ft.TemplateContextProcessorCallable:
self,
f: T_template_context_processor,
) -> T_template_context_processor:
"""Registers a template context processor function."""
self.template_context_processors[None].append(f)
return f
@setupmethod
def url_value_preprocessor(
self, f: ft.URLValuePreprocessorCallable
) -> ft.URLValuePreprocessorCallable:
self,
f: T_url_value_preprocessor,
) -> T_url_value_preprocessor:
"""Register a URL value preprocessor function for all view
functions in the application. These functions will be called before the
:meth:`before_request` functions.
@ -632,7 +634,7 @@ class Scaffold:
return f
@setupmethod
def url_defaults(self, f: ft.URLDefaultCallable) -> ft.URLDefaultCallable:
def url_defaults(self, f: T_url_defaults) -> T_url_defaults:
"""Callback function for URL defaults for all view functions of the
application. It's called with the endpoint and values and should
update the values passed in place.
@ -643,7 +645,7 @@ class Scaffold:
@setupmethod
def errorhandler(
self, code_or_exception: t.Union[t.Type[Exception], int]
) -> t.Callable[[ft.ErrorHandlerDecorator], ft.ErrorHandlerDecorator]:
) -> t.Callable[[T_error_handler], T_error_handler]:
"""Register a function to handle errors by code or exception class.
A decorator that is used to register a function given an
@ -673,7 +675,7 @@ class Scaffold:
an arbitrary exception
"""
def decorator(f: ft.ErrorHandlerDecorator) -> ft.ErrorHandlerDecorator:
def decorator(f: T_error_handler) -> T_error_handler:
self.register_error_handler(code_or_exception, f)
return f

View file

@ -42,10 +42,22 @@ ResponseReturnValue = t.Union[
ResponseClass = t.TypeVar("ResponseClass", bound="Response")
AppOrBlueprintKey = t.Optional[str] # The App key is None, whereas blueprints are named
AfterRequestCallable = t.Callable[[ResponseClass], ResponseClass]
BeforeFirstRequestCallable = t.Callable[[], None]
BeforeRequestCallable = t.Callable[[], t.Optional[ResponseReturnValue]]
TeardownCallable = t.Callable[[t.Optional[BaseException]], None]
AfterRequestCallable = t.Union[
t.Callable[[ResponseClass], ResponseClass],
t.Callable[[ResponseClass], t.Awaitable[ResponseClass]],
]
BeforeFirstRequestCallable = t.Union[
t.Callable[[], None], t.Callable[[], t.Awaitable[None]]
]
BeforeRequestCallable = t.Union[
t.Callable[[], t.Optional[ResponseReturnValue]],
t.Callable[[], t.Awaitable[t.Optional[ResponseReturnValue]]],
]
ShellContextProcessorCallable = t.Callable[[], t.Dict[str, t.Any]]
TeardownCallable = t.Union[
t.Callable[[t.Optional[BaseException]], None],
t.Callable[[t.Optional[BaseException]], t.Awaitable[None]],
]
TemplateContextProcessorCallable = t.Callable[[], t.Dict[str, t.Any]]
TemplateFilterCallable = t.Callable[..., t.Any]
TemplateGlobalCallable = t.Callable[..., t.Any]
@ -60,7 +72,8 @@ URLValuePreprocessorCallable = t.Callable[[t.Optional[str], t.Optional[dict]], N
# https://github.com/pallets/flask/issues/4295
# https://github.com/pallets/flask/issues/4297
ErrorHandlerCallable = t.Callable[[t.Any], ResponseReturnValue]
ErrorHandlerDecorator = t.TypeVar("ErrorHandlerDecorator", bound=ErrorHandlerCallable)
ViewCallable = t.Callable[..., ResponseReturnValue]
RouteDecorator = t.TypeVar("RouteDecorator", bound=ViewCallable)
RouteCallable = t.Union[
t.Callable[..., ResponseReturnValue],
t.Callable[..., t.Awaitable[ResponseReturnValue]],
]

View file

@ -82,7 +82,7 @@ class View:
@classmethod
def as_view(
cls, name: str, *class_args: t.Any, **class_kwargs: t.Any
) -> ft.ViewCallable:
) -> ft.RouteCallable:
"""Convert the class into a view function that can be registered
for a route.

View file

@ -84,6 +84,11 @@ def return_template_stream() -> t.Iterator[str]:
return stream_template("index.html", name="Hello")
@app.route("/async")
async def async_route() -> str:
return "Hello"
class RenderTemplateView(View):
def __init__(self: RenderTemplateView, template_name: str) -> None:
self.template_name = template_name