Changeset - 7a2ef3551801
[Not reviewed]
0 7 0
Lance Edgar - 8 years ago 2016-05-16 21:09:19
ledgar@sacfoodcoop.com
Add `BulkImporter` and `BulkImportHandler` base classes
7 files changed with 240 insertions and 125 deletions:
0 comments (0 inline, 0 general)
rattail/importing/__init__.py
Show inline comments
 
@@ -26,9 +26,9 @@ Data Importing Framework
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from .importers import Importer, FromQuery
 
from .importers import Importer, FromQuery, BulkImporter
 
from .sqlalchemy import FromSQLAlchemy, ToSQLAlchemy
 
from .postgresql import BulkToPostgreSQL
 
from .handlers import ImportHandler, FromSQLAlchemyHandler, ToSQLAlchemyHandler, BulkToPostgreSQLHandler
 
from .handlers import ImportHandler, BulkImportHandler, FromSQLAlchemyHandler, ToSQLAlchemyHandler
 
from .rattail import FromRattailHandler, ToRattailHandler
 
from . import model
rattail/importing/handlers.py
Show inline comments
 
@@ -229,6 +229,55 @@ class ImportHandler(object):
 
        log.info("warning email was sent for {} -> {} import".format(self.host_title, self.local_title))
 

	
 

	
 
class BulkImportHandler(ImportHandler):
    """
    Base class for bulk import handlers.

    Only a single "created" result is tracked per importer; the summary which
    is logged for each importer always reports zero updates and deletions.
    """

    def import_data(self, *keys, **kwargs):
        """
        Import all data for the given importer/model keys.

        Keys are processed in the order given.  A key for which
        ``get_importer()`` returns nothing is logged and skipped.

        Recognized kwargs:

        * ``dry_run`` - when present, replaces the handler's ``dry_run`` flag.
        * ``progress`` - removed from kwargs; falls back to the handler's
          existing ``progress`` attribute, or ``None``.
        * ``warnings`` - removed from kwargs; defaults to ``False``.

        Whatever kwargs remain, plus the effective ``dry_run`` and
        ``progress`` values, are handed to ``get_importer()``.

        Returns an ``OrderedDict`` which maps each key to the value returned
        by its importer's ``import_data()``, for those importers which
        returned something truthy.

        If anything is raised while the importers run, the error is re-raised.
        Before that happens, the host transaction is committed when
        ``commit_host_partial`` is set and this is not a dry run.  Note that
        ``teardown()`` is not reached in that case.
        """
        # TODO: still need to refactor much of this so can share with parent class
        self.import_began = make_utc(datetime.datetime.utcnow(), tzinfo=True)
        if 'dry_run' in kwargs:
            self.dry_run = kwargs['dry_run']
        self.progress = kwargs.pop('progress', getattr(self, 'progress', None))
        self.warnings = kwargs.pop('warnings', False)

        # importers get the handler's effective settings, not the caller's
        kwargs['dry_run'] = self.dry_run
        kwargs['progress'] = self.progress

        self.setup()
        self.begin_transaction()
        results = OrderedDict()

        try:
            for model_key in keys:
                importer = self.get_importer(model_key, **kwargs)
                if importer:
                    count = importer.import_data()
                    log.info("{} -> {}: added {}, updated 0, deleted 0 {} records".format(
                        self.host_title, self.local_title, count, model_key))
                    # only importers which actually created something are reported
                    if count:
                        results[model_key] = count
                else:
                    log.warning("skipping unknown importer: {}".format(model_key))

        except:
            # deliberately catches everything; the error is always re-raised
            # below, this only gives the host side a chance to commit first
            if self.commit_host_partial and not self.dry_run:
                log.warning("{host} -> {local}: committing partial transaction on host {host} (despite error)".format(
                    host=self.host_title, local=self.local_title))
                self.commit_host_transaction()
            raise

        # only reached when every importer finished without error
        finalize = self.rollback_transaction if self.dry_run else self.commit_transaction
        finalize()

        self.teardown()
        return results
 

	
 

	
 
