Add rate limiting functionality to Flask app
- Introduced MemoryRateLimiter for managing request limits. - Added configuration options for enabling rate limiting and setting request limits and time windows. - Implemented methods to enforce rate limits and build rate limit keys based on request context. - Integrated rate limiting checks into the request handling process.
This commit is contained in:
parent
2579ce9f18
commit
20fd9130b7
3 changed files with 153 additions and 0 deletions
|
|
@ -19,6 +19,7 @@ from werkzeug.datastructures import ImmutableDict
|
||||||
from werkzeug.exceptions import BadRequestKeyError
|
from werkzeug.exceptions import BadRequestKeyError
|
||||||
from werkzeug.exceptions import HTTPException
|
from werkzeug.exceptions import HTTPException
|
||||||
from werkzeug.exceptions import InternalServerError
|
from werkzeug.exceptions import InternalServerError
|
||||||
|
from werkzeug.exceptions import TooManyRequests
|
||||||
from werkzeug.routing import BuildError
|
from werkzeug.routing import BuildError
|
||||||
from werkzeug.routing import MapAdapter
|
from werkzeug.routing import MapAdapter
|
||||||
from werkzeug.routing import RequestRedirect
|
from werkzeug.routing import RequestRedirect
|
||||||
|
|
@ -40,6 +41,7 @@ from .helpers import get_debug_flag
|
||||||
from .helpers import get_flashed_messages
|
from .helpers import get_flashed_messages
|
||||||
from .helpers import get_load_dotenv
|
from .helpers import get_load_dotenv
|
||||||
from .helpers import send_from_directory
|
from .helpers import send_from_directory
|
||||||
|
from .rate_limiter import MemoryRateLimiter
|
||||||
from .sansio.app import App
|
from .sansio.app import App
|
||||||
from .sessions import SecureCookieSessionInterface
|
from .sessions import SecureCookieSessionInterface
|
||||||
from .sessions import SessionInterface
|
from .sessions import SessionInterface
|
||||||
|
|
@ -233,6 +235,9 @@ class Flask(App):
|
||||||
"TEMPLATES_AUTO_RELOAD": None,
|
"TEMPLATES_AUTO_RELOAD": None,
|
||||||
"MAX_COOKIE_SIZE": 4093,
|
"MAX_COOKIE_SIZE": 4093,
|
||||||
"PROVIDE_AUTOMATIC_OPTIONS": True,
|
"PROVIDE_AUTOMATIC_OPTIONS": True,
|
||||||
|
"RATELIMIT_ENABLED": False,
|
||||||
|
"RATELIMIT_REQUESTS": 60,
|
||||||
|
"RATELIMIT_WINDOW": 60,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -342,6 +347,9 @@ class Flask(App):
|
||||||
# the app's commands to another CLI tool.
|
# the app's commands to another CLI tool.
|
||||||
self.cli.name = self.name
|
self.cli.name = self.name
|
||||||
|
|
||||||
|
self._rate_limiter: MemoryRateLimiter | None = None
|
||||||
|
self._rate_limiter_settings: tuple[int, float] | None = None
|
||||||
|
|
||||||
# Add a static route using the provided static_url_path, static_host,
|
# Add a static route using the provided static_url_path, static_host,
|
||||||
# and static_folder if there is a configured static_folder.
|
# and static_folder if there is a configured static_folder.
|
||||||
# Note we do this without checking if static_folder exists.
|
# Note we do this without checking if static_folder exists.
|
||||||
|
|
@ -410,6 +418,59 @@ class Flask(App):
|
||||||
t.cast(str, self.static_folder), filename, max_age=max_age
|
t.cast(str, self.static_folder), filename, max_age=max_age
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _get_rate_limiter(self) -> MemoryRateLimiter | None:
|
||||||
|
if not self.config.get("RATELIMIT_ENABLED"):
|
||||||
|
return None
|
||||||
|
|
||||||
|
limit = self.config.get("RATELIMIT_REQUESTS")
|
||||||
|
window = self.config.get("RATELIMIT_WINDOW")
|
||||||
|
|
||||||
|
if not isinstance(limit, int) or limit <= 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not isinstance(window, (int, float)) or window <= 0:
|
||||||
|
return None
|
||||||
|
|
||||||
|
settings = (limit, float(window))
|
||||||
|
|
||||||
|
if self._rate_limiter is None or self._rate_limiter_settings != settings:
|
||||||
|
self._rate_limiter = MemoryRateLimiter(limit, float(window))
|
||||||
|
self._rate_limiter_settings = settings
|
||||||
|
|
||||||
|
return self._rate_limiter
|
||||||
|
|
||||||
|
def _build_rate_limit_key(self, ctx: AppContext) -> str | None:
|
||||||
|
req = ctx.request
|
||||||
|
|
||||||
|
if req.access_route:
|
||||||
|
return req.access_route[0]
|
||||||
|
|
||||||
|
if req.remote_addr:
|
||||||
|
return req.remote_addr
|
||||||
|
|
||||||
|
host = req.environ.get("REMOTE_ADDR")
|
||||||
|
|
||||||
|
if host is None:
|
||||||
|
host = get_host(req.environ)
|
||||||
|
|
||||||
|
return host
|
||||||
|
|
||||||
|
def _enforce_rate_limit(self, ctx: AppContext) -> None:
|
||||||
|
limiter = self._get_rate_limiter()
|
||||||
|
|
||||||
|
if limiter is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
key = self._build_rate_limit_key(ctx)
|
||||||
|
|
||||||
|
if key is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
allowed, _ = limiter.hit(key)
|
||||||
|
|
||||||
|
if not allowed:
|
||||||
|
raise TooManyRequests(description="Rate limit exceeded. Try again later.")
|
||||||
|
|
||||||
def open_resource(
|
def open_resource(
|
||||||
self, resource: str, mode: str = "rb", encoding: str | None = None
|
self, resource: str, mode: str = "rb", encoding: str | None = None
|
||||||
) -> t.IO[t.AnyStr]:
|
) -> t.IO[t.AnyStr]:
|
||||||
|
|
@ -998,6 +1059,7 @@ class Flask(App):
|
||||||
self._got_first_request = True
|
self._got_first_request = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
self._enforce_rate_limit(ctx)
|
||||||
request_started.send(self, _async_wrapper=self.ensure_sync)
|
request_started.send(self, _async_wrapper=self.ensure_sync)
|
||||||
rv = self.preprocess_request(ctx)
|
rv = self.preprocess_request(ctx)
|
||||||
if rv is None:
|
if rv is None:
|
||||||
|
|
|
||||||
54
src/flask/rate_limiter.py
Normal file
54
src/flask/rate_limiter.py
Normal file
|
|
@ -0,0 +1,54 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import defaultdict, deque
|
||||||
|
from threading import Lock
|
||||||
|
from time import monotonic
|
||||||
|
from typing import Deque, DefaultDict, Tuple
|
||||||
|
|
||||||
|
__all__ = ("MemoryRateLimiter",)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryRateLimiter:
|
||||||
|
"""Simple in-memory, per-key fixed window rate limiter.
|
||||||
|
|
||||||
|
The limiter stores request timestamps in monotonic time buckets. It is
|
||||||
|
designed for small Flask applications and the built-in development server,
|
||||||
|
not as a production-grade distributed limiter.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, limit: int, window: float) -> None:
|
||||||
|
if limit <= 0:
|
||||||
|
raise ValueError("limit must be greater than 0")
|
||||||
|
if window <= 0:
|
||||||
|
raise ValueError("window must be greater than 0 seconds")
|
||||||
|
|
||||||
|
self.limit = limit
|
||||||
|
self.window = window
|
||||||
|
self._requests: DefaultDict[str, Deque[float]] = defaultdict(deque)
|
||||||
|
self._lock = Lock()
|
||||||
|
|
||||||
|
def hit(self, key: str) -> Tuple[bool, float | None]:
|
||||||
|
"""Record a hit for *key*.
|
||||||
|
|
||||||
|
Returns a tuple ``(allowed, retry_after)`` where *allowed* indicates if
|
||||||
|
the request can proceed. When *allowed* is ``False`` the second value is
|
||||||
|
the remaining seconds until a new request would be accepted.
|
||||||
|
"""
|
||||||
|
|
||||||
|
now = monotonic()
|
||||||
|
window_start = now - self.window
|
||||||
|
|
||||||
|
with self._lock:
|
||||||
|
bucket = self._requests[key]
|
||||||
|
|
||||||
|
while bucket and bucket[0] <= window_start:
|
||||||
|
bucket.popleft()
|
||||||
|
|
||||||
|
if len(bucket) >= self.limit:
|
||||||
|
retry_after = max(0.0, self.window - (now - bucket[0]))
|
||||||
|
return False, retry_after
|
||||||
|
|
||||||
|
bucket.append(now)
|
||||||
|
|
||||||
|
return True, None
|
||||||
|
|
||||||
37
tests/test_rate_limit.py
Normal file
37
tests/test_rate_limit.py
Normal file
|
|
@ -0,0 +1,37 @@
|
||||||
|
import flask
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limit_disabled_by_default():
|
||||||
|
app = flask.Flask(__name__)
|
||||||
|
|
||||||
|
@app.route("/")
|
||||||
|
def index():
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
client = app.test_client()
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
rv = client.get("/")
|
||||||
|
assert rv.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limit_blocks_after_threshold():
|
||||||
|
app = flask.Flask(__name__)
|
||||||
|
app.config.update(
|
||||||
|
RATELIMIT_ENABLED=True,
|
||||||
|
RATELIMIT_REQUESTS=2,
|
||||||
|
RATELIMIT_WINDOW=60,
|
||||||
|
)
|
||||||
|
|
||||||
|
@app.route("/")
|
||||||
|
def index():
|
||||||
|
return "ok"
|
||||||
|
|
||||||
|
client = app.test_client()
|
||||||
|
|
||||||
|
assert client.get("/").status_code == 200
|
||||||
|
assert client.get("/").status_code == 200
|
||||||
|
|
||||||
|
rv = client.get("/")
|
||||||
|
assert rv.status_code == 429
|
||||||
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue