Add rate limiting functionality to Flask app

- Introduced MemoryRateLimiter for managing request limits.
- Added configuration options for enabling rate limiting and setting request limits and time windows.
- Implemented methods to enforce rate limits and build rate limit keys based on request context.
- Integrated rate limiting checks into the request handling process.
This commit is contained in:
txjas99 2025-11-25 22:08:03 +05:30
parent 2579ce9f18
commit 20fd9130b7
3 changed files with 153 additions and 0 deletions

View file

@ -19,6 +19,7 @@ from werkzeug.datastructures import ImmutableDict
from werkzeug.exceptions import BadRequestKeyError
from werkzeug.exceptions import HTTPException
from werkzeug.exceptions import InternalServerError
from werkzeug.exceptions import TooManyRequests
from werkzeug.routing import BuildError
from werkzeug.routing import MapAdapter
from werkzeug.routing import RequestRedirect
@ -40,6 +41,7 @@ from .helpers import get_debug_flag
from .helpers import get_flashed_messages
from .helpers import get_load_dotenv
from .helpers import send_from_directory
from .rate_limiter import MemoryRateLimiter
from .sansio.app import App
from .sessions import SecureCookieSessionInterface
from .sessions import SessionInterface
@ -233,6 +235,9 @@ class Flask(App):
"TEMPLATES_AUTO_RELOAD": None,
"MAX_COOKIE_SIZE": 4093,
"PROVIDE_AUTOMATIC_OPTIONS": True,
"RATELIMIT_ENABLED": False,
"RATELIMIT_REQUESTS": 60,
"RATELIMIT_WINDOW": 60,
}
)
@ -342,6 +347,9 @@ class Flask(App):
# the app's commands to another CLI tool.
self.cli.name = self.name
self._rate_limiter: MemoryRateLimiter | None = None
self._rate_limiter_settings: tuple[int, float] | None = None
# Add a static route using the provided static_url_path, static_host,
# and static_folder if there is a configured static_folder.
# Note we do this without checking if static_folder exists.
@ -410,6 +418,59 @@ class Flask(App):
t.cast(str, self.static_folder), filename, max_age=max_age
)
def _get_rate_limiter(self) -> MemoryRateLimiter | None:
if not self.config.get("RATELIMIT_ENABLED"):
return None
limit = self.config.get("RATELIMIT_REQUESTS")
window = self.config.get("RATELIMIT_WINDOW")
if not isinstance(limit, int) or limit <= 0:
return None
if not isinstance(window, (int, float)) or window <= 0:
return None
settings = (limit, float(window))
if self._rate_limiter is None or self._rate_limiter_settings != settings:
self._rate_limiter = MemoryRateLimiter(limit, float(window))
self._rate_limiter_settings = settings
return self._rate_limiter
def _build_rate_limit_key(self, ctx: AppContext) -> str | None:
req = ctx.request
if req.access_route:
return req.access_route[0]
if req.remote_addr:
return req.remote_addr
host = req.environ.get("REMOTE_ADDR")
if host is None:
host = get_host(req.environ)
return host
def _enforce_rate_limit(self, ctx: AppContext) -> None:
limiter = self._get_rate_limiter()
if limiter is None:
return
key = self._build_rate_limit_key(ctx)
if key is None:
return
allowed, _ = limiter.hit(key)
if not allowed:
raise TooManyRequests(description="Rate limit exceeded. Try again later.")
def open_resource(
self, resource: str, mode: str = "rb", encoding: str | None = None
) -> t.IO[t.AnyStr]:
@ -998,6 +1059,7 @@ class Flask(App):
self._got_first_request = True
try:
self._enforce_rate_limit(ctx)
request_started.send(self, _async_wrapper=self.ensure_sync)
rv = self.preprocess_request(ctx)
if rv is None:

54
src/flask/rate_limiter.py Normal file
View file

@ -0,0 +1,54 @@
from __future__ import annotations
from collections import defaultdict, deque
from threading import Lock
from time import monotonic
from typing import Deque, DefaultDict, Tuple
__all__ = ("MemoryRateLimiter",)
class MemoryRateLimiter:
"""Simple in-memory, per-key fixed window rate limiter.
The limiter stores request timestamps in monotonic time buckets. It is
designed for small Flask applications and the built-in development server,
not as a production-grade distributed limiter.
"""
def __init__(self, limit: int, window: float) -> None:
if limit <= 0:
raise ValueError("limit must be greater than 0")
if window <= 0:
raise ValueError("window must be greater than 0 seconds")
self.limit = limit
self.window = window
self._requests: DefaultDict[str, Deque[float]] = defaultdict(deque)
self._lock = Lock()
def hit(self, key: str) -> Tuple[bool, float | None]:
"""Record a hit for *key*.
Returns a tuple ``(allowed, retry_after)`` where *allowed* indicates if
the request can proceed. When *allowed* is ``False`` the second value is
the remaining seconds until a new request would be accepted.
"""
now = monotonic()
window_start = now - self.window
with self._lock:
bucket = self._requests[key]
while bucket and bucket[0] <= window_start:
bucket.popleft()
if len(bucket) >= self.limit:
retry_after = max(0.0, self.window - (now - bucket[0]))
return False, retry_after
bucket.append(now)
return True, None

37
tests/test_rate_limit.py Normal file
View file

@ -0,0 +1,37 @@
import flask
def test_rate_limit_disabled_by_default():
app = flask.Flask(__name__)
@app.route("/")
def index():
return "ok"
client = app.test_client()
for _ in range(3):
rv = client.get("/")
assert rv.status_code == 200
def test_rate_limit_blocks_after_threshold():
app = flask.Flask(__name__)
app.config.update(
RATELIMIT_ENABLED=True,
RATELIMIT_REQUESTS=2,
RATELIMIT_WINDOW=60,
)
@app.route("/")
def index():
return "ok"
client = app.test_client()
assert client.get("/").status_code == 200
assert client.get("/").status_code == 200
rv = client.get("/")
assert rv.status_code == 429