class FromSQLAlchemyHandler(ImportHandler):
 
    """
 
    Handler for imports for which the host data source is represented by a
 
@@ -292,42 +341,7 @@ class ToSQLAlchemyHandler(ImportHandler):
 
        self.session = None
 

	
 

	
 
class BulkToPostgreSQLHandler(ToSQLAlchemyHandler):
 
class BulkToPostgreSQLHandler(BulkImportHandler):
 
    """
 
    Handler for bulk imports which target PostgreSQL on the local side.
 
    """
 

	
 
    def import_data(self, *keys, **kwargs):
 
        """
 
        Import all data for the given importer/model keys.
 
        """
 
        # TODO: still need to refactor much of this so can share with parent class
 
        self.import_began = make_utc(datetime.datetime.utcnow(), tzinfo=True)
 
        if 'dry_run' in kwargs:
 
            self.dry_run = kwargs['dry_run']
 
        self.progress = kwargs.pop('progress', getattr(self, 'progress', None))
 
        self.warnings = kwargs.pop('warnings', False)
 
        kwargs.update({'dry_run': self.dry_run,
 
                       'progress': self.progress})
 
        self.setup()
 
        self.begin_transaction()
 
        changes = OrderedDict()
 

	
 
        for key in keys:
 
            importer = self.get_importer(key, **kwargs)
 
            if not importer:
 
                log.warning("skipping unknown importer: {}".format(key))
 
                continue
 

	
 
            created = importer.import_data()
 
            log.info("{} -> {}: added {}, updated 0, deleted 0 {} records".format(
 
                self.host_title, self.local_title, created, key))
 
            if created:
 
                changes[key] = created
 

	
 
        if self.dry_run:
 
            self.rollback_transaction()
 
        else:
 
            self.commit_transaction()
 
        self.teardown()
 
        return changes
rattail/importing/importers.py
Show inline comments
 
@@ -456,3 +456,54 @@ class FromQuery(Importer):
 
        Returns (raw) query results as a sequence.
 
        """
 
        return QuerySequence(self.query())
 

	
 

	
 
class BulkImporter(Importer):
    """
    Base class for bulk data importers.

    Every host record is handed to ``create_object()``; no attempt is made
    here to update or delete anything.  Once all records have been handed
    over, ``flush_create()`` is invoked so subclasses may finalize the batch.
    """

    def import_data(self, host_data=None, now=None, **kwargs):
        """
        Import the given host data, and return the number of records created.

        :param host_data: Sequence of host records.  When ``None``, the result
           of ``normalize_host_data()`` is used instead.

        :param now: Value to store as ``self.now``.  When falsy, the current
           time (as given by ``make_utc()``) is used.

        :param kwargs: If any are given, they are passed to ``_setup()``
           before ``setup()`` is called.
        """
        if not now:
            now = make_utc(datetime.datetime.utcnow(), tzinfo=True)
        self.now = now

        if kwargs:
            self._setup(**kwargs)
        self.setup()

        records = self.normalize_host_data() if host_data is None else host_data
        created = self._import_create(records)

        self.teardown()
        return created

    def _import_create(self, data):
        """
        Hand each record in ``data`` to ``create_object()``, then call
        ``flush_create()``.

        Returns the number of records handed over: ``len(data)``, unless
        ``max_create`` cut the loop short.  Empty data returns zero right
        away, in which case ``flush_create()`` is not called.
        """
        total = len(data)
        if not total:
            return 0

        prog = self.progress("Importing {} data".format(self.model_name), total) if self.progress else None

        # assume everything gets created, unless max_create says otherwise
        created = total
        for position, record in enumerate(data, 1):
            self.create_object(self.get_key(record), record)

            reached_max = self.max_create and position >= self.max_create
            if reached_max:
                log.warning("max of {} *created* records has been reached; stopping now".format(self.max_create))
                created = position
                break

            if prog:
                prog.update(position)

        if prog:
            prog.destroy()

        self.flush_create()
        return created

    def flush_create(self):
        """
        Hook for any final steps required to "flush" the created data.  It is
        called by ``_import_create()`` after the last record has been handed
        to ``create_object()``.  Note that the importer's handler is still
        responsible for actually committing changes to the local system, if
        applicable.
        """
rattail/importing/postgresql.py
Show inline comments
 
@@ -30,14 +30,14 @@ import os
 
import datetime
 
import logging
 

	
 
from rattail.importing.sqlalchemy import ToSQLAlchemy
 
from rattail.importing import BulkImporter, ToSQLAlchemy
 
from rattail.time import make_utc
 

	
 

	
 
log = logging.getLogger(__name__)
 

	
 

	
 
class BulkToPostgreSQL(ToSQLAlchemy):
 
