Last active
August 22, 2020 01:07
-
-
Save todofixthis/79a2f213989a3584211e49bfba582b40 to your computer and use it in GitHub Desktop.
MongoDB transparent escaping/unescaping of illegal keys
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding=utf-8 | |
from __future__ import absolute_import, print_function, unicode_literals | |
from pprint import pprint | |
from bson import ObjectId, SON | |
from pymongo import MongoClient | |
from pymongo.collection import Collection | |
from key_escaper import DeterministicKeyEscaper | |
def main():
    """
    Round-trips a document with illegal keys through MongoDB and checks
    that it comes back unchanged.

    Connects to the default ``MongoClient()`` and uses
    ``test_db.test_collection``.
    """
    test_collection = MongoClient()['test_db']['test_collection']

    # Both keys start with '$', which MongoDB restricts.
    # :see: https://docs.mongodb.com/manual/reference/limits/#Restrictions-on-Field-Names
    document = {
        '$foo': 'bar',
        '$baz': 'luhrmann',
    }

    round_tripped = retrieve(test_collection, store(test_collection, document))

    print('Asserting that retrieved document matches what we stored.')
    # The '_id' was generated on insert, so it is not part of the
    # comparison.
    round_tripped.pop('_id')
    assert round_tripped.to_dict() == document
    print('Match!')
def store(collection, document):
    # type: (Collection, dict) -> ObjectId
    """
    Escapes a document's keys, inserts it into the collection and
    returns the id of the inserted document.

    The original and escaped forms are printed along the way.
    """
    print('Original document:')
    pprint(document)
    print('')

    # The document is "incoming" from MongoDB's point of view, hence
    # ``transform_incoming``.
    escaped = DeterministicKeyEscaper().transform_incoming(
        document,
        collection.name,
    )

    print('Escaped document:')
    pprint(escaped.to_dict())
    print('')

    # The escaped copy is what gets inserted.
    return collection.insert_one(escaped).inserted_id
def retrieve(collection, document_id):
    # type: (Collection, ObjectId) -> SON
    """
    Retrieves a document from the specified collection and restores its
    original keys.

    :param collection: collection to read from.
    :param document_id: ``_id`` of the document to load.

    :return: the unescaped document, always as a ``SON``.

    :raise LookupError: if no document has the given ``_id``.
    """
    raw = collection.find_one({'_id': document_id})
    if raw is None:
        # ``find_one`` returns None when nothing matches; fail with a
        # clear message instead of an AttributeError further down.
        raise LookupError(
            'No document with _id {id!r} in collection {name}.'.format(
                id=document_id,
                name=collection.name,
            ),
        )

    print('Stored document:')
    pprint(raw)
    print('')

    # Run the document through our DeterministicKeyEscaper to restore the
    # original keys.
    manipulator = DeterministicKeyEscaper()

    # Note that the method to invoke here is `transform_outgoing`.
    # From MongoDB's perspective, this document is going out.
    unescaped = manipulator.transform_outgoing(raw, collection.name)

    # When nothing was escaped, ``transform_outgoing`` hands back the
    # document it was given, which may be a plain dict without
    # ``to_dict``.  Normalise so the declared return type holds.
    if not isinstance(unescaped, SON):
        unescaped = SON(unescaped)

    print('Unescaped document:')
    pprint(unescaped.to_dict())
    print('')

    return unescaped
