diff --git a/flask/testing.py b/flask/testing.py index 782b40f6..bdd3860f 100644 --- a/flask/testing.py +++ b/flask/testing.py @@ -15,6 +15,7 @@ from __future__ import with_statement from contextlib import contextmanager from werkzeug.test import Client, EnvironBuilder from flask import _request_ctx_stack +from urlparse import urlparse def make_test_environ_builder(app, path='/', base_url=None, *args, **kwargs): @@ -22,9 +23,12 @@ def make_test_environ_builder(app, path='/', base_url=None, *args, **kwargs): http_host = app.config.get('SERVER_NAME') app_root = app.config.get('APPLICATION_ROOT') if base_url is None: - base_url = 'http://%s/' % (http_host or 'localhost') + url = urlparse(path) + base_url = 'http://%s/' % (url.netloc or http_host or 'localhost') if app_root: base_url += app_root.lstrip('/') + if url.netloc: + path = url.path return EnvironBuilder(path, base_url, *args, **kwargs) diff --git a/flask/testsuite/testing.py b/flask/testsuite/testing.py index 0e6feb60..92e3f267 100644 --- a/flask/testsuite/testing.py +++ b/flask/testsuite/testing.py @@ -198,7 +198,46 @@ class TestToolsTestCase(FlaskTestCase): self.assert_equal(called, [None, None]) +class SubdomainTestCase(FlaskTestCase): + + def setUp(self): + self.app = flask.Flask(__name__) + self.app.config['SERVER_NAME'] = 'example.com' + self.client = self.app.test_client() + + self._ctx = self.app.test_request_context() + self._ctx.push() + + def tearDown(self): + if self._ctx is not None: + self._ctx.pop() + + def test_subdomain(self): + @self.app.route('/', subdomain='') + def view(company_id): + return company_id + + url = flask.url_for('view', company_id='xxx') + response = self.client.get(url) + + self.assertEquals(200, response.status_code) + self.assertEquals('xxx', response.data) + + + def test_nosubdomain(self): + @self.app.route('/') + def view(company_id): + return company_id + + url = flask.url_for('view', company_id='xxx') + response = self.client.get(url) + + self.assertEquals(200, response.status_code) + self.assertEquals('xxx', response.data) + + def suite(): suite = unittest.TestSuite() suite.addTest(unittest.makeSuite(TestToolsTestCase)) + suite.addTest(unittest.makeSuite(SubdomainTestCase)) return suite