class BulkToPostgreSQL(BulkImporter, ToSQLAlchemy):
 
    """
 
    Base class for bulk data importers which target PostgreSQL on the local side.
 
    """
 
@@ -55,44 +55,6 @@ class BulkToPostgreSQL(ToSQLAlchemy):
 
        os.remove(self.data_path)
 
        self.data_buffer = None
 

	
 
    def import_data(self, host_data=None, now=None, **kwargs):
 
        self.now = now or make_utc(datetime.datetime.utcnow(), tzinfo=True)
 
        if kwargs:
 
            self._setup(**kwargs)
 
        self.setup()
 
        if host_data is None:
 
            host_data = self.normalize_host_data()
 
        created = self._import_create(host_data)
 
        self.teardown()
 
        return created
 

	
 
    def _import_create(self, data):
 
        count = len(data)
 
        if not count:
 
            return 0
 
        created = count
 

	
 
        prog = None
 
        if self.progress:
 
            prog = self.progress("Importing {} data".format(self.model_name), count)
 

	
 
        for i, host_data in enumerate(data, 1):
 

	
 
            key = self.get_key(host_data)
 
            self.create_object(key, host_data)
 
            if self.max_create and i >= self.max_create:
 
                log.warning("max of {} *created* records has been reached; stopping now".format(self.max_create))
 
                created = i
 
                break
 

	
 
            if prog:
 
                prog.update(i)
 
        if prog:
 
            prog.destroy()
 

	
 
        self.commit_create()
 
        return created
 

	
 
    def create_object(self, key, data):
 
        data = self.prep_data_for_postgres(data)
 
        self.data_buffer.write('{}\n'.format('\t'.join([data[field] for field in self.fields])).encode('utf-8'))
 
@@ -121,7 +83,7 @@ class BulkToPostgreSQL(ToSQLAlchemy):
 

	
 
        return unicode(value)
 

	
 
    def commit_create(self):
 
    def flush_create(self):
 
        log.info("copying {} data from buffer to PostgreSQL".format(self.model_name))
 
        self.data_buffer.close()
 
        self.data_buffer = open(self.data_path, 'rb')
rattail/importing/rattail_bulk.py
Show inline comments
 
@@ -31,7 +31,7 @@ from rattail.util import OrderedDict
 
from rattail.importing.rattail import FromRattailToRattail, FromRattail
 

	
 

	
 
class BulkFromRattailToRattail(FromRattailToRattail, importing.BulkToPostgreSQLHandler):
 
class BulkFromRattailToRattail(FromRattailToRattail, importing.BulkImportHandler):
 
    """
 
    Handler for Rattail -> Rattail bulk data import.
 
    """
rattail/tests/importing/test_handlers.py
Show inline comments
 
@@ -19,12 +19,12 @@ from rattail.tests.importing.test_importers import MockImporter
 
from rattail.tests.importing.test_postgresql import MockBulkImporter
 

	
 

	
 
class TestImportHandler(unittest.TestCase):
 
class ImportHandlerBattery(ImporterTester):
 

	
 
    def test_init(self):
 

	
 
        # vanilla
 
        handler = handlers.ImportHandler()
 
        handler = self.handler_class()
 
        self.assertEqual(handler.importers, {})
 
        self.assertEqual(handler.get_importers(), {})
 
        self.assertEqual(handler.get_importer_keys(), [])
 
@@ -32,34 +32,34 @@ class TestImportHandler(unittest.TestCase):
 
        self.assertFalse(handler.commit_host_partial)
 

	
 
        # with config
 
        handler = handlers.ImportHandler()
 
        handler = self.handler_class()
 
        self.assertIsNone(handler.config)
 
        config = RattailConfig()
 
        handler = handlers.ImportHandler(config=config)
 
        handler = self.handler_class(config=config)
 
        self.assertIs(handler.config, config)
 

	
 
        # dry run
 
        handler = handlers.ImportHandler()
 
        handler = self.handler_class()
 
        self.assertFalse(handler.dry_run)
 
        handler = handlers.ImportHandler(dry_run=True)
 
        handler = self.handler_class(dry_run=True)
 
        self.assertTrue(handler.dry_run)
 

	
 
        # extra kwarg
 
        handler = handlers.ImportHandler()
 
        handler = self.handler_class()
 
        self.assertRaises(AttributeError, getattr, handler, 'foo')
 
        handler = handlers.ImportHandler(foo='bar')
 
        handler = self.handler_class(foo='bar')
 
        self.assertEqual(handler.foo, 'bar')
 

	
 
    def test_get_importer(self):
 
        get_importers = Mock(return_value={'foo': Importer})
 

	
 
        # no importers
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        self.assertIsNone(handler.get_importer('foo'))
 

	
 
        # no config
 
        with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
 
            handler = handlers.ImportHandler()
 
        with patch.object(self.handler_class, 'get_importers', get_importers):
 
            handler = self.handler_class()
 
        importer = handler.get_importer('foo')
 
        self.assertIs(type(importer), Importer)
 
        self.assertIsNone(importer.config)
 
