# -*- coding: utf-8 -*-
from __future__ import unicode_literals, absolute_import
import copy
from contextlib import contextmanager
from mock import patch
from rattail.tests import NullProgress
class ImporterTester(object):
    """
    Mixin for importer test suites.

    Meant to be combined with a ``unittest.TestCase``-style class: the
    assertion helpers rely on ``self.assertEqual`` and
    ``self.failureException``.  Subclasses set :attr:`importer_class` (and
    usually :attr:`sample_data`) and assign ``self.importer`` -- e.g. from
    :meth:`make_importer` -- before using the data/assertion helpers.
    """
    # Importer class instantiated by make_importer(); subclasses must set it.
    importer_class = None
    # Sample records keyed by test-chosen keys; deep-copied by copy_data().
    sample_data = {}

    def make_importer(self, **kwargs):
        """
        Instantiate :attr:`importer_class` with the given keyword arguments.

        ``config`` defaults to ``self.config`` when the test case has such an
        attribute, and ``progress`` defaults to ``NullProgress``.
        """
        if 'config' not in kwargs and hasattr(self, 'config'):
            kwargs['config'] = self.config
        kwargs.setdefault('progress', NullProgress)
        return self.importer_class(**kwargs)

    def copy_data(self):
        """Return a deep copy of :attr:`sample_data` that a test may mutate freely."""
        return copy.deepcopy(self.sample_data)

    @contextmanager
    def host_data(self, data):
        """
        Feed ``data`` to ``self.importer`` as its host data within the ``with`` body.

        Every value of ``data`` is normalized up front via the importer's
        ``normalize_host_object()``, and ``normalize_host_data()`` is patched
        to return that list.  ``data`` itself is remembered so that
        :meth:`assert_import_created` / :meth:`assert_import_updated` can look
        records up by key.
        """
        self._host_data = data
        # .values() instead of .itervalues(): works on both Python 2 and 3.
        normalized = [self.importer.normalize_host_object(obj)
                      for obj in data.values()]
        with patch.object(self.importer, 'normalize_host_data') as normalize:
            normalize.return_value = normalized
            yield

    @contextmanager
    def local_data(self, data):
        """
        Feed ``data`` to ``self.importer`` as its cached local data within the ``with`` body.

        Every value of ``data`` is normalized via ``normalize_local_object()``
        and ``cache_local_data()`` is patched to return a dict mapping
        ``get_key(normalized)`` to ``{'object': obj, 'data': normalized}``.
        ``data`` itself is remembered for :meth:`assert_import_deleted`.
        """
        self._local_data = data
        cached = {}
        # Only the values matter here; the dict keys are the test's own labels.
        for obj in data.values():
            normal = self.importer.normalize_local_object(obj)
            cached[self.importer.get_key(normal)] = {'object': obj, 'data': normal}
        with patch.object(self.importer, 'cache_local_data') as cache:
            cache.return_value = cached
            yield

    def import_data(self, **kwargs):
        """
        Run ``self.importer.import_data(**kwargs)`` and store the result in ``self.result``.

        The assertion helpers expect that result to be a
        ``(created, updated, deleted)`` triple.
        """
        self.result = self.importer.import_data(**kwargs)

    def assert_import_created(self, *keys):
        """
        Assert the last import created exactly the host records named by ``keys``.

        ``keys`` index the dict previously passed to :meth:`host_data`; each
        ``created`` entry is a ``(local_object, host_data)`` pair.
        """
        created, _, _ = self.result
        self._assert_import_keys(created, self._host_data, keys, 'created')

    def assert_import_updated(self, *keys):
        """
        Assert the last import updated exactly the host records named by ``keys``.

        ``keys`` index the dict previously passed to :meth:`host_data`; each
        ``updated`` entry is a ``(local_object, local_data, host_data)`` triple
        and is matched on its ``local_data``.
        """
        _, updated, _ = self.result
        self._assert_import_keys(updated, self._host_data, keys, 'updated')

    def assert_import_deleted(self, *keys):
        """
        Assert the last import deleted exactly the local records named by ``keys``.

        ``keys`` index the dict previously passed to :meth:`local_data`; each
        ``deleted`` entry is a ``(local_object, local_data)`` pair.
        """
        _, _, deleted = self.result
        self._assert_import_keys(deleted, self._local_data, keys, 'deleted')

    def _assert_import_keys(self, results, records, keys, action):
        """
        Shared implementation of the ``assert_import_*`` methods.

        Checks that ``results`` has exactly ``len(keys)`` entries and that,
        for every key, ``get_key(records[key])`` equals ``get_key()`` of the
        second element of some entry (host data for created entries, local
        data for updated/deleted ones).  Raises ``self.failureException``
        naming the first key that is missing.
        """
        self.assertEqual(len(results), len(keys))
        # Key every result once instead of rescanning the list for each key.
        actual = [self.importer.get_key(entry[1]) for entry in results]
        for key in keys:
            # NOTE(review): records[key] is the raw sample record, not its
            # normalized form; this assumes get_key() accepts it -- confirm.
            expected = self.importer.get_key(records[key])
            if expected not in actual:
                raise self.failureException(
                    "Key {} not {} when importing with {}".format(
                        expected, action, self.importer))