diff --git a/tests/db/test_model.py b/tests/db/test_model.py index e1998cbc015962414a4c26ebba1d8a44a9827059..e30e74bfd641c3a037324f8f5dc07d9430d7ebc1 100644 --- a/tests/db/test_model.py +++ b/tests/db/test_model.py @@ -1,11 +1,74 @@ from unittest import TestCase from . import DataTestCase +from mock import patch, DEFAULT, Mock, MagicMock from rattail.db.extension import model +from sqlalchemy import String, Boolean, Numeric +from rattail.db.types import GPCType from sqlalchemy.exc import IntegrityError +class TestBatch(TestCase): + + @patch('rattail.db.extension.model.object_session') + def test_rowclass(self, object_session): + object_session.return_value = object_session + + # no row classes to start with + self.assertEqual(model.Batch._rowclasses, {}) + + # basic, empty row class + batch = model.Batch(uuid='some_uuid') + batch.get_sqlalchemy_type = Mock(return_value='some_type') + batch.columns = MagicMock() + rowclass = batch.rowclass + self.assertTrue(issubclass(rowclass, model.BatchRow)) + self.assertEqual(model.Batch._rowclasses.keys(), ['some_uuid']) + self.assertIs(model.Batch._rowclasses['some_uuid'], rowclass) + self.assertFalse(object_session.flush.called) + + # make sure rowclass.batch works + object_session.query.return_value.get.return_value = batch + self.assertIs(rowclass().batch, batch) + object_session.query.return_value.get.assert_called_once_with('some_uuid') + + # row class with generated uuid and some columns + batch = model.Batch(uuid=None) + batch.columns = [model.BatchColumn(name='F01'), model.BatchColumn(name='F02')] + model.Batch.get_sqlalchemy_type = Mock(return_value=String(length=20)) + def set_uuid(): + batch.uuid = 'fresh_uuid' + object_session.flush.side_effect = set_uuid + rowclass = batch.rowclass + object_session.flush.assert_called_once_with() + self.assertItemsEqual(model.Batch._rowclasses.keys(), ['some_uuid', 'fresh_uuid']) + self.assertIs(model.Batch._rowclasses['fresh_uuid'], rowclass) + + def test_get_sqlalchemy_type(self): + + # gpc + self.assertIsInstance(model.Batch.get_sqlalchemy_type('GPC(14)'), GPCType) + + # boolean + self.assertIsInstance(model.Batch.get_sqlalchemy_type('FLAG(1)'), Boolean) + + # string + type_ = model.Batch.get_sqlalchemy_type('CHAR(20)') + self.assertIsInstance(type_, String) + self.assertEqual(type_.length, 20) + + # numeric + type_ = model.Batch.get_sqlalchemy_type('NUMBER(9,3)') + self.assertIsInstance(type_, Numeric) + self.assertEqual(type_.precision, 9) + self.assertEqual(type_.scale, 3) + + # invalid + self.assertRaises(AssertionError, model.Batch.get_sqlalchemy_type, 'CHAR(9,3)') + self.assertRaises(AssertionError, model.Batch.get_sqlalchemy_type, 'OMGWTFBBQ') + + class TestCustomer(DataTestCase): def test_repr(self):