diff --git a/scripts/flaskext_migrate.py b/scripts/flaskext_migrate.py index 3d78808f..023508b9 100644 --- a/scripts/flaskext_migrate.py +++ b/scripts/flaskext_migrate.py @@ -50,14 +50,14 @@ def fix_from_imports(red): if len(node.value) == 3: package = values[2].value modules = node.modules() + module_string = _get_modules(modules) if len(modules) > 1: - r = "{}," * len(modules) node.replace("from flask_%s import %s" - % (package, r.format(*modules)[:-1])) + % (package, module_string)) else: name = node.names()[0] node.replace("from flask_%s import %s as %s" - % (package, modules.pop(), name)) + % (package, module_string, name)) # Case 2 else: module = node.modules()[0] @@ -88,13 +88,36 @@ def fix_standard_imports(red): return red -def fix(ast): - """Wrapper which allows for testing when not running from shell""" - return fix_imports(ast).dumps() +def _get_modules(module): + """ + Takes a list of modules and converts into a string -if __name__ == "__main__": + The module list can include parens, this function checks each element in + the list, if there is a paren then it does not add a comma before the next + element. Otherwise a comma and space is added. This is to preserve module + imports which are multi-line and/or occur within parens. While also not + affecting imports which are not enclosed. + """ + modules_string = [cur + ', ' if cur.isalnum() and next.isalnum() + else cur + for (cur, next) in zip(module, module[1:]+[''])] + + return ''.join(modules_string) + + +def check_user_input(): + """Exits and gives error message if no argument is passed in the shell.""" if len(sys.argv) < 2: sys.exit("No filename was included, please try again.") + + +def fix(ast): + """Wrapper which allows for testing when not running from shell.""" + return fix_imports(ast).dumps() + + +if __name__ == "__main__": + check_user_input() input_file = sys.argv[1] ast = read_source(input_file) ast = fix_imports(ast) diff --git a/scripts/test_import_migration.py b/scripts/test_import_migration.py new file mode 100644 index 00000000..71956749 --- /dev/null +++ b/scripts/test_import_migration.py @@ -0,0 +1,51 @@ +# Tester for the flaskext_migrate.py module located in flask/scripts/ +# +# Author: Keyan Pishdadian +import pytest +from redbaron import RedBaron +import flaskext_migrate as migrate + + +def test_simple_from_import(): + red = RedBaron("from flask.ext import foo") + output = migrate.fix(red) + assert output == "import flask_foo as foo" + + +def test_from_to_from_import(): + red = RedBaron("from flask.ext.foo import bar") + output = migrate.fix(red) + assert output == "from flask_foo import bar as bar" + + +def test_multiple_import(): + red = RedBaron("from flask.ext.foo import bar, foobar, something") + output = migrate.fix(red) + assert output == "from flask_foo import bar, foobar, something" + + +def test_multiline_import(): + red = RedBaron("from flask.ext.foo import \ + bar,\ + foobar,\ + something") + output = migrate.fix(red) + assert output == "from flask_foo import bar, foobar, something" + + +def test_module_import(): + red = RedBaron("import flask.ext.foo") + output = migrate.fix(red) + assert output == "import flask_foo" + + +def test_module_import(): + red = RedBaron("from flask.ext.foo import bar as baz") + output = migrate.fix(red) + assert output == "from flask_foo import bar as baz" + + +def test_parens_import(): + red = RedBaron("from flask.ext.foo import (bar, foo, foobar)") + output = migrate.fix(red) + assert output == "from flask_foo import (bar, foo, foobar)"