improve typing for app.errorhandler decorator

This commit is contained in:
Pascal Corpet 2021-05-31 12:28:44 +02:00 committed by pgjones
parent 5205cd4ea9
commit 6a4e7e948d
5 changed files with 38 additions and 13 deletions

View file

@ -11,6 +11,7 @@ Unreleased
- Fixed the issue where typing requires template global - Fixed the issue where typing requires template global
decorators to accept functions with no arguments. :issue:`4098` decorators to accept functions with no arguments. :issue:`4098`
- Support View and MethodView instances with async handlers. :issue:`4112` - Support View and MethodView instances with async handlers. :issue:`4112`
- Enhance typing of ``app.errorhandler`` decorator. :issue:`4095`
Version 2.0.1 Version 2.0.1

View file

@ -61,7 +61,6 @@ from .templating import Environment
from .typing import AfterRequestCallable from .typing import AfterRequestCallable
from .typing import BeforeFirstRequestCallable from .typing import BeforeFirstRequestCallable
from .typing import BeforeRequestCallable from .typing import BeforeRequestCallable
from .typing import ErrorHandlerCallable
from .typing import ResponseReturnValue from .typing import ResponseReturnValue
from .typing import TeardownCallable from .typing import TeardownCallable
from .typing import TemplateContextProcessorCallable from .typing import TemplateContextProcessorCallable
@ -78,6 +77,7 @@ if t.TYPE_CHECKING:
from .blueprints import Blueprint from .blueprints import Blueprint
from .testing import FlaskClient from .testing import FlaskClient
from .testing import FlaskCliRunner from .testing import FlaskCliRunner
from .typing import ErrorHandlerCallable
if sys.version_info >= (3, 8): if sys.version_info >= (3, 8):
iscoroutinefunction = inspect.iscoroutinefunction iscoroutinefunction = inspect.iscoroutinefunction
@ -1268,7 +1268,9 @@ class Flask(Scaffold):
self.shell_context_processors.append(f) self.shell_context_processors.append(f)
return f return f
def _find_error_handler(self, e: Exception) -> t.Optional[ErrorHandlerCallable]: def _find_error_handler(
self, e: Exception
) -> t.Optional["ErrorHandlerCallable[Exception]"]:
"""Return a registered error handler for an exception in this order: """Return a registered error handler for an exception in this order:
blueprint handler for a specific code, app handler for a specific code, blueprint handler for a specific code, app handler for a specific code,
blueprint handler for an exception class, app handler for an exception blueprint handler for an exception class, app handler for an exception

View file

@ -8,7 +8,6 @@ from .scaffold import Scaffold
from .typing import AfterRequestCallable from .typing import AfterRequestCallable
from .typing import BeforeFirstRequestCallable from .typing import BeforeFirstRequestCallable
from .typing import BeforeRequestCallable from .typing import BeforeRequestCallable
from .typing import ErrorHandlerCallable
from .typing import TeardownCallable from .typing import TeardownCallable
from .typing import TemplateContextProcessorCallable from .typing import TemplateContextProcessorCallable
from .typing import TemplateFilterCallable from .typing import TemplateFilterCallable
@ -19,6 +18,7 @@ from .typing import URLValuePreprocessorCallable
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from .app import Flask from .app import Flask
from .typing import ErrorHandlerCallable
DeferredSetupFunction = t.Callable[["BlueprintSetupState"], t.Callable] DeferredSetupFunction = t.Callable[["BlueprintSetupState"], t.Callable]
@ -581,7 +581,9 @@ class Blueprint(Scaffold):
handler is used for all requests, even if outside of the blueprint. handler is used for all requests, even if outside of the blueprint.
""" """
def decorator(f: ErrorHandlerCallable) -> ErrorHandlerCallable: def decorator(
f: "ErrorHandlerCallable[Exception]",
) -> "ErrorHandlerCallable[Exception]":
self.record_once(lambda s: s.app.errorhandler(code)(f)) self.record_once(lambda s: s.app.errorhandler(code)(f))
return f return f

View file

