Changeset - 8016ffbd268e
[Not reviewed]
1 3 6
Lance Edgar - 8 years ago 2016-05-12 12:08:37
ledgar@sacfoodcoop.com
Add initial handlers and subcommand for new/final importing framework

plus tests
9 files changed with 976 insertions and 81 deletions:
0 comments (0 inline, 0 general)
rattail/commands/__init__.py
Show inline comments
 
@@ -27,3 +27,4 @@ Console Commands
 
from __future__ import unicode_literals, absolute_import
 

	
 
from .core import main, Command, Subcommand, OldImportSubcommand, NewImportSubcommand, Dump, date_argument
 
from .importing import ImportSubcommand
rattail/commands/importing.py
Show inline comments
 
new file 100644
 
# -*- coding: utf-8 -*-
 
################################################################################
 
#
 
#  Rattail -- Retail Software Framework
 
#  Copyright © 2010-2016 Lance Edgar
 
#
 
#  This file is part of Rattail.
 
#
 
#  Rattail is free software: you can redistribute it and/or modify it under the
 
#  terms of the GNU Affero General Public License as published by the Free
 
#  Software Foundation, either version 3 of the License, or (at your option)
 
#  any later version.
 
#
 
#  Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
 
#  WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 
#  FOR A PARTICULAR PURPOSE.  See the GNU Affero General Public License for
 
#  more details.
 
#
 
#  You should have received a copy of the GNU Affero General Public License
 
#  along with Rattail.  If not, see <http://www.gnu.org/licenses/>.
 
#
 
################################################################################
 
"""
 
Importing Commands
 
"""
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
import logging
 

	
 
from rattail.commands.core import Subcommand, date_argument
 
from rattail.util import load_object
 

	
 

	
 
log = logging.getLogger(__name__)
 

	
 

	
 
class ImportSubcommand(Subcommand):
 
    """
 
    Base class for subcommands which use the (new) data importing system.
 
    """
 
    # TODO: move this into Subcommand
 
    parent_name = None
 

	
 
    def get_handler_factory(self):
 
        """
 
        Subclasses must override this, and return a callable that creates an
 
        import handler instance which the command should use.
 
        """
 
        raise NotImplementedError
 

	
 
    def get_handler(self, **kwargs):
 
        """
 
        Returns a handler instance to be used by the command.
 
        """
 
        factory = self.get_handler_factory()
 
        kwargs.setdefault('config', getattr(self, 'config', None))
 
        kwargs.setdefault('command', self)
 
        kwargs = self.get_handler_kwargs(**kwargs)
 
        return factory(**kwargs)
 

	
 
    def get_handler_kwargs(self, **kwargs):
 
        """
 
        Return a dict of kwargs to be passed to the handler factory.
 
        """
 
        return kwargs
 

	
 
    def add_parser_args(self, parser):
 

	
 
        # model names (aka importer keys)
 
        doc = ("Which data models to import.  If you specify any, then only "
 
               "data for those models will be imported.  If you do not specify "
 
               "any, then all *default* models will be imported.")
 
        try:
 
            handler = self.get_handler()
 
        except NotImplementedError:
 
            pass
 
        else:
 
            doc += "  Supported models are: ({})".format(', '.join(handler.get_importer_keys()))
 
        parser.add_argument('models', nargs='*', metavar='MODEL', help=doc)
 

	
 
        # start/end date
 
        parser.add_argument('--start-date', type=date_argument,
 
                            help="Optional (inclusive) starting point for date range, by which host "
 
                            "data should be filtered.  Only used by certain importers.")
 
        parser.add_argument('--end-date', type=date_argument,
 
                            help="Optional (inclusive) ending point for date range, by which host "
 
                            "data should be filtered.  Only used by certain importers.")
 

	
 
        # allow create?
 
        parser.add_argument('--create', action='store_true', default=True,
 
                            help="Allow new records to be created during the import.")
 
        parser.add_argument('--no-create', action='store_false', dest='create',
 
                            help="Do not allow new records to be created during the import.")
 
        parser.add_argument('--max-create', type=int, metavar='COUNT',
 
                            help="Maximum number of records which may be created, after which a "
 
                            "given import task should stop.  Note that this applies on a per-model "
 
                            "basis and not overall.")
 

	
 
        # allow update?
 
        parser.add_argument('--update', action='store_true', default=True,
 
                            help="Allow existing records to be updated during the import.")
 
        parser.add_argument('--no-update', action='store_false', dest='update',
 
                            help="Do not allow existing records to be updated during the import.")
 
        parser.add_argument('--max-update', type=int, metavar='COUNT',
 
                            help="Maximum number of records which may be updated, after which a "
 
                            "given import task should stop.  Note that this applies on a per-model "
 
                            "basis and not overall.")
 

	
 
        # allow delete?
 
        parser.add_argument('--delete', action='store_true', default=False,
 
                            help="Allow records to be deleted during the import.")
 
        parser.add_argument('--no-delete', action='store_false', dest='delete',
 
                            help="Do not allow records to be deleted during the import.")
 
        parser.add_argument('--max-delete', type=int, metavar='COUNT',
 
                            help="Maximum number of records which may be deleted, after which a "
 
                            "given import task should stop.  Note that this applies on a per-model "
 
                            "basis and not overall.")
 

	
 
        # max total changes, per model
 
        parser.add_argument('--max-total', type=int, metavar='COUNT',
 
                            help="Maximum number of *any* record changes which may occur, after which "
 
                            "a given import task should stop.  Note that this applies on a per-model "
 
                            "basis and not overall.")
 

	
 
        # treat changes as warnings?
 
        parser.add_argument('--warnings', '-W', action='store_true',
 
                            help="Set this flag if you expect a \"clean\" import, and wish for any "
 
                            "changes which do occur to be processed further and/or specially.  The "
 
                            "behavior of this flag is ultimately up to the import handler, but the "
 
                            "default is to send an email notification.")
 

	
 
        # dry run?
 
        parser.add_argument('--dry-run', action='store_true',
 
                            help="Go through the full motions and allow logging etc. to "
 
                            "occur, but rollback (abort) the transaction at the end.")
 

	
 
    def run(self, args):
 
        log.info("begin `{} {}` for data models: {}".format(
 
                self.parent_name, self.name, ', '.join(args.models or ["(ALL)"])))
 

	
 
        handler = self.get_handler(args=args, progress=self.progress)
 
        models = args.models or handler.get_default_keys()
 
        log.debug("using handler: {}".format(handler))
 
        log.debug("importing models: {}".format(models))
 
        log.debug("args are: {}".format(args))
 
        handler.import_data(*models)
 

	
 
        # TODO: should this logging happen elsewhere / be customizable?
 
        if args.dry_run:
 
            log.info("dry run, so transaction was rolled back")
 
        else:
 
            log.info("transaction was committed")
