Add initial type hints

This should make it easier for users to correctly use Flask. The hints
are from Quart.
This commit is contained in:
pgjones 2021-04-24 12:22:26 +01:00
parent 943009c8ec
commit fa5c99ef4f
20 changed files with 820 additions and 461 deletions

View file

@ -2,8 +2,11 @@ import importlib.util
import os
import pkgutil
import sys
import typing as t
from collections import defaultdict
from functools import update_wrapper
from json import JSONDecoder
from json import JSONEncoder
from jinja2 import FileSystemLoader
from werkzeug.exceptions import default_exceptions
@ -14,17 +17,28 @@ from .globals import current_app
from .helpers import locked_cached_property
from .helpers import send_from_directory
from .templating import _default_template_ctx_processor
from .typing import AfterRequestCallable
from .typing import AppOrBlueprintKey
from .typing import BeforeRequestCallable
from .typing import ErrorHandlerCallable
from .typing import TeardownCallable
from .typing import TemplateContextProcessorCallable
from .typing import URLDefaultCallable
from .typing import URLValuePreprocessorCallable
if t.TYPE_CHECKING:
from .wrappers import Response
# a singleton sentinel value for parameter defaults
_sentinel = object()
def setupmethod(f):
def setupmethod(f: t.Callable) -> t.Callable:
"""Wraps a method so that it performs a check in debug mode if the
first request was already handled.
"""
def wrapper_func(self, *args, **kwargs):
def wrapper_func(self, *args: t.Any, **kwargs: t.Any) -> t.Any:
if self._is_setup_finished():
raise AssertionError(
"A setup function was called after the first request "
@ -60,24 +74,24 @@ class Scaffold:
"""
name: str
_static_folder = None
_static_url_path = None
_static_folder: t.Optional[str] = None
_static_url_path: t.Optional[str] = None
#: JSON encoder class used by :func:`flask.json.dumps`. If a
#: blueprint sets this, it will be used instead of the app's value.
json_encoder = None
json_encoder: t.Optional[t.Type[JSONEncoder]] = None
#: JSON decoder class used by :func:`flask.json.loads`. If a
#: blueprint sets this, it will be used instead of the app's value.
json_decoder = None
json_decoder: t.Optional[t.Type[JSONDecoder]] = None
def __init__(
self,
import_name,
static_folder=None,
static_url_path=None,
template_folder=None,
root_path=None,
import_name: str,
static_folder: t.Optional[str] = None,
static_url_path: t.Optional[str] = None,
template_folder: t.Optional[str] = None,
root_path: t.Optional[str] = None,
):
#: The name of the package or module that this object belongs
#: to. Do not change this once it is set by the constructor.
@ -110,7 +124,7 @@ class Scaffold:
#:
#: This data structure is internal. It should not be modified
#: directly and its format may change at any time.
self.view_functions = {}
self.view_functions: t.Dict[str, t.Callable] = {}
#: A data structure of registered error handlers, in the format
#: ``{scope: {code: {class: handler}}}```. The ``scope`` key is
@ -125,7 +139,10 @@ class Scaffold:
#:
#: This data structure is internal. It should not be modified
#: directly and its format may change at any time.
self.error_handler_spec = defaultdict(lambda: defaultdict(dict))
self.error_handler_spec: t.Dict[
AppOrBlueprintKey,
t.Dict[t.Optional[int], t.Dict[t.Type[Exception], ErrorHandlerCallable]],
] = defaultdict(lambda: defaultdict(dict))
#: A data structure of functions to call at the beginning of
#: each request, in the format ``{scope: [functions]}``. The
@ -137,7 +154,9 @@ class Scaffold:
#:
#: This data structure is internal. It should not be modified
#: directly and its format may change at any time.
self.before_request_funcs = defaultdict(list)
self.before_request_funcs: t.Dict[
AppOrBlueprintKey, t.List[BeforeRequestCallable]
] = defaultdict(list)
#: A data structure of functions to call at the end of each
#: request, in the format ``{scope: [functions]}``. The
@ -149,7 +168,9 @@ class Scaffold:
#:
#: This data structure is internal. It should not be modified
#: directly and its format may change at any time.
self.after_request_funcs = defaultdict(list)
self.after_request_funcs: t.Dict[
AppOrBlueprintKey, t.List[AfterRequestCallable]
] = defaultdict(list)
#: A data structure of functions to call at the end of each
#: request even if an exception is raised, in the format
@ -162,7 +183,9 @@ class Scaffold:
#:
#: This data structure is internal. It should not be modified
#: directly and its format may change at any time.
self.teardown_request_funcs = defaultdict(list)
self.teardown_request_funcs: t.Dict[
AppOrBlueprintKey, t.List[TeardownCallable]
] = defaultdict(list)
#: A data structure of functions to call to pass extra context
#: values when rendering templates, in the format
@ -175,9 +198,9 @@ class Scaffold:
#:
#: This data structure is internal. It should not be modified
#: directly and its format may change at any time.
self.template_context_processors = defaultdict(
list, {None: [_default_template_ctx_processor]}
)
self.template_context_processors: t.Dict[
AppOrBlueprintKey, t.List[TemplateContextProcessorCallable]
] = defaultdict(list, {None: [_default_template_ctx_processor]})
#: A data structure of functions to call to modify the keyword
#: arguments passed to the view function, in the format
@ -190,7 +213,10 @@ class Scaffold:
#:
#: This data structure is internal. It should not be modified
#: directly and its format may change at any time.
self.url_value_preprocessors = defaultdict(list)
self.url_value_preprocessors: t.Dict[
AppOrBlueprintKey,
t.List[URLValuePreprocessorCallable],
] = defaultdict(list)
#: A data structure of functions to call to modify the keyword
#: arguments when generating URLs, in the format
@ -203,31 +229,35 @@ class Scaffold:
#:
#: This data structure is internal. It should not be modified
#: directly and its format may change at any time.
self.url_default_functions = defaultdict(list)
self.url_default_functions: t.Dict[
AppOrBlueprintKey, t.List[URLDefaultCallable]
] = defaultdict(list)
def __repr__(self):
def __repr__(self) -> str:
return f"<{type(self).__name__} {self.name!r}>"
def _is_setup_finished(self):
def _is_setup_finished(self) -> bool:
raise NotImplementedError
@property
def static_folder(self):
def static_folder(self) -> t.Optional[str]:
"""The absolute path to the configured static folder. ``None``
if no static folder is set.
"""
if self._static_folder is not None:
return os.path.join(self.root_path, self._static_folder)
else:
return None
@static_folder.setter
def static_folder(self, value):
def static_folder(self, value: t.Optional[str]) -> None:
if value is not None:
value = os.fspath(value).rstrip(r"\/")
self._static_folder = value
@property
def has_static_folder(self):
def has_static_folder(self) -> bool:
"""``True`` if :attr:`static_folder` is set.
.. versionadded:: 0.5
@ -235,7 +265,7 @@ class Scaffold:
return self.static_folder is not None
@property
def static_url_path(self):
def static_url_path(self) -> t.Optional[str]:
"""The URL prefix that the static route will be accessible from.
If it was not configured during init, it is derived from
@ -248,14 +278,16 @@ class Scaffold:
basename = os.path.basename(self.static_folder)
return f"/{basename}".rstrip("/")
return None
@static_url_path.setter
def static_url_path(self, value):
def static_url_path(self, value: t.Optional[str]) -> None:
if value is not None:
value = value.rstrip("/")
self._static_url_path = value
def get_send_file_max_age(self, filename):
def get_send_file_max_age(self, filename: str) -> t.Optional[int]:
"""Used by :func:`send_file` to determine the ``max_age`` cache
value for a given file path if it wasn't passed.
@ -276,7 +308,7 @@ class Scaffold:
return int(value.total_seconds())
def send_static_file(self, filename):
def send_static_file(self, filename: str) -> "Response":
"""The view function used to serve files from
:attr:`static_folder`. A route is automatically registered for
this view at :attr:`static_url_path` if :attr:`static_folder` is
@ -290,10 +322,12 @@ class Scaffold:
# send_file only knows to call get_send_file_max_age on the app,
# call it here so it works for blueprints too.
max_age = self.get_send_file_max_age(filename)
return send_from_directory(self.static_folder, filename, max_age=max_age)
return send_from_directory(
t.cast(str, self.static_folder), filename, max_age=max_age
)
@locked_cached_property
def jinja_loader(self):
def jinja_loader(self) -> t.Optional[FileSystemLoader]:
"""The Jinja loader for this object's templates. By default this
is a class :class:`jinja2.loaders.FileSystemLoader` to
:attr:`template_folder` if it is set.
@ -302,8 +336,10 @@ class Scaffold:
"""
if self.template_folder is not None:
return FileSystemLoader(os.path.join(self.root_path, self.template_folder))
else:
return None
def open_resource(self, resource, mode="rb"):
def open_resource(self, resource: str, mode: str = "rb") -> t.IO[t.AnyStr]:
"""Open a resource file relative to :attr:`root_path` for
reading.
@ -326,48 +362,48 @@ class Scaffold:
return open(os.path.join(self.root_path, resource), mode)
def _method_route(self, method, rule, options):
def _method_route(self, method: str, rule: str, options: dict) -> t.Callable:
if "methods" in options:
raise TypeError("Use the 'route' decorator to use the 'methods' argument.")
return self.route(rule, methods=[method], **options)
def get(self, rule, **options):
def get(self, rule: str, **options: t.Any) -> t.Callable:
"""Shortcut for :meth:`route` with ``methods=["GET"]``.
.. versionadded:: 2.0
"""
return self._method_route("GET", rule, options)
def post(self, rule, **options):
def post(self, rule: str, **options: t.Any) -> t.Callable:
"""Shortcut for :meth:`route` with ``methods=["POST"]``.
.. versionadded:: 2.0
"""
return self._method_route("POST", rule, options)
def put(self, rule, **options):
def put(self, rule: str, **options: t.Any) -> t.Callable:
"""Shortcut for :meth:`route` with ``methods=["PUT"]``.
.. versionadded:: 2.0
"""
return self._method_route("PUT", rule, options)
def delete(self, rule, **options):
def delete(self, rule: str, **options: t.Any) -> t.Callable:
"""Shortcut for :meth:`route` with ``methods=["DELETE"]``.
.. versionadded:: 2.0
"""
return self._method_route("DELETE", rule, options)
def patch(self, rule, **options):
def patch(self, rule: str, **options: t.Any) -> t.Callable:
"""Shortcut for :meth:`route` with ``methods=["PATCH"]``.
.. versionadded:: 2.0
"""
return self._method_route("PATCH", rule, options)
def route(self, rule, **options):
def route(self, rule: str, **options: t.Any) -> t.Callable:
"""Decorate a view function to register it with the given URL
rule and options. Calls :meth:`add_url_rule`, which has more
details about the implementation.
@ -391,7 +427,7 @@ class Scaffold:
:class:`~werkzeug.routing.Rule` object.
"""
def decorator(f):
def decorator(f: t.Callable) -> t.Callable:
endpoint = options.pop("endpoint", None)
self.add_url_rule(rule, endpoint, f, **options)
return f
@ -401,12 +437,12 @@ class Scaffold:
@setupmethod
def add_url_rule(
self,
rule,
endpoint=None,
view_func=None,
provide_automatic_options=None,
**options,
):
rule: str,
endpoint: t.Optional[str] = None,
view_func: t.Optional[t.Callable] = None,
provide_automatic_options: t.Optional[bool] = None,
**options: t.Any,
) -> t.Callable:
"""Register a rule for routing incoming requests and building
URLs. The :meth:`route` decorator is a shortcut to call this
with the ``view_func`` argument. These are equivalent:
@ -466,7 +502,7 @@ class Scaffold:
"""
raise NotImplementedError
def endpoint(self, endpoint):
def endpoint(self, endpoint: str) -> t.Callable:
"""Decorate a view function to register it for the given
endpoint. Used if a rule is added without a ``view_func`` with
:meth:`add_url_rule`.
@ -490,7 +526,7 @@ class Scaffold:
return decorator
@setupmethod
def before_request(self, f):
def before_request(self, f: BeforeRequestCallable) -> BeforeRequestCallable:
"""Register a function to run before each request.
For example, this can be used to open a database connection, or
@ -512,7 +548,7 @@ class Scaffold:
return f
@setupmethod
def after_request(self, f):
def after_request(self, f: AfterRequestCallable) -> AfterRequestCallable:
"""Register a function to run after each request to this object.
The function is called with the response object, and must return
@ -528,7 +564,7 @@ class Scaffold:
return f
@setupmethod
def teardown_request(self, f):
def teardown_request(self, f: TeardownCallable) -> TeardownCallable:
"""Register a function to be run at the end of each request,
regardless of whether there was an exception or not. These functions
are executed when the request context is popped, even if not an
@ -567,13 +603,17 @@ class Scaffold:
return f
@setupmethod
def context_processor(self, f):
def context_processor(
self, f: TemplateContextProcessorCallable
) -> TemplateContextProcessorCallable:
"""Registers a template context processor function."""
self.template_context_processors[None].append(f)
return f
@setupmethod
def url_value_preprocessor(self, f):
def url_value_preprocessor(
self, f: URLValuePreprocessorCallable
) -> URLValuePreprocessorCallable:
"""Register a URL value preprocessor function for all view
functions in the application. These functions will be called before the
:meth:`before_request` functions.
@ -590,7 +630,7 @@ class Scaffold:
return f
@setupmethod
def url_defaults(self, f):
def url_defaults(self, f: URLDefaultCallable) -> URLDefaultCallable:
"""Callback function for URL defaults for all view functions of the
application. It's called with the endpoint and values and should
update the values passed in place.
@ -599,7 +639,9 @@ class Scaffold:
return f
@setupmethod
def errorhandler(self, code_or_exception):
def errorhandler(
self, code_or_exception: t.Union[t.Type[Exception], int]
) -> t.Callable:
"""Register a function to handle errors by code or exception class.
A decorator that is used to register a function given an
@ -629,14 +671,18 @@ class Scaffold:
an arbitrary exception
"""
def decorator(f):
def decorator(f: ErrorHandlerCallable) -> ErrorHandlerCallable:
self.register_error_handler(code_or_exception, f)
return f
return decorator
@setupmethod
def register_error_handler(self, code_or_exception, f):
def register_error_handler(
self,
code_or_exception: t.Union[t.Type[Exception], int],
f: ErrorHandlerCallable,
) -> None:
"""Alternative error attach function to the :meth:`errorhandler`
decorator that is more straightforward to use for non decorator
usage.
@ -662,7 +708,9 @@ class Scaffold:
self.error_handler_spec[None][code][exc_class] = self.ensure_sync(f)
@staticmethod
def _get_exc_class_and_code(exc_class_or_code):
def _get_exc_class_and_code(
exc_class_or_code: t.Union[t.Type[Exception], int]
) -> t.Tuple[t.Type[Exception], t.Optional[int]]:
"""Get the exception class being handled. For HTTP status codes
or ``HTTPException`` subclasses, return both the exception and
status code.
@ -670,6 +718,7 @@ class Scaffold:
:param exc_class_or_code: Any exception class, or an HTTP status
code as an integer.
"""
exc_class: t.Type[Exception]
if isinstance(exc_class_or_code, int):
exc_class = default_exceptions[exc_class_or_code]
else:
@ -684,11 +733,11 @@ class Scaffold:
else:
return exc_class, None
def ensure_sync(self, func):
def ensure_sync(self, func: t.Callable) -> t.Callable:
raise NotImplementedError()
def _endpoint_from_view_func(view_func):
def _endpoint_from_view_func(view_func: t.Callable) -> str:
"""Internal helper that returns the default endpoint for a given
function. This always is the function name.
"""
@ -696,7 +745,7 @@ def _endpoint_from_view_func(view_func):
return view_func.__name__
def get_root_path(import_name):
def get_root_path(import_name: str) -> str:
"""Find the root path of a package, or the path that contains a
module. If it cannot be found, returns the current working
directory.
@ -721,7 +770,7 @@ def get_root_path(import_name):
return os.getcwd()
if hasattr(loader, "get_filename"):
filepath = loader.get_filename(import_name)
filepath = loader.get_filename(import_name) # type: ignore
else:
# Fall back to imports.
__import__(import_name)
@ -822,7 +871,7 @@ def _find_package_path(root_mod_name):
return package_path
def find_package(import_name):
def find_package(import_name: str):
"""Find the prefix that a package is installed under, and the path
that it would be imported from.