converters have access to session

This commit is contained in:
David Lord 2021-05-14 08:11:09 -07:00
parent 8648750997
commit a7b02b3a07
No known key found for this signature in database
GPG key ID: 7A1C87E3F5BC42A8
4 changed files with 28 additions and 16 deletions

View file

@ -20,6 +20,9 @@ Unreleased
the endpoint name. :issue:`4041` the endpoint name. :issue:`4041`
- Combine URL prefixes when nesting blueprints that were created with - Combine URL prefixes when nesting blueprints that were created with
a ``url_prefix`` value. :issue:`4037` a ``url_prefix`` value. :issue:`4037`
- Roll back a change to the order that URL matching was done. The
URL is again matched after the session is loaded, so the session is
available in custom URL converters. :issue:`4053`
Version 2.0.0 Version 2.0.0

View file

@ -395,9 +395,6 @@ class RequestContext:
_request_ctx_stack.push(self) _request_ctx_stack.push(self)
if self.url_adapter is not None:
self.match_request()
# Open the session at the moment that the request context is available. # Open the session at the moment that the request context is available.
# This allows a custom open_session method to use the request context. # This allows a custom open_session method to use the request context.
# Only open a new session if this is the first time the request was # Only open a new session if this is the first time the request was
@ -409,6 +406,11 @@ class RequestContext:
if self.session is None: if self.session is None:
self.session = session_interface.make_null_session(self.app) self.session = session_interface.make_null_session(self.app)
# Match the request URL after loading the session, so that the
# session is available in custom URL converters.
if self.url_adapter is not None:
self.match_request()
def pop(self, exc: t.Optional[BaseException] = _sentinel) -> None: # type: ignore def pop(self, exc: t.Optional[BaseException] = _sentinel) -> None: # type: ignore
"""Pops the request context and unbinds it by doing that. This will """Pops the request context and unbinds it by doing that. This will
also trigger the execution of functions registered by the also trigger the execution of functions registered by the

View file

@ -1,6 +1,7 @@
from werkzeug.routing import BaseConverter from werkzeug.routing import BaseConverter
from flask import has_request_context from flask import request
from flask import session
from flask import url_for from flask import url_for
@ -28,12 +29,13 @@ def test_custom_converters(app, client):
def test_context_available(app, client): def test_context_available(app, client):
class ContextConverter(BaseConverter): class ContextConverter(BaseConverter):
def to_python(self, value): def to_python(self, value):
assert has_request_context() assert request is not None
assert session is not None
return value return value
app.url_map.converters["ctx"] = ContextConverter app.url_map.converters["ctx"] = ContextConverter
@app.route("/<ctx:name>") @app.get("/<ctx:name>")
def index(name): def index(name):
return name return name

View file

@ -2,21 +2,26 @@ import flask
from flask.sessions import SessionInterface from flask.sessions import SessionInterface
def test_open_session_endpoint_not_none(): def test_open_session_with_endpoint():
# Define a session interface that breaks if request.endpoint is None """If request.endpoint (or other URL matching behavior) is needed
while loading the session, RequestContext.match_request() can be
called manually.
"""
class MySessionInterface(SessionInterface): class MySessionInterface(SessionInterface):
def save_session(self): def save_session(self, app, session, response):
pass pass
def open_session(self, _, request): def open_session(self, app, request):
flask._request_ctx_stack.top.match_request()
assert request.endpoint is not None assert request.endpoint is not None
def index():
return "Hello World!"
# Confirm a 200 response, indicating that request.endpoint was NOT None
app = flask.Flask(__name__) app = flask.Flask(__name__)
app.route("/")(index)
app.session_interface = MySessionInterface() app.session_interface = MySessionInterface()
response = app.test_client().open("/")
@app.get("/")
def index():
return "Hello, World!"
response = app.test_client().get("/")
assert response.status_code == 200 assert response.status_code == 200