rattail/importing/__init__.py
Show inline comments
 
@@ -28,3 +28,4 @@ from __future__ import unicode_literals, absolute_import
 

	
 
from .importers import Importer, FromQuery
 
from .sqlalchemy import FromSQLAlchemy, ToSQLAlchemy
 
from .handlers import ImportHandler, FromSQLAlchemyHandler, ToSQLAlchemyHandler
rattail/importing/handlers.py
Show inline comments
 
new file 100644
 
# -*- coding: utf-8 -*-
 
################################################################################
 
#
 
#  Rattail -- Retail Software Framework
 
#  Copyright © 2010-2016 Lance Edgar
 
#
 
#  This file is part of Rattail.
 
#
 
#  Rattail is free software: you can redistribute it and/or modify it under the
 
#  terms of the GNU Affero General Public License as published by the Free
 
#  Software Foundation, either version 3 of the License, or (at your option)
 
#  any later version.
 
#
 
#  Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
 
#  WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
 
#  FOR A PARTICULAR PURPOSE.  See the GNU Affero General Public License for
 
#  more details.
 
#
 
#  You should have received a copy of the GNU Affero General Public License
 
#  along with Rattail.  If not, see <http://www.gnu.org/licenses/>.
 
#
 
################################################################################
 
"""
 
Import Handlers
 
"""
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
import datetime
 
import logging
 

	
 
from rattail.util import OrderedDict
 

	
 

	
 
log = logging.getLogger(__name__)
 

	
 

	
 
class ImportHandler(object):
 
    """
 
    Base class for all import handlers.
 
    """
 
    host_title = "Host"
 
    local_title = "Local"
 
    progress = None
 
    dry_run = False
 

	
 
    def __init__(self, config=None, **kwargs):
 
        self.config = config
 
        self.importers = self.get_importers()
 
        for key, value in kwargs.iteritems():
 
            setattr(self, key, value)
 

	
 
    def get_importers(self):
 
        """
 
        Returns a dict of all available importers, where the keys are model
 
        names and the values are importer factories.  All subclasses will want
 
        to override this.  Note that if you return an
 
        :class:`python:collections.OrderedDict` instance, you can affect the
 
        ordering of keys in the command line help system, etc.
 
        """
 
        return {}
 

	
 
    def get_importer_keys(self):
 
        """
 
        Returns the list of keys corresponding to the available importers.
 
        """
 
        return list(self.importers.iterkeys())
 

	
 
    def get_default_keys(self):
 
        """
 
        Returns the list of keys corresponding to the "default" importers.
 
        Override this if you wish certain importers to be excluded by default,
 
        e.g. when first testing them out etc.
 
        """
 
        return self.get_importer_keys()
 

	
 
    def get_importer(self, key, **kwargs):
 
        """
 
        Returns an importer instance corresponding to the given key.
 
        """
 
        if key in self.importers:
 
            kwargs.setdefault('handler', self)
 
            kwargs.setdefault('config', self.config)
 
            kwargs = self.get_importer_kwargs(key, **kwargs)
 
            return self.importers[key](**kwargs)
 

	
 
    def get_importer_kwargs(self, key, **kwargs):
 
        """
 
        Return a dict of kwargs to be used when construcing an importer with
 
        the given key.
 
        """
 
        return kwargs
 

	
 
    def import_data(self, *keys, **kwargs):
 
        """
 
        Import all data for the given importer/model keys.
 
        """
 
        self.import_began = datetime.datetime.utcnow()
 
        self.dry_run = kwargs.pop('dry_run', False)
 
        self.progress = kwargs.pop('progress', None)
 
        self.warnings = kwargs.pop('warnings', False)
 
        kwargs.update({'dry_run': self.dry_run,
 
                       'progress': self.progress})
 

	
 
        self.setup()
 
        self.begin_transaction()
 
        changes = OrderedDict()
 

	
 
        for key in keys:
 
            importer = self.get_importer(key, **kwargs)
 
            if not importer:
 
                log.warning("skipping unknown importer: {}".format(key))
 
                continue
 

	
 
            created, updated, deleted = importer.import_data()
 

	
 
            changed = bool(created or updated or deleted)
 
            logger = log.warning if changed and self.warnings else log.info
 
            logger("{} -> {}: added {}, updated {}, deleted {} {} records".format(
 
                self.host_title, self.local_title, len(created), len(updated), len(deleted), key))
 
            if changed:
 
                changes[key] = created, updated, deleted
 

	
 
        if changes:
 
            self.process_changes(changes)
 
        if self.dry_run:
 
            self.rollback_transaction()
 
        else:
 
            self.commit_transaction()
 
        self.teardown()
 
        return changes
 

	
 
    def setup(self):
 
        """
 
        Perform any additional setup if/as necessary, prior to running the
 
        import task(s).
 
        """
 

	
 
    def teardown(self):
 
        """
 
        Perform any cleanup necessary, after running the import task(s).
 
        """
 

	
 
    def begin_transaction(self):
 
        self.begin_host_transaction()
 
        self.begin_local_transaction()
 

	
 
    def begin_host_transaction(self):
 
        pass
 

	
 
    def begin_local_transaction(self):
 
        pass
 

	
 
    def rollback_transaction(self):
 
        self.rollback_host_transaction()
 
        self.rollback_local_transaction()
 

	
 
    def rollback_host_transaction(self):
 
        pass
 

	
 
    def rollback_local_transaction(self):
 
        pass
 

	
 
    def commit_transaction(self):
 
        self.commit_host_transaction()
 
        self.commit_local_transaction()
 

	
 
    def commit_host_transaction(self):
 
        pass
 

	
 
    def commit_local_transaction(self):
 
        pass
 

	
 
    def process_changes(self, changes):
 
        """
 
        This method is called any time changes occur, regardless of whether the
 
        import is running in "warnings" mode.  Default implementation does
 
        nothing; override as needed.
 
        """
 

	
 

	
 
