diff --git a/rattail/commands/importing.py b/rattail/commands/importing.py index c2887f92c6bfae597804fbb4f123874c06355093..2b1ccdebc2c863c2c6f57e4640177640a1ffe026 100644 --- a/rattail/commands/importing.py +++ b/rattail/commands/importing.py @@ -39,6 +39,7 @@ class ImportSubcommand(Subcommand): """ Base class for subcommands which use the (new) data importing system. """ + handler_spec = None # TODO: move this into Subcommand or something.. parent_name = None @@ -52,6 +53,8 @@ class ImportSubcommand(Subcommand): Subclasses must override this, and return a callable that creates an import handler instance which the command should use. """ + if self.handler_spec: + return load_object(self.handler_spec) raise NotImplementedError def get_handler(self, **kwargs): @@ -169,6 +172,7 @@ class ImportSubcommand(Subcommand): 'max_delete': args.max_delete, 'max_total': args.max_total, 'progress': self.progress, + 'args': args, } handler.import_data(*models, **kwargs) diff --git a/rattail/importing/__init__.py b/rattail/importing/__init__.py index 231990e47a3266ee2550a8d226620b667ac19a6e..58cec3f799edc100ec79e756dfc11f005322e727 100644 --- a/rattail/importing/__init__.py +++ b/rattail/importing/__init__.py @@ -30,5 +30,5 @@ from .importers import Importer, FromQuery from .sqlalchemy import FromSQLAlchemy, ToSQLAlchemy from .postgresql import BulkToPostgreSQL from .handlers import ImportHandler, FromSQLAlchemyHandler, ToSQLAlchemyHandler, BulkToPostgreSQLHandler -from .rattail import ToRattailHandler +from .rattail import FromRattailHandler, ToRattailHandler from . import model diff --git a/rattail/importing/rattail.py b/rattail/importing/rattail.py index 65ec820e973d47827266ad765a785acb3437dcf0..e8e6c2af542b3957754c80daf5a3ce10da6d6a77 100644 --- a/rattail/importing/rattail.py +++ b/rattail/importing/rattail.py @@ -33,6 +33,16 @@ from rattail.importing.sqlalchemy import FromSQLAlchemy from rattail.util import OrderedDict +class FromRattailHandler(FromSQLAlchemyHandler): + """ + Base class for import handlers which target a Rattail database on the local side. + """ + host_title = "Rattail" + + def make_host_session(self): + return Session() + + class ToRattailHandler(ToSQLAlchemyHandler): """ Base class for import handlers which target a Rattail database on the local side. @@ -43,7 +53,7 @@ class ToRattailHandler(ToSQLAlchemyHandler): return Session() -class FromRattailToRattail(FromSQLAlchemyHandler, ToRattailHandler): +class FromRattailToRattail(FromRattailHandler, ToRattailHandler): """ Handler for Rattail -> Rattail data import. """ diff --git a/rattail/tests/commands/test_importing.py b/rattail/tests/commands/test_importing.py index 22939f2f49a531028d2cae4078a3856aafdacba8..f664eb1e06ac9e389cfc8cbe4a422de5e4163baf 100644 --- a/rattail/tests/commands/test_importing.py +++ b/rattail/tests/commands/test_importing.py @@ -35,6 +35,9 @@ class TestImportSubcommandBasics(TestCase): def test_get_handler_factory(self): command = importing.ImportSubcommand() self.assertRaises(NotImplementedError, command.get_handler_factory) + command.handler_spec = 'rattail.importing.rattail:FromRattailToRattail' + factory = command.get_handler_factory() + self.assertIs(factory, FromRattailToRattail) def test_get_handler(self): diff --git a/rattail/tests/importing/__init__.py b/rattail/tests/importing/__init__.py index 33c95c3fbae5c6f77edec3a0a6688c28685fdacd..d157956682052732bad6fc2403f0695f65b9ce4f 100644 --- a/rattail/tests/importing/__init__.py +++ b/rattail/tests/importing/__init__.py @@ -2,87 +2,4 @@ from __future__ import unicode_literals, absolute_import -import copy -from contextlib import contextmanager - -from mock import patch - -from rattail.tests import NullProgress - - -class ImporterTester(object): - """ - Mixin for importer test suites. - """ - importer_class = None - sample_data = {} - - def make_importer(self, **kwargs): - if 'config' not in kwargs and hasattr(self, 'config'): - kwargs['config'] = self.config - kwargs.setdefault('progress', NullProgress) - return self.importer_class(**kwargs) - - def copy_data(self): - return copy.deepcopy(self.sample_data) - - @contextmanager - def host_data(self, data): - self._host_data = data - host_data = [self.importer.normalize_host_object(obj) for obj in data.itervalues()] - with patch.object(self.importer, 'normalize_host_data') as normalize: - normalize.return_value = host_data - yield - - @contextmanager - def local_data(self, data): - self._local_data = data - local_data = {} - for key, obj in data.iteritems(): - normal = self.importer.normalize_local_object(obj) - local_data[self.importer.get_key(normal)] = {'object': obj, 'data': normal} - with patch.object(self.importer, 'cache_local_data') as cache: - cache.return_value = local_data - yield - - def import_data(self, **kwargs): - self.result = self.importer.import_data(**kwargs) - - def assert_import_created(self, *keys): - created, updated, deleted = self.result - self.assertEqual(len(created), len(keys)) - for key in keys: - key = self.importer.get_key(self._host_data[key]) - found = False - for local_object, host_data in created: - if self.importer.get_key(host_data) == key: - found = True - break - if not found: - raise self.failureException("Key {} not created when importing with {}".format(key, self.importer)) - - def assert_import_updated(self, *keys): - created, updated, deleted = self.result - self.assertEqual(len(updated), len(keys)) - for key in keys: - key = self.importer.get_key(self._host_data[key]) - found = False - for local_object, local_data, host_data in updated: - if self.importer.get_key(local_data) == key: - found = True - break - if not found: - raise self.failureException("Key {} not updated when importing with {}".format(key, self.importer)) - - def assert_import_deleted(self, *keys): - created, updated, deleted = self.result - self.assertEqual(len(deleted), len(keys)) - for key in keys: - key = self.importer.get_key(self._local_data[key]) - found = False - for local_object, local_data in deleted: - if self.importer.get_key(local_data) == key: - found = True - break - if not found: - raise self.failureException("Key {} not deleted when importing with {}".format(key, self.importer)) +from .lib import ImporterTester diff --git a/rattail/tests/importing/lib.py b/rattail/tests/importing/lib.py new file mode 100644 index 0000000000000000000000000000000000000000..628acd3936d78a811e8aac857ec9e3c913322c6b --- /dev/null +++ b/rattail/tests/importing/lib.py @@ -0,0 +1,100 @@ +# -*- coding: utf-8 -*- + +from __future__ import unicode_literals, absolute_import + +import copy +from contextlib import contextmanager + +from mock import patch + +from rattail.tests import NullProgress + + +class ImporterTester(object): + """ + Mixin for importer test suites. + """ + handler_class = None + importer_class = None + sample_data = {} + + def make_handler(self, **kwargs): + if 'config' not in kwargs and hasattr(self, 'config'): + kwargs['config'] = self.config + return self.handler_class(**kwargs) + + def make_importer(self, **kwargs): + if 'config' not in kwargs and hasattr(self, 'config'): + kwargs['config'] = self.config + kwargs.setdefault('progress', NullProgress) + return self.importer_class(**kwargs) + + def copy_data(self): + return copy.deepcopy(self.sample_data) + + @contextmanager + def host_data(self, data): + self._host_data = data + host_data = [self.importer.normalize_host_object(obj) for obj in data.itervalues()] + with patch.object(self.importer, 'normalize_host_data') as normalize: + normalize.return_value = host_data + yield + + @contextmanager + def local_data(self, data): + self._local_data = data + local_data = {} + for key, obj in data.iteritems(): + normal = self.importer.normalize_local_object(obj) + local_data[self.importer.get_key(normal)] = {'object': obj, 'data': normal} + with patch.object(self.importer, 'cache_local_data') as cache: + cache.return_value = local_data + yield + + def import_data(self, host_data=None, local_data=None, **kwargs): + if host_data is None: + host_data = self.sample_data + if local_data is None: + local_data = self.sample_data + with self.host_data(host_data): + with self.local_data(local_data): + self.result = self.importer.import_data(**kwargs) + + def assert_import_created(self, *keys): + created, updated, deleted = self.result + self.assertEqual(len(created), len(keys)) + for key in keys: + key = self.importer.get_key(self._host_data[key]) + found = False + for local_object, host_data in created: + if self.importer.get_key(host_data) == key: + found = True + break + if not found: + raise self.failureException("Key {} not created when importing with {}".format(key, self.importer)) + + def assert_import_updated(self, *keys): + created, updated, deleted = self.result + self.assertEqual(len(updated), len(keys)) + for key in keys: + key = self.importer.get_key(self._host_data[key]) + found = False + for local_object, local_data, host_data in updated: + if self.importer.get_key(local_data) == key: + found = True + break + if not found: + raise self.failureException("Key {} not updated when importing with {}".format(key, self.importer)) + + def assert_import_deleted(self, *keys): + created, updated, deleted = self.result + self.assertEqual(len(deleted), len(keys)) + for key in keys: + key = self.importer.get_key(self._local_data[key]) + found = False + for local_object, local_data in deleted: + if self.importer.get_key(local_data) == key: + found = True + break + if not found: + raise self.failureException("Key {} not deleted when importing with {}".format(key, self.importer)) diff --git a/rattail/tests/importing/test_importers.py b/rattail/tests/importing/test_importers.py index f7f03f0b8471dc7e51f7a614854f52346d2428e8..934c4a918ed73b3a24dda4fc53eeb385cebe87ef 100644 --- a/rattail/tests/importing/test_importers.py +++ b/rattail/tests/importing/test_importers.py @@ -204,17 +204,13 @@ class TestMockImporter(ImporterTester, TestCase): def test_create(self): local = self.copy_data() del local['32oz'] - with self.host_data(self.sample_data): - with self.local_data(local): - self.import_data() + self.import_data(local_data=local) self.assert_import_created('32oz') self.assert_import_updated() self.assert_import_deleted() def test_create_empty(self): - with self.host_data({}): - with self.local_data({}): - self.import_data() + self.import_data(host_data={}, local_data={}) self.assert_import_created() self.assert_import_updated() self.assert_import_deleted() @@ -222,9 +218,7 @@ class TestMockImporter(ImporterTester, TestCase): def test_update(self): local = self.copy_data() local['16oz']['description'] = "wrong description" - with self.host_data(self.sample_data): - with self.local_data(local): - self.import_data() + self.import_data(local_data=local) self.assert_import_created() self.assert_import_updated('16oz') self.assert_import_deleted() @@ -232,9 +226,7 @@ class TestMockImporter(ImporterTester, TestCase): def test_delete(self): local = self.copy_data() local['bogus'] = {'upc': '00000000000000', 'description': "Delete Me"} - with self.host_data(self.sample_data): - with self.local_data(local): - self.import_data(delete=True) + self.import_data(local_data=local, delete=True) self.assert_import_created() self.assert_import_updated() self.assert_import_deleted('bogus') @@ -242,9 +234,7 @@ class TestMockImporter(ImporterTester, TestCase): def test_duplicate(self): host = self.copy_data() host['32oz-dupe'] = host['32oz'] - with self.host_data(host): - with self.local_data(self.sample_data): - self.import_data() + self.import_data(host_data=host) self.assert_import_created() self.assert_import_updated() self.assert_import_deleted() @@ -253,9 +243,7 @@ class TestMockImporter(ImporterTester, TestCase): local = self.copy_data() del local['16oz'] del local['1gal'] - with self.host_data(self.sample_data): - with self.local_data(local): - self.import_data(max_create=1) + self.import_data(local_data=local, max_create=1) self.assert_import_created('16oz') self.assert_import_updated() self.assert_import_deleted() @@ -264,9 +252,7 @@ class TestMockImporter(ImporterTester, TestCase): local = self.copy_data() del local['16oz'] del local['1gal'] - with self.host_data(self.sample_data): - with self.local_data(local): - self.import_data(max_total=1) + self.import_data(local_data=local, max_total=1) self.assert_import_created('16oz') self.assert_import_updated() self.assert_import_deleted() @@ -275,9 +261,7 @@ class TestMockImporter(ImporterTester, TestCase): local = self.copy_data() local['16oz']['description'] = "wrong" local['1gal']['description'] = "wrong" - with self.host_data(self.sample_data): - with self.local_data(local): - self.import_data(max_update=1) + self.import_data(local_data=local, max_update=1) self.assert_import_created() self.assert_import_updated('16oz') self.assert_import_deleted() @@ -286,9 +270,7 @@ class TestMockImporter(ImporterTester, TestCase): local = self.copy_data() local['16oz']['description'] = "wrong" local['1gal']['description'] = "wrong" - with self.host_data(self.sample_data): - with self.local_data(local): - self.import_data(max_total=1) + self.import_data(local_data=local, max_total=1) self.assert_import_created() self.assert_import_updated('16oz') self.assert_import_deleted() @@ -297,9 +279,7 @@ class TestMockImporter(ImporterTester, TestCase): local = self.copy_data() local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"} local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"} - with self.host_data(self.sample_data): - with self.local_data(local): - self.import_data(delete=True, max_delete=1) + self.import_data(local_data=local, delete=True, max_delete=1) self.assert_import_created() self.assert_import_updated() self.assert_import_deleted('bogus1') @@ -308,9 +288,7 @@ class TestMockImporter(ImporterTester, TestCase): local = self.copy_data() local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"} local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"} - with self.host_data(self.sample_data): - with self.local_data(local): - self.import_data(delete=True, max_total=1) + self.import_data(local_data=local, delete=True, max_total=1) self.assert_import_created() self.assert_import_updated() self.assert_import_deleted('bogus1') @@ -322,9 +300,7 @@ class TestMockImporter(ImporterTester, TestCase): local['1gal']['description'] = "wrong" local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"} local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"} - with self.host_data(self.sample_data): - with self.local_data(local): - self.import_data(delete=True, max_total=3) + self.import_data(local_data=local, delete=True, max_total=3) self.assert_import_created('16oz') self.assert_import_updated('32oz', '1gal') self.assert_import_deleted() diff --git a/rattail/tests/importing/test_rattail.py b/rattail/tests/importing/test_rattail.py index a5d85a04d6668ddd24548caf21d91e46a88961c8..aaa5a5b8de490eff8dae744231b16be976df95ad 100644 --- a/rattail/tests/importing/test_rattail.py +++ b/rattail/tests/importing/test_rattail.py @@ -11,6 +11,7 @@ from fixture import TempIO from rattail.db import model, Session, SessionBase, auth from rattail.importing import rattail as rattail_importing from rattail.tests import RattailMixin, RattailTestCase +from rattail.tests.importing import ImporterTester class DualRattailMixin(RattailMixin): @@ -43,10 +44,18 @@ class DualRattailTestCase(DualRattailMixin, TestCase): pass -class TestFromRattailToRattail(DualRattailTestCase): +class TestFromRattailHandler(RattailTestCase, ImporterTester): + handler_class = rattail_importing.FromRattailHandler + + def test_make_host_session(self): + handler = self.make_handler() + session = handler.make_host_session() + self.assertIsInstance(session, SessionBase) + self.assertIs(session.bind, self.config.rattail_engine) + - def make_handler(self, **kwargs): - return rattail_importing.FromRattailToRattail(self.config, **kwargs) +class TestFromRattailToRattail(DualRattailTestCase, ImporterTester): + handler_class = rattail_importing.FromRattailToRattail def test_host_title(self): handler = self.make_handler()