Add rate limiting functionality to Flask app
- Introduced MemoryRateLimiter for managing request limits. - Added configuration options for enabling rate limiting and setting request limits and time windows. - Implemented methods to enforce rate limits and build rate limit keys based on request context. - Integrated rate limiting checks into the request handling process.
This commit is contained in:
parent
2579ce9f18
commit
20fd9130b7
3 changed files with 153 additions and 0 deletions
|
|
@ -19,6 +19,7 @@ from werkzeug.datastructures import ImmutableDict
|
|||
from werkzeug.exceptions import BadRequestKeyError
|
||||
from werkzeug.exceptions import HTTPException
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
from werkzeug.exceptions import TooManyRequests
|
||||
from werkzeug.routing import BuildError
|
||||
from werkzeug.routing import MapAdapter
|
||||
from werkzeug.routing import RequestRedirect
|
||||
|
|
@ -40,6 +41,7 @@ from .helpers import get_debug_flag
|
|||
from .helpers import get_flashed_messages
|
||||
from .helpers import get_load_dotenv
|
||||
from .helpers import send_from_directory
|
||||
from .rate_limiter import MemoryRateLimiter
|
||||
from .sansio.app import App
|
||||
from .sessions import SecureCookieSessionInterface
|
||||
from .sessions import SessionInterface
|
||||
|
|
@ -233,6 +235,9 @@ class Flask(App):
|
|||
"TEMPLATES_AUTO_RELOAD": None,
|
||||
"MAX_COOKIE_SIZE": 4093,
|
||||
"PROVIDE_AUTOMATIC_OPTIONS": True,
|
||||
"RATELIMIT_ENABLED": False,
|
||||
"RATELIMIT_REQUESTS": 60,
|
||||
"RATELIMIT_WINDOW": 60,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -342,6 +347,9 @@ class Flask(App):
|
|||
# the app's commands to another CLI tool.
|
||||
self.cli.name = self.name
|
||||
|
||||
self._rate_limiter: MemoryRateLimiter | None = None
|
||||
self._rate_limiter_settings: tuple[int, float] | None = None
|
||||
|
||||
# Add a static route using the provided static_url_path, static_host,
|
||||
# and static_folder if there is a configured static_folder.
|
||||
# Note we do this without checking if static_folder exists.
|
||||
|
|
@ -410,6 +418,59 @@ class Flask(App):
|
|||
t.cast(str, self.static_folder), filename, max_age=max_age
|
||||
)
|
||||
|
||||
def _get_rate_limiter(self) -> MemoryRateLimiter | None:
|
||||
if not self.config.get("RATELIMIT_ENABLED"):
|
||||
return None
|
||||
|
||||
limit = self.config.get("RATELIMIT_REQUESTS")
|
||||
window = self.config.get("RATELIMIT_WINDOW")
|
||||
|
||||
if not isinstance(limit, int) or limit <= 0:
|
||||
return None
|
||||
|
||||
if not isinstance(window, (int, float)) or window <= 0:
|
||||
return None
|
||||
|
||||
settings = (limit, float(window))
|
||||
|
||||
if self._rate_limiter is None or self._rate_limiter_settings != settings:
|
||||
self._rate_limiter = MemoryRateLimiter(limit, float(window))
|
||||
self._rate_limiter_settings = settings
|
||||
|
||||
return self._rate_limiter
|
||||
|
||||
def _build_rate_limit_key(self, ctx: AppContext) -> str | None:
|
||||
req = ctx.request
|
||||
|
||||
if req.access_route:
|
||||
return req.access_route[0]
|
||||
|
||||
if req.remote_addr:
|
||||
return req.remote_addr
|
||||
|
||||
host = req.environ.get("REMOTE_ADDR")
|
||||
|
||||
if host is None:
|
||||
host = get_host(req.environ)
|
||||
|
||||
return host
|
||||
|
||||
def _enforce_rate_limit(self, ctx: AppContext) -> None:
|
||||
limiter = self._get_rate_limiter()
|
||||
|
||||
if limiter is None:
|
||||
return
|
||||
|
||||
key = self._build_rate_limit_key(ctx)
|
||||
|
||||
if key is None:
|
||||
return
|
||||
|
||||
allowed, _ = limiter.hit(key)
|
||||
|
||||
if not allowed:
|
||||
raise TooManyRequests(description="Rate limit exceeded. Try again later.")
|
||||
|
||||
def open_resource(
|
||||
self, resource: str, mode: str = "rb", encoding: str | None = None
|
||||
) -> t.IO[t.AnyStr]:
|
||||
|
|
@ -998,6 +1059,7 @@ class Flask(App):
|
|||
self._got_first_request = True
|
||||
|
||||
try:
|
||||
self._enforce_rate_limit(ctx)
|
||||
request_started.send(self, _async_wrapper=self.ensure_sync)
|
||||
rv = self.preprocess_request(ctx)
|
||||
if rv is None:
|
||||
|
|
|
|||
54
src/flask/rate_limiter.py
Normal file
54
src/flask/rate_limiter.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections import defaultdict, deque
|
||||
from threading import Lock
|
||||
from time import monotonic
|
||||
from typing import Deque, DefaultDict, Tuple
|
||||
|
||||
__all__ = ("MemoryRateLimiter",)
|
||||
|
||||
|
||||
class MemoryRateLimiter:
|
||||
"""Simple in-memory, per-key fixed window rate limiter.
|
||||
|
||||
The limiter stores request timestamps in monotonic time buckets. It is
|
||||
designed for small Flask applications and the built-in development server,
|
||||
not as a production-grade distributed limiter.
|
||||
"""
|
||||
|
||||
def __init__(self, limit: int, window: float) -> None:
|
||||
if limit <= 0:
|
||||
raise ValueError("limit must be greater than 0")
|
||||
if window <= 0:
|
||||
raise ValueError("window must be greater than 0 seconds")
|
||||
|
||||
self.limit = limit
|
||||
self.window = window
|
||||
self._requests: DefaultDict[str, Deque[float]] = defaultdict(deque)
|
||||
self._lock = Lock()
|
||||
|
||||
def hit(self, key: str) -> Tuple[bool, float | None]:
|
||||
"""Record a hit for *key*.
|
||||
|
||||
Returns a tuple ``(allowed, retry_after)`` where *allowed* indicates if
|
||||
the request can proceed. When *allowed* is ``False`` the second value is
|
||||
the remaining seconds until a new request would be accepted.
|
||||
"""
|
||||
|
||||
now = monotonic()
|
||||
window_start = now - self.window
|
||||
|
||||
with self._lock:
|
||||
bucket = self._requests[key]
|
||||
|
||||
while bucket and bucket[0] <= window_start:
|
||||
bucket.popleft()
|
||||
|
||||
if len(bucket) >= self.limit:
|
||||
retry_after = max(0.0, self.window - (now - bucket[0]))
|
||||
return False, retry_after
|
||||
|
||||
bucket.append(now)
|
||||
|
||||
return True, None
|
||||
|
||||
37
tests/test_rate_limit.py
Normal file
37
tests/test_rate_limit.py
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
import flask
|
||||
|
||||
|
||||
def test_rate_limit_disabled_by_default():
|
||||
app = flask.Flask(__name__)
|
||||
|
||||
@app.route("/")
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
client = app.test_client()
|
||||
|
||||
for _ in range(3):
|
||||
rv = client.get("/")
|
||||
assert rv.status_code == 200
|
||||
|
||||
|
||||
def test_rate_limit_blocks_after_threshold():
|
||||
app = flask.Flask(__name__)
|
||||
app.config.update(
|
||||
RATELIMIT_ENABLED=True,
|
||||
RATELIMIT_REQUESTS=2,
|
||||
RATELIMIT_WINDOW=60,
|
||||
)
|
||||
|
||||
@app.route("/")
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
client = app.test_client()
|
||||
|
||||
assert client.get("/").status_code == 200
|
||||
assert client.get("/").status_code == 200
|
||||
|
||||
rv = client.get("/")
|
||||
assert rv.status_code == 429
|
||||
|
||||
Loading…
Add table
Add a link
Reference in a new issue