Changeset - b69785f5be6b
[Not reviewed]
0 2 6
Lance Edgar - 8 years ago 2016-05-11 17:23:36
ledgar@sacfoodcoop.com
Add initial importers for new/final importing framework

Er, hopefully final. This one gets test coverage at least, hopefully
that makes it official.
8 files changed with 1249 insertions and 1 deletions:
0 comments (0 inline, 0 general)
rattail/db/util.py
Show inline comments
 
@@ -39,6 +39,23 @@ from rattail.db.config import engine_from_config, get_engines, get_default_engin
 
log = logging.getLogger(__name__)
 

	
 

	
 
class QuerySequence(object):
 
    """
 
    Simple wrapper for a SQLAlchemy (or Django, or other?) query, to make it
 
    sort of behave like a normal sequence, as much as needed to e.g. make an
 
    importer happy.
 
    """
 

	
 
    def __init__(self, query):
 
        self.query = query
 

	
 
    def __len__(self):
 
        return self.query.count()
 

	
 
    def __iter__(self):
 
        return iter(self.query)
 

	
 

	
 
def maxlen(attr):
 
    """
 
    Return the maximum length for the given attribute.
rattail/importing/__init__.py
Show inline comments
 
new file 100644
 
# -*- coding: utf-8 -*-
 
################################################################################
 
#
 
#  Rattail -- Retail Software Framework
 
#  Copyright © 2010-2016 Lance Edgar
 
#
 
#  This file is part of Rattail.
 
#
 
#  Rattail is free software: you can redistribute it and/or modify it under the
 
#  terms of the GNU Affero General Public License as published by the Free
 
#  Software Foundation, either version 3 of the License, or (at your option)
 
#  any later version.
 
#
 
#  Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
 
#  WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 
#  FOR A PARTICULAR PURPOSE.  See the GNU Affero General Public License for
 
#  more details.
 
#
 
#  You should have received a copy of the GNU Affero General Public License
 
#  along with Rattail.  If not, see <http://www.gnu.org/licenses/>.
 
#
 
################################################################################
 
"""
 
Data Importing Framework
 
"""
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from .importers import Importer, FromQuery
 
from .sqlalchemy import FromSQLAlchemy, ToSQLAlchemy
rattail/importing/importers.py
Show inline comments
 
new file 100644
 
# -*- coding: utf-8 -*-
 
################################################################################
 
#
 
#  Rattail -- Retail Software Framework
 
#  Copyright © 2010-2016 Lance Edgar
 
#
 
#  This file is part of Rattail.
 
#
 
#  Rattail is free software: you can redistribute it and/or modify it under the
 
#  terms of the GNU Affero General Public License as published by the Free
 
#  Software Foundation, either version 3 of the License, or (at your option)
 
#  any later version.
 
#
 
#  Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
 
#  WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 
#  FOR A PARTICULAR PURPOSE.  See the GNU Affero General Public License for
 
#  more details.
 
#
 
#  You should have received a copy of the GNU Affero General Public License
 
#  along with Rattail.  If not, see <http://www.gnu.org/licenses/>.
 
#
 
################################################################################
 
"""
 
Data Importers
 
