Changeset - 22e10accf5f0
[Not reviewed]
0 10 0
Lance Edgar - 8 years ago 2016-05-17 15:02:08
ledgar@sacfoodcoop.com
Add "batch" support to importer framework

Doesn't really split things up per se, just controls how often a "flush
changes" should occur.
10 files changed with 293 insertions and 146 deletions:
0 comments (0 inline, 0 general)
rattail/commands/importing.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
################################################################################
 
#
 
#  Rattail -- Retail Software Framework
 
#  Copyright © 2010-2016 Lance Edgar
 
#
 
#  This file is part of Rattail.
 
#
 
#  Rattail is free software: you can redistribute it and/or modify it under the
 
#  terms of the GNU Affero General Public License as published by the Free
 
#  Software Foundation, either version 3 of the License, or (at your option)
 
#  any later version.
 
#
 
#  Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
 
#  WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 
#  FOR A PARTICULAR PURPOSE.  See the GNU Affero General Public License for
 
#  more details.
 
#
 
#  You should have received a copy of the GNU Affero General Public License
 
#  along with Rattail.  If not, see <http://www.gnu.org/licenses/>.
 
#
 
################################################################################
 
"""
 
Importing Commands
 
"""
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
import logging
 

	
 
from rattail.commands.core import Subcommand, date_argument
 
from rattail.util import load_object
 

	
 

	
 
log = logging.getLogger(__name__)
 

	
 

	
 
class ImportSubcommand(Subcommand):
    """
    Base class for subcommands which use the (new) data importing system.

    Subclasses supply an import handler, either via :attr:`handler_spec`
    or by overriding :meth:`get_handler_factory()`.
    """
    # Spec string (e.g. "module:Class") naming the import handler; may also
    # be supplied as a 'handler_spec' kwarg to the constructor.
    handler_spec = None

    # TODO: move this into Subcommand or something..
    parent_name = None
    def __init__(self, *args, **kwargs):
        # allow handler_spec override via kwarg
        if 'handler_spec' in kwargs:
            self.handler_spec = kwargs.pop('handler_spec')
        super(ImportSubcommand, self).__init__(*args, **kwargs)
        # remember parent command name, for logging in run()
        if self.parent:
            self.parent_name = self.parent.name

    def get_handler_factory(self):
        """
        Subclasses must override this, and return a callable that creates an
        import handler instance which the command should use.
        """
        if self.handler_spec:
            return load_object(self.handler_spec)
        raise NotImplementedError

    def get_handler(self, **kwargs):
        """
        Returns a handler instance to be used by the command.

        Any kwargs given are passed on to the handler factory, after being
        supplemented with defaults (config, command, progress) plus whatever
        can be gleaned from parsed command line ``args`` (if present).
        """
        factory = self.get_handler_factory()
        kwargs.setdefault('config', getattr(self, 'config', None))
        kwargs.setdefault('command', self)
        kwargs.setdefault('progress', self.progress)
        if 'args' in kwargs:
            args = kwargs['args']
            kwargs.setdefault('dry_run', args.dry_run)
            # batch_size only exists if add_parser_args() defined the option
            if hasattr(args, 'batch_size'):
                kwargs.setdefault('batch_size', args.batch_size)
            # kwargs.setdefault('max_create', args.max_create)
            # kwargs.setdefault('max_update', args.max_update)
            # kwargs.setdefault('max_delete', args.max_delete)
            # kwargs.setdefault('max_total', args.max_total)
        kwargs = self.get_handler_kwargs(**kwargs)
        return factory(**kwargs)

    def get_handler_kwargs(self, **kwargs):
        """
        Return a dict of kwargs to be passed to the handler factory.
        Subclasses may override, to supply extra constructor arguments.
        """
        return kwargs

    def add_parser_args(self, parser):
        """
        Configure the arg parser with all options common to import commands:
        model list, date range, create/update/delete toggles and caps, batch
        size, warnings flag, and dry run.
        """

        # model names (aka importer keys)
        doc = ("Which data models to import.  If you specify any, then only "
               "data for those models will be imported.  If you do not specify "
               "any, then all *default* models will be imported.")
        try:
            handler = self.get_handler()
        except NotImplementedError:
            # no handler available yet; help text simply omits the model list
            pass
        else:
            doc += "  Supported models are: ({})".format(', '.join(handler.get_importer_keys()))
        parser.add_argument('models', nargs='*', metavar='MODEL', help=doc)

        # start/end date
        parser.add_argument('--start-date', type=date_argument,
                            help="Optional (inclusive) starting point for date range, by which host "
                            "data should be filtered.  Only used by certain importers.")
        parser.add_argument('--end-date', type=date_argument,
                            help="Optional (inclusive) ending point for date range, by which host "
                            "data should be filtered.  Only used by certain importers.")

        # allow create?
        parser.add_argument('--create', action='store_true', default=True,
                            help="Allow new records to be created during the import.")
        parser.add_argument('--no-create', action='store_false', dest='create',
                            help="Do not allow new records to be created during the import.")
        parser.add_argument('--max-create', type=int, metavar='COUNT',
                            help="Maximum number of records which may be created, after which a "
                            "given import task should stop.  Note that this applies on a per-model "
                            "basis and not overall.")

        # allow update?
        parser.add_argument('--update', action='store_true', default=True,
                            help="Allow existing records to be updated during the import.")
        parser.add_argument('--no-update', action='store_false', dest='update',
                            help="Do not allow existing records to be updated during the import.")
        parser.add_argument('--max-update', type=int, metavar='COUNT',
                            help="Maximum number of records which may be updated, after which a "
                            "given import task should stop.  Note that this applies on a per-model "
                            "basis and not overall.")

        # allow delete?
        parser.add_argument('--delete', action='store_true', default=False,
                            help="Allow records to be deleted during the import.")
        parser.add_argument('--no-delete', action='store_false', dest='delete',
                            help="Do not allow records to be deleted during the import.")
        parser.add_argument('--max-delete', type=int, metavar='COUNT',
                            help="Maximum number of records which may be deleted, after which a "
                            "given import task should stop.  Note that this applies on a per-model "
                            "basis and not overall.")

        # max total changes, per model
        parser.add_argument('--max-total', type=int, metavar='COUNT',
                            help="Maximum number of *any* record changes which may occur, after which "
                            "a given import task should stop.  Note that this applies on a per-model "
                            "basis and not overall.")

        # batch size
        parser.add_argument('--batch', type=int, dest='batch_size', metavar='SIZE', default=200,
                            help="Split work to be done into batches, with the specified number of "
                            "records in each batch.  Or, set this to 0 (zero) to disable batching. "
                            "Implementation for this may vary somewhat between importers; default "
                            "batch size is 200 records.")

        # treat changes as warnings?
        parser.add_argument('--warnings', '-W', action='store_true',
                            help="Set this flag if you expect a \"clean\" import, and wish for any "
                            "changes which do occur to be processed further and/or specially.  The "
                            "behavior of this flag is ultimately up to the import handler, but the "
                            "default is to send an email notification.")

        # dry run?
        parser.add_argument('--dry-run', action='store_true',
                            help="Go through the full motions and allow logging etc. to "
                            "occur, but rollback (abort) the transaction at the end.")

    def run(self, args):
        """
        Run the import: obtain a handler, determine which models to import,
        then delegate the real work to the handler's ``import_data()``.
        """
        log.info("begin `{} {}` for data models: {}".format(
                self.parent_name, self.name, ', '.join(args.models or ["(ALL)"])))

        handler = self.get_handler(args=args)
        models = args.models or handler.get_default_keys()
        log.debug("using handler: {}".format(handler))
        log.debug("importing models: {}".format(models))
        log.debug("args are: {}".format(args))

        kwargs = {
            'dry_run': args.dry_run,
            'warnings': args.warnings,
            'create': args.create,
            'max_create': args.max_create,
            'update': args.update,
            'max_update': args.max_update,
            'delete': args.delete,
            'max_delete': args.max_delete,
            'max_total': args.max_total,
            'progress': self.progress,
            'args': args,
        }
        handler.import_data(*models, **kwargs)

        # TODO: should this logging happen elsewhere / be customizable?
        if args.dry_run:
            log.info("dry run, so transaction was rolled back")
        else:
            log.info("transaction was committed")
 

	
 

	
 
class ImportRattail(ImportSubcommand):
 
    """
 
    Import data from another Rattail database
 
    """
rattail/importing/handlers.py
Show inline comments
 
@@ -46,96 +46,98 @@ log = logging.getLogger(__name__)
 
