forked from orbit-oss/flask
Addresses #1135 with some code cleanup and refactoring: changes the wrapper function that handles testing, further modularizes the code, adds a test covering function-call fixing, and fixes a duplicate test function name.
147 lines
4.5 KiB
Python
147 lines
4.5 KiB
Python
# Script which migrates source code away from the deprecated "flask.ext"
# import format.
#
# Imports in the style "import flask.ext.foo" are not yet fully supported:
# the import statement itself is rewritten to use the "flask_foo" package,
# but not every reference to "flask.ext.foo" elsewhere in the source is
# guaranteed to be updated.
#
# Run in the terminal by typing: `python flaskext_migrate.py <source_file.py>`
#
# Author: Keyan Pishdadian 2015
|
|
|
|
from redbaron import RedBaron
|
|
import sys
|
|
|
|
|
|
def read_source(input_file):
    """Load input_file and return its contents parsed into a RedBaron FST."""
    with open(input_file, "r") as handle:
        text = handle.read()
    return RedBaron(text)
|
|
|
|
|
|
def write_source(red, input_file):
    """
    Overwrites the input_file once the FST has been modified.

    The FST is rendered with red.dumps() *before* the file is opened.
    Opening with "w" truncates the file immediately, so rendering first
    means an exception raised by dumps() leaves the original source intact
    rather than an empty file.
    """
    new_source = red.dumps()
    with open(input_file, "w") as source_code:
        source_code.write(new_source)
|
|
|
|
|
|
def fix_imports(red):
    """
    Wrapper which fixes "import" style imports and then "from" style ones.

    fix_standard_imports() runs first and fix_from_imports() second; each
    rewrites matching nodes and hands the FST on to the next step.
    Returns the FST.
    """
    for fixer in (fix_standard_imports, fix_from_imports):
        red = fixer(red)
    return red
|
|
|
|
|
|
def fix_from_imports(red):
    """
    Converts "from" style imports to not use "flask.ext".

    Handles:
    Case 1: from flask.ext.foo import bam --> from flask_foo import bam
            A single imported name is written as "bam as bam" (or keeps
            its own alias). Deeper paths keep their suffix:
            from flask.ext.foo.sub import bam --> from flask_foo.sub import ...
    Case 2: from flask.ext import foo --> import flask_foo as foo
            Every listed extension is converted, and an "as" alias becomes
            the bound name:
            from flask.ext import foo as f, bar
                --> import flask_foo as f, flask_bar as bar

    Any other "from" import, including plain "from flask import Flask", is
    left untouched. Matching nodes are rewritten in place; returns the FST.
    """
    for node in red.find_all("FromImport"):
        values = node.value
        path = [values[i].value for i in range(len(values))]
        # Check the length before looking at path[1]: a one-part path such
        # as "from flask import Flask" used to raise IndexError here and
        # abort the whole migration.
        if len(path) < 2 or path[0] != 'flask' or path[1] != 'ext':
            continue
        if len(path) == 2:
            _fix_ext_from_import(node)
        else:
            _fix_ext_package_from_import(node, ".".join(path[2:]))
    return red


def _fix_ext_package_from_import(node, package):
    """Case 1: rewrite "from flask.ext.<package> import ..." in place."""
    modules = node.modules()
    # modules() can contain "(" / ")" tokens (see _get_modules); drop them
    # from both lists so there is exactly one entry per imported name.
    targets = [m for m in modules if m not in ('(', ')')]
    names = [n for n in node.names() if n not in ('(', ')')]
    if len(modules) > 1:
        if names == targets:
            # No aliases: _get_modules keeps any parens and adds commas.
            module_string = _get_modules(modules)
        else:
            # modules() only holds the pre-"as" names, so joining it alone
            # would silently drop aliases; spell each alias out instead.
            module_string = ", ".join(
                target if target == name else "%s as %s" % (target, name)
                for target, name in zip(targets, names))
        node.replace("from flask_%s import %s" % (package, module_string))
    else:
        node.replace("from flask_%s import %s as %s"
                     % (package, modules[0], names[0]))


def _fix_ext_from_import(node):
    """Case 2: rewrite "from flask.ext import foo, ..." as plain imports."""
    # Drop any "(" / ")" tokens so both lists hold one entry per name.
    modules = [m for m in node.modules() if m not in ('(', ')')]
    names = [n for n in node.names() if n not in ('(', ')')]
    node.replace("import " + ", ".join(
        "flask_%s as %s" % (module, name)
        for module, name in zip(modules, names)))
