diff --git a/examples/blueprintexample/blueprintexample_test.py b/examples/blueprintexample/blueprintexample_test.py deleted file mode 100644 index b8f93414..00000000 --- a/examples/blueprintexample/blueprintexample_test.py +++ /dev/null @@ -1,36 +0,0 @@ -# -*- coding: utf-8 -*- -""" - Blueprint Example Tests - ~~~~~~~~~~~~~~ - - Tests the Blueprint example app -""" -import blueprintexample -import unittest - - -class BlueprintExampleTestCase(unittest.TestCase): - - def setUp(self): - self.app = blueprintexample.app.test_client() - - def test_urls(self): - r = self.app.get('/') - self.assertEquals(r.status_code, 200) - - r = self.app.get('/hello') - self.assertEquals(r.status_code, 200) - - r = self.app.get('/world') - self.assertEquals(r.status_code, 200) - - #second blueprint instance - r = self.app.get('/pages/hello') - self.assertEquals(r.status_code, 200) - - r = self.app.get('/pages/world') - self.assertEquals(r.status_code, 200) - - -if __name__ == '__main__': - unittest.main() diff --git a/examples/blueprintexample/test_blueprintexample.py b/examples/blueprintexample/test_blueprintexample.py new file mode 100644 index 00000000..2f3dd93f --- /dev/null +++ b/examples/blueprintexample/test_blueprintexample.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +""" + Blueprint Example Tests + ~~~~~~~~~~~~~~ + + Tests the Blueprint example app +""" +import pytest + +import blueprintexample + + +@pytest.fixture +def client(): + return blueprintexample.app.test_client() + + +def test_urls(client): + r = client.get('/') + assert r.status_code == 200 + + r = client.get('/hello') + assert r.status_code == 200 + + r = client.get('/world') + assert r.status_code == 200 + + # second blueprint instance + r = client.get('/pages/hello') + assert r.status_code == 200 + + r = client.get('/pages/world') + assert r.status_code == 200 diff --git a/examples/flaskr/flaskr_tests.py b/examples/flaskr/flaskr_tests.py deleted file mode 100644 index b90a7be7..00000000 --- a/examples/flaskr/flaskr_tests.py +++ /dev/null @@ -1,76 +0,0 @@ -# -*- coding: utf-8 -*- -""" - Flaskr Tests - ~~~~~~~~~~~~ - - Tests the Flaskr application. - - :copyright: (c) 2014 by Armin Ronacher. - :license: BSD, see LICENSE for more details. -""" -import os -import flaskr -import unittest -import tempfile - - -class FlaskrTestCase(unittest.TestCase): - - def setUp(self): - """Before each test, set up a blank database""" - self.db_fd, flaskr.app.config['DATABASE'] = tempfile.mkstemp() - flaskr.app.config['TESTING'] = True - self.app = flaskr.app.test_client() - with flaskr.app.app_context(): - flaskr.init_db() - - def tearDown(self): - """Get rid of the database again after each test.""" - os.close(self.db_fd) - os.unlink(flaskr.app.config['DATABASE']) - - def login(self, username, password): - return self.app.post('/login', data=dict( - username=username, - password=password - ), follow_redirects=True) - - def logout(self): - return self.app.get('/logout', follow_redirects=True) - - # testing functions - - def test_empty_db(self): - """Start with a blank database.""" - rv = self.app.get('/') - assert b'No entries here so far' in rv.data - - def test_login_logout(self): - """Make sure login and logout works""" - rv = self.login(flaskr.app.config['USERNAME'], - flaskr.app.config['PASSWORD']) - assert b'You were logged in' in rv.data - rv = self.logout() - assert b'You were logged out' in rv.data - rv = self.login(flaskr.app.config['USERNAME'] + 'x', - flaskr.app.config['PASSWORD']) - assert b'Invalid username' in rv.data - rv = self.login(flaskr.app.config['USERNAME'], - flaskr.app.config['PASSWORD'] + 'x') - assert b'Invalid password' in rv.data - - def test_messages(self): - """Test that messages work""" - self.login(flaskr.app.config['USERNAME'], - flaskr.app.config['PASSWORD']) - rv = self.app.post('/add', data=dict( - title='', - text='HTML allowed here' - ), follow_redirects=True) - assert b'No entries here so far' not in rv.data - assert b'<Hello>' in rv.data - assert b'HTML allowed here' in rv.data - - -if __name__ == '__main__': - unittest.main() diff --git a/examples/flaskr/test_flaskr.py b/examples/flaskr/test_flaskr.py new file mode 100644 index 00000000..07116f81 --- /dev/null +++ b/examples/flaskr/test_flaskr.py @@ -0,0 +1,77 @@ +# -*- coding: utf-8 -*- +""" + Flaskr Tests + ~~~~~~~~~~~~ + + Tests the Flaskr application. + + :copyright: (c) 2014 by Armin Ronacher. + :license: BSD, see LICENSE for more details. +""" + +import pytest + +import os +import flaskr +import tempfile + + +@pytest.fixture +def client(request): + db_fd, flaskr.app.config['DATABASE'] = tempfile.mkstemp() + flaskr.app.config['TESTING'] = True + client = flaskr.app.test_client() + with flaskr.app.app_context(): + flaskr.init_db() + + def teardown(): + os.close(db_fd) + os.unlink(flaskr.app.config['DATABASE']) + request.addfinalizer(teardown) + + return client + + +def login(client, username, password): + return client.post('/login', data=dict( + username=username, + password=password + ), follow_redirects=True) + + +def logout(client): + return client.get('/logout', follow_redirects=True) + + +def test_empty_db(client): + """Start with a blank database.""" + rv = client.get('/') + assert b'No entries here so far' in rv.data + + +def test_login_logout(client): + """Make sure login and logout works""" + rv = login(client, flaskr.app.config['USERNAME'], + flaskr.app.config['PASSWORD']) + assert b'You were logged in' in rv.data + rv = logout(client) + assert b'You were logged out' in rv.data + rv = login(client, flaskr.app.config['USERNAME'] + 'x', + flaskr.app.config['PASSWORD']) + assert b'Invalid username' in rv.data + rv = login(client, flaskr.app.config['USERNAME'], + flaskr.app.config['PASSWORD'] + 'x') + assert b'Invalid password' in rv.data + + +def test_messages(client): + """Test that messages work""" + login(client, flaskr.app.config['USERNAME'], + flaskr.app.config['PASSWORD']) + rv = client.post('/add', data=dict( + title='', + text='HTML allowed here' + ), follow_redirects=True) + assert b'No entries here so far' not in rv.data + assert b'<Hello>' in rv.data + assert b'HTML allowed here' in rv.data diff --git a/examples/minitwit/minitwit_tests.py b/examples/minitwit/minitwit_tests.py deleted file mode 100644 index 0a1a3f67..00000000 --- a/examples/minitwit/minitwit_tests.py +++ /dev/null @@ -1,150 +0,0 @@ -# -*- coding: utf-8 -*- -""" - MiniTwit Tests - ~~~~~~~~~~~~~~ - - Tests the MiniTwit application. - - :copyright: (c) 2014 by Armin Ronacher. - :license: BSD, see LICENSE for more details. -""" -import os -import minitwit -import unittest -import tempfile - - -class MiniTwitTestCase(unittest.TestCase): - - def setUp(self): - """Before each test, set up a blank database""" - self.db_fd, minitwit.app.config['DATABASE'] = tempfile.mkstemp() - self.app = minitwit.app.test_client() - with minitwit.app.app_context(): - minitwit.init_db() - - def tearDown(self): - """Get rid of the database again after each test.""" - os.close(self.db_fd) - os.unlink(minitwit.app.config['DATABASE']) - - # helper functions - - def register(self, username, password, password2=None, email=None): - """Helper function to register a user""" - if password2 is None: - password2 = password - if email is None: - email = username + '@example.com' - return self.app.post('/register', data={ - 'username': username, - 'password': password, - 'password2': password2, - 'email': email, - }, follow_redirects=True) - - def login(self, username, password): - """Helper function to login""" - return self.app.post('/login', data={ - 'username': username, - 'password': password - }, follow_redirects=True) - - def register_and_login(self, username, password): - """Registers and logs in in one go""" - self.register(username, password) - return self.login(username, password) - - def logout(self): - """Helper function to logout""" - return self.app.get('/logout', follow_redirects=True) - - def add_message(self, text): - """Records a message""" - rv = self.app.post('/add_message', data={'text': text}, - follow_redirects=True) - if text: - assert b'Your message was recorded' in rv.data - return rv - - # testing functions - - def test_register(self): - """Make sure registering works""" - rv = self.register('user1', 'default') - assert b'You were successfully registered ' \ - b'and can login now' in rv.data - rv = self.register('user1', 'default') - assert b'The username is already taken' in rv.data - rv = self.register('', 'default') - assert b'You have to enter a username' in rv.data - rv = self.register('meh', '') - assert b'You have to enter a password' in rv.data - rv = self.register('meh', 'x', 'y') - assert b'The two passwords do not match' in rv.data - rv = self.register('meh', 'foo', email='broken') - assert b'You have to enter a valid email address' in rv.data - - def test_login_logout(self): - """Make sure logging in and logging out works""" - rv = self.register_and_login('user1', 'default') - assert b'You were logged in' in rv.data - rv = self.logout() - assert b'You were logged out' in rv.data - rv = self.login('user1', 'wrongpassword') - assert b'Invalid password' in rv.data - rv = self.login('user2', 'wrongpassword') - assert b'Invalid username' in rv.data - - def test_message_recording(self): - """Check if adding messages works""" - self.register_and_login('foo', 'default') - self.add_message('test message 1') - self.add_message('') - rv = self.app.get('/') - assert b'test message 1' in rv.data - assert b'<test message 2>' in rv.data - - def test_timelines(self): - """Make sure that timelines work""" - self.register_and_login('foo', 'default') - self.add_message('the message by foo') - self.logout() - self.register_and_login('bar', 'default') - self.add_message('the message by bar') - rv = self.app.get('/public') - assert b'the message by foo' in rv.data - assert b'the message by bar' in rv.data - - # bar's timeline should just show bar's message - rv = self.app.get('/') - assert b'the message by foo' not in rv.data - assert b'the message by bar' in rv.data - - # now let's follow foo - rv = self.app.get('/foo/follow', follow_redirects=True) - assert b'You are now following "foo"' in rv.data - - # we should now see foo's message - rv = self.app.get('/') - assert b'the message by foo' in rv.data - assert b'the message by bar' in rv.data - - # but on the user's page we only want the user's message - rv = self.app.get('/bar') - assert b'the message by foo' not in rv.data - assert b'the message by bar' in rv.data - rv = self.app.get('/foo') - assert b'the message by foo' in rv.data - assert b'the message by bar' not in rv.data - - # now unfollow and check if that worked - rv = self.app.get('/foo/unfollow', follow_redirects=True) - assert b'You are no longer following "foo"' in rv.data - rv = self.app.get('/') - assert b'the message by foo' not in rv.data - assert b'the message by bar' in rv.data - - -if __name__ == '__main__': - unittest.main() diff --git a/examples/minitwit/test_minitwit.py b/examples/minitwit/test_minitwit.py new file mode 100644 index 00000000..c9345e9d --- /dev/null +++ b/examples/minitwit/test_minitwit.py @@ -0,0 +1,151 @@ +# -*- coding: utf-8 -*- +""" + MiniTwit Tests + ~~~~~~~~~~~~~~ + + Tests the MiniTwit application. + + :copyright: (c) 2014 by Armin Ronacher. + :license: BSD, see LICENSE for more details. +""" +import os +import minitwit +import tempfile +import pytest + + +@pytest.fixture +def client(request): + db_fd, minitwit.app.config['DATABASE'] = tempfile.mkstemp() + client = minitwit.app.test_client() + with minitwit.app.app_context(): + minitwit.init_db() + + def teardown(): + """Get rid of the database again after each test.""" + os.close(db_fd) + os.unlink(minitwit.app.config['DATABASE']) + request.addfinalizer(teardown) + return client + + +def register(client, username, password, password2=None, email=None): + """Helper function to register a user""" + if password2 is None: + password2 = password + if email is None: + email = username + '@example.com' + return client.post('/register', data={ + 'username': username, + 'password': password, + 'password2': password2, + 'email': email, + }, follow_redirects=True) + + +def login(client, username, password): + """Helper function to login""" + return client.post('/login', data={ + 'username': username, + 'password': password + }, follow_redirects=True) + + +def register_and_login(client, username, password): + """Registers and logs in in one go""" + register(client, username, password) + return login(client, username, password) + + +def logout(client): + """Helper function to logout""" + return client.get('/logout', follow_redirects=True) + + +def add_message(client, text): + """Records a message""" + rv = client.post('/add_message', data={'text': text}, + follow_redirects=True) + if text: + assert b'Your message was recorded' in rv.data + return rv + + +def test_register(client): + """Make sure registering works""" + rv = register(client, 'user1', 'default') + assert b'You were successfully registered ' \ + b'and can login now' in rv.data + rv = register(client, 'user1', 'default') + assert b'The username is already taken' in rv.data + rv = register(client, '', 'default') + assert b'You have to enter a username' in rv.data + rv = register(client, 'meh', '') + assert b'You have to enter a password' in rv.data + rv = register(client, 'meh', 'x', 'y') + assert b'The two passwords do not match' in rv.data + rv = register(client, 'meh', 'foo', email='broken') + assert b'You have to enter a valid email address' in rv.data + + +def test_login_logout(client): + """Make sure logging in and logging out works""" + rv = register_and_login(client, 'user1', 'default') + assert b'You were logged in' in rv.data + rv = logout(client) + assert b'You were logged out' in rv.data + rv = login(client, 'user1', 'wrongpassword') + assert b'Invalid password' in rv.data + rv = login(client, 'user2', 'wrongpassword') + assert b'Invalid username' in rv.data + + +def test_message_recording(client): + """Check if adding messages works""" + register_and_login(client, 'foo', 'default') + add_message(client, 'test message 1') + add_message(client, '') + rv = client.get('/') + assert b'test message 1' in rv.data + assert b'<test message 2>' in rv.data + + +def test_timelines(client): + """Make sure that timelines work""" + register_and_login(client, 'foo', 'default') + add_message(client, 'the message by foo') + logout(client) + register_and_login(client, 'bar', 'default') + add_message(client, 'the message by bar') + rv = client.get('/public') + assert b'the message by foo' in rv.data + assert b'the message by bar' in rv.data + + # bar's timeline should just show bar's message + rv = client.get('/') + assert b'the message by foo' not in rv.data + assert b'the message by bar' in rv.data + + # now let's follow foo + rv = client.get('/foo/follow', follow_redirects=True) + assert b'You are now following "foo"' in rv.data + + # we should now see foo's message + rv = client.get('/') + assert b'the message by foo' in rv.data + assert b'the message by bar' in rv.data + + # but on the user's page we only want the user's message + rv = client.get('/bar') + assert b'the message by foo' not in rv.data + assert b'the message by bar' in rv.data + rv = client.get('/foo') + assert b'the message by foo' in rv.data + assert b'the message by bar' not in rv.data + + # now unfollow and check if that worked + rv = client.get('/foo/unfollow', follow_redirects=True) + assert b'You are no longer following "foo"' in rv.data + rv = client.get('/') + assert b'the message by foo' not in rv.data + assert b'the message by bar' in rv.data diff --git a/setup.cfg b/setup.cfg index dbd1f97a..c3b4abda 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,3 +1,6 @@ +[pytest] +norecursedirs= scripts docs + [aliases] release = egg_info -RDb '' diff --git a/tests/__init__.py b/tests/__init__.py index 2d1b979d..9c349854 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -11,16 +11,15 @@ """ from __future__ import print_function +import pytest import os import sys import flask import warnings -import unittest from functools import update_wrapper from contextlib import contextmanager -from werkzeug.utils import import_string, find_modules -from flask._compat import reraise, StringIO +from flask._compat import StringIO def add_to_path(path): @@ -43,29 +42,6 @@ def add_to_path(path): sys.path.insert(0, path) -def iter_suites(): - """Yields all testsuites.""" - for module in find_modules(__name__): - mod = import_string(module) - if hasattr(mod, 'suite'): - yield mod.suite() - - -def find_all_tests(suite): - """Yields all the tests and their names from a given suite.""" - suites = [suite] - while suites: - s = suites.pop() - try: - suites.extend(s) - except TypeError: - yield s, '%s.%s.%s' % ( - s.__class__.__module__, - s.__class__.__name__, - s._testMethodName - ) - - @contextmanager def catch_warnings(): """Catch warnings in a with block in a list""" @@ -76,6 +52,7 @@ def catch_warnings(): warnings.filters = filters[:] old_showwarning = warnings.showwarning log = [] + def showwarning(message, category, filename, lineno, file=None, line=None): log.append(locals()) try: @@ -107,12 +84,23 @@ def emits_module_deprecation_warning(f): return update_wrapper(new_f, f) -class FlaskTestCase(unittest.TestCase): +class TestFlask(object): """Baseclass for all the tests that Flask uses. Use these methods for testing instead of the camelcased ones in the baseclass for consistency. """ + @pytest.fixture(autouse=True) + def setup_path(self, monkeypatch): + monkeypatch.syspath_prepend( + os.path.abspath(os.path.join( + os.path.dirname(__file__), 'test_apps')) + ) + + @pytest.fixture(autouse=True) + def leak_detector(self, request): + request.addfinalizer(self.ensure_clean_request_context) + def ensure_clean_request_context(self): # make sure we're not leaking a request context since we are # testing flask internally in debug mode in a few cases @@ -121,133 +109,42 @@ class FlaskTestCase(unittest.TestCase): leaks.append(flask._request_ctx_stack.pop()) self.assert_equal(leaks, []) + def setup_method(self, method): + self.setup() + + def teardown_method(self, method): + self.teardown() + def setup(self): pass def teardown(self): pass - def setUp(self): - self.setup() - - def tearDown(self): - unittest.TestCase.tearDown(self) - self.ensure_clean_request_context() - self.teardown() - def assert_equal(self, x, y): - return self.assertEqual(x, y) + assert x == y def assert_raises(self, exc_type, callable=None, *args, **kwargs): - catcher = _ExceptionCatcher(self, exc_type) - if callable is None: - return catcher - with catcher: - callable(*args, **kwargs) + if callable: + return pytest.raises(exc_type, callable, *args, **kwargs) + else: + return pytest.raises(exc_type) def assert_true(self, x, msg=None): - self.assertTrue(x, msg) + assert x assert_ = assert_true def assert_false(self, x, msg=None): - self.assertFalse(x, msg) + assert not x def assert_in(self, x, y): - self.assertIn(x, y) + assert x in y def assert_not_in(self, x, y): - self.assertNotIn(x, y) + assert x not in y def assert_isinstance(self, obj, cls): - self.assertIsInstance(obj, cls) + assert isinstance(obj, cls) - if sys.version_info[:2] == (2, 6): - def assertIn(self, x, y): - assert x in y, "%r unexpectedly not in %r" % (x, y) - - def assertNotIn(self, x, y): - assert x not in y, "%r unexpectedly in %r" % (x, y) - - def assertIsInstance(self, x, y): - assert isinstance(x, y), "not isinstance(%r, %r)" % (x, y) - - -class _ExceptionCatcher(object): - - def __init__(self, test_case, exc_type): - self.test_case = test_case - self.exc_type = exc_type - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, tb): - exception_name = self.exc_type.__name__ - if exc_type is None: - self.test_case.fail('Expected exception of type %r' % - exception_name) - elif not issubclass(exc_type, self.exc_type): - reraise(exc_type, exc_value, tb) - return True - - -class BetterLoader(unittest.TestLoader): - """A nicer loader that solves two problems. First of all we are setting - up tests from different sources and we're doing this programmatically - which breaks the default loading logic so this is required anyways. - Secondly this loader has a nicer interpolation for test names than the - default one so you can just do ``run-tests.py ViewTestCase`` and it - will work. - """ - - def getRootSuite(self): - return suite() - - def loadTestsFromName(self, name, module=None): - root = self.getRootSuite() - if name == 'suite': - return root - - all_tests = [] - for testcase, testname in find_all_tests(root): - if testname == name or \ - testname.endswith('.' + name) or \ - ('.' + name + '.') in testname or \ - testname.startswith(name + '.'): - all_tests.append(testcase) - - if not all_tests: - raise LookupError('could not find test case for "%s"' % name) - - if len(all_tests) == 1: - return all_tests[0] - rv = unittest.TestSuite() - for test in all_tests: - rv.addTest(test) - return rv - - -def setup_path(): - add_to_path(os.path.abspath(os.path.join( - os.path.dirname(__file__), 'test_apps'))) - - -def suite(): - """A testsuite that has all the Flask tests. You can use this - function to integrate the Flask tests into your own testsuite - in case you want to test that monkeypatches to Flask do not - break it. - """ - setup_path() - suite = unittest.TestSuite() - for other_suite in iter_suites(): - suite.addTest(other_suite) - return suite - - -def main(): - """Runs the testsuite as command line application.""" - try: - unittest.main(testLoader=BetterLoader(), defaultTest='suite') - except Exception as e: - print('Error: %s' % e) + def fail(self, msg): + raise AssertionError(msg) diff --git a/tests/test_appctx.py b/tests/test_appctx.py index 65e9e86d..2430f577 100644 --- a/tests/test_appctx.py +++ b/tests/test_appctx.py @@ -11,10 +11,10 @@ import flask import unittest -from tests import FlaskTestCase +from tests import TestFlask -class AppContextTestCase(FlaskTestCase): +class TestAppContext(TestFlask): def test_basic_url_generation(self): app = flask.Flask(__name__) @@ -109,10 +109,10 @@ class AppContextTestCase(FlaskTestCase): return u'' c = app.test_client() c.get('/') - self.assertEqual(called, ['request', 'app']) + self.assert_equal(called, ['request', 'app']) def suite(): suite = unittest.TestSuite() - suite.addTest(unittest.makeSuite(AppContextTestCase)) + suite.addTest(unittest.makeSuite(TestAppContext)) return suite diff --git a/tests/test_apps/importerror.py b/tests/test_apps/importerror.py index eb298b9b..3970e3e7 100644 --- a/tests/test_apps/importerror.py +++ b/tests/test_apps/importerror.py @@ -1,2 +1,2 @@ -# NoImportsTestCase +# TestNoImports raise NotImplementedError diff --git a/tests/test_basic.py b/tests/test_basic.py index 1cc20ee0..c06907e6 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -17,14 +17,14 @@ import pickle import unittest from datetime import datetime from threading import Thread -from tests import FlaskTestCase, emits_module_deprecation_warning +from tests import TestFlask, emits_module_deprecation_warning from flask._compat import text_type from werkzeug.exceptions import BadRequest, NotFound, Forbidden from werkzeug.http import parse_date from werkzeug.routing import BuildError -class BasicFunctionalityTestCase(FlaskTestCase): +class TestBasicFunctionality(TestFlask): def test_options_work(self): app = flask.Flask(__name__) @@ -522,8 +522,8 @@ class BasicFunctionalityTestCase(FlaskTestCase): return 'Test' c = app.test_client() resp = c.get('/') - self.assertEqual(resp.status_code, 200) - self.assertEqual(resp.headers['X-Foo'], 'a header') + self.assert_equal(resp.status_code, 200) + self.assert_equal(resp.headers['X-Foo'], 'a header') def test_teardown_request_handler(self): called = [] @@ -840,22 +840,22 @@ class BasicFunctionalityTestCase(FlaskTestCase): with app.test_request_context(): rv = flask.make_response( flask.jsonify({'msg': 'W00t'}), 400) - self.assertEqual(rv.status_code, 400) - self.assertEqual(rv.data, b'{\n "msg": "W00t"\n}') - self.assertEqual(rv.mimetype, 'application/json') + self.assert_equal(rv.status_code, 400) + self.assert_equal(rv.data, b'{\n "msg": "W00t"\n}') + self.assert_equal(rv.mimetype, 'application/json') rv = flask.make_response( flask.Response(''), 400) - self.assertEqual(rv.status_code, 400) - self.assertEqual(rv.data, b'') - self.assertEqual(rv.mimetype, 'text/html') + self.assert_equal(rv.status_code, 400) + self.assert_equal(rv.data, b'') + self.assert_equal(rv.mimetype, 'text/html') rv = flask.make_response( flask.Response('', headers={'Content-Type': 'text/html'}), 400, [('X-Foo', 'bar')]) - self.assertEqual(rv.status_code, 400) - self.assertEqual(rv.headers['Content-Type'], 'text/html') - self.assertEqual(rv.headers['X-Foo'], 'bar') + self.assert_equal(rv.status_code, 400) + self.assert_equal(rv.headers['Content-Type'], 'text/html') + self.assert_equal(rv.headers['X-Foo'], 'bar') def test_url_generation(self): app = flask.Flask(__name__) @@ -872,7 +872,7 @@ class BasicFunctionalityTestCase(FlaskTestCase): # Test base case, a URL which results in a BuildError. with app.test_request_context(): - self.assertRaises(BuildError, flask.url_for, 'spam') + self.assert_raises(BuildError, flask.url_for, 'spam') # Verify the error is re-raised if not the current exception. try: @@ -883,7 +883,7 @@ class BasicFunctionalityTestCase(FlaskTestCase): try: raise RuntimeError('Test case where BuildError is not current.') except RuntimeError: - self.assertRaises(BuildError, app.handle_url_build_error, error, 'spam', {}) + self.assert_raises(BuildError, app.handle_url_build_error, error, 'spam', {}) # Test a custom handler. def handler(error, endpoint, values): @@ -936,7 +936,7 @@ class BasicFunctionalityTestCase(FlaskTestCase): def test_request_locals(self): self.assert_equal(repr(flask.g), '') - self.assertFalse(flask.g) + self.assert_false(flask.g) def test_test_app_proper_environ(self): app = flask.Flask(__name__) @@ -1205,9 +1205,9 @@ class BasicFunctionalityTestCase(FlaskTestCase): assert flask.url_for('123') == '/bar/123' c = app.test_client() - self.assertEqual(c.get('/foo/').data, b'foo') - self.assertEqual(c.get('/bar/').data, b'bar') - self.assertEqual(c.get('/bar/123').data, b'123') + self.assert_equal(c.get('/foo/').data, b'foo') + self.assert_equal(c.get('/bar/').data, b'bar') + self.assert_equal(c.get('/bar/123').data, b'123') def test_preserve_only_once(self): app = flask.Flask(__name__) @@ -1286,7 +1286,7 @@ class BasicFunctionalityTestCase(FlaskTestCase): self.assert_equal(sorted(flask.g), ['bar', 'foo']) -class SubdomainTestCase(FlaskTestCase): +class TestSubdomain(TestFlask): def test_basic_support(self): app = flask.Flask(__name__) @@ -1355,10 +1355,3 @@ class SubdomainTestCase(FlaskTestCase): self.assert_equal(rv.data, b'a') rv = app.test_client().open('/b/') self.assert_equal(rv.data, b'b') - - -def suite(): - suite = unittest.TestSuite() - suite.addTest(unittest.makeSuite(BasicFunctionalityTestCase)) - suite.addTest(unittest.makeSuite(SubdomainTestCase)) - return suite diff --git a/tests/test_blueprints.py b/tests/test_blueprints.py index a5488a57..c8e85b4a 100644 --- a/tests/test_blueprints.py +++ b/tests/test_blueprints.py @@ -11,13 +11,13 @@ import flask import unittest -from tests import FlaskTestCase +from tests import TestFlask from flask._compat import text_type from werkzeug.http import parse_cache_control_header from jinja2 import TemplateNotFound -class BlueprintTestCase(FlaskTestCase): +class TestBlueprint(TestFlask): def test_blueprint_specific_error_handling(self): frontend = flask.Blueprint('frontend', __name__) @@ -303,11 +303,11 @@ class BlueprintTestCase(FlaskTestCase): return flask.request.endpoint c = app.test_client() - self.assertEqual(c.get('/').data, b'index') - self.assertEqual(c.get('/py/foo').data, b'bp.foo') - self.assertEqual(c.get('/py/bar').data, b'bp.bar') - self.assertEqual(c.get('/py/bar/123').data, b'bp.123') - self.assertEqual(c.get('/py/bar/foo').data, b'bp.bar_foo') + self.assert_equal(c.get('/').data, b'index') + self.assert_equal(c.get('/py/foo').data, b'bp.foo') + self.assert_equal(c.get('/py/bar').data, b'bp.bar') + self.assert_equal(c.get('/py/bar/123').data, b'bp.123') + self.assert_equal(c.get('/py/bar/foo').data, b'bp.bar_foo') def test_route_decorator_custom_endpoint_with_dots(self): bp = flask.Blueprint('bp', __name__) @@ -337,14 +337,14 @@ class BlueprintTestCase(FlaskTestCase): def foo_foo_foo(): pass - self.assertRaises( + self.assert_raises( AssertionError, lambda: bp.add_url_rule( '/bar/123', endpoint='bar.123', view_func=foo_foo_foo ) ) - self.assertRaises( + self.assert_raises( AssertionError, bp.route('/bar/123', endpoint='bar.123'), lambda: None @@ -354,7 +354,7 @@ class BlueprintTestCase(FlaskTestCase): app.register_blueprint(bp, url_prefix='/py') c = app.test_client() - self.assertEqual(c.get('/py/foo').data, b'bp.foo') + self.assert_equal(c.get('/py/foo').data, b'bp.foo') # The rule's didn't actually made it through rv = c.get('/py/bar') assert rv.status_code == 404 @@ -581,5 +581,5 @@ class BlueprintTestCase(FlaskTestCase): def suite(): suite = unittest.TestSuite() - suite.addTest(unittest.makeSuite(BlueprintTestCase)) + suite.addTest(unittest.makeSuite(TestBlueprint)) return suite diff --git a/tests/test_config.py b/tests/test_config.py index 4f9e693f..99e2f45e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -15,21 +15,21 @@ import flask import pkgutil import unittest from contextlib import contextmanager -from tests import FlaskTestCase +from tests import TestFlask from flask._compat import PY2 -# config keys used for the ConfigTestCase +# config keys used for the TestConfig TEST_KEY = 'foo' SECRET_KEY = 'devkey' -class ConfigTestCase(FlaskTestCase): +class TestConfig(TestFlask): def common_object_test(self, app): self.assert_equal(app.secret_key, 'devkey') self.assert_equal(app.config['TEST_KEY'], 'foo') - self.assert_not_in('ConfigTestCase', app.config) + self.assert_not_in('TestConfig', app.config) def test_config_from_file(self): app = flask.Flask(__name__) @@ -117,7 +117,7 @@ class ConfigTestCase(FlaskTestCase): self.assert_true(msg.endswith("missing.cfg'")) else: self.fail('expected IOError') - self.assertFalse(app.config.from_envvar('FOO_SETTINGS', silent=True)) + self.assert_false(app.config.from_envvar('FOO_SETTINGS', silent=True)) finally: os.environ = env @@ -207,7 +207,7 @@ def patch_pkgutil_get_loader(wrapper_class=LimitedLoaderMockWrapper): pkgutil.get_loader = old_get_loader -class InstanceTestCase(FlaskTestCase): +class TestInstance(TestFlask): def test_explicit_instance_paths(self): here = os.path.abspath(os.path.dirname(__file__)) @@ -379,6 +379,6 @@ class InstanceTestCase(FlaskTestCase): def suite(): suite = unittest.TestSuite() - suite.addTest(unittest.makeSuite(ConfigTestCase)) - suite.addTest(unittest.makeSuite(InstanceTestCase)) + suite.addTest(unittest.makeSuite(TestConfig)) + suite.addTest(unittest.makeSuite(TestInstance)) return suite diff --git a/tests/test_deprecations.py b/tests/test_deprecations.py index 7353af5a..81fba63b 100644 --- a/tests/test_deprecations.py +++ b/tests/test_deprecations.py @@ -11,14 +11,14 @@ import flask import unittest -from tests import FlaskTestCase, catch_warnings +from tests import TestFlask, catch_warnings -class DeprecationsTestCase(FlaskTestCase): +class TestDeprecations(TestFlask): """not used currently""" def suite(): suite = unittest.TestSuite() - suite.addTest(unittest.makeSuite(DeprecationsTestCase)) + suite.addTest(unittest.makeSuite(TestDeprecations)) return suite diff --git a/tests/test_examples.py b/tests/test_examples.py index c321a141..e4f66325 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -24,15 +24,15 @@ def suite(): setup_path() suite = unittest.TestSuite() try: - from minitwit_tests import MiniTwitTestCase + from minitwit_tests import TestMiniTwit except ImportError: pass else: - suite.addTest(unittest.makeSuite(MiniTwitTestCase)) + suite.addTest(unittest.makeSuite(TestMiniTwit)) try: - from flaskr_tests import FlaskrTestCase + from flaskr_tests import TestFlaskr except ImportError: pass else: - suite.addTest(unittest.makeSuite(FlaskrTestCase)) + suite.addTest(unittest.makeSuite(TestFlaskr)) return suite diff --git a/tests/test_ext.py b/tests/test_ext.py index 110657d8..e48563bf 100644 --- a/tests/test_ext.py +++ b/tests/test_ext.py @@ -15,10 +15,10 @@ try: from imp import reload as reload_module except ImportError: reload_module = reload -from tests import FlaskTestCase +from tests import TestFlask from flask._compat import PY2 -class ExtImportHookTestCase(FlaskTestCase): +class TestExtImportHook(TestFlask): def setup(self): # we clear this out for various reasons. The most important one is @@ -132,5 +132,5 @@ class ExtImportHookTestCase(FlaskTestCase): def suite(): suite = unittest.TestSuite() - suite.addTest(unittest.makeSuite(ExtImportHookTestCase)) + suite.addTest(unittest.makeSuite(TestExtImportHook)) return suite diff --git a/tests/test_helpers.py b/tests/test_helpers.py index c923cae1..e5055fa6 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -13,7 +13,7 @@ import os import flask import unittest from logging import StreamHandler -from tests import FlaskTestCase, catch_warnings, catch_stderr +from tests import TestFlask, catch_warnings, catch_stderr from werkzeug.http import parse_cache_control_header, parse_options_header from flask._compat import StringIO, text_type @@ -27,7 +27,7 @@ def has_encoding(name): return False -class JSONTestCase(FlaskTestCase): +class TestJSON(TestFlask): def test_json_bad_requests(self): app = flask.Flask(__name__) @@ -148,7 +148,7 @@ class JSONTestCase(FlaskTestCase): rv = c.post('/', data=flask.json.dumps({ 'x': {'_foo': 42} }), content_type='application/json') - self.assertEqual(rv.data, b'"<42>"') + self.assert_equal(rv.data, b'"<42>"') def test_modified_url_encoding(self): class ModifiedRequest(flask.Request): @@ -240,7 +240,7 @@ class JSONTestCase(FlaskTestCase): except AssertionError: self.assert_equal(lines, sorted_by_str) -class SendfileTestCase(FlaskTestCase): +class TestSendfile(TestFlask): def test_send_file_regular(self): app = flask.Flask(__name__) @@ -422,7 +422,7 @@ class SendfileTestCase(FlaskTestCase): rv.close() -class LoggingTestCase(FlaskTestCase): +class TestLogging(TestFlask): def test_logger_cache(self): app = flask.Flask(__name__) @@ -450,7 +450,7 @@ class LoggingTestCase(FlaskTestCase): with catch_stderr() as err: c.get('/') out = err.getvalue() - self.assert_in('WARNING in helpers [', out) + self.assert_in('WARNING in test_helpers [', out) self.assert_in(os.path.basename(__file__.rsplit('.', 1)[0] + '.py'), out) self.assert_in('the standard library is dead', out) self.assert_in('this is a debug statement', out) @@ -572,7 +572,7 @@ class LoggingTestCase(FlaskTestCase): '/myview/create') -class NoImportsTestCase(FlaskTestCase): +class TestNoImports(TestFlask): """Test Flasks are created without import. Avoiding ``__import__`` helps create Flask instances where there are errors @@ -590,7 +590,7 @@ class NoImportsTestCase(FlaskTestCase): self.fail('Flask(import_name) is importing import_name.') -class StreamingTestCase(FlaskTestCase): +class TestStreaming(TestFlask): def test_streaming_with_context(self): app = flask.Flask(__name__) @@ -604,7 +604,7 @@ class StreamingTestCase(FlaskTestCase): return flask.Response(flask.stream_with_context(generate())) c = app.test_client() rv = c.get('/?name=World') - self.assertEqual(rv.data, b'Hello World!') + self.assert_equal(rv.data, b'Hello World!') def test_streaming_with_context_as_decorator(self): app = flask.Flask(__name__) @@ -619,7 +619,7 @@ class StreamingTestCase(FlaskTestCase): return flask.Response(generate()) c = app.test_client() rv = c.get('/?name=World') - self.assertEqual(rv.data, b'Hello World!') + self.assert_equal(rv.data, b'Hello World!') def test_streaming_with_context_and_custom_close(self): app = flask.Flask(__name__) @@ -645,16 +645,16 @@ class StreamingTestCase(FlaskTestCase): Wrapper(generate()))) c = app.test_client() rv = c.get('/?name=World') - self.assertEqual(rv.data, b'Hello World!') - self.assertEqual(called, [42]) + self.assert_equal(rv.data, b'Hello World!') + self.assert_equal(called, [42]) def suite(): suite = unittest.TestSuite() if flask.json_available: - suite.addTest(unittest.makeSuite(JSONTestCase)) - suite.addTest(unittest.makeSuite(SendfileTestCase)) - suite.addTest(unittest.makeSuite(LoggingTestCase)) - suite.addTest(unittest.makeSuite(NoImportsTestCase)) - suite.addTest(unittest.makeSuite(StreamingTestCase)) + suite.addTest(unittest.makeSuite(TestJSON)) + suite.addTest(unittest.makeSuite(TestSendfile)) + suite.addTest(unittest.makeSuite(TestLogging)) + suite.addTest(unittest.makeSuite(TestNoImports)) + suite.addTest(unittest.makeSuite(TestStreaming)) return suite diff --git a/tests/test_regression.py b/tests/test_regression.py index 8bdbbf6d..df766246 100644 --- a/tests/test_regression.py +++ b/tests/test_regression.py @@ -9,14 +9,15 @@ :license: BSD, see LICENSE for more details. """ +import pytest + import os import gc import sys import flask import threading -import unittest from werkzeug.exceptions import NotFound -from tests import FlaskTestCase +from tests import TestFlask _gc_lock = threading.Lock() @@ -51,13 +52,16 @@ class _NoLeakAsserter(object): gc.enable() -class MemoryTestCase(FlaskTestCase): +@pytest.mark.skipif(os.environ.get('RUN_FLASK_MEMORY_TESTS') != '1', + reason='Turned off due to envvar.') +class TestMemory(TestFlask): def assert_no_leak(self): return _NoLeakAsserter(self) def test_memory_consumption(self): app = flask.Flask(__name__) + @app.route('/') def index(): return flask.render_template('simple_template.html', whiskey=42) @@ -84,33 +88,28 @@ class MemoryTestCase(FlaskTestCase): safe_join('/foo', '..') -class ExceptionTestCase(FlaskTestCase): +class TestException(TestFlask): def test_aborting(self): class Foo(Exception): whatever = 42 app = flask.Flask(__name__) app.testing = True + @app.errorhandler(Foo) def handle_foo(e): return str(e.whatever) + @app.route('/') def index(): raise flask.abort(flask.redirect(flask.url_for('test'))) + @app.route('/test') def test(): raise Foo() with app.test_client() as c: rv = c.get('/') - self.assertEqual(rv.headers['Location'], 'http://localhost/test') + self.assert_equal(rv.headers['Location'], 'http://localhost/test') rv = c.get('/test') - self.assertEqual(rv.data, b'42') - - -def suite(): - suite = unittest.TestSuite() - if os.environ.get('RUN_FLASK_MEMORY_TESTS') == '1': - suite.addTest(unittest.makeSuite(MemoryTestCase)) - suite.addTest(unittest.makeSuite(ExceptionTestCase)) - return suite + self.assert_equal(rv.data, b'42') diff --git a/tests/test_reqctx.py b/tests/test_reqctx.py index 4ea04098..6787d90a 100644 --- a/tests/test_reqctx.py +++ b/tests/test_reqctx.py @@ -15,10 +15,10 @@ try: from greenlet import greenlet except ImportError: greenlet = None -from tests import FlaskTestCase +from tests import TestFlask -class RequestContextTestCase(FlaskTestCase): +class TestRequestContext(TestFlask): def test_teardown_on_pop(self): buffer = [] @@ -197,5 +197,5 @@ class RequestContextTestCase(FlaskTestCase): def suite(): suite = unittest.TestSuite() - suite.addTest(unittest.makeSuite(RequestContextTestCase)) + suite.addTest(unittest.makeSuite(TestRequestContext)) return suite diff --git a/tests/test_signals.py b/tests/test_signals.py index d05b5b27..db3287e5 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -11,10 +11,10 @@ import flask import unittest -from tests import FlaskTestCase +from tests import TestFlask -class SignalsTestCase(FlaskTestCase): +class TestSignals(TestFlask): def test_template_rendered(self): app = flask.Flask(__name__) @@ -149,5 +149,5 @@ class SignalsTestCase(FlaskTestCase): def suite(): suite = unittest.TestSuite() if flask.signals_available: - suite.addTest(unittest.makeSuite(SignalsTestCase)) + suite.addTest(unittest.makeSuite(TestSignals)) return suite diff --git a/tests/test_subclassing.py b/tests/test_subclassing.py index 41e587e7..b5c85dfd 100644 --- a/tests/test_subclassing.py +++ b/tests/test_subclassing.py @@ -12,11 +12,11 @@ import flask import unittest from logging import StreamHandler -from tests import FlaskTestCase +from tests import TestFlask from flask._compat import StringIO -class FlaskSubclassingTestCase(FlaskTestCase): +class TestFlaskSubclassing(TestFlask): def test_suppressed_exception_logging(self): class SuppressedFlask(flask.Flask): @@ -42,5 +42,5 @@ class FlaskSubclassingTestCase(FlaskTestCase): def suite(): suite = unittest.TestSuite() - suite.addTest(unittest.makeSuite(FlaskSubclassingTestCase)) + suite.addTest(unittest.makeSuite(TestFlaskSubclassing)) return suite diff --git a/tests/test_templating.py b/tests/test_templating.py index 5ce67d97..9727d21d 100644 --- a/tests/test_templating.py +++ b/tests/test_templating.py @@ -14,10 +14,10 @@ import unittest import logging from jinja2 import TemplateNotFound -from tests import FlaskTestCase +from tests import TestFlask -class TemplatingTestCase(FlaskTestCase): +class TestTemplating(TestFlask): def test_context_processing(self): app = flask.Flask(__name__) @@ -348,5 +348,5 @@ class TemplatingTestCase(FlaskTestCase): def suite(): suite = unittest.TestSuite() - suite.addTest(unittest.makeSuite(TemplatingTestCase)) + suite.addTest(unittest.makeSuite(TestTemplating)) return suite diff --git a/tests/test_testing.py b/tests/test_testing.py index 3a51c59f..257cf7d0 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -8,14 +8,15 @@ :copyright: (c) 2014 by Armin Ronacher. :license: BSD, see LICENSE for more details. """ +import pytest import flask import unittest -from tests import FlaskTestCase +from tests import TestFlask from flask._compat import text_type -class TestToolsTestCase(FlaskTestCase): +class TestTestTools(TestFlask): def test_environ_defaults_from_config(self): app = flask.Flask(__name__) @@ -212,46 +213,45 @@ class TestToolsTestCase(FlaskTestCase): self.assert_true('vodka' in flask.request.args) -class SubdomainTestCase(FlaskTestCase): +class TestSubdomain(TestFlask): - def setUp(self): - self.app = flask.Flask(__name__) - self.app.config['SERVER_NAME'] = 'example.com' - self.client = self.app.test_client() + @pytest.fixture + def app(self, request): + app = flask.Flask(__name__) + app.config['SERVER_NAME'] = 'example.com' - self._ctx = self.app.test_request_context() - self._ctx.push() + ctx = app.test_request_context() + ctx.push() - def tearDown(self): - if self._ctx is not None: - self._ctx.pop() + def teardown(): + if ctx is not None: + ctx.pop() + request.addfinalizer(teardown) + return app - def test_subdomain(self): - @self.app.route('/', subdomain='') + @pytest.fixture + def client(self, app): + return app.test_client() + + def test_subdomain(self, app, client): + @app.route('/', subdomain='') def view(company_id): return company_id url = flask.url_for('view', company_id='xxx') - response = self.client.get(url) + response = client.get(url) self.assert_equal(200, response.status_code) self.assert_equal(b'xxx', response.data) - def test_nosubdomain(self): - @self.app.route('/') + def test_nosubdomain(self, app, client): + @app.route('/') def view(company_id): return company_id url = flask.url_for('view', company_id='xxx') - response = self.client.get(url) + response = client.get(url) self.assert_equal(200, response.status_code) self.assert_equal(b'xxx', response.data) - - -def suite(): - suite = unittest.TestSuite() - suite.addTest(unittest.makeSuite(TestToolsTestCase)) - suite.addTest(unittest.makeSuite(SubdomainTestCase)) - return suite diff --git a/tests/test_views.py b/tests/test_views.py index f98f920b..a98bc2f0 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -12,10 +12,10 @@ import flask import flask.views import unittest -from tests import FlaskTestCase +from tests import TestFlask from werkzeug.http import parse_set_header -class ViewTestCase(FlaskTestCase): +class TestView(TestFlask): def common_test(self, app): c = app.test_client() @@ -165,5 +165,5 @@ class ViewTestCase(FlaskTestCase): def suite(): suite = unittest.TestSuite() - suite.addTest(unittest.makeSuite(ViewTestCase)) + suite.addTest(unittest.makeSuite(TestView)) return suite