import __main__
import logging
import pymongo
from .errors import ExtraFieldError, ValidationError
# Client classes accepted by valid_client(); allow_client() appends to this tuple.
ok_types = (pymongo.MongoClient, pymongo.MongoReplicaSetClient)

try:
    # motor is optional: its client class is only registered when the import succeeds.
    import motor
    ok_types = ok_types + (motor.MotorClient,)
except ImportError:
    pass

# Configure root logging at import time and create the module-level logger.
logging.basicConfig(
    format=(
        '[%(asctime)s] %(levelname)s '
        '[%(module)s.%(funcName)s():%(lineno)d] %(message)s'
    )
)
logger = logging.getLogger(__file__)
def valid_client(client):
    """
    Tell whether ``client`` is an instance of one of the accepted client types.

    The accepted types are those held in the module-level ``ok_types`` tuple, which
    includes anything registered through :func:`allow_client`.
    """
    accepted_types = ok_types
    return isinstance(client, accepted_types)
def allow_client(client_type):
    """
    Register ``client_type`` as an additional accepted client type.

    Rebinds the module-level ``ok_types`` tuple with ``client_type`` appended.
    Intended for use with mock clients.
    """
    global ok_types
    ok_types = ok_types + (client_type,)
def valid_field(obj, field):
    """
    Return the result of ``obj.nanomongo.has_field(field)``.

    The ``nanomongo`` attribute is fetched with ``object.__getattribute__`` so that
    any ``__getattribute__``/``__getattr__`` override on ``obj`` is bypassed.
    """
    nanomongo = object.__getattribute__(obj, 'nanomongo')
    return nanomongo.has_field(field)
def check_keys(dct):
    """
    Validate the keys of ``dct`` and, recursively, of any ``dict`` values it holds.

    A key is rejected when it contains ``.`` anywhere or when it starts with ``$``.
    Only values that are themselves ``dict`` instances are descended into.

    Raises ``TypeError`` if ``dct`` is not a ``dict`` and ``ValidationError`` on the
    first offending key encountered.
    """
    if not isinstance(dct, dict):
        raise TypeError('dict-like argument expected')
    for key, value in dct.items():
        if '.' in key:
            raise ValidationError('MongoDB does not allow . in field names. "%s"' % key)
        if key.startswith('$'):
            raise ValidationError('MongoDB does not allow fields starting with $. "%s"' % key)
        if isinstance(value, dict):
            check_keys(value)
def check_spec(cls, spec):
    """
    Inspect a query ``spec`` against the fields of ``cls`` and log a warning for each
    key that looks wrong. This is a best-effort aid for catching mistyped queries,
    not a complete validator.

    * Every key is reduced to its top-level name (``foo`` for ``foo.bar``); a warning
      is logged if ``cls.nanomongo`` has no such field.
    * For dotted keys, a warning is logged unless the top-level field's declared
      ``data_type`` is ``dict`` or ``list``.
    * For plain keys, a warning is logged when the query value is neither a ``dict``
      (eg. ``{'$gte': 0}``) nor exactly of the field's declared ``data_type``.

    Nothing is raised or returned; warnings go through ``logging.warning``.
    """
    for key, condition in spec.items():
        top_level = key.split('.')[0]
        if not cls.nanomongo.has_field(top_level):
            # unknown field
            logging.warning('%s has no field "%s" defined, spec %s can not match', cls, top_level, spec)
            continue
        declared = cls.nanomongo.fields[top_level].data_type
        observed = type(condition)
        if '.' in key:
            if declared not in (dict, list):
                # dotted access into a field that is neither dict nor list
                logging.warning(
                    '%s field "%s" is not of type %s, spec %s can not match', cls, top_level, (dict, list), spec
                )
        elif observed != dict and observed != declared:
            # plain value whose exact type differs from the declared one
            logging.warning('%s field "%s" has type %s, spec %s can not match', cls, top_level, declared, spec)
