# -*- coding: utf-8; -*-
import copy
from contextlib import contextmanager
from unittest.mock import patch
from .. import NullProgress
class ImporterTester:
    """
    Mixin providing shared helpers for importer test suites.

    The assertion helpers call ``self.assertEqual`` and raise
    ``self.failureException``, so this is meant to be combined with a
    ``unittest.TestCase``-style class.  Most helpers also read
    ``self.importer``, which this mixin never assigns; the test itself must
    set it before using them.
    """

    # Class instantiated by ``make_handler()``.
    handler_class = None
    # Class instantiated by ``make_importer()``.
    importer_class = None
    # Default data set; ``import_data()`` uses it for whichever side
    # (host and/or local) the caller does not supply.
    sample_data = {}

    def make_handler(self, **kwargs):
        """
        Instantiate ``handler_class`` with the given keyword arguments.

        When the caller passes no ``config`` and the test has a ``config``
        attribute, that attribute is supplied as the ``config`` argument.
        """
        caller_gave_config = 'config' in kwargs
        if not caller_gave_config and hasattr(self, 'config'):
            kwargs['config'] = self.config
        return self.handler_class(**kwargs)

    def make_importer(self, **kwargs):
        """
        Instantiate ``importer_class`` with the given keyword arguments.

        ``config`` falls back to ``self.config`` (when the test has one) and
        ``progress`` falls back to ``NullProgress``.
        """
        caller_gave_config = 'config' in kwargs
        if not caller_gave_config and hasattr(self, 'config'):
            kwargs['config'] = self.config
        kwargs.setdefault('progress', NullProgress)
        return self.importer_class(**kwargs)

    def copy_data(self):
        """
        Return a deep copy of ``sample_data``.
        """
        duplicate = copy.deepcopy(self.sample_data)
        return duplicate

    @contextmanager
    def host_data(self, data):
        """
        Context manager which makes the importer see ``data`` as host data.

        Stores ``data`` as ``self._host_data``, normalizes each of its values
        with the importer's ``normalize_host_object()``, and for the duration
        of the ``with`` block patches the importer's ``normalize_host_data``
        so that it returns the resulting list.
        """
        self._host_data = data
        normalized = []
        for obj in data.values():
            normalized.append(self.importer.normalize_host_object(obj))
        with patch.object(self.importer, 'normalize_host_data',
                          return_value=normalized):
            yield

    @contextmanager
    def local_data(self, data):
        """
        Context manager which makes the importer see ``data`` as local data.

        Stores ``data`` as ``self._local_data`` and builds a cache dict which
        maps each object's importer key to
        ``{'object': <raw object>, 'data': <normalized object>}``.  For the
        duration of the ``with`` block the importer's ``cache_local_data`` is
        patched to return that dict.
        """
        self._local_data = data
        cached = {}
        for obj in data.values():
            normalized = self.importer.normalize_local_object(obj)
            cache_key = self.importer.get_key(normalized)
            cached[cache_key] = {'object': obj, 'data': normalized}
        with patch.object(self.importer, 'cache_local_data',
                          return_value=cached):
            yield

    def import_data(self, host_data=None, local_data=None, **kwargs):
        """
        Run the importer's ``import_data()`` against patched host/local data.

        Either side defaults to ``sample_data`` when given as ``None``.  Any
        extra keyword arguments are passed through to the importer, and its
        return value is stored as ``self.result``.
        """
        host = self.sample_data if host_data is None else host_data
        local = self.sample_data if local_data is None else local_data
        with self.host_data(host), self.local_data(local):
            self.result = self.importer.import_data(**kwargs)

    def assert_import_created(self, *keys):
        """
        Assert that exactly the given host data entries were created.

        ``keys`` index into the host data most recently given to
        ``host_data()``.  ``self.result`` is unpacked as
        ``(created, updated, deleted)``, and each entry of ``created`` as
        ``(local_object, host_data)``.
        """
        created, _updated, _deleted = self.result
        self.assertEqual(len(created), len(keys))
        for name in keys:
            expected = self.importer.get_key(self._host_data[name])
            matched = any(self.importer.get_key(host) == expected
                          for _obj, host in created)
            if not matched:
                raise self.failureException(
                    "Key {} not created when importing with {}".format(
                        expected, self.importer))

    def assert_import_updated(self, *keys):
        """
        Assert that exactly the given host data entries caused updates.

        ``keys`` index into the host data most recently given to
        ``host_data()``.  Each entry of ``updated`` is unpacked as
        ``(local_object, local_data, host_data)`` and is matched on the key
        of its ``local_data``.
        """
        _created, updated, _deleted = self.result
        self.assertEqual(len(updated), len(keys))
        for name in keys:
            expected = self.importer.get_key(self._host_data[name])
            matched = any(self.importer.get_key(local) == expected
                          for _obj, local, _host in updated)
            if not matched:
                raise self.failureException(
                    "Key {} not updated when importing with {}".format(
                        expected, self.importer))

    def assert_import_deleted(self, *keys):
        """
        Assert that exactly the given local data entries were deleted.

        ``keys`` index into the local data most recently given to
        ``local_data()``.  Each entry of ``deleted`` is unpacked as
        ``(local_object, local_data)``.
        """
        _created, _updated, deleted = self.result
        self.assertEqual(len(deleted), len(keys))
        for name in keys:
            expected = self.importer.get_key(self._local_data[name])
            matched = any(self.importer.get_key(local) == expected
                          for _obj, local in deleted)
            if not matched:
                raise self.failureException(
                    "Key {} not deleted when importing with {}".format(
                        expected, self.importer))