Changeset - 22a589773921
[Not reviewed]
0 8 0
Lance Edgar (lance) - 11 years ago 2013-12-18 23:02:44
lance@edbob.org
Removed global `Session` from `rattail.db`.

A Session class may now be had via `get_session_class()`.
8 files changed with 90 insertions and 53 deletions:
0 comments (0 inline, 0 general)
rattail/commands.py
Show inline comments
 
@@ -34,7 +34,6 @@ from edbob import commands
 
from edbob.commands import Subcommand
 

	
 
from ._version import __version__
 
from .db import Session
 
from .db import model
 

	
 

	
 
@@ -76,13 +75,15 @@ class AddUser(Subcommand):
 

	
 
    def run(self, args):
 
        from sqlalchemy import create_engine
 
        from .db import Session
 
        from sqlalchemy.orm import sessionmaker
 
        from .db.model import User
 
        from getpass import getpass
 
        from .db.auth import set_user_password, administrator_role
 

	
 
        engine = create_engine(args.url)
 
        session = Session(bind=engine)
 
        Session = sessionmaker(bind=engine)
 

	
 
        session = Session()
 
        if session.query(User).filter_by(username=args.username).count():
 
            session.close()
 
            print("User '{0}' already exists.".format(args.username))
 
@@ -156,7 +157,7 @@ class Dump(Subcommand):
 
            'model', help="Model whose data will be dumped.")
 

	
 
    def run(self, args):
 
        from .db import Session
 
        from .db import get_session_class
 
        from .db.dump import dump_data
 
        from .console import Progress
 

	
 
@@ -175,6 +176,7 @@ class Dump(Subcommand):
 
        else:
 
            output = sys.stdout
 

	
 
        Session = get_session_class(edbob.config)
 
        session = Session()
 
        dump_data(session, cls, output, progress=progress)
 
        session.close()
 
@@ -435,12 +437,13 @@ class PurgeBatchesCommand(Subcommand):
 
                            help="Purge ALL batches regardless of purge date")
 

	
 
    def run(self, args):
 
        from rattail.db.batches.util import purge_batches
 
        from .db import get_session_class
 
        from .db.batches.util import purge_batches
 

	
 
        edbob.init_modules(['edbob.db', 'rattail.db'])
 
        Session = get_session_class(edbob.config)
 

	
 
        print "Purging batches from database:"
 
        print "    %s" % edbob.engine.url
 
        print "    %s" % Session.kw['bind'].url
 

	
 
        session = Session()
 
        purged = purge_batches(session, purge_everything=args.all)
rattail/db/__init__.py
Show inline comments
 
@@ -26,35 +26,62 @@
 
Database Stuff
 
"""
 

	
 
from edbob.db import Session
 
import warnings
 

	
 
from sqlalchemy import engine_from_config
 
from sqlalchemy.orm import sessionmaker
 

	
 
from .core import *
 
from .changes import *
 

	
 

	
 
def init(config):
 
def get_engines(config):
 
    """
 
    Initialize the Rattail database framework.
 
    Fetch all database engines defined in the given config object.
 

	
 
    :param config: A ``ConfigParser`` instance containing app configuration.
 

	
 
    :returns: A dictionary of SQLAlchemy engine instances, keyed according to
 
       the config settings.
 
    """
 
    keys = config.get('edbob.db', 'keys')
 
    if keys:
 
        keys = keys.split(',')
 
    else:
 
        keys = ['default']
 

	
 
    import edbob
 
    import rattail
 
    engines = {}
 
    cfg = config.get_dict('edbob.db')
 
    for key in keys:
 
        key = key.strip()
 
        try:
 
            engines[key] = engine_from_config(cfg, prefix='{0}.'.format(key))
 
        except KeyError:
 
            if key == 'default':
 
                try:
 
                    engines[key] = engine_from_config(cfg, prefix='sqlalchemy.')
 
                except KeyError:
 
                    pass
 
    return engines
 

	
 
    # Pretend all ``edbob`` models come from Rattail, until that is true.
 
    from edbob.db import Base
 
    names = []
 
    for name in edbob.__all__:
 
        obj = getattr(edbob, name)
 
        if isinstance(obj, type) and issubclass(obj, Base):
 
            names.append(name)
 
    edbob.graft(rattail, edbob, names)
 

	
 
    # Pretend all ``edbob`` enumerations come from Rattail.
 
    from edbob import enum
 
    edbob.graft(rattail, enum)
 
def get_default_engine(config):
 
    """
 
    Fetch the default SQLAlchemy database engine.
 
    """
 
    return get_engines(config).get('default')
 

	
 
    from rattail.db.extension import model
 
    edbob.graft(rattail, model)
 

	
 
def get_session_class(config):
 
    """
 
    Create and configure a database session class using the given config object.
 

	
 
    :returns: A class inheriting from ``sqlalchemy.orm.Session``.
 
    """
 
    from .changes import record_changes
 

	
 
    engine = get_default_engine(config)
 
    Session = sessionmaker(bind=engine)
 

	
 
    ignore_role_changes = config.getboolean(
 
        'rattail.db', 'changes.ignore_roles', default=True)
 
@@ -63,8 +90,14 @@ def init(config):
 
        record_changes(Session, ignore_role_changes)
 

	
 
    elif config.getboolean('rattail.db', 'record_changes'):
 
        import warnings
 
        warnings.warn("Config setting 'record_changes' in section [rattail.db] "
 
                      "is deprecated; please use 'changes.record' instead.",
 
                      DeprecationWarning)
 
        record_changes(Session, ignore_role_changes)
 

	
 
    return Session
 

	
 

	
 
# TODO: Remove once deprecation is complete.
 
def init(config):
 
    warnings.warn("Calling `rattail.db.init()` is deprecated.", DeprecationWarning)
rattail/db/diffs.py
Show inline comments
 
@@ -26,7 +26,6 @@
 
``rattail.db.diffs`` -- Data Diffs
 
"""
 

	
 