class ImportHandler(object):
    """
    Base class for all import handlers.
    """
    host_title = None           # display title for the host (source) system
    local_title = None          # display title for the local (target) system
    progress = None             # optional progress factory, for user feedback
    dry_run = False             # when true, the run is a rehearsal only
    # when true, the host-side transaction is committed even if the import
    # itself raises an error (see import_data())
    commit_host_partial = False
 

	
 
    def __init__(self, config=None, **kwargs):
        """
        Store the config object, build the importer map, and absorb any
        remaining kwargs as instance attributes (e.g. ``dry_run``,
        ``batch_size``, ``progress``).
        """
        self.config = config
        self.importers = self.get_importers()
        # .items() instead of .iteritems(): identical behavior on Python 2,
        # and keeps the code forward-compatible with Python 3
        for key, value in kwargs.items():
            setattr(self, key, value)
 

	
 
    def get_importers(self):
        """
        Return the mapping of available importers for this handler, where
        keys are model names and values are importer factories.  Subclasses
        are expected to override this; the base implementation provides no
        importers at all.  Returning a
        :class:`python:collections.OrderedDict` preserves key ordering in
        e.g. the command line help system.
        """
        return {}
 

	
 
    def get_importer_keys(self):
        """
        Returns the list of keys corresponding to the available importers.
        """
        # list(dict) iterates keys directly; unlike iterkeys() this also
        # works on Python 3, with identical results on Python 2
        return list(self.importers)
 

	
 
    def get_default_keys(self):
        """
        Return the keys for importers which should run by default, i.e. when
        the user specifies none explicitly.  The base implementation simply
        returns every available key; override this to exclude importers not
        yet ready for default use, e.g. when first testing them out etc.
        """
        return self.get_importer_keys()
 

	
 
    def get_importer(self, key, **kwargs):
        """
        Create and return the importer instance registered under ``key``,
        or ``None`` if no such importer is registered.
        """
        if key not in self.importers:
            return None
        kwargs.setdefault('handler', self)
        kwargs.setdefault('config', self.config)
        kwargs.setdefault('host_system_title', self.host_title)
        # batch_size is optional; forward it only when the handler has one
        if hasattr(self, 'batch_size'):
            kwargs.setdefault('batch_size', self.batch_size)
        factory = self.importers[key]
        return factory(**self.get_importer_kwargs(key, **kwargs))
 

	
 
    def get_importer_kwargs(self, key, **kwargs):
        """
        Hook for supplying extra kwargs when constructing the importer
        identified by ``key``.  Base implementation passes the given kwargs
        through unchanged.
        """
        return kwargs
 

	
 
    def import_data(self, *keys, **kwargs):
 
        """
 
        Import all data for the given importer/model keys.
 
        """
 
        self.import_began = make_utc(datetime.datetime.utcnow(), tzinfo=True)
 
        if 'dry_run' in kwargs:
 
            self.dry_run = kwargs['dry_run']
 
        self.progress = kwargs.pop('progress', getattr(self, 'progress', None))
 
        self.warnings = kwargs.pop('warnings', False)
 
        kwargs.update({'dry_run': self.dry_run,
 
                       'progress': self.progress})
 

	
 
        self.setup()
 
        self.begin_transaction()
 
        changes = OrderedDict()
 

	
 
        try:
 
            for key in keys:
 
                importer = self.get_importer(key, **kwargs)
 
                if importer:
 
                    created, updated, deleted = importer.import_data()
 
                    changed = bool(created or updated or deleted)
 
                    logger = log.warning if changed and self.warnings else log.info
 
                    logger("{} -> {}: added {}, updated {}, deleted {} {} records".format(
 
                        self.host_title, self.local_title, len(created), len(updated), len(deleted), key))
 
                    if changed:
 
                        changes[key] = created, updated, deleted
 
                else:
 
                    log.warning("skipping unknown importer: {}".format(key))
 
        except:
 
            if self.commit_host_partial and not self.dry_run:
 
                log.warning("{host} -> {local}: committing partial transaction on host {host} (despite error)".format(
 
                    host=self.host_title, local=self.local_title))
 
                self.commit_host_transaction()
 
            raise
 
        else:
 
            if changes:
 
                self.process_changes(changes)
 
@@ -294,54 +296,48 @@ class FromSQLAlchemyHandler(ImportHandler):
 
    def get_importer_kwargs(self, key, **kwargs):
        """
        Ensure every importer also receives the current host-side session.
        """
        final = super(FromSQLAlchemyHandler, self).get_importer_kwargs(key, **kwargs)
        final.setdefault('host_session', self.host_session)
        return final
 

	
 
    def begin_host_transaction(self):
        # open the host-side session; make_host_session() is presumably
        # provided by the concrete subclass -- not visible here, confirm
        self.host_session = self.make_host_session()
 

	
 
    def rollback_host_transaction(self):
        """
        Roll back the host-side session, then close and discard it.
        """
        session = self.host_session
        session.rollback()
        session.close()
        self.host_session = None
 

	
 
    def commit_host_transaction(self):
        """
        Commit the host-side session, then close and discard it.
        """
        session = self.host_session
        session.commit()
        session.close()
        self.host_session = None
 

	
 

	
 
class ToSQLAlchemyHandler(ImportHandler):
    """
    Handler for imports which target a SQLAlchemy ORM on the local side.
    """
    session = None      # current local-side DB session, when one is open

    def make_session(self):
        """
        Create and return the local database session.  Concrete subclasses
        must override this.
        """
        raise NotImplementedError

    def get_importer_kwargs(self, key, **kwargs):
        """
        Ensure every importer also receives the current local-side session.
        """
        final = super(ToSQLAlchemyHandler, self).get_importer_kwargs(key, **kwargs)
        final.setdefault('session', self.session)
        return final

    def begin_local_transaction(self):
        """
        Open the local-side session.
        """
        self.session = self.make_session()

    def rollback_local_transaction(self):
        """
        Roll back the local-side session, then close and discard it.
        """
        session = self.session
        session.rollback()
        session.close()
        self.session = None

    def commit_local_transaction(self):
        """
        Commit the local-side session, then close and discard it.
        """
        session = self.session
        session.commit()
        session.close()
        self.session = None
 

	
 

	
 
class BulkToPostgreSQLHandler(BulkImportHandler):
 
    """
 
    Handler for bulk imports which target PostgreSQL on the local side.
 
    """
rattail/importing/importers.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
################################################################################
 
#
 
#  Rattail -- Retail Software Framework
 
#  Copyright © 2010-2016 Lance Edgar
 
#
 
#  This file is part of Rattail.
 
#
 
#  Rattail is free software: you can redistribute it and/or modify it under the
 
#  terms of the GNU Affero General Public License as published by the Free
 
#  Software Foundation, either version 3 of the License, or (at your option)
 
#  any later version.
 
#
 
#  Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
 
#  WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 
#  FOR A PARTICULAR PURPOSE.  See the GNU Affero General Public License for
 
#  more details.
 
#
 
#  You should have received a copy of the GNU Affero General Public License
 
#  along with Rattail.  If not, see <http://www.gnu.org/licenses/>.
 
#
 
################################################################################
 
"""
 
Data Importers
 
"""
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
import datetime
 
import logging
 

	
 
from rattail.db.util import QuerySequence
 
from rattail.time import make_utc
 

	
 

	
 
log = logging.getLogger(__name__)
 

	
 

	
 
class Importer(object):
    """
    Base class for all data importers.
    """
    # Set this to the data model class which is targeted on the local side.
    model_class = None
    # Explicit model name; when not set, __init__() derives it from
    # ``model_class``.
    model_name = None

    # Key field name(s) used to match host records to local records; may be
    # given as a single name (normalized to a 1-tuple) or a tuple of names.
    key = None

    # The full list of field names supported by the importer, i.e. for the data
    # model to which the importer pertains.  By definition this list will be
    # restricted to what the local side can accommodate, but may be further
    # restricted by what the host side has to offer.
    supported_fields = []

    # The list of field names which may be considered "simple" and therefore
    # treated as such, i.e. with basic getattr/setattr calls.  Note that this
    # only applies to the local side, it has no effect on the host side.
    simple_fields = []

    # What the importer is fundamentally willing to do; effective
    # create/update/delete flags are resolved against these in _setup().
    allow_create = True
    allow_update = True
    allow_delete = True
    dry_run = False

    # Optional per-model caps on record changes; None means unlimited.
    max_create = None
    max_update = None
    max_delete = None
    max_total = None
    # Flush pending changes after this many records ("batch" support).
    batch_size = 200
    progress = None             # optional progress factory, for user feedback

    # Local-data caching knobs; when caches_local_data is true, import_data()
    # fills cached_local_data via cache_local_data().
    empty_local_data = False
    caches_local_data = False
    cached_local_data = None

    # Display titles used in log messages, e.g. "{host} -> {local}".
    host_system_title = None
    local_system_title = None
 

	
 
    def __init__(self, config=None, fields=None, key=None, **kwargs):
        """
        Constructor.

        :param config: rattail config object.
        :param fields: effective field list; defaults to :attr:`supported_fields`.
        :param key: key field name(s); a string is normalized to a 1-tuple.
        :raises ValueError: if any key field is missing from the effective fields.

        Remaining kwargs may include ``model_class`` / ``model_name``
        overrides; anything left over is handed off to :meth:`_setup()`.
        """
        self.config = config
        self.fields = fields or self.supported_fields
        if key is not None:
            self.key = key
        # normalize a single key field name to a tuple (basestring = Python 2)
        if isinstance(self.key, basestring):
            self.key = (self.key,)
        if self.key:
            for field in self.key:
                if field not in self.fields:
                    raise ValueError("Key field '{}' must be included in effective fields "
                                     "for {}".format(field, self.__class__.__name__))
        self.model_class = kwargs.pop('model_class', self.model_class)
        # NOTE(review): this assignment would raise AttributeError if the
        # read-only ``model_name`` property below is still in effect -- the
        # two look like old/new diff leftovers; confirm which is current
        self.model_name = kwargs.pop('model_name', self.model_name)
        if not self.model_name and self.model_class:
            self.model_name = self.model_class.__name__
        self._setup(**kwargs)
 

	
 
    def _setup(self, **kwargs):
 
        self.create = kwargs.pop('create', self.allow_create) and self.allow_create
 
        self.update = kwargs.pop('update', self.allow_update) and self.allow_update
 
        self.delete = kwargs.pop('delete', False) and self.allow_delete
 
        for key, value in kwargs.iteritems():
 
            setattr(self, key, value)
 

	
 
    @property
    def model_name(self):
        # NOTE(review): a data property with no setter would make the
        # ``self.model_name = ...`` assignments in __init__() raise
        # AttributeError; this looks like a leftover from before model_name
        # became a plain attribute -- confirm it should still exist.
        if self.model_class:
            return self.model_class.__name__
 

	
 
    def setup(self):
 
        """
 
        Perform any setup necessary, e.g. cache lookups for existing data.
 
        """
 

	
 
    def teardown(self):
 
        """
 
        Perform any cleanup after import, if necessary.
 
        """
 

	
 
    def import_data(self, host_data=None, now=None, **kwargs):
 
        """
 
        Import some data!  This is the core body of logic for that, regardless
 
        of where data is coming from or where it's headed.  Note that this
 
        method handles deletions as well as adds/updates.
 
        """
 
        self.now = now or make_utc(datetime.datetime.utcnow(), tzinfo=True)
 
        if kwargs:
 
            self._setup(**kwargs)
 
        self.setup()
 
        created = updated = deleted = []
 

	
 
        # Get complete set of normalized host data.
 
        if host_data is None:
 
            host_data = self.normalize_host_data()
 

	
 
        # Prune duplicate keys from host/source data.  This is for the sake of
 
        # sanity since duplicates typically lead to a ping-pong effect, where a
 
        # "clean" (change-less) import is impossible.
 
        unique = {}
 
        for data in host_data:
 
            key = self.get_key(data)
 
            if key in unique:
 
                log.warning("duplicate records detected from {} for key: {}".format(
 
                    self.host_system_title, key))
 
            unique[key] = data
 
        host_data = []
 
        for key in sorted(unique):
 
            host_data.append(unique[key])
 

	
 
        # Cache local data if appropriate.
 
        if self.caches_local_data:
 
            self.cached_local_data = self.cache_local_data(host_data)
 

	
 
        # Create and/or update data.
 
        if self.create or self.update:
 
            created, updated = self._import_create_update(host_data)
 

	
 
