diff --git a/rattail/importing/model.py b/rattail/importing/model.py index ee6ed7095b38a5469152a377022fbe3dd1fa4184..ec220cd7d73a76b1ea10b5e41332bb6c8fd6d889 100644 --- a/rattail/importing/model.py +++ b/rattail/importing/model.py @@ -26,7 +26,7 @@ Rattail Model Importers from __future__ import unicode_literals, absolute_import -from rattail.db import model +from rattail.db import model, auth from rattail.importing import ToSQLAlchemy @@ -72,6 +72,38 @@ class UserImporter(ToRattail): model_class = model.User +class AdminUserImporter(UserImporter): + """ + User data importer, plus 'admin' boolean field. + """ + + @property + def supported_fields(self): + return super(AdminUserImporter, self).supported_fields + ['admin'] + + def get_admin(self, session=None): + return auth.administrator_role(session or self.session) + + def normalize_local_object(self, user): + data = super(AdminUserImporter, self).normalize_local_object(user) + if 'admin' in self.fields: + data['admin'] = self.get_admin() in user.roles + return data + + def update_object(self, user, data, local_data=None): + user = super(UserImporter, self).update_object(user, data, local_data) + if user: + if 'admin' in self.fields: + admin = self.get_admin() + if data['admin']: + if admin not in user.roles: + user.roles.append(admin) + else: + if admin in user.roles: + user.roles.remove(admin) + return user + + class MessageImporter(ToRattail): """ User message data importer. diff --git a/rattail/importing/rattail.py b/rattail/importing/rattail.py index 682a036e26cbdf303047ace05f430b186b5bdce3..9a15583a723206dbf63687eb6028da1d6ab1e165 100644 --- a/rattail/importing/rattail.py +++ b/rattail/importing/rattail.py @@ -55,6 +55,7 @@ class FromRattailToRattail(importing.FromSQLAlchemyHandler, importing.ToSQLAlche importers['PersonPhoneNumber'] = PersonPhoneNumberImporter importers['PersonMailingAddress'] = PersonMailingAddressImporter importers['User'] = UserImporter + importers['AdminUser'] = AdminUserImporter importers['Message'] = MessageImporter importers['MessageRecipient'] = MessageRecipientImporter importers['Store'] = StoreImporter @@ -90,6 +91,12 @@ class FromRattailToRattail(importing.FromSQLAlchemyHandler, importing.ToSQLAlche importers['ProductPrice'] = ProductPriceImporter return importers + def get_default_keys(self): + keys = self.get_importer_keys() + if 'AdminUser' in keys: + keys.remove('AdminUser') + return keys + class FromRattail(importing.FromSQLAlchemy): """ @@ -119,6 +126,15 @@ class PersonMailingAddressImporter(FromRattail, importing.model.PersonMailingAdd class UserImporter(FromRattail, importing.model.UserImporter): pass +class AdminUserImporter(FromRattail, importing.model.AdminUserImporter): + + def normalize_host_object(self, user): + data = super(AdminUserImporter, self).normalize_local_object(user) # sic + if 'admin' in self.fields: + data['admin'] = self.get_admin(self.host_session) in user.roles + return data + + class MessageImporter(FromRattail, importing.model.MessageImporter): pass diff --git a/rattail/tests/importing/test_model.py b/rattail/tests/importing/test_model.py new file mode 100644 index 0000000000000000000000000000000000000000..de286b305c08859c7c5e064e56f4cc58cd1f1a7a --- /dev/null +++ b/rattail/tests/importing/test_model.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- + +from __future__ import unicode_literals, absolute_import + +from rattail.db import model, auth +from rattail.importing import model as import_model +from rattail.tests import RattailTestCase + + +class TestAdminUser(RattailTestCase): + + def make_importer(self, **kwargs): + kwargs.setdefault('session', self.session) + return import_model.AdminUserImporter(**kwargs) + + def get_admin(self): + return auth.administrator_role(self.session) + + def test_supported_fields(self): + importer = import_model.UserImporter() + standard_fields = importer.fields + importer = self.make_importer() + extra_fields = set(importer.fields) - set(standard_fields) + self.assertEqual(len(extra_fields), 1) + self.assertEqual(list(extra_fields)[0], 'admin') + + def test_normalize_local_object(self): + importer = self.make_importer() + user = model.User() + data = importer.normalize_local_object(user) + self.assertFalse(data['admin']) + user.roles.append(self.get_admin()) + data = importer.normalize_local_object(user) + self.assertTrue(data['admin']) + + def test_update_object(self): + importer = self.make_importer(fields=['uuid', 'admin']) + data = {'uuid': 'ccb1915419e511e6a3ad3ca9f40bc550'} + user = model.User(**data) + admin = self.get_admin() + self.assertNotIn(admin, user.roles) + + data['admin'] = True + importer.update_object(user, data) + self.assertIn(admin, user.roles) + + data['admin'] = False + importer.update_object(user, data) + self.assertNotIn(admin, user.roles) diff --git a/rattail/tests/importing/test_rattail.py b/rattail/tests/importing/test_rattail.py index 273d35c1977223ef9c08c586edda91b789255efc..a5d85a04d6668ddd24548caf21d91e46a88961c8 100644 --- a/rattail/tests/importing/test_rattail.py +++ b/rattail/tests/importing/test_rattail.py @@ -8,7 +8,7 @@ import sqlalchemy as sa from mock import patch from fixture import TempIO -from rattail.db import model, Session, SessionBase +from rattail.db import model, Session, SessionBase, auth from rattail.importing import rattail as rattail_importing from rattail.tests import RattailMixin, RattailTestCase @@ -112,3 +112,24 @@ class TestFromRattail(DualRattailTestCase): self.assertEqual(data, {}) normalize_local.assert_called_once_with(product) self.assertEqual(data, importer.normalize_local_object(product)) + + +class TestAdminUser(DualRattailTestCase): + + importer_class = rattail_importing.AdminUserImporter + + def make_importer(self, **kwargs): + kwargs.setdefault('session', self.session) + return self.importer_class(**kwargs) + + def get_admin(self): + return auth.administrator_role(self.session) + + def test_normalize_host_object(self): + importer = self.make_importer() + user = model.User() + data = importer.normalize_host_object(user) + self.assertFalse(data['admin']) + user.roles.append(self.get_admin()) + data = importer.normalize_host_object(user) + self.assertTrue(data['admin'])