Source code for nanomongo.document

import importlib
import weakref

import six

from bson import ObjectId, DBRef

from .errors import ConfigurationError, DBRefNotSetError, ExtraFieldError, UnsupportedOperation, ValidationError
from .field import Field
from .util import (
    RecordingDict, DotNotationMixin, valid_client, NanomongoSONManipulator,
    check_spec,
)


def ref_getter_maker(field_name, document_class=None):
    """Create a dereference method for ``field_name`` to be bound to
    Document instances.

    :param field_name: name of the field that holds the ``DBRef``
    :param document_class: optional name of the referenced document class,
        either ``'ClassName'`` (looked up in the module of the calling
        document's class) or dotted ``'package.module.ClassName'``. When
        omitted, the class is guessed from the registered direct subclasses
        of ``BaseDocument``.
    :returns: a function taking the document instance as its only argument
        and returning ``cls.find_one(dbref.id)``
    """
    def ref_getter(self):
        """Fetch the document referenced by this document's DBRef field.

        :raises DBRefNotSetError: if the field is missing or falsy
        :raises UnsupportedOperation: if ``document_class`` was not given
            and the class can not be determined unambiguously
        """
        if field_name not in self or not self[field_name]:
            raise DBRefNotSetError('"%s" field is not set' % field_name)
        dbref = self[field_name]
        if document_class is not None:
            splat = document_class.split('.')
            class_name = splat.pop()
            # no module part given: look in the module of the caller's class
            module = '.'.join(splat) if splat else self.__class__.__module__
            cls = getattr(importlib.import_module(module), class_name)
        else:
            # ``dbref.database`` is a database *name*; when it is unset fall
            # back to the *name* of our own database (not the Database
            # object) so that the comparison below can actually match
            database_name = dbref.database or self.nanomongo.database.name
            collection = dbref.collection
            classes = [
                klass for klass in BaseDocument.__subclasses__()
                if (klass.nanomongo.registered and
                    klass.nanomongo.database.name == database_name and
                    klass.nanomongo.collection == collection)
            ]
            if 1 != len(classes):
                err_str = ('can not guess document class for "%s", found: "%s". '
                           'Please provide document_class kwarg to "%s" Field')
                raise UnsupportedOperation(err_str % (dbref, classes, field_name))
            cls = classes[0]
        # we don't use dereference since BaseDocument.find_one handles type casting nicely
        return cls.find_one(dbref.id)
    return ref_getter


class BasesTuple(tuple):
    """Marker ``tuple`` subclass for an already processed bases tuple.

    ``DocumentMeta._get_bases`` returns its input unchanged when it is
    already an instance of this class.
    """