@@ -163,115 +164,119 @@ class Importer(object):
 
        """
 
        Import the given data; create and/or update records as needed.
 
        """
 
        created, updated = [], []
 
        count = len(data)
 
        if not count:
 
            return created, updated
 

	
 
        prog = None
 
        if self.progress:
 
            prog = self.progress("Importing {} data".format(self.model_name), count)
 

	
 
        for i, host_data in enumerate(data, 1):
 

	
 
            # Fetch local object, using key from host data.
 
            key = self.get_key(host_data)
 
            local_object = self.get_local_object(key)
 

	
 
            # If we have a local object, but its data differs from host, update it.
 
            if local_object and self.update:
 
                local_data = self.normalize_local_object(local_object)
 
                diffs = self.data_diffs(local_data, host_data)
 
                if diffs:
 
                    log.debug("fields '{}' differed for local data: {}, host data: {}".format(
 
                        ','.join(diffs), local_data, host_data))
 
                    local_object = self.update_object(local_object, host_data, local_data)
 
                    updated.append((local_object, local_data, host_data))
 
                    if self.max_update and len(updated) >= self.max_update:
 
                        log.warning("max of {} *updated* records has been reached; stopping now".format(self.max_update))
 
                        break
 
                    if self.max_total and (len(created) + len(updated)) >= self.max_total:
 
                        log.warning("max of {} *total changes* has been reached; stopping now".format(self.max_total))
 
                        break
 

	
 
            # If we did not yet have a local object, create it using host data.
 
            elif not local_object and self.create:
 
                local_object = self.create_object(key, host_data)
 
                log.debug("created new {} {}: {}".format(self.model_name, key, local_object))
 
                created.append((local_object, host_data))
 
                if self.caches_local_data and self.cached_local_data is not None:
 
                    self.cached_local_data[key] = {'object': local_object, 'data': self.normalize_local_object(local_object)}
 
                if self.max_create and len(created) >= self.max_create:
 
                    log.warning("max of {} *created* records has been reached; stopping now".format(self.max_create))
 
                    break
 
                if self.max_total and (len(created) + len(updated)) >= self.max_total:
 
                    log.warning("max of {} *total changes* has been reached; stopping now".format(self.max_total))
 
                    break
 

	
 
            self.flush_changes(i)
 
            # # TODO: this needs to be customizable etc. somehow maybe..
 
            # if i % 100 == 0 and hasattr(self, 'session'):
 
            #     self.session.flush()
 
            # flush changes every so often
 
            if not self.batch_size or (len(created) + len(updated)) % self.batch_size == 0:
 
                self.flush_create_update()
 

	
 
            if prog:
 
                prog.update(i)
 
        if prog:
 
            prog.destroy()
 

	
 
        self.flush_create_update_final()
 
        return created, updated
 

	
 
    # TODO: this surely goes elsewhere
    # Flush the session every this-many records; see flush_changes().
    flush_every_x = 100
    def flush_create_update(self):
        """
        Perform any steps necessary to "flush" the create/update changes which
        have occurred thus far in the import.  Base implementation is a no-op;
        subclasses may override.
        """
 

	
 
    def flush_changes(self, x):
        """
        Flush the session every :attr:`flush_every_x` records, when a
        session is present.  ``x`` is the 1-based count of records
        processed so far.
        """
        interval = self.flush_every_x
        if interval and x % interval == 0:
            if hasattr(self, 'session'):
                self.session.flush()
 
    def flush_create_update_final(self):
        """
        Perform any final "flush" of created/updated data, at the end of the
        import run.  Default simply delegates to :meth:`flush_create_update()`.
        """
        self.flush_create_update()
 

	
 
    def _import_delete(self, host_data, host_keys, changes=0):
        """
        Import deletions for the given data set.

        :param host_data: normalized host data set (unused in this body;
           presumably retained for signature symmetry -- confirm).
        :param host_keys: set of keys present on the host side; cached local
           keys NOT in this set become candidates for deletion.
        :param changes: number of changes already made, counted toward the
           ``max_total`` cap.
        :returns: list of ``(object, data)`` tuples for deleted records.
        """
        deleted = []
        # anything known locally but absent from host may be deleted
        deleting = self.get_deletion_keys() - host_keys
        count = len(deleting)
        log.debug("found {} instances to delete".format(count))
        if count:

            prog = None
            if self.progress:
                prog = self.progress("Deleting {} data".format(self.model_name), count)

            for i, key in enumerate(sorted(deleting), 1):

                # pop from cache so the entry cannot be re-processed
                cached = self.cached_local_data.pop(key)
                obj = cached['object']
                if self.delete_object(obj):
                    deleted.append((obj, cached['data']))

                    # honor per-model caps, stopping early once reached
                    if self.max_delete and len(deleted) >= self.max_delete:
                        log.warning("max of {} *deleted* records has been reached; stopping now".format(self.max_delete))
                        break
                    if self.max_total and (changes + len(deleted)) >= self.max_total:
                        log.warning("max of {} *total changes* has been reached; stopping now".format(self.max_total))
                        break

                if prog:
                    prog.update(i)
            if prog:
                prog.destroy()

        return deleted
 

	
 
    def get_key(self, data):
        """
        Return the key tuple for the given data dict, i.e. the values of all
        fields named in :attr:`key`, in order.
        """
        return tuple(data[field] for field in self.key)
 

	
 
    def get_host_objects(self):
 
        """
 
        Return the "raw" (as-is, not normalized) host objects which are to be
 
        imported.  This may return any sequence-like object, which has a
 
        ``len()`` value and responds to iteration etc.  The objects contained
 
        within it may be of any type, no assumptions are made there.  (That is
 
@@ -415,95 +420,92 @@ class Importer(object):
 
        """
 
        for field in self.simple_fields:
 
            if field in self.fields:
 
                if not local_data or local_data[field] != host_data[field]:
 
                    setattr(obj, field, host_data[field])
 
        return obj
 

	
 
    def get_deletion_keys(self):
        """
        Return the set of keys from the *local* data set which are eligible
        for deletion: all cached local keys when caching is active, else the
        empty set.
        """
        if not (self.caches_local_data and self.cached_local_data is not None):
            return set()
        return set(self.cached_local_data)
 

	
 
    def delete_object(self, obj):
        """
        Delete the given object from the local system (or not), returning a
        boolean indicating success.  The default implementation deletes
        nothing at all, and simply claims success.
        """
        return True
 

	
 

	
 
class FromQuery(Importer):
    """
    Generic base class for importers whose raw external data source is a
    SQLAlchemy (or Django, or possibly other?) query.
    """

    def query(self):
        """
        Return the primary query which will define the host data set.
        Concrete subclasses must override this.
        """
        raise NotImplementedError

    def get_host_objects(self, progress=None):
        """
        Return the raw query results, wrapped as a length-aware sequence.
        """
        return QuerySequence(self.query())
 

	
 

	
 
class BulkImporter(Importer):
    """
    Base class for bulk data importers.
    """

    def import_data(self, host_data=None, now=None, **kwargs):
        """
        Bulk version of the import: only creation is performed (no update /
        delete pass), and the return value is a *count* of created records
        rather than lists of changes.
        """
        self.now = now or make_utc(datetime.datetime.utcnow(), tzinfo=True)
        if kwargs:
            self._setup(**kwargs)
        self.setup()
        if host_data is None:
            host_data = self.normalize_host_data()
        created = self._import_create(host_data)
        self.teardown()
        return created

    def _import_create(self, data):
        """
        Create a record for each host data entry, honoring ``max_create``
        and flushing periodically per ``batch_size``.  Returns the number of
        records created.
        """
        count = len(data)
        if not count:
            return 0
        created = count

        prog = None
        if self.progress:
            prog = self.progress("Importing {} data".format(self.model_name), count)

        for i, host_data in enumerate(data, 1):

            key = self.get_key(host_data)
            self.create_object(key, host_data)
            if self.max_create and i >= self.max_create:
                log.warning("max of {} *created* records has been reached; stopping now".format(self.max_create))
                created = i
                break

            # flush changes every so often
            # NOTE(review): when batch_size is 0 ("disable batching") this
            # condition flushes on *every* record -- confirm that is intended
            if not self.batch_size or i % self.batch_size == 0:
                self.flush_create_update()

            if prog:
                prog.update(i)
        if prog:
            prog.destroy()

        # NOTE(review): both the older flush_create() hook and the newer
        # flush_create_update_final() are invoked here; flush_create() is a
        # no-op by default -- looks like diff residue, confirm
        self.flush_create()
        self.flush_create_update_final()
        return created

    def flush_create(self):
        """
        Perform any final steps to "flush" the created data here.  Note that
        the importer's handler is still responsible for actually committing
        changes to the local system, if applicable.
        """
rattail/importing/postgresql.py
Show inline comments
 
@@ -38,56 +38,59 @@ log = logging.getLogger(__name__)
 

	
 

	
 
class BulkToPostgreSQL(BulkImporter, ToSQLAlchemy):
 
    """
 
    Base class for bulk data importers which target PostgreSQL on the local side.
 
    """
 

	
 
    @property
 
    def data_path(self):
 
        return os.path.join(self.config.workdir(require=True),
 
                            'import_bulk_postgresql_{}.csv'.format(self.model_name))
 

	
 
    def setup(self):
 
        self.data_buffer = open(self.data_path, 'wb')
 

	
 
    def teardown(self):
 
        self.data_buffer.close()
 
        os.remove(self.data_path)
 
        self.data_buffer = None
 

	
 
    def create_object(self, key, data):
 
        data = self.prep_data_for_postgres(data)
 
        self.data_buffer.write('{}\n'.format('\t'.join([data[field] for field in self.fields])).encode('utf-8'))
 

	
 
    def prep_data_for_postgres(self, data):
 
        data = dict(data)
 
        for key, value in data.iteritems():
 
            data[key] = self.prep_value_for_postgres(value)
 
        return data
 

	
 
    def prep_value_for_postgres(self, value):
 
        if value is None:
 
            return '\\N'
 
        if value is True:
 
            return 't'
 
        if value is False:
 
            return 'f'
 

	
 
        if isinstance(value, datetime.datetime):
 
            value = make_utc(value, tzinfo=False)
 
        elif isinstance(value, basestring):
 
            value = value.replace('\\', '\\\\')
 
            value = value.replace('\r', '\\r')
 
            value = value.replace('\n', '\\n')
 
            value = value.replace('\t', '\\t') # TODO: add test for this
 

	
 
        return unicode(value)
 

	
 
    def flush_create(self):
 
    def flush_create_update(self):
 
        pass
 

	
 
    def flush_create_update_final(self):
 
        log.info("copying {} data from buffer to PostgreSQL".format(self.model_name))
 
        self.data_buffer.close()
 
        self.data_buffer = open(self.data_path, 'rb')
 
        cursor = self.session.connection().connection.cursor()
 
        table_name = '"{}"'.format(self.model_table.name)
 
        cursor.copy_from(self.data_buffer, table_name, columns=self.fields)
 
        log.debug("PostgreSQL data copy completed")