if __name__ == '__main__': | |
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding=utf-8 | |
from __future__ import absolute_import, division, print_function, \ | |
unicode_literals | |
from abc import ABCMeta, abstractmethod as abstract_method | |
from codecs import decode | |
from hashlib import md5 | |
from sys import getdefaultencoding as get_default_encoding | |
from typing import Any, Dict, Iterable, List, Mapping, Text, Tuple, Union | |
from bson import InvalidDocument, SON | |
from pymongo.son_manipulator import SONManipulator | |
from six import binary_type, iteritems, string_types, text_type, with_metaclass | |
__all__ = [ | |
'BaseKeyEscaper', | |
'DeterministicKeyEscaper', | |
'NonDeterministicKeyEscaper', | |
] | |
class BaseKeyEscaper(with_metaclass(ABCMeta, SONManipulator)):
    """
    Escapes illegal keys, ensuring that the original values can be
    recovered later.

    Note that the escaped keys will be virtually impossible to query
    for, but that's infinitely preferable to MongoDB refusing to
    store the document in the first place.
    """
    # Used to identify escaped keys.  It is also the top-level key under
    # which ``transform_incoming`` stores the map of escaped keys.
    magic_prefix = '__escaped__'

    # Iterable types that are values in their own right and must never be
    # crawled item by item.  Without ``binary_type``/``bytearray`` here, a
    # Python 3 byte string value would be turned into a list of integers.
    _atomic_iterables = string_types + (binary_type, bytearray)

    def __init__(self):
        super(BaseKeyEscaper, self).__init__()

        ##
        # These attributes are only used when escaping keys.
        ##

        # Keeps track of where we are in the document so that we can
        # populate ``escaped_keys`` correctly.
        self.current_path = None  # type: List[Text]

        # Keeps track of any keys that we've escaped so that a
        # KeyEscaper can later unescape them.
        # Shape: ``{stored_key: [original_key_or_None, {nested map}]}``.
        self.escaped_keys = None  # type: Dict[Text, List]

    @abstract_method
    def escape_key(self, key):
        # type: (Text) -> Text
        """
        Escapes a single key.

        :param key: the illegal key.

        :return: the replacement key to store in its place.
        """
        raise NotImplementedError(
            'Not implemented in {cls}.'.format(cls=type(self).__name__),
        )

    def will_copy(self):
        # type: () -> bool
        """
        Does this manipulator create a copy of the SON?
        """
        #
        # ``transform_incoming`` does create a copy, but
        # ``transform_outgoing`` does not.  Well, we have to pick one!
        #
        # We'll go with ``False`` because it is not safe to assume that
        # this manipulator will create a copy of the SON.
        #
        return False

    def transform_incoming(self, son, collection):
        # type: (Union[Mapping, SON], Text) -> SON
        """
        Transforms a document before it is stored to the database.

        :param son: the document to escape; it is not modified.
        :param collection: collection name (not used by this method).

        :return: an escaped copy of the document, with the map of escaped
            keys stored under ``magic_prefix``.

        :raise InvalidDocument: if a key is not (convertible to) a
            unicode string.
        """
        # Reset per-document state; it is only meaningful during one call.
        self.current_path = []  # type: List[Text]
        self.escaped_keys = {}  # type: Dict[Text, List]

        transformed = self._escape(son)
        transformed[self.magic_prefix] = self.escaped_keys

        return transformed

    def transform_outgoing(self, son, collection):
        # type: (Union[Mapping, SON], Text) -> SON
        """
        Transforms a document after it is retrieved from the database.

        Note that this method will directly modify the document!  The
        ``magic_prefix`` entry is removed from ``son``.

        :param son: the stored document.
        :param collection: collection name (not used by this method).

        :return: a new ``SON`` with the original keys restored, or ``son``
            itself when there is nothing to unescape.
        """
        escaped_keys = son.pop(self.magic_prefix, None)

        if not isinstance(escaped_keys, Mapping):
            # Document is corrupted or was not escaped when it was stored.
            return son

        return self._unescape(son, escaped_keys) if escaped_keys else son

    def _escape(self, son):
        # type: (Union[Mapping, SON]) -> SON
        """
        Recursively crawls the document, transforming keys as it goes.

        :return: a new ``SON``; ``son`` itself is left untouched.

        :raise InvalidDocument: if a key is not a unicode string and
            cannot be decoded into one.
        """
        copy = SON()

        for (key, value) in iteritems(son):  # type: Tuple[Text, Any]
            # Python 2 compatibility: Binary strings are allowed, so long
            # as they can be converted to unicode strings.
            if isinstance(key, binary_type):
                encoding = get_default_encoding()
                if encoding == 'ascii':
                    encoding = 'utf-8'

                try:
                    key = decode(key, encoding)
                except UnicodeDecodeError:
                    # Deliberately not re-raised: the key stays a byte
                    # string and is rejected by the type check below.
                    pass

            if not isinstance(key, text_type):
                raise InvalidDocument(
                    'documents must have only string keys, '
                    'key was {path}[{actual!r}]'.format(
                        actual=key,
                        path='.'.join(self.current_path),
                    ),
                )

            if (
                key.startswith(('$', self.magic_prefix))
                or ('.' in key)
                or ('\x00' in key)
            ):
                key = self._escape_key(key)

            # The path is made of *stored* keys, because that is what
            # ``_unescape`` will see when it walks the stored document.
            self.current_path.append(key)
            copy[key] = self._escape_value(value)
            self.current_path.pop()

        return copy

    def _escape_key(self, key):
        # type: (Text) -> Text
        """
        Transforms an illegal key into something that MongoDB will
        approve of, and records the substitution in ``escaped_keys``.
        """
        new_key = self.escape_key(key)

        # Insert the escaped key into the correct location in
        # ``self.escaped_keys`` so that it can be unescaped later.
        crawler = self.escaped_keys
        for x in self.current_path:
            # ``None`` marks an ancestor that was not itself escaped and
            # exists only to lead to this nested entry.
            crawler.setdefault(x, [None, {}])
            crawler = crawler[x][1]

        crawler[new_key] = [key, {}]

        return new_key

    def _escape_value(self, value):
        """
        Recursively escapes nested values inside mappings and iterables.

        Strings, byte strings and non-iterables are returned unchanged;
        any other iterable is returned as a new list.
        """
        # Escape nested mappings.
        if isinstance(value, Mapping):
            return self._escape(value)

        # Scan iterables for nested mappings.
        if (
            isinstance(value, Iterable)
            and not isinstance(value, self._atomic_iterables)
        ):
            copy = []
            for i, item in enumerate(value):
                # List positions take part in the path as text indexes.
                self.current_path.append(text_type(i))
                copy.append(self._escape_value(item))
                self.current_path.pop()
            return copy

        # Any other value is safe to return unescaped.
        return value

    def _unescape(self, son, escaped_keys):
        """
        Recursively unescapes a value.

        :param son: a stored mapping, list or scalar.
        :param escaped_keys: the part of the escape map that applies to
            ``son``.

        :return: a new ``SON`` for mappings, a new list for other
            non-string iterables, otherwise ``son`` itself.
        """
        if isinstance(son, Mapping):
            copy = SON()

            for key, value in iteritems(son):
                if key in escaped_keys:
                    # - ``r_key`` is the unescaped key value.
                    # - ``r_children`` contains information needed to
                    #   unescape nested objects (if any).
                    r_key, r_children = escaped_keys[key]

                    if r_key is None:
                        # The key did not need to be escaped; it's just a
                        # placeholder so that we can find a nested object
                        # that was escaped.
                        r_key = key

                    if r_children:
                        # Descend into the nested value and continue
                        # unescaping.
                        copy[r_key] = self._unescape(son[key], r_children)
                    else:
                        # The nested value did not need to be escaped.
                        copy[r_key] = son[key]
                else:
                    copy[key] = value

        elif (
            isinstance(son, Iterable)
            and not isinstance(son, self._atomic_iterables)
        ):
            copy = []

            for i, value in enumerate(son):
                key = text_type(i)

                if key in escaped_keys:
                    # Lists don't have keys that need escaping; we're only
                    # interested in whether the value is a nested mapping.
                    _, r_children = escaped_keys[key]

                    if r_children:
                        # Descend into the nested value and continue
                        # unescaping.
                        copy.append(self._unescape(value, r_children))
                    else:
                        # The nested value did not need to be escaped.
                        copy.append(value)
                else:
                    copy.append(value)

        else:
            copy = son

        return copy
