diff --git a/CHANGES b/CHANGES index dc39a95d..a3504ad5 100644 --- a/CHANGES +++ b/CHANGES @@ -65,6 +65,8 @@ Major release, unreleased - ``TRAP_BAD_REQUEST_ERRORS`` is enabled by default in debug mode. ``BadRequestKeyError`` has a message with the bad key in debug mode instead of the generic bad request message. (`#2348`_) +- Allow registering new tags with ``TaggedJSONSerializer`` to support + storing other types in the session cookie. (`#2352`_) .. _#1489: https://github.com/pallets/flask/pull/1489 .. _#1621: https://github.com/pallets/flask/pull/1621 @@ -84,6 +86,7 @@ Major release, unreleased .. _#2319: https://github.com/pallets/flask/pull/2319 .. _#2326: https://github.com/pallets/flask/pull/2326 .. _#2348: https://github.com/pallets/flask/pull/2348 +.. _#2352: https://github.com/pallets/flask/pull/2352 Version 0.12.2 -------------- diff --git a/docs/api.rst b/docs/api.rst index b5009907..fe4f151f 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -171,18 +171,6 @@ implementation that Flask is using. .. autoclass:: SessionMixin :members: -.. autodata:: session_json_serializer - - This object provides dumping and loading methods similar to simplejson - but it also tags certain builtin Python objects that commonly appear in - sessions. Currently the following extended values are supported in - the JSON it dumps: - - - :class:`~markupsafe.Markup` objects - - :class:`~uuid.UUID` objects - - :class:`~datetime.datetime` objects - - :class:`tuple`\s - .. admonition:: Notice The ``PERMANENT_SESSION_LIFETIME`` config key can also be an integer @@ -354,6 +342,8 @@ you are using Flask 0.10 which implies that: .. autoclass:: JSONDecoder :members: +.. automodule:: flask.json.tag + Template Rendering ------------------ diff --git a/flask/json.py b/flask/json/__init__.py similarity index 97% rename from flask/json.py rename to flask/json/__init__.py index a029e73a..93e6fdc4 100644 --- a/flask/json.py +++ b/flask/json/__init__.py @@ -1,18 +1,9 @@ # -*- coding: utf-8 -*- -""" - flask.json - ~~~~~~~~~~ - - Implementation helpers for the JSON support in Flask. - - :copyright: (c) 2015 by Armin Ronacher. - :license: BSD, see LICENSE for more details. -""" import io import uuid from datetime import date -from .globals import current_app, request -from ._compat import text_type, PY2 +from flask.globals import current_app, request +from flask._compat import text_type, PY2 from werkzeug.http import http_date from jinja2 import Markup diff --git a/flask/json/tag.py b/flask/json/tag.py new file mode 100644 index 00000000..3c57884e --- /dev/null +++ b/flask/json/tag.py @@ -0,0 +1,297 @@ +""" +Tagged JSON +~~~~~~~~~~~ + +A compact representation for lossless serialization of non-standard JSON types. +:class:`~flask.sessions.SecureCookieSessionInterface` uses this to serialize +the session data, but it may be useful in other places. It can be extended to +support other types. + +.. autoclass:: TaggedJSONSerializer + :members: + +.. autoclass:: JSONTag + :members: + +Let's seen an example that adds support for :class:`~collections.OrderedDict`. +Dicts don't have an order in Python or JSON, so to handle this we will dump +the items as a list of ``[key, value]`` pairs. Subclass :class:`JSONTag` and +give it the new key ``' od'`` to identify the type. The session serializer +processes dicts first, so insert the new tag at the front of the order since +``OrderedDict`` must be processed before ``dict``. :: + + from flask.json.tag import JSONTag + + class TagOrderedDict(JSONTag): + __slots__ = ('serializer',) + key = ' od' + + def check(self, value): + return isinstance(value, OrderedDict) + + def to_json(self, value): + return [[k, self.serializer.tag(v)] for k, v in iteritems(value)] + + def to_python(self, value): + return OrderedDict(value) + + app.session_interface.serializer.register(TagOrderedDict, 0) + +""" + +from base64 import b64decode, b64encode +from datetime import datetime +from uuid import UUID + +from jinja2 import Markup +from werkzeug.http import http_date, parse_date + +from flask._compat import iteritems, text_type +from flask.json import dumps, loads + + +class JSONTag(object): + """Base class for defining type tags for :class:`TaggedJSONSerializer`.""" + + __slots__ = ('serializer',) + + #: The tag to mark the serialized object with. If ``None``, this tag is + #: only used as an intermediate step during tagging. + key = None + + def __init__(self, serializer): + """Create a tagger for the given serializer.""" + self.serializer = serializer + + def check(self, value): + """Check if the given value should be tagged by this tag.""" + raise NotImplementedError + + def to_json(self, value): + """Convert the Python object to an object that is a valid JSON type. + The tag will be added later.""" + raise NotImplementedError + + def to_python(self, value): + """Convert the JSON representation back to the correct type. The tag + will already be removed.""" + raise NotImplementedError + + def tag(self, value): + """Convert the value to a valid JSON type and add the tag structure + around it.""" + return {self.key: self.to_json(value)} + + +class TagDict(JSONTag): + """Tag for 1-item dicts whose only key matches a registered tag. + + Internally, the dict key is suffixed with `__`, and the suffix is removed + when deserializing. + """ + + __slots__ = () + key = ' di' + + def check(self, value): + return ( + isinstance(value, dict) + and len(value) == 1 + and next(iter(value)) in self.serializer.tags + ) + + def to_json(self, value): + key = next(iter(value)) + return {key + '__': self.serializer.tag(value[key])} + + def to_python(self, value): + key = next(iter(value)) + return {key[:-2]: value[key]} + + +class PassDict(JSONTag): + __slots__ = () + + def check(self, value): + return isinstance(value, dict) + + def to_json(self, value): + # JSON objects may only have string keys, so don't bother tagging the + # key here. + return dict((k, self.serializer.tag(v)) for k, v in iteritems(value)) + + tag = to_json + + +class TagTuple(JSONTag): + __slots__ = () + key = ' t' + + def check(self, value): + return isinstance(value, tuple) + + def to_json(self, value): + return [self.serializer.tag(item) for item in value] + + def to_python(self, value): + return tuple(value) + + +class PassList(JSONTag): + __slots__ = () + + def check(self, value): + return isinstance(value, list) + + def to_json(self, value): + return [self.serializer.tag(item) for item in value] + + tag = to_json + + +class TagBytes(JSONTag): + __slots__ = () + key = ' b' + + def check(self, value): + return isinstance(value, bytes) + + def to_json(self, value): + return b64encode(value).decode('ascii') + + def to_python(self, value): + return b64decode(value) + + +class TagMarkup(JSONTag): + """Serialize anything matching the :class:`~flask.Markup` API by + having a ``__html__`` method to the result of that method. Always + deserializes to an instance of :class:`~flask.Markup`.""" + + __slots__ = () + key = ' m' + + def check(self, value): + return callable(getattr(value, '__html__', None)) + + def to_json(self, value): + return text_type(value.__html__()) + + def to_python(self, value): + return Markup(value) + + +class TagUUID(JSONTag): + __slots__ = () + key = ' u' + + def check(self, value): + return isinstance(value, UUID) + + def to_json(self, value): + return value.hex + + def to_python(self, value): + return UUID(value) + + +class TagDateTime(JSONTag): + __slots__ = () + key = ' d' + + def check(self, value): + return isinstance(value, datetime) + + def to_json(self, value): + return http_date(value) + + def to_python(self, value): + return parse_date(value) + + +class TaggedJSONSerializer(object): + """Serializer that uses a tag system to compactly represent objects that + are not JSON types. Passed as the intermediate serializer to + :class:`itsdangerous.Serializer`. + + The following extra types are supported: + + * :class:`dict` + * :class:`tuple` + * :class:`bytes` + * :class:`~flask.Markup` + * :class:`~uuid.UUID` + * :class:`~datetime.datetime` + """ + + __slots__ = ('tags', 'order') + + #: Tag classes to bind when creating the serializer. Other tags can be + #: added later using :meth:`~register`. + default_tags = [ + TagDict, PassDict, TagTuple, PassList, TagBytes, TagMarkup, TagUUID, + TagDateTime, + ] + + def __init__(self): + self.tags = {} + self.order = [] + + for cls in self.default_tags: + self.register(cls) + + def register(self, tag_class, force=False, index=-1): + """Register a new tag with this serializer. + + :param tag_class: tag class to register. Will be instantiated with this + serializer instance. + :param force: overwrite an existing tag. If false (default), a + :exc:`KeyError` is raised. + :param index: index to insert the new tag in the tag order. Useful when + the new tag is a special case of an existing tag. If -1 (default), + the tag is appended to the end of the order. + + :raise KeyError: if the tag key is already registered and ``force`` is + not true. + """ + tag = tag_class(self) + key = tag.key + + if key is not None: + if not force and key in self.tags: + raise KeyError("Tag '{0}' is already registered.".format(key)) + + self.tags[key] = tag + + if index == -1: + self.order.append(tag) + else: + self.order.insert(index, tag) + + def tag(self, value): + """Convert a value to a tagged representation if necessary.""" + for tag in self.order: + if tag.check(value): + return tag.tag(value) + + return value + + def untag(self, value): + """Convert a tagged representation back to the original type.""" + if len(value) != 1: + return value + + key = next(iter(value)) + + if key not in self.tags: + return value + + return self.tags[key].to_python(value[key]) + + def dumps(self, value): + """Tag the value and dump it to a compact JSON string.""" + return dumps(self.tag(value), separators=(',', ':')) + + def loads(self, value): + """Load data from a JSON string and deserialized any tagged objects.""" + return loads(value, object_hook=self.untag) diff --git a/flask/sessions.py b/flask/sessions.py index 3a5e2ac3..82b588bc 100644 --- a/flask/sessions.py +++ b/flask/sessions.py @@ -8,20 +8,15 @@ :copyright: (c) 2015 by Armin Ronacher. :license: BSD, see LICENSE for more details. """ - import hashlib -import uuid import warnings -from base64 import b64decode, b64encode from datetime import datetime from itsdangerous import BadSignature, URLSafeTimedSerializer from werkzeug.datastructures import CallbackDict -from werkzeug.http import http_date, parse_date -from . import Markup, json -from ._compat import iteritems, text_type -from .helpers import is_ip, total_seconds +from flask.helpers import is_ip, total_seconds +from flask.json.tag import TaggedJSONSerializer class SessionMixin(object): @@ -58,66 +53,6 @@ class SessionMixin(object): #: from being served the same cache. accessed = True -def _tag(value): - if isinstance(value, tuple): - return {' t': [_tag(x) for x in value]} - elif isinstance(value, uuid.UUID): - return {' u': value.hex} - elif isinstance(value, bytes): - return {' b': b64encode(value).decode('ascii')} - elif callable(getattr(value, '__html__', None)): - return {' m': text_type(value.__html__())} - elif isinstance(value, list): - return [_tag(x) for x in value] - elif isinstance(value, datetime): - return {' d': http_date(value)} - elif isinstance(value, dict): - return dict((k, _tag(v)) for k, v in iteritems(value)) - elif isinstance(value, str): - try: - return text_type(value) - except UnicodeError: - from flask.debughelpers import UnexpectedUnicodeError - raise UnexpectedUnicodeError(u'A byte string with ' - u'non-ASCII data was passed to the session system ' - u'which can only store unicode strings. Consider ' - u'base64 encoding your string (String was %r)' % value) - return value - - -class TaggedJSONSerializer(object): - """A customized JSON serializer that supports a few extra types that - we take for granted when serializing (tuples, markup objects, datetime). - """ - - def dumps(self, value): - return json.dumps(_tag(value), separators=(',', ':')) - - LOADS_MAP = { - ' t': tuple, - ' u': uuid.UUID, - ' b': b64decode, - ' m': Markup, - ' d': parse_date, - } - - def loads(self, value): - def object_hook(obj): - if len(obj) != 1: - return obj - the_key, the_value = next(iteritems(obj)) - # Check the key for a corresponding function - return_function = self.LOADS_MAP.get(the_key) - if return_function: - # Pass the value to the function - return return_function(the_value) - # Didn't find a function for this object - return obj - return json.loads(value, object_hook=object_hook) - - -session_json_serializer = TaggedJSONSerializer() - class SecureCookieSession(CallbackDict, SessionMixin): """Base class for sessions based on signed cookies.""" @@ -225,10 +160,10 @@ class SessionInterface(object): def get_cookie_domain(self, app): """Returns the domain that should be set for the session cookie. - + Uses ``SESSION_COOKIE_DOMAIN`` if it is configured, otherwise falls back to detecting the domain based on ``SERVER_NAME``. - + Once detected (or if not set at all), ``SESSION_COOKIE_DOMAIN`` is updated to avoid re-running the logic. """ @@ -318,7 +253,7 @@ class SessionInterface(object): has been modified, the cookie is set. If the session is permanent and the ``SESSION_REFRESH_EACH_REQUEST`` config is true, the cookie is always set. - + This check is usually skipped if the session was deleted. .. versionadded:: 0.11 @@ -345,6 +280,9 @@ class SessionInterface(object): raise NotImplementedError() +session_json_serializer = TaggedJSONSerializer() + + class SecureCookieSessionInterface(SessionInterface): """The default session interface that stores sessions in signed cookies through the :mod:`itsdangerous` module. diff --git a/tests/test_basic.py b/tests/test_basic.py index a38428b2..108c1409 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -435,28 +435,31 @@ def test_session_special_types(app, client): now = datetime.utcnow().replace(microsecond=0) the_uuid = uuid.uuid4() - @app.after_request - def modify_session(response): - flask.session['m'] = flask.Markup('Hello!') - flask.session['u'] = the_uuid - flask.session['dt'] = now - flask.session['b'] = b'\xff' - flask.session['t'] = (1, 2, 3) - return response - @app.route('/') def dump_session_contents(): - return pickle.dumps(dict(flask.session)) + flask.session['t'] = (1, 2, 3) + flask.session['b'] = b'\xff' + flask.session['m'] = flask.Markup('') + flask.session['u'] = the_uuid + flask.session['d'] = now + flask.session['t_tag'] = {' t': 'not-a-tuple'} + flask.session['di_t_tag'] = {' t__': 'not-a-tuple'} + flask.session['di_tag'] = {' di': 'not-a-dict'} + return '', 204 - client.get('/') - rv = pickle.loads(client.get('/').data) - assert rv['m'] == flask.Markup('Hello!') - assert type(rv['m']) == flask.Markup - assert rv['dt'] == now - assert rv['u'] == the_uuid - assert rv['b'] == b'\xff' - assert type(rv['b']) == bytes - assert rv['t'] == (1, 2, 3) + with client: + client.get('/') + s = flask.session + assert s['t'] == (1, 2, 3) + assert type(s['b']) == bytes + assert s['b'] == b'\xff' + assert type(s['m']) == flask.Markup + assert s['m'] == flask.Markup('') + assert s['u'] == the_uuid + assert s['d'] == now + assert s['t_tag'] == {' t': 'not-a-tuple'} + assert s['di_t_tag'] == {' t__': 'not-a-tuple'} + assert s['di_tag'] == {' di': 'not-a-dict'} def test_session_cookie_setting(app): diff --git a/tests/test_json_tag.py b/tests/test_json_tag.py new file mode 100644 index 00000000..b8cb6550 --- /dev/null +++ b/tests/test_json_tag.py @@ -0,0 +1,65 @@ +from datetime import datetime +from uuid import uuid4 + +import pytest + +from flask import Markup +from flask.json.tag import TaggedJSONSerializer, JSONTag + + +@pytest.mark.parametrize("data", ( + {' t': (1, 2, 3)}, + {' t__': b'a'}, + {' di': ' di'}, + {'x': (1, 2, 3), 'y': 4}, + (1, 2, 3), + [(1, 2, 3)], + b'\xff', + Markup(''), + uuid4(), + datetime.utcnow().replace(microsecond=0), +)) +def test_dump_load_unchanged(data): + s = TaggedJSONSerializer() + assert s.loads(s.dumps(data)) == data + + +def test_duplicate_tag(): + class TagDict(JSONTag): + key = ' d' + + s = TaggedJSONSerializer() + pytest.raises(KeyError, s.register, TagDict) + s.register(TagDict, force=True, index=0) + assert isinstance(s.tags[' d'], TagDict) + assert isinstance(s.order[0], TagDict) + + +def test_custom_tag(): + class Foo(object): + def __init__(self, data): + self.data = data + + class TagFoo(JSONTag): + __slots__ = () + key = ' f' + + def check(self, value): + return isinstance(value, Foo) + + def to_json(self, value): + return self.serializer.tag(value.data) + + def to_python(self, value): + return Foo(value) + + s = TaggedJSONSerializer() + s.register(TagFoo) + assert s.loads(s.dumps(Foo('bar'))).data == 'bar' + + +def test_tag_interface(): + t = JSONTag(None) + pytest.raises(NotImplementedError, t.check, None) + pytest.raises(NotImplementedError, t.to_json, None) + pytest.raises(NotImplementedError, t.to_python, None)