From 097353695e3178a38403b204ae4889c8a32bf997 Mon Sep 17 00:00:00 2001 From: Armin Ronacher Date: Tue, 14 May 2013 11:00:04 +0100 Subject: [PATCH] Added flask.copy_current_request_context which simplies working with greenlets --- CHANGES | 2 + docs/api.rst | 2 + flask/__init__.py | 2 +- flask/app.py | 6 ++ flask/ctx.py | 77 +++++++++++++++- flask/testsuite/basic.py | 110 +--------------------- flask/testsuite/reqctx.py | 187 ++++++++++++++++++++++++++++++++++++++ 7 files changed, 272 insertions(+), 114 deletions(-) create mode 100644 flask/testsuite/reqctx.py diff --git a/CHANGES b/CHANGES index ea7be9a7..e383b56f 100644 --- a/CHANGES +++ b/CHANGES @@ -52,6 +52,8 @@ Release date to be decided. - Changed logic for picking defaults for cookie values from sessions to work better with Google Chrome. - Added `message_flashed` signal that simplifies flashing testing. +- Added support for copying of request contexts for better working with + greenlets. Version 0.9 ----------- diff --git a/docs/api.rst b/docs/api.rst index 6e08241e..291bfabb 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -280,6 +280,8 @@ Useful Functions and Classes .. autofunction:: has_request_context +.. autofunction:: copy_current_request_context + .. autofunction:: has_app_context .. autofunction:: url_for diff --git a/flask/__init__.py b/flask/__init__.py index 52eec667..978a4a4c 100644 --- a/flask/__init__.py +++ b/flask/__init__.py @@ -26,7 +26,7 @@ from .helpers import url_for, flash, send_file, send_from_directory, \ from .globals import current_app, g, request, session, _request_ctx_stack, \ _app_ctx_stack from .ctx import has_request_context, has_app_context, \ - after_this_request + after_this_request, copy_current_request_context from .module import Module from .blueprints import Blueprint from .templating import render_template, render_template_string diff --git a/flask/app.py b/flask/app.py index 373479f5..fba200e8 100644 --- a/flask/app.py +++ b/flask/app.py @@ -1821,3 +1821,9 @@ class Flask(_PackageBoundObject): def __call__(self, environ, start_response): """Shortcut for :attr:`wsgi_app`.""" return self.wsgi_app(environ, start_response) + + def __repr__(self): + return '<%s %r>' % ( + self.__class__.__name__, + self.name, + ) diff --git a/flask/ctx.py b/flask/ctx.py index a1d9af62..6b271687 100644 --- a/flask/ctx.py +++ b/flask/ctx.py @@ -10,6 +10,7 @@ """ import sys +from functools import update_wrapper from werkzeug.exceptions import HTTPException @@ -19,7 +20,24 @@ from .module import blueprint_is_module class _AppCtxGlobals(object): """A plain object.""" - pass + + def __getitem__(self, name): + try: + return getattr(self, name) + except AttributeError: + return None + + def __setitem__(self, name, value): + setattr(self, name, value) + + def __delitem__(self, name, value): + delattr(self, name, value) + + def __repr__(self): + top = _app_ctx_stack.top + if top is not None: + return '' % top.app.name + return object.__repr__(self) def after_this_request(f): @@ -47,6 +65,41 @@ def after_this_request(f): return f +def copy_current_request_context(f): + """A helper function that decorates a function to retain the current + request context. This is useful when working with greenlets. The moment + the function is decorated a copy of the request context is created and + then pushed when the function is called. + + Example:: + + import gevent + from flask import copy_current_request_context + + @app.route('/') + def index(): + @copy_current_request_context + def do_some_work(): + # do some work here, it can access flask.request like you + # would otherwise in the view function. + ... + gevent.spawn(do_some_work) + return 'Regular response' + + .. versionadded:: 0.10 + """ + top = _request_ctx_stack.top + if top is None: + raise RuntimeError('This decorator can only be used at local scopes ' + 'when a request context is on the stack. For instance within ' + 'view functions.') + reqctx = top.copy() + def wrapper(*args, **kwargs): + with reqctx: + return f(*args, **kwargs) + return update_wrapper(wrapper, f) + + def has_request_context(): """If you have code that wants to test if a request context is there or not this function can be used. For instance, you may want to take advantage @@ -161,9 +214,11 @@ class RequestContext(object): that situation, otherwise your unittests will leak memory. """ - def __init__(self, app, environ): + def __init__(self, app, environ, request=None): self.app = app - self.request = app.request_class(environ) + if request is None: + request = app.request_class(environ) + self.request = request self.url_adapter = app.create_url_adapter(self.request) self.flashes = None self.session = None @@ -202,6 +257,20 @@ class RequestContext(object): g = property(_get_g, _set_g) del _get_g, _set_g + def copy(self): + """Creates a copy of this request context with the same request object. + This can be used to move a request context to a different greenlet. + Because the actual request object is the same this cannot be used to + move a request context to a different thread unless access to the + request object is locked. + + .. versionadded:: 0.10 + """ + return self.__class__(self.app, + environ=self.request.environ, + request=self.request + ) + def match_request(self): """Can be overridden by a subclass to hook into the matching of the request. @@ -299,5 +368,5 @@ class RequestContext(object): self.__class__.__name__, self.request.url, self.request.method, - self.app.name + self.app.name, ) diff --git a/flask/testsuite/basic.py b/flask/testsuite/basic.py index aaf02fce..445b6b41 100644 --- a/flask/testsuite/basic.py +++ b/flask/testsuite/basic.py @@ -666,19 +666,6 @@ class BasicFunctionalityTestCase(FlaskTestCase): else: self.fail('Expected exception') - def test_teardown_on_pop(self): - buffer = [] - app = flask.Flask(__name__) - @app.teardown_request - def end_of_request(exception): - buffer.append(exception) - - ctx = app.test_request_context() - ctx.push() - self.assert_equal(buffer, []) - ctx.pop() - self.assert_equal(buffer, [None]) - def test_response_creation(self): app = flask.Flask(__name__) @app.route('/unicode') @@ -821,53 +808,6 @@ class BasicFunctionalityTestCase(FlaskTestCase): self.assert_equal(repr(flask.g), '') self.assertFalse(flask.g) - def test_proper_test_request_context(self): - app = flask.Flask(__name__) - app.config.update( - SERVER_NAME='localhost.localdomain:5000' - ) - - @app.route('/') - def index(): - return None - - @app.route('/', subdomain='foo') - def sub(): - return None - - with app.test_request_context('/'): - self.assert_equal(flask.url_for('index', _external=True), 'http://localhost.localdomain:5000/') - - with app.test_request_context('/'): - self.assert_equal(flask.url_for('sub', _external=True), 'http://foo.localhost.localdomain:5000/') - - try: - with app.test_request_context('/', environ_overrides={'HTTP_HOST': 'localhost'}): - pass - except Exception, e: - self.assert_(isinstance(e, ValueError)) - self.assert_equal(str(e), "the server name provided " + - "('localhost.localdomain:5000') does not match the " + \ - "server name from the WSGI environment ('localhost')") - - try: - app.config.update(SERVER_NAME='localhost') - with app.test_request_context('/', environ_overrides={'SERVER_NAME': 'localhost'}): - pass - except ValueError, e: - raise ValueError( - "No ValueError exception should have been raised \"%s\"" % e - ) - - try: - app.config.update(SERVER_NAME='localhost:80') - with app.test_request_context('/', environ_overrides={'SERVER_NAME': 'localhost:80'}): - pass - except ValueError, e: - raise ValueError( - "No ValueError exception should have been raised \"%s\"" % e - ) - def test_test_app_proper_environ(self): app = flask.Flask(__name__) app.config.update( @@ -1012,7 +952,7 @@ class BasicFunctionalityTestCase(FlaskTestCase): values = dict() app.inject_url_defaults('foo.bar.baz.view', values) expected = dict(page='login') - self.assert_equal(values, expected) + self.assert_equal(values, expected) with app.test_request_context('/somepage'): url = flask.url_for('foo.bar.baz.view') @@ -1127,53 +1067,6 @@ class BasicFunctionalityTestCase(FlaskTestCase): self.assert_(flask._app_ctx_stack.top is None) -class ContextTestCase(FlaskTestCase): - - def test_context_binding(self): - app = flask.Flask(__name__) - @app.route('/') - def index(): - return 'Hello %s!' % flask.request.args['name'] - @app.route('/meh') - def meh(): - return flask.request.url - - with app.test_request_context('/?name=World'): - self.assert_equal(index(), 'Hello World!') - with app.test_request_context('/meh'): - self.assert_equal(meh(), 'http://localhost/meh') - self.assert_(flask._request_ctx_stack.top is None) - - def test_context_test(self): - app = flask.Flask(__name__) - self.assert_(not flask.request) - self.assert_(not flask.has_request_context()) - ctx = app.test_request_context() - ctx.push() - try: - self.assert_(flask.request) - self.assert_(flask.has_request_context()) - finally: - ctx.pop() - - def test_manual_context_binding(self): - app = flask.Flask(__name__) - @app.route('/') - def index(): - return 'Hello %s!' % flask.request.args['name'] - - ctx = app.test_request_context('/?name=World') - ctx.push() - self.assert_equal(index(), 'Hello World!') - ctx.pop() - try: - index() - except RuntimeError: - pass - else: - self.assert_(0, 'expected runtime error') - - class SubdomainTestCase(FlaskTestCase): def test_basic_support(self): @@ -1251,6 +1144,5 @@ class SubdomainTestCase(FlaskTestCase): def suite(): suite = unittest.TestSuite() suite.addTest(unittest.makeSuite(BasicFunctionalityTestCase)) - suite.addTest(unittest.makeSuite(ContextTestCase)) suite.addTest(unittest.makeSuite(SubdomainTestCase)) return suite diff --git a/flask/testsuite/reqctx.py b/flask/testsuite/reqctx.py new file mode 100644 index 00000000..a93523e7 --- /dev/null +++ b/flask/testsuite/reqctx.py @@ -0,0 +1,187 @@ +# -*- coding: utf-8 -*- +""" + flask.testsuite.reqctx + ~~~~~~~~~~~~~~~~~~~~~~ + + Tests the request context. + + :copyright: (c) 2012 by Armin Ronacher. + :license: BSD, see LICENSE for more details. +""" + +from __future__ import with_statement + +import flask +import unittest +try: + from greenlet import greenlet +except ImportError: + greenlet = None +from flask.testsuite import FlaskTestCase + + +class RequestContextTestCase(FlaskTestCase): + + def test_teardown_on_pop(self): + buffer = [] + app = flask.Flask(__name__) + @app.teardown_request + def end_of_request(exception): + buffer.append(exception) + + ctx = app.test_request_context() + ctx.push() + self.assert_equal(buffer, []) + ctx.pop() + self.assert_equal(buffer, [None]) + + def test_proper_test_request_context(self): + app = flask.Flask(__name__) + app.config.update( + SERVER_NAME='localhost.localdomain:5000' + ) + + @app.route('/') + def index(): + return None + + @app.route('/', subdomain='foo') + def sub(): + return None + + with app.test_request_context('/'): + self.assert_equal(flask.url_for('index', _external=True), 'http://localhost.localdomain:5000/') + + with app.test_request_context('/'): + self.assert_equal(flask.url_for('sub', _external=True), 'http://foo.localhost.localdomain:5000/') + + try: + with app.test_request_context('/', environ_overrides={'HTTP_HOST': 'localhost'}): + pass + except Exception, e: + self.assert_(isinstance(e, ValueError)) + self.assert_equal(str(e), "the server name provided " + + "('localhost.localdomain:5000') does not match the " + \ + "server name from the WSGI environment ('localhost')") + + try: + app.config.update(SERVER_NAME='localhost') + with app.test_request_context('/', environ_overrides={'SERVER_NAME': 'localhost'}): + pass + except ValueError, e: + raise ValueError( + "No ValueError exception should have been raised \"%s\"" % e + ) + + try: + app.config.update(SERVER_NAME='localhost:80') + with app.test_request_context('/', environ_overrides={'SERVER_NAME': 'localhost:80'}): + pass + except ValueError, e: + raise ValueError( + "No ValueError exception should have been raised \"%s\"" % e + ) + + def test_context_binding(self): + app = flask.Flask(__name__) + @app.route('/') + def index(): + return 'Hello %s!' % flask.request.args['name'] + @app.route('/meh') + def meh(): + return flask.request.url + + with app.test_request_context('/?name=World'): + self.assert_equal(index(), 'Hello World!') + with app.test_request_context('/meh'): + self.assert_equal(meh(), 'http://localhost/meh') + self.assert_(flask._request_ctx_stack.top is None) + + def test_context_test(self): + app = flask.Flask(__name__) + self.assert_(not flask.request) + self.assert_(not flask.has_request_context()) + ctx = app.test_request_context() + ctx.push() + try: + self.assert_(flask.request) + self.assert_(flask.has_request_context()) + finally: + ctx.pop() + + def test_manual_context_binding(self): + app = flask.Flask(__name__) + @app.route('/') + def index(): + return 'Hello %s!' % flask.request.args['name'] + + ctx = app.test_request_context('/?name=World') + ctx.push() + self.assert_equal(index(), 'Hello World!') + ctx.pop() + try: + index() + except RuntimeError: + pass + else: + self.assert_(0, 'expected runtime error') + + def test_greenlet_context_copying(self): + app = flask.Flask(__name__) + greenlets = [] + + @app.route('/') + def index(): + reqctx = flask._request_ctx_stack.top.copy() + def g(): + self.assert_(not flask.request) + self.assert_(not flask.current_app) + with reqctx: + self.assert_(flask.request) + self.assert_equal(flask.current_app, app) + self.assert_equal(flask.request.path, '/') + self.assert_equal(flask.request.args['foo'], 'bar') + self.assert_(not flask.request) + return 42 + greenlets.append(greenlet(g)) + return 'Hello World!' + + rv = app.test_client().get('/?foo=bar') + self.assert_equal(rv.data, 'Hello World!') + + result = greenlets[0].run() + self.assert_equal(result, 42) + + def test_greenlet_context_copying_api(self): + app = flask.Flask(__name__) + greenlets = [] + + @app.route('/') + def index(): + reqctx = flask._request_ctx_stack.top.copy() + @flask.copy_current_request_context + def g(): + self.assert_(flask.request) + self.assert_equal(flask.current_app, app) + self.assert_equal(flask.request.path, '/') + self.assert_equal(flask.request.args['foo'], 'bar') + return 42 + greenlets.append(greenlet(g)) + return 'Hello World!' + + rv = app.test_client().get('/?foo=bar') + self.assert_equal(rv.data, 'Hello World!') + + result = greenlets[0].run() + self.assert_equal(result, 42) + + # Disable test if we don't have greenlets available + if greenlet is None: + test_greenlet_context_copying = None + test_greenlet_context_copying_api = None + + +def suite(): + suite = unittest.TestSuite() + suite.addTest(unittest.makeSuite(RequestContextTestCase)) + return suite