Files @ 110c68468203
Branch filter:

Location: rattail-project/rattail/tests/importing/test_sqlalchemy.py

lance
Update changelog
# -*- coding: utf-8; -*-

from __future__ import unicode_literals, absolute_import

from unittest import TestCase

import six
from mock import patch, Mock

import sqlalchemy as sa
from sqlalchemy import orm
from sqlalchemy.orm.exc import MultipleResultsFound

from rattail.importing import sqlalchemy as saimport
from rattail.config import make_config


class Widget(object):
    """
    Plain model class used by the tests in this module.

    It declares no attributes of its own; it is mapped onto the ``widgets``
    table further down in this module, and the tests assign ``id`` and
    ``description`` on its instances.
    """

# MetaData instance which holds the test table declared below; the
# TestToSQLAlchemy setup uses it to create the schema in its database.
metadata = sa.MetaData()

# Minimal table: an integer primary key plus a short description column.
widget_table = sa.Table('widgets', metadata,
                        sa.Column('id', sa.Integer(), primary_key=True),
                        sa.Column('description', sa.String(length=50)))

# Map the plain Widget class onto the table; the mapper object is kept so
# that tests can compare it against what the importer reports.
widget_mapper = orm.mapper(Widget, widget_table)

# Fixture rows loaded by TestToSQLAlchemy.setUp().  Note that ids 2 and 3
# share the same description, which the "multiple widgets found" check in
# test_get_single_local_object relies upon.
WIDGETS = [
    {'id': 1, 'description': "Main Widget"},
    {'id': 2, 'description': "Other Widget"},
    {'id': 3, 'description': "Other Widget"},
]


class TestFromSQLAlchemy(TestCase):
    """
    Tests for the ``FromSQLAlchemy`` importer base class.
    """

    def setUp(self):
        self.config = self.make_config()

    def make_config(self):
        """
        Return a fresh config object, built from no files, with a default
        timezone setting in place.
        """
        config = make_config([], extend=False)
        config.setdefault('rattail', 'timezone.default', 'America/Chicago')
        return config

    def test_query(self):
        """
        The importer's query should render as a plain SELECT of every
        column of the host model class's table.
        """
        engine = sa.create_engine('sqlite://')
        # release the engine's connection pool once the test is done
        self.addCleanup(engine.dispose)
        Session = orm.sessionmaker(bind=engine)
        session = Session()
        # registered after the engine cleanup, so (LIFO) the session is
        # closed before the engine is disposed
        self.addCleanup(session.close)

        importer = saimport.FromSQLAlchemy(self.config,
                                           host_session=session,
                                           host_model_class=Widget)
        self.assertEqual(six.text_type(importer.query()),
                         "SELECT widgets.id AS widgets_id, widgets.description AS widgets_description \n"
                         "FROM widgets")


