Files @ 93889b9694f4
Branch filter:

Location: rattail-project/rattail/tests/importing/lib.py

lance
bump: version 0.18.12 → 0.19.0
# -*- coding: utf-8; -*-

import copy
from contextlib import contextmanager
from unittest.mock import patch

from .. import NullProgress


class ImporterTester(object):
    """
    Mixin for importer test suites.
    """
    handler_class = None
    importer_class = None
    sample_data = {}

    def make_handler(self, **kwargs):
        if 'config' not in kwargs and hasattr(self, 'config'):
            kwargs['config'] = self.config
        return self.handler_class(**kwargs)

    def make_importer(self, **kwargs):
        if 'config' not in kwargs and hasattr(self, 'config'):
            kwargs['config'] = self.config
        kwargs.setdefault('progress', NullProgress)
        return self.importer_class(**kwargs)

    def copy_data(self):
        return copy.deepcopy(self.sample_data)

    @contextmanager
    def host_data(self, data):
        self._host_data = data
        host_data = [self.importer.normalize_host_object(obj) for obj in data.values()]
        with patch.object(self.importer, 'normalize_host_data') as normalize:
            normalize.return_value = host_data
            yield

    @contextmanager
    def local_data(self, data):
        self._local_data = data
        local_data = {}
        for key, obj in data.items():
            normal = self.importer.normalize_local_object(obj)
            local_data[self.importer.get_key(normal)] = {'object': obj, 'data': normal}
        with patch.object(self.importer, 'cache_local_data') as cache:
            cache.return_value = local_data
            yield

    def import_data(self, host_data=None, local_data=None, **kwargs):
        if host_data is None:
            host_data = self.sample_data
        if local_data is None:
            local_data = self.sample_data
        with self.host_data(host_data):
            with self.local_data(local_data):
                self.result = self.importer.import_data(**kwargs)

    def assert_import_created(self, *keys):
        created, updated, deleted = self.result
        self.assertEqual(len(created), len(keys))
        for key in keys:
            key = self.importer.get_key(self._host_data[key])
            found = False
            for local_object, host_data in created:
                if self.importer.get_key(host_data) == key:
                    found = True
                    break
            if not found:
                raise self.failureException("Key {} not created when importing with {}".format(key, self.importer))

    def assert_import_updated(self, *keys):
        created, updated, deleted = self.result
        self.assertEqual(len(updated), len(keys))
        for key in keys:
            key = self.importer.get_key(self._host_data[key])
            found = False
            for local_object, local_data, host_data in updated:
                if self.importer.get_key(local_data) == key:
                    found = True
                    break
            if not found:
                raise self.failureException("Key {} not updated when importing with {}".format(key, self.importer))

    def assert_import_deleted(self, *keys):
        created, updated, deleted = self.result
        self.assertEqual(len(deleted), len(keys))
        for key in keys:
            key = self.importer.get_key(self._local_data[key])
            found = False
            for local_object, local_data in deleted:
                if self.importer.get_key(local_data) == key:
                    found = True
                    break
            if not found:
                raise self.failureException("Key {} not deleted when importing with {}".format(key, self.importer))