rattail/importing/sqlalchemy.py
Show inline comments
 
@@ -125,48 +125,55 @@ class ToSQLAlchemy(Importer):
 

	
 
    def update_object(self, obj, host_data, local_data=None):
 
        """
 
        Update the local data object with the given host data, and return the
 
        object.
 
        """
 
        obj = super(ToSQLAlchemy, self).update_object(obj, host_data, local_data)
 
        if obj:
 
            if self.flush_session:
 
                self.session.flush()
 
            return obj
 

	
 
    def delete_object(self, obj):
 
        """
 
        Delete the given object from the local system (or not), and return a
 
        boolean indicating whether deletion was successful.  Default
 
        implementation will truly delete and expunge the local object via SA
 
        ORM, and flush the local session.
 
        """
 
        self.session.delete(obj)
 
        self.session.flush()
 
        self.session.expunge(obj)
 
        return True
 

	
 
    def cache_model(self, model, **kwargs):
 
        """
 
        Convenience method which invokes :func:`rattail.db.cache.cache_model()`
 
        with the given model and keyword arguments.  It will provide the
 
        ``session`` and ``progress`` parameters by default, setting them to the
 
        importer's attributes of the same names.
 
        """
 
        session = kwargs.pop('session', self.session)
 
        kwargs.setdefault('progress', self.progress)
 
        return cache.cache_model(session, model, **kwargs)
 

	
 
    def cache_local_data(self, host_data=None):
 
        """
 
        Cache all local objects and data using SA ORM.
 
        """
 
        return self.cache_model(self.model_class, key=self.get_cache_key,
 
                                # omit_duplicates=True,
 
                                query_options=self.cache_query_options(),
 
                                normalizer=self.normalize_cache_object)
 

	
 
    def cache_query_options(self):
 
        """
 
        Return a list of options to apply to the cache query, if needed.
 
        """
 

	
 
    def flush_create_update(self):
 
        """
 
        Flush the database session, to send SQL to the server for all changes
 
        made thus far.
 
        """
 
        self.session.flush()
rattail/tests/commands/test_importing.py
Show inline comments
 
# -*- coding: utf-8 -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
import argparse
 
from unittest import TestCase
 

	
 
from mock import Mock, patch
 

	
 
from rattail.core import Object
 
from rattail.commands import importing
 
from rattail.importing import ImportHandler
 
from rattail.importing import Importer, ImportHandler
 
from rattail.importing.rattail import FromRattailToRattail
 
from rattail.config import RattailConfig
 
from rattail.tests.importing import ImporterTester
 
from rattail.tests.importing.test_handlers import MockImportHandler
 
from rattail.tests.importing.test_importers import MockImporter
 
from rattail.tests.importing.test_rattail import DualRattailTestCase
 

	
 

	
 
class MockImport(importing.ImportSubcommand):
 

	
 
    def get_handler_factory(self):
 
        return MockImportHandler
 

	
 

	
 
class TestImportSubcommandBasics(TestCase):
 
class TestImportSubcommand(TestCase):
 

	
 
    # TODO: lame, here only for coverage
 
    def test_parent_name(self):
 
        parent = Object(name='milo')
 
        command = importing.ImportSubcommand(parent=parent)
 
        self.assertEqual(command.parent_name, 'milo')
 

	
 
    def test_get_handler_factory(self):
 
        command = importing.ImportSubcommand()
 
        self.assertRaises(NotImplementedError, command.get_handler_factory)
 
        command.handler_spec = 'rattail.importing.rattail:FromRattailToRattail'
 
        factory = command.get_handler_factory()
 
        self.assertIs(factory, FromRattailToRattail)
 

	
 
    def test_handler_spec_attr(self):
 
        # default is None
 
        command = importing.ImportSubcommand()
 
        self.assertIsNone(command.handler_spec)
 

	
 
        # can't get a handler without a spec
 
        self.assertRaises(NotImplementedError, command.get_handler)
 

	
 
        # but may be specified with init kwarg
 
        command = importing.ImportSubcommand(handler_spec='rattail.importing:ImportHandler')
 
        self.assertEqual(command.handler_spec, 'rattail.importing:ImportHandler')
 

	
 
        # now we can get a handler
 
        handler = command.get_handler()
 
        self.assertIsInstance(handler, ImportHandler)
 

	
 
    def test_get_handler(self):
 

	
 
        # no config
 
        command = MockImport()
 
        handler = command.get_handler()
 
        self.assertIs(type(handler), MockImportHandler)
 
        self.assertIsNone(handler.config)
 

	
 
        # with config
 
        config = RattailConfig()
 
        command = MockImport(config=config)
 
        handler = command.get_handler()
 
        self.assertIs(type(handler), MockImportHandler)
 
        self.assertIs(handler.config, config)
 

	
 
        # dry run
 
        command = MockImport()
 
        handler = command.get_handler()
 
        self.assertFalse(handler.dry_run)
 
        handler = command.get_handler(dry_run=True)
 
        self.assertTrue(handler.dry_run)
 
        args = argparse.Namespace(dry_run=True)
 
        handler = command.get_handler(args=args)
 
        self.assertTrue(handler.dry_run)
 

	
 
    def test_add_parser_args(self):
 
        # TODO: this doesn't really test anything, but does give some coverage..
 

	
 
        # no handler
 
        # adding the args throws no error..(basic coverage)
 
        command = importing.ImportSubcommand()
 
        parser = argparse.ArgumentParser()
 
        command.add_parser_args(parser)
 

	
 
        # with handler
 
        command = MockImport()
 
        # confirm default values
 
        args = parser.parse_args([])
 
        self.assertIsNone(args.start_date)
 
        self.assertIsNone(args.end_date)
 
        self.assertTrue(args.create)
 
        self.assertIsNone(args.max_create)
 
        self.assertTrue(args.update)
 
        self.assertIsNone(args.max_update)
 
        self.assertFalse(args.delete)
 
        self.assertIsNone(args.max_delete)
 
        self.assertIsNone(args.max_total)
 
        self.assertEqual(args.batch_size, 200)
 
        self.assertFalse(args.warnings)
 
        self.assertFalse(args.dry_run)
 

	
 
    def test_batch_size_kwarg(self):
 
        command = importing.ImportSubcommand(handler_spec='rattail.importing:ImportHandler')
 
        parser = argparse.ArgumentParser()
 
        command.add_parser_args(parser)
 
        with patch.object(ImportHandler, 'get_importers', Mock(return_value={'Foo': Importer})):
 

	
 
            # importer default is 200
 
            args = parser.parse_args([])
 
            handler = command.get_handler(args=args)
 
            importer = handler.get_importer('Foo')
 
            self.assertEqual(importer.batch_size, 200)
 

	
 
            # but may be overridden with command line arg
 
            args = parser.parse_args(['--batch', '42'])
 
            handler = command.get_handler(args=args)
 
            importer = handler.get_importer('Foo')
 
            self.assertEqual(importer.batch_size, 42)
 

	
 

	
 
class TestImportSubcommandRun(ImporterTester, TestCase):
 