class TestToSQLAlchemy(TestCase):
    """
    Tests for the ``ToSQLAlchemy`` importer base class.

    Each test runs against its own in-memory SQLite database, which is
    pre-loaded with one ``Widget`` row per entry in ``WIDGETS``.
    """

    def setUp(self):
        self.config = self.make_config()

        engine = sa.create_engine('sqlite://')
        # cleanups also run when setUp() fails partway (tearDown would not),
        # so nothing is left open if loading the fixture data blows up
        self.addCleanup(engine.dispose)
        metadata.create_all(bind=engine)

        Session = orm.sessionmaker(bind=engine)
        self.session = Session()
        # registered last, hence executed first: the session is closed
        # before the engine underneath it is disposed
        self.addCleanup(self.session.close)

        # load the fixture rows
        for data in WIDGETS:
            widget = Widget()
            for key, value in data.items():
                setattr(widget, key, value)
            self.session.add(widget)
        self.session.commit()

    def make_config(self):
        """
        Return a fresh config object, built from no files, with a default
        timezone setting in place.
        """
        config = make_config([], extend=False)
        config.setdefault('rattail', 'timezone.default', 'America/Chicago')
        return config

    def make_importer(self, **kwargs):
        """
        Return a ``ToSQLAlchemy`` importer.  The test config and the
        ``Widget`` model class are supplied unless the caller overrides them.
        """
        kwargs.setdefault('config', self.config)
        kwargs.setdefault('model_class', Widget)
        return saimport.ToSQLAlchemy(**kwargs)

    def test_model_mapper(self):
        """The importer should expose the mapper of its model class."""
        importer = self.make_importer()
        self.assertIs(importer.model_mapper, widget_mapper)

    def test_model_table(self):
        """The importer should expose the table of its model class."""
        importer = self.make_importer()
        self.assertIs(importer.model_table, widget_table)

    def test_simple_fields(self):
        """Simple fields should be the table's columns, in order."""
        importer = self.make_importer()
        self.assertEqual(importer.simple_fields, ['id', 'description'])

    def test_get_single_local_object(self):
        """
        Lookup by key should find exactly one object, return ``None`` when
        nothing matches, and raise when the key is ambiguous.
        """

        # simple key
        importer = self.make_importer(key='id', session=self.session)
        widget = importer.get_single_local_object((1,))
        self.assertEqual(widget.id, 1)
        self.assertEqual(widget.description, "Main Widget")

        # compound key
        importer = self.make_importer(key=('id', 'description'), session=self.session)
        widget = importer.get_single_local_object((1, "Main Widget"))
        self.assertEqual(widget.id, 1)
        self.assertEqual(widget.description, "Main Widget")

        # widget not found
        importer = self.make_importer(key='id', session=self.session)
        self.assertIsNone(importer.get_single_local_object((42,)))

        # multiple widgets found (two fixture rows share this description)
        importer = self.make_importer(key='description', session=self.session)
        self.assertRaises(MultipleResultsFound, importer.get_single_local_object, ("Other Widget",))

    def test_create_object(self):
        """
        A created object should be pending in the session and carry the
        data it was created from.
        """
        importer = self.make_importer(key='id', session=self.session)
        widget = importer.create_object((42,), {'id': 42, 'description': "Latest Widget"})
        self.assertIn(widget, self.session)
        self.assertIn(widget, self.session.new)
        self.assertEqual(widget.id, 42)
        self.assertEqual(widget.description, "Latest Widget")

    def test_delete_object(self):
        """
        A deleted object should be marked for deletion in the session, and
        be gone from the database once the session is flushed.
        """
        widget = self.session.query(Widget).get(1)
        self.assertIn(widget, self.session)
        importer = self.make_importer(session=self.session)
        self.assertTrue(importer.delete_object(widget))
        # still attached to the session until the flush, but flagged deleted
        self.assertIn(widget, self.session)
        self.assertIn(widget, self.session.deleted)
        self.session.flush()
        self.assertIsNone(self.session.query(Widget).get(1))

    def test_cache_model(self):
        """
        The model cache should hold one ``Widget`` per fixture row, keyed
        by the bare ``id`` value.
        """
        importer = self.make_importer(key='id', session=self.session)
        cache = importer.cache_model(Widget, key='id')
        self.assertEqual(len(cache), len(WIDGETS))
        for expected in WIDGETS:
            self.assertIn(expected['id'], cache)
            widget = cache[expected['id']]
            self.assertIsInstance(widget, Widget)
            self.assertEqual(widget.id, expected['id'])
            self.assertEqual(widget.description, expected['description'])

    def test_cache_local_data(self):
        """
        The local data cache should hold one entry per fixture row, keyed
        by key tuple, each with the object and its normalized data dict.
        """
        importer = self.make_importer(key='id', session=self.session)
        importer.handler = Mock(local_title="Nevermind")
        cache = importer.cache_local_data()
        self.assertEqual(len(cache), len(WIDGETS))
        for expected in WIDGETS:
            cache_key = (expected['id'],)
            self.assertIn(cache_key, cache)
            cached = cache[cache_key]
            self.assertIsInstance(cached, dict)
            self.assertIsInstance(cached['object'], Widget)
            self.assertEqual(cached['object'].id, expected['id'])
            self.assertEqual(cached['object'].description, expected['description'])
            self.assertIsInstance(cached['data'], dict)
            self.assertEqual(cached['data']['id'], expected['id'])
            self.assertEqual(cached['data']['description'], expected['description'])

    # TODO: lame
    def test_flush_session(self):
        """
        With ``flush_session`` enabled, ``update_object()`` should still
        hand back the very same object it was given.
        """
        importer = self.make_importer(fields=['id'], key='id', session=self.session, flush_session=True)
        original = Widget()
        original.id = 1
        updated = importer.update_object(original, {'id': 1})
        self.assertIs(updated, original)

    def test_flush_create_update(self):
        """
        ``flush_create_update()`` should flush the session exactly once.
        """
        # note: importer is built directly here, without a model class
        importer = saimport.ToSQLAlchemy(self.config, fields=['id'],
                                         session=self.session)
        with patch.object(self.session, 'flush') as flush:
            importer.flush_create_update()
            flush.assert_called_once_with()