From 20fd9130b754a4143d0ef895de38aefee643f519 Mon Sep 17 00:00:00 2001 From: txjas99 Date: Tue, 25 Nov 2025 22:08:03 +0530 Subject: [PATCH] Add rate limiting functionality to Flask app - Introduced MemoryRateLimiter for managing request limits. - Added configuration options for enabling rate limiting and setting request limits and time windows. - Implemented methods to enforce rate limits and build rate limit keys based on request context. - Integrated rate limiting checks into the request handling process. --- src/flask/app.py | 62 +++++++++++++++++++++++++++++++++++++++ src/flask/rate_limiter.py | 54 ++++++++++++++++++++++++++++++++++ tests/test_rate_limit.py | 37 +++++++++++++++++++++++ 3 files changed, 153 insertions(+) create mode 100644 src/flask/rate_limiter.py create mode 100644 tests/test_rate_limit.py diff --git a/src/flask/app.py b/src/flask/app.py index e0c193dc..b620a0fa 100644 --- a/src/flask/app.py +++ b/src/flask/app.py @@ -19,6 +19,7 @@ from werkzeug.datastructures import ImmutableDict from werkzeug.exceptions import BadRequestKeyError from werkzeug.exceptions import HTTPException from werkzeug.exceptions import InternalServerError +from werkzeug.exceptions import TooManyRequests from werkzeug.routing import BuildError from werkzeug.routing import MapAdapter from werkzeug.routing import RequestRedirect @@ -40,6 +41,7 @@ from .helpers import get_debug_flag from .helpers import get_flashed_messages from .helpers import get_load_dotenv from .helpers import send_from_directory +from .rate_limiter import MemoryRateLimiter from .sansio.app import App from .sessions import SecureCookieSessionInterface from .sessions import SessionInterface @@ -233,6 +235,9 @@ class Flask(App): "TEMPLATES_AUTO_RELOAD": None, "MAX_COOKIE_SIZE": 4093, "PROVIDE_AUTOMATIC_OPTIONS": True, + "RATELIMIT_ENABLED": False, + "RATELIMIT_REQUESTS": 60, + "RATELIMIT_WINDOW": 60, } ) @@ -342,6 +347,9 @@ class Flask(App): # the app's commands to another CLI tool. self.cli.name = self.name + self._rate_limiter: MemoryRateLimiter | None = None + self._rate_limiter_settings: tuple[int, float] | None = None + # Add a static route using the provided static_url_path, static_host, # and static_folder if there is a configured static_folder. # Note we do this without checking if static_folder exists. @@ -410,6 +418,59 @@ class Flask(App): t.cast(str, self.static_folder), filename, max_age=max_age ) + def _get_rate_limiter(self) -> MemoryRateLimiter | None: + if not self.config.get("RATELIMIT_ENABLED"): + return None + + limit = self.config.get("RATELIMIT_REQUESTS") + window = self.config.get("RATELIMIT_WINDOW") + + if not isinstance(limit, int) or limit <= 0: + return None + + if not isinstance(window, (int, float)) or window <= 0: + return None + + settings = (limit, float(window)) + + if self._rate_limiter is None or self._rate_limiter_settings != settings: + self._rate_limiter = MemoryRateLimiter(limit, float(window)) + self._rate_limiter_settings = settings + + return self._rate_limiter + + def _build_rate_limit_key(self, ctx: AppContext) -> str | None: + req = ctx.request + + if req.access_route: + return req.access_route[0] + + if req.remote_addr: + return req.remote_addr + + host = req.environ.get("REMOTE_ADDR") + + if host is None: + host = get_host(req.environ) + + return host + + def _enforce_rate_limit(self, ctx: AppContext) -> None: + limiter = self._get_rate_limiter() + + if limiter is None: + return + + key = self._build_rate_limit_key(ctx) + + if key is None: + return + + allowed, _ = limiter.hit(key) + + if not allowed: + raise TooManyRequests(description="Rate limit exceeded. Try again later.") + def open_resource( self, resource: str, mode: str = "rb", encoding: str | None = None ) -> t.IO[t.AnyStr]: @@ -998,6 +1059,7 @@ class Flask(App): self._got_first_request = True try: + self._enforce_rate_limit(ctx) request_started.send(self, _async_wrapper=self.ensure_sync) rv = self.preprocess_request(ctx) if rv is None: diff --git a/src/flask/rate_limiter.py b/src/flask/rate_limiter.py new file mode 100644 index 00000000..fc95b1ad --- /dev/null +++ b/src/flask/rate_limiter.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from collections import defaultdict, deque +from threading import Lock +from time import monotonic +from typing import Deque, DefaultDict, Tuple + +__all__ = ("MemoryRateLimiter",) + + +class MemoryRateLimiter: + """Simple in-memory, per-key fixed window rate limiter. + + The limiter stores request timestamps in monotonic time buckets. It is + designed for small Flask applications and the built-in development server, + not as a production-grade distributed limiter. + """ + + def __init__(self, limit: int, window: float) -> None: + if limit <= 0: + raise ValueError("limit must be greater than 0") + if window <= 0: + raise ValueError("window must be greater than 0 seconds") + + self.limit = limit + self.window = window + self._requests: DefaultDict[str, Deque[float]] = defaultdict(deque) + self._lock = Lock() + + def hit(self, key: str) -> Tuple[bool, float | None]: + """Record a hit for *key*. + + Returns a tuple ``(allowed, retry_after)`` where *allowed* indicates if + the request can proceed. When *allowed* is ``False`` the second value is + the remaining seconds until a new request would be accepted. + """ + + now = monotonic() + window_start = now - self.window + + with self._lock: + bucket = self._requests[key] + + while bucket and bucket[0] <= window_start: + bucket.popleft() + + if len(bucket) >= self.limit: + retry_after = max(0.0, self.window - (now - bucket[0])) + return False, retry_after + + bucket.append(now) + + return True, None + diff --git a/tests/test_rate_limit.py b/tests/test_rate_limit.py new file mode 100644 index 00000000..dd9706b7 --- /dev/null +++ b/tests/test_rate_limit.py @@ -0,0 +1,37 @@ +import flask + + +def test_rate_limit_disabled_by_default(): + app = flask.Flask(__name__) + + @app.route("/") + def index(): + return "ok" + + client = app.test_client() + + for _ in range(3): + rv = client.get("/") + assert rv.status_code == 200 + + +def test_rate_limit_blocks_after_threshold(): + app = flask.Flask(__name__) + app.config.update( + RATELIMIT_ENABLED=True, + RATELIMIT_REQUESTS=2, + RATELIMIT_WINDOW=60, + ) + + @app.route("/") + def index(): + return "ok" + + client = app.test_client() + + assert client.get("/").status_code == 200 + assert client.get("/").status_code == 200 + + rv = client.get("/") + assert rv.status_code == 429 +