flask/tests/test_csrf.py

344 lines
11 KiB
Python

"""Tests for CSRF protection using Sec-Fetch-Site header."""
from flask.views import MethodView
class TestCSRFProtection:
    """Exercise the CSRF on/off switches and the per-HTTP-method behaviour.

    Every test registers a trivial view returning ``"ok"`` on ``/`` and then
    checks only the status code of a single request made with the test client.
    """

    def test_csrf_disabled_by_default(self, app, client):
        """With no CSRF setting at all, a cross-site POST is expected to pass."""

        @app.route("/", methods=["POST"])
        def index():
            return "ok"

        # Nothing enabled CSRF protection, so even a hostile origin gets a 200.
        hostile_headers = {
            "Origin": "https://evil.com",
            "Sec-Fetch-Site": "cross-site",
        }
        response = client.post("/", headers=hostile_headers)
        assert response.status_code == 200

    def test_csrf_enabled_via_config(self, app, client):
        """Setting ``CSRF_PROTECTION`` in the app config turns the check on."""
        app.config["CSRF_PROTECTION"] = True

        @app.route("/", methods=["POST"])
        def index():
            return "ok"

        # The same hostile request must now be refused.
        hostile_headers = {
            "Origin": "https://evil.com",
            "Sec-Fetch-Site": "cross-site",
        }
        response = client.post("/", headers=hostile_headers)
        assert response.status_code == 403

    def test_csrf_enabled_via_route_param(self, app, client):
        """``csrf_protection=True`` on a single route turns the check on."""

        @app.route("/", methods=["POST"], csrf_protection=True)
        def index():
            return "ok"

        # The route opted in, so the cross-origin POST must be refused.
        hostile_headers = {
            "Origin": "https://evil.com",
            "Sec-Fetch-Site": "cross-site",
        }
        response = client.post("/", headers=hostile_headers)
        assert response.status_code == 403

    def test_csrf_disabled_via_route_param_overrides_config(self, app, client):
        """A route-level ``csrf_protection=False`` wins over the app config."""
        app.config["CSRF_PROTECTION"] = True

        @app.route("/", methods=["POST"], csrf_protection=False)
        def index():
            return "ok"

        # The route opted out, so the cross-origin POST goes through.
        hostile_headers = {
            "Origin": "https://evil.com",
            "Sec-Fetch-Site": "cross-site",
        }
        response = client.post("/", headers=hostile_headers)
        assert response.status_code == 200

    def test_csrf_allows_same_origin(self, app, client):
        """``Sec-Fetch-Site: same-origin`` is accepted on a protected route."""

        @app.route("/", methods=["POST"], csrf_protection=True)
        def index():
            return "ok"

        fetch_metadata = {"Sec-Fetch-Site": "same-origin"}
        response = client.post("/", headers=fetch_metadata)
        assert response.status_code == 200

    def test_csrf_allows_none_fetch_site(self, app, client):
        """``Sec-Fetch-Site: none`` (direct navigation) is accepted."""

        @app.route("/", methods=["POST"], csrf_protection=True)
        def index():
            return "ok"

        fetch_metadata = {"Sec-Fetch-Site": "none"}
        response = client.post("/", headers=fetch_metadata)
        assert response.status_code == 200

    def test_csrf_rejects_same_site(self, app, client):
        """``Sec-Fetch-Site: same-site`` (same site, other origin) is refused."""

        @app.route("/", methods=["POST"], csrf_protection=True)
        def index():
            return "ok"

        fetch_metadata = {"Sec-Fetch-Site": "same-site"}
        response = client.post("/", headers=fetch_metadata)
        assert response.status_code == 403

    def test_csrf_rejects_cross_site(self, app, client):
        """``Sec-Fetch-Site: cross-site`` is refused on a protected route."""

        @app.route("/", methods=["POST"], csrf_protection=True)
        def index():
            return "ok"

        fetch_metadata = {"Sec-Fetch-Site": "cross-site"}
        response = client.post("/", headers=fetch_metadata)
        assert response.status_code == 403

    def test_csrf_allows_get_requests(self, app, client):
        """GET is a safe method: a cross-site GET passes despite protection."""

        @app.route("/", methods=["GET"], csrf_protection=True)
        def index():
            return "ok"

        fetch_metadata = {"Sec-Fetch-Site": "cross-site"}
        response = client.get("/", headers=fetch_metadata)
        assert response.status_code == 200

    def test_csrf_allows_head_requests(self, app, client):
        """HEAD is a safe method: a cross-site HEAD passes despite protection."""

        @app.route("/", methods=["GET", "HEAD"], csrf_protection=True)
        def index():
            return "ok"

        fetch_metadata = {"Sec-Fetch-Site": "cross-site"}
        response = client.head("/", headers=fetch_metadata)
        assert response.status_code == 200

    def test_csrf_allows_options_requests(self, app, client):
        """OPTIONS is a safe method: a cross-site OPTIONS passes."""

        @app.route("/", methods=["POST", "OPTIONS"], csrf_protection=True)
        def index():
            return "ok"

        fetch_metadata = {"Sec-Fetch-Site": "cross-site"}
        response = client.options("/", headers=fetch_metadata)
        assert response.status_code == 200

    def test_csrf_protects_post(self, app, client):
        """A cross-site POST to a protected route is refused."""

        @app.route("/", methods=["POST"], csrf_protection=True)
        def index():
            return "ok"

        fetch_metadata = {"Sec-Fetch-Site": "cross-site"}
        response = client.post("/", headers=fetch_metadata)
        assert response.status_code == 403

    def test_csrf_protects_put(self, app, client):
        """A cross-site PUT to a protected route is refused."""

        @app.route("/", methods=["PUT"], csrf_protection=True)
        def index():
            return "ok"

        fetch_metadata = {"Sec-Fetch-Site": "cross-site"}
        response = client.put("/", headers=fetch_metadata)
        assert response.status_code == 403

    def test_csrf_protects_patch(self, app, client):
        """A cross-site PATCH to a protected route is refused."""

        @app.route("/", methods=["PATCH"], csrf_protection=True)
        def index():
            return "ok"

        fetch_metadata = {"Sec-Fetch-Site": "cross-site"}
        response = client.patch("/", headers=fetch_metadata)
        assert response.status_code == 403

    def test_csrf_protects_delete(self, app, client):
        """A cross-site DELETE to a protected route is refused."""

        @app.route("/", methods=["DELETE"], csrf_protection=True)
        def index():
            return "ok"

        fetch_metadata = {"Sec-Fetch-Site": "cross-site"}
        response = client.delete("/", headers=fetch_metadata)
        assert response.status_code == 403
