Merge pull request #3183 from singingwolfboy/pre-commit

Pre-commit: Black
This commit is contained in:
David Baumgold 2019-05-06 16:34:46 -04:00 committed by GitHub
commit 1dda032b00
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
65 changed files with 3794 additions and 3459 deletions

5
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,5 @@
repos:
- repo: https://github.com/python/black
rev: 19.3b0
hooks:
- id: black

View file

@ -3,14 +3,14 @@ from flask import jsonify, render_template, request
from js_example import app from js_example import app
@app.route('/', defaults={'js': 'plain'}) @app.route("/", defaults={"js": "plain"})
@app.route('/<any(plain, jquery, fetch):js>') @app.route("/<any(plain, jquery, fetch):js>")
def index(js): def index(js):
return render_template('{0}.html'.format(js), js=js) return render_template("{0}.html".format(js), js=js)
@app.route('/add', methods=['POST']) @app.route("/add", methods=["POST"])
def add(): def add():
a = request.form.get('a', 0, type=float) a = request.form.get("a", 0, type=float)
b = request.form.get('b', 0, type=float) b = request.form.get("b", 0, type=float)
return jsonify(result=a + b) return jsonify(result=a + b)

View file

@ -2,29 +2,21 @@ import io
from setuptools import find_packages, setup from setuptools import find_packages, setup
with io.open('README.rst', 'rt', encoding='utf8') as f: with io.open("README.rst", "rt", encoding="utf8") as f:
readme = f.read() readme = f.read()
setup( setup(
name='js_example', name="js_example",
version='1.0.0', version="1.0.0",
url='http://flask.pocoo.org/docs/patterns/jquery/', url="http://flask.pocoo.org/docs/patterns/jquery/",
license='BSD', license="BSD",
maintainer='Pallets team', maintainer="Pallets team",
maintainer_email='contact@palletsprojects.com', maintainer_email="contact@palletsprojects.com",
description='Demonstrates making Ajax requests to Flask.', description="Demonstrates making Ajax requests to Flask.",
long_description=readme, long_description=readme,
packages=find_packages(), packages=find_packages(),
include_package_data=True, include_package_data=True,
zip_safe=False, zip_safe=False,
install_requires=[ install_requires=["flask"],
'flask', extras_require={"test": ["pytest", "coverage", "blinker"]},
],
extras_require={
'test': [
'pytest',
'coverage',
'blinker',
],
},
) )

View file

@ -3,7 +3,7 @@ import pytest
from js_example import app from js_example import app
@pytest.fixture(name='app') @pytest.fixture(name="app")
def fixture_app(): def fixture_app():
app.testing = True app.testing = True
yield app yield app

View file

@ -3,12 +3,15 @@ import pytest
from flask import template_rendered from flask import template_rendered
@pytest.mark.parametrize(('path', 'template_name'), ( @pytest.mark.parametrize(
('/', 'plain.html'), ("path", "template_name"),
('/plain', 'plain.html'), (
('/fetch', 'fetch.html'), ("/", "plain.html"),
('/jquery', 'jquery.html'), ("/plain", "plain.html"),
)) ("/fetch", "fetch.html"),
("/jquery", "jquery.html"),
),
)
def test_index(app, client, path, template_name): def test_index(app, client, path, template_name):
def check(sender, template, context): def check(sender, template, context):
assert template.name == template_name assert template.name == template_name
@ -17,12 +20,9 @@ def test_index(app, client, path, template_name):
client.get(path) client.get(path)
@pytest.mark.parametrize(('a', 'b', 'result'), ( @pytest.mark.parametrize(
(2, 3, 5), ("a", "b", "result"), ((2, 3, 5), (2.5, 3, 5.5), (2, None, 2), (2, "b", 2))
(2.5, 3, 5.5), )
(2, None, 2),
(2, 'b', 2),
))
def test_add(client, a, b, result): def test_add(client, a, b, result):
response = client.post('/add', data={'a': a, 'b': b}) response = client.post("/add", data={"a": a, "b": b})
assert response.get_json()['result'] == result assert response.get_json()["result"] == result

View file

@ -8,14 +8,14 @@ def create_app(test_config=None):
app = Flask(__name__, instance_relative_config=True) app = Flask(__name__, instance_relative_config=True)
app.config.from_mapping( app.config.from_mapping(
# a default secret that should be overridden by instance config # a default secret that should be overridden by instance config
SECRET_KEY='dev', SECRET_KEY="dev",
# store the database in the instance folder # store the database in the instance folder
DATABASE=os.path.join(app.instance_path, 'flaskr.sqlite'), DATABASE=os.path.join(app.instance_path, "flaskr.sqlite"),
) )
if test_config is None: if test_config is None:
# load the instance config, if it exists, when not testing # load the instance config, if it exists, when not testing
app.config.from_pyfile('config.py', silent=True) app.config.from_pyfile("config.py", silent=True)
else: else:
# load the test config if passed in # load the test config if passed in
app.config.update(test_config) app.config.update(test_config)
@ -26,16 +26,18 @@ def create_app(test_config=None):
except OSError: except OSError:
pass pass
@app.route('/hello') @app.route("/hello")
def hello(): def hello():
return 'Hello, World!' return "Hello, World!"
# register the database commands # register the database commands
from flaskr import db from flaskr import db
db.init_app(app) db.init_app(app)
# apply the blueprints to the app # apply the blueprints to the app
from flaskr import auth, blog from flaskr import auth, blog
app.register_blueprint(auth.bp) app.register_blueprint(auth.bp)
app.register_blueprint(blog.bp) app.register_blueprint(blog.bp)
@ -43,6 +45,6 @@ def create_app(test_config=None):
# in another app, you might define a separate main index here with # in another app, you might define a separate main index here with
# app.route, while giving the blog blueprint a url_prefix, but for # app.route, while giving the blog blueprint a url_prefix, but for
# the tutorial the blog will be the main index # the tutorial the blog will be the main index
app.add_url_rule('/', endpoint='index') app.add_url_rule("/", endpoint="index")
return app return app

View file

@ -1,21 +1,29 @@
import functools import functools
from flask import ( from flask import (
Blueprint, flash, g, redirect, render_template, request, session, url_for Blueprint,
flash,
g,
redirect,
render_template,
request,
session,
url_for,
) )
from werkzeug.security import check_password_hash, generate_password_hash from werkzeug.security import check_password_hash, generate_password_hash
from flaskr.db import get_db from flaskr.db import get_db
bp = Blueprint('auth', __name__, url_prefix='/auth') bp = Blueprint("auth", __name__, url_prefix="/auth")
def login_required(view): def login_required(view):
"""View decorator that redirects anonymous users to the login page.""" """View decorator that redirects anonymous users to the login page."""
@functools.wraps(view) @functools.wraps(view)
def wrapped_view(**kwargs): def wrapped_view(**kwargs):
if g.user is None: if g.user is None:
return redirect(url_for('auth.login')) return redirect(url_for("auth.login"))
return view(**kwargs) return view(**kwargs)
@ -26,83 +34,84 @@ def login_required(view):
def load_logged_in_user(): def load_logged_in_user():
"""If a user id is stored in the session, load the user object from """If a user id is stored in the session, load the user object from
the database into ``g.user``.""" the database into ``g.user``."""
user_id = session.get('user_id') user_id = session.get("user_id")
if user_id is None: if user_id is None:
g.user = None g.user = None
else: else:
g.user = get_db().execute( g.user = (
'SELECT * FROM user WHERE id = ?', (user_id,) get_db().execute("SELECT * FROM user WHERE id = ?", (user_id,)).fetchone()
).fetchone() )
@bp.route('/register', methods=('GET', 'POST')) @bp.route("/register", methods=("GET", "POST"))
def register(): def register():
"""Register a new user. """Register a new user.
Validates that the username is not already taken. Hashes the Validates that the username is not already taken. Hashes the
password for security. password for security.
""" """
if request.method == 'POST': if request.method == "POST":
username = request.form['username'] username = request.form["username"]
password = request.form['password'] password = request.form["password"]
db = get_db() db = get_db()
error = None error = None
if not username: if not username:
error = 'Username is required.' error = "Username is required."
elif not password: elif not password:
error = 'Password is required.' error = "Password is required."
elif db.execute( elif (
'SELECT id FROM user WHERE username = ?', (username,) db.execute("SELECT id FROM user WHERE username = ?", (username,)).fetchone()
).fetchone() is not None: is not None
error = 'User {0} is already registered.'.format(username) ):
error = "User {0} is already registered.".format(username)
if error is None: if error is None:
# the name is available, store it in the database and go to # the name is available, store it in the database and go to
# the login page # the login page
db.execute( db.execute(
'INSERT INTO user (username, password) VALUES (?, ?)', "INSERT INTO user (username, password) VALUES (?, ?)",
(username, generate_password_hash(password)) (username, generate_password_hash(password)),
) )
db.commit() db.commit()
return redirect(url_for('auth.login')) return redirect(url_for("auth.login"))
flash(error) flash(error)
return render_template('auth/register.html') return render_template("auth/register.html")
@bp.route('/login', methods=('GET', 'POST')) @bp.route("/login", methods=("GET", "POST"))
def login(): def login():
"""Log in a registered user by adding the user id to the session.""" """Log in a registered user by adding the user id to the session."""
if request.method == 'POST': if request.method == "POST":
username = request.form['username'] username = request.form["username"]
password = request.form['password'] password = request.form["password"]
db = get_db() db = get_db()
error = None error = None
user = db.execute( user = db.execute(
'SELECT * FROM user WHERE username = ?', (username,) "SELECT * FROM user WHERE username = ?", (username,)
).fetchone() ).fetchone()
if user is None: if user is None:
error = 'Incorrect username.' error = "Incorrect username."
elif not check_password_hash(user['password'], password): elif not check_password_hash(user["password"], password):
error = 'Incorrect password.' error = "Incorrect password."
if error is None: if error is None:
# store the user id in a new session and return to the index # store the user id in a new session and return to the index
session.clear() session.clear()
session['user_id'] = user['id'] session["user_id"] = user["id"]
return redirect(url_for('index')) return redirect(url_for("index"))
flash(error) flash(error)
return render_template('auth/login.html') return render_template("auth/login.html")
@bp.route('/logout') @bp.route("/logout")
def logout(): def logout():
"""Clear the current session, including the stored user id.""" """Clear the current session, including the stored user id."""
session.clear() session.clear()
return redirect(url_for('index')) return redirect(url_for("index"))

View file

@ -1,24 +1,22 @@
from flask import ( from flask import Blueprint, flash, g, redirect, render_template, request, url_for
Blueprint, flash, g, redirect, render_template, request, url_for
)
from werkzeug.exceptions import abort from werkzeug.exceptions import abort
from flaskr.auth import login_required from flaskr.auth import login_required
from flaskr.db import get_db from flaskr.db import get_db
bp = Blueprint('blog', __name__) bp = Blueprint("blog", __name__)
@bp.route('/') @bp.route("/")
def index(): def index():
"""Show all the posts, most recent first.""" """Show all the posts, most recent first."""
db = get_db() db = get_db()
posts = db.execute( posts = db.execute(
'SELECT p.id, title, body, created, author_id, username' "SELECT p.id, title, body, created, author_id, username"
' FROM post p JOIN user u ON p.author_id = u.id' " FROM post p JOIN user u ON p.author_id = u.id"
' ORDER BY created DESC' " ORDER BY created DESC"
).fetchall() ).fetchall()
return render_template('blog/index.html', posts=posts) return render_template("blog/index.html", posts=posts)
def get_post(id, check_author=True): def get_post(id, check_author=True):
@ -33,78 +31,80 @@ def get_post(id, check_author=True):
:raise 404: if a post with the given id doesn't exist :raise 404: if a post with the given id doesn't exist
:raise 403: if the current user isn't the author :raise 403: if the current user isn't the author
""" """
post = get_db().execute( post = (
'SELECT p.id, title, body, created, author_id, username' get_db()
' FROM post p JOIN user u ON p.author_id = u.id' .execute(
' WHERE p.id = ?', "SELECT p.id, title, body, created, author_id, username"
(id,) " FROM post p JOIN user u ON p.author_id = u.id"
).fetchone() " WHERE p.id = ?",
(id,),
)
.fetchone()
)
if post is None: if post is None:
abort(404, "Post id {0} doesn't exist.".format(id)) abort(404, "Post id {0} doesn't exist.".format(id))
if check_author and post['author_id'] != g.user['id']: if check_author and post["author_id"] != g.user["id"]:
abort(403) abort(403)
return post return post
@bp.route('/create', methods=('GET', 'POST')) @bp.route("/create", methods=("GET", "POST"))
@login_required @login_required
def create(): def create():
"""Create a new post for the current user.""" """Create a new post for the current user."""
if request.method == 'POST': if request.method == "POST":
title = request.form['title'] title = request.form["title"]
body = request.form['body'] body = request.form["body"]
error = None error = None
if not title: if not title:
error = 'Title is required.' error = "Title is required."
if error is not None: if error is not None:
flash(error) flash(error)
else: else:
db = get_db() db = get_db()
db.execute( db.execute(
'INSERT INTO post (title, body, author_id)' "INSERT INTO post (title, body, author_id)" " VALUES (?, ?, ?)",
' VALUES (?, ?, ?)', (title, body, g.user["id"]),
(title, body, g.user['id'])
) )
db.commit() db.commit()
return redirect(url_for('blog.index')) return redirect(url_for("blog.index"))
return render_template('blog/create.html') return render_template("blog/create.html")
@bp.route('/<int:id>/update', methods=('GET', 'POST')) @bp.route("/<int:id>/update", methods=("GET", "POST"))
@login_required @login_required
def update(id): def update(id):
"""Update a post if the current user is the author.""" """Update a post if the current user is the author."""
post = get_post(id) post = get_post(id)
if request.method == 'POST': if request.method == "POST":
title = request.form['title'] title = request.form["title"]
body = request.form['body'] body = request.form["body"]
error = None error = None
if not title: if not title:
error = 'Title is required.' error = "Title is required."
if error is not None: if error is not None:
flash(error) flash(error)
else: else:
db = get_db() db = get_db()
db.execute( db.execute(
'UPDATE post SET title = ?, body = ? WHERE id = ?', "UPDATE post SET title = ?, body = ? WHERE id = ?", (title, body, id)
(title, body, id)
) )
db.commit() db.commit()
return redirect(url_for('blog.index')) return redirect(url_for("blog.index"))
return render_template('blog/update.html', post=post) return render_template("blog/update.html", post=post)
@bp.route('/<int:id>/delete', methods=('POST',)) @bp.route("/<int:id>/delete", methods=("POST",))
@login_required @login_required
def delete(id): def delete(id):
"""Delete a post. """Delete a post.
@ -114,6 +114,6 @@ def delete(id):
""" """
get_post(id) get_post(id)
db = get_db() db = get_db()
db.execute('DELETE FROM post WHERE id = ?', (id,)) db.execute("DELETE FROM post WHERE id = ?", (id,))
db.commit() db.commit()
return redirect(url_for('blog.index')) return redirect(url_for("blog.index"))

View file

@ -10,10 +10,9 @@ def get_db():
is unique for each request and will be reused if this is called is unique for each request and will be reused if this is called
again. again.
""" """
if 'db' not in g: if "db" not in g:
g.db = sqlite3.connect( g.db = sqlite3.connect(
current_app.config['DATABASE'], current_app.config["DATABASE"], detect_types=sqlite3.PARSE_DECLTYPES
detect_types=sqlite3.PARSE_DECLTYPES
) )
g.db.row_factory = sqlite3.Row g.db.row_factory = sqlite3.Row
@ -24,7 +23,7 @@ def close_db(e=None):
"""If this request connected to the database, close the """If this request connected to the database, close the
connection. connection.
""" """
db = g.pop('db', None) db = g.pop("db", None)
if db is not None: if db is not None:
db.close() db.close()
@ -34,16 +33,16 @@ def init_db():
"""Clear existing data and create new tables.""" """Clear existing data and create new tables."""
db = get_db() db = get_db()
with current_app.open_resource('schema.sql') as f: with current_app.open_resource("schema.sql") as f:
db.executescript(f.read().decode('utf8')) db.executescript(f.read().decode("utf8"))
@click.command('init-db') @click.command("init-db")
@with_appcontext @with_appcontext
def init_db_command(): def init_db_command():
"""Clear existing data and create new tables.""" """Clear existing data and create new tables."""
init_db() init_db()
click.echo('Initialized the database.') click.echo("Initialized the database.")
def init_app(app): def init_app(app):

View file

@ -2,28 +2,21 @@ import io
from setuptools import find_packages, setup from setuptools import find_packages, setup
with io.open('README.rst', 'rt', encoding='utf8') as f: with io.open("README.rst", "rt", encoding="utf8") as f:
readme = f.read() readme = f.read()
setup( setup(
name='flaskr', name="flaskr",
version='1.0.0', version="1.0.0",
url='http://flask.pocoo.org/docs/tutorial/', url="http://flask.pocoo.org/docs/tutorial/",
license='BSD', license="BSD",
maintainer='Pallets team', maintainer="Pallets team",
maintainer_email='contact@palletsprojects.com', maintainer_email="contact@palletsprojects.com",
description='The basic blog app built in the Flask tutorial.', description="The basic blog app built in the Flask tutorial.",
long_description=readme, long_description=readme,
packages=find_packages(), packages=find_packages(),
include_package_data=True, include_package_data=True,
zip_safe=False, zip_safe=False,
install_requires=[ install_requires=["flask"],
'flask', extras_require={"test": ["pytest", "coverage"]},
],
extras_require={
'test': [
'pytest',
'coverage',
],
},
) )

View file

@ -6,8 +6,8 @@ from flaskr import create_app
from flaskr.db import get_db, init_db from flaskr.db import get_db, init_db
# read in SQL for populating test data # read in SQL for populating test data
with open(os.path.join(os.path.dirname(__file__), 'data.sql'), 'rb') as f: with open(os.path.join(os.path.dirname(__file__), "data.sql"), "rb") as f:
_data_sql = f.read().decode('utf8') _data_sql = f.read().decode("utf8")
@pytest.fixture @pytest.fixture
@ -16,10 +16,7 @@ def app():
# create a temporary file to isolate the database for each test # create a temporary file to isolate the database for each test
db_fd, db_path = tempfile.mkstemp() db_fd, db_path = tempfile.mkstemp()
# create the app with common test config # create the app with common test config
app = create_app({ app = create_app({"TESTING": True, "DATABASE": db_path})
'TESTING': True,
'DATABASE': db_path,
})
# create the database and load test data # create the database and load test data
with app.app_context(): with app.app_context():
@ -49,14 +46,13 @@ class AuthActions(object):
def __init__(self, client): def __init__(self, client):
self._client = client self._client = client
def login(self, username='test', password='test'): def login(self, username="test", password="test"):
return self._client.post( return self._client.post(
'/auth/login', "/auth/login", data={"username": username, "password": password}
data={'username': username, 'password': password}
) )
def logout(self): def logout(self):
return self._client.get('/auth/logout') return self._client.get("/auth/logout")
@pytest.fixture @pytest.fixture

View file

@ -5,54 +5,55 @@ from flaskr.db import get_db
def test_register(client, app): def test_register(client, app):
# test that viewing the page renders without template errors # test that viewing the page renders without template errors
assert client.get('/auth/register').status_code == 200 assert client.get("/auth/register").status_code == 200
# test that successful registration redirects to the login page # test that successful registration redirects to the login page
response = client.post( response = client.post("/auth/register", data={"username": "a", "password": "a"})
'/auth/register', data={'username': 'a', 'password': 'a'} assert "http://localhost/auth/login" == response.headers["Location"]
)
assert 'http://localhost/auth/login' == response.headers['Location']
# test that the user was inserted into the database # test that the user was inserted into the database
with app.app_context(): with app.app_context():
assert get_db().execute( assert (
"select * from user where username = 'a'", get_db().execute("select * from user where username = 'a'").fetchone()
).fetchone() is not None is not None
)
@pytest.mark.parametrize(('username', 'password', 'message'), ( @pytest.mark.parametrize(
('', '', b'Username is required.'), ("username", "password", "message"),
('a', '', b'Password is required.'), (
('test', 'test', b'already registered'), ("", "", b"Username is required."),
)) ("a", "", b"Password is required."),
("test", "test", b"already registered"),
),
)
def test_register_validate_input(client, username, password, message): def test_register_validate_input(client, username, password, message):
response = client.post( response = client.post(
'/auth/register', "/auth/register", data={"username": username, "password": password}
data={'username': username, 'password': password}
) )
assert message in response.data assert message in response.data
def test_login(client, auth): def test_login(client, auth):
# test that viewing the page renders without template errors # test that viewing the page renders without template errors
assert client.get('/auth/login').status_code == 200 assert client.get("/auth/login").status_code == 200
# test that successful login redirects to the index page # test that successful login redirects to the index page
response = auth.login() response = auth.login()
assert response.headers['Location'] == 'http://localhost/' assert response.headers["Location"] == "http://localhost/"
# login request set the user_id in the session # login request set the user_id in the session
# check that the user is loaded from the session # check that the user is loaded from the session
with client: with client:
client.get('/') client.get("/")
assert session['user_id'] == 1 assert session["user_id"] == 1
assert g.user['username'] == 'test' assert g.user["username"] == "test"
@pytest.mark.parametrize(('username', 'password', 'message'), ( @pytest.mark.parametrize(
('a', 'test', b'Incorrect username.'), ("username", "password", "message"),
('test', 'a', b'Incorrect password.'), (("a", "test", b"Incorrect username."), ("test", "a", b"Incorrect password.")),
)) )
def test_login_validate_input(auth, username, password, message): def test_login_validate_input(auth, username, password, message):
response = auth.login(username, password) response = auth.login(username, password)
assert message in response.data assert message in response.data
@ -63,4 +64,4 @@ def test_logout(client, auth):
with client: with client:
auth.logout() auth.logout()
assert 'user_id' not in session assert "user_id" not in session

View file

@ -3,47 +3,40 @@ from flaskr.db import get_db
def test_index(client, auth): def test_index(client, auth):
response = client.get('/') response = client.get("/")
assert b"Log In" in response.data assert b"Log In" in response.data
assert b"Register" in response.data assert b"Register" in response.data
auth.login() auth.login()
response = client.get('/') response = client.get("/")
assert b'test title' in response.data assert b"test title" in response.data
assert b'by test on 2018-01-01' in response.data assert b"by test on 2018-01-01" in response.data
assert b'test\nbody' in response.data assert b"test\nbody" in response.data
assert b'href="/1/update"' in response.data assert b'href="/1/update"' in response.data
@pytest.mark.parametrize('path', ( @pytest.mark.parametrize("path", ("/create", "/1/update", "/1/delete"))
'/create',
'/1/update',
'/1/delete',
))
def test_login_required(client, path): def test_login_required(client, path):
response = client.post(path) response = client.post(path)
assert response.headers['Location'] == 'http://localhost/auth/login' assert response.headers["Location"] == "http://localhost/auth/login"
def test_author_required(app, client, auth): def test_author_required(app, client, auth):
# change the post author to another user # change the post author to another user
with app.app_context(): with app.app_context():
db = get_db() db = get_db()
db.execute('UPDATE post SET author_id = 2 WHERE id = 1') db.execute("UPDATE post SET author_id = 2 WHERE id = 1")
db.commit() db.commit()
auth.login() auth.login()
# current user can't modify other user's post # current user can't modify other user's post
assert client.post('/1/update').status_code == 403 assert client.post("/1/update").status_code == 403
assert client.post('/1/delete').status_code == 403 assert client.post("/1/delete").status_code == 403
# current user doesn't see edit link # current user doesn't see edit link
assert b'href="/1/update"' not in client.get('/').data assert b'href="/1/update"' not in client.get("/").data
@pytest.mark.parametrize('path', ( @pytest.mark.parametrize("path", ("/2/update", "/2/delete"))
'/2/update',
'/2/delete',
))
def test_exists_required(client, auth, path): def test_exists_required(client, auth, path):
auth.login() auth.login()
assert client.post(path).status_code == 404 assert client.post(path).status_code == 404
@ -51,42 +44,39 @@ def test_exists_required(client, auth, path):
def test_create(client, auth, app): def test_create(client, auth, app):
auth.login() auth.login()
assert client.get('/create').status_code == 200 assert client.get("/create").status_code == 200
client.post('/create', data={'title': 'created', 'body': ''}) client.post("/create", data={"title": "created", "body": ""})
with app.app_context(): with app.app_context():
db = get_db() db = get_db()
count = db.execute('SELECT COUNT(id) FROM post').fetchone()[0] count = db.execute("SELECT COUNT(id) FROM post").fetchone()[0]
assert count == 2 assert count == 2
def test_update(client, auth, app): def test_update(client, auth, app):
auth.login() auth.login()
assert client.get('/1/update').status_code == 200 assert client.get("/1/update").status_code == 200
client.post('/1/update', data={'title': 'updated', 'body': ''}) client.post("/1/update", data={"title": "updated", "body": ""})
with app.app_context(): with app.app_context():
db = get_db() db = get_db()
post = db.execute('SELECT * FROM post WHERE id = 1').fetchone() post = db.execute("SELECT * FROM post WHERE id = 1").fetchone()
assert post['title'] == 'updated' assert post["title"] == "updated"
@pytest.mark.parametrize('path', ( @pytest.mark.parametrize("path", ("/create", "/1/update"))
'/create',
'/1/update',
))
def test_create_update_validate(client, auth, path): def test_create_update_validate(client, auth, path):
auth.login() auth.login()
response = client.post(path, data={'title': '', 'body': ''}) response = client.post(path, data={"title": "", "body": ""})
assert b'Title is required.' in response.data assert b"Title is required." in response.data
def test_delete(client, auth, app): def test_delete(client, auth, app):
auth.login() auth.login()
response = client.post('/1/delete') response = client.post("/1/delete")
assert response.headers['Location'] == 'http://localhost/' assert response.headers["Location"] == "http://localhost/"
with app.app_context(): with app.app_context():
db = get_db() db = get_db()
post = db.execute('SELECT * FROM post WHERE id = 1').fetchone() post = db.execute("SELECT * FROM post WHERE id = 1").fetchone()
assert post is None assert post is None

View file

@ -10,9 +10,9 @@ def test_get_close_db(app):
assert db is get_db() assert db is get_db()
with pytest.raises(sqlite3.ProgrammingError) as e: with pytest.raises(sqlite3.ProgrammingError) as e:
db.execute('SELECT 1') db.execute("SELECT 1")
assert 'closed' in str(e) assert "closed" in str(e)
def test_init_db_command(runner, monkeypatch): def test_init_db_command(runner, monkeypatch):
@ -22,7 +22,7 @@ def test_init_db_command(runner, monkeypatch):
def fake_init_db(): def fake_init_db():
Recorder.called = True Recorder.called = True
monkeypatch.setattr('flaskr.db.init_db', fake_init_db) monkeypatch.setattr("flaskr.db.init_db", fake_init_db)
result = runner.invoke(args=['init-db']) result = runner.invoke(args=["init-db"])
assert 'Initialized' in result.output assert "Initialized" in result.output
assert Recorder.called assert Recorder.called

View file

@ -4,9 +4,9 @@ from flaskr import create_app
def test_config(): def test_config():
"""Test create_app without passing test config.""" """Test create_app without passing test config."""
assert not create_app().testing assert not create_app().testing
assert create_app({'TESTING': True}).testing assert create_app({"TESTING": True}).testing
def test_hello(client): def test_hello(client):
response = client.get('/hello') response = client.get("/hello")
assert response.data == b'Hello, World!' assert response.data == b"Hello, World!"

View file

@ -10,7 +10,7 @@
:license: BSD, see LICENSE for more details. :license: BSD, see LICENSE for more details.
""" """
__version__ = '1.1.dev' __version__ = "1.1.dev"
# utilities we import from Werkzeug and Jinja2 that are unused # utilities we import from Werkzeug and Jinja2 that are unused
# in the module but are exported as public interface. # in the module but are exported as public interface.
@ -20,21 +20,48 @@ from jinja2 import Markup, escape
from .app import Flask, Request, Response from .app import Flask, Request, Response
from .config import Config from .config import Config
from .helpers import url_for, flash, send_file, send_from_directory, \ from .helpers import (
get_flashed_messages, get_template_attribute, make_response, safe_join, \ url_for,
stream_with_context flash,
from .globals import current_app, g, request, session, _request_ctx_stack, \ send_file,
_app_ctx_stack send_from_directory,
from .ctx import has_request_context, has_app_context, \ get_flashed_messages,
after_this_request, copy_current_request_context get_template_attribute,
make_response,
safe_join,
stream_with_context,
)
from .globals import (
current_app,
g,
request,
session,
_request_ctx_stack,
_app_ctx_stack,
)
from .ctx import (
has_request_context,
has_app_context,
after_this_request,
copy_current_request_context,
)
from .blueprints import Blueprint from .blueprints import Blueprint
from .templating import render_template, render_template_string from .templating import render_template, render_template_string
# the signals # the signals
from .signals import signals_available, template_rendered, request_started, \ from .signals import (
request_finished, got_request_exception, request_tearing_down, \ signals_available,
appcontext_tearing_down, appcontext_pushed, \ template_rendered,
appcontext_popped, message_flashed, before_render_template request_started,
request_finished,
got_request_exception,
request_tearing_down,
appcontext_tearing_down,
appcontext_pushed,
appcontext_popped,
message_flashed,
before_render_template,
)
# We're not exposing the actual json module but a convenient wrapper around # We're not exposing the actual json module but a convenient wrapper around
# it. # it.

View file

@ -9,6 +9,7 @@
:license: BSD, see LICENSE for more details. :license: BSD, see LICENSE for more details.
""" """
if __name__ == '__main__': if __name__ == "__main__":
from .cli import main from .cli import main
main(as_module=True) main(as_module=True)

View file

@ -50,11 +50,11 @@ else:
from cStringIO import StringIO from cStringIO import StringIO
import collections as collections_abc import collections as collections_abc
exec('def reraise(tp, value, tb=None):\n raise tp, value, tb') exec("def reraise(tp, value, tb=None):\n raise tp, value, tb")
def implements_to_string(cls): def implements_to_string(cls):
cls.__unicode__ = cls.__str__ cls.__unicode__ = cls.__str__
cls.__str__ = lambda x: x.__unicode__().encode('utf-8') cls.__str__ = lambda x: x.__unicode__().encode("utf-8")
return cls return cls
@ -66,7 +66,8 @@ def with_metaclass(meta, *bases):
class metaclass(type): class metaclass(type):
def __new__(cls, name, this_bases, d): def __new__(cls, name, this_bases, d):
return meta(name, bases, d) return meta(name, bases, d)
return type.__new__(metaclass, 'temporary_class', (), {})
return type.__new__(metaclass, "temporary_class", (), {})
# Certain versions of pypy have a bug where clearing the exception stack # Certain versions of pypy have a bug where clearing the exception stack
@ -81,14 +82,17 @@ def with_metaclass(meta, *bases):
# #
# Ubuntu 14.04 has PyPy 2.2.1, which does exhibit this bug. # Ubuntu 14.04 has PyPy 2.2.1, which does exhibit this bug.
BROKEN_PYPY_CTXMGR_EXIT = False BROKEN_PYPY_CTXMGR_EXIT = False
if hasattr(sys, 'pypy_version_info'): if hasattr(sys, "pypy_version_info"):
class _Mgr(object): class _Mgr(object):
def __enter__(self): def __enter__(self):
return self return self
def __exit__(self, *args): def __exit__(self, *args):
if hasattr(sys, 'exc_clear'): if hasattr(sys, "exc_clear"):
# Python 3 (PyPy3) doesn't have exc_clear # Python 3 (PyPy3) doesn't have exc_clear
sys.exc_clear() sys.exc_clear()
try: try:
try: try:
with _Mgr(): with _Mgr():
@ -107,4 +111,4 @@ except ImportError:
# Backwards compatibility as proposed in PEP 0519: # Backwards compatibility as proposed in PEP 0519:
# https://www.python.org/dev/peps/pep-0519/#backwards-compatibility # https://www.python.org/dev/peps/pep-0519/#backwards-compatibility
def fspath(path): def fspath(path):
return path.__fspath__() if hasattr(path, '__fspath__') else path return path.__fspath__() if hasattr(path, "__fspath__") else path

View file

@ -18,10 +18,15 @@ from itertools import chain
from threading import Lock from threading import Lock
from werkzeug.datastructures import Headers, ImmutableDict from werkzeug.datastructures import Headers, ImmutableDict
from werkzeug.exceptions import BadRequest, BadRequestKeyError, HTTPException, \ from werkzeug.exceptions import (
InternalServerError, MethodNotAllowed, default_exceptions BadRequest,
from werkzeug.routing import BuildError, Map, RequestRedirect, \ BadRequestKeyError,
RoutingException, Rule HTTPException,
InternalServerError,
MethodNotAllowed,
default_exceptions,
)
from werkzeug.routing import BuildError, Map, RequestRedirect, RoutingException, Rule
from . import cli, json from . import cli, json
from ._compat import integer_types, reraise, string_types, text_type from ._compat import integer_types, reraise, string_types, text_type
@ -30,15 +35,29 @@ from .ctx import AppContext, RequestContext, _AppCtxGlobals
from .globals import _request_ctx_stack, g, request, session from .globals import _request_ctx_stack, g, request, session
from .helpers import ( from .helpers import (
_PackageBoundObject, _PackageBoundObject,
_endpoint_from_view_func, find_package, get_env, get_debug_flag, _endpoint_from_view_func,
get_flashed_messages, locked_cached_property, url_for, get_load_dotenv find_package,
get_env,
get_debug_flag,
get_flashed_messages,
locked_cached_property,
url_for,
get_load_dotenv,
) )
from .logging import create_logger from .logging import create_logger
from .sessions import SecureCookieSessionInterface from .sessions import SecureCookieSessionInterface
from .signals import appcontext_tearing_down, got_request_exception, \ from .signals import (
request_finished, request_started, request_tearing_down appcontext_tearing_down,
from .templating import DispatchingJinjaLoader, Environment, \ got_request_exception,
_default_template_ctx_processor request_finished,
request_started,
request_tearing_down,
)
from .templating import (
DispatchingJinjaLoader,
Environment,
_default_template_ctx_processor,
)
from .wrappers import Request, Response from .wrappers import Request, Response
# a singleton sentinel value for parameter defaults # a singleton sentinel value for parameter defaults
@ -55,16 +74,20 @@ def setupmethod(f):
"""Wraps a method so that it performs a check in debug mode if the """Wraps a method so that it performs a check in debug mode if the
first request was already handled. first request was already handled.
""" """
def wrapper_func(self, *args, **kwargs): def wrapper_func(self, *args, **kwargs):
if self.debug and self._got_first_request: if self.debug and self._got_first_request:
raise AssertionError('A setup function was called after the ' raise AssertionError(
'first request was handled. This usually indicates a bug ' "A setup function was called after the "
'in the application where a module was not imported ' "first request was handled. This usually indicates a bug "
'and decorators or other functionality was called too late.\n' "in the application where a module was not imported "
'To fix this make sure to import all your view modules, ' "and decorators or other functionality was called too late.\n"
'database models and everything related at a central place ' "To fix this make sure to import all your view modules, "
'before the application starts serving requests.') "database models and everything related at a central place "
"before the application starts serving requests."
)
return f(self, *args, **kwargs) return f(self, *args, **kwargs)
return update_wrapper(wrapper_func, f) return update_wrapper(wrapper_func, f)
@ -217,7 +240,7 @@ class Flask(_PackageBoundObject):
#: #:
#: This attribute can also be configured from the config with the #: This attribute can also be configured from the config with the
#: ``TESTING`` configuration key. Defaults to ``False``. #: ``TESTING`` configuration key. Defaults to ``False``.
testing = ConfigAttribute('TESTING') testing = ConfigAttribute("TESTING")
#: If a secret key is set, cryptographic components can use this to #: If a secret key is set, cryptographic components can use this to
#: sign cookies and other things. Set this to a complex random value #: sign cookies and other things. Set this to a complex random value
@ -225,13 +248,13 @@ class Flask(_PackageBoundObject):
#: #:
#: This attribute can also be configured from the config with the #: This attribute can also be configured from the config with the
#: :data:`SECRET_KEY` configuration key. Defaults to ``None``. #: :data:`SECRET_KEY` configuration key. Defaults to ``None``.
secret_key = ConfigAttribute('SECRET_KEY') secret_key = ConfigAttribute("SECRET_KEY")
#: The secure cookie uses this for the name of the session cookie. #: The secure cookie uses this for the name of the session cookie.
#: #:
#: This attribute can also be configured from the config with the #: This attribute can also be configured from the config with the
#: ``SESSION_COOKIE_NAME`` configuration key. Defaults to ``'session'`` #: ``SESSION_COOKIE_NAME`` configuration key. Defaults to ``'session'``
session_cookie_name = ConfigAttribute('SESSION_COOKIE_NAME') session_cookie_name = ConfigAttribute("SESSION_COOKIE_NAME")
#: A :class:`~datetime.timedelta` which is used to set the expiration #: A :class:`~datetime.timedelta` which is used to set the expiration
#: date of a permanent session. The default is 31 days which makes a #: date of a permanent session. The default is 31 days which makes a
@ -240,8 +263,9 @@ class Flask(_PackageBoundObject):
#: This attribute can also be configured from the config with the #: This attribute can also be configured from the config with the
#: ``PERMANENT_SESSION_LIFETIME`` configuration key. Defaults to #: ``PERMANENT_SESSION_LIFETIME`` configuration key. Defaults to
#: ``timedelta(days=31)`` #: ``timedelta(days=31)``
permanent_session_lifetime = ConfigAttribute('PERMANENT_SESSION_LIFETIME', permanent_session_lifetime = ConfigAttribute(
get_converter=_make_timedelta) "PERMANENT_SESSION_LIFETIME", get_converter=_make_timedelta
)
#: A :class:`~datetime.timedelta` which is used as default cache_timeout #: A :class:`~datetime.timedelta` which is used as default cache_timeout
#: for the :func:`send_file` functions. The default is 12 hours. #: for the :func:`send_file` functions. The default is 12 hours.
@ -250,8 +274,9 @@ class Flask(_PackageBoundObject):
#: ``SEND_FILE_MAX_AGE_DEFAULT`` configuration key. This configuration #: ``SEND_FILE_MAX_AGE_DEFAULT`` configuration key. This configuration
#: variable can also be set with an integer value used as seconds. #: variable can also be set with an integer value used as seconds.
#: Defaults to ``timedelta(hours=12)`` #: Defaults to ``timedelta(hours=12)``
send_file_max_age_default = ConfigAttribute('SEND_FILE_MAX_AGE_DEFAULT', send_file_max_age_default = ConfigAttribute(
get_converter=_make_timedelta) "SEND_FILE_MAX_AGE_DEFAULT", get_converter=_make_timedelta
)
#: Enable this if you want to use the X-Sendfile feature. Keep in #: Enable this if you want to use the X-Sendfile feature. Keep in
#: mind that the server has to support this. This only affects files #: mind that the server has to support this. This only affects files
@ -261,7 +286,7 @@ class Flask(_PackageBoundObject):
#: #:
#: This attribute can also be configured from the config with the #: This attribute can also be configured from the config with the
#: ``USE_X_SENDFILE`` configuration key. Defaults to ``False``. #: ``USE_X_SENDFILE`` configuration key. Defaults to ``False``.
use_x_sendfile = ConfigAttribute('USE_X_SENDFILE') use_x_sendfile = ConfigAttribute("USE_X_SENDFILE")
#: The JSON encoder class to use. Defaults to :class:`~flask.json.JSONEncoder`. #: The JSON encoder class to use. Defaults to :class:`~flask.json.JSONEncoder`.
#: #:
@ -275,41 +300,43 @@ class Flask(_PackageBoundObject):
#: Options that are passed directly to the Jinja2 environment. #: Options that are passed directly to the Jinja2 environment.
jinja_options = ImmutableDict( jinja_options = ImmutableDict(
extensions=['jinja2.ext.autoescape', 'jinja2.ext.with_'] extensions=["jinja2.ext.autoescape", "jinja2.ext.with_"]
) )
#: Default configuration parameters. #: Default configuration parameters.
default_config = ImmutableDict({ default_config = ImmutableDict(
'ENV': None, {
'DEBUG': None, "ENV": None,
'TESTING': False, "DEBUG": None,
'PROPAGATE_EXCEPTIONS': None, "TESTING": False,
'PRESERVE_CONTEXT_ON_EXCEPTION': None, "PROPAGATE_EXCEPTIONS": None,
'SECRET_KEY': None, "PRESERVE_CONTEXT_ON_EXCEPTION": None,
'PERMANENT_SESSION_LIFETIME': timedelta(days=31), "SECRET_KEY": None,
'USE_X_SENDFILE': False, "PERMANENT_SESSION_LIFETIME": timedelta(days=31),
'SERVER_NAME': None, "USE_X_SENDFILE": False,
'APPLICATION_ROOT': '/', "SERVER_NAME": None,
'SESSION_COOKIE_NAME': 'session', "APPLICATION_ROOT": "/",
'SESSION_COOKIE_DOMAIN': None, "SESSION_COOKIE_NAME": "session",
'SESSION_COOKIE_PATH': None, "SESSION_COOKIE_DOMAIN": None,
'SESSION_COOKIE_HTTPONLY': True, "SESSION_COOKIE_PATH": None,
'SESSION_COOKIE_SECURE': False, "SESSION_COOKIE_HTTPONLY": True,
'SESSION_COOKIE_SAMESITE': None, "SESSION_COOKIE_SECURE": False,
'SESSION_REFRESH_EACH_REQUEST': True, "SESSION_COOKIE_SAMESITE": None,
'MAX_CONTENT_LENGTH': None, "SESSION_REFRESH_EACH_REQUEST": True,
'SEND_FILE_MAX_AGE_DEFAULT': timedelta(hours=12), "MAX_CONTENT_LENGTH": None,
'TRAP_BAD_REQUEST_ERRORS': None, "SEND_FILE_MAX_AGE_DEFAULT": timedelta(hours=12),
'TRAP_HTTP_EXCEPTIONS': False, "TRAP_BAD_REQUEST_ERRORS": None,
'EXPLAIN_TEMPLATE_LOADING': False, "TRAP_HTTP_EXCEPTIONS": False,
'PREFERRED_URL_SCHEME': 'http', "EXPLAIN_TEMPLATE_LOADING": False,
'JSON_AS_ASCII': True, "PREFERRED_URL_SCHEME": "http",
'JSON_SORT_KEYS': True, "JSON_AS_ASCII": True,
'JSONIFY_PRETTYPRINT_REGULAR': False, "JSON_SORT_KEYS": True,
'JSONIFY_MIMETYPE': 'application/json', "JSONIFY_PRETTYPRINT_REGULAR": False,
'TEMPLATES_AUTO_RELOAD': None, "JSONIFY_MIMETYPE": "application/json",
'MAX_COOKIE_SIZE': 4093, "TEMPLATES_AUTO_RELOAD": None,
}) "MAX_COOKIE_SIZE": 4093,
}
)
#: The rule object to use for URL rules created. This is used by #: The rule object to use for URL rules created. This is used by
#: :meth:`add_url_rule`. Defaults to :class:`werkzeug.routing.Rule`. #: :meth:`add_url_rule`. Defaults to :class:`werkzeug.routing.Rule`.
@ -355,20 +382,17 @@ class Flask(_PackageBoundObject):
self, self,
import_name, import_name,
static_url_path=None, static_url_path=None,
static_folder='static', static_folder="static",
static_host=None, static_host=None,
host_matching=False, host_matching=False,
subdomain_matching=False, subdomain_matching=False,
template_folder='templates', template_folder="templates",
instance_path=None, instance_path=None,
instance_relative_config=False, instance_relative_config=False,
root_path=None root_path=None,
): ):
_PackageBoundObject.__init__( _PackageBoundObject.__init__(
self, self, import_name, template_folder=template_folder, root_path=root_path
import_name,
template_folder=template_folder,
root_path=root_path
) )
if static_url_path is not None: if static_url_path is not None:
@ -381,8 +405,8 @@ class Flask(_PackageBoundObject):
instance_path = self.auto_find_instance_path() instance_path = self.auto_find_instance_path()
elif not os.path.isabs(instance_path): elif not os.path.isabs(instance_path):
raise ValueError( raise ValueError(
'If an instance path is provided it must be absolute.' "If an instance path is provided it must be absolute."
' A relative path was given instead.' " A relative path was given instead."
) )
#: Holds the path to the instance folder. #: Holds the path to the instance folder.
@ -490,9 +514,7 @@ class Flask(_PackageBoundObject):
#: requests. Each returns a dictionary that the template context is #: requests. Each returns a dictionary that the template context is
#: updated with. To register a function here, use the #: updated with. To register a function here, use the
#: :meth:`context_processor` decorator. #: :meth:`context_processor` decorator.
self.template_context_processors = { self.template_context_processors = {None: [_default_template_ctx_processor]}
None: [_default_template_ctx_processor]
}
#: A list of shell context processor functions that should be run #: A list of shell context processor functions that should be run
#: when a shell context is created. #: when a shell context is created.
@ -555,12 +577,14 @@ class Flask(_PackageBoundObject):
# For one, it might be created while the server is running (e.g. during # For one, it might be created while the server is running (e.g. during
# development). Also, Google App Engine stores static files somewhere # development). Also, Google App Engine stores static files somewhere
if self.has_static_folder: if self.has_static_folder:
assert bool(static_host) == host_matching, 'Invalid static_host/host_matching combination' assert (
bool(static_host) == host_matching
), "Invalid static_host/host_matching combination"
self.add_url_rule( self.add_url_rule(
self.static_url_path + '/<path:filename>', self.static_url_path + "/<path:filename>",
endpoint='static', endpoint="static",
host=static_host, host=static_host,
view_func=self.send_static_file view_func=self.send_static_file,
) )
#: The click command line context for this application. Commands #: The click command line context for this application. Commands
@ -581,10 +605,10 @@ class Flask(_PackageBoundObject):
.. versionadded:: 0.8 .. versionadded:: 0.8
""" """
if self.import_name == '__main__': if self.import_name == "__main__":
fn = getattr(sys.modules['__main__'], '__file__', None) fn = getattr(sys.modules["__main__"], "__file__", None)
if fn is None: if fn is None:
return '__main__' return "__main__"
return os.path.splitext(os.path.basename(fn))[0] return os.path.splitext(os.path.basename(fn))[0]
return self.import_name return self.import_name
@ -595,7 +619,7 @@ class Flask(_PackageBoundObject):
.. versionadded:: 0.7 .. versionadded:: 0.7
""" """
rv = self.config['PROPAGATE_EXCEPTIONS'] rv = self.config["PROPAGATE_EXCEPTIONS"]
if rv is not None: if rv is not None:
return rv return rv
return self.testing or self.debug return self.testing or self.debug
@ -608,7 +632,7 @@ class Flask(_PackageBoundObject):
.. versionadded:: 0.7 .. versionadded:: 0.7
""" """
rv = self.config['PRESERVE_CONTEXT_ON_EXCEPTION'] rv = self.config["PRESERVE_CONTEXT_ON_EXCEPTION"]
if rv is not None: if rv is not None:
return rv return rv
return self.debug return self.debug
@ -663,8 +687,8 @@ class Flask(_PackageBoundObject):
if instance_relative: if instance_relative:
root_path = self.instance_path root_path = self.instance_path
defaults = dict(self.default_config) defaults = dict(self.default_config)
defaults['ENV'] = get_env() defaults["ENV"] = get_env()
defaults['DEBUG'] = get_debug_flag() defaults["DEBUG"] = get_debug_flag()
return self.config_class(root_path, defaults) return self.config_class(root_path, defaults)
def auto_find_instance_path(self): def auto_find_instance_path(self):
@ -677,10 +701,10 @@ class Flask(_PackageBoundObject):
""" """
prefix, package_path = find_package(self.import_name) prefix, package_path = find_package(self.import_name)
if prefix is None: if prefix is None:
return os.path.join(package_path, 'instance') return os.path.join(package_path, "instance")
return os.path.join(prefix, 'var', self.name + '-instance') return os.path.join(prefix, "var", self.name + "-instance")
def open_instance_resource(self, resource, mode='rb'): def open_instance_resource(self, resource, mode="rb"):
"""Opens a resource from the application's instance folder """Opens a resource from the application's instance folder
(:attr:`instance_path`). Otherwise works like (:attr:`instance_path`). Otherwise works like
:meth:`open_resource`. Instance resources can also be opened for :meth:`open_resource`. Instance resources can also be opened for
@ -703,11 +727,11 @@ class Flask(_PackageBoundObject):
This property was added but the underlying config and behavior This property was added but the underlying config and behavior
already existed. already existed.
""" """
rv = self.config['TEMPLATES_AUTO_RELOAD'] rv = self.config["TEMPLATES_AUTO_RELOAD"]
return rv if rv is not None else self.debug return rv if rv is not None else self.debug
def _set_templates_auto_reload(self, value): def _set_templates_auto_reload(self, value):
self.config['TEMPLATES_AUTO_RELOAD'] = value self.config["TEMPLATES_AUTO_RELOAD"] = value
templates_auto_reload = property( templates_auto_reload = property(
_get_templates_auto_reload, _set_templates_auto_reload _get_templates_auto_reload, _set_templates_auto_reload
@ -727,11 +751,11 @@ class Flask(_PackageBoundObject):
""" """
options = dict(self.jinja_options) options = dict(self.jinja_options)
if 'autoescape' not in options: if "autoescape" not in options:
options['autoescape'] = self.select_jinja_autoescape options["autoescape"] = self.select_jinja_autoescape
if 'auto_reload' not in options: if "auto_reload" not in options:
options['auto_reload'] = self.templates_auto_reload options["auto_reload"] = self.templates_auto_reload
rv = self.jinja_environment(self, **options) rv = self.jinja_environment(self, **options)
rv.globals.update( rv.globals.update(
@ -743,9 +767,9 @@ class Flask(_PackageBoundObject):
# templates we also want the proxies in there. # templates we also want the proxies in there.
request=request, request=request,
session=session, session=session,
g=g g=g,
) )
rv.filters['tojson'] = json.tojson_filter rv.filters["tojson"] = json.tojson_filter
return rv return rv
def create_global_jinja_loader(self): def create_global_jinja_loader(self):
@ -769,7 +793,7 @@ class Flask(_PackageBoundObject):
""" """
if filename is None: if filename is None:
return True return True
return filename.endswith(('.html', '.htm', '.xml', '.xhtml')) return filename.endswith((".html", ".htm", ".xml", ".xhtml"))
def update_template_context(self, context): def update_template_context(self, context):
"""Update the template context with some commonly used variables. """Update the template context with some commonly used variables.
@ -803,7 +827,7 @@ class Flask(_PackageBoundObject):
.. versionadded:: 0.11 .. versionadded:: 0.11
""" """
rv = {'app': self, 'g': g} rv = {"app": self, "g": g}
for processor in self.shell_context_processors: for processor in self.shell_context_processors:
rv.update(processor()) rv.update(processor())
return rv return rv
@ -817,13 +841,13 @@ class Flask(_PackageBoundObject):
#: **Do not enable development when deploying in production.** #: **Do not enable development when deploying in production.**
#: #:
#: Default: ``'production'`` #: Default: ``'production'``
env = ConfigAttribute('ENV') env = ConfigAttribute("ENV")
def _get_debug(self): def _get_debug(self):
return self.config['DEBUG'] return self.config["DEBUG"]
def _set_debug(self, value): def _set_debug(self, value):
self.config['DEBUG'] = value self.config["DEBUG"] = value
self.jinja_env.auto_reload = self.templates_auto_reload self.jinja_env.auto_reload = self.templates_auto_reload
#: Whether debug mode is enabled. When using ``flask run`` to start #: Whether debug mode is enabled. When using ``flask run`` to start
@ -841,8 +865,7 @@ class Flask(_PackageBoundObject):
debug = property(_get_debug, _set_debug) debug = property(_get_debug, _set_debug)
del _get_debug, _set_debug del _get_debug, _set_debug
def run(self, host=None, port=None, debug=None, def run(self, host=None, port=None, debug=None, load_dotenv=True, **options):
load_dotenv=True, **options):
"""Runs the application on a local development server. """Runs the application on a local development server.
Do not use ``run()`` in a production setting. It is not intended to Do not use ``run()`` in a production setting. It is not intended to
@ -902,8 +925,9 @@ class Flask(_PackageBoundObject):
""" """
# Change this into a no-op if the server is invoked from the # Change this into a no-op if the server is invoked from the
# command line. Have a look at cli.py for more information. # command line. Have a look at cli.py for more information.
if os.environ.get('FLASK_RUN_FROM_CLI') == 'true': if os.environ.get("FLASK_RUN_FROM_CLI") == "true":
from .debughelpers import explain_ignored_app_run from .debughelpers import explain_ignored_app_run
explain_ignored_app_run() explain_ignored_app_run()
return return
@ -911,30 +935,30 @@ class Flask(_PackageBoundObject):
cli.load_dotenv() cli.load_dotenv()
# if set, let env vars override previous values # if set, let env vars override previous values
if 'FLASK_ENV' in os.environ: if "FLASK_ENV" in os.environ:
self.env = get_env() self.env = get_env()
self.debug = get_debug_flag() self.debug = get_debug_flag()
elif 'FLASK_DEBUG' in os.environ: elif "FLASK_DEBUG" in os.environ:
self.debug = get_debug_flag() self.debug = get_debug_flag()
# debug passed to method overrides all other sources # debug passed to method overrides all other sources
if debug is not None: if debug is not None:
self.debug = bool(debug) self.debug = bool(debug)
_host = '127.0.0.1' _host = "127.0.0.1"
_port = 5000 _port = 5000
server_name = self.config.get('SERVER_NAME') server_name = self.config.get("SERVER_NAME")
sn_host, sn_port = None, None sn_host, sn_port = None, None
if server_name: if server_name:
sn_host, _, sn_port = server_name.partition(':') sn_host, _, sn_port = server_name.partition(":")
host = host or sn_host or _host host = host or sn_host or _host
port = int(port or sn_port or _port) port = int(port or sn_port or _port)
options.setdefault('use_reloader', self.debug) options.setdefault("use_reloader", self.debug)
options.setdefault('use_debugger', self.debug) options.setdefault("use_debugger", self.debug)
options.setdefault('threaded', True) options.setdefault("threaded", True)
cli.show_server_banner(self.env, self.debug, self.name, False) cli.show_server_banner(self.env, self.debug, self.name, False)
@ -1034,10 +1058,12 @@ class Flask(_PackageBoundObject):
:param request: an instance of :attr:`request_class`. :param request: an instance of :attr:`request_class`.
""" """
warnings.warn(DeprecationWarning( warnings.warn(
'"open_session" is deprecated and will be removed in 1.1. Use' DeprecationWarning(
' "session_interface.open_session" instead.' '"open_session" is deprecated and will be removed in 1.1. Use'
)) ' "session_interface.open_session" instead.'
)
)
return self.session_interface.open_session(self, request) return self.session_interface.open_session(self, request)
def save_session(self, session, response): def save_session(self, session, response):
@ -1055,10 +1081,12 @@ class Flask(_PackageBoundObject):
:param response: an instance of :attr:`response_class` :param response: an instance of :attr:`response_class`
""" """
warnings.warn(DeprecationWarning( warnings.warn(
'"save_session" is deprecated and will be removed in 1.1. Use' DeprecationWarning(
' "session_interface.save_session" instead.' '"save_session" is deprecated and will be removed in 1.1. Use'
)) ' "session_interface.save_session" instead.'
)
)
return self.session_interface.save_session(self, session, response) return self.session_interface.save_session(self, session, response)
def make_null_session(self): def make_null_session(self):
@ -1072,10 +1100,12 @@ class Flask(_PackageBoundObject):
.. versionadded:: 0.7 .. versionadded:: 0.7
""" """
warnings.warn(DeprecationWarning( warnings.warn(
'"make_null_session" is deprecated and will be removed in 1.1. Use' DeprecationWarning(
' "session_interface.make_null_session" instead.' '"make_null_session" is deprecated and will be removed in 1.1. Use'
)) ' "session_interface.make_null_session" instead.'
)
)
return self.session_interface.make_null_session(self) return self.session_interface.make_null_session(self)
@setupmethod @setupmethod
@ -1102,11 +1132,10 @@ class Flask(_PackageBoundObject):
if blueprint.name in self.blueprints: if blueprint.name in self.blueprints:
assert self.blueprints[blueprint.name] is blueprint, ( assert self.blueprints[blueprint.name] is blueprint, (
'A name collision occurred between blueprints %r and %r. Both' "A name collision occurred between blueprints %r and %r. Both"
' share the same name "%s". Blueprints that are created on the' ' share the same name "%s". Blueprints that are created on the'
' fly need unique names.' % ( " fly need unique names."
blueprint, self.blueprints[blueprint.name], blueprint.name % (blueprint, self.blueprints[blueprint.name], blueprint.name)
)
) )
else: else:
self.blueprints[blueprint.name] = blueprint self.blueprints[blueprint.name] = blueprint
@ -1123,8 +1152,14 @@ class Flask(_PackageBoundObject):
return iter(self._blueprint_order) return iter(self._blueprint_order)
@setupmethod @setupmethod
def add_url_rule(self, rule, endpoint=None, view_func=None, def add_url_rule(
provide_automatic_options=None, **options): self,
rule,
endpoint=None,
view_func=None,
provide_automatic_options=None,
**options
):
"""Connects a URL rule. Works exactly like the :meth:`route` """Connects a URL rule. Works exactly like the :meth:`route`
decorator. If a view_func is provided it will be registered with the decorator. If a view_func is provided it will be registered with the
endpoint. endpoint.
@ -1179,32 +1214,35 @@ class Flask(_PackageBoundObject):
""" """
if endpoint is None: if endpoint is None:
endpoint = _endpoint_from_view_func(view_func) endpoint = _endpoint_from_view_func(view_func)
options['endpoint'] = endpoint options["endpoint"] = endpoint
methods = options.pop('methods', None) methods = options.pop("methods", None)
# if the methods are not given and the view_func object knows its # if the methods are not given and the view_func object knows its
# methods we can use that instead. If neither exists, we go with # methods we can use that instead. If neither exists, we go with
# a tuple of only ``GET`` as default. # a tuple of only ``GET`` as default.
if methods is None: if methods is None:
methods = getattr(view_func, 'methods', None) or ('GET',) methods = getattr(view_func, "methods", None) or ("GET",)
if isinstance(methods, string_types): if isinstance(methods, string_types):
raise TypeError('Allowed methods have to be iterables of strings, ' raise TypeError(
'for example: @app.route(..., methods=["POST"])') "Allowed methods have to be iterables of strings, "
'for example: @app.route(..., methods=["POST"])'
)
methods = set(item.upper() for item in methods) methods = set(item.upper() for item in methods)
# Methods that should always be added # Methods that should always be added
required_methods = set(getattr(view_func, 'required_methods', ())) required_methods = set(getattr(view_func, "required_methods", ()))
# starting with Flask 0.8 the view_func object can disable and # starting with Flask 0.8 the view_func object can disable and
# force-enable the automatic options handling. # force-enable the automatic options handling.
if provide_automatic_options is None: if provide_automatic_options is None:
provide_automatic_options = getattr(view_func, provide_automatic_options = getattr(
'provide_automatic_options', None) view_func, "provide_automatic_options", None
)
if provide_automatic_options is None: if provide_automatic_options is None:
if 'OPTIONS' not in methods: if "OPTIONS" not in methods:
provide_automatic_options = True provide_automatic_options = True
required_methods.add('OPTIONS') required_methods.add("OPTIONS")
else: else:
provide_automatic_options = False provide_automatic_options = False
@ -1218,8 +1256,10 @@ class Flask(_PackageBoundObject):
if view_func is not None: if view_func is not None:
old_func = self.view_functions.get(endpoint) old_func = self.view_functions.get(endpoint)
if old_func is not None and old_func != view_func: if old_func is not None and old_func != view_func:
raise AssertionError('View function mapping is overwriting an ' raise AssertionError(
'existing endpoint function: %s' % endpoint) "View function mapping is overwriting an "
"existing endpoint function: %s" % endpoint
)
self.view_functions[endpoint] = view_func self.view_functions[endpoint] = view_func
def route(self, rule, **options): def route(self, rule, **options):
@ -1246,10 +1286,12 @@ class Flask(_PackageBoundObject):
Starting with Flask 0.6, ``OPTIONS`` is implicitly Starting with Flask 0.6, ``OPTIONS`` is implicitly
added and handled by the standard request handling. added and handled by the standard request handling.
""" """
def decorator(f): def decorator(f):
endpoint = options.pop('endpoint', None) endpoint = options.pop("endpoint", None)
self.add_url_rule(rule, endpoint, f, **options) self.add_url_rule(rule, endpoint, f, **options)
return f return f
return decorator return decorator
@setupmethod @setupmethod
@ -1263,9 +1305,11 @@ class Flask(_PackageBoundObject):
:param endpoint: the name of the endpoint :param endpoint: the name of the endpoint
""" """
def decorator(f): def decorator(f):
self.view_functions[endpoint] = f self.view_functions[endpoint] = f
return f return f
return decorator return decorator
@staticmethod @staticmethod
@ -1313,9 +1357,11 @@ class Flask(_PackageBoundObject):
:param code_or_exception: the code as integer for the handler, or :param code_or_exception: the code as integer for the handler, or
an arbitrary exception an arbitrary exception
""" """
def decorator(f): def decorator(f):
self._register_error_handler(None, code_or_exception, f) self._register_error_handler(None, code_or_exception, f)
return f return f
return decorator return decorator
@setupmethod @setupmethod
@ -1337,9 +1383,9 @@ class Flask(_PackageBoundObject):
""" """
if isinstance(code_or_exception, HTTPException): # old broken behavior if isinstance(code_or_exception, HTTPException): # old broken behavior
raise ValueError( raise ValueError(
'Tried to register a handler for an exception instance {0!r}.' "Tried to register a handler for an exception instance {0!r}."
' Handlers can only be registered for exception classes or' " Handlers can only be registered for exception classes or"
' HTTP error codes.'.format(code_or_exception) " HTTP error codes.".format(code_or_exception)
) )
try: try:
@ -1366,9 +1412,11 @@ class Flask(_PackageBoundObject):
:param name: the optional name of the filter, otherwise the :param name: the optional name of the filter, otherwise the
function name will be used. function name will be used.
""" """
def decorator(f): def decorator(f):
self.add_template_filter(f, name=name) self.add_template_filter(f, name=name)
return f return f
return decorator return decorator
@setupmethod @setupmethod
@ -1401,9 +1449,11 @@ class Flask(_PackageBoundObject):
:param name: the optional name of the test, otherwise the :param name: the optional name of the test, otherwise the
function name will be used. function name will be used.
""" """
def decorator(f): def decorator(f):
self.add_template_test(f, name=name) self.add_template_test(f, name=name)
return f return f
return decorator return decorator
@setupmethod @setupmethod
@ -1433,9 +1483,11 @@ class Flask(_PackageBoundObject):
:param name: the optional name of the global function, otherwise the :param name: the optional name of the global function, otherwise the
function name will be used. function name will be used.
""" """
def decorator(f): def decorator(f):
self.add_template_global(f, name=name) self.add_template_global(f, name=name)
return f return f
return decorator return decorator
@setupmethod @setupmethod
@ -1613,8 +1665,10 @@ class Flask(_PackageBoundObject):
exc_class, code = self._get_exc_class_and_code(type(e)) exc_class, code = self._get_exc_class_and_code(type(e))
for name, c in ( for name, c in (
(request.blueprint, code), (None, code), (request.blueprint, code),
(request.blueprint, None), (None, None) (None, code),
(request.blueprint, None),
(None, None),
): ):
handler_map = self.error_handler_spec.setdefault(name, {}).get(c) handler_map = self.error_handler_spec.setdefault(name, {}).get(c)
@ -1677,14 +1731,15 @@ class Flask(_PackageBoundObject):
.. versionadded:: 0.8 .. versionadded:: 0.8
""" """
if self.config['TRAP_HTTP_EXCEPTIONS']: if self.config["TRAP_HTTP_EXCEPTIONS"]:
return True return True
trap_bad_request = self.config['TRAP_BAD_REQUEST_ERRORS'] trap_bad_request = self.config["TRAP_BAD_REQUEST_ERRORS"]
# if unset, trap key errors in debug mode # if unset, trap key errors in debug mode
if ( if (
trap_bad_request is None and self.debug trap_bad_request is None
and self.debug
and isinstance(e, BadRequestKeyError) and isinstance(e, BadRequestKeyError)
): ):
return True return True
@ -1773,10 +1828,9 @@ class Flask(_PackageBoundObject):
.. versionadded:: 0.8 .. versionadded:: 0.8
""" """
self.logger.error('Exception on %s [%s]' % ( self.logger.error(
request.path, "Exception on %s [%s]" % (request.path, request.method), exc_info=exc_info
request.method )
), exc_info=exc_info)
def raise_routing_exception(self, request): def raise_routing_exception(self, request):
"""Exceptions that are recording during routing are reraised with """Exceptions that are recording during routing are reraised with
@ -1786,12 +1840,15 @@ class Flask(_PackageBoundObject):
:internal: :internal:
""" """
if not self.debug \ if (
or not isinstance(request.routing_exception, RequestRedirect) \ not self.debug
or request.method in ('GET', 'HEAD', 'OPTIONS'): or not isinstance(request.routing_exception, RequestRedirect)
or request.method in ("GET", "HEAD", "OPTIONS")
):
raise request.routing_exception raise request.routing_exception
from .debughelpers import FormDataRoutingRedirect from .debughelpers import FormDataRoutingRedirect
raise FormDataRoutingRedirect(request) raise FormDataRoutingRedirect(request)
def dispatch_request(self): def dispatch_request(self):
@ -1810,8 +1867,10 @@ class Flask(_PackageBoundObject):
rule = req.url_rule rule = req.url_rule
# if we provide automatic options for this URL and the # if we provide automatic options for this URL and the
# request came with the OPTIONS method, reply automatically # request came with the OPTIONS method, reply automatically
if getattr(rule, 'provide_automatic_options', False) \ if (
and req.method == 'OPTIONS': getattr(rule, "provide_automatic_options", False)
and req.method == "OPTIONS"
):
return self.make_default_options_response() return self.make_default_options_response()
# otherwise dispatch to the handler for that endpoint # otherwise dispatch to the handler for that endpoint
return self.view_functions[rule.endpoint](**req.view_args) return self.view_functions[rule.endpoint](**req.view_args)
@ -1853,8 +1912,9 @@ class Flask(_PackageBoundObject):
except Exception: except Exception:
if not from_error_handler: if not from_error_handler:
raise raise
self.logger.exception('Request finalizing failed with an ' self.logger.exception(
'error while handling an error') "Request finalizing failed with an " "error while handling an error"
)
return response return response
def try_trigger_before_first_request_functions(self): def try_trigger_before_first_request_functions(self):
@ -1881,13 +1941,13 @@ class Flask(_PackageBoundObject):
.. versionadded:: 0.7 .. versionadded:: 0.7
""" """
adapter = _request_ctx_stack.top.url_adapter adapter = _request_ctx_stack.top.url_adapter
if hasattr(adapter, 'allowed_methods'): if hasattr(adapter, "allowed_methods"):
methods = adapter.allowed_methods() methods = adapter.allowed_methods()
else: else:
# fallback for Werkzeug < 0.7 # fallback for Werkzeug < 0.7
methods = [] methods = []
try: try:
adapter.match(method='--') adapter.match(method="--")
except MethodNotAllowed as e: except MethodNotAllowed as e:
methods = e.valid_methods methods = e.valid_methods
except HTTPException as e: except HTTPException as e:
@ -1964,17 +2024,17 @@ class Flask(_PackageBoundObject):
# other sized tuples are not allowed # other sized tuples are not allowed
else: else:
raise TypeError( raise TypeError(
'The view function did not return a valid response tuple.' "The view function did not return a valid response tuple."
' The tuple must have the form (body, status, headers),' " The tuple must have the form (body, status, headers),"
' (body, status), or (body, headers).' " (body, status), or (body, headers)."
) )
# the body must not be None # the body must not be None
if rv is None: if rv is None:
raise TypeError( raise TypeError(
'The view function did not return a valid response. The' "The view function did not return a valid response. The"
' function either returned None or ended without a return' " function either returned None or ended without a return"
' statement.' " statement."
) )
# make sure the body is an instance of the response class # make sure the body is an instance of the response class
@ -1992,10 +2052,10 @@ class Flask(_PackageBoundObject):
rv = self.response_class.force_type(rv, request.environ) rv = self.response_class.force_type(rv, request.environ)
except TypeError as e: except TypeError as e:
new_error = TypeError( new_error = TypeError(
'{e}\nThe view function did not return a valid' "{e}\nThe view function did not return a valid"
' response. The return type must be a string, tuple,' " response. The return type must be a string, tuple,"
' Response instance, or WSGI callable, but it was a' " Response instance, or WSGI callable, but it was a"
' {rv.__class__.__name__}.'.format(e=e, rv=rv) " {rv.__class__.__name__}.".format(e=e, rv=rv)
) )
reraise(TypeError, new_error, sys.exc_info()[2]) reraise(TypeError, new_error, sys.exc_info()[2])
@ -2031,19 +2091,24 @@ class Flask(_PackageBoundObject):
# If subdomain matching is disabled (the default), use the # If subdomain matching is disabled (the default), use the
# default subdomain in all cases. This should be the default # default subdomain in all cases. This should be the default
# in Werkzeug but it currently does not have that feature. # in Werkzeug but it currently does not have that feature.
subdomain = ((self.url_map.default_subdomain or None) subdomain = (
if not self.subdomain_matching else None) (self.url_map.default_subdomain or None)
if not self.subdomain_matching
else None
)
return self.url_map.bind_to_environ( return self.url_map.bind_to_environ(
request.environ, request.environ,
server_name=self.config['SERVER_NAME'], server_name=self.config["SERVER_NAME"],
subdomain=subdomain) subdomain=subdomain,
)
# We need at the very least the server name to be set for this # We need at the very least the server name to be set for this
# to work. # to work.
if self.config['SERVER_NAME'] is not None: if self.config["SERVER_NAME"] is not None:
return self.url_map.bind( return self.url_map.bind(
self.config['SERVER_NAME'], self.config["SERVER_NAME"],
script_name=self.config['APPLICATION_ROOT'], script_name=self.config["APPLICATION_ROOT"],
url_scheme=self.config['PREFERRED_URL_SCHEME']) url_scheme=self.config["PREFERRED_URL_SCHEME"],
)
def inject_url_defaults(self, endpoint, values): def inject_url_defaults(self, endpoint, values):
"""Injects the URL defaults for the given endpoint directly into """Injects the URL defaults for the given endpoint directly into
@ -2053,8 +2118,8 @@ class Flask(_PackageBoundObject):
.. versionadded:: 0.7 .. versionadded:: 0.7
""" """
funcs = self.url_default_functions.get(None, ()) funcs = self.url_default_functions.get(None, ())
if '.' in endpoint: if "." in endpoint:
bp = endpoint.rsplit('.', 1)[0] bp = endpoint.rsplit(".", 1)[0]
funcs = chain(funcs, self.url_default_functions.get(bp, ())) funcs = chain(funcs, self.url_default_functions.get(bp, ()))
for func in funcs: for func in funcs:
func(endpoint, values) func(endpoint, values)
@ -2327,7 +2392,4 @@ class Flask(_PackageBoundObject):
return self.wsgi_app(environ, start_response) return self.wsgi_app(environ, start_response)
def __repr__(self): def __repr__(self):
return '<%s %r>' % ( return "<%s %r>" % (self.__class__.__name__, self.name)
self.__class__.__name__,
self.name,
)

View file

@ -39,7 +39,7 @@ class BlueprintSetupState(object):
#: out if the blueprint was registered in the past already. #: out if the blueprint was registered in the past already.
self.first_registration = first_registration self.first_registration = first_registration
subdomain = self.options.get('subdomain') subdomain = self.options.get("subdomain")
if subdomain is None: if subdomain is None:
subdomain = self.blueprint.subdomain subdomain = self.blueprint.subdomain
@ -47,7 +47,7 @@ class BlueprintSetupState(object):
#: otherwise. #: otherwise.
self.subdomain = subdomain self.subdomain = subdomain
url_prefix = self.options.get('url_prefix') url_prefix = self.options.get("url_prefix")
if url_prefix is None: if url_prefix is None:
url_prefix = self.blueprint.url_prefix url_prefix = self.blueprint.url_prefix
#: The prefix that should be used for all URLs defined on the #: The prefix that should be used for all URLs defined on the
@ -57,7 +57,7 @@ class BlueprintSetupState(object):
#: A dictionary with URL defaults that is added to each and every #: A dictionary with URL defaults that is added to each and every
#: URL that was defined with the blueprint. #: URL that was defined with the blueprint.
self.url_defaults = dict(self.blueprint.url_values_defaults) self.url_defaults = dict(self.blueprint.url_values_defaults)
self.url_defaults.update(self.options.get('url_defaults', ())) self.url_defaults.update(self.options.get("url_defaults", ()))
def add_url_rule(self, rule, endpoint=None, view_func=None, **options): def add_url_rule(self, rule, endpoint=None, view_func=None, **options):
"""A helper method to register a rule (and optionally a view function) """A helper method to register a rule (and optionally a view function)
@ -66,18 +66,22 @@ class BlueprintSetupState(object):
""" """
if self.url_prefix is not None: if self.url_prefix is not None:
if rule: if rule:
rule = '/'.join(( rule = "/".join((self.url_prefix.rstrip("/"), rule.lstrip("/")))
self.url_prefix.rstrip('/'), rule.lstrip('/')))
else: else:
rule = self.url_prefix rule = self.url_prefix
options.setdefault('subdomain', self.subdomain) options.setdefault("subdomain", self.subdomain)
if endpoint is None: if endpoint is None:
endpoint = _endpoint_from_view_func(view_func) endpoint = _endpoint_from_view_func(view_func)
defaults = self.url_defaults defaults = self.url_defaults
if 'defaults' in options: if "defaults" in options:
defaults = dict(defaults, **options.pop('defaults')) defaults = dict(defaults, **options.pop("defaults"))
self.app.add_url_rule(rule, '%s.%s' % (self.blueprint.name, endpoint), self.app.add_url_rule(
view_func, defaults=defaults, **options) rule,
"%s.%s" % (self.blueprint.name, endpoint),
view_func,
defaults=defaults,
**options
)
class Blueprint(_PackageBoundObject): class Blueprint(_PackageBoundObject):
@ -115,12 +119,21 @@ class Blueprint(_PackageBoundObject):
#: resources contained in the package. #: resources contained in the package.
root_path = None root_path = None
def __init__(self, name, import_name, static_folder=None, def __init__(
static_url_path=None, template_folder=None, self,
url_prefix=None, subdomain=None, url_defaults=None, name,
root_path=None): import_name,
_PackageBoundObject.__init__(self, import_name, template_folder, static_folder=None,
root_path=root_path) static_url_path=None,
template_folder=None,
url_prefix=None,
subdomain=None,
url_defaults=None,
root_path=None,
):
_PackageBoundObject.__init__(
self, import_name, template_folder, root_path=root_path
)
self.name = name self.name = name
self.url_prefix = url_prefix self.url_prefix = url_prefix
self.subdomain = subdomain self.subdomain = subdomain
@ -139,9 +152,14 @@ class Blueprint(_PackageBoundObject):
""" """
if self._got_registered_once and self.warn_on_modifications: if self._got_registered_once and self.warn_on_modifications:
from warnings import warn from warnings import warn
warn(Warning('The blueprint was already registered once '
'but is getting modified now. These changes ' warn(
'will not show up.')) Warning(
"The blueprint was already registered once "
"but is getting modified now. These changes "
"will not show up."
)
)
self.deferred_functions.append(func) self.deferred_functions.append(func)
def record_once(self, func): def record_once(self, func):
@ -150,9 +168,11 @@ class Blueprint(_PackageBoundObject):
blueprint is registered a second time on the application, the blueprint is registered a second time on the application, the
function passed is not called. function passed is not called.
""" """
def wrapper(state): def wrapper(state):
if state.first_registration: if state.first_registration:
func(state) func(state)
return self.record(update_wrapper(wrapper, func)) return self.record(update_wrapper(wrapper, func))
def make_setup_state(self, app, options, first_registration=False): def make_setup_state(self, app, options, first_registration=False):
@ -179,8 +199,9 @@ class Blueprint(_PackageBoundObject):
if self.has_static_folder: if self.has_static_folder:
state.add_url_rule( state.add_url_rule(
self.static_url_path + '/<path:filename>', self.static_url_path + "/<path:filename>",
view_func=self.send_static_file, endpoint='static' view_func=self.send_static_file,
endpoint="static",
) )
for deferred in self.deferred_functions: for deferred in self.deferred_functions:
@ -190,10 +211,12 @@ class Blueprint(_PackageBoundObject):
"""Like :meth:`Flask.route` but for a blueprint. The endpoint for the """Like :meth:`Flask.route` but for a blueprint. The endpoint for the
:func:`url_for` function is prefixed with the name of the blueprint. :func:`url_for` function is prefixed with the name of the blueprint.
""" """
def decorator(f): def decorator(f):
endpoint = options.pop("endpoint", f.__name__) endpoint = options.pop("endpoint", f.__name__)
self.add_url_rule(rule, endpoint, f, **options) self.add_url_rule(rule, endpoint, f, **options)
return f return f
return decorator return decorator
def add_url_rule(self, rule, endpoint=None, view_func=None, **options): def add_url_rule(self, rule, endpoint=None, view_func=None, **options):
@ -201,11 +224,12 @@ class Blueprint(_PackageBoundObject):
the :func:`url_for` function is prefixed with the name of the blueprint. the :func:`url_for` function is prefixed with the name of the blueprint.
""" """
if endpoint: if endpoint:
assert '.' not in endpoint, "Blueprint endpoints should not contain dots" assert "." not in endpoint, "Blueprint endpoints should not contain dots"
if view_func and hasattr(view_func, '__name__'): if view_func and hasattr(view_func, "__name__"):
assert '.' not in view_func.__name__, "Blueprint view function name should not contain dots" assert (
self.record(lambda s: "." not in view_func.__name__
s.add_url_rule(rule, endpoint, view_func, **options)) ), "Blueprint view function name should not contain dots"
self.record(lambda s: s.add_url_rule(rule, endpoint, view_func, **options))
def endpoint(self, endpoint): def endpoint(self, endpoint):
"""Like :meth:`Flask.endpoint` but for a blueprint. This does not """Like :meth:`Flask.endpoint` but for a blueprint. This does not
@ -214,11 +238,14 @@ class Blueprint(_PackageBoundObject):
with a `.` it will be registered to the current blueprint, otherwise with a `.` it will be registered to the current blueprint, otherwise
it's an application independent endpoint. it's an application independent endpoint.
""" """
def decorator(f): def decorator(f):
def register_endpoint(state): def register_endpoint(state):
state.app.view_functions[endpoint] = f state.app.view_functions[endpoint] = f
self.record_once(register_endpoint) self.record_once(register_endpoint)
return f return f
return decorator return decorator
def app_template_filter(self, name=None): def app_template_filter(self, name=None):
@ -228,9 +255,11 @@ class Blueprint(_PackageBoundObject):
:param name: the optional name of the filter, otherwise the :param name: the optional name of the filter, otherwise the
function name will be used. function name will be used.
""" """
def decorator(f): def decorator(f):
self.add_app_template_filter(f, name=name) self.add_app_template_filter(f, name=name)
return f return f
return decorator return decorator
def add_app_template_filter(self, f, name=None): def add_app_template_filter(self, f, name=None):
@ -241,8 +270,10 @@ class Blueprint(_PackageBoundObject):
:param name: the optional name of the filter, otherwise the :param name: the optional name of the filter, otherwise the
function name will be used. function name will be used.
""" """
def register_template(state): def register_template(state):
state.app.jinja_env.filters[name or f.__name__] = f state.app.jinja_env.filters[name or f.__name__] = f
self.record_once(register_template) self.record_once(register_template)
def app_template_test(self, name=None): def app_template_test(self, name=None):
@ -254,9 +285,11 @@ class Blueprint(_PackageBoundObject):
:param name: the optional name of the test, otherwise the :param name: the optional name of the test, otherwise the
function name will be used. function name will be used.
""" """
def decorator(f): def decorator(f):
self.add_app_template_test(f, name=name) self.add_app_template_test(f, name=name)
return f return f
return decorator return decorator
def add_app_template_test(self, f, name=None): def add_app_template_test(self, f, name=None):
@ -269,8 +302,10 @@ class Blueprint(_PackageBoundObject):
:param name: the optional name of the test, otherwise the :param name: the optional name of the test, otherwise the
function name will be used. function name will be used.
""" """
def register_template(state): def register_template(state):
state.app.jinja_env.tests[name or f.__name__] = f state.app.jinja_env.tests[name or f.__name__] = f
self.record_once(register_template) self.record_once(register_template)
def app_template_global(self, name=None): def app_template_global(self, name=None):
@ -282,9 +317,11 @@ class Blueprint(_PackageBoundObject):
:param name: the optional name of the global, otherwise the :param name: the optional name of the global, otherwise the
function name will be used. function name will be used.
""" """
def decorator(f): def decorator(f):
self.add_app_template_global(f, name=name) self.add_app_template_global(f, name=name)
return f return f
return decorator return decorator
def add_app_template_global(self, f, name=None): def add_app_template_global(self, f, name=None):
@ -297,8 +334,10 @@ class Blueprint(_PackageBoundObject):
:param name: the optional name of the global, otherwise the :param name: the optional name of the global, otherwise the
function name will be used. function name will be used.
""" """
def register_template(state): def register_template(state):
state.app.jinja_env.globals[name or f.__name__] = f state.app.jinja_env.globals[name or f.__name__] = f
self.record_once(register_template) self.record_once(register_template)
def before_request(self, f): def before_request(self, f):
@ -306,16 +345,18 @@ class Blueprint(_PackageBoundObject):
is only executed before each request that is handled by a function of is only executed before each request that is handled by a function of
that blueprint. that blueprint.
""" """
self.record_once(lambda s: s.app.before_request_funcs self.record_once(
.setdefault(self.name, []).append(f)) lambda s: s.app.before_request_funcs.setdefault(self.name, []).append(f)
)
return f return f
def before_app_request(self, f): def before_app_request(self, f):
"""Like :meth:`Flask.before_request`. Such a function is executed """Like :meth:`Flask.before_request`. Such a function is executed
before each request, even if outside of a blueprint. before each request, even if outside of a blueprint.
""" """
self.record_once(lambda s: s.app.before_request_funcs self.record_once(
.setdefault(None, []).append(f)) lambda s: s.app.before_request_funcs.setdefault(None, []).append(f)
)
return f return f
def before_app_first_request(self, f): def before_app_first_request(self, f):
@ -330,16 +371,18 @@ class Blueprint(_PackageBoundObject):
is only executed after each request that is handled by a function of is only executed after each request that is handled by a function of
that blueprint. that blueprint.
""" """
self.record_once(lambda s: s.app.after_request_funcs self.record_once(
.setdefault(self.name, []).append(f)) lambda s: s.app.after_request_funcs.setdefault(self.name, []).append(f)
)
return f return f
def after_app_request(self, f): def after_app_request(self, f):
"""Like :meth:`Flask.after_request` but for a blueprint. Such a function """Like :meth:`Flask.after_request` but for a blueprint. Such a function
is executed after each request, even if outside of the blueprint. is executed after each request, even if outside of the blueprint.
""" """
self.record_once(lambda s: s.app.after_request_funcs self.record_once(
.setdefault(None, []).append(f)) lambda s: s.app.after_request_funcs.setdefault(None, []).append(f)
)
return f return f
def teardown_request(self, f): def teardown_request(self, f):
@ -349,8 +392,9 @@ class Blueprint(_PackageBoundObject):
when the request context is popped, even when no actual request was when the request context is popped, even when no actual request was
performed. performed.
""" """
self.record_once(lambda s: s.app.teardown_request_funcs self.record_once(
.setdefault(self.name, []).append(f)) lambda s: s.app.teardown_request_funcs.setdefault(self.name, []).append(f)
)
return f return f
def teardown_app_request(self, f): def teardown_app_request(self, f):
@ -358,33 +402,40 @@ class Blueprint(_PackageBoundObject):
function is executed when tearing down each request, even if outside of function is executed when tearing down each request, even if outside of
the blueprint. the blueprint.
""" """
self.record_once(lambda s: s.app.teardown_request_funcs self.record_once(
.setdefault(None, []).append(f)) lambda s: s.app.teardown_request_funcs.setdefault(None, []).append(f)
)
return f return f
def context_processor(self, f): def context_processor(self, f):
"""Like :meth:`Flask.context_processor` but for a blueprint. This """Like :meth:`Flask.context_processor` but for a blueprint. This
function is only executed for requests handled by a blueprint. function is only executed for requests handled by a blueprint.
""" """
self.record_once(lambda s: s.app.template_context_processors self.record_once(
.setdefault(self.name, []).append(f)) lambda s: s.app.template_context_processors.setdefault(
self.name, []
).append(f)
)
return f return f
def app_context_processor(self, f): def app_context_processor(self, f):
"""Like :meth:`Flask.context_processor` but for a blueprint. Such a """Like :meth:`Flask.context_processor` but for a blueprint. Such a
function is executed each request, even if outside of the blueprint. function is executed each request, even if outside of the blueprint.
""" """
self.record_once(lambda s: s.app.template_context_processors self.record_once(
.setdefault(None, []).append(f)) lambda s: s.app.template_context_processors.setdefault(None, []).append(f)
)
return f return f
def app_errorhandler(self, code): def app_errorhandler(self, code):
"""Like :meth:`Flask.errorhandler` but for a blueprint. This """Like :meth:`Flask.errorhandler` but for a blueprint. This
handler is used for all requests, even if outside of the blueprint. handler is used for all requests, even if outside of the blueprint.
""" """
def decorator(f): def decorator(f):
self.record_once(lambda s: s.app.errorhandler(code)(f)) self.record_once(lambda s: s.app.errorhandler(code)(f))
return f return f
return decorator return decorator
def url_value_preprocessor(self, f): def url_value_preprocessor(self, f):
@ -392,8 +443,9 @@ class Blueprint(_PackageBoundObject):
blueprint. It's called before the view functions are called and blueprint. It's called before the view functions are called and
can modify the url values provided. can modify the url values provided.
""" """
self.record_once(lambda s: s.app.url_value_preprocessors self.record_once(
.setdefault(self.name, []).append(f)) lambda s: s.app.url_value_preprocessors.setdefault(self.name, []).append(f)
)
return f return f
def url_defaults(self, f): def url_defaults(self, f):
@ -401,22 +453,25 @@ class Blueprint(_PackageBoundObject):
with the endpoint and values and should update the values passed with the endpoint and values and should update the values passed
in place. in place.
""" """
self.record_once(lambda s: s.app.url_default_functions self.record_once(
.setdefault(self.name, []).append(f)) lambda s: s.app.url_default_functions.setdefault(self.name, []).append(f)
)
return f return f
def app_url_value_preprocessor(self, f): def app_url_value_preprocessor(self, f):
"""Same as :meth:`url_value_preprocessor` but application wide. """Same as :meth:`url_value_preprocessor` but application wide.
""" """
self.record_once(lambda s: s.app.url_value_preprocessors self.record_once(
.setdefault(None, []).append(f)) lambda s: s.app.url_value_preprocessors.setdefault(None, []).append(f)
)
return f return f
def app_url_defaults(self, f): def app_url_defaults(self, f):
"""Same as :meth:`url_defaults` but application wide. """Same as :meth:`url_defaults` but application wide.
""" """
self.record_once(lambda s: s.app.url_default_functions self.record_once(
.setdefault(None, []).append(f)) lambda s: s.app.url_default_functions.setdefault(None, []).append(f)
)
return f return f
def errorhandler(self, code_or_exception): def errorhandler(self, code_or_exception):
@ -430,10 +485,13 @@ class Blueprint(_PackageBoundObject):
Otherwise works as the :meth:`~flask.Flask.errorhandler` decorator Otherwise works as the :meth:`~flask.Flask.errorhandler` decorator
of the :class:`~flask.Flask` object. of the :class:`~flask.Flask` object.
""" """
def decorator(f): def decorator(f):
self.record_once(lambda s: s.app._register_error_handler( self.record_once(
self.name, code_or_exception, f)) lambda s: s.app._register_error_handler(self.name, code_or_exception, f)
)
return f return f
return decorator return decorator
def register_error_handler(self, code_or_exception, f): def register_error_handler(self, code_or_exception, f):
@ -444,5 +502,6 @@ class Blueprint(_PackageBoundObject):
.. versionadded:: 0.11 .. versionadded:: 0.11
""" """
self.record_once(lambda s: s.app._register_error_handler( self.record_once(
self.name, code_or_exception, f)) lambda s: s.app._register_error_handler(self.name, code_or_exception, f)
)

View file

@ -48,16 +48,14 @@ def find_best_app(script_info, module):
from . import Flask from . import Flask
# Search for the most common names first. # Search for the most common names first.
for attr_name in ('app', 'application'): for attr_name in ("app", "application"):
app = getattr(module, attr_name, None) app = getattr(module, attr_name, None)
if isinstance(app, Flask): if isinstance(app, Flask):
return app return app
# Otherwise find the only object that is a Flask instance. # Otherwise find the only object that is a Flask instance.
matches = [ matches = [v for v in itervalues(module.__dict__) if isinstance(v, Flask)]
v for v in itervalues(module.__dict__) if isinstance(v, Flask)
]
if len(matches) == 1: if len(matches) == 1:
return matches[0] return matches[0]
@ -65,11 +63,11 @@ def find_best_app(script_info, module):
raise NoAppException( raise NoAppException(
'Detected multiple Flask applications in module "{module}". Use ' 'Detected multiple Flask applications in module "{module}". Use '
'"FLASK_APP={module}:name" to specify the correct ' '"FLASK_APP={module}:name" to specify the correct '
'one.'.format(module=module.__name__) "one.".format(module=module.__name__)
) )
# Search for app factory functions. # Search for app factory functions.
for attr_name in ('create_app', 'make_app'): for attr_name in ("create_app", "make_app"):
app_factory = getattr(module, attr_name, None) app_factory = getattr(module, attr_name, None)
if inspect.isfunction(app_factory): if inspect.isfunction(app_factory):
@ -83,18 +81,14 @@ def find_best_app(script_info, module):
raise raise
raise NoAppException( raise NoAppException(
'Detected factory "{factory}" in module "{module}", but ' 'Detected factory "{factory}" in module "{module}", but '
'could not call it without arguments. Use ' "could not call it without arguments. Use "
'"FLASK_APP=\'{module}:{factory}(args)\'" to specify ' "\"FLASK_APP='{module}:{factory}(args)'\" to specify "
'arguments.'.format( "arguments.".format(factory=attr_name, module=module.__name__)
factory=attr_name, module=module.__name__
)
) )
raise NoAppException( raise NoAppException(
'Failed to find Flask application or factory in module "{module}". ' 'Failed to find Flask application or factory in module "{module}". '
'Use "FLASK_APP={module}:name to specify one.'.format( 'Use "FLASK_APP={module}:name to specify one.'.format(module=module.__name__)
module=module.__name__
)
) )
@ -107,7 +101,7 @@ def call_factory(script_info, app_factory, arguments=()):
arg_names = args_spec.args arg_names = args_spec.args
arg_defaults = args_spec.defaults arg_defaults = args_spec.defaults
if 'script_info' in arg_names: if "script_info" in arg_names:
return app_factory(*arguments, script_info=script_info) return app_factory(*arguments, script_info=script_info)
elif arguments: elif arguments:
return app_factory(*arguments) return app_factory(*arguments)
@ -148,12 +142,13 @@ def find_app_by_string(script_info, module, app_name):
arguments. arguments.
""" """
from flask import Flask from flask import Flask
match = re.match(r'^ *([^ ()]+) *(?:\((.*?) *,? *\))? *$', app_name)
match = re.match(r"^ *([^ ()]+) *(?:\((.*?) *,? *\))? *$", app_name)
if not match: if not match:
raise NoAppException( raise NoAppException(
'"{name}" is not a valid variable name or function ' '"{name}" is not a valid variable name or function '
'expression.'.format(name=app_name) "expression.".format(name=app_name)
) )
name, args = match.groups() name, args = match.groups()
@ -166,10 +161,10 @@ def find_app_by_string(script_info, module, app_name):
if inspect.isfunction(attr): if inspect.isfunction(attr):
if args: if args:
try: try:
args = ast.literal_eval('({args},)'.format(args=args)) args = ast.literal_eval("({args},)".format(args=args))
except (ValueError, SyntaxError)as e: except (ValueError, SyntaxError) as e:
raise NoAppException( raise NoAppException(
'Could not parse the arguments in ' "Could not parse the arguments in "
'"{app_name}".'.format(e=e, app_name=app_name) '"{app_name}".'.format(e=e, app_name=app_name)
) )
else: else:
@ -183,7 +178,7 @@ def find_app_by_string(script_info, module, app_name):
raise NoAppException( raise NoAppException(
'{e}\nThe factory "{app_name}" in module "{module}" could not ' '{e}\nThe factory "{app_name}" in module "{module}" could not '
'be called with the specified arguments.'.format( "be called with the specified arguments.".format(
e=e, app_name=app_name, module=module.__name__ e=e, app_name=app_name, module=module.__name__
) )
) )
@ -194,10 +189,8 @@ def find_app_by_string(script_info, module, app_name):
return app return app
raise NoAppException( raise NoAppException(
'A valid Flask application was not obtained from ' "A valid Flask application was not obtained from "
'"{module}:{app_name}".'.format( '"{module}:{app_name}".'.format(module=module.__name__, app_name=app_name)
module=module.__name__, app_name=app_name
)
) )
@ -208,10 +201,10 @@ def prepare_import(path):
path = os.path.realpath(path) path = os.path.realpath(path)
fname, ext = os.path.splitext(path) fname, ext = os.path.splitext(path)
if ext == '.py': if ext == ".py":
path = fname path = fname
if os.path.basename(path) == '__init__': if os.path.basename(path) == "__init__":
path = os.path.dirname(path) path = os.path.dirname(path)
module_name = [] module_name = []
@ -221,13 +214,13 @@ def prepare_import(path):
path, name = os.path.split(path) path, name = os.path.split(path)
module_name.append(name) module_name.append(name)
if not os.path.exists(os.path.join(path, '__init__.py')): if not os.path.exists(os.path.join(path, "__init__.py")):
break break
if sys.path[0] != path: if sys.path[0] != path:
sys.path.insert(0, path) sys.path.insert(0, path)
return '.'.join(module_name[::-1]) return ".".join(module_name[::-1])
def locate_app(script_info, module_name, app_name, raise_if_not_found=True): def locate_app(script_info, module_name, app_name, raise_if_not_found=True):
@ -241,12 +234,10 @@ def locate_app(script_info, module_name, app_name, raise_if_not_found=True):
if sys.exc_info()[-1].tb_next: if sys.exc_info()[-1].tb_next:
raise NoAppException( raise NoAppException(
'While importing "{name}", an ImportError was raised:' 'While importing "{name}", an ImportError was raised:'
'\n\n{tb}'.format(name=module_name, tb=traceback.format_exc()) "\n\n{tb}".format(name=module_name, tb=traceback.format_exc())
) )
elif raise_if_not_found: elif raise_if_not_found:
raise NoAppException( raise NoAppException('Could not import "{name}".'.format(name=module_name))
'Could not import "{name}".'.format(name=module_name)
)
else: else:
return return
@ -262,26 +253,27 @@ def get_version(ctx, param, value):
if not value or ctx.resilient_parsing: if not value or ctx.resilient_parsing:
return return
import werkzeug import werkzeug
message = (
'Python %(python)s\n' message = "Python %(python)s\n" "Flask %(flask)s\n" "Werkzeug %(werkzeug)s"
'Flask %(flask)s\n' click.echo(
'Werkzeug %(werkzeug)s' message
% {
"python": platform.python_version(),
"flask": __version__,
"werkzeug": werkzeug.__version__,
},
color=ctx.color,
) )
click.echo(message % {
'python': platform.python_version(),
'flask': __version__,
'werkzeug': werkzeug.__version__,
}, color=ctx.color)
ctx.exit() ctx.exit()
version_option = click.Option( version_option = click.Option(
['--version'], ["--version"],
help='Show the flask version', help="Show the flask version",
expose_value=False, expose_value=False,
callback=get_version, callback=get_version,
is_flag=True, is_flag=True,
is_eager=True is_eager=True,
) )
@ -310,6 +302,7 @@ class DispatchingApp(object):
self._load_unlocked() self._load_unlocked()
except Exception: except Exception:
self._bg_loading_exc_info = sys.exc_info() self._bg_loading_exc_info = sys.exc_info()
t = Thread(target=_load_app, args=()) t = Thread(target=_load_app, args=())
t.start() t.start()
@ -348,10 +341,9 @@ class ScriptInfo(object):
onwards as click object. onwards as click object.
""" """
def __init__(self, app_import_path=None, create_app=None, def __init__(self, app_import_path=None, create_app=None, set_debug_flag=True):
set_debug_flag=True):
#: Optionally the import path for the Flask application. #: Optionally the import path for the Flask application.
self.app_import_path = app_import_path or os.environ.get('FLASK_APP') self.app_import_path = app_import_path or os.environ.get("FLASK_APP")
#: Optionally a function that is passed the script info to create #: Optionally a function that is passed the script info to create
#: the instance of the application. #: the instance of the application.
self.create_app = create_app self.create_app = create_app
@ -377,21 +369,22 @@ class ScriptInfo(object):
app = call_factory(self, self.create_app) app = call_factory(self, self.create_app)
else: else:
if self.app_import_path: if self.app_import_path:
path, name = (re.split(r':(?![\\/])', self.app_import_path, 1) + [None])[:2] path, name = (
re.split(r":(?![\\/])", self.app_import_path, 1) + [None]
)[:2]
import_name = prepare_import(path) import_name = prepare_import(path)
app = locate_app(self, import_name, name) app = locate_app(self, import_name, name)
else: else:
for path in ('wsgi.py', 'app.py'): for path in ("wsgi.py", "app.py"):
import_name = prepare_import(path) import_name = prepare_import(path)
app = locate_app(self, import_name, None, app = locate_app(self, import_name, None, raise_if_not_found=False)
raise_if_not_found=False)
if app: if app:
break break
if not app: if not app:
raise NoAppException( raise NoAppException(
'Could not locate a Flask application. You did not provide ' "Could not locate a Flask application. You did not provide "
'the "FLASK_APP" environment variable, and a "wsgi.py" or ' 'the "FLASK_APP" environment variable, and a "wsgi.py" or '
'"app.py" module was not found in the current directory.' '"app.py" module was not found in the current directory.'
) )
@ -414,10 +407,12 @@ def with_appcontext(f):
to the ``app.cli`` object then they are wrapped with this function to the ``app.cli`` object then they are wrapped with this function
by default unless it's disabled. by default unless it's disabled.
""" """
@click.pass_context @click.pass_context
def decorator(__ctx, *args, **kwargs): def decorator(__ctx, *args, **kwargs):
with __ctx.ensure_object(ScriptInfo).load_app().app_context(): with __ctx.ensure_object(ScriptInfo).load_app().app_context():
return __ctx.invoke(f, *args, **kwargs) return __ctx.invoke(f, *args, **kwargs)
return update_wrapper(decorator, f) return update_wrapper(decorator, f)
@ -434,11 +429,13 @@ class AppGroup(click.Group):
:class:`click.Group` but it wraps callbacks in :func:`with_appcontext` :class:`click.Group` but it wraps callbacks in :func:`with_appcontext`
unless it's disabled by passing ``with_appcontext=False``. unless it's disabled by passing ``with_appcontext=False``.
""" """
wrap_for_ctx = kwargs.pop('with_appcontext', True) wrap_for_ctx = kwargs.pop("with_appcontext", True)
def decorator(f): def decorator(f):
if wrap_for_ctx: if wrap_for_ctx:
f = with_appcontext(f) f = with_appcontext(f)
return click.Group.command(self, *args, **kwargs)(f) return click.Group.command(self, *args, **kwargs)(f)
return decorator return decorator
def group(self, *args, **kwargs): def group(self, *args, **kwargs):
@ -446,7 +443,7 @@ class AppGroup(click.Group):
:class:`click.Group` but it defaults the group class to :class:`click.Group` but it defaults the group class to
:class:`AppGroup`. :class:`AppGroup`.
""" """
kwargs.setdefault('cls', AppGroup) kwargs.setdefault("cls", AppGroup)
return click.Group.group(self, *args, **kwargs) return click.Group.group(self, *args, **kwargs)
@ -475,10 +472,16 @@ class FlaskGroup(AppGroup):
from :file:`.env` and :file:`.flaskenv` files. from :file:`.env` and :file:`.flaskenv` files.
""" """
def __init__(self, add_default_commands=True, create_app=None, def __init__(
add_version_option=True, load_dotenv=True, self,
set_debug_flag=True, **extra): add_default_commands=True,
params = list(extra.pop('params', None) or ()) create_app=None,
add_version_option=True,
load_dotenv=True,
set_debug_flag=True,
**extra
):
params = list(extra.pop("params", None) or ())
if add_version_option: if add_version_option:
params.append(version_option) params.append(version_option)
@ -504,7 +507,7 @@ class FlaskGroup(AppGroup):
self._loaded_plugin_commands = True self._loaded_plugin_commands = True
return return
for ep in pkg_resources.iter_entry_points('flask.commands'): for ep in pkg_resources.iter_entry_points("flask.commands"):
self.add_command(ep.load(), ep.name) self.add_command(ep.load(), ep.name)
self._loaded_plugin_commands = True self._loaded_plugin_commands = True
@ -554,19 +557,20 @@ class FlaskGroup(AppGroup):
# command line interface. This is detected by Flask.run to make the # command line interface. This is detected by Flask.run to make the
# call into a no-op. This is necessary to avoid ugly errors when the # call into a no-op. This is necessary to avoid ugly errors when the
# script that is loaded here also attempts to start a server. # script that is loaded here also attempts to start a server.
os.environ['FLASK_RUN_FROM_CLI'] = 'true' os.environ["FLASK_RUN_FROM_CLI"] = "true"
if get_load_dotenv(self.load_dotenv): if get_load_dotenv(self.load_dotenv):
load_dotenv() load_dotenv()
obj = kwargs.get('obj') obj = kwargs.get("obj")
if obj is None: if obj is None:
obj = ScriptInfo(create_app=self.create_app, obj = ScriptInfo(
set_debug_flag=self.set_debug_flag) create_app=self.create_app, set_debug_flag=self.set_debug_flag
)
kwargs['obj'] = obj kwargs["obj"] = obj
kwargs.setdefault('auto_envvar_prefix', 'FLASK') kwargs.setdefault("auto_envvar_prefix", "FLASK")
return super(FlaskGroup, self).main(*args, **kwargs) return super(FlaskGroup, self).main(*args, **kwargs)
@ -574,7 +578,7 @@ def _path_is_ancestor(path, other):
"""Take ``other`` and remove the length of ``path`` from it. Then join it """Take ``other`` and remove the length of ``path`` from it. Then join it
to ``path``. If it is the original value, ``path`` is an ancestor of to ``path``. If it is the original value, ``path`` is an ancestor of
``other``.""" ``other``."""
return os.path.join(path, other[len(path):].lstrip(os.sep)) == other return os.path.join(path, other[len(path) :].lstrip(os.sep)) == other
def load_dotenv(path=None): def load_dotenv(path=None):
@ -597,11 +601,12 @@ def load_dotenv(path=None):
.. versionadded:: 1.0 .. versionadded:: 1.0
""" """
if dotenv is None: if dotenv is None:
if path or os.path.isfile('.env') or os.path.isfile('.flaskenv'): if path or os.path.isfile(".env") or os.path.isfile(".flaskenv"):
click.secho( click.secho(
' * Tip: There are .env or .flaskenv files present.' " * Tip: There are .env or .flaskenv files present."
' Do "pip install python-dotenv" to use them.', ' Do "pip install python-dotenv" to use them.',
fg='yellow') fg="yellow",
)
return return
if path is not None: if path is not None:
@ -609,7 +614,7 @@ def load_dotenv(path=None):
new_dir = None new_dir = None
for name in ('.env', '.flaskenv'): for name in (".env", ".flaskenv"):
path = dotenv.find_dotenv(name, usecwd=True) path = dotenv.find_dotenv(name, usecwd=True)
if not path: if not path:
@ -630,27 +635,29 @@ def show_server_banner(env, debug, app_import_path, eager_loading):
"""Show extra startup messages the first time the server is run, """Show extra startup messages the first time the server is run,
ignoring the reloader. ignoring the reloader.
""" """
if os.environ.get('WERKZEUG_RUN_MAIN') == 'true': if os.environ.get("WERKZEUG_RUN_MAIN") == "true":
return return
if app_import_path is not None: if app_import_path is not None:
message = ' * Serving Flask app "{0}"'.format(app_import_path) message = ' * Serving Flask app "{0}"'.format(app_import_path)
if not eager_loading: if not eager_loading:
message += ' (lazy loading)' message += " (lazy loading)"
click.echo(message) click.echo(message)
click.echo(' * Environment: {0}'.format(env)) click.echo(" * Environment: {0}".format(env))
if env == 'production': if env == "production":
click.secho( click.secho(
' WARNING: Do not use the development server in a production' " WARNING: Do not use the development server in a production"
' environment.', fg='red') " environment.",
click.secho(' Use a production WSGI server instead.', dim=True) fg="red",
)
click.secho(" Use a production WSGI server instead.", dim=True)
if debug is not None: if debug is not None:
click.echo(' * Debug mode: {0}'.format('on' if debug else 'off')) click.echo(" * Debug mode: {0}".format("on" if debug else "off"))
class CertParamType(click.ParamType): class CertParamType(click.ParamType):
@ -659,11 +666,10 @@ class CertParamType(click.ParamType):
:class:`~ssl.SSLContext` object. :class:`~ssl.SSLContext` object.
""" """
name = 'path' name = "path"
def __init__(self): def __init__(self):
self.path_type = click.Path( self.path_type = click.Path(exists=True, dir_okay=False, resolve_path=True)
exists=True, dir_okay=False, resolve_path=True)
def convert(self, value, param, ctx): def convert(self, value, param, ctx):
try: try:
@ -671,13 +677,13 @@ class CertParamType(click.ParamType):
except click.BadParameter: except click.BadParameter:
value = click.STRING(value, param, ctx).lower() value = click.STRING(value, param, ctx).lower()
if value == 'adhoc': if value == "adhoc":
try: try:
import OpenSSL import OpenSSL
except ImportError: except ImportError:
raise click.BadParameter( raise click.BadParameter(
'Using ad-hoc certificates requires pyOpenSSL.', "Using ad-hoc certificates requires pyOpenSSL.", ctx, param
ctx, param) )
return value return value
@ -697,8 +703,8 @@ def _validate_key(ctx, param, value):
"""The ``--key`` option must be specified when ``--cert`` is a file. """The ``--key`` option must be specified when ``--cert`` is a file.
Modifies the ``cert`` param to be a ``(cert, key)`` pair if needed. Modifies the ``cert`` param to be a ``(cert, key)`` pair if needed.
""" """
cert = ctx.params.get('cert') cert = ctx.params.get("cert")
is_adhoc = cert == 'adhoc' is_adhoc = cert == "adhoc"
if sys.version_info < (2, 7, 9): if sys.version_info < (2, 7, 9):
is_context = cert and not isinstance(cert, (text_type, bytes)) is_context = cert and not isinstance(cert, (text_type, bytes))
@ -708,55 +714,64 @@ def _validate_key(ctx, param, value):
if value is not None: if value is not None:
if is_adhoc: if is_adhoc:
raise click.BadParameter( raise click.BadParameter(
'When "--cert" is "adhoc", "--key" is not used.', 'When "--cert" is "adhoc", "--key" is not used.', ctx, param
ctx, param) )
if is_context: if is_context:
raise click.BadParameter( raise click.BadParameter(
'When "--cert" is an SSLContext object, "--key is not used.', 'When "--cert" is an SSLContext object, "--key is not used.', ctx, param
ctx, param) )
if not cert: if not cert:
raise click.BadParameter( raise click.BadParameter('"--cert" must also be specified.', ctx, param)
'"--cert" must also be specified.',
ctx, param)
ctx.params['cert'] = cert, value ctx.params["cert"] = cert, value
else: else:
if cert and not (is_adhoc or is_context): if cert and not (is_adhoc or is_context):
raise click.BadParameter( raise click.BadParameter('Required when using "--cert".', ctx, param)
'Required when using "--cert".',
ctx, param)
return value return value
@click.command('run', short_help='Run a development server.') @click.command("run", short_help="Run a development server.")
@click.option('--host', '-h', default='127.0.0.1', @click.option("--host", "-h", default="127.0.0.1", help="The interface to bind to.")
help='The interface to bind to.') @click.option("--port", "-p", default=5000, help="The port to bind to.")
@click.option('--port', '-p', default=5000, @click.option(
help='The port to bind to.') "--cert", type=CertParamType(), help="Specify a certificate file to use HTTPS."
@click.option('--cert', type=CertParamType(), )
help='Specify a certificate file to use HTTPS.') @click.option(
@click.option('--key', "--key",
type=click.Path(exists=True, dir_okay=False, resolve_path=True), type=click.Path(exists=True, dir_okay=False, resolve_path=True),
callback=_validate_key, expose_value=False, callback=_validate_key,
help='The key file to use when specifying a certificate.') expose_value=False,
@click.option('--reload/--no-reload', default=None, help="The key file to use when specifying a certificate.",
help='Enable or disable the reloader. By default the reloader ' )
'is active if debug is enabled.') @click.option(
@click.option('--debugger/--no-debugger', default=None, "--reload/--no-reload",
help='Enable or disable the debugger. By default the debugger ' default=None,
'is active if debug is enabled.') help="Enable or disable the reloader. By default the reloader "
@click.option('--eager-loading/--lazy-loader', default=None, "is active if debug is enabled.",
help='Enable or disable eager loading. By default eager ' )
'loading is enabled if the reloader is disabled.') @click.option(
@click.option('--with-threads/--without-threads', default=True, "--debugger/--no-debugger",
help='Enable or disable multithreading.') default=None,
help="Enable or disable the debugger. By default the debugger "
"is active if debug is enabled.",
)
@click.option(
"--eager-loading/--lazy-loader",
default=None,
help="Enable or disable eager loading. By default eager "
"loading is enabled if the reloader is disabled.",
)
@click.option(
"--with-threads/--without-threads",
default=True,
help="Enable or disable multithreading.",
)
@pass_script_info @pass_script_info
def run_command(info, host, port, reload, debugger, eager_loading, def run_command(info, host, port, reload, debugger, eager_loading, with_threads, cert):
with_threads, cert):
"""Run a local development server. """Run a local development server.
This server is for development purposes only. It does not provide This server is for development purposes only. It does not provide
@ -780,11 +795,19 @@ def run_command(info, host, port, reload, debugger, eager_loading,
app = DispatchingApp(info.load_app, use_eager_loading=eager_loading) app = DispatchingApp(info.load_app, use_eager_loading=eager_loading)
from werkzeug.serving import run_simple from werkzeug.serving import run_simple
run_simple(host, port, app, use_reloader=reload, use_debugger=debugger,
threaded=with_threads, ssl_context=cert) run_simple(
host,
port,
app,
use_reloader=reload,
use_debugger=debugger,
threaded=with_threads,
ssl_context=cert,
)
@click.command('shell', short_help='Run a shell in the app context.') @click.command("shell", short_help="Run a shell in the app context.")
@with_appcontext @with_appcontext
def shell_command(): def shell_command():
"""Run an interactive Python shell in the context of a given """Run an interactive Python shell in the context of a given
@ -796,8 +819,9 @@ def shell_command():
""" """
import code import code
from flask.globals import _app_ctx_stack from flask.globals import _app_ctx_stack
app = _app_ctx_stack.top.app app = _app_ctx_stack.top.app
banner = 'Python %s on %s\nApp: %s [%s]\nInstance: %s' % ( banner = "Python %s on %s\nApp: %s [%s]\nInstance: %s" % (
sys.version, sys.version,
sys.platform, sys.platform,
app.import_name, app.import_name,
@ -808,68 +832,64 @@ def shell_command():
# Support the regular Python interpreter startup script if someone # Support the regular Python interpreter startup script if someone
# is using it. # is using it.
startup = os.environ.get('PYTHONSTARTUP') startup = os.environ.get("PYTHONSTARTUP")
if startup and os.path.isfile(startup): if startup and os.path.isfile(startup):
with open(startup, 'r') as f: with open(startup, "r") as f:
eval(compile(f.read(), startup, 'exec'), ctx) eval(compile(f.read(), startup, "exec"), ctx)
ctx.update(app.make_shell_context()) ctx.update(app.make_shell_context())
code.interact(banner=banner, local=ctx) code.interact(banner=banner, local=ctx)
@click.command('routes', short_help='Show the routes for the app.') @click.command("routes", short_help="Show the routes for the app.")
@click.option( @click.option(
'--sort', '-s', "--sort",
type=click.Choice(('endpoint', 'methods', 'rule', 'match')), "-s",
default='endpoint', type=click.Choice(("endpoint", "methods", "rule", "match")),
default="endpoint",
help=( help=(
'Method to sort routes by. "match" is the order that Flask will match ' 'Method to sort routes by. "match" is the order that Flask will match '
'routes when dispatching a request.' "routes when dispatching a request."
) ),
)
@click.option(
'--all-methods',
is_flag=True,
help="Show HEAD and OPTIONS methods."
) )
@click.option("--all-methods", is_flag=True, help="Show HEAD and OPTIONS methods.")
@with_appcontext @with_appcontext
def routes_command(sort, all_methods): def routes_command(sort, all_methods):
"""Show all registered routes with endpoints and methods.""" """Show all registered routes with endpoints and methods."""
rules = list(current_app.url_map.iter_rules()) rules = list(current_app.url_map.iter_rules())
if not rules: if not rules:
click.echo('No routes were registered.') click.echo("No routes were registered.")
return return
ignored_methods = set(() if all_methods else ('HEAD', 'OPTIONS')) ignored_methods = set(() if all_methods else ("HEAD", "OPTIONS"))
if sort in ('endpoint', 'rule'): if sort in ("endpoint", "rule"):
rules = sorted(rules, key=attrgetter(sort)) rules = sorted(rules, key=attrgetter(sort))
elif sort == 'methods': elif sort == "methods":
rules = sorted(rules, key=lambda rule: sorted(rule.methods)) rules = sorted(rules, key=lambda rule: sorted(rule.methods))
rule_methods = [ rule_methods = [", ".join(sorted(rule.methods - ignored_methods)) for rule in rules]
', '.join(sorted(rule.methods - ignored_methods)) for rule in rules
]
headers = ('Endpoint', 'Methods', 'Rule') headers = ("Endpoint", "Methods", "Rule")
widths = ( widths = (
max(len(rule.endpoint) for rule in rules), max(len(rule.endpoint) for rule in rules),
max(len(methods) for methods in rule_methods), max(len(methods) for methods in rule_methods),
max(len(rule.rule) for rule in rules), max(len(rule.rule) for rule in rules),
) )
widths = [max(len(h), w) for h, w in zip(headers, widths)] widths = [max(len(h), w) for h, w in zip(headers, widths)]
row = '{{0:<{0}}} {{1:<{1}}} {{2:<{2}}}'.format(*widths) row = "{{0:<{0}}} {{1:<{1}}} {{2:<{2}}}".format(*widths)
click.echo(row.format(*headers).strip()) click.echo(row.format(*headers).strip())
click.echo(row.format(*('-' * width for width in widths))) click.echo(row.format(*("-" * width for width in widths)))
for rule, methods in zip(rules, rule_methods): for rule, methods in zip(rules, rule_methods):
click.echo(row.format(rule.endpoint, methods, rule.rule).rstrip()) click.echo(row.format(rule.endpoint, methods, rule.rule).rstrip())
cli = FlaskGroup(help="""\ cli = FlaskGroup(
help="""\
A general utility script for Flask applications. A general utility script for Flask applications.
Provides commands from Flask, extensions, and the application. Loads the Provides commands from Flask, extensions, and the application. Loads the
@ -882,30 +902,31 @@ debug mode.
{prefix}{cmd} FLASK_ENV=development {prefix}{cmd} FLASK_ENV=development
{prefix}flask run {prefix}flask run
""".format( """.format(
cmd='export' if os.name == 'posix' else 'set', cmd="export" if os.name == "posix" else "set",
prefix='$ ' if os.name == 'posix' else '> ' prefix="$ " if os.name == "posix" else "> ",
)) )
)
def main(as_module=False): def main(as_module=False):
args = sys.argv[1:] args = sys.argv[1:]
if as_module: if as_module:
this_module = 'flask' this_module = "flask"
if sys.version_info < (2, 7): if sys.version_info < (2, 7):
this_module += '.cli' this_module += ".cli"
name = 'python -m ' + this_module name = "python -m " + this_module
# Python rewrites "python -m flask" to the path to the file in argv. # Python rewrites "python -m flask" to the path to the file in argv.
# Restore the original command so that the reloader works. # Restore the original command so that the reloader works.
sys.argv = ['-m', this_module] + args sys.argv = ["-m", this_module] + args
else: else:
name = None name = None
cli.main(args=args, prog_name=name) cli.main(args=args, prog_name=name)
if __name__ == '__main__': if __name__ == "__main__":
main(as_module=True) main(as_module=True)

View file

@ -101,11 +101,12 @@ class Config(dict):
if not rv: if not rv:
if silent: if silent:
return False return False
raise RuntimeError('The environment variable %r is not set ' raise RuntimeError(
'and as such configuration could not be ' "The environment variable %r is not set "
'loaded. Set this variable and make it ' "and as such configuration could not be "
'point to a configuration file' % "loaded. Set this variable and make it "
variable_name) "point to a configuration file" % variable_name
)
return self.from_pyfile(rv, silent=silent) return self.from_pyfile(rv, silent=silent)
def from_pyfile(self, filename, silent=False): def from_pyfile(self, filename, silent=False):
@ -123,17 +124,15 @@ class Config(dict):
`silent` parameter. `silent` parameter.
""" """
filename = os.path.join(self.root_path, filename) filename = os.path.join(self.root_path, filename)
d = types.ModuleType('config') d = types.ModuleType("config")
d.__file__ = filename d.__file__ = filename
try: try:
with open(filename, mode='rb') as config_file: with open(filename, mode="rb") as config_file:
exec(compile(config_file.read(), filename, 'exec'), d.__dict__) exec(compile(config_file.read(), filename, "exec"), d.__dict__)
except IOError as e: except IOError as e:
if silent and e.errno in ( if silent and e.errno in (errno.ENOENT, errno.EISDIR, errno.ENOTDIR):
errno.ENOENT, errno.EISDIR, errno.ENOTDIR
):
return False return False
e.strerror = 'Unable to load configuration file (%s)' % e.strerror e.strerror = "Unable to load configuration file (%s)" % e.strerror
raise raise
self.from_object(d) self.from_object(d)
return True return True
@ -197,7 +196,7 @@ class Config(dict):
except IOError as e: except IOError as e:
if silent and e.errno in (errno.ENOENT, errno.EISDIR): if silent and e.errno in (errno.ENOENT, errno.EISDIR):
return False return False
e.strerror = 'Unable to load configuration file (%s)' % e.strerror e.strerror = "Unable to load configuration file (%s)" % e.strerror
raise raise
return self.from_mapping(obj) return self.from_mapping(obj)
@ -209,13 +208,13 @@ class Config(dict):
""" """
mappings = [] mappings = []
if len(mapping) == 1: if len(mapping) == 1:
if hasattr(mapping[0], 'items'): if hasattr(mapping[0], "items"):
mappings.append(mapping[0].items()) mappings.append(mapping[0].items())
else: else:
mappings.append(mapping[0]) mappings.append(mapping[0])
elif len(mapping) > 1: elif len(mapping) > 1:
raise TypeError( raise TypeError(
'expected at most 1 positional argument, got %d' % len(mapping) "expected at most 1 positional argument, got %d" % len(mapping)
) )
mappings.append(kwargs.items()) mappings.append(kwargs.items())
for mapping in mappings: for mapping in mappings:
@ -257,7 +256,7 @@ class Config(dict):
if not k.startswith(namespace): if not k.startswith(namespace):
continue continue
if trim_namespace: if trim_namespace:
key = k[len(namespace):] key = k[len(namespace) :]
else: else:
key = k key = k
if lowercase: if lowercase:
@ -266,4 +265,4 @@ class Config(dict):
return rv return rv
def __repr__(self): def __repr__(self):
return '<%s %s>' % (self.__class__.__name__, dict.__repr__(self)) return "<%s %s>" % (self.__class__.__name__, dict.__repr__(self))

View file

@ -89,7 +89,7 @@ class _AppCtxGlobals(object):
def __repr__(self): def __repr__(self):
top = _app_ctx_stack.top top = _app_ctx_stack.top
if top is not None: if top is not None:
return '<flask.g of %r>' % top.app.name return "<flask.g of %r>" % top.app.name
return object.__repr__(self) return object.__repr__(self)
@ -144,13 +144,17 @@ def copy_current_request_context(f):
""" """
top = _request_ctx_stack.top top = _request_ctx_stack.top
if top is None: if top is None:
raise RuntimeError('This decorator can only be used at local scopes ' raise RuntimeError(
'when a request context is on the stack. For instance within ' "This decorator can only be used at local scopes "
'view functions.') "when a request context is on the stack. For instance within "
"view functions."
)
reqctx = top.copy() reqctx = top.copy()
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
with reqctx: with reqctx:
return f(*args, **kwargs) return f(*args, **kwargs)
return update_wrapper(wrapper, f) return update_wrapper(wrapper, f)
@ -217,7 +221,7 @@ class AppContext(object):
def push(self): def push(self):
"""Binds the app context to the current context.""" """Binds the app context to the current context."""
self._refcnt += 1 self._refcnt += 1
if hasattr(sys, 'exc_clear'): if hasattr(sys, "exc_clear"):
sys.exc_clear() sys.exc_clear()
_app_ctx_stack.push(self) _app_ctx_stack.push(self)
appcontext_pushed.send(self.app) appcontext_pushed.send(self.app)
@ -232,8 +236,7 @@ class AppContext(object):
self.app.do_teardown_appcontext(exc) self.app.do_teardown_appcontext(exc)
finally: finally:
rv = _app_ctx_stack.pop() rv = _app_ctx_stack.pop()
assert rv is self, 'Popped wrong app context. (%r instead of %r)' \ assert rv is self, "Popped wrong app context. (%r instead of %r)" % (rv, self)
% (rv, self)
appcontext_popped.send(self.app) appcontext_popped.send(self.app)
def __enter__(self): def __enter__(self):
@ -314,8 +317,10 @@ class RequestContext(object):
def _get_g(self): def _get_g(self):
return _app_ctx_stack.top.g return _app_ctx_stack.top.g
def _set_g(self, value): def _set_g(self, value):
_app_ctx_stack.top.g = value _app_ctx_stack.top.g = value
g = property(_get_g, _set_g) g = property(_get_g, _set_g)
del _get_g, _set_g del _get_g, _set_g
@ -332,10 +337,11 @@ class RequestContext(object):
The current session object is used instead of reloading the original The current session object is used instead of reloading the original
data. This prevents `flask.session` pointing to an out-of-date object. data. This prevents `flask.session` pointing to an out-of-date object.
""" """
return self.__class__(self.app, return self.__class__(
self.app,
environ=self.request.environ, environ=self.request.environ,
request=self.request, request=self.request,
session=self.session session=self.session,
) )
def match_request(self): def match_request(self):
@ -343,8 +349,7 @@ class RequestContext(object):
of the request. of the request.
""" """
try: try:
url_rule, self.request.view_args = \ url_rule, self.request.view_args = self.url_adapter.match(return_rule=True)
self.url_adapter.match(return_rule=True)
self.request.url_rule = url_rule self.request.url_rule = url_rule
except HTTPException as e: except HTTPException as e:
self.request.routing_exception = e self.request.routing_exception = e
@ -373,7 +378,7 @@ class RequestContext(object):
else: else:
self._implicit_app_ctx_stack.append(None) self._implicit_app_ctx_stack.append(None)
if hasattr(sys, 'exc_clear'): if hasattr(sys, "exc_clear"):
sys.exc_clear() sys.exc_clear()
_request_ctx_stack.push(self) _request_ctx_stack.push(self)
@ -384,9 +389,7 @@ class RequestContext(object):
# pushed, otherwise stream_with_context loses the session. # pushed, otherwise stream_with_context loses the session.
if self.session is None: if self.session is None:
session_interface = self.app.session_interface session_interface = self.app.session_interface
self.session = session_interface.open_session( self.session = session_interface.open_session(self.app, self.request)
self.app, self.request
)
if self.session is None: if self.session is None:
self.session = session_interface.make_null_session(self.app) self.session = session_interface.make_null_session(self.app)
@ -414,10 +417,10 @@ class RequestContext(object):
# we do that now. This will only go into effect on Python 2.x, # we do that now. This will only go into effect on Python 2.x,
# on 3.x it disappears automatically at the end of the exception # on 3.x it disappears automatically at the end of the exception
# stack. # stack.
if hasattr(sys, 'exc_clear'): if hasattr(sys, "exc_clear"):
sys.exc_clear() sys.exc_clear()
request_close = getattr(self.request, 'close', None) request_close = getattr(self.request, "close", None)
if request_close is not None: if request_close is not None:
request_close() request_close()
clear_request = True clear_request = True
@ -427,18 +430,20 @@ class RequestContext(object):
# get rid of circular dependencies at the end of the request # get rid of circular dependencies at the end of the request
# so that we don't require the GC to be active. # so that we don't require the GC to be active.
if clear_request: if clear_request:
rv.request.environ['werkzeug.request'] = None rv.request.environ["werkzeug.request"] = None
# Get rid of the app as well if necessary. # Get rid of the app as well if necessary.
if app_ctx is not None: if app_ctx is not None:
app_ctx.pop(exc) app_ctx.pop(exc)
assert rv is self, 'Popped wrong request context. ' \ assert (
'(%r instead of %r)' % (rv, self) rv is self
), "Popped wrong request context. " "(%r instead of %r)" % (rv, self)
def auto_pop(self, exc): def auto_pop(self, exc):
if self.request.environ.get('flask._preserve_context') or \ if self.request.environ.get("flask._preserve_context") or (
(exc is not None and self.app.preserve_context_on_exception): exc is not None and self.app.preserve_context_on_exception
):
self.preserved = True self.preserved = True
self._preserved_exc = exc self._preserved_exc = exc
else: else:
@ -460,7 +465,7 @@ class RequestContext(object):
reraise(exc_type, exc_value, tb) reraise(exc_type, exc_value, tb)
def __repr__(self): def __repr__(self):
return '<%s \'%s\' [%s] of %s>' % ( return "<%s '%s' [%s] of %s>" % (
self.__class__.__name__, self.__class__.__name__,
self.request.url, self.request.url,
self.request.method, self.request.method,

View file

@ -32,17 +32,20 @@ class DebugFilesKeyError(KeyError, AssertionError):
def __init__(self, request, key): def __init__(self, request, key):
form_matches = request.form.getlist(key) form_matches = request.form.getlist(key)
buf = ['You tried to access the file "%s" in the request.files ' buf = [
'dictionary but it does not exist. The mimetype for the request ' 'You tried to access the file "%s" in the request.files '
'is "%s" instead of "multipart/form-data" which means that no ' "dictionary but it does not exist. The mimetype for the request "
'file contents were transmitted. To fix this error you should ' 'is "%s" instead of "multipart/form-data" which means that no '
'provide enctype="multipart/form-data" in your form.' % "file contents were transmitted. To fix this error you should "
(key, request.mimetype)] 'provide enctype="multipart/form-data" in your form.'
% (key, request.mimetype)
]
if form_matches: if form_matches:
buf.append('\n\nThe browser instead transmitted some file names. ' buf.append(
'This was submitted: %s' % ', '.join('"%s"' % x "\n\nThe browser instead transmitted some file names. "
for x in form_matches)) "This was submitted: %s" % ", ".join('"%s"' % x for x in form_matches)
self.msg = ''.join(buf) )
self.msg = "".join(buf)
def __str__(self): def __str__(self):
return self.msg return self.msg
@ -56,23 +59,28 @@ class FormDataRoutingRedirect(AssertionError):
def __init__(self, request): def __init__(self, request):
exc = request.routing_exception exc = request.routing_exception
buf = ['A request was sent to this URL (%s) but a redirect was ' buf = [
'issued automatically by the routing system to "%s".' "A request was sent to this URL (%s) but a redirect was "
% (request.url, exc.new_url)] 'issued automatically by the routing system to "%s".'
% (request.url, exc.new_url)
]
# In case just a slash was appended we can be extra helpful # In case just a slash was appended we can be extra helpful
if request.base_url + '/' == exc.new_url.split('?')[0]: if request.base_url + "/" == exc.new_url.split("?")[0]:
buf.append(' The URL was defined with a trailing slash so ' buf.append(
'Flask will automatically redirect to the URL ' " The URL was defined with a trailing slash so "
'with the trailing slash if it was accessed ' "Flask will automatically redirect to the URL "
'without one.') "with the trailing slash if it was accessed "
"without one."
)
buf.append(' Make sure to directly send your %s-request to this URL ' buf.append(
'since we can\'t make browsers or HTTP clients redirect ' " Make sure to directly send your %s-request to this URL "
'with form data reliably or without user interaction.' % "since we can't make browsers or HTTP clients redirect "
request.method) "with form data reliably or without user interaction." % request.method
buf.append('\n\nNote: this exception is only raised in debug mode') )
AssertionError.__init__(self, ''.join(buf).encode('utf-8')) buf.append("\n\nNote: this exception is only raised in debug mode")
AssertionError.__init__(self, "".join(buf).encode("utf-8"))
def attach_enctype_error_multidict(request): def attach_enctype_error_multidict(request):
@ -81,6 +89,7 @@ def attach_enctype_error_multidict(request):
object is accessed. object is accessed.
""" """
oldcls = request.files.__class__ oldcls = request.files.__class__
class newcls(oldcls): class newcls(oldcls):
def __getitem__(self, key): def __getitem__(self, key):
try: try:
@ -89,26 +98,27 @@ def attach_enctype_error_multidict(request):
if key not in request.form: if key not in request.form:
raise raise
raise DebugFilesKeyError(request, key) raise DebugFilesKeyError(request, key)
newcls.__name__ = oldcls.__name__ newcls.__name__ = oldcls.__name__
newcls.__module__ = oldcls.__module__ newcls.__module__ = oldcls.__module__
request.files.__class__ = newcls request.files.__class__ = newcls
def _dump_loader_info(loader): def _dump_loader_info(loader):
yield 'class: %s.%s' % (type(loader).__module__, type(loader).__name__) yield "class: %s.%s" % (type(loader).__module__, type(loader).__name__)
for key, value in sorted(loader.__dict__.items()): for key, value in sorted(loader.__dict__.items()):
if key.startswith('_'): if key.startswith("_"):
continue continue
if isinstance(value, (tuple, list)): if isinstance(value, (tuple, list)):
if not all(isinstance(x, (str, text_type)) for x in value): if not all(isinstance(x, (str, text_type)) for x in value):
continue continue
yield '%s:' % key yield "%s:" % key
for item in value: for item in value:
yield ' - %s' % item yield " - %s" % item
continue continue
elif not isinstance(value, (str, text_type, int, float, bool)): elif not isinstance(value, (str, text_type, int, float, bool)):
continue continue
yield '%s: %r' % (key, value) yield "%s: %r" % (key, value)
def explain_template_loading_attempts(app, template, attempts): def explain_template_loading_attempts(app, template, attempts):
@ -124,45 +134,50 @@ def explain_template_loading_attempts(app, template, attempts):
if isinstance(srcobj, Flask): if isinstance(srcobj, Flask):
src_info = 'application "%s"' % srcobj.import_name src_info = 'application "%s"' % srcobj.import_name
elif isinstance(srcobj, Blueprint): elif isinstance(srcobj, Blueprint):
src_info = 'blueprint "%s" (%s)' % (srcobj.name, src_info = 'blueprint "%s" (%s)' % (srcobj.name, srcobj.import_name)
srcobj.import_name)
else: else:
src_info = repr(srcobj) src_info = repr(srcobj)
info.append('% 5d: trying loader of %s' % ( info.append("% 5d: trying loader of %s" % (idx + 1, src_info))
idx + 1, src_info))
for line in _dump_loader_info(loader): for line in _dump_loader_info(loader):
info.append(' %s' % line) info.append(" %s" % line)
if triple is None: if triple is None:
detail = 'no match' detail = "no match"
else: else:
detail = 'found (%r)' % (triple[1] or '<string>') detail = "found (%r)" % (triple[1] or "<string>")
total_found += 1 total_found += 1
info.append(' -> %s' % detail) info.append(" -> %s" % detail)
seems_fishy = False seems_fishy = False
if total_found == 0: if total_found == 0:
info.append('Error: the template could not be found.') info.append("Error: the template could not be found.")
seems_fishy = True seems_fishy = True
elif total_found > 1: elif total_found > 1:
info.append('Warning: multiple loaders returned a match for the template.') info.append("Warning: multiple loaders returned a match for the template.")
seems_fishy = True seems_fishy = True
if blueprint is not None and seems_fishy: if blueprint is not None and seems_fishy:
info.append(' The template was looked up from an endpoint that ' info.append(
'belongs to the blueprint "%s".' % blueprint) " The template was looked up from an endpoint that "
info.append(' Maybe you did not place a template in the right folder?') 'belongs to the blueprint "%s".' % blueprint
info.append(' See http://flask.pocoo.org/docs/blueprints/#templates') )
info.append(" Maybe you did not place a template in the right folder?")
info.append(" See http://flask.pocoo.org/docs/blueprints/#templates")
app.logger.info('\n'.join(info)) app.logger.info("\n".join(info))
def explain_ignored_app_run(): def explain_ignored_app_run():
if os.environ.get('WERKZEUG_RUN_MAIN') != 'true': if os.environ.get("WERKZEUG_RUN_MAIN") != "true":
warn(Warning('Silently ignoring app.run() because the ' warn(
'application is run from the flask command line ' Warning(
'executable. Consider putting app.run() behind an ' "Silently ignoring app.run() because the "
'if __name__ == "__main__" guard to silence this ' "application is run from the flask command line "
'warning.'), stacklevel=3) "executable. Consider putting app.run() behind an "
'if __name__ == "__main__" guard to silence this '
"warning."
),
stacklevel=3,
)

View file

@ -14,21 +14,21 @@ from functools import partial
from werkzeug.local import LocalStack, LocalProxy from werkzeug.local import LocalStack, LocalProxy
_request_ctx_err_msg = '''\ _request_ctx_err_msg = """\
Working outside of request context. Working outside of request context.
This typically means that you attempted to use functionality that needed This typically means that you attempted to use functionality that needed
an active HTTP request. Consult the documentation on testing for an active HTTP request. Consult the documentation on testing for
information about how to avoid this problem.\ information about how to avoid this problem.\
''' """
_app_ctx_err_msg = '''\ _app_ctx_err_msg = """\
Working outside of application context. Working outside of application context.
This typically means that you attempted to use functionality that needed This typically means that you attempted to use functionality that needed
to interface with the current application object in some way. To solve to interface with the current application object in some way. To solve
this, set up an application context with app.app_context(). See the this, set up an application context with app.app_context(). See the
documentation for more information.\ documentation for more information.\
''' """
def _lookup_req_object(name): def _lookup_req_object(name):
@ -56,6 +56,6 @@ def _find_app():
_request_ctx_stack = LocalStack() _request_ctx_stack = LocalStack()
_app_ctx_stack = LocalStack() _app_ctx_stack = LocalStack()
current_app = LocalProxy(_find_app) current_app = LocalProxy(_find_app)
request = LocalProxy(partial(_lookup_req_object, 'request')) request = LocalProxy(partial(_lookup_req_object, "request"))
session = LocalProxy(partial(_lookup_req_object, 'session')) session = LocalProxy(partial(_lookup_req_object, "session"))
g = LocalProxy(partial(_lookup_app_object, 'g')) g = LocalProxy(partial(_lookup_app_object, "g"))

View file

@ -24,15 +24,13 @@ from functools import update_wrapper
from werkzeug.urls import url_quote from werkzeug.urls import url_quote
from werkzeug.datastructures import Headers, Range from werkzeug.datastructures import Headers, Range
from werkzeug.exceptions import BadRequest, NotFound, \ from werkzeug.exceptions import BadRequest, NotFound, RequestedRangeNotSatisfiable
RequestedRangeNotSatisfiable
from werkzeug.wsgi import wrap_file from werkzeug.wsgi import wrap_file
from jinja2 import FileSystemLoader from jinja2 import FileSystemLoader
from .signals import message_flashed from .signals import message_flashed
from .globals import session, _request_ctx_stack, _app_ctx_stack, \ from .globals import session, _request_ctx_stack, _app_ctx_stack, current_app, request
current_app, request
from ._compat import string_types, text_type, PY2, fspath from ._compat import string_types, text_type, PY2, fspath
# sentinel # sentinel
@ -42,8 +40,9 @@ _missing = object()
# what separators does this operating system provide that are not a slash? # what separators does this operating system provide that are not a slash?
# this is used by the send_from_directory function to ensure that nobody is # this is used by the send_from_directory function to ensure that nobody is
# able to access files from outside the filesystem. # able to access files from outside the filesystem.
_os_alt_seps = list(sep for sep in [os.path.sep, os.path.altsep] _os_alt_seps = list(
if sep not in (None, '/')) sep for sep in [os.path.sep, os.path.altsep] if sep not in (None, "/")
)
def get_env(): def get_env():
@ -51,7 +50,7 @@ def get_env():
:envvar:`FLASK_ENV` environment variable. The default is :envvar:`FLASK_ENV` environment variable. The default is
``'production'``. ``'production'``.
""" """
return os.environ.get('FLASK_ENV') or 'production' return os.environ.get("FLASK_ENV") or "production"
def get_debug_flag(): def get_debug_flag():
@ -60,12 +59,12 @@ def get_debug_flag():
``True`` if :func:`.get_env` returns ``'development'``, or ``False`` ``True`` if :func:`.get_env` returns ``'development'``, or ``False``
otherwise. otherwise.
""" """
val = os.environ.get('FLASK_DEBUG') val = os.environ.get("FLASK_DEBUG")
if not val: if not val:
return get_env() == 'development' return get_env() == "development"
return val.lower() not in ('0', 'false', 'no') return val.lower() not in ("0", "false", "no")
def get_load_dotenv(default=True): def get_load_dotenv(default=True):
@ -75,20 +74,19 @@ def get_load_dotenv(default=True):
:param default: What to return if the env var isn't set. :param default: What to return if the env var isn't set.
""" """
val = os.environ.get('FLASK_SKIP_DOTENV') val = os.environ.get("FLASK_SKIP_DOTENV")
if not val: if not val:
return default return default
return val.lower() in ('0', 'false', 'no') return val.lower() in ("0", "false", "no")
def _endpoint_from_view_func(view_func): def _endpoint_from_view_func(view_func):
"""Internal helper that returns the default endpoint for a given """Internal helper that returns the default endpoint for a given
function. This always is the function name. function. This always is the function name.
""" """
assert view_func is not None, 'expected view func if endpoint ' \ assert view_func is not None, "expected view func if endpoint " "is not provided."
'is not provided.'
return view_func.__name__ return view_func.__name__
@ -129,16 +127,20 @@ def stream_with_context(generator_or_function):
try: try:
gen = iter(generator_or_function) gen = iter(generator_or_function)
except TypeError: except TypeError:
def decorator(*args, **kwargs): def decorator(*args, **kwargs):
gen = generator_or_function(*args, **kwargs) gen = generator_or_function(*args, **kwargs)
return stream_with_context(gen) return stream_with_context(gen)
return update_wrapper(decorator, generator_or_function) return update_wrapper(decorator, generator_or_function)
def generator(): def generator():
ctx = _request_ctx_stack.top ctx = _request_ctx_stack.top
if ctx is None: if ctx is None:
raise RuntimeError('Attempted to stream with context but ' raise RuntimeError(
'there was no context in the first place to keep around.') "Attempted to stream with context but "
"there was no context in the first place to keep around."
)
with ctx: with ctx:
# Dummy sentinel. Has to be inside the context block or we're # Dummy sentinel. Has to be inside the context block or we're
# not actually keeping the context around. # not actually keeping the context around.
@ -152,7 +154,7 @@ def stream_with_context(generator_or_function):
for item in gen: for item in gen:
yield item yield item
finally: finally:
if hasattr(gen, 'close'): if hasattr(gen, "close"):
gen.close() gen.close()
# The trick is to start the generator. Then the code execution runs until # The trick is to start the generator. Then the code execution runs until
@ -291,9 +293,9 @@ def url_for(endpoint, **values):
if appctx is None: if appctx is None:
raise RuntimeError( raise RuntimeError(
'Attempted to generate a URL without the application context being' "Attempted to generate a URL without the application context being"
' pushed. This has to be executed when application context is' " pushed. This has to be executed when application context is"
' available.' " available."
) )
# If request specific information is available we have some extra # If request specific information is available we have some extra
@ -302,13 +304,13 @@ def url_for(endpoint, **values):
url_adapter = reqctx.url_adapter url_adapter = reqctx.url_adapter
blueprint_name = request.blueprint blueprint_name = request.blueprint
if endpoint[:1] == '.': if endpoint[:1] == ".":
if blueprint_name is not None: if blueprint_name is not None:
endpoint = blueprint_name + endpoint endpoint = blueprint_name + endpoint
else: else:
endpoint = endpoint[1:] endpoint = endpoint[1:]
external = values.pop('_external', False) external = values.pop("_external", False)
# Otherwise go with the url adapter from the appctx and make # Otherwise go with the url adapter from the appctx and make
# the URLs external by default. # the URLs external by default.
@ -317,16 +319,16 @@ def url_for(endpoint, **values):
if url_adapter is None: if url_adapter is None:
raise RuntimeError( raise RuntimeError(
'Application was not able to create a URL adapter for request' "Application was not able to create a URL adapter for request"
' independent URL generation. You might be able to fix this by' " independent URL generation. You might be able to fix this by"
' setting the SERVER_NAME config variable.' " setting the SERVER_NAME config variable."
) )
external = values.pop('_external', True) external = values.pop("_external", True)
anchor = values.pop('_anchor', None) anchor = values.pop("_anchor", None)
method = values.pop('_method', None) method = values.pop("_method", None)
scheme = values.pop('_scheme', None) scheme = values.pop("_scheme", None)
appctx.app.inject_url_defaults(endpoint, values) appctx.app.inject_url_defaults(endpoint, values)
# This is not the best way to deal with this but currently the # This is not the best way to deal with this but currently the
@ -335,28 +337,29 @@ def url_for(endpoint, **values):
old_scheme = None old_scheme = None
if scheme is not None: if scheme is not None:
if not external: if not external:
raise ValueError('When specifying _scheme, _external must be True') raise ValueError("When specifying _scheme, _external must be True")
old_scheme = url_adapter.url_scheme old_scheme = url_adapter.url_scheme
url_adapter.url_scheme = scheme url_adapter.url_scheme = scheme
try: try:
try: try:
rv = url_adapter.build(endpoint, values, method=method, rv = url_adapter.build(
force_external=external) endpoint, values, method=method, force_external=external
)
finally: finally:
if old_scheme is not None: if old_scheme is not None:
url_adapter.url_scheme = old_scheme url_adapter.url_scheme = old_scheme
except BuildError as error: except BuildError as error:
# We need to inject the values again so that the app callback can # We need to inject the values again so that the app callback can
# deal with that sort of stuff. # deal with that sort of stuff.
values['_external'] = external values["_external"] = external
values['_anchor'] = anchor values["_anchor"] = anchor
values['_method'] = method values["_method"] = method
values['_scheme'] = scheme values["_scheme"] = scheme
return appctx.app.handle_url_build_error(error, endpoint, values) return appctx.app.handle_url_build_error(error, endpoint, values)
if anchor is not None: if anchor is not None:
rv += '#' + url_quote(anchor) rv += "#" + url_quote(anchor)
return rv return rv
@ -379,11 +382,10 @@ def get_template_attribute(template_name, attribute):
:param template_name: the name of the template :param template_name: the name of the template
:param attribute: the name of the variable of macro to access :param attribute: the name of the variable of macro to access
""" """
return getattr(current_app.jinja_env.get_template(template_name).module, return getattr(current_app.jinja_env.get_template(template_name).module, attribute)
attribute)
def flash(message, category='message'): def flash(message, category="message"):
"""Flashes a message to the next request. In order to remove the """Flashes a message to the next request. In order to remove the
flashed message from the session and to display it to the user, flashed message from the session and to display it to the user,
the template has to call :func:`get_flashed_messages`. the template has to call :func:`get_flashed_messages`.
@ -405,11 +407,12 @@ def flash(message, category='message'):
# This assumed that changes made to mutable structures in the session are # This assumed that changes made to mutable structures in the session are
# always in sync with the session object, which is not true for session # always in sync with the session object, which is not true for session
# implementations that use external storage for keeping their keys/values. # implementations that use external storage for keeping their keys/values.
flashes = session.get('_flashes', []) flashes = session.get("_flashes", [])
flashes.append((category, message)) flashes.append((category, message))
session['_flashes'] = flashes session["_flashes"] = flashes
message_flashed.send(current_app._get_current_object(), message_flashed.send(
message=message, category=category) current_app._get_current_object(), message=message, category=category
)
def get_flashed_messages(with_categories=False, category_filter=[]): def get_flashed_messages(with_categories=False, category_filter=[]):
@ -442,8 +445,9 @@ def get_flashed_messages(with_categories=False, category_filter=[]):
""" """
flashes = _request_ctx_stack.top.flashes flashes = _request_ctx_stack.top.flashes
if flashes is None: if flashes is None:
_request_ctx_stack.top.flashes = flashes = session.pop('_flashes') \ _request_ctx_stack.top.flashes = flashes = (
if '_flashes' in session else [] session.pop("_flashes") if "_flashes" in session else []
)
if category_filter: if category_filter:
flashes = list(filter(lambda f: f[0] in category_filter, flashes)) flashes = list(filter(lambda f: f[0] in category_filter, flashes))
if not with_categories: if not with_categories:
@ -451,9 +455,16 @@ def get_flashed_messages(with_categories=False, category_filter=[]):
return flashes return flashes
def send_file(filename_or_fp, mimetype=None, as_attachment=False, def send_file(
attachment_filename=None, add_etags=True, filename_or_fp,
cache_timeout=None, conditional=False, last_modified=None): mimetype=None,
as_attachment=False,
attachment_filename=None,
add_etags=True,
cache_timeout=None,
conditional=False,
last_modified=None,
):
"""Sends the contents of a file to the client. This will use the """Sends the contents of a file to the client. This will use the
most efficient method available and configured. By default it will most efficient method available and configured. By default it will
try to use the WSGI server's file_wrapper support. Alternatively try to use the WSGI server's file_wrapper support. Alternatively
@ -545,7 +556,7 @@ def send_file(filename_or_fp, mimetype=None, as_attachment=False,
mtime = None mtime = None
fsize = None fsize = None
if hasattr(filename_or_fp, '__fspath__'): if hasattr(filename_or_fp, "__fspath__"):
filename_or_fp = fspath(filename_or_fp) filename_or_fp = fspath(filename_or_fp)
if isinstance(filename_or_fp, string_types): if isinstance(filename_or_fp, string_types):
@ -561,62 +572,67 @@ def send_file(filename_or_fp, mimetype=None, as_attachment=False,
if mimetype is None: if mimetype is None:
if attachment_filename is not None: if attachment_filename is not None:
mimetype = mimetypes.guess_type(attachment_filename)[0] \ mimetype = (
or 'application/octet-stream' mimetypes.guess_type(attachment_filename)[0]
or "application/octet-stream"
)
if mimetype is None: if mimetype is None:
raise ValueError( raise ValueError(
'Unable to infer MIME-type because no filename is available. ' "Unable to infer MIME-type because no filename is available. "
'Please set either `attachment_filename`, pass a filepath to ' "Please set either `attachment_filename`, pass a filepath to "
'`filename_or_fp` or set your own MIME-type via `mimetype`.' "`filename_or_fp` or set your own MIME-type via `mimetype`."
) )
headers = Headers() headers = Headers()
if as_attachment: if as_attachment:
if attachment_filename is None: if attachment_filename is None:
raise TypeError('filename unavailable, required for ' raise TypeError(
'sending as attachment') "filename unavailable, required for " "sending as attachment"
)
if not isinstance(attachment_filename, text_type): if not isinstance(attachment_filename, text_type):
attachment_filename = attachment_filename.decode('utf-8') attachment_filename = attachment_filename.decode("utf-8")
try: try:
attachment_filename = attachment_filename.encode('ascii') attachment_filename = attachment_filename.encode("ascii")
except UnicodeEncodeError: except UnicodeEncodeError:
filenames = { filenames = {
'filename': unicodedata.normalize( "filename": unicodedata.normalize("NFKD", attachment_filename).encode(
'NFKD', attachment_filename).encode('ascii', 'ignore'), "ascii", "ignore"
'filename*': "UTF-8''%s" % url_quote(attachment_filename), ),
"filename*": "UTF-8''%s" % url_quote(attachment_filename),
} }
else: else:
filenames = {'filename': attachment_filename} filenames = {"filename": attachment_filename}
headers.add('Content-Disposition', 'attachment', **filenames) headers.add("Content-Disposition", "attachment", **filenames)
if current_app.use_x_sendfile and filename: if current_app.use_x_sendfile and filename:
if file is not None: if file is not None:
file.close() file.close()
headers['X-Sendfile'] = filename headers["X-Sendfile"] = filename
fsize = os.path.getsize(filename) fsize = os.path.getsize(filename)
headers['Content-Length'] = fsize headers["Content-Length"] = fsize
data = None data = None
else: else:
if file is None: if file is None:
file = open(filename, 'rb') file = open(filename, "rb")
mtime = os.path.getmtime(filename) mtime = os.path.getmtime(filename)
fsize = os.path.getsize(filename) fsize = os.path.getsize(filename)
headers['Content-Length'] = fsize headers["Content-Length"] = fsize
elif isinstance(file, io.BytesIO): elif isinstance(file, io.BytesIO):
try: try:
fsize = file.getbuffer().nbytes fsize = file.getbuffer().nbytes
except AttributeError: except AttributeError:
# Python 2 doesn't have getbuffer # Python 2 doesn't have getbuffer
fsize = len(file.getvalue()) fsize = len(file.getvalue())
headers['Content-Length'] = fsize headers["Content-Length"] = fsize
data = wrap_file(request.environ, file) data = wrap_file(request.environ, file)
rv = current_app.response_class(data, mimetype=mimetype, headers=headers, rv = current_app.response_class(
direct_passthrough=True) data, mimetype=mimetype, headers=headers, direct_passthrough=True
)
if last_modified is not None: if last_modified is not None:
rv.last_modified = last_modified rv.last_modified = last_modified
@ -634,22 +650,29 @@ def send_file(filename_or_fp, mimetype=None, as_attachment=False,
from warnings import warn from warnings import warn
try: try:
rv.set_etag('%s-%s-%s' % ( rv.set_etag(
os.path.getmtime(filename), "%s-%s-%s"
os.path.getsize(filename), % (
adler32( os.path.getmtime(filename),
filename.encode('utf-8') if isinstance(filename, text_type) os.path.getsize(filename),
else filename adler32(
) & 0xffffffff filename.encode("utf-8")
)) if isinstance(filename, text_type)
else filename
)
& 0xFFFFFFFF,
)
)
except OSError: except OSError:
warn('Access %s failed, maybe it does not exist, so ignore etags in ' warn(
'headers' % filename, stacklevel=2) "Access %s failed, maybe it does not exist, so ignore etags in "
"headers" % filename,
stacklevel=2,
)
if conditional: if conditional:
try: try:
rv = rv.make_conditional(request, accept_ranges=True, rv = rv.make_conditional(request, accept_ranges=True, complete_length=fsize)
complete_length=fsize)
except RequestedRangeNotSatisfiable: except RequestedRangeNotSatisfiable:
if file is not None: if file is not None:
file.close() file.close()
@ -657,7 +680,7 @@ def send_file(filename_or_fp, mimetype=None, as_attachment=False,
# make sure we don't send x-sendfile for servers that # make sure we don't send x-sendfile for servers that
# ignore the 304 status code for x-sendfile. # ignore the 304 status code for x-sendfile.
if rv.status_code == 304: if rv.status_code == 304:
rv.headers.pop('x-sendfile', None) rv.headers.pop("x-sendfile", None)
return rv return rv
@ -682,14 +705,14 @@ def safe_join(directory, *pathnames):
parts = [directory] parts = [directory]
for filename in pathnames: for filename in pathnames:
if filename != '': if filename != "":
filename = posixpath.normpath(filename) filename = posixpath.normpath(filename)
if ( if (
any(sep in filename for sep in _os_alt_seps) any(sep in filename for sep in _os_alt_seps)
or os.path.isabs(filename) or os.path.isabs(filename)
or filename == '..' or filename == ".."
or filename.startswith('../') or filename.startswith("../")
): ):
raise NotFound() raise NotFound()
@ -735,7 +758,7 @@ def send_from_directory(directory, filename, **options):
raise NotFound() raise NotFound()
except (TypeError, ValueError): except (TypeError, ValueError):
raise BadRequest() raise BadRequest()
options.setdefault('conditional', True) options.setdefault("conditional", True)
return send_file(filename, **options) return send_file(filename, **options)
@ -747,7 +770,7 @@ def get_root_path(import_name):
""" """
# Module already imported and has a file attribute. Use that first. # Module already imported and has a file attribute. Use that first.
mod = sys.modules.get(import_name) mod = sys.modules.get(import_name)
if mod is not None and hasattr(mod, '__file__'): if mod is not None and hasattr(mod, "__file__"):
return os.path.dirname(os.path.abspath(mod.__file__)) return os.path.dirname(os.path.abspath(mod.__file__))
# Next attempt: check the loader. # Next attempt: check the loader.
@ -756,30 +779,32 @@ def get_root_path(import_name):
# Loader does not exist or we're referring to an unloaded main module # Loader does not exist or we're referring to an unloaded main module
# or a main module without path (interactive sessions), go with the # or a main module without path (interactive sessions), go with the
# current working directory. # current working directory.
if loader is None or import_name == '__main__': if loader is None or import_name == "__main__":
return os.getcwd() return os.getcwd()
# For .egg, zipimporter does not have get_filename until Python 2.7. # For .egg, zipimporter does not have get_filename until Python 2.7.
# Some other loaders might exhibit the same behavior. # Some other loaders might exhibit the same behavior.
if hasattr(loader, 'get_filename'): if hasattr(loader, "get_filename"):
filepath = loader.get_filename(import_name) filepath = loader.get_filename(import_name)
else: else:
# Fall back to imports. # Fall back to imports.
__import__(import_name) __import__(import_name)
mod = sys.modules[import_name] mod = sys.modules[import_name]
filepath = getattr(mod, '__file__', None) filepath = getattr(mod, "__file__", None)
# If we don't have a filepath it might be because we are a # If we don't have a filepath it might be because we are a
# namespace package. In this case we pick the root path from the # namespace package. In this case we pick the root path from the
# first module that is contained in our package. # first module that is contained in our package.
if filepath is None: if filepath is None:
raise RuntimeError('No root path can be found for the provided ' raise RuntimeError(
'module "%s". This can happen because the ' "No root path can be found for the provided "
'module came from an import hook that does ' 'module "%s". This can happen because the '
'not provide file name information or because ' "module came from an import hook that does "
'it\'s a namespace package. In this case ' "not provide file name information or because "
'the root path needs to be explicitly ' "it's a namespace package. In this case "
'provided.' % import_name) "the root path needs to be explicitly "
"provided." % import_name
)
# filepath is import_name.py for a module, or __init__.py for a package. # filepath is import_name.py for a module, or __init__.py for a package.
return os.path.dirname(os.path.abspath(filepath)) return os.path.dirname(os.path.abspath(filepath))
@ -791,21 +816,26 @@ def _matching_loader_thinks_module_is_package(loader, mod_name):
""" """
# If the loader can tell us if something is a package, we can # If the loader can tell us if something is a package, we can
# directly ask the loader. # directly ask the loader.
if hasattr(loader, 'is_package'): if hasattr(loader, "is_package"):
return loader.is_package(mod_name) return loader.is_package(mod_name)
# importlib's namespace loaders do not have this functionality but # importlib's namespace loaders do not have this functionality but
# all the modules it loads are packages, so we can take advantage of # all the modules it loads are packages, so we can take advantage of
# this information. # this information.
elif (loader.__class__.__module__ == '_frozen_importlib' and elif (
loader.__class__.__name__ == 'NamespaceLoader'): loader.__class__.__module__ == "_frozen_importlib"
and loader.__class__.__name__ == "NamespaceLoader"
):
return True return True
# Otherwise we need to fail with an error that explains what went # Otherwise we need to fail with an error that explains what went
# wrong. # wrong.
raise AttributeError( raise AttributeError(
('%s.is_package() method is missing but is required by Flask of ' (
'PEP 302 import hooks. If you do not use import hooks and ' "%s.is_package() method is missing but is required by Flask of "
'you encounter this error please file a bug against Flask.') % "PEP 302 import hooks. If you do not use import hooks and "
loader.__class__.__name__) "you encounter this error please file a bug against Flask."
)
% loader.__class__.__name__
)
def find_package(import_name): def find_package(import_name):
@ -816,16 +846,16 @@ def find_package(import_name):
import the module. The prefix is the path below which a UNIX like import the module. The prefix is the path below which a UNIX like
folder structure exists (lib, share etc.). folder structure exists (lib, share etc.).
""" """
root_mod_name = import_name.split('.')[0] root_mod_name = import_name.split(".")[0]
loader = pkgutil.get_loader(root_mod_name) loader = pkgutil.get_loader(root_mod_name)
if loader is None or import_name == '__main__': if loader is None or import_name == "__main__":
# import name is not found, or interactive/main module # import name is not found, or interactive/main module
package_path = os.getcwd() package_path = os.getcwd()
else: else:
# For .egg, zipimporter does not have get_filename until Python 2.7. # For .egg, zipimporter does not have get_filename until Python 2.7.
if hasattr(loader, 'get_filename'): if hasattr(loader, "get_filename"):
filename = loader.get_filename(root_mod_name) filename = loader.get_filename(root_mod_name)
elif hasattr(loader, 'archive'): elif hasattr(loader, "archive"):
# zipimporter's loader.archive points to the .egg or .zip # zipimporter's loader.archive points to the .egg or .zip
# archive filename is dropped in call to dirname below. # archive filename is dropped in call to dirname below.
filename = loader.archive filename = loader.archive
@ -841,21 +871,20 @@ def find_package(import_name):
# In case the root module is a package we need to chop of the # In case the root module is a package we need to chop of the
# rightmost part. This needs to go through a helper function # rightmost part. This needs to go through a helper function
# because of python 3.3 namespace packages. # because of python 3.3 namespace packages.
if _matching_loader_thinks_module_is_package( if _matching_loader_thinks_module_is_package(loader, root_mod_name):
loader, root_mod_name):
package_path = os.path.dirname(package_path) package_path = os.path.dirname(package_path)
site_parent, site_folder = os.path.split(package_path) site_parent, site_folder = os.path.split(package_path)
py_prefix = os.path.abspath(sys.prefix) py_prefix = os.path.abspath(sys.prefix)
if package_path.startswith(py_prefix): if package_path.startswith(py_prefix):
return py_prefix, package_path return py_prefix, package_path
elif site_folder.lower() == 'site-packages': elif site_folder.lower() == "site-packages":
parent, folder = os.path.split(site_parent) parent, folder = os.path.split(site_parent)
# Windows like installations # Windows like installations
if folder.lower() == 'lib': if folder.lower() == "lib":
base_dir = parent base_dir = parent
# UNIX like installations # UNIX like installations
elif os.path.basename(parent).lower() == 'lib': elif os.path.basename(parent).lower() == "lib":
base_dir = os.path.dirname(parent) base_dir = os.path.dirname(parent)
else: else:
base_dir = site_parent base_dir = site_parent
@ -921,8 +950,9 @@ class _PackageBoundObject(object):
self._static_folder = value self._static_folder = value
static_folder = property( static_folder = property(
_get_static_folder, _set_static_folder, _get_static_folder,
doc='The absolute path to the configured static folder.' _set_static_folder,
doc="The absolute path to the configured static folder.",
) )
del _get_static_folder, _set_static_folder del _get_static_folder, _set_static_folder
@ -931,14 +961,15 @@ class _PackageBoundObject(object):
return self._static_url_path return self._static_url_path
if self.static_folder is not None: if self.static_folder is not None:
return '/' + os.path.basename(self.static_folder) return "/" + os.path.basename(self.static_folder)
def _set_static_url_path(self, value): def _set_static_url_path(self, value):
self._static_url_path = value self._static_url_path = value
static_url_path = property( static_url_path = property(
_get_static_url_path, _set_static_url_path, _get_static_url_path,
doc='The URL prefix that the static route will be registered for.' _set_static_url_path,
doc="The URL prefix that the static route will be registered for.",
) )
del _get_static_url_path, _set_static_url_path del _get_static_url_path, _set_static_url_path
@ -958,8 +989,7 @@ class _PackageBoundObject(object):
.. versionadded:: 0.5 .. versionadded:: 0.5
""" """
if self.template_folder is not None: if self.template_folder is not None:
return FileSystemLoader(os.path.join(self.root_path, return FileSystemLoader(os.path.join(self.root_path, self.template_folder))
self.template_folder))
def get_send_file_max_age(self, filename): def get_send_file_max_age(self, filename):
"""Provides default cache_timeout for the :func:`send_file` functions. """Provides default cache_timeout for the :func:`send_file` functions.
@ -994,14 +1024,15 @@ class _PackageBoundObject(object):
.. versionadded:: 0.5 .. versionadded:: 0.5
""" """
if not self.has_static_folder: if not self.has_static_folder:
raise RuntimeError('No static folder for this object') raise RuntimeError("No static folder for this object")
# Ensure get_send_file_max_age is called in all cases. # Ensure get_send_file_max_age is called in all cases.
# Here, we ensure get_send_file_max_age is called for Blueprints. # Here, we ensure get_send_file_max_age is called for Blueprints.
cache_timeout = self.get_send_file_max_age(filename) cache_timeout = self.get_send_file_max_age(filename)
return send_from_directory(self.static_folder, filename, return send_from_directory(
cache_timeout=cache_timeout) self.static_folder, filename, cache_timeout=cache_timeout
)
def open_resource(self, resource, mode='rb'): def open_resource(self, resource, mode="rb"):
"""Opens a resource from the application's resource folder. To see """Opens a resource from the application's resource folder. To see
how this works, consider the following folder structure:: how this works, consider the following folder structure::
@ -1024,8 +1055,8 @@ class _PackageBoundObject(object):
subfolders use forward slashes as separator. subfolders use forward slashes as separator.
:param mode: resource file opening mode, default is 'rb'. :param mode: resource file opening mode, default is 'rb'.
""" """
if mode not in ('r', 'rb'): if mode not in ("r", "rb"):
raise ValueError('Resources can only be opened for reading') raise ValueError("Resources can only be opened for reading")
return open(os.path.join(self.root_path, resource), mode) return open(os.path.join(self.root_path, resource), mode)
@ -1052,7 +1083,7 @@ def is_ip(value):
:return: True if string is an IP address :return: True if string is an IP address
:rtype: bool :rtype: bool
""" """
if PY2 and os.name == 'nt': if PY2 and os.name == "nt":
try: try:
socket.inet_aton(value) socket.inet_aton(value)
return True return True

View file

@ -23,12 +23,20 @@ from itsdangerous import json as _json
# Figure out if simplejson escapes slashes. This behavior was changed # Figure out if simplejson escapes slashes. This behavior was changed
# from one version to another without reason. # from one version to another without reason.
_slash_escape = '\\/' not in _json.dumps('/') _slash_escape = "\\/" not in _json.dumps("/")
__all__ = ['dump', 'dumps', 'load', 'loads', 'htmlsafe_dump', __all__ = [
'htmlsafe_dumps', 'JSONDecoder', 'JSONEncoder', "dump",
'jsonify'] "dumps",
"load",
"loads",
"htmlsafe_dump",
"htmlsafe_dumps",
"JSONDecoder",
"JSONEncoder",
"jsonify",
]
def _wrap_reader_for_text(fp, encoding): def _wrap_reader_for_text(fp, encoding):
@ -39,7 +47,7 @@ def _wrap_reader_for_text(fp, encoding):
def _wrap_writer_for_text(fp, encoding): def _wrap_writer_for_text(fp, encoding):
try: try:
fp.write('') fp.write("")
except TypeError: except TypeError:
fp = io.TextIOWrapper(fp, encoding) fp = io.TextIOWrapper(fp, encoding)
return fp return fp
@ -76,7 +84,7 @@ class JSONEncoder(_json.JSONEncoder):
return http_date(o.timetuple()) return http_date(o.timetuple())
if isinstance(o, uuid.UUID): if isinstance(o, uuid.UUID):
return str(o) return str(o)
if hasattr(o, '__html__'): if hasattr(o, "__html__"):
return text_type(o.__html__()) return text_type(o.__html__())
return _json.JSONEncoder.default(self, o) return _json.JSONEncoder.default(self, o)
@ -94,18 +102,17 @@ def _dump_arg_defaults(kwargs):
if current_app: if current_app:
bp = current_app.blueprints.get(request.blueprint) if request else None bp = current_app.blueprints.get(request.blueprint) if request else None
kwargs.setdefault( kwargs.setdefault(
'cls', "cls",
bp.json_encoder if bp and bp.json_encoder bp.json_encoder if bp and bp.json_encoder else current_app.json_encoder,
else current_app.json_encoder
) )
if not current_app.config['JSON_AS_ASCII']: if not current_app.config["JSON_AS_ASCII"]:
kwargs.setdefault('ensure_ascii', False) kwargs.setdefault("ensure_ascii", False)
kwargs.setdefault('sort_keys', current_app.config['JSON_SORT_KEYS']) kwargs.setdefault("sort_keys", current_app.config["JSON_SORT_KEYS"])
else: else:
kwargs.setdefault('sort_keys', True) kwargs.setdefault("sort_keys", True)
kwargs.setdefault('cls', JSONEncoder) kwargs.setdefault("cls", JSONEncoder)
def _load_arg_defaults(kwargs): def _load_arg_defaults(kwargs):
@ -113,12 +120,11 @@ def _load_arg_defaults(kwargs):
if current_app: if current_app:
bp = current_app.blueprints.get(request.blueprint) if request else None bp = current_app.blueprints.get(request.blueprint) if request else None
kwargs.setdefault( kwargs.setdefault(
'cls', "cls",
bp.json_decoder if bp and bp.json_decoder bp.json_decoder if bp and bp.json_decoder else current_app.json_decoder,
else current_app.json_decoder
) )
else: else:
kwargs.setdefault('cls', JSONDecoder) kwargs.setdefault("cls", JSONDecoder)
def detect_encoding(data): def detect_encoding(data):
@ -134,34 +140,34 @@ def detect_encoding(data):
head = data[:4] head = data[:4]
if head[:3] == codecs.BOM_UTF8: if head[:3] == codecs.BOM_UTF8:
return 'utf-8-sig' return "utf-8-sig"
if b'\x00' not in head: if b"\x00" not in head:
return 'utf-8' return "utf-8"
if head in (codecs.BOM_UTF32_BE, codecs.BOM_UTF32_LE): if head in (codecs.BOM_UTF32_BE, codecs.BOM_UTF32_LE):
return 'utf-32' return "utf-32"
if head[:2] in (codecs.BOM_UTF16_BE, codecs.BOM_UTF16_LE): if head[:2] in (codecs.BOM_UTF16_BE, codecs.BOM_UTF16_LE):
return 'utf-16' return "utf-16"
if len(head) == 4: if len(head) == 4:
if head[:3] == b'\x00\x00\x00': if head[:3] == b"\x00\x00\x00":
return 'utf-32-be' return "utf-32-be"
if head[::2] == b'\x00\x00': if head[::2] == b"\x00\x00":
return 'utf-16-be' return "utf-16-be"
if head[1:] == b'\x00\x00\x00': if head[1:] == b"\x00\x00\x00":
return 'utf-32-le' return "utf-32-le"
if head[1::2] == b'\x00\x00': if head[1::2] == b"\x00\x00":
return 'utf-16-le' return "utf-16-le"
if len(head) == 2: if len(head) == 2:
return 'utf-16-be' if head.startswith(b'\x00') else 'utf-16-le' return "utf-16-be" if head.startswith(b"\x00") else "utf-16-le"
return 'utf-8' return "utf-8"
def dumps(obj, **kwargs): def dumps(obj, **kwargs):
@ -175,7 +181,7 @@ def dumps(obj, **kwargs):
and can be overridden by the simplejson ``ensure_ascii`` parameter. and can be overridden by the simplejson ``ensure_ascii`` parameter.
""" """
_dump_arg_defaults(kwargs) _dump_arg_defaults(kwargs)
encoding = kwargs.pop('encoding', None) encoding = kwargs.pop("encoding", None)
rv = _json.dumps(obj, **kwargs) rv = _json.dumps(obj, **kwargs)
if encoding is not None and isinstance(rv, text_type): if encoding is not None and isinstance(rv, text_type):
rv = rv.encode(encoding) rv = rv.encode(encoding)
@ -185,7 +191,7 @@ def dumps(obj, **kwargs):
def dump(obj, fp, **kwargs): def dump(obj, fp, **kwargs):
"""Like :func:`dumps` but writes into a file object.""" """Like :func:`dumps` but writes into a file object."""
_dump_arg_defaults(kwargs) _dump_arg_defaults(kwargs)
encoding = kwargs.pop('encoding', None) encoding = kwargs.pop("encoding", None)
if encoding is not None: if encoding is not None:
fp = _wrap_writer_for_text(fp, encoding) fp = _wrap_writer_for_text(fp, encoding)
_json.dump(obj, fp, **kwargs) _json.dump(obj, fp, **kwargs)
@ -198,7 +204,7 @@ def loads(s, **kwargs):
""" """
_load_arg_defaults(kwargs) _load_arg_defaults(kwargs)
if isinstance(s, bytes): if isinstance(s, bytes):
encoding = kwargs.pop('encoding', None) encoding = kwargs.pop("encoding", None)
if encoding is None: if encoding is None:
encoding = detect_encoding(s) encoding = detect_encoding(s)
s = s.decode(encoding) s = s.decode(encoding)
@ -210,7 +216,7 @@ def load(fp, **kwargs):
""" """
_load_arg_defaults(kwargs) _load_arg_defaults(kwargs)
if not PY2: if not PY2:
fp = _wrap_reader_for_text(fp, kwargs.pop('encoding', None) or 'utf-8') fp = _wrap_reader_for_text(fp, kwargs.pop("encoding", None) or "utf-8")
return _json.load(fp, **kwargs) return _json.load(fp, **kwargs)
@ -239,13 +245,15 @@ def htmlsafe_dumps(obj, **kwargs):
quoted. Always single quote attributes if you use the ``|tojson`` quoted. Always single quote attributes if you use the ``|tojson``
filter. Alternatively use ``|tojson|forceescape``. filter. Alternatively use ``|tojson|forceescape``.
""" """
rv = dumps(obj, **kwargs) \ rv = (
.replace(u'<', u'\\u003c') \ dumps(obj, **kwargs)
.replace(u'>', u'\\u003e') \ .replace(u"<", u"\\u003c")
.replace(u'&', u'\\u0026') \ .replace(u">", u"\\u003e")
.replace(u"'", u'\\u0027') .replace(u"&", u"\\u0026")
.replace(u"'", u"\\u0027")
)
if not _slash_escape: if not _slash_escape:
rv = rv.replace('\\/', '/') rv = rv.replace("\\/", "/")
return rv return rv
@ -304,22 +312,22 @@ def jsonify(*args, **kwargs):
""" """
indent = None indent = None
separators = (',', ':') separators = (",", ":")
if current_app.config['JSONIFY_PRETTYPRINT_REGULAR'] or current_app.debug: if current_app.config["JSONIFY_PRETTYPRINT_REGULAR"] or current_app.debug:
indent = 2 indent = 2
separators = (', ', ': ') separators = (", ", ": ")
if args and kwargs: if args and kwargs:
raise TypeError('jsonify() behavior undefined when passed both args and kwargs') raise TypeError("jsonify() behavior undefined when passed both args and kwargs")
elif len(args) == 1: # single args are passed directly to dumps() elif len(args) == 1: # single args are passed directly to dumps()
data = args[0] data = args[0]
else: else:
data = args or kwargs data = args or kwargs
return current_app.response_class( return current_app.response_class(
dumps(data, indent=indent, separators=separators) + '\n', dumps(data, indent=indent, separators=separators) + "\n",
mimetype=current_app.config['JSONIFY_MIMETYPE'] mimetype=current_app.config["JSONIFY_MIMETYPE"],
) )

View file

@ -56,7 +56,7 @@ from flask.json import dumps, loads
class JSONTag(object): class JSONTag(object):
"""Base class for defining type tags for :class:`TaggedJSONSerializer`.""" """Base class for defining type tags for :class:`TaggedJSONSerializer`."""
__slots__ = ('serializer',) __slots__ = ("serializer",)
#: The tag to mark the serialized object with. If ``None``, this tag is #: The tag to mark the serialized object with. If ``None``, this tag is
#: only used as an intermediate step during tagging. #: only used as an intermediate step during tagging.
@ -94,7 +94,7 @@ class TagDict(JSONTag):
""" """
__slots__ = () __slots__ = ()
key = ' di' key = " di"
def check(self, value): def check(self, value):
return ( return (
@ -105,7 +105,7 @@ class TagDict(JSONTag):
def to_json(self, value): def to_json(self, value):
key = next(iter(value)) key = next(iter(value))
return {key + '__': self.serializer.tag(value[key])} return {key + "__": self.serializer.tag(value[key])}
def to_python(self, value): def to_python(self, value):
key = next(iter(value)) key = next(iter(value))
@ -128,7 +128,7 @@ class PassDict(JSONTag):
class TagTuple(JSONTag): class TagTuple(JSONTag):
__slots__ = () __slots__ = ()
key = ' t' key = " t"
def check(self, value): def check(self, value):
return isinstance(value, tuple) return isinstance(value, tuple)
@ -154,13 +154,13 @@ class PassList(JSONTag):
class TagBytes(JSONTag): class TagBytes(JSONTag):
__slots__ = () __slots__ = ()
key = ' b' key = " b"
def check(self, value): def check(self, value):
return isinstance(value, bytes) return isinstance(value, bytes)
def to_json(self, value): def to_json(self, value):
return b64encode(value).decode('ascii') return b64encode(value).decode("ascii")
def to_python(self, value): def to_python(self, value):
return b64decode(value) return b64decode(value)
@ -172,10 +172,10 @@ class TagMarkup(JSONTag):
deserializes to an instance of :class:`~flask.Markup`.""" deserializes to an instance of :class:`~flask.Markup`."""
__slots__ = () __slots__ = ()
key = ' m' key = " m"
def check(self, value): def check(self, value):
return callable(getattr(value, '__html__', None)) return callable(getattr(value, "__html__", None))
def to_json(self, value): def to_json(self, value):
return text_type(value.__html__()) return text_type(value.__html__())
@ -186,7 +186,7 @@ class TagMarkup(JSONTag):
class TagUUID(JSONTag): class TagUUID(JSONTag):
__slots__ = () __slots__ = ()
key = ' u' key = " u"
def check(self, value): def check(self, value):
return isinstance(value, UUID) return isinstance(value, UUID)
@ -200,7 +200,7 @@ class TagUUID(JSONTag):
class TagDateTime(JSONTag): class TagDateTime(JSONTag):
__slots__ = () __slots__ = ()
key = ' d' key = " d"
def check(self, value): def check(self, value):
return isinstance(value, datetime) return isinstance(value, datetime)
@ -227,12 +227,18 @@ class TaggedJSONSerializer(object):
* :class:`~datetime.datetime` * :class:`~datetime.datetime`
""" """
__slots__ = ('tags', 'order') __slots__ = ("tags", "order")
#: Tag classes to bind when creating the serializer. Other tags can be #: Tag classes to bind when creating the serializer. Other tags can be
#: added later using :meth:`~register`. #: added later using :meth:`~register`.
default_tags = [ default_tags = [
TagDict, PassDict, TagTuple, PassList, TagBytes, TagMarkup, TagUUID, TagDict,
PassDict,
TagTuple,
PassList,
TagBytes,
TagMarkup,
TagUUID,
TagDateTime, TagDateTime,
] ]
@ -293,7 +299,7 @@ class TaggedJSONSerializer(object):
def dumps(self, value): def dumps(self, value):
"""Tag the value and dump it to a compact JSON string.""" """Tag the value and dump it to a compact JSON string."""
return dumps(self.tag(value), separators=(',', ':')) return dumps(self.tag(value), separators=(",", ":"))
def loads(self, value): def loads(self, value):
"""Load data from a JSON string and deserialized any tagged objects.""" """Load data from a JSON string and deserialized any tagged objects."""

View file

@ -27,7 +27,7 @@ def wsgi_errors_stream():
can't import this directly, you can refer to it as can't import this directly, you can refer to it as
``ext://flask.logging.wsgi_errors_stream``. ``ext://flask.logging.wsgi_errors_stream``.
""" """
return request.environ['wsgi.errors'] if request else sys.stderr return request.environ["wsgi.errors"] if request else sys.stderr
def has_level_handler(logger): def has_level_handler(logger):
@ -52,9 +52,9 @@ def has_level_handler(logger):
#: Log messages to :func:`~flask.logging.wsgi_errors_stream` with the format #: Log messages to :func:`~flask.logging.wsgi_errors_stream` with the format
#: ``[%(asctime)s] %(levelname)s in %(module)s: %(message)s``. #: ``[%(asctime)s] %(levelname)s in %(module)s: %(message)s``.
default_handler = logging.StreamHandler(wsgi_errors_stream) default_handler = logging.StreamHandler(wsgi_errors_stream)
default_handler.setFormatter(logging.Formatter( default_handler.setFormatter(
'[%(asctime)s] %(levelname)s in %(module)s: %(message)s' logging.Formatter("[%(asctime)s] %(levelname)s in %(module)s: %(message)s")
)) )
def create_logger(app): def create_logger(app):
@ -67,7 +67,7 @@ def create_logger(app):
:class:`~logging.StreamHandler` for :class:`~logging.StreamHandler` for
:func:`~flask.logging.wsgi_errors_stream` with a basic format. :func:`~flask.logging.wsgi_errors_stream` with a basic format.
""" """
logger = logging.getLogger('flask.app') logger = logging.getLogger("flask.app")
if app.debug and logger.level == logging.NOTSET: if app.debug and logger.level == logging.NOTSET:
logger.setLevel(logging.DEBUG) logger.setLevel(logging.DEBUG)

View file

@ -27,11 +27,11 @@ class SessionMixin(collections_abc.MutableMapping):
@property @property
def permanent(self): def permanent(self):
"""This reflects the ``'_permanent'`` key in the dict.""" """This reflects the ``'_permanent'`` key in the dict."""
return self.get('_permanent', False) return self.get("_permanent", False)
@permanent.setter @permanent.setter
def permanent(self, value): def permanent(self, value):
self['_permanent'] = bool(value) self["_permanent"] = bool(value)
#: Some implementations can detect whether a session is newly #: Some implementations can detect whether a session is newly
#: created, but that is not guaranteed. Use with caution. The mixin #: created, but that is not guaranteed. Use with caution. The mixin
@ -98,11 +98,13 @@ class NullSession(SecureCookieSession):
""" """
def _fail(self, *args, **kwargs): def _fail(self, *args, **kwargs):
raise RuntimeError('The session is unavailable because no secret ' raise RuntimeError(
'key was set. Set the secret_key on the ' "The session is unavailable because no secret "
'application to something unique and secret.') "key was set. Set the secret_key on the "
__setitem__ = __delitem__ = clear = pop = popitem = \ "application to something unique and secret."
update = setdefault = _fail )
__setitem__ = __delitem__ = clear = pop = popitem = update = setdefault = _fail
del _fail del _fail
@ -180,52 +182,52 @@ class SessionInterface(object):
updated to avoid re-running the logic. updated to avoid re-running the logic.
""" """
rv = app.config['SESSION_COOKIE_DOMAIN'] rv = app.config["SESSION_COOKIE_DOMAIN"]
# set explicitly, or cached from SERVER_NAME detection # set explicitly, or cached from SERVER_NAME detection
# if False, return None # if False, return None
if rv is not None: if rv is not None:
return rv if rv else None return rv if rv else None
rv = app.config['SERVER_NAME'] rv = app.config["SERVER_NAME"]
# server name not set, cache False to return none next time # server name not set, cache False to return none next time
if not rv: if not rv:
app.config['SESSION_COOKIE_DOMAIN'] = False app.config["SESSION_COOKIE_DOMAIN"] = False
return None return None
# chop off the port which is usually not supported by browsers # chop off the port which is usually not supported by browsers
# remove any leading '.' since we'll add that later # remove any leading '.' since we'll add that later
rv = rv.rsplit(':', 1)[0].lstrip('.') rv = rv.rsplit(":", 1)[0].lstrip(".")
if '.' not in rv: if "." not in rv:
# Chrome doesn't allow names without a '.' # Chrome doesn't allow names without a '.'
# this should only come up with localhost # this should only come up with localhost
# hack around this by not setting the name, and show a warning # hack around this by not setting the name, and show a warning
warnings.warn( warnings.warn(
'"{rv}" is not a valid cookie domain, it must contain a ".".' '"{rv}" is not a valid cookie domain, it must contain a ".".'
' Add an entry to your hosts file, for example' " Add an entry to your hosts file, for example"
' "{rv}.localdomain", and use that instead.'.format(rv=rv) ' "{rv}.localdomain", and use that instead.'.format(rv=rv)
) )
app.config['SESSION_COOKIE_DOMAIN'] = False app.config["SESSION_COOKIE_DOMAIN"] = False
return None return None
ip = is_ip(rv) ip = is_ip(rv)
if ip: if ip:
warnings.warn( warnings.warn(
'The session cookie domain is an IP address. This may not work' "The session cookie domain is an IP address. This may not work"
' as intended in some browsers. Add an entry to your hosts' " as intended in some browsers. Add an entry to your hosts"
' file, for example "localhost.localdomain", and use that' ' file, for example "localhost.localdomain", and use that'
' instead.' " instead."
) )
# if this is not an ip and app is mounted at the root, allow subdomain # if this is not an ip and app is mounted at the root, allow subdomain
# matching by adding a '.' prefix # matching by adding a '.' prefix
if self.get_cookie_path(app) == '/' and not ip: if self.get_cookie_path(app) == "/" and not ip:
rv = '.' + rv rv = "." + rv
app.config['SESSION_COOKIE_DOMAIN'] = rv app.config["SESSION_COOKIE_DOMAIN"] = rv
return rv return rv
def get_cookie_path(self, app): def get_cookie_path(self, app):
@ -234,28 +236,27 @@ class SessionInterface(object):
config var if it's set, and falls back to ``APPLICATION_ROOT`` or config var if it's set, and falls back to ``APPLICATION_ROOT`` or
uses ``/`` if it's ``None``. uses ``/`` if it's ``None``.
""" """
return app.config['SESSION_COOKIE_PATH'] \ return app.config["SESSION_COOKIE_PATH"] or app.config["APPLICATION_ROOT"]
or app.config['APPLICATION_ROOT']
def get_cookie_httponly(self, app): def get_cookie_httponly(self, app):
"""Returns True if the session cookie should be httponly. This """Returns True if the session cookie should be httponly. This
currently just returns the value of the ``SESSION_COOKIE_HTTPONLY`` currently just returns the value of the ``SESSION_COOKIE_HTTPONLY``
config var. config var.
""" """
return app.config['SESSION_COOKIE_HTTPONLY'] return app.config["SESSION_COOKIE_HTTPONLY"]
def get_cookie_secure(self, app): def get_cookie_secure(self, app):
"""Returns True if the cookie should be secure. This currently """Returns True if the cookie should be secure. This currently
just returns the value of the ``SESSION_COOKIE_SECURE`` setting. just returns the value of the ``SESSION_COOKIE_SECURE`` setting.
""" """
return app.config['SESSION_COOKIE_SECURE'] return app.config["SESSION_COOKIE_SECURE"]
def get_cookie_samesite(self, app): def get_cookie_samesite(self, app):
"""Return ``'Strict'`` or ``'Lax'`` if the cookie should use the """Return ``'Strict'`` or ``'Lax'`` if the cookie should use the
``SameSite`` attribute. This currently just returns the value of ``SameSite`` attribute. This currently just returns the value of
the :data:`SESSION_COOKIE_SAMESITE` setting. the :data:`SESSION_COOKIE_SAMESITE` setting.
""" """
return app.config['SESSION_COOKIE_SAMESITE'] return app.config["SESSION_COOKIE_SAMESITE"]
def get_expiration_time(self, app, session): def get_expiration_time(self, app, session):
"""A helper method that returns an expiration date for the session """A helper method that returns an expiration date for the session
@ -279,7 +280,7 @@ class SessionInterface(object):
""" """
return session.modified or ( return session.modified or (
session.permanent and app.config['SESSION_REFRESH_EACH_REQUEST'] session.permanent and app.config["SESSION_REFRESH_EACH_REQUEST"]
) )
def open_session(self, app, request): def open_session(self, app, request):
@ -306,14 +307,15 @@ class SecureCookieSessionInterface(SessionInterface):
"""The default session interface that stores sessions in signed cookies """The default session interface that stores sessions in signed cookies
through the :mod:`itsdangerous` module. through the :mod:`itsdangerous` module.
""" """
#: the salt that should be applied on top of the secret key for the #: the salt that should be applied on top of the secret key for the
#: signing of cookie based sessions. #: signing of cookie based sessions.
salt = 'cookie-session' salt = "cookie-session"
#: the hash function to use for the signature. The default is sha1 #: the hash function to use for the signature. The default is sha1
digest_method = staticmethod(hashlib.sha1) digest_method = staticmethod(hashlib.sha1)
#: the name of the itsdangerous supported key derivation. The default #: the name of the itsdangerous supported key derivation. The default
#: is hmac. #: is hmac.
key_derivation = 'hmac' key_derivation = "hmac"
#: A python serializer for the payload. The default is a compact #: A python serializer for the payload. The default is a compact
#: JSON derived serializer with support for some extra Python types #: JSON derived serializer with support for some extra Python types
#: such as datetime objects or tuples. #: such as datetime objects or tuples.
@ -324,12 +326,14 @@ class SecureCookieSessionInterface(SessionInterface):
if not app.secret_key: if not app.secret_key:
return None return None
signer_kwargs = dict( signer_kwargs = dict(
key_derivation=self.key_derivation, key_derivation=self.key_derivation, digest_method=self.digest_method
digest_method=self.digest_method )
return URLSafeTimedSerializer(
app.secret_key,
salt=self.salt,
serializer=self.serializer,
signer_kwargs=signer_kwargs,
) )
return URLSafeTimedSerializer(app.secret_key, salt=self.salt,
serializer=self.serializer,
signer_kwargs=signer_kwargs)
def open_session(self, app, request): def open_session(self, app, request):
s = self.get_signing_serializer(app) s = self.get_signing_serializer(app)
@ -354,16 +358,14 @@ class SecureCookieSessionInterface(SessionInterface):
if not session: if not session:
if session.modified: if session.modified:
response.delete_cookie( response.delete_cookie(
app.session_cookie_name, app.session_cookie_name, domain=domain, path=path
domain=domain,
path=path
) )
return return
# Add a "Vary: Cookie" header if the session was accessed at all. # Add a "Vary: Cookie" header if the session was accessed at all.
if session.accessed: if session.accessed:
response.vary.add('Cookie') response.vary.add("Cookie")
if not self.should_set_cookie(app, session): if not self.should_set_cookie(app, session):
return return
@ -381,5 +383,5 @@ class SecureCookieSessionInterface(SessionInterface):
domain=domain, domain=domain,
path=path, path=path,
secure=secure, secure=secure,
samesite=samesite samesite=samesite,
) )

View file

@ -13,8 +13,10 @@
signals_available = False signals_available = False
try: try:
from blinker import Namespace from blinker import Namespace
signals_available = True signals_available = True
except ImportError: except ImportError:
class Namespace(object): class Namespace(object):
def signal(self, name, doc=None): def signal(self, name, doc=None):
return _FakeSignal(name, doc) return _FakeSignal(name, doc)
@ -29,15 +31,23 @@ except ImportError:
def __init__(self, name, doc=None): def __init__(self, name, doc=None):
self.name = name self.name = name
self.__doc__ = doc self.__doc__ = doc
def _fail(self, *args, **kwargs): def _fail(self, *args, **kwargs):
raise RuntimeError('signalling support is unavailable ' raise RuntimeError(
'because the blinker library is ' "signalling support is unavailable "
'not installed.') "because the blinker library is "
"not installed."
)
send = lambda *a, **kw: None send = lambda *a, **kw: None
connect = disconnect = has_receivers_for = receivers_for = \ connect = (
temporarily_connected_to = connected_to = _fail disconnect
) = (
has_receivers_for
) = receivers_for = temporarily_connected_to = connected_to = _fail
del _fail del _fail
# The namespace for code signals. If you are not Flask code, do # The namespace for code signals. If you are not Flask code, do
# not put signals in here. Create your own namespace instead. # not put signals in here. Create your own namespace instead.
_signals = Namespace() _signals = Namespace()
@ -45,13 +55,13 @@ _signals = Namespace()
# Core signals. For usage examples grep the source code or consult # Core signals. For usage examples grep the source code or consult
# the API documentation in docs/api.rst as well as docs/signals.rst # the API documentation in docs/api.rst as well as docs/signals.rst
template_rendered = _signals.signal('template-rendered') template_rendered = _signals.signal("template-rendered")
before_render_template = _signals.signal('before-render-template') before_render_template = _signals.signal("before-render-template")
request_started = _signals.signal('request-started') request_started = _signals.signal("request-started")
request_finished = _signals.signal('request-finished') request_finished = _signals.signal("request-finished")
request_tearing_down = _signals.signal('request-tearing-down') request_tearing_down = _signals.signal("request-tearing-down")
got_request_exception = _signals.signal('got-request-exception') got_request_exception = _signals.signal("got-request-exception")
appcontext_tearing_down = _signals.signal('appcontext-tearing-down') appcontext_tearing_down = _signals.signal("appcontext-tearing-down")
appcontext_pushed = _signals.signal('appcontext-pushed') appcontext_pushed = _signals.signal("appcontext-pushed")
appcontext_popped = _signals.signal('appcontext-popped') appcontext_popped = _signals.signal("appcontext-popped")
message_flashed = _signals.signal('message-flashed') message_flashed = _signals.signal("message-flashed")

View file

@ -9,8 +9,7 @@
:license: BSD, see LICENSE for more details. :license: BSD, see LICENSE for more details.
""" """
from jinja2 import BaseLoader, Environment as BaseEnvironment, \ from jinja2 import BaseLoader, Environment as BaseEnvironment, TemplateNotFound
TemplateNotFound
from .globals import _request_ctx_stack, _app_ctx_stack from .globals import _request_ctx_stack, _app_ctx_stack
from .signals import template_rendered, before_render_template from .signals import template_rendered, before_render_template
@ -24,10 +23,10 @@ def _default_template_ctx_processor():
appctx = _app_ctx_stack.top appctx = _app_ctx_stack.top
rv = {} rv = {}
if appctx is not None: if appctx is not None:
rv['g'] = appctx.g rv["g"] = appctx.g
if reqctx is not None: if reqctx is not None:
rv['request'] = reqctx.request rv["request"] = reqctx.request
rv['session'] = reqctx.session rv["session"] = reqctx.session
return rv return rv
@ -38,8 +37,8 @@ class Environment(BaseEnvironment):
""" """
def __init__(self, app, **options): def __init__(self, app, **options):
if 'loader' not in options: if "loader" not in options:
options['loader'] = app.create_global_jinja_loader() options["loader"] = app.create_global_jinja_loader()
BaseEnvironment.__init__(self, **options) BaseEnvironment.__init__(self, **options)
self.app = app self.app = app
@ -53,7 +52,7 @@ class DispatchingJinjaLoader(BaseLoader):
self.app = app self.app = app
def get_source(self, environment, template): def get_source(self, environment, template):
if self.app.config['EXPLAIN_TEMPLATE_LOADING']: if self.app.config["EXPLAIN_TEMPLATE_LOADING"]:
return self._get_source_explained(environment, template) return self._get_source_explained(environment, template)
return self._get_source_fast(environment, template) return self._get_source_fast(environment, template)
@ -71,6 +70,7 @@ class DispatchingJinjaLoader(BaseLoader):
attempts.append((loader, srcobj, rv)) attempts.append((loader, srcobj, rv))
from .debughelpers import explain_template_loading_attempts from .debughelpers import explain_template_loading_attempts
explain_template_loading_attempts(self.app, template, attempts) explain_template_loading_attempts(self.app, template, attempts)
if trv is not None: if trv is not None:
@ -131,8 +131,11 @@ def render_template(template_name_or_list, **context):
""" """
ctx = _app_ctx_stack.top ctx = _app_ctx_stack.top
ctx.app.update_template_context(context) ctx.app.update_template_context(context)
return _render(ctx.app.jinja_env.get_or_select_template(template_name_or_list), return _render(
context, ctx.app) ctx.app.jinja_env.get_or_select_template(template_name_or_list),
context,
ctx.app,
)
def render_template_string(source, **context): def render_template_string(source, **context):
@ -146,5 +149,4 @@ def render_template_string(source, **context):
""" """
ctx = _app_ctx_stack.top ctx = _app_ctx_stack.top
ctx.app.update_template_context(context) ctx.app.update_template_context(context)
return _render(ctx.app.jinja_env.from_string(source), return _render(ctx.app.jinja_env.from_string(source), context, ctx.app)
context, ctx.app)

View file

@ -22,8 +22,7 @@ from werkzeug.urls import url_parse
def make_test_environ_builder( def make_test_environ_builder(
app, path='/', base_url=None, subdomain=None, url_scheme=None, app, path="/", base_url=None, subdomain=None, url_scheme=None, *args, **kwargs
*args, **kwargs
): ):
"""Create a :class:`~werkzeug.test.EnvironBuilder`, taking some """Create a :class:`~werkzeug.test.EnvironBuilder`, taking some
defaults from the application. defaults from the application.
@ -46,44 +45,41 @@ def make_test_environ_builder(
:class:`~werkzeug.test.EnvironBuilder`. :class:`~werkzeug.test.EnvironBuilder`.
""" """
assert ( assert not (base_url or subdomain or url_scheme) or (base_url is not None) != bool(
not (base_url or subdomain or url_scheme) subdomain or url_scheme
or (base_url is not None) != bool(subdomain or url_scheme)
), 'Cannot pass "subdomain" or "url_scheme" with "base_url".' ), 'Cannot pass "subdomain" or "url_scheme" with "base_url".'
if base_url is None: if base_url is None:
http_host = app.config.get('SERVER_NAME') or 'localhost' http_host = app.config.get("SERVER_NAME") or "localhost"
app_root = app.config['APPLICATION_ROOT'] app_root = app.config["APPLICATION_ROOT"]
if subdomain: if subdomain:
http_host = '{0}.{1}'.format(subdomain, http_host) http_host = "{0}.{1}".format(subdomain, http_host)
if url_scheme is None: if url_scheme is None:
url_scheme = app.config['PREFERRED_URL_SCHEME'] url_scheme = app.config["PREFERRED_URL_SCHEME"]
url = url_parse(path) url = url_parse(path)
base_url = '{scheme}://{netloc}/{path}'.format( base_url = "{scheme}://{netloc}/{path}".format(
scheme=url.scheme or url_scheme, scheme=url.scheme or url_scheme,
netloc=url.netloc or http_host, netloc=url.netloc or http_host,
path=app_root.lstrip('/') path=app_root.lstrip("/"),
) )
path = url.path path = url.path
if url.query: if url.query:
sep = b'?' if isinstance(url.query, bytes) else '?' sep = b"?" if isinstance(url.query, bytes) else "?"
path += sep + url.query path += sep + url.query
if 'json' in kwargs: if "json" in kwargs:
assert 'data' not in kwargs, ( assert "data" not in kwargs, "Client cannot provide both 'json' and 'data'."
"Client cannot provide both 'json' and 'data'."
)
# push a context so flask.json can use app's json attributes # push a context so flask.json can use app's json attributes
with app.app_context(): with app.app_context():
kwargs['data'] = json_dumps(kwargs.pop('json')) kwargs["data"] = json_dumps(kwargs.pop("json"))
if 'content_type' not in kwargs: if "content_type" not in kwargs:
kwargs['content_type'] = 'application/json' kwargs["content_type"] = "application/json"
return EnvironBuilder(path, base_url, *args, **kwargs) return EnvironBuilder(path, base_url, *args, **kwargs)
@ -109,7 +105,7 @@ class FlaskClient(Client):
super(FlaskClient, self).__init__(*args, **kwargs) super(FlaskClient, self).__init__(*args, **kwargs)
self.environ_base = { self.environ_base = {
"REMOTE_ADDR": "127.0.0.1", "REMOTE_ADDR": "127.0.0.1",
"HTTP_USER_AGENT": "werkzeug/" + werkzeug.__version__ "HTTP_USER_AGENT": "werkzeug/" + werkzeug.__version__,
} }
@contextmanager @contextmanager
@ -131,18 +127,20 @@ class FlaskClient(Client):
passed through. passed through.
""" """
if self.cookie_jar is None: if self.cookie_jar is None:
raise RuntimeError('Session transactions only make sense ' raise RuntimeError(
'with cookies enabled.') "Session transactions only make sense " "with cookies enabled."
)
app = self.application app = self.application
environ_overrides = kwargs.setdefault('environ_overrides', {}) environ_overrides = kwargs.setdefault("environ_overrides", {})
self.cookie_jar.inject_wsgi(environ_overrides) self.cookie_jar.inject_wsgi(environ_overrides)
outer_reqctx = _request_ctx_stack.top outer_reqctx = _request_ctx_stack.top
with app.test_request_context(*args, **kwargs) as c: with app.test_request_context(*args, **kwargs) as c:
session_interface = app.session_interface session_interface = app.session_interface
sess = session_interface.open_session(app, c.request) sess = session_interface.open_session(app, c.request)
if sess is None: if sess is None:
raise RuntimeError('Session backend did not open a session. ' raise RuntimeError(
'Check the configuration') "Session backend did not open a session. " "Check the configuration"
)
# Since we have to open a new request context for the session # Since we have to open a new request context for the session
# handling we want to make sure that we hide out own context # handling we want to make sure that we hide out own context
@ -164,12 +162,13 @@ class FlaskClient(Client):
self.cookie_jar.extract_wsgi(c.request.environ, headers) self.cookie_jar.extract_wsgi(c.request.environ, headers)
def open(self, *args, **kwargs): def open(self, *args, **kwargs):
as_tuple = kwargs.pop('as_tuple', False) as_tuple = kwargs.pop("as_tuple", False)
buffered = kwargs.pop('buffered', False) buffered = kwargs.pop("buffered", False)
follow_redirects = kwargs.pop('follow_redirects', False) follow_redirects = kwargs.pop("follow_redirects", False)
if ( if (
not kwargs and len(args) == 1 not kwargs
and len(args) == 1
and isinstance(args[0], (EnvironBuilder, dict)) and isinstance(args[0], (EnvironBuilder, dict))
): ):
environ = self.environ_base.copy() environ = self.environ_base.copy()
@ -179,14 +178,13 @@ class FlaskClient(Client):
else: else:
environ.update(args[0]) environ.update(args[0])
environ['flask._preserve_context'] = self.preserve_context environ["flask._preserve_context"] = self.preserve_context
else: else:
kwargs.setdefault('environ_overrides', {}) \ kwargs.setdefault("environ_overrides", {})[
['flask._preserve_context'] = self.preserve_context "flask._preserve_context"
kwargs.setdefault('environ_base', self.environ_base) ] = self.preserve_context
builder = make_test_environ_builder( kwargs.setdefault("environ_base", self.environ_base)
self.application, *args, **kwargs builder = make_test_environ_builder(self.application, *args, **kwargs)
)
try: try:
environ = builder.get_environ() environ = builder.get_environ()
@ -194,15 +192,16 @@ class FlaskClient(Client):
builder.close() builder.close()
return Client.open( return Client.open(
self, environ, self,
environ,
as_tuple=as_tuple, as_tuple=as_tuple,
buffered=buffered, buffered=buffered,
follow_redirects=follow_redirects follow_redirects=follow_redirects,
) )
def __enter__(self): def __enter__(self):
if self.preserve_context: if self.preserve_context:
raise RuntimeError('Cannot nest client invocations') raise RuntimeError("Cannot nest client invocations")
self.preserve_context = True self.preserve_context = True
return self return self
@ -222,6 +221,7 @@ class FlaskCliRunner(CliRunner):
CLI commands. Typically created using CLI commands. Typically created using
:meth:`~flask.Flask.test_cli_runner`. See :ref:`testing-cli`. :meth:`~flask.Flask.test_cli_runner`. See :ref:`testing-cli`.
""" """
def __init__(self, app, **kwargs): def __init__(self, app, **kwargs):
self.app = app self.app = app
super(FlaskCliRunner, self).__init__(**kwargs) super(FlaskCliRunner, self).__init__(**kwargs)
@ -244,7 +244,7 @@ class FlaskCliRunner(CliRunner):
if cli is None: if cli is None:
cli = self.app.cli cli = self.app.cli
if 'obj' not in kwargs: if "obj" not in kwargs:
kwargs['obj'] = ScriptInfo(create_app=lambda: self.app) kwargs["obj"] = ScriptInfo(create_app=lambda: self.app)
return super(FlaskCliRunner, self).invoke(cli, args, **kwargs) return super(FlaskCliRunner, self).invoke(cli, args, **kwargs)

View file

@ -13,8 +13,9 @@ from .globals import request
from ._compat import with_metaclass from ._compat import with_metaclass
http_method_funcs = frozenset(['get', 'post', 'head', 'options', http_method_funcs = frozenset(
'delete', 'put', 'trace', 'patch']) ["get", "post", "head", "options", "delete", "put", "trace", "patch"]
)
class View(object): class View(object):
@ -83,6 +84,7 @@ class View(object):
The arguments passed to :meth:`as_view` are forwarded to the The arguments passed to :meth:`as_view` are forwarded to the
constructor of the class. constructor of the class.
""" """
def view(*args, **kwargs): def view(*args, **kwargs):
self = view.view_class(*class_args, **class_kwargs) self = view.view_class(*class_args, **class_kwargs)
return self.dispatch_request(*args, **kwargs) return self.dispatch_request(*args, **kwargs)
@ -115,7 +117,7 @@ class MethodViewType(type):
def __init__(cls, name, bases, d): def __init__(cls, name, bases, d):
super(MethodViewType, cls).__init__(name, bases, d) super(MethodViewType, cls).__init__(name, bases, d)
if 'methods' not in d: if "methods" not in d:
methods = set() methods = set()
for key in http_method_funcs: for key in http_method_funcs:
@ -151,8 +153,8 @@ class MethodView(with_metaclass(MethodViewType, View)):
# If the request method is HEAD and we don't have a handler for it # If the request method is HEAD and we don't have a handler for it
# retry with GET. # retry with GET.
if meth is None and request.method == 'HEAD': if meth is None and request.method == "HEAD":
meth = getattr(self, 'get', None) meth = getattr(self, "get", None)
assert meth is not None, 'Unimplemented method %r' % request.method assert meth is not None, "Unimplemented method %r" % request.method
return meth(*args, **kwargs) return meth(*args, **kwargs)

View file

@ -34,8 +34,9 @@ class JSONMixin(object):
""" """
mt = self.mimetype mt = self.mimetype
return ( return (
mt == 'application/json' mt == "application/json"
or (mt.startswith('application/')) and mt.endswith('+json') or (mt.startswith("application/"))
and mt.endswith("+json")
) )
@property @property
@ -103,7 +104,7 @@ class JSONMixin(object):
.. versionadded:: 0.8 .. versionadded:: 0.8
""" """
if current_app is not None and current_app.debug: if current_app is not None and current_app.debug:
raise BadRequest('Failed to decode JSON object: {0}'.format(e)) raise BadRequest("Failed to decode JSON object: {0}".format(e))
raise BadRequest() raise BadRequest()
@ -146,7 +147,7 @@ class Request(RequestBase, JSONMixin):
def max_content_length(self): def max_content_length(self):
"""Read-only view of the ``MAX_CONTENT_LENGTH`` config key.""" """Read-only view of the ``MAX_CONTENT_LENGTH`` config key."""
if current_app: if current_app:
return current_app.config['MAX_CONTENT_LENGTH'] return current_app.config["MAX_CONTENT_LENGTH"]
@property @property
def endpoint(self): def endpoint(self):
@ -161,8 +162,8 @@ class Request(RequestBase, JSONMixin):
@property @property
def blueprint(self): def blueprint(self):
"""The name of the current blueprint""" """The name of the current blueprint"""
if self.url_rule and '.' in self.url_rule.endpoint: if self.url_rule and "." in self.url_rule.endpoint:
return self.url_rule.endpoint.rsplit('.', 1)[0] return self.url_rule.endpoint.rsplit(".", 1)[0]
def _load_form_data(self): def _load_form_data(self):
RequestBase._load_form_data(self) RequestBase._load_form_data(self)
@ -172,10 +173,11 @@ class Request(RequestBase, JSONMixin):
if ( if (
current_app current_app
and current_app.debug and current_app.debug
and self.mimetype != 'multipart/form-data' and self.mimetype != "multipart/form-data"
and not self.files and not self.files
): ):
from .debughelpers import attach_enctype_error_multidict from .debughelpers import attach_enctype_error_multidict
attach_enctype_error_multidict(self) attach_enctype_error_multidict(self)
@ -197,7 +199,7 @@ class Response(ResponseBase, JSONMixin):
Added :attr:`max_cookie_size`. Added :attr:`max_cookie_size`.
""" """
default_mimetype = 'text/html' default_mimetype = "text/html"
def _get_data_for_json(self, cache): def _get_data_for_json(self, cache):
return self.get_data() return self.get_data()
@ -210,7 +212,7 @@ class Response(ResponseBase, JSONMixin):
Werkzeug's docs. Werkzeug's docs.
""" """
if current_app: if current_app:
return current_app.config['MAX_COOKIE_SIZE'] return current_app.config["MAX_COOKIE_SIZE"]
# return Werkzeug's default when not in an app context # return Werkzeug's default when not in an app context
return super(Response, self).max_cookie_size return super(Response, self).max_cookie_size

View file

@ -7,21 +7,21 @@ import sys
from datetime import date, datetime from datetime import date, datetime
from subprocess import PIPE, Popen from subprocess import PIPE, Popen
_date_strip_re = re.compile(r'(?<=\d)(st|nd|rd|th)') _date_strip_re = re.compile(r"(?<=\d)(st|nd|rd|th)")
def parse_changelog(): def parse_changelog():
with open('CHANGES.rst') as f: with open("CHANGES.rst") as f:
lineiter = iter(f) lineiter = iter(f)
for line in lineiter: for line in lineiter:
match = re.search('^Version\s+(.*)', line.strip()) match = re.search("^Version\s+(.*)", line.strip())
if match is None: if match is None:
continue continue
version = match.group(1).strip() version = match.group(1).strip()
if next(lineiter).count('-') != len(match.group(0)): if next(lineiter).count("-") != len(match.group(0)):
continue continue
while 1: while 1:
@ -31,9 +31,9 @@ def parse_changelog():
break break
match = re.search( match = re.search(
r'released on (\w+\s+\d+\w+\s+\d+)(?:, codename (.*))?', r"released on (\w+\s+\d+\w+\s+\d+)(?:, codename (.*))?",
change_info, change_info,
flags=re.IGNORECASE flags=re.IGNORECASE,
) )
if match is None: if match is None:
@ -45,17 +45,17 @@ def parse_changelog():
def bump_version(version): def bump_version(version):
try: try:
parts = [int(i) for i in version.split('.')] parts = [int(i) for i in version.split(".")]
except ValueError: except ValueError:
fail('Current version is not numeric') fail("Current version is not numeric")
parts[-1] += 1 parts[-1] += 1
return '.'.join(map(str, parts)) return ".".join(map(str, parts))
def parse_date(string): def parse_date(string):
string = _date_strip_re.sub('', string) string = _date_strip_re.sub("", string)
return datetime.strptime(string, '%B %d %Y') return datetime.strptime(string, "%B %d %Y")
def set_filename_version(filename, version_number, pattern): def set_filename_version(filename, version_number, pattern):
@ -69,29 +69,30 @@ def set_filename_version(filename, version_number, pattern):
with open(filename) as f: with open(filename) as f:
contents = re.sub( contents = re.sub(
r"^(\s*%s\s*=\s*')(.+?)(')" % pattern, r"^(\s*%s\s*=\s*')(.+?)(')" % pattern,
inject_version, f.read(), inject_version,
flags=re.DOTALL | re.MULTILINE f.read(),
flags=re.DOTALL | re.MULTILINE,
) )
if not changed: if not changed:
fail('Could not find %s in %s', pattern, filename) fail("Could not find %s in %s", pattern, filename)
with open(filename, 'w') as f: with open(filename, "w") as f:
f.write(contents) f.write(contents)
def set_init_version(version): def set_init_version(version):
info('Setting __init__.py version to %s', version) info("Setting __init__.py version to %s", version)
set_filename_version('flask/__init__.py', version, '__version__') set_filename_version("flask/__init__.py", version, "__version__")
def build(): def build():
cmd = [sys.executable, 'setup.py', 'sdist', 'bdist_wheel'] cmd = [sys.executable, "setup.py", "sdist", "bdist_wheel"]
Popen(cmd).wait() Popen(cmd).wait()
def fail(message, *args): def fail(message, *args):
print('Error:', message % args, file=sys.stderr) print("Error:", message % args, file=sys.stderr)
sys.exit(1) sys.exit(1)
@ -100,39 +101,39 @@ def info(message, *args):
def get_git_tags(): def get_git_tags():
return set( return set(Popen(["git", "tag"], stdout=PIPE).communicate()[0].splitlines())
Popen(['git', 'tag'], stdout=PIPE).communicate()[0].splitlines()
)
def git_is_clean(): def git_is_clean():
return Popen(['git', 'diff', '--quiet']).wait() == 0 return Popen(["git", "diff", "--quiet"]).wait() == 0
def make_git_commit(message, *args): def make_git_commit(message, *args):
message = message % args message = message % args
Popen(['git', 'commit', '-am', message]).wait() Popen(["git", "commit", "-am", message]).wait()
def make_git_tag(tag): def make_git_tag(tag):
info('Tagging "%s"', tag) info('Tagging "%s"', tag)
Popen(['git', 'tag', tag]).wait() Popen(["git", "tag", tag]).wait()
def main(): def main():
os.chdir(os.path.join(os.path.dirname(__file__), '..')) os.chdir(os.path.join(os.path.dirname(__file__), ".."))
rv = parse_changelog() rv = parse_changelog()
if rv is None: if rv is None:
fail('Could not parse changelog') fail("Could not parse changelog")
version, release_date, codename = rv version, release_date, codename = rv
dev_version = bump_version(version) + '.dev' dev_version = bump_version(version) + ".dev"
info( info(
'Releasing %s (codename %s, release date %s)', "Releasing %s (codename %s, release date %s)",
version, codename, release_date.strftime('%d/%m/%Y') version,
codename,
release_date.strftime("%d/%m/%Y"),
) )
tags = get_git_tags() tags = get_git_tags()
@ -140,25 +141,22 @@ def main():
fail('Version "%s" is already tagged', version) fail('Version "%s" is already tagged', version)
if release_date.date() != date.today(): if release_date.date() != date.today():
fail( fail("Release date is not today (%s != %s)", release_date.date(), date.today())
'Release date is not today (%s != %s)',
release_date.date(), date.today()
)
if not git_is_clean(): if not git_is_clean():
fail('You have uncommitted changes in git') fail("You have uncommitted changes in git")
try: try:
import wheel # noqa: F401 import wheel # noqa: F401
except ImportError: except ImportError:
fail('You need to install the wheel package.') fail("You need to install the wheel package.")
set_init_version(version) set_init_version(version)
make_git_commit('Bump version number to %s', version) make_git_commit("Bump version number to %s", version)
make_git_tag(version) make_git_tag(version)
build() build()
set_init_version(dev_version) set_init_version(dev_version)
if __name__ == '__main__': if __name__ == "__main__":
main() main()

112
setup.py
View file

@ -6,78 +6,72 @@ from collections import OrderedDict
from setuptools import setup from setuptools import setup
with io.open('README.rst', 'rt', encoding='utf8') as f: with io.open("README.rst", "rt", encoding="utf8") as f:
readme = f.read() readme = f.read()
with io.open('flask/__init__.py', 'rt', encoding='utf8') as f: with io.open("flask/__init__.py", "rt", encoding="utf8") as f:
version = re.search(r'__version__ = \'(.*?)\'', f.read()).group(1) version = re.search(r"__version__ = \"(.*?)\"", f.read()).group(1)
setup( setup(
name='Flask', name="Flask",
version=version, version=version,
url='https://www.palletsprojects.com/p/flask/', url="https://www.palletsprojects.com/p/flask/",
project_urls=OrderedDict(( project_urls=OrderedDict(
('Documentation', 'http://flask.pocoo.org/docs/'), (
('Code', 'https://github.com/pallets/flask'), ("Documentation", "http://flask.pocoo.org/docs/"),
('Issue tracker', 'https://github.com/pallets/flask/issues'), ("Code", "https://github.com/pallets/flask"),
)), ("Issue tracker", "https://github.com/pallets/flask/issues"),
license='BSD', )
author='Armin Ronacher', ),
author_email='armin.ronacher@active-4.com', license="BSD",
maintainer='Pallets team', author="Armin Ronacher",
maintainer_email='contact@palletsprojects.com', author_email="armin.ronacher@active-4.com",
description='A simple framework for building complex web applications.', maintainer="Pallets team",
maintainer_email="contact@palletsprojects.com",
description="A simple framework for building complex web applications.",
long_description=readme, long_description=readme,
packages=['flask', 'flask.json'], packages=["flask", "flask.json"],
include_package_data=True, include_package_data=True,
zip_safe=False, zip_safe=False,
platforms='any', platforms="any",
python_requires='>=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*', python_requires=">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*",
install_requires=[ install_requires=[
'Werkzeug>=0.14', "Werkzeug>=0.14",
'Jinja2>=2.10.1', "Jinja2>=2.10.1",
'itsdangerous>=0.24', "itsdangerous>=0.24",
'click>=5.1', "click>=5.1",
], ],
extras_require={ extras_require={
'dotenv': ['python-dotenv'], "dotenv": ["python-dotenv"],
'dev': [ "dev": [
'pytest>=3', "pytest>=3",
'coverage', "coverage",
'tox', "tox",
'sphinx', "sphinx",
'pallets-sphinx-themes', "pallets-sphinx-themes",
'sphinxcontrib-log-cabinet', "sphinxcontrib-log-cabinet",
], ],
'docs': [ "docs": ["sphinx", "pallets-sphinx-themes", "sphinxcontrib-log-cabinet"],
'sphinx',
'pallets-sphinx-themes',
'sphinxcontrib-log-cabinet',
]
}, },
classifiers=[ classifiers=[
'Development Status :: 5 - Production/Stable', "Development Status :: 5 - Production/Stable",
'Environment :: Web Environment', "Environment :: Web Environment",
'Framework :: Flask', "Framework :: Flask",
'Intended Audience :: Developers', "Intended Audience :: Developers",
'License :: OSI Approved :: BSD License', "License :: OSI Approved :: BSD License",
'Operating System :: OS Independent', "Operating System :: OS Independent",
'Programming Language :: Python', "Programming Language :: Python",
'Programming Language :: Python :: 2', "Programming Language :: Python :: 2",
'Programming Language :: Python :: 2.7', "Programming Language :: Python :: 2.7",
'Programming Language :: Python :: 3', "Programming Language :: Python :: 3",
'Programming Language :: Python :: 3.4', "Programming Language :: Python :: 3.4",
'Programming Language :: Python :: 3.5', "Programming Language :: Python :: 3.5",
'Programming Language :: Python :: 3.6', "Programming Language :: Python :: 3.6",
'Programming Language :: Python :: 3.7', "Programming Language :: Python :: 3.7",
'Topic :: Internet :: WWW/HTTP :: Dynamic Content', "Topic :: Internet :: WWW/HTTP :: Dynamic Content",
'Topic :: Internet :: WWW/HTTP :: WSGI :: Application', "Topic :: Internet :: WWW/HTTP :: WSGI :: Application",
'Topic :: Software Development :: Libraries :: Application Frameworks', "Topic :: Software Development :: Libraries :: Application Frameworks",
'Topic :: Software Development :: Libraries :: Python Modules', "Topic :: Software Development :: Libraries :: Python Modules",
], ],
entry_points={ entry_points={"console_scripts": ["flask = flask.cli:main"]},
'console_scripts': [
'flask = flask.cli:main',
],
},
) )

View file

@ -20,7 +20,7 @@ import flask
from flask import Flask as _Flask from flask import Flask as _Flask
@pytest.fixture(scope='session', autouse=True) @pytest.fixture(scope="session", autouse=True)
def _standard_os_environ(): def _standard_os_environ():
"""Set up ``os.environ`` at the start of the test session to have """Set up ``os.environ`` at the start of the test session to have
standard values. Returns a list of operations that is used by standard values. Returns a list of operations that is used by
@ -28,11 +28,11 @@ def _standard_os_environ():
""" """
mp = monkeypatch.MonkeyPatch() mp = monkeypatch.MonkeyPatch()
out = ( out = (
(os.environ, 'FLASK_APP', monkeypatch.notset), (os.environ, "FLASK_APP", monkeypatch.notset),
(os.environ, 'FLASK_ENV', monkeypatch.notset), (os.environ, "FLASK_ENV", monkeypatch.notset),
(os.environ, 'FLASK_DEBUG', monkeypatch.notset), (os.environ, "FLASK_DEBUG", monkeypatch.notset),
(os.environ, 'FLASK_RUN_FROM_CLI', monkeypatch.notset), (os.environ, "FLASK_RUN_FROM_CLI", monkeypatch.notset),
(os.environ, 'WERKZEUG_RUN_MAIN', monkeypatch.notset), (os.environ, "WERKZEUG_RUN_MAIN", monkeypatch.notset),
) )
for _, key, value in out: for _, key, value in out:
@ -55,12 +55,12 @@ def _reset_os_environ(monkeypatch, _standard_os_environ):
class Flask(_Flask): class Flask(_Flask):
testing = True testing = True
secret_key = 'test key' secret_key = "test key"
@pytest.fixture @pytest.fixture
def app(): def app():
app = Flask('flask_test', root_path=os.path.dirname(__file__)) app = Flask("flask_test", root_path=os.path.dirname(__file__))
return app return app
@ -84,8 +84,7 @@ def client(app):
@pytest.fixture @pytest.fixture
def test_apps(monkeypatch): def test_apps(monkeypatch):
monkeypatch.syspath_prepend( monkeypatch.syspath_prepend(
os.path.abspath(os.path.join( os.path.abspath(os.path.join(os.path.dirname(__file__), "test_apps"))
os.path.dirname(__file__), 'test_apps'))
) )
@ -120,8 +119,8 @@ def limit_loader(request, monkeypatch):
self.loader = loader self.loader = loader
def __getattr__(self, name): def __getattr__(self, name):
if name in ('archive', 'get_filename'): if name in ("archive", "get_filename"):
msg = 'Mocking a loader which does not have `%s.`' % name msg = "Mocking a loader which does not have `%s.`" % name
raise AttributeError(msg) raise AttributeError(msg)
return getattr(self.loader, name) return getattr(self.loader, name)
@ -130,30 +129,31 @@ def limit_loader(request, monkeypatch):
def get_loader(*args, **kwargs): def get_loader(*args, **kwargs):
return LimitedLoader(old_get_loader(*args, **kwargs)) return LimitedLoader(old_get_loader(*args, **kwargs))
monkeypatch.setattr(pkgutil, 'get_loader', get_loader) monkeypatch.setattr(pkgutil, "get_loader", get_loader)
@pytest.fixture @pytest.fixture
def modules_tmpdir(tmpdir, monkeypatch): def modules_tmpdir(tmpdir, monkeypatch):
"""A tmpdir added to sys.path.""" """A tmpdir added to sys.path."""
rv = tmpdir.mkdir('modules_tmpdir') rv = tmpdir.mkdir("modules_tmpdir")
monkeypatch.syspath_prepend(str(rv)) monkeypatch.syspath_prepend(str(rv))
return rv return rv
@pytest.fixture @pytest.fixture
def modules_tmpdir_prefix(modules_tmpdir, monkeypatch): def modules_tmpdir_prefix(modules_tmpdir, monkeypatch):
monkeypatch.setattr(sys, 'prefix', str(modules_tmpdir)) monkeypatch.setattr(sys, "prefix", str(modules_tmpdir))
return modules_tmpdir return modules_tmpdir
@pytest.fixture @pytest.fixture
def site_packages(modules_tmpdir, monkeypatch): def site_packages(modules_tmpdir, monkeypatch):
"""Create a fake site-packages.""" """Create a fake site-packages."""
rv = modules_tmpdir \ rv = (
.mkdir('lib') \ modules_tmpdir.mkdir("lib")
.mkdir('python{x[0]}.{x[1]}'.format(x=sys.version_info)) \ .mkdir("python{x[0]}.{x[1]}".format(x=sys.version_info))
.mkdir('site-packages') .mkdir("site-packages")
)
monkeypatch.syspath_prepend(str(rv)) monkeypatch.syspath_prepend(str(rv))
return rv return rv
@ -167,23 +167,29 @@ def install_egg(modules_tmpdir, monkeypatch):
if not isinstance(name, str): if not isinstance(name, str):
raise ValueError(name) raise ValueError(name)
base.join(name).ensure_dir() base.join(name).ensure_dir()
base.join(name).join('__init__.py').ensure() base.join(name).join("__init__.py").ensure()
egg_setup = base.join('setup.py') egg_setup = base.join("setup.py")
egg_setup.write(textwrap.dedent(""" egg_setup.write(
textwrap.dedent(
"""
from setuptools import setup from setuptools import setup
setup(name='{0}', setup(name='{0}',
version='1.0', version='1.0',
packages=['site_egg'], packages=['site_egg'],
zip_safe=True) zip_safe=True)
""".format(name))) """.format(
name
)
)
)
import subprocess import subprocess
subprocess.check_call( subprocess.check_call(
[sys.executable, 'setup.py', 'bdist_egg'], [sys.executable, "setup.py", "bdist_egg"], cwd=str(modules_tmpdir)
cwd=str(modules_tmpdir)
) )
egg_path, = modules_tmpdir.join('dist/').listdir() egg_path, = modules_tmpdir.join("dist/").listdir()
monkeypatch.syspath_prepend(str(egg_path)) monkeypatch.syspath_prepend(str(egg_path))
return egg_path return egg_path
@ -202,4 +208,4 @@ def purge_module(request):
def catch_deprecation_warnings(recwarn): def catch_deprecation_warnings(recwarn):
yield yield
gc.collect() gc.collect()
assert not recwarn.list, '\n'.join(str(w.message) for w in recwarn.list) assert not recwarn.list, "\n".join(str(w.message) for w in recwarn.list)

View file

@ -15,27 +15,27 @@ import flask
def test_basic_url_generation(app): def test_basic_url_generation(app):
app.config['SERVER_NAME'] = 'localhost' app.config["SERVER_NAME"] = "localhost"
app.config['PREFERRED_URL_SCHEME'] = 'https' app.config["PREFERRED_URL_SCHEME"] = "https"
@app.route('/') @app.route("/")
def index(): def index():
pass pass
with app.app_context(): with app.app_context():
rv = flask.url_for('index') rv = flask.url_for("index")
assert rv == 'https://localhost/' assert rv == "https://localhost/"
def test_url_generation_requires_server_name(app): def test_url_generation_requires_server_name(app):
with app.app_context(): with app.app_context():
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
flask.url_for('index') flask.url_for("index")
def test_url_generation_without_context_fails(): def test_url_generation_without_context_fails():
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
flask.url_for('index') flask.url_for("index")
def test_request_context_means_app_context(app): def test_request_context_means_app_context(app):
@ -71,7 +71,7 @@ def test_app_tearing_down_with_previous_exception(app):
cleanup_stuff.append(exception) cleanup_stuff.append(exception)
try: try:
raise Exception('dummy') raise Exception("dummy")
except Exception: except Exception:
pass pass
@ -90,7 +90,7 @@ def test_app_tearing_down_with_handled_exception_by_except_block(app):
with app.app_context(): with app.app_context():
try: try:
raise Exception('dummy') raise Exception("dummy")
except Exception: except Exception:
pass pass
@ -98,79 +98,79 @@ def test_app_tearing_down_with_handled_exception_by_except_block(app):
def test_app_tearing_down_with_handled_exception_by_app_handler(app, client): def test_app_tearing_down_with_handled_exception_by_app_handler(app, client):
app.config['PROPAGATE_EXCEPTIONS'] = True app.config["PROPAGATE_EXCEPTIONS"] = True
cleanup_stuff = [] cleanup_stuff = []
@app.teardown_appcontext @app.teardown_appcontext
def cleanup(exception): def cleanup(exception):
cleanup_stuff.append(exception) cleanup_stuff.append(exception)
@app.route('/') @app.route("/")
def index(): def index():
raise Exception('dummy') raise Exception("dummy")
@app.errorhandler(Exception) @app.errorhandler(Exception)
def handler(f): def handler(f):
return flask.jsonify(str(f)) return flask.jsonify(str(f))
with app.app_context(): with app.app_context():
client.get('/') client.get("/")
assert cleanup_stuff == [None] assert cleanup_stuff == [None]
def test_app_tearing_down_with_unhandled_exception(app, client): def test_app_tearing_down_with_unhandled_exception(app, client):
app.config['PROPAGATE_EXCEPTIONS'] = True app.config["PROPAGATE_EXCEPTIONS"] = True
cleanup_stuff = [] cleanup_stuff = []
@app.teardown_appcontext @app.teardown_appcontext
def cleanup(exception): def cleanup(exception):
cleanup_stuff.append(exception) cleanup_stuff.append(exception)
@app.route('/') @app.route("/")
def index(): def index():
raise Exception('dummy') raise Exception("dummy")
with pytest.raises(Exception): with pytest.raises(Exception):
with app.app_context(): with app.app_context():
client.get('/') client.get("/")
assert len(cleanup_stuff) == 1 assert len(cleanup_stuff) == 1
assert isinstance(cleanup_stuff[0], Exception) assert isinstance(cleanup_stuff[0], Exception)
assert str(cleanup_stuff[0]) == 'dummy' assert str(cleanup_stuff[0]) == "dummy"
def test_app_ctx_globals_methods(app, app_ctx): def test_app_ctx_globals_methods(app, app_ctx):
# get # get
assert flask.g.get('foo') is None assert flask.g.get("foo") is None
assert flask.g.get('foo', 'bar') == 'bar' assert flask.g.get("foo", "bar") == "bar"
# __contains__ # __contains__
assert 'foo' not in flask.g assert "foo" not in flask.g
flask.g.foo = 'bar' flask.g.foo = "bar"
assert 'foo' in flask.g assert "foo" in flask.g
# setdefault # setdefault
flask.g.setdefault('bar', 'the cake is a lie') flask.g.setdefault("bar", "the cake is a lie")
flask.g.setdefault('bar', 'hello world') flask.g.setdefault("bar", "hello world")
assert flask.g.bar == 'the cake is a lie' assert flask.g.bar == "the cake is a lie"
# pop # pop
assert flask.g.pop('bar') == 'the cake is a lie' assert flask.g.pop("bar") == "the cake is a lie"
with pytest.raises(KeyError): with pytest.raises(KeyError):
flask.g.pop('bar') flask.g.pop("bar")
assert flask.g.pop('bar', 'more cake') == 'more cake' assert flask.g.pop("bar", "more cake") == "more cake"
# __iter__ # __iter__
assert list(flask.g) == ['foo'] assert list(flask.g) == ["foo"]
#__repr__ # __repr__
assert repr(flask.g) == "<flask.g of 'flask_test'>" assert repr(flask.g) == "<flask.g of 'flask_test'>"
def test_custom_app_ctx_globals_class(app): def test_custom_app_ctx_globals_class(app):
class CustomRequestGlobals(object): class CustomRequestGlobals(object):
def __init__(self): def __init__(self):
self.spam = 'eggs' self.spam = "eggs"
app.app_ctx_globals_class = CustomRequestGlobals app.app_ctx_globals_class = CustomRequestGlobals
with app.app_context(): with app.app_context():
assert flask.render_template_string('{{ g.spam }}') == 'eggs' assert flask.render_template_string("{{ g.spam }}") == "eggs"
def test_context_refcounts(app, client): def test_context_refcounts(app, client):
@ -178,25 +178,25 @@ def test_context_refcounts(app, client):
@app.teardown_request @app.teardown_request
def teardown_req(error=None): def teardown_req(error=None):
called.append('request') called.append("request")
@app.teardown_appcontext @app.teardown_appcontext
def teardown_app(error=None): def teardown_app(error=None):
called.append('app') called.append("app")
@app.route('/') @app.route("/")
def index(): def index():
with flask._app_ctx_stack.top: with flask._app_ctx_stack.top:
with flask._request_ctx_stack.top: with flask._request_ctx_stack.top:
pass pass
env = flask._request_ctx_stack.top.request.environ env = flask._request_ctx_stack.top.request.environ
assert env['werkzeug.request'] is not None assert env["werkzeug.request"] is not None
return u'' return u""
res = client.get('/') res = client.get("/")
assert res.status_code == 200 assert res.status_code == 200
assert res.data == b'' assert res.data == b""
assert called == ['request', 'app'] assert called == ["request", "app"]
def test_clean_pop(app): def test_clean_pop(app):
@ -209,7 +209,7 @@ def test_clean_pop(app):
@app.teardown_appcontext @app.teardown_appcontext
def teardown_app(error=None): def teardown_app(error=None):
called.append('TEARDOWN') called.append("TEARDOWN")
try: try:
with app.test_request_context(): with app.test_request_context():
@ -217,5 +217,5 @@ def test_clean_pop(app):
except ZeroDivisionError: except ZeroDivisionError:
pass pass
assert called == ['flask_test', 'TEARDOWN'] assert called == ["flask_test", "TEARDOWN"]
assert not flask.current_app assert not flask.current_app

View file

@ -1,8 +1,9 @@
from flask import Flask from flask import Flask
app = Flask(__name__) app = Flask(__name__)
app.config['DEBUG'] = True app.config["DEBUG"] = True
from blueprintapp.apps.admin import admin from blueprintapp.apps.admin import admin
from blueprintapp.apps.frontend import frontend from blueprintapp.apps.frontend import frontend
app.register_blueprint(admin) app.register_blueprint(admin)
app.register_blueprint(frontend) app.register_blueprint(frontend)

View file

@ -1,15 +1,19 @@
from flask import Blueprint, render_template from flask import Blueprint, render_template
admin = Blueprint('admin', __name__, url_prefix='/admin', admin = Blueprint(
template_folder='templates', "admin",
static_folder='static') __name__,
url_prefix="/admin",
template_folder="templates",
static_folder="static",
)
@admin.route('/') @admin.route("/")
def index(): def index():
return render_template('admin/index.html') return render_template("admin/index.html")
@admin.route('/index2') @admin.route("/index2")
def index2(): def index2():
return render_template('./admin/index.html') return render_template("./admin/index.html")

View file

@ -1,13 +1,13 @@
from flask import Blueprint, render_template from flask import Blueprint, render_template
frontend = Blueprint('frontend', __name__, template_folder='templates') frontend = Blueprint("frontend", __name__, template_folder="templates")
@frontend.route('/') @frontend.route("/")
def index(): def index():
return render_template('frontend/index.html') return render_template("frontend/index.html")
@frontend.route('/missing') @frontend.route("/missing")
def missing_template(): def missing_template():
return render_template('missing_template.html') return render_template("missing_template.html")

View file

@ -2,4 +2,4 @@ from __future__ import absolute_import, print_function
from flask import Flask from flask import Flask
testapp = Flask('testapp') testapp = Flask("testapp")

View file

@ -4,15 +4,15 @@ from flask import Flask
def create_app(): def create_app():
return Flask('app') return Flask("app")
def create_app2(foo, bar): def create_app2(foo, bar):
return Flask('_'.join(['app2', foo, bar])) return Flask("_".join(["app2", foo, bar]))
def create_app3(foo, script_info): def create_app3(foo, script_info):
return Flask('_'.join(['app3', foo, script_info.data['test']])) return Flask("_".join(["app3", foo, script_info.data["test"]]))
def no_app(): def no_app():

View file

@ -4,4 +4,4 @@ from flask import Flask
raise ImportError() raise ImportError()
testapp = Flask('testapp') testapp = Flask("testapp")

View file

@ -2,5 +2,5 @@ from __future__ import absolute_import, print_function
from flask import Flask from flask import Flask
app1 = Flask('app1') app1 = Flask("app1")
app2 = Flask('app2') app2 = Flask("app2")

View file

@ -1,6 +1,8 @@
from flask import Flask from flask import Flask
app = Flask(__name__) app = Flask(__name__)
@app.route("/") @app.route("/")
def hello(): def hello():
return "Hello World!" return "Hello World!"

View file

@ -1,4 +1,4 @@
from flask import Module from flask import Module
mod = Module(__name__, 'foo', subdomain='foo') mod = Module(__name__, "foo", subdomain="foo")

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -25,15 +25,22 @@ from click.testing import CliRunner
from flask import Flask, current_app from flask import Flask, current_app
from flask.cli import ( from flask.cli import (
AppGroup, FlaskGroup, NoAppException, ScriptInfo, dotenv, find_best_app, AppGroup,
get_version, load_dotenv, locate_app, prepare_import, run_command, FlaskGroup,
with_appcontext NoAppException,
ScriptInfo,
dotenv,
find_best_app,
get_version,
load_dotenv,
locate_app,
prepare_import,
run_command,
with_appcontext,
) )
cwd = os.getcwd() cwd = os.getcwd()
test_path = os.path.abspath(os.path.join( test_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "test_apps"))
os.path.dirname(__file__), 'test_apps'
))
@pytest.fixture @pytest.fixture
@ -44,6 +51,7 @@ def runner():
def test_cli_name(test_apps): def test_cli_name(test_apps):
"""Make sure the CLI object's name is the app's name and not the app itself""" """Make sure the CLI object's name is the app's name and not the app itself"""
from cliapp.app import testapp from cliapp.app import testapp
assert testapp.cli.name == testapp.name assert testapp.cli.name == testapp.name
@ -52,67 +60,67 @@ def test_find_best_app(test_apps):
script_info = ScriptInfo() script_info = ScriptInfo()
class Module: class Module:
app = Flask('appname') app = Flask("appname")
assert find_best_app(script_info, Module) == Module.app assert find_best_app(script_info, Module) == Module.app
class Module: class Module:
application = Flask('appname') application = Flask("appname")
assert find_best_app(script_info, Module) == Module.application assert find_best_app(script_info, Module) == Module.application
class Module: class Module:
myapp = Flask('appname') myapp = Flask("appname")
assert find_best_app(script_info, Module) == Module.myapp assert find_best_app(script_info, Module) == Module.myapp
class Module: class Module:
@staticmethod @staticmethod
def create_app(): def create_app():
return Flask('appname') return Flask("appname")
assert isinstance(find_best_app(script_info, Module), Flask) assert isinstance(find_best_app(script_info, Module), Flask)
assert find_best_app(script_info, Module).name == 'appname' assert find_best_app(script_info, Module).name == "appname"
class Module: class Module:
@staticmethod @staticmethod
def create_app(foo): def create_app(foo):
return Flask('appname') return Flask("appname")
assert isinstance(find_best_app(script_info, Module), Flask) assert isinstance(find_best_app(script_info, Module), Flask)
assert find_best_app(script_info, Module).name == 'appname' assert find_best_app(script_info, Module).name == "appname"
class Module: class Module:
@staticmethod @staticmethod
def create_app(foo=None, script_info=None): def create_app(foo=None, script_info=None):
return Flask('appname') return Flask("appname")
assert isinstance(find_best_app(script_info, Module), Flask) assert isinstance(find_best_app(script_info, Module), Flask)
assert find_best_app(script_info, Module).name == 'appname' assert find_best_app(script_info, Module).name == "appname"
class Module: class Module:
@staticmethod @staticmethod
def make_app(): def make_app():
return Flask('appname') return Flask("appname")
assert isinstance(find_best_app(script_info, Module), Flask) assert isinstance(find_best_app(script_info, Module), Flask)
assert find_best_app(script_info, Module).name == 'appname' assert find_best_app(script_info, Module).name == "appname"
class Module: class Module:
myapp = Flask('appname1') myapp = Flask("appname1")
@staticmethod @staticmethod
def create_app(): def create_app():
return Flask('appname2') return Flask("appname2")
assert find_best_app(script_info, Module) == Module.myapp assert find_best_app(script_info, Module) == Module.myapp
class Module: class Module:
myapp = Flask('appname1') myapp = Flask("appname1")
@staticmethod @staticmethod
def create_app(): def create_app():
return Flask('appname2') return Flask("appname2")
assert find_best_app(script_info, Module) == Module.myapp assert find_best_app(script_info, Module) == Module.myapp
@ -122,50 +130,56 @@ def test_find_best_app(test_apps):
pytest.raises(NoAppException, find_best_app, script_info, Module) pytest.raises(NoAppException, find_best_app, script_info, Module)
class Module: class Module:
myapp1 = Flask('appname1') myapp1 = Flask("appname1")
myapp2 = Flask('appname2') myapp2 = Flask("appname2")
pytest.raises(NoAppException, find_best_app, script_info, Module) pytest.raises(NoAppException, find_best_app, script_info, Module)
class Module: class Module:
@staticmethod @staticmethod
def create_app(foo, bar): def create_app(foo, bar):
return Flask('appname2') return Flask("appname2")
pytest.raises(NoAppException, find_best_app, script_info, Module) pytest.raises(NoAppException, find_best_app, script_info, Module)
class Module: class Module:
@staticmethod @staticmethod
def create_app(): def create_app():
raise TypeError('bad bad factory!') raise TypeError("bad bad factory!")
pytest.raises(TypeError, find_best_app, script_info, Module) pytest.raises(TypeError, find_best_app, script_info, Module)
@pytest.mark.parametrize('value,path,result', ( @pytest.mark.parametrize(
('test', cwd, 'test'), "value,path,result",
('test.py', cwd, 'test'),
('a/test', os.path.join(cwd, 'a'), 'test'),
('test/__init__.py', cwd, 'test'),
('test/__init__', cwd, 'test'),
# nested package
( (
os.path.join(test_path, 'cliapp', 'inner1', '__init__'), ("test", cwd, "test"),
test_path, 'cliapp.inner1' ("test.py", cwd, "test"),
("a/test", os.path.join(cwd, "a"), "test"),
("test/__init__.py", cwd, "test"),
("test/__init__", cwd, "test"),
# nested package
(
os.path.join(test_path, "cliapp", "inner1", "__init__"),
test_path,
"cliapp.inner1",
),
(
os.path.join(test_path, "cliapp", "inner1", "inner2"),
test_path,
"cliapp.inner1.inner2",
),
# dotted name
("test.a.b", cwd, "test.a.b"),
(os.path.join(test_path, "cliapp.app"), test_path, "cliapp.app"),
# not a Python file, will be caught during import
(
os.path.join(test_path, "cliapp", "message.txt"),
test_path,
"cliapp.message.txt",
),
), ),
( )
os.path.join(test_path, 'cliapp', 'inner1', 'inner2'),
test_path, 'cliapp.inner1.inner2'
),
# dotted name
('test.a.b', cwd, 'test.a.b'),
(os.path.join(test_path, 'cliapp.app'), test_path, 'cliapp.app'),
# not a Python file, will be caught during import
(
os.path.join(test_path, 'cliapp', 'message.txt'),
test_path, 'cliapp.message.txt'
),
))
def test_prepare_import(request, value, path, result): def test_prepare_import(request, value, path, result):
"""Expect the correct path to be set and the correct import and app names """Expect the correct path to be set and the correct import and app names
to be returned. to be returned.
@ -185,42 +199,48 @@ def test_prepare_import(request, value, path, result):
assert sys.path[0] == path assert sys.path[0] == path
@pytest.mark.parametrize('iname,aname,result', ( @pytest.mark.parametrize(
('cliapp.app', None, 'testapp'), "iname,aname,result",
('cliapp.app', 'testapp', 'testapp'), (
('cliapp.factory', None, 'app'), ("cliapp.app", None, "testapp"),
('cliapp.factory', 'create_app', 'app'), ("cliapp.app", "testapp", "testapp"),
('cliapp.factory', 'create_app()', 'app'), ("cliapp.factory", None, "app"),
# no script_info ("cliapp.factory", "create_app", "app"),
('cliapp.factory', 'create_app2("foo", "bar")', 'app2_foo_bar'), ("cliapp.factory", "create_app()", "app"),
# trailing comma space # no script_info
('cliapp.factory', 'create_app2("foo", "bar", )', 'app2_foo_bar'), ("cliapp.factory", 'create_app2("foo", "bar")', "app2_foo_bar"),
# takes script_info # trailing comma space
('cliapp.factory', 'create_app3("foo")', 'app3_foo_spam'), ("cliapp.factory", 'create_app2("foo", "bar", )', "app2_foo_bar"),
# strip whitespace # takes script_info
('cliapp.factory', ' create_app () ', 'app'), ("cliapp.factory", 'create_app3("foo")', "app3_foo_spam"),
)) # strip whitespace
("cliapp.factory", " create_app () ", "app"),
),
)
def test_locate_app(test_apps, iname, aname, result): def test_locate_app(test_apps, iname, aname, result):
info = ScriptInfo() info = ScriptInfo()
info.data['test'] = 'spam' info.data["test"] = "spam"
assert locate_app(info, iname, aname).name == result assert locate_app(info, iname, aname).name == result
@pytest.mark.parametrize('iname,aname', ( @pytest.mark.parametrize(
('notanapp.py', None), "iname,aname",
('cliapp/app', None), (
('cliapp.app', 'notanapp'), ("notanapp.py", None),
# not enough arguments ("cliapp/app", None),
('cliapp.factory', 'create_app2("foo")'), ("cliapp.app", "notanapp"),
# invalid identifier # not enough arguments
('cliapp.factory', 'create_app('), ("cliapp.factory", 'create_app2("foo")'),
# no app returned # invalid identifier
('cliapp.factory', 'no_app'), ("cliapp.factory", "create_app("),
# nested import error # no app returned
('cliapp.importerrorapp', None), ("cliapp.factory", "no_app"),
# not a Python file # nested import error
('cliapp.message.txt', None), ("cliapp.importerrorapp", None),
)) # not a Python file
("cliapp.message.txt", None),
),
)
def test_locate_app_raises(test_apps, iname, aname): def test_locate_app_raises(test_apps, iname, aname):
info = ScriptInfo() info = ScriptInfo()
@ -230,14 +250,12 @@ def test_locate_app_raises(test_apps, iname, aname):
def test_locate_app_suppress_raise(): def test_locate_app_suppress_raise():
info = ScriptInfo() info = ScriptInfo()
app = locate_app(info, 'notanapp.py', None, raise_if_not_found=False) app = locate_app(info, "notanapp.py", None, raise_if_not_found=False)
assert app is None assert app is None
# only direct import error is suppressed # only direct import error is suppressed
with pytest.raises(NoAppException): with pytest.raises(NoAppException):
locate_app( locate_app(info, "cliapp.importerrorapp", None, raise_if_not_found=False)
info, 'cliapp.importerrorapp', None, raise_if_not_found=False
)
def test_get_version(test_apps, capsys): def test_get_version(test_apps, capsys):
@ -249,7 +267,8 @@ def test_get_version(test_apps, capsys):
resilient_parsing = False resilient_parsing = False
color = None color = None
def exit(self): return def exit(self):
return
ctx = MockCtx() ctx = MockCtx()
get_version(ctx, None, "test") get_version(ctx, None, "test")
@ -267,15 +286,16 @@ def test_scriptinfo(test_apps, monkeypatch):
assert obj.load_app() is app assert obj.load_app() is app
# import app with module's absolute path # import app with module's absolute path
cli_app_path = os.path.abspath(os.path.join( cli_app_path = os.path.abspath(
os.path.dirname(__file__), 'test_apps', 'cliapp', 'app.py')) os.path.join(os.path.dirname(__file__), "test_apps", "cliapp", "app.py")
)
obj = ScriptInfo(app_import_path=cli_app_path) obj = ScriptInfo(app_import_path=cli_app_path)
app = obj.load_app() app = obj.load_app()
assert app.name == 'testapp' assert app.name == "testapp"
assert obj.load_app() is app assert obj.load_app() is app
obj = ScriptInfo(app_import_path=cli_app_path + ':testapp') obj = ScriptInfo(app_import_path=cli_app_path + ":testapp")
app = obj.load_app() app = obj.load_app()
assert app.name == 'testapp' assert app.name == "testapp"
assert obj.load_app() is app assert obj.load_app() is app
def create_app(info): def create_app(info):
@ -290,20 +310,22 @@ def test_scriptinfo(test_apps, monkeypatch):
pytest.raises(NoAppException, obj.load_app) pytest.raises(NoAppException, obj.load_app)
# import app from wsgi.py in current directory # import app from wsgi.py in current directory
monkeypatch.chdir(os.path.abspath(os.path.join( monkeypatch.chdir(
os.path.dirname(__file__), 'test_apps', 'helloworld' os.path.abspath(
))) os.path.join(os.path.dirname(__file__), "test_apps", "helloworld")
)
)
obj = ScriptInfo() obj = ScriptInfo()
app = obj.load_app() app = obj.load_app()
assert app.name == 'hello' assert app.name == "hello"
# import app from app.py in current directory # import app from app.py in current directory
monkeypatch.chdir(os.path.abspath(os.path.join( monkeypatch.chdir(
os.path.dirname(__file__), 'test_apps', 'cliapp' os.path.abspath(os.path.join(os.path.dirname(__file__), "test_apps", "cliapp"))
))) )
obj = ScriptInfo() obj = ScriptInfo()
app = obj.load_app() app = obj.load_app()
assert app.name == 'testapp' assert app.name == "testapp"
def test_with_appcontext(runner): def test_with_appcontext(runner):
@ -318,7 +340,7 @@ def test_with_appcontext(runner):
result = runner.invoke(testcmd, obj=obj) result = runner.invoke(testcmd, obj=obj)
assert result.exit_code == 0 assert result.exit_code == 0
assert result.output == 'testapp\n' assert result.output == "testapp\n"
def test_appgroup(runner): def test_appgroup(runner):
@ -342,13 +364,13 @@ def test_appgroup(runner):
obj = ScriptInfo(create_app=lambda info: Flask("testappgroup")) obj = ScriptInfo(create_app=lambda info: Flask("testappgroup"))
result = runner.invoke(cli, ['test'], obj=obj) result = runner.invoke(cli, ["test"], obj=obj)
assert result.exit_code == 0 assert result.exit_code == 0
assert result.output == 'testappgroup\n' assert result.output == "testappgroup\n"
result = runner.invoke(cli, ['subgroup', 'test2'], obj=obj) result = runner.invoke(cli, ["subgroup", "test2"], obj=obj)
assert result.exit_code == 0 assert result.exit_code == 0
assert result.output == 'testappgroup\n' assert result.output == "testappgroup\n"
def test_flaskgroup(runner): def test_flaskgroup(runner):
@ -365,12 +387,12 @@ def test_flaskgroup(runner):
def test(): def test():
click.echo(current_app.name) click.echo(current_app.name)
result = runner.invoke(cli, ['test']) result = runner.invoke(cli, ["test"])
assert result.exit_code == 0 assert result.exit_code == 0
assert result.output == 'flaskgroup\n' assert result.output == "flaskgroup\n"
@pytest.mark.parametrize('set_debug_flag', (True, False)) @pytest.mark.parametrize("set_debug_flag", (True, False))
def test_flaskgroup_debug(runner, set_debug_flag): def test_flaskgroup_debug(runner, set_debug_flag):
"""Test FlaskGroup debug flag behavior.""" """Test FlaskGroup debug flag behavior."""
@ -387,9 +409,9 @@ def test_flaskgroup_debug(runner, set_debug_flag):
def test(): def test():
click.echo(str(current_app.debug)) click.echo(str(current_app.debug))
result = runner.invoke(cli, ['test']) result = runner.invoke(cli, ["test"])
assert result.exit_code == 0 assert result.exit_code == 0
assert result.output == '%s\n' % str(not set_debug_flag) assert result.output == "%s\n" % str(not set_debug_flag)
def test_print_exceptions(runner): def test_print_exceptions(runner):
@ -403,10 +425,10 @@ def test_print_exceptions(runner):
def cli(**params): def cli(**params):
pass pass
result = runner.invoke(cli, ['--help']) result = runner.invoke(cli, ["--help"])
assert result.exit_code == 0 assert result.exit_code == 0
assert 'Exception: oh no' in result.output assert "Exception: oh no" in result.output
assert 'Traceback' in result.output assert "Traceback" in result.output
class TestRoutes: class TestRoutes:
@ -416,11 +438,11 @@ class TestRoutes:
app = Flask(__name__) app = Flask(__name__)
app.testing = True app.testing = True
@app.route('/get_post/<int:x>/<int:y>', methods=['GET', 'POST']) @app.route("/get_post/<int:x>/<int:y>", methods=["GET", "POST"])
def yyy_get_post(x, y): def yyy_get_post(x, y):
pass pass
@app.route('/zzz_post', methods=['POST']) @app.route("/zzz_post", methods=["POST"])
def aaa_post(): def aaa_post():
pass pass
@ -444,138 +466,132 @@ class TestRoutes:
# skip the header and match the start of each row # skip the header and match the start of each row
for expect, line in zip(order, output.splitlines()[2:]): for expect, line in zip(order, output.splitlines()[2:]):
# do this instead of startswith for nicer pytest output # do this instead of startswith for nicer pytest output
assert line[:len(expect)] == expect assert line[: len(expect)] == expect
def test_simple(self, invoke): def test_simple(self, invoke):
result = invoke(['routes']) result = invoke(["routes"])
assert result.exit_code == 0 assert result.exit_code == 0
self.expect_order( self.expect_order(["aaa_post", "static", "yyy_get_post"], result.output)
['aaa_post', 'static', 'yyy_get_post'],
result.output
)
def test_sort(self, invoke): def test_sort(self, invoke):
default_output = invoke(['routes']).output default_output = invoke(["routes"]).output
endpoint_output = invoke(['routes', '-s', 'endpoint']).output endpoint_output = invoke(["routes", "-s", "endpoint"]).output
assert default_output == endpoint_output assert default_output == endpoint_output
self.expect_order( self.expect_order(
['static', 'yyy_get_post', 'aaa_post'], ["static", "yyy_get_post", "aaa_post"],
invoke(['routes', '-s', 'methods']).output invoke(["routes", "-s", "methods"]).output,
) )
self.expect_order( self.expect_order(
['yyy_get_post', 'static', 'aaa_post'], ["yyy_get_post", "static", "aaa_post"],
invoke(['routes', '-s', 'rule']).output invoke(["routes", "-s", "rule"]).output,
) )
self.expect_order( self.expect_order(
['aaa_post', 'yyy_get_post', 'static'], ["aaa_post", "yyy_get_post", "static"],
invoke(['routes', '-s', 'match']).output invoke(["routes", "-s", "match"]).output,
) )
def test_all_methods(self, invoke): def test_all_methods(self, invoke):
output = invoke(['routes']).output output = invoke(["routes"]).output
assert 'GET, HEAD, OPTIONS, POST' not in output assert "GET, HEAD, OPTIONS, POST" not in output
output = invoke(['routes', '--all-methods']).output output = invoke(["routes", "--all-methods"]).output
assert 'GET, HEAD, OPTIONS, POST' in output assert "GET, HEAD, OPTIONS, POST" in output
def test_no_routes(self, invoke_no_routes): def test_no_routes(self, invoke_no_routes):
result = invoke_no_routes(['routes']) result = invoke_no_routes(["routes"])
assert result.exit_code == 0 assert result.exit_code == 0
assert 'No routes were registered.' in result.output assert "No routes were registered." in result.output
need_dotenv = pytest.mark.skipif( need_dotenv = pytest.mark.skipif(dotenv is None, reason="dotenv is not installed")
dotenv is None, reason='dotenv is not installed'
)
@need_dotenv @need_dotenv
def test_load_dotenv(monkeypatch): def test_load_dotenv(monkeypatch):
# can't use monkeypatch.delitem since the keys don't exist yet # can't use monkeypatch.delitem since the keys don't exist yet
for item in ('FOO', 'BAR', 'SPAM'): for item in ("FOO", "BAR", "SPAM"):
monkeypatch._setitem.append((os.environ, item, notset)) monkeypatch._setitem.append((os.environ, item, notset))
monkeypatch.setenv('EGGS', '3') monkeypatch.setenv("EGGS", "3")
monkeypatch.chdir(os.path.join(test_path, 'cliapp', 'inner1')) monkeypatch.chdir(os.path.join(test_path, "cliapp", "inner1"))
load_dotenv() load_dotenv()
assert os.getcwd() == test_path assert os.getcwd() == test_path
# .flaskenv doesn't overwrite .env # .flaskenv doesn't overwrite .env
assert os.environ['FOO'] == 'env' assert os.environ["FOO"] == "env"
# set only in .flaskenv # set only in .flaskenv
assert os.environ['BAR'] == 'bar' assert os.environ["BAR"] == "bar"
# set only in .env # set only in .env
assert os.environ['SPAM'] == '1' assert os.environ["SPAM"] == "1"
# set manually, files don't overwrite # set manually, files don't overwrite
assert os.environ['EGGS'] == '3' assert os.environ["EGGS"] == "3"
@need_dotenv @need_dotenv
def test_dotenv_path(monkeypatch): def test_dotenv_path(monkeypatch):
for item in ('FOO', 'BAR', 'EGGS'): for item in ("FOO", "BAR", "EGGS"):
monkeypatch._setitem.append((os.environ, item, notset)) monkeypatch._setitem.append((os.environ, item, notset))
cwd = os.getcwd() cwd = os.getcwd()
load_dotenv(os.path.join(test_path, '.flaskenv')) load_dotenv(os.path.join(test_path, ".flaskenv"))
assert os.getcwd() == cwd assert os.getcwd() == cwd
assert 'FOO' in os.environ assert "FOO" in os.environ
def test_dotenv_optional(monkeypatch): def test_dotenv_optional(monkeypatch):
monkeypatch.setattr('flask.cli.dotenv', None) monkeypatch.setattr("flask.cli.dotenv", None)
monkeypatch.chdir(test_path) monkeypatch.chdir(test_path)
load_dotenv() load_dotenv()
assert 'FOO' not in os.environ assert "FOO" not in os.environ
@need_dotenv @need_dotenv
def test_disable_dotenv_from_env(monkeypatch, runner): def test_disable_dotenv_from_env(monkeypatch, runner):
monkeypatch.chdir(test_path) monkeypatch.chdir(test_path)
monkeypatch.setitem(os.environ, 'FLASK_SKIP_DOTENV', '1') monkeypatch.setitem(os.environ, "FLASK_SKIP_DOTENV", "1")
runner.invoke(FlaskGroup()) runner.invoke(FlaskGroup())
assert 'FOO' not in os.environ assert "FOO" not in os.environ
def test_run_cert_path(): def test_run_cert_path():
# no key # no key
with pytest.raises(click.BadParameter): with pytest.raises(click.BadParameter):
run_command.make_context('run', ['--cert', __file__]) run_command.make_context("run", ["--cert", __file__])
# no cert # no cert
with pytest.raises(click.BadParameter): with pytest.raises(click.BadParameter):
run_command.make_context('run', ['--key', __file__]) run_command.make_context("run", ["--key", __file__])
ctx = run_command.make_context( ctx = run_command.make_context("run", ["--cert", __file__, "--key", __file__])
'run', ['--cert', __file__, '--key', __file__]) assert ctx.params["cert"] == (__file__, __file__)
assert ctx.params['cert'] == (__file__, __file__)
def test_run_cert_adhoc(monkeypatch): def test_run_cert_adhoc(monkeypatch):
monkeypatch.setitem(sys.modules, 'OpenSSL', None) monkeypatch.setitem(sys.modules, "OpenSSL", None)
# pyOpenSSL not installed # pyOpenSSL not installed
with pytest.raises(click.BadParameter): with pytest.raises(click.BadParameter):
run_command.make_context('run', ['--cert', 'adhoc']) run_command.make_context("run", ["--cert", "adhoc"])
# pyOpenSSL installed # pyOpenSSL installed
monkeypatch.setitem(sys.modules, 'OpenSSL', types.ModuleType('OpenSSL')) monkeypatch.setitem(sys.modules, "OpenSSL", types.ModuleType("OpenSSL"))
ctx = run_command.make_context('run', ['--cert', 'adhoc']) ctx = run_command.make_context("run", ["--cert", "adhoc"])
assert ctx.params['cert'] == 'adhoc' assert ctx.params["cert"] == "adhoc"
# no key with adhoc # no key with adhoc
with pytest.raises(click.BadParameter): with pytest.raises(click.BadParameter):
run_command.make_context('run', ['--cert', 'adhoc', '--key', __file__]) run_command.make_context("run", ["--cert", "adhoc", "--key", __file__])
def test_run_cert_import(monkeypatch): def test_run_cert_import(monkeypatch):
monkeypatch.setitem(sys.modules, 'not_here', None) monkeypatch.setitem(sys.modules, "not_here", None)
# ImportError # ImportError
with pytest.raises(click.BadParameter): with pytest.raises(click.BadParameter):
run_command.make_context('run', ['--cert', 'not_here']) run_command.make_context("run", ["--cert", "not_here"])
# not an SSLContext # not an SSLContext
if sys.version_info >= (2, 7, 9): if sys.version_info >= (2, 7, 9):
with pytest.raises(click.BadParameter): with pytest.raises(click.BadParameter):
run_command.make_context('run', ['--cert', 'flask']) run_command.make_context("run", ["--cert", "flask"])
# SSLContext # SSLContext
if sys.version_info < (2, 7, 9): if sys.version_info < (2, 7, 9):
@ -583,11 +599,10 @@ def test_run_cert_import(monkeypatch):
else: else:
ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
monkeypatch.setitem(sys.modules, 'ssl_context', ssl_context) monkeypatch.setitem(sys.modules, "ssl_context", ssl_context)
ctx = run_command.make_context('run', ['--cert', 'ssl_context']) ctx = run_command.make_context("run", ["--cert", "ssl_context"])
assert ctx.params['cert'] is ssl_context assert ctx.params["cert"] is ssl_context
# no --key with SSLContext # no --key with SSLContext
with pytest.raises(click.BadParameter): with pytest.raises(click.BadParameter):
run_command.make_context( run_command.make_context("run", ["--cert", "ssl_context", "--key", __file__])
'run', ['--cert', 'ssl_context', '--key', __file__])

View file

@ -17,19 +17,19 @@ import pytest
# config keys used for the TestConfig # config keys used for the TestConfig
TEST_KEY = 'foo' TEST_KEY = "foo"
SECRET_KEY = 'config' SECRET_KEY = "config"
def common_object_test(app): def common_object_test(app):
assert app.secret_key == 'config' assert app.secret_key == "config"
assert app.config['TEST_KEY'] == 'foo' assert app.config["TEST_KEY"] == "foo"
assert 'TestConfig' not in app.config assert "TestConfig" not in app.config
def test_config_from_file(): def test_config_from_file():
app = flask.Flask(__name__) app = flask.Flask(__name__)
app.config.from_pyfile(__file__.rsplit('.', 1)[0] + '.py') app.config.from_pyfile(__file__.rsplit(".", 1)[0] + ".py")
common_object_test(app) common_object_test(app)
@ -42,45 +42,34 @@ def test_config_from_object():
def test_config_from_json(): def test_config_from_json():
app = flask.Flask(__name__) app = flask.Flask(__name__)
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
app.config.from_json(os.path.join(current_dir, 'static', 'config.json')) app.config.from_json(os.path.join(current_dir, "static", "config.json"))
common_object_test(app) common_object_test(app)
def test_config_from_mapping(): def test_config_from_mapping():
app = flask.Flask(__name__) app = flask.Flask(__name__)
app.config.from_mapping({ app.config.from_mapping({"SECRET_KEY": "config", "TEST_KEY": "foo"})
'SECRET_KEY': 'config',
'TEST_KEY': 'foo'
})
common_object_test(app) common_object_test(app)
app = flask.Flask(__name__) app = flask.Flask(__name__)
app.config.from_mapping([ app.config.from_mapping([("SECRET_KEY", "config"), ("TEST_KEY", "foo")])
('SECRET_KEY', 'config'),
('TEST_KEY', 'foo')
])
common_object_test(app) common_object_test(app)
app = flask.Flask(__name__) app = flask.Flask(__name__)
app.config.from_mapping( app.config.from_mapping(SECRET_KEY="config", TEST_KEY="foo")
SECRET_KEY='config',
TEST_KEY='foo'
)
common_object_test(app) common_object_test(app)
app = flask.Flask(__name__) app = flask.Flask(__name__)
with pytest.raises(TypeError): with pytest.raises(TypeError):
app.config.from_mapping( app.config.from_mapping({}, {})
{}, {}
)
def test_config_from_class(): def test_config_from_class():
class Base(object): class Base(object):
TEST_KEY = 'foo' TEST_KEY = "foo"
class Test(Base): class Test(Base):
SECRET_KEY = 'config' SECRET_KEY = "config"
app = flask.Flask(__name__) app = flask.Flask(__name__)
app.config.from_object(Test) app.config.from_object(Test)
@ -93,12 +82,12 @@ def test_config_from_envvar():
os.environ = {} os.environ = {}
app = flask.Flask(__name__) app = flask.Flask(__name__)
with pytest.raises(RuntimeError) as e: with pytest.raises(RuntimeError) as e:
app.config.from_envvar('FOO_SETTINGS') app.config.from_envvar("FOO_SETTINGS")
assert "'FOO_SETTINGS' is not set" in str(e.value) assert "'FOO_SETTINGS' is not set" in str(e.value)
assert not app.config.from_envvar('FOO_SETTINGS', silent=True) assert not app.config.from_envvar("FOO_SETTINGS", silent=True)
os.environ = {'FOO_SETTINGS': __file__.rsplit('.', 1)[0] + '.py'} os.environ = {"FOO_SETTINGS": __file__.rsplit(".", 1)[0] + ".py"}
assert app.config.from_envvar('FOO_SETTINGS') assert app.config.from_envvar("FOO_SETTINGS")
common_object_test(app) common_object_test(app)
finally: finally:
os.environ = env os.environ = env
@ -107,15 +96,17 @@ def test_config_from_envvar():
def test_config_from_envvar_missing(): def test_config_from_envvar_missing():
env = os.environ env = os.environ
try: try:
os.environ = {'FOO_SETTINGS': 'missing.cfg'} os.environ = {"FOO_SETTINGS": "missing.cfg"}
with pytest.raises(IOError) as e: with pytest.raises(IOError) as e:
app = flask.Flask(__name__) app = flask.Flask(__name__)
app.config.from_envvar('FOO_SETTINGS') app.config.from_envvar("FOO_SETTINGS")
msg = str(e.value) msg = str(e.value)
assert msg.startswith('[Errno 2] Unable to load configuration ' assert msg.startswith(
'file (No such file or directory):') "[Errno 2] Unable to load configuration "
"file (No such file or directory):"
)
assert msg.endswith("missing.cfg'") assert msg.endswith("missing.cfg'")
assert not app.config.from_envvar('FOO_SETTINGS', silent=True) assert not app.config.from_envvar("FOO_SETTINGS", silent=True)
finally: finally:
os.environ = env os.environ = env
@ -123,23 +114,25 @@ def test_config_from_envvar_missing():
def test_config_missing(): def test_config_missing():
app = flask.Flask(__name__) app = flask.Flask(__name__)
with pytest.raises(IOError) as e: with pytest.raises(IOError) as e:
app.config.from_pyfile('missing.cfg') app.config.from_pyfile("missing.cfg")
msg = str(e.value) msg = str(e.value)
assert msg.startswith('[Errno 2] Unable to load configuration ' assert msg.startswith(
'file (No such file or directory):') "[Errno 2] Unable to load configuration " "file (No such file or directory):"
)
assert msg.endswith("missing.cfg'") assert msg.endswith("missing.cfg'")
assert not app.config.from_pyfile('missing.cfg', silent=True) assert not app.config.from_pyfile("missing.cfg", silent=True)
def test_config_missing_json(): def test_config_missing_json():
app = flask.Flask(__name__) app = flask.Flask(__name__)
with pytest.raises(IOError) as e: with pytest.raises(IOError) as e:
app.config.from_json('missing.json') app.config.from_json("missing.json")
msg = str(e.value) msg = str(e.value)
assert msg.startswith('[Errno 2] Unable to load configuration ' assert msg.startswith(
'file (No such file or directory):') "[Errno 2] Unable to load configuration " "file (No such file or directory):"
)
assert msg.endswith("missing.json'") assert msg.endswith("missing.json'")
assert not app.config.from_json('missing.json', silent=True) assert not app.config.from_json("missing.json", silent=True)
def test_custom_config_class(): def test_custom_config_class():
@ -148,6 +141,7 @@ def test_custom_config_class():
class Flask(flask.Flask): class Flask(flask.Flask):
config_class = Config config_class = Config
app = Flask(__name__) app = Flask(__name__)
assert isinstance(app.config, Config) assert isinstance(app.config, Config)
app.config.from_object(__name__) app.config.from_object(__name__)
@ -156,52 +150,60 @@ def test_custom_config_class():
def test_session_lifetime(): def test_session_lifetime():
app = flask.Flask(__name__) app = flask.Flask(__name__)
app.config['PERMANENT_SESSION_LIFETIME'] = 42 app.config["PERMANENT_SESSION_LIFETIME"] = 42
assert app.permanent_session_lifetime.seconds == 42 assert app.permanent_session_lifetime.seconds == 42
def test_send_file_max_age(): def test_send_file_max_age():
app = flask.Flask(__name__) app = flask.Flask(__name__)
app.config['SEND_FILE_MAX_AGE_DEFAULT'] = 3600 app.config["SEND_FILE_MAX_AGE_DEFAULT"] = 3600
assert app.send_file_max_age_default.seconds == 3600 assert app.send_file_max_age_default.seconds == 3600
app.config['SEND_FILE_MAX_AGE_DEFAULT'] = timedelta(hours=2) app.config["SEND_FILE_MAX_AGE_DEFAULT"] = timedelta(hours=2)
assert app.send_file_max_age_default.seconds == 7200 assert app.send_file_max_age_default.seconds == 7200
def test_get_namespace(): def test_get_namespace():
app = flask.Flask(__name__) app = flask.Flask(__name__)
app.config['FOO_OPTION_1'] = 'foo option 1' app.config["FOO_OPTION_1"] = "foo option 1"
app.config['FOO_OPTION_2'] = 'foo option 2' app.config["FOO_OPTION_2"] = "foo option 2"
app.config['BAR_STUFF_1'] = 'bar stuff 1' app.config["BAR_STUFF_1"] = "bar stuff 1"
app.config['BAR_STUFF_2'] = 'bar stuff 2' app.config["BAR_STUFF_2"] = "bar stuff 2"
foo_options = app.config.get_namespace('FOO_') foo_options = app.config.get_namespace("FOO_")
assert 2 == len(foo_options) assert 2 == len(foo_options)
assert 'foo option 1' == foo_options['option_1'] assert "foo option 1" == foo_options["option_1"]
assert 'foo option 2' == foo_options['option_2'] assert "foo option 2" == foo_options["option_2"]
bar_options = app.config.get_namespace('BAR_', lowercase=False) bar_options = app.config.get_namespace("BAR_", lowercase=False)
assert 2 == len(bar_options) assert 2 == len(bar_options)
assert 'bar stuff 1' == bar_options['STUFF_1'] assert "bar stuff 1" == bar_options["STUFF_1"]
assert 'bar stuff 2' == bar_options['STUFF_2'] assert "bar stuff 2" == bar_options["STUFF_2"]
foo_options = app.config.get_namespace('FOO_', trim_namespace=False) foo_options = app.config.get_namespace("FOO_", trim_namespace=False)
assert 2 == len(foo_options) assert 2 == len(foo_options)
assert 'foo option 1' == foo_options['foo_option_1'] assert "foo option 1" == foo_options["foo_option_1"]
assert 'foo option 2' == foo_options['foo_option_2'] assert "foo option 2" == foo_options["foo_option_2"]
bar_options = app.config.get_namespace('BAR_', lowercase=False, trim_namespace=False) bar_options = app.config.get_namespace(
"BAR_", lowercase=False, trim_namespace=False
)
assert 2 == len(bar_options) assert 2 == len(bar_options)
assert 'bar stuff 1' == bar_options['BAR_STUFF_1'] assert "bar stuff 1" == bar_options["BAR_STUFF_1"]
assert 'bar stuff 2' == bar_options['BAR_STUFF_2'] assert "bar stuff 2" == bar_options["BAR_STUFF_2"]
@pytest.mark.parametrize('encoding', ['utf-8', 'iso-8859-15', 'latin-1']) @pytest.mark.parametrize("encoding", ["utf-8", "iso-8859-15", "latin-1"])
def test_from_pyfile_weird_encoding(tmpdir, encoding): def test_from_pyfile_weird_encoding(tmpdir, encoding):
f = tmpdir.join('my_config.py') f = tmpdir.join("my_config.py")
f.write_binary(textwrap.dedent(u''' f.write_binary(
textwrap.dedent(
u"""
# -*- coding: {0} -*- # -*- coding: {0} -*-
TEST_VALUE = "föö" TEST_VALUE = "föö"
'''.format(encoding)).encode(encoding)) """.format(
encoding
)
).encode(encoding)
)
app = flask.Flask(__name__) app = flask.Flask(__name__)
app.config.from_pyfile(str(f)) app.config.from_pyfile(str(f))
value = app.config['TEST_VALUE'] value = app.config["TEST_VALUE"]
if PY2: if PY2:
value = value.decode(encoding) value = value.decode(encoding)
assert value == u'föö' assert value == u"föö"

File diff suppressed because it is too large Load diff

View file

@ -17,112 +17,120 @@ from flask._compat import PY2
def test_explicit_instance_paths(modules_tmpdir): def test_explicit_instance_paths(modules_tmpdir):
with pytest.raises(ValueError) as excinfo: with pytest.raises(ValueError) as excinfo:
flask.Flask(__name__, instance_path='instance') flask.Flask(__name__, instance_path="instance")
assert 'must be absolute' in str(excinfo.value) assert "must be absolute" in str(excinfo.value)
app = flask.Flask(__name__, instance_path=str(modules_tmpdir)) app = flask.Flask(__name__, instance_path=str(modules_tmpdir))
assert app.instance_path == str(modules_tmpdir) assert app.instance_path == str(modules_tmpdir)
def test_main_module_paths(modules_tmpdir, purge_module): def test_main_module_paths(modules_tmpdir, purge_module):
app = modules_tmpdir.join('main_app.py') app = modules_tmpdir.join("main_app.py")
app.write('import flask\n\napp = flask.Flask("__main__")') app.write('import flask\n\napp = flask.Flask("__main__")')
purge_module('main_app') purge_module("main_app")
from main_app import app from main_app import app
here = os.path.abspath(os.getcwd()) here = os.path.abspath(os.getcwd())
assert app.instance_path == os.path.join(here, 'instance') assert app.instance_path == os.path.join(here, "instance")
def test_uninstalled_module_paths(modules_tmpdir, purge_module): def test_uninstalled_module_paths(modules_tmpdir, purge_module):
app = modules_tmpdir.join('config_module_app.py').write( app = modules_tmpdir.join("config_module_app.py").write(
'import os\n' "import os\n"
'import flask\n' "import flask\n"
'here = os.path.abspath(os.path.dirname(__file__))\n' "here = os.path.abspath(os.path.dirname(__file__))\n"
'app = flask.Flask(__name__)\n' "app = flask.Flask(__name__)\n"
) )
purge_module('config_module_app') purge_module("config_module_app")
from config_module_app import app from config_module_app import app
assert app.instance_path == str(modules_tmpdir.join('instance'))
assert app.instance_path == str(modules_tmpdir.join("instance"))
def test_uninstalled_package_paths(modules_tmpdir, purge_module): def test_uninstalled_package_paths(modules_tmpdir, purge_module):
app = modules_tmpdir.mkdir('config_package_app') app = modules_tmpdir.mkdir("config_package_app")
init = app.join('__init__.py') init = app.join("__init__.py")
init.write( init.write(
'import os\n' "import os\n"
'import flask\n' "import flask\n"
'here = os.path.abspath(os.path.dirname(__file__))\n' "here = os.path.abspath(os.path.dirname(__file__))\n"
'app = flask.Flask(__name__)\n' "app = flask.Flask(__name__)\n"
) )
purge_module('config_package_app') purge_module("config_package_app")
from config_package_app import app from config_package_app import app
assert app.instance_path == str(modules_tmpdir.join('instance'))
assert app.instance_path == str(modules_tmpdir.join("instance"))
def test_installed_module_paths(modules_tmpdir, modules_tmpdir_prefix, def test_installed_module_paths(
purge_module, site_packages, limit_loader): modules_tmpdir, modules_tmpdir_prefix, purge_module, site_packages, limit_loader
site_packages.join('site_app.py').write( ):
'import flask\n' site_packages.join("site_app.py").write(
'app = flask.Flask(__name__)\n' "import flask\n" "app = flask.Flask(__name__)\n"
) )
purge_module('site_app') purge_module("site_app")
from site_app import app from site_app import app
assert app.instance_path == \
modules_tmpdir.join('var').join('site_app-instance') assert app.instance_path == modules_tmpdir.join("var").join("site_app-instance")
def test_installed_package_paths(limit_loader, modules_tmpdir, def test_installed_package_paths(
modules_tmpdir_prefix, purge_module, limit_loader, modules_tmpdir, modules_tmpdir_prefix, purge_module, monkeypatch
monkeypatch): ):
installed_path = modules_tmpdir.mkdir('path') installed_path = modules_tmpdir.mkdir("path")
monkeypatch.syspath_prepend(installed_path) monkeypatch.syspath_prepend(installed_path)
app = installed_path.mkdir('installed_package') app = installed_path.mkdir("installed_package")
init = app.join('__init__.py') init = app.join("__init__.py")
init.write('import flask\napp = flask.Flask(__name__)') init.write("import flask\napp = flask.Flask(__name__)")
purge_module('installed_package') purge_module("installed_package")
from installed_package import app from installed_package import app
assert app.instance_path == \
modules_tmpdir.join('var').join('installed_package-instance') assert app.instance_path == modules_tmpdir.join("var").join(
"installed_package-instance"
)
def test_prefix_package_paths(limit_loader, modules_tmpdir, def test_prefix_package_paths(
modules_tmpdir_prefix, purge_module, limit_loader, modules_tmpdir, modules_tmpdir_prefix, purge_module, site_packages
site_packages): ):
app = site_packages.mkdir('site_package') app = site_packages.mkdir("site_package")
init = app.join('__init__.py') init = app.join("__init__.py")
init.write('import flask\napp = flask.Flask(__name__)') init.write("import flask\napp = flask.Flask(__name__)")
purge_module('site_package') purge_module("site_package")
import site_package import site_package
assert site_package.app.instance_path == \
modules_tmpdir.join('var').join('site_package-instance')
assert site_package.app.instance_path == modules_tmpdir.join("var").join(
def test_egg_installed_paths(install_egg, modules_tmpdir, "site_package-instance"
modules_tmpdir_prefix):
modules_tmpdir.mkdir('site_egg').join('__init__.py').write(
'import flask\n\napp = flask.Flask(__name__)'
) )
install_egg('site_egg')
def test_egg_installed_paths(install_egg, modules_tmpdir, modules_tmpdir_prefix):
modules_tmpdir.mkdir("site_egg").join("__init__.py").write(
"import flask\n\napp = flask.Flask(__name__)"
)
install_egg("site_egg")
try: try:
import site_egg import site_egg
assert site_egg.app.instance_path == \
str(modules_tmpdir.join('var/').join('site_egg-instance')) assert site_egg.app.instance_path == str(
modules_tmpdir.join("var/").join("site_egg-instance")
)
finally: finally:
if 'site_egg' in sys.modules: if "site_egg" in sys.modules:
del sys.modules['site_egg'] del sys.modules["site_egg"]
@pytest.mark.skipif(not PY2, reason='This only works under Python 2.') @pytest.mark.skipif(not PY2, reason="This only works under Python 2.")
def test_meta_path_loader_without_is_package(request, modules_tmpdir): def test_meta_path_loader_without_is_package(request, modules_tmpdir):
app = modules_tmpdir.join('unimportable.py') app = modules_tmpdir.join("unimportable.py")
app.write('import flask\napp = flask.Flask(__name__)') app.write("import flask\napp = flask.Flask(__name__)")
class Loader(object): class Loader(object):
def find_module(self, name, path=None): def find_module(self, name, path=None):

View file

@ -16,18 +16,21 @@ from flask import Markup
from flask.json.tag import TaggedJSONSerializer, JSONTag from flask.json.tag import TaggedJSONSerializer, JSONTag
@pytest.mark.parametrize("data", ( @pytest.mark.parametrize(
{' t': (1, 2, 3)}, "data",
{' t__': b'a'}, (
{' di': ' di'}, {" t": (1, 2, 3)},
{'x': (1, 2, 3), 'y': 4}, {" t__": b"a"},
(1, 2, 3), {" di": " di"},
[(1, 2, 3)], {"x": (1, 2, 3), "y": 4},
b'\xff', (1, 2, 3),
Markup('<html>'), [(1, 2, 3)],
uuid4(), b"\xff",
datetime.utcnow().replace(microsecond=0), Markup("<html>"),
)) uuid4(),
datetime.utcnow().replace(microsecond=0),
),
)
def test_dump_load_unchanged(data): def test_dump_load_unchanged(data):
s = TaggedJSONSerializer() s = TaggedJSONSerializer()
assert s.loads(s.dumps(data)) == data assert s.loads(s.dumps(data)) == data
@ -35,12 +38,12 @@ def test_dump_load_unchanged(data):
def test_duplicate_tag(): def test_duplicate_tag():
class TagDict(JSONTag): class TagDict(JSONTag):
key = ' d' key = " d"
s = TaggedJSONSerializer() s = TaggedJSONSerializer()
pytest.raises(KeyError, s.register, TagDict) pytest.raises(KeyError, s.register, TagDict)
s.register(TagDict, force=True, index=0) s.register(TagDict, force=True, index=0)
assert isinstance(s.tags[' d'], TagDict) assert isinstance(s.tags[" d"], TagDict)
assert isinstance(s.order[0], TagDict) assert isinstance(s.order[0], TagDict)
@ -51,7 +54,7 @@ def test_custom_tag():
class TagFoo(JSONTag): class TagFoo(JSONTag):
__slots__ = () __slots__ = ()
key = ' f' key = " f"
def check(self, value): def check(self, value):
return isinstance(value, Foo) return isinstance(value, Foo)
@ -64,7 +67,7 @@ def test_custom_tag():
s = TaggedJSONSerializer() s = TaggedJSONSerializer()
s.register(TagFoo) s.register(TagFoo)
assert s.loads(s.dumps(Foo('bar'))).data == 'bar' assert s.loads(s.dumps(Foo("bar"))).data == "bar"
def test_tag_interface(): def test_tag_interface():
@ -76,10 +79,10 @@ def test_tag_interface():
def test_tag_order(): def test_tag_order():
class Tag1(JSONTag): class Tag1(JSONTag):
key = ' 1' key = " 1"
class Tag2(JSONTag): class Tag2(JSONTag):
key = ' 2' key = " 2"
s = TaggedJSONSerializer() s = TaggedJSONSerializer()

View file

@ -13,8 +13,7 @@ import sys
import pytest import pytest
from flask._compat import StringIO from flask._compat import StringIO
from flask.logging import default_handler, has_level_handler, \ from flask.logging import default_handler, has_level_handler, wsgi_errors_stream
wsgi_errors_stream
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
@ -23,12 +22,11 @@ def reset_logging(pytestconfig):
logging.root.handlers = [] logging.root.handlers = []
root_level = logging.root.level root_level = logging.root.level
logger = logging.getLogger('flask.app') logger = logging.getLogger("flask.app")
logger.handlers = [] logger.handlers = []
logger.setLevel(logging.NOTSET) logger.setLevel(logging.NOTSET)
logging_plugin = pytestconfig.pluginmanager.unregister( logging_plugin = pytestconfig.pluginmanager.unregister(name="logging-plugin")
name='logging-plugin')
yield yield
@ -39,11 +37,11 @@ def reset_logging(pytestconfig):
logger.setLevel(logging.NOTSET) logger.setLevel(logging.NOTSET)
if logging_plugin: if logging_plugin:
pytestconfig.pluginmanager.register(logging_plugin, 'logging-plugin') pytestconfig.pluginmanager.register(logging_plugin, "logging-plugin")
def test_logger(app): def test_logger(app):
assert app.logger.name == 'flask.app' assert app.logger.name == "flask.app"
assert app.logger.level == logging.NOTSET assert app.logger.level == logging.NOTSET
assert app.logger.handlers == [default_handler] assert app.logger.handlers == [default_handler]
@ -61,14 +59,14 @@ def test_existing_handler(app):
def test_wsgi_errors_stream(app, client): def test_wsgi_errors_stream(app, client):
@app.route('/') @app.route("/")
def index(): def index():
app.logger.error('test') app.logger.error("test")
return '' return ""
stream = StringIO() stream = StringIO()
client.get('/', errors_stream=stream) client.get("/", errors_stream=stream)
assert 'ERROR in test_logging: test' in stream.getvalue() assert "ERROR in test_logging: test" in stream.getvalue()
assert wsgi_errors_stream._get_current_object() is sys.stderr assert wsgi_errors_stream._get_current_object() is sys.stderr
@ -77,7 +75,7 @@ def test_wsgi_errors_stream(app, client):
def test_has_level_handler(): def test_has_level_handler():
logger = logging.getLogger('flask.app') logger = logging.getLogger("flask.app")
assert not has_level_handler(logger) assert not has_level_handler(logger)
handler = logging.StreamHandler() handler = logging.StreamHandler()
@ -93,15 +91,15 @@ def test_has_level_handler():
def test_log_view_exception(app, client): def test_log_view_exception(app, client):
@app.route('/') @app.route("/")
def index(): def index():
raise Exception('test') raise Exception("test")
app.testing = False app.testing = False
stream = StringIO() stream = StringIO()
rv = client.get('/', errors_stream=stream) rv = client.get("/", errors_stream=stream)
assert rv.status_code == 500 assert rv.status_code == 500
assert rv.data assert rv.data
err = stream.getvalue() err = stream.getvalue()
assert 'Exception on / [GET]' in err assert "Exception on / [GET]" in err
assert 'Exception: test' in err assert "Exception: test" in err

View file

@ -22,7 +22,6 @@ _gc_lock = threading.Lock()
class assert_no_leak(object): class assert_no_leak(object):
def __enter__(self): def __enter__(self):
gc.disable() gc.disable()
_gc_lock.acquire() _gc_lock.acquire()
@ -32,7 +31,7 @@ class assert_no_leak(object):
# This is necessary since Python only starts tracking # This is necessary since Python only starts tracking
# dicts if they contain mutable objects. It's a horrible, # dicts if they contain mutable objects. It's a horrible,
# horrible hack but makes this kinda testable. # horrible hack but makes this kinda testable.
loc.__storage__['FOOO'] = [1, 2, 3] loc.__storage__["FOOO"] = [1, 2, 3]
gc.collect() gc.collect()
self.old_objects = len(gc.get_objects()) self.old_objects = len(gc.get_objects())
@ -41,7 +40,7 @@ class assert_no_leak(object):
gc.collect() gc.collect()
new_objects = len(gc.get_objects()) new_objects = len(gc.get_objects())
if new_objects > self.old_objects: if new_objects > self.old_objects:
pytest.fail('Example code leaked') pytest.fail("Example code leaked")
_gc_lock.release() _gc_lock.release()
gc.enable() gc.enable()
@ -49,22 +48,21 @@ class assert_no_leak(object):
def test_memory_consumption(): def test_memory_consumption():
app = flask.Flask(__name__) app = flask.Flask(__name__)
@app.route('/') @app.route("/")
def index(): def index():
return flask.render_template('simple_template.html', whiskey=42) return flask.render_template("simple_template.html", whiskey=42)
def fire(): def fire():
with app.test_client() as c: with app.test_client() as c:
rv = c.get('/') rv = c.get("/")
assert rv.status_code == 200 assert rv.status_code == 200
assert rv.data == b'<h1>42</h1>' assert rv.data == b"<h1>42</h1>"
# Trigger caches # Trigger caches
fire() fire()
# This test only works on CPython 2.7. # This test only works on CPython 2.7.
if sys.version_info >= (2, 7) and \ if sys.version_info >= (2, 7) and not hasattr(sys, "pypy_translation_info"):
not hasattr(sys, 'pypy_translation_info'):
with assert_no_leak(): with assert_no_leak():
for x in range(10): for x in range(10):
fire() fire()
@ -72,8 +70,9 @@ def test_memory_consumption():
def test_safe_join_toplevel_pardir(): def test_safe_join_toplevel_pardir():
from flask.helpers import safe_join from flask.helpers import safe_join
with pytest.raises(NotFound): with pytest.raises(NotFound):
safe_join('/foo', '..') safe_join("/foo", "..")
def test_aborting(app): def test_aborting(app):
@ -84,16 +83,16 @@ def test_aborting(app):
def handle_foo(e): def handle_foo(e):
return str(e.whatever) return str(e.whatever)
@app.route('/') @app.route("/")
def index(): def index():
raise flask.abort(flask.redirect(flask.url_for('test'))) raise flask.abort(flask.redirect(flask.url_for("test")))
@app.route('/test') @app.route("/test")
def test(): def test():
raise Foo() raise Foo()
with app.test_client() as c: with app.test_client() as c:
rv = c.get('/') rv = c.get("/")
assert rv.headers['Location'] == 'http://localhost/test' assert rv.headers["Location"] == "http://localhost/test"
rv = c.get('/test') rv = c.get("/test")
assert rv.data == b'42' assert rv.data == b"42"

View file

@ -42,7 +42,7 @@ def test_teardown_with_previous_exception(app):
buffer.append(exception) buffer.append(exception)
try: try:
raise Exception('dummy') raise Exception("dummy")
except Exception: except Exception:
pass pass
@ -61,35 +61,39 @@ def test_teardown_with_handled_exception(app):
with app.test_request_context(): with app.test_request_context():
assert buffer == [] assert buffer == []
try: try:
raise Exception('dummy') raise Exception("dummy")
except Exception: except Exception:
pass pass
assert buffer == [None] assert buffer == [None]
def test_proper_test_request_context(app): def test_proper_test_request_context(app):
app.config.update( app.config.update(SERVER_NAME="localhost.localdomain:5000")
SERVER_NAME='localhost.localdomain:5000'
)
@app.route('/') @app.route("/")
def index(): def index():
return None return None
@app.route('/', subdomain='foo') @app.route("/", subdomain="foo")
def sub(): def sub():
return None return None
with app.test_request_context('/'): with app.test_request_context("/"):
assert flask.url_for('index', _external=True) == \ assert (
'http://localhost.localdomain:5000/' flask.url_for("index", _external=True)
== "http://localhost.localdomain:5000/"
)
with app.test_request_context('/'): with app.test_request_context("/"):
assert flask.url_for('sub', _external=True) == \ assert (
'http://foo.localhost.localdomain:5000/' flask.url_for("sub", _external=True)
== "http://foo.localhost.localdomain:5000/"
)
try: try:
with app.test_request_context('/', environ_overrides={'HTTP_HOST': 'localhost'}): with app.test_request_context(
"/", environ_overrides={"HTTP_HOST": "localhost"}
):
pass pass
except ValueError as e: except ValueError as e:
assert str(e) == ( assert str(e) == (
@ -98,28 +102,30 @@ def test_proper_test_request_context(app):
"server name from the WSGI environment ('localhost')" "server name from the WSGI environment ('localhost')"
) )
app.config.update(SERVER_NAME='localhost') app.config.update(SERVER_NAME="localhost")
with app.test_request_context('/', environ_overrides={'SERVER_NAME': 'localhost'}): with app.test_request_context("/", environ_overrides={"SERVER_NAME": "localhost"}):
pass pass
app.config.update(SERVER_NAME='localhost:80') app.config.update(SERVER_NAME="localhost:80")
with app.test_request_context('/', environ_overrides={'SERVER_NAME': 'localhost:80'}): with app.test_request_context(
"/", environ_overrides={"SERVER_NAME": "localhost:80"}
):
pass pass
def test_context_binding(app): def test_context_binding(app):
@app.route('/') @app.route("/")
def index(): def index():
return 'Hello %s!' % flask.request.args['name'] return "Hello %s!" % flask.request.args["name"]
@app.route('/meh') @app.route("/meh")
def meh(): def meh():
return flask.request.url return flask.request.url
with app.test_request_context('/?name=World'): with app.test_request_context("/?name=World"):
assert index() == 'Hello World!' assert index() == "Hello World!"
with app.test_request_context('/meh'): with app.test_request_context("/meh"):
assert meh() == 'http://localhost/meh' assert meh() == "http://localhost/meh"
assert flask._request_ctx_stack.top is None assert flask._request_ctx_stack.top is None
@ -136,27 +142,26 @@ def test_context_test(app):
def test_manual_context_binding(app): def test_manual_context_binding(app):
@app.route('/') @app.route("/")
def index(): def index():
return 'Hello %s!' % flask.request.args['name'] return "Hello %s!" % flask.request.args["name"]
ctx = app.test_request_context('/?name=World') ctx = app.test_request_context("/?name=World")
ctx.push() ctx.push()
assert index() == 'Hello World!' assert index() == "Hello World!"
ctx.pop() ctx.pop()
with pytest.raises(RuntimeError): with pytest.raises(RuntimeError):
index() index()
@pytest.mark.skipif(greenlet is None, reason='greenlet not installed') @pytest.mark.skipif(greenlet is None, reason="greenlet not installed")
class TestGreenletContextCopying(object): class TestGreenletContextCopying(object):
def test_greenlet_context_copying(self, app, client): def test_greenlet_context_copying(self, app, client):
greenlets = [] greenlets = []
@app.route('/') @app.route("/")
def index(): def index():
flask.session['fizz'] = 'buzz' flask.session["fizz"] = "buzz"
reqctx = flask._request_ctx_stack.top.copy() reqctx = flask._request_ctx_stack.top.copy()
def g(): def g():
@ -165,17 +170,17 @@ class TestGreenletContextCopying(object):
with reqctx: with reqctx:
assert flask.request assert flask.request
assert flask.current_app == app assert flask.current_app == app
assert flask.request.path == '/' assert flask.request.path == "/"
assert flask.request.args['foo'] == 'bar' assert flask.request.args["foo"] == "bar"
assert flask.session.get('fizz') == 'buzz' assert flask.session.get("fizz") == "buzz"
assert not flask.request assert not flask.request
return 42 return 42
greenlets.append(greenlet(g)) greenlets.append(greenlet(g))
return 'Hello World!' return "Hello World!"
rv = client.get('/?foo=bar') rv = client.get("/?foo=bar")
assert rv.data == b'Hello World!' assert rv.data == b"Hello World!"
result = greenlets[0].run() result = greenlets[0].run()
assert result == 42 assert result == 42
@ -183,25 +188,25 @@ class TestGreenletContextCopying(object):
def test_greenlet_context_copying_api(self, app, client): def test_greenlet_context_copying_api(self, app, client):
greenlets = [] greenlets = []
@app.route('/') @app.route("/")
def index(): def index():
flask.session['fizz'] = 'buzz' flask.session["fizz"] = "buzz"
reqctx = flask._request_ctx_stack.top.copy() reqctx = flask._request_ctx_stack.top.copy()
@flask.copy_current_request_context @flask.copy_current_request_context
def g(): def g():
assert flask.request assert flask.request
assert flask.current_app == app assert flask.current_app == app
assert flask.request.path == '/' assert flask.request.path == "/"
assert flask.request.args['foo'] == 'bar' assert flask.request.args["foo"] == "bar"
assert flask.session.get('fizz') == 'buzz' assert flask.session.get("fizz") == "buzz"
return 42 return 42
greenlets.append(greenlet(g)) greenlets.append(greenlet(g))
return 'Hello World!' return "Hello World!"
rv = client.get('/?foo=bar') rv = client.get("/?foo=bar")
assert rv.data == b'Hello World!' assert rv.data == b"Hello World!"
result = greenlets[0].run() result = greenlets[0].run()
assert result == 42 assert result == 42
@ -220,12 +225,12 @@ def test_session_error_pops_context():
app = CustomFlask(__name__) app = CustomFlask(__name__)
@app.route('/') @app.route("/")
def index(): def index():
# shouldn't get here # shouldn't get here
assert False assert False
response = app.test_client().get('/') response = app.test_client().get("/")
assert response.status_code == 500 assert response.status_code == 500
assert not flask.request assert not flask.request
assert not flask.current_app assert not flask.current_app
@ -239,11 +244,12 @@ def test_bad_environ_raises_bad_request():
# However it works when actually passed to the server. # However it works when actually passed to the server.
from flask.testing import make_test_environ_builder from flask.testing import make_test_environ_builder
builder = make_test_environ_builder(app) builder = make_test_environ_builder(app)
environ = builder.get_environ() environ = builder.get_environ()
# use a non-printable character in the Host - this is key to this test # use a non-printable character in the Host - this is key to this test
environ['HTTP_HOST'] = u'\x8a' environ["HTTP_HOST"] = u"\x8a"
with app.request_context(environ): with app.request_context(environ):
response = app.full_dispatch_request() response = app.full_dispatch_request()
@ -253,20 +259,21 @@ def test_bad_environ_raises_bad_request():
def test_environ_for_valid_idna_completes(): def test_environ_for_valid_idna_completes():
app = flask.Flask(__name__) app = flask.Flask(__name__)
@app.route('/') @app.route("/")
def index(): def index():
return 'Hello World!' return "Hello World!"
# We cannot use app.test_client() for the Unicode-rich Host header, # We cannot use app.test_client() for the Unicode-rich Host header,
# because werkzeug enforces latin1 on Python 2. # because werkzeug enforces latin1 on Python 2.
# However it works when actually passed to the server. # However it works when actually passed to the server.
from flask.testing import make_test_environ_builder from flask.testing import make_test_environ_builder
builder = make_test_environ_builder(app) builder = make_test_environ_builder(app)
environ = builder.get_environ() environ = builder.get_environ()
# these characters are all IDNA-compatible # these characters are all IDNA-compatible
environ['HTTP_HOST'] = u'ąśźäüжŠßя.com' environ["HTTP_HOST"] = u"ąśźäüжŠßя.com"
with app.request_context(environ): with app.request_context(environ):
response = app.full_dispatch_request() response = app.full_dispatch_request()
@ -277,9 +284,9 @@ def test_environ_for_valid_idna_completes():
def test_normal_environ_completes(): def test_normal_environ_completes():
app = flask.Flask(__name__) app = flask.Flask(__name__)
@app.route('/') @app.route("/")
def index(): def index():
return 'Hello World!' return "Hello World!"
response = app.test_client().get('/', headers={'host': 'xn--on-0ia.com'}) response = app.test_client().get("/", headers={"host": "xn--on-0ia.com"})
assert response.status_code == 200 assert response.status_code == 200

View file

@ -19,15 +19,14 @@ except ImportError:
import flask import flask
pytestmark = pytest.mark.skipif( pytestmark = pytest.mark.skipif(
blinker is None, blinker is None, reason="Signals require the blinker library."
reason='Signals require the blinker library.'
) )
def test_template_rendered(app, client): def test_template_rendered(app, client):
@app.route('/') @app.route("/")
def index(): def index():
return flask.render_template('simple_template.html', whiskey=42) return flask.render_template("simple_template.html", whiskey=42)
recorded = [] recorded = []
@ -36,11 +35,11 @@ def test_template_rendered(app, client):
flask.template_rendered.connect(record, app) flask.template_rendered.connect(record, app)
try: try:
client.get('/') client.get("/")
assert len(recorded) == 1 assert len(recorded) == 1
template, context = recorded[0] template, context = recorded[0]
assert template.name == 'simple_template.html' assert template.name == "simple_template.html"
assert context['whiskey'] == 42 assert context["whiskey"] == 42
finally: finally:
flask.template_rendered.disconnect(record, app) flask.template_rendered.disconnect(record, app)
@ -48,24 +47,24 @@ def test_template_rendered(app, client):
def test_before_render_template(): def test_before_render_template():
app = flask.Flask(__name__) app = flask.Flask(__name__)
@app.route('/') @app.route("/")
def index(): def index():
return flask.render_template('simple_template.html', whiskey=42) return flask.render_template("simple_template.html", whiskey=42)
recorded = [] recorded = []
def record(sender, template, context): def record(sender, template, context):
context['whiskey'] = 43 context["whiskey"] = 43
recorded.append((template, context)) recorded.append((template, context))
flask.before_render_template.connect(record, app) flask.before_render_template.connect(record, app)
try: try:
rv = app.test_client().get('/') rv = app.test_client().get("/")
assert len(recorded) == 1 assert len(recorded) == 1
template, context = recorded[0] template, context = recorded[0]
assert template.name == 'simple_template.html' assert template.name == "simple_template.html"
assert context['whiskey'] == 43 assert context["whiskey"] == 43
assert rv.data == b'<h1>43</h1>' assert rv.data == b"<h1>43</h1>"
finally: finally:
flask.before_render_template.disconnect(record, app) flask.before_render_template.disconnect(record, app)
@ -75,36 +74,41 @@ def test_request_signals():
calls = [] calls = []
def before_request_signal(sender): def before_request_signal(sender):
calls.append('before-signal') calls.append("before-signal")
def after_request_signal(sender, response): def after_request_signal(sender, response):
assert response.data == b'stuff' assert response.data == b"stuff"
calls.append('after-signal') calls.append("after-signal")
@app.before_request @app.before_request
def before_request_handler(): def before_request_handler():
calls.append('before-handler') calls.append("before-handler")
@app.after_request @app.after_request
def after_request_handler(response): def after_request_handler(response):
calls.append('after-handler') calls.append("after-handler")
response.data = 'stuff' response.data = "stuff"
return response return response
@app.route('/') @app.route("/")
def index(): def index():
calls.append('handler') calls.append("handler")
return 'ignored anyway' return "ignored anyway"
flask.request_started.connect(before_request_signal, app) flask.request_started.connect(before_request_signal, app)
flask.request_finished.connect(after_request_signal, app) flask.request_finished.connect(after_request_signal, app)
try: try:
rv = app.test_client().get('/') rv = app.test_client().get("/")
assert rv.data == b'stuff' assert rv.data == b"stuff"
assert calls == ['before-signal', 'before-handler', 'handler', assert calls == [
'after-handler', 'after-signal'] "before-signal",
"before-handler",
"handler",
"after-handler",
"after-signal",
]
finally: finally:
flask.request_started.disconnect(before_request_signal, app) flask.request_started.disconnect(before_request_signal, app)
flask.request_finished.disconnect(after_request_signal, app) flask.request_finished.disconnect(after_request_signal, app)
@ -114,7 +118,7 @@ def test_request_exception_signal():
app = flask.Flask(__name__) app = flask.Flask(__name__)
recorded = [] recorded = []
@app.route('/') @app.route("/")
def index(): def index():
1 // 0 1 // 0
@ -123,7 +127,7 @@ def test_request_exception_signal():
flask.got_request_exception.connect(record, app) flask.got_request_exception.connect(record, app)
try: try:
assert app.test_client().get('/').status_code == 500 assert app.test_client().get("/").status_code == 500
assert len(recorded) == 1 assert len(recorded) == 1
assert isinstance(recorded[0], ZeroDivisionError) assert isinstance(recorded[0], ZeroDivisionError)
finally: finally:
@ -135,33 +139,33 @@ def test_appcontext_signals():
recorded = [] recorded = []
def record_push(sender, **kwargs): def record_push(sender, **kwargs):
recorded.append('push') recorded.append("push")
def record_pop(sender, **kwargs): def record_pop(sender, **kwargs):
recorded.append('pop') recorded.append("pop")
@app.route('/') @app.route("/")
def index(): def index():
return 'Hello' return "Hello"
flask.appcontext_pushed.connect(record_push, app) flask.appcontext_pushed.connect(record_push, app)
flask.appcontext_popped.connect(record_pop, app) flask.appcontext_popped.connect(record_pop, app)
try: try:
with app.test_client() as c: with app.test_client() as c:
rv = c.get('/') rv = c.get("/")
assert rv.data == b'Hello' assert rv.data == b"Hello"
assert recorded == ['push'] assert recorded == ["push"]
assert recorded == ['push', 'pop'] assert recorded == ["push", "pop"]
finally: finally:
flask.appcontext_pushed.disconnect(record_push, app) flask.appcontext_pushed.disconnect(record_push, app)
flask.appcontext_popped.disconnect(record_pop, app) flask.appcontext_popped.disconnect(record_pop, app)
def test_flash_signal(app): def test_flash_signal(app):
@app.route('/') @app.route("/")
def index(): def index():
flask.flash('This is a flash message', category='notice') flask.flash("This is a flash message", category="notice")
return flask.redirect('/other') return flask.redirect("/other")
recorded = [] recorded = []
@ -172,11 +176,11 @@ def test_flash_signal(app):
try: try:
client = app.test_client() client = app.test_client()
with client.session_transaction(): with client.session_transaction():
client.get('/') client.get("/")
assert len(recorded) == 1 assert len(recorded) == 1
message, category = recorded[0] message, category = recorded[0]
assert message == 'This is a flash message' assert message == "This is a flash message"
assert category == 'notice' assert category == "notice"
finally: finally:
flask.message_flashed.disconnect(record, app) flask.message_flashed.disconnect(record, app)
@ -186,18 +190,18 @@ def test_appcontext_tearing_down_signal():
recorded = [] recorded = []
def record_teardown(sender, **kwargs): def record_teardown(sender, **kwargs):
recorded.append(('tear_down', kwargs)) recorded.append(("tear_down", kwargs))
@app.route('/') @app.route("/")
def index(): def index():
1 // 0 1 // 0
flask.appcontext_tearing_down.connect(record_teardown, app) flask.appcontext_tearing_down.connect(record_teardown, app)
try: try:
with app.test_client() as c: with app.test_client() as c:
rv = c.get('/') rv = c.get("/")
assert rv.status_code == 500 assert rv.status_code == 500
assert recorded == [] assert recorded == []
assert recorded == [('tear_down', {'exc': None})] assert recorded == [("tear_down", {"exc": None})]
finally: finally:
flask.appcontext_tearing_down.disconnect(record_teardown, app) flask.appcontext_tearing_down.disconnect(record_teardown, app)

View file

@ -23,11 +23,11 @@ def test_suppressed_exception_logging():
out = StringIO() out = StringIO()
app = SuppressedFlask(__name__) app = SuppressedFlask(__name__)
@app.route('/') @app.route("/")
def index(): def index():
raise Exception('test') raise Exception("test")
rv = app.test_client().get('/', errors_stream=out) rv = app.test_client().get("/", errors_stream=out)
assert rv.status_code == 500 assert rv.status_code == 500
assert b'Internal Server Error' in rv.data assert b"Internal Server Error" in rv.data
assert not out.getvalue() assert not out.getvalue()

View file

@ -20,102 +20,104 @@ import werkzeug.serving
def test_context_processing(app, client): def test_context_processing(app, client):
@app.context_processor @app.context_processor
def context_processor(): def context_processor():
return {'injected_value': 42} return {"injected_value": 42}
@app.route('/') @app.route("/")
def index(): def index():
return flask.render_template('context_template.html', value=23) return flask.render_template("context_template.html", value=23)
rv = client.get('/') rv = client.get("/")
assert rv.data == b'<p>23|42' assert rv.data == b"<p>23|42"
def test_original_win(app, client): def test_original_win(app, client):
@app.route('/') @app.route("/")
def index(): def index():
return flask.render_template_string('{{ config }}', config=42) return flask.render_template_string("{{ config }}", config=42)
rv = client.get('/') rv = client.get("/")
assert rv.data == b'42' assert rv.data == b"42"
def test_request_less_rendering(app, app_ctx): def test_request_less_rendering(app, app_ctx):
app.config['WORLD_NAME'] = 'Special World' app.config["WORLD_NAME"] = "Special World"
@app.context_processor @app.context_processor
def context_processor(): def context_processor():
return dict(foo=42) return dict(foo=42)
rv = flask.render_template_string('Hello {{ config.WORLD_NAME }} ' rv = flask.render_template_string("Hello {{ config.WORLD_NAME }} " "{{ foo }}")
'{{ foo }}') assert rv == "Hello Special World 42"
assert rv == 'Hello Special World 42'
def test_standard_context(app, client): def test_standard_context(app, client):
@app.route('/') @app.route("/")
def index(): def index():
flask.g.foo = 23 flask.g.foo = 23
flask.session['test'] = 'aha' flask.session["test"] = "aha"
return flask.render_template_string(''' return flask.render_template_string(
"""
{{ request.args.foo }} {{ request.args.foo }}
{{ g.foo }} {{ g.foo }}
{{ config.DEBUG }} {{ config.DEBUG }}
{{ session.test }} {{ session.test }}
''') """
)
rv = client.get('/?foo=42') rv = client.get("/?foo=42")
assert rv.data.split() == [b'42', b'23', b'False', b'aha'] assert rv.data.split() == [b"42", b"23", b"False", b"aha"]
def test_escaping(app, client): def test_escaping(app, client):
text = '<p>Hello World!' text = "<p>Hello World!"
@app.route('/') @app.route("/")
def index(): def index():
return flask.render_template('escaping_template.html', text=text, return flask.render_template(
html=flask.Markup(text)) "escaping_template.html", text=text, html=flask.Markup(text)
)
lines = client.get('/').data.splitlines() lines = client.get("/").data.splitlines()
assert lines == [ assert lines == [
b'&lt;p&gt;Hello World!', b"&lt;p&gt;Hello World!",
b'<p>Hello World!', b"<p>Hello World!",
b'<p>Hello World!', b"<p>Hello World!",
b'<p>Hello World!', b"<p>Hello World!",
b'&lt;p&gt;Hello World!', b"&lt;p&gt;Hello World!",
b'<p>Hello World!' b"<p>Hello World!",
] ]
def test_no_escaping(app, client): def test_no_escaping(app, client):
text = '<p>Hello World!' text = "<p>Hello World!"
@app.route('/') @app.route("/")
def index(): def index():
return flask.render_template('non_escaping_template.txt', text=text, return flask.render_template(
html=flask.Markup(text)) "non_escaping_template.txt", text=text, html=flask.Markup(text)
)
lines = client.get('/').data.splitlines() lines = client.get("/").data.splitlines()
assert lines == [ assert lines == [
b'<p>Hello World!', b"<p>Hello World!",
b'<p>Hello World!', b"<p>Hello World!",
b'<p>Hello World!', b"<p>Hello World!",
b'<p>Hello World!', b"<p>Hello World!",
b'&lt;p&gt;Hello World!', b"&lt;p&gt;Hello World!",
b'<p>Hello World!', b"<p>Hello World!",
b'<p>Hello World!', b"<p>Hello World!",
b'<p>Hello World!' b"<p>Hello World!",
] ]
def test_escaping_without_template_filename(app, client, req_ctx): def test_escaping_without_template_filename(app, client, req_ctx):
assert flask.render_template_string( assert flask.render_template_string("{{ foo }}", foo="<test>") == "&lt;test&gt;"
'{{ foo }}', foo='<test>') == '&lt;test&gt;' assert flask.render_template("mail.txt", foo="<test>") == "<test> Mail"
assert flask.render_template('mail.txt', foo='<test>') == '<test> Mail'
def test_macros(app, req_ctx): def test_macros(app, req_ctx):
macro = flask.get_template_attribute('_macro.html', 'hello') macro = flask.get_template_attribute("_macro.html", "hello")
assert macro('World') == 'Hello World!' assert macro("World") == "Hello World!"
def test_template_filter(app): def test_template_filter(app):
@ -123,9 +125,9 @@ def test_template_filter(app):
def my_reverse(s): def my_reverse(s):
return s[::-1] return s[::-1]
assert 'my_reverse' in app.jinja_env.filters.keys() assert "my_reverse" in app.jinja_env.filters.keys()
assert app.jinja_env.filters['my_reverse'] == my_reverse assert app.jinja_env.filters["my_reverse"] == my_reverse
assert app.jinja_env.filters['my_reverse']('abcd') == 'dcba' assert app.jinja_env.filters["my_reverse"]("abcd") == "dcba"
def test_add_template_filter(app): def test_add_template_filter(app):
@ -133,29 +135,29 @@ def test_add_template_filter(app):
return s[::-1] return s[::-1]
app.add_template_filter(my_reverse) app.add_template_filter(my_reverse)
assert 'my_reverse' in app.jinja_env.filters.keys() assert "my_reverse" in app.jinja_env.filters.keys()
assert app.jinja_env.filters['my_reverse'] == my_reverse assert app.jinja_env.filters["my_reverse"] == my_reverse
assert app.jinja_env.filters['my_reverse']('abcd') == 'dcba' assert app.jinja_env.filters["my_reverse"]("abcd") == "dcba"
def test_template_filter_with_name(app): def test_template_filter_with_name(app):
@app.template_filter('strrev') @app.template_filter("strrev")
def my_reverse(s): def my_reverse(s):
return s[::-1] return s[::-1]
assert 'strrev' in app.jinja_env.filters.keys() assert "strrev" in app.jinja_env.filters.keys()
assert app.jinja_env.filters['strrev'] == my_reverse assert app.jinja_env.filters["strrev"] == my_reverse
assert app.jinja_env.filters['strrev']('abcd') == 'dcba' assert app.jinja_env.filters["strrev"]("abcd") == "dcba"
def test_add_template_filter_with_name(app): def test_add_template_filter_with_name(app):
def my_reverse(s): def my_reverse(s):
return s[::-1] return s[::-1]
app.add_template_filter(my_reverse, 'strrev') app.add_template_filter(my_reverse, "strrev")
assert 'strrev' in app.jinja_env.filters.keys() assert "strrev" in app.jinja_env.filters.keys()
assert app.jinja_env.filters['strrev'] == my_reverse assert app.jinja_env.filters["strrev"] == my_reverse
assert app.jinja_env.filters['strrev']('abcd') == 'dcba' assert app.jinja_env.filters["strrev"]("abcd") == "dcba"
def test_template_filter_with_template(app, client): def test_template_filter_with_template(app, client):
@ -163,12 +165,12 @@ def test_template_filter_with_template(app, client):
def super_reverse(s): def super_reverse(s):
return s[::-1] return s[::-1]
@app.route('/') @app.route("/")
def index(): def index():
return flask.render_template('template_filter.html', value='abcd') return flask.render_template("template_filter.html", value="abcd")
rv = client.get('/') rv = client.get("/")
assert rv.data == b'dcba' assert rv.data == b"dcba"
def test_add_template_filter_with_template(app, client): def test_add_template_filter_with_template(app, client):
@ -177,39 +179,39 @@ def test_add_template_filter_with_template(app, client):
app.add_template_filter(super_reverse) app.add_template_filter(super_reverse)
@app.route('/') @app.route("/")
def index(): def index():
return flask.render_template('template_filter.html', value='abcd') return flask.render_template("template_filter.html", value="abcd")
rv = client.get('/') rv = client.get("/")
assert rv.data == b'dcba' assert rv.data == b"dcba"
def test_template_filter_with_name_and_template(app, client): def test_template_filter_with_name_and_template(app, client):
@app.template_filter('super_reverse') @app.template_filter("super_reverse")
def my_reverse(s): def my_reverse(s):
return s[::-1] return s[::-1]
@app.route('/') @app.route("/")
def index(): def index():
return flask.render_template('template_filter.html', value='abcd') return flask.render_template("template_filter.html", value="abcd")
rv = client.get('/') rv = client.get("/")
assert rv.data == b'dcba' assert rv.data == b"dcba"
def test_add_template_filter_with_name_and_template(app, client): def test_add_template_filter_with_name_and_template(app, client):
def my_reverse(s): def my_reverse(s):
return s[::-1] return s[::-1]
app.add_template_filter(my_reverse, 'super_reverse') app.add_template_filter(my_reverse, "super_reverse")
@app.route('/') @app.route("/")
def index(): def index():
return flask.render_template('template_filter.html', value='abcd') return flask.render_template("template_filter.html", value="abcd")
rv = client.get('/') rv = client.get("/")
assert rv.data == b'dcba' assert rv.data == b"dcba"
def test_template_test(app): def test_template_test(app):
@ -217,9 +219,9 @@ def test_template_test(app):
def boolean(value): def boolean(value):
return isinstance(value, bool) return isinstance(value, bool)
assert 'boolean' in app.jinja_env.tests.keys() assert "boolean" in app.jinja_env.tests.keys()
assert app.jinja_env.tests['boolean'] == boolean assert app.jinja_env.tests["boolean"] == boolean
assert app.jinja_env.tests['boolean'](False) assert app.jinja_env.tests["boolean"](False)
def test_add_template_test(app): def test_add_template_test(app):
@ -227,29 +229,29 @@ def test_add_template_test(app):
return isinstance(value, bool) return isinstance(value, bool)
app.add_template_test(boolean) app.add_template_test(boolean)
assert 'boolean' in app.jinja_env.tests.keys() assert "boolean" in app.jinja_env.tests.keys()
assert app.jinja_env.tests['boolean'] == boolean assert app.jinja_env.tests["boolean"] == boolean
assert app.jinja_env.tests['boolean'](False) assert app.jinja_env.tests["boolean"](False)
def test_template_test_with_name(app): def test_template_test_with_name(app):
@app.template_test('boolean') @app.template_test("boolean")
def is_boolean(value): def is_boolean(value):
return isinstance(value, bool) return isinstance(value, bool)
assert 'boolean' in app.jinja_env.tests.keys() assert "boolean" in app.jinja_env.tests.keys()
assert app.jinja_env.tests['boolean'] == is_boolean assert app.jinja_env.tests["boolean"] == is_boolean
assert app.jinja_env.tests['boolean'](False) assert app.jinja_env.tests["boolean"](False)
def test_add_template_test_with_name(app): def test_add_template_test_with_name(app):
def is_boolean(value): def is_boolean(value):
return isinstance(value, bool) return isinstance(value, bool)
app.add_template_test(is_boolean, 'boolean') app.add_template_test(is_boolean, "boolean")
assert 'boolean' in app.jinja_env.tests.keys() assert "boolean" in app.jinja_env.tests.keys()
assert app.jinja_env.tests['boolean'] == is_boolean assert app.jinja_env.tests["boolean"] == is_boolean
assert app.jinja_env.tests['boolean'](False) assert app.jinja_env.tests["boolean"](False)
def test_template_test_with_template(app, client): def test_template_test_with_template(app, client):
@ -257,12 +259,12 @@ def test_template_test_with_template(app, client):
def boolean(value): def boolean(value):
return isinstance(value, bool) return isinstance(value, bool)
@app.route('/') @app.route("/")
def index(): def index():
return flask.render_template('template_test.html', value=False) return flask.render_template("template_test.html", value=False)
rv = client.get('/') rv = client.get("/")
assert b'Success!' in rv.data assert b"Success!" in rv.data
def test_add_template_test_with_template(app, client): def test_add_template_test_with_template(app, client):
@ -271,39 +273,39 @@ def test_add_template_test_with_template(app, client):
app.add_template_test(boolean) app.add_template_test(boolean)
@app.route('/') @app.route("/")
def index(): def index():
return flask.render_template('template_test.html', value=False) return flask.render_template("template_test.html", value=False)
rv = client.get('/') rv = client.get("/")
assert b'Success!' in rv.data assert b"Success!" in rv.data
def test_template_test_with_name_and_template(app, client): def test_template_test_with_name_and_template(app, client):
@app.template_test('boolean') @app.template_test("boolean")
def is_boolean(value): def is_boolean(value):
return isinstance(value, bool) return isinstance(value, bool)
@app.route('/') @app.route("/")
def index(): def index():
return flask.render_template('template_test.html', value=False) return flask.render_template("template_test.html", value=False)
rv = client.get('/') rv = client.get("/")
assert b'Success!' in rv.data assert b"Success!" in rv.data
def test_add_template_test_with_name_and_template(app, client): def test_add_template_test_with_name_and_template(app, client):
def is_boolean(value): def is_boolean(value):
return isinstance(value, bool) return isinstance(value, bool)
app.add_template_test(is_boolean, 'boolean') app.add_template_test(is_boolean, "boolean")
@app.route('/') @app.route("/")
def index(): def index():
return flask.render_template('template_test.html', value=False) return flask.render_template("template_test.html", value=False)
rv = client.get('/') rv = client.get("/")
assert b'Success!' in rv.data assert b"Success!" in rv.data
def test_add_template_global(app, app_ctx): def test_add_template_global(app, app_ctx):
@ -311,84 +313,89 @@ def test_add_template_global(app, app_ctx):
def get_stuff(): def get_stuff():
return 42 return 42
assert 'get_stuff' in app.jinja_env.globals.keys() assert "get_stuff" in app.jinja_env.globals.keys()
assert app.jinja_env.globals['get_stuff'] == get_stuff assert app.jinja_env.globals["get_stuff"] == get_stuff
assert app.jinja_env.globals['get_stuff'](), 42 assert app.jinja_env.globals["get_stuff"](), 42
rv = flask.render_template_string('{{ get_stuff() }}') rv = flask.render_template_string("{{ get_stuff() }}")
assert rv == '42' assert rv == "42"
def test_custom_template_loader(client): def test_custom_template_loader(client):
class MyFlask(flask.Flask): class MyFlask(flask.Flask):
def create_global_jinja_loader(self): def create_global_jinja_loader(self):
from jinja2 import DictLoader from jinja2 import DictLoader
return DictLoader({'index.html': 'Hello Custom World!'})
return DictLoader({"index.html": "Hello Custom World!"})
app = MyFlask(__name__) app = MyFlask(__name__)
@app.route('/') @app.route("/")
def index(): def index():
return flask.render_template('index.html') return flask.render_template("index.html")
c = app.test_client() c = app.test_client()
rv = c.get('/') rv = c.get("/")
assert rv.data == b'Hello Custom World!' assert rv.data == b"Hello Custom World!"
def test_iterable_loader(app, client): def test_iterable_loader(app, client):
@app.context_processor @app.context_processor
def context_processor(): def context_processor():
return {'whiskey': 'Jameson'} return {"whiskey": "Jameson"}
@app.route('/') @app.route("/")
def index(): def index():
return flask.render_template( return flask.render_template(
['no_template.xml', # should skip this one [
'simple_template.html', # should render this "no_template.xml", # should skip this one
'context_template.html'], "simple_template.html", # should render this
value=23) "context_template.html",
],
value=23,
)
rv = client.get('/') rv = client.get("/")
assert rv.data == b'<h1>Jameson</h1>' assert rv.data == b"<h1>Jameson</h1>"
def test_templates_auto_reload(app): def test_templates_auto_reload(app):
# debug is False, config option is None # debug is False, config option is None
assert app.debug is False assert app.debug is False
assert app.config['TEMPLATES_AUTO_RELOAD'] is None assert app.config["TEMPLATES_AUTO_RELOAD"] is None
assert app.jinja_env.auto_reload is False assert app.jinja_env.auto_reload is False
# debug is False, config option is False # debug is False, config option is False
app = flask.Flask(__name__) app = flask.Flask(__name__)
app.config['TEMPLATES_AUTO_RELOAD'] = False app.config["TEMPLATES_AUTO_RELOAD"] = False
assert app.debug is False assert app.debug is False
assert app.jinja_env.auto_reload is False assert app.jinja_env.auto_reload is False
# debug is False, config option is True # debug is False, config option is True
app = flask.Flask(__name__) app = flask.Flask(__name__)
app.config['TEMPLATES_AUTO_RELOAD'] = True app.config["TEMPLATES_AUTO_RELOAD"] = True
assert app.debug is False assert app.debug is False
assert app.jinja_env.auto_reload is True assert app.jinja_env.auto_reload is True
# debug is True, config option is None # debug is True, config option is None
app = flask.Flask(__name__) app = flask.Flask(__name__)
app.config['DEBUG'] = True app.config["DEBUG"] = True
assert app.config['TEMPLATES_AUTO_RELOAD'] is None assert app.config["TEMPLATES_AUTO_RELOAD"] is None
assert app.jinja_env.auto_reload is True assert app.jinja_env.auto_reload is True
# debug is True, config option is False # debug is True, config option is False
app = flask.Flask(__name__) app = flask.Flask(__name__)
app.config['DEBUG'] = True app.config["DEBUG"] = True
app.config['TEMPLATES_AUTO_RELOAD'] = False app.config["TEMPLATES_AUTO_RELOAD"] = False
assert app.jinja_env.auto_reload is False assert app.jinja_env.auto_reload is False
# debug is True, config option is True # debug is True, config option is True
app = flask.Flask(__name__) app = flask.Flask(__name__)
app.config['DEBUG'] = True app.config["DEBUG"] = True
app.config['TEMPLATES_AUTO_RELOAD'] = True app.config["TEMPLATES_AUTO_RELOAD"] = True
assert app.jinja_env.auto_reload is True assert app.jinja_env.auto_reload is True
def test_templates_auto_reload_debug_run(app, monkeypatch): def test_templates_auto_reload_debug_run(app, monkeypatch):
def run_simple_mock(*args, **kwargs): def run_simple_mock(*args, **kwargs):
pass pass
monkeypatch.setattr(werkzeug.serving, 'run_simple', run_simple_mock) monkeypatch.setattr(werkzeug.serving, "run_simple", run_simple_mock)
app.run() app.run()
assert app.templates_auto_reload == False assert app.templates_auto_reload == False
@ -409,25 +416,26 @@ def test_template_loader_debugging(test_apps, monkeypatch):
called.append(True) called.append(True)
text = str(record.msg) text = str(record.msg)
assert '1: trying loader of application "blueprintapp"' in text assert '1: trying loader of application "blueprintapp"' in text
assert ('2: trying loader of blueprint "admin" ' assert (
'(blueprintapp.apps.admin)') in text '2: trying loader of blueprint "admin" ' "(blueprintapp.apps.admin)"
assert ('trying loader of blueprint "frontend" ' ) in text
'(blueprintapp.apps.frontend)') in text assert (
assert 'Error: the template could not be found' in text 'trying loader of blueprint "frontend" ' "(blueprintapp.apps.frontend)"
assert ('looked up from an endpoint that belongs to ' ) in text
'the blueprint "frontend"') in text assert "Error: the template could not be found" in text
assert 'See http://flask.pocoo.org/docs/blueprints/#templates' in text assert (
"looked up from an endpoint that belongs to " 'the blueprint "frontend"'
) in text
assert "See http://flask.pocoo.org/docs/blueprints/#templates" in text
with app.test_client() as c: with app.test_client() as c:
monkeypatch.setitem(app.config, 'EXPLAIN_TEMPLATE_LOADING', True) monkeypatch.setitem(app.config, "EXPLAIN_TEMPLATE_LOADING", True)
monkeypatch.setattr( monkeypatch.setattr(logging.getLogger("flask"), "handlers", [_TestHandler()])
logging.getLogger('flask'), 'handlers', [_TestHandler()]
)
with pytest.raises(TemplateNotFound) as excinfo: with pytest.raises(TemplateNotFound) as excinfo:
c.get('/missing') c.get("/missing")
assert 'missing_template.html' in str(excinfo.value) assert "missing_template.html" in str(excinfo.value)
assert len(called) == 1 assert len(called) == 1

View file

@ -21,166 +21,166 @@ from flask.testing import make_test_environ_builder, FlaskCliRunner
def test_environ_defaults_from_config(app, client): def test_environ_defaults_from_config(app, client):
app.config['SERVER_NAME'] = 'example.com:1234' app.config["SERVER_NAME"] = "example.com:1234"
app.config['APPLICATION_ROOT'] = '/foo' app.config["APPLICATION_ROOT"] = "/foo"
@app.route('/') @app.route("/")
def index(): def index():
return flask.request.url return flask.request.url
ctx = app.test_request_context() ctx = app.test_request_context()
assert ctx.request.url == 'http://example.com:1234/foo/' assert ctx.request.url == "http://example.com:1234/foo/"
rv = client.get('/') rv = client.get("/")
assert rv.data == b'http://example.com:1234/foo/' assert rv.data == b"http://example.com:1234/foo/"
def test_environ_defaults(app, client, app_ctx, req_ctx): def test_environ_defaults(app, client, app_ctx, req_ctx):
@app.route('/') @app.route("/")
def index(): def index():
return flask.request.url return flask.request.url
ctx = app.test_request_context() ctx = app.test_request_context()
assert ctx.request.url == 'http://localhost/' assert ctx.request.url == "http://localhost/"
with client: with client:
rv = client.get('/') rv = client.get("/")
assert rv.data == b'http://localhost/' assert rv.data == b"http://localhost/"
def test_environ_base_default(app, client, app_ctx): def test_environ_base_default(app, client, app_ctx):
@app.route('/') @app.route("/")
def index(): def index():
flask.g.user_agent = flask.request.headers["User-Agent"] flask.g.user_agent = flask.request.headers["User-Agent"]
return flask.request.remote_addr return flask.request.remote_addr
rv = client.get('/') rv = client.get("/")
assert rv.data == b'127.0.0.1' assert rv.data == b"127.0.0.1"
assert flask.g.user_agent == 'werkzeug/' + werkzeug.__version__ assert flask.g.user_agent == "werkzeug/" + werkzeug.__version__
def test_environ_base_modified(app, client, app_ctx): def test_environ_base_modified(app, client, app_ctx):
@app.route('/') @app.route("/")
def index(): def index():
flask.g.user_agent = flask.request.headers["User-Agent"] flask.g.user_agent = flask.request.headers["User-Agent"]
return flask.request.remote_addr return flask.request.remote_addr
client.environ_base['REMOTE_ADDR'] = '0.0.0.0' client.environ_base["REMOTE_ADDR"] = "0.0.0.0"
client.environ_base['HTTP_USER_AGENT'] = 'Foo' client.environ_base["HTTP_USER_AGENT"] = "Foo"
rv = client.get('/') rv = client.get("/")
assert rv.data == b'0.0.0.0' assert rv.data == b"0.0.0.0"
assert flask.g.user_agent == 'Foo' assert flask.g.user_agent == "Foo"
client.environ_base['REMOTE_ADDR'] = '0.0.0.1' client.environ_base["REMOTE_ADDR"] = "0.0.0.1"
client.environ_base['HTTP_USER_AGENT'] = 'Bar' client.environ_base["HTTP_USER_AGENT"] = "Bar"
rv = client.get('/') rv = client.get("/")
assert rv.data == b'0.0.0.1' assert rv.data == b"0.0.0.1"
assert flask.g.user_agent == 'Bar' assert flask.g.user_agent == "Bar"
def test_client_open_environ(app, client, request): def test_client_open_environ(app, client, request):
@app.route('/index') @app.route("/index")
def index(): def index():
return flask.request.remote_addr return flask.request.remote_addr
builder = make_test_environ_builder(app, path='/index', method='GET') builder = make_test_environ_builder(app, path="/index", method="GET")
request.addfinalizer(builder.close) request.addfinalizer(builder.close)
rv = client.open(builder) rv = client.open(builder)
assert rv.data == b'127.0.0.1' assert rv.data == b"127.0.0.1"
environ = builder.get_environ() environ = builder.get_environ()
client.environ_base['REMOTE_ADDR'] = '127.0.0.2' client.environ_base["REMOTE_ADDR"] = "127.0.0.2"
rv = client.open(environ) rv = client.open(environ)
assert rv.data == b'127.0.0.2' assert rv.data == b"127.0.0.2"
def test_specify_url_scheme(app, client): def test_specify_url_scheme(app, client):
@app.route('/') @app.route("/")
def index(): def index():
return flask.request.url return flask.request.url
ctx = app.test_request_context(url_scheme='https') ctx = app.test_request_context(url_scheme="https")
assert ctx.request.url == 'https://localhost/' assert ctx.request.url == "https://localhost/"
rv = client.get('/', url_scheme='https') rv = client.get("/", url_scheme="https")
assert rv.data == b'https://localhost/' assert rv.data == b"https://localhost/"
def test_path_is_url(app): def test_path_is_url(app):
eb = make_test_environ_builder(app, 'https://example.com/') eb = make_test_environ_builder(app, "https://example.com/")
assert eb.url_scheme == 'https' assert eb.url_scheme == "https"
assert eb.host == 'example.com' assert eb.host == "example.com"
assert eb.script_root == '' assert eb.script_root == ""
assert eb.path == '/' assert eb.path == "/"
def test_blueprint_with_subdomain(): def test_blueprint_with_subdomain():
app = flask.Flask(__name__, subdomain_matching=True) app = flask.Flask(__name__, subdomain_matching=True)
app.config['SERVER_NAME'] = 'example.com:1234' app.config["SERVER_NAME"] = "example.com:1234"
app.config['APPLICATION_ROOT'] = '/foo' app.config["APPLICATION_ROOT"] = "/foo"
client = app.test_client() client = app.test_client()
bp = flask.Blueprint('company', __name__, subdomain='xxx') bp = flask.Blueprint("company", __name__, subdomain="xxx")
@bp.route('/') @bp.route("/")
def index(): def index():
return flask.request.url return flask.request.url
app.register_blueprint(bp) app.register_blueprint(bp)
ctx = app.test_request_context('/', subdomain='xxx') ctx = app.test_request_context("/", subdomain="xxx")
assert ctx.request.url == 'http://xxx.example.com:1234/foo/' assert ctx.request.url == "http://xxx.example.com:1234/foo/"
assert ctx.request.blueprint == bp.name assert ctx.request.blueprint == bp.name
rv = client.get('/', subdomain='xxx') rv = client.get("/", subdomain="xxx")
assert rv.data == b'http://xxx.example.com:1234/foo/' assert rv.data == b"http://xxx.example.com:1234/foo/"
def test_redirect_keep_session(app, client, app_ctx): def test_redirect_keep_session(app, client, app_ctx):
@app.route('/', methods=['GET', 'POST']) @app.route("/", methods=["GET", "POST"])
def index(): def index():
if flask.request.method == 'POST': if flask.request.method == "POST":
return flask.redirect('/getsession') return flask.redirect("/getsession")
flask.session['data'] = 'foo' flask.session["data"] = "foo"
return 'index' return "index"
@app.route('/getsession') @app.route("/getsession")
def get_session(): def get_session():
return flask.session.get('data', '<missing>') return flask.session.get("data", "<missing>")
with client: with client:
rv = client.get('/getsession') rv = client.get("/getsession")
assert rv.data == b'<missing>' assert rv.data == b"<missing>"
rv = client.get('/') rv = client.get("/")
assert rv.data == b'index' assert rv.data == b"index"
assert flask.session.get('data') == 'foo' assert flask.session.get("data") == "foo"
rv = client.post('/', data={}, follow_redirects=True) rv = client.post("/", data={}, follow_redirects=True)
assert rv.data == b'foo' assert rv.data == b"foo"
# This support requires a new Werkzeug version # This support requires a new Werkzeug version
if not hasattr(client, 'redirect_client'): if not hasattr(client, "redirect_client"):
assert flask.session.get('data') == 'foo' assert flask.session.get("data") == "foo"
rv = client.get('/getsession') rv = client.get("/getsession")
assert rv.data == b'foo' assert rv.data == b"foo"
def test_session_transactions(app, client): def test_session_transactions(app, client):
@app.route('/') @app.route("/")
def index(): def index():
return text_type(flask.session['foo']) return text_type(flask.session["foo"])
with client: with client:
with client.session_transaction() as sess: with client.session_transaction() as sess:
assert len(sess) == 0 assert len(sess) == 0
sess['foo'] = [42] sess["foo"] = [42]
assert len(sess) == 1 assert len(sess) == 1
rv = client.get('/') rv = client.get("/")
assert rv.data == b'[42]' assert rv.data == b"[42]"
with client.session_transaction() as sess: with client.session_transaction() as sess:
assert len(sess) == 1 assert len(sess) == 1
assert sess['foo'] == [42] assert sess["foo"] == [42]
def test_session_transactions_no_null_sessions(): def test_session_transactions_no_null_sessions():
@ -191,11 +191,11 @@ def test_session_transactions_no_null_sessions():
with pytest.raises(RuntimeError) as e: with pytest.raises(RuntimeError) as e:
with c.session_transaction() as sess: with c.session_transaction() as sess:
pass pass
assert 'Session backend did not open a session' in str(e.value) assert "Session backend did not open a session" in str(e.value)
def test_session_transactions_keep_context(app, client, req_ctx): def test_session_transactions_keep_context(app, client, req_ctx):
rv = client.get('/') rv = client.get("/")
req = flask.request._get_current_object() req = flask.request._get_current_object()
assert req is not None assert req is not None
with client.session_transaction(): with client.session_transaction():
@ -207,30 +207,30 @@ def test_session_transaction_needs_cookies(app):
with pytest.raises(RuntimeError) as e: with pytest.raises(RuntimeError) as e:
with c.session_transaction() as s: with c.session_transaction() as s:
pass pass
assert 'cookies' in str(e.value) assert "cookies" in str(e.value)
def test_test_client_context_binding(app, client): def test_test_client_context_binding(app, client):
app.testing = False app.testing = False
@app.route('/') @app.route("/")
def index(): def index():
flask.g.value = 42 flask.g.value = 42
return 'Hello World!' return "Hello World!"
@app.route('/other') @app.route("/other")
def other(): def other():
1 // 0 1 // 0
with client: with client:
resp = client.get('/') resp = client.get("/")
assert flask.g.value == 42 assert flask.g.value == 42
assert resp.data == b'Hello World!' assert resp.data == b"Hello World!"
assert resp.status_code == 200 assert resp.status_code == 200
resp = client.get('/other') resp = client.get("/other")
assert not hasattr(flask.g, 'value') assert not hasattr(flask.g, "value")
assert b'Internal Server Error' in resp.data assert b"Internal Server Error" in resp.data
assert resp.status_code == 500 assert resp.status_code == 500
flask.g.value = 23 flask.g.value = 23
@ -239,17 +239,17 @@ def test_test_client_context_binding(app, client):
except (AttributeError, RuntimeError): except (AttributeError, RuntimeError):
pass pass
else: else:
raise AssertionError('some kind of exception expected') raise AssertionError("some kind of exception expected")
def test_reuse_client(client): def test_reuse_client(client):
c = client c = client
with c: with c:
assert client.get('/').status_code == 404 assert client.get("/").status_code == 404
with c: with c:
assert client.get('/').status_code == 404 assert client.get("/").status_code == 404
def test_test_client_calls_teardown_handlers(app, client): def test_test_client_calls_teardown_handlers(app, client):
@ -261,40 +261,40 @@ def test_test_client_calls_teardown_handlers(app, client):
with client: with client:
assert called == [] assert called == []
client.get('/') client.get("/")
assert called == [] assert called == []
assert called == [None] assert called == [None]
del called[:] del called[:]
with client: with client:
assert called == [] assert called == []
client.get('/') client.get("/")
assert called == [] assert called == []
client.get('/') client.get("/")
assert called == [None] assert called == [None]
assert called == [None, None] assert called == [None, None]
def test_full_url_request(app, client): def test_full_url_request(app, client):
@app.route('/action', methods=['POST']) @app.route("/action", methods=["POST"])
def action(): def action():
return 'x' return "x"
with client: with client:
rv = client.post('http://domain.com/action?vodka=42', data={'gin': 43}) rv = client.post("http://domain.com/action?vodka=42", data={"gin": 43})
assert rv.status_code == 200 assert rv.status_code == 200
assert 'gin' in flask.request.form assert "gin" in flask.request.form
assert 'vodka' in flask.request.args assert "vodka" in flask.request.args
def test_json_request_and_response(app, client): def test_json_request_and_response(app, client):
@app.route('/echo', methods=['POST']) @app.route("/echo", methods=["POST"])
def echo(): def echo():
return jsonify(flask.request.get_json()) return jsonify(flask.request.get_json())
with client: with client:
json_data = {'drink': {'gin': 1, 'tonic': True}, 'price': 10} json_data = {"drink": {"gin": 1, "tonic": True}, "price": 10}
rv = client.post('/echo', json=json_data) rv = client.post("/echo", json=json_data)
# Request should be in JSON # Request should be in JSON
assert flask.request.is_json assert flask.request.is_json
@ -308,38 +308,38 @@ def test_json_request_and_response(app, client):
def test_subdomain(): def test_subdomain():
app = flask.Flask(__name__, subdomain_matching=True) app = flask.Flask(__name__, subdomain_matching=True)
app.config['SERVER_NAME'] = 'example.com' app.config["SERVER_NAME"] = "example.com"
client = app.test_client() client = app.test_client()
@app.route('/', subdomain='<company_id>') @app.route("/", subdomain="<company_id>")
def view(company_id): def view(company_id):
return company_id return company_id
with app.test_request_context(): with app.test_request_context():
url = flask.url_for('view', company_id='xxx') url = flask.url_for("view", company_id="xxx")
with client: with client:
response = client.get(url) response = client.get(url)
assert 200 == response.status_code assert 200 == response.status_code
assert b'xxx' == response.data assert b"xxx" == response.data
def test_nosubdomain(app, client): def test_nosubdomain(app, client):
app.config['SERVER_NAME'] = 'example.com' app.config["SERVER_NAME"] = "example.com"
@app.route('/<company_id>') @app.route("/<company_id>")
def view(company_id): def view(company_id):
return company_id return company_id
with app.test_request_context(): with app.test_request_context():
url = flask.url_for('view', company_id='xxx') url = flask.url_for("view", company_id="xxx")
with client: with client:
response = client.get(url) response = client.get(url)
assert 200 == response.status_code assert 200 == response.status_code
assert b'xxx' == response.data assert b"xxx" == response.data
def test_cli_runner_class(app): def test_cli_runner_class(app):
@ -355,17 +355,17 @@ def test_cli_runner_class(app):
def test_cli_invoke(app): def test_cli_invoke(app):
@app.cli.command('hello') @app.cli.command("hello")
def hello_command(): def hello_command():
click.echo('Hello, World!') click.echo("Hello, World!")
runner = app.test_cli_runner() runner = app.test_cli_runner()
# invoke with command name # invoke with command name
result = runner.invoke(args=['hello']) result = runner.invoke(args=["hello"])
assert 'Hello' in result.output assert "Hello" in result.output
# invoke with command object # invoke with command object
result = runner.invoke(hello_command) result = runner.invoke(hello_command)
assert 'Hello' in result.output assert "Hello" in result.output
def test_cli_custom_obj(app): def test_cli_custom_obj(app):
@ -376,9 +376,9 @@ def test_cli_custom_obj(app):
NS.called = True NS.called = True
return app return app
@app.cli.command('hello') @app.cli.command("hello")
def hello_command(): def hello_command():
click.echo('Hello, World!') click.echo("Hello, World!")
script_info = ScriptInfo(create_app=create_app) script_info = ScriptInfo(create_app=create_app)
runner = app.test_cli_runner() runner = app.test_cli_runner()

View file

@ -7,40 +7,34 @@ tests.test_user_error_handler
:license: BSD, see LICENSE for more details. :license: BSD, see LICENSE for more details.
""" """
from werkzeug.exceptions import ( from werkzeug.exceptions import Forbidden, InternalServerError, HTTPException, NotFound
Forbidden,
InternalServerError,
HTTPException,
NotFound
)
import flask import flask
def test_error_handler_no_match(app, client): def test_error_handler_no_match(app, client):
class CustomException(Exception): class CustomException(Exception):
pass pass
@app.errorhandler(CustomException) @app.errorhandler(CustomException)
def custom_exception_handler(e): def custom_exception_handler(e):
assert isinstance(e, CustomException) assert isinstance(e, CustomException)
return 'custom' return "custom"
@app.errorhandler(500) @app.errorhandler(500)
def handle_500(e): def handle_500(e):
return type(e).__name__ return type(e).__name__
@app.route('/custom') @app.route("/custom")
def custom_test(): def custom_test():
raise CustomException() raise CustomException()
@app.route('/keyerror') @app.route("/keyerror")
def key_error(): def key_error():
raise KeyError() raise KeyError()
app.testing = False app.testing = False
assert client.get('/custom').data == b'custom' assert client.get("/custom").data == b"custom"
assert client.get('/keyerror').data == b'KeyError' assert client.get("/keyerror").data == b"KeyError"
def test_error_handler_subclass(app): def test_error_handler_subclass(app):
@ -56,30 +50,30 @@ def test_error_handler_subclass(app):
@app.errorhandler(ParentException) @app.errorhandler(ParentException)
def parent_exception_handler(e): def parent_exception_handler(e):
assert isinstance(e, ParentException) assert isinstance(e, ParentException)
return 'parent' return "parent"
@app.errorhandler(ChildExceptionRegistered) @app.errorhandler(ChildExceptionRegistered)
def child_exception_handler(e): def child_exception_handler(e):
assert isinstance(e, ChildExceptionRegistered) assert isinstance(e, ChildExceptionRegistered)
return 'child-registered' return "child-registered"
@app.route('/parent') @app.route("/parent")
def parent_test(): def parent_test():
raise ParentException() raise ParentException()
@app.route('/child-unregistered') @app.route("/child-unregistered")
def unregistered_test(): def unregistered_test():
raise ChildExceptionUnregistered() raise ChildExceptionUnregistered()
@app.route('/child-registered') @app.route("/child-registered")
def registered_test(): def registered_test():
raise ChildExceptionRegistered() raise ChildExceptionRegistered()
c = app.test_client() c = app.test_client()
assert c.get('/parent').data == b'parent' assert c.get("/parent").data == b"parent"
assert c.get('/child-unregistered').data == b'parent' assert c.get("/child-unregistered").data == b"parent"
assert c.get('/child-registered').data == b'child-registered' assert c.get("/child-registered").data == b"child-registered"
def test_error_handler_http_subclass(app): def test_error_handler_http_subclass(app):
@ -92,78 +86,78 @@ def test_error_handler_http_subclass(app):
@app.errorhandler(403) @app.errorhandler(403)
def code_exception_handler(e): def code_exception_handler(e):
assert isinstance(e, Forbidden) assert isinstance(e, Forbidden)
return 'forbidden' return "forbidden"
@app.errorhandler(ForbiddenSubclassRegistered) @app.errorhandler(ForbiddenSubclassRegistered)
def subclass_exception_handler(e): def subclass_exception_handler(e):
assert isinstance(e, ForbiddenSubclassRegistered) assert isinstance(e, ForbiddenSubclassRegistered)
return 'forbidden-registered' return "forbidden-registered"
@app.route('/forbidden') @app.route("/forbidden")
def forbidden_test(): def forbidden_test():
raise Forbidden() raise Forbidden()
@app.route('/forbidden-registered') @app.route("/forbidden-registered")
def registered_test(): def registered_test():
raise ForbiddenSubclassRegistered() raise ForbiddenSubclassRegistered()
@app.route('/forbidden-unregistered') @app.route("/forbidden-unregistered")
def unregistered_test(): def unregistered_test():
raise ForbiddenSubclassUnregistered() raise ForbiddenSubclassUnregistered()
c = app.test_client() c = app.test_client()
assert c.get('/forbidden').data == b'forbidden' assert c.get("/forbidden").data == b"forbidden"
assert c.get('/forbidden-unregistered').data == b'forbidden' assert c.get("/forbidden-unregistered").data == b"forbidden"
assert c.get('/forbidden-registered').data == b'forbidden-registered' assert c.get("/forbidden-registered").data == b"forbidden-registered"
def test_error_handler_blueprint(app): def test_error_handler_blueprint(app):
bp = flask.Blueprint('bp', __name__) bp = flask.Blueprint("bp", __name__)
@bp.errorhandler(500) @bp.errorhandler(500)
def bp_exception_handler(e): def bp_exception_handler(e):
return 'bp-error' return "bp-error"
@bp.route('/error') @bp.route("/error")
def bp_test(): def bp_test():
raise InternalServerError() raise InternalServerError()
@app.errorhandler(500) @app.errorhandler(500)
def app_exception_handler(e): def app_exception_handler(e):
return 'app-error' return "app-error"
@app.route('/error') @app.route("/error")
def app_test(): def app_test():
raise InternalServerError() raise InternalServerError()
app.register_blueprint(bp, url_prefix='/bp') app.register_blueprint(bp, url_prefix="/bp")
c = app.test_client() c = app.test_client()
assert c.get('/error').data == b'app-error' assert c.get("/error").data == b"app-error"
assert c.get('/bp/error').data == b'bp-error' assert c.get("/bp/error").data == b"bp-error"
def test_default_error_handler(): def test_default_error_handler():
bp = flask.Blueprint('bp', __name__) bp = flask.Blueprint("bp", __name__)
@bp.errorhandler(HTTPException) @bp.errorhandler(HTTPException)
def bp_exception_handler(e): def bp_exception_handler(e):
assert isinstance(e, HTTPException) assert isinstance(e, HTTPException)
assert isinstance(e, NotFound) assert isinstance(e, NotFound)
return 'bp-default' return "bp-default"
@bp.errorhandler(Forbidden) @bp.errorhandler(Forbidden)
def bp_exception_handler(e): def bp_exception_handler(e):
assert isinstance(e, Forbidden) assert isinstance(e, Forbidden)
return 'bp-forbidden' return "bp-forbidden"
@bp.route('/undefined') @bp.route("/undefined")
def bp_registered_test(): def bp_registered_test():
raise NotFound() raise NotFound()
@bp.route('/forbidden') @bp.route("/forbidden")
def bp_forbidden_test(): def bp_forbidden_test():
raise Forbidden() raise Forbidden()
@ -173,14 +167,14 @@ def test_default_error_handler():
def catchall_errorhandler(e): def catchall_errorhandler(e):
assert isinstance(e, HTTPException) assert isinstance(e, HTTPException)
assert isinstance(e, NotFound) assert isinstance(e, NotFound)
return 'default' return "default"
@app.errorhandler(Forbidden) @app.errorhandler(Forbidden)
def catchall_errorhandler(e): def catchall_errorhandler(e):
assert isinstance(e, Forbidden) assert isinstance(e, Forbidden)
return 'forbidden' return "forbidden"
@app.route('/forbidden') @app.route("/forbidden")
def forbidden(): def forbidden():
raise Forbidden() raise Forbidden()
@ -188,12 +182,12 @@ def test_default_error_handler():
def slash(): def slash():
return "slash" return "slash"
app.register_blueprint(bp, url_prefix='/bp') app.register_blueprint(bp, url_prefix="/bp")
c = app.test_client() c = app.test_client()
assert c.get('/bp/undefined').data == b'bp-default' assert c.get("/bp/undefined").data == b"bp-default"
assert c.get('/bp/forbidden').data == b'bp-forbidden' assert c.get("/bp/forbidden").data == b"bp-forbidden"
assert c.get('/undefined').data == b'default' assert c.get("/undefined").data == b"default"
assert c.get('/forbidden').data == b'forbidden' assert c.get("/forbidden").data == b"forbidden"
# Don't handle RequestRedirect raised when adding slash. # Don't handle RequestRedirect raised when adding slash.
assert c.get("/slash", follow_redirects=True).data == b"slash" assert c.get("/slash", follow_redirects=True).data == b"slash"

View file

@ -20,33 +20,33 @@ from werkzeug.http import parse_set_header
def common_test(app): def common_test(app):
c = app.test_client() c = app.test_client()
assert c.get('/').data == b'GET' assert c.get("/").data == b"GET"
assert c.post('/').data == b'POST' assert c.post("/").data == b"POST"
assert c.put('/').status_code == 405 assert c.put("/").status_code == 405
meths = parse_set_header(c.open('/', method='OPTIONS').headers['Allow']) meths = parse_set_header(c.open("/", method="OPTIONS").headers["Allow"])
assert sorted(meths) == ['GET', 'HEAD', 'OPTIONS', 'POST'] assert sorted(meths) == ["GET", "HEAD", "OPTIONS", "POST"]
def test_basic_view(app): def test_basic_view(app):
class Index(flask.views.View): class Index(flask.views.View):
methods = ['GET', 'POST'] methods = ["GET", "POST"]
def dispatch_request(self): def dispatch_request(self):
return flask.request.method return flask.request.method
app.add_url_rule('/', view_func=Index.as_view('index')) app.add_url_rule("/", view_func=Index.as_view("index"))
common_test(app) common_test(app)
def test_method_based_view(app): def test_method_based_view(app):
class Index(flask.views.MethodView): class Index(flask.views.MethodView):
def get(self): def get(self):
return 'GET' return "GET"
def post(self): def post(self):
return 'POST' return "POST"
app.add_url_rule('/', view_func=Index.as_view('index')) app.add_url_rule("/", view_func=Index.as_view("index"))
common_test(app) common_test(app)
@ -61,40 +61,40 @@ def test_view_patching(app):
class Other(Index): class Other(Index):
def get(self): def get(self):
return 'GET' return "GET"
def post(self): def post(self):
return 'POST' return "POST"
view = Index.as_view('index') view = Index.as_view("index")
view.view_class = Other view.view_class = Other
app.add_url_rule('/', view_func=view) app.add_url_rule("/", view_func=view)
common_test(app) common_test(app)
def test_view_inheritance(app, client): def test_view_inheritance(app, client):
class Index(flask.views.MethodView): class Index(flask.views.MethodView):
def get(self): def get(self):
return 'GET' return "GET"
def post(self): def post(self):
return 'POST' return "POST"
class BetterIndex(Index): class BetterIndex(Index):
def delete(self): def delete(self):
return 'DELETE' return "DELETE"
app.add_url_rule('/', view_func=BetterIndex.as_view('index')) app.add_url_rule("/", view_func=BetterIndex.as_view("index"))
meths = parse_set_header(client.open('/', method='OPTIONS').headers['Allow']) meths = parse_set_header(client.open("/", method="OPTIONS").headers["Allow"])
assert sorted(meths) == ['DELETE', 'GET', 'HEAD', 'OPTIONS', 'POST'] assert sorted(meths) == ["DELETE", "GET", "HEAD", "OPTIONS", "POST"]
def test_view_decorators(app, client): def test_view_decorators(app, client):
def add_x_parachute(f): def add_x_parachute(f):
def new_function(*args, **kwargs): def new_function(*args, **kwargs):
resp = flask.make_response(f(*args, **kwargs)) resp = flask.make_response(f(*args, **kwargs))
resp.headers['X-Parachute'] = 'awesome' resp.headers["X-Parachute"] = "awesome"
return resp return resp
return new_function return new_function
@ -103,12 +103,12 @@ def test_view_decorators(app, client):
decorators = [add_x_parachute] decorators = [add_x_parachute]
def dispatch_request(self): def dispatch_request(self):
return 'Awesome' return "Awesome"
app.add_url_rule('/', view_func=Index.as_view('index')) app.add_url_rule("/", view_func=Index.as_view("index"))
rv = client.get('/') rv = client.get("/")
assert rv.headers['X-Parachute'] == 'awesome' assert rv.headers["X-Parachute"] == "awesome"
assert rv.data == b'Awesome' assert rv.data == b"Awesome"
def test_view_provide_automatic_options_attr(): def test_view_provide_automatic_options_attr():
@ -118,84 +118,82 @@ def test_view_provide_automatic_options_attr():
provide_automatic_options = False provide_automatic_options = False
def dispatch_request(self): def dispatch_request(self):
return 'Hello World!' return "Hello World!"
app.add_url_rule('/', view_func=Index1.as_view('index')) app.add_url_rule("/", view_func=Index1.as_view("index"))
c = app.test_client() c = app.test_client()
rv = c.open('/', method='OPTIONS') rv = c.open("/", method="OPTIONS")
assert rv.status_code == 405 assert rv.status_code == 405
app = flask.Flask(__name__) app = flask.Flask(__name__)
class Index2(flask.views.View): class Index2(flask.views.View):
methods = ['OPTIONS'] methods = ["OPTIONS"]
provide_automatic_options = True provide_automatic_options = True
def dispatch_request(self): def dispatch_request(self):
return 'Hello World!' return "Hello World!"
app.add_url_rule('/', view_func=Index2.as_view('index')) app.add_url_rule("/", view_func=Index2.as_view("index"))
c = app.test_client() c = app.test_client()
rv = c.open('/', method='OPTIONS') rv = c.open("/", method="OPTIONS")
assert sorted(rv.allow) == ['OPTIONS'] assert sorted(rv.allow) == ["OPTIONS"]
app = flask.Flask(__name__) app = flask.Flask(__name__)
class Index3(flask.views.View): class Index3(flask.views.View):
def dispatch_request(self): def dispatch_request(self):
return 'Hello World!' return "Hello World!"
app.add_url_rule('/', view_func=Index3.as_view('index')) app.add_url_rule("/", view_func=Index3.as_view("index"))
c = app.test_client() c = app.test_client()
rv = c.open('/', method='OPTIONS') rv = c.open("/", method="OPTIONS")
assert 'OPTIONS' in rv.allow assert "OPTIONS" in rv.allow
def test_implicit_head(app, client): def test_implicit_head(app, client):
class Index(flask.views.MethodView): class Index(flask.views.MethodView):
def get(self): def get(self):
return flask.Response('Blub', headers={ return flask.Response("Blub", headers={"X-Method": flask.request.method})
'X-Method': flask.request.method
})
app.add_url_rule('/', view_func=Index.as_view('index')) app.add_url_rule("/", view_func=Index.as_view("index"))
rv = client.get('/') rv = client.get("/")
assert rv.data == b'Blub' assert rv.data == b"Blub"
assert rv.headers['X-Method'] == 'GET' assert rv.headers["X-Method"] == "GET"
rv = client.head('/') rv = client.head("/")
assert rv.data == b'' assert rv.data == b""
assert rv.headers['X-Method'] == 'HEAD' assert rv.headers["X-Method"] == "HEAD"
def test_explicit_head(app, client): def test_explicit_head(app, client):
class Index(flask.views.MethodView): class Index(flask.views.MethodView):
def get(self): def get(self):
return 'GET' return "GET"
def head(self): def head(self):
return flask.Response('', headers={'X-Method': 'HEAD'}) return flask.Response("", headers={"X-Method": "HEAD"})
app.add_url_rule('/', view_func=Index.as_view('index')) app.add_url_rule("/", view_func=Index.as_view("index"))
rv = client.get('/') rv = client.get("/")
assert rv.data == b'GET' assert rv.data == b"GET"
rv = client.head('/') rv = client.head("/")
assert rv.data == b'' assert rv.data == b""
assert rv.headers['X-Method'] == 'HEAD' assert rv.headers["X-Method"] == "HEAD"
def test_endpoint_override(app): def test_endpoint_override(app):
app.debug = True app.debug = True
class Index(flask.views.View): class Index(flask.views.View):
methods = ['GET', 'POST'] methods = ["GET", "POST"]
def dispatch_request(self): def dispatch_request(self):
return flask.request.method return flask.request.method
app.add_url_rule('/', view_func=Index.as_view('index')) app.add_url_rule("/", view_func=Index.as_view("index"))
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
app.add_url_rule('/', view_func=Index.as_view('index')) app.add_url_rule("/", view_func=Index.as_view("index"))
# But these tests should still pass. We just log a warning. # But these tests should still pass. We just log a warning.
common_test(app) common_test(app)
@ -204,36 +202,36 @@ def test_endpoint_override(app):
def test_multiple_inheritance(app, client): def test_multiple_inheritance(app, client):
class GetView(flask.views.MethodView): class GetView(flask.views.MethodView):
def get(self): def get(self):
return 'GET' return "GET"
class DeleteView(flask.views.MethodView): class DeleteView(flask.views.MethodView):
def delete(self): def delete(self):
return 'DELETE' return "DELETE"
class GetDeleteView(GetView, DeleteView): class GetDeleteView(GetView, DeleteView):
pass pass
app.add_url_rule('/', view_func=GetDeleteView.as_view('index')) app.add_url_rule("/", view_func=GetDeleteView.as_view("index"))
assert client.get('/').data == b'GET' assert client.get("/").data == b"GET"
assert client.delete('/').data == b'DELETE' assert client.delete("/").data == b"DELETE"
assert sorted(GetDeleteView.methods) == ['DELETE', 'GET'] assert sorted(GetDeleteView.methods) == ["DELETE", "GET"]
def test_remove_method_from_parent(app, client): def test_remove_method_from_parent(app, client):
class GetView(flask.views.MethodView): class GetView(flask.views.MethodView):
def get(self): def get(self):
return 'GET' return "GET"
class OtherView(flask.views.MethodView): class OtherView(flask.views.MethodView):
def post(self): def post(self):
return 'POST' return "POST"
class View(GetView, OtherView): class View(GetView, OtherView):
methods = ['GET'] methods = ["GET"]
app.add_url_rule('/', view_func=View.as_view('index')) app.add_url_rule("/", view_func=View.as_view("index"))
assert client.get('/').data == b'GET' assert client.get("/").data == b"GET"
assert client.post('/').status_code == 405 assert client.post("/").status_code == 405
assert sorted(View.methods) == ['GET'] assert sorted(View.methods) == ["GET"]

View file

@ -37,6 +37,11 @@ commands =
# pytest-cov doesn't seem to play nice with -p # pytest-cov doesn't seem to play nice with -p
coverage run -p -m pytest tests examples coverage run -p -m pytest tests examples
[testenv:stylecheck]
deps = pre-commit
skip_install = true
commands = pre-commit run --all-files --show-diff-on-failure
[testenv:docs-html] [testenv:docs-html]
deps = deps =
-r docs/requirements.txt -r docs/requirements.txt