|
|
|
|
|
|
def fix_standard_imports(red):
    """
    Rewrites "import flask.ext..." statements to import flask_<ext> directly.

    Handles import modification in the form:
        import flask.ext.foo           --> import flask_foo
        import flask.ext.foo as bar    --> import flask_foo as bar
        import flask.ext.foo.sub       --> import flask_foo.sub
        import flask.ext.foo, os       --> import flask_foo, os

    The unaliased form binds the name "flask_foo", which is the name the
    "flask_foo..." text written by fix_function_calls() refers to. Other
    modules in a rewritten statement are kept exactly as written, and
    statements with no flask.ext module are left untouched. Matching nodes
    are rewritten in place; returns the FST.
    """
    for node in red.find_all("ImportNode"):
        bound_names = node.names()
        rewritten = []
        changed = False
        for i in range(len(node.value)):
            dotted = node.value[i]
            path = [dotted.value[j].value for j in range(len(dotted.value))]
            if len(path) < 3 or path[0] != 'flask' or path[1] != 'ext':
                # Not a flask.ext extension: keep this module verbatim
                # instead of silently dropping it from the statement.
                rewritten.append(dotted.dumps())
                continue
            module = "flask_" + ".".join(path[2:])
            bound = bound_names[i]
            if "." in bound:
                # No "as" clause: names() echoed the dotted path itself
                # (an alias is a plain identifier and cannot contain a dot).
                rewritten.append(module)
            else:
                rewritten.append("%s as %s" % (module, bound))
            changed = True
        if changed:
            node.replace("import " + ", ".join(rewritten))
    return red
|
|
|
|
|
|
def _get_modules(module):
    """
    Takes a list of modules and converts into a string.

    The module list can include "(" and ")" tokens. A comma and space is
    placed between two consecutive module names, but never next to a
    paren. This preserves imports which occur within parens while not
    affecting imports which are not enclosed; names are joined onto one
    line.

        ['bar', 'baz']            -> 'bar, baz'
        ['(', 'bar', 'baz', ')']  -> '(bar, baz)'
    """
    def is_name(token):
        # Every token that is not a paren is an imported name. The old
        # str.isalnum() test rejected names containing "_" (e.g.
        # "login_user"), gluing them to the next name with no comma.
        return token is not None and token not in ('(', ')')

    tokens = list(module)
    following = tokens[1:] + [None]
    return ''.join(cur + ', ' if is_name(cur) and is_name(nxt) else cur
                   for cur, nxt in zip(tokens, following))
|
|
|
|
|
|
def fix_function_calls(red):
    """
    Modifies function calls in the source to reflect import changes.

    Searches the FST for Atomtrailers nodes that start with
    "flask.ext.<name>" followed by at least one trailer (a call, a
    subscript or a further attribute), and rewrites the prefix as
    "flask_<name>" while keeping the trailers:

        flask.ext.foo(var)       --> flask_foo(var)
        flask.ext.foo.bar(var)   --> flask_foo.bar(var)
        flask.ext.foo[0]         --> flask_foo[0]

    Matching nodes are rewritten in place; returns the FST.
    """
    atoms = red.find_all("Atomtrailers")
    # Work from the last match to the first. Assuming find_all() lists nodes
    # in document order (an outer expression before those nested in its
    # arguments), the inner call of flask.ext.foo(flask.ext.bar(x)) is
    # rewritten before the outer one, so the outer rewrite -- rebuilt from
    # its parts' text -- already contains it.
    for node in reversed(atoms):
        parts = node.value
        if (len(parts) < 4 or parts[0].value != 'flask' or
                parts[1].value != 'ext' or parts[2].type != 'name'):
            continue
        # Dots between attribute names are not list elements (parts[1] is
        # "ext", not "."), so re-insert one before each attribute name;
        # calls and subscripts carry their own brackets.
        trailers = "".join(
            ("." if parts[i].type == "name" else "") + parts[i].dumps()
            for i in range(3, len(parts)))
        node.replace("flask_%s%s" % (parts[2].value, trailers))
    return red
|
|
|
|
|
|
def check_user_input():
    """Exit with an error message if no filename argument was given."""
    if not sys.argv[1:]:
        sys.exit("No filename was included, please try again.")
|
|
|
|
|
|
def fix_tester(ast):
    """
    Run the same migration as fix() on an already-parsed FST.

    Lets tests exercise the rewrite without touching sys.argv or the
    filesystem: imports are fixed first, then function calls, and the
    migrated source is returned as a string.
    """
    migrated = fix_function_calls(fix_imports(ast))
    return migrated.dumps()
|
|
|
|
|
|
def fix():
    """
    Command-line entry point: migrate the file named by sys.argv[1] in place.

    Exits with an error message when no filename argument is given.
    Otherwise the file is parsed, its imports are rewritten (fix_imports),
    then its "flask.ext" references (fix_function_calls), and the result is
    written back over the same file.
    """
    check_user_input()
    target = sys.argv[1]
    migrated = fix_function_calls(fix_imports(read_source(target)))
    write_source(migrated, target)


if __name__ == "__main__":
    fix()
|