class TestImportSubcommandRun(TestCase, ImporterTester):
 

	
 
    sample_data = {
 
        '16oz': {'upc': '00074305001161', 'description': "Apple Cider Vinegar 16oz"},
 
        '32oz': {'upc': '00074305001321', 'description': "Apple Cider Vinegar 32oz"},
 
        '1gal': {'upc': '00074305011283', 'description': "Apple Cider Vinegar 1gal"},
 
    }
 

	
 
    def setUp(self):
 
        self.command = MockImport()
 
        self.handler = MockImportHandler()
 
        self.importer = MockImporter()
 

	
 
    def import_data(self, **kwargs):
 
    def import_data(self, host_data=None, local_data=None, **kwargs):
 
        if host_data is None:
 
            host_data = self.sample_data
 
        if local_data is None:
 
            local_data = self.sample_data
 

	
 
        models = kwargs.pop('models', [])
 
        kwargs.setdefault('dry_run', False)
 

	
 
        kw = {
 
            'warnings': False,
 
            'create': None,
 
            'max_create': None,
 
            'update': None,
 
            'max_update': None,
 
            'delete': None,
 
            'max_delete': None,
 
            'max_total': None,
 
            'batchcount': None,
 
            'progress': None,
 
        }
 
        kw.update(kwargs)
 
        args = argparse.Namespace(models=models, **kw)
 

	
 
        # must modify our importer in-place since we need the handler to return
 
        # that specific instance, below (because the host/local data context
 
        # managers reference that instance directly)
 
        self.importer._setup(**kwargs)
 
        with patch.object(self.command, 'get_handler', Mock(return_value=self.handler)):
 
            with patch.object(self.handler, 'get_importer', Mock(return_value=self.importer)):
 
                with self.host_data(host_data):
 
                    with self.local_data(local_data):
 
                        self.command.run(args)
 

	
 
        if self.handler._result:
 
            self.result = self.handler._result['Product']
 
        else:
 
            self.result = [], [], []
 

	
 
    def test_create(self):
 
        local = self.copy_data()
 
        del local['32oz']
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data()
 
        self.import_data(local_data=local)
 
        self.assert_import_created('32oz')
 
        self.assert_import_updated()
 
        self.assert_import_deleted()
 

	
 
    def test_update(self):
 
        local = self.copy_data()
 
        local['16oz']['description'] = "wrong description"
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data()
 
        self.import_data(local_data=local)
 
        self.assert_import_created()
 
        self.assert_import_updated('16oz')
 
        self.assert_import_deleted()
 

	
 
    def test_delete(self):
 
        local = self.copy_data()
 
        local['bogus'] = {'upc': '00000000000000', 'description': "Delete Me"}
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(delete=True)
 
        self.import_data(local_data=local, delete=True)
 
        self.assert_import_created()
 
        self.assert_import_updated()
 
        self.assert_import_deleted('bogus')
 

	
 
    def test_duplicate(self):
 
        host = self.copy_data()
 
        host['32oz-dupe'] = host['32oz']
 
        with self.host_data(host):
 
            with self.local_data(self.sample_data):
 
                self.import_data()
 
        self.import_data(host_data=host)
 
        self.assert_import_created()
 
        self.assert_import_updated()
 
        self.assert_import_deleted()
 

	
 
    def test_max_create(self):
 
        local = self.copy_data()
 
        del local['16oz']
 
        del local['1gal']
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_create=1)
 
        self.import_data(local_data=local, max_create=1)
 
        self.assert_import_created('16oz')
 
        self.assert_import_updated()
 
        self.assert_import_deleted()
 

	
 
    def test_max_total_create(self):
 
        local = self.copy_data()
 
        del local['16oz']
 
        del local['1gal']
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_total=1)
 
        self.import_data(local_data=local, max_total=1)
 
        self.assert_import_created('16oz')
 
        self.assert_import_updated()
 
        self.assert_import_deleted()
 

	
 
    def test_max_update(self):
 
        local = self.copy_data()
 
        local['16oz']['description'] = "wrong"
 
        local['1gal']['description'] = "wrong"
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_update=1)
 
        self.import_data(local_data=local, max_update=1)
 
        self.assert_import_created()
 
        self.assert_import_updated('16oz')
 
        self.assert_import_deleted()
 

	
 
    def test_max_total_update(self):
 
        local = self.copy_data()
 
        local['16oz']['description'] = "wrong"
 
        local['1gal']['description'] = "wrong"
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_total=1)
 
        self.import_data(local_data=local, max_total=1)
 
        self.assert_import_created()
 
        self.assert_import_updated('16oz')
 
        self.assert_import_deleted()
 

	
 
    def test_max_delete(self):
 
        local = self.copy_data()
 
        local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"}
 
        local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"}
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(delete=True, max_delete=1)
 
        self.import_data(local_data=local, delete=True, max_delete=1)
 
        self.assert_import_created()
 
        self.assert_import_updated()
 
        self.assert_import_deleted('bogus1')
 

	
 
    def test_max_total_delete(self):
 
        local = self.copy_data()
 
        local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"}
 
        local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"}
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(delete=True, max_total=1)
 
        self.import_data(local_data=local, delete=True, max_total=1)
 
        self.assert_import_created()
 
        self.assert_import_updated()
 
        self.assert_import_deleted('bogus1')
 

	
 
    def test_dry_run(self):
 
        local = self.copy_data()
 
        del local['32oz']
 
        local['16oz']['description'] = "wrong description"
 
        local['bogus'] = {'upc': '00000000000000', 'description': "Delete Me"}
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(delete=True, dry_run=True)
 
        self.import_data(local_data=local, delete=True, dry_run=True)
 
        # TODO: maybe need a way to confirm no changes actually made due to dry
 
        # run; currently results still reflect "proposed" changes.  this rather
 
        # bogus test is here just for coverage sake
 
        self.assert_import_created('32oz')
 
        self.assert_import_updated('16oz')
 
        self.assert_import_deleted('bogus')
 

	
 

	
 
class TestImportRattail(DualRattailTestCase):
 

	
 
    def make_command(self, **kwargs):
 
        kwargs.setdefault('config', self.config)
 
        return importing.ImportRattail(**kwargs)
 

	
 
    def test_get_handler_factory(self):
 

	
 
        # default handler
 
        config = RattailConfig()
 
        command = self.make_command(config=config)
 
        Handler = command.get_handler_factory()
 
        self.assertIs(Handler, FromRattailToRattail)
 

	
 
    def test_get_handler_kwargs(self):
 
        command = self.make_command()
 
        kwargs = command.get_handler_kwargs()
 
        self.assertEqual(kwargs, {})
 

	
 
        args = argparse.Namespace(dbkey='other')
 
        kwargs = command.get_handler_kwargs(args=args)
 
        self.assertEqual(len(kwargs), 2)
 
        self.assertIs(kwargs['args'], args)
 
        self.assertEqual(kwargs['dbkey'], 'other')
rattail/tests/importing/test_handlers.py
Show inline comments
 
@@ -29,96 +29,106 @@ class ImportHandlerBattery(ImporterTester):
 
        self.assertEqual(handler.get_importers(), {})
 
        self.assertEqual(handler.get_importer_keys(), [])
 
        self.assertEqual(handler.get_default_keys(), [])
 
        self.assertFalse(handler.commit_host_partial)
 

	
 
        # with config
 
        handler = self.handler_class()
 
        self.assertIsNone(handler.config)
 
        config = RattailConfig()
 
        handler = self.handler_class(config=config)
 
        self.assertIs(handler.config, config)
 

	
 
        # dry run
 
        handler = self.handler_class()
 
        self.assertFalse(handler.dry_run)
 
        handler = self.handler_class(dry_run=True)
 
        self.assertTrue(handler.dry_run)
 

	
 
        # extra kwarg
 
        handler = self.handler_class()
 
        self.assertRaises(AttributeError, getattr, handler, 'foo')
 
        handler = self.handler_class(foo='bar')
 
        self.assertEqual(handler.foo, 'bar')
 

	
 
    def test_get_importer(self):
 
        get_importers = Mock(return_value={'foo': Importer})
 

	
 
        # no importers
 
        handler = self.make_handler()
 
        self.assertIsNone(handler.get_importer('foo'))
 

	
 
        # no config
 
        with patch.object(self.handler_class, 'get_importers', get_importers):
 
            handler = self.handler_class()
 
        importer = handler.get_importer('foo')
 
        self.assertIs(type(importer), Importer)
 
        self.assertIsNone(importer.config)
 
        self.assertIs(importer.handler, handler)
 

	
 
        # with config
 
        config = RattailConfig()
 
        with patch.object(self.handler_class, 'get_importers', get_importers):
 
            handler = self.handler_class(config=config)
 
        importer = handler.get_importer('foo')
 
        self.assertIs(type(importer), Importer)
 
        self.assertIs(importer.config, config)
 
        self.assertIs(importer.handler, handler)
 

	
 
        # # batch size
 
        # with patch.object(self.handler_class, 'get_importers', get_importers):
 
        #     handler = self.handler_class()
 
        # importer = handler.get_importer('foo')
 
        # self.assertEqual(importer.batch_size, 200)
 
        # with patch.object(self.handler_class, 'get_importers', get_importers):
 
        #     handler = self.handler_class(batch_size=42)
 
        # importer = handler.get_importer('foo')
 
        # self.assertEqual(importer.batch_size, 42)
 

	
 
        # dry run
 
        with patch.object(self.handler_class, 'get_importers', get_importers):
 
            handler = self.handler_class()
 
        importer = handler.get_importer('foo')
 
        self.assertFalse(importer.dry_run)
 
        with patch.object(self.handler_class, 'get_importers', get_importers):
 
            handler = self.handler_class(dry_run=True)
 
        importer = handler.get_importer('foo')
 
        self.assertTrue(handler.dry_run)
 

	
 
        # host title
 
        with patch.object(self.handler_class, 'get_importers', get_importers):
 
            handler = self.handler_class()
 
        importer = handler.get_importer('foo')
 
        self.assertIsNone(importer.host_system_title)
 
        handler.host_title = "Foo"
 
        importer = handler.get_importer('foo')
 
        self.assertEqual(importer.host_system_title, "Foo")
 

	
 
        # extra kwarg
 
        with patch.object(self.handler_class, 'get_importers', get_importers):
 
            handler = self.handler_class()
 
        importer = handler.get_importer('foo')
 
        self.assertRaises(AttributeError, getattr, importer, 'bar')
 
        importer = handler.get_importer('foo', bar='baz')
 
        self.assertEqual(importer.bar, 'baz')
 

	
 
    def test_get_importer_kwargs(self):
 

	
 
        # empty by default
 
        handler = self.make_handler()
 
        self.assertEqual(handler.get_importer_kwargs('foo'), {})
 

	
 
        # extra kwargs are preserved
 
        handler = self.make_handler()
 
        self.assertEqual(handler.get_importer_kwargs('foo', bar='baz'), {'bar': 'baz'})
 

	
 
    def test_begin_transaction(self):
 
        handler = self.make_handler()
 
        with patch.object(handler, 'begin_host_transaction') as begin_host:
 
            with patch.object(handler, 'begin_local_transaction') as begin_local:
 
                handler.begin_transaction()
 
                begin_host.assert_called_once_with()
 
                begin_local.assert_called_once_with()
 

	
 
    def test_begin_host_transaction(self):
 
        handler = self.make_handler()
 
        handler.begin_host_transaction()
 
@@ -231,96 +241,122 @@ class ImportHandlerBattery(ImporterTester):
 
            self.assertFalse(commit.called)
 

	
 
        handler.commit_host_partial = True
 
        with patch.object(handler, 'commit_host_transaction') as commit:
 
            self.assertRaises(ValueError, handler.import_data, 'Foo')
 
            commit.assert_called_once_with()
 

	
 

	
 
class BulkImportHandlerBattery(ImportHandlerBattery):
 

	
 
    def test_import_data_invalid_model(self):
 
        handler = self.make_handler()
 
        importer = Mock()
 
        importer.import_data.return_value = 0
 
        FooImporter = Mock(return_value=importer)
 
        handler.importers = {'Foo': FooImporter}
 

	
 
        handler.import_data('Foo')
 
        self.assertEqual(FooImporter.call_count, 1)
 
        importer.import_data.assert_called_once_with()
 

	
 
        FooImporter.reset_mock()
 
        importer.reset_mock()
 

	
 
        handler.import_data('Missing')
 
        self.assertFalse(FooImporter.called)
 
        self.assertFalse(importer.called)
 

	
 
    def test_import_data_with_changes(self):
 
        handler = self.make_handler()
 
        importer = Mock()
 
        FooImporter = Mock(return_value=importer)
 
        handler.importers = {'Foo': FooImporter}
 

	
 
        importer.import_data.return_value = 0
 
        with patch.object(handler, 'process_changes') as process:
 
            handler.import_data('Foo')
 
            self.assertFalse(process.called)
 

	
 
        importer.import_data.return_value = 3
 
        with patch.object(handler, 'process_changes') as process:
 
            handler.import_data('Foo')
 
            self.assertFalse(process.called)
 

	
 

	
 
