Files @ 1e2491a5a9fa
Branch filter:

Location: rattail-project/rattail/tests/db/sync/test_init.py

lance
More db sync fixes (tested on Windows).

import warnings
from unittest import TestCase
from mock import patch, call, Mock, DEFAULT

from sqlalchemy import create_engine
from sqlalchemy.exc import OperationalError, SAWarning

from rattail.db import sync
from rattail.db import Session
from rattail.db import model


class SynchronizerTests(TestCase):
    """
    Tests for ``sync.Synchronizer``.

    Every test gets one "local" engine and two named "remote" engines, all
    created from the ``sqlite://`` URL with the full model schema in place.
    """

    # Message regex for the SAWarning silenced by `_ignore_decimal_warning()`.
    # Every fragment is a raw string so that `\,` and `\.` reach the regex
    # engine verbatim instead of being invalid string escape sequences.
    decimal_warning_message = (
        r'^Dialect sqlite\+pysqlite does \*not\* support Decimal '
        r'objects natively\, and SQLAlchemy must convert from floating '
        r'point - rounding errors and other issues may occur\. Please '
        r'consider storing Decimal numbers as strings or integers on '
        r'this platform for lossless storage\.$')

    def _ignore_decimal_warning(self):
        # Install an 'ignore' filter for the Decimal SAWarning raised from
        # the `sqlalchemy.types` module.  Call this inside an active
        # `warnings.catch_warnings()` block so the filter is removed again
        # when that block exits.
        warnings.filterwarnings(
            'ignore',
            self.decimal_warning_message,
            SAWarning, r'^sqlalchemy\.types$')

    def setUp(self):
        self.local_engine = create_engine('sqlite://')
        self.remote_engines = {
            'one': create_engine('sqlite://'),
            'two': create_engine('sqlite://'),
            }
        model.Base.metadata.create_all(bind=self.local_engine)
        model.Base.metadata.create_all(bind=self.remote_engines['one'])
        model.Base.metadata.create_all(bind=self.remote_engines['two'])

    def test_init(self):
        # The constructor must keep references to the very same engine objects.
        synchronizer = sync.Synchronizer(self.local_engine, self.remote_engines)
        self.assertIs(synchronizer.local_engine, self.local_engine)
        self.assertIs(synchronizer.remote_engines, self.remote_engines)

    def test_loop(self):

        # Stand-in error type: it deliberately skips OperationalError's own
        # constructor and only carries the `connection_invalidated` flag.
        class FakeOperationalError(OperationalError):
            def __init__(self, connection_invalidated):
                self.connection_invalidated = connection_invalidated

        synchronizer = sync.Synchronizer(self.local_engine, self.remote_engines)
        with patch.object(synchronizer, 'sleep') as sleep:
            with patch.object(synchronizer, 'synchronize') as synchronize:

                # Scripted outcomes: three normal returns, an error flagged as
                # an invalidated connection, three more normal returns, and
                # finally an error which is not flagged that way.
                synchronize.side_effect = [1, 2, 3, FakeOperationalError(True),
                                                   5, 6, 7, FakeOperationalError(False)]

                # The loop only ends by raising, after all 8 scripted calls.
                self.assertRaises(FakeOperationalError, synchronizer.loop)
                self.assertEqual(synchronize.call_count, 8)

                # Exact sequence of sleep intervals expected from the loop.
                self.assertEqual(sleep.call_args_list, [
                        call(3), call(3), call(3), call(5), call(3),
                        call(3), call(3), call(3)])

    def test_synchronize(self):
        synchronizer = sync.Synchronizer(self.local_engine, self.remote_engines)

        with patch.object(synchronizer, 'synchronize_changes') as synchronize_changes:

            # no changes
            synchronizer.synchronize()
            self.assertFalse(synchronize_changes.called)

            # some changes
            local_session = Session(bind=self.local_engine)
            # release the session even if an assertion below fails
            self.addCleanup(local_session.close)
            product = model.Product()
            local_session.add(product)
            local_session.flush()
            local_session.add(model.Change(class_name='Product', uuid=product.uuid, deleted=False))
            product = model.Product()
            local_session.add(product)
            local_session.flush()
            local_session.add(model.Change(class_name='Product', uuid=product.uuid, deleted=False))
            local_session.commit()
            synchronizer.synchronize()
            self.assertEqual(synchronize_changes.call_count, 1)
            # call_args is a tuple of (args, kwargs) - first element of args should be our 2 changes
            self.assertEqual(len(synchronize_changes.call_args[0][0]), 2)
            self.assertIsInstance(synchronize_changes.call_args[0][0][0], model.Change)

    def test_synchronize_changes(self):
        synchronizer = sync.Synchronizer(self.local_engine, self.remote_engines)

        local_session = Session(bind=self.local_engine)
        remote_sessions = {
            'one': Session(bind=self.remote_engines['one']),
            'two': Session(bind=self.remote_engines['two']),
            }

        # release all sessions once the test is over, pass or fail
        self.addCleanup(local_session.close)
        for remote_session in remote_sessions.values():
            self.addCleanup(remote_session.close)

        # no changes; nothing should happen but make sure nothing blows up also
        local_changes = []
        synchronizer.synchronize_changes(local_changes, local_session, remote_sessions)

        # add a product, with change
        product = model.Product()
        local_session.add(product)
        local_session.flush()
        change = model.Change(class_name='Product', uuid=product.uuid, deleted=False)
        local_session.add(change)
        local_session.flush()
        self.assertEqual(local_session.query(model.Product).count(), 1)
        self.assertEqual(local_session.query(model.Change).count(), 1)

        # remote sessions don't have the product yet
        self.assertEqual(remote_sessions['one'].query(model.Product).count(), 0)
        self.assertEqual(remote_sessions['two'].query(model.Product).count(), 0)

        # sync the change; the change record itself should be consumed
        synchronizer.synchronize_changes([change], local_session, remote_sessions)
        self.assertEqual(local_session.query(model.Product).count(), 1)
        self.assertEqual(local_session.query(model.Change).count(), 0)

        # remote session 'one' has the product
        self.assertEqual(remote_sessions['one'].query(model.Product).count(), 1)
        remote_product_1 = remote_sessions['one'].query(model.Product).one()
        self.assertEqual(remote_product_1.uuid, product.uuid)

        # remote session 'two' has the product
        self.assertEqual(remote_sessions['two'].query(model.Product).count(), 1)
        remote_product_2 = remote_sessions['two'].query(model.Product).one()
        self.assertEqual(remote_product_2.uuid, product.uuid)

        # delete the product (new change)
        local_session.delete(product)
        change = model.Change(class_name='Product', uuid=product.uuid, deleted=True)
        local_session.add(change)
        local_session.flush()
        self.assertEqual(local_session.query(model.Product).count(), 0)
        self.assertEqual(local_session.query(model.Change).count(), 1)

        # sync the change
        synchronizer.synchronize_changes([change], local_session, remote_sessions)
        self.assertEqual(local_session.query(model.Change).count(), 0)

        # remote sessions no longer have the product
        self.assertEqual(remote_sessions['one'].query(model.Product).count(), 0)
        self.assertEqual(remote_sessions['two'].query(model.Product).count(), 0)

    def test_merge_instance(self):

        class FakeClass(object):
            pass

        synchronizer = sync.Synchronizer(self.local_engine, self.remote_engines)
        session = Mock()
        instance = FakeClass()

        # without a class-specific `merge_FakeClass`, the session's own
        # merge() is used
        self.assertFalse(hasattr(synchronizer, 'merge_FakeClass'))
        synchronizer.merge_instance(session, instance)
        session.merge.assert_called_once_with(instance)

        # once `merge_FakeClass` exists, it is called with (session, instance)
        synchronizer.merge_FakeClass = Mock()
        synchronizer.merge_instance(session, instance)
        synchronizer.merge_FakeClass.assert_called_once_with(session, instance)

    def test_merge_Product(self):
        synchronizer = sync.Synchronizer(self.local_engine, self.remote_engines)

        with warnings.catch_warnings():
            self._ignore_decimal_warning()

            # no prices
            local_session = Session(bind=self.local_engine)
            remote_session = Session(bind=self.remote_engines['one'])
            source_product = model.Product()
            local_session.add(source_product)
            local_session.flush()
            self.assertIsNone(source_product.regular_price_uuid)
            self.assertIsNone(source_product.regular_price)
            self.assertIsNone(source_product.current_price_uuid)
            self.assertIsNone(source_product.current_price)
            target_product = synchronizer.merge_Product(remote_session, source_product)
            # the merge result is a distinct object carrying the same uuid
            self.assertIsNotNone(target_product)
            self.assertIsNot(source_product, target_product)
            self.assertEqual(source_product.uuid, target_product.uuid)
            self.assertIsNone(target_product.regular_price_uuid)
            self.assertIsNone(target_product.regular_price)
            self.assertIsNone(target_product.current_price_uuid)
            self.assertIsNone(target_product.current_price)
            # discard this scenario's data before starting the next one
            local_session.rollback()
            local_session.close()
            remote_session.rollback()
            remote_session.close()

            # regular price
            local_session = Session(bind=self.local_engine)
            remote_session = Session(bind=self.remote_engines['one'])
            source_product = model.Product()
            regular_price = model.ProductPrice()
            source_product.prices.append(regular_price)
            source_product.regular_price = regular_price
            local_session.add(source_product)
            local_session.flush()
            self.assertIsNotNone(source_product.regular_price_uuid)
            self.assertIsNotNone(source_product.regular_price)
            target_product = synchronizer.merge_Product(remote_session, source_product)
            self.assertEqual(target_product.regular_price_uuid, source_product.regular_price_uuid)
            self.assertIsNotNone(target_product.regular_price)
            local_session.rollback()
            local_session.close()
            remote_session.rollback()
            remote_session.close()

            # current price
            local_session = Session(bind=self.local_engine)
            remote_session = Session(bind=self.remote_engines['one'])
            source_product = model.Product()
            current_price = model.ProductPrice()
            source_product.prices.append(current_price)
            source_product.current_price = current_price
            local_session.add(source_product)
            local_session.flush()
            self.assertIsNotNone(source_product.current_price_uuid)
            self.assertIsNotNone(source_product.current_price)
            target_product = synchronizer.merge_Product(remote_session, source_product)
            self.assertEqual(target_product.current_price_uuid, source_product.current_price_uuid)
            self.assertIsNotNone(target_product.current_price)
            local_session.rollback()
            local_session.close()
            remote_session.rollback()
            remote_session.close()

    def test_delete_instance(self):

        class FakeClass(object):
            pass

        synchronizer = sync.Synchronizer(self.local_engine, self.remote_engines)
        session = Mock()
        instance = FakeClass()

        # without a class-specific `delete_FakeClass`, the session's own
        # delete() is used
        self.assertFalse(hasattr(synchronizer, 'delete_FakeClass'))
        synchronizer.delete_instance(session, instance)
        session.delete.assert_called_once_with(instance)

        # once `delete_FakeClass` exists, it is called with (session, instance)
        synchronizer.delete_FakeClass = Mock()
        synchronizer.delete_instance(session, instance)
        synchronizer.delete_FakeClass.assert_called_once_with(session, instance)

    def test_delete_Department(self):
        synchronizer = sync.Synchronizer(self.local_engine, self.remote_engines)

        with warnings.catch_warnings():
            self._ignore_decimal_warning()

            # a subdepartment survives the delete, but loses its department
            session = Session(bind=self.local_engine)
            department = model.Department()
            department.subdepartments.append(model.Subdepartment())
            session.add(department)
            session.flush()
            self.assertEqual(session.query(model.Subdepartment).count(), 1)
            subdepartment = session.query(model.Subdepartment).one()
            self.assertEqual(subdepartment.department_uuid, department.uuid)
            synchronizer.delete_Department(session, department)
            self.assertEqual(session.query(model.Subdepartment).count(), 1)
            subdepartment = session.query(model.Subdepartment).one()
            self.assertIsNone(subdepartment.department_uuid)
            session.rollback()
            session.close()

            # a product likewise loses its department reference
            session = Session(bind=self.local_engine)
            department = model.Department()
            product = model.Product(department=department)
            session.add(product)
            session.flush()
            product = session.query(model.Product).one()
            self.assertEqual(product.department_uuid, department.uuid)
            synchronizer.delete_Department(session, department)
            product = session.query(model.Product).one()
            self.assertIsNone(product.department_uuid)
            session.rollback()
            session.close()

    def test_delete_Subdepartment(self):
        synchronizer = sync.Synchronizer(self.local_engine, self.remote_engines)

        # deleting the subdepartment must clear the product's reference to it
        session = Session(bind=self.local_engine)
        subdepartment = model.Subdepartment()
        product = model.Product(subdepartment=subdepartment)
        session.add(product)
        session.flush()
        product = session.query(model.Product).one()
        self.assertEqual(product.subdepartment_uuid, subdepartment.uuid)
        synchronizer.delete_Subdepartment(session, subdepartment)
        product = session.query(model.Product).one()
        self.assertIsNone(product.subdepartment_uuid)
        session.rollback()
        session.close()

    def test_delete_Family(self):
        synchronizer = sync.Synchronizer(self.local_engine, self.remote_engines)

        # deleting the family must clear the product's reference to it
        session = Session(bind=self.local_engine)
        family = model.Family()
        product = model.Product(family=family)
        session.add(product)
        session.flush()
        product = session.query(model.Product).one()
        self.assertEqual(product.family_uuid, family.uuid)
        synchronizer.delete_Family(session, family)
        product = session.query(model.Product).one()
        self.assertIsNone(product.family_uuid)
        session.rollback()
        session.close()

    def test_delete_Vendor(self):
        synchronizer = sync.Synchronizer(self.local_engine, self.remote_engines)

        # deleting the vendor must also remove the cost record which uses it
        session = Session(bind=self.local_engine)
        vendor = model.Vendor()
        product = model.Product()
        product.costs.append(model.ProductCost(vendor=vendor))
        session.add(product)
        session.flush()
        cost = session.query(model.ProductCost).one()
        self.assertEqual(cost.vendor_uuid, vendor.uuid)
        synchronizer.delete_Vendor(session, vendor)
        self.assertEqual(session.query(model.ProductCost).count(), 0)
        session.rollback()
        session.close()

    def test_delete_CustomerGroup(self):
        synchronizer = sync.Synchronizer(self.local_engine, self.remote_engines)

        # deleting the group must also remove the customer's assignment to it
        session = Session(bind=self.local_engine)
        group = model.CustomerGroup()
        customer = model.Customer()
        customer.groups.append(group)
        session.add(customer)
        session.flush()
        assignment = session.query(model.CustomerGroupAssignment).one()
        self.assertEqual(assignment.customer_uuid, customer.uuid)
        self.assertEqual(assignment.group_uuid, group.uuid)
        synchronizer.delete_CustomerGroup(session, group)
        self.assertEqual(session.query(model.CustomerGroupAssignment).count(), 0)
        session.rollback()
        session.close()


