Add initial type hints

This should make it easier for users to correctly use Flask. The hints
are from Quart.
This commit is contained in:
pgjones 2021-04-24 12:22:26 +01:00
parent f405c6f19e
commit 77237093da
20 changed files with 820 additions and 461 deletions

View file

@ -1,6 +1,8 @@
import os
import socket
import typing as t
import warnings
from datetime import timedelta
from functools import update_wrapper
from functools import wraps
from threading import RLock
@ -18,8 +20,11 @@ from .globals import request
from .globals import session
from .signals import message_flashed
if t.TYPE_CHECKING:
from .wrappers import Response
def get_env():
def get_env() -> str:
"""Get the environment the app is running in, indicated by the
:envvar:`FLASK_ENV` environment variable. The default is
``'production'``.
@ -27,7 +32,7 @@ def get_env():
return os.environ.get("FLASK_ENV") or "production"
def get_debug_flag():
def get_debug_flag() -> bool:
"""Get whether debug mode should be enabled for the app, indicated
by the :envvar:`FLASK_DEBUG` environment variable. The default is
``True`` if :func:`.get_env` returns ``'development'``, or ``False``
@ -41,7 +46,7 @@ def get_debug_flag():
return val.lower() not in ("0", "false", "no")
def get_load_dotenv(default=True):
def get_load_dotenv(default: bool = True) -> bool:
"""Get whether the user has disabled loading dotenv files by setting
:envvar:`FLASK_SKIP_DOTENV`. The default is ``True``, load the
files.
@ -56,7 +61,9 @@ def get_load_dotenv(default=True):
return val.lower() in ("0", "false", "no")
def stream_with_context(generator_or_function):
def stream_with_context(
generator_or_function: t.Union[t.Generator, t.Callable]
) -> t.Generator:
"""Request contexts disappear when the response is started on the server.
This is done for efficiency reasons and to make it less likely to encounter
memory leaks with badly written WSGI middlewares. The downside is that if
@ -91,16 +98,16 @@ def stream_with_context(generator_or_function):
.. versionadded:: 0.9
"""
try:
gen = iter(generator_or_function)
gen = iter(generator_or_function) # type: ignore
except TypeError:
def decorator(*args, **kwargs):
gen = generator_or_function(*args, **kwargs)
def decorator(*args: t.Any, **kwargs: t.Any) -> t.Any:
gen = generator_or_function(*args, **kwargs) # type: ignore
return stream_with_context(gen)
return update_wrapper(decorator, generator_or_function)
return update_wrapper(decorator, generator_or_function) # type: ignore
def generator():
def generator() -> t.Generator:
ctx = _request_ctx_stack.top
if ctx is None:
raise RuntimeError(
@ -120,7 +127,7 @@ def stream_with_context(generator_or_function):
yield from gen
finally:
if hasattr(gen, "close"):
gen.close()
gen.close() # type: ignore
# The trick is to start the generator. Then the code execution runs until
# the first dummy None is yielded at which point the context was already
@ -131,7 +138,7 @@ def stream_with_context(generator_or_function):
return wrapped_g
def make_response(*args):
def make_response(*args: t.Any) -> "Response":
"""Sometimes it is necessary to set additional headers in a view. Because
views do not have to return response objects but can return a value that
is converted into a response object by Flask itself, it becomes tricky to
@ -180,7 +187,7 @@ def make_response(*args):
return current_app.make_response(args)
def url_for(endpoint, **values):
def url_for(endpoint: str, **values: t.Any) -> str:
"""Generates a URL to the given endpoint with the method provided.
Variable arguments that are unknown to the target endpoint are appended
@ -331,7 +338,7 @@ def url_for(endpoint, **values):
return rv
def get_template_attribute(template_name, attribute):
def get_template_attribute(template_name: str, attribute: str) -> t.Any:
"""Loads a macro (or variable) a template exports. This can be used to
invoke a macro from within Python code. If you for example have a
template named :file:`_cider.html` with the following contents:
@ -353,7 +360,7 @@ def get_template_attribute(template_name, attribute):
return getattr(current_app.jinja_env.get_template(template_name).module, attribute)
def flash(message, category="message"):
def flash(message: str, category: str = "message") -> None:
"""Flashes a message to the next request. In order to remove the
flashed message from the session and to display it to the user,
the template has to call :func:`get_flashed_messages`.
@ -379,11 +386,15 @@ def flash(message, category="message"):
flashes.append((category, message))
session["_flashes"] = flashes
message_flashed.send(
current_app._get_current_object(), message=message, category=category
current_app._get_current_object(), # type: ignore
message=message,
category=category,
)
def get_flashed_messages(with_categories=False, category_filter=()):
def get_flashed_messages(
with_categories: bool = False, category_filter: t.Iterable[str] = ()
) -> t.Union[t.List[str], t.List[t.Tuple[str, str]]]:
"""Pulls all flashed messages from the session and returns them.
Further calls in the same request to the function will return
the same messages. By default just the messages are returned,
@ -608,7 +619,7 @@ def send_file(
)
def safe_join(directory, *pathnames):
def safe_join(directory: str, *pathnames: str) -> str:
"""Safely join zero or more untrusted path components to a base
directory to avoid escaping the base directory.
@ -631,7 +642,7 @@ def safe_join(directory, *pathnames):
return path
def send_from_directory(directory, path, **kwargs):
def send_from_directory(directory: str, path: str, **kwargs: t.Any) -> "Response":
"""Send a file from within a directory using :func:`send_file`.
.. code-block:: python
@ -661,7 +672,7 @@ def send_from_directory(directory, path, **kwargs):
.. versionadded:: 0.5
"""
return werkzeug.utils.send_from_directory(
return werkzeug.utils.send_from_directory( # type: ignore
directory, path, **_prepare_send_file_kwargs(**kwargs)
)
@ -675,27 +686,32 @@ class locked_cached_property(werkzeug.utils.cached_property):
Inherits from Werkzeug's ``cached_property`` (and ``property``).
"""
def __init__(self, fget, name=None, doc=None):
def __init__(
self,
fget: t.Callable[[t.Any], t.Any],
name: t.Optional[str] = None,
doc: t.Optional[str] = None,
) -> None:
super().__init__(fget, name=name, doc=doc)
self.lock = RLock()
def __get__(self, obj, type=None):
def __get__(self, obj: object, type: type = None) -> t.Any: # type: ignore
if obj is None:
return self
with self.lock:
return super().__get__(obj, type=type)
def __set__(self, obj, value):
def __set__(self, obj: object, value: t.Any) -> None:
with self.lock:
super().__set__(obj, value)
def __delete__(self, obj):
def __delete__(self, obj: object) -> None:
with self.lock:
super().__delete__(obj)
def total_seconds(td):
def total_seconds(td: timedelta) -> int:
"""Returns the total seconds from a timedelta object.
:param timedelta td: the timedelta to be converted in seconds
@ -716,7 +732,7 @@ def total_seconds(td):
return td.days * 60 * 60 * 24 + td.seconds
def is_ip(value):
def is_ip(value: str) -> bool:
"""Determine if the given string is an IP address.
:param value: value to check
@ -736,7 +752,7 @@ def is_ip(value):
return False
def run_async(func):
def run_async(func: t.Callable[..., t.Coroutine]) -> t.Callable[..., t.Any]:
"""Return a sync function that will run the coroutine function *func*."""
try:
from asgiref.sync import async_to_sync
@ -752,7 +768,7 @@ def run_async(func):
)
@wraps(func)
def outer(*args, **kwargs):
def outer(*args: t.Any, **kwargs: t.Any) -> t.Any:
"""This function grabs the current context for the inner function.
This is similar to the copy_current_xxx_context functions in the
@ -764,7 +780,7 @@ def run_async(func):
ctx = _request_ctx_stack.top.copy()
@wraps(func)
async def inner(*a, **k):
async def inner(*a: t.Any, **k: t.Any) -> t.Any:
"""This restores the context before awaiting the func.
This is required as the function must be awaited within the
@ -780,5 +796,5 @@ def run_async(func):
return async_to_sync(inner)(*args, **kwargs)
outer._flask_sync_wrapper = True
outer._flask_sync_wrapper = True # type: ignore
return outer