class TestImportHandler(unittest.TestCase, ImportHandlerBattery):
 
    handler_class = handlers.ImportHandler
 

	
 
    def test_batch_size_kwarg(self):
 
        with patch.object(handlers.ImportHandler, 'get_importers', Mock(return_value={'Foo': Importer})):
 

	
 
            # handler has no batch size by default
 
            handler = handlers.ImportHandler()
 
            self.assertFalse(hasattr(handler, 'batch_size'))
 

	
 
            # but can override with kwarg
 
            handler = handlers.ImportHandler(batch_size=42)
 
            self.assertEqual(handler.batch_size, 42)
 

	
 
            # importer default is 200
 
            handler = handlers.ImportHandler()
 
            importer = handler.get_importer('Foo')
 
            self.assertEqual(importer.batch_size, 200)
 

	
 
            # but can override with handler init kwarg
 
            handler = handlers.ImportHandler(batch_size=37)
 
            importer = handler.get_importer('Foo')
 
            self.assertEqual(importer.batch_size, 37)
 

	
 
            # can also override with get_importer kwarg
 
            handler = handlers.ImportHandler()
 
            importer = handler.get_importer('Foo', batch_size=98)
 
            self.assertEqual(importer.batch_size, 98)
 

	
 
    @patch('rattail.importing.handlers.send_email')
 
    def test_process_changes_sends_email(self, send_email):
 
        handler = handlers.ImportHandler()
 
        handler.import_began = pytz.utc.localize(datetime.datetime.utcnow())
 
        changes = [], [], []
 

	
 
        # warnings disabled
 
        handler.warnings = False
 
        handler.process_changes(changes)
 
        self.assertFalse(send_email.called)
 

	
 
        # warnings enabled
 
        handler.warnings = True
 
        handler.process_changes(changes)
 
        self.assertEqual(send_email.call_count, 1)
 

	
 
        send_email.reset_mock()
 

	
 
        # warnings enabled, with command (just for coverage..)
 
        handler.warnings = True
 
        handler.command = Mock(name='import-testing', parent=Mock(name='rattail'))
 
        handler.process_changes(changes)
 
        self.assertEqual(send_email.call_count, 1)
 

	
 

	
 
class TestBulkImportHandler(unittest.TestCase, BulkImportHandlerBattery):
 
    handler_class = handlers.BulkImportHandler
 

	
 

	
 
######################################################################
 
# fake import handler, tested mostly for basic coverage
 
######################################################################
 

	
 
class MockImportHandler(handlers.ImportHandler):
 

	
 
    def get_importers(self):
 
        return {'Product': MockImporter}
 

	
 
    def import_data(self, *keys, **kwargs):
 
        result = super(MockImportHandler, self).import_data(*keys, **kwargs)
 
        self._result = result
 
        return result
 

	
 

	
 
