Changeset - 22e10accf5f0
Lance Edgar <ledgar@sacfoodcoop.com> - 2016-05-17 15:02:08
Add "batch" support to importer framework

Doesn't really split things up per se, just controls how often a "flush
changes" should occur.
10 files changed with 294 insertions and 147 deletions:
rattail/commands/importing.py
@@ -23,72 +23,76 @@

"""
Importing Commands
"""

from __future__ import unicode_literals, absolute_import

import logging

from rattail.commands.core import Subcommand, date_argument
from rattail.util import load_object


log = logging.getLogger(__name__)


class ImportSubcommand(Subcommand):
    """
    Base class for subcommands which use the (new) data importing system.
    """
    handler_spec = None

    # TODO: move this into Subcommand or something..
    parent_name = None
    def __init__(self, *args, **kwargs):
        if 'handler_spec' in kwargs:
            self.handler_spec = kwargs.pop('handler_spec')
        super(ImportSubcommand, self).__init__(*args, **kwargs)
        if self.parent:
            self.parent_name = self.parent.name

    def get_handler_factory(self):
        """
        Subclasses must override this, and return a callable that creates an
        import handler instance which the command should use.
        """
        if self.handler_spec:
            return load_object(self.handler_spec)
        raise NotImplementedError

    def get_handler(self, **kwargs):
        """
        Returns a handler instance to be used by the command.
        """
        factory = self.get_handler_factory()
        kwargs.setdefault('config', getattr(self, 'config', None))
        kwargs.setdefault('command', self)
        kwargs.setdefault('progress', self.progress)
        if 'args' in kwargs:
            args = kwargs['args']
            kwargs.setdefault('dry_run', args.dry_run)
            if hasattr(args, 'batch_size'):
                kwargs.setdefault('batch_size', args.batch_size)
            # kwargs.setdefault('max_create', args.max_create)
            # kwargs.setdefault('max_update', args.max_update)
            # kwargs.setdefault('max_delete', args.max_delete)
            # kwargs.setdefault('max_total', args.max_total)
        kwargs = self.get_handler_kwargs(**kwargs)
        return factory(**kwargs)

    def get_handler_kwargs(self, **kwargs):
        """
        Return a dict of kwargs to be passed to the handler factory.
        """
        return kwargs

    def add_parser_args(self, parser):

        # model names (aka importer keys)
        doc = ("Which data models to import.  If you specify any, then only "
               "data for those models will be imported.  If you do not specify "
               "any, then all *default* models will be imported.")
        try:
            handler = self.get_handler()
        except NotImplementedError:
            pass
        else:
@@ -118,48 +122,55 @@ class ImportSubcommand(Subcommand):
                            help="Allow existing records to be updated during the import.")
        parser.add_argument('--no-update', action='store_false', dest='update',
                            help="Do not allow existing records to be updated during the import.")
        parser.add_argument('--max-update', type=int, metavar='COUNT',
                            help="Maximum number of records which may be updated, after which a "
                            "given import task should stop.  Note that this applies on a per-model "
                            "basis and not overall.")

        # allow delete?
        parser.add_argument('--delete', action='store_true', default=False,
                            help="Allow records to be deleted during the import.")
        parser.add_argument('--no-delete', action='store_false', dest='delete',
                            help="Do not allow records to be deleted during the import.")
        parser.add_argument('--max-delete', type=int, metavar='COUNT',
                            help="Maximum number of records which may be deleted, after which a "
                            "given import task should stop.  Note that this applies on a per-model "
                            "basis and not overall.")

        # max total changes, per model
        parser.add_argument('--max-total', type=int, metavar='COUNT',
                            help="Maximum number of *any* record changes which may occur, after which "
                            "a given import task should stop.  Note that this applies on a per-model "
                            "basis and not overall.")

        # batch size
        parser.add_argument('--batch', type=int, dest='batch_size', metavar='SIZE', default=200,
                            help="Split work to be done into batches, with the specified number of "
                            "records in each batch.  Or, set this to 0 (zero) to disable batching. "
                            "Implementation for this may vary somewhat between importers; default "
                            "batch size is 200 records.")

        # treat changes as warnings?
        parser.add_argument('--warnings', '-W', action='store_true',
                            help="Set this flag if you expect a \"clean\" import, and wish for any "
                            "changes which do occur to be processed further and/or specially.  The "
                            "behavior of this flag is ultimately up to the import handler, but the "
                            "default is to send an email notification.")

        # dry run?
        parser.add_argument('--dry-run', action='store_true',
                            help="Go through the full motions and allow logging etc. to "
                            "occur, but rollback (abort) the transaction at the end.")

    def run(self, args):
        log.info("begin `{} {}` for data models: {}".format(
                self.parent_name, self.name, ', '.join(args.models or ["(ALL)"])))

        handler = self.get_handler(args=args)
        models = args.models or handler.get_default_keys()
        log.debug("using handler: {}".format(handler))
        log.debug("importing models: {}".format(models))
        log.debug("args are: {}".format(args))

        kwargs = {
            'dry_run': args.dry_run,
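With the new --batch argument wired into get_handler() above, the batch size can be tuned per run from the command line. A hypothetical invocation (the subcommand name is illustrative only; any ImportSubcommand picks up the flag):

    rattail import-something --batch 500 --dry-run   # larger batches, rolled back at the end
    rattail import-something --batch 0               # disable batching entirely

The parsed args.batch_size flows through get_handler(args=args) into the handler, and from there down to each importer the handler constructs.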
rattail/importing/handlers.py
@@ -70,48 +70,50 @@ class ImportHandler(object):
        return {}

    def get_importer_keys(self):
        """
        Returns the list of keys corresponding to the available importers.
        """
        return list(self.importers.iterkeys())

    def get_default_keys(self):
        """
        Returns the list of keys corresponding to the "default" importers.
        Override this if you wish certain importers to be excluded by default,
        e.g. when first testing them out etc.
        """
        return self.get_importer_keys()

    def get_importer(self, key, **kwargs):
        """
        Returns an importer instance corresponding to the given key.
        """
        if key in self.importers:
            kwargs.setdefault('handler', self)
            kwargs.setdefault('config', self.config)
            kwargs.setdefault('host_system_title', self.host_title)
            if hasattr(self, 'batch_size'):
                kwargs.setdefault('batch_size', self.batch_size)
            kwargs = self.get_importer_kwargs(key, **kwargs)
            return self.importers[key](**kwargs)

    def get_importer_kwargs(self, key, **kwargs):
        """
        Return a dict of kwargs to be used when constructing an importer with
        the given key.
        """
        return kwargs

    def import_data(self, *keys, **kwargs):
        """
        Import all data for the given importer/model keys.
        """
        self.import_began = make_utc(datetime.datetime.utcnow(), tzinfo=True)
        if 'dry_run' in kwargs:
            self.dry_run = kwargs['dry_run']
        self.progress = kwargs.pop('progress', getattr(self, 'progress', None))
        self.warnings = kwargs.pop('warnings', False)
        kwargs.update({'dry_run': self.dry_run,
                       'progress': self.progress})

        self.setup()
        self.begin_transaction()
@@ -318,30 +320,24 @@ class ToSQLAlchemyHandler(ImportHandler):

    def make_session(self):
        """
        Subclasses must override this to define the local database connection.
        """
        raise NotImplementedError

    def get_importer_kwargs(self, key, **kwargs):
        kwargs = super(ToSQLAlchemyHandler, self).get_importer_kwargs(key, **kwargs)
        kwargs.setdefault('session', self.session)
        return kwargs

    def begin_local_transaction(self):
        self.session = self.make_session()

    def rollback_local_transaction(self):
        self.session.rollback()
        self.session.close()
        self.session = None

    def commit_local_transaction(self):
        self.session.commit()
        self.session.close()
        self.session = None


class BulkToPostgreSQLHandler(BulkImportHandler):
    """
    Handler for bulk imports which target PostgreSQL on the local side.
    """
rattail/importing/importers.py
@@ -21,106 +21,107 @@
#
################################################################################
"""
Data Importers
"""

from __future__ import unicode_literals, absolute_import

import datetime
import logging

from rattail.db.util import QuerySequence
from rattail.time import make_utc


log = logging.getLogger(__name__)


class Importer(object):
    """
    Base class for all data importers.
    """
    # Set this to the data model class which is targeted on the local side.
    model_class = None
    model_name = None

    key = None

    # The full list of field names supported by the importer, i.e. for the data
    # model to which the importer pertains.  By definition this list will be
    # restricted to what the local side can accommodate, but may be further
    # restricted by what the host side has to offer.
    supported_fields = []

    # The list of field names which may be considered "simple" and therefore
    # treated as such, i.e. with basic getattr/setattr calls.  Note that this
    # only applies to the local side, it has no effect on the host side.
    simple_fields = []

    allow_create = True
    allow_update = True
    allow_delete = True
    dry_run = False

    max_create = None
    max_update = None
    max_delete = None
    max_total = None
    batch_size = 200
    progress = None

    empty_local_data = False
    caches_local_data = False
    cached_local_data = None

    host_system_title = None
    local_system_title = None

    def __init__(self, config=None, fields=None, key=None, **kwargs):
        self.config = config
        self.fields = fields or self.supported_fields
        if key is not None:
            self.key = key
        if isinstance(self.key, basestring):
            self.key = (self.key,)
        if self.key:
            for field in self.key:
                if field not in self.fields:
                    raise ValueError("Key field '{}' must be included in effective fields "
                                     "for {}".format(field, self.__class__.__name__))
        self.model_class = kwargs.pop('model_class', self.model_class)
        self.model_name = kwargs.pop('model_name', self.model_name)
        if not self.model_name and self.model_class:
            self.model_name = self.model_class.__name__
        self._setup(**kwargs)

    def _setup(self, **kwargs):
        self.create = kwargs.pop('create', self.allow_create) and self.allow_create
        self.update = kwargs.pop('update', self.allow_update) and self.allow_update
        self.delete = kwargs.pop('delete', False) and self.allow_delete
        for key, value in kwargs.iteritems():
            setattr(self, key, value)

    @property
    def model_name(self):
        if self.model_class:
            return self.model_class.__name__

    def setup(self):
        """
        Perform any setup necessary, e.g. cache lookups for existing data.
        """

    def teardown(self):
        """
        Perform any cleanup after import, if necessary.
        """

    def import_data(self, host_data=None, now=None, **kwargs):
        """
        Import some data!  This is the core body of logic for that, regardless
        of where data is coming from or where it's headed.  Note that this
        method handles deletions as well as adds/updates.
        """
        self.now = now or make_utc(datetime.datetime.utcnow(), tzinfo=True)
        if kwargs:
            self._setup(**kwargs)
        self.setup()
        created = updated = deleted = []

        # Get complete set of normalized host data.
        if host_data is None:
@@ -187,67 +188,71 @@ class Importer(object):
                        ','.join(diffs), local_data, host_data))
                    local_object = self.update_object(local_object, host_data, local_data)
                    updated.append((local_object, local_data, host_data))
                    if self.max_update and len(updated) >= self.max_update:
                        log.warning("max of {} *updated* records has been reached; stopping now".format(self.max_update))
                        break
                    if self.max_total and (len(created) + len(updated)) >= self.max_total:
                        log.warning("max of {} *total changes* has been reached; stopping now".format(self.max_total))
                        break

            # If we did not yet have a local object, create it using host data.
            elif not local_object and self.create:
                local_object = self.create_object(key, host_data)
                log.debug("created new {} {}: {}".format(self.model_name, key, local_object))
                created.append((local_object, host_data))
                if self.caches_local_data and self.cached_local_data is not None:
                    self.cached_local_data[key] = {'object': local_object, 'data': self.normalize_local_object(local_object)}
                if self.max_create and len(created) >= self.max_create:
                    log.warning("max of {} *created* records has been reached; stopping now".format(self.max_create))
                    break
                if self.max_total and (len(created) + len(updated)) >= self.max_total:
                    log.warning("max of {} *total changes* has been reached; stopping now".format(self.max_total))
                    break

            self.flush_changes(i)
            # # TODO: this needs to be customizable etc. somehow maybe..
            # if i % 100 == 0 and hasattr(self, 'session'):
            #     self.session.flush()
            # flush changes every so often
            if not self.batch_size or (len(created) + len(updated)) % self.batch_size == 0:
                self.flush_create_update()

            if prog:
                prog.update(i)
        if prog:
            prog.destroy()

        self.flush_create_update_final()
        return created, updated

    # TODO: this surely goes elsewhere
    flush_every_x = 100
    def flush_create_update(self):
        """
        Perform any steps necessary to "flush" the create/update changes which
        have occurred thus far in the import.
        """

    def flush_changes(self, x):
        if self.flush_every_x and x % self.flush_every_x == 0:
            if hasattr(self, 'session'):
                self.session.flush()
    def flush_create_update_final(self):
        """
        Perform any final steps to "flush" the created/updated data here.
        """
        self.flush_create_update()

    def _import_delete(self, host_data, host_keys, changes=0):
        """
        Import deletions for the given data set.
        """
        deleted = []
        deleting = self.get_deletion_keys() - host_keys
        count = len(deleting)
        log.debug("found {} instances to delete".format(count))
        if count:

            prog = None
            if self.progress:
                prog = self.progress("Deleting {} data".format(self.model_name), count)

            for i, key in enumerate(sorted(deleting), 1):

                cached = self.cached_local_data.pop(key)
                obj = cached['object']
                if self.delete_object(obj):
                    deleted.append((obj, cached['data']))

                    if self.max_delete and len(deleted) >= self.max_delete:
                        log.warning("max of {} *deleted* records has been reached; stopping now".format(self.max_delete))
@@ -439,71 +444,68 @@ class Importer(object):


class FromQuery(Importer):
    """
    Generic base class for importers whose raw external data source is a
    SQLAlchemy (or Django, or possibly other?) query.
    """

    def query(self):
        """
        Subclasses must override this, and return the primary query which will
        define the data set.
        """
        raise NotImplementedError

    def get_host_objects(self, progress=None):
        """
        Returns (raw) query results as a sequence.
        """
        return QuerySequence(self.query())


class BulkImporter(Importer):
    """
    Base class for bulk data importers which target PostgreSQL on the local side.
    Base class for bulk data importers.
    """

    def import_data(self, host_data=None, now=None, **kwargs):
        self.now = now or make_utc(datetime.datetime.utcnow(), tzinfo=True)
        if kwargs:
            self._setup(**kwargs)
        self.setup()
        if host_data is None:
            host_data = self.normalize_host_data()
        created = self._import_create(host_data)
        self.teardown()
        return created

    def _import_create(self, data):
        count = len(data)
        if not count:
            return 0
        created = count

        prog = None
        if self.progress:
            prog = self.progress("Importing {} data".format(self.model_name), count)

        for i, host_data in enumerate(data, 1):

            key = self.get_key(host_data)
            self.create_object(key, host_data)
            if self.max_create and i >= self.max_create:
                log.warning("max of {} *created* records has been reached; stopping now".format(self.max_create))
                created = i
                break

            # flush changes every so often
            if not self.batch_size or i % self.batch_size == 0:
                self.flush_create_update()

            if prog:
                prog.update(i)
        if prog:
            prog.destroy()

        self.flush_create()
        self.flush_create_update_final()
        return created

    def flush_create(self):
        """
        Perform any final steps to "flush" the created data here.  Note that
        the importer's handler is still responsible for actually committing
        changes to the local system, if applicable.
        """
rattail/importing/postgresql.py
@@ -62,32 +62,35 @@ class BulkToPostgreSQL(BulkImporter, ToSQLAlchemy):
    def prep_data_for_postgres(self, data):
        data = dict(data)
        for key, value in data.iteritems():
            data[key] = self.prep_value_for_postgres(value)
        return data

    def prep_value_for_postgres(self, value):
        if value is None:
            return '\\N'
        if value is True:
            return 't'
        if value is False:
            return 'f'

        if isinstance(value, datetime.datetime):
            value = make_utc(value, tzinfo=False)
        elif isinstance(value, basestring):
            value = value.replace('\\', '\\\\')
            value = value.replace('\r', '\\r')
            value = value.replace('\n', '\\n')
            value = value.replace('\t', '\\t') # TODO: add test for this

        return unicode(value)

    def flush_create(self):
    def flush_create_update(self):
        pass

    def flush_create_update_final(self):
        log.info("copying {} data from buffer to PostgreSQL".format(self.model_name))
        self.data_buffer.close()
        self.data_buffer = open(self.data_path, 'rb')
        cursor = self.session.connection().connection.cursor()
        table_name = '"{}"'.format(self.model_table.name)
        cursor.copy_from(self.data_buffer, table_name, columns=self.fields)
        log.debug("PostgreSQL data copy completed")
rattail/importing/sqlalchemy.py
@@ -149,24 +149,31 @@ class ToSQLAlchemy(Importer):
    def cache_model(self, model, **kwargs):
        """
        Convenience method which invokes :func:`rattail.db.cache.cache_model()`
        with the given model and keyword arguments.  It will provide the
        ``session`` and ``progress`` parameters by default, setting them to the
        importer's attributes of the same names.
        """
        session = kwargs.pop('session', self.session)
        kwargs.setdefault('progress', self.progress)
        return cache.cache_model(session, model, **kwargs)

    def cache_local_data(self, host_data=None):
        """
        Cache all local objects and data using SA ORM.
        """
        return self.cache_model(self.model_class, key=self.get_cache_key,
                                # omit_duplicates=True,
                                query_options=self.cache_query_options(),
                                normalizer=self.normalize_cache_object)

    def cache_query_options(self):
        """
        Return a list of options to apply to the cache query, if needed.
        """

    def flush_create_update(self):
        """
        Flush the database session, to send SQL to the server for all changes
        made thus far.
        """
        self.session.flush()
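Note the division of labor here: flush_create_update() only calls session.flush(), which sends the accumulated INSERT/UPDATE statements to the server while leaving the transaction open; the single commit (or rollback, for --dry-run) still belongs to the handler's commit_local_transaction() / rollback_local_transaction(). A self-contained sketch of that distinction, in the same classic-mapper style as the test module below, assuming an in-memory SQLite database:

    import sqlalchemy as sa
    from sqlalchemy import orm

    class Widget(object):
        pass

    metadata = sa.MetaData()
    widget_table = sa.Table('widgets', metadata,
                            sa.Column('id', sa.Integer(), primary_key=True))
    orm.mapper(Widget, widget_table)

    engine = sa.create_engine('sqlite://')
    metadata.create_all(engine)
    session = orm.Session(bind=engine)

    for i in range(5):
        widget = Widget()
        widget.id = i + 1
        session.add(widget)
        session.flush()    # INSERT sent to the server; transaction stays open
    session.rollback()     # dry-run style: all five rows vanish again
    assert session.query(Widget).count() == 0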
rattail/tests/commands/test_importing.py
# -*- coding: utf-8 -*-

from __future__ import unicode_literals, absolute_import

import argparse
from unittest import TestCase

from mock import Mock, patch

from rattail.core import Object
from rattail.commands import importing
from rattail.importing import ImportHandler
from rattail.importing import Importer, ImportHandler
from rattail.importing.rattail import FromRattailToRattail
from rattail.config import RattailConfig
from rattail.tests.importing import ImporterTester
from rattail.tests.importing.test_handlers import MockImportHandler
from rattail.tests.importing.test_importers import MockImporter
from rattail.tests.importing.test_rattail import DualRattailTestCase


class MockImport(importing.ImportSubcommand):

    def get_handler_factory(self):
        return MockImportHandler


class TestImportSubcommandBasics(TestCase):
class TestImportSubcommand(TestCase):

    # TODO: lame, here only for coverage
    def test_parent_name(self):
        parent = Object(name='milo')
        command = importing.ImportSubcommand(parent=parent)
        self.assertEqual(command.parent_name, 'milo')

    def test_get_handler_factory(self):
        command = importing.ImportSubcommand()
        self.assertRaises(NotImplementedError, command.get_handler_factory)
        command.handler_spec = 'rattail.importing.rattail:FromRattailToRattail'
        factory = command.get_handler_factory()
        self.assertIs(factory, FromRattailToRattail)

    def test_handler_spec_attr(self):
        # default is None
        command = importing.ImportSubcommand()
        self.assertIsNone(command.handler_spec)

        # can't get a handler without a spec
        self.assertRaises(NotImplementedError, command.get_handler)

        # but may be specified with init kwarg
        command = importing.ImportSubcommand(handler_spec='rattail.importing:ImportHandler')
        self.assertEqual(command.handler_spec, 'rattail.importing:ImportHandler')

        # now we can get a handler
        handler = command.get_handler()
        self.assertIsInstance(handler, ImportHandler)

    def test_get_handler(self):

        # no config
        command = MockImport()
        handler = command.get_handler()
        self.assertIs(type(handler), MockImportHandler)
        self.assertIsNone(handler.config)

        # with config
        config = RattailConfig()
        command = MockImport(config=config)
        handler = command.get_handler()
        self.assertIs(type(handler), MockImportHandler)
        self.assertIs(handler.config, config)

        # dry run
        command = MockImport()
        handler = command.get_handler()
        self.assertFalse(handler.dry_run)
        handler = command.get_handler(dry_run=True)
        self.assertTrue(handler.dry_run)
        args = argparse.Namespace(dry_run=True)
        handler = command.get_handler(args=args)
        self.assertTrue(handler.dry_run)

    def test_add_parser_args(self):
        # TODO: this doesn't really test anything, but does give some coverage..

        # no handler
        # adding the args throws no error..(basic coverage)
        command = importing.ImportSubcommand()
        parser = argparse.ArgumentParser()
        command.add_parser_args(parser)

        # with handler
        command = MockImport()
        # confirm default values
        args = parser.parse_args([])
        self.assertIsNone(args.start_date)
        self.assertIsNone(args.end_date)
        self.assertTrue(args.create)
        self.assertIsNone(args.max_create)
        self.assertTrue(args.update)
        self.assertIsNone(args.max_update)
        self.assertFalse(args.delete)
        self.assertIsNone(args.max_delete)
        self.assertIsNone(args.max_total)
        self.assertEqual(args.batch_size, 200)
        self.assertFalse(args.warnings)
        self.assertFalse(args.dry_run)

    def test_batch_size_kwarg(self):
        command = importing.ImportSubcommand(handler_spec='rattail.importing:ImportHandler')
        parser = argparse.ArgumentParser()
        command.add_parser_args(parser)
        with patch.object(ImportHandler, 'get_importers', Mock(return_value={'Foo': Importer})):

            # importer default is 200
            args = parser.parse_args([])
            handler = command.get_handler(args=args)
            importer = handler.get_importer('Foo')
            self.assertEqual(importer.batch_size, 200)

            # but may be overridden with command line arg
            args = parser.parse_args(['--batch', '42'])
            handler = command.get_handler(args=args)
            importer = handler.get_importer('Foo')
            self.assertEqual(importer.batch_size, 42)


class TestImportSubcommandRun(ImporterTester, TestCase):

class TestImportSubcommandRun(TestCase, ImporterTester):

    sample_data = {
        '16oz': {'upc': '00074305001161', 'description': "Apple Cider Vinegar 16oz"},
        '32oz': {'upc': '00074305001321', 'description': "Apple Cider Vinegar 32oz"},
        '1gal': {'upc': '00074305011283', 'description': "Apple Cider Vinegar 1gal"},
    }

    def setUp(self):
        self.command = MockImport()
        self.handler = MockImportHandler()
        self.importer = MockImporter()

    def import_data(self, **kwargs):
    def import_data(self, host_data=None, local_data=None, **kwargs):
        if host_data is None:
            host_data = self.sample_data
        if local_data is None:
            local_data = self.sample_data

        models = kwargs.pop('models', [])
        kwargs.setdefault('dry_run', False)

        kw = {
            'warnings': False,
            'create': None,
            'max_create': None,
            'update': None,
            'max_update': None,
            'delete': None,
            'max_delete': None,
            'max_total': None,
            'batchcount': None,
            'progress': None,
        }
        kw.update(kwargs)
        args = argparse.Namespace(models=models, **kw)

        # must modify our importer in-place since we need the handler to return
        # that specific instance, below (because the host/local data context
        # managers reference that instance directly)
        self.importer._setup(**kwargs)
        with patch.object(self.command, 'get_handler', Mock(return_value=self.handler)):
            with patch.object(self.handler, 'get_importer', Mock(return_value=self.importer)):
                self.command.run(args)
                with self.host_data(host_data):
                    with self.local_data(local_data):
                        self.command.run(args)

        if self.handler._result:
            self.result = self.handler._result['Product']
        else:
            self.result = [], [], []

    def test_create(self):
        local = self.copy_data()
        del local['32oz']
        with self.host_data(self.sample_data):
            with self.local_data(local):
                self.import_data()
        self.import_data(local_data=local)
        self.assert_import_created('32oz')
        self.assert_import_updated()
        self.assert_import_deleted()

    def test_update(self):
        local = self.copy_data()
        local['16oz']['description'] = "wrong description"
        with self.host_data(self.sample_data):
            with self.local_data(local):
                self.import_data()
        self.import_data(local_data=local)
        self.assert_import_created()
        self.assert_import_updated('16oz')
        self.assert_import_deleted()

    def test_delete(self):
        local = self.copy_data()
        local['bogus'] = {'upc': '00000000000000', 'description': "Delete Me"}
        with self.host_data(self.sample_data):
            with self.local_data(local):
                self.import_data(delete=True)
        self.import_data(local_data=local, delete=True)
        self.assert_import_created()
        self.assert_import_updated()
        self.assert_import_deleted('bogus')

    def test_duplicate(self):
        host = self.copy_data()
        host['32oz-dupe'] = host['32oz']
        with self.host_data(host):
            with self.local_data(self.sample_data):
                self.import_data()
        self.import_data(host_data=host)
        self.assert_import_created()
        self.assert_import_updated()
        self.assert_import_deleted()

    def test_max_create(self):
        local = self.copy_data()
        del local['16oz']
        del local['1gal']
        with self.host_data(self.sample_data):
            with self.local_data(local):
                self.import_data(max_create=1)
        self.import_data(local_data=local, max_create=1)
        self.assert_import_created('16oz')
        self.assert_import_updated()
        self.assert_import_deleted()

    def test_max_total_create(self):
        local = self.copy_data()
        del local['16oz']
        del local['1gal']
        with self.host_data(self.sample_data):
            with self.local_data(local):
                self.import_data(max_total=1)
        self.import_data(local_data=local, max_total=1)
        self.assert_import_created('16oz')
        self.assert_import_updated()
        self.assert_import_deleted()

    def test_max_update(self):
        local = self.copy_data()
        local['16oz']['description'] = "wrong"
        local['1gal']['description'] = "wrong"
        with self.host_data(self.sample_data):
            with self.local_data(local):
                self.import_data(max_update=1)
        self.import_data(local_data=local, max_update=1)
        self.assert_import_created()
        self.assert_import_updated('16oz')
        self.assert_import_deleted()

    def test_max_total_update(self):
        local = self.copy_data()
        local['16oz']['description'] = "wrong"
        local['1gal']['description'] = "wrong"
        with self.host_data(self.sample_data):
            with self.local_data(local):
                self.import_data(max_total=1)
        self.import_data(local_data=local, max_total=1)
        self.assert_import_created()
        self.assert_import_updated('16oz')
        self.assert_import_deleted()

    def test_max_delete(self):
        local = self.copy_data()
        local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"}
        local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"}
        with self.host_data(self.sample_data):
            with self.local_data(local):
                self.import_data(delete=True, max_delete=1)
        self.import_data(local_data=local, delete=True, max_delete=1)
        self.assert_import_created()
        self.assert_import_updated()
        self.assert_import_deleted('bogus1')

    def test_max_total_delete(self):
        local = self.copy_data()
        local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"}
        local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"}
        with self.host_data(self.sample_data):
            with self.local_data(local):
                self.import_data(delete=True, max_total=1)
        self.import_data(local_data=local, delete=True, max_total=1)
        self.assert_import_created()
        self.assert_import_updated()
        self.assert_import_deleted('bogus1')

    def test_dry_run(self):
        local = self.copy_data()
        del local['32oz']
        local['16oz']['description'] = "wrong description"
        local['bogus'] = {'upc': '00000000000000', 'description': "Delete Me"}
        with self.host_data(self.sample_data):
            with self.local_data(local):
                self.import_data(delete=True, dry_run=True)
        self.import_data(local_data=local, delete=True, dry_run=True)
        # TODO: maybe need a way to confirm no changes actually made due to dry
        # run; currently results still reflect "proposed" changes.  this rather
        # bogus test is here just for coverage sake
        self.assert_import_created('32oz')
        self.assert_import_updated('16oz')
        self.assert_import_deleted('bogus')


class TestImportRattail(DualRattailTestCase):

    def make_command(self, **kwargs):
        kwargs.setdefault('config', self.config)
        return importing.ImportRattail(**kwargs)

    def test_get_handler_factory(self):

        # default handler
        config = RattailConfig()
        command = self.make_command(config=config)
        Handler = command.get_handler_factory()
        self.assertIs(Handler, FromRattailToRattail)

    def test_get_handler_kwargs(self):
        command = self.make_command()
rattail/tests/importing/test_handlers.py
@@ -53,48 +53,58 @@ class ImportHandlerBattery(ImporterTester):
    def test_get_importer(self):
        get_importers = Mock(return_value={'foo': Importer})

        # no importers
        handler = self.make_handler()
        self.assertIsNone(handler.get_importer('foo'))

        # no config
        with patch.object(self.handler_class, 'get_importers', get_importers):
            handler = self.handler_class()
        importer = handler.get_importer('foo')
        self.assertIs(type(importer), Importer)
        self.assertIsNone(importer.config)
        self.assertIs(importer.handler, handler)

        # with config
        config = RattailConfig()
        with patch.object(self.handler_class, 'get_importers', get_importers):
            handler = self.handler_class(config=config)
        importer = handler.get_importer('foo')
        self.assertIs(type(importer), Importer)
        self.assertIs(importer.config, config)
        self.assertIs(importer.handler, handler)

        # # batch size
        # with patch.object(self.handler_class, 'get_importers', get_importers):
        #     handler = self.handler_class()
        # importer = handler.get_importer('foo')
        # self.assertEqual(importer.batch_size, 200)
        # with patch.object(self.handler_class, 'get_importers', get_importers):
        #     handler = self.handler_class(batch_size=42)
        # importer = handler.get_importer('foo')
        # self.assertEqual(importer.batch_size, 42)

        # dry run
        with patch.object(self.handler_class, 'get_importers', get_importers):
            handler = self.handler_class()
        importer = handler.get_importer('foo')
        self.assertFalse(importer.dry_run)
        with patch.object(self.handler_class, 'get_importers', get_importers):
            handler = self.handler_class(dry_run=True)
        importer = handler.get_importer('foo')
        self.assertTrue(handler.dry_run)

        # host title
        with patch.object(self.handler_class, 'get_importers', get_importers):
            handler = self.handler_class()
        importer = handler.get_importer('foo')
        self.assertIsNone(importer.host_system_title)
        handler.host_title = "Foo"
        importer = handler.get_importer('foo')
        self.assertEqual(importer.host_system_title, "Foo")

        # extra kwarg
        with patch.object(self.handler_class, 'get_importers', get_importers):
            handler = self.handler_class()
        importer = handler.get_importer('foo')
        self.assertRaises(AttributeError, getattr, importer, 'bar')
@@ -255,48 +265,74 @@ class BulkImportHandlerBattery(ImportHandlerBattery):
        handler.import_data('Missing')
        self.assertFalse(FooImporter.called)
        self.assertFalse(importer.called)

    def test_import_data_with_changes(self):
        handler = self.make_handler()
        importer = Mock()
        FooImporter = Mock(return_value=importer)
        handler.importers = {'Foo': FooImporter}

        importer.import_data.return_value = 0
        with patch.object(handler, 'process_changes') as process:
            handler.import_data('Foo')
            self.assertFalse(process.called)

        importer.import_data.return_value = 3
        with patch.object(handler, 'process_changes') as process:
            handler.import_data('Foo')
            self.assertFalse(process.called)


class TestImportHandler(unittest.TestCase, ImportHandlerBattery):
    handler_class = handlers.ImportHandler

    def test_batch_size_kwarg(self):
        with patch.object(handlers.ImportHandler, 'get_importers', Mock(return_value={'Foo': Importer})):

            # handler has no batch size by default
            handler = handlers.ImportHandler()
            self.assertFalse(hasattr(handler, 'batch_size'))

            # but can override with kwarg
            handler = handlers.ImportHandler(batch_size=42)
            self.assertEqual(handler.batch_size, 42)

            # importer default is 200
            handler = handlers.ImportHandler()
            importer = handler.get_importer('Foo')
            self.assertEqual(importer.batch_size, 200)

            # but can override with handler init kwarg
            handler = handlers.ImportHandler(batch_size=37)
            importer = handler.get_importer('Foo')
            self.assertEqual(importer.batch_size, 37)

            # can also override with get_importer kwarg
            handler = handlers.ImportHandler()
            importer = handler.get_importer('Foo', batch_size=98)
            self.assertEqual(importer.batch_size, 98)

    @patch('rattail.importing.handlers.send_email')
    def test_process_changes_sends_email(self, send_email):
        handler = handlers.ImportHandler()
        handler.import_began = pytz.utc.localize(datetime.datetime.utcnow())
        changes = [], [], []

        # warnings disabled
        handler.warnings = False
        handler.process_changes(changes)
        self.assertFalse(send_email.called)

        # warnings enabled
        handler.warnings = True
        handler.process_changes(changes)
        self.assertEqual(send_email.call_count, 1)

        send_email.reset_mock()

        # warnings enabled, with command (just for coverage..)
        handler.warnings = True
        handler.command = Mock(name='import-testing', parent=Mock(name='rattail'))
        handler.process_changes(changes)
        self.assertEqual(send_email.call_count, 1)

@@ -570,87 +606,24 @@ class TestToSQLAlchemyHandler(unittest.TestCase):
        self.assertIsInstance(handler.session, orm.Session)
        handler.session.close()

    def test_commit_local_transaction(self):
        # TODO: test actual commit for data changes
        session = Session()
        handler = handlers.ToSQLAlchemyHandler(session=session)
        self.assertIs(handler.session, session)
        with patch.object(handler, 'session') as session:
            handler.commit_local_transaction()
            session.commit.assert_called_once_with()
            self.assertFalse(session.rollback.called)
        # self.assertIsNone(handler.session)

    def test_rollback_local_transaction(self):
        # TODO: test actual rollback for data changes
        session = Session()
        handler = handlers.ToSQLAlchemyHandler(session=session)
        self.assertIs(handler.session, session)
        with patch.object(handler, 'session') as session:
            handler.rollback_local_transaction()
            session.rollback.assert_called_once_with()
            self.assertFalse(session.commit.called)
        # self.assertIsNone(handler.session)


######################################################################
# fake bulk import handler, tested mostly for basic coverage
######################################################################

class MockBulkImportHandler(handlers.BulkToPostgreSQLHandler):

    def get_importers(self):
        return {'Department': MockBulkImporter}

    def make_session(self):
        return Session()


class TestBulkImportHandlerOld(RattailTestCase, ImporterTester):

    importer_class = MockBulkImporter

    sample_data = {
        'grocery': {'number': 1, 'name': "Grocery", 'uuid': 'decd909a194011e688093ca9f40bc550'},
        'bulk': {'number': 2, 'name': "Bulk", 'uuid': 'e633d54c194011e687e33ca9f40bc550'},
        'hba': {'number': 3, 'name': "HBA", 'uuid': 'e2bad79e194011e6a4783ca9f40bc550'},
    }

    def setUp(self):
        self.setup_rattail()
        self.tempio = TempIO()
        self.config.set('rattail', 'workdir', self.tempio.realpath())
        self.handler = MockBulkImportHandler(config=self.config)

    def tearDown(self):
        self.teardown_rattail()
        self.tempio = None

    def import_data(self, host_data=None, **kwargs):
        if host_data is None:
            host_data = list(self.copy_data().itervalues())
        with patch.object(self.importer_class, 'normalize_host_data', Mock(return_value=host_data)):
            with patch.object(self.handler, 'make_session', Mock(return_value=self.session)):
                return self.handler.import_data('Department', **kwargs)

    def test_invalid_importer_key_is_ignored(self):
        handler = MockBulkImportHandler()
        self.assertNotIn('InvalidKey', handler.importers)
        self.assertEqual(handler.import_data('InvalidKey'), {})

    def assert_import_created(self, *keys):
        pass

    def assert_import_updated(self, *keys):
        pass

    def assert_import_deleted(self, *keys):
        pass

    def test_normal_run(self):
        if self.postgresql():
            self.import_data()

    def test_dry_run(self):
        if self.postgresql():
            self.import_data(dry_run=True)
rattail/tests/importing/test_importers.py
@@ -38,51 +38,52 @@ class ImporterBattery(ImporterTester):
                    call(1, 1), call(2, 2), call(3, 3)])

    def test_import_data_max_create(self):
        importer = self.make_importer()
        with patch.object(importer, 'get_key', lambda k: k):
            with patch.object(importer, 'create_object') as create:
                importer.import_data(host_data=[1, 2, 3], max_create=1)
                self.assertEqual(create.call_args_list, [call(1, 1)])


class BulkImporterBattery(ImporterBattery):
    """
    Battery of tests which can hopefully be run for any bulk importer.
    """

    def test_import_data_empty(self):
        importer = self.make_importer()
        result = importer.import_data()
        self.assertEqual(result, 0)


class TestImporter(TestCase):

    def test_init(self):

        # defaults
        importer = importers.Importer()
        self.assertIsNone(importer.model_class)
        self.assertIsNone(importer.model_name)
        self.assertIsNone(importer.key)
        self.assertEqual(importer.fields, [])
        self.assertIsNone(importer.host_system_title)

        # key must be included among the fields
        self.assertRaises(ValueError, importers.Importer, key='upc', fields=[])
        importer = importers.Importer(key='upc', fields=['upc'])
        self.assertEqual(importer.key, ('upc',))
        self.assertEqual(importer.fields, ['upc'])

        # extra bits are passed as-is
        importer = importers.Importer()
        self.assertFalse(hasattr(importer, 'extra_bit'))
        extra_bit = object()
        importer = importers.Importer(extra_bit=extra_bit)
        self.assertIs(importer.extra_bit, extra_bit)

    def test_delete_flag(self):
        # disabled by default
        importer = importers.Importer()
        self.assertTrue(importer.allow_delete)
        self.assertFalse(importer.delete)
        importer.import_data(host_data=[])
        self.assertFalse(importer.delete)
@@ -171,66 +172,184 @@ class TestImporter(TestCase):
        self.assertEqual(host_data, [])

        host_data = importer.normalize_host_data(host_objects=data)
        self.assertEqual(host_data, data)

        with patch.object(importer, 'get_host_objects', new=Mock(return_value=data)):
            host_data = importer.normalize_host_data()
        self.assertEqual(host_data, data)

    def test_get_deletion_keys(self):
        importer = importers.Importer()
        self.assertFalse(importer.caches_local_data)
        keys = importer.get_deletion_keys()
        self.assertEqual(keys, set())

        importer.caches_local_data = True
        self.assertIsNone(importer.cached_local_data)
        keys = importer.get_deletion_keys()
        self.assertEqual(keys, set())

        importer.cached_local_data = {'delete-me': object()}
        keys = importer.get_deletion_keys()
        self.assertEqual(keys, set(['delete-me']))

    def test_model_name_attr(self):
        # default is None
        importer = importers.Importer()
        self.assertIsNone(importer.model_name)

        # but may be overridden via init kwarg
        importer = importers.Importer(model_name='Foo')
        self.assertEqual(importer.model_name, 'Foo')

        # or may inherit its value from 'model_class'
        class Foo:
            pass
        importer = importers.Importer(model_class=Foo)
        self.assertEqual(importer.model_name, 'Foo')

    def test_batch_size_attr(self):
        # default is 200
        importer = importers.Importer()
        self.assertEqual(importer.batch_size, 200)

        # but may be overridden via init kwarg
        importer = importers.Importer(batch_size=0)
        self.assertEqual(importer.batch_size, 0)
        importer = importers.Importer(batch_size=42)
        self.assertEqual(importer.batch_size, 42)

        # batch size determines how often flush occurs
        data = [{'id': i} for i in range(1, 101)]
        importer = importers.Importer(model_name='Foo', key='id', fields=['id'], empty_local_data=True)
        with patch.object(importer, 'create_object'): # just mock that out
            with patch.object(importer, 'flush_create_update') as flush:

                # 4 batches @ 33/per
                importer.import_data(host_data=data, batch_size=33)
                self.assertEqual(flush.call_count, 4)
                flush.reset_mock()

                # 3 batches @ 34/per
                importer.import_data(host_data=data, batch_size=34)
                self.assertEqual(flush.call_count, 3)
                flush.reset_mock()

                # 2 batches @ 50/per
                importer.import_data(host_data=data, batch_size=100)
                self.assertEqual(flush.call_count, 2)
                flush.reset_mock()

                # one extra/final flush happens, whenever the total number of
                # changes happens to match the batch size...

                # 1 batch @ 100/per, plus final flush
                importer.import_data(host_data=data, batch_size=100)
                self.assertEqual(flush.call_count, 2)
                flush.reset_mock()

                # 1 batch @ 200/per
                importer.import_data(host_data=data, batch_size=200)
                self.assertEqual(flush.call_count, 1)
                flush.reset_mock()

                # one extra/final flush also happens when batching is disabled

                # 100 "batches" @ 0/per, plus final flush
                importer.import_data(host_data=data, batch_size=0)
                self.assertEqual(flush.call_count, 101)
                flush.reset_mock()


class TestFromQuery(RattailTestCase):

    def test_query(self):
        importer = importers.FromQuery()
        self.assertRaises(NotImplementedError, importer.query)

    def test_get_host_objects(self):
        query = self.session.query(model.Product)
        importer = importers.FromQuery()
        with patch.object(importer, 'query', Mock(return_value=query)):
            objects = importer.get_host_objects()
        self.assertIsInstance(objects, QuerySequence)


class TestBulkImporter(TestCase, BulkImporterBattery):
    importer_class = importers.BulkImporter

    def test_batch_size_attr(self):
        # default is 200
        importer = importers.BulkImporter()
        self.assertEqual(importer.batch_size, 200)

        # but may be overridden via init kwarg
        importer = importers.BulkImporter(batch_size=0)
        self.assertEqual(importer.batch_size, 0)
        importer = importers.BulkImporter(batch_size=42)
        self.assertEqual(importer.batch_size, 42)

        # batch size determines how often flush occurs
        data = [{'id': i} for i in range(1, 101)]
        importer = importers.BulkImporter(model_name='Foo', key='id', fields=['id'], empty_local_data=True)
        with patch.object(importer, 'create_object'): # just mock that out
            with patch.object(importer, 'flush_create_update') as flush:

                # 4 batches @ 33/per
                importer.import_data(host_data=data, batch_size=33)
                self.assertEqual(flush.call_count, 4)
                flush.reset_mock()

                # 3 batches @ 34/per
                importer.import_data(host_data=data, batch_size=34)
                self.assertEqual(flush.call_count, 3)
                flush.reset_mock()

                # 2 batches @ 50/per
                importer.import_data(host_data=data, batch_size=100)
                self.assertEqual(flush.call_count, 2)
                flush.reset_mock()

                # one extra/final flush happens, whenever the total number of
                # changes happens to match the batch size...

                # 1 batch @ 100/per, plus final flush
                importer.import_data(host_data=data, batch_size=100)
                self.assertEqual(flush.call_count, 2)
                flush.reset_mock()

                # 1 batch @ 200/per
                importer.import_data(host_data=data, batch_size=200)
                self.assertEqual(flush.call_count, 1)
                flush.reset_mock()

                # one extra/final flush also happens when batching is disabled

                # 100 "batches" @ 0/per, plus final flush
                importer.import_data(host_data=data, batch_size=0)
                self.assertEqual(flush.call_count, 101)
                flush.reset_mock()


######################################################################
# fake importer class, tested mostly for basic coverage
######################################################################

class Product(object):
    upc = None
    description = None


class MockImporter(importers.Importer):
    model_class = Product
    key = 'upc'
    simple_fields = ['upc', 'description']
    supported_fields = simple_fields
    caches_local_data = True
    flush_every_x = 1
    session = Mock()

    def normalize_local_object(self, obj):
        return obj

    def update_object(self, obj, host_data, local_data=None):
rattail/tests/importing/test_postgresql.py
@@ -18,59 +18,58 @@ from rattail.tests.importing import ImporterTester
from rattail.tests.importing.test_rattail import DualRattailTestCase
from rattail.time import localtime


class Widget(object):
    pass


class TestBulkToPostgreSQL(unittest.TestCase):

    def setUp(self):
        self.tempio = TempIO()
        self.config = RattailConfig()
        self.config.set('rattail', 'workdir', self.tempio.realpath())
        self.config.set('rattail', 'timezone.default', 'America/Chicago')

    def tearDown(self):
        self.tempio = None

    def make_importer(self, **kwargs):
        kwargs.setdefault('config', self.config)
        kwargs.setdefault('fields', ['id']) # hack
        return pgimport.BulkToPostgreSQL(**kwargs)

    def test_data_path(self):
        importer = self.make_importer(config=None)
        self.assertIsNone(importer.config)
        self.assertRaises(AttributeError, getattr, importer, 'data_path')
        importer.config = RattailConfig()
        self.assertRaises(ConfigurationError, getattr, importer, 'data_path')
        importer.config = self.config
    def test_data_path_property(self):
        self.config.set('rattail', 'workdir', '/tmp')
        self.assertEqual(importer.data_path, '/tmp/import_bulk_postgresql_None.csv') # no model yet
        importer.model_class = Widget
        self.assertEqual(importer.data_path, '/tmp/import_bulk_postgresql_Widget.csv')
        importer = pgimport.BulkToPostgreSQL(config=self.config, fields=['id'])

        # path leverages model name, so default is None
        self.assertEqual(importer.data_path, '/tmp/import_bulk_postgresql_None.csv')

        # but it will reflect model name
        importer.model_name = 'Foo'
        self.assertEqual(importer.data_path, '/tmp/import_bulk_postgresql_Foo.csv')

    def test_setup(self):
        importer = self.make_importer()
        self.assertFalse(hasattr(importer, 'data_buffer'))
        importer.setup()
        self.assertIsNotNone(importer.data_buffer)
        importer.data_buffer.close()

    def test_teardown(self):
        importer = self.make_importer()
        importer.data_buffer = open(importer.data_path, 'wb')
        importer.teardown()
        self.assertIsNone(importer.data_buffer)

    def test_prep_value_for_postgres(self):
        importer = self.make_importer()

        # constants
        self.assertEqual(importer.prep_value_for_postgres(None), '\\N')
        self.assertEqual(importer.prep_value_for_postgres(True), 't')
        self.assertEqual(importer.prep_value_for_postgres(False), 'f')

        # datetime (local zone is Chicago/CDT; UTC-5)
        value = localtime(self.config, datetime.datetime(2016, 5, 13, 12))
rattail/tests/importing/test_sqlalchemy.py
# -*- coding: utf-8 -*-

from __future__ import unicode_literals, absolute_import

from unittest import TestCase

from mock import patch

import sqlalchemy as sa
from sqlalchemy import orm
from sqlalchemy.orm.exc import MultipleResultsFound

from rattail.importing import sqlalchemy as saimport


class Widget(object):
    pass

metadata = sa.MetaData()

widget_table = sa.Table('widgets', metadata,
                        sa.Column('id', sa.Integer(), primary_key=True),
                        sa.Column('description', sa.String(length=50)))

widget_mapper = orm.mapper(Widget, widget_table)

WIDGETS = [
    {'id': 1, 'description': "Main Widget"},
    {'id': 2, 'description': "Other Widget"},
    {'id': 3, 'description': "Other Widget"},
]

@@ -124,24 +126,30 @@ class TestToSQLAlchemy(TestCase):
            self.assertEqual(cache[i].description, WIDGETS[i-1]['description'])

    def test_cache_local_data(self):
        importer = self.make_importer(key='id', session=self.session)
        cache = importer.cache_local_data()
        self.assertEqual(len(cache), 3)
        for i in range(1, 4):
            self.assertIn((i,), cache)
            cached = cache[(i,)]
            self.assertIsInstance(cached, dict)
            self.assertIsInstance(cached['object'], Widget)
            self.assertEqual(cached['object'].id, i)
            self.assertEqual(cached['object'].description, WIDGETS[i-1]['description'])
            self.assertIsInstance(cached['data'], dict)
            self.assertEqual(cached['data']['id'], i)
            self.assertEqual(cached['data']['description'], WIDGETS[i-1]['description'])

    # TODO: lame
    def test_flush_session(self):
        importer = self.make_importer(fields=['id'], session=self.session, flush_session=True)
        widget = Widget()
        widget.id = 1
        widget, original = importer.update_object(widget, {'id': 1}), widget
        self.assertIs(widget, original)

    def test_flush_create_update(self):
        importer = saimport.ToSQLAlchemy(fields=['id'], session=self.session)
        with patch.object(self.session, 'flush') as flush:
            importer.flush_create_update()
            flush.assert_called_once_with()