class Nanomongo(object):
    """Contains information about the Document it's attached to like
    its fields (which contain validators), db and collection etc and
    provides methods to ease checks
    """
    def __init__(self, fields=None):
        """
        :param fields: ``dict`` mapping field names to field objects;
            required despite the ``None`` default
        :raises TypeError: if ``fields`` is not a ``dict``
        """
        super(Nanomongo, self).__init__()
        if not isinstance(fields, dict):
            raise TypeError('fields kwarg expected of type dict')
        self.fields = fields
        self.classref = None  # weakref to the owning class, assigned by DocumentMeta
        self.registered = False
        self.client, self.database, self.collection = None, None, None
        self.transforms = {}  # save auto_update fields so we don't keep looping
        for field_name, field in self.fields.items():
            if hasattr(field, 'auto_update'):
                self.transforms[field_name] = field.auto_update

    @classmethod
    def from_dicts(cls, *args):
        """Create from one or more dicts, keeping only the items whose value
        is a :class:`~nanomongo.field.Field`. Later dicts override earlier
        ones on name clashes.

        :raises TypeError: if no argument is given or an argument is not a
            ``dict``
        """
        if not args:
            raise TypeError('from_dicts takes at least 1 positional argument')
        fields = {}
        for dct in args:
            if not isinstance(dct, dict):
                raise TypeError('expected input of dictionaries')
            for field_name, field_value in dct.items():
                if isinstance(field_value, Field):
                    fields[field_name] = field_value
        return cls(fields=fields)

    def has_field(self, key):
        """Check existence of field"""
        return key in self.fields

    def list_fields(self):
        """Return a sorted list of strings denoting fields"""
        return sorted(self.fields.keys())

    def validate(self, field_name, value):
        """Validate field input by running the field's validator.

        :raises KeyError: if ``field_name`` is not a defined field
        """
        return self.fields[field_name].validator(value, field_name=field_name)

    def set_client(self, client):
        """Set client, a Client from pymongo or motor expected

        :raises TypeError: if ``valid_client`` rejects ``client``
        """
        if not valid_client(client):
            raise TypeError('pymongo or motor Client expected')
        self.client = client

    def set_db(self, db_string):
        """Set database, string expected. The client must be set first.

        :raises TypeError: if ``db_string`` is empty or not a string
        :raises ConfigurationError: if the client is not set
        """
        if not db_string or not isinstance(db_string, six.string_types):
            raise TypeError('Expected database string')
        if not self.client:
            raise ConfigurationError('Mongo client not set')
        self.database = self.client[db_string]

    def set_collection(self, col_string):
        """Set collection, string expected

        :raises TypeError: if ``col_string`` is empty or not a string
        """
        if not col_string or not isinstance(col_string, six.string_types):
            raise TypeError('Expected collection string')
        self.collection = col_string

    def check_config(self):
        """Check if client, database and collection attributes are set

        :raises ConfigurationError: naming the first missing setting
        """
        if not self.client:
            raise ConfigurationError('Mongo client not set')
        # compare with None: the database is an object (not a string) and
        # newer PyMongo Database objects refuse truth value testing
        if self.database is None:
            raise ConfigurationError('database not set')
        if not self.collection:
            raise ConfigurationError('collection not set')

    def add_son_manipulator(self):
        """add our son manipulator to transform documents coming from mongodb
        to the class we defined, see
        :class:`~nanomongo.util.NanomongoSONManipulator`
        """
        if six.PY2:  # unicode -> str implicit transform for binary_type (str) Fields
            def str_transformer(unicode_str):
                return unicode_str.encode('utf-8')

            transforms = dict(
                (fname, str_transformer) for fname, field in self.fields.items()
                if six.binary_type == field.data_type
            )
        else:
            transforms = None

        manipulator = NanomongoSONManipulator(self.classref(), transforms=transforms)
        self.database.add_son_manipulator(manipulator)

    def register(self, client=None, db_string=None, collection=None):
        """register the class. this is called from defined documents'
        :meth:`~BaseDocument.register()` method. Note that this also
        runs :meth:`~pymongo.collection.Collection.create_indexes()`

        Arguments that are not given leave the current setting untouched.

        :raises ConfigurationError: if client, database or collection is
            still missing after applying the arguments
        """
        if client:
            self.set_client(client)
        if db_string:
            self.set_db(db_string)
        if collection:
            self.set_collection(collection)
        self.check_config()
        self.add_son_manipulator()
        # indexes
        doc_class = self.classref()
        indexes = getattr(doc_class, '__indexes__', [])
        if indexes:
            self.get_collection().create_indexes(indexes)
        # mark as registered
        self.registered = True

    def get_collection(self):
        """Returns collection

        :raises ConfigurationError: if the configuration is incomplete
        """
        self.check_config()
        return self.database[self.collection]


