def record_changes(session, ignore_role_changes=True):
    """
    Arrange for changes made within a session to be recorded.

    A :class:`ChangeRecorder` is created and attached to the session's
    ``before_flush`` step.  The attachment goes through
    ``sqlalchemy.event.listen`` when that module can be imported, and
    through a :class:`ChangeRecorderExtension` otherwise.

    :param session: A :class:`sqlalchemy:sqlalchemy.orm.session.Session`
       class, or an instance thereof.

    :param ignore_role_changes: Whether changes involving roles and role
       membership should be ignored; handed to :class:`ChangeRecorder` as-is.
       The default of ``True`` leaves each database responsible for
       maintaining its own role (and by extension, permissions) data.
    """
    recorder = ChangeRecorder(ignore_role_changes)

    # Probe for the event API; its absence selects the extension fallback.
    try:
        from sqlalchemy.event import listen
    except ImportError: # pragma: no cover
        listen = None

    if listen is not None:
        listen(session, u'before_flush', recorder)
        return

    # No event API available: wrap the recorder in a session extension.
    extension = ChangeRecorderExtension(recorder) # pragma: no cover
    if isinstance(session, Session): # pragma: no cover
        # An actual session instance: add to its list of extensions.
        session.extensions.append(extension)
    else: # pragma: no cover
        # Not an instance: let the object configure itself with the extension.
        session.configure(extension=extension)


class ChangeRecorderExtension(SessionExtension): # pragma: no cover
    """
    Session extension which forwards ``before_flush`` to a change recorder.

    .. note::
       This is only used when ``sqlalchemy.event`` cannot be imported, i.e.
       when the installed SQLAlchemy does not offer the event interfaces.
    """

    def __init__(self, recorder):
        # Callable invoked with the ``before_flush`` arguments.
        self.recorder = recorder

    def before_flush(self, session, flush_context, instances):
        """
        Pass the flush notification through to the wrapped recorder.
        """
        notify = self.recorder
        notify(session, flush_context, instances)
class SAErrorHelper(object):
    """
    Mixin for test cases which need to know what error a failed commit
    may raise under the installed SQLAlchemy.
    """

    def integrity_or_flush_error(self):
        """
        Return the exception class(es) to hand to ``assertRaises``.

        The result is ``IntegrityError`` alone when ``FlushError`` cannot be
        imported from ``sqlalchemy.exc``, and the tuple
        ``(IntegrityError, FlushError)`` when it can.
        """
        # NOTE(review): current SQLAlchemy documents FlushError under
        # ``sqlalchemy.orm.exc``; confirm that probing ``sqlalchemy.exc``
        # is the intended version check.
        expected = IntegrityError
        try:
            from sqlalchemy.exc import FlushError
        except ImportError:
            pass
        else:
            expected = (expected, FlushError)
        return expected
model.CustomerGroupAssignment(uuid='whatever') @@ -169,7 +181,7 @@ class TestCustomerGroupAssignment(DataTestCase): def test_customer_required(self): assignment = model.CustomerGroupAssignment(group=model.CustomerGroup()) self.session.add(assignment) - self.assertRaises(IntegrityError, self.session.commit) + self.assertRaises(self.integrity_or_flush_error(), self.session.commit) self.session.rollback() self.assertEqual(self.session.query(model.CustomerGroupAssignment).count(), 0) assignment.customer = model.Customer()