From 9b44bf2818d8e3cde422ad7f43fb33dfc6737289 Mon Sep 17 00:00:00 2001 From: Phil Jones Date: Wed, 6 Jul 2022 22:05:20 +0100 Subject: [PATCH] Improve decorator typing (#4676) * Add a missing setupmethod decorator * Improve the decorator typing This will allow type checkers to understand that the decorators return the same function signature as passed as an argument. This follows the guidelines from https://mypy.readthedocs.io/en/stable/generics.html#declaring-decorators. I've chosen to keep a TypeVar per module and usage as I think encouraged by PEP 695, which I hope is accepted as the syntax is much nicer. --- src/flask/app.py | 37 +++++++++++++------- src/flask/blueprints.py | 60 +++++++++++++++++++------------ src/flask/scaffold.py | 68 +++++++++++++++++++----------------- src/flask/typing.py | 27 ++++++++++---- src/flask/views.py | 2 +- tests/typing/typing_route.py | 5 +++ 6 files changed, 123 insertions(+), 76 deletions(-) diff --git a/src/flask/app.py b/src/flask/app.py index 8726a6e8..736061b9 100644 --- a/src/flask/app.py +++ b/src/flask/app.py @@ -74,6 +74,17 @@ if t.TYPE_CHECKING: # pragma: no cover from .testing import FlaskClient from .testing import FlaskCliRunner +T_before_first_request = t.TypeVar( + "T_before_first_request", bound=ft.BeforeFirstRequestCallable +) +T_shell_context_processor = t.TypeVar( + "T_shell_context_processor", bound=ft.ShellContextProcessorCallable +) +T_teardown = t.TypeVar("T_teardown", bound=ft.TeardownCallable) +T_template_filter = t.TypeVar("T_template_filter", bound=ft.TemplateFilterCallable) +T_template_global = t.TypeVar("T_template_global", bound=ft.TemplateGlobalCallable) +T_template_test = t.TypeVar("T_template_test", bound=ft.TemplateTestCallable) + if sys.version_info >= (3, 8): iscoroutinefunction = inspect.iscoroutinefunction else: @@ -472,7 +483,7 @@ class Flask(Scaffold): #: when a shell context is created. #: #: .. versionadded:: 0.11 - self.shell_context_processors: t.List[t.Callable[[], t.Dict[str, t.Any]]] = [] + self.shell_context_processors: t.List[ft.ShellContextProcessorCallable] = [] #: Maps registered blueprint names to blueprint objects. The #: dict retains the order the blueprints were registered in. @@ -1075,7 +1086,7 @@ class Flask(Scaffold): self, rule: str, endpoint: t.Optional[str] = None, - view_func: t.Optional[ft.ViewCallable] = None, + view_func: t.Optional[ft.RouteCallable] = None, provide_automatic_options: t.Optional[bool] = None, **options: t.Any, ) -> None: @@ -1132,7 +1143,7 @@ class Flask(Scaffold): @setupmethod def template_filter( self, name: t.Optional[str] = None - ) -> t.Callable[[ft.TemplateFilterCallable], ft.TemplateFilterCallable]: + ) -> t.Callable[[T_template_filter], T_template_filter]: """A decorator that is used to register custom template filter. You can specify a name for the filter, otherwise the function name will be used. Example:: @@ -1145,7 +1156,7 @@ class Flask(Scaffold): function name will be used. """ - def decorator(f: ft.TemplateFilterCallable) -> ft.TemplateFilterCallable: + def decorator(f: T_template_filter) -> T_template_filter: self.add_template_filter(f, name=name) return f @@ -1166,7 +1177,7 @@ class Flask(Scaffold): @setupmethod def template_test( self, name: t.Optional[str] = None - ) -> t.Callable[[ft.TemplateTestCallable], ft.TemplateTestCallable]: + ) -> t.Callable[[T_template_test], T_template_test]: """A decorator that is used to register custom template test. You can specify a name for the test, otherwise the function name will be used. Example:: @@ -1186,7 +1197,7 @@ class Flask(Scaffold): function name will be used. """ - def decorator(f: ft.TemplateTestCallable) -> ft.TemplateTestCallable: + def decorator(f: T_template_test) -> T_template_test: self.add_template_test(f, name=name) return f @@ -1209,7 +1220,7 @@ class Flask(Scaffold): @setupmethod def template_global( self, name: t.Optional[str] = None - ) -> t.Callable[[ft.TemplateGlobalCallable], ft.TemplateGlobalCallable]: + ) -> t.Callable[[T_template_global], T_template_global]: """A decorator that is used to register a custom template global function. You can specify a name for the global function, otherwise the function name will be used. Example:: @@ -1224,7 +1235,7 @@ class Flask(Scaffold): function name will be used. """ - def decorator(f: ft.TemplateGlobalCallable) -> ft.TemplateGlobalCallable: + def decorator(f: T_template_global) -> T_template_global: self.add_template_global(f, name=name) return f @@ -1245,9 +1256,7 @@ class Flask(Scaffold): self.jinja_env.globals[name or f.__name__] = f @setupmethod - def before_first_request( - self, f: ft.BeforeFirstRequestCallable - ) -> ft.BeforeFirstRequestCallable: + def before_first_request(self, f: T_before_first_request) -> T_before_first_request: """Registers a function to be run before the first request to this instance of the application. @@ -1273,7 +1282,7 @@ class Flask(Scaffold): return f @setupmethod - def teardown_appcontext(self, f: ft.TeardownCallable) -> ft.TeardownCallable: + def teardown_appcontext(self, f: T_teardown) -> T_teardown: """Registers a function to be called when the application context ends. These functions are typically also called when the request context is popped. @@ -1306,7 +1315,9 @@ class Flask(Scaffold): return f @setupmethod - def shell_context_processor(self, f: t.Callable) -> t.Callable: + def shell_context_processor( + self, f: T_shell_context_processor + ) -> T_shell_context_processor: """Registers a shell context processor function. .. versionadded:: 0.11 diff --git a/src/flask/blueprints.py b/src/flask/blueprints.py index 76b36067..6deda47e 100644 --- a/src/flask/blueprints.py +++ b/src/flask/blueprints.py @@ -13,6 +13,23 @@ if t.TYPE_CHECKING: # pragma: no cover from .app import Flask DeferredSetupFunction = t.Callable[["BlueprintSetupState"], t.Callable] +T_after_request = t.TypeVar("T_after_request", bound=ft.AfterRequestCallable) +T_before_first_request = t.TypeVar( + "T_before_first_request", bound=ft.BeforeFirstRequestCallable +) +T_before_request = t.TypeVar("T_before_request", bound=ft.BeforeRequestCallable) +T_error_handler = t.TypeVar("T_error_handler", bound=ft.ErrorHandlerCallable) +T_teardown = t.TypeVar("T_teardown", bound=ft.TeardownCallable) +T_template_context_processor = t.TypeVar( + "T_template_context_processor", bound=ft.TemplateContextProcessorCallable +) +T_template_filter = t.TypeVar("T_template_filter", bound=ft.TemplateFilterCallable) +T_template_global = t.TypeVar("T_template_global", bound=ft.TemplateGlobalCallable) +T_template_test = t.TypeVar("T_template_test", bound=ft.TemplateTestCallable) +T_url_defaults = t.TypeVar("T_url_defaults", bound=ft.URLDefaultCallable) +T_url_value_preprocessor = t.TypeVar( + "T_url_value_preprocessor", bound=ft.URLValuePreprocessorCallable +) class BlueprintSetupState: @@ -236,7 +253,7 @@ class Blueprint(Scaffold): if state.first_registration: func(state) - return self.record(update_wrapper(wrapper, func)) + self.record(update_wrapper(wrapper, func)) def make_setup_state( self, app: "Flask", options: dict, first_registration: bool = False @@ -392,7 +409,7 @@ class Blueprint(Scaffold): self, rule: str, endpoint: t.Optional[str] = None, - view_func: t.Optional[ft.ViewCallable] = None, + view_func: t.Optional[ft.RouteCallable] = None, provide_automatic_options: t.Optional[bool] = None, **options: t.Any, ) -> None: @@ -418,7 +435,7 @@ class Blueprint(Scaffold): @setupmethod def app_template_filter( self, name: t.Optional[str] = None - ) -> t.Callable[[ft.TemplateFilterCallable], ft.TemplateFilterCallable]: + ) -> t.Callable[[T_template_filter], T_template_filter]: """Register a custom template filter, available application wide. Like :meth:`Flask.template_filter` but for a blueprint. @@ -426,7 +443,7 @@ class Blueprint(Scaffold): function name will be used. """ - def decorator(f: ft.TemplateFilterCallable) -> ft.TemplateFilterCallable: + def decorator(f: T_template_filter) -> T_template_filter: self.add_app_template_filter(f, name=name) return f @@ -452,7 +469,7 @@ class Blueprint(Scaffold): @setupmethod def app_template_test( self, name: t.Optional[str] = None - ) -> t.Callable[[ft.TemplateTestCallable], ft.TemplateTestCallable]: + ) -> t.Callable[[T_template_test], T_template_test]: """Register a custom template test, available application wide. Like :meth:`Flask.template_test` but for a blueprint. @@ -462,7 +479,7 @@ class Blueprint(Scaffold): function name will be used. """ - def decorator(f: ft.TemplateTestCallable) -> ft.TemplateTestCallable: + def decorator(f: T_template_test) -> T_template_test: self.add_app_template_test(f, name=name) return f @@ -490,7 +507,7 @@ class Blueprint(Scaffold): @setupmethod def app_template_global( self, name: t.Optional[str] = None - ) -> t.Callable[[ft.TemplateGlobalCallable], ft.TemplateGlobalCallable]: + ) -> t.Callable[[T_template_global], T_template_global]: """Register a custom template global, available application wide. Like :meth:`Flask.template_global` but for a blueprint. @@ -500,7 +517,7 @@ class Blueprint(Scaffold): function name will be used. """ - def decorator(f: ft.TemplateGlobalCallable) -> ft.TemplateGlobalCallable: + def decorator(f: T_template_global) -> T_template_global: self.add_app_template_global(f, name=name) return f @@ -526,9 +543,7 @@ class Blueprint(Scaffold): self.record_once(register_template) @setupmethod - def before_app_request( - self, f: ft.BeforeRequestCallable - ) -> ft.BeforeRequestCallable: + def before_app_request(self, f: T_before_request) -> T_before_request: """Like :meth:`Flask.before_request`. Such a function is executed before each request, even if outside of a blueprint. """ @@ -539,8 +554,8 @@ class Blueprint(Scaffold): @setupmethod def before_app_first_request( - self, f: ft.BeforeFirstRequestCallable - ) -> ft.BeforeFirstRequestCallable: + self, f: T_before_first_request + ) -> T_before_first_request: """Like :meth:`Flask.before_first_request`. Such a function is executed before the first request to the application. @@ -560,7 +575,8 @@ class Blueprint(Scaffold): self.record_once(lambda s: s.app.before_first_request_funcs.append(f)) return f - def after_app_request(self, f: ft.AfterRequestCallable) -> ft.AfterRequestCallable: + @setupmethod + def after_app_request(self, f: T_after_request) -> T_after_request: """Like :meth:`Flask.after_request` but for a blueprint. Such a function is executed after each request, even if outside of the blueprint. """ @@ -570,7 +586,7 @@ class Blueprint(Scaffold): return f @setupmethod - def teardown_app_request(self, f: ft.TeardownCallable) -> ft.TeardownCallable: + def teardown_app_request(self, f: T_teardown) -> T_teardown: """Like :meth:`Flask.teardown_request` but for a blueprint. Such a function is executed when tearing down each request, even if outside of the blueprint. @@ -582,8 +598,8 @@ class Blueprint(Scaffold): @setupmethod def app_context_processor( - self, f: ft.TemplateContextProcessorCallable - ) -> ft.TemplateContextProcessorCallable: + self, f: T_template_context_processor + ) -> T_template_context_processor: """Like :meth:`Flask.context_processor` but for a blueprint. Such a function is executed each request, even if outside of the blueprint. """ @@ -595,12 +611,12 @@ class Blueprint(Scaffold): @setupmethod def app_errorhandler( self, code: t.Union[t.Type[Exception], int] - ) -> t.Callable[[ft.ErrorHandlerDecorator], ft.ErrorHandlerDecorator]: + ) -> t.Callable[[T_error_handler], T_error_handler]: """Like :meth:`Flask.errorhandler` but for a blueprint. This handler is used for all requests, even if outside of the blueprint. """ - def decorator(f: ft.ErrorHandlerDecorator) -> ft.ErrorHandlerDecorator: + def decorator(f: T_error_handler) -> T_error_handler: self.record_once(lambda s: s.app.errorhandler(code)(f)) return f @@ -608,8 +624,8 @@ class Blueprint(Scaffold): @setupmethod def app_url_value_preprocessor( - self, f: ft.URLValuePreprocessorCallable - ) -> ft.URLValuePreprocessorCallable: + self, f: T_url_value_preprocessor + ) -> T_url_value_preprocessor: """Same as :meth:`url_value_preprocessor` but application wide.""" self.record_once( lambda s: s.app.url_value_preprocessors.setdefault(None, []).append(f) @@ -617,7 +633,7 @@ class Blueprint(Scaffold): return f @setupmethod - def app_url_defaults(self, f: ft.URLDefaultCallable) -> ft.URLDefaultCallable: + def app_url_defaults(self, f: T_url_defaults) -> T_url_defaults: """Same as :meth:`url_defaults` but application wide.""" self.record_once( lambda s: s.app.url_default_functions.setdefault(None, []).append(f) diff --git a/src/flask/scaffold.py b/src/flask/scaffold.py index 418b24ae..7f099f40 100644 --- a/src/flask/scaffold.py +++ b/src/flask/scaffold.py @@ -28,6 +28,18 @@ if t.TYPE_CHECKING: # pragma: no cover _sentinel = object() F = t.TypeVar("F", bound=t.Callable[..., t.Any]) +T_after_request = t.TypeVar("T_after_request", bound=ft.AfterRequestCallable) +T_before_request = t.TypeVar("T_before_request", bound=ft.BeforeRequestCallable) +T_error_handler = t.TypeVar("T_error_handler", bound=ft.ErrorHandlerCallable) +T_teardown = t.TypeVar("T_teardown", bound=ft.TeardownCallable) +T_template_context_processor = t.TypeVar( + "T_template_context_processor", bound=ft.TemplateContextProcessorCallable +) +T_url_defaults = t.TypeVar("T_url_defaults", bound=ft.URLDefaultCallable) +T_url_value_preprocessor = t.TypeVar( + "T_url_value_preprocessor", bound=ft.URLValuePreprocessorCallable +) +T_route = t.TypeVar("T_route", bound=ft.RouteCallable) def setupmethod(f: F) -> F: @@ -352,16 +364,14 @@ class Scaffold: method: str, rule: str, options: dict, - ) -> t.Callable[[ft.RouteDecorator], ft.RouteDecorator]: + ) -> t.Callable[[T_route], T_route]: if "methods" in options: raise TypeError("Use the 'route' decorator to use the 'methods' argument.") return self.route(rule, methods=[method], **options) @setupmethod - def get( - self, rule: str, **options: t.Any - ) -> t.Callable[[ft.RouteDecorator], ft.RouteDecorator]: + def get(self, rule: str, **options: t.Any) -> t.Callable[[T_route], T_route]: """Shortcut for :meth:`route` with ``methods=["GET"]``. .. versionadded:: 2.0 @@ -369,9 +379,7 @@ class Scaffold: return self._method_route("GET", rule, options) @setupmethod - def post( - self, rule: str, **options: t.Any - ) -> t.Callable[[ft.RouteDecorator], ft.RouteDecorator]: + def post(self, rule: str, **options: t.Any) -> t.Callable[[T_route], T_route]: """Shortcut for :meth:`route` with ``methods=["POST"]``. .. versionadded:: 2.0 @@ -379,9 +387,7 @@ class Scaffold: return self._method_route("POST", rule, options) @setupmethod - def put( - self, rule: str, **options: t.Any - ) -> t.Callable[[ft.RouteDecorator], ft.RouteDecorator]: + def put(self, rule: str, **options: t.Any) -> t.Callable[[T_route], T_route]: """Shortcut for :meth:`route` with ``methods=["PUT"]``. .. versionadded:: 2.0 @@ -389,9 +395,7 @@ class Scaffold: return self._method_route("PUT", rule, options) @setupmethod - def delete( - self, rule: str, **options: t.Any - ) -> t.Callable[[ft.RouteDecorator], ft.RouteDecorator]: + def delete(self, rule: str, **options: t.Any) -> t.Callable[[T_route], T_route]: """Shortcut for :meth:`route` with ``methods=["DELETE"]``. .. versionadded:: 2.0 @@ -399,9 +403,7 @@ class Scaffold: return self._method_route("DELETE", rule, options) @setupmethod - def patch( - self, rule: str, **options: t.Any - ) -> t.Callable[[ft.RouteDecorator], ft.RouteDecorator]: + def patch(self, rule: str, **options: t.Any) -> t.Callable[[T_route], T_route]: """Shortcut for :meth:`route` with ``methods=["PATCH"]``. .. versionadded:: 2.0 @@ -409,9 +411,7 @@ class Scaffold: return self._method_route("PATCH", rule, options) @setupmethod - def route( - self, rule: str, **options: t.Any - ) -> t.Callable[[ft.RouteDecorator], ft.RouteDecorator]: + def route(self, rule: str, **options: t.Any) -> t.Callable[[T_route], T_route]: """Decorate a view function to register it with the given URL rule and options. Calls :meth:`add_url_rule`, which has more details about the implementation. @@ -435,7 +435,7 @@ class Scaffold: :class:`~werkzeug.routing.Rule` object. """ - def decorator(f: ft.RouteDecorator) -> ft.RouteDecorator: + def decorator(f: T_route) -> T_route: endpoint = options.pop("endpoint", None) self.add_url_rule(rule, endpoint, f, **options) return f @@ -447,7 +447,7 @@ class Scaffold: self, rule: str, endpoint: t.Optional[str] = None, - view_func: t.Optional[ft.ViewCallable] = None, + view_func: t.Optional[ft.RouteCallable] = None, provide_automatic_options: t.Optional[bool] = None, **options: t.Any, ) -> None: @@ -511,7 +511,7 @@ class Scaffold: raise NotImplementedError @setupmethod - def endpoint(self, endpoint: str) -> t.Callable: + def endpoint(self, endpoint: str) -> t.Callable[[F], F]: """Decorate a view function to register it for the given endpoint. Used if a rule is added without a ``view_func`` with :meth:`add_url_rule`. @@ -528,14 +528,14 @@ class Scaffold: function. """ - def decorator(f): + def decorator(f: F) -> F: self.view_functions[endpoint] = f return f return decorator @setupmethod - def before_request(self, f: ft.BeforeRequestCallable) -> ft.BeforeRequestCallable: + def before_request(self, f: T_before_request) -> T_before_request: """Register a function to run before each request. For example, this can be used to open a database connection, or @@ -557,7 +557,7 @@ class Scaffold: return f @setupmethod - def after_request(self, f: ft.AfterRequestCallable) -> ft.AfterRequestCallable: + def after_request(self, f: T_after_request) -> T_after_request: """Register a function to run after each request to this object. The function is called with the response object, and must return @@ -573,7 +573,7 @@ class Scaffold: return f @setupmethod - def teardown_request(self, f: ft.TeardownCallable) -> ft.TeardownCallable: + def teardown_request(self, f: T_teardown) -> T_teardown: """Register a function to be run at the end of each request, regardless of whether there was an exception or not. These functions are executed when the request context is popped, even if not an @@ -606,16 +606,18 @@ class Scaffold: @setupmethod def context_processor( - self, f: ft.TemplateContextProcessorCallable - ) -> ft.TemplateContextProcessorCallable: + self, + f: T_template_context_processor, + ) -> T_template_context_processor: """Registers a template context processor function.""" self.template_context_processors[None].append(f) return f @setupmethod def url_value_preprocessor( - self, f: ft.URLValuePreprocessorCallable - ) -> ft.URLValuePreprocessorCallable: + self, + f: T_url_value_preprocessor, + ) -> T_url_value_preprocessor: """Register a URL value preprocessor function for all view functions in the application. These functions will be called before the :meth:`before_request` functions. @@ -632,7 +634,7 @@ class Scaffold: return f @setupmethod - def url_defaults(self, f: ft.URLDefaultCallable) -> ft.URLDefaultCallable: + def url_defaults(self, f: T_url_defaults) -> T_url_defaults: """Callback function for URL defaults for all view functions of the application. It's called with the endpoint and values and should update the values passed in place. @@ -643,7 +645,7 @@ class Scaffold: @setupmethod def errorhandler( self, code_or_exception: t.Union[t.Type[Exception], int] - ) -> t.Callable[[ft.ErrorHandlerDecorator], ft.ErrorHandlerDecorator]: + ) -> t.Callable[[T_error_handler], T_error_handler]: """Register a function to handle errors by code or exception class. A decorator that is used to register a function given an @@ -673,7 +675,7 @@ class Scaffold: an arbitrary exception """ - def decorator(f: ft.ErrorHandlerDecorator) -> ft.ErrorHandlerDecorator: + def decorator(f: T_error_handler) -> T_error_handler: self.register_error_handler(code_or_exception, f) return f diff --git a/src/flask/typing.py b/src/flask/typing.py index 6bbdb1dd..89bdc71e 100644 --- a/src/flask/typing.py +++ b/src/flask/typing.py @@ -42,10 +42,22 @@ ResponseReturnValue = t.Union[ ResponseClass = t.TypeVar("ResponseClass", bound="Response") AppOrBlueprintKey = t.Optional[str] # The App key is None, whereas blueprints are named -AfterRequestCallable = t.Callable[[ResponseClass], ResponseClass] -BeforeFirstRequestCallable = t.Callable[[], None] -BeforeRequestCallable = t.Callable[[], t.Optional[ResponseReturnValue]] -TeardownCallable = t.Callable[[t.Optional[BaseException]], None] +AfterRequestCallable = t.Union[ + t.Callable[[ResponseClass], ResponseClass], + t.Callable[[ResponseClass], t.Awaitable[ResponseClass]], +] +BeforeFirstRequestCallable = t.Union[ + t.Callable[[], None], t.Callable[[], t.Awaitable[None]] +] +BeforeRequestCallable = t.Union[ + t.Callable[[], t.Optional[ResponseReturnValue]], + t.Callable[[], t.Awaitable[t.Optional[ResponseReturnValue]]], +] +ShellContextProcessorCallable = t.Callable[[], t.Dict[str, t.Any]] +TeardownCallable = t.Union[ + t.Callable[[t.Optional[BaseException]], None], + t.Callable[[t.Optional[BaseException]], t.Awaitable[None]], +] TemplateContextProcessorCallable = t.Callable[[], t.Dict[str, t.Any]] TemplateFilterCallable = t.Callable[..., t.Any] TemplateGlobalCallable = t.Callable[..., t.Any] @@ -60,7 +72,8 @@ URLValuePreprocessorCallable = t.Callable[[t.Optional[str], t.Optional[dict]], N # https://github.com/pallets/flask/issues/4295 # https://github.com/pallets/flask/issues/4297 ErrorHandlerCallable = t.Callable[[t.Any], ResponseReturnValue] -ErrorHandlerDecorator = t.TypeVar("ErrorHandlerDecorator", bound=ErrorHandlerCallable) -ViewCallable = t.Callable[..., ResponseReturnValue] -RouteDecorator = t.TypeVar("RouteDecorator", bound=ViewCallable) +RouteCallable = t.Union[ + t.Callable[..., ResponseReturnValue], + t.Callable[..., t.Awaitable[ResponseReturnValue]], +] diff --git a/src/flask/views.py b/src/flask/views.py index 7aac3dd5..a82f1912 100644 --- a/src/flask/views.py +++ b/src/flask/views.py @@ -82,7 +82,7 @@ class View: @classmethod def as_view( cls, name: str, *class_args: t.Any, **class_kwargs: t.Any - ) -> ft.ViewCallable: + ) -> ft.RouteCallable: """Convert the class into a view function that can be registered for a route. diff --git a/tests/typing/typing_route.py b/tests/typing/typing_route.py index 41973c26..5f2ddbfd 100644 --- a/tests/typing/typing_route.py +++ b/tests/typing/typing_route.py @@ -84,6 +84,11 @@ def return_template_stream() -> t.Iterator[str]: return stream_template("index.html", name="Hello") +@app.route("/async") +async def async_route() -> str: + return "Hello" + + class RenderTemplateView(View): def __init__(self: RenderTemplateView, template_name: str) -> None: self.template_name = template_name