Add a test and cover edge case with parens
This commit is contained in:
parent
1479cf80f6
commit
9cbe83ef0d
2 changed files with 81 additions and 7 deletions
|
|
@ -50,14 +50,14 @@ def fix_from_imports(red):
|
||||||
if len(node.value) == 3:
|
if len(node.value) == 3:
|
||||||
package = values[2].value
|
package = values[2].value
|
||||||
modules = node.modules()
|
modules = node.modules()
|
||||||
|
module_string = _get_modules(modules)
|
||||||
if len(modules) > 1:
|
if len(modules) > 1:
|
||||||
r = "{}," * len(modules)
|
|
||||||
node.replace("from flask_%s import %s"
|
node.replace("from flask_%s import %s"
|
||||||
% (package, r.format(*modules)[:-1]))
|
% (package, module_string))
|
||||||
else:
|
else:
|
||||||
name = node.names()[0]
|
name = node.names()[0]
|
||||||
node.replace("from flask_%s import %s as %s"
|
node.replace("from flask_%s import %s as %s"
|
||||||
% (package, modules.pop(), name))
|
% (package, module_string, name))
|
||||||
# Case 2
|
# Case 2
|
||||||
else:
|
else:
|
||||||
module = node.modules()[0]
|
module = node.modules()[0]
|
||||||
|
|
@ -88,13 +88,36 @@ def fix_standard_imports(red):
|
||||||
return red
|
return red
|
||||||
|
|
||||||
|
|
||||||
def fix(ast):
|
def _get_modules(module):
|
||||||
"""Wrapper which allows for testing when not running from shell"""
|
"""
|
||||||
return fix_imports(ast).dumps()
|
Takes a list of modules and converts into a string
|
||||||
|
|
||||||
if __name__ == "__main__":
|
The module list can include parens, this function checks each element in
|
||||||
|
the list, if there is a paren then it does not add a comma before the next
|
||||||
|
element. Otherwise a comma and space is added. This is to preserve module
|
||||||
|
imports which are multi-line and/or occur within parens. While also not
|
||||||
|
affecting imports which are not enclosed.
|
||||||
|
"""
|
||||||
|
modules_string = [cur + ', ' if cur.isalnum() and next.isalnum()
|
||||||
|
else cur
|
||||||
|
for (cur, next) in zip(module, module[1:]+[''])]
|
||||||
|
|
||||||
|
return ''.join(modules_string)
|
||||||
|
|
||||||
|
|
||||||
|
def check_user_input():
|
||||||
|
"""Exits and gives error message if no argument is passed in the shell."""
|
||||||
if len(sys.argv) < 2:
|
if len(sys.argv) < 2:
|
||||||
sys.exit("No filename was included, please try again.")
|
sys.exit("No filename was included, please try again.")
|
||||||
|
|
||||||
|
|
||||||
|
def fix(ast):
|
||||||
|
"""Wrapper which allows for testing when not running from shell."""
|
||||||
|
return fix_imports(ast).dumps()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
check_user_input()
|
||||||
input_file = sys.argv[1]
|
input_file = sys.argv[1]
|
||||||
ast = read_source(input_file)
|
ast = read_source(input_file)
|
||||||
ast = fix_imports(ast)
|
ast = fix_imports(ast)
|
||||||
|
|
|
||||||
51
scripts/test_import_migration.py
Normal file
51
scripts/test_import_migration.py
Normal file
|
|
@ -0,0 +1,51 @@
|
||||||
|
# Tester for the flaskext_migrate.py module located in flask/scripts/
|
||||||
|
#
|
||||||
|
# Author: Keyan Pishdadian
|
||||||
|
import pytest
|
||||||
|
from redbaron import RedBaron
|
||||||
|
import flaskext_migrate as migrate
|
||||||
|
|
||||||
|
|
||||||
|
def test_simple_from_import():
|
||||||
|
red = RedBaron("from flask.ext import foo")
|
||||||
|
output = migrate.fix(red)
|
||||||
|
assert output == "import flask_foo as foo"
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_to_from_import():
|
||||||
|
red = RedBaron("from flask.ext.foo import bar")
|
||||||
|
output = migrate.fix(red)
|
||||||
|
assert output == "from flask_foo import bar as bar"
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_import():
|
||||||
|
red = RedBaron("from flask.ext.foo import bar, foobar, something")
|
||||||
|
output = migrate.fix(red)
|
||||||
|
assert output == "from flask_foo import bar, foobar, something"
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiline_import():
|
||||||
|
red = RedBaron("from flask.ext.foo import \
|
||||||
|
bar,\
|
||||||
|
foobar,\
|
||||||
|
something")
|
||||||
|
output = migrate.fix(red)
|
||||||
|
assert output == "from flask_foo import bar, foobar, something"
|
||||||
|
|
||||||
|
|
||||||
|
def test_module_import():
|
||||||
|
red = RedBaron("import flask.ext.foo")
|
||||||
|
output = migrate.fix(red)
|
||||||
|
assert output == "import flask_foo"
|
||||||
|
|
||||||
|
|
||||||
|
def test_module_import():
|
||||||
|
red = RedBaron("from flask.ext.foo import bar as baz")
|
||||||
|
output = migrate.fix(red)
|
||||||
|
assert output == "from flask_foo import bar as baz"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parens_import():
|
||||||
|
red = RedBaron("from flask.ext.foo import (bar, foo, foobar)")
|
||||||
|
output = migrate.fix(red)
|
||||||
|
assert output == "from flask_foo import (bar, foo, foobar)"
|
||||||
Loading…
Add table
Add a link
Reference in a new issue