Cleanup in the test finder
This commit is contained in:
parent
5a49688554
commit
a082a5e0ba
1 changed files with 27 additions and 24 deletions
|
|
@ -20,9 +20,6 @@ from contextlib import contextmanager
|
|||
from werkzeug.utils import import_string, find_modules
|
||||
|
||||
|
||||
common_prefix = __name__ + '.'
|
||||
|
||||
|
||||
def add_to_path(path):
|
||||
def _samefile(x, y):
|
||||
try:
|
||||
|
|
@ -50,23 +47,18 @@ def iter_suites():
|
|||
yield mod.suite()
|
||||
|
||||
|
||||
def find_all_tests():
|
||||
suites = [suite()]
|
||||
def find_all_tests(suite):
|
||||
suites = [suite]
|
||||
while suites:
|
||||
s = suites.pop()
|
||||
try:
|
||||
suites.extend(s)
|
||||
except TypeError:
|
||||
yield s
|
||||
|
||||
|
||||
def find_all_tests_with_name():
|
||||
for testcase in find_all_tests():
|
||||
yield testcase, '%s.%s.%s' % (
|
||||
testcase.__class__.__module__,
|
||||
testcase.__class__.__name__,
|
||||
testcase._testMethodName
|
||||
)
|
||||
yield s, '%s.%s.%s' % (
|
||||
s.__class__.__module__,
|
||||
s.__class__.__name__,
|
||||
s._testMethodName
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
|
@ -111,6 +103,10 @@ def emits_module_deprecation_warning(f):
|
|||
|
||||
|
||||
class FlaskTestCase(unittest.TestCase):
|
||||
"""Baseclass for all the tests that Flask uses. Use these methods
|
||||
for testing instead of the camelcased ones in the baseclass for
|
||||
consistency.
|
||||
"""
|
||||
|
||||
def ensure_clean_request_context(self):
|
||||
# make sure we're not leaking a request context since we are
|
||||
|
|
@ -136,20 +132,27 @@ class FlaskTestCase(unittest.TestCase):
|
|||
|
||||
|
||||
class BetterLoader(unittest.TestLoader):
|
||||
"""A nicer loader that solves two problems. First of all we are setting
|
||||
up tests from different sources and we're doing this programmatically
|
||||
which breaks the default loading logic so this is required anyways.
|
||||
Secondly this loader has a nicer interpolation for test names than the
|
||||
default one so you can just do ``run-tests.py ViewTestCase`` and it
|
||||
will work.
|
||||
"""
|
||||
|
||||
def getRootSuite(self):
|
||||
return suite()
|
||||
|
||||
def loadTestsFromName(self, name, module=None):
|
||||
root = self.getRootSuite()
|
||||
if name == 'suite':
|
||||
return suite()
|
||||
for testcase, testname in find_all_tests_with_name():
|
||||
if testname == name:
|
||||
return testcase
|
||||
if testname.startswith(common_prefix):
|
||||
if testname[len(common_prefix):] == name:
|
||||
return testcase
|
||||
return root
|
||||
|
||||
all_tests = []
|
||||
for testcase, testname in find_all_tests_with_name():
|
||||
if testname.endswith('.' + name) or ('.' + name + '.') in testname or \
|
||||
for testcase, testname in find_all_tests(root):
|
||||
if testname == name or \
|
||||
testname.endswith('.' + name) or \
|
||||
('.' + name + '.') in testname or \
|
||||
testname.startswith(name + '.'):
|
||||
all_tests.append(testcase)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue