Changeset - 7a2ef3551801
[Not reviewed]
0 7 0
Lance Edgar - 8 years ago 2016-05-16 21:09:19
ledgar@sacfoodcoop.com
Add `BulkImporter` and `BulkImportHandler` base classes
7 files changed with 240 insertions and 125 deletions:
0 comments (0 inline, 0 general)
rattail/importing/__init__.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
################################################################################
 
#
 
#  Rattail -- Retail Software Framework
 
#  Copyright © 2010-2016 Lance Edgar
 
#
 
#  This file is part of Rattail.
 
#
 
#  Rattail is free software: you can redistribute it and/or modify it under the
 
#  terms of the GNU Affero General Public License as published by the Free
 
#  Software Foundation, either version 3 of the License, or (at your option)
 
#  any later version.
 
#
 
#  Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
 
#  WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 
#  FOR A PARTICULAR PURPOSE.  See the GNU Affero General Public License for
 
#  more details.
 
#
 
#  You should have received a copy of the GNU Affero General Public License
 
#  along with Rattail.  If not, see <http://www.gnu.org/licenses/>.
 
#
 
################################################################################
 
"""
 
Data Importing Framework
 
"""
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from .importers import Importer, FromQuery
 
from .importers import Importer, FromQuery, BulkImporter
 
from .sqlalchemy import FromSQLAlchemy, ToSQLAlchemy
 
from .postgresql import BulkToPostgreSQL
 
from .handlers import ImportHandler, FromSQLAlchemyHandler, ToSQLAlchemyHandler, BulkToPostgreSQLHandler
 
from .handlers import ImportHandler, BulkImportHandler, FromSQLAlchemyHandler, ToSQLAlchemyHandler
 
from .rattail import FromRattailHandler, ToRattailHandler
 
from . import model
rattail/importing/handlers.py
Show inline comments
 
@@ -184,150 +184,164 @@ class ImportHandler(object):
 

	
 
    def commit_host_transaction(self):
        """
        Commit the transaction on the *host* side of the import.

        This base implementation does nothing; handlers whose host side has a
        real transaction (e.g. a database session) override it.
        """
        pass
 

	
 
    def commit_local_transaction(self):
        """
        Commit the transaction on the *local* side of the import.

        This base implementation does nothing; handlers whose local side has a
        real transaction (e.g. a database session) override it.
        """
        pass
 

	
 
    def process_changes(self, changes):
        """
        Send a "warnings" email describing the given changes.

        Nothing happens unless the handler is running in "warnings" mode
        (``self.warnings`` is true).  Otherwise an email is sent via
        ``send_email()``, whose template context includes the host/local
        titles, ``sys.argv``, the elapsed time since ``self.import_began``, the
        changes themselves, and the dry-run flag.

        If the handler has a ``command`` attribute, the email key is derived
        from the command's parent name and its own name (dashes replaced with
        underscores); otherwise the generic ``rattail_import_updates`` key is
        used.  That generic key is also always passed as the fallback.

        :param changes: Change data to be rendered by the email template; it
           is passed through untouched.
        """
        # TODO: This whole thing needs a re-write...but for now, waiting until
        # the old importer has really gone away, so we can share its email
        # template instead of bothering with something more complicated.

        if not self.warnings:
            return

        finished = make_utc(datetime.datetime.utcnow(), tzinfo=True)

        # both the displayed command title and the email key depend on whether
        # we were invoked from a command
        command = getattr(self, 'command', None)
        if command:
            command_title = '{} {}'.format(command.parent.name, command.name)
            email_key = '{}_{}_updates'.format(command.parent.name, command.name)
            email_key = email_key.replace('-', '_')
        else:
            command_title = None
            email_key = 'rattail_import_updates'

        context = {
            'local_title': self.local_title,
            'host_title': self.host_title,
            'argv': sys.argv,
            'runtime': humanize.naturaldelta(finished - self.import_began),
            'changes': changes,
            'dry_run': self.dry_run,
            'render_record': RecordRenderer(self.config),
            'max_display': 15,
            'command': command_title,
        }

        send_email(self.config, email_key, fallback_key='rattail_import_updates', data=context)
        log.info("warning email was sent for {} -> {} import".format(self.host_title, self.local_title))
 

	
 

	
 
class BulkImportHandler(ImportHandler):
    """
    Base class for bulk import handlers.

    Each importer's ``import_data()`` is expected to return a simple count of
    created records, and that count is what gets reported and returned here.
    """

    def import_data(self, *keys, **kwargs):
        """
        Import all data for the given importer/model keys.

        Recognized keyword arguments are ``dry_run`` (overrides the handler's
        own flag), ``progress`` and ``warnings``; the effective ``dry_run`` and
        ``progress`` values, plus any other kwargs, are passed on when
        obtaining each importer.

        Keys for which no importer can be found are skipped with a warning.
        If anything raises while importing, the host transaction is committed
        first when ``commit_host_partial`` is set and this is not a dry run,
        and the error is then re-raised.  On success the whole transaction is
        rolled back (dry run) or committed.

        :returns: Ordered mapping of model key to created-record count,
           containing only the keys for which something was created.
        """
        # TODO: still need to refactor much of this so can share with parent class
        self.import_began = make_utc(datetime.datetime.utcnow(), tzinfo=True)
        if 'dry_run' in kwargs:
            self.dry_run = kwargs['dry_run']
        self.progress = kwargs.pop('progress', getattr(self, 'progress', None))
        self.warnings = kwargs.pop('warnings', False)

        # importers always receive the handler's effective settings
        kwargs['dry_run'] = self.dry_run
        kwargs['progress'] = self.progress

        self.setup()
        self.begin_transaction()
        summary = OrderedDict()

        try:
            for model_key in keys:
                importer = self.get_importer(model_key, **kwargs)
                if importer:
                    count = importer.import_data()
                    log.info("{} -> {}: added {}, updated 0, deleted 0 {} records".format(
                        self.host_title, self.local_title, count, model_key))
                    if count:
                        summary[model_key] = count
                else:
                    log.warning("skipping unknown importer: {}".format(model_key))
        except:
            # deliberately catches everything; the error is always re-raised
            keep_host_work = self.commit_host_partial and not self.dry_run
            if keep_host_work:
                log.warning("{host} -> {local}: committing partial transaction on host {host} (despite error)".format(
                    host=self.host_title, local=self.local_title))
                self.commit_host_transaction()
            raise

        # only reached when every importer finished without raising
        if not self.dry_run:
            self.commit_transaction()
        else:
            self.rollback_transaction()

        self.teardown()
        return summary
 

	
 

	
 