class TestImportHandlerImportData(ImporterTester, unittest.TestCase):
 

	
 
    sample_data = {
 
        '16oz': {'upc': '00074305001161', 'description': "Apple Cider Vinegar 16oz"},
 
@@ -546,111 +582,48 @@ class TestFromSQLAlchemyHandler(unittest.TestCase):
 
        session = Session()
 
        handler = handlers.FromSQLAlchemyHandler(host_session=session)
 
        self.assertIs(handler.host_session, session)
 
        handler.rollback_host_transaction()
 
        self.assertIsNone(handler.host_session)
 

	
 

	
 
class TestToSQLAlchemyHandler(unittest.TestCase):
 

	
 
    def test_init(self):
 
        handler = handlers.ToSQLAlchemyHandler()
 
        self.assertRaises(NotImplementedError, handler.make_session)
 

	
 
    def test_get_importer_kwargs(self):
 
        session = object()
 
        handler = handlers.ToSQLAlchemyHandler(session=session)
 
        kwargs = handler.get_importer_kwargs(None)
 
        self.assertEqual(list(kwargs.iterkeys()), ['session'])
 
        self.assertIs(kwargs['session'], session)
 

	
 
    def test_begin_local_transaction(self):
 
        handler = MockToSQLAlchemyHandler()
 
        self.assertIsNone(handler.session)
 
        handler.begin_local_transaction()
 
        self.assertIsInstance(handler.session, orm.Session)
 
        handler.session.close()
 

	
 
    def test_commit_local_transaction(self):
 
        # TODO: test actual commit for data changes
 
        session = Session()
 
        handler = handlers.ToSQLAlchemyHandler(session=session)
 
        self.assertIs(handler.session, session)
 
        with patch.object(handler, 'session') as session:
 
            handler.commit_local_transaction()
 
            session.commit.assert_called_once_with()
 
            self.assertFalse(session.rollback.called)
 
        # self.assertIsNone(handler.session)
 

	
 
    def test_rollback_local_transaction(self):
 
        # TODO: test actual rollback for data changes
 
        session = Session()
 
        handler = handlers.ToSQLAlchemyHandler(session=session)
 
        self.assertIs(handler.session, session)
 
        with patch.object(handler, 'session') as session:
 
            handler.rollback_local_transaction()
 
            session.rollback.assert_called_once_with()
 
            self.assertFalse(session.commit.called)
 
        # self.assertIsNone(handler.session)
 

	
 

	
 
######################################################################
 
# fake bulk import handler, tested mostly for basic coverage
 
######################################################################
 

	
 
class MockBulkImportHandler(handlers.BulkToPostgreSQLHandler):
 

	
 
    def get_importers(self):
 
        return {'Department': MockBulkImporter}
 

	
 
    def make_session(self):
 
        return Session()
 

	
 

	
 
class TestBulkImportHandlerOld(RattailTestCase, ImporterTester):
 

	
 
    importer_class = MockBulkImporter
 

	
 
    sample_data = {
 
        'grocery': {'number': 1, 'name': "Grocery", 'uuid': 'decd909a194011e688093ca9f40bc550'},
 
        'bulk': {'number': 2, 'name': "Bulk", 'uuid': 'e633d54c194011e687e33ca9f40bc550'},
 
        'hba': {'number': 3, 'name': "HBA", 'uuid': 'e2bad79e194011e6a4783ca9f40bc550'},
 
    }
 

	
 
    def setUp(self):
 
        self.setup_rattail()
 
        self.tempio = TempIO()
 
        self.config.set('rattail', 'workdir', self.tempio.realpath())
 
        self.handler = MockBulkImportHandler(config=self.config)
 

	
 
    def tearDown(self):
 
        self.teardown_rattail()
 
        self.tempio = None
 

	
 
    def import_data(self, host_data=None, **kwargs):
 
        if host_data is None:
 
            host_data = list(self.copy_data().itervalues())
 
        with patch.object(self.importer_class, 'normalize_host_data', Mock(return_value=host_data)):
 
            with patch.object(self.handler, 'make_session', Mock(return_value=self.session)):
 
                return self.handler.import_data('Department', **kwargs)
 

	
 
    def test_invalid_importer_key_is_ignored(self):
 
        handler = MockBulkImportHandler()
 
        self.assertNotIn('InvalidKey', handler.importers)
 
        self.assertEqual(handler.import_data('InvalidKey'), {})
 

	
 
    def assert_import_created(self, *keys):
 
        pass
 

	
 
    def assert_import_updated(self, *keys):
 
        pass
 

	
 
    def assert_import_deleted(self, *keys):
 
        pass
 

	
 
    def test_normal_run(self):
 
        if self.postgresql():
 
            self.import_data()
 

	
 
    def test_dry_run(self):
 
        if self.postgresql():
 
            self.import_data(dry_run=True)
rattail/tests/importing/test_importers.py
Show inline comments
 
@@ -14,99 +14,100 @@ from rattail.tests.importing import ImporterTester
 

	
 

	
 
class ImporterBattery(ImporterTester):
    """
    Generic test battery, intended to be applicable to any non-bulk importer.
    """

    def test_import_data_empty(self):
        # importing nothing yields an empty result dict
        imp = self.make_importer()
        self.assertEqual(imp.import_data(), {})

    def test_import_data_dry_run(self):
        # dry_run flag is off by default, but may be set via kwarg
        imp = self.make_importer()
        self.assertFalse(imp.dry_run)
        imp.import_data(dry_run=True)
        self.assertTrue(imp.dry_run)

    def test_import_data_create(self):
        # each host record with no local match gets created
        imp = self.make_importer()
        with patch.object(imp, 'get_key', lambda k: k):
            with patch.object(imp, 'create_object') as create:
                imp.import_data(host_data=[1, 2, 3])
                expected = [call(1, 1), call(2, 2), call(3, 3)]
                self.assertEqual(create.call_args_list, expected)

    def test_import_data_max_create(self):
        # creation stops once max_create is reached
        imp = self.make_importer()
        with patch.object(imp, 'get_key', lambda k: k):
            with patch.object(imp, 'create_object') as create:
                imp.import_data(host_data=[1, 2, 3], max_create=1)
                self.assertEqual(create.call_args_list, [call(1, 1)])
 

	
 

	
 
class BulkImporterBattery(ImporterBattery):
    """
    Generic test battery, intended to be applicable to any bulk importer.
    """

    def test_import_data_empty(self):
        # bulk importers report a row count rather than a result dict
        imp = self.make_importer()
        self.assertEqual(imp.import_data(), 0)
 

	
 

	
 
class TestImporter(TestCase):
 

	
 
    def test_init(self):
        """Constructor defaults, key/fields validation, and extra kwargs."""
        # with no args, everything defaults to empty/None
        imp = importers.Importer()
        self.assertIsNone(imp.model_class)
        self.assertIsNone(imp.model_name)
        self.assertIsNone(imp.key)
        self.assertEqual(imp.fields, [])
        self.assertIsNone(imp.host_system_title)

        # specifying a key not present among fields is an error
        self.assertRaises(ValueError, importers.Importer, key='upc', fields=[])

        # a valid key is normalized to a tuple
        imp = importers.Importer(key='upc', fields=['upc'])
        self.assertEqual(imp.key, ('upc',))
        self.assertEqual(imp.fields, ['upc'])

        # unrecognized kwargs are simply attached as attributes
        imp = importers.Importer()
        self.assertFalse(hasattr(imp, 'extra_bit'))
        token = object()
        imp = importers.Importer(extra_bit=token)
        self.assertIs(imp.extra_bit, token)
 

	
 
    def test_delete_flag(self):
        """Deletion is allowed but inactive unless explicitly requested."""
        # out of the box, deletion is permitted yet not active
        imp = importers.Importer()
        self.assertTrue(imp.allow_delete)
        self.assertFalse(imp.delete)
        imp.import_data(host_data=[])
        self.assertFalse(imp.delete)

        # may be switched on via constructor kwarg...
        imp = importers.Importer(delete=True)
        self.assertTrue(imp.allow_delete)
        self.assertTrue(imp.delete)

        # ...or at import time
        imp = importers.Importer()
        self.assertFalse(imp.delete)
        imp.import_data(host_data=[], delete=True)
        self.assertTrue(imp.delete)
 

	
 
    def test_get_host_objects(self):
        """Default importer has no host objects to offer."""
        self.assertEqual(importers.Importer().get_host_objects(), [])
 

	
 
    def test_cache_local_data(self):
        """Base class cannot cache local data; subclass must implement."""
        imp = importers.Importer()
        self.assertRaises(NotImplementedError, imp.cache_local_data)
 

	
 
    def test_get_local_object(self):
        """Base class cannot fetch a local object; subclass must implement."""
        imp = importers.Importer()
        self.assertFalse(imp.caches_local_data)
        self.assertRaises(NotImplementedError, imp.get_local_object, None)
 

	
 
@@ -147,114 +148,232 @@ class TestImporter(TestCase):
 

	
 
    def test_update_object(self):
        """Updating applies only the simple fields which actually differ."""
        importer = importers.Importer(key='upc', fields=['upc', 'description'])
        importer.simple_fields = importer.fields
        obj = Mock(upc='00074305001321', description="Apple Cider Vinegar")

        # identical data -> same object back, unchanged
        newobj = importer.update_object(obj, {'upc': '00074305001321', 'description': "Apple Cider Vinegar"})
        self.assertIs(newobj, obj)
        self.assertEqual(obj.description, "Apple Cider Vinegar")

        # differing description -> field is overwritten in place
        newobj = importer.update_object(obj, {'upc': '00074305001321', 'description': "Apple Cider Vinegar 32oz"})
        self.assertIs(newobj, obj)
        self.assertEqual(obj.description, "Apple Cider Vinegar 32oz")
 

	
 
    def test_normalize_host_data(self):
        """Host data may be given directly, or derived from host objects."""
        importer = importers.Importer(key='upc', fields=['upc', 'description'],
                                          progress=NullProgress)

        data = [
            {'upc': '00074305001161', 'description': "Apple Cider Vinegar 16oz"},
            {'upc': '00074305001321', 'description': "Apple Cider Vinegar 32oz"},
        ]

        # no objects -> no data
        host_data = importer.normalize_host_data(host_objects=[])
        self.assertEqual(host_data, [])

        # already-normalized dicts pass through as-is
        host_data = importer.normalize_host_data(host_objects=data)
        self.assertEqual(host_data, data)

        # with no objects given, get_host_objects() supplies them
        with patch.object(importer, 'get_host_objects', new=Mock(return_value=data)):
            host_data = importer.normalize_host_data()
        self.assertEqual(host_data, data)
 

	
 
    def test_get_deletion_keys(self):
        """Deletion keys come from the local data cache, when present."""
        imp = importers.Importer()

        # no cache support at all -> nothing to delete
        self.assertFalse(imp.caches_local_data)
        self.assertEqual(imp.get_deletion_keys(), set())

        # cache supported but not yet populated -> still nothing
        imp.caches_local_data = True
        self.assertIsNone(imp.cached_local_data)
        self.assertEqual(imp.get_deletion_keys(), set())

        # populated cache -> its keys are the deletion candidates
        imp.cached_local_data = {'delete-me': object()}
        self.assertEqual(imp.get_deletion_keys(), set(['delete-me']))
 

	
 
    def test_model_name_attr(self):
        """Model name defaults to None, or follows the model class."""
        # nothing specified -> None
        imp = importers.Importer()
        self.assertIsNone(imp.model_name)

        # explicit kwarg wins
        imp = importers.Importer(model_name='Foo')
        self.assertEqual(imp.model_name, 'Foo')

        # otherwise inferred from the model class name
        class Foo:
            pass
        imp = importers.Importer(model_class=Foo)
        self.assertEqual(imp.model_name, 'Foo')
 

	
 
    def test_batch_size_attr(self):
        """Batch size defaults to 200, and governs flush frequency."""
        # default is 200
        importer = importers.Importer()
        self.assertEqual(importer.batch_size, 200)

        # but may be overridden via init kwarg
        importer = importers.Importer(batch_size=0)
        self.assertEqual(importer.batch_size, 0)
        importer = importers.Importer(batch_size=42)
        self.assertEqual(importer.batch_size, 42)

        # batch size determines how often flush occurs
        data = [{'id': i} for i in range(1, 101)]
        importer = importers.Importer(model_name='Foo', key='id', fields=['id'], empty_local_data=True)
        with patch.object(importer, 'create_object'): # just mock that out
            with patch.object(importer, 'flush_create_update') as flush:

                # 4 batches @ 33/per (3 in-loop flushes, plus remainder)
                importer.import_data(host_data=data, batch_size=33)
                self.assertEqual(flush.call_count, 4)
                flush.reset_mock()

                # 3 batches @ 34/per
                importer.import_data(host_data=data, batch_size=34)
                self.assertEqual(flush.call_count, 3)
                flush.reset_mock()

                # 1 batch @ 100/per, plus final flush
                # NOTE(review): comment previously said "2 batches @ 50/per"
                # but the call clearly uses batch_size=100 (same as the next
                # case) -- comment corrected to match the code
                importer.import_data(host_data=data, batch_size=100)
                self.assertEqual(flush.call_count, 2)
                flush.reset_mock()

                # one extra/final flush happens, whenever the total number of
                # changes happens to match the batch size...

                # 1 batch @ 100/per, plus final flush
                importer.import_data(host_data=data, batch_size=100)
                self.assertEqual(flush.call_count, 2)
                flush.reset_mock()

                # 1 batch @ 200/per
                importer.import_data(host_data=data, batch_size=200)
                self.assertEqual(flush.call_count, 1)
                flush.reset_mock()

                # one extra/final flush also happens when batching is disabled

                # 100 "batches" @ 0/per, plus final flush
                importer.import_data(host_data=data, batch_size=0)
                self.assertEqual(flush.call_count, 101)
                flush.reset_mock()
 

	
 

	
 
class TestFromQuery(RattailTestCase):
    """Tests for the query-based host importer."""

    def test_query(self):
        # base class leaves query() abstract
        imp = importers.FromQuery()
        self.assertRaises(NotImplementedError, imp.query)

    def test_get_host_objects(self):
        # host objects come back wrapped in a QuerySequence
        q = self.session.query(model.Product)
        imp = importers.FromQuery()
        with patch.object(imp, 'query', Mock(return_value=q)):
            result = imp.get_host_objects()
        self.assertIsInstance(result, QuerySequence)
 

	
 

	
 
class TestBulkImporter(TestCase, BulkImporterBattery):
    """Tests for the base BulkImporter, plus the generic bulk battery."""
    importer_class = importers.BulkImporter

    def test_batch_size_attr(self):
        """Batch size defaults to 200, and governs flush frequency."""
        # default is 200
        importer = importers.BulkImporter()
        self.assertEqual(importer.batch_size, 200)

        # but may be overridden via init kwarg
        importer = importers.BulkImporter(batch_size=0)
        self.assertEqual(importer.batch_size, 0)
        importer = importers.BulkImporter(batch_size=42)
        self.assertEqual(importer.batch_size, 42)

        # batch size determines how often flush occurs
        data = [{'id': i} for i in range(1, 101)]
        importer = importers.BulkImporter(model_name='Foo', key='id', fields=['id'], empty_local_data=True)
        with patch.object(importer, 'create_object'): # just mock that out
            with patch.object(importer, 'flush_create_update') as flush:

                # 4 batches @ 33/per (3 in-loop flushes, plus remainder)
                importer.import_data(host_data=data, batch_size=33)
                self.assertEqual(flush.call_count, 4)
                flush.reset_mock()

                # 3 batches @ 34/per
                importer.import_data(host_data=data, batch_size=34)
                self.assertEqual(flush.call_count, 3)
                flush.reset_mock()

                # 1 batch @ 100/per, plus final flush
                # NOTE(review): comment previously said "2 batches @ 50/per"
                # but the call clearly uses batch_size=100 (same as the next
                # case) -- comment corrected to match the code
                importer.import_data(host_data=data, batch_size=100)
                self.assertEqual(flush.call_count, 2)
                flush.reset_mock()

                # one extra/final flush happens, whenever the total number of
                # changes happens to match the batch size...

                # 1 batch @ 100/per, plus final flush
                importer.import_data(host_data=data, batch_size=100)
                self.assertEqual(flush.call_count, 2)
                flush.reset_mock()

                # 1 batch @ 200/per
                importer.import_data(host_data=data, batch_size=200)
                self.assertEqual(flush.call_count, 1)
                flush.reset_mock()

                # one extra/final flush also happens when batching is disabled

                # 100 "batches" @ 0/per, plus final flush
                importer.import_data(host_data=data, batch_size=0)
                self.assertEqual(flush.call_count, 101)
                flush.reset_mock()
 

	
 

	
 
######################################################################
 
# fake importer class, tested mostly for basic coverage
 
######################################################################
 

	
 
class Product(object):
    """Minimal stand-in for a product record, used by MockImporter."""
    # simple attributes matching MockImporter's fields
    upc = None
    description = None
 

	
 

	
 
class MockImporter(importers.Importer):
    """Fake importer over the local Product stand-in, for basic coverage."""
    model_class = Product
    key = 'upc'
    simple_fields = ['upc', 'description']
    supported_fields = simple_fields
    caches_local_data = True
    flush_every_x = 1
    session = Mock()

    def normalize_local_object(self, obj):
        # local "objects" are already plain data in these tests
        return obj

    def update_object(self, obj, host_data, local_data=None):
        # simply adopt the host data wholesale
        return host_data
 

	
 

	
 
class TestMockImporter(ImporterTester, TestCase):
    """Exercise the fake importer, mostly for basic coverage."""
    importer_class = MockImporter

    # sample host data, keyed by friendly nickname
    sample_data = {
        '16oz': {'upc': '00074305001161', 'description': "Apple Cider Vinegar 16oz"},
        '32oz': {'upc': '00074305001321', 'description': "Apple Cider Vinegar 32oz"},
        '1gal': {'upc': '00074305011283', 'description': "Apple Cider Vinegar 1gal"},
    }
 

	
 
    def setUp(self):
        # fresh importer for each test
        self.importer = self.make_importer()
 

	
 
    def test_create(self):
        """A host record missing locally should be created, nothing else."""
        existing = self.copy_data()
        del existing['32oz']
        self.import_data(local_data=existing)
        self.assert_import_created('32oz')
        self.assert_import_updated()
        self.assert_import_deleted()
 

	
 
    def test_create_empty(self):
rattail/tests/importing/test_postgresql.py
Show inline comments
 
# -*- coding: utf-8 -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
import datetime
 
import unittest
 

	
 
import sqlalchemy as sa
 
from sqlalchemy import orm
 
from fixture import TempIO
 

	
 
from rattail.db import Session, model
 
from rattail.importing import postgresql as pgimport
 
from rattail.config import RattailConfig
 
from rattail.exceptions import ConfigurationError
 
from rattail.tests import RattailTestCase, NullProgress
 
from rattail.tests.importing import ImporterTester
 
from rattail.tests.importing.test_rattail import DualRattailTestCase
 
from rattail.time import localtime
 

	
 

	
 
class Widget(object):
    """Dummy model class, used only to provide a model name."""
    pass
 

	
 

	
 
class TestBulkToPostgreSQL(unittest.TestCase):
 

	
 
    def setUp(self):
        # importer requires a config with workdir and timezone
        self.tempio = TempIO()
        self.config = RattailConfig()
        self.config.set('rattail', 'workdir', self.tempio.realpath())
        self.config.set('rattail', 'timezone.default', 'America/Chicago')
 

	
 
    def tearDown(self):
        # dropping the reference destroys the temp dir
        self.tempio = None
 

	
 
    def make_importer(self, **kwargs):
        """Create a bulk PostgreSQL importer with sensible defaults."""
        kwargs.setdefault('config', self.config)
        kwargs.setdefault('fields', ['id']) # hack
        return pgimport.BulkToPostgreSQL(**kwargs)
 

	
 
    def test_data_path(self):
        """
        The 'data_path' property requires config with a workdir, and
        embeds the importer's model name in the filename.

        NOTE(review): previously this was split into ``test_data_path``
        and ``test_data_path_property``; the latter referenced an
        undefined ``importer`` local (NameError) and duplicated the
        assertions below, so the two have been merged back into one
        coherent test.
        """
        # no config at all -> AttributeError
        importer = self.make_importer(config=None)
        self.assertIsNone(importer.config)
        self.assertRaises(AttributeError, getattr, importer, 'data_path')

        # config without workdir -> ConfigurationError
        importer.config = RattailConfig()
        self.assertRaises(ConfigurationError, getattr, importer, 'data_path')

        # now provide a config with workdir
        importer.config = self.config
        self.config.set('rattail', 'workdir', '/tmp')

        # path leverages model name, so default is None
        self.assertEqual(importer.data_path, '/tmp/import_bulk_postgresql_None.csv') # no model yet

        # but it will reflect the model class...
        importer.model_class = Widget
        self.assertEqual(importer.data_path, '/tmp/import_bulk_postgresql_Widget.csv')

        # ...or an explicitly-set model name
        importer = pgimport.BulkToPostgreSQL(config=self.config, fields=['id'])
        self.assertEqual(importer.data_path, '/tmp/import_bulk_postgresql_None.csv')
        importer.model_name = 'Foo'
        self.assertEqual(importer.data_path, '/tmp/import_bulk_postgresql_Foo.csv')
 

	
 
    def test_setup(self):
        """setup() should open the data buffer file."""
        importer = self.make_importer()
        self.assertFalse(hasattr(importer, 'data_buffer'))
        importer.setup()
        self.assertIsNotNone(importer.data_buffer)
        importer.data_buffer.close()
 

	
 
    def test_teardown(self):
        """teardown() should close and discard the data buffer."""
        importer = self.make_importer()
        importer.data_buffer = open(importer.data_path, 'wb')
        importer.teardown()
        self.assertIsNone(importer.data_buffer)
 

	
 
    def test_prep_value_for_postgres(self):
        """Values must be rendered per PostgreSQL COPY text format."""
        importer = self.make_importer()

        # constants
        self.assertEqual(importer.prep_value_for_postgres(None), '\\N')
        self.assertEqual(importer.prep_value_for_postgres(True), 't')
        self.assertEqual(importer.prep_value_for_postgres(False), 'f')

        # datetime (local zone is Chicago/CDT; UTC-5)
        value = localtime(self.config, datetime.datetime(2016, 5, 13, 12))
        self.assertEqual(importer.prep_value_for_postgres(value), '2016-05-13 17:00:00')

        # strings...

        # backslash is escaped by doubling
        self.assertEqual(importer.prep_value_for_postgres('\\'), '\\\\')

        # CR and LF chars are escaped -- note that per the expected value
        # below, CRLF pairs are preserved (escaped), not collapsed, despite
        # what the original "(\r\n -> \n)" comment claimed
        self.assertEqual(importer.prep_value_for_postgres('one\rtwo\nthree\r\nfour\r\nfive\nsix\rseven'), 'one\\rtwo\\nthree\\r\\nfour\\r\\nfive\\nsix\\rseven')
 

	
 
    def test_prep_data_for_postgres(self):
 
        importer = self.make_importer()
 
        time = localtime(self.config, datetime.datetime(2016, 5, 13, 12))
 
        data = {
 
            'none': None,
 
            'true': True,
 
            'false': False,
 
            'datetime': time,
 
            'backslash': '\\',
 
            'newlines': 'one\rtwo\nthree\r\nfour\r\nfive\nsix\rseven',
 
        }
 
        data = importer.prep_data_for_postgres(data)
 
        self.assertEqual(data['none'], '\\N')
 
        self.assertEqual(data['true'], 't')
rattail/tests/importing/test_sqlalchemy.py
Show inline comments
 
# -*- coding: utf-8 -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from unittest import TestCase
 

	
 
from mock import patch
 

	
 
import sqlalchemy as sa
 
from sqlalchemy import orm
 
from sqlalchemy.orm.exc import MultipleResultsFound
 

	
 
from rattail.importing import sqlalchemy as saimport
 

	
 

	
 
class Widget(object):
    """Dummy model class, mapped to the 'widgets' table below."""
    pass
 

	
 
# schema and mapping for the dummy Widget model
metadata = sa.MetaData()

widget_table = sa.Table('widgets', metadata,
                        sa.Column('id', sa.Integer(), primary_key=True),
                        sa.Column('description', sa.String(length=50)))

widget_mapper = orm.mapper(Widget, widget_table)

# sample rows, loaded into the test database during setUp
WIDGETS = [
    {'id': 1, 'description': "Main Widget"},
    {'id': 2, 'description': "Other Widget"},
    {'id': 3, 'description': "Other Widget"},
]
 

	
 

	
 
class TestFromSQLAlchemy(TestCase):
    """Tests for the generic SQLAlchemy host-side importer."""

    def test_query(self):
        """Default query selects everything for the host model class."""
        Session = orm.sessionmaker(bind=sa.create_engine('sqlite://'))
        session = Session()
        importer = saimport.FromSQLAlchemy(host_session=session,
                                           host_model_class=Widget)
        self.assertEqual(unicode(importer.query()),
                         "SELECT widgets.id AS widgets_id, widgets.description AS widgets_description \n"
                         "FROM widgets")
 

	
 

	
 
class TestToSQLAlchemy(TestCase):
 

	
 
    def setUp(self):
 
        engine = sa.create_engine('sqlite://')
 
        metadata.create_all(bind=engine)
 
        Session = orm.sessionmaker(bind=engine)
 
        self.session = Session()
 
        for data in WIDGETS:
 
            widget = Widget()
 
            for key, value in data.iteritems():
 
                setattr(widget, key, value)
 
@@ -100,48 +102,54 @@ class TestToSQLAlchemy(TestCase):
 
    def test_create_object(self):
        """Created objects are added to the session and flushed."""
        importer = self.make_importer(key='id', session=self.session)
        widget = importer.create_object((42,), {'id': 42, 'description': "Latest Widget"})
        self.assertFalse(self.session.new or self.session.dirty or self.session.deleted) # i.e. has been flushed
        self.assertIn(widget, self.session) # therefore widget has been flushed and would be committed
        self.assertEqual(widget.id, 42)
        self.assertEqual(widget.description, "Latest Widget")
 

	
 
    def test_delete_object(self):
        """Deleting an object removes it from session and database."""
        victim = self.session.query(Widget).get(1)
        self.assertIn(victim, self.session)
        imp = self.make_importer(session=self.session)
        self.assertTrue(imp.delete_object(victim))
        # gone from the session, and no longer fetchable
        self.assertNotIn(victim, self.session)
        self.assertIsNone(self.session.query(Widget).get(1))
 

	
 
    def test_cache_model(self):
        """cache_model() returns all records, keyed by the given key field."""
        importer = self.make_importer(key='id', session=self.session)
        cache = importer.cache_model(Widget, key='id')
        self.assertEqual(len(cache), 3)
        for i in range(1, 4):
            self.assertIn(i, cache)
            self.assertIsInstance(cache[i], Widget)
            self.assertEqual(cache[i].id, i)
            self.assertEqual(cache[i].description, WIDGETS[i-1]['description'])
 

	
 
    def test_cache_local_data(self):
        """Local cache maps key tuples to object + normalized data dicts."""
        importer = self.make_importer(key='id', session=self.session)
        cache = importer.cache_local_data()
        self.assertEqual(len(cache), 3)
        for i in range(1, 4):
            # cache keys are (single-element) tuples
            self.assertIn((i,), cache)
            cached = cache[(i,)]
            self.assertIsInstance(cached, dict)
            self.assertIsInstance(cached['object'], Widget)
            self.assertEqual(cached['object'].id, i)
            self.assertEqual(cached['object'].description, WIDGETS[i-1]['description'])
            self.assertIsInstance(cached['data'], dict)
            self.assertEqual(cached['data']['id'], i)
            self.assertEqual(cached['data']['description'], WIDGETS[i-1]['description'])
 

	
 
    # TODO: lame
    def test_flush_session(self):
        """With flush_session enabled, update_object() still returns the object."""
        importer = self.make_importer(fields=['id'], session=self.session, flush_session=True)
        widget = Widget()
        widget.id = 1
        widget, original = importer.update_object(widget, {'id': 1}), widget
        self.assertIs(widget, original)

	
 
    def test_flush_create_update(self):
        """Pending creates/updates are flushed via the session."""
        imp = saimport.ToSQLAlchemy(fields=['id'], session=self.session)
        with patch.object(self.session, 'flush') as flush:
            imp.flush_create_update()
            flush.assert_called_once_with()
0 comments (0 inline, 0 general)