class FromSQLAlchemyHandler(ImportHandler):
 
    """
 
    Handler for imports for which the host data source is represented by a
 
    SQLAlchemy engine and ORM.
 
    """
 
    host_session = None
 

	
 
    def make_host_session(self):
 
        """
 
        Subclasses must override this to define the host database connection.
 
        """
 
        raise NotImplementedError
 

	
 
    def get_importer_kwargs(self, key):
 
        kwargs = super(FromSQLAlchemyHandler, self).get_importer_kwargs(key)
 
        kwargs.setdefault('host_session', self.host_session)
 
        return kwargs
 

	
 
    def begin_host_transaction(self):
 
        self.host_session = self.make_host_session()
 

	
 
    def rollback_host_transaction(self):
 
        self.host_session.rollback()
 
        self.host_session.close()
 
        self.host_session = None
 

	
 
    def commit_host_transaction(self):
 
        self.host_session.commit()
 
        self.host_session.close()
 
        self.host_session = None
 

	
 

	
 
class ToSQLAlchemyHandler(ImportHandler):
 
    """
 
    Handler for imports which target a SQLAlchemy ORM on the local side.
 
    """
 
    session = None
 

	
 
    def make_session(self):
 
        """
 
        Subclasses must override this to define the local database connection.
 
        """
 
        raise NotImplementedError
 

	
 
    def get_importer_kwargs(self, key):
 
        kwargs = super(ToSQLAlchemyHandler, self).get_importer_kwargs(key)
 
        kwargs.setdefault('session', self.session)
 
        return kwargs
 

	
 
    def begin_local_transaction(self):
 
        self.session = self.make_session()
 

	
 
    def rollback_local_transaction(self):
 
        self.session.rollback()
 
        self.session.close()
 
        self.session = None
 

	
 
    def commit_local_transaction(self):
 
        self.session.commit()
 
        self.session.close()
 
        self.session = None
rattail/importing/importers.py
Show inline comments
 
@@ -88,7 +88,6 @@ class Importer(object):
 
        self._setup(**kwargs)
 

	
 
    def _setup(self, **kwargs):
 
        self.now = kwargs.pop('now', make_utc(datetime.datetime.utcnow(), tzinfo=True))
 
        self.create = kwargs.pop('create', self.allow_create) and self.allow_create
 
        self.update = kwargs.pop('update', self.allow_update) and self.allow_update
 
        self.delete = kwargs.pop('delete', self.allow_delete) and self.allow_delete
 
