diff --git a/CHANGES b/CHANGES index dc39a95d..a3504ad5 100644 --- a/CHANGES +++ b/CHANGES @@ -65,6 +65,8 @@ Major release, unreleased - ``TRAP_BAD_REQUEST_ERRORS`` is enabled by default in debug mode. ``BadRequestKeyError`` has a message with the bad key in debug mode instead of the generic bad request message. (`#2348`_) +- Allow registering new tags with ``TaggedJSONSerializer`` to support + storing other types in the session cookie. (`#2352`_) .. _#1489: https://github.com/pallets/flask/pull/1489 .. _#1621: https://github.com/pallets/flask/pull/1621 @@ -84,6 +86,7 @@ Major release, unreleased .. _#2319: https://github.com/pallets/flask/pull/2319 .. _#2326: https://github.com/pallets/flask/pull/2326 .. _#2348: https://github.com/pallets/flask/pull/2348 +.. _#2352: https://github.com/pallets/flask/pull/2352 Version 0.12.2 -------------- diff --git a/flask/json.py b/flask/json/__init__.py similarity index 97% rename from flask/json.py rename to flask/json/__init__.py index a029e73a..93e6fdc4 100644 --- a/flask/json.py +++ b/flask/json/__init__.py @@ -1,18 +1,9 @@ # -*- coding: utf-8 -*- -""" - flask.json - ~~~~~~~~~~ - - Implementation helpers for the JSON support in Flask. - - :copyright: (c) 2015 by Armin Ronacher. - :license: BSD, see LICENSE for more details. -""" import io import uuid from datetime import date -from .globals import current_app, request -from ._compat import text_type, PY2 +from flask.globals import current_app, request +from flask._compat import text_type, PY2 from werkzeug.http import http_date from jinja2 import Markup diff --git a/flask/json/tag.py b/flask/json/tag.py new file mode 100644 index 00000000..40594282 --- /dev/null +++ b/flask/json/tag.py @@ -0,0 +1,188 @@ +from base64 import b64decode, b64encode +from datetime import datetime +from uuid import UUID + +from jinja2 import Markup +from werkzeug.http import http_date, parse_date + +from flask._compat import iteritems, text_type +from flask.json import dumps, loads + + +class JSONTag(object): + __slots__ = () + key = None + + def check(self, serializer, value): + raise NotImplementedError + + def to_json(self, serializer, value): + raise NotImplementedError + + def to_python(self, serializer, value): + raise NotImplementedError + + def tag(self, serializer, value): + return {self.key: self.to_json(serializer, value)} + + +class TagDict(JSONTag): + __slots__ = () + key = ' di' + + def check(self, serializer, value): + return isinstance(value, dict) + + def to_json(self, serializer, value, key=None): + if key is not None: + return {key + '__': serializer._tag(value[key])} + + return dict((k, serializer._tag(v)) for k, v in iteritems(value)) + + def to_python(self, serializer, value): + key = next(iter(value)) + return {key[:-2]: value[key]} + + def tag(self, serializer, value): + if len(value) == 1: + key = next(iter(value)) + + if key in serializer._tags: + return {self.key: self.to_json(serializer, value, key=key)} + + return self.to_json(serializer, value) + + +class TagTuple(JSONTag): + __slots__ = () + key = ' t' + + def check(self, serializer, value): + return isinstance(value, tuple) + + def to_json(self, serializer, value): + return [serializer._tag(item) for item in value] + + def to_python(self, serializer, value): + return tuple(value) + + +class PassList(JSONTag): + __slots__ = () + + def check(self, serializer, value): + return isinstance(value, list) + + def to_json(self, serializer, value): + return [serializer._tag(item) for item in value] + + tag = to_json + + +class TagBytes(JSONTag): + __slots__ = () + key = ' b' + + def check(self, serializer, value): + return isinstance(value, bytes) + + def to_json(self, serializer, value): + return b64encode(value).decode('ascii') + + def to_python(self, serializer, value): + return b64decode(value) + + +class TagMarkup(JSONTag): + __slots__ = () + key = ' m' + + def check(self, serializer, value): + return callable(getattr(value, '__html__', None)) + + def to_json(self, serializer, value): + return text_type(value.__html__()) + + def to_python(self, serializer, value): + return Markup(value) + + +class TagUUID(JSONTag): + __slots__ = () + key = ' u' + + def check(self, serializer, value): + return isinstance(value, UUID) + + def to_json(self, serializer, value): + return value.hex + + def to_python(self, serializer, value): + return UUID(value) + + +class TagDateTime(JSONTag): + __slots__ = () + key = ' d' + + def check(self, serializer, value): + return isinstance(value, datetime) + + def to_json(self, serializer, value): + return http_date(value) + + def to_python(self, serializer, value): + return parse_date(value) + + +class TaggedJSONSerializer(object): + __slots__ = ('_tags', '_order') + _default_tags = [ + TagDict(), TagTuple(), PassList(), TagBytes(), TagMarkup(), TagUUID(), + TagDateTime(), + ] + + def __init__(self): + self._tags = {} + self._order = [] + + for tag in self._default_tags: + self.register(tag) + + def register(self, tag, force=False, index=-1): + key = tag.key + + if key is not None: + if not force and key in self._tags: + raise KeyError("Tag '{0}' is already registered.".format(key)) + + self._tags[key] = tag + + if index == -1: + self._order.append(tag) + else: + self._order.insert(index, tag) + + def _tag(self, value): + for tag in self._order: + if tag.check(self, value): + return tag.tag(self, value) + + return value + + def _untag(self, value): + if len(value) != 1: + return value + + key = next(iter(value)) + + if key not in self._tags: + return value + + return self._tags[key].to_python(self, value[key]) + + def dumps(self, value): + return dumps(self._tag(value), separators=(',', ':')) + + def loads(self, value): + return loads(value, object_hook=self._untag) diff --git a/flask/sessions.py b/flask/sessions.py index a334e703..82b588bc 100644 --- a/flask/sessions.py +++ b/flask/sessions.py @@ -9,18 +9,14 @@ :license: BSD, see LICENSE for more details. """ import hashlib -import uuid import warnings -from base64 import b64decode, b64encode from datetime import datetime from itsdangerous import BadSignature, URLSafeTimedSerializer from werkzeug.datastructures import CallbackDict -from werkzeug.http import http_date, parse_date -from . import Markup, json -from ._compat import iteritems, text_type -from .helpers import is_ip, total_seconds +from flask.helpers import is_ip, total_seconds +from flask.json.tag import TaggedJSONSerializer class SessionMixin(object): @@ -57,126 +53,6 @@ class SessionMixin(object): #: from being served the same cache. accessed = True -class TaggedJSONSerializer(object): - """A customized JSON serializer that supports a few extra types that - we take for granted when serializing (tuples, markup objects, datetime). - """ - - def __init__(self): - self.conversions = [ - { - 'check': lambda value: self._is_dict_with_used_key(value), - 'tag': lambda value: self._tag_dict_used_with_key(value), - 'untag': lambda value: self._untag_dict_used_with_key(value), - 'key': ' di', - }, - { - 'check': lambda value: isinstance(value, tuple), - 'tag': lambda value: [self._tag(x) for x in value], - 'untag': lambda value: tuple(value), - 'key': ' t', - }, - { - 'check': lambda value: isinstance(value, uuid.UUID), - 'tag': lambda value: value.hex, - 'untag': lambda value: uuid.UUID(value), - 'key': ' u', - }, - { - 'check': lambda value: isinstance(value, bytes), - 'tag': lambda value: b64encode(value).decode('ascii'), - 'untag': lambda value: b64decode(value), - 'key': ' b', - }, - { - 'check': lambda value: callable(getattr(value, '__html__', - None)), - 'tag': lambda value: text_type(value.__html__()), - 'untag': lambda value: Markup(value), - 'key': ' m', - }, - { - 'check': lambda value: isinstance(value, list), - 'tag': lambda value: [self._tag(x) for x in value], - }, - { - 'check': lambda value: isinstance(value, datetime), - 'tag': lambda value: http_date(value), - 'untag': lambda value: parse_date(value), - 'key': ' d', - }, - { - 'check': lambda value: isinstance(value, dict), - 'tag': lambda value: dict((k, self._tag(v)) for k, v in - iteritems(value)), - }, - { - 'check': lambda value: isinstance(value, str), - 'tag': lambda value: self._tag_string(value), - } - ] - - @property - def keys(self): - return [c['key'] for c in self.conversions if c.get('key')] - - def _get_conversion_untag(self, key): - return next( - (c['untag'] for c in self.conversions if c.get('key') == key), - lambda v: None - ) - - def _is_dict_with_used_key(self, v): - return isinstance(v, dict) and len(v) == 1 and list(v)[0] in self.keys - - def _was_dict_with_used_key(self, k): - return k.endswith('__') and k[:-2] in self.keys - - def _tag_string(self, value): - try: - return text_type(value) - except UnicodeError: - from flask.debughelpers import UnexpectedUnicodeError - raise UnexpectedUnicodeError(u'A byte string with ' - u'non-ASCII data was passed to the session system ' - u'which can only store unicode strings. Consider ' - u'base64 encoding your string (String was %r)' % value) - - def _tag_dict_used_with_key(self, value): - k, v = next(iteritems(value)) - return {'%s__' % k: v} - - def _tag(self, value): - for tag_ops in self.conversions: - if tag_ops['check'](value): - tag = tag_ops.get('key') - if tag: - return {tag: tag_ops['tag'](value)} - return tag_ops['tag'](value) - return value - - def _untag_dict_used_with_key(self, the_value): - k, v = next(iteritems(the_value)) - if self._was_dict_with_used_key(k): - return {k[:-2]: self._untag(v)} - - def _untag(self, obj): - if len(obj) != 1: - return obj - the_key, the_value = next(iteritems(obj)) - untag = self._get_conversion_untag(the_key) - new_value = untag(the_value) - return new_value if new_value else obj - - def dumps(self, value): - return json.dumps(self._tag(value), separators=(',', ':')) - - def loads(self, value): - return json.loads(value, object_hook=self._untag) - - -session_json_serializer = TaggedJSONSerializer() - class SecureCookieSession(CallbackDict, SessionMixin): """Base class for sessions based on signed cookies.""" @@ -284,10 +160,10 @@ class SessionInterface(object): def get_cookie_domain(self, app): """Returns the domain that should be set for the session cookie. - + Uses ``SESSION_COOKIE_DOMAIN`` if it is configured, otherwise falls back to detecting the domain based on ``SERVER_NAME``. - + Once detected (or if not set at all), ``SESSION_COOKIE_DOMAIN`` is updated to avoid re-running the logic. """ @@ -377,7 +253,7 @@ class SessionInterface(object): has been modified, the cookie is set. If the session is permanent and the ``SESSION_REFRESH_EACH_REQUEST`` config is true, the cookie is always set. - + This check is usually skipped if the session was deleted. .. versionadded:: 0.11 @@ -404,6 +280,9 @@ class SessionInterface(object): raise NotImplementedError() +session_json_serializer = TaggedJSONSerializer() + + class SecureCookieSessionInterface(SessionInterface): """The default session interface that stores sessions in signed cookies through the :mod:`itsdangerous` module.