Files @ c44ddd3740e8
Branch filter:

Location: rattail-project/rattail/tests/db/test_model.py

lance
Overhauled some database stuff; added tests.

from unittest import TestCase
from . import DataTestCase
from mock import patch, DEFAULT, Mock, MagicMock

from rattail.db.extension import model
from sqlalchemy import String, Boolean, Numeric
from rattail.db.types import GPCType
from sqlalchemy.exc import IntegrityError


class TestBatch(TestCase):
    """Unit tests for ``model.Batch`` that run without a real database.

    ``Batch`` keeps generated row classes in the class-level ``_rowclasses``
    dict, and ``test_rowclass`` has to stub the class-level
    ``get_sqlalchemy_type``.  Both are put back when the test finishes so
    that neither the cache entries nor the stub leak into other tests
    (including ``test_get_sqlalchemy_type``, which needs the real method).
    """

    @patch('rattail.db.extension.model.object_session')
    def test_rowclass(self, object_session):
        """``Batch.rowclass`` builds and caches a ``BatchRow`` subclass per uuid."""
        # object_session(batch) returns the mock itself, so flush()/query()
        # calls made through "the session" are recorded on object_session.
        object_session.return_value = object_session

        # Empty the class-level cache again once this test is done; the dict
        # is looked up lazily in case the implementation ever rebinds it.
        self.addCleanup(lambda: model.Batch._rowclasses.clear())

        # no row classes to start with
        self.assertEqual(model.Batch._rowclasses, {})

        # basic, empty row class
        batch = model.Batch(uuid='some_uuid')
        batch.get_sqlalchemy_type = Mock(return_value='some_type')
        batch.columns = MagicMock()
        rowclass = batch.rowclass
        self.assertTrue(issubclass(rowclass, model.BatchRow))
        # list() keeps this comparison valid where keys() is a view, not a list
        self.assertEqual(list(model.Batch._rowclasses), ['some_uuid'])
        self.assertIs(model.Batch._rowclasses['some_uuid'], rowclass)
        # the batch already had a uuid, so no flush is expected here
        self.assertFalse(object_session.flush.called)

        # make sure rowclass.batch works
        object_session.query.return_value.get.return_value = batch
        self.assertIs(rowclass().batch, batch)
        object_session.query.return_value.get.assert_called_once_with('some_uuid')

        # row class with generated uuid and some columns
        batch = model.Batch(uuid=None)
        batch.columns = [model.BatchColumn(name='F01'), model.BatchColumn(name='F02')]
        # patch.object (rather than plain assignment) restores the real
        # get_sqlalchemy_type on the class when this block exits.
        with patch.object(model.Batch, 'get_sqlalchemy_type',
                          Mock(return_value=String(length=20))):

            def set_uuid():
                # simulate the uuid being assigned when the session flushes
                batch.uuid = 'fresh_uuid'
            object_session.flush.side_effect = set_uuid

            rowclass = batch.rowclass
            object_session.flush.assert_called_once_with()
            self.assertItemsEqual(model.Batch._rowclasses.keys(),
                                  ['some_uuid', 'fresh_uuid'])
            self.assertIs(model.Batch._rowclasses['fresh_uuid'], rowclass)

    def test_get_sqlalchemy_type(self):
        """Data-type specs map to the expected SQLAlchemy column types."""

        # gpc
        self.assertIsInstance(model.Batch.get_sqlalchemy_type('GPC(14)'), GPCType)

        # boolean
        self.assertIsInstance(model.Batch.get_sqlalchemy_type('FLAG(1)'), Boolean)

        # string: the parenthesized size becomes the column length
        type_ = model.Batch.get_sqlalchemy_type('CHAR(20)')
        self.assertIsInstance(type_, String)
        self.assertEqual(type_.length, 20)

        # numeric: NUMBER(precision,scale)
        type_ = model.Batch.get_sqlalchemy_type('NUMBER(9,3)')
        self.assertIsInstance(type_, Numeric)
        self.assertEqual(type_.precision, 9)
        self.assertEqual(type_.scale, 3)

        # invalid: malformed or unknown specs are rejected with AssertionError
        self.assertRaises(AssertionError, model.Batch.get_sqlalchemy_type, 'CHAR(9,3)')
        self.assertRaises(AssertionError, model.Batch.get_sqlalchemy_type, 'OMGWTFBBQ')


