diff --git a/CHANGES.rst b/CHANGES.rst index c0313eb7..f1806fa4 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,5 +1,14 @@ .. currentmodule:: flask +Version 2.0.0 +------------- + +Unreleased + +- Add :meth:`sessions.SessionInterface.get_cookie_name` to allow + setting the session cookie name dynamically. :pr:`3369` + + Version 1.1.2 ------------- diff --git a/src/flask/sessions.py b/src/flask/sessions.py index c57ba29c..fe2a20ee 100644 --- a/src/flask/sessions.py +++ b/src/flask/sessions.py @@ -173,6 +173,13 @@ class SessionInterface(object): """ return isinstance(obj, self.null_session_class) + def get_cookie_name(self, app): + """Returns the name of the session cookie. + + Uses ``app.session_cookie_name`` which is set to ``SESSION_COOKIE_NAME`` + """ + return app.session_cookie_name + def get_cookie_domain(self, app): """Returns the domain that should be set for the session cookie. @@ -340,7 +347,7 @@ class SecureCookieSessionInterface(SessionInterface): s = self.get_signing_serializer(app) if s is None: return None - val = request.cookies.get(app.session_cookie_name) + val = request.cookies.get(self.get_cookie_name(app)) if not val: return self.session_class() max_age = total_seconds(app.permanent_session_lifetime) @@ -351,6 +358,7 @@ class SecureCookieSessionInterface(SessionInterface): return self.session_class() def save_session(self, app, session, response): + name = self.get_cookie_name(app) domain = self.get_cookie_domain(app) path = self.get_cookie_path(app) @@ -358,9 +366,7 @@ class SecureCookieSessionInterface(SessionInterface): # If the session is empty, return without setting the cookie. if not session: if session.modified: - response.delete_cookie( - app.session_cookie_name, domain=domain, path=path - ) + response.delete_cookie(name, domain=domain, path=path) return @@ -377,7 +383,7 @@ class SecureCookieSessionInterface(SessionInterface): expires = self.get_expiration_time(app, session) val = self.get_signing_serializer(app).dumps(dict(session)) response.set_cookie( - app.session_cookie_name, + name, val, expires=expires, httponly=httponly, diff --git a/tests/test_reqctx.py b/tests/test_reqctx.py index 90eae9d8..8f00034b 100644 --- a/tests/test_reqctx.py +++ b/tests/test_reqctx.py @@ -11,6 +11,7 @@ import pytest import flask +from flask.sessions import SecureCookieSessionInterface from flask.sessions import SessionInterface try: @@ -229,6 +230,58 @@ def test_session_error_pops_context(): assert not flask.current_app +def test_session_dynamic_cookie_name(): + + # This session interface will use a cookie with a different name if the + # requested url ends with the string "dynamic_cookie" + class PathAwareSessionInterface(SecureCookieSessionInterface): + def get_cookie_name(self, app): + if flask.request.url.endswith("dynamic_cookie"): + return "dynamic_cookie_name" + else: + return super(PathAwareSessionInterface, self).get_cookie_name(app) + + class CustomFlask(flask.Flask): + session_interface = PathAwareSessionInterface() + + app = CustomFlask(__name__) + app.secret_key = "secret_key" + + @app.route("/set", methods=["POST"]) + def set(): + flask.session["value"] = flask.request.form["value"] + return "value set" + + @app.route("/get") + def get(): + v = flask.session.get("value", "None") + return v + + @app.route("/set_dynamic_cookie", methods=["POST"]) + def set_dynamic_cookie(): + flask.session["value"] = flask.request.form["value"] + return "value set" + + @app.route("/get_dynamic_cookie") + def get_dynamic_cookie(): + v = flask.session.get("value", "None") + return v + + test_client = app.test_client() + + # first set the cookie in both /set urls but each with a different value + assert test_client.post("/set", data={"value": "42"}).data == b"value set" + assert ( + test_client.post("/set_dynamic_cookie", data={"value": "616"}).data + == b"value set" + ) + + # now check that the relevant values come back - meaning that different + # cookies are being used for the urls that end with "dynamic cookie" + assert test_client.get("/get").data == b"42" + assert test_client.get("/get_dynamic_cookie").data == b"616" + + def test_bad_environ_raises_bad_request(): app = flask.Flask(__name__)