diff --git a/scripts/flaskext_migrate.py b/scripts/flaskext_migrate.py new file mode 100644 index 00000000..e24ab95c --- /dev/null +++ b/scripts/flaskext_migrate.py @@ -0,0 +1,160 @@ +# Script which modifies source code away from the deprecated "flask.ext" +# format. +# +# Run in the terminal by typing: `python flaskext_migrate.py ` +# +# Author: Keyan Pishdadian 2015 + +from redbaron import RedBaron +import sys + + +def read_source(input_file): + """Parses the input_file into a RedBaron FST.""" + with open(input_file, "r") as source_code: + red = RedBaron(source_code.read()) + return red + + +def write_source(red, input_file): + """Overwrites the input_file once the FST has been modified.""" + with open(input_file, "w") as source_code: + source_code.write(red.dumps()) + + +def fix_imports(red): + """Wrapper which fixes "from" style imports and then "import" style.""" + red = fix_standard_imports(red) + red = fix_from_imports(red) + return red + + +def fix_from_imports(red): + """ + Converts "from" style imports to not use "flask.ext". + + Handles: + Case 1: from flask.ext.foo import bam --> from flask_foo import bam + Case 2: from flask.ext import foo --> import flask_foo as foo + """ + from_imports = red.find_all("FromImport") + for x, node in enumerate(from_imports): + values = node.value + if (values[0].value == 'flask') and (values[1].value == 'ext'): + # Case 1 + if len(node.value) == 3: + package = values[2].value + modules = node.modules() + module_string = _get_modules(modules) + if len(modules) > 1: + node.replace("from flask_%s import %s" + % (package, module_string)) + else: + name = node.names()[0] + node.replace("from flask_%s import %s as %s" + % (package, module_string, name)) + # Case 2 + else: + module = node.modules()[0] + node.replace("import flask_%s as %s" + % (module, module)) + return red + + +def fix_standard_imports(red): + """ + Handles import modification in the form: + import flask.ext.foo" --> import flask_foo + """ + imports = red.find_all("ImportNode") + for x, node in enumerate(imports): + try: + if (node.value[0].value[0].value == 'flask' and + node.value[0].value[1].value == 'ext'): + package = node.value[0].value[2].value + name = node.names()[0].split('.')[-1] + if name == package: + node.replace("import flask_%s" % (package)) + else: + node.replace("import flask_%s as %s" % (package, name)) + except IndexError: + pass + + return red + + +def _get_modules(module): + """ + Takes a list of modules and converts into a string. + + The module list can include parens, this function checks each element in + the list, if there is a paren then it does not add a comma before the next + element. Otherwise a comma and space is added. This is to preserve module + imports which are multi-line and/or occur within parens. While also not + affecting imports which are not enclosed. + """ + modules_string = [cur + ', ' if cur.isalnum() and next.isalnum() + else cur + for (cur, next) in zip(module, module[1:]+[''])] + + return ''.join(modules_string) + + +def fix_function_calls(red): + """ + Modifies function calls in the source to reflect import changes. + + Searches the AST for AtomtrailerNodes and replaces them. + """ + atoms = red.find_all("Atomtrailers") + for x, node in enumerate(atoms): + try: + if (node.value[0].value == 'flask' and + node.value[1].value == 'ext'): + params = _form_function_call(node) + node.replace("flask_%s%s" % (node.value[2], params)) + except IndexError: + pass + + return red + + +def _form_function_call(node): + """ + Reconstructs function call strings when making attribute access calls. + """ + node_vals = node.value + output = "." + for x, param in enumerate(node_vals[3::]): + if param.dumps()[0] == "(": + output = output[0:-1] + param.dumps() + return output + else: + output += param.dumps() + "." + + +def check_user_input(): + """Exits and gives error message if no argument is passed in the shell.""" + if len(sys.argv) < 2: + sys.exit("No filename was included, please try again.") + + +def fix_tester(ast): + """Wrapper which allows for testing when not running from shell.""" + ast = fix_imports(ast) + ast = fix_function_calls(ast) + return ast.dumps() + + +def fix(): + """Wrapper for user argument checking and import fixing.""" + check_user_input() + input_file = sys.argv[1] + ast = read_source(input_file) + ast = fix_imports(ast) + ast = fix_function_calls(ast) + write_source(ast, input_file) + + +if __name__ == "__main__": + fix() diff --git a/scripts/test_import_migration.py b/scripts/test_import_migration.py new file mode 100644 index 00000000..0220e70a --- /dev/null +++ b/scripts/test_import_migration.py @@ -0,0 +1,71 @@ +# Tester for the flaskext_migrate.py module located in flask/scripts/ +# +# Author: Keyan Pishdadian +import pytest +from redbaron import RedBaron +import flaskext_migrate as migrate + + +def test_simple_from_import(): + red = RedBaron("from flask.ext import foo") + output = migrate.fix_tester(red) + assert output == "import flask_foo as foo" + + +def test_from_to_from_import(): + red = RedBaron("from flask.ext.foo import bar") + output = migrate.fix_tester(red) + assert output == "from flask_foo import bar as bar" + + +def test_multiple_import(): + red = RedBaron("from flask.ext.foo import bar, foobar, something") + output = migrate.fix_tester(red) + assert output == "from flask_foo import bar, foobar, something" + + +def test_multiline_import(): + red = RedBaron("from flask.ext.foo import \ + bar,\ + foobar,\ + something") + output = migrate.fix_tester(red) + assert output == "from flask_foo import bar, foobar, something" + + +def test_module_import(): + red = RedBaron("import flask.ext.foo") + output = migrate.fix_tester(red) + assert output == "import flask_foo" + + +def test_named_module_import(): + red = RedBaron("import flask.ext.foo as foobar") + output = migrate.fix_tester(red) + assert output == "import flask_foo as foobar" + + +def test__named_from_import(): + red = RedBaron("from flask.ext.foo import bar as baz") + output = migrate.fix_tester(red) + assert output == "from flask_foo import bar as baz" + + +def test_parens_import(): + red = RedBaron("from flask.ext.foo import (bar, foo, foobar)") + output = migrate.fix_tester(red) + assert output == "from flask_foo import (bar, foo, foobar)" + + +def test_function_call_migration(): + red = RedBaron("flask.ext.foo(var)") + output = migrate.fix_tester(red) + assert output == "flask_foo(var)" + + +def test_nested_function_call_migration(): + red = RedBaron("import flask.ext.foo\n\n" + "flask.ext.foo.bar(var)") + output = migrate.fix_tester(red) + assert output == ("import flask_foo\n\n" + "flask_foo.bar(var)") diff --git a/tox.ini b/tox.ini index ba2d2668..3e170d87 100644 --- a/tox.ini +++ b/tox.ini @@ -8,6 +8,7 @@ commands = deps= pytest greenlet + redbaron lowest: Werkzeug==0.7 lowest: Jinja2==2.4