diff --git a/rattail/db/util.py b/rattail/db/util.py
index bd1ca43adb4997cbec4ab46930033fdaa6feb394..055cdc0d298c776723be4b88c78617fe93dc8b64 100644
--- a/rattail/db/util.py
+++ b/rattail/db/util.py
@@ -39,6 +39,23 @@ from rattail.db.config import engine_from_config, get_engines, get_default_engin
log = logging.getLogger(__name__)
+class QuerySequence(object):
+ """
+ Thin wrapper which gives a SQLAlchemy (or Django, or similar) query just
+ enough of the sequence protocol to make e.g. an importer happy: ``len()``
+ delegates to ``query.count()`` and iteration delegates to the query.
+ """
+
+ def __init__(self, query):
+ self.query = query
+
+ def __len__(self):
+ return self.query.count()  # evaluated again on every len() call
+
+ def __iter__(self):
+ return iter(self.query)  # a fresh iterator over the query each time
+
+
def maxlen(attr):
"""
Return the maximum length for the given attribute.
diff --git a/rattail/importing/__init__.py b/rattail/importing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..59a333cc8cf5d243aaa4a69a2ed640e29bfbcab7
--- /dev/null
+++ b/rattail/importing/__init__.py
@@ -0,0 +1,30 @@
+# -*- coding: utf-8 -*-
+################################################################################
+#
+# Rattail -- Retail Software Framework
+# Copyright © 2010-2016 Lance Edgar
+#
+# This file is part of Rattail.
+#
+# Rattail is free software: you can redistribute it and/or modify it under the
+# terms of the GNU Affero General Public License as published by the Free
+# Software Foundation, either version 3 of the License, or (at your option)
+# any later version.
+#
+# Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for
+# more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with Rattail. If not, see <http://www.gnu.org/licenses/>.
+#
+################################################################################
+"""
+Data Importing Framework
+"""
+
+from __future__ import unicode_literals, absolute_import
+
+from .importers import Importer, FromQuery
+from .sqlalchemy import FromSQLAlchemy, ToSQLAlchemy
diff --git a/rattail/importing/importers.py b/rattail/importing/importers.py
new file mode 100644
index 0000000000000000000000000000000000000000..35bcea6e9d19e45e24ad4d331a71b6fffa9206e4
--- /dev/null
+++ b/rattail/importing/importers.py
@@ -0,0 +1,442 @@
+# -*- coding: utf-8 -*-
+################################################################################
+#
+# Rattail -- Retail Software Framework
+# Copyright © 2010-2016 Lance Edgar
+#
+# This file is part of Rattail.
+#
+# Rattail is free software: you can redistribute it and/or modify it under the
+# terms of the GNU Affero General Public License as published by the Free
+# Software Foundation, either version 3 of the License, or (at your option)
+# any later version.
+#
+# Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for
+# more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with Rattail. If not, see <http://www.gnu.org/licenses/>.
+#
+################################################################################
+"""
+Data Importers
+"""
+
+from __future__ import unicode_literals, absolute_import
+
+import datetime
+import logging
+
+from rattail.db.util import QuerySequence
+from rattail.time import make_utc
+
+
+log = logging.getLogger(__name__)
+
+
+class Importer(object):
+ """
+ Base class for all data importers.
+ """
+ # Set this to the data model class which is targeted on the local side.
+ model_class = None
+
+ key = None
+
+ # The full list of field names supported by the importer, i.e. for the data
+ # model to which the importer pertains. By definition this list will be
+ # restricted to what the local side can accommodate, but may be further
+ # restricted by what the host side has to offer.
+ supported_fields = []
+
+ # The list of field names which may be considered "simple" and therefore
+ # treated as such, i.e. with basic getattr/setattr calls. Note that this
+ # only applies to the local side, it has no effect on the host side.
+ simple_fields = []
+
+ allow_create = True
+ allow_update = True
+ allow_delete = True
+ dry_run = False
+
+ max_create = None
+ max_update = None
+ max_delete = None
+ max_total = None
+ progress = None
+
+ caches_local_data = False
+ cached_local_data = None
+
+ host_system_title = None
+ local_system_title = None
+
+    def __init__(self, config=None, fields=None, key=None, **kwargs):
+        """Set config, effective fields and key (ValueError if a key field is not in fields)."""
+        self.config = config
+        self.fields = fields or self.supported_fields
+        if key is not None:
+            self.key = key
+        if isinstance(self.key, basestring):
+            self.key = (self.key,)
+        for field in self.key or ():  # no key means there is nothing to validate
+            if field not in self.fields:
+                raise ValueError("Key field '{}' must be included in effective fields "
+                                 "for {}".format(field, self.__class__.__name__))
+        self._setup(**kwargs)
+
+ def _setup(self, **kwargs):  # also re-run by import_data() with that call's kwargs
+ self.now = kwargs.pop('now', make_utc(datetime.datetime.utcnow(), tzinfo=True))  # default is computed even when 'now' is given
+ self.create = kwargs.pop('create', self.allow_create) and self.allow_create  # a False allow_* flag always wins
+ self.update = kwargs.pop('update', self.allow_update) and self.allow_update
+ self.delete = kwargs.pop('delete', self.allow_delete) and self.allow_delete
+ for key, value in kwargs.iteritems():  # every remaining kwarg becomes an instance attribute
+ setattr(self, key, value)
+
+    @property
+    def model_name(self):
+        """Name of the local model class, or ``None`` when no class is set."""
+        return self.model_class.__name__ if self.model_class else None
+
+ def setup(self):
+ """
+ Perform any setup necessary, e.g. cache lookups for existing data.
+ """
+
+ def teardown(self):
+ """
+ Perform any cleanup after import, if necessary.
+ """
+
+    def import_data(self, host_data=None, **kwargs):
+        """
+        Import some data!  This is the core body of logic for that, regardless
+        of where data is coming from or where it's headed.  Note that this
+        method handles deletions as well as adds/updates.  ``host_data``
+        defaults to the result of :meth:`normalize_host_data()`; extra kwargs
+        go to ``_setup()``.  Returns ``(created, updated, deleted)`` lists.
+        """
+        self._setup(**kwargs)
+        self.setup()
+        created, updated, deleted = [], [], []  # three distinct lists, never aliased
+
+        # Get complete set of normalized host data.
+        if host_data is None:
+            host_data = self.normalize_host_data()
+
+        # Prune duplicate keys from host/source data.  This is for the sake of
+        # sanity since duplicates typically lead to a ping-pong effect, where a
+        # "clean" (change-less) import is impossible.  The last duplicate wins.
+        unique = {}
+        for data in host_data:
+            key = self.get_key(data)
+            if key in unique:
+                log.warning("duplicate records detected from {} for key: {}".format(
+                    self.host_system_title, key))
+            unique[key] = data
+        host_data = [unique[key] for key in sorted(unique)]
+
+        # Cache local data if appropriate.
+        if self.caches_local_data:
+            self.cached_local_data = self.cache_local_data(host_data)
+
+        # Create and/or update data.
+        if self.create or self.update:
+            created, updated = self._import_create_update(host_data)
+
+        # Delete data.
+        if self.delete:
+            changes = len(created) + len(updated)
+            if self.max_total and changes >= self.max_total:
+                log.warning("max of {} total changes already reached; skipping deletions".format(self.max_total))
+            else:
+                deleted = self._import_delete(host_data, set(unique), changes=changes)
+
+        self.teardown()
+        return created, updated, deleted
+
+    def _import_create_update(self, data):
+        """
+        Import the given data; create and/or update records as needed.
+        Returns ``(created, updated)`` lists; stops early once a max is hit.
+        """
+        created, updated = [], []
+        count = len(data)
+        if not count:
+            return created, updated
+
+        prog = None
+        if self.progress:
+            prog = self.progress("Importing {} data".format(self.model_name), count)
+
+        for i, host_data in enumerate(data, 1):
+            # Fetch local object, using key from host data.
+            key = self.get_key(host_data)
+            local_object = self.get_local_object(key)
+            # If we have a local object, but its data differs from host, update it.
+            if local_object and self.update:
+                local_data = self.normalize_local_object(local_object)
+                diffs = self.data_diffs(local_data, host_data)
+                if diffs:
+                    log.debug("fields '{}' differed for local data: {}, host data: {}".format(
+                        ','.join(diffs), local_data, host_data))
+                    local_object = self.update_object(local_object, host_data, local_data)
+                    updated.append((local_object, local_data, host_data))
+                    if self.max_update and len(updated) >= self.max_update:
+                        log.warning("max of {} *updated* records has been reached; stopping now".format(self.max_update))
+                        break
+                    if self.max_total and (len(created) + len(updated)) >= self.max_total:
+                        log.warning("max of {} *total changes* has been reached; stopping now".format(self.max_total))
+                        break
+
+            # If we did not yet have a local object, create it using host data.
+            elif not local_object and self.create:
+                local_object = self.create_object(key, host_data)
+                if local_object is not None:  # create_object() may decline to create anything
+                    log.debug("created new {} {}: {}".format(self.model_name, key, local_object))
+                    created.append((local_object, host_data))
+                    if self.caches_local_data and self.cached_local_data is not None:
+                        self.cached_local_data[key] = {'object': local_object, 'data': self.normalize_local_object(local_object)}
+                    if self.max_create and len(created) >= self.max_create:
+                        log.warning("max of {} *created* records has been reached; stopping now".format(self.max_create))
+                        break
+                    if self.max_total and (len(created) + len(updated)) >= self.max_total:
+                        log.warning("max of {} *total changes* has been reached; stopping now".format(self.max_total))
+                        break
+
+            if prog:
+                prog.update(i)
+        if prog:
+            prog.destroy()
+
+        return created, updated
+
+ def _import_delete(self, host_data, host_keys, changes=0):  # host_data is not used in this body
+ """
+ Delete cached local records whose keys are absent from ``host_keys``.
+ """
+ deleted = []
+ deleting = self.get_deletion_keys() - host_keys  # eligible local keys which the host lacks
+ count = len(deleting)
+ log.debug("found {} instances to delete".format(count))
+ if count:
+
+ prog = None
+ if self.progress:
+ prog = self.progress("Deleting {} data".format(self.model_name), count)
+
+ for i, key in enumerate(sorted(deleting), 1):
+
+ cached = self.cached_local_data.pop(key)  # assumes every deletion key is in the cache (KeyError if not)
+ obj = cached['object']
+ if self.delete_object(obj):  # only recorded when deletion reports success
+ deleted.append((obj, cached['data']))
+
+ if self.max_delete and len(deleted) >= self.max_delete:
+ log.warning("max of {} *deleted* records has been reached; stopping now".format(self.max_delete))
+ break
+ if self.max_total and (changes + len(deleted)) >= self.max_total:  # changes = creates + updates already made
+ log.warning("max of {} *total changes* has been reached; stopping now".format(self.max_total))
+ break
+
+ if prog:
+ prog.update(i)
+ if prog:
+ prog.destroy()
+
+ return deleted
+
+    def get_key(self, data):
+        """
+        Build the key tuple for a data dict, in the order of ``self.key``.
+        """
+        return tuple([data[name] for name in self.key])
+
+ def get_host_objects(self):
+ """
+ Return the "raw" (as-is, not normalized) host objects which are to be
+ imported. This may return any sequence-like object, which has a
+ ``len()`` value and responds to iteration etc. The objects contained
+ within it may be of any type, no assumptions are made there. (That is
+ the job of the :meth:`normalize_host_data()` method.)
+ """
+ return []
+
+    def normalize_host_data(self, host_objects=None):
+        """
+        Produce the normalized data set for the host side.  Unless raw objects
+        are supplied, they are obtained from :meth:`get_host_objects()`.  Each
+        one is passed through :meth:`normalize_host_object()`, and any falsy
+        result is left out, so the list returned may be shorter than the raw
+        sequence was.
+        """
+        if host_objects is None:
+            host_objects = self.get_host_objects()
+        results = []
+        total = len(host_objects)
+        if not total:
+            return results
+        progress = self.progress(
+            "Reading {} data from {}".format(self.model_name, self.host_system_title),
+            total) if self.progress else None
+        for index, raw in enumerate(host_objects, 1):
+            normal = self.normalize_host_object(raw)
+            if normal:
+                results.append(normal)
+            if progress:
+                progress.update(index)
+        if progress:
+            progress.destroy()
+        return results
+
+ def normalize_host_object(self, obj):
+ """
+ Normalize a raw host object into a data dict, or return ``None`` if the
+ object should be ignored for the importer's purposes.
+ """
+ return obj
+
+ def cache_local_data(self, host_data=None):
+ """
+ Cache all raw objects and normalized data from the local system.
+ """
+ raise NotImplementedError
+
+    def get_cache_key(self, obj, normal):
+        """
+        Compute the primary cache key for an object and its normalized form.
+
+        The signature is designed for use with the
+        :func:`rattail.db.cache.cache_model()` function: ``normal`` is a dict
+        whose ``'data'`` value is the normalized data dict for the raw object.
+        """
+        data = normal['data']
+        return tuple([data.get(field) for field in self.key])
+
+ def normalize_cache_object(self, obj):
+ """
+ Normalizer for cached local data. This returns a simple dict with
+ ``'object'`` and ``'data'`` keys; values are the raw object and its
+ normalized data dict, respectively.
+ """
+ return {'object': obj, 'data': self.normalize_local_object(obj)}
+
+    def normalize_local_object(self, obj):
+        """
+        Build the data dict for a local (raw) object from its simple fields.
+
+        Only simple fields which are also effective fields are included.
+        """
+        return dict((field, getattr(obj, field))
+                    for field in self.simple_fields
+                    if field in self.fields)
+
+    def get_local_object(self, key):
+        """
+        Return the local object corresponding to the given key, or ``None``.
+        When local data caching is in effect (and the cache is populated), only
+        the cache is consulted; otherwise the lookup is delegated to
+        :meth:`get_single_local_object()`.
+        """
+        if not self.caches_local_data or self.cached_local_data is None:
+            return self.get_single_local_object(key)
+        entry = self.cached_local_data.get(key)
+        return entry['object'] if entry else None
+
+ def get_single_local_object(self, key):
+ """
+ Must return the local object corresponding to the given key, or None.
+ This method should not consult the cache; that is handled within the
+ :meth:`get_local_object()` method.
+ """
+ raise NotImplementedError
+
+    def data_diffs(self, local_data, host_data):
+        """
+        List the effective fields whose values differ between the local and
+        host data dicts for one record, in ``self.fields`` order.
+
+        Both dicts are indexed by every effective field.
+        """
+        changed = [name for name in self.fields
+                   if local_data[name] != host_data[name]]
+        return changed
+
+ def make_object(self):
+ """
+ Make a new/empty local object from scratch.
+ """
+ return self.model_class()
+
+ def new_object(self, key):
+ """
+ Return a new local object to correspond to the given key. Note that
+ this method should only populate the object's key, and leave the rest
+ of the fields to :meth:`update_object()`.
+ """
+ obj = self.make_object()
+ for i, k in enumerate(self.key):
+ if hasattr(obj, k):
+ setattr(obj, k, key[i])
+ return obj
+
+ def create_object(self, key, host_data):
+ """
+ Create and return a new local object for the given key, fully populated
+ from the given host data. This may return ``None`` if no object is
+ created.
+ """
+ obj = self.new_object(key)
+ if obj:
+ return self.update_object(obj, host_data)
+
+    def update_object(self, obj, host_data, local_data=None):
+        """
+        Copy simple-field values from the host data onto the local object, and
+        return the object.  When ``local_data`` is given (and non-empty), only
+        the fields whose host value differs from the local value are set.
+        """
+        for name in self.simple_fields:
+            if name in self.fields and (not local_data or local_data[name] != host_data[name]):
+                setattr(obj, name, host_data[name])
+        return obj
+
+ def get_deletion_keys(self):
+ """
+ Return a set of keys from the *local* data set, which are eligible for
+ deletion. By default this will be all keys from the local cached data
+ set, or an empty set if local data isn't cached.
+ """
+ if self.caches_local_data and self.cached_local_data is not None:
+ return set(self.cached_local_data)
+ return set()
+
+ def delete_object(self, obj):
+ """
+ Delete the given object from the local system (or not), and return a
+ boolean indicating whether deletion was successful. What exactly this
+ entails may vary; default implementation does nothing at all.
+ """
+ return True
+
+
+class FromQuery(Importer):
+ """
+ Generic base class for importers whose raw external data source is a
+ SQLAlchemy (or Django, or possibly other?) query.
+ """
+
+ def query(self):
+ """
+ Subclasses must override this, and return the primary query which will
+ define the data set.
+ """
+ raise NotImplementedError
+
+ def get_host_objects(self, progress=None):
+ """
+ Returns (raw) query results as a sequence.
+ """
+ return QuerySequence(self.query())
diff --git a/rattail/importing/sqlalchemy.py b/rattail/importing/sqlalchemy.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc67d98d56a83ac96088bfc22343aaed6d6417ae
--- /dev/null
+++ b/rattail/importing/sqlalchemy.py
@@ -0,0 +1,170 @@
+# -*- coding: utf-8 -*-
+################################################################################
+#
+# Rattail -- Retail Software Framework
+# Copyright © 2010-2016 Lance Edgar
+#
+# This file is part of Rattail.
+#
+# Rattail is free software: you can redistribute it and/or modify it under the
+# terms of the GNU Affero General Public License as published by the Free
+# Software Foundation, either version 3 of the License, or (at your option)
+# any later version.
+#
+# Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
+# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
+# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for
+# more details.
+#
+# You should have received a copy of the GNU Affero General Public License
+# along with Rattail. If not, see <http://www.gnu.org/licenses/>.
+#
+################################################################################
+"""
+Data Importers for SQLAlchemy
+"""
+
+from __future__ import unicode_literals, absolute_import
+
+from sqlalchemy import orm
+from sqlalchemy.orm.exc import NoResultFound
+
+from rattail.db import cache
+from rattail.importing import Importer, FromQuery
+
+
+class FromSQLAlchemy(FromQuery):
+ """
+ Base class for importers whose external data source is a SQLAlchemy query.
+ """
+ host_session = None
+
+ # For default behavior, set this to a model class to be used in generating
+ # the host data query.
+ host_model_class = None
+
+ def query(self):
+ """
+ Must return the primary query which will define the data set. Default
+ behavior is to leverage :attr:`host_session` and generate a query for
+ the class defined by :attr:`host_model_class`.
+ """
+ return self.host_session.query(self.host_model_class)
+
+
+class ToSQLAlchemy(Importer):
+ """
+ Base class for all data importers which support the common use case of
+ targeting a SQLAlchemy ORM on the local side. This is the base class for
+ all primary Rattail importers.
+ """
+ caches_local_data = True
+
+ def __init__(self, model_class=None, **kwargs):
+ if model_class:
+ self.model_class = model_class
+ super(ToSQLAlchemy, self).__init__(**kwargs)
+
+ @property
+ def model_mapper(self):
+ """
+ Reference to the effective SQLAlchemy mapper for the local model class.
+ """
+ return orm.class_mapper(self.model_class)
+
+ @property
+ def model_table(self):
+ """
+ Reference to the effective SQLAlchemy table for the local model class.
+ """
+ tables = self.model_mapper.tables
+ assert len(tables) == 1
+ return tables[0]
+
+ @property
+ def simple_fields(self):
+ """
+ Returns the list of column names on the underlying local model mapper.
+ """
+ return list(self.model_mapper.columns.keys())
+
+ @property
+ def supported_fields(self):
+ """
+ All/only simple fields are supported by default.
+ """
+ return self.simple_fields
+
+ def get_single_local_object(self, key):
+ """
+ Try to fetch the object from the local database, using SA ORM.
+ """
+ query = self.session.query(self.model_class)
+ for i, k in enumerate(self.key):
+ query = query.filter(getattr(self.model_class, k) == key[i])
+ try:
+ return query.one()
+ except NoResultFound:
+ pass
+
+ def create_object(self, key, host_data):
+ """
+ Create and return a new local object for the given key, fully populated
+ from the given host data. This may return ``None`` if no object is
+ created.
+
+ Note that this also adds the new object to the local database session,
+ and flushes the session.
+ """
+ obj = super(ToSQLAlchemy, self).create_object(key, host_data)
+ if obj:
+ self.session.add(obj)
+ self.session.flush()
+ return obj
+
+ def update_object(self, obj, host_data, local_data=None):
+ """
+ Update the local data object with the given host data, and return the
+ object.
+ """
+ obj = super(ToSQLAlchemy, self).update_object(obj, host_data, local_data)
+ if obj:
+ self.session.flush()
+ return obj
+
+ def delete_object(self, obj):
+ """
+ Delete the given object from the local system (or not), and return a
+ boolean indicating whether deletion was successful. Default
+ implementation will truly delete and expunge the local object via SA
+ ORM, and flush the local session.
+ """
+ self.session.delete(obj)
+ self.session.flush()
+ self.session.expunge(obj)
+ return True
+
+ def cache_model(self, model, **kwargs):
+ """
+ Convenience method which invokes :func:`rattail.db.cache.cache_model()`
+ with the given model and keyword arguments. It will provide the
+ ``session`` and ``progress`` parameters by default, setting them to the
+ importer's attributes of the same names.
+ """
+ session = kwargs.pop('session', self.session)
+ kwargs.setdefault('progress', self.progress)
+ return cache.cache_model(session, model, **kwargs)
+
+ def cache_local_data(self, host_data=None):
+ """
+ Cache all local objects and data using SA ORM.
+ """
+ return self.cache_model(self.model_class, key=self.get_cache_key,
+ # omit_duplicates=True,
+ query_options=self.cache_query_options(),
+ normalizer=self.normalize_cache_object)
+
+ def cache_query_options(self):
+ """
+ Return a list of options to apply to the cache query, if needed.
+ """
diff --git a/rattail/tests/__init__.py b/rattail/tests/__init__.py
index 63c476c592c7021cf36948613466389bd55ed13c..56c2d453b1b99e2a76a7fb2906567db01bacf02f 100644
--- a/rattail/tests/__init__.py
+++ b/rattail/tests/__init__.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
-from __future__ import unicode_literals
+from __future__ import unicode_literals, absolute_import
import os
import warnings
@@ -9,6 +9,7 @@ from unittest import TestCase
from sqlalchemy import create_engine
from sqlalchemy.exc import SAWarning
+from rattail.config import make_config
from rattail.db import model
from rattail.db import Session
@@ -22,6 +23,69 @@ warnings.filterwarnings(
SAWarning, r'^sqlalchemy\..*$')
+class NullProgress(object):
+ """
+ Dummy progress bar which does nothing, but allows for better test coverage
+ when used with code under test.
+ """
+
+ def __init__(self, message, count):
+ pass
+
+ def update(self, value):
+ pass
+
+ def destroy(self):
+ pass
+
+
+class RattailMixin(object):
+ """
+ Generic mixin for ``TestCase`` classes which need common Rattail setup
+ functionality.
+ """
+ engine_url = os.environ.get('RATTAIL_TEST_ENGINE_URL', 'sqlite://')
+
+ def setUp(self):
+ self.setup_rattail()
+
+ def tearDown(self):
+ self.teardown_rattail()
+
+ def setup_rattail(self):
+ config = self.make_rattail_config()
+ self.rattail_config = config
+
+ engine = create_engine(self.engine_url)
+ config.rattail_engines['default'] = engine
+ config.rattail_engine = engine
+
+ model = self.get_rattail_model()
+ model.Base.metadata.create_all(bind=engine)
+
+ Session.configure(bind=engine, rattail_config=config)
+ self.session = Session()
+
+ def teardown_rattail(self):
+ self.session.close()
+ Session.configure(bind=None, rattail_config=None)
+ model = self.get_rattail_model()
+ model.Base.metadata.drop_all(bind=self.rattail_config.rattail_engine)
+
+ def make_rattail_config(self, **kwargs):
+ kwargs.setdefault('files', [])
+ return make_config(**kwargs)
+
+ def get_rattail_model(self):
+ return model
+
+
+class RattailTestCase(RattailMixin, TestCase):
+ """
+ Generic base class for Rattail tests.
+ """
+
+
class DataTestCase(TestCase):
engine_url = os.environ.get('RATTAIL_TEST_ENGINE_URL', 'sqlite://')
diff --git a/rattail/tests/importing/__init__.py b/rattail/tests/importing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..388d0dfaea9d2520190010b7f109768925551547
--- /dev/null
+++ b/rattail/tests/importing/__init__.py
@@ -0,0 +1,100 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import unicode_literals, absolute_import
+
+import copy
+from contextlib import contextmanager
+
+from mock import patch
+
+from rattail.tests import NullProgress
+
+
+class ImporterTester(object):
+ """
+ Mixin for importer test suites.
+ """
+ importer_class = None
+ sample_data = {}
+
+ def setUp(self):
+ self.setup_importer()
+
+ def setup_importer(self):
+ self.importer = self.make_importer()
+
+ def make_importer(self, **kwargs):
+ kwargs.setdefault('progress', NullProgress)
+ return self.importer_class(**kwargs)
+
+ def copy_data(self):
+ return copy.deepcopy(self.sample_data)
+
+ @contextmanager
+ def host_data(self, data):
+ self._host_data = data
+ host_data = [self.importer.normalize_host_object(obj) for obj in data.itervalues()]
+ with patch.object(self.importer, 'normalize_host_data') as normalize:
+ normalize.return_value = host_data
+ yield
+
+ @contextmanager
+ def local_data(self, data):
+ self._local_data = data
+ local_data = {}
+ for key, obj in data.iteritems():
+ normal = self.importer.normalize_local_object(obj)
+ local_data[self.importer.get_key(normal)] = {'object': obj, 'data': normal}
+ with patch.object(self.importer, 'cache_local_data') as cache:
+ cache.return_value = local_data
+ yield
+
+ def import_data(self, **kwargs):
+ self.result = self.importer.import_data(**kwargs)
+
+ def assert_import_created(self, *keys):
+ created, updated, deleted = self.result
+ self.assertEqual(len(created), len(keys))
+ for key in keys:
+ key = self.importer.get_key(self._host_data[key])
+ found = False
+ for local_object, host_data in created:
+ if self.importer.get_key(host_data) == key:
+ found = True
+ break
+ if not found:
+ raise self.failureException("Key {} not created when importing with {}".format(key, self.importer))
+
+ def assert_import_updated(self, *keys):
+ created, updated, deleted = self.result
+ self.assertEqual(len(updated), len(keys))
+ for key in keys:
+ key = self.importer.get_key(self._host_data[key])
+ found = False
+ for local_object, local_data, host_data in updated:
+ if self.importer.get_key(local_data) == key:
+ found = True
+ break
+ if not found:
+ raise self.failureException("Key {} not updated when importing with {}".format(key, self.importer))
+
+ def assert_import_deleted(self, *keys):
+ created, updated, deleted = self.result
+ self.assertEqual(len(deleted), len(keys))
+ for key in keys:
+ key = self.importer.get_key(self._local_data[key])
+ found = False
+ for local_object, local_data in deleted:
+ if self.importer.get_key(local_data) == key:
+ found = True
+ break
+ if not found:
+ raise self.failureException("Key {} not deleted when importing with {}".format(key, self.importer))
+
+ def test_empty_host(self):
+ with self.host_data({}):
+ with self.local_data(self.sample_data):
+ self.import_data(delete=False)
+ self.assert_import_created()
+ self.assert_import_updated()
+ self.assert_import_deleted()
diff --git a/rattail/tests/importing/test_importers.py b/rattail/tests/importing/test_importers.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a4638b685d817a865e798b50b2b2ad727c414fd
--- /dev/null
+++ b/rattail/tests/importing/test_importers.py
@@ -0,0 +1,285 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import unicode_literals, absolute_import
+
+from unittest import TestCase
+
+from mock import Mock, patch
+
+from rattail.db import model
+from rattail.db.util import QuerySequence
+from rattail.importing import importers
+from rattail.tests import NullProgress, RattailTestCase
+from rattail.tests.importing import ImporterTester
+
+
+class TestImporter(TestCase):  # tests aimed directly at the base importers.Importer class
+
+ def test_init(self):  # constructor defaults, key/fields validation, pass-through of extra kwargs
+ importer = importers.Importer()
+ self.assertIsNone(importer.model_class)
+ self.assertIsNone(importer.model_name)
+ self.assertIsNone(importer.key)
+ self.assertEqual(importer.fields, [])
+
+ # key must be included among the fields, else ValueError; a string key is expected back as a 1-tuple
+ self.assertRaises(ValueError, importers.Importer, key='upc', fields=[])
+ importer = importers.Importer(key='upc', fields=['upc'])
+ self.assertEqual(importer.key, ('upc',))
+ self.assertEqual(importer.fields, ['upc'])
+
+ # unrecognized keyword arguments are expected to become attributes, passed as-is
+ importer = importers.Importer()
+ self.assertFalse(hasattr(importer, 'extra_bit'))
+ extra_bit = object()
+ importer = importers.Importer(extra_bit=extra_bit)
+ self.assertIs(importer.extra_bit, extra_bit)
+
+ def test_get_host_objects(self):  # the base class is expected to offer no host objects
+ importer = importers.Importer()
+ objects = importer.get_host_objects()
+ self.assertEqual(objects, [])
+
+ def test_cache_local_data(self):  # expected to raise NotImplementedError on the base class
+ importer = importers.Importer()
+ self.assertRaises(NotImplementedError, importer.cache_local_data)
+
+ def test_get_local_object(self):  # lookup by key, first without and then with local caching
+ importer = importers.Importer()
+ self.assertFalse(importer.caches_local_data)
+ self.assertRaises(NotImplementedError, importer.get_local_object, None)
+
+ someobj = object()
+ with patch.object(importer, 'get_single_local_object', Mock(return_value=someobj)):  # stand-in for the unimplemented single lookup
+ obj = importer.get_local_object('somekey')
+ self.assertIs(obj, someobj)
+
+ importer.caches_local_data = True  # presumably switches the lookup over to cached_local_data
+ importer.cached_local_data = {'somekey': {'object': someobj, 'data': {}}}
+ obj = importer.get_local_object('somekey')
+ self.assertIs(obj, someobj)
+
+ def test_get_single_local_object(self):  # expected to raise NotImplementedError on the base class
+ importer = importers.Importer()
+ self.assertRaises(NotImplementedError, importer.get_single_local_object, None)
+
+ def test_get_cache_key(self):  # cache key expected to be a tuple of the key field values
+ importer = importers.Importer(key='upc', fields=['upc'])
+ obj = {'upc': '00074305001321'}
+ normal = {'data': obj}  # same shape as the dict checked in test_normalize_cache_object
+ key = importer.get_cache_key(obj, normal)
+ self.assertEqual(key, ('00074305001321',))
+
+ def test_normalize_cache_object(self):  # cache entry expected to pair the object with its normalized data
+ importer = importers.Importer()
+ obj = {'upc': '00074305001321'}
+ with patch.object(importer, 'normalize_local_object', new=lambda obj: obj):  # identity normalizer, so 'data' equals obj
+ cached = importer.normalize_cache_object(obj)
+ self.assertEqual(cached, {'object': obj, 'data': obj})
+
+ def test_normalize_local_object(self):  # expects a dict of the simple fields read off the object
+ importer = importers.Importer(key='upc', fields=['upc', 'description'])
+ importer.simple_fields = importer.fields  # treat every field as a simple field
+ obj = Mock(upc='00074305001321', description="Apple Cider Vinegar")
+ data = importer.normalize_local_object(obj)
+ self.assertEqual(data, {'upc': '00074305001321', 'description': "Apple Cider Vinegar"})
+
+ def test_update_object(self):  # update_object() is expected to return the very object it was given
+ importer = importers.Importer(key='upc', fields=['upc', 'description'])
+ importer.simple_fields = importer.fields  # treat every field as a simple field
+ obj = Mock(upc='00074305001321', description="Apple Cider Vinegar")
+
+ newobj = importer.update_object(obj, {'upc': '00074305001321', 'description': "Apple Cider Vinegar"})  # data identical to the object
+ self.assertIs(newobj, obj)
+ self.assertEqual(obj.description, "Apple Cider Vinegar")
+
+ newobj = importer.update_object(obj, {'upc': '00074305001321', 'description': "Apple Cider Vinegar 32oz"})  # changed description
+ self.assertIs(newobj, obj)
+ self.assertEqual(obj.description, "Apple Cider Vinegar 32oz")
+
+ def test_normalize_host_data(self):  # host data given explicitly, then via get_host_objects()
+ importer = importers.Importer(key='upc', fields=['upc', 'description'],
+ progress=NullProgress)
+
+ data = [
+ {'upc': '00074305001161', 'description': "Apple Cider Vinegar 16oz"},
+ {'upc': '00074305001321', 'description': "Apple Cider Vinegar 32oz"},
+ ]
+
+ host_data = importer.normalize_host_data(host_objects=[])  # empty in, empty out
+ self.assertEqual(host_data, [])
+
+ host_data = importer.normalize_host_data(host_objects=data)  # records expected back unchanged
+ self.assertEqual(host_data, data)
+
+ with patch.object(importer, 'get_host_objects', new=Mock(return_value=data)):  # no argument given: presumably falls back to get_host_objects()
+ host_data = importer.normalize_host_data()
+ self.assertEqual(host_data, data)
+
+ def test_get_deletion_keys(self):  # deletion keys expected to come from the cached local data
+ importer = importers.Importer()
+ self.assertFalse(importer.caches_local_data)
+ keys = importer.get_deletion_keys()
+ self.assertEqual(keys, set())
+
+ importer.caches_local_data = True  # caching enabled, but no cache has been built yet
+ self.assertIsNone(importer.cached_local_data)
+ keys = importer.get_deletion_keys()
+ self.assertEqual(keys, set())
+
+ importer.cached_local_data = {'delete-me': object()}  # a single cached key, expected back as a deletion key
+ keys = importer.get_deletion_keys()
+ self.assertEqual(keys, set(['delete-me']))
+
+
+class TestFromQuery(RattailTestCase):  # tests for importers.FromQuery; uses self.session from the base test case
+
+ def test_query(self):  # query() is expected to raise NotImplementedError until overridden
+ importer = importers.FromQuery()
+ self.assertRaises(NotImplementedError, importer.query)
+
+ def test_get_host_objects(self):  # host objects are expected to arrive wrapped in a QuerySequence
+ query = self.session.query(model.Product)
+ importer = importers.FromQuery()
+ with patch.object(importer, 'query', Mock(return_value=query)):  # supply a real query in place of the unimplemented one
+ objects = importer.get_host_objects()
+ self.assertIsInstance(objects, QuerySequence)
+
+
+######################################################################
+# fake importer class, tested mostly for basic coverage
+######################################################################
+
+class Product(object):
+    """Bare stand-in model class; it carries only the two fields MockImporter lists."""
+    upc = description = None
+
+
+class MockImporter(importers.Importer):  # fake importer used by TestMockImporter; its records are plain dicts
+ model_class = Product
+ key = 'upc'
+ simple_fields = ['upc', 'description']
+ supported_fields = simple_fields  # the very same list object as simple_fields
+ caches_local_data = True
+
+ def normalize_local_object(self, obj):  # a local "object" is returned as its own normalized data
+ return obj
+
+ def update_object(self, obj, host_data, local_data=None):  # no mutation: the host data is handed back as the updated object
+ return host_data
+
+
+class TestMockImporter(ImporterTester, TestCase):  # create/update/delete scenarios run against MockImporter
+ importer_class = MockImporter
+
+ sample_data = {  # named records; each test starts from these on the host and/or local side
+ '16oz': {'upc': '00074305001161', 'description': "Apple Cider Vinegar 16oz"},
+ '32oz': {'upc': '00074305001321', 'description': "Apple Cider Vinegar 32oz"},
+ '1gal': {'upc': '00074305011283', 'description': "Apple Cider Vinegar 1gal"},
+ }
+
+ def test_create(self):  # one record missing locally -> exactly that record is created
+ local = self.copy_data()  # presumably a private copy of sample_data (helper defined outside this chunk)
+ del local['32oz']
+ with self.host_data(self.sample_data):
+ with self.local_data(local):
+ self.import_data()
+ self.assert_import_created('32oz')
+ self.assert_import_updated()
+ self.assert_import_deleted()
+
+ def test_update(self):  # one local description differs -> exactly that record is updated
+ local = self.copy_data()
+ local['16oz']['description'] = "wrong description"
+ with self.host_data(self.sample_data):
+ with self.local_data(local):
+ self.import_data()
+ self.assert_import_created()
+ self.assert_import_updated('16oz')
+ self.assert_import_deleted()
+
+ def test_delete(self):  # extra local record unknown to the host -> it is deleted
+ local = self.copy_data()
+ local['bogus'] = {'upc': '00000000000000', 'description': "Delete Me"}
+ with self.host_data(self.sample_data):
+ with self.local_data(local):
+ self.import_data()
+ self.assert_import_created()
+ self.assert_import_updated()
+ self.assert_import_deleted('bogus')
+
+ def test_duplicate(self):  # the same host record listed twice must not cause any change
+ host = self.copy_data()
+ host['32oz-dupe'] = host['32oz']  # second name bound to the very same record dict
+ with self.host_data(host):
+ with self.local_data(self.sample_data):
+ self.import_data()
+ self.assert_import_created()
+ self.assert_import_updated()
+ self.assert_import_deleted()
+
+ def test_max_create(self):  # two records missing locally, but max_create=1 allows only one creation
+ local = self.copy_data()
+ del local['16oz']
+ del local['1gal']
+ with self.host_data(self.sample_data):
+ with self.local_data(local):
+ self.import_data(max_create=1)
+ self.assert_import_created('16oz')  # NOTE(review): assumes '16oz' is processed before '1gal' - confirm the ordering is guaranteed
+ self.assert_import_updated()
+ self.assert_import_deleted()
+
+ def test_max_total_create(self):  # same setup as test_max_create, limited through max_total instead
+ local = self.copy_data()
+ del local['16oz']
+ del local['1gal']
+ with self.host_data(self.sample_data):
+ with self.local_data(local):
+ self.import_data(max_total=1)
+ self.assert_import_created('16oz')  # NOTE(review): same ordering assumption as in test_max_create
+ self.assert_import_updated()
+ self.assert_import_deleted()
+
+ def test_max_update(self):  # two records differ locally, but max_update=1 allows only one update
+ local = self.copy_data()
+ local['16oz']['description'] = "wrong"
+ local['1gal']['description'] = "wrong"
+ with self.host_data(self.sample_data):
+ with self.local_data(local):
+ self.import_data(max_update=1)
+ self.assert_import_created()
+ self.assert_import_updated('16oz')  # NOTE(review): assumes '16oz' is processed before '1gal' - confirm
+ self.assert_import_deleted()
+
+ def test_max_total_update(self):  # same setup as test_max_update, limited through max_total instead
+ local = self.copy_data()
+ local['16oz']['description'] = "wrong"
+ local['1gal']['description'] = "wrong"
+ with self.host_data(self.sample_data):
+ with self.local_data(local):
+ self.import_data(max_total=1)
+ self.assert_import_created()
+ self.assert_import_updated('16oz')  # NOTE(review): same ordering assumption as in test_max_update
+ self.assert_import_deleted()
+
+ def test_max_delete(self):  # two extra local records, but max_delete=1 allows only one deletion
+ local = self.copy_data()
+ local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"}
+ local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"}
+ with self.host_data(self.sample_data):
+ with self.local_data(local):
+ self.import_data(max_delete=1)
+ self.assert_import_created()
+ self.assert_import_updated()
+ self.assert_import_deleted('bogus1')  # NOTE(review): assumes 'bogus1' is deleted before 'bogus2' - confirm
+
+ def test_max_total_delete(self):  # same setup as test_max_delete, limited through max_total instead
+ local = self.copy_data()
+ local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"}
+ local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"}
+ with self.host_data(self.sample_data):
+ with self.local_data(local):
+ self.import_data(max_total=1)
+ self.assert_import_created()
+ self.assert_import_updated()
+ self.assert_import_deleted('bogus1')  # NOTE(review): same ordering assumption as in test_max_delete
diff --git a/rattail/tests/importing/test_sqlalchemy.py b/rattail/tests/importing/test_sqlalchemy.py
new file mode 100644
index 0000000000000000000000000000000000000000..435343f1b0f2d39823c0f40b03becf480ce6210f
--- /dev/null
+++ b/rattail/tests/importing/test_sqlalchemy.py
@@ -0,0 +1,140 @@
+# -*- coding: utf-8 -*-
+
+from __future__ import unicode_literals, absolute_import
+
+from unittest import TestCase
+
+import sqlalchemy as sa
+from sqlalchemy import orm
+from sqlalchemy.orm.exc import MultipleResultsFound
+
+from rattail.importing import sqlalchemy as saimport
+
+
+class Widget(object):
+    """Plain attribute-less class; this module maps it onto ``widget_table`` with ``orm.mapper``."""
+
+metadata = sa.MetaData()  # metadata for the test-only table below; TestToSQLAlchemy.setUp calls create_all() on it
+
+widget_table = sa.Table('widgets', metadata,
+ sa.Column('id', sa.Integer(), primary_key=True),
+ sa.Column('description', sa.String(length=50)))
+
+widget_mapper = orm.mapper(Widget, widget_table)  # classical mapping of the plain Widget class onto the table
+
+WIDGETS = [  # fixture rows inserted by TestToSQLAlchemy.setUp
+ {'id': 1, 'description': "Main Widget"},
+ {'id': 2, 'description': "Other Widget"},
+ {'id': 3, 'description': "Other Widget"},  # shares its description with id 2 (see the MultipleResultsFound check)
+]
+
+
+class TestFromSQLAlchemy(TestCase):  # tests for saimport.FromSQLAlchemy
+
+ def test_query(self):  # query() is expected to select every column of the host model class
+ Session = orm.sessionmaker(bind=sa.create_engine('sqlite://'))  # sqlite engine; no tables are created for this test
+ session = Session()
+ importer = saimport.FromSQLAlchemy(host_session=session,
+ host_model_class=Widget)
+ self.assertEqual(unicode(importer.query()),  # compares the rendered SQL text (unicode() is a Python 2 builtin)
+ "SELECT widgets.id AS widgets_id, widgets.description AS widgets_description \n"
+ "FROM widgets")
+
+
+class TestToSQLAlchemy(TestCase):  # tests for saimport.ToSQLAlchemy against a sqlite database holding WIDGETS
+
+ def setUp(self):  # fresh engine and session per test, pre-loaded with the WIDGETS rows
+ engine = sa.create_engine('sqlite://')
+ metadata.create_all(bind=engine)  # creates the widgets table
+ Session = orm.sessionmaker(bind=engine)
+ self.session = Session()
+ for data in WIDGETS:
+ widget = Widget()
+ for key, value in data.iteritems():  # iteritems() is Python 2 only
+ setattr(widget, key, value)
+ self.session.add(widget)
+ self.session.commit()
+
+ def tearDown(self):  # close and drop the per-test session
+ self.session.close()
+ del self.session
+
+ def make_importer(self, **kwargs):  # build a ToSQLAlchemy importer, defaulting model_class to Widget
+ kwargs.setdefault('model_class', Widget)
+ return saimport.ToSQLAlchemy(**kwargs)
+
+ def test_model_mapper(self):  # model_mapper is expected to be the mapper created for Widget
+ importer = self.make_importer()
+ self.assertIs(importer.model_mapper, widget_mapper)
+
+ def test_model_table(self):  # model_table is expected to be the table behind that mapper
+ importer = self.make_importer()
+ self.assertIs(importer.model_table, widget_table)
+
+ def test_simple_fields(self):  # simple_fields is expected to list the table's column names in order
+ importer = self.make_importer()
+ self.assertEqual(importer.simple_fields, ['id', 'description'])
+
+ def test_get_single_local_object(self):  # lookups by key tuple: hit, compound key, miss, ambiguous
+
+ # simple key
+ importer = self.make_importer(key='id', session=self.session)
+ widget = importer.get_single_local_object((1,))
+ self.assertEqual(widget.id, 1)
+ self.assertEqual(widget.description, "Main Widget")
+
+ # compound key
+ importer = self.make_importer(key=('id', 'description'), session=self.session)
+ widget = importer.get_single_local_object((1, "Main Widget"))
+ self.assertEqual(widget.id, 1)
+ self.assertEqual(widget.description, "Main Widget")
+
+ # widget not found
+ importer = self.make_importer(key='id', session=self.session)
+ self.assertIsNone(importer.get_single_local_object((42,)))
+
+ # multiple widgets found
+ importer = self.make_importer(key='description', session=self.session)
+ self.assertRaises(MultipleResultsFound, importer.get_single_local_object, ("Other Widget",))
+
+ def test_create_object(self):  # a created object is expected to be flushed into the session
+ importer = self.make_importer(key='id', session=self.session)
+ widget = importer.create_object((42,), {'id': 42, 'description': "Latest Widget"})
+ self.assertFalse(self.session.new or self.session.dirty or self.session.deleted) # i.e. has been flushed
+ self.assertIn(widget, self.session) # therefore widget has been flushed and would be committed
+ self.assertEqual(widget.id, 42)
+ self.assertEqual(widget.description, "Latest Widget")
+
+ def test_delete_object(self):  # delete_object() is expected to return a true value and remove the row
+ widget = self.session.query(Widget).get(1)
+ self.assertIn(widget, self.session)
+ importer = self.make_importer(session=self.session)
+ self.assertTrue(importer.delete_object(widget))
+ self.assertNotIn(widget, self.session)
+ self.assertIsNone(self.session.query(Widget).get(1))  # a fresh lookup no longer finds it
+
+ def test_cache_model(self):  # cache_model() is expected to map each bare key value to its Widget
+ importer = self.make_importer(key='id', session=self.session)
+ cache = importer.cache_model(Widget, key='id')
+ self.assertEqual(len(cache), 3)
+ for i in range(1, 4):  # ids 1..3 as defined in WIDGETS
+ self.assertIn(i, cache)
+ self.assertIsInstance(cache[i], Widget)
+ self.assertEqual(cache[i].id, i)
+ self.assertEqual(cache[i].description, WIDGETS[i-1]['description'])
+
+ def test_cache_local_data(self):  # cache_local_data() is expected to map key tuples to {'object', 'data'} dicts
+ importer = self.make_importer(key='id', session=self.session)
+ cache = importer.cache_local_data()
+ self.assertEqual(len(cache), 3)
+ for i in range(1, 4):  # ids 1..3 as defined in WIDGETS
+ self.assertIn((i,), cache)  # keys are tuples here, unlike the bare values in test_cache_model
+ cached = cache[(i,)]
+ self.assertIsInstance(cached, dict)
+ self.assertIsInstance(cached['object'], Widget)
+ self.assertEqual(cached['object'].id, i)
+ self.assertEqual(cached['object'].description, WIDGETS[i-1]['description'])
+ self.assertIsInstance(cached['data'], dict)
+ self.assertEqual(cached['data']['id'], i)
+ self.assertEqual(cached['data']['description'], WIDGETS[i-1]['description'])
+