Files @ 1e2491a5a9fa
Branch filter:

Location: rattail-project/rattail/tests/db/test_model.py

lance
More db sync fixes (tested on Windows).

from unittest import TestCase
from . import DataTestCase
from mock import patch, DEFAULT, Mock, MagicMock

from rattail.db.extension import model
from sqlalchemy import String, Boolean, Numeric
from rattail.db.types import GPCType
from sqlalchemy.exc import IntegrityError
from rattail.db import record_changes


class TestBatch(TestCase):
    """Tests for ``model.Batch`` row-class generation and column-type parsing."""

    @patch('rattail.db.extension.model.object_session')
    def test_rowclass(self, object_session):
        """``Batch.rowclass`` builds, caches and binds a per-batch row class."""
        # The patched ``object_session()`` returns the mock itself, so the same
        # mock acts as the "session" the batch talks to (``.flush``, ``.query``).
        object_session.return_value = object_session

        # ``Batch._rowclasses`` is class-level state shared by every test; empty
        # it afterwards so entries created here cannot leak into other tests.
        # (Looked up lazily in case the attribute is rebound during the test.)
        self.addCleanup(lambda: model.Batch._rowclasses.clear())

        # no row classes to start with
        self.assertEqual(model.Batch._rowclasses, {})

        # basic, empty row class; the batch already has a uuid, so building the
        # row class must not flush the session
        batch = model.Batch(uuid='some_uuid')
        batch.get_sqlalchemy_type = Mock(return_value='some_type')
        batch.columns = MagicMock()
        rowclass = batch.rowclass
        self.assertTrue(issubclass(rowclass, model.BatchRow))
        self.assertEqual(list(model.Batch._rowclasses), ['some_uuid'])
        self.assertIs(model.Batch._rowclasses['some_uuid'], rowclass)
        self.assertFalse(object_session.flush.called)

        # make sure rowclass.batch works
        object_session.query.return_value.get.return_value = batch
        self.assertIs(rowclass().batch, batch)
        object_session.query.return_value.get.assert_called_once_with('some_uuid')

        # row class with generated uuid and some columns; the flush side effect
        # stands in for the session assigning the uuid
        batch = model.Batch(uuid=None)
        batch.columns = [model.BatchColumn(name='F01'), model.BatchColumn(name='F02')]

        def set_uuid():
            batch.uuid = 'fresh_uuid'
        object_session.flush.side_effect = set_uuid

        # Use ``patch.object`` rather than plain assignment so the real
        # ``get_sqlalchemy_type`` is restored when the block exits; assigning a
        # Mock to the class would leave it replaced for all later tests.
        with patch.object(model.Batch, 'get_sqlalchemy_type',
                          Mock(return_value=String(length=20))):
            rowclass = batch.rowclass
        object_session.flush.assert_called_once_with()
        self.assertEqual(sorted(model.Batch._rowclasses), ['fresh_uuid', 'some_uuid'])
        self.assertIs(model.Batch._rowclasses['fresh_uuid'], rowclass)

    def test_get_sqlalchemy_type(self):
        """Column type strings map to the matching SQLAlchemy type objects."""

        # gpc
        self.assertIsInstance(model.Batch.get_sqlalchemy_type('GPC(14)'), GPCType)

        # boolean
        self.assertIsInstance(model.Batch.get_sqlalchemy_type('FLAG(1)'), Boolean)

        # string
        type_ = model.Batch.get_sqlalchemy_type('CHAR(20)')
        self.assertIsInstance(type_, String)
        self.assertEqual(type_.length, 20)

        # numeric
        type_ = model.Batch.get_sqlalchemy_type('NUMBER(9,3)')
        self.assertIsInstance(type_, Numeric)
        self.assertEqual(type_.precision, 9)
        self.assertEqual(type_.scale, 3)

        # invalid: malformed or unknown specs raise AssertionError
        self.assertRaises(AssertionError, model.Batch.get_sqlalchemy_type, 'CHAR(9,3)')
        self.assertRaises(AssertionError, model.Batch.get_sqlalchemy_type, 'OMGWTFBBQ')


class TestCustomer(DataTestCase):
    """Behavior of the ``Customer`` model: repr, unicode and delete cascades."""

    def test_repr(self):
        expected = "Customer(uuid='whatever')"
        self.assertEqual(repr(model.Customer(uuid='whatever')), expected)

    def test_unicode(self):
        # unnamed customer renders as u'None'; a named one renders its name
        cases = (
            ({}, u'None'),
            ({'name': 'Fred'}, u'Fred'),
        )
        for kwargs, expected in cases:
            self.assertEqual(unicode(model.Customer(**kwargs)), expected)

    def test_cascade_delete_assignment(self):
        def assignment_count():
            return self.session.query(model.CustomerGroupAssignment).count()

        customer = model.Customer()
        group = model.CustomerGroup()
        assignment = model.CustomerGroupAssignment(
            customer=customer, group=group, ordinal=1)
        self.session.add_all([customer, assignment])
        self.session.commit()
        self.assertEqual(assignment_count(), 1)

        # deleting the customer must remove its group assignment as well
        self.session.delete(customer)
        self.session.commit()
        self.assertEqual(assignment_count(), 0)