class ModuleTests(TestCase):
    """
    Tests for the module-level functions of ``rattail.db.sync``.
    """

    @patch('rattail.db.sync.edbob')
    def test_get_sync_engines(self, edbob):
        config_get = edbob.config.get

        # With no config value there is nothing to return.
        config_get.return_value = None
        result = sync.get_sync_engines()
        self.assertIsNone(result)

        # Three engines are known, but the config value names only two of
        # them; just those two should come back.
        edbob.engines = {'one': 'first', 'two': 'second', 'three': 'third'}
        config_get.return_value = 'one, two'
        result = sync.get_sync_engines()
        self.assertEqual(result, {'one': 'first', 'two': 'second'})

    def test_dependency_sort(self):
        # Each row is (first name, second name, expected comparison result).
        expectations = [
            # Product depends on Department, so Department sorts first.
            ('Department', 'Product', -1),
            ('Product', 'Department', 1),
            # Store has no dependencies while Product does, so Store sorts
            # first even though the two are not directly connected.
            ('Store', 'Product', -1),
            ('Product', 'Store', 1),
            # Some pairs cannot be ordered at all.
            ('Store', 'CustomerGroup', 0),
            ('Product', 'Product', 0),
        ]
        for first, second, outcome in expectations:
            self.assertEqual(sync.dependency_sort(first, second), outcome)

    @patch.multiple('rattail.db.sync', edbob=DEFAULT, Synchronizer=DEFAULT)
    def test_synchronize_changes(self, edbob, Synchronizer):
        local, remotes = Mock(), Mock()
        config_get = edbob.config.get

        # No config value: the stock Synchronizer is built from both engine
        # arguments and its loop() is started.
        config_get.return_value = None
        sync.synchronize_changes(local, remotes)
        Synchronizer.assert_called_once_with(local, remotes)
        stock_instance = Synchronizer.return_value
        stock_instance.loop.assert_called_once_with()

        # A config value is present: whatever edbob.load_spec() returns is
        # built from both engine arguments instead, and its loop() is started.
        config_get.return_value = 'some_class'
        sync.synchronize_changes(local, remotes)
        custom_class = edbob.load_spec.return_value
        custom_class.assert_called_once_with(local, remotes)
        custom_instance = custom_class.return_value
        custom_instance.loop.assert_called_once_with()