class FromSQLAlchemyHandler(ImportHandler):
    """
    Import handler whose *host* data source is a SQLAlchemy engine / ORM.

    The host session is created when the host transaction begins, and is
    closed and cleared again on rollback or commit.
    """
    # current host session; ``None`` whenever no host transaction is open
    host_session = None

    def make_host_session(self):
        """
        Return a new session for the host database.  Subclasses must override
        this; the base implementation always raises ``NotImplementedError``.
        """
        raise NotImplementedError

    def get_importer_kwargs(self, key, **kwargs):
        """
        Extend the parent's importer kwargs with ``host_session``, unless the
        caller already supplied one.
        """
        importer_kwargs = super(FromSQLAlchemyHandler, self).get_importer_kwargs(key, **kwargs)
        importer_kwargs.setdefault('host_session', self.host_session)
        return importer_kwargs

    def begin_host_transaction(self):
        """
        Open the host session via :meth:`make_host_session()`.
        """
        self.host_session = self.make_host_session()

    def rollback_host_transaction(self):
        """
        Roll back and close the host session, then forget it.
        """
        session = self.host_session
        session.rollback()
        session.close()
        self.host_session = None

    def commit_host_transaction(self):
        """
        Commit and close the host session, then forget it.
        """
        session = self.host_session
        session.commit()
        session.close()
        self.host_session = None
 

	
 

	
 
class ToSQLAlchemyHandler(ImportHandler):
    """
    Import handler whose *local* target is a SQLAlchemy ORM.

    The local session is created when the local transaction begins, and is
    closed and cleared again on rollback or commit.
    """
    # current local session; ``None`` whenever no local transaction is open
    session = None

    def make_session(self):
        """
        Return a new session for the local database.  Subclasses must override
        this; the base implementation always raises ``NotImplementedError``.
        """
        raise NotImplementedError

    def get_importer_kwargs(self, key, **kwargs):
        """
        Extend the parent's importer kwargs with ``session``, unless the
        caller already supplied one.
        """
        importer_kwargs = super(ToSQLAlchemyHandler, self).get_importer_kwargs(key, **kwargs)
        importer_kwargs.setdefault('session', self.session)
        return importer_kwargs

    def begin_local_transaction(self):
        """
        Open the local session via :meth:`make_session()`.
        """
        self.session = self.make_session()

    def rollback_local_transaction(self):
        """
        Roll back and close the local session, then forget it.
        """
        local_session = self.session
        local_session.rollback()
        local_session.close()
        self.session = None

    def commit_local_transaction(self):
        """
        Commit and close the local session, then forget it.
        """
        local_session = self.session
        local_session.commit()
        local_session.close()
        self.session = None
 

	
 

	
 
class BulkToPostgreSQLHandler(ToSQLAlchemyHandler):
 
class BulkToPostgreSQLHandler(BulkImportHandler):
 
    """
 
    Handler for bulk imports which target PostgreSQL on the local side.
 
    """
 

	
 
    def import_data(self, *keys, **kwargs):
 
        """
 
        Import all data for the given importer/model keys.
 
        """
 
        # TODO: still need to refactor much of this so can share with parent class
 
        self.import_began = make_utc(datetime.datetime.utcnow(), tzinfo=True)
 
        if 'dry_run' in kwargs:
 
            self.dry_run = kwargs['dry_run']
 
        self.progress = kwargs.pop('progress', getattr(self, 'progress', None))
 
        self.warnings = kwargs.pop('warnings', False)
 
        kwargs.update({'dry_run': self.dry_run,
 
                       'progress': self.progress})
 
        self.setup()
 
        self.begin_transaction()
 
        changes = OrderedDict()
 

	
 
        for key in keys:
 
            importer = self.get_importer(key, **kwargs)
 
            if not importer:
 
                log.warning("skipping unknown importer: {}".format(key))
 
                continue
 

	
 
            created = importer.import_data()
 
            log.info("{} -> {}: added {}, updated 0, deleted 0 {} records".format(
 
                self.host_title, self.local_title, created, key))
 
            if created:
 
                changes[key] = created
 

	
 
        if self.dry_run:
 
            self.rollback_transaction()
 
        else:
 
            self.commit_transaction()
 
        self.teardown()
 
        return changes
rattail/importing/importers.py
Show inline comments
 
@@ -411,48 +411,99 @@ class Importer(object):
 
    def update_object(self, obj, host_data, local_data=None):
        """
        Copy "simple" field values from the host data onto the local object.

        Only fields listed in both ``self.simple_fields`` and ``self.fields``
        are considered.  When ``local_data`` is given (and non-empty), a field
        is assigned only if its host value differs from the local value;
        otherwise every considered field is assigned.

        :returns: The same ``obj`` which was passed in.
        """
        active_fields = [name for name in self.simple_fields
                         if name in self.fields]
        for name in active_fields:
            if not local_data or local_data[name] != host_data[name]:
                setattr(obj, name, host_data[name])
        return obj
 

	
 
    def get_deletion_keys(self):
        """
        Return the set of *local* keys which are candidates for deletion.

        By default this is every key of the cached local data set.  If the
        importer does not cache local data, or the cache has not been
        populated (is ``None``), an empty set is returned instead.
        """
        # do not even look at the cache unless caching is enabled
        cache = self.cached_local_data if self.caches_local_data else None
        if cache is None:
            return set()
        return set(cache)
 

	
 
    def delete_object(self, obj):
        """
        Delete the given object from the local system (or not), and return a
        boolean indicating whether deletion was successful.  What exactly this
        entails may vary; the default implementation does not touch ``obj`` at
        all and simply reports success.

        :param obj: The local object to delete (unused here).

        :returns: ``True``, always, in this default implementation.
        """
        return True
 

	
 

	
 
class FromQuery(Importer):
    """
    Generic base class for importers whose raw external data source is a
    SQLAlchemy (or Django, or possibly other?) query.

    Subclasses provide the query via :meth:`query()`; the host objects are
    then the results of that query.
    """

    def query(self):
        """
        Subclasses must override this, and return the primary query which will
        define the data set.  The base implementation always raises
        ``NotImplementedError``.
        """
        raise NotImplementedError

    def get_host_objects(self, progress=None):
        """
        Returns (raw) query results as a sequence, by wrapping the result of
        :meth:`query()` in a ``QuerySequence``.

        :param progress: Accepted for interface compatibility; not used here.
        """
        return QuerySequence(self.query())
 

	
 

	
 
class BulkImporter(Importer):
    """
    Base class for bulk data importers.

    A bulk importer only ever *creates* records: every host record is handed
    to ``create_object()``, and once all of them have been processed
    :meth:`flush_create()` is invoked.  Nothing in this class is specific to
    a particular local target.
    """

    def import_data(self, host_data=None, now=None, **kwargs):
        """
        Import (create) all host data and return the number of records created.

        :param host_data: Optional sequence of host data records; when
           ``None`` it is obtained from ``normalize_host_data()``.

        :param now: Optional timestamp to record as ``self.now``; defaults to
           the current UTC time.

        :param kwargs: If any are given, they are passed to ``_setup()``.

        :returns: Count of created records.

        ``teardown()`` is always invoked once ``setup()`` has succeeded, even
        if the import itself raises, so that anything ``setup()`` acquired is
        released.  The original error still propagates to the caller.
        """
        self.now = now or make_utc(datetime.datetime.utcnow(), tzinfo=True)
        if kwargs:
            self._setup(**kwargs)
        self.setup()
        try:
            if host_data is None:
                host_data = self.normalize_host_data()
            created = self._import_create(host_data)
        finally:
            self.teardown()
        return created

    def _import_create(self, data):
        """
        Create a local object for each record in ``data``, then flush.

        Stops early (with a warning) once ``self.max_create`` records have
        been processed, if that limit is set.

        :param data: Sized sequence of host data records.

        :returns: Number of records processed; zero if ``data`` is empty, in
           which case :meth:`flush_create()` is not called.
        """
        count = len(data)
        if not count:
            return 0
        created = count

        prog = None
        if self.progress:
            prog = self.progress("Importing {} data".format(self.model_name), count)

        try:
            for i, host_data in enumerate(data, 1):

                key = self.get_key(host_data)
                self.create_object(key, host_data)
                if self.max_create and i >= self.max_create:
                    log.warning("max of {} *created* records has been reached; stopping now".format(self.max_create))
                    created = i
                    break

                if prog:
                    prog.update(i)
        finally:
            # never leave a progress indicator dangling, even on error
            if prog:
                prog.destroy()

        self.flush_create()
        return created

    def flush_create(self):
        """
        Perform any final steps to "flush" the created data here.  Note that
        the importer's handler is still responsible for actually committing
        changes to the local system, if applicable.

        Default implementation does nothing.
        """
rattail/importing/postgresql.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
################################################################################
 
#
 
#  Rattail -- Retail Software Framework
 
#  Copyright © 2010-2016 Lance Edgar
 
#
 
#  This file is part of Rattail.
 
#
 
#  Rattail is free software: you can redistribute it and/or modify it under the
 
#  terms of the GNU Affero General Public License as published by the Free
 
#  Software Foundation, either version 3 of the License, or (at your option)
 
#  any later version.
 
#
 
#  Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
 
#  WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 
#  FOR A PARTICULAR PURPOSE.  See the GNU Affero General Public License for
 
#  more details.
 
#
 
#  You should have received a copy of the GNU Affero General Public License
 
#  along with Rattail.  If not, see <http://www.gnu.org/licenses/>.
 
#
 
################################################################################
 
"""
 
PostgreSQL data importers
 
"""
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
import os
 
import datetime
 
import logging
 

	
 
from rattail.importing.sqlalchemy import ToSQLAlchemy
 
from rattail.importing import BulkImporter, ToSQLAlchemy
 
from rattail.time import make_utc
 

	
 

	
 
log = logging.getLogger(__name__)
 

	
 

	
 
class BulkToPostgreSQL(ToSQLAlchemy):
 
class BulkToPostgreSQL(BulkImporter, ToSQLAlchemy):
 
    """
 
    Base class for bulk data importers which target PostgreSQL on the local side.
 
    """
 

	
 
    @property
    def data_path(self):
        """
        Path of the data file used to buffer records for this importer's
        model: ``import_bulk_postgresql_<model_name>.csv`` within the
        configured work directory (``config.workdir(require=True)``).
        """
        return os.path.join(self.config.workdir(require=True),
                            'import_bulk_postgresql_{}.csv'.format(self.model_name))
 

	
 
    def setup(self):
        """
        Open the data file at ``data_path`` for binary writing, and keep the
        file object as ``self.data_buffer``.
        """
        self.data_buffer = open(self.data_path, 'wb')
 

	
 
    def teardown(self):
        """
        Close the data buffer, delete the data file from disk, and clear the
        ``self.data_buffer`` reference.
        """
        self.data_buffer.close()
        os.remove(self.data_path)
        self.data_buffer = None
 

	
 
    def import_data(self, host_data=None, now=None, **kwargs):
 
        self.now = now or make_utc(datetime.datetime.utcnow(), tzinfo=True)
 
        if kwargs:
 
            self._setup(**kwargs)
 
        self.setup()
 
        if host_data is None:
 
            host_data = self.normalize_host_data()
 
        created = self._import_create(host_data)
 
        self.teardown()
 
        return created
 

	
 
    def _import_create(self, data):
 
        count = len(data)
 
        if not count:
 
            return 0
 
        created = count
 

	
 
        prog = None
 
        if self.progress:
 
            prog = self.progress("Importing {} data".format(self.model_name), count)
 

	
 
        for i, host_data in enumerate(data, 1):
 

	
 
            key = self.get_key(host_data)
 
            self.create_object(key, host_data)
 
            if self.max_create and i >= self.max_create:
 
                log.warning("max of {} *created* records has been reached; stopping now".format(self.max_create))
 
                created = i
 
                break
 

	
 
            if prog:
 
                prog.update(i)
 
        if prog:
 
            prog.destroy()
 

	
 
        self.commit_create()
 
        return created
 

	
 
    def create_object(self, key, data):
        """
        Append one record to the data buffer instead of creating an object.

        The record's values are first converted with
        ``prep_data_for_postgres()``, then written as a single UTF-8 encoded
        line: the values for ``self.fields`` (in that order) joined by tabs,
        followed by a newline.

        :param key: Record key; not used here.

        :param data: Mapping of field name to value for the record.
        """
        prepped = self.prep_data_for_postgres(data)
        columns = [prepped[field] for field in self.fields]
        line = '{}\n'.format('\t'.join(columns))
        self.data_buffer.write(line.encode('utf-8'))
 

	
 
    def prep_data_for_postgres(self, data):
        """
        Return a new dict containing every key of ``data``, with each value
        converted by :meth:`prep_value_for_postgres()`.

        :param data: Mapping (or anything else accepted by ``dict()``) of
           field name to raw value.  It is not modified.

        :returns: A new plain ``dict`` of field name to converted value.
        """
        # Build a fresh dict rather than assigning into one while iterating
        # over it.  ``items()`` exists on both Python 2 and Python 3, whereas
        # ``iteritems()`` is Python 2 only.
        return {key: self.prep_value_for_postgres(value)
                for key, value in dict(data).items()}
 

	
 
    def prep_value_for_postgres(self, value):
        """
        Convert a single value to the text written to the data buffer.

        * ``None`` becomes ``\\N``
        * ``True`` / ``False`` become ``t`` / ``f``
        * datetimes are passed through ``make_utc(value, tzinfo=False)``
        * strings have backslash, carriage return, newline and tab escaped

        Whatever remains is converted with ``unicode()``.
        """
        # identity (not equality) checks, so e.g. 0 and 1 are not mistaken for
        # booleans
        for sentinel, text in ((None, '\\N'), (True, 't'), (False, 'f')):
            if value is sentinel:
                return text

        if isinstance(value, datetime.datetime):
            value = make_utc(value, tzinfo=False)
        elif isinstance(value, basestring):
            # order matters: the backslash must be escaped first, or else the
            # backslashes added by the later replacements would get doubled
            escapes = (
                ('\\', '\\\\'),
                ('\r', '\\r'),
                ('\n', '\\n'),
                ('\t', '\\t'), # TODO: add test for this
            )
            for raw, escaped in escapes:
                value = value.replace(raw, escaped)

        return unicode(value)
 

	
 
    def commit_create(self):
 
    def flush_create(self):
 
        log.info("copying {} data from buffer to PostgreSQL".format(self.model_name))
 
        self.data_buffer.close()
 
        self.data_buffer = open(self.data_path, 'rb')
 
        cursor = self.session.connection().connection.cursor()
 
        table_name = '"{}"'.format(self.model_table.name)
 
        cursor.copy_from(self.data_buffer, table_name, columns=self.fields)
 
        log.debug("PostgreSQL data copy completed")
rattail/importing/rattail_bulk.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
################################################################################
 
#
 
#  Rattail -- Retail Software Framework
 
#  Copyright © 2010-2016 Lance Edgar
 
#
 
#  This file is part of Rattail.
 
#
 
#  Rattail is free software: you can redistribute it and/or modify it under the
 
#  terms of the GNU Affero General Public License as published by the Free
 
#  Software Foundation, either version 3 of the License, or (at your option)
 
#  any later version.
 
#
 
#  Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
 
#  WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 
#  FOR A PARTICULAR PURPOSE.  See the GNU Affero General Public License for
 
#  more details.
 
#
 
#  You should have received a copy of the GNU Affero General Public License
 
#  along with Rattail.  If not, see <http://www.gnu.org/licenses/>.
 
#
 
################################################################################
 
"""
 
Rattail -> Rattail bulk data import
 
"""
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from rattail import importing
 
from rattail.util import OrderedDict
 
from rattail.importing.rattail import FromRattailToRattail, FromRattail
 

	
 

	
 
class BulkFromRattailToRattail(FromRattailToRattail, importing.BulkToPostgreSQLHandler):
 
class BulkFromRattailToRattail(FromRattailToRattail, importing.BulkImportHandler):
 
    """
 
    Handler for Rattail -> Rattail bulk data import.
 
    """
 

	
 
    def get_importers(self):
        """
        Return an ordered mapping of model key to importer class, covering
        every model handled by this bulk import.

        The mapping is built in a fixed, explicit order.
        """
        # NOTE(review): the order presumably matters (e.g. so that records
        # which others refer to are imported first) -- confirm before
        # rearranging.
        importers = OrderedDict()
        importers['Person'] = PersonImporter
        importers['PersonEmailAddress'] = PersonEmailAddressImporter
        importers['PersonPhoneNumber'] = PersonPhoneNumberImporter
        importers['PersonMailingAddress'] = PersonMailingAddressImporter
        importers['User'] = UserImporter
        importers['Message'] = MessageImporter
        importers['MessageRecipient'] = MessageRecipientImporter
        importers['Store'] = StoreImporter
        importers['StorePhoneNumber'] = StorePhoneNumberImporter
        importers['Employee'] = EmployeeImporter
        importers['EmployeeStore'] = EmployeeStoreImporter
        importers['EmployeeEmailAddress'] = EmployeeEmailAddressImporter
        importers['EmployeePhoneNumber'] = EmployeePhoneNumberImporter
        importers['ScheduledShift'] = ScheduledShiftImporter
        importers['WorkedShift'] = WorkedShiftImporter
        importers['Customer'] = CustomerImporter
        importers['CustomerGroup'] = CustomerGroupImporter
        importers['CustomerGroupAssignment'] = CustomerGroupAssignmentImporter
        importers['CustomerPerson'] = CustomerPersonImporter
        importers['CustomerEmailAddress'] = CustomerEmailAddressImporter
        importers['CustomerPhoneNumber'] = CustomerPhoneNumberImporter
        importers['Vendor'] = VendorImporter
        importers['VendorEmailAddress'] = VendorEmailAddressImporter
        importers['VendorPhoneNumber'] = VendorPhoneNumberImporter
        importers['VendorContact'] = VendorContactImporter
        importers['Department'] = DepartmentImporter
        importers['EmployeeDepartment'] = EmployeeDepartmentImporter
        importers['Subdepartment'] = SubdepartmentImporter
        importers['Category'] = CategoryImporter
        importers['Family'] = FamilyImporter
        importers['ReportCode'] = ReportCodeImporter
        importers['DepositLink'] = DepositLinkImporter
        importers['Tax'] = TaxImporter
        importers['Brand'] = BrandImporter
        importers['Product'] = ProductImporter
        importers['ProductCode'] = ProductCodeImporter
        importers['ProductCost'] = ProductCostImporter
        importers['ProductPrice'] = ProductPriceImporter
        return importers
 

	
 

	
 
class BulkFromRattail(FromRattail, importing.BulkToPostgreSQL):
rattail/tests/importing/test_handlers.py
Show inline comments
 
# -*- coding: utf-8 -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
import datetime
 
import unittest
 

	
 
import pytz
 
from sqlalchemy import orm
 
from mock import patch, Mock
 
from fixture import TempIO
 

	
 
from rattail.db import Session
 
from rattail.importing import handlers, Importer
 
from rattail.config import RattailConfig
 
from rattail.tests import RattailTestCase
 
from rattail.tests.importing import ImporterTester
 
from rattail.tests.importing.test_importers import MockImporter
 
from rattail.tests.importing.test_postgresql import MockBulkImporter
 

	
 

	
 
class TestImportHandler(unittest.TestCase):
 
class ImportHandlerBattery(ImporterTester):
 

	
 
    def test_init(self):
 

	
 
        # vanilla
 
        handler = handlers.ImportHandler()
 
        handler = self.handler_class()
 
        self.assertEqual(handler.importers, {})
 
        self.assertEqual(handler.get_importers(), {})
 
        self.assertEqual(handler.get_importer_keys(), [])
 
        self.assertEqual(handler.get_default_keys(), [])
 
        self.assertFalse(handler.commit_host_partial)
 

	
 
        # with config
 
        handler = handlers.ImportHandler()
 
        handler = self.handler_class()
 
        self.assertIsNone(handler.config)
 
        config = RattailConfig()
 
        handler = handlers.ImportHandler(config=config)
 
        handler = self.handler_class(config=config)
 
        self.assertIs(handler.config, config)
 

	
 
        # dry run
 
        handler = handlers.ImportHandler()
 
        handler = self.handler_class()
 
        self.assertFalse(handler.dry_run)
 
        handler = handlers.ImportHandler(dry_run=True)
 
        handler = self.handler_class(dry_run=True)
 
        self.assertTrue(handler.dry_run)
 

	
 
        # extra kwarg
 
        handler = handlers.ImportHandler()
 
        handler = self.handler_class()
 
        self.assertRaises(AttributeError, getattr, handler, 'foo')
 
        handler = handlers.ImportHandler(foo='bar')
 
        handler = self.handler_class(foo='bar')
 
        self.assertEqual(handler.foo, 'bar')
 

	
 
    def test_get_importer(self):
 
        get_importers = Mock(return_value={'foo': Importer})
 

	
 
        # no importers
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        self.assertIsNone(handler.get_importer('foo'))
 

	
 
        # no config
 
        with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
 
            handler = handlers.ImportHandler()
 
        with patch.object(self.handler_class, 'get_importers', get_importers):
 
            handler = self.handler_class()
 
        importer = handler.get_importer('foo')
 
        self.assertIs(type(importer), Importer)
 
        self.assertIsNone(importer.config)
 
        self.assertIs(importer.handler, handler)
 

	
 
        # with config
 
        config = RattailConfig()
 
        with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
 
            handler = handlers.ImportHandler(config=config)
 
        with patch.object(self.handler_class, 'get_importers', get_importers):
 
            handler = self.handler_class(config=config)
 
        importer = handler.get_importer('foo')
 
        self.assertIs(type(importer), Importer)
 
        self.assertIs(importer.config, config)
 
        self.assertIs(importer.handler, handler)
 

	
 
        # dry run
 
        with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
 
            handler = handlers.ImportHandler()
 
        with patch.object(self.handler_class, 'get_importers', get_importers):
 
            handler = self.handler_class()
 
        importer = handler.get_importer('foo')
 
        self.assertFalse(importer.dry_run)
 
        with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
 
            handler = handlers.ImportHandler(dry_run=True)
 
        with patch.object(self.handler_class, 'get_importers', get_importers):
 
            handler = self.handler_class(dry_run=True)
 
        importer = handler.get_importer('foo')
 
        self.assertTrue(handler.dry_run)
 

	
 
        # host title
 
        with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
 
            handler = handlers.ImportHandler()
 
        with patch.object(self.handler_class, 'get_importers', get_importers):
 
            handler = self.handler_class()
 
        importer = handler.get_importer('foo')
 
        self.assertIsNone(importer.host_system_title)
 
        handler.host_title = "Foo"
 
        importer = handler.get_importer('foo')
 
        self.assertEqual(importer.host_system_title, "Foo")
 

	
 
        # extra kwarg
 
        with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
 
            handler = handlers.ImportHandler()
 
        with patch.object(self.handler_class, 'get_importers', get_importers):
 
            handler = self.handler_class()
 
        importer = handler.get_importer('foo')
 
        self.assertRaises(AttributeError, getattr, importer, 'bar')
 
        importer = handler.get_importer('foo', bar='baz')
 
        self.assertEqual(importer.bar, 'baz')
 

	
 
    def test_get_importer_kwargs(self):
 

	
 
        # empty by default
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        self.assertEqual(handler.get_importer_kwargs('foo'), {})
 

	
 
        # extra kwargs are preserved
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        self.assertEqual(handler.get_importer_kwargs('foo', bar='baz'), {'bar': 'baz'})
 

	
 
    def test_begin_transaction(self):
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        with patch.object(handler, 'begin_host_transaction') as begin_host:
 
            with patch.object(handler, 'begin_local_transaction') as begin_local:
 
                handler.begin_transaction()
 
                begin_host.assert_called_once_with()
 
                begin_local.assert_called_once_with()
 

	
 
    def test_begin_host_transaction(self):
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        handler.begin_host_transaction()
 

	
 
    def test_begin_local_transaction(self):
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        handler.begin_local_transaction()
 

	
 
    def test_commit_transaction(self):
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        with patch.object(handler, 'commit_host_transaction') as commit_host:
 
            with patch.object(handler, 'commit_local_transaction') as commit_local:
 
                handler.commit_transaction()
 
                commit_host.assert_called_once_with()
 
                commit_local.assert_called_once_with()
 

	
 
    def test_commit_host_transaction(self):
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        handler.commit_host_transaction()
 

	
 
    def test_commit_local_transaction(self):
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        handler.commit_local_transaction()
 

	
 
    def test_rollback_transaction(self):
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        with patch.object(handler, 'rollback_host_transaction') as rollback_host:
 
            with patch.object(handler, 'rollback_local_transaction') as rollback_local:
 
                handler.rollback_transaction()
 
                rollback_host.assert_called_once_with()
 
                rollback_local.assert_called_once_with()
 

	
 
    def test_rollback_host_transaction(self):
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        handler.rollback_host_transaction()
 

	
 
    def test_rollback_local_transaction(self):
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        handler.rollback_local_transaction()
 

	
 
    def test_import_data(self):
 

	
 
        # normal
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        result = handler.import_data()
 
        self.assertEqual(result, {})
 

	
 
    def test_import_data_dry_run(self):
 

	
 
        # as init kwarg
 
        handler = handlers.ImportHandler(dry_run=True)
 
        handler = self.make_handler(dry_run=True)
 
        with patch.object(handler, 'commit_transaction') as commit:
 
            with patch.object(handler, 'rollback_transaction') as rollback:
 
                handler.import_data()
 
                self.assertFalse(commit.called)
 
                rollback.assert_called_once_with()
 
        self.assertTrue(handler.dry_run)
 

	
 
        # as import kwarg
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        with patch.object(handler, 'commit_transaction') as commit:
 
            with patch.object(handler, 'rollback_transaction') as rollback:
 
                handler.import_data(dry_run=True)
 
                self.assertFalse(commit.called)
 
                rollback.assert_called_once_with()
 
        self.assertTrue(handler.dry_run)
 

	
 
    def test_import_data_invalid_model(self):
 
        handler = self.make_handler()
 
        importer = Mock()
 
        importer.import_data.return_value = [], [], []
 
        FooImporter = Mock(return_value=importer)
 

	
 
        handler = handlers.ImportHandler()
 
        handler.importers = {'Foo': FooImporter}
 

	
 
        handler.import_data('Foo')
 
        self.assertEqual(FooImporter.call_count, 1)
 
        importer.import_data.assert_called_once_with()
 

	
 
        FooImporter.reset_mock()
 
        importer.reset_mock()
 

	
 
        handler.import_data('Missing')
 
        self.assertFalse(FooImporter.called)
 
        self.assertFalse(importer.called)
 

	
 
    def test_import_data_with_changes(self):
 
        handler = self.make_handler()
 
        importer = Mock()
 
        FooImporter = Mock(return_value=importer)
 

	
 
        handler = handlers.ImportHandler()
 
        handler.importers = {'Foo': FooImporter}
 

	
 
        importer.import_data.return_value = [], [], []
 
        with patch.object(handler, 'process_changes') as process:
 
            handler.import_data('Foo')
 
            self.assertFalse(process.called)
 

	
 
        importer.import_data.return_value = [1], [2], [3]
 
        with patch.object(handler, 'process_changes') as process:
 
            handler.import_data('Foo')
 
            process.assert_called_once_with({'Foo': ([1], [2], [3])})
 

	
 
    def test_import_data_commit_host_partial(self):
 
        handler = self.make_handler()
 
        importer = Mock()
 
        importer.import_data.side_effect = ValueError
 
        FooImporter = Mock(return_value=importer)
 

	
 
        handler = handlers.ImportHandler()
 
        handler.importers = {'Foo': FooImporter}
 

	
 
        handler.commit_host_partial = False
 
        with patch.object(handler, 'commit_host_transaction') as commit:
 
            self.assertRaises(ValueError, handler.import_data, 'Foo')
 
            self.assertFalse(commit.called)
 

	
 
        handler.commit_host_partial = True
 
        with patch.object(handler, 'commit_host_transaction') as commit:
 
            self.assertRaises(ValueError, handler.import_data, 'Foo')
 
            commit.assert_called_once_with()
 

	
 

	
 
class BulkImportHandlerBattery(ImportHandlerBattery):
    """
    Variant of the handler battery in which the mocked importer returns an
    integer from ``import_data()`` instead of three lists.
    """

    def test_import_data_invalid_model(self):
        """
        A known key instantiates and runs its importer; an unknown key
        touches neither the importer factory nor the importer.
        """
        handler = self.make_handler()
        fake_importer = Mock()
        fake_importer.import_data.return_value = 0
        importer_factory = Mock(return_value=fake_importer)
        handler.importers = {'Foo': importer_factory}

        # known key
        handler.import_data('Foo')
        self.assertEqual(importer_factory.call_count, 1)
        fake_importer.import_data.assert_called_once_with()

        # start over with clean call records
        importer_factory.reset_mock()
        fake_importer.reset_mock()

        # unknown key
        handler.import_data('Missing')
        self.assertFalse(importer_factory.called)
        self.assertFalse(fake_importer.called)

    def test_import_data_with_changes(self):
        """
        ``process_changes()`` is never invoked, whatever integer the importer
        returns.
        """
        handler = self.make_handler()
        fake_importer = Mock()
        importer_factory = Mock(return_value=fake_importer)
        handler.importers = {'Foo': importer_factory}

        for returned in (0, 3):
            fake_importer.import_data.return_value = returned
            with patch.object(handler, 'process_changes') as process:
                handler.import_data('Foo')
                self.assertFalse(process.called)
 

	
 

	
 
class TestImportHandler(unittest.TestCase, ImportHandlerBattery):
    """
    Runs the inherited ``ImportHandlerBattery`` against
    ``handlers.ImportHandler``, plus tests specific to that class.
    """
    handler_class = handlers.ImportHandler

    @patch('rattail.importing.handlers.send_email')
    def test_process_changes_sends_email(self, send_email):
        """
        ``process_changes()`` calls ``send_email`` only when ``warnings`` is
        enabled on the handler.
        """
        handler = handlers.ImportHandler()
        handler.import_began = pytz.utc.localize(datetime.datetime.utcnow())
        changes = [], [], []

        # warnings disabled
        handler.warnings = False
        handler.process_changes(changes)
        self.assertFalse(send_email.called)

        # warnings enabled
        handler.warnings = True
        handler.process_changes(changes)
        self.assertEqual(send_email.call_count, 1)

        send_email.reset_mock()

        # warnings enabled, with command (just for coverage..)
        # NOTE: ``name`` and ``parent`` are reserved keyword arguments of the
        # ``Mock`` constructor (they configure the mock's repr and internal
        # parent, not attributes), so ``Mock(name='x').name`` is just another
        # auto-created mock.  Assign the attributes explicitly instead, so the
        # command really carries these string names.
        command = Mock()
        command.name = 'import-testing'
        command.parent = Mock()
        command.parent.name = 'rattail'
        handler.warnings = True
        handler.command = command
        handler.process_changes(changes)
        self.assertEqual(send_email.call_count, 1)
 

	
 

	
 
class TestBulkImportHandler(unittest.TestCase, BulkImportHandlerBattery):
    """
    Runs the inherited ``BulkImportHandlerBattery`` tests, with
    ``handler_class`` pointing at ``handlers.BulkImportHandler``.
    """
    handler_class = handlers.BulkImportHandler
 

	
 

	
 
######################################################################
 
# fake import handler, tested mostly for basic coverage
 
######################################################################
 

	
 
class MockImportHandler(handlers.ImportHandler):
    """
    Import handler which exposes a single 'Product' importer and keeps a
    reference to the result of its most recent ``import_data()`` call.
    """

    def get_importers(self):
        """
        Return the importer mapping for this handler.
        """
        importers = {'Product': MockImporter}
        return importers

    def import_data(self, *keys, **kwargs):
        """
        Delegate to the parent implementation, stash its return value on
        ``self._result`` and return it unchanged.
        """
        parent = super(MockImportHandler, self)
        self._result = parent.import_data(*keys, **kwargs)
        return self._result
 

	
 

	
 
class TestImportHandlerImportData(ImporterTester, unittest.TestCase):
 

	
 
    sample_data = {
 
        '16oz': {'upc': '00074305001161', 'description': "Apple Cider Vinegar 16oz"},
 
        '32oz': {'upc': '00074305001321', 'description': "Apple Cider Vinegar 32oz"},
 
        '1gal': {'upc': '00074305011283', 'description': "Apple Cider Vinegar 1gal"},
 
    }
 

	
 
    def setUp(self):
        """
        Create one config object and share it between a fresh mock handler
        and a fresh mock importer.
        """
        config = RattailConfig()
        self.config = config
        self.handler = MockImportHandler(config=config)
        self.importer = MockImporter(config=config)
 

	
 
    def import_data(self, **kwargs):
        """
        Run the handler's 'Product' import and store the outcome on
        ``self.result``; an empty/falsy handler result is stored as three
        empty lists.
        """
        # must modify our importer in-place since we need the handler to return
        # that specific instance, below (because the host/local data context
        # managers reference that instance directly)
        self.importer._setup(**kwargs)
        get_importer = Mock(return_value=self.importer)
        with patch.object(self.handler, 'get_importer', get_importer):
            outcome = self.handler.import_data('Product', **kwargs)
        self.result = outcome['Product'] if outcome else ([], [], [])
 

	
 
    def test_invalid_importer_key_is_ignored(self):
        """
        Importing a key the handler has no importer for yields an empty dict.
        """
        plain_handler = handlers.ImportHandler()
        self.assertNotIn('InvalidKey', plain_handler.importers)
        outcome = plain_handler.import_data('InvalidKey')
        self.assertEqual(outcome, {})
 

	
 
    def test_create(self):
 
        local = self.copy_data()
 
        del local['32oz']
 
