diff --git a/rattail/commands/__init__.py b/rattail/commands/__init__.py index c86091fd3e8a89707d22cfb2dfbf638f6e3d3eec..d0269a1dac5cdd99f3e80167863d8bd457a7ae3d 100644 --- a/rattail/commands/__init__.py +++ b/rattail/commands/__init__.py @@ -27,3 +27,4 @@ Console Commands from __future__ import unicode_literals, absolute_import from .core import main, Command, Subcommand, OldImportSubcommand, NewImportSubcommand, Dump, date_argument +from .importing import ImportSubcommand diff --git a/rattail/commands/importing.py b/rattail/commands/importing.py new file mode 100644 index 0000000000000000000000000000000000000000..2a10a5ca966133126e7bd855d3341d1bcd0d8c17 --- /dev/null +++ b/rattail/commands/importing.py @@ -0,0 +1,153 @@ +# -*- coding: utf-8 -*- +################################################################################ +# +# Rattail -- Retail Software Framework +# Copyright © 2010-2016 Lance Edgar +# +# This file is part of Rattail. +# +# Rattail is free software: you can redistribute it and/or modify it under the +# terms of the GNU Affero General Public License as published by the Free +# Software Foundation, either version 3 of the License, or (at your option) +# any later version. +# +# Rattail is distributed in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for +# more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with Rattail. If not, see . +# +################################################################################ +""" +Importing Commands +""" + +from __future__ import unicode_literals, absolute_import + +import logging + +from rattail.commands.core import Subcommand, date_argument +from rattail.util import load_object + + +log = logging.getLogger(__name__) + + +class ImportSubcommand(Subcommand): + """ + Base class for subcommands which use the (new) data importing system. + """ + # TODO: move this into Subcommand + parent_name = None + + def get_handler_factory(self): + """ + Subclasses must override this, and return a callable that creates an + import handler instance which the command should use. + """ + raise NotImplementedError + + def get_handler(self, **kwargs): + """ + Returns a handler instance to be used by the command. + """ + factory = self.get_handler_factory() + kwargs.setdefault('config', getattr(self, 'config', None)) + kwargs.setdefault('command', self) + kwargs = self.get_handler_kwargs(**kwargs) + return factory(**kwargs) + + def get_handler_kwargs(self, **kwargs): + """ + Return a dict of kwargs to be passed to the handler factory. + """ + return kwargs + + def add_parser_args(self, parser): + + # model names (aka importer keys) + doc = ("Which data models to import. If you specify any, then only " + "data for those models will be imported. If you do not specify " + "any, then all *default* models will be imported.") + try: + handler = self.get_handler() + except NotImplementedError: + pass + else: + doc += " Supported models are: ({})".format(', '.join(handler.get_importer_keys())) + parser.add_argument('models', nargs='*', metavar='MODEL', help=doc) + + # start/end date + parser.add_argument('--start-date', type=date_argument, + help="Optional (inclusive) starting point for date range, by which host " + "data should be filtered. Only used by certain importers.") + parser.add_argument('--end-date', type=date_argument, + help="Optional (inclusive) ending point for date range, by which host " + "data should be filtered. Only used by certain importers.") + + # allow create? + parser.add_argument('--create', action='store_true', default=True, + help="Allow new records to be created during the import.") + parser.add_argument('--no-create', action='store_false', dest='create', + help="Do not allow new records to be created during the import.") + parser.add_argument('--max-create', type=int, metavar='COUNT', + help="Maximum number of records which may be created, after which a " + "given import task should stop. Note that this applies on a per-model " + "basis and not overall.") + + # allow update? + parser.add_argument('--update', action='store_true', default=True, + help="Allow existing records to be updated during the import.") + parser.add_argument('--no-update', action='store_false', dest='update', + help="Do not allow existing records to be updated during the import.") + parser.add_argument('--max-update', type=int, metavar='COUNT', + help="Maximum number of records which may be updated, after which a " + "given import task should stop. Note that this applies on a per-model " + "basis and not overall.") + + # allow delete? + parser.add_argument('--delete', action='store_true', default=False, + help="Allow records to be deleted during the import.") + parser.add_argument('--no-delete', action='store_false', dest='delete', + help="Do not allow records to be deleted during the import.") + parser.add_argument('--max-delete', type=int, metavar='COUNT', + help="Maximum number of records which may be deleted, after which a " + "given import task should stop. Note that this applies on a per-model " + "basis and not overall.") + + # max total changes, per model + parser.add_argument('--max-total', type=int, metavar='COUNT', + help="Maximum number of *any* record changes which may occur, after which " + "a given import task should stop. Note that this applies on a per-model " + "basis and not overall.") + + # treat changes as warnings? + parser.add_argument('--warnings', '-W', action='store_true', + help="Set this flag if you expect a \"clean\" import, and wish for any " + "changes which do occur to be processed further and/or specially. The " + "behavior of this flag is ultimately up to the import handler, but the " + "default is to send an email notification.") + + # dry run? + parser.add_argument('--dry-run', action='store_true', + help="Go through the full motions and allow logging etc. to " + "occur, but rollback (abort) the transaction at the end.") + + def run(self, args): + log.info("begin `{} {}` for data models: {}".format( + self.parent_name, self.name, ', '.join(args.models or ["(ALL)"]))) + + handler = self.get_handler(args=args, progress=self.progress) + models = args.models or handler.get_default_keys() + log.debug("using handler: {}".format(handler)) + log.debug("importing models: {}".format(models)) + log.debug("args are: {}".format(args)) + handler.import_data(*models) + + # TODO: should this logging happen elsewhere / be customizable? + if args.dry_run: + log.info("dry run, so transaction was rolled back") + else: + log.info("transaction was committed") diff --git a/rattail/importing/__init__.py b/rattail/importing/__init__.py index 59a333cc8cf5d243aaa4a69a2ed640e29bfbcab7..8383c939ec6eacb851000e12ce6e4cb3aa6f52cf 100644 --- a/rattail/importing/__init__.py +++ b/rattail/importing/__init__.py @@ -28,3 +28,4 @@ from __future__ import unicode_literals, absolute_import from .importers import Importer, FromQuery from .sqlalchemy import FromSQLAlchemy, ToSQLAlchemy +from .handlers import ImportHandler, FromSQLAlchemyHandler, ToSQLAlchemyHandler diff --git a/rattail/importing/handlers.py b/rattail/importing/handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..185de60d1d9b904759e54ef2674b39020f8762ba --- /dev/null +++ b/rattail/importing/handlers.py @@ -0,0 +1,242 @@ +# -*- coding: utf-8 -*- +################################################################################ +# +# Rattail -- Retail Software Framework +# Copyright © 2010-2016 Lance Edgar +# +# This file is part of Rattail. +# +# Rattail is free software: you can redistribute it and/or modify it under the +# terms of the GNU Affero General Public License as published by the Free +# Software Foundation, either version 3 of the License, or (at your option) +# any later version. +# +# Rattail is distributed in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS +# FOR A PARTICULAR PURPOSE. See the GNU Affero General Public License for +# more details. +# +# You should have received a copy of the GNU Affero General Public License +# along with Rattail. If not, see . +# +################################################################################ +""" +Import Handlers +""" + +from __future__ import unicode_literals, absolute_import + +import datetime +import logging + +from rattail.util import OrderedDict + + +log = logging.getLogger(__name__) + + +class ImportHandler(object): + """ + Base class for all import handlers. + """ + host_title = "Host" + local_title = "Local" + progress = None + dry_run = False + + def __init__(self, config=None, **kwargs): + self.config = config + self.importers = self.get_importers() + for key, value in kwargs.iteritems(): + setattr(self, key, value) + + def get_importers(self): + """ + Returns a dict of all available importers, where the keys are model + names and the values are importer factories. All subclasses will want + to override this. Note that if you return an + :class:`python:collections.OrderedDict` instance, you can affect the + ordering of keys in the command line help system, etc. + """ + return {} + + def get_importer_keys(self): + """ + Returns the list of keys corresponding to the available importers. + """ + return list(self.importers.iterkeys()) + + def get_default_keys(self): + """ + Returns the list of keys corresponding to the "default" importers. + Override this if you wish certain importers to be excluded by default, + e.g. when first testing them out etc. + """ + return self.get_importer_keys() + + def get_importer(self, key, **kwargs): + """ + Returns an importer instance corresponding to the given key. + """ + if key in self.importers: + kwargs.setdefault('handler', self) + kwargs.setdefault('config', self.config) + kwargs = self.get_importer_kwargs(key, **kwargs) + return self.importers[key](**kwargs) + + def get_importer_kwargs(self, key, **kwargs): + """ + Return a dict of kwargs to be used when construcing an importer with + the given key. + """ + return kwargs + + def import_data(self, *keys, **kwargs): + """ + Import all data for the given importer/model keys. + """ + self.import_began = datetime.datetime.utcnow() + self.dry_run = kwargs.pop('dry_run', False) + self.progress = kwargs.pop('progress', None) + self.warnings = kwargs.pop('warnings', False) + kwargs.update({'dry_run': self.dry_run, + 'progress': self.progress}) + + self.setup() + self.begin_transaction() + changes = OrderedDict() + + for key in keys: + importer = self.get_importer(key, **kwargs) + if not importer: + log.warning("skipping unknown importer: {}".format(key)) + continue + + created, updated, deleted = importer.import_data() + + changed = bool(created or updated or deleted) + logger = log.warning if changed and self.warnings else log.info + logger("{} -> {}: added {}, updated {}, deleted {} {} records".format( + self.host_title, self.local_title, len(created), len(updated), len(deleted), key)) + if changed: + changes[key] = created, updated, deleted + + if changes: + self.process_changes(changes) + if self.dry_run: + self.rollback_transaction() + else: + self.commit_transaction() + self.teardown() + return changes + + def setup(self): + """ + Perform any additional setup if/as necessary, prior to running the + import task(s). + """ + + def teardown(self): + """ + Perform any cleanup necessary, after running the import task(s). + """ + + def begin_transaction(self): + self.begin_host_transaction() + self.begin_local_transaction() + + def begin_host_transaction(self): + pass + + def begin_local_transaction(self): + pass + + def rollback_transaction(self): + self.rollback_host_transaction() + self.rollback_local_transaction() + + def rollback_host_transaction(self): + pass + + def rollback_local_transaction(self): + pass + + def commit_transaction(self): + self.commit_host_transaction() + self.commit_local_transaction() + + def commit_host_transaction(self): + pass + + def commit_local_transaction(self): + pass + + def process_changes(self, changes): + """ + This method is called any time changes occur, regardless of whether the + import is running in "warnings" mode. Default implementation does + nothing; override as needed. + """ + + +class FromSQLAlchemyHandler(ImportHandler): + """ + Handler for imports for which the host data source is represented by a + SQLAlchemy engine and ORM. + """ + host_session = None + + def make_host_session(self): + """ + Subclasses must override this to define the host database connection. + """ + raise NotImplementedError + + def get_importer_kwargs(self, key): + kwargs = super(FromSQLAlchemyHandler, self).get_importer_kwargs(key) + kwargs.setdefault('host_session', self.host_session) + return kwargs + + def begin_host_transaction(self): + self.host_session = self.make_host_session() + + def rollback_host_transaction(self): + self.host_session.rollback() + self.host_session.close() + self.host_session = None + + def commit_host_transaction(self): + self.host_session.commit() + self.host_session.close() + self.host_session = None + + +class ToSQLAlchemyHandler(ImportHandler): + """ + Handler for imports which target a SQLAlchemy ORM on the local side. + """ + session = None + + def make_session(self): + """ + Subclasses must override this to define the local database connection. + """ + raise NotImplementedError + + def get_importer_kwargs(self, key): + kwargs = super(ToSQLAlchemyHandler, self).get_importer_kwargs(key) + kwargs.setdefault('session', self.session) + return kwargs + + def begin_local_transaction(self): + self.session = self.make_session() + + def rollback_local_transaction(self): + self.session.rollback() + self.session.close() + self.session = None + + def commit_local_transaction(self): + self.session.commit() + self.session.close() + self.session = None diff --git a/rattail/importing/importers.py b/rattail/importing/importers.py index 35bcea6e9d19e45e24ad4d331a71b6fffa9206e4..73e369daecff679d11f75ecca0eb29974b176cec 100644 --- a/rattail/importing/importers.py +++ b/rattail/importing/importers.py @@ -88,7 +88,6 @@ class Importer(object): self._setup(**kwargs) def _setup(self, **kwargs): - self.now = kwargs.pop('now', make_utc(datetime.datetime.utcnow(), tzinfo=True)) self.create = kwargs.pop('create', self.allow_create) and self.allow_create self.update = kwargs.pop('update', self.allow_update) and self.allow_update self.delete = kwargs.pop('delete', self.allow_delete) and self.allow_delete @@ -110,13 +109,15 @@ class Importer(object): Perform any cleanup after import, if necessary. """ - def import_data(self, host_data=None, **kwargs): + def import_data(self, host_data=None, now=None, **kwargs): """ Import some data! This is the core body of logic for that, regardless of where data is coming from or where it's headed. Note that this method handles deletions as well as adds/updates. """ - self._setup(**kwargs) + self.now = now or make_utc(datetime.datetime.utcnow(), tzinfo=True) + if kwargs: + self._setup(**kwargs) self.setup() created = updated = deleted = [] diff --git a/rattail/tests/commands/__init__.py b/rattail/tests/commands/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/rattail/tests/test_commands.py b/rattail/tests/commands/test_core.py similarity index 76% rename from rattail/tests/test_commands.py rename to rattail/tests/commands/test_core.py index 4964f033863610f591b2cfd5af72a0d248fa3e3f..fc80ab55aa6ca44fdfdd231181640f6d92355a24 100644 --- a/rattail/tests/test_commands.py +++ b/rattail/tests/commands/test_core.py @@ -5,18 +5,14 @@ from __future__ import unicode_literals, absolute_import import csv import datetime import argparse -import logging from unittest import TestCase from cStringIO import StringIO +# from sqlalchemy import func from mock import patch, Mock from fixture import TempIO -from sqlalchemy import create_engine -from sqlalchemy import func - -from rattail import commands -from rattail.commands.core import ArgumentParser, date_argument +from rattail.commands import core from rattail.db import Session, model from rattail.db.auth import authenticate_user from rattail.tests import DataTestCase @@ -25,7 +21,7 @@ from rattail.tests import DataTestCase class TestArgumentParser(TestCase): def test_parse_args_preserves_extra_argv(self): - parser = ArgumentParser() + parser = core.ArgumentParser() parser.add_argument('--some-optional-arg') parser.add_argument('some_required_arg') args = parser.parse_args([ @@ -39,26 +35,26 @@ class TestArgumentParser(TestCase): class TestDateArgument(TestCase): def test_valid_date_string_returns_date_object(self): - date = date_argument('2014-01-01') + date = core.date_argument('2014-01-01') self.assertEqual(date, datetime.date(2014, 1, 1)) def test_invalid_date_string_raises_error(self): - self.assertRaises(argparse.ArgumentTypeError, date_argument, 'invalid-date') + self.assertRaises(argparse.ArgumentTypeError, core.date_argument, 'invalid-date') class TestCommand(TestCase): def test_initial_subcommands_are_sane(self): - command = commands.Command() + command = core.Command() self.assertTrue('filemon' in command.subcommands) def test_unicode(self): - command = commands.Command() + command = core.Command() command.name = 'some-app' self.assertEqual(unicode(command), u'some-app') def test_iter_subcommands_includes_expected_item(self): - command = commands.Command() + command = core.Command() found = False for subcommand in command.iter_subcommands(): if subcommand.name == 'filemon': @@ -67,7 +63,7 @@ class TestCommand(TestCase): self.assertTrue(found) def test_print_help(self): - command = commands.Command() + command = core.Command() stdout = StringIO() command.stdout = stdout command.print_help() @@ -77,41 +73,41 @@ class TestCommand(TestCase): self.assertTrue('Options:' in output) def test_run_with_no_args_prints_help(self): - command = commands.Command() + command = core.Command() with patch.object(command, 'print_help') as print_help: command.run() print_help.assert_called_once_with() def test_run_with_single_help_arg_prints_help(self): - command = commands.Command() + command = core.Command() with patch.object(command, 'print_help') as print_help: command.run('help') print_help.assert_called_once_with() def test_run_with_help_and_unknown_subcommand_args_prints_help(self): - command = commands.Command() + command = core.Command() with patch.object(command, 'print_help') as print_help: command.run('help', 'invalid-subcommand-name') print_help.assert_called_once_with() def test_run_with_help_and_subcommand_args_prints_subcommand_help(self): - command = commands.Command() + command = core.Command() fake = command.subcommands['fake'] = Mock() command.run('help', 'fake') fake.return_value.parser.print_help.assert_called_once_with() def test_run_with_unknown_subcommand_arg_prints_help(self): - command = commands.Command() + command = core.Command() with patch.object(command, 'print_help') as print_help: command.run('invalid-command-name') print_help.assert_called_once_with() def test_stdout_may_be_redirected(self): - class Fake(commands.Subcommand): + class Fake(core.Subcommand): def run(self, args): self.stdout.write("standard output stuff") self.stdout.flush() - command = commands.Command() + command = core.Command() fake = command.subcommands['fake'] = Fake tmp = TempIO() config_path = tmp.putfile('test.ini', '') @@ -121,11 +117,11 @@ class TestCommand(TestCase): self.assertEqual(f.read(), "standard output stuff") def test_stderr_may_be_redirected(self): - class Fake(commands.Subcommand): + class Fake(core.Subcommand): def run(self, args): self.stderr.write("standard error stuff") self.stderr.flush() - command = commands.Command() + command = core.Command() fake = command.subcommands['fake'] = Fake tmp = TempIO() config_path = tmp.putfile('test.ini', '') @@ -145,22 +141,22 @@ class TestCommand(TestCase): class TestSubcommand(TestCase): def test_repr(self): - command = commands.Command() - subcommand = commands.Subcommand(command) + command = core.Command() + subcommand = core.Subcommand(command) subcommand.name = 'fake-command' self.assertEqual(repr(subcommand), "Subcommand(name=u'fake-command')") def test_add_parser_args_does_nothing(self): - command = commands.Command() - subcommand = commands.Subcommand(command) + command = core.Command() + subcommand = core.Subcommand(command) # Not sure this is really the way to test this, but... self.assertEqual(len(subcommand.parser._action_groups[0]._actions), 1) subcommand.add_parser_args(subcommand.parser) self.assertEqual(len(subcommand.parser._action_groups[0]._actions), 1) def test_run_not_implemented(self): - command = commands.Command() - subcommand = commands.Subcommand(command) + command = core.Command() + subcommand = core.Subcommand(command) args = subcommand.parser.parse_args([]) self.assertRaises(NotImplementedError, subcommand.run, args) @@ -177,7 +173,7 @@ class TestAddUser(DataTestCase): self.session.add(model.User(username='fred')) self.session.commit() self.assertEqual(self.session.query(model.User).count(), 1) - commands.main('adduser', '--no-init', '--stderr', self.stderr_path, 'fred') + core.main('adduser', '--no-init', '--stderr', self.stderr_path, 'fred') with open(self.stderr_path) as f: self.assertEqual(f.read(), "User 'fred' already exists.\n") self.assertEqual(self.session.query(model.User).count(), 1) @@ -186,7 +182,7 @@ class TestAddUser(DataTestCase): self.assertEqual(self.session.query(model.User).count(), 0) with patch('rattail.commands.core.getpass') as getpass: getpass.side_effect = KeyboardInterrupt - commands.main('adduser', '--no-init', '--stderr', self.stderr_path, 'fred') + core.main('adduser', '--no-init', '--stderr', self.stderr_path, 'fred') with open(self.stderr_path) as f: self.assertEqual(f.read(), "\nOperation was canceled.\n") self.assertEqual(self.session.query(model.User).count(), 0) @@ -195,7 +191,7 @@ class TestAddUser(DataTestCase): self.assertEqual(self.session.query(model.User).count(), 0) with patch('rattail.commands.core.getpass') as getpass: getpass.return_value = 'fredpass' - commands.main('adduser', '--no-init', '--stdout', self.stdout_path, 'fred') + core.main('adduser', '--no-init', '--stdout', self.stdout_path, 'fred') with open(self.stdout_path) as f: self.assertEqual(f.read(), "Created user: fred\n") fred = self.session.query(model.User).one() @@ -208,55 +204,12 @@ class TestAddUser(DataTestCase): self.assertEqual(self.session.query(model.User).count(), 0) with patch('rattail.commands.core.getpass') as getpass: getpass.return_value = 'fredpass' - commands.main('adduser', '--no-init', '--stdout', self.stdout_path, 'fred', '--administrator') + core.main('adduser', '--no-init', '--stdout', self.stdout_path, 'fred', '--administrator') fred = self.session.query(model.User).one() self.assertEqual(len(fred.roles), 1) self.assertEqual(fred.roles[0].name, 'Administrator') -# TODO: more broken tests..ugh. these aren't very good or else i might bother -# fixing them... -# class TestDatabaseSync(TestCase): - -# @patch('rattail.db.sync.linux.start_daemon') -# def test_start_daemon_with_default_args(self, start_daemon): -# commands.main('dbsync', '--no-init', 'start') -# start_daemon.assert_called_once_with(None, None, True) - -# @patch('rattail.db.sync.linux.start_daemon') -# def test_start_daemon_with_explicit_args(self, start_daemon): -# tmp = TempIO() -# pid_path = tmp.putfile('test.pid', '') -# commands.main('dbsync', '--no-init', '--pidfile', pid_path, '--do-not-daemonize', 'start') -# start_daemon.assert_called_once_with(None, pid_path, False) - -# @patch('rattail.db.sync.linux.start_daemon') -# def test_keyboard_interrupt_raises_error_when_daemonized(self, start_daemon): -# start_daemon.side_effect = KeyboardInterrupt -# self.assertRaises(KeyboardInterrupt, commands.main, 'dbsync', '--no-init', 'start') - -# @patch('rattail.db.sync.linux.start_daemon') -# def test_keyboard_interrupt_handled_gracefully_when_not_daemonized(self, start_daemon): -# tmp = TempIO() -# stderr_path = tmp.putfile('stderr.txt', '') -# start_daemon.side_effect = KeyboardInterrupt -# commands.main('dbsync', '--no-init', '--stderr', stderr_path, '--do-not-daemonize', 'start') -# with open(stderr_path) as f: -# self.assertEqual(f.read(), "Interrupted.\n") - -# @patch('rattail.db.sync.linux.stop_daemon') -# def test_stop_daemon_with_default_args(self, stop_daemon): -# commands.main('dbsync', '--no-init', 'stop') -# stop_daemon.assert_called_once_with(None, None) - -# @patch('rattail.db.sync.linux.stop_daemon') -# def test_stop_daemon_with_explicit_args(self, stop_daemon): -# tmp = TempIO() -# pid_path = tmp.putfile('test.pid', '') -# commands.main('dbsync', '--no-init', '--pidfile', pid_path, 'stop') -# stop_daemon.assert_called_once_with(None, pid_path) - - class TestDump(DataTestCase): def setUp(self): @@ -268,14 +221,14 @@ class TestDump(DataTestCase): def test_unknown_model_cannot_be_dumped(self): tmp = TempIO() stderr_path = tmp.putfile('stderr.txt', '') - self.assertRaises(SystemExit, commands.main, '--no-init', '--stderr', stderr_path, 'dump', 'NoSuchModel') + self.assertRaises(SystemExit, core.main, '--no-init', '--stderr', stderr_path, 'dump', 'NoSuchModel') with open(stderr_path) as f: self.assertEqual(f.read(), "Unknown model: NoSuchModel\n") def test_dump_goes_to_stdout_by_default(self): tmp = TempIO() stdout_path = tmp.putfile('stdout.txt', '') - commands.main('--no-init', '--stdout', stdout_path, 'dump', 'Product') + core.main('--no-init', '--stdout', stdout_path, 'dump', 'Product') with open(stdout_path, 'rb') as csv_file: reader = csv.DictReader(csv_file) upcs = [row['upc'] for row in reader] @@ -286,7 +239,7 @@ class TestDump(DataTestCase): def test_dump_goes_to_file_if_so_invoked(self): tmp = TempIO() output_path = tmp.putfile('output.txt', '') - commands.main('--no-init', 'dump', 'Product', '--output', output_path) + core.main('--no-init', 'dump', 'Product', '--output', output_path) with open(output_path, 'rb') as csv_file: reader = csv.DictReader(csv_file) upcs = [row['upc'] for row in reader] diff --git a/rattail/tests/commands/test_importing.py b/rattail/tests/commands/test_importing.py new file mode 100644 index 0000000000000000000000000000000000000000..bfaf4f6fbbecda9e3470361cfe50ea1dd8687ad8 --- /dev/null +++ b/rattail/tests/commands/test_importing.py @@ -0,0 +1,209 @@ +# -*- coding: utf-8 -*- + +from __future__ import unicode_literals, absolute_import + +import argparse +from unittest import TestCase + +from mock import Mock, patch + +from rattail.commands import importing +from rattail.importing import ImportHandler +from rattail.config import RattailConfig +from rattail.tests.importing import ImporterTester +from rattail.tests.importing.test_handlers import MockImportHandler +from rattail.tests.importing.test_importers import MockImporter + + +class MockImport(importing.ImportSubcommand): + + def get_handler_factory(self): + return MockImportHandler + + +class TestImportSubcommandBasics(TestCase): + + def test_get_handler_factory(self): + command = importing.ImportSubcommand() + self.assertRaises(NotImplementedError, command.get_handler_factory) + + def test_get_handler(self): + + # no config + command = MockImport() + handler = command.get_handler() + self.assertIs(type(handler), MockImportHandler) + self.assertIsNone(handler.config) + + # with config + config = RattailConfig() + command = MockImport(config=config) + handler = command.get_handler() + self.assertIs(type(handler), MockImportHandler) + self.assertIs(handler.config, config) + + def test_add_parser_args(self): + # TODO: this doesn't really test anything, but does give some coverage.. + + # no handler + command = importing.ImportSubcommand() + parser = argparse.ArgumentParser() + command.add_parser_args(parser) + + # with handler + command = MockImport() + parser = argparse.ArgumentParser() + command.add_parser_args(parser) + + +class TestImportSubcommandRun(ImporterTester, TestCase): + + sample_data = { + '16oz': {'upc': '00074305001161', 'description': "Apple Cider Vinegar 16oz"}, + '32oz': {'upc': '00074305001321', 'description': "Apple Cider Vinegar 32oz"}, + '1gal': {'upc': '00074305011283', 'description': "Apple Cider Vinegar 1gal"}, + } + + def setUp(self): + self.command = MockImport() + self.handler = MockImportHandler() + self.importer = MockImporter() + + def import_data(self, **kwargs): + models = kwargs.pop('models', []) + kwargs.setdefault('dry_run', False) + args = argparse.Namespace(models=models, **kwargs) + + # must modify our importer in-place since we need the handler to return + # that specific instance, below (because the host/local data context + # managers reference that instance directly) + self.importer._setup(**kwargs) + with patch.object(self.command, 'get_handler', Mock(return_value=self.handler)): + with patch.object(self.handler, 'get_importer', Mock(return_value=self.importer)): + self.command.run(args) + + if self.handler._result: + self.result = self.handler._result['Product'] + else: + self.result = [], [], [] + + def test_create(self): + local = self.copy_data() + del local['32oz'] + with self.host_data(self.sample_data): + with self.local_data(local): + self.import_data() + self.assert_import_created('32oz') + self.assert_import_updated() + self.assert_import_deleted() + + def test_update(self): + local = self.copy_data() + local['16oz']['description'] = "wrong description" + with self.host_data(self.sample_data): + with self.local_data(local): + self.import_data() + self.assert_import_created() + self.assert_import_updated('16oz') + self.assert_import_deleted() + + def test_delete(self): + local = self.copy_data() + local['bogus'] = {'upc': '00000000000000', 'description': "Delete Me"} + with self.host_data(self.sample_data): + with self.local_data(local): + self.import_data() + self.assert_import_created() + self.assert_import_updated() + self.assert_import_deleted('bogus') + + def test_duplicate(self): + host = self.copy_data() + host['32oz-dupe'] = host['32oz'] + with self.host_data(host): + with self.local_data(self.sample_data): + self.import_data() + self.assert_import_created() + self.assert_import_updated() + self.assert_import_deleted() + + def test_max_create(self): + local = self.copy_data() + del local['16oz'] + del local['1gal'] + with self.host_data(self.sample_data): + with self.local_data(local): + self.import_data(max_create=1) + self.assert_import_created('16oz') + self.assert_import_updated() + self.assert_import_deleted() + + def test_max_total_create(self): + local = self.copy_data() + del local['16oz'] + del local['1gal'] + with self.host_data(self.sample_data): + with self.local_data(local): + self.import_data(max_total=1) + self.assert_import_created('16oz') + self.assert_import_updated() + self.assert_import_deleted() + + def test_max_update(self): + local = self.copy_data() + local['16oz']['description'] = "wrong" + local['1gal']['description'] = "wrong" + with self.host_data(self.sample_data): + with self.local_data(local): + self.import_data(max_update=1) + self.assert_import_created() + self.assert_import_updated('16oz') + self.assert_import_deleted() + + def test_max_total_update(self): + local = self.copy_data() + local['16oz']['description'] = "wrong" + local['1gal']['description'] = "wrong" + with self.host_data(self.sample_data): + with self.local_data(local): + self.import_data(max_total=1) + self.assert_import_created() + self.assert_import_updated('16oz') + self.assert_import_deleted() + + def test_max_delete(self): + local = self.copy_data() + local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"} + local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"} + with self.host_data(self.sample_data): + with self.local_data(local): + self.import_data(max_delete=1) + self.assert_import_created() + self.assert_import_updated() + self.assert_import_deleted('bogus1') + + def test_max_total_delete(self): + local = self.copy_data() + local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"} + local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"} + with self.host_data(self.sample_data): + with self.local_data(local): + self.import_data(max_total=1) + self.assert_import_created() + self.assert_import_updated() + self.assert_import_deleted('bogus1') + + def test_dry_run(self): + local = self.copy_data() + del local['32oz'] + local['16oz']['description'] = "wrong description" + local['bogus'] = {'upc': '00000000000000', 'description': "Delete Me"} + with self.host_data(self.sample_data): + with self.local_data(local): + self.import_data(dry_run=True) + # TODO: maybe need a way to confirm no changes actually made due to dry + # run; currently results still reflect "proposed" changes. this rather + # bogus test is here just for coverage sake + self.assert_import_created('32oz') + self.assert_import_updated('16oz') + self.assert_import_deleted('bogus') diff --git a/rattail/tests/importing/test_handlers.py b/rattail/tests/importing/test_handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..2243be6a2aabc60cd7c48f7f379c7f8b9491b3a6 --- /dev/null +++ b/rattail/tests/importing/test_handlers.py @@ -0,0 +1,335 @@ +# -*- coding: utf-8 -*- + +from __future__ import unicode_literals, absolute_import + +from unittest import TestCase + +from sqlalchemy import orm +from mock import patch, Mock + +from rattail.importing import handlers, Importer +from rattail.config import RattailConfig +from rattail.tests.importing import ImporterTester +from rattail.tests.importing.test_importers import MockImporter + + +class TestImportHandlerBasics(TestCase): + + def test_init(self): + + # vanilla + handler = handlers.ImportHandler() + self.assertEqual(handler.importers, {}) + self.assertEqual(handler.get_importers(), {}) + self.assertEqual(handler.get_importer_keys(), []) + self.assertEqual(handler.get_default_keys(), []) + + # with config + handler = handlers.ImportHandler() + self.assertIsNone(handler.config) + config = RattailConfig() + handler = handlers.ImportHandler(config=config) + self.assertIs(handler.config, config) + + # extra kwarg + handler = handlers.ImportHandler() + self.assertRaises(AttributeError, getattr, handler, 'foo') + handler = handlers.ImportHandler(foo='bar') + self.assertEqual(handler.foo, 'bar') + + def test_get_importer(self): + get_importers = Mock(return_value={'foo': Importer}) + + # no importers + handler = handlers.ImportHandler() + self.assertIsNone(handler.get_importer('foo')) + + # no config + with patch.object(handlers.ImportHandler, 'get_importers', get_importers): + handler = handlers.ImportHandler() + importer = handler.get_importer('foo') + self.assertIs(type(importer), Importer) + self.assertIsNone(importer.config) + self.assertIs(importer.handler, handler) + + # with config + config = RattailConfig() + with patch.object(handlers.ImportHandler, 'get_importers', get_importers): + handler = handlers.ImportHandler(config=config) + importer = handler.get_importer('foo') + self.assertIs(type(importer), Importer) + self.assertIs(importer.config, config) + self.assertIs(importer.handler, handler) + + # with extra kwarg + with patch.object(handlers.ImportHandler, 'get_importers', get_importers): + handler = handlers.ImportHandler() + importer = handler.get_importer('foo') + self.assertRaises(AttributeError, getattr, importer, 'bar') + importer = handler.get_importer('foo', bar='baz') + self.assertEqual(importer.bar, 'baz') + + def test_get_importer_kwargs(self): + + # empty by default + handler = handlers.ImportHandler() + self.assertEqual(handler.get_importer_kwargs('foo'), {}) + + # extra kwargs are preserved + handler = handlers.ImportHandler() + self.assertEqual(handler.get_importer_kwargs('foo', bar='baz'), {'bar': 'baz'}) + + +###################################################################### +# fake import handler, tested mostly for basic coverage +###################################################################### + +class MockImportHandler(handlers.ImportHandler): + + def get_importers(self): + return {'Product': MockImporter} + + def import_data(self, *keys, **kwargs): + result = super(MockImportHandler, self).import_data(*keys, **kwargs) + self._result = result + return result + + +class TestImportHandlerImportData(ImporterTester, TestCase): + + sample_data = { + '16oz': {'upc': '00074305001161', 'description': "Apple Cider Vinegar 16oz"}, + '32oz': {'upc': '00074305001321', 'description': "Apple Cider Vinegar 32oz"}, + '1gal': {'upc': '00074305011283', 'description': "Apple Cider Vinegar 1gal"}, + } + + def setUp(self): + self.handler = MockImportHandler() + self.importer = MockImporter() + + def import_data(self, **kwargs): + # must modify our importer in-place since we need the handler to return + # that specific instance, below (because the host/local data context + # managers reference that instance directly) + self.importer._setup(**kwargs) + with patch.object(self.handler, 'get_importer', Mock(return_value=self.importer)): + result = self.handler.import_data('Product', **kwargs) + if result: + self.result = result['Product'] + else: + self.result = [], [], [] + + def test_invalid_importer_key_is_ignored(self): + handler = handlers.ImportHandler() + self.assertNotIn('InvalidKey', handler.importers) + self.assertEqual(handler.import_data('InvalidKey'), {}) + + def test_create(self): + local = self.copy_data() + del local['32oz'] + with self.host_data(self.sample_data): + with self.local_data(local): + self.import_data() + self.assert_import_created('32oz') + self.assert_import_updated() + self.assert_import_deleted() + + def test_update(self): + local = self.copy_data() + local['16oz']['description'] = "wrong description" + with self.host_data(self.sample_data): + with self.local_data(local): + self.import_data() + self.assert_import_created() + self.assert_import_updated('16oz') + self.assert_import_deleted() + + def test_delete(self): + local = self.copy_data() + local['bogus'] = {'upc': '00000000000000', 'description': "Delete Me"} + with self.host_data(self.sample_data): + with self.local_data(local): + self.import_data() + self.assert_import_created() + self.assert_import_updated() + self.assert_import_deleted('bogus') + + def test_duplicate(self): + host = self.copy_data() + host['32oz-dupe'] = host['32oz'] + with self.host_data(host): + with self.local_data(self.sample_data): + self.import_data() + self.assert_import_created() + self.assert_import_updated() + self.assert_import_deleted() + + def test_max_create(self): + local = self.copy_data() + del local['16oz'] + del local['1gal'] + with self.host_data(self.sample_data): + with self.local_data(local): + self.import_data(max_create=1) + self.assert_import_created('16oz') + self.assert_import_updated() + self.assert_import_deleted() + + def test_max_total_create(self): + local = self.copy_data() + del local['16oz'] + del local['1gal'] + with self.host_data(self.sample_data): + with self.local_data(local): + self.import_data(max_total=1) + self.assert_import_created('16oz') + self.assert_import_updated() + self.assert_import_deleted() + + def test_max_update(self): + local = self.copy_data() + local['16oz']['description'] = "wrong" + local['1gal']['description'] = "wrong" + with self.host_data(self.sample_data): + with self.local_data(local): + self.import_data(max_update=1) + self.assert_import_created() + self.assert_import_updated('16oz') + self.assert_import_deleted() + + def test_max_total_update(self): + local = self.copy_data() + local['16oz']['description'] = "wrong" + local['1gal']['description'] = "wrong" + with self.host_data(self.sample_data): + with self.local_data(local): + self.import_data(max_total=1) + self.assert_import_created() + self.assert_import_updated('16oz') + self.assert_import_deleted() + + def test_max_delete(self): + local = self.copy_data() + local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"} + local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"} + with self.host_data(self.sample_data): + with self.local_data(local): + self.import_data(max_delete=1) + self.assert_import_created() + self.assert_import_updated() + self.assert_import_deleted('bogus1') + + def test_max_total_delete(self): + local = self.copy_data() + local['bogus1'] = {'upc': '00000000000001', 'description': "Delete Me"} + local['bogus2'] = {'upc': '00000000000002', 'description': "Delete Me"} + with self.host_data(self.sample_data): + with self.local_data(local): + self.import_data(max_total=1) + self.assert_import_created() + self.assert_import_updated() + self.assert_import_deleted('bogus1') + + def test_dry_run(self): + local = self.copy_data() + del local['32oz'] + local['16oz']['description'] = "wrong description" + local['bogus'] = {'upc': '00000000000000', 'description': "Delete Me"} + with self.host_data(self.sample_data): + with self.local_data(local): + self.import_data(dry_run=True) + # TODO: maybe need a way to confirm no changes actually made due to dry + # run; currently results still reflect "proposed" changes. this rather + # bogus test is here just for coverage sake + self.assert_import_created('32oz') + self.assert_import_updated('16oz') + self.assert_import_deleted('bogus') + + +Session = orm.sessionmaker() + + +class MockFromSQLAlchemyHandler(handlers.FromSQLAlchemyHandler): + + def make_host_session(self): + return Session() + + +class MockToSQLAlchemyHandler(handlers.ToSQLAlchemyHandler): + + def make_session(self): + return Session() + + +class TestFromSQLAlchemyHandler(TestCase): + + def test_init(self): + handler = handlers.FromSQLAlchemyHandler() + self.assertRaises(NotImplementedError, handler.make_host_session) + + def test_get_importer_kwargs(self): + session = object() + handler = handlers.FromSQLAlchemyHandler(host_session=session) + kwargs = handler.get_importer_kwargs(None) + self.assertEqual(list(kwargs.iterkeys()), ['host_session']) + self.assertIs(kwargs['host_session'], session) + + def test_begin_host_transaction(self): + handler = MockFromSQLAlchemyHandler() + self.assertIsNone(handler.host_session) + handler.begin_host_transaction() + self.assertIsInstance(handler.host_session, orm.Session) + handler.host_session.close() + + def test_commit_host_transaction(self): + # TODO: test actual commit for data changes + session = Session() + handler = handlers.FromSQLAlchemyHandler(host_session=session) + self.assertIs(handler.host_session, session) + handler.commit_host_transaction() + self.assertIsNone(handler.host_session) + + def test_rollback_host_transaction(self): + # TODO: test actual rollback for data changes + session = Session() + handler = handlers.FromSQLAlchemyHandler(host_session=session) + self.assertIs(handler.host_session, session) + handler.rollback_host_transaction() + self.assertIsNone(handler.host_session) + + +class TestToSQLAlchemyHandler(TestCase): + + def test_init(self): + handler = handlers.ToSQLAlchemyHandler() + self.assertRaises(NotImplementedError, handler.make_session) + + def test_get_importer_kwargs(self): + session = object() + handler = handlers.ToSQLAlchemyHandler(session=session) + kwargs = handler.get_importer_kwargs(None) + self.assertEqual(list(kwargs.iterkeys()), ['session']) + self.assertIs(kwargs['session'], session) + + def test_begin_local_transaction(self): + handler = MockToSQLAlchemyHandler() + self.assertIsNone(handler.session) + handler.begin_local_transaction() + self.assertIsInstance(handler.session, orm.Session) + handler.session.close() + + def test_commit_local_transaction(self): + # TODO: test actual commit for data changes + session = Session() + handler = handlers.ToSQLAlchemyHandler(session=session) + self.assertIs(handler.session, session) + handler.commit_local_transaction() + self.assertIsNone(handler.session) + + def test_rollback_local_transaction(self): + # TODO: test actual rollback for data changes + session = Session() + handler = handlers.ToSQLAlchemyHandler(session=session) + self.assertIs(handler.session, session) + handler.rollback_local_transaction() + self.assertIsNone(handler.session)