class QuerySequence(object):
    """
    Thin sequence-style adapter around a SQLAlchemy (or Django, or other?)
    query.  It exposes just enough of the sequence protocol -- ``len()`` and
    iteration -- for consumers such as importers to treat the query like a
    normal sequence.
    """

    def __init__(self, query):
        # Hold a reference only; the query is not evaluated until the caller
        # asks for a length or iterates.
        self.query = query

    def __len__(self):
        # Delegate to the query's own count() instead of materializing rows.
        return self.query.count()

    def __iter__(self):
        # Iterating the wrapper simply iterates the wrapped query.
        return iter(self.query)
+# +################################################################################ +""" +Data Importing Framework +""" + +from __future__ import unicode_literals, absolute_import + +from .importers import Importer, FromQuery +from .sqlalchemy import FromSQLAlchemy, ToSQLAlchemy diff --git a/rattail/importing/importers.py b/rattail/importing/importers.py new file mode 100644 index 0000000000000000000000000000000000000000..35bcea6e9d19e45e24ad4d331a71b6fffa9206e4 --- /dev/null +++ b/rattail/importing/importers.py @@ -0,0 +1,442 @@ +# -*- coding: utf-8 -*- +################################################################################ +# +# Rattail -- Retail Software Framework +# Copyright © 2010-2016 Lance Edgar +# +# This file is part of Rattail. +# +# Rattail is free software: you can redistribute it and/or modify it under the +# terms of the GNU Affero General Public License as published by the Free +# Software Foundation, either version 3 of the License, or (at your option) +# any later version. +# +# Rattail is distributed in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for +# more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with Rattail. If not, see . +# +################################################################################ +""" +Data Importers +""" + +from __future__ import unicode_literals, absolute_import + +import datetime +import logging + +from rattail.db.util import QuerySequence +from rattail.time import make_utc + + +log = logging.getLogger(__name__) + + +class Importer(object): + """ + Base class for all data importers. + """ + # Set this to the data model class which is targeted on the local side. + model_class = None + + key = None + + # The full list of field names supported by the importer, i.e. 
for the data + # model to which the importer pertains. By definition this list will be + # restricted to what the local side can acommodate, but may be further + # restricted by what the host side has to offer. + supported_fields = [] + + # The list of field names which may be considered "simple" and therefore + # treated as such, i.e. with basic getattr/setattr calls. Note that this + # only applies to the local side, it has no effect on the host side. + simple_fields = [] + + allow_create = True + allow_update = True + allow_delete = True + dry_run = False + + max_create = None + max_update = None + max_delete = None + max_total = None + progress = None + + caches_local_data = False + cached_local_data = None + + host_system_title = None + local_system_title = None + + def __init__(self, config=None, fields=None, key=None, **kwargs): + self.config = config + self.fields = fields or self.supported_fields + if key is not None: + self.key = key + if isinstance(self.key, basestring): + self.key = (self.key,) + if self.key: + for field in self.key: + if field not in self.fields: + raise ValueError("Key field '{}' must be included in effective fields " + "for {}".format(key, self.__class__.__name__)) + self._setup(**kwargs) + + def _setup(self, **kwargs): + self.now = kwargs.pop('now', make_utc(datetime.datetime.utcnow(), tzinfo=True)) + self.create = kwargs.pop('create', self.allow_create) and self.allow_create + self.update = kwargs.pop('update', self.allow_update) and self.allow_update + self.delete = kwargs.pop('delete', self.allow_delete) and self.allow_delete + for key, value in kwargs.iteritems(): + setattr(self, key, value) + + @property + def model_name(self): + if self.model_class: + return self.model_class.__name__ + + def setup(self): + """ + Perform any setup necessary, e.g. cache lookups for existing data. + """ + + def teardown(self): + """ + Perform any cleanup after import, if necessary. 
+ """ + + def import_data(self, host_data=None, **kwargs): + """ + Import some data! This is the core body of logic for that, regardless + of where data is coming from or where it's headed. Note that this + method handles deletions as well as adds/updates. + """ + self._setup(**kwargs) + self.setup() + created = updated = deleted = [] + + # Get complete set of normalized host data. + if host_data is None: + host_data = self.normalize_host_data() + + # Prune duplicate keys from host/source data. This is for the sake of + # sanity since duplicates typically lead to a ping-pong effect, where a + # "clean" (change-less) import is impossible. + unique = {} + for data in host_data: + key = self.get_key(data) + if key in unique: + log.warning("duplicate records detected from {} for key: {}".format( + self.host_system_title, key)) + unique[key] = data + host_data = [] + for key in sorted(unique): + host_data.append(unique[key]) + + # Cache local data if appropriate. + if self.caches_local_data: + self.cached_local_data = self.cache_local_data(host_data) + + # Create and/or update data. + if self.create or self.update: + created, updated = self._import_create_update(host_data) + + # Delete data. + if self.delete: + changes = len(created) + len(updated) + if self.max_total and changes >= self.max_total: + log.warning("max of {} total changes already reached; skipping deletions".format(self.max_total)) + else: + deleted = self._import_delete(host_data, set(unique), changes=changes) + + self.teardown() + return created, updated, deleted + + def _import_create_update(self, data): + """ + Import the given data; create and/or update records as needed. + """ + created, updated = [], [] + count = len(data) + if not count: + return created, updated + + prog = None + if self.progress: + prog = self.progress("Importing {} data".format(self.model_name), count) + + for i, host_data in enumerate(data, 1): + + # Fetch local object, using key from host data. 
+ key = self.get_key(host_data) + local_object = self.get_local_object(key) + + # If we have a local object, but its data differs from host, update it. + if local_object and self.update: + local_data = self.normalize_local_object(local_object) + diffs = self.data_diffs(local_data, host_data) + if diffs: + log.debug("fields '{}' differed for local data: {}, host data: {}".format( + ','.join(diffs), local_data, host_data)) + local_object = self.update_object(local_object, host_data, local_data) + updated.append((local_object, local_data, host_data)) + if self.max_update and len(updated) >= self.max_update: + log.warning("max of {} *updated* records has been reached; stopping now".format(self.max_update)) + break + if self.max_total and (len(created) + len(updated)) >= self.max_total: + log.warning("max of {} *total changes* has been reached; stopping now".format(self.max_total)) + break + + # If we did not yet have a local object, create it using host data. + elif not local_object and self.create: + local_object = self.create_object(key, host_data) + log.debug("created new {} {}: {}".format(self.model_name, key, local_object)) + created.append((local_object, host_data)) + if self.caches_local_data and self.cached_local_data is not None: + self.cached_local_data[key] = {'object': local_object, 'data': self.normalize_local_object(local_object)} + if self.max_create and len(created) >= self.max_create: + log.warning("max of {} *created* records has been reached; stopping now".format(self.max_create)) + break + if self.max_total and (len(created) + len(updated)) >= self.max_total: + log.warning("max of {} *total changes* has been reached; stopping now".format(self.max_total)) + break + + if prog: + prog.update(i) + if prog: + prog.destroy() + + return created, updated + + def _import_delete(self, host_data, host_keys, changes=0): + """ + Import deletions for the given data set. 
+ """ + deleted = [] + deleting = self.get_deletion_keys() - host_keys + count = len(deleting) + log.debug("found {} instances to delete".format(count)) + if count: + + prog = None + if self.progress: + prog = self.progress("Deleting {} data".format(self.model_name), count) + + for i, key in enumerate(sorted(deleting), 1): + + cached = self.cached_local_data.pop(key) + obj = cached['object'] + if self.delete_object(obj): + deleted.append((obj, cached['data'])) + + if self.max_delete and len(deleted) >= self.max_delete: + log.warning("max of {} *deleted* records has been reached; stopping now".format(self.max_delete)) + break + if self.max_total and (changes + len(deleted)) >= self.max_total: + log.warning("max of {} *total changes* has been reached; stopping now".format(self.max_total)) + break + + if prog: + prog.update(i) + if prog: + prog.destroy() + + return deleted + + def get_key(self, data): + """ + Return the key value for the given data dict. + """ + return tuple(data[k] for k in self.key) + + def get_host_objects(self): + """ + Return the "raw" (as-is, not normalized) host objects which are to be + imported. This may return any sequence-like object, which has a + ``len()`` value and responds to iteration etc. The objects contained + within it may be of any type, no assumptions are made there. (That is + the job of the :meth:`normalize_host_data()` method.) + """ + return [] + + def normalize_host_data(self, host_objects=None): + """ + Return a normalized version of the full set of host data. Note that + this calls :meth:`get_host_objects()` to obtain the initial raw + objects, and then normalizes each object. The normalization process + may filter out some records from the set, in which case the return + value will be smaller than the original data set. 
+ """ + if host_objects is None: + host_objects = self.get_host_objects() + normalized = [] + count = len(host_objects) + if count == 0: + return normalized + prog = None + if self.progress: + prog = self.progress("Reading {} data from {}".format(self.model_name, self.host_system_title), count) + for i, obj in enumerate(host_objects, 1): + data = self.normalize_host_object(obj) + if data: + normalized.append(data) + if prog: + prog.update(i) + if prog: + prog.destroy() + return normalized + + def normalize_host_object(self, obj): + """ + Normalize a raw host object into a data dict, or return ``None`` if the + object should be ignored for the importer's purposes. + """ + return obj + + def cache_local_data(self, host_data=None): + """ + Cache all raw objects and normalized data from the local system. + """ + raise NotImplementedError + + def get_cache_key(self, obj, normal): + """ + Get the primary cache key for a given object and normalized data. + + Note that this method's signature is designed for use with the + :func:`rattail.db.cache.cache_model()` function, and as such the + ``normal`` parameter is understood to be a dict with a ``'data'`` key, + value for which is the normalized data dict for the raw object. + """ + return tuple(normal['data'].get(k) for k in self.key) + + def normalize_cache_object(self, obj): + """ + Normalizer for cached local data. This returns a simple dict with + ``'object'`` and ``'data'`` keys; values are the raw object and its + normalized data dict, respectively. + """ + return {'object': obj, 'data': self.normalize_local_object(obj)} + + def normalize_local_object(self, obj): + """ + Normalize a local (raw) object into a data dict. + """ + data = {} + for field in self.simple_fields: + if field in self.fields: + data[field] = getattr(obj, field) + return data + + def get_local_object(self, key): + """ + Must return the local object corresponding to the given key, or + ``None``. 
Default behavior here will be to check the cache if one is + in effect, otherwise return the value from + :meth:`get_single_local_object()`. + """ + if self.caches_local_data and self.cached_local_data is not None: + data = self.cached_local_data.get(key) + return data['object'] if data else None + return self.get_single_local_object(key) + + def get_single_local_object(self, key): + """ + Must return the local object corresponding to the given key, or None. + This method should not consult the cache; that is handled within the + :meth:`get_local_object()` method. + """ + raise NotImplementedError + + def data_diffs(self, local_data, host_data): + """ + Find all (relevant) fields which differ between the host and local data + values for a given record. + """ + diffs = [] + for field in self.fields: + if local_data[field] != host_data[field]: + diffs.append(field) + return diffs + + def make_object(self): + """ + Make a new/empty local object from scratch. + """ + return self.model_class() + + def new_object(self, key): + """ + Return a new local object to correspond to the given key. Note that + this method should only populate the object's key, and leave the rest + of the fields to :meth:`update_object()`. + """ + obj = self.make_object() + for i, k in enumerate(self.key): + if hasattr(obj, k): + setattr(obj, k, key[i]) + return obj + + def create_object(self, key, host_data): + """ + Create and return a new local object for the given key, fully populated + from the given host data. This may return ``None`` if no object is + created. + """ + obj = self.new_object(key) + if obj: + return self.update_object(obj, host_data) + + def update_object(self, obj, host_data, local_data=None): + """ + Update the local data object with the given host data, and return the + object. 
+ """ + for field in self.simple_fields: + if field in self.fields: + if not local_data or local_data[field] != host_data[field]: + setattr(obj, field, host_data[field]) + return obj + + def get_deletion_keys(self): + """ + Return a set of keys from the *local* data set, which are eligible for + deletion. By default this will be all keys from the local cached data + set, or an empty set if local data isn't cached. + """ + if self.caches_local_data and self.cached_local_data is not None: + return set(self.cached_local_data) + return set() + + def delete_object(self, obj): + """ + Delete the given object from the local system (or not), and return a + boolean indicating whether deletion was successful. What exactly this + entails may vary; default implementation does nothing at all. + """ + return True + + +class FromQuery(Importer): + """ + Generic base class for importers whose raw external data source is a + SQLAlchemy (or Django, or possibly other?) query. + """ + + def query(self): + """ + Subclasses must override this, and return the primary query which will + define the data set. + """ + raise NotImplementedError + + def get_host_objects(self, progress=None): + """ + Returns (raw) query results as a sequence. + """ + return QuerySequence(self.query()) diff --git a/rattail/importing/sqlalchemy.py b/rattail/importing/sqlalchemy.py new file mode 100644 index 0000000000000000000000000000000000000000..cc67d98d56a83ac96088bfc22343aaed6d6417ae --- /dev/null +++ b/rattail/importing/sqlalchemy.py @@ -0,0 +1,170 @@ +# -*- coding: utf-8 -*- +################################################################################ +# +# Rattail -- Retail Software Framework +# Copyright © 2010-2016 Lance Edgar +# +# This file is part of Rattail. 
+# +# Rattail is free software: you can redistribute it and/or modify it under the +# terms of the GNU Affero General Public License as published by the Free +# Software Foundation, either version 3 of the License, or (at your option) +# any later version. +# +# Rattail is distributed in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for +# more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with Rattail. If not, see . +# +################################################################################ +""" +Data Importers for SQLAlchemy +""" + +from __future__ import unicode_literals, absolute_import + +from sqlalchemy import orm +from sqlalchemy.orm.exc import NoResultFound + +from rattail.db import cache +from rattail.importing import Importer, FromQuery + + +class FromSQLAlchemy(FromQuery): + """ + Base class for importers whose external data source is a SQLAlchemy query. + """ + host_session = None + + # For default behavior, set this to a model class to be used in generating + # the host data query. + host_model_class = None + + def query(self): + """ + Must return the primary query which will define the data set. Default + behavior is to leverage :attr:`host_session` and generate a query for + the class defined by :attr:`host_model_class`. + """ + return self.host_session.query(self.host_model_class) + + +class ToSQLAlchemy(Importer): + """ + Base class for all data importers which support the common use case of + targeting a SQLAlchemy ORM on the local side. This is the base class for + all primary Rattail importers. 
+ """ + caches_local_data = True + + def __init__(self, model_class=None, **kwargs): + if model_class: + self.model_class = model_class + super(ToSQLAlchemy, self).__init__(**kwargs) + + @property + def model_mapper(self): + """ + Reference to the effective SQLAlchemy mapper for the local model class. + """ + return orm.class_mapper(self.model_class) + + @property + def model_table(self): + """ + Reference to the effective SQLAlchemy table for the local model class. + """ + tables = self.model_mapper.tables + assert len(tables) == 1 + return tables[0] + + @property + def simple_fields(self): + """ + Returns the list of column names on the underlying local model mapper. + """ + return list(self.model_mapper.columns.keys()) + + @property + def supported_fields(self): + """ + All/only simple fields are supported by default. + """ + return self.simple_fields + + def get_single_local_object(self, key): + """ + Try to fetch the object from the local database, using SA ORM. + """ + query = self.session.query(self.model_class) + for i, k in enumerate(self.key): + query = query.filter(getattr(self.model_class, k) == key[i]) + try: + return query.one() + except NoResultFound: + pass + + def create_object(self, key, host_data): + """ + Create and return a new local object for the given key, fully populated + from the given host data. This may return ``None`` if no object is + created. + + Note that this also adds the new object to the local database session, + and flushes the session. + """ + obj = super(ToSQLAlchemy, self).create_object(key, host_data) + if obj: + self.session.add(obj) + self.session.flush() + return obj + + def update_object(self, obj, host_data, local_data=None): + """ + Update the local data object with the given host data, and return the + object. 
+ """ + obj = super(ToSQLAlchemy, self).update_object(obj, host_data, local_data) + if obj: + self.session.flush() + return obj + + def delete_object(self, obj): + """ + Delete the given object from the local system (or not), and return a + boolean indicating whether deletion was successful. Default + implementation will truly delete and expunge the local object via SA + ORM, and flush the local session. + """ + self.session.delete(obj) + self.session.flush() + self.session.expunge(obj) + return True + + def cache_model(self, model, **kwargs): + """ + Convenience method which invokes :func:`rattail.db.cache.cache_model()` + with the given model and keyword arguments. It will provide the + ``session`` and ``progress`` parameters by default, setting them to the + importer's attributes of the same names. + """ + session = kwargs.pop('session', self.session) + kwargs.setdefault('progress', self.progress) + return cache.cache_model(session, model, **kwargs) + + def cache_local_data(self, host_data=None): + """ + Cache all local objects and data using SA ORM. + """ + return self.cache_model(self.model_class, key=self.get_cache_key, + # omit_duplicates=True, + query_options=self.cache_query_options(), + normalizer=self.normalize_cache_object) + + def cache_query_options(self): + """ + Return a list of options to apply to the cache query, if needed. 
+ """ diff --git a/rattail/tests/__init__.py b/rattail/tests/__init__.py index 63c476c592c7021cf36948613466389bd55ed13c..56c2d453b1b99e2a76a7fb2906567db01bacf02f 100644 --- a/rattail/tests/__init__.py +++ b/rattail/tests/__init__.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- -from __future__ import unicode_literals +from __future__ import unicode_literals, absolute_import import os import warnings @@ -9,6 +9,7 @@ from unittest import TestCase from sqlalchemy import create_engine from sqlalchemy.exc import SAWarning +from rattail.config import make_config from rattail.db import model from rattail.db import Session @@ -22,6 +23,69 @@ warnings.filterwarnings( SAWarning, r'^sqlalchemy\..*$') +class NullProgress(object): + """ + Dummy progress bar which does nothing, but allows for better test coverage + when used with code under test. + """ + + def __init__(self, message, count): + pass + + def update(self, value): + pass + + def destroy(self): + pass + + +class RattailMixin(object): + """ + Generic mixin for ``TestCase`` classes which need common Rattail setup + functionality. 
+ """ + engine_url = os.environ.get('RATTAIL_TEST_ENGINE_URL', 'sqlite://') + + def setUp(self): + self.setup_rattail() + + def tearDown(self): + self.teardown_rattail() + + def setup_rattail(self): + config = self.make_rattail_config() + self.rattail_config = config + + engine = create_engine(self.engine_url) + config.rattail_engines['default'] = engine + config.rattail_engine = engine + + model = self.get_rattail_model() + model.Base.metadata.create_all(bind=engine) + + Session.configure(bind=engine, rattail_config=config) + self.session = Session() + + def teardown_rattail(self): + self.session.close() + Session.configure(bind=None, rattail_config=None) + model = self.get_rattail_model() + model.Base.metadata.drop_all(bind=self.rattail_config.rattail_engine) + + def make_rattail_config(self, **kwargs): + kwargs.setdefault('files', []) + return make_config(**kwargs) + + def get_rattail_model(self): + return model + + +class RattailTestCase(RattailMixin, TestCase): + """ + Generic base class for Rattail tests. + """ + + class DataTestCase(TestCase): engine_url = os.environ.get('RATTAIL_TEST_ENGINE_URL', 'sqlite://') diff --git a/rattail/tests/importing/__init__.py b/rattail/tests/importing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..388d0dfaea9d2520190010b7f109768925551547 --- /dev/null +++ b/rattail/tests/importing/__init__.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- + +from __future__ import unicode_literals, absolute_import + +import copy +from contextlib import contextmanager + +from mock import patch + +from rattail.tests import NullProgress + + +class ImporterTester(object): + """ + Mixin for importer test suites. 
+ """ + importer_class = None + sample_data = {} + + def setUp(self): + self.setup_importer() + + def setup_importer(self): + self.importer = self.make_importer() + + def make_importer(self, **kwargs): + kwargs.setdefault('progress', NullProgress) + return self.importer_class(**kwargs) + + def copy_data(self): + return copy.deepcopy(self.sample_data) + + @contextmanager + def host_data(self, data): + self._host_data = data + host_data = [self.importer.normalize_host_object(obj) for obj in data.itervalues()] + with patch.object(self.importer, 'normalize_host_data') as normalize: + normalize.return_value = host_data + yield + + @contextmanager + def local_data(self, data): + self._local_data = data + local_data = {} + for key, obj in data.iteritems(): + normal = self.importer.normalize_local_object(obj) + local_data[self.importer.get_key(normal)] = {'object': obj, 'data': normal} + with patch.object(self.importer, 'cache_local_data') as cache: + cache.return_value = local_data + yield + + def import_data(self, **kwargs): + self.result = self.importer.import_data(**kwargs) + + def assert_import_created(self, *keys): + created, updated, deleted = self.result + self.assertEqual(len(created), len(keys)) + for key in keys: + key = self.importer.get_key(self._host_data[key]) + found = False + for local_object, host_data in created: + if self.importer.get_key(host_data) == key: + found = True + break + if not found: + raise self.failureException("Key {} not created when importing with {}".format(key, self.importer)) + + def assert_import_updated(self, *keys): + created, updated, deleted = self.result + self.assertEqual(len(updated), len(keys)) + for key in keys: + key = self.importer.get_key(self._host_data[key]) + found = False + for local_object, local_data, host_data in updated: + if self.importer.get_key(local_data) == key: + found = True + break + if not found: + raise self.failureException("Key {} not updated when importing with {}".format(key, self.importer)) + + def 
assert_import_deleted(self, *keys): + created, updated, deleted = self.result + self.assertEqual(len(deleted), len(keys)) + for key in keys: + key = self.importer.get_key(self._local_data[key]) + found = False + for local_object, local_data in deleted: + if self.importer.get_key(local_data) == key: + found = True + break + if not found: + raise self.failureException("Key {} not deleted when importing with {}".format(key, self.importer)) + + def test_empty_host(self): + with self.host_data({}): + with self.local_data(self.sample_data): + self.import_data(delete=False) + self.assert_import_created() + self.assert_import_updated() + self.assert_import_deleted() diff --git a/rattail/tests/importing/test_importers.py b/rattail/tests/importing/test_importers.py new file mode 100644 index 0000000000000000000000000000000000000000..7a4638b685d817a865e798b50b2b2ad727c414fd --- /dev/null +++ b/rattail/tests/importing/test_importers.py @@ -0,0 +1,285 @@ +# -*- coding: utf-8 -*- + +from __future__ import unicode_literals, absolute_import + +from unittest import TestCase + +from mock import Mock, patch + +from rattail.db import model +from rattail.db.util import QuerySequence +from rattail.importing import importers +from rattail.tests import NullProgress, RattailTestCase +from rattail.tests.importing import ImporterTester + + +class TestImporter(TestCase): + + def test_init(self): + importer = importers.Importer() + self.assertIsNone(importer.model_class) + self.assertIsNone(importer.model_name) + self.assertIsNone(importer.key) + self.assertEqual(importer.fields, []) + + # key must be included among the fields + self.assertRaises(ValueError, importers.Importer, key='upc', fields=[]) + importer = importers.Importer(key='upc', fields=['upc']) + self.assertEqual(importer.key, ('upc',)) + self.assertEqual(importer.fields, ['upc']) + + # extra bits are passed as-is + importer = importers.Importer() + self.assertFalse(hasattr(importer, 'extra_bit')) + extra_bit = object() + importer 
= importers.Importer(extra_bit=extra_bit) + self.assertIs(importer.extra_bit, extra_bit) + + def test_get_host_objects(self): + importer = importers.Importer() + objects = importer.get_host_objects() + self.assertEqual(objects, []) + + def test_cache_local_data(self): + importer = importers.Importer() + self.assertRaises(NotImplementedError, importer.cache_local_data) + + def test_get_local_object(self): + importer = importers.Importer() + self.assertFalse(importer.caches_local_data) + self.assertRaises(NotImplementedError, importer.get_local_object, None) + + someobj = object() + with patch.object(importer, 'get_single_local_object', Mock(return_value=someobj)): + obj = importer.get_local_object('somekey') + self.assertIs(obj, someobj) + + importer.caches_local_data = True + importer.cached_local_data = {'somekey': {'object': someobj, 'data': {}}} + obj = importer.get_local_object('somekey') + self.assertIs(obj, someobj) + + def test_get_single_local_object(self): + importer = importers.Importer() + self.assertRaises(NotImplementedError, importer.get_single_local_object, None) + + def test_get_cache_key(self): + importer = importers.Importer(key='upc', fields=['upc']) + obj = {'upc': '00074305001321'} + normal = {'data': obj} + key = importer.get_cache_key(obj, normal) + self.assertEqual(key, ('00074305001321',)) + + def test_normalize_cache_object(self): + importer = importers.Importer() + obj = {'upc': '00074305001321'} + with patch.object(importer, 'normalize_local_object', new=lambda obj: obj): + cached = importer.normalize_cache_object(obj) + self.assertEqual(cached, {'object': obj, 'data': obj}) + + def test_normalize_local_object(self): + importer = importers.Importer(key='upc', fields=['upc', 'description']) + importer.simple_fields = importer.fields + obj = Mock(upc='00074305001321', description="Apple Cider Vinegar") + data = importer.normalize_local_object(obj) + self.assertEqual(data, {'upc': '00074305001321', 'description': "Apple Cider Vinegar"}) + + 
# NOTE(review): the following methods belong to the test class for
# importers.Importer, begun earlier in this file (header outside this chunk).
    def test_update_object(self):
        importer = importers.Importer(key='upc', fields=['upc', 'description'])
        importer.simple_fields = importer.fields
        obj = Mock(upc='00074305001321', description="Apple Cider Vinegar")

        # update with identical data returns the same object, unchanged
        newobj = importer.update_object(obj, {'upc': '00074305001321', 'description': "Apple Cider Vinegar"})
        self.assertIs(newobj, obj)
        self.assertEqual(obj.description, "Apple Cider Vinegar")

        # update with differing data mutates the object in place
        newobj = importer.update_object(obj, {'upc': '00074305001321', 'description': "Apple Cider Vinegar 32oz"})
        self.assertIs(newobj, obj)
        self.assertEqual(obj.description, "Apple Cider Vinegar 32oz")

    def test_normalize_host_data(self):
        importer = importers.Importer(key='upc', fields=['upc', 'description'],
                                      progress=NullProgress)

        data = [
            {'upc': '00074305001161', 'description': "Apple Cider Vinegar 16oz"},
            {'upc': '00074305001321', 'description': "Apple Cider Vinegar 32oz"},
        ]

        # empty input yields empty output
        host_data = importer.normalize_host_data(host_objects=[])
        self.assertEqual(host_data, [])

        # plain dicts normalize to themselves
        host_data = importer.normalize_host_data(host_objects=data)
        self.assertEqual(host_data, data)

        # when no objects are given, they come from get_host_objects()
        with patch.object(importer, 'get_host_objects', new=Mock(return_value=data)):
            host_data = importer.normalize_host_data()
        self.assertEqual(host_data, data)

    def test_get_deletion_keys(self):
        # no cache support means no deletion candidates
        importer = importers.Importer()
        self.assertFalse(importer.caches_local_data)
        keys = importer.get_deletion_keys()
        self.assertEqual(keys, set())

        # cache supported but not yet populated: still nothing to delete
        importer.caches_local_data = True
        self.assertIsNone(importer.cached_local_data)
        keys = importer.get_deletion_keys()
        self.assertEqual(keys, set())

        # cached keys absent from host data are deletion candidates
        importer.cached_local_data = {'delete-me': object()}
        keys = importer.get_deletion_keys()
        self.assertEqual(keys, set(['delete-me']))


class TestFromQuery(RattailTestCase):

    def test_query(self):
        # query() must be provided by a concrete subclass
        importer = importers.FromQuery()
        self.assertRaises(NotImplementedError, importer.query)

    def test_get_host_objects(self):
        query =
self.session.query(model.Product)
        importer = importers.FromQuery()
        # host objects come from the query, wrapped so they behave like a sequence
        with patch.object(importer, 'query', Mock(return_value=query)):
            objects = importer.get_host_objects()
        self.assertIsInstance(objects, QuerySequence)


######################################################################
# fake importer class, tested mostly for basic coverage
######################################################################

class Product(object):
    # minimal stand-in for a data model class, with two simple fields
    upc = None
    description = None


class MockImporter(importers.Importer):
    # concrete Importer subclass, just enough for the import logic to run
    model_class = Product
    key = 'upc'
    simple_fields = ['upc', 'description']
    supported_fields = simple_fields
    caches_local_data = True

    def normalize_local_object(self, obj):
        # local "objects" are already plain dicts in these tests
        return obj

    def update_object(self, obj, host_data, local_data=None):
        # "updating" simply adopts the host data wholesale
        return host_data


class TestMockImporter(ImporterTester, TestCase):
    importer_class = MockImporter

    sample_data = {
        '16oz': {'upc': '00074305001161', 'description': "Apple Cider Vinegar 16oz"},
        '32oz': {'upc': '00074305001321', 'description': "Apple Cider Vinegar 32oz"},
        '1gal': {'upc': '00074305011283', 'description': "Apple Cider Vinegar 1gal"},
    }

    def test_create(self):
        # a host record missing locally should be created
        local = self.copy_data()
        del local['32oz']
        with self.host_data(self.sample_data):
            with self.local_data(local):
                self.import_data()
        self.assert_import_created('32oz')
        self.assert_import_updated()
        self.assert_import_deleted()

    def test_update(self):
        # a local record differing from host should be updated
        local = self.copy_data()
        local['16oz']['description'] = "wrong description"
        with self.host_data(self.sample_data):
            with self.local_data(local):
                self.import_data()
        self.assert_import_created()
        self.assert_import_updated('16oz')
        self.assert_import_deleted()

    def test_delete(self):
        # a local record absent from host should be deleted
        local = self.copy_data()
        local['bogus'] = {'upc': '00000000000000', 'description': "Delete Me"}
        with self.host_data(self.sample_data):
            with self.local_data(local):
                self.import_data()
        self.assert_import_created()
# NOTE(review): continuation of TestMockImporter.test_delete from the chunk boundary.
        self.assert_import_updated()
        self.assert_import_deleted('bogus')

    def test_duplicate(self):
        # duplicate host records should not cause spurious changes
        host = self.copy_data()
        host['32oz-dupe'] = host['32oz']
        with self.host_data(host):
            with self.local_data(self.sample_data):
                self.import_data()
        self.assert_import_created()
        self.assert_import_updated()
        self.assert_import_deleted()

    def test_max_create(self):
        # max_create caps how many records may be created in one run
        local = self.copy_data()
        del local['16oz']
        del local['1gal']
        with self.host_data(self.sample_data):
            with self.local_data(local):
                self.import_data(max_create=1)
        self.assert_import_created('16oz')
        self.assert_import_updated()
        self.assert_import_deleted()

    def test_max_total_create(self):
        # max_total also caps creations
        local = self.copy_data()
        del local['16oz']
        del local['1gal']
        with self.host_data(self.sample_data):
            with self.local_data(local):
                self.import_data(max_total=1)
        self.assert_import_created('16oz')
        self.assert_import_updated()
        self.assert_import_deleted()

    def test_max_update(self):
        # max_update caps how many records may be updated in one run
        local = self.copy_data()
        local['16oz']['description'] = "wrong"
        local['1gal']['description'] = "wrong"
        with self.host_data(self.sample_data):
            with self.local_data(local):
                self.import_data(max_update=1)
        self.assert_import_created()
        self.assert_import_updated('16oz')
        self.assert_import_deleted()

    def test_max_total_update(self):
        # max_total also caps updates
        local = self.copy_data()
        local['16oz']['description'] = "wrong"
        local['1gal']['description'] = "wrong"
        with self.host_data(self.sample_data):
            with self.local_data(local):
                self.import_data(max_total=1)
        self.assert_import_created()
        self.assert_import_updated('16oz')
        self.assert_import_deleted()

    def test_max_delete(self):
        # max_delete caps how many records may be deleted in one run
        local = self.copy_data()
        local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"}
        local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"}
        with self.host_data(self.sample_data):
            with self.local_data(local):
                self.import_data(max_delete=1)
        self.assert_import_created()
self.assert_import_updated() + self.assert_import_deleted('bogus1') + + def test_max_total_delete(self): + local = self.copy_data() + local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"} + local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"} + with self.host_data(self.sample_data): + with self.local_data(local): + self.import_data(max_total=1) + self.assert_import_created() + self.assert_import_updated() + self.assert_import_deleted('bogus1') diff --git a/rattail/tests/importing/test_sqlalchemy.py b/rattail/tests/importing/test_sqlalchemy.py new file mode 100644 index 0000000000000000000000000000000000000000..435343f1b0f2d39823c0f40b03becf480ce6210f --- /dev/null +++ b/rattail/tests/importing/test_sqlalchemy.py @@ -0,0 +1,140 @@ +# -*- coding: utf-8 -*- + +from __future__ import unicode_literals, absolute_import + +from unittest import TestCase + +import sqlalchemy as sa +from sqlalchemy import orm +from sqlalchemy.orm.exc import MultipleResultsFound + +from rattail.importing import sqlalchemy as saimport + + +class Widget(object): + pass + +metadata = sa.MetaData() + +widget_table = sa.Table('widgets', metadata, + sa.Column('id', sa.Integer(), primary_key=True), + sa.Column('description', sa.String(length=50))) + +widget_mapper = orm.mapper(Widget, widget_table) + +WIDGETS = [ + {'id': 1, 'description': "Main Widget"}, + {'id': 2, 'description': "Other Widget"}, + {'id': 3, 'description': "Other Widget"}, +] + + +class TestFromSQLAlchemy(TestCase): + + def test_query(self): + Session = orm.sessionmaker(bind=sa.create_engine('sqlite://')) + session = Session() + importer = saimport.FromSQLAlchemy(host_session=session, + host_model_class=Widget) + self.assertEqual(unicode(importer.query()), + "SELECT widgets.id AS widgets_id, widgets.description AS widgets_description \n" + "FROM widgets") + + +class TestToSQLAlchemy(TestCase): + + def setUp(self): + engine = sa.create_engine('sqlite://') + metadata.create_all(bind=engine) + 
Session = orm.sessionmaker(bind=engine) + self.session = Session() + for data in WIDGETS: + widget = Widget() + for key, value in data.iteritems(): + setattr(widget, key, value) + self.session.add(widget) + self.session.commit() + + def tearDown(self): + self.session.close() + del self.session + + def make_importer(self, **kwargs): + kwargs.setdefault('model_class', Widget) + return saimport.ToSQLAlchemy(**kwargs) + + def test_model_mapper(self): + importer = self.make_importer() + self.assertIs(importer.model_mapper, widget_mapper) + + def test_model_table(self): + importer = self.make_importer() + self.assertIs(importer.model_table, widget_table) + + def test_simple_fields(self): + importer = self.make_importer() + self.assertEqual(importer.simple_fields, ['id', 'description']) + + def test_get_single_local_object(self): + + # simple key + importer = self.make_importer(key='id', session=self.session) + widget = importer.get_single_local_object((1,)) + self.assertEqual(widget.id, 1) + self.assertEqual(widget.description, "Main Widget") + + # compound key + importer = self.make_importer(key=('id', 'description'), session=self.session) + widget = importer.get_single_local_object((1, "Main Widget")) + self.assertEqual(widget.id, 1) + self.assertEqual(widget.description, "Main Widget") + + # widget not found + importer = self.make_importer(key='id', session=self.session) + self.assertIsNone(importer.get_single_local_object((42,))) + + # multiple widgets found + importer = self.make_importer(key='description', session=self.session) + self.assertRaises(MultipleResultsFound, importer.get_single_local_object, ("Other Widget",)) + + def test_create_object(self): + importer = self.make_importer(key='id', session=self.session) + widget = importer.create_object((42,), {'id': 42, 'description': "Latest Widget"}) + self.assertFalse(self.session.new or self.session.dirty or self.session.deleted) # i.e. 
# ("has been flushed" — tail of a comment cut at the chunk boundary)
        self.assertIn(widget, self.session) # therefore widget has been flushed and would be committed
        self.assertEqual(widget.id, 42)
        self.assertEqual(widget.description, "Latest Widget")

    def test_delete_object(self):
        widget = self.session.query(Widget).get(1)
        self.assertIn(widget, self.session)
        importer = self.make_importer(session=self.session)
        # delete_object() should remove the record and report success
        self.assertTrue(importer.delete_object(widget))
        self.assertNotIn(widget, self.session)
        self.assertIsNone(self.session.query(Widget).get(1))

    def test_cache_model(self):
        # cache_model() returns a dict of model instances keyed by the given field
        importer = self.make_importer(key='id', session=self.session)
        cache = importer.cache_model(Widget, key='id')
        self.assertEqual(len(cache), 3)
        for i in range(1, 4):
            self.assertIn(i, cache)
            self.assertIsInstance(cache[i], Widget)
            self.assertEqual(cache[i].id, i)
            self.assertEqual(cache[i].description, WIDGETS[i-1]['description'])

    def test_cache_local_data(self):
        # cache_local_data() keys on tuples, and each entry pairs the raw
        # object with its normalized data dict
        importer = self.make_importer(key='id', session=self.session)
        cache = importer.cache_local_data()
        self.assertEqual(len(cache), 3)
        for i in range(1, 4):
            self.assertIn((i,), cache)
            cached = cache[(i,)]
            self.assertIsInstance(cached, dict)
            self.assertIsInstance(cached['object'], Widget)
            self.assertEqual(cached['object'].id, i)
            self.assertEqual(cached['object'].description, WIDGETS[i-1]['description'])
            self.assertIsInstance(cached['data'], dict)
            self.assertEqual(cached['data']['id'], i)
            self.assertEqual(cached['data']['description'], WIDGETS[i-1]['description'])