Changeset - 22e10accf5f0
[Not reviewed]
0 10 0
Lance Edgar - 8 years ago 2016-05-17 15:02:08
ledgar@sacfoodcoop.com
Add "batch" support to importer framework

Doesn't really split things up per se, just controls how often a "flush
changes" should occur.
10 files changed with 294 insertions and 147 deletions:
0 comments (0 inline, 0 general)
rattail/commands/importing.py
Show inline comments
 
@@ -41,12 +41,14 @@ class ImportSubcommand(Subcommand):
 
    """
 
    handler_spec = None
 

	
 
    # TODO: move this into Subcommand or something..
 
    parent_name = None
 
    def __init__(self, *args, **kwargs):
 
        if 'handler_spec' in kwargs:
 
            self.handler_spec = kwargs.pop('handler_spec')
 
        super(ImportSubcommand, self).__init__(*args, **kwargs)
 
        if self.parent:
 
            self.parent_name = self.parent.name
 

	
 
    def get_handler_factory(self):
 
        """
 
@@ -65,12 +67,14 @@ class ImportSubcommand(Subcommand):
 
        kwargs.setdefault('config', getattr(self, 'config', None))
 
        kwargs.setdefault('command', self)
 
        kwargs.setdefault('progress', self.progress)
 
        if 'args' in kwargs:
 
            args = kwargs['args']
 
            kwargs.setdefault('dry_run', args.dry_run)
 
            if hasattr(args, 'batch_size'):
 
                kwargs.setdefault('batch_size', args.batch_size)
 
            # kwargs.setdefault('max_create', args.max_create)
 
            # kwargs.setdefault('max_update', args.max_update)
 
            # kwargs.setdefault('max_delete', args.max_delete)
 
            # kwargs.setdefault('max_total', args.max_total)
 
        kwargs = self.get_handler_kwargs(**kwargs)
 
        return factory(**kwargs)
 
@@ -136,12 +140,19 @@ class ImportSubcommand(Subcommand):
 
        # max total changes, per model
 
        parser.add_argument('--max-total', type=int, metavar='COUNT',
 
                            help="Maximum number of *any* record changes which may occur, after which "
 
                            "a given import task should stop.  Note that this applies on a per-model "
 
                            "basis and not overall.")
 

	
 
        # batch size
 
        parser.add_argument('--batch', type=int, dest='batch_size', metavar='SIZE', default=200,
 
                            help="Split work to be done into batches, with the specified number of "
 
                            "records in each batch.  Or, set this to 0 (zero) to disable batching. "
 
                            "Implementation for this may vary somewhat between importers; default "
 
                            "batch size is 200 records.")
 

	
 
        # treat changes as warnings?
 
        parser.add_argument('--warnings', '-W', action='store_true',
 
                            help="Set this flag if you expect a \"clean\" import, and wish for any "
 
                            "changes which do occur to be processed further and/or specially.  The "
 
                            "behavior of this flag is ultimately up to the import handler, but the "
 
                            "default is to send an email notification.")
rattail/importing/handlers.py
Show inline comments
 
@@ -88,12 +88,14 @@ class ImportHandler(object):
 
        Returns an importer instance corresponding to the given key.
 
        """
 
        if key in self.importers:
 
            kwargs.setdefault('handler', self)
 
            kwargs.setdefault('config', self.config)
 
            kwargs.setdefault('host_system_title', self.host_title)
 
            if hasattr(self, 'batch_size'):
 
                kwargs.setdefault('batch_size', self.batch_size)
 
            kwargs = self.get_importer_kwargs(key, **kwargs)
 
            return self.importers[key](**kwargs)
 

	
 
    def get_importer_kwargs(self, key, **kwargs):
 
        """
 
        Return a dict of kwargs to be used when constructing an importer with
 
@@ -336,12 +338,6 @@ class ToSQLAlchemyHandler(ImportHandler):
 
        self.session = None
 

	
 
    def commit_local_transaction(self):
 
        self.session.commit()
 
        self.session.close()
 
        self.session = None
 

	
 

	
 
class BulkToPostgreSQLHandler(BulkImportHandler):
 
    """
 
    Handler for bulk imports which target PostgreSQL on the local side.
 
    """
rattail/importing/importers.py
Show inline comments
 
@@ -39,12 +39,13 @@ log = logging.getLogger(__name__)
 
class Importer(object):
 
    """
 
    Base class for all data importers.
 
    """
 
    # Set this to the data model class which is targeted on the local side.
 
    model_class = None
 
    model_name = None
 

	
 
    key = None
 

	
 
    # The full list of field names supported by the importer, i.e. for the data
 
    # model to which the importer pertains.  By definition this list will be
 
    # restricted to what the local side can accommodate, but may be further
 
@@ -62,12 +63,13 @@ class Importer(object):
 
    dry_run = False
 

	
 
    max_create = None
 
    max_update = None
 
    max_delete = None
 
    max_total = None
 
    batch_size = 200
 
    progress = None
 

	
 
    empty_local_data = False
 
    caches_local_data = False
 
    cached_local_data = None
 

	
 
@@ -83,26 +85,25 @@ class Importer(object):
 
            self.key = (self.key,)
 
        if self.key:
 
            for field in self.key:
 
                if field not in self.fields:
 
                    raise ValueError("Key field '{}' must be included in effective fields "
 
                                     "for {}".format(field, self.__class__.__name__))
 
        self.model_class = kwargs.pop('model_class', self.model_class)
 
        self.model_name = kwargs.pop('model_name', self.model_name)
 
        if not self.model_name and self.model_class:
 
            self.model_name = self.model_class.__name__
 
        self._setup(**kwargs)
 

	
 
    def _setup(self, **kwargs):
 
        self.create = kwargs.pop('create', self.allow_create) and self.allow_create
 
        self.update = kwargs.pop('update', self.allow_update) and self.allow_update
 
        self.delete = kwargs.pop('delete', False) and self.allow_delete
 
        for key, value in kwargs.iteritems():
 
            setattr(self, key, value)
 

	
 
    @property
 
    def model_name(self):
 
        if self.model_class:
 
            return self.model_class.__name__
 

	
 
    def setup(self):
 
        """
 
        Perform any setup necessary, e.g. cache lookups for existing data.
 
        """
 

	
 
    def teardown(self):
 
@@ -205,31 +206,35 @@ class Importer(object):
 
                    log.warning("max of {} *created* records has been reached; stopping now".format(self.max_create))
 
                    break
 
                if self.max_total and (len(created) + len(updated)) >= self.max_total:
 
                    log.warning("max of {} *total changes* has been reached; stopping now".format(self.max_total))
 
                    break
 

	
 
            self.flush_changes(i)
 
            # # TODO: this needs to be customizable etc. somehow maybe..
 
            # if i % 100 == 0 and hasattr(self, 'session'):
 
            #     self.session.flush()
 
            # flush changes every so often
 
            if not self.batch_size or (len(created) + len(updated)) % self.batch_size == 0:
 
                self.flush_create_update()
 

	
 
            if prog:
 
                prog.update(i)
 
        if prog:
 
            prog.destroy()
 

	
 
        self.flush_create_update_final()
 
        return created, updated
 

	
 
    # TODO: this surely goes elsewhere
 
    flush_every_x = 100
 
    def flush_create_update(self):
 
        """
 
        Perform any steps necessary to "flush" the create/update changes which
 
        have occurred thus far in the import.
 
        """
 

	
 
    def flush_changes(self, x):
 
        if self.flush_every_x and x % self.flush_every_x == 0:
 
            if hasattr(self, 'session'):
 
                self.session.flush()
 
    def flush_create_update_final(self):
 
        """
 
        Perform any final steps to "flush" the created/updated data here.
 
        """
 
        self.flush_create_update()
 

	
 
    def _import_delete(self, host_data, host_keys, changes=0):
 
        """
 
        Import deletions for the given data set.
 
        """
 
        deleted = []
 
@@ -457,13 +462,13 @@ class FromQuery(Importer):
 
        """
 
        return QuerySequence(self.query())
 

	
 

	
 
class BulkImporter(Importer):
 
    """
 
    Base class for bulk data importers which target PostgreSQL on the local side.
 
    Base class for bulk data importers.
 
    """
 

	
 
    def import_data(self, host_data=None, now=None, **kwargs):
 
        self.now = now or make_utc(datetime.datetime.utcnow(), tzinfo=True)
 
        if kwargs:
 
            self._setup(**kwargs)
 
@@ -490,20 +495,17 @@ class BulkImporter(Importer):
 
            self.create_object(key, host_data)
 
            if self.max_create and i >= self.max_create:
 
                log.warning("max of {} *created* records has been reached; stopping now".format(self.max_create))
 
                created = i
 
                break
 

	
 
            # flush changes every so often
 
            if not self.batch_size or i % self.batch_size == 0:
 
                self.flush_create_update()
 

	
 
            if prog:
 
                prog.update(i)
 
        if prog:
 
            prog.destroy()
 

	
 
        self.flush_create()
 
        self.flush_create_update_final()
 
        return created
 

	
 
    def flush_create(self):
 
        """
 
        Perform any final steps to "flush" the created data here.  Note that
 
        the importer's handler is still responsible for actually committing
 
        changes to the local system, if applicable.
 
        """
rattail/importing/postgresql.py
Show inline comments
 
@@ -80,13 +80,16 @@ class BulkToPostgreSQL(BulkImporter, ToSQLAlchemy):
 
            value = value.replace('\r', '\\r')
 
            value = value.replace('\n', '\\n')
 
            value = value.replace('\t', '\\t') # TODO: add test for this
 

	
 
        return unicode(value)
 

	
 
    def flush_create(self):
 
    def flush_create_update(self):
 
        pass
 

	
 
    def flush_create_update_final(self):
 
        log.info("copying {} data from buffer to PostgreSQL".format(self.model_name))
 
        self.data_buffer.close()
 
        self.data_buffer = open(self.data_path, 'rb')
 
        cursor = self.session.connection().connection.cursor()
 
        table_name = '"{}"'.format(self.model_table.name)
 
        cursor.copy_from(self.data_buffer, table_name, columns=self.fields)
rattail/importing/sqlalchemy.py
Show inline comments
 
@@ -167,6 +167,13 @@ class ToSQLAlchemy(Importer):
 
                                normalizer=self.normalize_cache_object)
 

	
 
    def cache_query_options(self):
 
        """
 
        Return a list of options to apply to the cache query, if needed.
 
        """
 

	
 
    def flush_create_update(self):
 
        """
 
        Flush the database session, to send SQL to the server for all changes
 
        made thus far.
 
        """
 
        self.session.flush()
rattail/tests/commands/test_importing.py
Show inline comments
 
@@ -6,13 +6,13 @@ import argparse
 
from unittest import TestCase
 

	
 
from mock import Mock, patch
 

	
 
from rattail.core import Object
 
from rattail.commands import importing
 
from rattail.importing import ImportHandler
 
from rattail.importing import Importer, ImportHandler
 
from rattail.importing.rattail import FromRattailToRattail
 
from rattail.config import RattailConfig
 
from rattail.tests.importing import ImporterTester
 
from rattail.tests.importing.test_handlers import MockImportHandler
 
from rattail.tests.importing.test_importers import MockImporter
 
from rattail.tests.importing.test_rattail import DualRattailTestCase
 
@@ -21,13 +21,13 @@ from rattail.tests.importing.test_rattail import DualRattailTestCase
 
class MockImport(importing.ImportSubcommand):
 

	
 
    def get_handler_factory(self):
 
        return MockImportHandler
 

	
 

	
 
class TestImportSubcommandBasics(TestCase):
 
class TestImportSubcommand(TestCase):
 

	
 
    # TODO: lame, here only for coverage
 
    def test_parent_name(self):
 
        parent = Object(name='milo')
 
        command = importing.ImportSubcommand(parent=parent)
 
        self.assertEqual(command.parent_name, 'milo')
 
@@ -36,12 +36,28 @@ class TestImportSubcommandBasics(TestCase):
 
        command = importing.ImportSubcommand()
 
        self.assertRaises(NotImplementedError, command.get_handler_factory)
 
        command.handler_spec = 'rattail.importing.rattail:FromRattailToRattail'
 
        factory = command.get_handler_factory()
 
        self.assertIs(factory, FromRattailToRattail)
 

	
 
    def test_handler_spec_attr(self):
 
        # default is None
 
        command = importing.ImportSubcommand()
 
        self.assertIsNone(command.handler_spec)
 

	
 
        # can't get a handler without a spec
 
        self.assertRaises(NotImplementedError, command.get_handler)
 

	
 
        # but may be specified with init kwarg
 
        command = importing.ImportSubcommand(handler_spec='rattail.importing:ImportHandler')
 
        self.assertEqual(command.handler_spec, 'rattail.importing:ImportHandler')
 

	
 
        # now we can get a handler
 
        handler = command.get_handler()
 
        self.assertIsInstance(handler, ImportHandler)
 

	
 
    def test_get_handler(self):
 

	
 
        # no config
 
        command = MockImport()
 
        handler = command.get_handler()
 
        self.assertIs(type(handler), MockImportHandler)
 
@@ -62,183 +78,196 @@ class TestImportSubcommandBasics(TestCase):
 
        self.assertTrue(handler.dry_run)
 
        args = argparse.Namespace(dry_run=True)
 
        handler = command.get_handler(args=args)
 
        self.assertTrue(handler.dry_run)
 

	
 
    def test_add_parser_args(self):
 
        # TODO: this doesn't really test anything, but does give some coverage..
 

	
 
        # no handler
 
        # adding the args throws no error..(basic coverage)
 
        command = importing.ImportSubcommand()
 
        parser = argparse.ArgumentParser()
 
        command.add_parser_args(parser)
 

	
 
        # with handler
 
        command = MockImport()
 
        # confirm default values
 
        args = parser.parse_args([])
 
        self.assertIsNone(args.start_date)
 
        self.assertIsNone(args.end_date)
 
        self.assertTrue(args.create)
 
        self.assertIsNone(args.max_create)
 
        self.assertTrue(args.update)
 
        self.assertIsNone(args.max_update)
 
        self.assertFalse(args.delete)
 
        self.assertIsNone(args.max_delete)
 
        self.assertIsNone(args.max_total)
 
        self.assertEqual(args.batch_size, 200)
 
        self.assertFalse(args.warnings)
 
        self.assertFalse(args.dry_run)
 

	
 
    def test_batch_size_kwarg(self):
 
        command = importing.ImportSubcommand(handler_spec='rattail.importing:ImportHandler')
 
        parser = argparse.ArgumentParser()
 
        command.add_parser_args(parser)
 
        with patch.object(ImportHandler, 'get_importers', Mock(return_value={'Foo': Importer})):
 

	
 
            # importer default is 200
 
            args = parser.parse_args([])
 
            handler = command.get_handler(args=args)
 
            importer = handler.get_importer('Foo')
 
            self.assertEqual(importer.batch_size, 200)
 

	
 
            # but may be overridden with command line arg
 
            args = parser.parse_args(['--batch', '42'])
 
            handler = command.get_handler(args=args)
 
            importer = handler.get_importer('Foo')
 
            self.assertEqual(importer.batch_size, 42)
 

	
 
class TestImportSubcommandRun(ImporterTester, TestCase):
 

	
 
class TestImportSubcommandRun(TestCase, ImporterTester):
 

	
 
    sample_data = {
 
        '16oz': {'upc': '00074305001161', 'description': "Apple Cider Vinegar 16oz"},
 
        '32oz': {'upc': '00074305001321', 'description': "Apple Cider Vinegar 32oz"},
 
        '1gal': {'upc': '00074305011283', 'description': "Apple Cider Vinegar 1gal"},
 
    }
 

	
 
    def setUp(self):
 
        self.command = MockImport()
 
        self.handler = MockImportHandler()
 
        self.importer = MockImporter()
 

	
 
    def import_data(self, **kwargs):
 
    def import_data(self, host_data=None, local_data=None, **kwargs):
 
        if host_data is None:
 
            host_data = self.sample_data
 
        if local_data is None:
 
            local_data = self.sample_data
 

	
 
        models = kwargs.pop('models', [])
 
        kwargs.setdefault('dry_run', False)
 

	
 
        kw = {
 
            'warnings': False,
 
            'create': None,
 
            'max_create': None,
 
            'update': None,
 
            'max_update': None,
 
            'delete': None,
 
            'max_delete': None,
 
            'max_total': None,
 
            'batchcount': None,
 
            'progress': None,
 
        }
 
        kw.update(kwargs)
 
        args = argparse.Namespace(models=models, **kw)
 

	
 
        # must modify our importer in-place since we need the handler to return
 
        # that specific instance, below (because the host/local data context
 
        # managers reference that instance directly)
 
        self.importer._setup(**kwargs)
 
        with patch.object(self.command, 'get_handler', Mock(return_value=self.handler)):
 
            with patch.object(self.handler, 'get_importer', Mock(return_value=self.importer)):
 
                self.command.run(args)
 
                with self.host_data(host_data):
 
                    with self.local_data(local_data):
 
                        self.command.run(args)
 

	
 
        if self.handler._result:
 
            self.result = self.handler._result['Product']
 
        else:
 
            self.result = [], [], []
 

	
 
    def test_create(self):
 
        local = self.copy_data()
 
        del local['32oz']
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data()
 
        self.import_data(local_data=local)
 
        self.assert_import_created('32oz')
 
        self.assert_import_updated()
 
        self.assert_import_deleted()
 

	
 
    def test_update(self):
 
        local = self.copy_data()
 
        local['16oz']['description'] = "wrong description"
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data()
 
        self.import_data(local_data=local)
 
        self.assert_import_created()
 
        self.assert_import_updated('16oz')
 
        self.assert_import_deleted()
 

	
 
    def test_delete(self):
 
        local = self.copy_data()
 
        local['bogus'] = {'upc': '00000000000000', 'description': "Delete Me"}
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(delete=True)
 
        self.import_data(local_data=local, delete=True)
 
        self.assert_import_created()
 
        self.assert_import_updated()
 
        self.assert_import_deleted('bogus')
 

	
 
    def test_duplicate(self):
 
        host = self.copy_data()
 
        host['32oz-dupe'] = host['32oz']
 
        with self.host_data(host):
 
            with self.local_data(self.sample_data):
 
                self.import_data()
 
        self.import_data(host_data=host)
 
        self.assert_import_created()
 
        self.assert_import_updated()
 
        self.assert_import_deleted()
 

	
 
    def test_max_create(self):
 
        local = self.copy_data()
 
        del local['16oz']
 
        del local['1gal']
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_create=1)
 
        self.import_data(local_data=local, max_create=1)
 
        self.assert_import_created('16oz')
 
        self.assert_import_updated()
 
        self.assert_import_deleted()
 

	
 
    def test_max_total_create(self):
 
        local = self.copy_data()
 
        del local['16oz']
 
        del local['1gal']
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_total=1)
 
        self.import_data(local_data=local, max_total=1)
 
        self.assert_import_created('16oz')
 
        self.assert_import_updated()
 
        self.assert_import_deleted()
 

	
 
    def test_max_update(self):
 
        local = self.copy_data()
 
        local['16oz']['description'] = "wrong"
 
        local['1gal']['description'] = "wrong"
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_update=1)
 
        self.import_data(local_data=local, max_update=1)
 
        self.assert_import_created()
 
        self.assert_import_updated('16oz')
 
        self.assert_import_deleted()
 

	
 
    def test_max_total_update(self):
 
        local = self.copy_data()
 
        local['16oz']['description'] = "wrong"
 
        local['1gal']['description'] = "wrong"
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_total=1)
 
        self.import_data(local_data=local, max_total=1)
 
        self.assert_import_created()
 
        self.assert_import_updated('16oz')
 
        self.assert_import_deleted()
 

	
 
    def test_max_delete(self):
 
        local = self.copy_data()
 
        local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"}
 
        local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"}
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(delete=True, max_delete=1)
 
        self.import_data(local_data=local, delete=True, max_delete=1)
 
        self.assert_import_created()
 
        self.assert_import_updated()
 
        self.assert_import_deleted('bogus1')
 

	
 
    def test_max_total_delete(self):
 
        local = self.copy_data()
 
        local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"}
 
        local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"}
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(delete=True, max_total=1)
 
        self.import_data(local_data=local, delete=True, max_total=1)
 
        self.assert_import_created()
 
        self.assert_import_updated()
 
        self.assert_import_deleted('bogus1')
 

	
 
    def test_dry_run(self):
 
        local = self.copy_data()
 
        del local['32oz']
 
        local['16oz']['description'] = "wrong description"
 
        local['bogus'] = {'upc': '00000000000000', 'description': "Delete Me"}
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(delete=True, dry_run=True)
 
        self.import_data(local_data=local, delete=True, dry_run=True)
 
        # TODO: maybe need a way to confirm no changes actually made due to dry
 
        # run; currently results still reflect "proposed" changes.  this rather
 
        # bogus test is here just for coverage sake
 
        self.assert_import_created('32oz')
 
        self.assert_import_updated('16oz')
 
        self.assert_import_deleted('bogus')
rattail/tests/importing/test_handlers.py
Show inline comments
 
@@ -71,12 +71,22 @@ class ImportHandlerBattery(ImporterTester):
 
            handler = self.handler_class(config=config)
 
        importer = handler.get_importer('foo')
 
        self.assertIs(type(importer), Importer)
 
        self.assertIs(importer.config, config)
 
        self.assertIs(importer.handler, handler)
 

	
 
        # # batch size
 
        # with patch.object(self.handler_class, 'get_importers', get_importers):
 
        #     handler = self.handler_class()
 
        # importer = handler.get_importer('foo')
 
        # self.assertEqual(importer.batch_size, 200)
 
        # with patch.object(self.handler_class, 'get_importers', get_importers):
 
        #     handler = self.handler_class(batch_size=42)
 
        # importer = handler.get_importer('foo')
 
        # self.assertEqual(importer.batch_size, 42)
 

	
 
        # dry run
 
        with patch.object(self.handler_class, 'get_importers', get_importers):
 
            handler = self.handler_class()
 
        importer = handler.get_importer('foo')
 
        self.assertFalse(importer.dry_run)
 
        with patch.object(self.handler_class, 'get_importers', get_importers):
 
@@ -273,12 +283,38 @@ class BulkImportHandlerBattery(ImportHandlerBattery):
 
            self.assertFalse(process.called)
 

	
 

	
 
class TestImportHandler(unittest.TestCase, ImportHandlerBattery):
 
    handler_class = handlers.ImportHandler
 

	
 
    def test_batch_size_kwarg(self):
 
        with patch.object(handlers.ImportHandler, 'get_importers', Mock(return_value={'Foo': Importer})):
 

	
 
            # handler has no batch size by default
 
            handler = handlers.ImportHandler()
 
            self.assertFalse(hasattr(handler, 'batch_size'))
 

	
 
            # but can override with kwarg
 
            handler = handlers.ImportHandler(batch_size=42)
 
            self.assertEqual(handler.batch_size, 42)
 

	
 
            # importer default is 200
 
            handler = handlers.ImportHandler()
 
            importer = handler.get_importer('Foo')
 
            self.assertEqual(importer.batch_size, 200)
 

	
 
            # but can override with handler init kwarg
 
            handler = handlers.ImportHandler(batch_size=37)
 
            importer = handler.get_importer('Foo')
 
            self.assertEqual(importer.batch_size, 37)
 

	
 
            # can also override with get_importer kwarg
 
            handler = handlers.ImportHandler()
 
            importer = handler.get_importer('Foo', batch_size=98)
 
            self.assertEqual(importer.batch_size, 98)
 

	
 
    @patch('rattail.importing.handlers.send_email')
 
    def test_process_changes_sends_email(self, send_email):
 
        handler = handlers.ImportHandler()
 
        handler.import_began = pytz.utc.localize(datetime.datetime.utcnow())
 
        changes = [], [], []
 

	
 
@@ -588,69 +624,6 @@ class TestToSQLAlchemyHandler(unittest.TestCase):
 
        self.assertIs(handler.session, session)
 
        with patch.object(handler, 'session') as session:
 
            handler.rollback_local_transaction()
 
            session.rollback.assert_called_once_with()
 
            self.assertFalse(session.commit.called)
 
        # self.assertIsNone(handler.session)
 

	
 

	
 
######################################################################
 
# fake bulk import handler, tested mostly for basic coverage
 
######################################################################
 

	
 
class MockBulkImportHandler(handlers.BulkToPostgreSQLHandler):
 

	
 
    def get_importers(self):
 
        return {'Department': MockBulkImporter}
 

	
 
    def make_session(self):
 
        return Session()
 

	
 

	
 
class TestBulkImportHandlerOld(RattailTestCase, ImporterTester):
 

	
 
    importer_class = MockBulkImporter
 

	
 
    sample_data = {
 
        'grocery': {'number': 1, 'name': "Grocery", 'uuid': 'decd909a194011e688093ca9f40bc550'},
 
        'bulk': {'number': 2, 'name': "Bulk", 'uuid': 'e633d54c194011e687e33ca9f40bc550'},
 
        'hba': {'number': 3, 'name': "HBA", 'uuid': 'e2bad79e194011e6a4783ca9f40bc550'},
 
    }
 

	
 
    def setUp(self):
 
        self.setup_rattail()
 
        self.tempio = TempIO()
 
        self.config.set('rattail', 'workdir', self.tempio.realpath())
 
        self.handler = MockBulkImportHandler(config=self.config)
 

	
 
    def tearDown(self):
 
        self.teardown_rattail()
 
        self.tempio = None
 

	
 
    def import_data(self, host_data=None, **kwargs):
 
        if host_data is None:
 
            host_data = list(self.copy_data().itervalues())
 
        with patch.object(self.importer_class, 'normalize_host_data', Mock(return_value=host_data)):
 
            with patch.object(self.handler, 'make_session', Mock(return_value=self.session)):
 
                return self.handler.import_data('Department', **kwargs)
 

	
 
    def test_invalid_importer_key_is_ignored(self):
 
        handler = MockBulkImportHandler()
 
        self.assertNotIn('InvalidKey', handler.importers)
 
        self.assertEqual(handler.import_data('InvalidKey'), {})
 

	
 
    def assert_import_created(self, *keys):
 
        pass
 

	
 
    def assert_import_updated(self, *keys):
 
        pass
 

	
 
    def assert_import_deleted(self, *keys):
 
        pass
 

	
 
    def test_normal_run(self):
 
        if self.postgresql():
 
            self.import_data()
 

	
 
    def test_dry_run(self):
 
        if self.postgresql():
 
            self.import_data(dry_run=True)
rattail/tests/importing/test_importers.py
Show inline comments
 
@@ -56,15 +56,16 @@ class BulkImporterBattery(ImporterBattery):
 
        self.assertEqual(result, 0)
 

	
 

	
 
class TestImporter(TestCase):
 

	
 
    def test_init(self):
 

	
 
        # defaults
 
        importer = importers.Importer()
 
        self.assertIsNone(importer.model_class)
 
        self.assertIsNone(importer.model_name)
 
        self.assertIsNone(importer.key)
 
        self.assertEqual(importer.fields, [])
 
        self.assertIsNone(importer.host_system_title)
 

	
 
        # key must be included among the fields
 
        self.assertRaises(ValueError, importers.Importer, key='upc', fields=[])
 
@@ -189,12 +190,79 @@ class TestImporter(TestCase):
 
        self.assertEqual(keys, set())
 

	
 
        importer.cached_local_data = {'delete-me': object()}
 
        keys = importer.get_deletion_keys()
 
        self.assertEqual(keys, set(['delete-me']))
 

	
 
    def test_model_name_attr(self):
 
        # default is None
 
        importer = importers.Importer()
 
        self.assertIsNone(importer.model_name)
 

	
 
        # but may be overridden via init kwarg
 
        importer = importers.Importer(model_name='Foo')
 
        self.assertEqual(importer.model_name, 'Foo')
 

	
 
        # or may inherit its value from 'model_class'
 
        class Foo:
 
            pass
 
        importer = importers.Importer(model_class=Foo)
 
        self.assertEqual(importer.model_name, 'Foo')
 

	
 
    def test_batch_size_attr(self):
 
        # default is 200
 
        importer = importers.Importer()
 
        self.assertEqual(importer.batch_size, 200)
 

	
 
        # but may be overridden via init kwarg
 
        importer = importers.Importer(batch_size=0)
 
        self.assertEqual(importer.batch_size, 0)
 
        importer = importers.Importer(batch_size=42)
 
        self.assertEqual(importer.batch_size, 42)
 

	
 
        # batch size determines how often flush occurs
 
        data = [{'id': i} for i in range(1, 101)]
 
        importer = importers.Importer(model_name='Foo', key='id', fields=['id'], empty_local_data=True)
 
        with patch.object(importer, 'create_object'): # just mock that out
 
            with patch.object(importer, 'flush_create_update') as flush:
 

	
 
                # 4 batches @ 33/per
 
                importer.import_data(host_data=data, batch_size=33)
 
                self.assertEqual(flush.call_count, 4)
 
                flush.reset_mock()
 

	
 
                # 3 batches @ 34/per
 
                importer.import_data(host_data=data, batch_size=34)
 
                self.assertEqual(flush.call_count, 3)
 
                flush.reset_mock()
 

	
 
                # 2 batches @ 50/per
 
                importer.import_data(host_data=data, batch_size=100)
 
                self.assertEqual(flush.call_count, 2)
 
                flush.reset_mock()
 

	
 
                # one extra/final flush happens, whenever the total number of
 
                # changes happens to match the batch size...
 

	
 
                # 1 batch @ 100/per, plus final flush
 
                importer.import_data(host_data=data, batch_size=100)
 
                self.assertEqual(flush.call_count, 2)
 
                flush.reset_mock()
 

	
 
                # 1 batch @ 200/per
 
                importer.import_data(host_data=data, batch_size=200)
 
                self.assertEqual(flush.call_count, 1)
 
                flush.reset_mock()
 

	
 
                # one extra/final flush also happens when batching is disabled
 

	
 
                # 100 "batches" @ 0/per, plus final flush
 
                importer.import_data(host_data=data, batch_size=0)
 
                self.assertEqual(flush.call_count, 101)
 
                flush.reset_mock()
 

	
 

	
 
class TestFromQuery(RattailTestCase):
 

	
 
    def test_query(self):
        """The base FromQuery class must leave query() unimplemented."""
        # Context-manager form of assertRaises; same check as the
        # two-argument form, calling query() with no arguments.
        with self.assertRaises(NotImplementedError):
            importers.FromQuery().query()
 
@@ -207,12 +275,63 @@ class TestFromQuery(RattailTestCase):
 
        self.assertIsInstance(objects, QuerySequence)
 

	
 

	
 
class TestBulkImporter(TestCase, BulkImporterBattery):
    """
    Unit tests for the bulk importer, focused on the ``batch_size``
    attribute and how it drives flush frequency during an import run.
    """
    importer_class = importers.BulkImporter

    def test_batch_size_attr(self):
        """
        ``batch_size`` should default to 200, be overridable via init
        kwarg, and determine how often ``flush_create_update()`` fires.
        """
        # default is 200
        importer = importers.BulkImporter()
        self.assertEqual(importer.batch_size, 200)

        # but may be overridden via init kwarg
        importer = importers.BulkImporter(batch_size=0)
        self.assertEqual(importer.batch_size, 0)
        importer = importers.BulkImporter(batch_size=42)
        self.assertEqual(importer.batch_size, 42)

        # batch size determines how often flush occurs; all the counts
        # asserted below are consistent with one flush per completed batch,
        # plus one final flush at the end of every run
        data = [{'id': i} for i in range(1, 101)]
        importer = importers.BulkImporter(model_name='Foo', key='id', fields=['id'], empty_local_data=True)
        with patch.object(importer, 'create_object'): # just mock that out
            with patch.object(importer, 'flush_create_update') as flush:

                # 3 full batches @ 33/per (flush at 33, 66, 99), plus final flush
                importer.import_data(host_data=data, batch_size=33)
                self.assertEqual(flush.call_count, 4)
                flush.reset_mock()

                # 2 full batches @ 34/per (flush at 34, 68), plus final flush
                importer.import_data(host_data=data, batch_size=34)
                self.assertEqual(flush.call_count, 3)
                flush.reset_mock()

                # NOTE(review): comment here previously claimed "2 batches @
                # 50/per" but the call passes batch_size=100; with 50 the
                # flush count would be 3 (flush at 50, 100, plus final), so
                # the value matching the assertion was kept and only the
                # stale comment corrected.
                # 1 full batch @ 100/per, plus final flush
                importer.import_data(host_data=data, batch_size=100)
                self.assertEqual(flush.call_count, 2)
                flush.reset_mock()

                # the extra/final flush happens even when the total number of
                # changes exactly matches the batch size...

                # 1 batch @ 100/per, plus final flush
                importer.import_data(host_data=data, batch_size=100)
                self.assertEqual(flush.call_count, 2)
                flush.reset_mock()

                # 0 full batches @ 200/per; only the final flush fires
                importer.import_data(host_data=data, batch_size=200)
                self.assertEqual(flush.call_count, 1)
                flush.reset_mock()

                # one extra/final flush also happens when batching is disabled

                # 100 "batches" @ 0/per (flush per record), plus final flush
                importer.import_data(host_data=data, batch_size=0)
                self.assertEqual(flush.call_count, 101)
                flush.reset_mock()
 

	
 

	
 
######################################################################
 
# fake importer class, tested mostly for basic coverage
 
######################################################################
 

	
rattail/tests/importing/test_postgresql.py
Show inline comments
 
@@ -36,23 +36,22 @@ class TestBulkToPostgreSQL(unittest.TestCase):
 

	
 
    def make_importer(self, **kwargs):
        """Build a BulkToPostgreSQL importer with test-friendly defaults."""
        defaults = {
            'config': self.config,
            'fields': ['id'],   # hack
        }
        for name, value in defaults.items():
            kwargs.setdefault(name, value)
        return pgimport.BulkToPostgreSQL(**kwargs)
 

	
 
    def test_data_path(self):
 
        importer = self.make_importer(config=None)
 
        self.assertIsNone(importer.config)
 
        self.assertRaises(AttributeError, getattr, importer, 'data_path')
 
        importer.config = RattailConfig()
 
        self.assertRaises(ConfigurationError, getattr, importer, 'data_path')
 
        importer.config = self.config
 
    def test_data_path_property(self):
 
        self.config.set('rattail', 'workdir', '/tmp')
 
        self.assertEqual(importer.data_path, '/tmp/import_bulk_postgresql_None.csv') # no model yet
 
        importer.model_class = Widget
 
        self.assertEqual(importer.data_path, '/tmp/import_bulk_postgresql_Widget.csv')
 
        importer = pgimport.BulkToPostgreSQL(config=self.config, fields=['id'])
 

	
 
        # path leverages model name, so default is None
 
        self.assertEqual(importer.data_path, '/tmp/import_bulk_postgresql_None.csv')
 

	
 
        # but it will reflect model name
 
        importer.model_name = 'Foo'
 
        self.assertEqual(importer.data_path, '/tmp/import_bulk_postgresql_Foo.csv')
 

	
 
    def test_setup(self):
        """setup() should be what creates the importer's data buffer."""
        importer = self.make_importer()
        # no buffer attribute exists before setup runs...
        marker = object()
        self.assertIs(getattr(importer, 'data_buffer', marker), marker)
        # ...but afterward one must be present
        importer.setup()
        self.assertTrue(importer.data_buffer is not None)
rattail/tests/importing/test_sqlalchemy.py
Show inline comments
 
# -*- coding: utf-8 -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from unittest import TestCase
 

	
 
from mock import patch
 

	
 
import sqlalchemy as sa
 
from sqlalchemy import orm
 
from sqlalchemy.orm.exc import MultipleResultsFound
 

	
 
from rattail.importing import sqlalchemy as saimport
 

	
 
@@ -142,6 +144,12 @@ class TestToSQLAlchemy(TestCase):
 
    def test_flush_session(self):
        """update_object() must hand back the same instance when flushing."""
        importer = self.make_importer(fields=['id'], session=self.session, flush_session=True)
        original = Widget()
        original.id = 1
        result = importer.update_object(original, {'id': 1})
        # identity, not mere equality: the very same widget comes back
        self.assertIs(result, original)
 

	
 
    def test_flush_create_update(self):
        """flush_create_update() should flush the session exactly once."""
        importer = saimport.ToSQLAlchemy(fields=['id'], session=self.session)
        with patch.object(self.session, 'flush') as mock_flush:
            importer.flush_create_update()
        # mock retains its call record after the patch is undone
        mock_flush.assert_called_once_with()
0 comments (0 inline, 0 general)