class RecordingDict(dict):
    """
    ``dict`` subclass whose ``__setitem__()`` and ``__delitem__()`` record every change
    in the instance's ``__nanodiff__`` attribute, a mapping with the keys ``$set``,
    ``$unset`` and ``$addToSet``.
    """

    def __init__(self, *args, **kwargs):
        """Build the dict as ``dict`` would and start with an empty ``__nanodiff__``."""
        super(RecordingDict, self).__init__(*args, **kwargs)
        self.__nanodiff__ = {'$set': {}, '$unset': {}, '$addToSet': {}}

    def __setitem__(self, key, value):
        """
        Store ``value`` under ``key`` and record it under ``$set``.

        Nothing happens when ``key`` already holds an equal value. ``dict`` values
        are wrapped in a new :class:`~RecordingDict` before being stored.
        """
        try:
            if self[key] == value:
                return  # same value, nothing to record
        except KeyError:
            pass  # never set
        if isinstance(value, dict):
            value = RecordingDict(value)
        super(RecordingDict, self).__setitem__(key, value)
        self.__nanodiff__['$set'][key] = value
        self.clear_other_modifiers('$set', key)

    def __delitem__(self, key):
        """Remove ``key`` and record the removal under ``$unset``."""
        super(RecordingDict, self).__delitem__(key)
        self.__nanodiff__['$unset'][key] = 1
        self.clear_other_modifiers('$unset', key)

    def clear_other_modifiers(self, current_mod, field_name):
        """
        Drop ``field_name`` from every modifier in ``__nanodiff__`` except
        ``current_mod``; eg. when called with ``$set``, any ``$unset`` or
        ``$addToSet`` entry for ``field_name`` is removed.
        """
        for modifier, pending in self.__nanodiff__.items():
            if modifier != current_mod and field_name in pending:
                del pending[field_name]

    def reset_diff(self):
        """
        Replace ``__nanodiff__`` with an empty one, here and recursively in every
        :class:`~RecordingDict` value. To be used after saving diffs.
        This does NOT do a rollback. Reload from db for that.
        """
        self.__nanodiff__ = {'$set': {}, '$unset': {}, '$addToSet': {}}
        for child in self.values():
            if isinstance(child, RecordingDict):
                child.reset_diff()

    def get_sub_diff(self):
        """
        Collect the diffs of values that are :class:`~RecordingDict` instances and
        return them as one diff whose keys are dotted (``<field>.<sub key>``), ready
        to be merged into the top level diff. Only direct values are inspected.
        """
        merged = {'$set': {}, '$unset': {}, '$addToSet': {}}
        for parent_key, child in self.items():
            if not isinstance(child, RecordingDict):
                continue
            for modifier in ('$set', '$unset', '$addToSet'):
                for sub_key, change in child.__nanodiff__[modifier].items():
                    merged[modifier]['%s.%s' % (parent_key, sub_key)] = change
        return merged

    def check_can_update(self, modifier, field_name):
        """
        Check that ``field_name`` may be recorded under ``modifier``.

        MongoDB does not allow field duplication with update modifiers, so a
        ``ValidationError`` is raised when ``field_name`` is already recorded under
        a different modifier. This is to be used with methods
        :meth:`~.document.BaseDocument.add_to_set()` ...
        """
        for existing_mod, pending in self.__nanodiff__.items():
            if existing_mod == modifier or field_name not in pending:
                continue
            detail = 'new: {%s} old: {%s: {%s: %s}}' % (
                modifier, existing_mod, field_name, pending[field_name])
            raise ValidationError('Field name duplication not allowed with modifiers ' + detail)
class DotNotationMixin(object):
    """
    Mixin to make dot notation available on dictionaries.

    Attribute access for names that ``valid_field()`` reports as fields is routed to
    the item protocol (``__getitem__``/``__setitem__``/``__delitem__``); every other
    name falls through to the regular attribute machinery.
    """
    # TODO: When dot_notation is active but key not a field, FAIL?

    def __setattr__(self, key, value):
        """object attribute setting eg. ``self.foo = 42``"""
        if valid_field(self, key):
            self.__setitem__(key, value)
        else:
            super(DotNotationMixin, self).__setattr__(key, value)

    def __getattr__(self, key):
        """
        object attribute lookup eg. ``print(self.foo)``

        Raises ``AttributeError`` when ``key`` is a field that has no item set.
        """
        if not valid_field(self, key):
            return super(DotNotationMixin, self).__getattribute__(key)
        try:
            return self.__getitem__(key)
        except KeyError:
            pass
        raise AttributeError("'%s' object has no attribute '%s'" %
                             (self.__class__.__name__, key))

    def __getattribute__(self, key):
        """
        Attribute lookup hook that runs before normal lookup.

        When ``__main__`` has a ``__file__`` attribute (presumably: not an interactive
        interpreter session) or ``key`` is not a field, normal lookup is used. Otherwise
        the item stored under ``key`` is returned, or ``None`` when no item is set.
        """
        # first to check interpreter
        if hasattr(__main__, '__file__') or not valid_field(self, key):
            return super(DotNotationMixin, self).__getattribute__(key)
        try:
            return self.__getitem__(key)
        except KeyError:
            # An unset field yields None here instead of raising. The original had
            # one more ``return super(...)`` after this point that could never run.
            return None

    def __delattr__(self, key):
        """
        object attribute delete eg. ``del self.foo``

        Raises ``AttributeError`` when ``key`` is a field that has no item set.
        """
        if not valid_field(self, key):
            super(DotNotationMixin, self).__delattr__(key)
            return
        try:
            self.__delitem__(key)
        except KeyError:
            raise AttributeError("'%s' object has no attribute '%s'" %
                                 (self.__class__.__name__, key))
class NanomongoSONManipulator(pymongo.son_manipulator.SONManipulator):
    """A pymongo SON Manipulator used on data that comes from the database
    to transform data to the document class we want because `as_class`
    argument to pymongo find methods is called in a way that screws us.

    - Recursively applied, we don't want that
    - `__init__` is not properly used but rather __setitem__, fails us

    JIRA: PYTHON-175 PYTHON-215
    """

    def __init__(self, as_class, transforms=None):
        """
        Remember the target class and the optional per-field transforms.

        ``as_class`` is called with each outgoing document. ``transforms``, when
        given and non-empty, must be a dict mapping field name to a callable; the
        ``transforms`` attribute is only set in that case.
        """
        self.as_class = as_class
        if transforms:
            assert isinstance(transforms, dict), 'transforms must be a dict'
            self.transforms = transforms

    def will_copy(self):
        """Always report ``True`` to pymongo."""
        return True

    def transform_outgoing(self, son, collection):
        """
        Apply the configured transforms to ``son`` in place, then wrap it in
        ``as_class``.

        A transform is skipped when its field is absent from ``son``, so documents
        that lack a transformed field no longer raise ``KeyError``. If ``as_class``
        raises ``ExtraFieldError`` the (already transformed) ``son`` is returned
        as is. ``collection`` is unused.
        """
        for field, transformer in getattr(self, 'transforms', {}).items():
            if field in son:
                son[field] = transformer(son[field])
        try:
            return self.as_class(son)
        except ExtraFieldError:
            return son