diff --git a/scripts/flaskext_migrate.py b/scripts/flaskext_migrate.py index 023508b9..5dcfa664 100644 --- a/scripts/flaskext_migrate.py +++ b/scripts/flaskext_migrate.py @@ -70,18 +70,15 @@ def fix_standard_imports(red): """ Handles import modification in the form: import flask.ext.foo" --> import flask_foo - - Does not modify function calls elsewhere in the source outside of the - original import statement. """ imports = red.find_all("ImportNode") for x, node in enumerate(imports): try: - if (node.value[0].value == 'flask' and - node.value[1].value == 'ext'): - package = node.value[2].value - name = node.names()[0] - imports[x].replace("import flask_%s as %s" % (package, name)) + if (node.value[0].value[0].value == 'flask' and + node.value[0].value[1].value == 'ext'): + package = node.value[0].value[2] + name = node.names()[0].split('.')[-1] + node.replace("import flask_%s as %s" % (package, name)) except IndexError: pass @@ -90,7 +87,7 @@ def fix_standard_imports(red): def _get_modules(module): """ - Takes a list of modules and converts into a string + Takes a list of modules and converts into a string. The module list can include parens, this function checks each element in the list, if there is a paren then it does not add a comma before the next @@ -105,20 +102,46 @@ def _get_modules(module): return ''.join(modules_string) +def fix_function_calls(red): + """ + Modifies function calls in the source to reflect import changes. + + Searches the AST for AtomtrailerNodes and replaces them. + """ + atoms = red.find_all("Atomtrailers") + for x, node in enumerate(atoms): + try: + if (node.value[0].value == 'flask' and + node.value[1].value == 'ext'): + node.replace("flask_foo%s" % node.value[3]) + except IndexError: + pass + + return red + + def check_user_input(): """Exits and gives error message if no argument is passed in the shell.""" if len(sys.argv) < 2: sys.exit("No filename was included, please try again.") -def fix(ast): +def fix_tester(ast): """Wrapper which allows for testing when not running from shell.""" - return fix_imports(ast).dumps() + ast = fix_imports(ast) + ast = fix_function_calls(ast) + return ast.dumps() -if __name__ == "__main__": +def fix(): + """Wrapper for user argument checking and import fixing.""" check_user_input() input_file = sys.argv[1] ast = read_source(input_file) ast = fix_imports(ast) + ast = fix_function_calls(ast) write_source(ast, input_file) + + +if __name__ == "__main__": + fix() diff --git a/scripts/test_import_migration.py b/scripts/test_import_migration.py index 71956749..fa30ef69 100644 --- a/scripts/test_import_migration.py +++ b/scripts/test_import_migration.py @@ -8,19 +8,19 @@ import flaskext_migrate as migrate def test_simple_from_import(): red = RedBaron("from flask.ext import foo") - output = migrate.fix(red) + output = migrate.fix_tester(red) assert output == "import flask_foo as foo" def test_from_to_from_import(): red = RedBaron("from flask.ext.foo import bar") - output = migrate.fix(red) + output = migrate.fix_tester(red) assert output == "from flask_foo import bar as bar" def test_multiple_import(): red = RedBaron("from flask.ext.foo import bar, foobar, something") - output = migrate.fix(red) + output = migrate.fix_tester(red) assert output == "from flask_foo import bar, foobar, something" @@ -29,23 +29,35 @@ def test_multiline_import(): bar,\ foobar,\ something") - output = migrate.fix(red) + output = migrate.fix_tester(red) assert output == "from flask_foo import bar, foobar, something" def test_module_import(): red = RedBaron("import flask.ext.foo") - output = migrate.fix(red) - assert output == "import flask_foo" + output = migrate.fix_tester(red) + assert output == "import flask_foo as foo" -def test_module_import(): +def test_named_module_import(): + red = RedBaron("import flask.ext.foo as foobar") + output = migrate.fix_tester(red) + assert output == "import flask_foo as foobar" + + +def test__named_from_import(): red = RedBaron("from flask.ext.foo import bar as baz") - output = migrate.fix(red) + output = migrate.fix_tester(red) assert output == "from flask_foo import bar as baz" def test_parens_import(): red = RedBaron("from flask.ext.foo import (bar, foo, foobar)") - output = migrate.fix(red) + output = migrate.fix_tester(red) assert output == "from flask_foo import (bar, foo, foobar)" + + +def test_function_call_migration(): + red = RedBaron("flask.ext.foo(var)") + output = migrate.fix_tester(red) + assert output == "flask_foo(var)"