class DocumentMeta(type):
    """Metaclass for documents: collects the allowed fields (with their
    validators) into a ``Nanomongo`` instance attached to the class.
    """

    def __new__(cls, name, bases, dct, **kwargs):
        """Reject illegal attributes (eg. ``nanomongo``), flatten the bases
        so their :class:`~nanomongo.field.Field` definitions can be reached
        and prepend ``DotNotationMixin`` when ``dot_notation`` is requested
        (class keyword or class attribute; the attribute wins).
        """
        if 'nanomongo' in dct:
            raise TypeError('field name "nanomongo" is not allowed')
        if not isinstance(dct.get('__indexes__', []), list):
            raise TypeError('__indexes__: list of Index instances expected')
        use_dot_notation = kwargs.pop('dot_notation', None)
        if 'dot_notation' in dct:
            use_dot_notation = dct.pop('dot_notation')
        new_bases = cls._get_bases(bases)
        if use_dot_notation and DotNotationMixin not in new_bases:
            new_bases = (DotNotationMixin,) + new_bases
        return super(DocumentMeta, cls).__new__(cls, name, new_bases, dct)

    def __init__(cls, name, bases, dct, **kwargs):
        """Build the `~nanomongo.document.Nanomongo` for this class, remove
        the :class:`~nanomongo.field.Field` attributes from it, apply the
        client, db, collection settings if provided and register the class
        when the configuration is complete."""
        super(DocumentMeta, cls).__init__(name, bases, dct)
        # fields of an inherited nanomongo come first so ``dct`` overrides them
        field_sources = [cls.nanomongo.fields] if hasattr(cls, 'nanomongo') else []
        field_sources.append(dct)
        cls.nanomongo = Nanomongo.from_dicts(*field_sources)
        if not cls.nanomongo.has_field('_id'):
            cls.nanomongo.fields['_id'] = Field(ObjectId, required=False)
        field_attrs = [attr for attr, value in dct.items() if isinstance(value, Field)]
        for attr in field_attrs:
            delattr(cls, attr)
        cls.nanomongo.classref = weakref.ref(cls)

        def _take(option):
            """Return ``(found, value)`` for ``option``; class keyword
            arguments win over class attributes, and a class attribute
            that gets used is removed from the class"""
            if option in kwargs:
                return True, kwargs[option]
            if hasattr(cls, option):
                value = getattr(cls, option)
                delattr(cls, option)
                return True, value
            return False, None

        # client, database, collection
        found, client = _take('client')
        if found:
            cls.nanomongo.set_client(client)
        found, db = _take('db')
        if found:
            cls.nanomongo.set_db(db)
        found, collection = _take('collection')
        # the collection defaults to the lower cased class name
        cls.nanomongo.set_collection(collection if found else name.lower())
        # register if nanomongo config is OK
        try:
            cls.nanomongo.check_config()
            cls.nanomongo.register()
        except ConfigurationError:
            pass

    @classmethod
    def _get_bases(cls, bases):
        """Return the flattened bases, duplicates removed (first occurrence
        kept), as a ``BasesTuple``; a ``BasesTuple`` input is returned as is.
        """
        # taken from MongoEngine
        if isinstance(bases, BasesTuple):
            return bases
        unique_bases = []
        for base in cls.__get_bases(bases):
            if base not in unique_bases:
                unique_bases.append(base)
        return BasesTuple(unique_bases)

    @classmethod
    def __get_bases(cls, bases):
        """Yield every base followed by its own bases, recursively and
        depth first, skipping ``object``."""
        for base in bases:
            if base is not object:
                yield base
                for ancestor in cls.__get_bases(base.__bases__):
                    yield ancestor