@ -21,7 +21,7 @@ from .templating import _default_template_ctx_processor
from .typing import AfterRequestCallable from .typing import AfterRequestCallable
from .typing import AppOrBlueprintKey from .typing import AppOrBlueprintKey
from .typing import BeforeRequestCallable from .typing import BeforeRequestCallable
from .typing import ErrorHandlerCallable from .typing import GenericException
from .typing import TeardownCallable from .typing import TeardownCallable
from .typing import TemplateContextProcessorCallable from .typing import TemplateContextProcessorCallable
from .typing import URLDefaultCallable from .typing import URLDefaultCallable
@ -29,6 +29,7 @@ from .typing import URLValuePreprocessorCallable
if t.TYPE_CHECKING: if t.TYPE_CHECKING:
from .wrappers import Response from .wrappers import Response
from .typing import ErrorHandlerCallable
# a singleton sentinel value for parameter defaults # a singleton sentinel value for parameter defaults
_sentinel = object() _sentinel = object()
@ -144,7 +145,10 @@ class Scaffold:
#: directly and its format may change at any time. #: directly and its format may change at any time.
self.error_handler_spec: t.Dict[ self.error_handler_spec: t.Dict[
AppOrBlueprintKey, AppOrBlueprintKey,
t.Dict[t.Optional[int], t.Dict[t.Type[Exception], ErrorHandlerCallable]], t.Dict[
t.Optional[int],
t.Dict[t.Type[Exception], "ErrorHandlerCallable[Exception]"],
],
] = defaultdict(lambda: defaultdict(dict)) ] = defaultdict(lambda: defaultdict(dict))
#: A data structure of functions to call at the beginning of #: A data structure of functions to call at the beginning of
@ -643,8 +647,11 @@ class Scaffold:
@setupmethod @setupmethod
def errorhandler( def errorhandler(
self, code_or_exception: t.Union[t.Type[Exception], int] self, code_or_exception: t.Union[t.Type[GenericException], int]
) -> t.Callable[[ErrorHandlerCallable], ErrorHandlerCallable]: ) -> t.Callable[
["ErrorHandlerCallable[GenericException]"],
"ErrorHandlerCallable[GenericException]",
]:
"""Register a function to handle errors by code or exception class. """Register a function to handle errors by code or exception class.
A decorator that is used to register a function given an A decorator that is used to register a function given an
@ -674,7 +681,9 @@ class Scaffold:
an arbitrary exception an arbitrary exception
""" """
def decorator(f: ErrorHandlerCallable) -> ErrorHandlerCallable: def decorator(
f: "ErrorHandlerCallable[GenericException]",
) -> "ErrorHandlerCallable[GenericException]":
self.register_error_handler(code_or_exception, f) self.register_error_handler(code_or_exception, f)
return f return f
@ -683,8 +692,8 @@ class Scaffold:
@setupmethod @setupmethod
def register_error_handler( def register_error_handler(
self, self,
code_or_exception: t.Union[t.Type[Exception], int], code_or_exception: t.Union[t.Type[GenericException], int],
f: ErrorHandlerCallable, f: "ErrorHandlerCallable[GenericException]",
) -> None: ) -> None:
"""Alternative error attach function to the :meth:`errorhandler` """Alternative error attach function to the :meth:`errorhandler`
decorator that is more straightforward to use for non decorator decorator that is more straightforward to use for non decorator
@ -708,7 +717,9 @@ class Scaffold:
" instead." " instead."
) )
self.error_handler_spec[None][code][exc_class] = f self.error_handler_spec[None][code][exc_class] = t.cast(
"ErrorHandlerCallable[Exception]", f
)
@staticmethod @staticmethod
def _get_exc_class_and_code( def _get_exc_class_and_code(

View file

@ -33,11 +33,12 @@ ResponseReturnValue = t.Union[
"WSGIApplication", "WSGIApplication",
] ]
GenericException = t.TypeVar("GenericException", bound=Exception, contravariant=True)
AppOrBlueprintKey = t.Optional[str] # The App key is None, whereas blueprints are named AppOrBlueprintKey = t.Optional[str] # The App key is None, whereas blueprints are named
AfterRequestCallable = t.Callable[["Response"], "Response"] AfterRequestCallable = t.Callable[["Response"], "Response"]
BeforeFirstRequestCallable = t.Callable[[], None] BeforeFirstRequestCallable = t.Callable[[], None]
BeforeRequestCallable = t.Callable[[], t.Optional[ResponseReturnValue]] BeforeRequestCallable = t.Callable[[], t.Optional[ResponseReturnValue]]
ErrorHandlerCallable = t.Callable[[Exception], ResponseReturnValue]
TeardownCallable = t.Callable[[t.Optional[BaseException]], None] TeardownCallable = t.Callable[[t.Optional[BaseException]], None]
TemplateContextProcessorCallable = t.Callable[[], t.Dict[str, t.Any]] TemplateContextProcessorCallable = t.Callable[[], t.Dict[str, t.Any]]
TemplateFilterCallable = t.Callable[..., t.Any] TemplateFilterCallable = t.Callable[..., t.Any]
@ -45,3 +46,11 @@ TemplateGlobalCallable = t.Callable[..., t.Any]
TemplateTestCallable = t.Callable[..., bool] TemplateTestCallable = t.Callable[..., bool]
URLDefaultCallable = t.Callable[[str, dict], None] URLDefaultCallable = t.Callable[[str, dict], None]
URLValuePreprocessorCallable = t.Callable[[t.Optional[str], t.Optional[dict]], None] URLValuePreprocessorCallable = t.Callable[[t.Optional[str], t.Optional[dict]], None]
if t.TYPE_CHECKING:
import typing_extensions as te
class ErrorHandlerCallable(te.Protocol[GenericException]):
def __call__(self, error: GenericException) -> ResponseReturnValue:
...