diff --git a/rattail/importing/__init__.py b/rattail/importing/__init__.py index 58cec3f799edc100ec79e756dfc11f005322e727..397a1cee845d818c99f932490c9b68ead4bdfa77 100644 --- a/rattail/importing/__init__.py +++ b/rattail/importing/__init__.py @@ -26,9 +26,9 @@ Data Importing Framework from __future__ import unicode_literals, absolute_import -from .importers import Importer, FromQuery +from .importers import Importer, FromQuery, BulkImporter from .sqlalchemy import FromSQLAlchemy, ToSQLAlchemy from .postgresql import BulkToPostgreSQL -from .handlers import ImportHandler, FromSQLAlchemyHandler, ToSQLAlchemyHandler, BulkToPostgreSQLHandler +from .handlers import ImportHandler, BulkImportHandler, FromSQLAlchemyHandler, ToSQLAlchemyHandler from .rattail import FromRattailHandler, ToRattailHandler from . import model diff --git a/rattail/importing/handlers.py b/rattail/importing/handlers.py index 27a01fc214b6eeff7302bbabdcd48ed334e932ee..268874482c2159fc4d64a0fc5805698cc1534334 100644 --- a/rattail/importing/handlers.py +++ b/rattail/importing/handlers.py @@ -229,6 +229,55 @@ class ImportHandler(object): log.info("warning email was sent for {} -> {} import".format(self.host_title, self.local_title)) +class BulkImportHandler(ImportHandler): + """ + Base class for bulk import handlers. + """ + + def import_data(self, *keys, **kwargs): + """ + Import all data for the given importer/model keys. 
+ """ + # TODO: still need to refactor much of this so can share with parent class + self.import_began = make_utc(datetime.datetime.utcnow(), tzinfo=True) + if 'dry_run' in kwargs: + self.dry_run = kwargs['dry_run'] + self.progress = kwargs.pop('progress', getattr(self, 'progress', None)) + self.warnings = kwargs.pop('warnings', False) + kwargs.update({'dry_run': self.dry_run, + 'progress': self.progress}) + self.setup() + self.begin_transaction() + changes = OrderedDict() + + try: + for key in keys: + importer = self.get_importer(key, **kwargs) + if not importer: + log.warning("skipping unknown importer: {}".format(key)) + continue + + created = importer.import_data() + log.info("{} -> {}: added {}, updated 0, deleted 0 {} records".format( + self.host_title, self.local_title, created, key)) + if created: + changes[key] = created + except: + if self.commit_host_partial and not self.dry_run: + log.warning("{host} -> {local}: committing partial transaction on host {host} (despite error)".format( + host=self.host_title, local=self.local_title)) + self.commit_host_transaction() + raise + else: + if self.dry_run: + self.rollback_transaction() + else: + self.commit_transaction() + + self.teardown() + return changes + + class FromSQLAlchemyHandler(ImportHandler): """ Handler for imports for which the host data source is represented by a @@ -292,42 +341,7 @@ class ToSQLAlchemyHandler(ImportHandler): self.session = None -class BulkToPostgreSQLHandler(ToSQLAlchemyHandler): +class BulkToPostgreSQLHandler(BulkImportHandler): """ Handler for bulk imports which target PostgreSQL on the local side. """ - - def import_data(self, *keys, **kwargs): - """ - Import all data for the given importer/model keys. 
- """ - # TODO: still need to refactor much of this so can share with parent class - self.import_began = make_utc(datetime.datetime.utcnow(), tzinfo=True) - if 'dry_run' in kwargs: - self.dry_run = kwargs['dry_run'] - self.progress = kwargs.pop('progress', getattr(self, 'progress', None)) - self.warnings = kwargs.pop('warnings', False) - kwargs.update({'dry_run': self.dry_run, - 'progress': self.progress}) - self.setup() - self.begin_transaction() - changes = OrderedDict() - - for key in keys: - importer = self.get_importer(key, **kwargs) - if not importer: - log.warning("skipping unknown importer: {}".format(key)) - continue - - created = importer.import_data() - log.info("{} -> {}: added {}, updated 0, deleted 0 {} records".format( - self.host_title, self.local_title, created, key)) - if created: - changes[key] = created - - if self.dry_run: - self.rollback_transaction() - else: - self.commit_transaction() - self.teardown() - return changes diff --git a/rattail/importing/importers.py b/rattail/importing/importers.py index 612bae84fe4a7f5445e9fe5c8c0f9a75f1373559..1a53a0aad38214eb9c7de671b3fd37b9fc06aa96 100644 --- a/rattail/importing/importers.py +++ b/rattail/importing/importers.py @@ -456,3 +456,54 @@ class FromQuery(Importer): Returns (raw) query results as a sequence. """ return QuerySequence(self.query()) + + +class BulkImporter(Importer): + """ + Base class for bulk data importers; target-specific subclasses override ``flush_create()``. 
+ """ + + def import_data(self, host_data=None, now=None, **kwargs): + self.now = now or make_utc(datetime.datetime.utcnow(), tzinfo=True) + if kwargs: + self._setup(**kwargs) + self.setup() + if host_data is None: + host_data = self.normalize_host_data() + created = self._import_create(host_data) + self.teardown() + return created + + def _import_create(self, data): + count = len(data) + if not count: + return 0 + created = count + + prog = None + if self.progress: + prog = self.progress("Importing {} data".format(self.model_name), count) + + for i, host_data in enumerate(data, 1): + + key = self.get_key(host_data) + self.create_object(key, host_data) + if self.max_create and i >= self.max_create: + log.warning("max of {} *created* records has been reached; stopping now".format(self.max_create)) + created = i + break + + if prog: + prog.update(i) + if prog: + prog.destroy() + + self.flush_create() + return created + + def flush_create(self): + """ + Perform any final steps to "flush" the created data here. Note that + the importer's handler is still responsible for actually committing + changes to the local system, if applicable. + """ diff --git a/rattail/importing/postgresql.py b/rattail/importing/postgresql.py index 607fb94ce59bc717991c149ae7d4458c066909f7..c484d50d18c1e278b5dbc32b35a34ea0ed360ab4 100644 --- a/rattail/importing/postgresql.py +++ b/rattail/importing/postgresql.py @@ -30,14 +30,14 @@ import os import datetime import logging -from rattail.importing.sqlalchemy import ToSQLAlchemy +from rattail.importing import BulkImporter, ToSQLAlchemy from rattail.time import make_utc log = logging.getLogger(__name__) -class BulkToPostgreSQL(ToSQLAlchemy): +class BulkToPostgreSQL(BulkImporter, ToSQLAlchemy): """ Base class for bulk data importers which target PostgreSQL on the local side. 
""" @@ -55,44 +55,6 @@ class BulkToPostgreSQL(ToSQLAlchemy): os.remove(self.data_path) self.data_buffer = None - def import_data(self, host_data=None, now=None, **kwargs): - self.now = now or make_utc(datetime.datetime.utcnow(), tzinfo=True) - if kwargs: - self._setup(**kwargs) - self.setup() - if host_data is None: - host_data = self.normalize_host_data() - created = self._import_create(host_data) - self.teardown() - return created - - def _import_create(self, data): - count = len(data) - if not count: - return 0 - created = count - - prog = None - if self.progress: - prog = self.progress("Importing {} data".format(self.model_name), count) - - for i, host_data in enumerate(data, 1): - - key = self.get_key(host_data) - self.create_object(key, host_data) - if self.max_create and i >= self.max_create: - log.warning("max of {} *created* records has been reached; stopping now".format(self.max_create)) - created = i - break - - if prog: - prog.update(i) - if prog: - prog.destroy() - - self.commit_create() - return created - def create_object(self, key, data): data = self.prep_data_for_postgres(data) self.data_buffer.write('{}\n'.format('\t'.join([data[field] for field in self.fields])).encode('utf-8')) @@ -121,7 +83,7 @@ class BulkToPostgreSQL(ToSQLAlchemy): return unicode(value) - def commit_create(self): + def flush_create(self): log.info("copying {} data from buffer to PostgreSQL".format(self.model_name)) self.data_buffer.close() self.data_buffer = open(self.data_path, 'rb') diff --git a/rattail/importing/rattail_bulk.py b/rattail/importing/rattail_bulk.py index 2dcf93c32b2b82bfe3f8449cf3694df32e0aa4d5..4d60ce4ffc26872f92c82c43ab4c8219b14593b6 100644 --- a/rattail/importing/rattail_bulk.py +++ b/rattail/importing/rattail_bulk.py @@ -31,7 +31,7 @@ from rattail.util import OrderedDict from rattail.importing.rattail import FromRattailToRattail, FromRattail -class BulkFromRattailToRattail(FromRattailToRattail, importing.BulkToPostgreSQLHandler): +class 
BulkFromRattailToRattail(FromRattailToRattail, importing.BulkImportHandler): """ Handler for Rattail -> Rattail bulk data import. """ diff --git a/rattail/tests/importing/test_handlers.py b/rattail/tests/importing/test_handlers.py index edada94f1fb108d0e9b8e929cdf4bc47bcb8c9fb..0797cebd7bd1fa9f4f1de333748143ad8a306d72 100644 --- a/rattail/tests/importing/test_handlers.py +++ b/rattail/tests/importing/test_handlers.py @@ -19,12 +19,12 @@ from rattail.tests.importing.test_importers import MockImporter from rattail.tests.importing.test_postgresql import MockBulkImporter -class TestImportHandler(unittest.TestCase): +class ImportHandlerBattery(ImporterTester): def test_init(self): # vanilla - handler = handlers.ImportHandler() + handler = self.handler_class() self.assertEqual(handler.importers, {}) self.assertEqual(handler.get_importers(), {}) self.assertEqual(handler.get_importer_keys(), []) @@ -32,34 +32,34 @@ class TestImportHandler(unittest.TestCase): self.assertFalse(handler.commit_host_partial) # with config - handler = handlers.ImportHandler() + handler = self.handler_class() self.assertIsNone(handler.config) config = RattailConfig() - handler = handlers.ImportHandler(config=config) + handler = self.handler_class(config=config) self.assertIs(handler.config, config) # dry run - handler = handlers.ImportHandler() + handler = self.handler_class() self.assertFalse(handler.dry_run) - handler = handlers.ImportHandler(dry_run=True) + handler = self.handler_class(dry_run=True) self.assertTrue(handler.dry_run) # extra kwarg - handler = handlers.ImportHandler() + handler = self.handler_class() self.assertRaises(AttributeError, getattr, handler, 'foo') - handler = handlers.ImportHandler(foo='bar') + handler = self.handler_class(foo='bar') self.assertEqual(handler.foo, 'bar') def test_get_importer(self): get_importers = Mock(return_value={'foo': Importer}) # no importers - handler = handlers.ImportHandler() + handler = self.make_handler() 
self.assertIsNone(handler.get_importer('foo')) # no config - with patch.object(handlers.ImportHandler, 'get_importers', get_importers): - handler = handlers.ImportHandler() + with patch.object(self.handler_class, 'get_importers', get_importers): + handler = self.handler_class() importer = handler.get_importer('foo') self.assertIs(type(importer), Importer) self.assertIsNone(importer.config) @@ -67,26 +67,26 @@ class TestImportHandler(unittest.TestCase): # with config config = RattailConfig() - with patch.object(handlers.ImportHandler, 'get_importers', get_importers): - handler = handlers.ImportHandler(config=config) + with patch.object(self.handler_class, 'get_importers', get_importers): + handler = self.handler_class(config=config) importer = handler.get_importer('foo') self.assertIs(type(importer), Importer) self.assertIs(importer.config, config) self.assertIs(importer.handler, handler) # dry run - with patch.object(handlers.ImportHandler, 'get_importers', get_importers): - handler = handlers.ImportHandler() + with patch.object(self.handler_class, 'get_importers', get_importers): + handler = self.handler_class() importer = handler.get_importer('foo') self.assertFalse(importer.dry_run) - with patch.object(handlers.ImportHandler, 'get_importers', get_importers): - handler = handlers.ImportHandler(dry_run=True) + with patch.object(self.handler_class, 'get_importers', get_importers): + handler = self.handler_class(dry_run=True) importer = handler.get_importer('foo') self.assertTrue(handler.dry_run) # host title - with patch.object(handlers.ImportHandler, 'get_importers', get_importers): - handler = handlers.ImportHandler() + with patch.object(self.handler_class, 'get_importers', get_importers): + handler = self.handler_class() importer = handler.get_importer('foo') self.assertIsNone(importer.host_system_title) handler.host_title = "Foo" @@ -94,8 +94,8 @@ class TestImportHandler(unittest.TestCase): self.assertEqual(importer.host_system_title, "Foo") # extra kwarg - 
with patch.object(handlers.ImportHandler, 'get_importers', get_importers): - handler = handlers.ImportHandler() + with patch.object(self.handler_class, 'get_importers', get_importers): + handler = self.handler_class() importer = handler.get_importer('foo') self.assertRaises(AttributeError, getattr, importer, 'bar') importer = handler.get_importer('foo', bar='baz') @@ -104,15 +104,15 @@ class TestImportHandler(unittest.TestCase): def test_get_importer_kwargs(self): # empty by default - handler = handlers.ImportHandler() + handler = self.make_handler() self.assertEqual(handler.get_importer_kwargs('foo'), {}) # extra kwargs are preserved - handler = handlers.ImportHandler() + handler = self.make_handler() self.assertEqual(handler.get_importer_kwargs('foo', bar='baz'), {'bar': 'baz'}) def test_begin_transaction(self): - handler = handlers.ImportHandler() + handler = self.make_handler() with patch.object(handler, 'begin_host_transaction') as begin_host: with patch.object(handler, 'begin_local_transaction') as begin_local: handler.begin_transaction() @@ -120,15 +120,15 @@ class TestImportHandler(unittest.TestCase): begin_local.assert_called_once_with() def test_begin_host_transaction(self): - handler = handlers.ImportHandler() + handler = self.make_handler() handler.begin_host_transaction() def test_begin_local_transaction(self): - handler = handlers.ImportHandler() + handler = self.make_handler() handler.begin_local_transaction() def test_commit_transaction(self): - handler = handlers.ImportHandler() + handler = self.make_handler() with patch.object(handler, 'commit_host_transaction') as commit_host: with patch.object(handler, 'commit_local_transaction') as commit_local: handler.commit_transaction() @@ -136,15 +136,15 @@ class TestImportHandler(unittest.TestCase): commit_local.assert_called_once_with() def test_commit_host_transaction(self): - handler = handlers.ImportHandler() + handler = self.make_handler() handler.commit_host_transaction() def 
test_commit_local_transaction(self): - handler = handlers.ImportHandler() + handler = self.make_handler() handler.commit_local_transaction() def test_rollback_transaction(self): - handler = handlers.ImportHandler() + handler = self.make_handler() with patch.object(handler, 'rollback_host_transaction') as rollback_host: with patch.object(handler, 'rollback_local_transaction') as rollback_local: handler.rollback_transaction() @@ -152,24 +152,22 @@ class TestImportHandler(unittest.TestCase): rollback_local.assert_called_once_with() def test_rollback_host_transaction(self): - handler = handlers.ImportHandler() + handler = self.make_handler() handler.rollback_host_transaction() def test_rollback_local_transaction(self): - handler = handlers.ImportHandler() + handler = self.make_handler() handler.rollback_local_transaction() def test_import_data(self): - - # normal - handler = handlers.ImportHandler() + handler = self.make_handler() result = handler.import_data() self.assertEqual(result, {}) def test_import_data_dry_run(self): # as init kwarg - handler = handlers.ImportHandler(dry_run=True) + handler = self.make_handler(dry_run=True) with patch.object(handler, 'commit_transaction') as commit: with patch.object(handler, 'rollback_transaction') as rollback: handler.import_data() @@ -178,7 +176,7 @@ class TestImportHandler(unittest.TestCase): self.assertTrue(handler.dry_run) # as import kwarg - handler = handlers.ImportHandler() + handler = self.make_handler() with patch.object(handler, 'commit_transaction') as commit: with patch.object(handler, 'rollback_transaction') as rollback: handler.import_data(dry_run=True) @@ -187,11 +185,10 @@ class TestImportHandler(unittest.TestCase): self.assertTrue(handler.dry_run) def test_import_data_invalid_model(self): + handler = self.make_handler() importer = Mock() importer.import_data.return_value = [], [], [] FooImporter = Mock(return_value=importer) - - handler = handlers.ImportHandler() handler.importers = {'Foo': FooImporter} 
handler.import_data('Foo') @@ -206,10 +203,9 @@ class TestImportHandler(unittest.TestCase): self.assertFalse(importer.called) def test_import_data_with_changes(self): + handler = self.make_handler() importer = Mock() FooImporter = Mock(return_value=importer) - - handler = handlers.ImportHandler() handler.importers = {'Foo': FooImporter} importer.import_data.return_value = [], [], [] @@ -223,11 +219,10 @@ class TestImportHandler(unittest.TestCase): process.assert_called_once_with({'Foo': ([1], [2], [3])}) def test_import_data_commit_host_partial(self): + handler = self.make_handler() importer = Mock() importer.import_data.side_effect = ValueError FooImporter = Mock(return_value=importer) - - handler = handlers.ImportHandler() handler.importers = {'Foo': FooImporter} handler.commit_host_partial = False @@ -240,6 +235,47 @@ class TestImportHandler(unittest.TestCase): self.assertRaises(ValueError, handler.import_data, 'Foo') commit.assert_called_once_with() + +class BulkImportHandlerBattery(ImportHandlerBattery): + + def test_import_data_invalid_model(self): + handler = self.make_handler() + importer = Mock() + importer.import_data.return_value = 0 + FooImporter = Mock(return_value=importer) + handler.importers = {'Foo': FooImporter} + + handler.import_data('Foo') + self.assertEqual(FooImporter.call_count, 1) + importer.import_data.assert_called_once_with() + + FooImporter.reset_mock() + importer.reset_mock() + + handler.import_data('Missing') + self.assertFalse(FooImporter.called) + self.assertFalse(importer.called) + + def test_import_data_with_changes(self): + handler = self.make_handler() + importer = Mock() + FooImporter = Mock(return_value=importer) + handler.importers = {'Foo': FooImporter} + + importer.import_data.return_value = 0 + with patch.object(handler, 'process_changes') as process: + handler.import_data('Foo') + self.assertFalse(process.called) + + importer.import_data.return_value = 3 + with patch.object(handler, 'process_changes') as process: + 
handler.import_data('Foo') + self.assertFalse(process.called) + + +class TestImportHandler(unittest.TestCase, ImportHandlerBattery): + handler_class = handlers.ImportHandler + @patch('rattail.importing.handlers.send_email') def test_process_changes_sends_email(self, send_email): handler = handlers.ImportHandler() @@ -265,6 +301,10 @@ class TestImportHandler(unittest.TestCase): self.assertEqual(send_email.call_count, 1) +class TestBulkImportHandler(unittest.TestCase, BulkImportHandlerBattery): + handler_class = handlers.BulkImportHandler + + ###################################################################### # fake import handler, tested mostly for basic coverage ###################################################################### @@ -566,7 +606,7 @@ class MockBulkImportHandler(handlers.BulkToPostgreSQLHandler): return Session() -class TestBulkImportHandler(RattailTestCase, ImporterTester): +class TestBulkImportHandlerOld(RattailTestCase, ImporterTester): importer_class = MockBulkImporter diff --git a/rattail/tests/importing/test_importers.py b/rattail/tests/importing/test_importers.py index 934c4a918ed73b3a24dda4fc53eeb385cebe87ef..d814a763de9bf08003e8398a26568e4c81c71daa 100644 --- a/rattail/tests/importing/test_importers.py +++ b/rattail/tests/importing/test_importers.py @@ -4,7 +4,7 @@ from __future__ import unicode_literals, absolute_import from unittest import TestCase -from mock import Mock, patch +from mock import Mock, patch, call from rattail.db import model from rattail.db.util import QuerySequence @@ -13,6 +13,49 @@ from rattail.tests import NullProgress, RattailTestCase from rattail.tests.importing import ImporterTester +class ImporterBattery(ImporterTester): + """ + Battery of tests which can hopefully be run for any non-bulk importer. 
+ """ + + def test_import_data_empty(self): + importer = self.make_importer() + result = importer.import_data() + self.assertEqual(result, {}) + + def test_import_data_dry_run(self): + importer = self.make_importer() + self.assertFalse(importer.dry_run) + importer.import_data(dry_run=True) + self.assertTrue(importer.dry_run) + + def test_import_data_create(self): + importer = self.make_importer() + with patch.object(importer, 'get_key', lambda k: k): + with patch.object(importer, 'create_object') as create: + importer.import_data(host_data=[1, 2, 3]) + self.assertEqual(create.call_args_list, [ + call(1, 1), call(2, 2), call(3, 3)]) + + def test_import_data_max_create(self): + importer = self.make_importer() + with patch.object(importer, 'get_key', lambda k: k): + with patch.object(importer, 'create_object') as create: + importer.import_data(host_data=[1, 2, 3], max_create=1) + self.assertEqual(create.call_args_list, [call(1, 1)]) + + +class BulkImporterBattery(ImporterBattery): + """ + Battery of tests which can hopefully be run for any bulk importer. + """ + + def test_import_data_empty(self): + importer = self.make_importer() + result = importer.import_data() + self.assertEqual(result, 0) + + class TestImporter(TestCase): def test_init(self): @@ -164,6 +207,11 @@ class TestFromQuery(RattailTestCase): self.assertIsInstance(objects, QuerySequence) +class TestBulkImporter(TestCase, BulkImporterBattery): + importer_class = importers.BulkImporter + + + ###################################################################### # fake importer class, tested mostly for basic coverage ######################################################################