@@ -67,26 +67,26 @@ class TestImportHandler(unittest.TestCase):
 

	
 
        # with config
 
        config = RattailConfig()
 
        with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
 
            handler = handlers.ImportHandler(config=config)
 
        with patch.object(self.handler_class, 'get_importers', get_importers):
 
            handler = self.handler_class(config=config)
 
        importer = handler.get_importer('foo')
 
        self.assertIs(type(importer), Importer)
 
        self.assertIs(importer.config, config)
 
        self.assertIs(importer.handler, handler)
 

	
 
        # dry run
 
        with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
 
            handler = handlers.ImportHandler()
 
        with patch.object(self.handler_class, 'get_importers', get_importers):
 
            handler = self.handler_class()
 
        importer = handler.get_importer('foo')
 
        self.assertFalse(importer.dry_run)
 
        with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
 
            handler = handlers.ImportHandler(dry_run=True)
 
        with patch.object(self.handler_class, 'get_importers', get_importers):
 
            handler = self.handler_class(dry_run=True)
 
        importer = handler.get_importer('foo')
 
        self.assertTrue(handler.dry_run)
 

	
 
        # host title
 
        with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
 
            handler = handlers.ImportHandler()
 
        with patch.object(self.handler_class, 'get_importers', get_importers):
 
            handler = self.handler_class()
 
        importer = handler.get_importer('foo')
 
        self.assertIsNone(importer.host_system_title)
 
        handler.host_title = "Foo"
 
@@ -94,8 +94,8 @@ class TestImportHandler(unittest.TestCase):
 
        self.assertEqual(importer.host_system_title, "Foo")
 

	
 
        # extra kwarg
 
        with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
 
            handler = handlers.ImportHandler()
 
        with patch.object(self.handler_class, 'get_importers', get_importers):
 
            handler = self.handler_class()
 
        importer = handler.get_importer('foo')
 
        self.assertRaises(AttributeError, getattr, importer, 'bar')
 
        importer = handler.get_importer('foo', bar='baz')
 
@@ -104,15 +104,15 @@ class TestImportHandler(unittest.TestCase):
 
    def test_get_importer_kwargs(self):
 

	
 
        # empty by default
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        self.assertEqual(handler.get_importer_kwargs('foo'), {})
 

	
 
        # extra kwargs are preserved
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        self.assertEqual(handler.get_importer_kwargs('foo', bar='baz'), {'bar': 'baz'})
 

	
 
    def test_begin_transaction(self):
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        with patch.object(handler, 'begin_host_transaction') as begin_host:
 
            with patch.object(handler, 'begin_local_transaction') as begin_local:
 
                handler.begin_transaction()
 
@@ -120,15 +120,15 @@ class TestImportHandler(unittest.TestCase):
 
                begin_local.assert_called_once_with()
 

	
 
    def test_begin_host_transaction(self):
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        handler.begin_host_transaction()
 

	
 
    def test_begin_local_transaction(self):
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        handler.begin_local_transaction()
 

	
 
    def test_commit_transaction(self):
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        with patch.object(handler, 'commit_host_transaction') as commit_host:
 
            with patch.object(handler, 'commit_local_transaction') as commit_local:
 
                handler.commit_transaction()
 
@@ -136,15 +136,15 @@ class TestImportHandler(unittest.TestCase):
 
                commit_local.assert_called_once_with()
 

	
 
    def test_commit_host_transaction(self):
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        handler.commit_host_transaction()
 

	
 
    def test_commit_local_transaction(self):
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        handler.commit_local_transaction()
 

	
 
    def test_rollback_transaction(self):
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        with patch.object(handler, 'rollback_host_transaction') as rollback_host:
 
            with patch.object(handler, 'rollback_local_transaction') as rollback_local:
 
                handler.rollback_transaction()
 