@six.add_metaclass(DocumentMeta)
class BaseDocument(RecordingDict):
    """BaseDocument class. Subclasses should be used. See
    :meth:`~BaseDocument.__init__()`
    """

    def __init__(self, *args, **kwargs):
        """Inits the document with given data and validates the fields
        (field validation bad idea during init?). If you define
        ``__init__`` method for your document class, make sure to call
        this
        ::

            class MyDoc(BaseDocument, dot_notation=True):
                foo = Field(str)
                bar = Field(int, required=False)

                def __init__(self, *args, **kwargs):
                    super(MyDoc, self).__init__(*args, **kwargs)
                    # do other stuff

        :raises TypeError: if a positional argument is given that is not a
            ``dict``
        :raises ExtraFieldError: if data for an undefined field is given
        """
        # if input dict, merge (not updating) into kwargs
        if args and not isinstance(args[0], dict):
            raise TypeError('dict or dict subclass argument expected')
        elif args:
            for field_name, field_value in args[0].items():
                if field_name not in kwargs:
                    kwargs[field_name] = field_value
        super(BaseDocument, self).__init__()
        for field_name, field in self.nanomongo.fields.items():
            if hasattr(field, 'default_value'):
                val = field.default_value
                dict.__setitem__(self, field_name, val() if callable(val) else val)
            # attach get_<field_name>_field methods for DBRef fields
            if field.data_type in [DBRef] + DBRef.__subclasses__():
                getter_name = 'get_%s_field' % field_name
                doc_class = field.document_class if hasattr(field, 'document_class') else None
                getter = ref_getter_maker(field_name, document_class=doc_class)
                setattr(self, getter_name, six.create_bound_method(getter, self))
        for field_name in kwargs:
            if self.nanomongo.has_field(field_name):
                self.nanomongo.validate(field_name, kwargs[field_name])
                dict.__setitem__(self, field_name, kwargs[field_name])
            else:
                raise ExtraFieldError('Undefined field %s=%s in %s' %
                                      (field_name, kwargs[field_name], self.__class__))
        for field_name, field_value in self.items():
            # transform dict to RecordingDict so we can track diff in embedded docs
            if isinstance(field_value, dict):
                dict.__setitem__(self, field_name, RecordingDict(field_value))

    @classmethod
    def register(cls, client=None, db=None, collection=None):
        """Register this document. Sets client, database, collection
        information, creates indexes and sets SON manipulator

        :raises ConfigurationError: if the class is already registered
        """
        if cls.nanomongo.registered:
            err_str = ('%s is already registered. This is automatic if you '
                       'have defined your document class with client, db, '
                       'collection.') % cls
            raise ConfigurationError(err_str)
        cls.nanomongo.register(client=client, db_string=db, collection=collection)

    @classmethod
    def get_collection(cls):
        """Returns collection as set in :attr:`~cls.nanomongo`"""
        return cls.nanomongo.get_collection()

    @classmethod
    def find(cls, *args, **kwargs):
        """``pymongo.Collection().find`` wrapper for this document"""
        if args and isinstance(args[0], dict):
            check_spec(cls, args[0])
        return cls.get_collection().find(*args, **kwargs)

    @classmethod
    def find_one(cls, *args, **kwargs):
        """``pymongo.Collection().find_one`` wrapper for this document"""
        if args and isinstance(args[0], dict):
            check_spec(cls, args[0])
        return cls.get_collection().find_one(*args, **kwargs)

    def __dir__(self):
        """Add defined Fields to dir"""
        return sorted(dir(super(BaseDocument, self)) + self.nanomongo.list_fields())

    def validate(self):
        """
        **Override** this to add extra document validation. It will be
        called during :meth:`~insert` and :meth:`~save` before the database
        operation.
        """
        pass

    def validate_all(self):
        """
        Check correctness of the document before :meth:`~insert()`. Ensure that

        * no extra (undefined) fields are present
        * field values are of correct data type
        * required fields are present

        :raises ValidationError: on an undefined or a missing required field
        """
        for field_name, field_value in self.items():
            if not self.nanomongo.has_field(field_name):
                raise ValidationError(
                    'Extra undefined field "{}" with value "{}"'.format(field_name, field_value))
            field = self.nanomongo.fields[field_name]
            field.validator(field_value, field_name=field_name)  # run field validator
        for field_name, field in self.nanomongo.fields.items():
            if field.required and field_name not in self:
                raise ValidationError('Required field "{}" is missing'.format(field_name))

    def validate_diff(self):
        """
        Check correctness of diffs (ie. ``$set`` and ``$unset``) before
        :meth:`~save()`. Ensure that

        * no extra (undefined) fields are present for either set or unset
        * field values are of correct data type
        * required fields are not unset

        :raises ValidationError: if any of the checks above fails
        """
        sets = self.__nanodiff__['$set']
        unsets = self.__nanodiff__['$unset']
        for field_name, field_value in sets.items():
            if not self.nanomongo.has_field(field_name):
                raise ValidationError(
                    'Extra undefined field "{}" with value "{}"'.format(field_name, field_value))
            field = self.nanomongo.fields[field_name]
            field.validator(field_value, field_name=field_name)  # run field validator
        for field_name, field_value in unsets.items():
            if self.nanomongo.has_field(field_name):
                if self.nanomongo.fields[field_name].required:
                    raise ValidationError('Can not unset required field "{}"'.format(field_name))
            else:
                raise ValidationError('Can not unset undefined field "{}"'.format(field_name))

    def run_auto_updates(self):
        """Runs auto_update functions in ``.nanomongo.transforms``."""
        # TODO: This would override any preceding $set on the field
        for field_name, updater in self.nanomongo.transforms.items():
            self[field_name] = updater()

    def insert(self, **kwargs):
        """
        Runs auto updates, validates the document, and inserts into database.
        Returns ``pymongo.results.InsertOneResult``.
        """
        self.run_auto_updates()
        self.validate_all()
        self.validate()
        insert_one_result = self.get_collection().insert_one(self, **kwargs)
        self.reset_diff()
        for field_name, field_value in self.items():
            # cast plain dicts so changes in embedded docs can be tracked;
            # the value has to be stored back, rebinding the loop variable
            # alone would leave the document untouched
            if isinstance(field_value, dict) and not isinstance(field_value, RecordingDict):
                dict.__setitem__(self, field_name, RecordingDict(field_value))
        return insert_one_result

    def save(self, **kwargs):
        """
        Runs auto updates, validates the document, and saves the changes
        into database. Returns ``pymongo.results.UpdateResult``, or ``None``
        when there is nothing to update.

        :raises ValidationError: if the document has no ``_id`` yet or the
            ``_id`` has been set manually
        """
        if '_id' not in self:
            raise ValidationError('insert first; save does partial updates')
        if '_id' in self.__nanodiff__['$set']:
            raise ValidationError('_id seems to be manually set, do insert')
        self.run_auto_updates()
        self.validate_diff()
        self.validate()
        assert 3 == len(self.__nanodiff__), '__nanodiff__: %s' % self.__nanodiff__
        query = {'_id': self['_id']}
        diff = self.__nanodiff__
        # get subdiff containing dotted keys, merge into diff
        subdiff = self.get_sub_diff()
        for operator, value in subdiff.items():
            diff[operator].update(value)
        # remove empty update ops, MongoDB 2.6 returns error for them
        for operator in list(diff.keys()):
            if not diff[operator]:
                diff.pop(operator)
        if not diff:
            self.reset_diff()
            return
        update_result = self.get_collection().update_one(query, diff, **kwargs)
        self.reset_diff()
        return update_result

    def add_to_set(self, field, value):
        """
        Explicitly defined ``$addToSet`` functionality. This sets/updates
        the field value accordingly and records the change to be saved with
        :meth:`~save()`.
        ::

            # MongoDB style dot notation can be used to add to lists
            # in embedded documents
            doc = Doc(foo=[], bar={})
            doc.add_to_set('foo', new_value)

        Contrary to how ``$set`` ing the same value has no effect under
        __setitem__ (see ``.util.RecordingDict.__setitem__()``) when the new
        value is equal to the current, this explicitly records the change so
        it will be sent to the database when :meth:`~save()` is called.

        :raises ValidationError: for ``$`` field names, undefined fields,
            non-list targets and non-dict parents of dotted keys
        :raises UnsupportedOperation: for keys more than one level deep
        """
        # TODO: doc.add_to_set('bar.sub_field', new_value) doesn't actually work
        def top_level_add(self, field, value):
            """add the value to field. appending if the list exists and
            does not contain the value; create new list otherwise. raise
            :class:`.errors.ValidationError` if non-list value initiated
            """
            self.check_can_update('$addToSet', field)
            if field in self and isinstance(self[field], list):
                if value not in self[field]:
                    self[field].append(value)
            elif field not in self or self[field] is None:
                dict.__setitem__(self, field, [value])  # to avoid $set record
            else:
                err_str = 'Could not $addToSet on valid field, bad init? %s: %s'
                raise ValidationError(err_str % (field, self[field]))
            if field not in self.__nanodiff__['$addToSet']:
                self.__nanodiff__['$addToSet'][field] = {'$each': [value]}
            elif value not in self.__nanodiff__['$addToSet'][field]['$each']:
                self.__nanodiff__['$addToSet'][field]['$each'].append(value)

        if field.startswith('$') or '.$' in field:
            err_str = 'MongoDB does not allow fields starting with $. "%s"'
            raise ValidationError(err_str % field)
        # if top-level
        if '.' not in field:
            if ((self.nanomongo.has_field(field) and
                 list == self.nanomongo.fields[field].data_type)):
                top_level_add(self, field, value)  # add & record
            elif self.nanomongo.has_field(field):
                err_str = 'Cannot apply $addToSet modifier to non-array: %s=%s'
                err_str = err_str % (field, self.nanomongo.fields[field].data_type)
                raise ValidationError(err_str)
            else:
                raise ValidationError('Undefined field: "%s"' % field)
        # if deep-level
        else:
            try:
                top_key, deep_key = field.split('.')
            except ValueError:
                err_str = ('Only top level and one level deep keys supported for '
                           '$addToSet: "%s"')
                # interpolate the key into the message instead of passing it
                # as a second exception argument
                raise UnsupportedOperation(err_str % field)
            if not self.nanomongo.has_field(top_key):
                raise ValidationError('Undefined field: "%s"' % top_key)
            elif dict != self.nanomongo.fields[top_key].data_type:
                raise ValidationError('"%s" is not a dict' % top_key)
            # field name ok, ensure top level value is RecordingDict type
            if top_key not in self:  # not set yet, do it
                dict.__setitem__(self, top_key, RecordingDict())
            elif not isinstance(self[top_key], RecordingDict):
                # what did you do, use dict.__setitem__ ? :)
                err_str = ("Dotted key's target is not a RecordingDict: %s=%s "
                           "If you've just set it as a new dict; FYI: you can't "
                           "$set and $addToSet together")
                raise ValidationError(err_str % (top_key, self[top_key]))
            # make sure we have no $set or $unset on top_key
            self.check_can_update('$addToSet', top_key)
            top_level_add(self[top_key], deep_key, value)  # add & record

    def get_dbref(self):
        """Return a ``bson.DBRef`` instance for this :class:`~BaseDocument`
        instance. The document must have a truthy ``_id``.
        """
        assert '_id' in self and self['_id'], 'Cannot get DBRef for document with no _id'
        collection = self.get_collection()
        return DBRef(collection.name, self['_id'], database=collection.database.name)