diff --git a/scripts/flaskext_migrate.py b/scripts/flaskext_migrate.py index 9b378a09..53bc80e4 100644 --- a/scripts/flaskext_migrate.py +++ b/scripts/flaskext_migrate.py @@ -1,22 +1,47 @@ -# CASE 1 - from flask.ext.foo import bam --> from flask_foo import bam -# CASE 2 - from flask.ext import foo --> import flask_foo as foo +# Script which modifies source code away from the deprecated "flask.ext" +# format. Does not yet fully support imports in the style: +# +# "import flask.ext.foo" +# +# these are converted to "import flask_foo" in the +# main import statement, but does not handle function calls in the source. +# +# Run in the terminal by typing: `python flaskext_migrate.py ` +# +# Author: Keyan Pishdadian 2015 from redbaron import RedBaron import sys def read_source(input_file): + """Parses the input_file into a RedBaron FST.""" with open(input_file, "r") as source_code: red = RedBaron(source_code.read()) return red def write_source(red, input_file): + """Overwrites the input_file once the FST has been modified.""" with open(input_file, "w") as source_code: source_code.write(red.dumps()) def fix_imports(red): + """Wrapper which fixes "from" style imports and then "import" style.""" + red = fix_standard_imports(red) + red = fix_from_imports(red) + return red + + +def fix_from_imports(red): + """ + Converts "from" style imports to not use "flask.ext". + + Handles: + Case 1: from flask.ext.foo import bam --> from flask_foo import bam + Case 2: from flask.ext import foo --> import flask_foo as foo + """ from_imports = red.find_all("FromImport") for x in range(len(from_imports)): values = from_imports[x].value @@ -33,6 +58,27 @@ def fix_imports(red): module = from_imports[x].modules()[0] from_imports[x].replace("import flask_%s as %s" % (module, module)) + return red + + +def fix_standard_imports(red): + """ + Handles import modification in the form: + import flask.ext.foo" --> import flask_foo + + Does not modify function calls elsewhere in the source outside of the + original import statement. + """ + imports = red.find_all("ImportNode") + for x in range(len(imports)): + values = imports[x].value + try: + if (values[x].value[0].value == 'flask' and + values[x].value[1].value == 'ext'): + package = values[x].value[2].value + imports[x].replace("import flask_%s" % package) + except IndexError: + pass return red