@@ -152,24 +152,22 @@ class TestImportHandler(unittest.TestCase):
 
                rollback_local.assert_called_once_with()
 

	
 
    def test_rollback_host_transaction(self):
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        handler.rollback_host_transaction()
 

	
 
    def test_rollback_local_transaction(self):
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        handler.rollback_local_transaction()
 

	
 
    def test_import_data(self):
 

	
 
        # normal
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        result = handler.import_data()
 
        self.assertEqual(result, {})
 

	
 
    def test_import_data_dry_run(self):
 

	
 
        # as init kwarg
 
        handler = handlers.ImportHandler(dry_run=True)
 
        handler = self.make_handler(dry_run=True)
 
        with patch.object(handler, 'commit_transaction') as commit:
 
            with patch.object(handler, 'rollback_transaction') as rollback:
 
                handler.import_data()
 
@@ -178,7 +176,7 @@ class TestImportHandler(unittest.TestCase):
 
        self.assertTrue(handler.dry_run)
 

	
 
        # as import kwarg
 
        handler = handlers.ImportHandler()
 
        handler = self.make_handler()
 
        with patch.object(handler, 'commit_transaction') as commit:
 
            with patch.object(handler, 'rollback_transaction') as rollback:
 
                handler.import_data(dry_run=True)
 
@@ -187,11 +185,10 @@ class TestImportHandler(unittest.TestCase):
 
        self.assertTrue(handler.dry_run)
 

	
 
    def test_import_data_invalid_model(self):
 
        handler = self.make_handler()
 
        importer = Mock()
 
        importer.import_data.return_value = [], [], []
 
        FooImporter = Mock(return_value=importer)
 

	
 
        handler = handlers.ImportHandler()
 
        handler.importers = {'Foo': FooImporter}
 

	
 
        handler.import_data('Foo')
 
@@ -206,10 +203,9 @@ class TestImportHandler(unittest.TestCase):
 
        self.assertFalse(importer.called)
 

	
 
    def test_import_data_with_changes(self):
 
        handler = self.make_handler()
 
        importer = Mock()
 
        FooImporter = Mock(return_value=importer)
 

	
 
        handler = handlers.ImportHandler()
 
        handler.importers = {'Foo': FooImporter}
 

	
 
        importer.import_data.return_value = [], [], []
 
@@ -223,11 +219,10 @@ class TestImportHandler(unittest.TestCase):
 
            process.assert_called_once_with({'Foo': ([1], [2], [3])})
 

	
 
    def test_import_data_commit_host_partial(self):
 
        handler = self.make_handler()
 
        importer = Mock()
 
        importer.import_data.side_effect = ValueError
 
        FooImporter = Mock(return_value=importer)
 

	
 
        handler = handlers.ImportHandler()
 
        handler.importers = {'Foo': FooImporter}
 

	
 
        handler.commit_host_partial = False
 
@@ -240,6 +235,47 @@ class TestImportHandler(unittest.TestCase):
 
            self.assertRaises(ValueError, handler.import_data, 'Foo')
 
            commit.assert_called_once_with()
 

	
 

	
 
class BulkImportHandlerBattery(ImportHandlerBattery):
    """
    Battery of tests for bulk import handlers.  These override the inherited
    tests whose mock importers must return a plain "created" value.
    """

    def test_import_data_invalid_model(self):
        # mock importer class, whose instances report nothing created
        mock_importer = Mock()
        mock_importer.import_data.return_value = 0
        importer_factory = Mock(return_value=mock_importer)

        handler = self.make_handler()
        handler.importers = {'Foo': importer_factory}

        # known key: importer is instantiated once and ran with no args
        handler.import_data('Foo')
        self.assertEqual(importer_factory.call_count, 1)
        mock_importer.import_data.assert_called_once_with()

        importer_factory.reset_mock()
        mock_importer.reset_mock()

        # unknown key: nothing should be touched
        handler.import_data('Missing')
        self.assertFalse(importer_factory.called)
        self.assertFalse(mock_importer.called)

    def test_import_data_with_changes(self):
        mock_importer = Mock()
        importer_factory = Mock(return_value=mock_importer)

        handler = self.make_handler()
        handler.importers = {'Foo': importer_factory}

        # process_changes() must not be called, whether or not anything
        # was created
        for created in (0, 3):
            mock_importer.import_data.return_value = created
            with patch.object(handler, 'process_changes') as process:
                handler.import_data('Foo')
                self.assertFalse(process.called)
 

	
 

	
 
