Files @ 110c68468203
Branch filter:

Location: rattail-project/rattail/tests/db/test_init.py

lance
Update changelog
# -*- coding: utf-8; -*-

from __future__ import unicode_literals, absolute_import

import os
import shutil
import tempfile
from unittest import TestCase

import six
import sqlalchemy as sa

from rattail import db
from rattail.config import RattailConfig


class TestSession(TestCase):

    def test_init_rattail_config(self):
        session = db.Session()
        self.assertIsNone(session.rattail_config)
        session.close()

        config = object()
        session = db.Session(rattail_config=config)
        self.assertIs(session.rattail_config, config)
        session.close()

    def test_init_record_changes(self):
        if hasattr(db.Session, 'kw'):
            self.assertIsNone(db.Session.kw.get('rattail_record_changes'))

        session = db.Session()
        self.assertFalse(session.rattail_record_changes)
        session.close()

        session = db.Session(rattail_record_changes=True)
        self.assertTrue(session.rattail_record_changes)
        session.close()

        engine = sa.create_engine('sqlite://')
        engine.rattail_record_changes = True
        session = db.Session(bind=engine)
        self.assertTrue(session.rattail_record_changes)
        session.close()


class TestConfigExtension(TestCase):

    def setUp(self):
        self.tempdir = tempfile.mkdtemp()

    def tearDown(self):
        db.Session.configure(bind=None, rattail_config=None)
        shutil.rmtree(self.tempdir)

    def write_file(self, fname, content):
        path = os.path.join(self.tempdir, fname)
        with open(path, 'wt') as f:
            f.write(content)
        return path

    def test_configure_empty(self):
        config = RattailConfig()
        self.assertRaises(AttributeError, getattr, config, 'rattail_engines')
        self.assertRaises(AttributeError, getattr, config, 'rattail_engine')
        if hasattr(db.Session, 'kw'):
            self.assertIsNone(db.Session.kw['bind'])
            self.assertIsNone(db.Session.kw['rattail_config'])

        db.ConfigExtension().configure(config)
        self.assertEqual(config.rattail_engines, {})
        self.assertIsNone(config.rattail_engine)
        if hasattr(db.Session, 'kw'):
            self.assertIs(db.Session.kw['rattail_config'], config)

    def test_configure_connections(self):
        default_path = self.write_file('default.sqlite', '')
        default_url = 'sqlite:///{}'.format(default_path)
        host_path = self.write_file('host.sqlite', '')
        host_url = 'sqlite:///{}'.format(host_path)

        config = RattailConfig()
        config.set('rattail.db', 'keys', 'default, host')
        config.set('rattail.db', 'default.url', default_url)
        config.set('rattail.db', 'host.url', host_url)
        db.ConfigExtension().configure(config)
        self.assertEqual(len(config.rattail_engines), 2)
        self.assertEqual(six.text_type(config.rattail_engines['default'].url), default_url)
        self.assertEqual(six.text_type(config.rattail_engines['host'].url), host_url)
        self.assertEqual(six.text_type(config.rattail_engine.url), default_url)