# -*- coding: utf-8; -*-
import os
import shutil
import tempfile
from unittest import TestCase
import sqlalchemy as sa
from rattail import db
from rattail.config import RattailConfig
class TestSession(TestCase):
    """Exercise the extra keyword arguments accepted by ``db.Session``."""

    def test_init_rattail_config(self):
        """A session carries the ``rattail_config`` it was built with.

        With the factory configured for ``rattail_config=None`` a plain
        session reports ``None``; an explicit value passed to the
        constructor is exposed as-is (same object).
        """
        db.Session.configure(rattail_config=None)
        unconfigured = db.Session()
        self.assertIsNone(unconfigured.rattail_config)
        unconfigured.close()

        # any object will do; only identity is checked
        marker = object()
        configured = db.Session(rattail_config=marker)
        self.assertIs(configured.rattail_config, marker)
        configured.close()

    def test_init_record_changes(self):
        """``rattail_record_changes`` is falsy by default and can be enabled.

        It is enabled either by passing the keyword directly or by
        binding to an engine whose ``rattail_record_changes`` attribute
        is true.
        """
        factory = db.Session

        # the factory defaults are only inspected when it exposes ``kw``
        if hasattr(factory, 'kw'):
            self.assertIsNone(factory.kw.get('rattail_record_changes'))

        # no keyword, no bind
        default_session = factory()
        self.assertFalse(default_session.rattail_record_changes)
        default_session.close()

        # explicit keyword
        explicit_session = factory(rattail_record_changes=True)
        self.assertTrue(explicit_session.rattail_record_changes)
        explicit_session.close()

        # flag carried by the bound engine
        flagged_engine = sa.create_engine('sqlite://')
        flagged_engine.rattail_record_changes = True
        bound_session = factory(bind=flagged_engine)
        self.assertTrue(bound_session.rattail_record_changes)
        bound_session.close()
class TestConfigExtension(TestCase):
    """Exercise ``db.ConfigExtension.configure()`` against a ``RattailConfig``."""

    def setUp(self):
        """Create a scratch directory for files written by the tests."""
        self.tempdir = tempfile.mkdtemp()

    def tearDown(self):
        """Reset the ``db.Session`` factory, then remove the scratch directory."""
        db.Session.configure(bind=None, rattail_config=None)
        shutil.rmtree(self.tempdir)

    def write_file(self, fname, content):
        """Write ``content`` to ``fname`` inside the scratch directory.

        Returns the full path of the file that was written.
        """
        target = os.path.join(self.tempdir, fname)
        with open(target, 'wt') as stream:
            stream.write(content)
        return target

    def test_configure_empty(self):
        """Configuring an empty config yields no engines and a ``None`` default."""
        config = RattailConfig()

        # neither attribute exists until the extension has run
        for attr in ('rattail_engines', 'rattail_engine'):
            self.assertRaises(AttributeError, getattr, config, attr)

        # factory defaults are only inspected when it exposes ``kw``
        if hasattr(db.Session, 'kw'):
            self.assertIsNone(db.Session.kw['bind'])
            self.assertIsNone(db.Session.kw['rattail_config'])

        db.ConfigExtension().configure(config)

        self.assertEqual(config.rattail_engines, {})
        self.assertIsNone(config.rattail_engine)
        if hasattr(db.Session, 'kw'):
            self.assertIs(db.Session.kw['rattail_config'], config)

    def test_configure_connections(self):
        """Each configured key gets an engine; ``default`` is the primary one."""
        # one empty file per engine key, addressed by its sqlite URL
        urls = {
            'default': 'sqlite:///{}'.format(
                self.write_file('default.sqlite', '')),
            'host': 'sqlite:///{}'.format(
                self.write_file('host.sqlite', '')),
        }

        config = RattailConfig()
        config.setdefault('rattail.db', 'keys', 'default, host')
        config.setdefault('rattail.db', 'default.url', urls['default'])
        config.setdefault('rattail.db', 'host.url', urls['host'])

        db.ConfigExtension().configure(config)

        self.assertEqual(len(config.rattail_engines), 2)
        for key in ('default', 'host'):
            self.assertEqual(str(config.rattail_engines[key].url), urls[key])
        self.assertEqual(str(config.rattail_engine.url), urls['default'])