class NonDeterministicKeyEscaper(BaseKeyEscaper):
    """
    A KeyEscaper that uses an internal counter to generate escaped keys.

    This method is a bit faster and tends to yield smaller escaped keys
    than DeterministicKeyEscaper, but the result is more difficult to
    query.
    """
    def __init__(self):
        super(NonDeterministicKeyEscaper, self).__init__()

        # Used to ensure each escaped key is unique.  Starts at 0 rather
        # than None so that ``escape_key`` also works when it is called
        # before any ``transform_incoming``.
        self.escaped_key_count = 0  # type: int

    def transform_incoming(self, son, collection):
        # type: (Union[Mapping, SON], Text) -> SON
        """
        Escapes a document, numbering its escaped keys from 0.
        """
        # Restart the numbering for every document.
        self.escaped_key_count = 0  # type: int

        return \
            super(NonDeterministicKeyEscaper, self) \
            .transform_incoming(son, collection)

    def escape_key(self, key):
        # type: (Text) -> Text
        """
        Returns ``magic_prefix`` followed by the current counter value,
        then advances the counter.  The value of ``key`` is not used.
        """
        escaped = self.magic_prefix + text_type(self.escaped_key_count)
        self.escaped_key_count += 1
        return escaped
class DeterministicKeyEscaper(BaseKeyEscaper):
    """
    A KeyEscaper that derives each escaped key from a hash of the
    original key.

    Compared with NonDeterministicKeyEscaper this is a little slower and
    tends to produce longer escaped keys, but the same input key always
    maps to the same escaped key, which makes querying easier.
    """
    def escape_key(self, key):
        # type: (Text) -> Text
        """
        Returns ``magic_prefix`` followed by the MD5 hex digest of the
        UTF-8 encoded key.
        """
        # hashlib operates on byte strings, so encode first.
        digest = md5(key.encode('utf-8'))
        return '{prefix}{suffix}'.format(
            prefix=self.magic_prefix,
            suffix=digest.hexdigest(),
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding=utf-8 | |
from __future__ import absolute_import, unicode_literals

from abc import (
    ABCMeta,
    abstractmethod as abstract_method,
    abstractproperty as abstract_property,
)
from inspect import isabstract as is_abstract
from sys import getdefaultencoding as get_default_encoding
from unittest import TestCase

from bson import InvalidDocument, SON
from pymongo import MongoClient
from six import with_metaclass

from key_escaper import (
    BaseKeyEscaper,
    DeterministicKeyEscaper,
    NonDeterministicKeyEscaper,
)
# Public names of this test module: the concrete test cases it defines,
# rather than the escaper classes it imports from ``key_escaper``.
__all__ = [
    'DeterministicKeyEscaperTestCase',
    'NonDeterministicKeyEscaperTestCase',
]
class TestCaseMeta(ABCMeta):
    """
    Metaclass that sets ``__test__`` on every class it creates: ``True``
    for concrete classes, ``False`` for abstract ones.
    """
    # noinspection PyShadowingBuiltins
    def __init__(cls, name, bases=None, dict=None):
        super(TestCaseMeta, cls).__init__(name, bases, dict)

        # ``__test__`` is the attribute the nose test finder consults.
        # :see: https://nose.readthedocs.io/en/latest/finding_tests.html
        is_concrete = not is_abstract(cls)
        cls.__test__ = is_concrete
class BaseKeyEscaperTestCase(with_metaclass(TestCaseMeta, TestCase)):
    """
    Defines base functionality and templates for KeyEscaper test cases.

    Subclasses provide the manipulator under test via
    ``get_manipulator``.  Every test talks to the default
    ``MongoClient()`` and uses ``test_db.test_collection``.
    """
    @abstract_method
    def get_manipulator(self):
        # type: () -> BaseKeyEscaper
        """
        Returns a new instance of the KeyEscaper under test.
        """
        raise NotImplementedError(
            'Not implemented in {cls}.'.format(cls=type(self).__name__),
        )

    def setUp(self):
        super(BaseKeyEscaperTestCase, self).setUp()

        client = MongoClient()
        # Release the connection once the test has finished.
        self.addCleanup(client.close)

        self.collection = client['test_db']['test_collection']

        # Purge any existing documents from the collection.
        self.collection.drop()

        self.manipulator = self.get_manipulator()  # type: BaseKeyEscaper

    def _store_document(self, document):
        """
        Escapes the document with ``self.manipulator``, inserts it into
        the test collection and returns the inserted id.
        """
        escaped = \
            self.manipulator.transform_incoming(document, self.collection.name)

        return self.collection.insert_one(escaped).inserted_id

    def _retrieve_document(self, query):
        """
        Loads the first document matching ``query`` and unescapes it with
        ``self.manipulator``.

        The '_id' field is omitted, since we can't predict that value for
        comparisons.

        :return: the unescaped document as a plain dict, or ``None`` if
            nothing matched.
        """
        stored = self.collection.find_one(
            filter=query,
            projection={'_id': False},
        )

        if stored is None:
            return None

        unescaped = \
            self.manipulator.transform_outgoing(stored, self.collection.name)

        if isinstance(unescaped, SON):
            unescaped = unescaped.to_dict()

        return unescaped

    def assertKeysEscaped(self, document):
        """
        Asserts that the KeyEscaper correctly escapes/unescapes keys in the
        document.
        """
        document_id = self._store_document(document)

        self.assertEqual(
            self._retrieve_document({'_id': document_id}),
            document,
        )

    def test_illegal_key_names_dollar(self):
        """
        The stored document includes keys that start with '$'.

        This is a MongoDB no-no, according to
        https://docs.mongodb.com/manual/reference/limits/#Restrictions-on-Field-Names
        """
        self.assertKeysEscaped({
            '$topLevel': {
                'severity': 'innocent enough',
                'explanation': 'this is a common enough use case',
            },

            'nested': {
                '$tricky': 'can we handle nested values?',

                '$deep': {
                    '$reference': 'we need to go deeper',
                },
            },

            '$iñtërnâtiônàlizætiøn': 'non-ascii characters supported, too',

            'perfectly$legal':
                "keys may contain '$' so long as it's not the first character",

            'string': '$values may start with "$", no problem',
            'list': ['$list', '$items', '$are', '$also', '$exempt'],
        })

    def test_illegal_key_names_dot(self):
        """
        The stored document includes keys that include '.' characters.

        This is a MongoDB no-no, according to
        https://docs.mongodb.com/manual/reference/limits/#Restrictions-on-Field-Names
        """
        self.assertKeysEscaped({
            'top.level': {
                'severity': 'innocent enough',
                'explanation': 'this is a common enough use case',
            },

            'nested': {
                '.tricky': 'can we handle nested values?',

                '.deep': {
                    'reference.': 'we need to go deeper',
                },
            },

            '.iñtërnâtiônàlizætiøn': 'non-ascii characters supported, too',

            'string': 'values.may.contain "." no.problem',
            'list': ['.list', 'items.', '.are.', '.also', 'exempt.'],
        })

    def test_illegal_key_names_null(self):
        """
        The stored document includes keys that include null bytes.

        This is a MongoDB no-no, according to
        https://docs.mongodb.com/manual/reference/limits/#Restrictions-on-Field-Names
        """
        self.assertKeysEscaped({
            # These all should evaluate to the same code point, but
            # just to make absolutely sure....
            '\U00000000top\x00level\u0000': {
                'severity': 'suspect',

                'explanation':
                    'not sure why you would ever really need to do this',
            },

            'nested': {
                '\x00tricky\u0000': 'can we handle nested values?',

                '\x00deep': {
                    '\U00000000reference': 'we need to go deeper',
                },
            },

            '\x00iñtërnâtiônàlizætiøn': 'non-ascii characters supported, too',

            'string': 'values\x00may\x00contain\x00nulls\x00no\x00problem',

            'list': [
                '\x00list',
                'items\x00',
                '\x00are\x00',
                'also\x00',
                '\x00exempt',
            ],
        })

    def test_illegal_key_names_magic(self):
        """
        The stored document includes key names that coincide with
        escaped keys.
        """
        self.assertKeysEscaped({
            # This is the attribute where the self.manipulator stores the
            # escaped keys.
            self.manipulator.magic_prefix: {
                'severity': 'strange',
                'explanation': 'i guess it could happen',
            },

            # This is an example of an escaped key.
            self.manipulator.magic_prefix + '1':
                'somebody has to think of these things',

            # This is nonsense, but props for creative thinking.
            self.manipulator.magic_prefix + 'wonka':
                'there is no life i know to compare with pure imagination',

            'nested': {
                self.manipulator.magic_prefix: 'can we handle nested values?',
                self.manipulator.magic_prefix + '0': 'same story, different day',

                self.manipulator.magic_prefix + 'deep': {
                    self.manipulator.magic_prefix: 'we need to go deeper',
                },
            },

            self.manipulator.magic_prefix + 'iñtërnâtiônàlizætiøn':
                'non-ascii characters supported, too',

            # Values may use the magic prefix without consequence.
            'string': self.manipulator.magic_prefix,

            'list': [
                self.manipulator.magic_prefix,
                self.manipulator.magic_prefix + '0',
                self.manipulator.magic_prefix + 'foobar',
            ],
        })

    def test_illegal_key_names_combo(self):
        """The stored document has all kinds of illegal keys."""
        self.assertKeysEscaped({
            self.manipulator.magic_prefix + '$very.very.\x00illegal\x00': {
                'severity': 'major',
                'explanation': 'did you even read the instructions?',
            },

            'nested': {
                '$dollars': 'starts with $',
                'has.dot': 'contains a .',
                'has\x00null': 'contains a null',
                '$iñtërnâtiônàlizætiøn': 'contains non-ascii',

                '$level.down': {
                    '..': 'low-budget ascii bear',
                },

                self.manipulator.magic_prefix: 'overslept',
            },
        })

    def test_safe_byte_strings(self):
        """
        Byte strings are allowed, so long as they can be converted into
        unicode strings.
        """
        document_id = self._store_document({
            b'$ascii_escaped': 'escaped, safe; contains ascii only',
            b'ascii_unescaped': 'unescaped, safe; contains ascii only',

            '$iñtërnâtiônàlizætiøn_escaped'.encode('utf-8'):
                'escaped, safe; non-ascii, but can be decoded w/ default encoding',

            'iñtërnâtiônàlizætiøn_unescaped'.encode('utf-8'):
                'unescaped, safe; non-ascii, but can be decoded w/ default encoding',
        })

        retrieved = self._retrieve_document({'_id': document_id})

        self.assertDictEqual(
            retrieved,

            {
                # Note that keys are automatically converted to unicode strings
                # before storage.
                '$ascii_escaped': 'escaped, safe; contains ascii only',
                'ascii_unescaped': 'unescaped, safe; contains ascii only',

                '$iñtërnâtiônàlizætiøn_escaped':
                    'escaped, safe; non-ascii, but can be decoded w/ default encoding',

                'iñtërnâtiônàlizætiøn_unescaped':
                    'unescaped, safe; non-ascii, but can be decoded w/ default encoding',
            },
        )

    def test_unsafe_byte_strings(self):
        """
        Any byte string that can't be converted into a unicode string is
        invalid.
        """
        # Ensure that we pick the wrong encoding, regardless of system
        # configuration.
        wrong_encoding = \
            'latin-1' if get_default_encoding() == 'utf-16' else 'utf-16'

        with self.assertRaises(InvalidDocument):
            self.manipulator.transform_incoming(
                {'$iñtërnâtiônàlizætiøn'.encode(wrong_encoding): 'wrong encoding!'},
                self.collection.name,
            )

        # An exception will be raised even if the key doesn't need to be
        # escaped.
        with self.assertRaises(InvalidDocument):
            self.manipulator.transform_incoming(
                {'iñtërnâtiônàlizætiøn'.encode(wrong_encoding): 'wrong encoding!'},
                self.collection.name,
            )
class DeterministicKeyEscaperTestCase(BaseKeyEscaperTestCase):
    """
    Runs the inherited KeyEscaper tests against DeterministicKeyEscaper
    and adds a query-by-escaped-key test.
    """
    def get_manipulator(self):
        return DeterministicKeyEscaper()

    def test_query_by_escaped_key(self):
        """
        With a little work, a stored document can be found by one of its
        escaped keys.
        """
        response_values = {
            '$firstName': 'Marcus',
            '$lastName': 'Brody',
        }
        document = {'data': {'responseValues': response_values}}

        self._store_document(document)

        # Escaping the whole dotted path yields one opaque key that
        # matches nothing in the stored document.
        whole_path_escaped = \
            self.manipulator.escape_key('data.responseValues.$lastName')

        self.assertIsNone(
            self.collection.find_one({whole_path_escaped: 'Brody'})
        )

        # Escaping only the last path segment is what works.
        leaf_escaped = \
            'data.responseValues.' + self.manipulator.escape_key('$lastName')

        self.assertDictEqual(
            self._retrieve_document({leaf_escaped: 'Brody'}),
            document,
        )
class NonDeterministicKeyEscaperTestCase(BaseKeyEscaperTestCase):
    """
    Runs the inherited KeyEscaper tests against
    NonDeterministicKeyEscaper and checks that querying by an escaped
    key does not find the document.
    """
    def get_manipulator(self):
        return NonDeterministicKeyEscaper()

    def test_query_by_escaped_key(self):
        """
        NonDeterministicKeyEscaper's replacement key names are
        (effectively) unpredictable, as the name suggests.
        """
        response_values = {
            '$firstName': 'Marcus',
            '$lastName': 'Brody',
        }

        self._store_document({'data': {'responseValues': response_values}})

        #
        # Guessing the right escaped key is possible in theory, but
        # outside of contrived unit-test examples it is very unlikely to
        # be practical.
        #
        # To query against escaped keys, use DeterministicKeyEscaper
        # instead.
        #
        guessed_key = \
            'data.responseValues.' + self.manipulator.escape_key('$lastName')

        self.assertIsNone(self.collection.find_one({guessed_key: 'Brody'}))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment