diff --git a/rattail/commands/importing.py b/rattail/commands/importing.py index 2b1ccdebc2c863c2c6f57e4640177640a1ffe026..ba1a05867de71da2a20f611398d72b2b9724ea1c 100644 --- a/rattail/commands/importing.py +++ b/rattail/commands/importing.py @@ -44,6 +44,8 @@ class ImportSubcommand(Subcommand): # TODO: move this into Subcommand or something.. parent_name = None def __init__(self, *args, **kwargs): + if 'handler_spec' in kwargs: + self.handler_spec = kwargs.pop('handler_spec') super(ImportSubcommand, self).__init__(*args, **kwargs) if self.parent: self.parent_name = self.parent.name @@ -68,6 +70,8 @@ class ImportSubcommand(Subcommand): if 'args' in kwargs: args = kwargs['args'] kwargs.setdefault('dry_run', args.dry_run) + if hasattr(args, 'batch_size'): + kwargs.setdefault('batch_size', args.batch_size) # kwargs.setdefault('max_create', args.max_create) # kwargs.setdefault('max_update', args.max_update) # kwargs.setdefault('max_delete', args.max_delete) @@ -139,6 +143,13 @@ class ImportSubcommand(Subcommand): "a given import task should stop. Note that this applies on a per-model " "basis and not overall.") + # batch size + parser.add_argument('--batch', type=int, dest='batch_size', metavar='SIZE', default=200, + help="Split work to be done into batches, with the specified number of " + "records in each batch. Or, set this to 0 (zero) to disable batching. " + "Implementation for this may vary somewhat between importers; default " + "batch size is 200 records.") + # treat changes as warnings? parser.add_argument('--warnings', '-W', action='store_true', help="Set this flag if you expect a \"clean\" import, and wish for any " diff --git a/rattail/importing/handlers.py b/rattail/importing/handlers.py index 268874482c2159fc4d64a0fc5805698cc1534334..f797cb61d6d9763c2cd3cdd1fe9b58032a72a508 100644 --- a/rattail/importing/handlers.py +++ b/rattail/importing/handlers.py @@ -91,6 +91,8 @@ class ImportHandler(object): kwargs.setdefault('handler', self) kwargs.setdefault('config', self.config) kwargs.setdefault('host_system_title', self.host_title) + if hasattr(self, 'batch_size'): + kwargs.setdefault('batch_size', self.batch_size) kwargs = self.get_importer_kwargs(key, **kwargs) return self.importers[key](**kwargs) @@ -339,9 +341,3 @@ class ToSQLAlchemyHandler(ImportHandler): self.session.commit() self.session.close() self.session = None - - -class BulkToPostgreSQLHandler(BulkImportHandler): - """ - Handler for bulk imports which target PostgreSQL on the local side. - """ diff --git a/rattail/importing/importers.py b/rattail/importing/importers.py index 1a53a0aad38214eb9c7de671b3fd37b9fc06aa96..60fa27d302688167420795f1f03088d80d2b3b42 100644 --- a/rattail/importing/importers.py +++ b/rattail/importing/importers.py @@ -42,6 +42,7 @@ class Importer(object): """ # Set this to the data model class which is targeted on the local side. model_class = None + model_name = None key = None @@ -65,6 +66,7 @@ class Importer(object): max_update = None max_delete = None max_total = None + batch_size = 200 progress = None empty_local_data = False @@ -86,6 +88,10 @@ class Importer(object): if field not in self.fields: raise ValueError("Key field '{}' must be included in effective fields " "for {}".format(field, self.__class__.__name__)) + self.model_class = kwargs.pop('model_class', self.model_class) + self.model_name = kwargs.pop('model_name', self.model_name) + if not self.model_name and self.model_class: + self.model_name = self.model_class.__name__ self._setup(**kwargs) def _setup(self, **kwargs): @@ -95,11 +101,6 @@ class Importer(object): for key, value in kwargs.iteritems(): setattr(self, key, value) - @property - def model_name(self): - if self.model_class: - return self.model_class.__name__ - def setup(self): """ Perform any setup necessary, e.g. cache lookups for existing data. @@ -208,25 +209,29 @@ class Importer(object): log.warning("max of {} *total changes* has been reached; stopping now".format(self.max_total)) break - self.flush_changes(i) - # # TODO: this needs to be customizable etc. somehow maybe.. - # if i % 100 == 0 and hasattr(self, 'session'): - # self.session.flush() + # flush changes every so often + if not self.batch_size or (len(created) + len(updated)) % self.batch_size == 0: + self.flush_create_update() if prog: prog.update(i) if prog: prog.destroy() + self.flush_create_update_final() return created, updated - # TODO: this surely goes elsewhere - flush_every_x = 100 + def flush_create_update(self): + """ + Perform any steps necessary to "flush" the create/update changes which + have occurred thus far in the import. + """ - def flush_changes(self, x): - if self.flush_every_x and x % self.flush_every_x == 0: - if hasattr(self, 'session'): - self.session.flush() + def flush_create_update_final(self): + """ + Perform any final steps to "flush" the created/updated data here. + """ + self.flush_create_update() def _import_delete(self, host_data, host_keys, changes=0): """ @@ -460,7 +465,7 @@ class FromQuery(Importer): class BulkImporter(Importer): """ - Base class for bulk data importers which target PostgreSQL on the local side. + Base class for bulk data importers. """ def import_data(self, host_data=None, now=None, **kwargs): @@ -493,17 +498,14 @@ class BulkImporter(Importer): created = i break + # flush changes every so often + if not self.batch_size or i % self.batch_size == 0: + self.flush_create_update() + if prog: prog.update(i) if prog: prog.destroy() - self.flush_create() + self.flush_create_update_final() return created - - def flush_create(self): - """ - Perform any final steps to "flush" the created data here. Note that - the importer's handler is still responsible for actually committing - changes to the local system, if applicable. - """ diff --git a/rattail/importing/postgresql.py b/rattail/importing/postgresql.py index c484d50d18c1e278b5dbc32b35a34ea0ed360ab4..ee43af0dc5c5fca3f00cd4b71973bd21c9f19f1b 100644 --- a/rattail/importing/postgresql.py +++ b/rattail/importing/postgresql.py @@ -83,7 +83,10 @@ class BulkToPostgreSQL(BulkImporter, ToSQLAlchemy): return unicode(value) - def flush_create(self): + def flush_create_update(self): + pass + + def flush_create_update_final(self): log.info("copying {} data from buffer to PostgreSQL".format(self.model_name)) self.data_buffer.close() self.data_buffer = open(self.data_path, 'rb') diff --git a/rattail/importing/sqlalchemy.py b/rattail/importing/sqlalchemy.py index a16ad2ffb53fbea5b8b394afa748c4b6b98cc7d0..9289af18ff80cd99dab11d18b1e448d7786ed375 100644 --- a/rattail/importing/sqlalchemy.py +++ b/rattail/importing/sqlalchemy.py @@ -170,3 +170,10 @@ class ToSQLAlchemy(Importer): """ Return a list of options to apply to the cache query, if needed. """ + + def flush_create_update(self): + """ + Flush the database session, to send SQL to the server for all changes + made thus far. + """ + self.session.flush() diff --git a/rattail/tests/commands/test_importing.py b/rattail/tests/commands/test_importing.py index f664eb1e06ac9e389cfc8cbe4a422de5e4163baf..dddf0cdaeb8d48577ccfcaf747063e839348d332 100644 --- a/rattail/tests/commands/test_importing.py +++ b/rattail/tests/commands/test_importing.py @@ -9,7 +9,7 @@ from mock import Mock, patch from rattail.core import Object from rattail.commands import importing -from rattail.importing import ImportHandler +from rattail.importing import Importer, ImportHandler from rattail.importing.rattail import FromRattailToRattail from rattail.config import RattailConfig from rattail.tests.importing import ImporterTester @@ -24,7 +24,7 @@ class MockImport(importing.ImportSubcommand): return MockImportHandler -class TestImportSubcommandBasics(TestCase): +class TestImportSubcommand(TestCase): # TODO: lame, here only for coverage def test_parent_name(self): @@ -39,6 +39,22 @@ class TestImportSubcommandBasics(TestCase): factory = command.get_handler_factory() self.assertIs(factory, FromRattailToRattail) + def test_handler_spec_attr(self): + # default is None + command = importing.ImportSubcommand() + self.assertIsNone(command.handler_spec) + + # can't get a handler without a spec + self.assertRaises(NotImplementedError, command.get_handler) + + # but may be specified with init kwarg + command = importing.ImportSubcommand(handler_spec='rattail.importing:ImportHandler') + self.assertEqual(command.handler_spec, 'rattail.importing:ImportHandler') + + # now we can get a handler + handler = command.get_handler() + self.assertIsInstance(handler, ImportHandler) + def test_get_handler(self): # no config @@ -65,20 +81,47 @@ class TestImportSubcommandBasics(TestCase): self.assertTrue(handler.dry_run) def test_add_parser_args(self): - # TODO: this doesn't really test anything, but does give some coverage.. - # no handler + # adding the args throws no error..(basic coverage) command = importing.ImportSubcommand() parser = argparse.ArgumentParser() command.add_parser_args(parser) - # with handler - command = MockImport() + # confirm default values + args = parser.parse_args([]) + self.assertIsNone(args.start_date) + self.assertIsNone(args.end_date) + self.assertTrue(args.create) + self.assertIsNone(args.max_create) + self.assertTrue(args.update) + self.assertIsNone(args.max_update) + self.assertFalse(args.delete) + self.assertIsNone(args.max_delete) + self.assertIsNone(args.max_total) + self.assertEqual(args.batch_size, 200) + self.assertFalse(args.warnings) + self.assertFalse(args.dry_run) + + def test_batch_size_kwarg(self): + command = importing.ImportSubcommand(handler_spec='rattail.importing:ImportHandler') parser = argparse.ArgumentParser() command.add_parser_args(parser) + with patch.object(ImportHandler, 'get_importers', Mock(return_value={'Foo': Importer})): + + # importer default is 200 + args = parser.parse_args([]) + handler = command.get_handler(args=args) + importer = handler.get_importer('Foo') + self.assertEqual(importer.batch_size, 200) + # but may be overridden with command line arg + args = parser.parse_args(['--batch', '42']) + handler = command.get_handler(args=args) + importer = handler.get_importer('Foo') + self.assertEqual(importer.batch_size, 42) -class TestImportSubcommandRun(ImporterTester, TestCase): + +class TestImportSubcommandRun(TestCase, ImporterTester): sample_data = { '16oz': {'upc': '00074305001161', 'description': "Apple Cider Vinegar 16oz"}, @@ -91,7 +134,12 @@ class TestImportSubcommandRun(ImporterTester, TestCase): self.handler = MockImportHandler() self.importer = MockImporter() - def import_data(self, **kwargs): + def import_data(self, host_data=None, local_data=None, **kwargs): + if host_data is None: + host_data = self.sample_data + if local_data is None: + local_data = self.sample_data + models = kwargs.pop('models', []) kwargs.setdefault('dry_run', False) @@ -104,6 +152,7 @@ class TestImportSubcommandRun(ImporterTester, TestCase): 'delete': None, 'max_delete': None, 'max_total': None, + 'batchcount': None, 'progress': None, } kw.update(kwargs) @@ -115,7 +164,9 @@ class TestImportSubcommandRun(ImporterTester, TestCase): self.importer._setup(**kwargs) with patch.object(self.command, 'get_handler', Mock(return_value=self.handler)): with patch.object(self.handler, 'get_importer', Mock(return_value=self.importer)): - self.command.run(args) + with self.host_data(host_data): + with self.local_data(local_data): + self.command.run(args) if self.handler._result: self.result = self.handler._result['Product'] @@ -125,9 +176,7 @@ class TestImportSubcommandRun(ImporterTester, TestCase): def test_create(self): local = self.copy_data() del local['32oz'] - with self.host_data(self.sample_data): - with self.local_data(local): - self.import_data() + self.import_data(local_data=local) self.assert_import_created('32oz') self.assert_import_updated() self.assert_import_deleted() @@ -135,9 +184,7 @@ class TestImportSubcommandRun(ImporterTester, TestCase): def test_update(self): local = self.copy_data() local['16oz']['description'] = "wrong description" - with self.host_data(self.sample_data): - with self.local_data(local): - self.import_data() + self.import_data(local_data=local) self.assert_import_created() self.assert_import_updated('16oz') self.assert_import_deleted() @@ -145,9 +192,7 @@ class TestImportSubcommandRun(ImporterTester, TestCase): def test_delete(self): local = self.copy_data() local['bogus'] = {'upc': '00000000000000', 'description': "Delete Me"} - with self.host_data(self.sample_data): - with self.local_data(local): - self.import_data(delete=True) + self.import_data(local_data=local, delete=True) self.assert_import_created() self.assert_import_updated() self.assert_import_deleted('bogus') @@ -155,9 +200,7 @@ class TestImportSubcommandRun(ImporterTester, TestCase): def test_duplicate(self): host = self.copy_data() host['32oz-dupe'] = host['32oz'] - with self.host_data(host): - with self.local_data(self.sample_data): - self.import_data() + self.import_data(host_data=host) self.assert_import_created() self.assert_import_updated() self.assert_import_deleted() @@ -166,9 +209,7 @@ class TestImportSubcommandRun(ImporterTester, TestCase): local = self.copy_data() del local['16oz'] del local['1gal'] - with self.host_data(self.sample_data): - with self.local_data(local): - self.import_data(max_create=1) + self.import_data(local_data=local, max_create=1) self.assert_import_created('16oz') self.assert_import_updated() self.assert_import_deleted() @@ -177,9 +218,7 @@ class TestImportSubcommandRun(ImporterTester, TestCase): local = self.copy_data() del local['16oz'] del local['1gal'] - with self.host_data(self.sample_data): - with self.local_data(local): - self.import_data(max_total=1) + self.import_data(local_data=local, max_total=1) self.assert_import_created('16oz') self.assert_import_updated() self.assert_import_deleted() @@ -188,9 +227,7 @@ class TestImportSubcommandRun(ImporterTester, TestCase): local = self.copy_data() local['16oz']['description'] = "wrong" local['1gal']['description'] = "wrong" - with self.host_data(self.sample_data): - with self.local_data(local): - self.import_data(max_update=1) + self.import_data(local_data=local, max_update=1) self.assert_import_created() self.assert_import_updated('16oz') self.assert_import_deleted() @@ -199,9 +236,7 @@ class TestImportSubcommandRun(ImporterTester, TestCase): local = self.copy_data() local['16oz']['description'] = "wrong" local['1gal']['description'] = "wrong" - with self.host_data(self.sample_data): - with self.local_data(local): - self.import_data(max_total=1) + self.import_data(local_data=local, max_total=1) self.assert_import_created() self.assert_import_updated('16oz') self.assert_import_deleted() @@ -210,9 +245,7 @@ class TestImportSubcommandRun(ImporterTester, TestCase): local = self.copy_data() local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"} local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"} - with self.host_data(self.sample_data): - with self.local_data(local): - self.import_data(delete=True, max_delete=1) + self.import_data(local_data=local, delete=True, max_delete=1) self.assert_import_created() self.assert_import_updated() self.assert_import_deleted('bogus1') @@ -221,9 +254,7 @@ class TestImportSubcommandRun(ImporterTester, TestCase): local = self.copy_data() local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"} local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"} - with self.host_data(self.sample_data): - with self.local_data(local): - self.import_data(delete=True, max_total=1) + self.import_data(local_data=local, delete=True, max_total=1) self.assert_import_created() self.assert_import_updated() self.assert_import_deleted('bogus1') @@ -233,9 +264,7 @@ class TestImportSubcommandRun(ImporterTester, TestCase): del local['32oz'] local['16oz']['description'] = "wrong description" local['bogus'] = {'upc': '00000000000000', 'description': "Delete Me"} - with self.host_data(self.sample_data): - with self.local_data(local): - self.import_data(delete=True, dry_run=True) + self.import_data(local_data=local, delete=True, dry_run=True) # TODO: maybe need a way to confirm no changes actually made due to dry # run; currently results still reflect "proposed" changes. this rather # bogus test is here just for coverage sake diff --git a/rattail/tests/importing/test_handlers.py b/rattail/tests/importing/test_handlers.py index 0797cebd7bd1fa9f4f1de333748143ad8a306d72..758c573d2b6905fc25aa4e0aac943eb86f8d437a 100644 --- a/rattail/tests/importing/test_handlers.py +++ b/rattail/tests/importing/test_handlers.py @@ -74,6 +74,16 @@ class ImportHandlerBattery(ImporterTester): self.assertIs(importer.config, config) self.assertIs(importer.handler, handler) + # # batch size + # with patch.object(self.handler_class, 'get_importers', get_importers): + # handler = self.handler_class() + # importer = handler.get_importer('foo') + # self.assertEqual(importer.batch_size, 200) + # with patch.object(self.handler_class, 'get_importers', get_importers): + # handler = self.handler_class(batch_size=42) + # importer = handler.get_importer('foo') + # self.assertEqual(importer.batch_size, 42) + # dry run with patch.object(self.handler_class, 'get_importers', get_importers): handler = self.handler_class() @@ -276,6 +286,32 @@ class BulkImportHandlerBattery(ImportHandlerBattery): class TestImportHandler(unittest.TestCase, ImportHandlerBattery): handler_class = handlers.ImportHandler + def test_batch_size_kwarg(self): + with patch.object(handlers.ImportHandler, 'get_importers', Mock(return_value={'Foo': Importer})): + + # handler has no batch size by default + handler = handlers.ImportHandler() + self.assertFalse(hasattr(handler, 'batch_size')) + + # but can override with kwarg + handler = handlers.ImportHandler(batch_size=42) + self.assertEqual(handler.batch_size, 42) + + # importer default is 200 + handler = handlers.ImportHandler() + importer = handler.get_importer('Foo') + self.assertEqual(importer.batch_size, 200) + + # but can override with handler init kwarg + handler = handlers.ImportHandler(batch_size=37) + importer = handler.get_importer('Foo') + self.assertEqual(importer.batch_size, 37) + + # can also override with get_importer kwarg + handler = handlers.ImportHandler() + importer = handler.get_importer('Foo', batch_size=98) + self.assertEqual(importer.batch_size, 98) + @patch('rattail.importing.handlers.send_email') def test_process_changes_sends_email(self, send_email): handler = handlers.ImportHandler() @@ -591,66 +627,3 @@ class TestToSQLAlchemyHandler(unittest.TestCase): session.rollback.assert_called_once_with() self.assertFalse(session.commit.called) # self.assertIsNone(handler.session) - - -###################################################################### -# fake bulk import handler, tested mostly for basic coverage -###################################################################### - -class MockBulkImportHandler(handlers.BulkToPostgreSQLHandler): - - def get_importers(self): - return {'Department': MockBulkImporter} - - def make_session(self): - return Session() - - -class TestBulkImportHandlerOld(RattailTestCase, ImporterTester): - - importer_class = MockBulkImporter - - sample_data = { - 'grocery': {'number': 1, 'name': "Grocery", 'uuid': 'decd909a194011e688093ca9f40bc550'}, - 'bulk': {'number': 2, 'name': "Bulk", 'uuid': 'e633d54c194011e687e33ca9f40bc550'}, - 'hba': {'number': 3, 'name': "HBA", 'uuid': 'e2bad79e194011e6a4783ca9f40bc550'}, - } - - def setUp(self): - self.setup_rattail() - self.tempio = TempIO() - self.config.set('rattail', 'workdir', self.tempio.realpath()) - self.handler = MockBulkImportHandler(config=self.config) - - def tearDown(self): - self.teardown_rattail() - self.tempio = None - - def import_data(self, host_data=None, **kwargs): - if host_data is None: - host_data = list(self.copy_data().itervalues()) - with patch.object(self.importer_class, 'normalize_host_data', Mock(return_value=host_data)): - with patch.object(self.handler, 'make_session', Mock(return_value=self.session)): - return self.handler.import_data('Department', **kwargs) - - def test_invalid_importer_key_is_ignored(self): - handler = MockBulkImportHandler() - self.assertNotIn('InvalidKey', handler.importers) - self.assertEqual(handler.import_data('InvalidKey'), {}) - - def assert_import_created(self, *keys): - pass - - def assert_import_updated(self, *keys): - pass - - def assert_import_deleted(self, *keys): - pass - - def test_normal_run(self): - if self.postgresql(): - self.import_data() - - def test_dry_run(self): - if self.postgresql(): - self.import_data(dry_run=True) diff --git a/rattail/tests/importing/test_importers.py b/rattail/tests/importing/test_importers.py index d814a763de9bf08003e8398a26568e4c81c71daa..a6247f3efbebfe757cbd90c0b9f623a49ac5e304 100644 --- a/rattail/tests/importing/test_importers.py +++ b/rattail/tests/importing/test_importers.py @@ -59,9 +59,10 @@ class BulkImporterBattery(ImporterBattery): class TestImporter(TestCase): def test_init(self): + + # defaults importer = importers.Importer() self.assertIsNone(importer.model_class) - self.assertIsNone(importer.model_name) self.assertIsNone(importer.key) self.assertEqual(importer.fields, []) self.assertIsNone(importer.host_system_title) @@ -192,6 +193,73 @@ class TestImporter(TestCase): keys = importer.get_deletion_keys() self.assertEqual(keys, set(['delete-me'])) + def test_model_name_attr(self): + # default is None + importer = importers.Importer() + self.assertIsNone(importer.model_name) + + # but may be overridden via init kwarg + importer = importers.Importer(model_name='Foo') + self.assertEqual(importer.model_name, 'Foo') + + # or may inherit its value from 'model_class' + class Foo: + pass + importer = importers.Importer(model_class=Foo) + self.assertEqual(importer.model_name, 'Foo') + + def test_batch_size_attr(self): + # default is 200 + importer = importers.Importer() + self.assertEqual(importer.batch_size, 200) + + # but may be overridden via init kwarg + importer = importers.Importer(batch_size=0) + self.assertEqual(importer.batch_size, 0) + importer = importers.Importer(batch_size=42) + self.assertEqual(importer.batch_size, 42) + + # batch size determines how often flush occurs + data = [{'id': i} for i in range(1, 101)] + importer = importers.Importer(model_name='Foo', key='id', fields=['id'], empty_local_data=True) + with patch.object(importer, 'create_object'): # just mock that out + with patch.object(importer, 'flush_create_update') as flush: + + # 4 batches @ 33/per + importer.import_data(host_data=data, batch_size=33) + self.assertEqual(flush.call_count, 4) + flush.reset_mock() + + # 3 batches @ 34/per + importer.import_data(host_data=data, batch_size=34) + self.assertEqual(flush.call_count, 3) + flush.reset_mock() + + # 2 batches @ 50/per + importer.import_data(host_data=data, batch_size=100) + self.assertEqual(flush.call_count, 2) + flush.reset_mock() + + # one extra/final flush happens, whenever the total number of + # changes happens to match the batch size... + + # 1 batch @ 100/per, plus final flush + importer.import_data(host_data=data, batch_size=100) + self.assertEqual(flush.call_count, 2) + flush.reset_mock() + + # 1 batch @ 200/per + importer.import_data(host_data=data, batch_size=200) + self.assertEqual(flush.call_count, 1) + flush.reset_mock() + + # one extra/final flush also happens when batching is disabled + + # 100 "batches" @ 0/per, plus final flush + importer.import_data(host_data=data, batch_size=0) + self.assertEqual(flush.call_count, 101) + flush.reset_mock() + class TestFromQuery(RattailTestCase): @@ -210,6 +278,57 @@ class TestFromQuery(RattailTestCase): class TestBulkImporter(TestCase, BulkImporterBattery): importer_class = importers.BulkImporter + def test_batch_size_attr(self): + # default is 200 + importer = importers.BulkImporter() + self.assertEqual(importer.batch_size, 200) + + # but may be overridden via init kwarg + importer = importers.BulkImporter(batch_size=0) + self.assertEqual(importer.batch_size, 0) + importer = importers.BulkImporter(batch_size=42) + self.assertEqual(importer.batch_size, 42) + + # batch size determines how often flush occurs + data = [{'id': i} for i in range(1, 101)] + importer = importers.BulkImporter(model_name='Foo', key='id', fields=['id'], empty_local_data=True) + with patch.object(importer, 'create_object'): # just mock that out + with patch.object(importer, 'flush_create_update') as flush: + + # 4 batches @ 33/per + importer.import_data(host_data=data, batch_size=33) + self.assertEqual(flush.call_count, 4) + flush.reset_mock() + + # 3 batches @ 34/per + importer.import_data(host_data=data, batch_size=34) + self.assertEqual(flush.call_count, 3) + flush.reset_mock() + + # 2 batches @ 50/per + importer.import_data(host_data=data, batch_size=100) + self.assertEqual(flush.call_count, 2) + flush.reset_mock() + + # one extra/final flush happens, whenever the total number of + # changes happens to match the batch size... + + # 1 batch @ 100/per, plus final flush + importer.import_data(host_data=data, batch_size=100) + self.assertEqual(flush.call_count, 2) + flush.reset_mock() + + # 1 batch @ 200/per + importer.import_data(host_data=data, batch_size=200) + self.assertEqual(flush.call_count, 1) + flush.reset_mock() + + # one extra/final flush also happens when batching is disabled + + # 100 "batches" @ 0/per, plus final flush + importer.import_data(host_data=data, batch_size=0) + self.assertEqual(flush.call_count, 101) + flush.reset_mock() ###################################################################### diff --git a/rattail/tests/importing/test_postgresql.py b/rattail/tests/importing/test_postgresql.py index 846127dc28147a0c1cb4a05f0d9a523f400cbe57..bf91ba64d4f935254bd8528aec094fff84bceafe 100644 --- a/rattail/tests/importing/test_postgresql.py +++ b/rattail/tests/importing/test_postgresql.py @@ -39,17 +39,16 @@ class TestBulkToPostgreSQL(unittest.TestCase): kwargs.setdefault('fields', ['id']) # hack return pgimport.BulkToPostgreSQL(**kwargs) - def test_data_path(self): - importer = self.make_importer(config=None) - self.assertIsNone(importer.config) - self.assertRaises(AttributeError, getattr, importer, 'data_path') - importer.config = RattailConfig() - self.assertRaises(ConfigurationError, getattr, importer, 'data_path') - importer.config = self.config + def test_data_path_property(self): self.config.set('rattail', 'workdir', '/tmp') - self.assertEqual(importer.data_path, '/tmp/import_bulk_postgresql_None.csv') # no model yet - importer.model_class = Widget - self.assertEqual(importer.data_path, '/tmp/import_bulk_postgresql_Widget.csv') + importer = pgimport.BulkToPostgreSQL(config=self.config, fields=['id']) + + # path leverages model name, so default is None + self.assertEqual(importer.data_path, '/tmp/import_bulk_postgresql_None.csv') + + # but it will reflect model name + importer.model_name = 'Foo' + self.assertEqual(importer.data_path, '/tmp/import_bulk_postgresql_Foo.csv') def test_setup(self): importer = self.make_importer() diff --git a/rattail/tests/importing/test_sqlalchemy.py b/rattail/tests/importing/test_sqlalchemy.py index 0f693b4b548805baf5982e44df59f70a9268e924..c5674c97edb3b02ef45860dae40ab86453b0350e 100644 --- a/rattail/tests/importing/test_sqlalchemy.py +++ b/rattail/tests/importing/test_sqlalchemy.py @@ -4,6 +4,8 @@ from __future__ import unicode_literals, absolute_import from unittest import TestCase +from mock import patch + import sqlalchemy as sa from sqlalchemy import orm from sqlalchemy.orm.exc import MultipleResultsFound @@ -145,3 +147,9 @@ class TestToSQLAlchemy(TestCase): widget.id = 1 widget, original = importer.update_object(widget, {'id': 1}), widget self.assertIs(widget, original) + + def test_flush_create_update(self): + importer = saimport.ToSQLAlchemy(fields=['id'], session=self.session) + with patch.object(self.session, 'flush') as flush: + importer.flush_create_update() + flush.assert_called_once_with()