From b759aa2b955cccf7050ece74b84c32089443e47e Mon Sep 17 00:00:00 2001 From: Keyan Pishdadian Date: Thu, 12 Feb 2015 11:58:38 -0500 Subject: [PATCH] Add test for naming module and fix logic to cover --- scripts/flaskext_migrate.py | 39 ++++++++++++++++++++-------------- tests/test_import_migration.py | 11 ++++++++-- 2 files changed, 32 insertions(+), 18 deletions(-) diff --git a/scripts/flaskext_migrate.py b/scripts/flaskext_migrate.py index e23b732b..3d78808f 100644 --- a/scripts/flaskext_migrate.py +++ b/scripts/flaskext_migrate.py @@ -43,21 +43,26 @@ def fix_from_imports(red): Case 2: from flask.ext import foo --> import flask_foo as foo """ from_imports = red.find_all("FromImport") - for x in range(len(from_imports)): - values = from_imports[x].value + for x, node in enumerate(from_imports): + values = node.value if (values[0].value == 'flask') and (values[1].value == 'ext'): # Case 1 - if len(from_imports[x].value) == 3: + if len(node.value) == 3: package = values[2].value - modules = from_imports[x].modules() - r = "{}," * len(modules) - from_imports[x].replace("from flask_%s import %s" - % (package, r.format(*modules)[:-1])) + modules = node.modules() + if len(modules) > 1: + r = "{}," * len(modules) + node.replace("from flask_%s import %s" + % (package, r.format(*modules)[:-1])) + else: + name = node.names()[0] + node.replace("from flask_%s import %s as %s" + % (package, modules.pop(), name)) # Case 2 else: - module = from_imports[x].modules()[0] - from_imports[x].replace("import flask_%s as %s" - % (module, module)) + module = node.modules()[0] + node.replace("import flask_%s as %s" + % (module, module)) return red @@ -70,13 +75,13 @@ def fix_standard_imports(red): original import statement. """ imports = red.find_all("ImportNode") - for x in range(len(imports)): - values = imports[x].value + for x, node in enumerate(imports): try: - if (values[x].value[0].value == 'flask' and - values[x].value[1].value == 'ext'): - package = values[x].value[2].value - imports[x].replace("import flask_%s" % package) + if (node.value[0].value == 'flask' and + node.value[1].value == 'ext'): + package = node.value[2].value + name = node.names()[0] + imports[x].replace("import flask_%s as %s" % (package, name)) except IndexError: pass @@ -88,6 +93,8 @@ def fix(ast): return fix_imports(ast).dumps() if __name__ == "__main__": + if len(sys.argv) < 2: + sys.exit("No filename was included, please try again.") input_file = sys.argv[1] ast = read_source(input_file) ast = fix_imports(ast) diff --git a/tests/test_import_migration.py b/tests/test_import_migration.py index dc662aa7..ddd49142 100644 --- a/tests/test_import_migration.py +++ b/tests/test_import_migration.py @@ -1,7 +1,8 @@ # Tester for the flaskext_migrate.py module located in flask/scripts/ # # Author: Keyan Pishdadian - +import sys +sys.path.append('scripts') import pytest from redbaron import RedBaron import flaskext_migrate as migrate @@ -16,7 +17,7 @@ def test_simple_from_import(): def test_from_to_from_import(): red = RedBaron("from flask.ext.foo import bar") output = migrate.fix(red) - assert output == "from flask_foo import bar" + assert output == "from flask_foo import bar as bar" def test_multiple_import(): @@ -38,3 +39,9 @@ def test_module_import(): red = RedBaron("import flask.ext.foo") output = migrate.fix(red) assert output == "import flask_foo" + + +def test_module_import(): + red = RedBaron("from flask.ext.foo import bar as baz") + output = migrate.fix(red) + assert output == "from flask_foo import bar as baz"