diff --git a/tests/db/test_model.py b/tests/db/test_model.py index a5ce87ff6bea4b988d1e3db7df34c75c3c9272b6..a7ec3b7ae007decd79fa42f57bc7ea24072d77dd 100644 --- a/tests/db/test_model.py +++ b/tests/db/test_model.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- from unittest import TestCase from . import DataTestCase @@ -11,6 +12,17 @@ from rattail.db.types import GPCType from rattail.db.changes import record_changes +class SAErrorHelper(object): + + def integrity_or_flush_error(self): + try: + from sqlalchemy.exc import FlushError + except ImportError: + return IntegrityError + else: + return (IntegrityError, FlushError) + + class TestPerson(DataTestCase): def test_default_display_name_is_generated_from_first_and_last_name_if_both_provided(self): @@ -119,7 +131,7 @@ class TestCustomer(DataTestCase): self.assertEqual(self.session.query(model.CustomerGroupAssignment).count(), 0) -class TestCustomerPerson(DataTestCase): +class TestCustomerPerson(DataTestCase, SAErrorHelper): def test_repr(self): assoc = model.CustomerPerson(uuid='whatever') @@ -128,7 +140,7 @@ class TestCustomerPerson(DataTestCase): def test_customer_required(self): assoc = model.CustomerPerson(person=model.Person()) self.session.add(assoc) - self.assertRaises(IntegrityError, self.session.commit) + self.assertRaises(self.integrity_or_flush_error(), self.session.commit) self.session.rollback() self.assertEqual(self.session.query(model.CustomerPerson).count(), 0) assoc.customer = model.Customer() @@ -160,7 +172,7 @@ class TestCustomerPerson(DataTestCase): self.assertEqual(assoc.ordinal, 2) -class TestCustomerGroupAssignment(DataTestCase): +class TestCustomerGroupAssignment(DataTestCase, SAErrorHelper): def test_repr(self): assignment = model.CustomerGroupAssignment(uuid='whatever') @@ -169,7 +181,7 @@ class TestCustomerGroupAssignment(DataTestCase): def test_customer_required(self): assignment = model.CustomerGroupAssignment(group=model.CustomerGroup()) self.session.add(assignment) - self.assertRaises(IntegrityError, self.session.commit) + self.assertRaises(self.integrity_or_flush_error(), self.session.commit) self.session.rollback() self.assertEqual(self.session.query(model.CustomerGroupAssignment).count(), 0) assignment.customer = model.Customer()