@@ -521,96 +561,96 @@ class TestToSQLAlchemyHandler(unittest.TestCase):
 
        handler = handlers.ToSQLAlchemyHandler(session=session)
 
        kwargs = handler.get_importer_kwargs(None)
 
        self.assertEqual(list(kwargs.iterkeys()), ['session'])
 
        self.assertIs(kwargs['session'], session)
 

	
 
    def test_begin_local_transaction(self):
        """
        The handler starts without a session and has an ``orm.Session`` once
        the local transaction has begun.
        """
        handler = MockToSQLAlchemyHandler()
        self.assertIsNone(handler.session)

        handler.begin_local_transaction()
        opened = handler.session
        self.assertIsInstance(opened, orm.Session)
        opened.close()
 

	
 
    def test_commit_local_transaction(self):
        """
        Committing the local transaction calls ``commit()`` on the handler's
        session, and never ``rollback()``.
        """
        # TODO: test actual commit for data changes
        real_session = Session()
        handler = handlers.ToSQLAlchemyHandler(session=real_session)
        self.assertIs(handler.session, real_session)

        # swap in a mock session so that the calls can be inspected
        with patch.object(handler, 'session') as fake_session:
            handler.commit_local_transaction()
            fake_session.commit.assert_called_once_with()
            self.assertFalse(fake_session.rollback.called)
        # self.assertIsNone(handler.session)
 

	
 
    def test_rollback_local_transaction(self):
        """
        Rolling back the local transaction calls ``rollback()`` on the
        handler's session, and never ``commit()``.
        """
        # TODO: test actual rollback for data changes
        real_session = Session()
        handler = handlers.ToSQLAlchemyHandler(session=real_session)
        self.assertIs(handler.session, real_session)

        # swap in a mock session so that the calls can be inspected
        with patch.object(handler, 'session') as fake_session:
            handler.rollback_local_transaction()
            fake_session.rollback.assert_called_once_with()
            self.assertFalse(fake_session.commit.called)
        # self.assertIsNone(handler.session)
 

	
 

	
 
######################################################################
 
# fake bulk import handler, tested mostly for basic coverage
 
######################################################################
 

	
 
class MockBulkImportHandler(handlers.BulkToPostgreSQLHandler):
    """
    Bulk handler which exposes a single 'Department' importer and creates its
    session from the ``Session`` factory available in this module.
    """

    def get_importers(self):
        """
        Return the importer mapping: only 'Department' is offered.
        """
        return {'Department': MockBulkImporter}

    def make_session(self):
        """
        Return a new ``Session()`` instance.
        """
        return Session()
 

	
 

	
 
class TestBulkImportHandler(RattailTestCase, ImporterTester):
 
class TestBulkImportHandlerOld(RattailTestCase, ImporterTester):
    """
    Drives ``MockBulkImportHandler`` with sample department data; the import
    itself is only attempted when ``self.postgresql()`` is true.
    """

    importer_class = MockBulkImporter

    sample_data = {
        'grocery': {'number': 1, 'name': "Grocery", 'uuid': 'decd909a194011e688093ca9f40bc550'},
        'bulk': {'number': 2, 'name': "Bulk", 'uuid': 'e633d54c194011e687e33ca9f40bc550'},
        'hba': {'number': 3, 'name': "HBA", 'uuid': 'e2bad79e194011e6a4783ca9f40bc550'},
    }

    def setUp(self):
        """
        Prepare the rattail fixtures, point the configured workdir at a
        temporary location, and create the handler under test.
        """
        self.setup_rattail()
        self.tempio = TempIO()
        workdir = self.tempio.realpath()
        self.config.set('rattail', 'workdir', workdir)
        self.handler = MockBulkImportHandler(config=self.config)

    def tearDown(self):
        """
        Tear down the rattail fixtures, then drop the temp dir reference.
        """
        self.teardown_rattail()
        self.tempio = None

    def import_data(self, host_data=None, **kwargs):
        """
        Run the 'Department' import through the handler and return its result.

        ``host_data`` defaults to the values of a copy of the sample data.
        The importer's ``normalize_host_data`` and the handler's
        ``make_session`` are patched so that this data, and the test's own
        session, are what the import sees.
        """
        if host_data is None:
            host_data = list(self.copy_data().itervalues())
        fake_normalize = patch.object(self.importer_class, 'normalize_host_data',
                                      Mock(return_value=host_data))
        with fake_normalize:
            fake_make_session = patch.object(self.handler, 'make_session',
                                             Mock(return_value=self.session))
            with fake_make_session:
                return self.handler.import_data('Department', **kwargs)

    def test_invalid_importer_key_is_ignored(self):
        """
        Importing a key the handler has no importer for yields an empty dict.
        """
        bulk_handler = MockBulkImportHandler()
        self.assertNotIn('InvalidKey', bulk_handler.importers)
        outcome = bulk_handler.import_data('InvalidKey')
        self.assertEqual(outcome, {})

    def assert_import_created(self, *keys):
        """
        Deliberately a no-op in this test case.
        """

    def assert_import_updated(self, *keys):
        """
        Deliberately a no-op in this test case.
        """

    def assert_import_deleted(self, *keys):
        """
        Deliberately a no-op in this test case.
        """

    def test_normal_run(self):
        """
        Plain import; does nothing unless ``self.postgresql()`` is true.
        """
        if not self.postgresql():
            return
        self.import_data()

    def test_dry_run(self):
        """
        Dry-run import; does nothing unless ``self.postgresql()`` is true.
        """
        if not self.postgresql():
            return
        self.import_data(dry_run=True)
rattail/tests/importing/test_importers.py
Show inline comments
 
# -*- coding: utf-8 -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from unittest import TestCase
 

	
 
from mock import Mock, patch
 
from mock import Mock, patch, call
 

	
 
from rattail.db import model
 
from rattail.db.util import QuerySequence
 
from rattail.importing import importers
 
from rattail.tests import NullProgress, RattailTestCase
 
from rattail.tests.importing import ImporterTester
 

	
 

	
 
class ImporterBattery(ImporterTester):
    """
    Battery of tests which can hopefully be run for any non-bulk importer.
    """

    def test_import_data_empty(self):
        """
        With no arguments, ``import_data()`` returns an empty dict.
        """
        outcome = self.make_importer().import_data()
        self.assertEqual(outcome, {})

    def test_import_data_dry_run(self):
        """
        Passing ``dry_run=True`` to ``import_data()`` sets the importer's
        ``dry_run`` attribute.
        """
        subject = self.make_importer()
        self.assertFalse(subject.dry_run)
        subject.import_data(dry_run=True)
        self.assertTrue(subject.dry_run)

    def test_import_data_create(self):
        """
        Each host record leads to one ``create_object(key, data)`` call.
        """
        def identity(k):
            return k

        subject = self.make_importer()
        expected_calls = [call(1, 1), call(2, 2), call(3, 3)]
        with patch.object(subject, 'get_key', identity):
            with patch.object(subject, 'create_object') as create_object:
                subject.import_data(host_data=[1, 2, 3])
                self.assertEqual(create_object.call_args_list, expected_calls)

    def test_import_data_max_create(self):
        """
        ``max_create=1`` stops object creation after the first record.
        """
        def identity(k):
            return k

        subject = self.make_importer()
        expected_calls = [call(1, 1)]
        with patch.object(subject, 'get_key', identity):
            with patch.object(subject, 'create_object') as create_object:
                subject.import_data(host_data=[1, 2, 3], max_create=1)
                self.assertEqual(create_object.call_args_list, expected_calls)
 

	
 

	
 
class BulkImporterBattery(ImporterBattery):
    """
    Battery of tests which can hopefully be run for any bulk importer.
    """

    def test_import_data_empty(self):
        """
        With no arguments, ``import_data()`` returns zero rather than an
        empty dict.
        """
        outcome = self.make_importer().import_data()
        self.assertEqual(outcome, 0)
 

	
 

	
 