class TestImportHandler(unittest.TestCase, ImportHandlerBattery):
 
    handler_class = handlers.ImportHandler
 

	
 
    @patch('rattail.importing.handlers.send_email')
 
    def test_process_changes_sends_email(self, send_email):
 
        handler = handlers.ImportHandler()
 
@@ -265,6 +301,10 @@ class TestImportHandler(unittest.TestCase):
 
        self.assertEqual(send_email.call_count, 1)
 

	
 

	
 
class TestBulkImportHandler(unittest.TestCase, BulkImportHandlerBattery):
    """
    Runs the ``BulkImportHandlerBattery`` tests against the base
    ``BulkImportHandler`` class.
    """

    # handler class which the battery tests instantiate
    handler_class = handlers.BulkImportHandler
 

	
 

	
 
######################################################################
 
# fake import handler, tested mostly for basic coverage
 
######################################################################
 
@@ -566,7 +606,7 @@ class MockBulkImportHandler(handlers.BulkToPostgreSQLHandler):
 
        return Session()
 

	
 

	
 
class TestBulkImportHandler(RattailTestCase, ImporterTester):
 
class TestBulkImportHandlerOld(RattailTestCase, ImporterTester):
 

	
 
    importer_class = MockBulkImporter
 

	
rattail/tests/importing/test_importers.py
Show inline comments
 
@@ -4,7 +4,7 @@ from __future__ import unicode_literals, absolute_import
 

	
 
from unittest import TestCase
 

	
 
from mock import Mock, patch
 
from mock import Mock, patch, call
 

	
 
from rattail.db import model
 
from rattail.db.util import QuerySequence
 
@@ -13,6 +13,49 @@ from rattail.tests import NullProgress, RattailTestCase
 
from rattail.tests.importing import ImporterTester
 

	
 

	
 
class ImporterBattery(ImporterTester):
    """
    Battery of tests which can hopefully be run for any non-bulk importer.
    """

    def test_import_data_empty(self):
        # with no host data, the result should be an empty dict
        self.assertEqual(self.make_importer().import_data(), {})

    def test_import_data_dry_run(self):
        # dry_run kwarg given to import_data() should stick to the importer
        subject = self.make_importer()
        self.assertFalse(subject.dry_run)
        subject.import_data(dry_run=True)
        self.assertTrue(subject.dry_run)

    def test_import_data_create(self):
        subject = self.make_importer()

        def identity(value):
            return value

        # each host record serves as its own key here
        with patch.object(subject, 'get_key', identity), \
             patch.object(subject, 'create_object') as create:
            subject.import_data(host_data=[1, 2, 3])
            expected = [call(1, 1), call(2, 2), call(3, 3)]
            self.assertEqual(create.call_args_list, expected)

    def test_import_data_max_create(self):
        subject = self.make_importer()

        def identity(value):
            return value

        # only the first record should be created when max_create is 1
        with patch.object(subject, 'get_key', identity), \
             patch.object(subject, 'create_object') as create:
            subject.import_data(host_data=[1, 2, 3], max_create=1)
            self.assertEqual(create.call_args_list, [call(1, 1)])
 

	
 

	
 
class BulkImporterBattery(ImporterBattery):
    """
    Battery of tests which can hopefully be run for any bulk importer.
    """

    def test_import_data_empty(self):
        # a bulk importer's result is compared against zero, not an empty dict
        self.assertEqual(self.make_importer().import_data(), 0)
 

	
 

	
 
class TestImporter(TestCase):
 

	
 
    def test_init(self):
 
@@ -164,6 +207,11 @@ class TestFromQuery(RattailTestCase):
 
        self.assertIsInstance(objects, QuerySequence)
 

	
 

	
 
class TestBulkImporter(TestCase, BulkImporterBattery):
    """
    Runs the ``BulkImporterBattery`` tests against the base ``BulkImporter``
    class.
    """

    # importer class bound for the battery tests
    # NOTE(review): presumably consumed by ImporterTester.make_importer() - confirm
    importer_class = importers.BulkImporter
        
 

	
 

	
 
######################################################################
 
# fake importer class, tested mostly for basic coverage
 
######################################################################
0 comments (0 inline, 0 general)