class TestCSRFTrustedOrigins:
    """Check how the ``CSRF_TRUSTED_ORIGINS`` allow-list affects cross-site POSTs."""

    def test_trusted_origin_allowed(self, app, client):
        """A cross-site POST whose Origin is on the allow-list is accepted."""
        app.config["CSRF_TRUSTED_ORIGINS"] = ["https://trusted.com"]

        @app.route("/", methods=["POST"], csrf_protection=True)
        def index():
            return "ok"

        # Origin matches the single allow-listed entry configured above.
        trusted_headers = {
            "Origin": "https://trusted.com",
            "Sec-Fetch-Site": "cross-site",
        }
        response = client.post("/", headers=trusted_headers)
        assert response.status_code == 200

    def test_untrusted_origin_rejected(self, app, client):
        """A cross-site POST whose Origin is not on the allow-list is refused."""
        app.config["CSRF_TRUSTED_ORIGINS"] = ["https://trusted.com"]

        @app.route("/", methods=["POST"], csrf_protection=True)
        def index():
            return "ok"

        # Origin differs from the only allow-listed entry.
        untrusted_headers = {
            "Origin": "https://evil.com",
            "Sec-Fetch-Site": "cross-site",
        }
        response = client.post("/", headers=untrusted_headers)
        assert response.status_code == 403