"""
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
import datetime
 
import logging
 

	
 
from rattail.db.util import QuerySequence
 
from rattail.time import make_utc
 

	
 

	
 
log = logging.getLogger(__name__)
 

	
 

	
 
class Importer(object):
 
    """
 
    Base class for all data importers.
 
    """
 
    # Set this to the data model class which is targeted on the local side.
 
    model_class = None
 

	
 
    key = None
 

	
 
    # The full list of field names supported by the importer, i.e. for the data
 
    # model to which the importer pertains.  By definition this list will be
 
    # restricted to what the local side can acommodate, but may be further
 
    # restricted by what the host side has to offer.
 
    supported_fields = []
 

	
 
    # The list of field names which may be considered "simple" and therefore
 
    # treated as such, i.e. with basic getattr/setattr calls.  Note that this
 
    # only applies to the local side, it has no effect on the host side.
 
    simple_fields = []
 

	
 
    allow_create = True
 
    allow_update = True
 
    allow_delete = True
 
    dry_run = False
 

	
 
    max_create = None
 
    max_update = None
 
    max_delete = None
 
    max_total = None
 
    progress = None
 

	
 
    caches_local_data = False
 
    cached_local_data = None
 

	
 
    host_system_title = None
 
    local_system_title = None
 

	
 
    def __init__(self, config=None, fields=None, key=None, **kwargs):
 
        self.config = config
 
        self.fields = fields or self.supported_fields
 
        if key is not None:
 
            self.key = key
 
        if isinstance(self.key, basestring):
 
            self.key = (self.key,)
 
        if self.key:
 
            for field in self.key:
 
                if field not in self.fields:
 
                    raise ValueError("Key field '{}' must be included in effective fields "
 
                                     "for {}".format(key, self.__class__.__name__))
 
        self._setup(**kwargs)
 

	
 
    def _setup(self, **kwargs):
 
        self.now = kwargs.pop('now', make_utc(datetime.datetime.utcnow(), tzinfo=True))
 
        self.create = kwargs.pop('create', self.allow_create) and self.allow_create
 
        self.update = kwargs.pop('update', self.allow_update) and self.allow_update
 
        self.delete = kwargs.pop('delete', self.allow_delete) and self.allow_delete
 
        for key, value in kwargs.iteritems():
 
            setattr(self, key, value)
 

	
 
    @property
 
    def model_name(self):
 
        if self.model_class:
 
            return self.model_class.__name__
 

	
 
    def setup(self):
 
        """
 
        Perform any setup necessary, e.g. cache lookups for existing data.
 
        """
 

	
 
    def teardown(self):
 
        """
 
        Perform any cleanup after import, if necessary.
 
        """
 

	
 
    def import_data(self, host_data=None, **kwargs):
 
        """
 
        Import some data!  This is the core body of logic for that, regardless
 
        of where data is coming from or where it's headed.  Note that this
 
        method handles deletions as well as adds/updates.
 
        """
 
        self._setup(**kwargs)
 
        self.setup()
 
        created = updated = deleted = []
 

	
 
        # Get complete set of normalized host data.
 
        if host_data is None:
 
            host_data = self.normalize_host_data()
 

	
 
        # Prune duplicate keys from host/source data.  This is for the sake of
 
        # sanity since duplicates typically lead to a ping-pong effect, where a
 
        # "clean" (change-less) import is impossible.
 
        unique = {}
 
        for data in host_data:
 
            key = self.get_key(data)
 
            if key in unique:
 
                log.warning("duplicate records detected from {} for key: {}".format(
 
                    self.host_system_title, key))
 
            unique[key] = data
 
        host_data = []
 
        for key in sorted(unique):
 
            host_data.append(unique[key])
 

	
 
        # Cache local data if appropriate.
 
        if self.caches_local_data:
 
            self.cached_local_data = self.cache_local_data(host_data)
 

	
 
        # Create and/or update data.
 
        if self.create or self.update:
 
            created, updated = self._import_create_update(host_data)
 

	
 
        # Delete data.
 
        if self.delete:
 
            changes = len(created) + len(updated)
 
            if self.max_total and changes >= self.max_total:
 
                log.warning("max of {} total changes already reached; skipping deletions".format(self.max_total))
 
            else:
 
                deleted = self._import_delete(host_data, set(unique), changes=changes)
 

	
 
        self.teardown()
 
        return created, updated, deleted
 

	
 
    def _import_create_update(self, data):
 
        """
 
        Import the given data; create and/or update records as needed.
 
        """
 
        created, updated = [], []
 
        count = len(data)
 
        if not count:
 
            return created, updated
 

	
 
        prog = None
 
        if self.progress:
 
            prog = self.progress("Importing {} data".format(self.model_name), count)
 

	
 
        for i, host_data in enumerate(data, 1):
 

	
 
            # Fetch local object, using key from host data.
 
            key = self.get_key(host_data)
 
            local_object = self.get_local_object(key)
 

	
 
            # If we have a local object, but its data differs from host, update it.
 
            if local_object and self.update:
 
                local_data = self.normalize_local_object(local_object)
 
                diffs = self.data_diffs(local_data, host_data)
 
                if diffs:
 
                    log.debug("fields '{}' differed for local data: {}, host data: {}".format(
 
                        ','.join(diffs), local_data, host_data))
 
                    local_object = self.update_object(local_object, host_data, local_data)
 
                    updated.append((local_object, local_data, host_data))
 
                    if self.max_update and len(updated) >= self.max_update:
 
                        log.warning("max of {} *updated* records has been reached; stopping now".format(self.max_update))
 
                        break
 
                    if self.max_total and (len(created) + len(updated)) >= self.max_total:
 
                        log.warning("max of {} *total changes* has been reached; stopping now".format(self.max_total))
 
                        break
 

	
 
            # If we did not yet have a local object, create it using host data.
 
            elif not local_object and self.create:
 
                local_object = self.create_object(key, host_data)
 
                log.debug("created new {} {}: {}".format(self.model_name, key, local_object))
 
                created.append((local_object, host_data))
 
                if self.caches_local_data and self.cached_local_data is not None:
 
                    self.cached_local_data[key] = {'object': local_object, 'data': self.normalize_local_object(local_object)}
 
                if self.max_create and len(created) >= self.max_create:
 
                    log.warning("max of {} *created* records has been reached; stopping now".format(self.max_create))
 
                    break
 
                if self.max_total and (len(created) + len(updated)) >= self.max_total:
 
                    log.warning("max of {} *total changes* has been reached; stopping now".format(self.max_total))
 
                    break
 

	
 
            if prog:
 
                prog.update(i)
 
        if prog:
 
            prog.destroy()
 

	
 
        return created, updated
 

	
 
    def _import_delete(self, host_data, host_keys, changes=0):
 
        """
 
        Import deletions for the given data set.
 
        """
 
        deleted = []
 
        deleting = self.get_deletion_keys() - host_keys
 
        count = len(deleting)
 
        log.debug("found {} instances to delete".format(count))
 
        if count:
 

	
 
            prog = None
 
            if self.progress:
 
                prog = self.progress("Deleting {} data".format(self.model_name), count)
 

	
 
            for i, key in enumerate(sorted(deleting), 1):
 

	
 
                cached = self.cached_local_data.pop(key)
 
                obj = cached['object']
 
                if self.delete_object(obj):
 
                    deleted.append((obj, cached['data']))
 

	
 
                    if self.max_delete and len(deleted) >= self.max_delete:
 
                        log.warning("max of {} *deleted* records has been reached; stopping now".format(self.max_delete))
 
                        break
 
                    if self.max_total and (changes + len(deleted)) >= self.max_total:
 
                        log.warning("max of {} *total changes* has been reached; stopping now".format(self.max_total))
 
                        break
 

	
 
                if prog:
 
                    prog.update(i)
 
            if prog:
 
                prog.destroy()
 

	
 
        return deleted
 

	
 
    def get_key(self, data):
 
        """
 
        Return the key value for the given data dict.
 
        """
 
        return tuple(data[k] for k in self.key)
 

	
 
    def get_host_objects(self):
 
        """
 
        Return the "raw" (as-is, not normalized) host objects which are to be
 
        imported.  This may return any sequence-like object, which has a
 
        ``len()`` value and responds to iteration etc.  The objects contained
 
        within it may be of any type, no assumptions are made there.  (That is
 
        the job of the :meth:`normalize_host_data()` method.)
 
        """
 
        return []
 

	
 
    def normalize_host_data(self, host_objects=None):
 
        """
 
        Return a normalized version of the full set of host data.  Note that
 
        this calls :meth:`get_host_objects()` to obtain the initial raw
 
        objects, and then normalizes each object.  The normalization process
 
        may filter out some records from the set, in which case the return
 
        value will be smaller than the original data set.
 
        """
 
        if host_objects is None:
 
            host_objects = self.get_host_objects()
 
        normalized = []
 
        count = len(host_objects)
 
        if count == 0:
 
            return normalized
 
        prog = None
 
        if self.progress:
 
            prog = self.progress("Reading {} data from {}".format(self.model_name, self.host_system_title), count)
 
        for i, obj in enumerate(host_objects, 1):
 
            data = self.normalize_host_object(obj)
 
            if data:
 
                normalized.append(data)
 
            if prog:
 
                prog.update(i)
 
        if prog:
 
            prog.destroy()
 
        return normalized
 

	
 
    def normalize_host_object(self, obj):
 
        """
 
        Normalize a raw host object into a data dict, or return ``None`` if the
 
        object should be ignored for the importer's purposes.
 
        """
 
        return obj
 

	
 
    def cache_local_data(self, host_data=None):
 
        """
 
        Cache all raw objects and normalized data from the local system.
 
        """
 
        raise NotImplementedError
 

	
 
    def get_cache_key(self, obj, normal):
 
        """
 
        Get the primary cache key for a given object and normalized data.
 

	
 
        Note that this method's signature is designed for use with the
 
        :func:`rattail.db.cache.cache_model()` function, and as such the
 
        ``normal`` parameter is understood to be a dict with a ``'data'`` key,
 
        value for which is the normalized data dict for the raw object.
 
        """
 
        return tuple(normal['data'].get(k) for k in self.key)
 

	
 
    def normalize_cache_object(self, obj):
 
        """
 
        Normalizer for cached local data.  This returns a simple dict with
 
        ``'object'`` and ``'data'`` keys; values are the raw object and its
 
        normalized data dict, respectively.
 
        """
 
        return {'object': obj, 'data': self.normalize_local_object(obj)}
 

	
 
    def normalize_local_object(self, obj):
 
        """
 
        Normalize a local (raw) object into a data dict.
 
        """
 
        data = {}
 
        for field in self.simple_fields:
 
            if field in self.fields:
 
                data[field] = getattr(obj, field)
 
        return data
 

	
 
    def get_local_object(self, key):
 
        """
 
        Must return the local object corresponding to the given key, or
 
        ``None``.  Default behavior here will be to check the cache if one is
 
        in effect, otherwise return the value from
 
        :meth:`get_single_local_object()`.
 
        """
 
        if self.caches_local_data and self.cached_local_data is not None:
 
            data = self.cached_local_data.get(key)
 
            return data['object'] if data else None
 
        return self.get_single_local_object(key)
 

	
 
    def get_single_local_object(self, key):
 
        """
 
        Must return the local object corresponding to the given key, or None.
 
        This method should not consult the cache; that is handled within the
 
        :meth:`get_local_object()` method.
 
        """
 
        raise NotImplementedError
 

	
 
    def data_diffs(self, local_data, host_data):
 
        """
 
        Find all (relevant) fields which differ between the host and local data
 
        values for a given record.
 
        """
 
        diffs = []
 
        for field in self.fields:
 
            if local_data[field] != host_data[field]:
 
                diffs.append(field)
 
        return diffs
 

	
 
    def make_object(self):
 
        """
 
        Make a new/empty local object from scratch.
 
        """
 
        return self.model_class()
 

	
 
    def new_object(self, key):
 
        """
 
        Return a new local object to correspond to the given key.  Note that
 
        this method should only populate the object's key, and leave the rest
 
        of the fields to :meth:`update_object()`.
 
        """
 
        obj = self.make_object()
 
        for i, k in enumerate(self.key):
 
            if hasattr(obj, k):
 
                setattr(obj, k, key[i])
 
        return obj
 

	
 
    def create_object(self, key, host_data):
 
        """
 
        Create and return a new local object for the given key, fully populated
 
        from the given host data.  This may return ``None`` if no object is
 
        created.
 
        """
 
        obj = self.new_object(key)
 
        if obj:
 
            return self.update_object(obj, host_data)
 

	
 
    def update_object(self, obj, host_data, local_data=None):
 
        """
 
        Update the local data object with the given host data, and return the
 
        object.
 
        """
 
        for field in self.simple_fields:
 
            if field in self.fields:
 
                if not local_data or local_data[field] != host_data[field]:
 
                    setattr(obj, field, host_data[field])
 
        return obj
 

	
 
    def get_deletion_keys(self):
 
        """
 
        Return a set of keys from the *local* data set, which are eligible for
 
        deletion.  By default this will be all keys from the local cached data
 
        set, or an empty set if local data isn't cached.
 
        """
 
        if self.caches_local_data and self.cached_local_data is not None:
 
            return set(self.cached_local_data)
 
        return set()
 

	
 
    def delete_object(self, obj):
 
        """
 
        Delete the given object from the local system (or not), and return a
 
        boolean indicating whether deletion was successful.  What exactly this
 
        entails may vary; default implementation does nothing at all.
 
        """
 
        return True
 

	
 

	
 
class FromQuery(Importer):
 
    """
 
    Generic base class for importers whose raw external data source is a
 
    SQLAlchemy (or Django, or possibly other?) query.
 
    """
 

	
 
    def query(self):
 
        """
 
        Subclasses must override this, and return the primary query which will
 
        define the data set.
 
        """
 
        raise NotImplementedError
 

	
 
    def get_host_objects(self, progress=None):
 
        """
 
        Returns (raw) query results as a sequence.
 
        """
 
        return QuerySequence(self.query())
rattail/importing/sqlalchemy.py
Show inline comments
 
new file 100644
 
# -*- coding: utf-8 -*-
 
################################################################################
 
#
 
#  Rattail -- Retail Software Framework
 
#  Copyright © 2010-2016 Lance Edgar
 
#
 
#  This file is part of Rattail.
 
#
 
#  Rattail is free software: you can redistribute it and/or modify it under the
 
#  terms of the GNU Affero General Public License as published by the Free
 
#  Software Foundation, either version 3 of the License, or (at your option)
 
#  any later version.
 
#
 
#  Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
 
#  WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 
#  FOR A PARTICULAR PURPOSE.  See the GNU Affero General Public License for
 
#  more details.
 
#
 
#  You should have received a copy of the GNU Affero General Public License
 
#  along with Rattail.  If not, see <http://www.gnu.org/licenses/>.
 
#
 
################################################################################
 
"""
 