class TestCustomer(DataTestCase):
    """Tests for ``model.Customer`` using the session provided by ``DataTestCase``."""

    def test_repr(self):
        """repr() shows the class name together with the uuid."""
        expected = "Customer(uuid='whatever')"
        self.assertEqual(repr(model.Customer(uuid='whatever')), expected)

    def test_unicode(self):
        """unicode() renders the customer's name, or u'None' when it has none."""
        cases = (
            ({}, u'None'),
            ({'name': 'Fred'}, u'Fred'),
        )
        for kwargs, rendered in cases:
            self.assertEqual(unicode(model.Customer(**kwargs)), rendered)

    def test_cascade_delete_assignment(self):
        """Deleting a customer also removes its group assignments."""
        session = self.session

        def assignment_count():
            return session.query(model.CustomerGroupAssignment).count()

        customer = model.Customer()
        link = model.CustomerGroupAssignment(
            customer=customer,
            group=model.CustomerGroup(),
            ordinal=1,
        )
        session.add_all([customer, link])
        session.commit()
        self.assertEqual(assignment_count(), 1)

        # removing the customer must take the assignment row with it
        session.delete(customer)
        session.commit()
        self.assertEqual(assignment_count(), 0)


class TestCustomerPerson(DataTestCase):
    """Tests for the ``model.CustomerPerson`` customer/person association."""

    def _assoc_count(self):
        """Return the number of ``CustomerPerson`` rows visible to the session."""
        return self.session.query(model.CustomerPerson).count()

    def _check_required_link(self, assoc, attr, make_value):
        """Assert that *assoc* cannot be committed until *attr* is populated.

        The first commit must raise ``IntegrityError`` and, after rolling
        back, no row may exist.  Once ``attr`` is set to ``make_value()`` the
        same object is re-added and must commit as the only row.
        """
        session = self.session
        session.add(assoc)
        self.assertRaises(IntegrityError, session.commit)
        session.rollback()
        self.assertEqual(self._assoc_count(), 0)

        setattr(assoc, attr, make_value())
        session.add(assoc)
        session.commit()
        self.assertEqual(self._assoc_count(), 1)

    def test_repr(self):
        """repr() shows the class name together with the uuid."""
        expected = "CustomerPerson(uuid='whatever')"
        self.assertEqual(repr(model.CustomerPerson(uuid='whatever')), expected)

    def test_customer_required(self):
        """An association without a customer violates the schema."""
        self._check_required_link(
            model.CustomerPerson(person=model.Person()), 'customer', model.Customer)

    def test_person_required(self):
        """An association without a person violates the schema."""
        self._check_required_link(
            model.CustomerPerson(customer=model.Customer()), 'person', model.Person)

    def test_ordinal_autoincrement(self):
        """People appended to a customer receive ordinals 1, then 2."""
        customer = model.Customer()
        self.session.add(customer)
        for expected_ordinal in (1, 2):
            assoc = model.CustomerPerson(person=model.Person())
            customer._people.append(assoc)
            self.session.commit()
            self.assertEqual(assoc.ordinal, expected_ordinal)


class TestCustomerGroupAssignment(DataTestCase):
    """Tests for the ``model.CustomerGroupAssignment`` customer/group link."""

    def _assignment_count(self):
        """Return the number of ``CustomerGroupAssignment`` rows in the session."""
        return self.session.query(model.CustomerGroupAssignment).count()

    def _check_required_link(self, assignment, attr, make_value):
        """Assert that *assignment* cannot be committed until *attr* is set.

        The first commit must raise ``IntegrityError`` and, after rolling
        back, no row may exist.  Once ``attr`` is set to ``make_value()`` the
        same object is re-added and must commit as the only row.
        """
        session = self.session
        session.add(assignment)
        self.assertRaises(IntegrityError, session.commit)
        session.rollback()
        self.assertEqual(self._assignment_count(), 0)

        setattr(assignment, attr, make_value())
        session.add(assignment)
        session.commit()
        self.assertEqual(self._assignment_count(), 1)

    def test_repr(self):
        """repr() shows the class name together with the uuid."""
        expected = "CustomerGroupAssignment(uuid='whatever')"
        self.assertEqual(repr(model.CustomerGroupAssignment(uuid='whatever')), expected)

    def test_customer_required(self):
        """An assignment without a customer violates the schema."""
        self._check_required_link(
            model.CustomerGroupAssignment(group=model.CustomerGroup()),
            'customer', model.Customer)

    def test_group_required(self):
        """An assignment without a group violates the schema."""
        self._check_required_link(
            model.CustomerGroupAssignment(customer=model.Customer()),
            'group', model.CustomerGroup)

    def test_ordinal_autoincrement(self):
        """Groups appended to a customer receive ordinals 1, then 2."""
        customer = model.Customer()
        self.session.add(customer)
        for expected_ordinal in (1, 2):
            assignment = model.CustomerGroupAssignment(group=model.CustomerGroup())
            customer._groups.append(assignment)
            self.session.commit()
            self.assertEqual(assignment.ordinal, expected_ordinal)