class TestCSRFOriginFallback:
    """Check the Origin-only path used when ``Sec-Fetch-Site`` is not sent."""

    def test_no_headers_allowed(self, app, client):
        """A POST with neither Sec-Fetch-Site nor Origin (non-browser) passes."""

        @app.route("/", methods=["POST"], csrf_protection=True)
        def index():
            return "ok"

        # Deliberately no headers at all.
        response = client.post("/")
        assert response.status_code == 200

    def test_matching_origin_allowed(self, app, client):
        """A POST whose Origin header matches the request's own origin passes."""

        @app.route("/", methods=["POST"], csrf_protection=True)
        def index():
            return "ok"

        # NOTE(review): presumably "http://localhost" is the test client's
        # default origin -- confirm against the client fixture.
        origin_only = {"Origin": "http://localhost"}
        response = client.post("/", headers=origin_only)
        assert response.status_code == 200

    def test_mismatched_origin_rejected(self, app, client):
        """A POST whose Origin header names a different origin is refused."""

        @app.route("/", methods=["POST"], csrf_protection=True)
        def index():
            return "ok"

        origin_only = {"Origin": "https://evil.com"}
        response = client.post("/", headers=origin_only)
        assert response.status_code == 403
class TestCSRFViewFunctionAttribute:
    """Check that a ``csrf_protection`` attribute set on the view is honoured."""

    def test_view_func_attribute_enables_csrf(self, app, client):
        """``view.csrf_protection = True`` makes a cross-site POST fail."""

        def index():
            return "ok"

        # The flag is attached before the rule is registered.
        index.csrf_protection = True
        app.add_url_rule("/", view_func=index, methods=["POST"])

        fetch_metadata = {"Sec-Fetch-Site": "cross-site"}
        response = client.post("/", headers=fetch_metadata)
        assert response.status_code == 403

    def test_view_func_attribute_disables_csrf(self, app, client):
        """``view.csrf_protection = False`` overrides an enabled app config."""
        app.config["CSRF_PROTECTION"] = True

        def index():
            return "ok"

        # The flag is attached before the rule is registered.
        index.csrf_protection = False
        app.add_url_rule("/", view_func=index, methods=["POST"])

        fetch_metadata = {"Sec-Fetch-Site": "cross-site"}
        response = client.post("/", headers=fetch_metadata)
        assert response.status_code == 200
class TestCSRFClassBasedViews:
    """Check the ``csrf_protection`` class attribute on ``MethodView`` subclasses."""

    def test_method_view_csrf_protection(self, app, client):
        """A ``MethodView`` with ``csrf_protection = True`` is protected."""

        class MyView(MethodView):
            csrf_protection = True

            def post(self):
                return "ok"

        view = MyView.as_view("myview")
        app.add_url_rule("/", view_func=view)

        # Cross-site is refused ...
        cross_site = {"Sec-Fetch-Site": "cross-site"}
        response = client.post("/", headers=cross_site)
        assert response.status_code == 403

        # ... while same-origin still reaches the view.
        same_origin = {"Sec-Fetch-Site": "same-origin"}
        response = client.post("/", headers=same_origin)
        assert response.status_code == 200

    def test_method_view_csrf_disabled(self, app, client):
        """A ``MethodView`` with ``csrf_protection = False`` overrides the config."""
        app.config["CSRF_PROTECTION"] = True

        class MyView(MethodView):
            csrf_protection = False

            def post(self):
                return "ok"

        view = MyView.as_view("myview")
        app.add_url_rule("/", view_func=view)

        cross_site = {"Sec-Fetch-Site": "cross-site"}
        response = client.post("/", headers=cross_site)
        assert response.status_code == 200
class TestCSRFProtectedMethods:
    """Check that ``CSRF_PROTECTED_METHODS`` narrows which methods are checked."""

    def test_custom_protected_methods(self, app, client):
        """Only methods listed in ``CSRF_PROTECTED_METHODS`` are protected."""
        app.config["CSRF_PROTECTED_METHODS"] = frozenset({"POST"})

        @app.route("/", methods=["POST", "DELETE"], csrf_protection=True)
        def index():
            return "ok"

        cross_site = {"Sec-Fetch-Site": "cross-site"}

        # POST is in the configured set, so it is still refused.
        post_response = client.post("/", headers=cross_site)
        assert post_response.status_code == 403

        # DELETE is absent from the configured set, so it goes through.
        delete_response = client.delete("/", headers=cross_site)
        assert delete_response.status_code == 200