Add built-in CSRF protection using Sec-Fetch-Site header
This commit is contained in:
parent
2579ce9f18
commit
54aaa01eb5
7 changed files with 525 additions and 20 deletions
|
|
@ -3,6 +3,9 @@ Version 3.2.0
|
|||
|
||||
Unreleased
|
||||
|
||||
- Add built-in CSRF protection using ``Sec-Fetch-Site`` header with
|
||||
``Origin`` fallback. Controlled by ``CSRF_PROTECTION`` config and
|
||||
``csrf_protection`` route parameter. :issue:`5863`
|
||||
- Drop support for Python 3.9. :pr:`5730`
|
||||
- Remove previously deprecated code: ``__version__``. :pr:`5648`
|
||||
- ``RequestContext`` has merged with ``AppContext``. ``RequestContext`` is now
|
||||
|
|
|
|||
|
|
@ -394,6 +394,39 @@ The following configuration values are used internally by Flask:
|
|||
responses. This can be overridden per route by altering the
|
||||
``provide_automatic_options`` attribute.
|
||||
|
||||
.. py:data:: CSRF_PROTECTION
|
||||
|
||||
Enable CSRF protection globally for all routes. When enabled, requests
|
||||
using methods in :data:`CSRF_PROTECTED_METHODS` will be validated using
|
||||
the ``Sec-Fetch-Site`` header (with a fallback to ``Origin`` header
|
||||
validation). This can be overridden per route using the ``csrf_protection``
|
||||
parameter on ``@app.route()`` or ``add_url_rule()``.
|
||||
See :ref:`security-csrf`.
|
||||
|
||||
Default: ``False``
|
||||
|
||||
.. versionadded:: 3.2
|
||||
|
||||
.. py:data:: CSRF_TRUSTED_ORIGINS
|
||||
|
||||
A list of origins that are trusted to make cross-origin requests without
|
||||
CSRF validation. Each value should be a full origin including the scheme,
|
||||
such as ``"https://example.com"``.
|
||||
|
||||
Default: ``None``
|
||||
|
||||
.. versionadded:: 3.2
|
||||
|
||||
.. py:data:: CSRF_PROTECTED_METHODS
|
||||
|
||||
A set of HTTP methods that require CSRF validation when
|
||||
:data:`CSRF_PROTECTION` is ``True`` or ``csrf_protection=True`` is set on a
|
||||
route. Safe methods like GET, HEAD, and OPTIONS should not be included.
|
||||
|
||||
Default: ``frozenset({"POST", "PUT", "PATCH", "DELETE"})``
|
||||
|
||||
.. versionadded:: 3.2
|
||||
|
||||
.. versionadded:: 0.4
|
||||
``LOGGER_NAME``
|
||||
|
||||
|
|
|
|||
|
|
@ -100,40 +100,84 @@ which the browser will execute when clicked if not secured properly.
|
|||
|
||||
To prevent this, you'll need to set the :ref:`security-csp` response header.
|
||||
|
||||
.. _security-csrf:
|
||||
|
||||
Cross-Site Request Forgery (CSRF)
|
||||
---------------------------------
|
||||
|
||||
Another big problem is CSRF. This is a very complex topic and I won't
|
||||
outline it here in detail just mention what it is and how to theoretically
|
||||
prevent it.
|
||||
|
||||
If your authentication information is stored in cookies, you have implicit
|
||||
state management. The state of "being logged in" is controlled by a
|
||||
cookie, and that cookie is sent with each request to a page.
|
||||
Unfortunately that includes requests triggered by 3rd party sites. If you
|
||||
don't keep that in mind, some people might be able to trick your
|
||||
application's users with social engineering to do stupid things without
|
||||
them knowing.
|
||||
state management. The state of "being logged in" is controlled by a cookie,
|
||||
and that cookie is sent with each request to a page. Unfortunately that
|
||||
includes requests triggered by 3rd party sites. If you don't keep that in
|
||||
mind, some people might be able to trick your application's users with social
|
||||
engineering to do stupid things without them knowing.
|
||||
|
||||
Say you have a specific URL that, when you sent ``POST`` requests to will
|
||||
delete a user's profile (say ``http://example.com/user/delete``). If an
|
||||
Say you have a specific URL that, when you send ``POST`` requests to will
|
||||
delete a user's profile (say ``http://example.com/user/delete``). If an
|
||||
attacker now creates a page that sends a post request to that page with
|
||||
some JavaScript they just have to trick some users to load that page and
|
||||
their profiles will end up being deleted.
|
||||
|
||||
Imagine you were to run Facebook with millions of concurrent users and
|
||||
someone would send out links to images of little kittens. When users
|
||||
someone would send out links to images of little kittens. When users
|
||||
would go to that page, their profiles would get deleted while they are
|
||||
looking at images of fluffy cats.
|
||||
|
||||
How can you prevent that? Basically for each request that modifies
|
||||
content on the server you would have to either use a one-time token and
|
||||
store that in the cookie **and** also transmit it with the form data.
|
||||
After receiving the data on the server again, you would then have to
|
||||
compare the two tokens and ensure they are equal.
|
||||
Flask provides built-in CSRF protection that can be enabled for state-changing
|
||||
requests (POST, PUT, PATCH, DELETE) using the ``Sec-Fetch-Site`` header that
|
||||
modern browsers send automatically. This header tells the server whether a
|
||||
request is coming from the same origin, the same site, or a cross-site source.
|
||||
|
||||
Enabling CSRF protection
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
CSRF protection is disabled by default. To enable it globally::
|
||||
|
||||
app.config['CSRF_PROTECTION'] = True
|
||||
|
||||
Or enable it per-route::
|
||||
|
||||
@app.route('/delete', methods=['POST'], csrf_protection=True)
|
||||
def delete():
|
||||
...
|
||||
|
||||
How it works
|
||||
~~~~~~~~~~~~
|
||||
|
||||
When enabled, CSRF protection validates requests as follows:
|
||||
|
||||
1. Requests with ``Sec-Fetch-Site: same-origin`` or ``none`` are allowed.
|
||||
2. Requests with ``Sec-Fetch-Site: same-site`` or ``cross-site`` are rejected.
|
||||
3. Requests without browser headers (API clients, curl, etc.) are allowed,
|
||||
as CSRF is exclusively a browser attack vector.
|
||||
4. For browsers that don't send ``Sec-Fetch-Site``, the ``Origin`` header
|
||||
is checked against the request host.
|
||||
|
||||
Allowing cross-origin requests
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
For legitimate cross-origin requests (OAuth callbacks, third-party embeds),
|
||||
add the origin to :data:`CSRF_TRUSTED_ORIGINS`::
|
||||
|
||||
app.config['CSRF_TRUSTED_ORIGINS'] = [
|
||||
'https://accounts.example.com',
|
||||
]
|
||||
|
||||
Exempting specific routes
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
For routes that use other protection mechanisms (webhook signature
|
||||
verification, bearer token authentication), disable CSRF per-route::
|
||||
|
||||
@app.route('/webhooks/stripe', methods=['POST'], csrf_protection=False)
|
||||
def stripe_webhook():
|
||||
# Verify Stripe signature instead
|
||||
...
|
||||
|
||||
For more information on CSRF and the ``Sec-Fetch-Site`` header, see:
|
||||
|
||||
- `MDN: Sec-Fetch-Site <https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Sec-Fetch-Site>`_
|
||||
|
||||
Why does Flask not do that for you? The ideal place for this to happen is
|
||||
the form validation framework, which does not exist in Flask.
|
||||
|
||||
.. _security-json:
|
||||
|
||||
|
|
|
|||
|
|
@ -12,11 +12,13 @@ from inspect import iscoroutinefunction
|
|||
from itertools import chain
|
||||
from types import TracebackType
|
||||
from urllib.parse import quote as _url_quote
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import click
|
||||
from werkzeug.datastructures import Headers
|
||||
from werkzeug.datastructures import ImmutableDict
|
||||
from werkzeug.exceptions import BadRequestKeyError
|
||||
from werkzeug.exceptions import Forbidden
|
||||
from werkzeug.exceptions import HTTPException
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
from werkzeug.routing import BuildError
|
||||
|
|
@ -233,6 +235,9 @@ class Flask(App):
|
|||
"TEMPLATES_AUTO_RELOAD": None,
|
||||
"MAX_COOKIE_SIZE": 4093,
|
||||
"PROVIDE_AUTOMATIC_OPTIONS": True,
|
||||
"CSRF_PROTECTION": False,
|
||||
"CSRF_TRUSTED_ORIGINS": None,
|
||||
"CSRF_PROTECTED_METHODS": frozenset({"POST", "PUT", "PATCH", "DELETE"}),
|
||||
}
|
||||
)
|
||||
|
||||
|
|
@ -962,6 +967,64 @@ class Flask(App):
|
|||
f"Exception on {ctx.request.path} [{ctx.request.method}]", exc_info=exc_info
|
||||
)
|
||||
|
||||
def check_csrf(self, ctx: AppContext) -> None:
|
||||
"""Check CSRF protection for the current request. Raises
|
||||
:exc:`~werkzeug.exceptions.Forbidden` if the request fails
|
||||
CSRF validation.
|
||||
|
||||
This implements fetch metadata-based CSRF protection using the
|
||||
``Sec-Fetch-Site`` header, with a fallback to ``Origin`` header
|
||||
validation for browsers that don't support fetch metadata.
|
||||
|
||||
The check is only performed for methods listed in
|
||||
:data:`CSRF_PROTECTED_METHODS` (POST, PUT, PATCH, DELETE by default)
|
||||
and only when the route has ``csrf_protection=True``.
|
||||
|
||||
.. versionadded:: 3.2
|
||||
"""
|
||||
req = ctx.request
|
||||
|
||||
# Skip if no url_rule matched (routing error will be raised later)
|
||||
if req.url_rule is None:
|
||||
return
|
||||
|
||||
# Skip if csrf protection is not enabled for this route
|
||||
if not getattr(req.url_rule, "csrf_protection", False):
|
||||
return
|
||||
|
||||
# Skip safe methods
|
||||
if req.method not in self.config["CSRF_PROTECTED_METHODS"]:
|
||||
return
|
||||
|
||||
origin = req.headers.get("Origin")
|
||||
|
||||
# Check if origin is in trusted list
|
||||
trusted_origins = self.config["CSRF_TRUSTED_ORIGINS"]
|
||||
if trusted_origins and origin in trusted_origins:
|
||||
return
|
||||
|
||||
# Check Sec-Fetch-Site header (modern browsers)
|
||||
sec_fetch_site = req.headers.get("Sec-Fetch-Site")
|
||||
if sec_fetch_site is not None:
|
||||
# same-origin and none (e.g., direct navigation) are allowed
|
||||
if sec_fetch_site in ("same-origin", "none"):
|
||||
return
|
||||
# same-site and cross-site are rejected
|
||||
raise Forbidden("CSRF validation failed: cross-origin request detected")
|
||||
|
||||
# Fallback for browsers without Sec-Fetch-Site support:
|
||||
# If neither Sec-Fetch-Site nor Origin is present, allow the request
|
||||
# (likely a non-browser client like curl or API tool)
|
||||
if origin is None:
|
||||
return
|
||||
|
||||
# If Origin is present, verify it matches the Host
|
||||
origin_host = urlparse(origin).netloc
|
||||
if origin_host == req.host:
|
||||
return
|
||||
|
||||
raise Forbidden("CSRF validation failed: origin mismatch")
|
||||
|
||||
def dispatch_request(self, ctx: AppContext) -> ft.ResponseReturnValue:
|
||||
"""Does the request dispatching. Matches the URL and returns the
|
||||
return value of the view or error handler. This does not have to
|
||||
|
|
@ -999,6 +1062,7 @@ class Flask(App):
|
|||
|
||||
try:
|
||||
request_started.send(self, _async_wrapper=self.ensure_sync)
|
||||
self.check_csrf(ctx)
|
||||
rv = self.preprocess_request(ctx)
|
||||
if rv is None:
|
||||
rv = self.dispatch_request(ctx)
|
||||
|
|
|
|||
|
|
@ -605,6 +605,7 @@ class App(Scaffold):
|
|||
endpoint: str | None = None,
|
||||
view_func: ft.RouteCallable | None = None,
|
||||
provide_automatic_options: bool | None = None,
|
||||
csrf_protection: bool | None = None,
|
||||
**options: t.Any,
|
||||
) -> None:
|
||||
if endpoint is None:
|
||||
|
|
@ -641,11 +642,19 @@ class App(Scaffold):
|
|||
else:
|
||||
provide_automatic_options = False
|
||||
|
||||
# Handle csrf_protection: check view_func attribute, then config.
|
||||
if csrf_protection is None:
|
||||
csrf_protection = getattr(view_func, "csrf_protection", None)
|
||||
|
||||
if csrf_protection is None:
|
||||
csrf_protection = self.config["CSRF_PROTECTION"]
|
||||
|
||||
# Add the required methods now.
|
||||
methods |= required_methods
|
||||
|
||||
rule_obj = self.url_rule_class(rule, methods=methods, **options)
|
||||
rule_obj.provide_automatic_options = provide_automatic_options # type: ignore[attr-defined]
|
||||
rule_obj.csrf_protection = csrf_protection # type: ignore[attr-defined]
|
||||
|
||||
self.url_map.add(rule_obj)
|
||||
if view_func is not None:
|
||||
|
|
|
|||
|
|
@ -55,6 +55,13 @@ class View:
|
|||
#: ``add_url_rule`` by default.
|
||||
provide_automatic_options: t.ClassVar[bool | None] = None
|
||||
|
||||
#: Control whether CSRF protection is enabled for this view.
|
||||
#: Uses the same default (``CSRF_PROTECTION`` config) as ``route`` and
|
||||
#: ``add_url_rule`` by default.
|
||||
#:
|
||||
#: .. versionadded:: 3.2
|
||||
csrf_protection: t.ClassVar[bool | None] = None
|
||||
|
||||
#: A list of decorators to apply, in order, to the generated view
|
||||
#: function. Remember that ``@decorator`` syntax is applied bottom
|
||||
#: to top, so the first decorator in the list would be the bottom
|
||||
|
|
@ -132,6 +139,7 @@ class View:
|
|||
view.__module__ = cls.__module__
|
||||
view.methods = cls.methods # type: ignore
|
||||
view.provide_automatic_options = cls.provide_automatic_options # type: ignore
|
||||
view.csrf_protection = cls.csrf_protection # type: ignore
|
||||
return view
|
||||
|
||||
|
||||
|
|
|
|||
344
tests/test_csrf.py
Normal file
344
tests/test_csrf.py
Normal file
|
|
@ -0,0 +1,344 @@
|
|||
"""Tests for CSRF protection using Sec-Fetch-Site header."""
|
||||
|
||||
from flask.views import MethodView
|
||||
|
||||
|
||||
class TestCSRFProtection:
|
||||
"""Test CSRF protection functionality."""
|
||||
|
||||
def test_csrf_disabled_by_default(self, app, client):
|
||||
"""CSRF protection is disabled by default."""
|
||||
|
||||
@app.route("/", methods=["POST"])
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
# Cross-origin request should succeed when CSRF is disabled
|
||||
rv = client.post(
|
||||
"/",
|
||||
headers={
|
||||
"Origin": "https://evil.com",
|
||||
"Sec-Fetch-Site": "cross-site",
|
||||
},
|
||||
)
|
||||
assert rv.status_code == 200
|
||||
|
||||
def test_csrf_enabled_via_config(self, app, client):
|
||||
"""CSRF protection can be enabled via CSRF_PROTECTION config."""
|
||||
app.config["CSRF_PROTECTION"] = True
|
||||
|
||||
@app.route("/", methods=["POST"])
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
# Cross-origin request should be rejected
|
||||
rv = client.post(
|
||||
"/",
|
||||
headers={
|
||||
"Origin": "https://evil.com",
|
||||
"Sec-Fetch-Site": "cross-site",
|
||||
},
|
||||
)
|
||||
assert rv.status_code == 403
|
||||
|
||||
def test_csrf_enabled_via_route_param(self, app, client):
|
||||
"""CSRF protection can be enabled per-route."""
|
||||
|
||||
@app.route("/", methods=["POST"], csrf_protection=True)
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
# Cross-origin request should be rejected
|
||||
rv = client.post(
|
||||
"/",
|
||||
headers={
|
||||
"Origin": "https://evil.com",
|
||||
"Sec-Fetch-Site": "cross-site",
|
||||
},
|
||||
)
|
||||
assert rv.status_code == 403
|
||||
|
||||
def test_csrf_disabled_via_route_param_overrides_config(self, app, client):
|
||||
"""Route-level csrf_protection=False overrides CSRF_PROTECTION config."""
|
||||
app.config["CSRF_PROTECTION"] = True
|
||||
|
||||
@app.route("/", methods=["POST"], csrf_protection=False)
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
# Cross-origin request should succeed
|
||||
rv = client.post(
|
||||
"/",
|
||||
headers={
|
||||
"Origin": "https://evil.com",
|
||||
"Sec-Fetch-Site": "cross-site",
|
||||
},
|
||||
)
|
||||
assert rv.status_code == 200
|
||||
|
||||
def test_csrf_allows_same_origin(self, app, client):
|
||||
"""Same-origin requests are allowed."""
|
||||
|
||||
@app.route("/", methods=["POST"], csrf_protection=True)
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
rv = client.post("/", headers={"Sec-Fetch-Site": "same-origin"})
|
||||
assert rv.status_code == 200
|
||||
|
||||
def test_csrf_allows_none_fetch_site(self, app, client):
|
||||
"""Requests with Sec-Fetch-Site: none are allowed (direct navigation)."""
|
||||
|
||||
@app.route("/", methods=["POST"], csrf_protection=True)
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
rv = client.post("/", headers={"Sec-Fetch-Site": "none"})
|
||||
assert rv.status_code == 200
|
||||
|
||||
def test_csrf_rejects_same_site(self, app, client):
|
||||
"""Same-site cross-origin requests are rejected."""
|
||||
|
||||
@app.route("/", methods=["POST"], csrf_protection=True)
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
rv = client.post("/", headers={"Sec-Fetch-Site": "same-site"})
|
||||
assert rv.status_code == 403
|
||||
|
||||
def test_csrf_rejects_cross_site(self, app, client):
|
||||
"""Cross-site requests are rejected."""
|
||||
|
||||
@app.route("/", methods=["POST"], csrf_protection=True)
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
rv = client.post("/", headers={"Sec-Fetch-Site": "cross-site"})
|
||||
assert rv.status_code == 403
|
||||
|
||||
def test_csrf_allows_get_requests(self, app, client):
|
||||
"""GET requests are not protected by CSRF (safe method)."""
|
||||
|
||||
@app.route("/", methods=["GET"], csrf_protection=True)
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
rv = client.get("/", headers={"Sec-Fetch-Site": "cross-site"})
|
||||
assert rv.status_code == 200
|
||||
|
||||
def test_csrf_allows_head_requests(self, app, client):
|
||||
"""HEAD requests are not protected by CSRF (safe method)."""
|
||||
|
||||
@app.route("/", methods=["GET", "HEAD"], csrf_protection=True)
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
rv = client.head("/", headers={"Sec-Fetch-Site": "cross-site"})
|
||||
assert rv.status_code == 200
|
||||
|
||||
def test_csrf_allows_options_requests(self, app, client):
|
||||
"""OPTIONS requests are not protected by CSRF (safe method)."""
|
||||
|
||||
@app.route("/", methods=["POST", "OPTIONS"], csrf_protection=True)
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
rv = client.options("/", headers={"Sec-Fetch-Site": "cross-site"})
|
||||
assert rv.status_code == 200
|
||||
|
||||
def test_csrf_protects_post(self, app, client):
|
||||
"""POST requests are protected."""
|
||||
|
||||
@app.route("/", methods=["POST"], csrf_protection=True)
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
rv = client.post("/", headers={"Sec-Fetch-Site": "cross-site"})
|
||||
assert rv.status_code == 403
|
||||
|
||||
def test_csrf_protects_put(self, app, client):
|
||||
"""PUT requests are protected."""
|
||||
|
||||
@app.route("/", methods=["PUT"], csrf_protection=True)
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
rv = client.put("/", headers={"Sec-Fetch-Site": "cross-site"})
|
||||
assert rv.status_code == 403
|
||||
|
||||
def test_csrf_protects_patch(self, app, client):
|
||||
"""PATCH requests are protected."""
|
||||
|
||||
@app.route("/", methods=["PATCH"], csrf_protection=True)
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
rv = client.patch("/", headers={"Sec-Fetch-Site": "cross-site"})
|
||||
assert rv.status_code == 403
|
||||
|
||||
def test_csrf_protects_delete(self, app, client):
|
||||
"""DELETE requests are protected."""
|
||||
|
||||
@app.route("/", methods=["DELETE"], csrf_protection=True)
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
rv = client.delete("/", headers={"Sec-Fetch-Site": "cross-site"})
|
||||
assert rv.status_code == 403
|
||||
|
||||
|
||||
class TestCSRFTrustedOrigins:
|
||||
"""Test CSRF_TRUSTED_ORIGINS configuration."""
|
||||
|
||||
def test_trusted_origin_allowed(self, app, client):
|
||||
"""Requests from trusted origins are allowed."""
|
||||
app.config["CSRF_TRUSTED_ORIGINS"] = ["https://trusted.com"]
|
||||
|
||||
@app.route("/", methods=["POST"], csrf_protection=True)
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
rv = client.post(
|
||||
"/",
|
||||
headers={
|
||||
"Origin": "https://trusted.com",
|
||||
"Sec-Fetch-Site": "cross-site",
|
||||
},
|
||||
)
|
||||
assert rv.status_code == 200
|
||||
|
||||
def test_untrusted_origin_rejected(self, app, client):
|
||||
"""Requests from untrusted origins are rejected."""
|
||||
app.config["CSRF_TRUSTED_ORIGINS"] = ["https://trusted.com"]
|
||||
|
||||
@app.route("/", methods=["POST"], csrf_protection=True)
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
rv = client.post(
|
||||
"/",
|
||||
headers={
|
||||
"Origin": "https://evil.com",
|
||||
"Sec-Fetch-Site": "cross-site",
|
||||
},
|
||||
)
|
||||
assert rv.status_code == 403
|
||||
|
||||
|
||||
class TestCSRFOriginFallback:
|
||||
"""Test Origin header fallback for browsers without Sec-Fetch-Site."""
|
||||
|
||||
def test_no_headers_allowed(self, app, client):
|
||||
"""Requests without Sec-Fetch-Site or Origin are allowed (non-browser)."""
|
||||
|
||||
@app.route("/", methods=["POST"], csrf_protection=True)
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
rv = client.post("/")
|
||||
assert rv.status_code == 200
|
||||
|
||||
def test_matching_origin_allowed(self, app, client):
|
||||
"""Requests with matching Origin header are allowed."""
|
||||
|
||||
@app.route("/", methods=["POST"], csrf_protection=True)
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
rv = client.post("/", headers={"Origin": "http://localhost"})
|
||||
assert rv.status_code == 200
|
||||
|
||||
def test_mismatched_origin_rejected(self, app, client):
|
||||
"""Requests with mismatched Origin header are rejected."""
|
||||
|
||||
@app.route("/", methods=["POST"], csrf_protection=True)
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
rv = client.post("/", headers={"Origin": "https://evil.com"})
|
||||
assert rv.status_code == 403
|
||||
|
||||
|
||||
class TestCSRFViewFunctionAttribute:
|
||||
"""Test csrf_protection attribute on view functions."""
|
||||
|
||||
def test_view_func_attribute_enables_csrf(self, app, client):
|
||||
"""View function csrf_protection attribute enables CSRF protection."""
|
||||
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
index.csrf_protection = True
|
||||
app.add_url_rule("/", view_func=index, methods=["POST"])
|
||||
|
||||
rv = client.post("/", headers={"Sec-Fetch-Site": "cross-site"})
|
||||
assert rv.status_code == 403
|
||||
|
||||
def test_view_func_attribute_disables_csrf(self, app, client):
|
||||
"""View function csrf_protection attribute disables CSRF protection."""
|
||||
app.config["CSRF_PROTECTION"] = True
|
||||
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
index.csrf_protection = False
|
||||
app.add_url_rule("/", view_func=index, methods=["POST"])
|
||||
|
||||
rv = client.post("/", headers={"Sec-Fetch-Site": "cross-site"})
|
||||
assert rv.status_code == 200
|
||||
|
||||
|
||||
class TestCSRFClassBasedViews:
|
||||
"""Test CSRF protection with class-based views."""
|
||||
|
||||
def test_method_view_csrf_protection(self, app, client):
|
||||
"""MethodView with csrf_protection class attribute."""
|
||||
|
||||
class MyView(MethodView):
|
||||
csrf_protection = True
|
||||
|
||||
def post(self):
|
||||
return "ok"
|
||||
|
||||
app.add_url_rule("/", view_func=MyView.as_view("myview"))
|
||||
|
||||
rv = client.post("/", headers={"Sec-Fetch-Site": "cross-site"})
|
||||
assert rv.status_code == 403
|
||||
|
||||
rv = client.post("/", headers={"Sec-Fetch-Site": "same-origin"})
|
||||
assert rv.status_code == 200
|
||||
|
||||
def test_method_view_csrf_disabled(self, app, client):
|
||||
"""MethodView with csrf_protection=False overrides config."""
|
||||
app.config["CSRF_PROTECTION"] = True
|
||||
|
||||
class MyView(MethodView):
|
||||
csrf_protection = False
|
||||
|
||||
def post(self):
|
||||
return "ok"
|
||||
|
||||
app.add_url_rule("/", view_func=MyView.as_view("myview"))
|
||||
|
||||
rv = client.post("/", headers={"Sec-Fetch-Site": "cross-site"})
|
||||
assert rv.status_code == 200
|
||||
|
||||
|
||||
class TestCSRFProtectedMethods:
|
||||
"""Test CSRF_PROTECTED_METHODS configuration."""
|
||||
|
||||
def test_custom_protected_methods(self, app, client):
|
||||
"""Custom CSRF_PROTECTED_METHODS configuration."""
|
||||
app.config["CSRF_PROTECTED_METHODS"] = frozenset({"POST"})
|
||||
|
||||
@app.route("/", methods=["POST", "DELETE"], csrf_protection=True)
|
||||
def index():
|
||||
return "ok"
|
||||
|
||||
# POST should still be protected
|
||||
rv = client.post("/", headers={"Sec-Fetch-Site": "cross-site"})
|
||||
assert rv.status_code == 403
|
||||
|
||||
# DELETE should not be protected (not in CSRF_PROTECTED_METHODS)
|
||||
rv = client.delete("/", headers={"Sec-Fetch-Site": "cross-site"})
|
||||
assert rv.status_code == 200
|
||||
Loading…
Add table
Add a link
Reference in a new issue