Add support for function call fixing, add tests

Addresses #1135, some code cleanup and refactoring. Changes wrapper function which handles testing, further modularized code, added test to cover function call fixing, and fixed duplicate test function name.
This commit is contained in:
Keyan Pishdadian 2015-02-13 15:27:29 -05:00
parent 9cbe83ef0d
commit 4cb311b945
2 changed files with 56 additions and 21 deletions

View file

@ -70,18 +70,15 @@ def fix_standard_imports(red):
"""
Handles import modification in the form:
import flask.ext.foo" --> import flask_foo
Does not modify function calls elsewhere in the source outside of the
original import statement.
"""
imports = red.find_all("ImportNode")
for x, node in enumerate(imports):
try:
if (node.value[0].value == 'flask' and
node.value[1].value == 'ext'):
package = node.value[2].value
name = node.names()[0]
imports[x].replace("import flask_%s as %s" % (package, name))
if (node.value[0].value[0].value == 'flask' and
node.value[0].value[1].value == 'ext'):
package = node.value[0].value[2]
name = node.names()[0].split('.')[-1]
node.replace("import flask_%s as %s" % (package, name))
except IndexError:
pass
@ -90,7 +87,7 @@ def fix_standard_imports(red):
def _get_modules(module):
"""
Takes a list of modules and converts into a string
Takes a list of modules and converts into a string.
The module list can include parens, this function checks each element in
the list, if there is a paren then it does not add a comma before the next
@ -105,20 +102,46 @@ def _get_modules(module):
return ''.join(modules_string)
def fix_function_calls(red):
"""
Modifies function calls in the source to reflect import changes.
Searches the AST for AtomtrailerNodes and replaces them.
"""
atoms = red.find_all("Atomtrailers")
for x, node in enumerate(atoms):
try:
if (node.value[0].value == 'flask' and
node.value[1].value == 'ext'):
node.replace("flask_foo%s" % node.value[3])
except IndexError:
pass
return red
def check_user_input():
"""Exits and gives error message if no argument is passed in the shell."""
if len(sys.argv) < 2:
sys.exit("No filename was included, please try again.")
def fix(ast):
def fix_tester(ast):
"""Wrapper which allows for testing when not running from shell."""
return fix_imports(ast).dumps()
ast = fix_imports(ast)
ast = fix_function_calls(ast)
return ast.dumps()
if __name__ == "__main__":
def fix():
"""Wrapper for user argument checking and import fixing."""
check_user_input()
input_file = sys.argv[1]
ast = read_source(input_file)
ast = fix_imports(ast)
ast = fix_function_calls(ast)
write_source(ast, input_file)
if __name__ == "__main__":
fix()