class TestCustomerPerson(DataTestCase):
    """Constraints and ordering of the ``CustomerPerson`` association."""

    def _customer_person_count(self):
        return self.session.query(model.CustomerPerson).count()

    def _assert_link_required(self, assoc, attr, factory):
        """Check that ``assoc`` only commits once ``attr`` is populated.

        The first commit must raise IntegrityError and leave no rows behind;
        after assigning ``factory()`` to ``attr`` the commit must succeed.
        """
        self.session.add(assoc)
        self.assertRaises(IntegrityError, self.session.commit)
        self.session.rollback()
        self.assertEqual(self._customer_person_count(), 0)
        setattr(assoc, attr, factory())
        self.session.add(assoc)
        self.session.commit()
        self.assertEqual(self._customer_person_count(), 1)

    def test_repr(self):
        expected = "CustomerPerson(uuid='whatever')"
        self.assertEqual(repr(model.CustomerPerson(uuid='whatever')), expected)

    def test_customer_required(self):
        self._assert_link_required(
            model.CustomerPerson(person=model.Person()), 'customer', model.Customer)

    def test_person_required(self):
        self._assert_link_required(
            model.CustomerPerson(customer=model.Customer()), 'person', model.Person)

    def test_ordinal_autoincrement(self):
        customer = model.Customer()
        self.session.add(customer)
        # each newly appended person receives the next ordinal, starting at 1
        for expected_ordinal in (1, 2):
            assoc = model.CustomerPerson(person=model.Person())
            customer._people.append(assoc)
            self.session.commit()
            self.assertEqual(assoc.ordinal, expected_ordinal)


class TestCustomerGroupAssignment(DataTestCase):
    """Constraints and ordering of the ``CustomerGroupAssignment`` association."""

    def _group_assignment_count(self):
        return self.session.query(model.CustomerGroupAssignment).count()

    def _assert_link_required(self, assignment, attr, factory):
        """Check that ``assignment`` only commits once ``attr`` is populated.

        The first commit must raise IntegrityError and leave no rows behind;
        after assigning ``factory()`` to ``attr`` the commit must succeed.
        """
        self.session.add(assignment)
        self.assertRaises(IntegrityError, self.session.commit)
        self.session.rollback()
        self.assertEqual(self._group_assignment_count(), 0)
        setattr(assignment, attr, factory())
        self.session.add(assignment)
        self.session.commit()
        self.assertEqual(self._group_assignment_count(), 1)

    def test_repr(self):
        expected = "CustomerGroupAssignment(uuid='whatever')"
        self.assertEqual(repr(model.CustomerGroupAssignment(uuid='whatever')), expected)

    def test_customer_required(self):
        self._assert_link_required(
            model.CustomerGroupAssignment(group=model.CustomerGroup()),
            'customer', model.Customer)

    def test_group_required(self):
        self._assert_link_required(
            model.CustomerGroupAssignment(customer=model.Customer()),
            'group', model.CustomerGroup)

    def test_ordinal_autoincrement(self):
        customer = model.Customer()
        self.session.add(customer)
        # each newly appended assignment receives the next ordinal, starting at 1
        for expected_ordinal in (1, 2):
            assignment = model.CustomerGroupAssignment(group=model.CustomerGroup())
            customer._groups.append(assignment)
            self.session.commit()
            self.assertEqual(assignment.ordinal, expected_ordinal)


class TestCustomerEmailAddress(DataTestCase):
    """Removing a customer's email addresses, with and without change recording."""

    def _assert_row_counts(self, customers, emails):
        query = self.session.query
        self.assertEqual(query(model.Customer).count(), customers)
        self.assertEqual(query(model.CustomerEmailAddress).count(), emails)

    def _pop_all_emails(self):
        """Create a customer with two addresses, then pop every address.

        Row counts are verified after each commit; returns the customer.
        """
        customer = model.Customer()
        for address in ('fred.home@mailinator.com', 'fred.work@mailinator.com'):
            customer.add_email_address(address)
        self.session.add(customer)
        self.session.commit()
        self._assert_row_counts(customers=1, emails=2)

        while customer.emails:
            customer.emails.pop()
        self.session.commit()
        self._assert_row_counts(customers=1, emails=0)
        return customer

    def test_pop(self):
        self._pop_all_emails()

        # record_changes() is not called in this test, so no Change rows are expected
        self.assertEqual(self.session.query(model.Change).count(), 0)

    def test_pop_with_changes(self):
        record_changes(self.session)
        customer = self._pop_all_emails()

        # one change for the customer plus one per email address
        changes = self.session.query(model.Change)
        self.assertEqual(changes.count(), 3)

        customer_change = changes.filter_by(class_name='Customer').one()
        self.assertEqual(customer_change.uuid, customer.uuid)
        self.assertFalse(customer_change.deleted)

        # NOTE(review): the address rows were removed, yet their changes are
        # expected with deleted=False -- confirm this is the intended behavior.
        email_changes = changes.filter_by(class_name='CustomerEmailAddress')
        self.assertEqual(email_changes.count(), 2)
        self.assertEqual([change.deleted for change in email_changes], [False, False])