class TestImporter(TestCase):
 

	
 
    def test_init(self):
        """
        Check constructor defaults, key/fields validation, and pass-through
        of arbitrary extra keyword arguments.
        """
        # defaults
        bare = importers.Importer()
        for attr in ('model_class', 'model_name', 'key'):
            self.assertIsNone(getattr(bare, attr))
        self.assertEqual(bare.fields, [])
        self.assertIsNone(bare.host_system_title)

        # key must be included among the fields
        with self.assertRaises(ValueError):
            importers.Importer(key='upc', fields=[])
        keyed = importers.Importer(key='upc', fields=['upc'])
        self.assertEqual(keyed.key, ('upc',))
        self.assertEqual(keyed.fields, ['upc'])

        # extra bits are passed as-is
        without_extra = importers.Importer()
        self.assertFalse(hasattr(without_extra, 'extra_bit'))
        marker = object()
        with_extra = importers.Importer(extra_bit=marker)
        self.assertIs(with_extra.extra_bit, marker)
 

	
 
    def test_delete_flag(self):
        """
        ``delete`` is off unless requested, either at construction or via
        ``import_data()``; ``allow_delete`` is on in both cases checked.
        """
        # disabled by default
        default = importers.Importer()
        self.assertTrue(default.allow_delete)
        self.assertFalse(default.delete)
        default.import_data(host_data=[])
        self.assertFalse(default.delete)

        # but can be enabled: via the constructor..
        at_init = importers.Importer(delete=True)
        self.assertTrue(at_init.allow_delete)
        self.assertTrue(at_init.delete)

        # ..or via import_data()
        at_import = importers.Importer()
        self.assertFalse(at_import.delete)
        at_import.import_data(host_data=[], delete=True)
        self.assertTrue(at_import.delete)
 

	
 
    def test_get_host_objects(self):
        """
        The base importer has no host objects: an empty list is returned.
        """
        found = importers.Importer().get_host_objects()
        self.assertEqual(found, [])
 

	
 
    def test_cache_local_data(self):
        """
        The base importer does not implement ``cache_local_data()``.
        """
        subject = importers.Importer()
        with self.assertRaises(NotImplementedError):
            subject.cache_local_data()
 
@@ -119,96 +162,101 @@ class TestImporter(TestCase):
 
        importer = importers.Importer(key='upc', fields=['upc', 'description'],
 
                                          progress=NullProgress)
 

	
 
        data = [
 
            {'upc': '00074305001161', 'description': "Apple Cider Vinegar 16oz"},
 
            {'upc': '00074305001321', 'description': "Apple Cider Vinegar 32oz"},
 
        ]
 

	
 
        host_data = importer.normalize_host_data(host_objects=[])
 
        self.assertEqual(host_data, [])
 

	
 
        host_data = importer.normalize_host_data(host_objects=data)
 
        self.assertEqual(host_data, data)
 

	
 
        with patch.object(importer, 'get_host_objects', new=Mock(return_value=data)):
 
            host_data = importer.normalize_host_data()
 
        self.assertEqual(host_data, data)
 

	
 
    def test_get_deletion_keys(self):
        """
        Deletion keys are empty without cached local data, and otherwise are
        the keys of that cache.
        """
        subject = importers.Importer()

        # local data caching is off
        self.assertFalse(subject.caches_local_data)
        self.assertEqual(subject.get_deletion_keys(), set())

        # caching is on, but nothing has been cached yet
        subject.caches_local_data = True
        self.assertIsNone(subject.cached_local_data)
        self.assertEqual(subject.get_deletion_keys(), set())

        # caching is on, with one cached record
        subject.cached_local_data = {'delete-me': object()}
        self.assertEqual(subject.get_deletion_keys(), {'delete-me'})
 

	
 

	
 
class TestFromQuery(RattailTestCase):
    """
    Tests for the ``FromQuery`` importer base class.
    """

    def test_query(self):
        """
        ``query()`` is not implemented by the base class.
        """
        subject = importers.FromQuery()
        with self.assertRaises(NotImplementedError):
            subject.query()

    def test_get_host_objects(self):
        """
        Host objects come back wrapped in a ``QuerySequence``.
        """
        subject = importers.FromQuery()
        product_query = self.session.query(model.Product)
        fake_query = Mock(return_value=product_query)
        with patch.object(subject, 'query', fake_query):
            found = subject.get_host_objects()
        self.assertIsInstance(found, QuerySequence)
 

	
 

	
 
class TestBulkImporter(TestCase, BulkImporterBattery):
    """
    Runs the inherited ``BulkImporterBattery`` tests, with ``importer_class``
    pointing at ``importers.BulkImporter``.
    """
    importer_class = importers.BulkImporter
 
        
 

	
 

	
 
######################################################################
 
# fake importer class, tested mostly for basic coverage
 
######################################################################
 

	
 
class Product(object):
    """
    Minimal stand-in model with just two attributes, both defaulting to
    ``None``; used as the ``model_class`` of the fake importer below.
    """
    upc = None
    description = None
 

	
 

	
 
class MockImporter(importers.Importer):
    """
    Fake importer for the stand-in ``Product`` class, keyed on ``upc``.
    """
    model_class = Product
    key = 'upc'
    simple_fields = ['upc', 'description']
    # NOTE: same list object as ``simple_fields``, not a copy
    supported_fields = simple_fields
    caches_local_data = True
    flush_every_x = 1
    # NOTE: class-level mock, hence shared by every instance of this importer
    session = Mock()

    def normalize_local_object(self, obj):
        """
        Return the local object untouched, i.e. no normalization is done.
        """
        return obj

    def update_object(self, obj, host_data, local_data=None):
        """
        Return ``host_data`` as-is; ``obj`` and ``local_data`` are ignored.
        """
        return host_data
 

	
 

	
 
class TestMockImporter(ImporterTester, TestCase):
 
    importer_class = MockImporter
 

	
 
    sample_data = {
 
        '16oz': {'upc': '00074305001161', 'description': "Apple Cider Vinegar 16oz"},
 
        '32oz': {'upc': '00074305001321', 'description': "Apple Cider Vinegar 32oz"},
 
        '1gal': {'upc': '00074305011283', 'description': "Apple Cider Vinegar 1gal"},
 
    }
 

	
 
    def setUp(self):
        """
        Create a fresh importer for each test, via ``make_importer()``.
        """
        self.importer = self.make_importer()
 

	
 
    def test_create(self):
        """
        A record missing from the local data is reported as created, with
        nothing updated or deleted.
        """
        local_records = self.copy_data()
        del local_records['32oz']
        self.import_data(local_data=local_records)

        self.assert_import_created('32oz')
        self.assert_import_updated()
        self.assert_import_deleted()
 

	
 
    def test_create_empty(self):
 
        self.import_data(host_data={}, local_data={})
 
        self.assert_import_created()
0 comments (0 inline, 0 general)