# -*- coding: utf-8; -*-

import os
import shutil
import tempfile
from unittest import TestCase

import sqlalchemy as sa

from rattail import db
from rattail.config import RattailConfig


class TestSession(TestCase):
    """Tests for the rattail-specific attributes set on new ``db.Session`` objects."""

    def test_init_rattail_config(self):
        # With the factory's default cleared, a fresh session has no config.
        db.Session.configure(rattail_config=None)
        plain = db.Session()
        self.assertIsNone(plain.rattail_config)
        plain.close()

        # A config passed explicitly is stored on the session by identity.
        sentinel = object()
        explicit = db.Session(rattail_config=sentinel)
        self.assertIs(explicit.rattail_config, sentinel)
        explicit.close()

    def test_init_record_changes(self):
        # When the factory exposes its default kwargs, the record-changes
        # option must not be preset there.
        if hasattr(db.Session, 'kw'):
            self.assertIsNone(db.Session.kw.get('rattail_record_changes'))

        # No keyword and no flagged engine: changes are not recorded.
        plain = db.Session()
        self.assertFalse(plain.rattail_record_changes)
        plain.close()

        # An explicit keyword turns recording on.
        explicit = db.Session(rattail_record_changes=True)
        self.assertTrue(explicit.rattail_record_changes)
        explicit.close()

        # A flag set on the engine the session is bound to also turns it on.
        memory_engine = sa.create_engine('sqlite://')
        memory_engine.rattail_record_changes = True
        bound = db.Session(bind=memory_engine)
        self.assertTrue(bound.rattail_record_changes)
        bound.close()


class TestConfigExtension(TestCase):
    """Tests for ``db.ConfigExtension().configure()`` against a ``RattailConfig``."""

    def setUp(self):
        self.tempdir = tempfile.mkdtemp()

    def tearDown(self):
        # configure() mutates the global Session factory; put it back first,
        # then discard the scratch directory.
        db.Session.configure(bind=None, rattail_config=None)
        shutil.rmtree(self.tempdir)

    def write_file(self, fname, content):
        """Write *content* to *fname* inside the scratch dir; return the full path."""
        target = os.path.join(self.tempdir, fname)
        with open(target, 'wt') as outfile:
            outfile.write(content)
        return target

    def test_configure_empty(self):
        config = RattailConfig()

        # Before configure(), neither engine attribute exists on the config.
        for attr in ('rattail_engines', 'rattail_engine'):
            self.assertRaises(AttributeError, getattr, config, attr)
        if hasattr(db.Session, 'kw'):
            self.assertIsNone(db.Session.kw['bind'])
            self.assertIsNone(db.Session.kw['rattail_config'])

        db.ConfigExtension().configure(config)

        # With no db settings, there are no engines and no default engine,
        # but the config itself is handed to the Session factory.
        self.assertEqual(config.rattail_engines, {})
        self.assertIsNone(config.rattail_engine)
        if hasattr(db.Session, 'kw'):
            self.assertIs(db.Session.kw['rattail_config'], config)

    def test_configure_connections(self):
        keys = ('default', 'host')

        # One empty SQLite file per key; remember the URL pointing at each.
        urls = {}
        for key in keys:
            sqlite_path = self.write_file(key + '.sqlite', '')
            urls[key] = 'sqlite:///{}'.format(sqlite_path)

        config = RattailConfig()
        config.setdefault('rattail.db', 'keys', 'default, host')
        for key in keys:
            config.setdefault('rattail.db', key + '.url', urls[key])

        db.ConfigExtension().configure(config)

        # One engine per key, each with its configured URL; the default
        # engine is the one registered under 'default'.
        self.assertEqual(len(config.rattail_engines), 2)
        for key in keys:
            self.assertEqual(str(config.rattail_engines[key].url), urls[key])
        self.assertEqual(str(config.rattail_engine.url), urls['default'])