Add a test and cover edge case with parens

This commit is contained in:
Keyan Pishdadian 2015-02-12 16:41:57 -05:00
parent 1479cf80f6
commit 9cbe83ef0d
2 changed files with 81 additions and 7 deletions

View file

@ -50,14 +50,14 @@ def fix_from_imports(red):
if len(node.value) == 3:
package = values[2].value
modules = node.modules()
module_string = _get_modules(modules)
if len(modules) > 1:
r = "{}," * len(modules)
node.replace("from flask_%s import %s"
% (package, r.format(*modules)[:-1]))
% (package, module_string))
else:
name = node.names()[0]
node.replace("from flask_%s import %s as %s"
% (package, modules.pop(), name))
% (package, module_string, name))
# Case 2
else:
module = node.modules()[0]
@ -88,13 +88,36 @@ def fix_standard_imports(red):
return red
def fix(ast):
"""Wrapper which allows for testing when not running from shell"""
return fix_imports(ast).dumps()
def _get_modules(module):
"""
Takes a list of modules and converts into a string
if __name__ == "__main__":
The module list can include parens, this function checks each element in
the list, if there is a paren then it does not add a comma before the next
element. Otherwise a comma and space is added. This is to preserve module
imports which are multi-line and/or occur within parens. While also not
affecting imports which are not enclosed.
"""
modules_string = [cur + ', ' if cur.isalnum() and next.isalnum()
else cur
for (cur, next) in zip(module, module[1:]+[''])]
return ''.join(modules_string)
def check_user_input():
"""Exits and gives error message if no argument is passed in the shell."""
if len(sys.argv) < 2:
sys.exit("No filename was included, please try again.")
def fix(ast):
"""Wrapper which allows for testing when not running from shell."""
return fix_imports(ast).dumps()
if __name__ == "__main__":
check_user_input()
input_file = sys.argv[1]
ast = read_source(input_file)
ast = fix_imports(ast)