@@ -110,13 +109,15 @@ class Importer(object):
 
        Perform any cleanup after import, if necessary.
 
        """
 

	
 
    def import_data(self, host_data=None, **kwargs):
 
    def import_data(self, host_data=None, now=None, **kwargs):
 
        """
 
        Import some data!  This is the core body of logic for that, regardless
 
        of where data is coming from or where it's headed.  Note that this
 
        method handles deletions as well as adds/updates.
 
        """
 
        self._setup(**kwargs)
 
        self.now = now or make_utc(datetime.datetime.utcnow(), tzinfo=True)
 
        if kwargs:
 
            self._setup(**kwargs)
 
        self.setup()
 
        created = updated = deleted = []
 

	
rattail/tests/commands/__init__.py
Show inline comments
 
new file 100644
rattail/tests/commands/test_core.py
Show inline comments
 
file renamed from rattail/tests/test_commands.py to rattail/tests/commands/test_core.py
 
@@ -5,18 +5,14 @@ from __future__ import unicode_literals, absolute_import
 
import csv
 
import datetime
 
import argparse
 
import logging
 
from unittest import TestCase
 
from cStringIO import StringIO
 

	
 
# from sqlalchemy import func
 
from mock import patch, Mock
 
from fixture import TempIO
 

	
 
from sqlalchemy import create_engine
 
from sqlalchemy import func
 

	
 
from rattail import commands
 
from rattail.commands.core import ArgumentParser, date_argument
 
from rattail.commands import core
 
from rattail.db import Session, model
 
from rattail.db.auth import authenticate_user
 
from rattail.tests import DataTestCase
 
@@ -25,7 +21,7 @@ from rattail.tests import DataTestCase
 
class TestArgumentParser(TestCase):
 

	
 
    def test_parse_args_preserves_extra_argv(self):
 
        parser = ArgumentParser()
 
        parser = core.ArgumentParser()
 
        parser.add_argument('--some-optional-arg')
 
        parser.add_argument('some_required_arg')
 
        args = parser.parse_args([
 
@@ -39,26 +35,26 @@ class TestArgumentParser(TestCase):
 
class TestDateArgument(TestCase):
 

	
 
    def test_valid_date_string_returns_date_object(self):
 
        date = date_argument('2014-01-01')
 
        date = core.date_argument('2014-01-01')
 
        self.assertEqual(date, datetime.date(2014, 1, 1))
 

	
 
    def test_invalid_date_string_raises_error(self):
 
        self.assertRaises(argparse.ArgumentTypeError, date_argument, 'invalid-date')
 
        self.assertRaises(argparse.ArgumentTypeError, core.date_argument, 'invalid-date')
 

	
 

	
 
class TestCommand(TestCase):
 

	
 
    def test_initial_subcommands_are_sane(self):
 
        command = commands.Command()
 
        command = core.Command()
 
        self.assertTrue('filemon' in command.subcommands)
 

	
 
    def test_unicode(self):
 
        command = commands.Command()
 
        command = core.Command()
 
        command.name = 'some-app'
 
        self.assertEqual(unicode(command), u'some-app')
 
        
 
    def test_iter_subcommands_includes_expected_item(self):
 
        command = commands.Command()
 
        command = core.Command()
 
        found = False
 
        for subcommand in command.iter_subcommands():
 
            if subcommand.name == 'filemon':
 
@@ -67,7 +63,7 @@ class TestCommand(TestCase):
 
        self.assertTrue(found)
 

	
 
    def test_print_help(self):
 
        command = commands.Command()
 
        command = core.Command()
 
        stdout = StringIO()
 
        command.stdout = stdout
 
        command.print_help()
 
@@ -77,41 +73,41 @@ class TestCommand(TestCase):
 
        self.assertTrue('Options:' in output)
 

	
 
    def test_run_with_no_args_prints_help(self):
 
        command = commands.Command()
 
        command = core.Command()
 
        with patch.object(command, 'print_help') as print_help:
 
            command.run()
 
            print_help.assert_called_once_with()
 

	
 
    def test_run_with_single_help_arg_prints_help(self):
 
        command = commands.Command()
 
        command = core.Command()
 
        with patch.object(command, 'print_help') as print_help:
 
            command.run('help')
 
            print_help.assert_called_once_with()
 

	
 
    def test_run_with_help_and_unknown_subcommand_args_prints_help(self):
 
        command = commands.Command()
 
        command = core.Command()
 
        with patch.object(command, 'print_help') as print_help:
 
            command.run('help', 'invalid-subcommand-name')
 
            print_help.assert_called_once_with()
 

	
 
    def test_run_with_help_and_subcommand_args_prints_subcommand_help(self):
 
        command = commands.Command()
 
        command = core.Command()
 
        fake = command.subcommands['fake'] = Mock()
 
        command.run('help', 'fake')
 
        fake.return_value.parser.print_help.assert_called_once_with()
 

	
 
    def test_run_with_unknown_subcommand_arg_prints_help(self):
 
        command = commands.Command()
 
        command = core.Command()
 
        with patch.object(command, 'print_help') as print_help:
 
            command.run('invalid-command-name')
 
            print_help.assert_called_once_with()
 

	
 
    def test_stdout_may_be_redirected(self):
 
        class Fake(commands.Subcommand):
 
        class Fake(core.Subcommand):
 
            def run(self, args):
 
                self.stdout.write("standard output stuff")
 
                self.stdout.flush()
 
        command = commands.Command()
 
        command = core.Command()
 
        fake = command.subcommands['fake'] = Fake
 
        tmp = TempIO()
 
        config_path = tmp.putfile('test.ini', '')
 
@@ -121,11 +117,11 @@ class TestCommand(TestCase):
 
            self.assertEqual(f.read(), "standard output stuff")
 

	
 
    def test_stderr_may_be_redirected(self):
 
        class Fake(commands.Subcommand):
 
        class Fake(core.Subcommand):
 
            def run(self, args):
 
                self.stderr.write("standard error stuff")
 
                self.stderr.flush()
 
        command = commands.Command()
 
        command = core.Command()
 
        fake = command.subcommands['fake'] = Fake
 
        tmp = TempIO()
 
        config_path = tmp.putfile('test.ini', '')
 
@@ -145,22 +141,22 @@ class TestCommand(TestCase):
 
class TestSubcommand(TestCase):
 

	
 
    def test_repr(self):
 
        command = commands.Command()
 
        subcommand = commands.Subcommand(command)
 
        command = core.Command()
 
        subcommand = core.Subcommand(command)
 
        subcommand.name = 'fake-command'
 
        self.assertEqual(repr(subcommand), "Subcommand(name=u'fake-command')")
 

	
 
    def test_add_parser_args_does_nothing(self):
 
        command = commands.Command()
 
        subcommand = commands.Subcommand(command)
 
        command = core.Command()
 
        subcommand = core.Subcommand(command)
 
        # Not sure this is really the way to test this, but...
 
        self.assertEqual(len(subcommand.parser._action_groups[0]._actions), 1)
 
        subcommand.add_parser_args(subcommand.parser)
 
        self.assertEqual(len(subcommand.parser._action_groups[0]._actions), 1)
 

	
 
    def test_run_not_implemented(self):
 
        command = commands.Command()
 
        subcommand = commands.Subcommand(command)
 
        command = core.Command()
 
        subcommand = core.Subcommand(command)
 
        args = subcommand.parser.parse_args([])
 
        self.assertRaises(NotImplementedError, subcommand.run, args)
 

	
 
@@ -177,7 +173,7 @@ class TestAddUser(DataTestCase):
 
        self.session.add(model.User(username='fred'))
 
        self.session.commit()
 
        self.assertEqual(self.session.query(model.User).count(), 1)
 
        commands.main('adduser', '--no-init', '--stderr', self.stderr_path, 'fred')
 
        core.main('adduser', '--no-init', '--stderr', self.stderr_path, 'fred')
 
        with open(self.stderr_path) as f:
 
            self.assertEqual(f.read(), "User 'fred' already exists.\n")
 
        self.assertEqual(self.session.query(model.User).count(), 1)
 
@@ -186,7 +182,7 @@ class TestAddUser(DataTestCase):
 
        self.assertEqual(self.session.query(model.User).count(), 0)
 
        with patch('rattail.commands.core.getpass') as getpass:
 
            getpass.side_effect = KeyboardInterrupt
 
            commands.main('adduser', '--no-init', '--stderr', self.stderr_path, 'fred')
 
            core.main('adduser', '--no-init', '--stderr', self.stderr_path, 'fred')
 
        with open(self.stderr_path) as f:
 
            self.assertEqual(f.read(), "\nOperation was canceled.\n")
 
        self.assertEqual(self.session.query(model.User).count(), 0)
 
@@ -195,7 +191,7 @@ class TestAddUser(DataTestCase):
 
        self.assertEqual(self.session.query(model.User).count(), 0)
 
        with patch('rattail.commands.core.getpass') as getpass:
 
            getpass.return_value = 'fredpass'
 
            commands.main('adduser', '--no-init', '--stdout', self.stdout_path, 'fred')
 
            core.main('adduser', '--no-init', '--stdout', self.stdout_path, 'fred')
 
        with open(self.stdout_path) as f:
 
            self.assertEqual(f.read(), "Created user: fred\n")
 
        fred = self.session.query(model.User).one()
 
@@ -208,55 +204,12 @@ class TestAddUser(DataTestCase):
 
        self.assertEqual(self.session.query(model.User).count(), 0)
 
        with patch('rattail.commands.core.getpass') as getpass:
 
            getpass.return_value = 'fredpass'
 
            commands.main('adduser', '--no-init', '--stdout', self.stdout_path, 'fred', '--administrator')
 
            core.main('adduser', '--no-init', '--stdout', self.stdout_path, 'fred', '--administrator')
 
        fred = self.session.query(model.User).one()
 
        self.assertEqual(len(fred.roles), 1)
 
        self.assertEqual(fred.roles[0].name, 'Administrator')
 

	
 

	
 
# TODO: more broken tests..ugh.  these aren't very good or else i might bother
 
# fixing them...
 
# class TestDatabaseSync(TestCase):
 

	
 
#     @patch('rattail.db.sync.linux.start_daemon')
 
#     def test_start_daemon_with_default_args(self, start_daemon):
 
#         commands.main('dbsync', '--no-init', 'start')
 
#         start_daemon.assert_called_once_with(None, None, True)
 

	
 
#     @patch('rattail.db.sync.linux.start_daemon')
 
#     def test_start_daemon_with_explicit_args(self, start_daemon):
 
#         tmp = TempIO()
 
#         pid_path = tmp.putfile('test.pid', '')
 
#         commands.main('dbsync', '--no-init', '--pidfile', pid_path, '--do-not-daemonize', 'start')
 
#         start_daemon.assert_called_once_with(None, pid_path, False)
 

	
 
#     @patch('rattail.db.sync.linux.start_daemon')
 
#     def test_keyboard_interrupt_raises_error_when_daemonized(self, start_daemon):
 
#         start_daemon.side_effect = KeyboardInterrupt
 
#         self.assertRaises(KeyboardInterrupt, commands.main, 'dbsync', '--no-init', 'start')
 

	
 
#     @patch('rattail.db.sync.linux.start_daemon')
 
#     def test_keyboard_interrupt_handled_gracefully_when_not_daemonized(self, start_daemon):
 
#         tmp = TempIO()
 
#         stderr_path = tmp.putfile('stderr.txt', '')
 
#         start_daemon.side_effect = KeyboardInterrupt
 
#         commands.main('dbsync', '--no-init', '--stderr', stderr_path, '--do-not-daemonize', 'start')
 
#         with open(stderr_path) as f:
 
#             self.assertEqual(f.read(), "Interrupted.\n")
 

	
 
#     @patch('rattail.db.sync.linux.stop_daemon')
 
#     def test_stop_daemon_with_default_args(self, stop_daemon):
 
#         commands.main('dbsync', '--no-init', 'stop')
 
#         stop_daemon.assert_called_once_with(None, None)
 

	
 
#     @patch('rattail.db.sync.linux.stop_daemon')
 
#     def test_stop_daemon_with_explicit_args(self, stop_daemon):
 
#         tmp = TempIO()
 
#         pid_path = tmp.putfile('test.pid', '')
 
#         commands.main('dbsync', '--no-init', '--pidfile', pid_path, 'stop')
 
#         stop_daemon.assert_called_once_with(None, pid_path)
 

	
 

	
 
class TestDump(DataTestCase):
 

	
 
    def setUp(self):
 
@@ -268,14 +221,14 @@ class TestDump(DataTestCase):
 
    def test_unknown_model_cannot_be_dumped(self):
 
        tmp = TempIO()
 
        stderr_path = tmp.putfile('stderr.txt', '')
 
        self.assertRaises(SystemExit, commands.main, '--no-init', '--stderr', stderr_path, 'dump', 'NoSuchModel')
 
        self.assertRaises(SystemExit, core.main, '--no-init', '--stderr', stderr_path, 'dump', 'NoSuchModel')
 
        with open(stderr_path) as f:
 
            self.assertEqual(f.read(), "Unknown model: NoSuchModel\n")
 

	
 
    def test_dump_goes_to_stdout_by_default(self):
 
        tmp = TempIO()
 
        stdout_path = tmp.putfile('stdout.txt', '')
 
        commands.main('--no-init', '--stdout', stdout_path, 'dump', 'Product')
 
        core.main('--no-init', '--stdout', stdout_path, 'dump', 'Product')
 
        with open(stdout_path, 'rb') as csv_file:
 
            reader = csv.DictReader(csv_file)
 
            upcs = [row['upc'] for row in reader]
 
@@ -286,7 +239,7 @@ class TestDump(DataTestCase):
 
    def test_dump_goes_to_file_if_so_invoked(self):
 
        tmp = TempIO()
 
        output_path = tmp.putfile('output.txt', '')
 
        commands.main('--no-init', 'dump', 'Product', '--output', output_path)
 
        core.main('--no-init', 'dump', 'Product', '--output', output_path)
 
        with open(output_path, 'rb') as csv_file:
 
            reader = csv.DictReader(csv_file)
 
            upcs = [row['upc'] for row in reader]
rattail/tests/commands/test_importing.py
Show inline comments
 
new file 100644
 
# -*- coding: utf-8 -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
import argparse
 
from unittest import TestCase
 

	
 
from mock import Mock, patch
 

	
 
from rattail.commands import importing
 
from rattail.importing import ImportHandler
 
from rattail.config import RattailConfig
 
from rattail.tests.importing import ImporterTester
 
from rattail.tests.importing.test_handlers import MockImportHandler
 
from rattail.tests.importing.test_importers import MockImporter
 

	
 

	
 
class MockImport(importing.ImportSubcommand):
 

	
 
    def get_handler_factory(self):
 
        return MockImportHandler
 

	
 

	
 
class TestImportSubcommandBasics(TestCase):
 

	
 
    def test_get_handler_factory(self):
 
        command = importing.ImportSubcommand()
 
        self.assertRaises(NotImplementedError, command.get_handler_factory)
 

	
 
    def test_get_handler(self):
 

	
 
        # no config
 
        command = MockImport()
 
        handler = command.get_handler()
 
        self.assertIs(type(handler), MockImportHandler)
 
        self.assertIsNone(handler.config)
 

	
 
        # with config
 
        config = RattailConfig()
 
        command = MockImport(config=config)
 
        handler = command.get_handler()
 
        self.assertIs(type(handler), MockImportHandler)
 
        self.assertIs(handler.config, config)
 

	
 
    def test_add_parser_args(self):
 
        # TODO: this doesn't really test anything, but does give some coverage..
 

	
 
        # no handler
 
        command = importing.ImportSubcommand()
 
        parser = argparse.ArgumentParser()
 
        command.add_parser_args(parser)
 

	
 
        # with handler
 
        command = MockImport()
 
        parser = argparse.ArgumentParser()
 
        command.add_parser_args(parser)
 

	
 

	
 
class TestImportSubcommandRun(ImporterTester, TestCase):
 

	
 
    sample_data = {
 
        '16oz': {'upc': '00074305001161', 'description': "Apple Cider Vinegar 16oz"},
 
        '32oz': {'upc': '00074305001321', 'description': "Apple Cider Vinegar 32oz"},
 
        '1gal': {'upc': '00074305011283', 'description': "Apple Cider Vinegar 1gal"},
 
    }
 

	
 
    def setUp(self):
 
        self.command = MockImport()
 
        self.handler = MockImportHandler()
 
        self.importer = MockImporter()
 

	
 
    def import_data(self, **kwargs):
 
        models = kwargs.pop('models', [])
 
        kwargs.setdefault('dry_run', False)
 
        args = argparse.Namespace(models=models, **kwargs)
 

	
 
        # must modify our importer in-place since we need the handler to return
 
        # that specific instance, below (because the host/local data context
 
        # managers reference that instance directly)
 
        self.importer._setup(**kwargs)
 
        with patch.object(self.command, 'get_handler', Mock(return_value=self.handler)):
 
            with patch.object(self.handler, 'get_importer', Mock(return_value=self.importer)):
 
                self.command.run(args)
 

	
 
        if self.handler._result:
 
            self.result = self.handler._result['Product']
 
        else:
 
            self.result = [], [], []
 

	
 
    def test_create(self):
 
        local = self.copy_data()
 
        del local['32oz']
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data()
 
        self.assert_import_created('32oz')
 
        self.assert_import_updated()
 
        self.assert_import_deleted()
 

	
 
    def test_update(self):
 
        local = self.copy_data()
 
        local['16oz']['description'] = "wrong description"
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data()
 
        self.assert_import_created()
 
        self.assert_import_updated('16oz')
 
        self.assert_import_deleted()
 

	
 
    def test_delete(self):
 
        local = self.copy_data()
 
        local['bogus'] = {'upc': '00000000000000', 'description': "Delete Me"}
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data()
 
        self.assert_import_created()
 
        self.assert_import_updated()
 
        self.assert_import_deleted('bogus')
 

	
 
    def test_duplicate(self):
 
        host = self.copy_data()
 
        host['32oz-dupe'] = host['32oz']
 
        with self.host_data(host):
 
            with self.local_data(self.sample_data):
 
                self.import_data()
 
        self.assert_import_created()
 
        self.assert_import_updated()
 
        self.assert_import_deleted()
 

	
 
    def test_max_create(self):
 
        local = self.copy_data()
 
        del local['16oz']
 
        del local['1gal']
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_create=1)
 
        self.assert_import_created('16oz')
 
        self.assert_import_updated()
 
        self.assert_import_deleted()
 

	
 
    def test_max_total_create(self):
 
        local = self.copy_data()
 
        del local['16oz']
 
        del local['1gal']
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_total=1)
 
        self.assert_import_created('16oz')
 
        self.assert_import_updated()
 
        self.assert_import_deleted()
 

	
 
    def test_max_update(self):
 
        local = self.copy_data()
 
        local['16oz']['description'] = "wrong"
 
        local['1gal']['description'] = "wrong"
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_update=1)
 
        self.assert_import_created()
 
        self.assert_import_updated('16oz')
 
        self.assert_import_deleted()
 

	
 
    def test_max_total_update(self):
 
        local = self.copy_data()
 
        local['16oz']['description'] = "wrong"
 
        local['1gal']['description'] = "wrong"
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_total=1)
 
        self.assert_import_created()
 
        self.assert_import_updated('16oz')
 
        self.assert_import_deleted()
 

	
 
    def test_max_delete(self):
 
        local = self.copy_data()
 
        local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"}
 
        local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"}
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_delete=1)
 
        self.assert_import_created()
 
        self.assert_import_updated()
 
        self.assert_import_deleted('bogus1')
 

	
 
    def test_max_total_delete(self):
 
        local = self.copy_data()
 
        local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"}
 
        local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"}
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_total=1)
 
        self.assert_import_created()
 
        self.assert_import_updated()
 
        self.assert_import_deleted('bogus1')
 

	
 
    def test_dry_run(self):
 
        local = self.copy_data()
 
        del local['32oz']
 
        local['16oz']['description'] = "wrong description"
 
        local['bogus'] = {'upc': '00000000000000', 'description': "Delete Me"}
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(dry_run=True)
 
        # TODO: maybe need a way to confirm no changes actually made due to dry
 
        # run; currently results still reflect "proposed" changes.  this rather
 
        # bogus test is here just for coverage sake
 
        self.assert_import_created('32oz')
 
        self.assert_import_updated('16oz')
 
        self.assert_import_deleted('bogus')
rattail/tests/importing/test_handlers.py
Show inline comments
 
new file 100644
 
# -*- coding: utf-8 -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from unittest import TestCase
 

	
 
from sqlalchemy import orm
 
from mock import patch, Mock
 

	
 
from rattail.importing import handlers, Importer
 
from rattail.config import RattailConfig
 
from rattail.tests.importing import ImporterTester
 
from rattail.tests.importing.test_importers import MockImporter
 

	
 

	
 
class TestImportHandlerBasics(TestCase):
 

	
 
    def test_init(self):
 

	
 
        # vanilla
 
        handler = handlers.ImportHandler()
 
        self.assertEqual(handler.importers, {})
 
        self.assertEqual(handler.get_importers(), {})
 
        self.assertEqual(handler.get_importer_keys(), [])
 
        self.assertEqual(handler.get_default_keys(), [])
 

	
 
        # with config
 
        handler = handlers.ImportHandler()
 
        self.assertIsNone(handler.config)
 
        config = RattailConfig()
 
        handler = handlers.ImportHandler(config=config)
 
        self.assertIs(handler.config, config)
 

	
 
        # extra kwarg
 
        handler = handlers.ImportHandler()
 
        self.assertRaises(AttributeError, getattr, handler, 'foo')
 
        handler = handlers.ImportHandler(foo='bar')
 
        self.assertEqual(handler.foo, 'bar')
 

	
 
    def test_get_importer(self):
 
        get_importers = Mock(return_value={'foo': Importer})
 

	
 
        # no importers
 
        handler = handlers.ImportHandler()
 
        self.assertIsNone(handler.get_importer('foo'))
 

	
 
        # no config
 
        with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
 
            handler = handlers.ImportHandler()
 
        importer = handler.get_importer('foo')
 
        self.assertIs(type(importer), Importer)
 
        self.assertIsNone(importer.config)
 
        self.assertIs(importer.handler, handler)
 

	
 
        # with config
 
        config = RattailConfig()
 
        with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
 
            handler = handlers.ImportHandler(config=config)
 
        importer = handler.get_importer('foo')
 
        self.assertIs(type(importer), Importer)
 
        self.assertIs(importer.config, config)
 
        self.assertIs(importer.handler, handler)
 

	
 
        # with extra kwarg
 
        with patch.object(handlers.ImportHandler, 'get_importers', get_importers):
 
            handler = handlers.ImportHandler()
 
        importer = handler.get_importer('foo')
 
        self.assertRaises(AttributeError, getattr, importer, 'bar')
 
        importer = handler.get_importer('foo', bar='baz')
 
        self.assertEqual(importer.bar, 'baz')
 

	
 
    def test_get_importer_kwargs(self):
 

	
 
        # empty by default
 
        handler = handlers.ImportHandler()
 
        self.assertEqual(handler.get_importer_kwargs('foo'), {})
 

	
 
        # extra kwargs are preserved
 
        handler = handlers.ImportHandler()
 
        self.assertEqual(handler.get_importer_kwargs('foo', bar='baz'), {'bar': 'baz'})
 

	
 

	
 
######################################################################
 
# fake import handler, tested mostly for basic coverage
 
######################################################################
 

	
 
class MockImportHandler(handlers.ImportHandler):
 

	
 
    def get_importers(self):
 
        return {'Product': MockImporter}
 

	
 
    def import_data(self, *keys, **kwargs):
 
        result = super(MockImportHandler, self).import_data(*keys, **kwargs)
 
        self._result = result
 
        return result
 

	
 

	
 
class TestImportHandlerImportData(ImporterTester, TestCase):
 

	
 
    sample_data = {
 
        '16oz': {'upc': '00074305001161', 'description': "Apple Cider Vinegar 16oz"},
 
        '32oz': {'upc': '00074305001321', 'description': "Apple Cider Vinegar 32oz"},
 
        '1gal': {'upc': '00074305011283', 'description': "Apple Cider Vinegar 1gal"},
 
    }
 

	
 
    def setUp(self):
 
        self.handler = MockImportHandler()
 
        self.importer = MockImporter()
 

	
 
    def import_data(self, **kwargs):
 
        # must modify our importer in-place since we need the handler to return
 
        # that specific instance, below (because the host/local data context
 
        # managers reference that instance directly)
 
        self.importer._setup(**kwargs)
 
        with patch.object(self.handler, 'get_importer', Mock(return_value=self.importer)):
 
            result = self.handler.import_data('Product', **kwargs)
 
        if result:
 
            self.result = result['Product']
 
        else:
 
            self.result = [], [], []
 

	
 
    def test_invalid_importer_key_is_ignored(self):
 
        handler = handlers.ImportHandler()
 
        self.assertNotIn('InvalidKey', handler.importers)
 
        self.assertEqual(handler.import_data('InvalidKey'), {})
 

	
 
    def test_create(self):
 
        local = self.copy_data()
 
        del local['32oz']
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data()
 
        self.assert_import_created('32oz')
 
        self.assert_import_updated()
 
        self.assert_import_deleted()
 

	
 
    def test_update(self):
 
        local = self.copy_data()
 
        local['16oz']['description'] = "wrong description"
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data()
 
        self.assert_import_created()
 
        self.assert_import_updated('16oz')
 
        self.assert_import_deleted()
 

	
 
    def test_delete(self):
 
        local = self.copy_data()
 
        local['bogus'] = {'upc': '00000000000000', 'description': "Delete Me"}
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data()
 
        self.assert_import_created()
 
        self.assert_import_updated()
 
        self.assert_import_deleted('bogus')
 

	
 
    def test_duplicate(self):
 
        host = self.copy_data()
 
        host['32oz-dupe'] = host['32oz']
 
        with self.host_data(host):
 
            with self.local_data(self.sample_data):
 
                self.import_data()
 
        self.assert_import_created()
 
        self.assert_import_updated()
 
        self.assert_import_deleted()
 

	
 
    def test_max_create(self):
 
        local = self.copy_data()
 
        del local['16oz']
 
        del local['1gal']
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_create=1)
 
        self.assert_import_created('16oz')
 
        self.assert_import_updated()
 
        self.assert_import_deleted()
 

	
 
    def test_max_total_create(self):
 
        local = self.copy_data()
 
        del local['16oz']
 
        del local['1gal']
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_total=1)
 
        self.assert_import_created('16oz')
 
        self.assert_import_updated()
 
        self.assert_import_deleted()
 

	
 
    def test_max_update(self):
 
        local = self.copy_data()
 
        local['16oz']['description'] = "wrong"
 
        local['1gal']['description'] = "wrong"
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_update=1)
 
        self.assert_import_created()
 
        self.assert_import_updated('16oz')
 
        self.assert_import_deleted()
 

	
 
    def test_max_total_update(self):
 
        local = self.copy_data()
 
        local['16oz']['description'] = "wrong"
 
        local['1gal']['description'] = "wrong"
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_total=1)
 
        self.assert_import_created()
 
        self.assert_import_updated('16oz')
 
        self.assert_import_deleted()
 

	
 
    def test_max_delete(self):
 
        local = self.copy_data()
 
        local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"}
 
        local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"}
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_delete=1)
 
        self.assert_import_created()
 
        self.assert_import_updated()
 
        self.assert_import_deleted('bogus1')
 

	
 
    def test_max_total_delete(self):
 
        local = self.copy_data()
 
        local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"}
 
        local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"}
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(max_total=1)
 
        self.assert_import_created()
 
        self.assert_import_updated()
 
        self.assert_import_deleted('bogus1')
 

	
 
    def test_dry_run(self):
 
        local = self.copy_data()
 
        del local['32oz']
 
        local['16oz']['description'] = "wrong description"
 
        local['bogus'] = {'upc': '00000000000000', 'description': "Delete Me"}
 
        with self.host_data(self.sample_data):
 
            with self.local_data(local):
 
                self.import_data(dry_run=True)
 
        # TODO: maybe need a way to confirm no changes actually made due to dry
 
        # run; currently results still reflect "proposed" changes.  this rather
 
        # bogus test is here just for coverage sake
 
        self.assert_import_created('32oz')
 
        self.assert_import_updated('16oz')
 
        self.assert_import_deleted('bogus')
 

	
 

	
 
Session = orm.sessionmaker()
 

	
 

	
 
class MockFromSQLAlchemyHandler(handlers.FromSQLAlchemyHandler):
 

	
 
    def make_host_session(self):
 
        return Session()
 

	
 

	
 
class MockToSQLAlchemyHandler(handlers.ToSQLAlchemyHandler):
 

	
 
    def make_session(self):
 
        return Session()
 

	
 

	
 
class TestFromSQLAlchemyHandler(TestCase):
 

	
 
    def test_init(self):
 
        handler = handlers.FromSQLAlchemyHandler()
 
        self.assertRaises(NotImplementedError, handler.make_host_session)
 

	
 
    def test_get_importer_kwargs(self):
 
        session = object()
 
        handler = handlers.FromSQLAlchemyHandler(host_session=session)
 
        kwargs = handler.get_importer_kwargs(None)
 
        self.assertEqual(list(kwargs.iterkeys()), ['host_session'])
 
        self.assertIs(kwargs['host_session'], session)
 

	
 
    def test_begin_host_transaction(self):
 
        handler = MockFromSQLAlchemyHandler()
 
        self.assertIsNone(handler.host_session)
 
        handler.begin_host_transaction()
 
        self.assertIsInstance(handler.host_session, orm.Session)
 
        handler.host_session.close()
 

	
 
    def test_commit_host_transaction(self):
 
        # TODO: test actual commit for data changes
 
        session = Session()
 
        handler = handlers.FromSQLAlchemyHandler(host_session=session)
 
        self.assertIs(handler.host_session, session)
 
        handler.commit_host_transaction()
 
        self.assertIsNone(handler.host_session)
 

	
 
    def test_rollback_host_transaction(self):
 
        # TODO: test actual rollback for data changes
 
        session = Session()
 
        handler = handlers.FromSQLAlchemyHandler(host_session=session)
 
        self.assertIs(handler.host_session, session)
 
        handler.rollback_host_transaction()
 
        self.assertIsNone(handler.host_session)
 

	
 

	
 
class TestToSQLAlchemyHandler(TestCase):
 

	
 
    def test_init(self):
 
        handler = handlers.ToSQLAlchemyHandler()
 
        self.assertRaises(NotImplementedError, handler.make_session)
 

	
 
    def test_get_importer_kwargs(self):
 
        session = object()
 
        handler = handlers.ToSQLAlchemyHandler(session=session)
 
        kwargs = handler.get_importer_kwargs(None)
 
        self.assertEqual(list(kwargs.iterkeys()), ['session'])
 
        self.assertIs(kwargs['session'], session)
 

	
 
    def test_begin_local_transaction(self):
 
        handler = MockToSQLAlchemyHandler()
 
        self.assertIsNone(handler.session)
 
        handler.begin_local_transaction()
 
        self.assertIsInstance(handler.session, orm.Session)
 
        handler.session.close()
 

	
 
    def test_commit_local_transaction(self):
 
        # TODO: test actual commit for data changes
 
        session = Session()
 
        handler = handlers.ToSQLAlchemyHandler(session=session)
 
        self.assertIs(handler.session, session)
 
        handler.commit_local_transaction()
 
        self.assertIsNone(handler.session)
 

	
 
    def test_rollback_local_transaction(self):
 
        # TODO: test actual rollback for data changes
 
        session = Session()
 
        handler = handlers.ToSQLAlchemyHandler(session=session)
 
        self.assertIs(handler.session, session)
 
        handler.rollback_local_transaction()
 
        self.assertIsNone(handler.session)
0 comments (0 inline, 0 general)