Data Importers for SQLAlchemy
 
"""
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from sqlalchemy import orm
 
from sqlalchemy.orm.exc import NoResultFound
 

	
 
from rattail.db import cache
 
from rattail.importing import Importer, FromQuery
 

	
 

	
 
class FromSQLAlchemy(FromQuery):
 
    """
 
    Base class for importers whose external data source is a SQLAlchemy query.
 
    """
 
    host_session = None
 

	
 
    # For default behavior, set this to a model class to be used in generating
 
    # the host data query.
 
    host_model_class = None
 

	
 
    def query(self):
 
        """
 
        Must return the primary query which will define the data set.  Default
 
        behavior is to leverage :attr:`host_session` and generate a query for
 
        the class defined by :attr:`host_model_class`.
 
        """
 
        return self.host_session.query(self.host_model_class)
 

	
 

	
 
class ToSQLAlchemy(Importer):
 
    """
 
    Base class for all data importers which support the common use case of
 
    targeting a SQLAlchemy ORM on the local side.  This is the base class for
 
    all primary Rattail importers.
 
    """
 
    caches_local_data = True
 

	
 
    def __init__(self, model_class=None, **kwargs):
 
        if model_class:
 
            self.model_class = model_class
 
        super(ToSQLAlchemy, self).__init__(**kwargs)
 

	
 
    @property
 
    def model_mapper(self):
 
        """
 
        Reference to the effective SQLAlchemy mapper for the local model class.
 
        """
 
        return orm.class_mapper(self.model_class)
 

	
 
    @property
 
    def model_table(self):
 
        """
 
        Reference to the effective SQLAlchemy table for the local model class.
 
        """
 
        tables = self.model_mapper.tables
 
        assert len(tables) == 1
 
        return tables[0]
 

	
 
    @property
 
    def simple_fields(self):
 
        """
 
        Returns the list of column names on the underlying local model mapper.
 
        """
 
        return list(self.model_mapper.columns.keys())
 

	
 
    @property
 
    def supported_fields(self):
 
        """
 
        All/only simple fields are supported by default.
 
        """
 
        return self.simple_fields
 

	
 
    def get_single_local_object(self, key):
 
        """
 
        Try to fetch the object from the local database, using SA ORM.
 
        """
 
        query = self.session.query(self.model_class)
 
        for i, k in enumerate(self.key):
 
            query = query.filter(getattr(self.model_class, k) == key[i])
 
        try:
 
            return query.one()
 
        except NoResultFound:
 
            pass
 

	
 
    def create_object(self, key, host_data):
 
        """
 
        Create and return a new local object for the given key, fully populated
 
        from the given host data.  This may return ``None`` if no object is
 
        created.
 

	
 
        Note that this also adds the new object to the local database session,
 
        and flushes the session.
 
        """
 
        obj = super(ToSQLAlchemy, self).create_object(key, host_data)
 
        if obj:
 
            self.session.add(obj)
 
            self.session.flush()
 
            return obj
 

	
 
    def update_object(self, obj, host_data, local_data=None):
 
        """
 
        Update the local data object with the given host data, and return the
 
        object.
 
        """
 
        obj = super(ToSQLAlchemy, self).update_object(obj, host_data, local_data)
 
        if obj:
 
            self.session.flush()
 
            return obj
 

	
 
    def delete_object(self, obj):
 
        """
 
        Delete the given object from the local system (or not), and return a
 
        boolean indicating whether deletion was successful.  Default
 
        implementation will truly delete and expunge the local object via SA
 
        ORM, and flush the local session.
 
        """
 
        self.session.delete(obj)
 
        self.session.flush()
 
        self.session.expunge(obj)
 
        return True
 

	
 
    def cache_model(self, model, **kwargs):
 
        """
 
        Convenience method which invokes :func:`rattail.db.cache.cache_model()`
 
        with the given model and keyword arguments.  It will provide the
 
        ``session`` and ``progress`` parameters by default, setting them to the
 
        importer's attributes of the same names.
 
        """
 
        session = kwargs.pop('session', self.session)
 
        kwargs.setdefault('progress', self.progress)
 
        return cache.cache_model(session, model, **kwargs)
 

	
 
    def cache_local_data(self, host_data=None):
 
        """
 
        Cache all local objects and data using SA ORM.
 
        """
 
        return self.cache_model(self.model_class, key=self.get_cache_key,
 
                                # omit_duplicates=True,
 
                                query_options=self.cache_query_options(),
 
                                normalizer=self.normalize_cache_object)
 

	
 
    def cache_query_options(self):
 
        """
 
        Return a list of options to apply to the cache query, if needed.
 
        """
rattail/tests/__init__.py
Show inline comments
 
# -*- coding: utf-8 -*-
 

	
 
from __future__ import unicode_literals
 
from __future__ import unicode_literals, absolute_import
 

	
 
import os
 
import warnings
 
@@ -9,6 +9,7 @@ from unittest import TestCase
 
from sqlalchemy import create_engine
 
from sqlalchemy.exc import SAWarning
 

	
 
from rattail.config import make_config
 
from rattail.db import model
 
from rattail.db import Session
 

	
 
@@ -22,6 +23,69 @@ warnings.filterwarnings(
 
    SAWarning, r'^sqlalchemy\..*$')
 

	
 

	
 
class NullProgress(object):
 
    """
 
    Dummy progress bar which does nothing, but allows for better test coverage
 
    when used with code under test.
 
    """
 

	
 
    def __init__(self, message, count):
 
        pass
 

	
 
    def update(self, value):
 
        pass
 

	
 
    def destroy(self):
 
        pass
 

	
 

	
 
class RattailMixin(object):
 
    """
 
    Generic mixin for ``TestCase`` classes which need common Rattail setup
 
    functionality.
 
    """
 
    engine_url = os.environ.get('RATTAIL_TEST_ENGINE_URL', 'sqlite://')
 

	
 
    def setUp(self):
 
        self.setup_rattail()
 

	
 
    def tearDown(self):
 
        self.teardown_rattail()
 

	
 
    def setup_rattail(self):
 
        config = self.make_rattail_config()
 
        self.rattail_config = config
 

	
 
        engine = create_engine(self.engine_url)
 
        config.rattail_engines['default'] = engine
 
        config.rattail_engine = engine
 

	
 
        model = self.get_rattail_model()
 
        model.Base.metadata.create_all(bind=engine)
 

	
 
        Session.configure(bind=engine, rattail_config=config)
 
        self.session = Session()
 

	
 
    def teardown_rattail(self):
 
        self.session.close()
 
        Session.configure(bind=None, rattail_config=None)
 
        model = self.get_rattail_model()
 
        model.Base.metadata.drop_all(bind=self.rattail_config.rattail_engine)
 

	
 
    def make_rattail_config(self, **kwargs):
 
        kwargs.setdefault('files', [])
 
        return make_config(**kwargs)
 

	
 
    def get_rattail_model(self):
 
        return model
 

	
 

	
 
class RattailTestCase(RattailMixin, TestCase):
 
    """
 
    Generic base class for Rattail tests.
 
    """
 

	
 

	
 
class DataTestCase(TestCase):
 

	
 
    engine_url = os.environ.get('RATTAIL_TEST_ENGINE_URL', 'sqlite://')
rattail/tests/importing/__init__.py
Show inline comments
 
new file 100644
 
# -*- coding: utf-8 -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
import copy
 
from contextlib import contextmanager
 

	
 
from mock import patch
 

	
 
from rattail.tests import NullProgress
 

	
 

	
 
class ImporterTester(object):
 
    """
 
    Mixin for importer test suites.
 
    """
 
    importer_class = None
 
    sample_data = {}
 

	
 
    def setUp(self):
 
        self.setup_importer()
 

	
 
    def setup_importer(self):
 
        self.importer = self.make_importer()
 

	
 
    def make_importer(self, **kwargs):
 
        kwargs.setdefault('progress', NullProgress)
 
        return self.importer_class(**kwargs)
 

	
 
    def copy_data(self):
 
        return copy.deepcopy(self.sample_data)
 

	
 
    @contextmanager
 
    def host_data(self, data):
 
        self._host_data = data
 
        host_data = [self.importer.normalize_host_object(obj) for obj in data.itervalues()]
 
        with patch.object(self.importer, 'normalize_host_data') as normalize:
 
            normalize.return_value = host_data
 
            yield
 

	
 
    @contextmanager
 
    def local_data(self, data):
 
        self._local_data = data
 
        local_data = {}
 
        for key, obj in data.iteritems():
 
            normal = self.importer.normalize_local_object(obj)
 
            local_data[self.importer.get_key(normal)] = {'object': obj, 'data': normal}
 
        with patch.object(self.importer, 'cache_local_data') as cache:
 
            cache.return_value = local_data
 
            yield
 

	
 
    def import_data(self, **kwargs):
 
        self.result = self.importer.import_data(**kwargs)
 

	
 
    def assert_import_created(self, *keys):
 
        created, updated, deleted = self.result
 
        self.assertEqual(len(created), len(keys))
 
        for key in keys:
 
            key = self.importer.get_key(self._host_data[key])
 
            found = False
 
            for local_object, host_data in created:
 
                if self.importer.get_key(host_data) == key:
 
                    found = True
 
                    break
 
            if not found:
 
                raise self.failureException("Key {} not created when importing with {}".format(key, self.importer))
 

	
 
    def assert_import_updated(self, *keys):
 
        created, updated, deleted = self.result
 
        self.assertEqual(len(updated), len(keys))
 
        for key in keys:
 
            key = self.importer.get_key(self._host_data[key])
 
            found = False
 
            for local_object, local_data, host_data in updated:
 
                if self.importer.get_key(local_data) == key:
 
                    found = True
 
                    break
 
            if not found:
 
                raise self.failureException("Key {} not updated when importing with {}".format(key, self.importer))
 

	
 
    def assert_import_deleted(self, *keys):
 
        created, updated, deleted = self.result
 
        self.assertEqual(len(deleted), len(keys))
 
        for key in keys:
 
            key = self.importer.get_key(self._local_data[key])
 
            found = False
 
            for local_object, local_data in deleted:
 
                if self.importer.get_key(local_data) == key:
 
                    found = True
 
                    break
 
            if not found:
 
                raise self.failureException("Key {} not deleted when importing with {}".format(key, self.importer))
 

	
 
    def test_empty_host(self):
 
        with self.host_data({}):
 
            with self.local_data(self.sample_data):
 
                self.import_data(delete=False)
 
        self.assert_import_created()
 
        self.assert_import_updated()
 
        self.assert_import_deleted()
rattail/tests/importing/test_importers.py
Show inline comments
 
new file 100644
 
# -*- coding: utf-8 -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from unittest import TestCase
 

	
 
from mock import Mock, patch
 

	
 
from rattail.db import model
 
from rattail.db.util import QuerySequence
 
from rattail.importing import importers
 
from rattail.tests import NullProgress, RattailTestCase
 
from rattail.tests.importing import ImporterTester
 

	
 

	
 
class TestImporter(TestCase):
 

	
 
    def test_init(self):
 
        importer = importers.Importer()
 
        self.assertIsNone(importer.model_class)
 
        self.assertIsNone(importer.model_name)
 
        self.assertIsNone(importer.key)
 
        self.assertEqual(importer.fields, [])
 

	
 
        # key must be included among the fields
 
        self.assertRaises(ValueError, importers.Importer, key='upc', fields=[])
 
        importer = importers.Importer(key='upc', fields=['upc'])
 
        self.assertEqual(importer.key, ('upc',))
 
        self.assertEqual(importer.fields, ['upc'])
 

	
 
        # extra bits are passed as-is
 
        importer = importers.Importer()
 
        self.assertFalse(hasattr(importer, 'extra_bit'))
 
        extra_bit = object()
 
        importer = importers.Importer(extra_bit=extra_bit)
 
        self.assertIs(importer.extra_bit, extra_bit)
 

	
 
    def test_get_host_objects(self):
 
        importer = importers.Importer()
 
        objects = importer.get_host_objects()
 
        self.assertEqual(objects, [])
 

	
 
    def test_cache_local_data(self):
 
        importer = importers.Importer()
 
        self.assertRaises(NotImplementedError, importer.cache_local_data)
 

	
 
    def test_get_local_object(self):
 
        importer = importers.Importer()
 
        self.assertFalse(importer.caches_local_data)
 
        self.assertRaises(NotImplementedError, importer.get_local_object, None)
 

	
 
        someobj = object()
 
        with patch.object(importer, 'get_single_local_object', Mock(return_value=someobj)):
 
            obj = importer.get_local_object('somekey')
 
        self.assertIs(obj, someobj)
 

	
 
        importer.caches_local_data = True
 
        importer.cached_local_data = {'somekey': {'object': someobj, 'data': {}}}
 
        obj = importer.get_local_object('somekey')
 
        self.assertIs(obj, someobj)
 

	
 
    def test_get_single_local_object(self):
 
        importer = importers.Importer()
 
        self.assertRaises(NotImplementedError, importer.get_single_local_object, None)
 

	
 
    def test_get_cache_key(self):
 
        importer = importers.Importer(key='upc', fields=['upc'])
 
        obj = {'upc': '00074305001321'}
 
        normal = {'data': obj}
 
        key = importer.get_cache_key(obj, normal)
 
        self.assertEqual(key, ('00074305001321',))
 

	
 
    def test_normalize_cache_object(self):
 
        importer = importers.Importer()
 
        obj = {'upc': '00074305001321'}
 
        with patch.object(importer, 'normalize_local_object', new=lambda obj: obj):
 
            cached = importer.normalize_cache_object(obj)
 
        self.assertEqual(cached, {'object': obj, 'data': obj})
 

	
 
    def test_normalize_local_object(self):
 
        importer = importers.Importer(key='upc', fields=['upc', 'description'])
 
        importer.simple_fields = importer.fields
 
        obj = Mock(upc='00074305001321', description="Apple Cider Vinegar")
 
        data = importer.normalize_local_object(obj)
 
        self.assertEqual(data, {'upc': '00074305001321', 'description': "Apple Cider Vinegar"})
 

	
 
    def test_update_object(self):
 
        importer = importers.Importer(key='upc', fields=['upc', 'description'])
 
        importer.simple_fields = importer.fields
 
        obj = Mock(upc='00074305001321', description="Apple Cider Vinegar")
 

	
 
        newobj = importer.update_object(obj, {'upc': '00074305001321', 'description': "Apple Cider Vinegar"})
 
        self.assertIs(newobj, obj)
 
        self.assertEqual(obj.description, "Apple Cider Vinegar")
 

	
 
        newobj = importer.update_object(obj, {'upc': '00074305001321', 'description': "Apple Cider Vinegar 32oz"})
 
        self.assertIs(newobj, obj)
 
        self.assertEqual(obj.description, "Apple Cider Vinegar 32oz")
 

	
 
    def test_normalize_host_data(self):
 
        importer = importers.Importer(key='upc', fields=['upc', 'description'],
 
                                          progress=NullProgress)
 

	
 
        data = [
 
            {'upc': '00074305001161', 'description': "Apple Cider Vinegar 16oz"},
 
            {'upc': '00074305001321', 'description': "Apple Cider Vinegar 32oz"},
 
        ]
 

	
 
        host_data = importer.normalize_host_data(host_objects=[])
 
        self.assertEqual(host_data, [])
 

	
 
        host_data = importer.normalize_host_data(host_objects=data)
 
        self.assertEqual(host_data, data)
 

	
 
        with patch.object(importer, 'get_host_objects', new=Mock(return_value=data)):
 
            host_data = importer.normalize_host_data()
 
        self.assertEqual(host_data, data)
 

	
 
    def test_get_deletion_keys(self):
 
        importer = importers.Importer()
 
        self.assertFalse(importer.caches_local_data)
 
        keys = importer.get_deletion_keys()
 
        self.assertEqual(keys, set())
 

	
 
        importer.caches_local_data = True
 
        self.assertIsNone(importer.cached_local_data)
 
        keys = importer.get_deletion_keys()
 
        self.assertEqual(keys, set())
 

	
 
        importer.cached_local_data = {'delete-me': object()}
 
        keys = importer.get_deletion_keys()
 
        self.assertEqual(keys, set(['delete-me']))
 

	
 

	
 
class TestFromQuery(RattailTestCase):
 

	
 
    def test_query(self):
 
        importer = importers.FromQuery()
 
        self.assertRaises(NotImplementedError, importer.query)
 

	
 
    def test_get_host_objects(self):
 
        query = self.session.query(model.Product)
 
        importer = importers.FromQuery()
 
        with patch.object(importer, 'query', Mock(return_value=query)):
 
            objects = importer.get_host_objects()
 
        self.assertIsInstance(objects, QuerySequence)
 

	
 

	
 
######################################################################
 
# fake importer class, tested mostly for basic coverage
 
######################################################################
 

	
 
class Product(object):
 
    upc = None
 
    description = None
 

	
 

	
 
class MockImporter(importers.Importer):
 
    model_class = Product
 
    key = 'upc'
 
    simple_fields = ['upc', 'description']
 
    supported_fields = simple_fields
 
    caches_local_data = True
 

	
 
    def normalize_local_object(self, obj):
 
        return obj
 

	
 
    def update_object(self, obj, host_data, local_data=None):
 
        return host_data
 

	
 

	
 
class TestMockImporter(ImporterTester, TestCase):
 
    importer_class = MockImporter
 

	
 
    sample_data = {
 
        '16oz': {'upc': '00074305001161', 'description': "Apple Cider Vinegar 16oz"},
 
        '32oz': {'upc': '00074305001321', 'description': "Apple Cider Vinegar 32oz"},
 
        '1gal': {'upc': '00074305011283', 'description': "Apple Cider Vinegar 1gal"},
 
    }
 

	
 
    def test_create(self):
 
        local = self.copy_data()
 
        del local['32oz']
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data()
 
        self.assert_import_created('32oz')
 
        self.assert_import_updated()
 
        self.assert_import_deleted()
 

	
 
    def test_update(self):
 
        local = self.copy_data()
 
        local['16oz']['description'] = "wrong description"
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data()
 
        self.assert_import_created()
 
        self.assert_import_updated('16oz')
 
        self.assert_import_deleted()
 

	
 
    def test_delete(self):
 
        local = self.copy_data()
 
        local['bogus'] = {'upc': '00000000000000', 'description': "Delete Me"}
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data()
 
        self.assert_import_created()
 
        self.assert_import_updated()
 
        self.assert_import_deleted('bogus')
 

	
 
    def test_duplicate(self):
 
        host = self.copy_data()
 
        host['32oz-dupe'] = host['32oz']
 
        with self.host_data(host):
 
            with self.local_data(self.sample_data):
 
                self.import_data()
 
        self.assert_import_created()
 
        self.assert_import_updated()
 
        self.assert_import_deleted()
 

	
 
    def test_max_create(self):
 
        local = self.copy_data()
 
        del local['16oz']
 
        del local['1gal']
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_create=1)
 
        self.assert_import_created('16oz')
 
        self.assert_import_updated()
 
        self.assert_import_deleted()
 

	
 
    def test_max_total_create(self):
 
        local = self.copy_data()
 
        del local['16oz']
 
        del local['1gal']
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_total=1)
 
        self.assert_import_created('16oz')
 
        self.assert_import_updated()
 
        self.assert_import_deleted()
 

	
 
    def test_max_update(self):
 
        local = self.copy_data()
 
        local['16oz']['description'] = "wrong"
 
        local['1gal']['description'] = "wrong"
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_update=1)
 
        self.assert_import_created()
 
        self.assert_import_updated('16oz')
 
        self.assert_import_deleted()
 

	
 
    def test_max_total_update(self):
 
        local = self.copy_data()
 
        local['16oz']['description'] = "wrong"
 
        local['1gal']['description'] = "wrong"
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_total=1)
 
        self.assert_import_created()
 
        self.assert_import_updated('16oz')
 
        self.assert_import_deleted()
 

	
 
    def test_max_delete(self):
 
        local = self.copy_data()
 
        local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"}
 
        local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"}
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_delete=1)
 
        self.assert_import_created()
 
        self.assert_import_updated()
 
        self.assert_import_deleted('bogus1')
 

	
 
    def test_max_total_delete(self):
 
        local = self.copy_data()
 
        local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"}
 
        local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"}
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_total=1)
 
        self.assert_import_created()
 
        self.assert_import_updated()
 
        self.assert_import_deleted('bogus1')
rattail/tests/importing/test_sqlalchemy.py
Show inline comments
 
new file 100644
 
# -*- coding: utf-8 -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from unittest import TestCase
 

	
 
import sqlalchemy as sa
 
from sqlalchemy import orm
 
from sqlalchemy.orm.exc import MultipleResultsFound
 

	
 
from rattail.importing import sqlalchemy as saimport
 

	
 

	
 
class Widget(object):
 
    pass
 

	
 
metadata = sa.MetaData()
 

	
 
widget_table = sa.Table('widgets', metadata,
 
                        sa.Column('id', sa.Integer(), primary_key=True),
 
                        sa.Column('description', sa.String(length=50)))
 

	
 
widget_mapper = orm.mapper(Widget, widget_table)
 

	
 
WIDGETS = [
 
    {'id': 1, 'description': "Main Widget"},
 
    {'id': 2, 'description': "Other Widget"},
 
    {'id': 3, 'description': "Other Widget"},
 
]
 

	
 

	
 
class TestFromSQLAlchemy(TestCase):
 

	
 
    def test_query(self):
 
        Session = orm.sessionmaker(bind=sa.create_engine('sqlite://'))
 
        session = Session()
 
        importer = saimport.FromSQLAlchemy(host_session=session,
 
                                           host_model_class=Widget)
 
        self.assertEqual(unicode(importer.query()),
 
                         "SELECT widgets.id AS widgets_id, widgets.description AS widgets_description \n"
 
                         "FROM widgets")
 

	
 

	
 
class TestToSQLAlchemy(TestCase):
 

	
 
    def setUp(self):
 
        engine = sa.create_engine('sqlite://')
 
        metadata.create_all(bind=engine)
 
        Session = orm.sessionmaker(bind=engine)
 
        self.session = Session()
 
        for data in WIDGETS:
 
            widget = Widget()
 
            for key, value in data.iteritems():
 
                setattr(widget, key, value)
 
            self.session.add(widget)
 
        self.session.commit()
 

	
 
    def tearDown(self):
 
        self.session.close()
 
        del self.session
 

	
 
    def make_importer(self, **kwargs):
 
        kwargs.setdefault('model_class', Widget)
 
        return saimport.ToSQLAlchemy(**kwargs)
 

	
 
    def test_model_mapper(self):
 
        importer = self.make_importer()
 
        self.assertIs(importer.model_mapper, widget_mapper)
 

	
 
    def test_model_table(self):
 
        importer = self.make_importer()
 
        self.assertIs(importer.model_table, widget_table)
 

	
 
    def test_simple_fields(self):
 
        importer = self.make_importer()
 
        self.assertEqual(importer.simple_fields, ['id', 'description'])
 

	
 
    def test_get_single_local_object(self):
 

	
 
        # simple key
 
        importer = self.make_importer(key='id', session=self.session)
 
        widget = importer.get_single_local_object((1,))
 
        self.assertEqual(widget.id, 1)
 
        self.assertEqual(widget.description, "Main Widget")
 

	
 
        # compound key
 
        importer = self.make_importer(key=('id', 'description'), session=self.session)
 
        widget = importer.get_single_local_object((1, "Main Widget"))
 
        self.assertEqual(widget.id, 1)
 
        self.assertEqual(widget.description, "Main Widget")
 

	
 
        # widget not found
 
        importer = self.make_importer(key='id', session=self.session)
 
        self.assertIsNone(importer.get_single_local_object((42,)))
 

	
 
        # multiple widgets found
 
        importer = self.make_importer(key='description', session=self.session)
 
        self.assertRaises(MultipleResultsFound, importer.get_single_local_object, ("Other Widget",))
 

	
 
    def test_create_object(self):
 
        importer = self.make_importer(key='id', session=self.session)
 
        widget = importer.create_object((42,), {'id': 42, 'description': "Latest Widget"})
 
        self.assertFalse(self.session.new or self.session.dirty or self.session.deleted) # i.e. has been flushed
 
        self.assertIn(widget, self.session) # therefore widget has been flushed and would be committed
 
        self.assertEqual(widget.id, 42)
 
        self.assertEqual(widget.description, "Latest Widget")
 

	
 
    def test_delete_object(self):
 
        widget = self.session.query(Widget).get(1)
 
        self.assertIn(widget, self.session)
 
        importer = self.make_importer(session=self.session)
 
        self.assertTrue(importer.delete_object(widget))
 
        self.assertNotIn(widget, self.session)
 
        self.assertIsNone(self.session.query(Widget).get(1))
 

	
 
    def test_cache_model(self):
 
        importer = self.make_importer(key='id', session=self.session)
 
        cache = importer.cache_model(Widget, key='id')
 
        self.assertEqual(len(cache), 3)
 
        for i in range(1, 4):
 
            self.assertIn(i, cache)
 
            self.assertIsInstance(cache[i], Widget)
 
            self.assertEqual(cache[i].id, i)
 
            self.assertEqual(cache[i].description, WIDGETS[i-1]['description'])
 

	
 
    def test_cache_local_data(self):
 
        importer = self.make_importer(key='id', session=self.session)
 
        cache = importer.cache_local_data()
 
        self.assertEqual(len(cache), 3)
 
        for i in range(1, 4):
 
            self.assertIn((i,), cache)
 
            cached = cache[(i,)]
 
            self.assertIsInstance(cached, dict)
 
            self.assertIsInstance(cached['object'], Widget)
 
            self.assertEqual(cached['object'].id, i)
 
            self.assertEqual(cached['object'].description, WIDGETS[i-1]['description'])
 
            self.assertIsInstance(cached['data'], dict)
 
            self.assertEqual(cached['data']['id'], i)
 
            self.assertEqual(cached['data']['description'], WIDGETS[i-1]['description'])
 
        
0 comments (0 inline, 0 general)