from rattail.db import Session
 
from rattail.db.sync import get_sync_engines
 

	
 
from sqlalchemy.orm import class_mapper
rattail/db/load.py
Show inline comments
 
@@ -31,16 +31,15 @@ from sqlalchemy.orm import joinedload
 
import edbob
 

	
 
from ..core import Object
 
from . import Session
 
from . import model
 
from . import get_session_class
 

	
 

	
 
class LoadProcessor(Object):
 

	
 
    def load_all_data(self, host_engine, progress=None):
 

	
 
        edbob.init_modules(['edbob.db', 'rattail.db'])
 

	
 
        Session = get_session_class(edbob.config)
 
        self.host_session = Session(bind=host_engine)
 
        self.session = Session()
 

	
rattail/db/sync/__init__.py
Show inline comments
 
@@ -34,12 +34,11 @@ if sys.platform == 'win32': # pragma no cover
 
    import win32api
 

	
 
import sqlalchemy.exc
 
from sqlalchemy.orm import class_mapper
 
from sqlalchemy.orm import sessionmaker, class_mapper
 
from sqlalchemy.exc import OperationalError
 

	
 
import edbob
 

	
 
from rattail.db import Session
 
from rattail.db import model
 

	
 

	
 
@@ -74,6 +73,7 @@ class Synchronizer(object):
 
    model = model
 

	
 
    def __init__(self, local_engine, remote_engines):
 
        self.Session = sessionmaker()
 
        self.local_engine = local_engine
 
        self.remote_engines = remote_engines
 

	
 
@@ -99,14 +99,14 @@ class Synchronizer(object):
 
            time.sleep(seconds)
 

	
 
    def synchronize(self):
 
        local_session = Session(bind=self.local_engine)
 
        local_session = self.Session(bind=self.local_engine)
 
        local_changes = local_session.query(model.Change).all()
 
        if len(local_changes):
 
            log.debug("Synchronizer.synchronize: found {0} change(s) to synchronize".format(len(local_changes)))
 

	
 
            remote_sessions = {}
 
            for key, remote_engine in self.remote_engines.iteritems():
 
                remote_sessions[key] = Session(bind=remote_engine)
 
                remote_sessions[key] = self.Session(bind=remote_engine)
 

	
 
            self.synchronize_changes(local_changes, local_session, remote_sessions)
 

	
tests/db/__init__.py
Show inline comments
 
@@ -2,8 +2,8 @@
 
import unittest
 

	
 
from sqlalchemy import create_engine
 
from sqlalchemy.orm import sessionmaker
 

	
 
from rattail.db import Session
 
from rattail.db.model import Base
 

	
 

	
 
@@ -15,7 +15,8 @@ class DataTestCase(unittest.TestCase):
 
    def setUp(self):
 
        engine = create_engine('sqlite://')
 
        Base.metadata.create_all(bind=engine)
 
        self.session = Session(bind=engine)
 
        self.Session = sessionmaker(bind=engine)
 
        self.session = self.Session()
 

	
 
    def tearDown(self):
 
        self.session.close()
tests/db/sync/__init__.py
Show inline comments
 
@@ -2,6 +2,7 @@
 
from unittest import TestCase
 

	
 
from sqlalchemy import create_engine
 
from sqlalchemy.orm import sessionmaker
 

	
 
from rattail.db import model
 

	
 
@@ -14,6 +15,7 @@ class SyncTestCase(TestCase):
 
            'one': create_engine('sqlite://'),
 
            'two': create_engine('sqlite://'),
 
            }
 
        self.Session = sessionmaker()
 
        model.Base.metadata.create_all(bind=self.local_engine)
 
        model.Base.metadata.create_all(bind=self.remote_engines['one'])
 
        model.Base.metadata.create_all(bind=self.remote_engines['two'])
tests/db/sync/test_init.py
Show inline comments
 
@@ -7,7 +7,7 @@ from sqlalchemy.exc import OperationalError, SAWarning
 

	
 
from . import SyncTestCase
 
from rattail.db import sync
 
from rattail.db import Session
 
from rattail.db import get_session_class
 
from rattail.db import model
 

	
 

	
 
@@ -46,7 +46,7 @@ class SynchronizerTests(SyncTestCase):
 
            self.assertFalse(synchronize_changes.called)
 

	
 
            # some changes
 
            local_session = Session(bind=self.local_engine)
 
            local_session = self.Session(bind=self.local_engine)
 
            product = model.Product()
 
            local_session.add(product)
 
            local_session.flush()
 
@@ -65,10 +65,10 @@ class SynchronizerTests(SyncTestCase):
 
    def test_synchronize_changes(self):
 
        synchronizer = sync.Synchronizer(self.local_engine, self.remote_engines)
 

	
 
        local_session = Session(bind=self.local_engine)
 
        local_session = self.Session(bind=self.local_engine)
 
        remote_sessions = {
 
            'one': Session(bind=self.remote_engines['one']),
 
            'two': Session(bind=self.remote_engines['two']),
 
            'one': self.Session(bind=self.remote_engines['one']),
 
            'two': self.Session(bind=self.remote_engines['two']),
 
            }
 

	
 
        # no changes; nothing should happen but make sure nothing blows up also
 
@@ -168,8 +168,8 @@ class SynchronizerTests(SyncTestCase):
 
                SAWarning, r'^sqlalchemy\.types$')
 

	
 
            # no prices
 
            local_session = Session(bind=self.local_engine)
 
            remote_session = Session(bind=self.remote_engines['one'])
 
            local_session = self.Session(bind=self.local_engine)
 
            remote_session = self.Session(bind=self.remote_engines['one'])
 
            source_product = model.Product()
 
            local_session.add(source_product)
 
            local_session.flush()
 
@@ -191,8 +191,8 @@ class SynchronizerTests(SyncTestCase):
 
            remote_session.close()
 

	
 
            # regular price
 
            local_session = Session(bind=self.local_engine)
 
            remote_session = Session(bind=self.remote_engines['one'])
 
            local_session = self.Session(bind=self.local_engine)
 
            remote_session = self.Session(bind=self.remote_engines['one'])
 
            source_product = model.Product()
 
            regular_price = model.ProductPrice()
 
            source_product.prices.append(regular_price)
 
@@ -210,8 +210,8 @@ class SynchronizerTests(SyncTestCase):
 
            remote_session.close()
 

	
 
            # current price
 
            local_session = Session(bind=self.local_engine)
 
            remote_session = Session(bind=self.remote_engines['one'])
 
            local_session = self.Session(bind=self.local_engine)
 
            remote_session = self.Session(bind=self.remote_engines['one'])
 
            source_product = model.Product()
 
            current_price = model.ProductPrice()
 
            source_product.prices.append(current_price)
 
@@ -258,7 +258,7 @@ class SynchronizerTests(SyncTestCase):
 
                'this platform for lossless storage\.$',
 
                SAWarning, r'^sqlalchemy\.types$')
 

	
 
            session = Session(bind=self.local_engine)
 
            session = self.Session(bind=self.local_engine)
 
            department = model.Department()
 
            department.subdepartments.append(model.Subdepartment())
 
            session.add(department)
 
@@ -273,7 +273,7 @@ class SynchronizerTests(SyncTestCase):
 
            session.rollback()
 
            session.close()
 

	
 
            session = Session(bind=self.local_engine)
 
            session = self.Session(bind=self.local_engine)
 
            department = model.Department()
 
            product = model.Product(department=department)
 
            session.add(product)
 
@@ -289,7 +289,7 @@ class SynchronizerTests(SyncTestCase):
 
    def test_delete_Subdepartment(self):
 
        synchronizer = sync.Synchronizer(self.local_engine, self.remote_engines)
 

	
 
        session = Session(bind=self.local_engine)
 
        session = self.Session(bind=self.local_engine)
 
        subdepartment = model.Subdepartment()
 
        product = model.Product(subdepartment=subdepartment)
 
        session.add(product)
 
@@ -305,7 +305,7 @@ class SynchronizerTests(SyncTestCase):
 
    def test_delete_Family(self):
 
        synchronizer = sync.Synchronizer(self.local_engine, self.remote_engines)
 

	
 
        session = Session(bind=self.local_engine)
 
        session = self.Session(bind=self.local_engine)
 
        family = model.Family()
 
        product = model.Product(family=family)
 
        session.add(product)
 
@@ -321,7 +321,7 @@ class SynchronizerTests(SyncTestCase):
 
    def test_delete_Vendor(self):
 
        synchronizer = sync.Synchronizer(self.local_engine, self.remote_engines)
 

	
 
        session = Session(bind=self.local_engine)
 
        session = self.Session(bind=self.local_engine)
 
        vendor = model.Vendor()
 
        product = model.Product()
 
        product.costs.append(model.ProductCost(vendor=vendor))
 
@@ -337,7 +337,7 @@ class SynchronizerTests(SyncTestCase):
 
    def test_delete_CustomerGroup(self):
 
        synchronizer = sync.Synchronizer(self.local_engine, self.remote_engines)
 

	
 
        session = Session(bind=self.local_engine)
 
        session = self.Session(bind=self.local_engine)
 
        group = model.CustomerGroup()
 
        customer = model.Customer()
 
        customer.groups.append(group)
0 comments (0 inline, 0 general)