Changeset - 9fcbdb023c5d
[Not reviewed]
! ! !
Lance Edgar (lance) - 4 months ago 2024-07-04 10:35:04
lance@edbob.org
fix: refactor code so most things work without sqlalchemy

also refactor tests to help ensure that remains true
54 files changed:
Changeset was too big and was cut off... Show full diff anyway
0 comments (0 inline, 0 general)
rattail/autocomplete/base.py
Show inline comments
 
@@ -2,7 +2,7 @@
 
################################################################################
 
#
 
#  Rattail -- Retail Software Framework
 
#  Copyright © 2010-2023 Lance Edgar
 
#  Copyright © 2010-2024 Lance Edgar
 
#
 
#  This file is part of Rattail.
 
#
 
@@ -26,11 +26,8 @@ Autocomplete handlers - base class
 

	
 
import re
 

	
 
import sqlalchemy as sa
 
from sqlalchemy import orm
 

	
 

	
 
class Autocompleter(object):
 
class Autocompleter:
 
    """
 
    Base class and partial default implementation for autocomplete
 
    handlers.  It is expected that all autocomplete handlers will
 
@@ -77,7 +74,10 @@ class Autocompleter(object):
 
        self.config = config
 
        self.app = self.config.get_app()
 
        self.enum = config.get_enum()
 
        self.model = config.get_model()
 
        try:
 
            self.model = config.get_model()
 
        except ImportError:
 
            pass
 

	
 
    def get_model_class(self):
 
        return self.model_class
 
@@ -165,6 +165,8 @@ class Autocompleter(object):
 
        """
 
        Apply the actual "search" filtering and return the query.
 
        """
 
        import sqlalchemy as sa
 

	
 
        model_class = self.get_model_class()
 
        column = getattr(model_class, self.autocomplete_fieldname)
 
        criteria = [column.ilike('%{}%'.format(word))
 
@@ -215,6 +217,9 @@ class PhoneMagicMixin(object):
 
        the search term resembles a phone number and if so, do a phone
 
        number search; otherwise a name search.
 
        """
 
        import sqlalchemy as sa
 
        from sqlalchemy import orm
 

	
 
        column = getattr(self.model_class, self.autocomplete_fieldname)
 

	
 
        # define the base query
rattail/batch/handlers.py
Show inline comments
 
@@ -2,7 +2,7 @@
 
################################################################################
 
#
 
#  Rattail -- Retail Software Framework
 
#  Copyright © 2010-2023 Lance Edgar
 
#  Copyright © 2010-2024 Lance Edgar
 
#
 
#  This file is part of Rattail.
 
#
 
@@ -31,7 +31,6 @@ import warnings
 
import logging
 

	
 
import json
 
from sqlalchemy import orm
 

	
 
from rattail.barcodes import upce_to_upca
 

	
 
@@ -221,7 +220,7 @@ class BatchHandler(object):
 
        ``rowcount`` is incremented, and the :meth:`after_add_row()` hook is
 
        invoked.
 
        """
 
        session = orm.object_session(batch)
 
        session = self.app.get_session(batch)
 
        with session.no_autoflush:
 
            batch.data_rows.append(row)
 
            self.refresh_row(row)
 
@@ -267,6 +266,8 @@ class BatchHandler(object):
 

	
 
        :returns: Integer indicating the number of batches purged.
 
        """
 
        from sqlalchemy import orm
 

	
 
        if delete_all_data and dry_run:
 
            raise ValueError("You can enable (n)either of `dry_run` or "
 
                             "`delete_all_data` but both cannot be True")
 
@@ -470,7 +471,7 @@ class BatchHandler(object):
 
        should *not* override the :meth:`~do_refresh()` method, but callers
 
        *should* use that one directly.
 
        """
 
        session = orm.object_session(batch)
 
        session = self.app.get_session(batch)
 
        self.setup_refresh(batch, progress=progress)
 
        if self.repopulate_when_refresh:
 
            del batch.data_rows[:]
 
@@ -841,7 +842,7 @@ class BatchHandler(object):
 
        *should* override the :meth:`~delete()` method, but callers should
 
        *not* use that one directly.
 
        """
 
        session = orm.object_session(batch)
 
        session = self.app.get_session(batch)
 

	
 
        if 'delete_all_data' in kwargs:
 
            warnings.warn("The 'delete_all_data' kwarg is not supported for "
 
@@ -877,7 +878,7 @@ class BatchHandler(object):
 
        # TODO: in other words i don't even know why this is necessary.  seems
 
        # to me that one fell swoop should not incur FK errors
 
        if hasattr(batch, 'data_rows'):
 
            session = orm.object_session(batch)
 
            session = self.app.get_session(batch)
 

	
 
            def delete(row, i):
 
                session.delete(row)
 
@@ -930,6 +931,8 @@ class BatchHandler(object):
 
        """
 
        Clone the given batch as a new batch, and return the new batch.
 
        """
 
        from sqlalchemy import orm
 

	
 
        self.setup_clone(oldbatch, progress=progress)
 
        batch_class = self.batch_model_class
 
        batch_mapper = orm.class_mapper(batch_class)
 
@@ -963,6 +966,8 @@ class BatchHandler(object):
 
        return batch.data_rows
 

	
 
    def clone_row(self, oldrow):
 
        from sqlalchemy import orm
 

	
 
        row_class = self.batch_model_class.row_class
 
        row_mapper = orm.class_mapper(row_class)
 
        newrow = row_class()
rattail/bouncer/handler.py
Show inline comments
 
@@ -2,7 +2,7 @@
 
################################################################################
 
#
 
#  Rattail -- Retail Software Framework
 
#  Copyright © 2010-2023 Lance Edgar
 
#  Copyright © 2010-2024 Lance Edgar
 
#
 
#  This file is part of Rattail.
 
#
 
@@ -31,11 +31,9 @@ import warnings
 
from email import message_from_string
 
from email.utils import parsedate_tz, mktime_tz
 

	
 
from sqlalchemy import orm
 
from flufl.bounce import all_failures
 

	
 
from rattail.core import Object
 
from rattail.util import load_object
 

	
 

	
 
log = logging.getLogger(__name__)
 
@@ -59,7 +57,7 @@ class Link(Object):
 
    """
 

	
 

	
 
class BounceHandler(object):
 
class BounceHandler:
 
    """
 
    Default implementation for email bounce handlers.
 
    """
 
@@ -68,9 +66,13 @@ class BounceHandler(object):
 
        self.config = config
 
        self.config_key = config_key
 
        self.app = config.get_app()
 
        self.model = config.get_model()
 
        self.enum = config.get_enum()
 

	
 
        try:
 
            self.model = config.get_model()
 
        except ImportError:
 
            pass
 

	
 
    def get_all_failures(self, msg):
 
        warnings, failures = all_failures(msg)
 
        if warnings:
rattail/clientele.py
Show inline comments
 
@@ -2,7 +2,7 @@
 
################################################################################
 
#
 
#  Rattail -- Retail Software Framework
 
#  Copyright © 2010-2023 Lance Edgar
 
#  Copyright © 2010-2024 Lance Edgar
 
#
 
#  This file is part of Rattail.
 
#
 
@@ -28,9 +28,6 @@ from collections import OrderedDict
 
import logging
 
import warnings
 

	
 
import sqlalchemy as sa
 
from sqlalchemy import orm
 

	
 
from rattail.app import GenericHandler
 

	
 

	
 
@@ -116,6 +113,8 @@ class ClienteleHandler(GenericHandler):
 
        :returns: List of
 
           :class:`~rattail.db.model.customers.Customer` objects.
 
        """
 
        import sqlalchemy as sa
 

	
 
        model = self.model
 
        customers = session.query(model.Customer)\
 
                           .order_by(model.Customer.name)
 
@@ -371,6 +370,8 @@ class ClienteleHandler(GenericHandler):
 
        :returns: First :class:`~rattail.db.model.customers.Customer`
 
           instance found if there was a match; otherwise ``None``.
 
        """
 
        from sqlalchemy import orm
 

	
 
        if not entry:
 
            return
 

	
 
@@ -401,6 +402,8 @@ class ClienteleHandler(GenericHandler):
 
        :returns: First :class:`~rattail.db.model.customers.Customer`
 
           instance found if there was a match; otherwise ``None``.
 
        """
 
        from sqlalchemy import orm
 

	
 
        if not entry:
 
            return
 

	
rattail/contrib/vendors/catalogs/dutchvalley.py
Show inline comments
 
@@ -31,7 +31,6 @@ import logging
 

	
 
import xlrd
 

	
 
from rattail.db import model
 
from rattail.vendors.catalogs import CatalogParser
 

	
 

	
 
@@ -69,7 +68,7 @@ class DutchValleyCatalogParser(CatalogParser):
 
                continue
 
            code = code.replace(' ', '')
 
            
 
            row = model.VendorCatalogBatchRow()
 
            row = self.make_row()
 
            row.vendor_code = code
 

	
 
            # Try to parse unit and/or case size from description.
rattail/contrib/vendors/catalogs/equalexchange.py
Show inline comments
 
@@ -29,8 +29,6 @@ import decimal
 

	
 
import xlrd
 

	
 
from rattail.db import model
 
from rattail.gpc import GPC
 
from rattail.vendors.catalogs import CatalogParser
 

	
 

	
 
@@ -59,7 +57,7 @@ class EqualExchangeCatalogParser(CatalogParser):
 
        unit_price_pattern = re.compile(r'^\$(\d+\.\d\d)/lb$')
 

	
 
        for r in range(sheet.nrows):
 
            row = model.VendorCatalogBatchRow()
 
            row = self.make_row()
 
            row.brand_name = "Equal Exchange"
 
            row.department_name = department_name
 

	
 
@@ -86,7 +84,7 @@ class EqualExchangeCatalogParser(CatalogParser):
 
            if unit_upc != 'NA':
 
                assert len(unit_upc) == 12
 
                row.item_entry = unit_upc
 
                row.upc = GPC(unit_upc, calc_check_digit=False)
 
                row.upc = self.app.make_gpc(unit_upc, calc_check_digit=False)
 

	
 
            # column E (UNIT PRICE)
 
            unit_price = sheet.cell_value(r, 4)
rattail/contrib/vendors/catalogs/kehe.py
Show inline comments
 
@@ -2,7 +2,7 @@
 
################################################################################
 
#
 
#  Rattail -- Retail Software Framework
 
#  Copyright © 2010-2021 Lance Edgar
 
#  Copyright © 2010-2024 Lance Edgar
 
#
 
#  This file is part of Rattail.
 
#
 
@@ -24,8 +24,6 @@
 
Vendor catalog parser for KeHE Distributors
 
"""
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
import os
 
import re
 
import datetime
 
@@ -33,11 +31,7 @@ import decimal
 
import logging
 

	
 
import xlrd
 
from sqlalchemy.orm import joinedload
 

	
 
from rattail.db import model
 
from rattail.db.cache import cache_model
 
from rattail.gpc import GPC
 
from rattail.vendors.catalogs import CatalogParser
 

	
 

	
 
@@ -68,17 +62,20 @@ class KeheCatalogParser(CatalogParser):
 
            return datetime.datetime.strptime(match.group(1), '%b %Y').date()
 

	
 
    def parse_rows(self, path, progress=None):
 
        from sqlalchemy import orm
 

	
 
        model = self.app.model
 
        book = xlrd.open_workbook(path)
 
        sheet = book.sheet_by_index(0)
 
        products = cache_model(self.session, model.Product, key='upc',
 
                               query_options=[joinedload(model.Product.costs),
 
                                              joinedload(model.Product.cost)])
 
        products = self.app.cache_model(self.session, model.Product, key='upc',
 
                                        query_options=[orm.joinedload(model.Product.costs),
 
                                                       orm.joinedload(model.Product.cost)])
 

	
 
        for r in range(1, sheet.nrows): # Skip first header row.
 

	
 
            row = model.VendorCatalogBatchRow()
 
            row = self.make_row()
 
            upc = sheet.cell_value(r, 5) or None
 
            row.upc = GPC(int(upc)) if upc else None
 
            row.upc = self.app.make_gpc(int(upc)) if upc else None
 
            row.brand_name = sheet.cell_value(r, 1)
 
            row.description = sheet.cell_value(r, 2)
 
            row.size = sheet.cell_value(r, 3)
rattail/contrib/vendors/catalogs/lotuslight.py
Show inline comments
 
@@ -2,7 +2,7 @@
 
################################################################################
 
#
 
#  Rattail -- Retail Software Framework
 
#  Copyright © 2010-2021 Lance Edgar
 
#  Copyright © 2010-2024 Lance Edgar
 
#
 
#  This file is part of Rattail.
 
#
 
@@ -24,8 +24,6 @@
 
Vendor catalog parser for Lotus Light
 
"""
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
import os
 
import re
 
import datetime
 
@@ -33,8 +31,6 @@ import logging
 

	
 
import xlrd
 

	
 
from rattail.db import model
 
from rattail.gpc import GPC
 
from rattail.vendors.catalogs import CatalogParser
 

	
 

	
 
@@ -72,16 +68,16 @@ class LotusLightCatalogParser(CatalogParser):
 
        sheet = book.sheet_by_index(0)
 
        for i in range(1, sheet.nrows): # skip first header row
 

	
 
            row = model.VendorCatalogBatchRow()
 
            row = self.make_row()
 

	
 
            upc = sheet.cell_value(i, 1)
 
            if upc:
 
                if upc.isdigit():
 
                    # Sometimes their catalog doesn't include the check digit..?
 
                    if len(upc) == 11:
 
                        row.upc = GPC(int(upc), calc_check_digit='upc')
 
                        row.upc = self.app.make_gpc(int(upc), calc_check_digit='upc')
 
                    else:
 
                        row.upc = GPC(int(upc))
 
                        row.upc = self.app.make_gpc(int(upc))
 
                else:
 
                    log.warning("invalid UPC at row {0}: {1}".format(i + 1, repr(upc)))
 

	
rattail/contrib/vendors/catalogs/unfi.py
Show inline comments
 
@@ -30,8 +30,6 @@ import decimal
 

	
 
import xlrd
 

	
 
from rattail.db import model
 
from rattail.gpc import GPC
 
from rattail.vendors.catalogs import CatalogParser
 
from rattail.csvutil import UnicodeDictReader
 

	
 
@@ -71,12 +69,12 @@ class UNFICatalogParser(CatalogParser):
 
            # Warn if UPC is not valid.
 
            upc = sheet.cell_value(r, 1)
 
            if self.upc_pattern.match(upc):
 
                upc = GPC(upc.replace('-', ''))
 
                upc = self.app.make_gpc(upc.replace('-', ''))
 
            else:
 
                log.warning("invalid upc at row {0}: {1}".format(r + 1, upc))
 
                upc = None
 

	
 
            row = model.VendorCatalogBatchRow()
 
            row = self.make_row()
 
            row.upc = upc
 
            row.vendor_code = code.replace('-', '')
 
            row.brand_name = sheet.cell_value(r, 2)
 
@@ -107,12 +105,12 @@ class UNFICatalogParser2(UNFICatalogParser):
 
            # Warn if UPC is not valid.
 
            upc = sheet.cell_value(r, 1)
 
            if self.upc_pattern.match(upc):
 
                upc = GPC(upc.replace('-', ''))
 
                upc = self.app.make_gpc(upc.replace('-', ''))
 
            else:
 
                log.warning("invalid upc at row {0}: {1}".format(r + 1, upc))
 
                upc = None
 

	
 
            row = model.VendorCatalogBatchRow()
 
            row = self.make_row()
 
            row.upc = upc
 
            row.vendor_code = code.replace('-', '')
 
            row.brand_name = sheet.cell_value(r, 2)
rattail/db/__init__.py
Show inline comments
 
@@ -139,9 +139,9 @@ class ConfigExtension(BaseExtension):
 
    key = 'rattail.db'
 

	
 
    def configure(self, config):
 
        from rattail.db.config import configure_session
 

	
 
        if Session:
 
            from rattail.db.config import configure_session
 
            from wuttjamaican.db import get_engines
 

	
 
            # Add Rattail database connection info to config.
rattail/db/util.py
Show inline comments
 
@@ -30,231 +30,13 @@ import pprint
 
import logging
 
import warnings
 

	
 
import sqlalchemy as sa
 
from sqlalchemy import orm
 
from sqlalchemy.ext.associationproxy import ASSOCIATION_PROXY
 

	
 

	
 
log = logging.getLogger(__name__)
 

	
 

	
 
class CounterMagic(object):
 
    """
 
    Provides magic counter values, to simulate PostgreSQL sequence.
 
    """
 

	
 
    def __init__(self, config):
 
        self.config = config
 
        self.metadata = sa.MetaData()
 

	
 
    def next_value(self, session, key):
 
        """
 
        Increment and return the next counter value for given key.
 
        """
 
        engine = session.bind
 
        table = sa.Table('counter_{}'.format(key), self.metadata,
 
                         sa.Column('value', sa.Integer(), primary_key=True))
 
        table.create(engine, checkfirst=True)
 
        with engine.begin() as cxn:
 
            result = cxn.execute(table.insert())
 
            return result.lastrowid
 

	
 

	
 
class QuerySequence(object):
 
    """
 
    Simple wrapper for a SQLAlchemy (or Django, or other?) query, to make it
 
    sort of behave like a normal sequence, as much as needed to e.g. make an
 
    importer happy.
 
    """
 

	
 
    def __init__(self, query):
 
        self.query = query
 

	
 
    def __len__(self):
 
        try:
 
            return len(self.query)
 
        except TypeError:
 
            return self.query.count()
 

	
 
    def __iter__(self):
 
        return iter(self.query)
 

	
 

	
 
def short_session(
 
        session=None,
 
        Session=None,
 
        commit=False,
 
        factory=None,
 
        config=None):
 
    """
 
    Compatibility wrapper around
 
    :class:`wuttjamaican:wuttjamaican.db.sess.short_session`.
 

	
 
    Note that this wrapper is a function whereas the upsream version
 
    is a proper context manager (class).  So calling this function
 
    will return a new instance of the upsream class.
 

	
 
    You should always specify keyword arguments when calling this
 
    function, since the arg order is different between this function
 
    and the upstream class.  And note that this function will
 
    eventually be deprecated and removed, so new code should call
 
    upstream directly.
 
    """
 
    from wuttjamaican.db import short_session
 

	
 
    warnings.warn("rattail.db.util.short_session() is deprecated; "
 
                  "please use wuttjamaican.db.short_session() instead",
 
                  DeprecationWarning, stacklevel=2)
 

	
 
    if not factory and Session:
 
        warnings.warn("passing a 'Session' kwarg is deprecated; "
 
                      "please pass 'factory' instead",
 
                      DeprecationWarning, stacklevel=2)
 
        factory = Session
 

	
 
    if not session and not factory and not config:
 
        from rattail.db import Session
 
        factory = Session
 

	
 
    return short_session(config=config, factory=factory, session=session, commit=commit)
 

	
 

	
 
def finalize_session(session, dry_run=False, success=True):
 
    """
 
    Wrap up the given session, per the given arguments.  This is meant
 
    to provide a simple convenience, for commands which must do work
 
    within a DB session, but would like to support a "dry run" mode.
 
    """
 
    if dry_run:
 
        session.rollback()
 
        log.info("dry run, so transaction was aborted")
 
    elif success:
 
        session.commit()
 
        log.info("transaction was committed")
 
    else:
 
        session.rollback()
 
        log.warning("action failed, so transaction was rolled back")
 
    session.close()
 

	
 

	
 
def get_fieldnames(config, obj, columns=True, proxies=True,
 
                   relations=False):
 
    """
 
    Produce a simple list of fieldnames for the given class,
 
    reflecting its table columns as well as any association proxies,
 
    and optionally, relationships.
 

	
 
    :param obj: Either a class or instance of a class, which derives
 
       from the base model class.
 

	
 
    :param columns: Whether or not to include simple columns.
 

	
 
    :param relations: Whether or not to include fields which represent
 
       relationships to other models.  If ``False`` (the default) then
 
       only "simple" fields will be included.
 

	
 
    :param proxies: Whether or not to include association proxy fields.
 
    """
 
    if isinstance(obj, type):
 
        cls = obj
 
    else:
 
        cls = obj.__class__
 

	
 
    mapper = orm.class_mapper(cls)
 
    fields = []
 

	
 
    # columns + relations
 
    prop_classes = []
 
    if columns:
 
        prop_classes.append(orm.ColumnProperty)
 
    if relations:
 
        prop_classes.append(orm.RelationshipProperty)
 
    if prop_classes:
 
        prop_classes = tuple(prop_classes)
 
        fields.extend([prop.key for prop in mapper.iterate_properties
 
                       if isinstance(prop, prop_classes)
 
                       and not prop.key.startswith('_')
 
                       and prop.key != 'versions'])
 

	
 
    # proxies
 
    if proxies:
 
        for key, desc in sa.inspect(cls).all_orm_descriptors.items():
 
            if desc.extension_type == ASSOCIATION_PROXY:
 

	
 
                # must avoid association proxies which in turn use
 
                # relationships, unless those are wanted by caller
 
                if not relations:
 
                    # TODO: this probably needs help, i stumbled thru it..
 
                    prop = sa.inspect(desc.for_class(cls).target_class)\
 
                             .get_property(desc.value_attr)
 
                    if isinstance(prop, orm.RelationshipProperty):
 
                        continue
 

	
 
                fields.append(key)
 

	
 
    return fields
 

	
 

	
 
def maxlen(attr):
 
    """
 
    Return the maximum length for the given attribute.
 
    """
 
    if len(attr.property.columns) == 1:
 
        type_ = attr.property.columns[0].type
 
        return getattr(type_, 'length', None)
 

	
 

	
 
def maxval(attr):
 
    """
 
    Return the maximum value possible for the given attribute.
 
    """
 
    if len(attr.property.columns) == 1:
 
        typ = attr.property.columns[0].type
 

	
 
        if isinstance(typ, sa.Numeric) and not isinstance(typ, sa.Float):
 
            maxint = pow(10, (typ.precision - typ.scale)) - 1
 
            return decimal.Decimal('{}.{}'.format(maxint,
 
                                                  '9' * typ.scale))
 

	
 

	
 
def make_topo_sortkey(model, metadata=None):
 
    """
 
    Returns a function suitable for use as a ``key`` kwarg to a standard Python
 
    sorting call.  This key function will expect a single class mapper and
 
    return a sequence number associated with that model.  The sequence is
 
    determined by SQLAlchemy's topological table sorting.
 
    """
 
    if metadata is None:
 
        metadata = model.Base.metadata
 

	
 
    tables = {}
 
    for i, table in enumerate(metadata.sorted_tables, 1):
 
        tables[table.name] = i
 

	
 
    # log.debug("topo sortkeys for '{}' will be:\n{}".format(model.__name__, pprint.pformat(
 
    #     [(i, name) for name, i in sorted(tables.items(), key=lambda t: t[1])])))
 

	
 
    def sortkey(name):
 
        if hasattr(model, name):
 
            mapper = orm.class_mapper(getattr(model, name))
 
            return tuple(tables[t.name] for t in mapper.tables)
 
        else:
 
            return tuple()
 

	
 
    return sortkey
 

	
 

	
 
def make_full_description(brand_name, description, size):
 
    """
 
    Combine the given field values into a complete description.
 
    """
 
    fields = [
 
        brand_name or '',
 
        description or '',
 
        size or '']
 
    fields = [f.strip() for f in fields if f.strip()]
 
    return ' '.join(fields)
 

	
 
##############################
 
# people
 
##############################
 

	
 
def normalize_full_name(first_name, last_name):
 
    """
 
@@ -273,7 +55,7 @@ def normalize_full_name(first_name, last_name):
 

	
 

	
 
##############################
 
# phone number validation
 
# phone numbers
 
##############################
 

	
 
class PhoneValidator(object):
 
@@ -319,3 +101,237 @@ def format_phone_number(number):
 
    if number and len(number) == 10:
 
        return '({}) {}-{}'.format(number[:3], number[3:6], number[6:])
 
    return original
 

	
 

	
 
##############################
 
# products
 
##############################
 

	
 
def make_full_description(brand_name, description, size):
 
    """
 
    Combine the given field values into a complete description.
 
    """
 
    fields = [
 
        brand_name or '',
 
        description or '',
 
        size or '']
 
    fields = [f.strip() for f in fields if f.strip()]
 
    return ' '.join(fields)
 

	
 

	
 
##############################
 
# database
 
##############################
 

	
 
try:
 
    import sqlalchemy as sa
 
    from sqlalchemy import orm
 
    from sqlalchemy.ext.associationproxy import ASSOCIATION_PROXY
 
except ImportError:
 
    pass
 
else:
 

	
 
    class CounterMagic(object):
 
        """
 
        Provides magic counter values, to simulate PostgreSQL sequence.
 
        """
 

	
 
        def __init__(self, config):
 
            self.config = config
 
            self.metadata = sa.MetaData()
 

	
 
        def next_value(self, session, key):
 
            """
 
            Increment and return the next counter value for given key.
 
            """
 
            engine = session.bind
 
            table = sa.Table('counter_{}'.format(key), self.metadata,
 
                             sa.Column('value', sa.Integer(), primary_key=True))
 
            table.create(engine, checkfirst=True)
 
            with engine.begin() as cxn:
 
                result = cxn.execute(table.insert())
 
                return result.lastrowid
 

	
 

	
 
    class QuerySequence(object):
 
        """
 
        Simple wrapper for a SQLAlchemy (or Django, or other?) query, to make it
 
        sort of behave like a normal sequence, as much as needed to e.g. make an
 
        importer happy.
 
        """
 

	
 
        def __init__(self, query):
 
            self.query = query
 

	
 
        def __len__(self):
 
            try:
 
                return len(self.query)
 
            except TypeError:
 
                return self.query.count()
 

	
 
        def __iter__(self):
 
            return iter(self.query)
 

	
 

	
 
    def short_session(
 
            session=None,
 
            Session=None,
 
            commit=False,
 
            factory=None,
 
            config=None):
 
        """
 
        Compatibility wrapper around
 
        :class:`wuttjamaican:wuttjamaican.db.sess.short_session`.
 

	
 
        Note that this wrapper is a function whereas the upsream version
 
        is a proper context manager (class).  So calling this function
 
        will return a new instance of the upsream class.
 

	
 
        You should always specify keyword arguments when calling this
 
        function, since the arg order is different between this function
 
        and the upstream class.  And note that this function will
 
        eventually be deprecated and removed, so new code should call
 
        upstream directly.
 
        """
 
        from wuttjamaican.db import short_session
 

	
 
        warnings.warn("rattail.db.util.short_session() is deprecated; "
 
                      "please use wuttjamaican.db.short_session() instead",
 
                      DeprecationWarning, stacklevel=2)
 

	
 
        if not factory and Session:
 
            warnings.warn("passing a 'Session' kwarg is deprecated; "
 
                          "please pass 'factory' instead",
 
                          DeprecationWarning, stacklevel=2)
 
            factory = Session
 

	
 
        if not session and not factory and not config:
 
            from rattail.db import Session
 
            factory = Session
 

	
 
        return short_session(config=config, factory=factory, session=session, commit=commit)
 

	
 

	
 
    def finalize_session(session, dry_run=False, success=True):
 
        """
 
        Wrap up the given session, per the given arguments.  This is meant
 
        to provide a simple convenience, for commands which must do work
 
        within a DB session, but would like to support a "dry run" mode.
 
        """
 
        if dry_run:
 
            session.rollback()
 
            log.info("dry run, so transaction was aborted")
 
        elif success:
 
            session.commit()
 
            log.info("transaction was committed")
 
        else:
 
            session.rollback()
 
            log.warning("action failed, so transaction was rolled back")
 
        session.close()
 

	
 

	
 
    def get_fieldnames(config, obj, columns=True, proxies=True,
 
                       relations=False):
 
        """
 
        Produce a simple list of fieldnames for the given class,
 
        reflecting its table columns as well as any association proxies,
 
        and optionally, relationships.
 

	
 
        :param obj: Either a class or instance of a class, which derives
 
           from the base model class.
 

	
 
        :param columns: Whether or not to include simple columns.
 

	
 
        :param relations: Whether or not to include fields which represent
 
           relationships to other models.  If ``False`` (the default) then
 
           only "simple" fields will be included.
 

	
 
        :param proxies: Whether or not to include association proxy fields.
 
        """
 
        if isinstance(obj, type):
 
            cls = obj
 
        else:
 
            cls = obj.__class__
 

	
 
        mapper = orm.class_mapper(cls)
 
        fields = []
 

	
 
        # columns + relations
 
        prop_classes = []
 
        if columns:
 
            prop_classes.append(orm.ColumnProperty)
 
        if relations:
 
            prop_classes.append(orm.RelationshipProperty)
 
        if prop_classes:
 
            prop_classes = tuple(prop_classes)
 
            fields.extend([prop.key for prop in mapper.iterate_properties
 
                           if isinstance(prop, prop_classes)
 
                           and not prop.key.startswith('_')
 
                           and prop.key != 'versions'])
 

	
 
        # proxies
 
        if proxies:
 
            for key, desc in sa.inspect(cls).all_orm_descriptors.items():
 
                if desc.extension_type == ASSOCIATION_PROXY:
 

	
 
                    # must avoid association proxies which in turn use
 
                    # relationships, unless those are wanted by caller
 
                    if not relations:
 
                        # TODO: this probably needs help, i stumbled thru it..
 
                        prop = sa.inspect(desc.for_class(cls).target_class)\
 
                                 .get_property(desc.value_attr)
 
                        if isinstance(prop, orm.RelationshipProperty):
 
                            continue
 

	
 
                    fields.append(key)
 

	
 
        return fields
 

	
 

	
 
    def maxlen(attr):
 
        """
 
        Return the maximum length for the given attribute.
 
        """
 
        if len(attr.property.columns) == 1:
 
            type_ = attr.property.columns[0].type
 
            return getattr(type_, 'length', None)
 

	
 

	
 
    def maxval(attr):
 
        """
 
        Return the maximum value possible for the given attribute.
 
        """
 
        if len(attr.property.columns) == 1:
 
            typ = attr.property.columns[0].type
 

	
 
            if isinstance(typ, sa.Numeric) and not isinstance(typ, sa.Float):
 
                maxint = pow(10, (typ.precision - typ.scale)) - 1
 
                return decimal.Decimal('{}.{}'.format(maxint,
 
                                                      '9' * typ.scale))
 

	
 

	
 
    def make_topo_sortkey(model, metadata=None):
 
        """
 
        Returns a function suitable for use as a ``key`` kwarg to a standard Python
 
        sorting call.  This key function will expect a single class mapper and
 
        return a sequence number associated with that model.  The sequence is
 
        determined by SQLAlchemy's topological table sorting.
 
        """
 
        if metadata is None:
 
            metadata = model.Base.metadata
 

	
 
        tables = {}
 
        for i, table in enumerate(metadata.sorted_tables, 1):
 
            tables[table.name] = i
 

	
 
        # log.debug("topo sortkeys for '{}' will be:\n{}".format(model.__name__, pprint.pformat(
 
        #     [(i, name) for name, i in sorted(tables.items(), key=lambda t: t[1])])))
 

	
 
        def sortkey(name):
 
            if hasattr(model, name):
 
                mapper = orm.class_mapper(getattr(model, name))
 
                return tuple(tables[t.name] for t in mapper.tables)
 
            else:
 
                return tuple()
 

	
 
        return sortkey
rattail/importing/__init__.py
Show inline comments
 
@@ -2,7 +2,7 @@
 
################################################################################
 
#
 
#  Rattail -- Retail Software Framework
 
#  Copyright © 2010-2018 Lance Edgar
 
#  Copyright © 2010-2024 Lance Edgar
 
#
 
#  This file is part of Rattail.
 
#
 
@@ -24,11 +24,14 @@
 
Data Importing Framework
 
"""
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from .importers import Importer, FromCSV, FromQuery, FromDjango, BatchImporter, BulkImporter
 
from .sqlalchemy import FromSQLAlchemy, ToSQLAlchemy
 
from .postgresql import BulkToPostgreSQL
 
from .handlers import ImportHandler, BulkImportHandler, FromSQLAlchemyHandler, ToSQLAlchemyHandler
 
from .rattail import FromRattailHandler, ToRattailHandler
 
from . import model
 
from .handlers import ImportHandler, BulkImportHandler
 

	
 
try:
 
    from .sqlalchemy import FromSQLAlchemy, ToSQLAlchemy
 
    from .postgresql import BulkToPostgreSQL
 
    from .handlers import FromSQLAlchemyHandler, ToSQLAlchemyHandler
 
    from .rattail import FromRattailHandler, ToRattailHandler
 
    from . import model
 
except ImportError:
 
    pass # sqlalchemy not installed
rattail/importing/handlers.py
Show inline comments
 
@@ -33,7 +33,6 @@ from collections import OrderedDict
 

	
 
import humanize
 
import markupsafe
 
import sqlalchemy as sa
 

	
 
from rattail.util import get_object_spec
 

	
 
@@ -93,7 +92,10 @@ class ImportHandler(object):
 
        if self.config:
 
            self.app = self.config.get_app()
 
            self.enum = self.config.get_enum()
 
            self.model = self.config.get_model()
 
            try:
 
                self.model = self.config.get_model()
 
            except ImportError:
 
                pass
 

	
 
        # nb. must assign extra attrs before we get_importers() since
 
        # attrs may need to affect that behavior
 
@@ -615,69 +617,6 @@ class BulkImportHandler(ImportHandler):
 
        return changes
 

	
 

	
 
class FromSQLAlchemyHandler(ImportHandler):
 
    """
 
    Handler for imports for which the host data source is represented by a
 
    SQLAlchemy engine and ORM.
 
    """
 
    host_session = None
 

	
 
    def make_host_session(self):
 
        """
 
        Subclasses must override this to define the host database connection.
 
        """
 
        raise NotImplementedError
 

	
 
    def get_importer_kwargs(self, key, **kwargs):
 
        kwargs = super(FromSQLAlchemyHandler, self).get_importer_kwargs(key, **kwargs)
 
        kwargs.setdefault('host_session', self.host_session)
 
        return kwargs
 

	
 
    def begin_host_transaction(self):
 
        self.host_session = self.make_host_session()
 

	
 
    def rollback_host_transaction(self):
 
        self.host_session.rollback()
 
        self.host_session.close()
 
        self.host_session = None
 

	
 
    def commit_host_transaction(self):
 
        self.host_session.commit()
 
        self.host_session.close()
 
        self.host_session = None
 

	
 

	
 
class ToSQLAlchemyHandler(ImportHandler):
 
    """
 
    Handler for imports which target a SQLAlchemy ORM on the local side.
 
    """
 
    session = None
 

	
 
    def make_session(self):
 
        """
 
        Subclasses must override this to define the local database connection.
 
        """
 
        raise NotImplementedError
 

	
 
    def get_importer_kwargs(self, key, **kwargs):
 
        kwargs = super(ToSQLAlchemyHandler, self).get_importer_kwargs(key, **kwargs)
 
        kwargs.setdefault('session', self.session)
 
        return kwargs
 

	
 
    def begin_local_transaction(self):
 
        self.session = self.make_session()
 

	
 
    def rollback_local_transaction(self):
 
        self.session.rollback()
 
        self.session.close()
 
        self.session = None
 

	
 
    def commit_local_transaction(self):
 
        self.session.commit()
 
        self.session.close()
 
        self.session = None
 

	
 

	
 
class FromFileHandler(ImportHandler):
 
    """
 
    Handler for imports whose data comes from file(s) on the host side.
 
@@ -792,3 +731,72 @@ class RecordRenderer(object):
 
                    name = 'people'
 
                url = '{}/{}/{{uuid}}'.format(url, name)
 
                return url.format(uuid=record.uuid)
 

	
 

	
 
try:
 
    import sqlalchemy as sa
 
except ImportError:
 
    pass
 
else:
 

	
 
    class FromSQLAlchemyHandler(ImportHandler):
 
        """
 
        Handler for imports for which the host data source is represented by a
 
        SQLAlchemy engine and ORM.
 
        """
 
        host_session = None
 

	
 
        def make_host_session(self):
 
            """
 
            Subclasses must override this to define the host database connection.
 
            """
 
            raise NotImplementedError
 

	
 
        def get_importer_kwargs(self, key, **kwargs):
 
            kwargs = super(FromSQLAlchemyHandler, self).get_importer_kwargs(key, **kwargs)
 
            kwargs.setdefault('host_session', self.host_session)
 
            return kwargs
 

	
 
        def begin_host_transaction(self):
 
            self.host_session = self.make_host_session()
 

	
 
        def rollback_host_transaction(self):
 
            self.host_session.rollback()
 
            self.host_session.close()
 
            self.host_session = None
 

	
 
        def commit_host_transaction(self):
 
            self.host_session.commit()
 
            self.host_session.close()
 
            self.host_session = None
 

	
 

	
 
    class ToSQLAlchemyHandler(ImportHandler):
 
        """
 
        Handler for imports which target a SQLAlchemy ORM on the local side.
 
        """
 
        session = None
 

	
 
        def make_session(self):
 
            """
 
            Subclasses must override this to define the local database connection.
 
            """
 
            raise NotImplementedError
 

	
 
        def get_importer_kwargs(self, key, **kwargs):
 
            kwargs = super(ToSQLAlchemyHandler, self).get_importer_kwargs(key, **kwargs)
 
            kwargs.setdefault('session', self.session)
 
            return kwargs
 

	
 
        def begin_local_transaction(self):
 
            self.session = self.make_session()
 

	
 
        def rollback_local_transaction(self):
 
            self.session.rollback()
 
            self.session.close()
 
            self.session = None
 

	
 
        def commit_local_transaction(self):
 
            self.session.commit()
 
            self.session.close()
 
            self.session = None
rattail/importing/importers.py
Show inline comments
 
@@ -2,7 +2,7 @@
 
################################################################################
 
#
 
#  Rattail -- Retail Software Framework
 
#  Copyright © 2010-2023 Lance Edgar
 
#  Copyright © 2010-2024 Lance Edgar
 
#
 
#  This file is part of Rattail.
 
#
 
@@ -28,10 +28,7 @@ import datetime
 
import logging
 
from collections import OrderedDict
 

	
 
from rattail.db import cache
 
from rattail.db.util import QuerySequence
 
from rattail.time import make_utc
 
from rattail.util import data_diffs, progress_loop
 
from rattail.util import data_diffs
 
from rattail.csvutil import UnicodeDictReader
 

	
 

	
 
@@ -127,7 +124,10 @@ class Importer(object):
 
        self.config = config
 
        self.app = config.get_app() if config else None
 
        self.enum = config.get_enum() if config else None
 
        self.model = config.get_model() if config else None
 
        try:
 
            self.model = config.get_model() if config else None
 
        except ImportError:
 
            pass
 
        self.model_class = kwargs.pop('model_class', self.get_model_class())
 
        if key is not None:
 
            self.key = key
 
@@ -261,7 +261,7 @@ class Importer(object):
 

	
 
    def progress_loop(self, func, items, factory=None, **kwargs):
 
        factory = factory or self.progress
 
        return progress_loop(func, items, factory, **kwargs)
 
        return self.app.progress_loop(func, items, factory, **kwargs)
 

	
 
    def unique_data(self, host_data, warn=True):
 
        # Prune duplicate keys from host/source data.  This is for the sake of
 
@@ -284,7 +284,7 @@ class Importer(object):
 
        of where data is coming from or where it's headed.  Note that this
 
        method handles deletions as well as adds/updates.
 
        """
 
        self.now = now or make_utc(tzinfo=True)
 
        self.now = now or self.app.make_utc(tzinfo=True)
 
        if kwargs:
 
            self._setup(**kwargs)
 
        self.setup()
 
@@ -714,7 +714,7 @@ class Importer(object):
 
        if not session:
 
            session = self.session
 
        kwargs.setdefault('progress', self.progress)
 
        return cache.cache_model(session, model, **kwargs)
 
        return self.app.cache_model(session, model, **kwargs)
 

	
 
    def data_diffs(self, local_data, host_data):
 
        """
 
@@ -858,6 +858,8 @@ class FromQuery(Importer):
 
        """
 
        Returns (raw) query results as a sequence.
 
        """
 
        from rattail.db.util import QuerySequence
 

	
 
        query = self.query()
 
        if hasattr(self, 'sorted_query'):
 
            query = self.sorted_query(query)
 
@@ -888,7 +890,7 @@ class BatchImporter(Importer):
 
        if host_data is not None:
 
            raise ValueError("User-provided host data is not supported")
 

	
 
        self.now = now or make_utc(tzinfo=True)
 
        self.now = now or self.app.make_utc(tzinfo=True)
 
        if kwargs:
 
            self._setup(**kwargs)
 
        self.setup()
 
@@ -947,7 +949,7 @@ class BulkImporter(Importer):
 
    """
 

	
 
    def import_data(self, host_data=None, now=None, **kwargs):
 
        self.now = now or make_utc(tzinfo=True)
 
        self.now = now or self.app.make_utc(tzinfo=True)
 
        if kwargs:
 
            self._setup(**kwargs)
 
        self.setup()
rattail/monitoring.py
Show inline comments
 
@@ -46,7 +46,10 @@ class MonitorAction(object):
 
        self.config = config
 
        self.app = config.get_app()
 
        self.enum = config.get_enum()
 
        self.model = config.get_model()
 
        try:
 
            self.model = config.get_model()
 
        except ImportError:
 
            pass
 

	
 
    def __call__(self, *args, **kwargs):
 
        """
rattail/people.py
Show inline comments
 
@@ -28,9 +28,6 @@ See also :doc:`rattail-manual:base/handlers/other/people`.
 

	
 
import warnings
 

	
 
import sqlalchemy as sa
 
from sqlalchemy import orm
 

	
 
from rattail.app import GenericHandler, MergeMixin
 

	
 

	
 
@@ -96,6 +93,8 @@ class PeopleHandler(GenericHandler, MergeMixin):
 
        default this will do a lookup based on the configured Customer
 
        key field.  Override as needed.
 
        """
 
        from sqlalchemy import orm
 

	
 
        model = self.model
 
        field = self.app.get_customer_key_field()
 

	
 
@@ -557,6 +556,8 @@ class PeopleHandler(GenericHandler, MergeMixin):
 
        """
 
        If there was a merge request(s) for this pair, mark it complete.
 
        """
 
        import sqlalchemy as sa
 

	
 
        session = self.app.get_session(keeping)
 
        model = self.model
 
        merge_requests = session.query(model.MergePeopleRequest)\
rattail/problems/rattail.py
Show inline comments
 
@@ -2,7 +2,7 @@
 
################################################################################
 
#
 
#  Rattail -- Retail Software Framework
 
#  Copyright © 2010-2023 Lance Edgar
 
#  Copyright © 2010-2024 Lance Edgar
 
#
 
#  This file is part of Rattail.
 
#
 
@@ -26,8 +26,6 @@ Problem Reports for Rattail Systems
 

	
 
import datetime
 

	
 
from sqlalchemy import orm
 

	
 
from rattail.problems import ProblemReport
 

	
 

	
 
@@ -82,6 +80,8 @@ class ProductWithoutPrice(RattailProblemReport):
 
    problem_title = "Products with no price"
 

	
 
    def find_problems(self, **kwargs):
 
        from sqlalchemy import orm
 

	
 
        problems = []
 
        session = self.app.make_session()
 
        model = self.model
 
@@ -111,13 +111,15 @@ class StaleInventoryBatch(RattailProblemReport):
 
    problem_title = "Stale inventory batches"
 

	
 
    def __init__(self, *args, **kwargs):
 
        super(StaleInventoryBatch, self).__init__(*args, **kwargs)
 
        super().__init__(*args, **kwargs)
 

	
 
        self.cutoff_days = self.config.getint(
 
            'rattail', 'problems.stale_inventory_batches.cutoff_days',
 
            default=4)
 

	
 
    def find_problems(self, **kwargs):
 
        from sqlalchemy import orm
 

	
 
        session = self.app.make_session()
 
        model = self.model
 

	
 
@@ -137,8 +139,7 @@ class StaleInventoryBatch(RattailProblemReport):
 
        return batches
 

	
 
    def get_email_context(self, problems, **kwargs):
 
        kwargs = super(StaleInventoryBatch, self).get_email_context(problems,
 
                                                                    **kwargs)
 
        kwargs = super().get_email_context(problems, **kwargs)
 
        kwargs['cutoff_days'] = self.cutoff_days
 
        return kwargs
 

	
 
@@ -151,6 +152,8 @@ class UpgradePending(RattailProblemReport):
 
    problem_title = "Pending upgrade"
 

	
 
    def find_problems(self, **kwargs):
 
        from sqlalchemy import orm
 

	
 
        session = self.app.make_session()
 
        model = self.model
 
        upgrades = session.query(model.Upgrade)\
rattail/products.py
Show inline comments
 
@@ -28,9 +28,6 @@ import decimal
 
import warnings
 
import logging
 

	
 
import sqlalchemy as sa
 
from sqlalchemy import orm
 

	
 
from rattail import pod
 
from rattail.app import GenericHandler, MergeMixin
 
from rattail.gpc import GPC
 
@@ -338,6 +335,8 @@ class ProductsHandler(GenericHandler, MergeMixin):
 
        :returns: First :class:`~rattail.db.model.products.Product`
 
           instance found if there was a match; otherwise ``None``.
 
        """
 
        from sqlalchemy import orm
 

	
 
        if not entry:
 
            return
 

	
 
@@ -385,6 +384,8 @@ class ProductsHandler(GenericHandler, MergeMixin):
 
        :returns: First :class:`~rattail.db.model.products.Product`
 
           instance found if there was a match; otherwise ``None``.
 
        """
 
        from sqlalchemy import orm
 

	
 
        if not entry:
 
            return
 

	
 
@@ -478,6 +479,8 @@ class ProductsHandler(GenericHandler, MergeMixin):
 
        :returns: First :class:`~rattail.db.model.products.Product`
 
           instance found if there was a match; otherwise ``None``.
 
        """
 
        from sqlalchemy import orm
 

	
 
        if not gpc:
 
            return
 

	
 
@@ -1150,6 +1153,9 @@ class ProductsHandler(GenericHandler, MergeMixin):
 
        :returns: SIL code for the UOM, as string (e.g. ``'49'``), or ``None``
 
           if no matching code was found.
 
        """
 
        import sqlalchemy as sa
 
        from sqlalchemy import orm
 

	
 
        model = self.model
 
        query = session.query(model.UnitOfMeasure)
 
        if uppercase:
rattail/reporting/handlers.py
Show inline comments
 
@@ -2,7 +2,7 @@
 
################################################################################
 
#
 
#  Rattail -- Retail Software Framework
 
#  Copyright © 2010-2023 Lance Edgar
 
#  Copyright © 2010-2024 Lance Edgar
 
#
 
#  This file is part of Rattail.
 
#
 
@@ -27,8 +27,8 @@ Report Handlers
 
import datetime
 
import decimal
 

	
 
from rattail.db import model
 
from rattail.util import load_entry_points, load_object
 
from wuttjamaican.util import load_entry_points
 

	
 
from rattail.time import localtime
 

	
 

	
 
@@ -86,6 +86,7 @@ class ReportHandler(object):
 
        """
 
        Generate and return output for the given report and params.
 
        """
 
        model = self.model
 
        data = report.make_data(session, params, progress=progress, **kwargs)
 

	
 
        output = model.ReportOutput()
 
@@ -159,7 +160,8 @@ def get_report_handler(config, **kwargs):
 
    """
 
    Create and return the configured :class:`ReportHandler` instance.
 
    """
 
    app = config.get_app()
 
    spec = config.get('rattail.reports', 'handler')
 
    if spec:
 
        return load_object(spec)(config)
 
        return app.load_object(spec)(config)
 
    return ReportHandler(config)
rattail/reporting/reports.py
Show inline comments
 
@@ -26,11 +26,8 @@ Report Definitions
 

	
 
import re
 

	
 
from rattail.db import cache
 
from rattail.util import progress_loop
 

	
 

	
 
class Report(object):
 
class Report:
 
    """
 
    Base class for all reports.
 

	
 
@@ -152,10 +149,10 @@ class Report(object):
 
        raise NotImplementedError
 

	
 
    def cache_model(self, session, model, **kwargs):
 
        return cache.cache_model(session, model, **kwargs)
 
        return self.app.cache_model(session, model, **kwargs)
 

	
 
    def progress_loop(self, func, items, factory, **kwargs):
 
        return progress_loop(func, items, factory, **kwargs)
 
        return self.app.progress_loop(func, items, factory, **kwargs)
 

	
 

	
 
class ReportParam(object):
rattail/trainwreck/handler.py
Show inline comments
 
@@ -28,7 +28,6 @@ import warnings
 
from collections import OrderedDict
 

	
 
from rattail.app import GenericHandler
 
from rattail.trainwreck.db import Session as TrainwreckSession
 

	
 

	
 
class TrainwreckHandler(GenericHandler):
 
@@ -68,10 +67,12 @@ class TrainwreckHandler(GenericHandler):
 
        """
 
        Make a session for a Trainwreck DB.
 
        """
 
        from rattail.trainwreck.db import Session
 

	
 
        if 'bind' not in kwargs:
 
            engine = self.config.trainwreck_engines[dbkey]
 
            kwargs['bind'] = engine
 
        return TrainwreckSession(**kwargs)
 
        return Session(**kwargs)
 

	
 
    def get_model(self, **kwargs):
 
        """
rattail/vendors/catalogs.py
Show inline comments
 
@@ -2,7 +2,7 @@
 
################################################################################
 
#
 
#  Rattail -- Retail Software Framework
 
#  Copyright © 2010-2023 Lance Edgar
 
#  Copyright © 2010-2024 Lance Edgar
 
#
 
#  This file is part of Rattail.
 
#
 
@@ -24,8 +24,8 @@
 
Vendor Catalogs
 
"""
 

	
 
import decimal
 
import warnings
 
from decimal import Decimal
 

	
 
from rattail.exceptions import RattailError
 
from rattail.util import load_entry_points
 
@@ -58,9 +58,13 @@ class CatalogParser(object):
 
        if config:
 
            self.config = config
 
            self.app = config.get_app()
 
            self.model = config.get_model()
 
            self.enum = config.get_enum()
 

	
 
            try:
 
                self.model = config.get_model()
 
            except ImportError:
 
                pass # sqlalchemy not installed
 

	
 
    @property
 
    def key(self):
 
        """
 
@@ -102,14 +106,14 @@ class CatalogParser(object):
 
        # No reason to convert integers, really.
 
        if isinstance(value, int):
 
            return value
 
        if isinstance(value, Decimal):
 
        if isinstance(value, decimal.Decimal):
 
            return value
 

	
 
        if isinstance(value, float):
 
            value = "{{0:0.{0}f}}".format(scale).format(value)
 
        else:
 
            value = value.strip()
 
        return Decimal(value)
 
        return decimal.Decimal(value)
 

	
 
    def int_(self, value):
 
        """
rattail/vendors/handler.py
Show inline comments
 
@@ -2,7 +2,7 @@
 
################################################################################
 
#
 
#  Rattail -- Retail Software Framework
 
#  Copyright © 2010-2023 Lance Edgar
 
#  Copyright © 2010-2024 Lance Edgar
 
#
 
#  This file is part of Rattail.
 
#
 
@@ -24,8 +24,6 @@
 
Vendors Handler
 
"""
 

	
 
from sqlalchemy import orm
 

	
 
from rattail.app import GenericHandler
 
from rattail.util import load_entry_points
 

	
 
@@ -82,6 +80,8 @@ class VendorHandler(GenericHandler):
 
        :returns: The :class:`~rattail.db.model.vendors.Vendor`
 
           instance if found; otherwise ``None``.
 
        """
 
        from sqlalchemy import orm
 

	
 
        model = self.model
 

	
 
        # Vendor.uuid match?
tests/__init__.py
Show inline comments
 
@@ -4,21 +4,7 @@ import os
 
import warnings
 
from unittest import TestCase
 

	
 
import sqlalchemy as sa
 
from sqlalchemy.exc import SAWarning
 

	
 
from rattail.config import make_config
 
from rattail.db import model
 
from rattail.db import Session
 

	
 

	
 
warnings.filterwarnings(
 
    'ignore',
 
    r"^Dialect sqlite\+pysqlite does \*not\* support Decimal objects natively\, "
 
    "and SQLAlchemy must convert from floating point - rounding errors and other "
 
    "issues may occur\. Please consider storing Decimal numbers as strings or "
 
    "integers on this platform for lossless storage\.$",
 
    SAWarning, r'^sqlalchemy\..*$')
 

	
 

	
 
class NullProgress(object):
 
@@ -40,80 +26,98 @@ class NullProgress(object):
 
        pass
 

	
 

	
 
class RattailMixin(object):
 
    """
 
    Generic mixin for ``TestCase`` classes which need common Rattail setup
 
    functionality.
 
    """
 
    engine_url = os.environ.get('RATTAIL_TEST_ENGINE_URL', 'sqlite://')
 
    host_engine_url = os.environ.get('RATTAIL_TEST_HOST_ENGINE_URL')
 

	
 
    def postgresql(self):
 
        return self.config.rattail_engine.url.get_dialect().name == 'postgresql'
 
try:
 
    import sqlalchemy as sa
 
    from sqlalchemy.exc import SAWarning
 
    from rattail.db import model
 
    from rattail.db import Session
 
except ImportError:
 
    pass
 
else:
 

	
 
    def setUp(self):
 
        self.setup_rattail()
 
    warnings.filterwarnings(
 
        'ignore',
 
        r"^Dialect sqlite\+pysqlite does \*not\* support Decimal objects natively\, "
 
        "and SQLAlchemy must convert from floating point - rounding errors and other "
 
        "issues may occur\. Please consider storing Decimal numbers as strings or "
 
        "integers on this platform for lossless storage\.$",
 
        SAWarning, r'^sqlalchemy\..*$')
 

	
 
    def tearDown(self):
 
        self.teardown_rattail()
 

	
 
    def setup_rattail(self):
 
        config = self.make_rattail_config()
 
        self.config = config
 
        self.rattail_config = config
 

	
 
        engine = sa.create_engine(self.engine_url)
 
        config.rattail_engines = {'default': engine}
 
        config.rattail_engine = engine
 
    class RattailMixin(object):
 
        """
 
        Generic mixin for ``TestCase`` classes which need common Rattail setup
 
        functionality.
 
        """
 
        engine_url = os.environ.get('RATTAIL_TEST_ENGINE_URL', 'sqlite://')
 
        host_engine_url = os.environ.get('RATTAIL_TEST_HOST_ENGINE_URL')
 

	
 
        if self.host_engine_url:
 
            config.rattail_engines['host'] = sa.create_engine(self.host_engine_url)
 
        def postgresql(self):
 
            return self.config.rattail_engine.url.get_dialect().name == 'postgresql'
 

	
 
        model = self.get_rattail_model()
 
        model.Base.metadata.create_all(bind=engine)
 
        def setUp(self):
 
            self.setup_rattail()
 

	
 
        Session.configure(bind=engine, rattail_config=config)
 
        self.session = Session()
 
        def tearDown(self):
 
            self.teardown_rattail()
 

	
 
    def teardown_rattail(self):
 
        self.session.close()
 
        Session.configure(bind=None, rattail_config=None)
 
        model = self.get_rattail_model()
 
        model.Base.metadata.drop_all(bind=self.rattail_config.rattail_engine)
 
        def setup_rattail(self):
 
            config = self.make_rattail_config()
 
            self.config = config
 
            self.rattail_config = config
 

	
 
    def make_rattail_config(self, **kwargs):
 
        kwargs.setdefault('files', [])
 
        kwargs.setdefault('use_wuttaconfig', True)
 
        return make_config(**kwargs)
 
            engine = sa.create_engine(self.engine_url)
 
            config.rattail_engines = {'default': engine}
 
            config.rattail_engine = engine
 

	
 
    def get_rattail_model(self):
 
        return model
 
            if self.host_engine_url:
 
                config.rattail_engines['host'] = sa.create_engine(self.host_engine_url)
 

	
 
            model = self.get_rattail_model()
 
            model.Base.metadata.create_all(bind=engine)
 

	
 
class RattailTestCase(RattailMixin, TestCase):
 
    """
 
    Generic base class for Rattail tests.
 
    """
 
            Session.configure(bind=engine, rattail_config=config)
 
            self.session = Session()
 

	
 
        def teardown_rattail(self):
 
            self.session.close()
 
            Session.configure(bind=None, rattail_config=None)
 
            model = self.get_rattail_model()
 
            model.Base.metadata.drop_all(bind=self.rattail_config.rattail_engine)
 

	
 
class DataTestCase(TestCase):
 
        def make_rattail_config(self, **kwargs):
 
            kwargs.setdefault('files', [])
 
            kwargs.setdefault('use_wuttaconfig', True)
 
            return make_config(**kwargs)
 

	
 
    engine_url = os.environ.get('RATTAIL_TEST_ENGINE_URL', 'sqlite://')
 
        def get_rattail_model(self):
 
            return model
 

	
 
    def setUp(self):
 
        self.engine = sa.create_engine(self.engine_url)
 
        model.Base.metadata.create_all(bind=self.engine)
 
        Session.configure(bind=self.engine)
 
        self.session = Session()
 
        self.extra_setup()
 

	
 
    def extra_setup(self):
 
    class RattailTestCase(RattailMixin, TestCase):
 
        """
 
        Derivative classes may define this as necessary, to avoid having to
 
        override the :meth:`setUp()` method.
 
        Generic base class for Rattail tests.
 
        """
 

	
 
    def tearDown(self):
 
        self.session.close()
 
        Session.configure(bind=None)
 
        model.Base.metadata.drop_all(bind=self.engine)
 

	
 
    class DataTestCase(TestCase):
 

	
 
        engine_url = os.environ.get('RATTAIL_TEST_ENGINE_URL', 'sqlite://')
 

	
 
        def setUp(self):
 
            self.engine = sa.create_engine(self.engine_url)
 
            model.Base.metadata.create_all(bind=self.engine)
 
            Session.configure(bind=self.engine)
 
            self.session = Session()
 
            self.extra_setup()
 

	
 
        def extra_setup(self):
 
            """
 
            Derivative classes may define this as necessary, to avoid having to
 
            override the :meth:`setUp()` method.
 
            """
 

	
 
        def tearDown(self):
 
            self.session.close()
 
            Session.configure(bind=None)
 
            model.Base.metadata.drop_all(bind=self.engine)
tests/autocomplete/test_base.py
Show inline comments
 
# -*- coding: utf-8; -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from unittest import TestCase
 

	
 
import sqlalchemy as sa
 
import pytest
 

	
 
from rattail.autocomplete import base as mod
 
from rattail.config import make_config
 
from rattail.autocomplete import base as mod
 
from rattail.db import Session
 

	
 

	
 
@@ -33,7 +31,10 @@ class TestAutocompleter(TestCase):
 
        self.assertIsNotNone(autocompleter)
 

	
 
    def test_get_model_class(self):
 
        model = self.config.get_model()
 
        try:
 
            model = self.config.get_model()
 
        except ImportError:
 
            pytest.skip("test is not relevant without sqlalchemy")
 

	
 
        # no model class by default; hence error
 
        class BadAutocompleter(mod.Autocompleter):
 
@@ -66,6 +67,11 @@ class TestAutocompleter(TestCase):
 
        self.assertEqual(autocompleter.autocomplete_fieldname, 'name')
 

	
 
    def test_autocomplete(self):
 
        try:
 
            import sqlalchemy as sa
 
        except ImportError:
 
            pytest.skip("test is not relevant without sqlalchemy")
 

	
 
        engine = sa.create_engine('sqlite://')
 
        model = self.config.get_model()
 
        model.Base.metadata.create_all(bind=engine)
 
@@ -131,6 +137,11 @@ class TestPhoneMagicMixin(TestCase):
 
        return make_config([], extend=False)
 

	
 
    def test_autocomplete(self):
 
        try:
 
            import sqlalchemy as sa
 
        except ImportError:
 
            pytest.skip("test is not relevant without sqlalchemy")
 

	
 
        engine = sa.create_engine('sqlite://')
 
        model = self.config.get_model()
 
        model.Base.metadata.create_all(bind=engine)
tests/autocomplete/test_brands.py
Show inline comments
 
# -*- coding: utf-8; -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from unittest import TestCase
 

	
 
import sqlalchemy as sa
 

	
 
from rattail.autocomplete import brands as mod
 
from rattail.config import make_config
 
from rattail.db import Session
 

	
 

	
 
class TestBrandAutocompleter(TestCase):
 

	
 
    def setUp(self):
 
        self.config = self.make_config()
 
        self.autocompleter = self.make_autocompleter()
 

	
 
    def make_config(self):
 
        return make_config([], extend=False)
 

	
 
    def make_autocompleter(self):
 
        return mod.BrandAutocompleter(self.config)
 

	
 
    def test_autocomplete(self):
 
        engine = sa.create_engine('sqlite://')
 
        model = self.config.get_model()
 
        model.Base.metadata.create_all(bind=engine)
 
        session = Session(bind=engine)
 

	
 
        # first create a few brands
 
        alpha = model.Brand(name='Alpha Natural Foods')
 
        session.add(alpha)
 
        beta = model.Brand(name='Beta Natural Foods')
 
        session.add(beta)
 
        gamma = model.Brand(name='Gamma Natural Foods')
 
        session.add(gamma)
 

	
 
        # searching for nothing yields no results
 
        result = self.autocompleter.autocomplete(session, '')
 
        self.assertEqual(len(result), 0)
 

	
 
        # search for 'natural' yields all 3 brands
 
        result = self.autocompleter.autocomplete(session, 'natural')
 
        self.assertEqual(len(result), 3)
 

	
 
        # search for 'gamma' yields just that brand
 
        result = self.autocompleter.autocomplete(session, 'gamma')
 
        self.assertEqual(len(result), 1)
 
        self.assertEqual(result[0]['value'], gamma.uuid)
 
try:
 
    import sqlalchemy as sa
 
    from rattail.autocomplete import brands as mod
 
    from rattail.db import Session
 
except ImportError:
 
    pass
 
else:
 

	
 
    class TestBrandAutocompleter(TestCase):
 

	
 
        def setUp(self):
 
            self.config = self.make_config()
 
            self.autocompleter = self.make_autocompleter()
 

	
 
        def make_config(self):
 
            return make_config([], extend=False)
 

	
 
        def make_autocompleter(self):
 
            return mod.BrandAutocompleter(self.config)
 

	
 
        def test_autocomplete(self):
 
            engine = sa.create_engine('sqlite://')
 
            model = self.config.get_model()
 
            model.Base.metadata.create_all(bind=engine)
 
            session = Session(bind=engine)
 

	
 
            # first create a few brands
 
            alpha = model.Brand(name='Alpha Natural Foods')
 
            session.add(alpha)
 
            beta = model.Brand(name='Beta Natural Foods')
 
            session.add(beta)
 
            gamma = model.Brand(name='Gamma Natural Foods')
 
            session.add(gamma)
 

	
 
            # searching for nothing yields no results
 
            result = self.autocompleter.autocomplete(session, '')
 
            self.assertEqual(len(result), 0)
 

	
 
            # search for 'natural' yields all 3 brands
 
            result = self.autocompleter.autocomplete(session, 'natural')
 
            self.assertEqual(len(result), 3)
 

	
 
            # search for 'gamma' yields just that brand
 
            result = self.autocompleter.autocomplete(session, 'gamma')
 
            self.assertEqual(len(result), 1)
 
            self.assertEqual(result[0]['value'], gamma.uuid)
tests/autocomplete/test_customers.py
Show inline comments
 
# -*- coding: utf-8; -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from unittest import TestCase
 

	
 
import sqlalchemy as sa
 

	
 
from rattail.autocomplete import customers as mod
 
from rattail.config import make_config
 
from rattail.db import Session
 

	
 

	
 
class TestCustomerAutocompleter(TestCase):
 

	
 
    def setUp(self):
 
        self.config = self.make_config()
 
        self.autocompleter = self.make_autocompleter()
 

	
 
    def make_config(self):
 
        return make_config([], extend=False)
 

	
 
    def make_autocompleter(self):
 
        return mod.CustomerAutocompleter(self.config)
 

	
 
    def test_autocomplete(self):
 
        engine = sa.create_engine('sqlite://')
 
        model = self.config.get_model()
 
        model.Base.metadata.create_all(bind=engine)
 
        session = Session(bind=engine)
 
        enum = self.config.get_enum()
 

	
 
        # first create some customers
 
        alice = model.Customer(name='Alice Chalmers')
 
        session.add(alice)
 
        bob = model.Customer(name='Bob Loblaw')
 
        session.add(bob)
 
        charlie = model.Customer(name='Charlie Chaplin')
 
        session.add(charlie)
 

	
 
        # searching for nothing yields no results
 
        result = self.autocompleter.autocomplete(session, '')
 
        self.assertEqual(len(result), 0)
 

	
 
        # search for 'l' yields all 3 customers
 
        result = self.autocompleter.autocomplete(session, 'l')
 
        self.assertEqual(len(result), 3)
 

	
 
        # search for 'cha' yields just 2 customers
 
        result = self.autocompleter.autocomplete(session, 'cha')
 
        self.assertEqual(len(result), 2)
 
        uuids = [info['value'] for info in result]
 
        self.assertIn(alice.uuid, uuids)
 
        self.assertIn(charlie.uuid, uuids)
 

	
 
try:
 
    import sqlalchemy as sa
 
    from rattail.autocomplete import customers as mod
 
    from rattail.db import Session
 
except ImportError:
 
    pass
 
else:
 

	
 
    class TestCustomerAutocompleter(TestCase):
 

	
 
        def setUp(self):
 
            self.config = self.make_config()
 
            self.autocompleter = self.make_autocompleter()
 

	
 
        def make_config(self):
 
            return make_config([], extend=False)
 

	
 
        def make_autocompleter(self):
 
            return mod.CustomerAutocompleter(self.config)
 

	
 
        def test_autocomplete(self):
 
            engine = sa.create_engine('sqlite://')
 
            model = self.config.get_model()
 
            model.Base.metadata.create_all(bind=engine)
 
            session = Session(bind=engine)
 
            enum = self.config.get_enum()
 

	
 
            # first create some customers
 
            alice = model.Customer(name='Alice Chalmers')
 
            session.add(alice)
 
            bob = model.Customer(name='Bob Loblaw')
 
            session.add(bob)
 
            charlie = model.Customer(name='Charlie Chaplin')
 
            session.add(charlie)
 

	
 
            # searching for nothing yields no results
 
            result = self.autocompleter.autocomplete(session, '')
 
            self.assertEqual(len(result), 0)
 

	
 
            # search for 'l' yields all 3 customers
 
            result = self.autocompleter.autocomplete(session, 'l')
 
            self.assertEqual(len(result), 3)
 

	
 
            # search for 'cha' yields just 2 customers
 
            result = self.autocompleter.autocomplete(session, 'cha')
 
            self.assertEqual(len(result), 2)
 
            uuids = [info['value'] for info in result]
 
            self.assertIn(alice.uuid, uuids)
 
            self.assertIn(charlie.uuid, uuids)
tests/autocomplete/test_departments.py
Show inline comments
 
# -*- coding: utf-8; -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from unittest import TestCase
 

	
 
import sqlalchemy as sa
 

	
 
from rattail.autocomplete import departments as mod
 
from rattail.config import make_config
 
from rattail.db import Session
 

	
 

	
 
class TestDepartmentAutocompleter(TestCase):
 

	
 
    def setUp(self):
 
        self.config = self.make_config()
 
        self.autocompleter = self.make_autocompleter()
 

	
 
    def make_config(self):
 
        return make_config([], extend=False)
 

	
 
    def make_autocompleter(self):
 
        return mod.DepartmentAutocompleter(self.config)
 

	
 
    def test_autocomplete(self):
 
        engine = sa.create_engine('sqlite://')
 
        model = self.config.get_model()
 
        model.Base.metadata.create_all(bind=engine)
 
        session = Session(bind=engine)
 

	
 
        # first create a few departments
 
        grocery = model.Department(name='Grocery')
 
        session.add(grocery)
 
        wellness = model.Department(name='Wellness')
 
        session.add(wellness)
 
        bulk = model.Department(name='Bulk')
 
        session.add(bulk)
 

	
 
        # searching for nothing yields no results
 
        result = self.autocompleter.autocomplete(session, '')
 
        self.assertEqual(len(result), 0)
 

	
 
        # search for 'l' yields Wellness, Bulk
 
        result = self.autocompleter.autocomplete(session, 'l')
 
        self.assertEqual(len(result), 2)
 
        uuids = [info['value'] for info in result]
 
        self.assertIn(wellness.uuid, uuids)
 
        self.assertIn(bulk.uuid, uuids)
 

	
 
        # search for 'grocery' yields just that department
 
        result = self.autocompleter.autocomplete(session, 'grocery')
 
        self.assertEqual(len(result), 1)
 
        self.assertEqual(result[0]['value'], grocery.uuid)
 

	
 
try:
 
    import sqlalchemy as sa
 
    from rattail.autocomplete import departments as mod
 
    from rattail.db import Session
 
except ImportError:
 
    pass
 
else:
 

	
 
    class TestDepartmentAutocompleter(TestCase):
 

	
 
        def setUp(self):
 
            self.config = self.make_config()
 
            self.autocompleter = self.make_autocompleter()
 

	
 
        def make_config(self):
 
            return make_config([], extend=False)
 

	
 
        def make_autocompleter(self):
 
            return mod.DepartmentAutocompleter(self.config)
 

	
 
        def test_autocomplete(self):
 
            engine = sa.create_engine('sqlite://')
 
            model = self.config.get_model()
 
            model.Base.metadata.create_all(bind=engine)
 
            session = Session(bind=engine)
 

	
 
            # first create a few departments
 
            grocery = model.Department(name='Grocery')
 
            session.add(grocery)
 
            wellness = model.Department(name='Wellness')
 
            session.add(wellness)
 
            bulk = model.Department(name='Bulk')
 
            session.add(bulk)
 

	
 
            # searching for nothing yields no results
 
            result = self.autocompleter.autocomplete(session, '')
 
            self.assertEqual(len(result), 0)
 

	
 
            # search for 'l' yields Wellness, Bulk
 
            result = self.autocompleter.autocomplete(session, 'l')
 
            self.assertEqual(len(result), 2)
 
            uuids = [info['value'] for info in result]
 
            self.assertIn(wellness.uuid, uuids)
 
            self.assertIn(bulk.uuid, uuids)
 

	
 
            # search for 'grocery' yields just that department
 
            result = self.autocompleter.autocomplete(session, 'grocery')
 
            self.assertEqual(len(result), 1)
 
            self.assertEqual(result[0]['value'], grocery.uuid)
tests/autocomplete/test_employees.py
Show inline comments
 
# -*- coding: utf-8; -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from unittest import TestCase
 

	
 
import sqlalchemy as sa
 

	
 
from rattail.autocomplete import employees as mod
 
from rattail.config import make_config
 
from rattail.db import Session
 

	
 

	
 
class TestEmployeeAutocompleter(TestCase):
 

	
 
    def setUp(self):
 
        self.config = self.make_config()
 
        self.autocompleter = self.make_autocompleter()
 

	
 
    def make_config(self):
 
        return make_config([], extend=False)
 

	
 
    def make_autocompleter(self):
 
        return mod.EmployeeAutocompleter(self.config)
 

	
 
    def test_autocomplete(self):
 
        engine = sa.create_engine('sqlite://')
 
        model = self.config.get_model()
 
        model.Base.metadata.create_all(bind=engine)
 
        session = Session(bind=engine)
 
        enum = self.config.get_enum()
 

	
 
        # first create some employees
 
        alice = model.Person(display_name='Alice Chalmers')
 
        alice.employee = model.Employee(status=enum.EMPLOYEE_STATUS_CURRENT)
 
        session.add(alice)
 
        bob = model.Person(display_name='Bob Loblaw')
 
        bob.employee = model.Employee(status=enum.EMPLOYEE_STATUS_CURRENT)
 
        session.add(bob)
 
        charlie = model.Person(display_name='Charlie Chaplin')
 
        charlie.employee = model.Employee(status=enum.EMPLOYEE_STATUS_FORMER)
 
        session.add(charlie)
 

	
 
        # searching for nothing yields no results
 
        result = self.autocompleter.autocomplete(session, '')
 
        self.assertEqual(len(result), 0)
 

	
 
        # search for 'l' yields only 2 current employees
 
        result = self.autocompleter.autocomplete(session, 'l')
 
        self.assertEqual(len(result), 2)
 

	
 
        # search for 'alice' yields just Alice Chalmers
 
        result = self.autocompleter.autocomplete(session, 'alice')
 
        self.assertEqual(len(result), 1)
 
        self.assertEqual(result[0]['value'], alice.employee.uuid)
 

	
 
try:
 
    import sqlalchemy as sa
 
    from rattail.autocomplete import employees as mod
 
    from rattail.db import Session
 
except ImportError:
 
    pass
 
else:
 

	
 
    class TestEmployeeAutocompleter(TestCase):
 

	
 
        def setUp(self):
 
            self.config = self.make_config()
 
            self.autocompleter = self.make_autocompleter()
 

	
 
        def make_config(self):
 
            return make_config([], extend=False)
 

	
 
        def make_autocompleter(self):
 
            return mod.EmployeeAutocompleter(self.config)
 

	
 
        def test_autocomplete(self):
 
            engine = sa.create_engine('sqlite://')
 
            model = self.config.get_model()
 
            model.Base.metadata.create_all(bind=engine)
 
            session = Session(bind=engine)
 
            enum = self.config.get_enum()
 

	
 
            # first create some employees
 
            alice = model.Person(display_name='Alice Chalmers')
 
            alice.employee = model.Employee(status=enum.EMPLOYEE_STATUS_CURRENT)
 
            session.add(alice)
 
            bob = model.Person(display_name='Bob Loblaw')
 
            bob.employee = model.Employee(status=enum.EMPLOYEE_STATUS_CURRENT)
 
            session.add(bob)
 
            charlie = model.Person(display_name='Charlie Chaplin')
 
            charlie.employee = model.Employee(status=enum.EMPLOYEE_STATUS_FORMER)
 
            session.add(charlie)
 

	
 
            # searching for nothing yields no results
 
            result = self.autocompleter.autocomplete(session, '')
 
            self.assertEqual(len(result), 0)
 

	
 
            # search for 'l' yields only 2 current employees
 
            result = self.autocompleter.autocomplete(session, 'l')
 
            self.assertEqual(len(result), 2)
 

	
 
            # search for 'alice' yields just Alice Chalmers
 
            result = self.autocompleter.autocomplete(session, 'alice')
 
            self.assertEqual(len(result), 1)
 
            self.assertEqual(result[0]['value'], alice.employee.uuid)
tests/autocomplete/test_people.py
Show inline comments
 
# -*- coding: utf-8; -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from unittest import TestCase
 

	
 
import sqlalchemy as sa
 

	
 
from rattail.autocomplete import people as mod
 
from rattail.config import make_config
 
from rattail.db import Session
 

	
 

	
 
class TestPersonAutocompleter(TestCase):
 

	
 
    def setUp(self):
 
        self.config = self.make_config()
 
        self.autocompleter = self.make_autocompleter()
 

	
 
    def make_config(self):
 
        return make_config([], extend=False)
 

	
 
    def make_autocompleter(self):
 
        return mod.PersonAutocompleter(self.config)
 

	
 
    def test_autocomplete(self):
 
        engine = sa.create_engine('sqlite://')
 
        model = self.config.get_model()
 
        model.Base.metadata.create_all(bind=engine)
 
        session = Session(bind=engine)
 
        enum = self.config.get_enum()
 

	
 
        # first create some people
 
        alice = model.Person(display_name='Alice Chalmers')
 
        session.add(alice)
 
        bob = model.Person(display_name='Bob Loblaw')
 
        session.add(bob)
 
        charlie = model.Person(display_name='Charlie Chaplin')
 
        session.add(charlie)
 

	
 
        # searching for nothing yields no results
 
        result = self.autocompleter.autocomplete(session, '')
 
        self.assertEqual(len(result), 0)
 

	
 
        # search for 'l' yields all 3 people
 
        result = self.autocompleter.autocomplete(session, 'l')
 
        self.assertEqual(len(result), 3)
 

	
 
        # search for 'cha' yields just 2 people
 
        result = self.autocompleter.autocomplete(session, 'cha')
 
        self.assertEqual(len(result), 2)
 
        uuids = [info['value'] for info in result]
 
        self.assertIn(alice.uuid, uuids)
 
        self.assertIn(charlie.uuid, uuids)
 

	
 

	
 
class TestPersonEmployeeAutocompleter(TestCase):
 

	
 
    def setUp(self):
 
        self.config = self.make_config()
 
        self.autocompleter = self.make_autocompleter()
 

	
 
    def make_config(self):
 
        return make_config([], extend=False)
 

	
 
    def make_autocompleter(self):
 
        return mod.PersonEmployeeAutocompleter(self.config)
 

	
 
    def test_autocomplete(self):
 
        engine = sa.create_engine('sqlite://')
 
        model = self.config.get_model()
 
        model.Base.metadata.create_all(bind=engine)
 
        session = Session(bind=engine)
 
        enum = self.config.get_enum()
 

	
 
        # first create some people
 
        alice = model.Person(display_name='Alice Chalmers')
 
        session.add(alice)
 
        bob = model.Person(display_name='Bob Loblaw')
 
        bob.employee = model.Employee(status=enum.EMPLOYEE_STATUS_CURRENT)
 
        session.add(bob)
 
        charlie = model.Person(display_name='Charlie Chaplin')
 
        charlie.employee = model.Employee(status=enum.EMPLOYEE_STATUS_FORMER)
 
        session.add(charlie)
 

	
 
        # searching for nothing yields no results
 
        result = self.autocompleter.autocomplete(session, '')
 
        self.assertEqual(len(result), 0)
 

	
 
        # search for 'l' yields only Bob, Charlie
 
        result = self.autocompleter.autocomplete(session, 'l')
 
        self.assertEqual(len(result), 2)
 
        uuids = [info['value'] for info in result]
 
        self.assertIn(bob.uuid, uuids)
 
        self.assertIn(charlie.uuid, uuids)
 

	
 
        # search for 'cha' yields just Charlie
 
        result = self.autocompleter.autocomplete(session, 'cha')
 
        self.assertEqual(len(result), 1)
 
        self.assertEqual(result[0]['value'], charlie.uuid)
 

	
 
try:
 
    import sqlalchemy as sa
 
    from rattail.autocomplete import people as mod
 
    from rattail.db import Session
 
except ImportError:
 
    pass
 
else:
 

	
 
    class TestPersonAutocompleter(TestCase):
 

	
 
        def setUp(self):
 
            self.config = self.make_config()
 
            self.autocompleter = self.make_autocompleter()
 

	
 
        def make_config(self):
 
            return make_config([], extend=False)
 

	
 
        def make_autocompleter(self):
 
            return mod.PersonAutocompleter(self.config)
 

	
 
        def test_autocomplete(self):
 
            engine = sa.create_engine('sqlite://')
 
            model = self.config.get_model()
 
            model.Base.metadata.create_all(bind=engine)
 
            session = Session(bind=engine)
 
            enum = self.config.get_enum()
 

	
 
            # first create some people
 
            alice = model.Person(display_name='Alice Chalmers')
 
            session.add(alice)
 
            bob = model.Person(display_name='Bob Loblaw')
 
            session.add(bob)
 
            charlie = model.Person(display_name='Charlie Chaplin')
 
            session.add(charlie)
 

	
 
            # searching for nothing yields no results
 
            result = self.autocompleter.autocomplete(session, '')
 
            self.assertEqual(len(result), 0)
 

	
 
            # search for 'l' yields all 3 people
 
            result = self.autocompleter.autocomplete(session, 'l')
 
            self.assertEqual(len(result), 3)
 

	
 
            # search for 'cha' yields just 2 people
 
            result = self.autocompleter.autocomplete(session, 'cha')
 
            self.assertEqual(len(result), 2)
 
            uuids = [info['value'] for info in result]
 
            self.assertIn(alice.uuid, uuids)
 
            self.assertIn(charlie.uuid, uuids)
 

	
 

	
 
    class TestPersonEmployeeAutocompleter(TestCase):
 

	
 
        def setUp(self):
 
            self.config = self.make_config()
 
            self.autocompleter = self.make_autocompleter()
 

	
 
        def make_config(self):
 
            return make_config([], extend=False)
 

	
 
        def make_autocompleter(self):
 
            return mod.PersonEmployeeAutocompleter(self.config)
 

	
 
        def test_autocomplete(self):
 
            engine = sa.create_engine('sqlite://')
 
            model = self.config.get_model()
 
            model.Base.metadata.create_all(bind=engine)
 
            session = Session(bind=engine)
 
            enum = self.config.get_enum()
 

	
 
            # first create some people
 
            alice = model.Person(display_name='Alice Chalmers')
 
            session.add(alice)
 
            bob = model.Person(display_name='Bob Loblaw')
 
            bob.employee = model.Employee(status=enum.EMPLOYEE_STATUS_CURRENT)
 
            session.add(bob)
 
            charlie = model.Person(display_name='Charlie Chaplin')
 
            charlie.employee = model.Employee(status=enum.EMPLOYEE_STATUS_FORMER)
 
            session.add(charlie)
 

	
 
            # searching for nothing yields no results
 
            result = self.autocompleter.autocomplete(session, '')
 
            self.assertEqual(len(result), 0)
 

	
 
            # search for 'l' yields only Bob, Charlie
 
            result = self.autocompleter.autocomplete(session, 'l')
 
            self.assertEqual(len(result), 2)
 
            uuids = [info['value'] for info in result]
 
            self.assertIn(bob.uuid, uuids)
 
            self.assertIn(charlie.uuid, uuids)
 

	
 
            # search for 'cha' yields just Charlie
 
            result = self.autocompleter.autocomplete(session, 'cha')
 
            self.assertEqual(len(result), 1)
 
            self.assertEqual(result[0]['value'], charlie.uuid)
tests/autocomplete/test_products.py
Show inline comments
 
# -*- coding: utf-8; -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from unittest import TestCase
 

	
 
import sqlalchemy as sa
 

	
 
from rattail.autocomplete import products as mod
 
from rattail.config import make_config
 
from rattail.db import Session
 

	
 

	
 
class AutocompleterTestCase(TestCase):
 

	
 
    def setUp(self):
 
        self.config = self.make_config()
 
        self.autocompleter = self.make_autocompleter()
 
try:
 
    import sqlalchemy as sa
 
    from rattail.autocomplete import products as mod
 
    from rattail.db import Session
 
except ImportError:
 
    pass
 
else:
 

	
 
    class AutocompleterTestCase(TestCase):
 

	
 
        self.engine = sa.create_engine('sqlite://')
 
        self.model = self.config.get_model()
 
        self.model.Base.metadata.create_all(bind=self.engine)
 
        self.session = Session(bind=self.engine)
 

	
 
    def tearDown(self):
 
        self.session.rollback()
 
        self.session.close()
 

	
 
    def make_config(self):
 
        return make_config([], extend=False)
 

	
 

	
 
class TestProductAutocompleter(AutocompleterTestCase):
 

	
 
    def make_autocompleter(self):
 
        return mod.ProductAutocompleter(self.config)
 

	
 
    def test_autocomplete(self):
 
        session = self.session
 
        model = self.model
 

	
 
        # first create a few products
 
        vinegar = model.Product(description='Apple Cider Vinegar')
 
        session.add(vinegar)
 
        dressing = model.Product(description='Apple Cider Dressing')
 
        session.add(dressing)
 
        oats = model.Product(description='Bulk Oats')
 
        session.add(oats)
 
        deleted = model.Product(description='More Oats', deleted=True)
 
        session.add(deleted)
 

	
 
        # searching for nothing yields no results
 
        result = self.autocompleter.autocomplete(session, '')
 
        self.assertEqual(len(result), 0)
 

	
 
        # search for 'apple' yields Vinegar, Dressing
 
        result = self.autocompleter.autocomplete(session, 'apple')
 
        self.assertEqual(len(result), 2)
 
        uuids = [info['value'] for info in result]
 
        self.assertIn(vinegar.uuid, uuids)
 
        self.assertIn(dressing.uuid, uuids)
 

	
 
        # search for 'oats' yields just the undeleted product
 
        result = self.autocompleter.autocomplete(session, 'oats')
 
        self.assertEqual(len(result), 1)
 
        self.assertEqual(result[0]['value'], oats.uuid)
 

	
 

	
 
class TestProductAllAutocompleter(AutocompleterTestCase):
 

	
 
    def make_autocompleter(self):
 
        return mod.ProductAllAutocompleter(self.config)
 

	
 
    def test_autocomplete(self):
 
        session = self.session
 
        model = self.model
 

	
 
        # first create a few products
 
        vinegar = model.Product(description='Apple Cider Vinegar')
 
        session.add(vinegar)
 
        dressing = model.Product(description='Apple Cider Dressing')
 
        session.add(dressing)
 
        oats = model.Product(description='Bulk Oats')
 
        session.add(oats)
 
        deleted = model.Product(description='More Oats', deleted=True)
 
        session.add(deleted)
 

	
 
        # searching for nothing yields no results
 
        result = self.autocompleter.autocomplete(session, '')
 
        self.assertEqual(len(result), 0)
 

	
 
        # search for 'apple' yields Vinegar, Dressing
 
        result = self.autocompleter.autocomplete(session, 'apple')
 
        self.assertEqual(len(result), 2)
 
        uuids = [info['value'] for info in result]
 
        self.assertIn(vinegar.uuid, uuids)
 
        self.assertIn(dressing.uuid, uuids)
 

	
 
        # search for 'oats' yields Bulk, More
 
        result = self.autocompleter.autocomplete(session, 'oats')
 
        self.assertEqual(len(result), 2)
 
        uuids = [info['value'] for info in result]
 
        self.assertIn(oats.uuid, uuids)
 
        self.assertIn(deleted.uuid, uuids)
 

	
 

	
 
class TestProductNewOrderAutocompleter(AutocompleterTestCase):
 

	
 
    def make_autocompleter(self):
 
        return mod.ProductNewOrderAutocompleter(self.config)
 

	
 
    def test_autocomplete(self):
 
        session = self.session
 
        model = self.model
 

	
 
        # first create a few products
 
        vinegar = model.Product(description='Apple Cider Vinegar',
 
                                upc='074305001321')
 
        session.add(vinegar)
 
        dressing = model.Product(description='Apple Cider Dressing')
 
        session.add(dressing)
 
        oats = model.Product(description='Bulk Oats')
 
        session.add(oats)
 
        deleted = model.Product(description='More Oats', deleted=True)
 
        session.add(deleted)
 

	
 
        # searching for nothing yields no results
 
        result = self.autocompleter.autocomplete(session, '')
 
        self.assertEqual(len(result), 0)
 

	
 
        # search for 'apple' yields Vinegar, Dressing
 
        result = self.autocompleter.autocomplete(session, 'apple')
 
        self.assertEqual(len(result), 2)
 
        uuids = [info['value'] for info in result]
 
        self.assertIn(vinegar.uuid, uuids)
 
        self.assertIn(dressing.uuid, uuids)
 

	
 
        # search for unknown upc yields no results
 
        result = self.autocompleter.autocomplete(session, '7430500116')
 
        self.assertEqual(len(result), 0)
 

	
 
        # search for known upc yields just that product
 
        result = self.autocompleter.autocomplete(session, '7430500132')
 
        self.assertEqual(len(result), 1)
 
        self.assertEqual(result[0]['value'], vinegar.uuid)
 
        def setUp(self):
 
            self.config = self.make_config()
 
            self.autocompleter = self.make_autocompleter()
 

	
 
            self.engine = sa.create_engine('sqlite://')
 
            self.model = self.config.get_model()
 
            self.model.Base.metadata.create_all(bind=self.engine)
 
            self.session = Session(bind=self.engine)
 

	
 
        def tearDown(self):
 
            self.session.rollback()
 
            self.session.close()
 

	
 
        def make_config(self):
 
            return make_config([], extend=False)
 

	
 

	
 
    class TestProductAutocompleter(AutocompleterTestCase):
 

	
 
        def make_autocompleter(self):
 
            return mod.ProductAutocompleter(self.config)
 

	
 
        def test_autocomplete(self):
 
            session = self.session
 
            model = self.model
 

	
 
            # first create a few products
 
            vinegar = model.Product(description='Apple Cider Vinegar')
 
            session.add(vinegar)
 
            dressing = model.Product(description='Apple Cider Dressing')
 
            session.add(dressing)
 
            oats = model.Product(description='Bulk Oats')
 
            session.add(oats)
 
            deleted = model.Product(description='More Oats', deleted=True)
 
            session.add(deleted)
 

	
 
            # searching for nothing yields no results
 
            result = self.autocompleter.autocomplete(session, '')
 
            self.assertEqual(len(result), 0)
 

	
 
            # search for 'apple' yields Vinegar, Dressing
 
            result = self.autocompleter.autocomplete(session, 'apple')
 
            self.assertEqual(len(result), 2)
 
            uuids = [info['value'] for info in result]
 
            self.assertIn(vinegar.uuid, uuids)
 
            self.assertIn(dressing.uuid, uuids)
 

	
 
            # search for 'oats' yields just the undeleted product
 
            result = self.autocompleter.autocomplete(session, 'oats')
 
            self.assertEqual(len(result), 1)
 
            self.assertEqual(result[0]['value'], oats.uuid)
 

	
 

	
 
    class TestProductAllAutocompleter(AutocompleterTestCase):
 

	
 
        def make_autocompleter(self):
 
            return mod.ProductAllAutocompleter(self.config)
 

	
 
        def test_autocomplete(self):
 
            session = self.session
 
            model = self.model
 

	
 
            # first create a few products
 
            vinegar = model.Product(description='Apple Cider Vinegar')
 
            session.add(vinegar)
 
            dressing = model.Product(description='Apple Cider Dressing')
 
            session.add(dressing)
 
            oats = model.Product(description='Bulk Oats')
 
            session.add(oats)
 
            deleted = model.Product(description='More Oats', deleted=True)
 
            session.add(deleted)
 

	
 
            # searching for nothing yields no results
 
            result = self.autocompleter.autocomplete(session, '')
 
            self.assertEqual(len(result), 0)
 

	
 
            # search for 'apple' yields Vinegar, Dressing
 
            result = self.autocompleter.autocomplete(session, 'apple')
 
            self.assertEqual(len(result), 2)
 
            uuids = [info['value'] for info in result]
 
            self.assertIn(vinegar.uuid, uuids)
 
            self.assertIn(dressing.uuid, uuids)
 

	
 
            # search for 'oats' yields Bulk, More
 
            result = self.autocompleter.autocomplete(session, 'oats')
 
            self.assertEqual(len(result), 2)
 
            uuids = [info['value'] for info in result]
 
            self.assertIn(oats.uuid, uuids)
 
            self.assertIn(deleted.uuid, uuids)
 

	
 

	
 
    class TestProductNewOrderAutocompleter(AutocompleterTestCase):
 

	
 
        def make_autocompleter(self):
 
            return mod.ProductNewOrderAutocompleter(self.config)
 

	
 
        def test_autocomplete(self):
 
            session = self.session
 
            model = self.model
 

	
 
            # first create a few products
 
            vinegar = model.Product(description='Apple Cider Vinegar',
 
                                    upc='074305001321')
 
            session.add(vinegar)
 
            dressing = model.Product(description='Apple Cider Dressing')
 
            session.add(dressing)
 
            oats = model.Product(description='Bulk Oats')
 
            session.add(oats)
 
            deleted = model.Product(description='More Oats', deleted=True)
 
            session.add(deleted)
 

	
 
            # searching for nothing yields no results
 
            result = self.autocompleter.autocomplete(session, '')
 
            self.assertEqual(len(result), 0)
 

	
 
            # search for 'apple' yields Vinegar, Dressing
 
            result = self.autocompleter.autocomplete(session, 'apple')
 
            self.assertEqual(len(result), 2)
 
            uuids = [info['value'] for info in result]
 
            self.assertIn(vinegar.uuid, uuids)
 
            self.assertIn(dressing.uuid, uuids)
 

	
 
            # search for unknown upc yields no results
 
            result = self.autocompleter.autocomplete(session, '7430500116')
 
            self.assertEqual(len(result), 0)
 

	
 
            # search for known upc yields just that product
 
            result = self.autocompleter.autocomplete(session, '7430500132')
 
            self.assertEqual(len(result), 1)
 
            self.assertEqual(result[0]['value'], vinegar.uuid)
tests/autocomplete/test_vendors.py
Show inline comments
 
# -*- coding: utf-8; -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from unittest import TestCase
 

	
 
import sqlalchemy as sa
 

	
 
from rattail.autocomplete import vendors as mod
 
from rattail.config import make_config
 
from rattail.db import Session
 

	
 

	
 
class TestVendorAutocompleter(TestCase):
 

	
 
    def setUp(self):
 
        self.config = self.make_config()
 
        self.autocompleter = self.make_autocompleter()
 

	
 
    def make_config(self):
 
        return make_config([], extend=False)
 

	
 
    def make_autocompleter(self):
 
        return mod.VendorAutocompleter(self.config)
 

	
 
    def test_autocomplete(self):
 
        engine = sa.create_engine('sqlite://')
 
        model = self.config.get_model()
 
        model.Base.metadata.create_all(bind=engine)
 
        session = Session(bind=engine)
 

	
 
        # first create some vendors
 
        acme = model.Vendor(name='Acme Wholesale Foods')
 
        session.add(acme)
 
        bigboy = model.Vendor(name='Big Boy Distributors')
 
        session.add(bigboy)
 

	
 
        # searching for nothing yields no results
 
        result = self.autocompleter.autocomplete(session, '')
 
        self.assertEqual(len(result), 0)
 

	
 
        # search for 'd' yields both vendors
 
        result = self.autocompleter.autocomplete(session, 'd')
 
        self.assertEqual(len(result), 2)
 

	
 
        # search for 'big' yields just Big Boy
 
        result = self.autocompleter.autocomplete(session, 'big')
 
        self.assertEqual(len(result), 1)
 
        self.assertEqual(result[0]['value'], bigboy.uuid)
 
try:
 
    import sqlalchemy as sa
 
    from rattail.autocomplete import vendors as mod
 
    from rattail.db import Session
 
except ImportError:
 
    pass
 
else:
 

	
 
    class TestVendorAutocompleter(TestCase):
 

	
 
        def setUp(self):
 
            self.config = self.make_config()
 
            self.autocompleter = self.make_autocompleter()
 

	
 
        def make_config(self):
 
            return make_config([], extend=False)
 

	
 
        def make_autocompleter(self):
 
            return mod.VendorAutocompleter(self.config)
 

	
 
        def test_autocomplete(self):
 
            engine = sa.create_engine('sqlite://')
 
            model = self.config.get_model()
 
            model.Base.metadata.create_all(bind=engine)
 
            session = Session(bind=engine)
 

	
 
            # first create some vendors
 
            acme = model.Vendor(name='Acme Wholesale Foods')
 
            session.add(acme)
 
            bigboy = model.Vendor(name='Big Boy Distributors')
 
            session.add(bigboy)
 

	
 
            # searching for nothing yields no results
 
            result = self.autocompleter.autocomplete(session, '')
 
            self.assertEqual(len(result), 0)
 

	
 
            # search for 'd' yields both vendors
 
            result = self.autocompleter.autocomplete(session, 'd')
 
            self.assertEqual(len(result), 2)
 

	
 
            # search for 'big' yields just Big Boy
 
            result = self.autocompleter.autocomplete(session, 'big')
 
            self.assertEqual(len(result), 1)
 
            self.assertEqual(result[0]['value'], bigboy.uuid)
tests/batch/test_handheld.py
Show inline comments
 
# -*- coding: utf-8; -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from unittest import TestCase
 

	
 
import sqlalchemy as sa
 

	
 
from rattail.batch import handheld as mod
 
from rattail.config import make_config
 
from rattail.db import Session
 

	
 

	
 
class TestHandheldBatchHandler(TestCase):
 

	
 
    def setUp(self):
 
        self.config = self.make_config()
 
        self.handler = self.make_handler()
 

	
 
    def make_config(self):
 
        return make_config([], extend=False)
 

	
 
    def make_handler(self):
 
        return mod.HandheldBatchHandler(self.config)
 

	
 
    def test_make_inventory_batch(self):
 
        engine = sa.create_engine('sqlite://')
 
        model = self.config.get_model()
 
        model.Base.metadata.create_all(bind=engine)
 
        session = Session(bind=engine)
 

	
 
        # prep data
 
        betty = model.User(username='betty')
 
        handbatch = model.HandheldBatch(id=1, created_by=betty)
 
        session.add(handbatch)
 
        session.commit()
 

	
 
        # make basic inventory batch
 
        invbatch = self.handler.make_inventory_batch([handbatch], betty, id=2)
 
        self.assertIsNotNone(invbatch)
 
        self.assertEqual(invbatch.id, 2)
 

	
 
    def test_make_label_batch(self):
 
        engine = sa.create_engine('sqlite://')
 
        model = self.config.get_model()
 
        model.Base.metadata.create_all(bind=engine)
 
        session = Session(bind=engine)
 

	
 
        # prep data
 
        betty = model.User(username='betty')
 
        handbatch = model.HandheldBatch(id=1, created_by=betty)
 
        session.add(handbatch)
 
        session.commit()
 

	
 
        # make basic label batch
 
        lblbatch = self.handler.make_label_batch([handbatch], betty, id=2)
 
        self.assertIsNotNone(lblbatch)
 
        self.assertEqual(lblbatch.id, 2)
 

	
 
try:
 
    import sqlalchemy as sa
 
    from rattail.db import Session
 
    from rattail.batch import handheld as mod
 
except ImportError:
 
    pass
 
else:
 

	
 
    class TestHandheldBatchHandler(TestCase):
 

	
 
        def setUp(self):
 
            self.config = self.make_config()
 
            self.handler = self.make_handler()
 

	
 
        def make_config(self):
 
            return make_config([], extend=False)
 

	
 
        def make_handler(self):
 
            return mod.HandheldBatchHandler(self.config)
 

	
 
        def test_make_inventory_batch(self):
 
            engine = sa.create_engine('sqlite://')
 
            model = self.config.get_model()
 
            model.Base.metadata.create_all(bind=engine)
 
            session = Session(bind=engine)
 

	
 
            # prep data
 
            betty = model.User(username='betty')
 
            handbatch = model.HandheldBatch(id=1, created_by=betty)
 
            session.add(handbatch)
 
            session.commit()
 

	
 
            # make basic inventory batch
 
            invbatch = self.handler.make_inventory_batch([handbatch], betty, id=2)
 
            self.assertIsNotNone(invbatch)
 
            self.assertEqual(invbatch.id, 2)
 

	
 
        def test_make_label_batch(self):
 
            engine = sa.create_engine('sqlite://')
 
            model = self.config.get_model()
 
            model.Base.metadata.create_all(bind=engine)
 
            session = Session(bind=engine)
 

	
 
            # prep data
 
            betty = model.User(username='betty')
 
            handbatch = model.HandheldBatch(id=1, created_by=betty)
 
            session.add(handbatch)
 
            session.commit()
 

	
 
            # make basic label batch
 
            lblbatch = self.handler.make_label_batch([handbatch], betty, id=2)
 
            self.assertIsNotNone(lblbatch)
 
            self.assertEqual(lblbatch.id, 2)
tests/batch/test_handlers.py
Show inline comments
 
# -*- coding: utf-8; -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
import os
 
from unittest import TestCase
 

	
 
import sqlalchemy as sa
 

	
 
from rattail.batch import handlers as mod
 
from rattail.config import make_config
 
from rattail.db import Session
 

	
 

	
 
class TestBatchHandler(TestCase):
 

	
 
    def setUp(self):
 
        self.config = self.make_config()
 
        self.handler = self.make_handler()
 

	
 
    def make_config(self):
 
        return make_config([], extend=False)
 

	
 
    def make_handler(self):
 
        return mod.BatchHandler(self.config)
 

	
 
    def test_consume_batch_id(self):
 
        engine = sa.create_engine('sqlite://')
 
        model = self.config.get_model()
 
        model.Base.metadata.create_all(bind=engine)
 
        session = Session(bind=engine)
 

	
 
        # first id is 1
 
        result = self.handler.consume_batch_id(session)
 
        self.assertEqual(result, 1)
 

	
 
        # second is 2; test string version
 
        result = self.handler.consume_batch_id(session, as_str=True)
 
        self.assertEqual(result, '00000002')
 

	
 
    def test_get_effective_rows(self):
 
        engine = sa.create_engine('sqlite://')
 
        model = self.config.get_model()
 
        model.Base.metadata.create_all(bind=engine)
 
        session = Session(bind=engine)
 

	
 
        # make batch w/ 3 rows
 
        user = model.User(username='patty')
 
        batch = model.NewProductBatch(id=1, created_by=user)
 
        batch.data_rows.append(model.NewProductBatchRow())
 
        batch.data_rows.append(model.NewProductBatchRow())
 
        batch.data_rows.append(model.NewProductBatchRow())
 
        self.assertEqual(len(batch.data_rows), 3)
 

	
 
        # all rows should be effective by default
 
        result = self.handler.get_effective_rows(batch)
 
        self.assertEqual(len(result), 3)
 

	
 
        # unless we mark one as "removed"
 
        batch.data_rows[1].removed = True
 
        result = self.handler.get_effective_rows(batch)
 
        self.assertEqual(len(result), 2)
 

	
 
        # or if we delete one
 
        batch.data_rows.pop(-1)
 
        result = self.handler.get_effective_rows(batch)
 
        self.assertEqual(len(result), 1)
 
try:
 
    import sqlalchemy as sa
 
except ImportError:
 
    pass
 
else:
 

	
 
    class TestBatchHandler(TestCase):
 

	
 
        def setUp(self):
 
            self.config = self.make_config()
 
            self.handler = self.make_handler()
 

	
 
        def make_config(self):
 
            return make_config([], extend=False)
 

	
 
        def make_handler(self):
 
            return mod.BatchHandler(self.config)
 

	
 
        def test_consume_batch_id(self):
 
            engine = sa.create_engine('sqlite://')
 
            model = self.config.get_model()
 
            model.Base.metadata.create_all(bind=engine)
 
            session = Session(bind=engine)
 

	
 
            # first id is 1
 
            result = self.handler.consume_batch_id(session)
 
            self.assertEqual(result, 1)
 

	
 
            # second is 2; test string version
 
            result = self.handler.consume_batch_id(session, as_str=True)
 
            self.assertEqual(result, '00000002')
 

	
 
        def test_get_effective_rows(self):
 
            engine = sa.create_engine('sqlite://')
 
            model = self.config.get_model()
 
            model.Base.metadata.create_all(bind=engine)
 
            session = Session(bind=engine)
 

	
 
            # make batch w/ 3 rows
 
            user = model.User(username='patty')
 
            batch = model.NewProductBatch(id=1, created_by=user)
 
            batch.data_rows.append(model.NewProductBatchRow())
 
            batch.data_rows.append(model.NewProductBatchRow())
 
            batch.data_rows.append(model.NewProductBatchRow())
 
            self.assertEqual(len(batch.data_rows), 3)
 

	
 
            # all rows should be effective by default
 
            result = self.handler.get_effective_rows(batch)
 
            self.assertEqual(len(result), 3)
 

	
 
            # unless we mark one as "removed"
 
            batch.data_rows[1].removed = True
 
            result = self.handler.get_effective_rows(batch)
 
            self.assertEqual(len(result), 2)
 

	
 
            # or if we delete one
 
            batch.data_rows.pop(-1)
 
            result = self.handler.get_effective_rows(batch)
 
            self.assertEqual(len(result), 1)
tests/batch/test_product.py
Show inline comments
 
# -*- coding: utf-8; -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from unittest import TestCase
 

	
 
import sqlalchemy as sa
 

	
 
from rattail.batch import product as mod
 
from rattail.config import make_config
 
from rattail.db import Session
 

	
 

	
 
class TestProductBatchHandler(TestCase):
 

	
 
    def setUp(self):
 
        self.config = self.make_config()
 
        self.handler = self.make_handler()
 

	
 
    def make_config(self):
 
        return make_config([], extend=False)
 

	
 
    def make_handler(self):
 
        return mod.ProductBatchHandler(self.config)
 

	
 
    def test_make_label_batch(self):
 
        engine = sa.create_engine('sqlite://')
 
        model = self.config.get_model()
 
        model.Base.metadata.create_all(bind=engine)
 
        session = Session(bind=engine)
 

	
 
        # prep data
 
        betty = model.User(username='betty')
 
        prodbatch = model.ProductBatch(id=1, created_by=betty)
 
        session.add(prodbatch)
 
        session.commit()
 

	
 
        # make basic label batch
 
        lblbatch = self.handler.make_label_batch(prodbatch, betty, id=2)
 
        self.assertIsNotNone(lblbatch)
 
        self.assertEqual(lblbatch.id, 2)
 

	
 
    def test_make_pricing_batch(self):
 
        engine = sa.create_engine('sqlite://')
 
        model = self.config.get_model()
 
        model.Base.metadata.create_all(bind=engine)
 
        session = Session(bind=engine)
 

	
 
        # prep data
 
        betty = model.User(username='betty')
 
        prodbatch = model.ProductBatch(id=1, created_by=betty)
 
        session.add(prodbatch)
 
        session.commit()
 

	
 
        # make basic pricing batch
 
        prcbatch = self.handler.make_pricing_batch(prodbatch, betty, id=2)
 
        self.assertIsNotNone(prcbatch)
 
        self.assertEqual(prcbatch.id, 2)
 

	
 
try:
 
    import sqlalchemy as sa
 
    from rattail.batch import product as mod
 
    from rattail.db import Session
 
except ImportError:
 
    pass
 
else:
 

	
 
    class TestProductBatchHandler(TestCase):
 

	
 
        def setUp(self):
 
            self.config = self.make_config()
 
            self.handler = self.make_handler()
 

	
 
        def make_config(self):
 
            return make_config([], extend=False)
 

	
 
        def make_handler(self):
 
            return mod.ProductBatchHandler(self.config)
 

	
 
        def test_make_label_batch(self):
 
            engine = sa.create_engine('sqlite://')
 
            model = self.config.get_model()
 
            model.Base.metadata.create_all(bind=engine)
 
            session = Session(bind=engine)
 

	
 
            # prep data
 
            betty = model.User(username='betty')
 
            prodbatch = model.ProductBatch(id=1, created_by=betty)
 
            session.add(prodbatch)
 
            session.commit()
 

	
 
            # make basic label batch
 
            lblbatch = self.handler.make_label_batch(prodbatch, betty, id=2)
 
            self.assertIsNotNone(lblbatch)
 
            self.assertEqual(lblbatch.id, 2)
 

	
 
        def test_make_pricing_batch(self):
 
            engine = sa.create_engine('sqlite://')
 
            model = self.config.get_model()
 
            model.Base.metadata.create_all(bind=engine)
 
            session = Session(bind=engine)
 

	
 
            # prep data
 
            betty = model.User(username='betty')
 
            prodbatch = model.ProductBatch(id=1, created_by=betty)
 
            session.add(prodbatch)
 
            session.commit()
 

	
 
            # make basic pricing batch
 
            prcbatch = self.handler.make_pricing_batch(prodbatch, betty, id=2)
 
            self.assertIsNotNone(prcbatch)
 
            self.assertEqual(prcbatch.id, 2)
tests/batch/test_vendorcatalog.py
Show inline comments
 
# -*- coding: utf-8; -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
import os
 
import shutil
 
import decimal
 
from unittest import TestCase
 

	
 
import sqlalchemy as sa
 

	
 
from rattail.batch import vendorcatalog as mod
 
from rattail.config import make_config
 
from rattail.db import Session
 
from rattail.excel import ExcelWriter
 
from rattail.gpc import GPC
 

	
 

	
 
class TestProductBatchHandler(TestCase):
 

	
 
    def setUp(self):
 
        self.config = self.make_config()
 
        self.handler = self.make_handler()
 

	
 
    def make_config(self):
 
        return make_config([], extend=False)
 

	
 
    def make_handler(self):
 
        return mod.VendorCatalogHandler(self.config)
 

	
 
    def test_allow_future(self):
 

	
 
        # off by default
 
        result = self.handler.allow_future()
 
        self.assertFalse(result)
 

	
 
        # but can be enabled via config
 
        self.config.setdefault('rattail.batch', 'vendor_catalog.allow_future',
 
                               'true')
 
        result = self.handler.allow_future()
 
        self.assertTrue(result)
 

	
 
    def test_populate_from_file(self):
 
        engine = sa.create_engine('sqlite://')
 
        model = self.config.get_model()
 
        model.Base.metadata.create_all(bind=engine)
 
        session = Session(bind=engine)
 
        app = self.config.get_app()
 

	
 
        # we'll need a user to create the batches
 
        user = model.User(username='ralph')
 
        session.add(user)
 

	
 
        # make root folder to contain all temp files
 
        tempdir = app.make_temp_dir()
 

	
 
        # generate sample xlsx file
 
        path = os.path.join(tempdir, 'sample.xlsx')
 
        writer = ExcelWriter(path, ['UPC', 'Vendor Code', 'Unit Cost'])
 
        writer.write_header()
 
        writer.write_row(['074305001321', '123456', 4.19], row=2)
 
        writer.save()
 

	
 
        # make, configure folder for batch files
 
        filesdir = os.path.join(tempdir, 'batch_files')
 
        os.makedirs(filesdir)
 
        self.config.setdefault('rattail', 'batch.files', filesdir)
 

	
 
        # make the basic batch
 
        batch = model.VendorCatalogBatch(uuid=app.make_uuid(),
 
                                         id=1, created_by=user)
 
        session.add(batch)
 

	
 
        # batch must have certain attributes, else error
 
        self.assertRaises(ValueError, self.handler.populate_from_file, batch)
 
        self.handler.set_input_file(batch, path) # sets batch.filename
 
        self.assertRaises(ValueError, self.handler.populate_from_file, batch)
 
        batch.parser_key = 'rattail.contrib.generic'
 

	
 
        # and finally, test our method proper
 
        self.handler.setup_populate(batch)
 
        self.handler.populate_from_file(batch)
 
        self.assertEqual(len(batch.data_rows), 1)
 
        row = batch.data_rows[0]
 
        self.assertEqual(row.item_entry, '074305001321')
 
        self.assertEqual(row.vendor_code, '123456')
 
        self.assertEqual(row.unit_cost, decimal.Decimal('4.19'))
 

	
 
        shutil.rmtree(tempdir)
 
        session.rollback()
 
        session.close()
 

	
 
    def test_identify_product(self):
 
        engine = sa.create_engine('sqlite://')
 
        model = self.config.get_model()
 
        model.Base.metadata.create_all(bind=engine)
 
        session = Session(bind=engine)
 
        app = self.config.get_app()
 

	
 
        # make a test user, vendor, product, cost
 
        user = model.User(username='ralph')
 
        session.add(user)
 
        vendor = model.Vendor()
 
        session.add(vendor)
 
        product = model.Product(upc=GPC('074305001321'))
 
        session.add(product)
 
        cost = model.ProductCost(vendor=vendor,
 
                                 code='123456',
 
                                 case_size=12,
 
                                 case_cost=decimal.Decimal('54.00'),
 
                                 unit_cost=decimal.Decimal('4.50'))
 
        product.costs.append(cost)
 

	
 
        # also a batch to contain the rows
 
        batch = model.VendorCatalogBatch(uuid=app.make_uuid(),
 
                                         id=1, created_by=user,
 
                                         vendor=vendor,
 
                                         filename='sample.xlsx',
 
                                         parser_key='rattail.contrib.generic')
 
        session.add(batch)
 

	
 
        # row w/ no interesting attributes cannot yield a product
 
        row = model.VendorCatalogBatchRow()
 
        batch.data_rows.append(row)
 
        result = self.handler.identify_product(row)
 
        self.assertIsNone(result)
 

	
 
        # but if we give row a upc, product is found
 
        row.upc = GPC('074305001321')
 
        result = self.handler.identify_product(row)
 
        self.assertIs(result, product)
 

	
 
        # now try one with vendor code instead of upc
 
        row = model.VendorCatalogBatchRow(vendor_code='123456')
 
        batch.data_rows.append(row)
 
        result = self.handler.identify_product(row)
 
        self.assertIs(result, product)
 

	
 
        session.rollback()
 
        session.close()
 

	
 
    def test_refresh_row(self):
 
        engine = sa.create_engine('sqlite://')
 
        model = self.config.get_model()
 
        model.Base.metadata.create_all(bind=engine)
 
        session = Session(bind=engine)
 
        app = self.config.get_app()
 

	
 
        # make a test user, vendor, product
 
        user = model.User(username='ralph')
 
        session.add(user)
 
        vendor = model.Vendor()
 
        session.add(vendor)
 
        product = model.Product(upc=GPC('074305001321'))
 
        session.add(product)
 

	
 
        # also a batch to contain the rows
 
        batch = model.VendorCatalogBatch(uuid=app.make_uuid(),
 
                                         id=1, created_by=user,
 
                                         vendor=vendor,
 
                                         filename='sample.xlsx',
 
                                         parser_key='rattail.contrib.generic')
 
        session.add(batch)
 

	
 
        # empty row is just marked as product not found
 
        row = model.VendorCatalogBatchRow()
 
        batch.data_rows.append(row)
 
        self.handler.refresh_row(row)
 
        self.assertEqual(row.status_code, row.STATUS_PRODUCT_NOT_FOUND)
 

	
 
        # row with upc is matched with product; also make sure unit
 
        # cost is calculated from case cost
 
        row = model.VendorCatalogBatchRow(upc=GPC('074305001321'),
 
                                          case_size=12,
 
                                          case_cost=decimal.Decimal('58.00'))
 
        batch.data_rows.append(row)
 
        self.handler.refresh_row(row)
 
        self.assertIs(row.product, product)
 
        self.assertEqual(row.status_code, row.STATUS_NEW_COST)
 
        self.assertEqual(row.case_cost, 58)
 
        self.assertEqual(row.case_size, 12)
 
        self.assertEqual(row.unit_cost, decimal.Decimal('4.8333'))
 

	
 
        # now we add a cost to the master product, and make sure new
 
        # row will reflect an update for that cost
 
        cost = model.ProductCost(vendor=vendor, 
 
                                 case_size=12,
 
                                 case_cost=decimal.Decimal('54.00'),
 
                                 unit_cost=decimal.Decimal('4.50'))
 
        product.costs.append(cost)
 
        row = model.VendorCatalogBatchRow(upc=GPC('074305001321'),
 
                                          case_size=12,
 
                                          case_cost=decimal.Decimal('58.00'))
 
        batch.data_rows.append(row)
 
        self.handler.refresh_row(row)
 
        self.assertIs(row.product, product)
 
        self.assertEqual(row.status_code, row.STATUS_CHANGE_COST)
 
        self.assertEqual(row.old_case_cost, 54)
 
        self.assertEqual(row.case_cost, 58)
 
        self.assertEqual(row.old_unit_cost, decimal.Decimal('4.50'))
 
        self.assertEqual(row.unit_cost, decimal.Decimal('4.8333'))
 

	
 
        # and finally let's refresh everything, note that row #2
 
        # should now *also* get "change cost" status
 
        row = batch.data_rows[1]
 
        self.assertEqual(row.status_code, row.STATUS_NEW_COST)
 
        self.handler.setup_refresh(batch)
 
        for row in batch.data_rows:
 
try:
 
    import sqlalchemy as sa
 
    from rattail.batch import vendorcatalog as mod
 
    from rattail.db import Session
 
except ImportError:
 
    pass
 
else:
 

	
 
    class TestProductBatchHandler(TestCase):
 

	
 
        def setUp(self):
 
            self.config = self.make_config()
 
            self.handler = self.make_handler()
 

	
 
        def make_config(self):
 
            return make_config([], extend=False)
 

	
 
        def make_handler(self):
 
            return mod.VendorCatalogHandler(self.config)
 

	
 
        def test_allow_future(self):
 

	
 
            # off by default
 
            result = self.handler.allow_future()
 
            self.assertFalse(result)
 

	
 
            # but can be enabled via config
 
            self.config.setdefault('rattail.batch', 'vendor_catalog.allow_future',
 
                                   'true')
 
            result = self.handler.allow_future()
 
            self.assertTrue(result)
 

	
 
        def test_populate_from_file(self):
 
            engine = sa.create_engine('sqlite://')
 
            model = self.config.get_model()
 
            model.Base.metadata.create_all(bind=engine)
 
            session = Session(bind=engine)
 
            app = self.config.get_app()
 

	
 
            # we'll need a user to create the batches
 
            user = model.User(username='ralph')
 
            session.add(user)
 

	
 
            # make root folder to contain all temp files
 
            tempdir = app.make_temp_dir()
 

	
 
            # generate sample xlsx file
 
            path = os.path.join(tempdir, 'sample.xlsx')
 
            writer = ExcelWriter(path, ['UPC', 'Vendor Code', 'Unit Cost'])
 
            writer.write_header()
 
            writer.write_row(['074305001321', '123456', 4.19], row=2)
 
            writer.save()
 

	
 
            # make, configure folder for batch files
 
            filesdir = os.path.join(tempdir, 'batch_files')
 
            os.makedirs(filesdir)
 
            self.config.setdefault('rattail', 'batch.files', filesdir)
 

	
 
            # make the basic batch
 
            batch = model.VendorCatalogBatch(uuid=app.make_uuid(),
 
                                             id=1, created_by=user)
 
            session.add(batch)
 

	
 
            # batch must have certain attributes, else error
 
            self.assertRaises(ValueError, self.handler.populate_from_file, batch)
 
            self.handler.set_input_file(batch, path) # sets batch.filename
 
            self.assertRaises(ValueError, self.handler.populate_from_file, batch)
 
            batch.parser_key = 'rattail.contrib.generic'
 

	
 
            # and finally, test our method proper
 
            self.handler.setup_populate(batch)
 
            self.handler.populate_from_file(batch)
 
            self.assertEqual(len(batch.data_rows), 1)
 
            row = batch.data_rows[0]
 
            self.assertEqual(row.item_entry, '074305001321')
 
            self.assertEqual(row.vendor_code, '123456')
 
            self.assertEqual(row.unit_cost, decimal.Decimal('4.19'))
 

	
 
            shutil.rmtree(tempdir)
 
            session.rollback()
 
            session.close()
 

	
 
        def test_identify_product(self):
 
            engine = sa.create_engine('sqlite://')
 
            model = self.config.get_model()
 
            model.Base.metadata.create_all(bind=engine)
 
            session = Session(bind=engine)
 
            app = self.config.get_app()
 

	
 
            # make a test user, vendor, product, cost
 
            user = model.User(username='ralph')
 
            session.add(user)
 
            vendor = model.Vendor()
 
            session.add(vendor)
 
            product = model.Product(upc=GPC('074305001321'))
 
            session.add(product)
 
            cost = model.ProductCost(vendor=vendor,
 
                                     code='123456',
 
                                     case_size=12,
 
                                     case_cost=decimal.Decimal('54.00'),
 
                                     unit_cost=decimal.Decimal('4.50'))
 
            product.costs.append(cost)
 

	
 
            # also a batch to contain the rows
 
            batch = model.VendorCatalogBatch(uuid=app.make_uuid(),
 
                                             id=1, created_by=user,
 
                                             vendor=vendor,
 
                                             filename='sample.xlsx',
 
                                             parser_key='rattail.contrib.generic')
 
            session.add(batch)
 

	
 
            # row w/ no interesting attributes cannot yield a product
 
            row = model.VendorCatalogBatchRow()
 
            batch.data_rows.append(row)
 
            result = self.handler.identify_product(row)
 
            self.assertIsNone(result)
 

	
 
            # but if we give row a upc, product is found
 
            row.upc = GPC('074305001321')
 
            result = self.handler.identify_product(row)
 
            self.assertIs(result, product)
 

	
 
            # now try one with vendor code instead of upc
 
            row = model.VendorCatalogBatchRow(vendor_code='123456')
 
            batch.data_rows.append(row)
 
            result = self.handler.identify_product(row)
 
            self.assertIs(result, product)
 

	
 
            session.rollback()
 
            session.close()
 

	
 
        def test_refresh_row(self):
 
            engine = sa.create_engine('sqlite://')
 
            model = self.config.get_model()
 
            model.Base.metadata.create_all(bind=engine)
 
            session = Session(bind=engine)
 
            app = self.config.get_app()
 

	
 
            # make a test user, vendor, product
 
            user = model.User(username='ralph')
 
            session.add(user)
 
            vendor = model.Vendor()
 
            session.add(vendor)
 
            product = model.Product(upc=GPC('074305001321'))
 
            session.add(product)
 

	
 
            # also a batch to contain the rows
 
            batch = model.VendorCatalogBatch(uuid=app.make_uuid(),
 
                                             id=1, created_by=user,
 
                                             vendor=vendor,
 
                                             filename='sample.xlsx',
 
                                             parser_key='rattail.contrib.generic')
 
            session.add(batch)
 

	
 
            # empty row is just marked as product not found
 
            row = model.VendorCatalogBatchRow()
 
            batch.data_rows.append(row)
 
            self.handler.refresh_row(row)
 
        self.assertEqual(row.status_code, row.STATUS_CHANGE_COST)
 

	
 
        session.rollback()
 
        session.close()
 
            self.assertEqual(row.status_code, row.STATUS_PRODUCT_NOT_FOUND)
 

	
 
            # row with upc is matched with product; also make sure unit
 
            # cost is calculated from case cost
 
            row = model.VendorCatalogBatchRow(upc=GPC('074305001321'),
 
                                              case_size=12,
 
                                              case_cost=decimal.Decimal('58.00'))
 
            batch.data_rows.append(row)
 
            self.handler.refresh_row(row)
 
            self.assertIs(row.product, product)
 
            self.assertEqual(row.status_code, row.STATUS_NEW_COST)
 
            self.assertEqual(row.case_cost, 58)
 
            self.assertEqual(row.case_size, 12)
 
            self.assertEqual(row.unit_cost, decimal.Decimal('4.8333'))
 

	
 
            # now we add a cost to the master product, and make sure new
 
            # row will reflect an update for that cost
 
            cost = model.ProductCost(vendor=vendor, 
 
                                     case_size=12,
 
                                     case_cost=decimal.Decimal('54.00'),
 
                                     unit_cost=decimal.Decimal('4.50'))
 
            product.costs.append(cost)
 
            row = model.VendorCatalogBatchRow(upc=GPC('074305001321'),
 
                                              case_size=12,
 
                                              case_cost=decimal.Decimal('58.00'))
 
            batch.data_rows.append(row)
 
            self.handler.refresh_row(row)
 
            self.assertIs(row.product, product)
 
            self.assertEqual(row.status_code, row.STATUS_CHANGE_COST)
 
            self.assertEqual(row.old_case_cost, 54)
 
            self.assertEqual(row.case_cost, 58)
 
            self.assertEqual(row.old_unit_cost, decimal.Decimal('4.50'))
 
            self.assertEqual(row.unit_cost, decimal.Decimal('4.8333'))
 

	
 
            # and finally let's refresh everything, note that row #2
 
            # should now *also* get "change cost" status
 
            row = batch.data_rows[1]
 
            self.assertEqual(row.status_code, row.STATUS_NEW_COST)
 
            self.handler.setup_refresh(batch)
 
            for row in batch.data_rows:
 
                self.handler.refresh_row(row)
 
            self.assertEqual(row.status_code, row.STATUS_CHANGE_COST)
 

	
 
            session.rollback()
 
            session.close()
tests/db/model/test_customers.py
Show inline comments
 
# -*- coding: utf-8; -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from rattail.db import model
 
from ... import DataTestCase
 

	
 

	
 
class TestCustomer(DataTestCase):
 

	
 
    # TODO: this is duplicated in TestPerson
 
    def test_add_email_address(self):
 
        customer = model.Customer()
 
        self.assertEqual(len(customer.emails), 0)
 
        customer.add_email_address('fred@mailinator.com')
 
        self.assertEqual(len(customer.emails), 1)
 
        email = customer.emails[0]
 
        self.assertEqual(email.type, 'Home')
 

	
 
        customer = model.Customer()
 
        self.assertEqual(len(customer.emails), 0)
 
        customer.add_email_address('fred@mailinator.com', type='Work')
 
        self.assertEqual(len(customer.emails), 1)
 
        email = customer.emails[0]
 
        self.assertEqual(email.type, 'Work')
 

	
 
    # TODO: this is duplicated in TestPerson
 
    def test_add_phone_number(self):
 
        customer = model.Customer()
 
        self.assertEqual(len(customer.phones), 0)
 
        customer.add_phone_number('417-555-1234')
 
        self.assertEqual(len(customer.phones), 1)
 
        phone = customer.phones[0]
 
        self.assertEqual(phone.type, 'Home')
 

	
 
        customer = model.Customer()
 
        self.assertEqual(len(customer.phones), 0)
 
        customer.add_phone_number('417-555-1234', type='Work')
 
        self.assertEqual(len(customer.phones), 1)
 
        phone = customer.phones[0]
 
        self.assertEqual(phone.type, 'Work')
 
try:
 
    from rattail.db import model
 
    from ... import DataTestCase
 
except ImportError:
 
    pass
 
else:
 

	
 
    class TestCustomer(DataTestCase):
 

	
 
        # TODO: this is duplicated in TestPerson
 
        def test_add_email_address(self):
 
            customer = model.Customer()
 
            self.assertEqual(len(customer.emails), 0)
 
            customer.add_email_address('fred@mailinator.com')
 
            self.assertEqual(len(customer.emails), 1)
 
            email = customer.emails[0]
 
            self.assertEqual(email.type, 'Home')
 

	
 
            customer = model.Customer()
 
            self.assertEqual(len(customer.emails), 0)
 
            customer.add_email_address('fred@mailinator.com', type='Work')
 
            self.assertEqual(len(customer.emails), 1)
 
            email = customer.emails[0]
 
            self.assertEqual(email.type, 'Work')
 

	
 
        # TODO: this is duplicated in TestPerson
 
        def test_add_phone_number(self):
 
            customer = model.Customer()
 
            self.assertEqual(len(customer.phones), 0)
 
            customer.add_phone_number('417-555-1234')
 
            self.assertEqual(len(customer.phones), 1)
 
            phone = customer.phones[0]
 
            self.assertEqual(phone.type, 'Home')
 

	
 
            customer = model.Customer()
 
            self.assertEqual(len(customer.phones), 0)
 
            customer.add_phone_number('417-555-1234', type='Work')
 
            self.assertEqual(len(customer.phones), 1)
 
            phone = customer.phones[0]
 
            self.assertEqual(phone.type, 'Work')
tests/db/model/test_datasync.py
Show inline comments
 
# -*- coding: utf-8; -*-
 

	
 
from rattail.db import model
 
from ... import DataTestCase
 
try:
 
    from rattail.db import model
 
    from ... import DataTestCase
 
except ImportError:
 
    pass
 
else:
 

	
 
    class TestDataSyncChange(DataTestCase):
 

	
 
class TestDataSyncChange(DataTestCase):
 
        def test_unicode(self):
 
            change = model.DataSyncChange()
 
            self.assertEqual(str(change), "(empty)")
 

	
 
    def test_unicode(self):
 
        change = model.DataSyncChange()
 
        self.assertEqual(str(change), "(empty)")
 
            change = model.DataSyncChange(payload_type='Product', payload_key='00074305001321')
 
            self.assertEqual(str(change), "Product: 00074305001321")
 

	
 
        change = model.DataSyncChange(payload_type='Product', payload_key='00074305001321')
 
        self.assertEqual(str(change), "Product: 00074305001321")
 

	
 
        change = model.DataSyncChange(payload_type='Product', payload_key='00074305001321', deletion=True)
 
        self.assertEqual(str(change), "Product: 00074305001321 (deletion)")
 
            change = model.DataSyncChange(payload_type='Product', payload_key='00074305001321', deletion=True)
 
            self.assertEqual(str(change), "Product: 00074305001321 (deletion)")
tests/db/model/test_people.py
Show inline comments
 
# -*- coding: utf-8; -*-
 

	
 
from unittest import TestCase
 
from unittest.mock import Mock
 

	
 
from mock import Mock
 
try:
 
    from rattail.db import model
 
    from rattail.db.model import people
 
    from ... import DataTestCase
 
except ImportError:
 
    pass
 
else:
 

	
 
from rattail.db import model
 
from rattail.db.model import people
 
from ... import DataTestCase
 
    class TestPerson(DataTestCase):
 

	
 
        def test_unicode(self):
 
            person = model.Person()
 
            self.assertEqual(str(person), "")
 

	
 
class TestPerson(DataTestCase):
 
            person = model.Person(display_name="Fred Flintstone")
 
            self.assertEqual(str(person), "Fred Flintstone")
 

	
 
    def test_unicode(self):
 
        person = model.Person()
 
        self.assertEqual(str(person), "")
 
        # TODO: this is duplicated in TestPerson
 
        def test_add_email_address(self):
 
            person = model.Person()
 
            self.assertEqual(len(person.emails), 0)
 
            person.add_email_address('fred@mailinator.com')
 
            self.assertEqual(len(person.emails), 1)
 
            email = person.emails[0]
 
            self.assertEqual(email.type, 'Home')
 

	
 
        person = model.Person(display_name="Fred Flintstone")
 
        self.assertEqual(str(person), "Fred Flintstone")
 
            person = model.Person()
 
            self.assertEqual(len(person.emails), 0)
 
            person.add_email_address('fred@mailinator.com', type='Work')
 
            self.assertEqual(len(person.emails), 1)
 
            email = person.emails[0]
 
            self.assertEqual(email.type, 'Work')
 

	
 
    # TODO: this is duplicated in TestPerson
 
    def test_add_email_address(self):
 
        person = model.Person()
 
        self.assertEqual(len(person.emails), 0)
 
        person.add_email_address('fred@mailinator.com')
 
        self.assertEqual(len(person.emails), 1)
 
        email = person.emails[0]
 
        self.assertEqual(email.type, 'Home')
 
        # TODO: this is duplicated in TestPerson
 
        def test_add_phone_number(self):
 
            person = model.Person()
 
            self.assertEqual(len(person.phones), 0)
 
            person.add_phone_number('417-555-1234')
 
            self.assertEqual(len(person.phones), 1)
 
            phone = person.phones[0]
 
            self.assertEqual(phone.type, 'Home')
 

	
 
        person = model.Person()
 
        self.assertEqual(len(person.emails), 0)
 
        person.add_email_address('fred@mailinator.com', type='Work')
 
        self.assertEqual(len(person.emails), 1)
 
        email = person.emails[0]
 
        self.assertEqual(email.type, 'Work')
 
            person = model.Person()
 
            self.assertEqual(len(person.phones), 0)
 
            person.add_phone_number('417-555-1234', type='Work')
 
            self.assertEqual(len(person.phones), 1)
 
            phone = person.phones[0]
 
            self.assertEqual(phone.type, 'Work')
 

	
 
    # TODO: this is duplicated in TestPerson
 
    def test_add_phone_number(self):
 
        person = model.Person()
 
        self.assertEqual(len(person.phones), 0)
 
        person.add_phone_number('417-555-1234')
 
        self.assertEqual(len(person.phones), 1)
 
        phone = person.phones[0]
 
        self.assertEqual(phone.type, 'Home')
 

	
 
        person = model.Person()
 
        self.assertEqual(len(person.phones), 0)
 
        person.add_phone_number('417-555-1234', type='Work')
 
        self.assertEqual(len(person.phones), 1)
 
        phone = person.phones[0]
 
        self.assertEqual(phone.type, 'Work')
 
    # TODO: deprecate/remove this?
 
    class TestFunctions(TestCase):
 

	
 
        def test_get_person_display_name(self):
 
            name = people.get_person_display_name("Fred", "Flintstone")
 
            self.assertEqual(name, "Fred Flintstone")
 

	
 
# TODO: deprecate/remove this?
 
class TestFunctions(TestCase):
 

	
 
    def test_get_person_display_name(self):
 
        name = people.get_person_display_name("Fred", "Flintstone")
 
        self.assertEqual(name, "Fred Flintstone")
 

	
 
    def test_get_person_display_name_from_context(self):
 
        context = Mock(current_parameters={'first_name': "Fred", 'last_name': "Flintstone"})
 
        name = people.get_person_display_name_from_context(context)
 
        self.assertEqual(name, "Fred Flintstone")
 
        def test_get_person_display_name_from_context(self):
 
            context = Mock(current_parameters={'first_name': "Fred", 'last_name': "Flintstone"})
 
            name = people.get_person_display_name_from_context(context)
 
            self.assertEqual(name, "Fred Flintstone")
tests/db/model/test_users.py
Show inline comments
 
# -*- coding: utf-8 -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from rattail.db import model
 
from ... import DataTestCase
 

	
 

	
 
class TestUserEmailAddress(DataTestCase):
 

	
 
    def extra_setup(self):
 
        self.user = model.User(username='fred')
 
        self.session.add(self.user)
 
        self.session.flush()
 

	
 
    def test_email_defaults_to_none(self):
 
        self.assertTrue(self.user.get_email_address() is None)
 

	
 
    def test_email_comes_from_person_then_customer(self):
 
        # only customer has email at this point
 
        person = model.Person(first_name='Fred')
 
        customer = model.Customer(name='Fred')
 
        customer.add_email_address('customer@mailinator.com')
 
        customer.people.append(person)
 
        self.user.person = person
 
        self.session.add(customer)
 
        self.session.flush()
 
        self.assertEqual(self.user.get_email_address(), 'customer@mailinator.com')
 

	
 
        # now person email will take preference
 
        person.add_email_address('person@mailinator.com')
 
        self.session.refresh(person)
 
        self.assertEqual(self.user.get_email_address(), 'person@mailinator.com')
 

	
 
    def test_email_address_property_works_too(self):
 
        # even though this may go away some day, cover it for now
 
        person = model.Person(first_name='Fred')
 
        person.add_email_address('person@mailinator.com')
 
        self.user.person = person
 
        self.session.flush()
 
        self.assertEqual(self.user.email_address, 'person@mailinator.com')
 
# -*- coding: utf-8; -*-
 

	
 
try:
 
    from rattail.db import model
 
    from ... import DataTestCase
 
except ImportError:
 
    pass
 
else:
 

	
 
    class TestUserEmailAddress(DataTestCase):
 

	
 
        def extra_setup(self):
 
            self.user = model.User(username='fred')
 
            self.session.add(self.user)
 
            self.session.flush()
 

	
 
        def test_email_defaults_to_none(self):
 
            self.assertTrue(self.user.get_email_address() is None)
 

	
 
        def test_email_comes_from_person_then_customer(self):
 
            # only customer has email at this point
 
            person = model.Person(first_name='Fred')
 
            customer = model.Customer(name='Fred')
 
            customer.add_email_address('customer@mailinator.com')
 
            customer.people.append(person)
 
            self.user.person = person
 
            self.session.add(customer)
 
            self.session.flush()
 
            self.assertEqual(self.user.get_email_address(), 'customer@mailinator.com')
 

	
 
            # now person email will take preference
 
            person.add_email_address('person@mailinator.com')
 
            self.session.refresh(person)
 
            self.assertEqual(self.user.get_email_address(), 'person@mailinator.com')
 

	
 
        def test_email_address_property_works_too(self):
 
            # even though this may go away some day, cover it for now
 
            person = model.Person(first_name='Fred')
 
            person.add_email_address('person@mailinator.com')
 
            self.user.person = person
 
            self.session.flush()
 
            self.assertEqual(self.user.email_address, 'person@mailinator.com')
tests/db/test_auth.py
Show inline comments
 
# -*- coding: utf-8 -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from rattail.db import auth, model
 
from .. import DataTestCase
 

	
 

	
 
class TestAuthenticateUser(DataTestCase):
 

	
 
    def test_nonexistent_user_returns_none(self):
 
        self.assertTrue(auth.authenticate_user(self.session, u'fred', u'fredpass') is None)
 

	
 
    def test_correct_credentials_returns_user(self):
 
        fred = model.User(username=u'fred')
 
        auth.set_user_password(fred, u'fredpass')
 
        self.session.add(fred)
 
        self.session.commit()
 
        user = auth.authenticate_user(self.session, u'fred', u'fredpass')
 
        self.assertTrue(user is fred)
 

	
 
    def test_wrong_password_user_returns_none(self):
 
        fred = model.User(username=u'fred', active=False)
 
        auth.set_user_password(fred, u'fredpass')
 
        self.session.add(fred)
 
        self.session.commit()
 
        self.assertTrue(auth.authenticate_user(self.session, u'fred', u'BADPASS') is None)
 

	
 
    def test_inactive_user_returns_none(self):
 
        fred = model.User(username=u'fred', active=False)
 
        auth.set_user_password(fred, u'fredpass')
 
        self.session.add(fred)
 
        self.session.commit()
 
        self.assertTrue(auth.authenticate_user(self.session, u'fred', u'fredpass') is None)
 
# -*- coding: utf-8; -*-
 

	
 
try:
 
    from rattail.db import auth, model
 
    from .. import DataTestCase
 
except ImportError:
 
    pass
 
else:
 

	
 
    class TestAuthenticateUser(DataTestCase):
 

	
 
        def test_nonexistent_user_returns_none(self):
 
            self.assertTrue(auth.authenticate_user(self.session, u'fred', u'fredpass') is None)
 

	
 
        def test_correct_credentials_returns_user(self):
 
            fred = model.User(username=u'fred')
 
            auth.set_user_password(fred, u'fredpass')
 
            self.session.add(fred)
 
            self.session.commit()
 
            user = auth.authenticate_user(self.session, u'fred', u'fredpass')
 
            self.assertTrue(user is fred)
 

	
 
        def test_wrong_password_user_returns_none(self):
 
            fred = model.User(username=u'fred', active=False)
 
            auth.set_user_password(fred, u'fredpass')
 
            self.session.add(fred)
 
            self.session.commit()
 
            self.assertTrue(auth.authenticate_user(self.session, u'fred', u'BADPASS') is None)
 

	
 
        def test_inactive_user_returns_none(self):
 
            fred = model.User(username=u'fred', active=False)
 
            auth.set_user_password(fred, u'fredpass')
 
            self.session.add(fred)
 
            self.session.commit()
 
            self.assertTrue(auth.authenticate_user(self.session, u'fred', u'fredpass') is None)
tests/db/test_changes.py
Show inline comments
 
# -*- coding: utf-8 -*-
 

	
 
from unittest import TestCase
 
from unittest.mock import patch, DEFAULT, Mock, MagicMock, call
 

	
 
from sqlalchemy import orm
 
from mock import patch, DEFAULT, Mock, MagicMock, call
 

	
 
from rattail import db
 
from rattail.db import changes, model
 
from .. import DataTestCase
 
from rattail.config import RattailConfig
 
from rattail.core import get_uuid
 

	
 
try:
 
    from sqlalchemy import orm
 
    from rattail import db
 
    from rattail.db import changes, model
 
    from .. import DataTestCase
 
except ImportError:
 
    pass
 
else:
 

	
 
    class TestRecordChangesFunc(TestCase):
 

	
 
        def setUp(self):
 
            self.config = RattailConfig()
 

	
 
        def test_session_class(self):
 
            Session = orm.sessionmaker()
 
            if hasattr(Session, 'kw'):
 
                self.assertRaises(KeyError, Session.kw.__getitem__, 'rattail_record_changes')
 
            self.assertRaises(AttributeError, getattr, Session, 'rattail_record_changes')
 
            changes.record_changes(Session)
 
            self.assertTrue(Session.rattail_record_changes)
 

	
 
        def test_session_instance(self):
 
            session = db.Session()
 
            self.assertFalse(session.rattail_record_changes)
 
            changes.record_changes(session)
 
            self.assertTrue(session.rattail_record_changes)
 
            session.close()
 

	
 
        def test_recorder(self):
 

	
 
            # no recorder
 
            session = db.Session()
 
            self.assertRaises(AttributeError, getattr, session, 'rattail_change_recorder')
 
            session.close()
 

	
 
            # default recorder
 
            session = db.Session()
 
            changes.record_changes(session)
 
            self.assertIs(type(session.rattail_change_recorder), changes.ChangeRecorder)
 
            session.close()
 

	
 
            # specify recorder instance
 
            recorder = changes.ChangeRecorder(self.config)
 
            session = db.Session()
 
            changes.record_changes(session, recorder=recorder)
 
            self.assertIs(session.rattail_change_recorder, recorder)
 
            session.close()
 

	
 
            # specify recorder factory
 
            session = db.Session()
 
            changes.record_changes(session, recorder=changes.ChangeRecorder, config=self.config)
 
            self.assertIs(type(session.rattail_change_recorder), changes.ChangeRecorder)
 
            session.close()
 

	
 
            # specify recorder spec via config
 
            config = RattailConfig()
 
            config.setdefault('rattail.db', 'changes.recorder', 'rattail.db.changes:ChangeRecorder')
 
            session = db.Session()
 
            changes.record_changes(session, config=config)
 
            self.assertIs(type(session.rattail_change_recorder), changes.ChangeRecorder)
 
            session.close()
 

	
 
            # invalid recorder
 
            session = db.Session()
 
            self.assertRaises(ValueError, changes.record_changes, session, recorder='bogus')
 
            session.close()
 

	
 

	
 
    class TestChangeRecorder(DataTestCase):
 

	
 
        def extra_setup(self):
 
            self.config = RattailConfig()
 

	
 
        def test_ignore_object(self):
 
            recorder = changes.ChangeRecorder(self.config)
 
            self.assertTrue(recorder.ignore_object(model.Setting()))
 
            self.assertTrue(recorder.ignore_object(model.Change()))
 
            self.assertTrue(recorder.ignore_object(model.DataSyncChange()))
 
            self.assertFalse(recorder.ignore_object(model.Product()))
 
            self.assertFalse(recorder.ignore_object(model.Customer()))
 

	
 
        def test_process_new_object(self):
 
            self.assertEqual(self.session.query(model.Change).count(), 0)
 
            recorder = changes.ChangeRecorder(self.config)
 
            product = model.Product(uuid='6de299ca178d11e6be2c3ca9f40bc550')
 
            recorder.process_new_object(self.session, product)
 
            change = self.session.query(model.Change).one()
 
            self.assertEqual(change.class_name, 'Product')
 
            self.assertEqual(change.object_key, '6de299ca178d11e6be2c3ca9f40bc550')
 
            self.assertFalse(change.deleted)
 

	
 
        def test_process_dirty_object(self):
 
            self.assertEqual(self.session.query(model.Change).count(), 0)
 
            recorder = changes.ChangeRecorder(self.config)
 
            product = model.Product(uuid='6de299ca178d11e6be2c3ca9f40bc550')
 
            recorder.process_dirty_object(self.session, product)
 
            change = self.session.query(model.Change).one()
 
            self.assertEqual(change.class_name, 'Product')
 
            self.assertEqual(change.object_key, '6de299ca178d11e6be2c3ca9f40bc550')
 
            self.assertFalse(change.deleted)
 

	
 
class TestRecordChangesFunc(TestCase):
 

	
 
    def setUp(self):
 
        self.config = RattailConfig()
 

	
 
    def test_session_class(self):
 
        Session = orm.sessionmaker()
 
        if hasattr(Session, 'kw'):
 
            self.assertRaises(KeyError, Session.kw.__getitem__, 'rattail_record_changes')
 
        self.assertRaises(AttributeError, getattr, Session, 'rattail_record_changes')
 
        changes.record_changes(Session)
 
        self.assertTrue(Session.rattail_record_changes)
 

	
 
    def test_session_instance(self):
 
        session = db.Session()
 
        self.assertFalse(session.rattail_record_changes)
 
        changes.record_changes(session)
 
        self.assertTrue(session.rattail_record_changes)
 
        session.close()
 

	
 
    def test_recorder(self):
 

	
 
        # no recorder
 
        session = db.Session()
 
        self.assertRaises(AttributeError, getattr, session, 'rattail_change_recorder')
 
        session.close()
 

	
 
        # default recorder
 
        session = db.Session()
 
        changes.record_changes(session)
 
        self.assertIs(type(session.rattail_change_recorder), changes.ChangeRecorder)
 
        session.close()
 

	
 
        # specify recorder instance
 
        recorder = changes.ChangeRecorder(self.config)
 
        session = db.Session()
 
        changes.record_changes(session, recorder=recorder)
 
        self.assertIs(session.rattail_change_recorder, recorder)
 
        session.close()
 

	
 
        # specify recorder factory
 
        session = db.Session()
 
        changes.record_changes(session, recorder=changes.ChangeRecorder, config=self.config)
 
        self.assertIs(type(session.rattail_change_recorder), changes.ChangeRecorder)
 
        session.close()
 

	
 
        # specify recorder spec via config
 
        config = RattailConfig()
 
        config.setdefault('rattail.db', 'changes.recorder', 'rattail.db.changes:ChangeRecorder')
 
        session = db.Session()
 
        changes.record_changes(session, config=config)
 
        self.assertIs(type(session.rattail_change_recorder), changes.ChangeRecorder)
 
        session.close()
 

	
 
        # invalid recorder
 
        session = db.Session()
 
        self.assertRaises(ValueError, changes.record_changes, session, recorder='bogus')
 
        session.close()
 

	
 

	
 
class TestChangeRecorder(DataTestCase):
 

	
 
    def extra_setup(self):
 
        self.config = RattailConfig()
 

	
 
    def test_ignore_object(self):
 
        recorder = changes.ChangeRecorder(self.config)
 
        self.assertTrue(recorder.ignore_object(model.Setting()))
 
        self.assertTrue(recorder.ignore_object(model.Change()))
 
        self.assertTrue(recorder.ignore_object(model.DataSyncChange()))
 
        self.assertFalse(recorder.ignore_object(model.Product()))
 
        self.assertFalse(recorder.ignore_object(model.Customer()))
 

	
 
    def test_process_new_object(self):
 
        self.assertEqual(self.session.query(model.Change).count(), 0)
 
        recorder = changes.ChangeRecorder(self.config)
 
        product = model.Product(uuid='6de299ca178d11e6be2c3ca9f40bc550')
 
        recorder.process_new_object(self.session, product)
 
        change = self.session.query(model.Change).one()
 
        self.assertEqual(change.class_name, 'Product')
 
        self.assertEqual(change.object_key, '6de299ca178d11e6be2c3ca9f40bc550')
 
        self.assertFalse(change.deleted)
 

	
 
    def test_process_dirty_object(self):
 
        self.assertEqual(self.session.query(model.Change).count(), 0)
 
        recorder = changes.ChangeRecorder(self.config)
 
        product = model.Product(uuid='6de299ca178d11e6be2c3ca9f40bc550')
 
        recorder.process_dirty_object(self.session, product)
 
        change = self.session.query(model.Change).one()
 
        self.assertEqual(change.class_name, 'Product')
 
        self.assertEqual(change.object_key, '6de299ca178d11e6be2c3ca9f40bc550')
 
        self.assertFalse(change.deleted)
 

	
 
    def test_process_deleted_object(self):
 
        self.assertEqual(self.session.query(model.Change).count(), 0)
 
        recorder = changes.ChangeRecorder(self.config)
 
        product = model.Product(uuid='6de299ca178d11e6be2c3ca9f40bc550')
 
        recorder.process_deleted_object(self.session, product)
 
        change = self.session.query(model.Change).one()
 
        self.assertEqual(change.class_name, 'Product')
 
        self.assertEqual(change.object_key, '6de299ca178d11e6be2c3ca9f40bc550')
 
        self.assertTrue(change.deleted)
 

	
 
    def test_process_deleted_object_special(self):
 
        self.assertEqual(self.session.query(model.Change).count(), 0)
 
        recorder = changes.ChangeRecorder(self.config)
 

	
 
        person = model.Person(uuid='06100a34178e11e6a8633ca9f40bc550')
 
        for Model in (model.PersonEmailAddress, model.PersonPhoneNumber, model.PersonMailingAddress):
 
        def test_process_deleted_object(self):
 
            self.assertEqual(self.session.query(model.Change).count(), 0)
 
            recorder = changes.ChangeRecorder(self.config)
 
            product = model.Product(uuid='6de299ca178d11e6be2c3ca9f40bc550')
 
            recorder.process_deleted_object(self.session, product)
 
            change = self.session.query(model.Change).one()
 
            self.assertEqual(change.class_name, 'Product')
 
            self.assertEqual(change.object_key, '6de299ca178d11e6be2c3ca9f40bc550')
 
            self.assertTrue(change.deleted)
 

	
 
        def test_process_deleted_object_special(self):
 
            self.assertEqual(self.session.query(model.Change).count(), 0)
 
            uuid = get_uuid()
 
            obj = Model(uuid=uuid, person=person)
 
            recorder.process_deleted_object(self.session, obj)
 
            self.assertEqual(self.session.query(model.Change).count(), 2)
 
            recorder = changes.ChangeRecorder(self.config)
 

	
 
            person = model.Person(uuid='06100a34178e11e6a8633ca9f40bc550')
 
            for Model in (model.PersonEmailAddress, model.PersonPhoneNumber, model.PersonMailingAddress):
 

	
 
                self.assertEqual(self.session.query(model.Change).count(), 0)
 
                uuid = get_uuid()
 
                obj = Model(uuid=uuid, person=person)
 
                recorder.process_deleted_object(self.session, obj)
 
                self.assertEqual(self.session.query(model.Change).count(), 2)
 

	
 
                change = self.session.query(model.Change).filter_by(class_name=Model.__name__).one()
 
                self.assertEqual(change.object_key, uuid)
 
                self.assertTrue(change.deleted)
 

	
 
                change = self.session.query(model.Change).filter_by(class_name='Person').one()
 
                self.assertEqual(change.object_key, '06100a34178e11e6a8633ca9f40bc550')
 
                self.assertFalse(change.deleted)
 

	
 
                self.session.query(model.Change).delete(synchronize_session=False)
 

	
 
            employee = model.Employee(uuid='b302ac38178e11e6b8843ca9f40bc550')
 
            for Model in (model.EmployeeStore, model.EmployeeDepartment):
 

	
 
                self.assertEqual(self.session.query(model.Change).count(), 0)
 
                uuid = get_uuid()
 
                obj = Model(uuid=uuid, employee=employee)
 
                recorder.process_deleted_object(self.session, obj)
 
                self.assertEqual(self.session.query(model.Change).count(), 2)
 

	
 
                change = self.session.query(model.Change).filter_by(class_name=Model.__name__).one()
 
                self.assertEqual(change.object_key, uuid)
 
                self.assertTrue(change.deleted)
 

	
 
                change = self.session.query(model.Change).filter_by(class_name='Employee').one()
 
                self.assertEqual(change.object_key, 'b302ac38178e11e6b8843ca9f40bc550')
 
                self.assertFalse(change.deleted)
 

	
 
                self.session.query(model.Change).delete(synchronize_session=False)
 

	
 
            change = self.session.query(model.Change).filter_by(class_name=Model.__name__).one()
 
            self.assertEqual(change.object_key, uuid)
 
        def test_record_change(self):
 
            self.assertEqual(self.session.query(model.Change).count(), 0)
 
            recorder = changes.ChangeRecorder(self.config)
 

	
 
            recorder.record_change(self.session, class_name='Bogus', object_key='bogus', deleted=False)
 
            self.assertEqual(self.session.query(model.Change).count(), 1)
 
            change = self.session.query(model.Change).one()
 
            self.assertEqual(change.class_name, 'Bogus')
 
            self.assertEqual(change.object_key, 'bogus')
 
            self.assertFalse(change.deleted)
 

	
 
            recorder.record_change(self.session, class_name='Invalid', object_key='invalid', deleted=True)
 
            self.assertEqual(self.session.query(model.Change).count(), 2)
 
            change = self.session.query(model.Change).filter_by(class_name='Invalid').one()
 
            self.assertEqual(change.object_key, 'invalid')
 
            self.assertTrue(change.deleted)
 

	
 
            change = self.session.query(model.Change).filter_by(class_name='Person').one()
 
            self.assertEqual(change.object_key, '06100a34178e11e6a8633ca9f40bc550')
 
        def test_record_rattail_change(self):
 
            self.assertEqual(self.session.query(model.Change).count(), 0)
 
            recorder = changes.ChangeRecorder(self.config)
 

	
 
            # ignore change object (TODO: redundant?)
 
            self.assertFalse(recorder.record_rattail_change(self.session, model.Change()))
 
            self.assertFalse(recorder.record_rattail_change(self.session, model.DataSyncChange()))
 
            self.assertEqual(self.session.query(model.Change).count(), 0)
 

	
 
            # ignore batch object (TODO: redundant?)
 
            self.assertFalse(recorder.record_rattail_change(self.session, model.BatchMixin()))
 
            self.assertFalse(recorder.record_rattail_change(self.session, model.BatchRowMixin()))
 
            self.assertEqual(self.session.query(model.Change).count(), 0)
 

	
 
            # ignore instance with no UUID attribute
 
            self.assertFalse(recorder.record_rattail_change(self.session, model.Setting()))
 
            self.assertFalse(recorder.record_rattail_change(self.session, model.Permission()))
 
            self.assertEqual(self.session.query(model.Change).count(), 0)
 

	
 
            product = model.Product(uuid='c1b0fb94178f11e680fd3ca9f40bc550')
 

	
 
            # default
 
            self.assertTrue(recorder.record_rattail_change(self.session, product))
 
            change = self.session.query(model.Change).one()
 
            self.assertEqual(change.class_name, 'Product')
 
            self.assertEqual(change.object_key, 'c1b0fb94178f11e680fd3ca9f40bc550')
 
            self.assertFalse(change.deleted)
 
            self.session.query(model.Change).delete(synchronize_session=False)
 

	
 
            # new
 
            self.assertTrue(recorder.record_rattail_change(self.session, product, type_='new'))
 
            change = self.session.query(model.Change).one()
 
            self.assertEqual(change.class_name, 'Product')
 
            self.assertEqual(change.object_key, 'c1b0fb94178f11e680fd3ca9f40bc550')
 
            self.assertFalse(change.deleted)
 
            self.session.query(model.Change).delete(synchronize_session=False)
 

	
 
        employee = model.Employee(uuid='b302ac38178e11e6b8843ca9f40bc550')
 
        for Model in (model.EmployeeStore, model.EmployeeDepartment):
 
            # dirty
 
            self.assertTrue(recorder.record_rattail_change(self.session, product, type_='dirty'))
 
            change = self.session.query(model.Change).one()
 
            self.assertEqual(change.class_name, 'Product')
 
            self.assertEqual(change.object_key, 'c1b0fb94178f11e680fd3ca9f40bc550')
 
            self.assertFalse(change.deleted)
 
            self.session.query(model.Change).delete(synchronize_session=False)
 

	
 
            # deleted
 
            self.assertTrue(recorder.record_rattail_change(self.session, product, type_='deleted'))
 
            change = self.session.query(model.Change).one()
 
            self.assertEqual(change.class_name, 'Product')
 
            self.assertEqual(change.object_key, 'c1b0fb94178f11e680fd3ca9f40bc550')
 
            self.assertTrue(change.deleted)
 
            self.session.query(model.Change).delete(synchronize_session=False)
 

	
 

	
 
    class TestChangeRecorderLegacy(TestCase):
 

	
 
        def setUp(self):
 
            self.config = RattailConfig()
 

	
 
        def test_init(self):
 
            recorder = changes.ChangeRecorder(self.config)
 

	
 
        # def test_record_change(self):
 
        #     session = Mock()
 
        #     recorder = changes.ChangeRecorder()
 
        #     recorder.ensure_uuid = Mock()
 

	
 
        #     # don't record changes for changes
 
        #     self.assertFalse(recorder.record_change(session, model.Change()))
 

	
 
        #     # don't record changes for objects with no uuid attribute
 
        #     self.assertFalse(recorder.record_change(session, object()))
 

	
 
        #     # none of the above should have involved a call to `ensure_uuid()`
 
        #     self.assertFalse(recorder.ensure_uuid.called)
 

	
 
        #     # so far no *new* changes have been created
 
        #     self.assertFalse(session.add.called)
 

	
 
        #     # mock up session to force new change creation
 
        #     session.query.return_value = session
 
        #     session.get.return_value = None
 
        #     self.assertTrue(recorder.record_change(session, model.Product()))
 

	
 
        @patch.multiple('rattail.db.changes', get_uuid=DEFAULT, object_mapper=DEFAULT)
 
        def test_ensure_uuid(self, get_uuid, object_mapper):
 
            recorder = changes.ChangeRecorder(self.config)
 
            uuid_column = Mock()
 
            object_mapper.return_value.columns.__getitem__.return_value = uuid_column
 

	
 
            # uuid already present
 
            product = model.Product(uuid='some_uuid')
 
            recorder.ensure_uuid(product)
 
            self.assertEqual(product.uuid, 'some_uuid')
 
            self.assertFalse(get_uuid.called)
 

	
 
            # no uuid yet, auto-generate
 
            uuid_column.foreign_keys = False
 
            get_uuid.return_value = 'another_uuid'
 
            product = model.Product()
 
            self.assertTrue(product.uuid is None)
 
            recorder.ensure_uuid(product)
 
            get_uuid.assert_called_once_with()
 
            self.assertEqual(product.uuid, 'another_uuid')
 

	
 
            # some heavy mocking for following tests
 
            uuid_column.foreign_keys = True
 
            remote_side = MagicMock(key='uuid')
 
            prop = MagicMock(__class__=orm.RelationshipProperty, key='foreign_thing')
 
            prop.remote_side.__len__.return_value = 1
 
            prop.remote_side.__iter__.return_value = [remote_side]
 
            object_mapper.return_value.iterate_properties.__iter__.return_value = [prop]
 

	
 
            # uuid fetched from existing foreign key object
 
            get_uuid.reset_mock()
 
            instance = Mock(uuid=None, foreign_thing=Mock(uuid='secondary_uuid'))
 
            recorder.ensure_uuid(instance)
 
            self.assertFalse(get_uuid.called)
 
            self.assertEqual(instance.uuid, 'secondary_uuid')
 

	
 
            # foreign key object doesn't exist; uuid generated as fallback
 
            get_uuid.return_value = 'fallback_uuid'
 
            instance = Mock(uuid=None, foreign_thing=None)
 
            recorder.ensure_uuid(instance)
 
            get_uuid.assert_called_once_with()
 
            self.assertEqual(instance.uuid, 'fallback_uuid')
 

	
 

	
 
    class TestFunctionalChanges(DataTestCase):
 

	
 
        def setUp(self):
 
            super(TestFunctionalChanges, self).setUp()
 
            changes.record_changes(self.session)
 

	
 
        def test_add(self):
 
            product = model.Product()
 
            self.session.add(product)
 
            self.session.commit()
 

	
 
            self.assertEqual(self.session.query(model.Change).count(), 1)
 
            change = self.session.query(model.Change).one()
 
            self.assertEqual(change.class_name, 'Product')
 
            self.assertEqual(change.instance_uuid, product.uuid)
 
            self.assertFalse(change.deleted)
 

	
 
        def test_change(self):
 
            product = model.Product()
 
            self.session.add(product)
 
            self.session.commit()
 

	
 
            self.assertEqual(self.session.query(model.Change).count(), 1)
 
            self.session.query(model.Change).delete()
 
            self.assertEqual(self.session.query(model.Change).count(), 0)
 
            uuid = get_uuid()
 
            obj = Model(uuid=uuid, employee=employee)
 
            recorder.process_deleted_object(self.session, obj)
 
            self.assertEqual(self.session.query(model.Change).count(), 2)
 

	
 
            change = self.session.query(model.Change).filter_by(class_name=Model.__name__).one()
 
            self.assertEqual(change.object_key, uuid)
 
            product.description = 'Acme Bricks'
 
            self.session.commit()
 

	
 
            self.assertEqual(self.session.query(model.Change).count(), 1)
 
            change = self.session.query(model.Change).one()
 
            self.assertEqual(change.class_name, 'Product')
 
            self.assertEqual(change.instance_uuid, product.uuid)
 
            self.assertFalse(change.deleted)
 

	
 
        def test_delete(self):
 
            product = model.Product()
 
            self.session.add(product)
 
            self.session.commit()
 

	
 
            self.assertEqual(self.session.query(model.Change).count(), 1)
 
            self.session.query(model.Change).delete()
 
            self.assertEqual(self.session.query(model.Change).count(), 0)
 

	
 
            self.session.delete(product)
 

	
 
            self.assertEqual(self.session.query(model.Change).count(), 1)
 
            change = self.session.query(model.Change).one()
 
            self.assertEqual(change.class_name, 'Product')
 
            self.assertEqual(change.instance_uuid, product.uuid)
 
            self.assertTrue(change.deleted)
 

	
 
            change = self.session.query(model.Change).filter_by(class_name='Employee').one()
 
            self.assertEqual(change.object_key, 'b302ac38178e11e6b8843ca9f40bc550')
 
        def test_orphan_change(self):
 
            department = model.Department()
 
            subdepartment = model.Subdepartment()
 
            department.subdepartments.append(subdepartment)
 
            self.session.add(department)
 
            self.session.commit()
 

	
 
            self.assertEqual(self.session.query(model.Change).count(), 2)
 
            change = self.session.query(model.Change).filter_by(class_name='Department').one()
 
            self.assertFalse(change.deleted)
 
            change = self.session.query(model.Change).filter_by(class_name='Subdepartment').one()
 
            self.assertFalse(change.deleted)
 

	
 
            self.session.query(model.Change).delete()
 
            self.assertEqual(self.session.query(model.Change).count(), 0)
 

	
 
            # Creating an orphaned Subdepartment, which should be recorded as a
 
            # *change* due to the cascade rules in effect.
 
            department.subdepartments.remove(subdepartment)
 
            self.session.commit()
 

	
 
            self.assertEqual(self.session.query(model.Change).count(), 2)
 
            change = self.session.query(model.Change).filter_by(class_name='Department').one()
 
            self.assertFalse(change.deleted)
 
            change = self.session.query(model.Change).filter_by(class_name='Subdepartment').one()
 
            self.assertFalse(change.deleted)
 
            self.assertEqual(self.session.query(model.Subdepartment).count(), 1)
 

	
 
            self.session.query(model.Change).delete(synchronize_session=False)
 
        def test_orphan_delete(self):
 
            customer = model.Customer()
 
            group = model.CustomerGroup()
 
            customer.groups.append(group)
 
            self.session.add(customer)
 
            self.session.commit()
 

	
 
            self.assertEqual(self.session.query(model.Change).count(), 3)
 
            change = self.session.query(model.Change).filter_by(class_name='Customer').one()
 
            self.assertFalse(change.deleted)
 
            change = self.session.query(model.Change).filter_by(class_name='CustomerGroup').one()
 
            self.assertFalse(change.deleted)
 
            change = self.session.query(model.Change).filter_by(class_name='CustomerGroupAssignment').one()
 
            self.assertFalse(change.deleted)
 

	
 
            self.session.query(model.Change).delete()
 
            self.assertEqual(self.session.query(model.Change).count(), 0)
 

	
 
    def test_record_change(self):
 
        self.assertEqual(self.session.query(model.Change).count(), 0)
 
        recorder = changes.ChangeRecorder(self.config)
 

	
 
        recorder.record_change(self.session, class_name='Bogus', object_key='bogus', deleted=False)
 
        self.assertEqual(self.session.query(model.Change).count(), 1)
 
        change = self.session.query(model.Change).one()
 
        self.assertEqual(change.class_name, 'Bogus')
 
        self.assertEqual(change.object_key, 'bogus')
 
        self.assertFalse(change.deleted)
 

	
 
        recorder.record_change(self.session, class_name='Invalid', object_key='invalid', deleted=True)
 
        self.assertEqual(self.session.query(model.Change).count(), 2)
 
        change = self.session.query(model.Change).filter_by(class_name='Invalid').one()
 
        self.assertEqual(change.object_key, 'invalid')
 
        self.assertTrue(change.deleted)
 

	
 
    def test_record_rattail_change(self):
 
        self.assertEqual(self.session.query(model.Change).count(), 0)
 
        recorder = changes.ChangeRecorder(self.config)
 

	
 
        # ignore change object (TODO: redundant?)
 
        self.assertFalse(recorder.record_rattail_change(self.session, model.Change()))
 
        self.assertFalse(recorder.record_rattail_change(self.session, model.DataSyncChange()))
 
        self.assertEqual(self.session.query(model.Change).count(), 0)
 

	
 
        # ignore batch object (TODO: redundant?)
 
        self.assertFalse(recorder.record_rattail_change(self.session, model.BatchMixin()))
 
        self.assertFalse(recorder.record_rattail_change(self.session, model.BatchRowMixin()))
 
        self.assertEqual(self.session.query(model.Change).count(), 0)
 

	
 
        # ignore instance with no UUID attribute
 
        self.assertFalse(recorder.record_rattail_change(self.session, model.Setting()))
 
        self.assertFalse(recorder.record_rattail_change(self.session, model.Permission()))
 
        self.assertEqual(self.session.query(model.Change).count(), 0)
 

	
 
        product = model.Product(uuid='c1b0fb94178f11e680fd3ca9f40bc550')
 

	
 
        # default
 
        self.assertTrue(recorder.record_rattail_change(self.session, product))
 
        change = self.session.query(model.Change).one()
 
        self.assertEqual(change.class_name, 'Product')
 
        self.assertEqual(change.object_key, 'c1b0fb94178f11e680fd3ca9f40bc550')
 
        self.assertFalse(change.deleted)
 
        self.session.query(model.Change).delete(synchronize_session=False)
 

	
 
        # new
 
        self.assertTrue(recorder.record_rattail_change(self.session, product, type_='new'))
 
        change = self.session.query(model.Change).one()
 
        self.assertEqual(change.class_name, 'Product')
 
        self.assertEqual(change.object_key, 'c1b0fb94178f11e680fd3ca9f40bc550')
 
        self.assertFalse(change.deleted)
 
        self.session.query(model.Change).delete(synchronize_session=False)
 

	
 
        # dirty
 
        self.assertTrue(recorder.record_rattail_change(self.session, product, type_='dirty'))
 
        change = self.session.query(model.Change).one()
 
        self.assertEqual(change.class_name, 'Product')
 
        self.assertEqual(change.object_key, 'c1b0fb94178f11e680fd3ca9f40bc550')
 
        self.assertFalse(change.deleted)
 
        self.session.query(model.Change).delete(synchronize_session=False)
 

	
 
        # deleted
 
        self.assertTrue(recorder.record_rattail_change(self.session, product, type_='deleted'))
 
        change = self.session.query(model.Change).one()
 
        self.assertEqual(change.class_name, 'Product')
 
        self.assertEqual(change.object_key, 'c1b0fb94178f11e680fd3ca9f40bc550')
 
        self.assertTrue(change.deleted)
 
        self.session.query(model.Change).delete(synchronize_session=False)
 

	
 

	
 
class TestChangeRecorderLegacy(TestCase):
 

	
 
    def setUp(self):
 
        self.config = RattailConfig()
 

	
 
    def test_init(self):
 
        recorder = changes.ChangeRecorder(self.config)
 

	
 
    # def test_record_change(self):
 
    #     session = Mock()
 
    #     recorder = changes.ChangeRecorder()
 
    #     recorder.ensure_uuid = Mock()
 

	
 
    #     # don't record changes for changes
 
    #     self.assertFalse(recorder.record_change(session, model.Change()))
 

	
 
    #     # don't record changes for objects with no uuid attribute
 
    #     self.assertFalse(recorder.record_change(session, object()))
 

	
 
    #     # none of the above should have involved a call to `ensure_uuid()`
 
    #     self.assertFalse(recorder.ensure_uuid.called)
 

	
 
    #     # so far no *new* changes have been created
 
    #     self.assertFalse(session.add.called)
 

	
 
    #     # mock up session to force new change creation
 
    #     session.query.return_value = session
 
    #     session.get.return_value = None
 
    #     self.assertTrue(recorder.record_change(session, model.Product()))
 

	
 
    @patch.multiple('rattail.db.changes', get_uuid=DEFAULT, object_mapper=DEFAULT)
 
    def test_ensure_uuid(self, get_uuid, object_mapper):
 
        recorder = changes.ChangeRecorder(self.config)
 
        uuid_column = Mock()
 
        object_mapper.return_value.columns.__getitem__.return_value = uuid_column
 

	
 
        # uuid already present
 
        product = model.Product(uuid='some_uuid')
 
        recorder.ensure_uuid(product)
 
        self.assertEqual(product.uuid, 'some_uuid')
 
        self.assertFalse(get_uuid.called)
 

	
 
        # no uuid yet, auto-generate
 
        uuid_column.foreign_keys = False
 
        get_uuid.return_value = 'another_uuid'
 
        product = model.Product()
 
        self.assertTrue(product.uuid is None)
 
        recorder.ensure_uuid(product)
 
        get_uuid.assert_called_once_with()
 
        self.assertEqual(product.uuid, 'another_uuid')
 

	
 
        # some heavy mocking for following tests
 
        uuid_column.foreign_keys = True
 
        remote_side = MagicMock(key='uuid')
 
        prop = MagicMock(__class__=orm.RelationshipProperty, key='foreign_thing')
 
        prop.remote_side.__len__.return_value = 1
 
        prop.remote_side.__iter__.return_value = [remote_side]
 
        object_mapper.return_value.iterate_properties.__iter__.return_value = [prop]
 
        
 
        # uuid fetched from existing foreign key object
 
        get_uuid.reset_mock()
 
        instance = Mock(uuid=None, foreign_thing=Mock(uuid='secondary_uuid'))
 
        recorder.ensure_uuid(instance)
 
        self.assertFalse(get_uuid.called)
 
        self.assertEqual(instance.uuid, 'secondary_uuid')
 

	
 
        # foreign key object doesn't exist; uuid generated as fallback
 
        get_uuid.return_value = 'fallback_uuid'
 
        instance = Mock(uuid=None, foreign_thing=None)
 
        recorder.ensure_uuid(instance)
 
        get_uuid.assert_called_once_with()
 
        self.assertEqual(instance.uuid, 'fallback_uuid')
 

	
 

	
 
class TestFunctionalChanges(DataTestCase):
 

	
 
    def setUp(self):
 
        super(TestFunctionalChanges, self).setUp()
 
        changes.record_changes(self.session)
 

	
 
    def test_add(self):
 
        product = model.Product()
 
        self.session.add(product)
 
        self.session.commit()
 

	
 
        self.assertEqual(self.session.query(model.Change).count(), 1)
 
        change = self.session.query(model.Change).one()
 
        self.assertEqual(change.class_name, 'Product')
 
        self.assertEqual(change.instance_uuid, product.uuid)
 
        self.assertFalse(change.deleted)
 

	
 
    def test_change(self):
 
        product = model.Product()
 
        self.session.add(product)
 
        self.session.commit()
 

	
 
        self.assertEqual(self.session.query(model.Change).count(), 1)
 
        self.session.query(model.Change).delete()
 
        self.assertEqual(self.session.query(model.Change).count(), 0)
 

	
 
        product.description = 'Acme Bricks'
 
        self.session.commit()
 

	
 
        self.assertEqual(self.session.query(model.Change).count(), 1)
 
        change = self.session.query(model.Change).one()
 
        self.assertEqual(change.class_name, 'Product')
 
        self.assertEqual(change.instance_uuid, product.uuid)
 
        self.assertFalse(change.deleted)
 

	
 
    def test_delete(self):
 
        product = model.Product()
 
        self.session.add(product)
 
        self.session.commit()
 

	
 
        self.assertEqual(self.session.query(model.Change).count(), 1)
 
        self.session.query(model.Change).delete()
 
        self.assertEqual(self.session.query(model.Change).count(), 0)
 

	
 
        self.session.delete(product)
 

	
 
        self.assertEqual(self.session.query(model.Change).count(), 1)
 
        change = self.session.query(model.Change).one()
 
        self.assertEqual(change.class_name, 'Product')
 
        self.assertEqual(change.instance_uuid, product.uuid)
 
        self.assertTrue(change.deleted)
 

	
 
    def test_orphan_change(self):
 
        department = model.Department()
 
        subdepartment = model.Subdepartment()
 
        department.subdepartments.append(subdepartment)
 
        self.session.add(department)
 
        self.session.commit()
 

	
 
        self.assertEqual(self.session.query(model.Change).count(), 2)
 
        change = self.session.query(model.Change).filter_by(class_name='Department').one()
 
        self.assertFalse(change.deleted)
 
        change = self.session.query(model.Change).filter_by(class_name='Subdepartment').one()
 
        self.assertFalse(change.deleted)
 

	
 
        self.session.query(model.Change).delete()
 
        self.assertEqual(self.session.query(model.Change).count(), 0)
 

	
 
        # Creating an orphaned Subdepartment, which should be recorded as a
 
        # *change* due to the cascade rules in effect.
 
        department.subdepartments.remove(subdepartment)
 
        self.session.commit()
 

	
 
        self.assertEqual(self.session.query(model.Change).count(), 2)
 
        change = self.session.query(model.Change).filter_by(class_name='Department').one()
 
        self.assertFalse(change.deleted)
 
        change = self.session.query(model.Change).filter_by(class_name='Subdepartment').one()
 
        self.assertFalse(change.deleted)
 
        self.assertEqual(self.session.query(model.Subdepartment).count(), 1)
 
    
 
    def test_orphan_delete(self):
 
        customer = model.Customer()
 
        group = model.CustomerGroup()
 
        customer.groups.append(group)
 
        self.session.add(customer)
 
        self.session.commit()
 

	
 
        self.assertEqual(self.session.query(model.Change).count(), 3)
 
        change = self.session.query(model.Change).filter_by(class_name='Customer').one()
 
        self.assertFalse(change.deleted)
 
        change = self.session.query(model.Change).filter_by(class_name='CustomerGroup').one()
 
        self.assertFalse(change.deleted)
 
        change = self.session.query(model.Change).filter_by(class_name='CustomerGroupAssignment').one()
 
        self.assertFalse(change.deleted)
 

	
 
        self.session.query(model.Change).delete()
 
        self.assertEqual(self.session.query(model.Change).count(), 0)
 

	
 
        # Creating an orphaned CustomerGroupAssociation, which should be
 
        # recorded as a *deletion* due to the cascade rules in effect.  Note
 
        # that the CustomerGroup is not technically an orphan and in fact is
 
        # not even changed.
 
        customer.groups.remove(group)
 
        self.session.commit()
 

	
 
        self.assertEqual(self.session.query(model.Change).count(), 2)
 
        change = self.session.query(model.Change).filter_by(class_name='Customer').one()
 
        self.assertFalse(change.deleted)
 
        change = self.session.query(model.Change).filter_by(class_name='CustomerGroupAssignment').one()
 
        self.assertTrue(change.deleted)
 
        self.assertEqual(self.session.query(model.CustomerGroupAssignment).count(), 0)
 
            # Creating an orphaned CustomerGroupAssociation, which should be
 
            # recorded as a *deletion* due to the cascade rules in effect.  Note
 
            # that the CustomerGroup is not technically an orphan and in fact is
 
            # not even changed.
 
            customer.groups.remove(group)
 
            self.session.commit()
 

	
 
            self.assertEqual(self.session.query(model.Change).count(), 2)
 
            change = self.session.query(model.Change).filter_by(class_name='Customer').one()
 
            self.assertFalse(change.deleted)
 
            change = self.session.query(model.Change).filter_by(class_name='CustomerGroupAssignment').one()
 
            self.assertTrue(change.deleted)
 
            self.assertEqual(self.session.query(model.CustomerGroupAssignment).count(), 0)
tests/db/test_config.py
Show inline comments
 
@@ -2,37 +2,40 @@
 

	
 
from unittest import TestCase
 

	
 
from rattail.db import config as conf
 

	
 

	
 
class TestMakeEngineFromConfig(TestCase):
 

	
 
    def test_record_changes(self):
 

	
 
        # no attribute is set by default
 
        engine = conf.make_engine_from_config({
 
            'sqlalchemy.url': 'sqlite://',
 
        })
 
        self.assertRaises(AttributeError, getattr, engine, 'rattail_record_changes')
 

	
 
        # but if flag is true, attr is set
 
        engine = conf.make_engine_from_config({
 
            'sqlalchemy.url': 'sqlite://',
 
            'sqlalchemy.record_changes': 'true'
 
        })
 
        self.assertTrue(engine.rattail_record_changes)
 

	
 
    def test_log_pool_status(self):
 

	
 
        # no attribute is set by default
 
        engine = conf.make_engine_from_config({
 
            'sqlalchemy.url': 'sqlite://',
 
        })
 
        self.assertRaises(AttributeError, getattr, engine, 'rattail_log_pool_status')
 

	
 
        # but if flag is true, attr is set
 
        engine = conf.make_engine_from_config({
 
            'sqlalchemy.url': 'sqlite://',
 
            'sqlalchemy.log_pool_status': 'true'
 
        })
 
        self.assertTrue(engine.rattail_log_pool_status)
 
try:
 
    from rattail.db import config as conf
 
except ImportError:
 
    pass
 
else:
 

	
 
    class TestMakeEngineFromConfig(TestCase):
 

	
 
        def test_record_changes(self):
 

	
 
            # no attribute is set by default
 
            engine = conf.make_engine_from_config({
 
                'sqlalchemy.url': 'sqlite://',
 
            })
 
            self.assertRaises(AttributeError, getattr, engine, 'rattail_record_changes')
 

	
 
            # but if flag is true, attr is set
 
            engine = conf.make_engine_from_config({
 
                'sqlalchemy.url': 'sqlite://',
 
                'sqlalchemy.record_changes': 'true'
 
            })
 
            self.assertTrue(engine.rattail_record_changes)
 

	
 
        def test_log_pool_status(self):
 

	
 
            # no attribute is set by default
 
            engine = conf.make_engine_from_config({
 
                'sqlalchemy.url': 'sqlite://',
 
            })
 
            self.assertRaises(AttributeError, getattr, engine, 'rattail_log_pool_status')
 

	
 
            # but if flag is true, attr is set
 
            engine = conf.make_engine_from_config({
 
                'sqlalchemy.url': 'sqlite://',
 
                'sqlalchemy.log_pool_status': 'true'
 
            })
 
            self.assertTrue(engine.rattail_log_pool_status)
tests/db/test_core.py
Show inline comments
 
# -*- coding: utf-8 -*-
 

	
 
from __future__ import unicode_literals
 
# -*- coding: utf-8; -*-
 

	
 
from unittest import TestCase
 

	
 
from sqlalchemy import create_engine
 
from sqlalchemy import Column, Integer, String, ForeignKey
 
from sqlalchemy.orm import sessionmaker, relationship
 
from sqlalchemy.ext.declarative import declarative_base
 
from sqlalchemy.ext.associationproxy import association_proxy
 

	
 
from rattail.db import core
 

	
 

	
 
class TestCore(TestCase):
 

	
 
    def test_uuid_column(self):
 
        column = core.uuid_column()
 
        self.assertTrue(isinstance(column, Column))
 
        self.assertEqual(column.name, None)
 
        self.assertTrue(column.primary_key)
 
        self.assertFalse(column.nullable)
 
        self.assertFalse(column.default is None)
 

	
 
    def test_uuid_column_no_default(self):
 
        column = core.uuid_column(default=None)
 
        self.assertTrue(column.default is None)
 

	
 
    def test_uuid_column_nullable(self):
 
        column = core.uuid_column(nullable=True)
 
        self.assertTrue(column.nullable)
 

	
 

	
 
class TestGetSetFactory(TestCase):
 

	
 
    def setUp(self):
 
        Base = declarative_base()
 

	
 
        class Primary(Base):
 
            __tablename__ = 'primary'
 
            id = Column(Integer(), primary_key=True)
 
            foo = Column(String(length=10))
 

	
 
        class Secondary(Base):
 
            __tablename__ = 'secondary'
 
            id = Column(Integer(), primary_key=True)
 
            primary_id = Column(Integer(), ForeignKey('primary.id'))
 
            bar = Column(String(length=10))
 

	
 
        Primary._secondary = relationship(
 
            Secondary, backref='primary', uselist=False)
 
        Primary.bar = association_proxy(
 
            '_secondary', 'bar',
 
            getset_factory=core.getset_factory)
 

	
 
        self.Primary = Primary
 
        self.Secondary = Secondary
 
        
 
        engine = create_engine('sqlite://')
 
        Base.metadata.create_all(bind=engine)
 
        Session = sessionmaker(bind=engine)
 
        self.session = Session()
 

	
 
    def tearDown(self):
 
        self.session.close()
 

	
 
    def test_getter_returns_none_if_proxy_value_is_absent(self):
 
        p = self.Primary()
 
        self.session.add(p)
 
        self.assertTrue(p.bar is None)
 

	
 
    def test_getter_returns_proxy_value_if_proxy_value_is_present(self):
 
        p = self.Primary()
 
        self.assertTrue(p.bar is None)
 
        s = self.Secondary(primary=p, bar='something')
 
        self.session.add(p)
 
        self.assertEqual(p.bar, 'something')
 

	
 
    def test_setter_assigns_proxy_value(self):
 
        p = self.Primary()
 
        s = self.Secondary(primary=p)
 
        self.session.add(p)
 
        self.assertTrue(s.bar is None)
 
        p.bar = 'something'
 
        self.assertEqual(s.bar, 'something')
 
try:
 
    from sqlalchemy import create_engine
 
    from sqlalchemy import Column, Integer, String, ForeignKey
 
    from sqlalchemy.orm import sessionmaker, relationship
 
    from sqlalchemy.ext.declarative import declarative_base
 
    from sqlalchemy.ext.associationproxy import association_proxy
 
    from rattail.db import core
 
except ImportError:
 
    pass
 
else:
 

	
 
    class TestCore(TestCase):
 

	
 
        def test_uuid_column(self):
 
            column = core.uuid_column()
 
            self.assertTrue(isinstance(column, Column))
 
            self.assertEqual(column.name, None)
 
            self.assertTrue(column.primary_key)
 
            self.assertFalse(column.nullable)
 
            self.assertFalse(column.default is None)
 

	
 
        def test_uuid_column_no_default(self):
 
            column = core.uuid_column(default=None)
 
            self.assertTrue(column.default is None)
 

	
 
        def test_uuid_column_nullable(self):
 
            column = core.uuid_column(nullable=True)
 
            self.assertTrue(column.nullable)
 

	
 

	
 
    class TestGetSetFactory(TestCase):
 

	
 
        def setUp(self):
 
            Base = declarative_base()
 

	
 
            class Primary(Base):
 
                __tablename__ = 'primary'
 
                id = Column(Integer(), primary_key=True)
 
                foo = Column(String(length=10))
 

	
 
            class Secondary(Base):
 
                __tablename__ = 'secondary'
 
                id = Column(Integer(), primary_key=True)
 
                primary_id = Column(Integer(), ForeignKey('primary.id'))
 
                bar = Column(String(length=10))
 

	
 
            Primary._secondary = relationship(
 
                Secondary, backref='primary', uselist=False)
 
            Primary.bar = association_proxy(
 
                '_secondary', 'bar',
 
                getset_factory=core.getset_factory)
 

	
 
            self.Primary = Primary
 
            self.Secondary = Secondary
 

	
 
            engine = create_engine('sqlite://')
 
            Base.metadata.create_all(bind=engine)
 
            Session = sessionmaker(bind=engine)
 
            self.session = Session()
 

	
 
        def tearDown(self):
 
            self.session.close()
 

	
 
        def test_getter_returns_none_if_proxy_value_is_absent(self):
 
            p = self.Primary()
 
            self.session.add(p)
 
            self.assertTrue(p.bar is None)
 

	
 
        def test_getter_returns_proxy_value_if_proxy_value_is_present(self):
 
            p = self.Primary()
 
            self.assertTrue(p.bar is None)
 
            s = self.Secondary(primary=p, bar='something')
 
            self.session.add(p)
 
            self.assertEqual(p.bar, 'something')
 

	
 
        def test_setter_assigns_proxy_value(self):
 
            p = self.Primary()
 
            s = self.Secondary(primary=p)
 
            self.session.add(p)
 
            self.assertTrue(s.bar is None)
 
            p.bar = 'something'
 
            self.assertEqual(s.bar, 'something')
tests/db/test_init.py
Show inline comments
 
@@ -5,51 +5,18 @@ import shutil
 
import tempfile
 
from unittest import TestCase
 

	
 
import sqlalchemy as sa
 

	
 
from rattail import db
 
from rattail.config import RattailConfig
 

	
 

	
 
class TestSession(TestCase):
 

	
 
    def test_init_rattail_config(self):
 
        db.Session.configure(rattail_config=None)
 
        session = db.Session()
 
        self.assertIsNone(session.rattail_config)
 
        session.close()
 

	
 
        config = object()
 
        session = db.Session(rattail_config=config)
 
        self.assertIs(session.rattail_config, config)
 
        session.close()
 

	
 
    def test_init_record_changes(self):
 
        if hasattr(db.Session, 'kw'):
 
            self.assertIsNone(db.Session.kw.get('rattail_record_changes'))
 

	
 
        session = db.Session()
 
        self.assertFalse(session.rattail_record_changes)
 
        session.close()
 

	
 
        session = db.Session(rattail_record_changes=True)
 
        self.assertTrue(session.rattail_record_changes)
 
        session.close()
 

	
 
        engine = sa.create_engine('sqlite://')
 
        engine.rattail_record_changes = True
 
        session = db.Session(bind=engine)
 
        self.assertTrue(session.rattail_record_changes)
 
        session.close()
 

	
 

	
 
class TestConfigExtension(TestCase):
 

	
 
    def setUp(self):
 
        self.tempdir = tempfile.mkdtemp()
 

	
 
    def tearDown(self):
 
        db.Session.configure(bind=None, rattail_config=None)
 
        if db.Session:
 
            db.Session.configure(bind=None, rattail_config=None)
 
        shutil.rmtree(self.tempdir)
 

	
 
    def write_file(self, fname, content):
 
@@ -66,11 +33,12 @@ class TestConfigExtension(TestCase):
 
            self.assertIsNone(db.Session.kw['bind'])
 
            self.assertIsNone(db.Session.kw['rattail_config'])
 

	
 
        db.ConfigExtension().configure(config)
 
        self.assertEqual(config.rattail_engines, {})
 
        self.assertIsNone(config.rattail_engine)
 
        if hasattr(db.Session, 'kw'):
 
            self.assertIs(db.Session.kw['rattail_config'], config)
 
        if db.Session:
 
            db.ConfigExtension().configure(config)
 
            self.assertEqual(config.rattail_engines, {})
 
            self.assertIsNone(config.rattail_engine)
 
            if hasattr(db.Session, 'kw'):
 
                self.assertIs(db.Session.kw['rattail_config'], config)
 

	
 
    def test_configure_connections(self):
 
        default_path = self.write_file('default.sqlite', '')
 
@@ -83,7 +51,46 @@ class TestConfigExtension(TestCase):
 
        config.setdefault('rattail.db', 'default.url', default_url)
 
        config.setdefault('rattail.db', 'host.url', host_url)
 
        db.ConfigExtension().configure(config)
 
        self.assertEqual(len(config.rattail_engines), 2)
 
        self.assertEqual(str(config.rattail_engines['default'].url), default_url)
 
        self.assertEqual(str(config.rattail_engines['host'].url), host_url)
 
        self.assertEqual(str(config.rattail_engine.url), default_url)
 
        if db.Session:
 
            self.assertEqual(len(config.rattail_engines), 2)
 
            self.assertEqual(str(config.rattail_engines['default'].url), default_url)
 
            self.assertEqual(str(config.rattail_engines['host'].url), host_url)
 
            self.assertEqual(str(config.rattail_engine.url), default_url)
 

	
 

	
 
try:
 
    import sqlalchemy as sa
 
except ImportError:
 
    pass
 
else:
 

	
 
    class TestSession(TestCase):
 

	
 
        def test_init_rattail_config(self):
 
            db.Session.configure(rattail_config=None)
 
            session = db.Session()
 
            self.assertIsNone(session.rattail_config)
 
            session.close()
 

	
 
            config = object()
 
            session = db.Session(rattail_config=config)
 
            self.assertIs(session.rattail_config, config)
 
            session.close()
 

	
 
        def test_init_record_changes(self):
 
            if hasattr(db.Session, 'kw'):
 
                self.assertIsNone(db.Session.kw.get('rattail_record_changes'))
 

	
 
            session = db.Session()
 
            self.assertFalse(session.rattail_record_changes)
 
            session.close()
 

	
 
            session = db.Session(rattail_record_changes=True)
 
            self.assertTrue(session.rattail_record_changes)
 
            session.close()
 

	
 
            engine = sa.create_engine('sqlite://')
 
            engine.rattail_record_changes = True
 
            session = db.Session(bind=engine)
 
            self.assertTrue(session.rattail_record_changes)
 
            session.close()
tests/db/test_model.py
Show inline comments
 
# -*- coding: utf-8; -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 

	
 
from sqlalchemy.exc import IntegrityError
 

	
 
from rattail.db import model
 
# from rattail.db.changes import record_changes
 
from .. import DataTestCase
 

	
 

	
 
class SAErrorHelper(object):
 

	
 
    def integrity_or_flush_error(self):
 
        try:
 
            from sqlalchemy.exc import FlushError
 
        except ImportError:
 
            return IntegrityError
 
        else:
 
            return (IntegrityError, FlushError)
 

	
 

	
 
class TestCustomerPerson(DataTestCase, SAErrorHelper):
 

	
 
    def test_customer_required(self):
 
        assoc = model.CustomerPerson(person=model.Person())
 
        self.session.add(assoc)
 
        self.assertRaises(self.integrity_or_flush_error(), self.session.commit)
 
        self.session.rollback()
 
        self.assertEqual(self.session.query(model.CustomerPerson).count(), 0)
 
        assoc.customer = model.Customer()
 
        self.session.add(assoc)
 
        self.session.commit()
 
        self.assertEqual(self.session.query(model.CustomerPerson).count(), 1)
 

	
 
    def test_person_required(self):
 
        assoc = model.CustomerPerson(customer=model.Customer())
 
        self.session.add(assoc)
 
        self.assertRaises(IntegrityError, self.session.commit)
 
        self.session.rollback()
 
        self.assertEqual(self.session.query(model.CustomerPerson).count(), 0)
 
        assoc.person = model.Person()
 
        self.session.add(assoc)
 
        self.session.commit()
 
        self.assertEqual(self.session.query(model.CustomerPerson).count(), 1)
 

	
 
    def test_ordinal_autoincrement(self):
 
        customer = model.Customer()
 
        self.session.add(customer)
 
        assoc = model.CustomerPerson(person=model.Person())
 
        customer._people.append(assoc)
 
        self.session.commit()
 
        self.assertEqual(assoc.ordinal, 1)
 
        assoc = model.CustomerPerson(person=model.Person())
 
        customer._people.append(assoc)
 
        self.session.commit()
 
        self.assertEqual(assoc.ordinal, 2)
 

	
 

	
 
class TestCustomerGroupAssignment(DataTestCase, SAErrorHelper):
 

	
 
    def test_customer_required(self):
 
        assignment = model.CustomerGroupAssignment(group=model.CustomerGroup())
 
        self.session.add(assignment)
 
        self.assertRaises(self.integrity_or_flush_error(), self.session.commit)
 
        self.session.rollback()
 
        self.assertEqual(self.session.query(model.CustomerGroupAssignment).count(), 0)
 
        assignment.customer = model.Customer()
 
        self.session.add(assignment)
 
        self.session.commit()
 
        self.assertEqual(self.session.query(model.CustomerGroupAssignment).count(), 1)
 

	
 
    def test_group_required(self):
 
        assignment = model.CustomerGroupAssignment(customer=model.Customer())
 
        self.session.add(assignment)
 
        self.assertRaises(IntegrityError, self.session.commit)
 
        self.session.rollback()
 
        self.assertEqual(self.session.query(model.CustomerGroupAssignment).count(), 0)
 
        assignment.group = model.CustomerGroup()
 
        self.session.add(assignment)
 
        self.session.commit()
 
        self.assertEqual(self.session.query(model.CustomerGroupAssignment).count(), 1)
 

	
 
    def test_ordinal_autoincrement(self):
 
        customer = model.Customer()
 
        self.session.add(customer)
 
        assignment = model.CustomerGroupAssignment(group=model.CustomerGroup())
 
        customer._groups.append(assignment)
 
        self.session.commit()
 
        self.assertEqual(assignment.ordinal, 1)
 
        assignment = model.CustomerGroupAssignment(group=model.CustomerGroup())
 
        customer._groups.append(assignment)
 
        self.session.commit()
 
        self.assertEqual(assignment.ordinal, 2)
 

	
 

	
 
# class TestCustomerEmailAddress(DataTestCase):
 

	
 
#     def test_pop(self):
 
#         customer = model.Customer()
 
#         customer.add_email_address('fred.home@mailinator.com')
 
#         customer.add_email_address('fred.work@mailinator.com')
 
#         self.session.add(customer)
 
#         self.session.commit()
 
#         self.assertEqual(self.session.query(model.Customer).count(), 1)
 
#         self.assertEqual(self.session.query(model.CustomerEmailAddress).count(), 2)
 

	
 
#         while customer.emails:
 
#             customer.emails.pop()
 
#         self.session.commit()
 
#         self.assertEqual(self.session.query(model.Customer).count(), 1)
 
#         self.assertEqual(self.session.query(model.CustomerEmailAddress).count(), 0)
 

	
 
#         # changes weren't being recorded
 
#         self.assertEqual(self.session.query(model.Change).count(), 0)
 

	
 
#     def test_pop_with_changes(self):
 
#         record_changes(self.session)
 

	
 
#         customer = model.Customer()
 
#         customer.add_email_address('fred.home@mailinator.com')
 
#         customer.add_email_address('fred.work@mailinator.com')
 
#         self.session.add(customer)
 
#         self.session.commit()
 
#         self.assertEqual(self.session.query(model.Customer).count(), 1)
 
#         self.assertEqual(self.session.query(model.CustomerEmailAddress).count(), 2)
 

	
 
#         while customer.emails:
 
#             customer.emails.pop()
 
#         self.session.commit()
 
#         self.assertEqual(self.session.query(model.Customer).count(), 1)
 
#         self.assertEqual(self.session.query(model.CustomerEmailAddress).count(), 0)
 

	
 
#         # changes should have been recorded
 
#         changes = self.session.query(model.Change)
 
#         self.assertEqual(changes.count(), 3)
 

	
 
#         customer_change = changes.filter_by(class_name='Customer').one()
 
#         self.assertEqual(customer_change.uuid, customer.uuid)
 
#         self.assertFalse(customer_change.deleted)
 

	
 
#         email_changes = changes.filter_by(class_name='CustomerEmailAddress')
 
#         self.assertEqual(email_changes.count(), 2)
 
#         self.assertEqual([x.deleted for x in email_changes], [True, True])
 

	
 

	
 
class TestLabelProfile(DataTestCase):
 

	
 
    def test_get_printer_setting(self):
 
        profile = model.LabelProfile()
 
        self.session.add(profile)
 

	
 
        self.assertTrue(profile.uuid is None)
 
        setting = profile.get_printer_setting('some_setting')
 
        self.assertTrue(setting is None)
 
        self.assertTrue(profile.uuid is None)
 

	
 
        profile.uuid = 'some_uuid'
 
        self.session.add(model.Setting(
 
                name='labels.some_uuid.printer.some_setting',
 
                value='some_value'))
 
        self.session.flush()
 
        setting = profile.get_printer_setting('some_setting')
 
        self.assertEqual(setting, 'some_value')
 

	
 
    def test_save_printer_setting(self):
 
        self.assertEqual(self.session.query(model.Setting).count(), 0)
 
        profile = model.LabelProfile()
 
        self.session.add(profile)
 

	
 
        self.assertTrue(profile.uuid is None)
 
        profile.save_printer_setting('some_setting', 'some_value')
 
        self.assertFalse(profile.uuid is None)
 
        self.assertEqual(self.session.query(model.Setting).count(), 1)
 

	
 
        profile.uuid = 'some_uuid'
 
        profile.save_printer_setting('some_setting', 'some_value')
 
        self.assertEqual(self.session.query(model.Setting).count(), 2)
 
        setting = self.session.query(model.Setting)\
 
            .filter_by(name='labels.some_uuid.printer.some_setting')\
 
            .one()
 
        self.assertEqual(setting.value, 'some_value')
 
try:
 
    from sqlalchemy.exc import IntegrityError
 
    from rattail.db import model
 
    # from rattail.db.changes import record_changes
 
    from .. import DataTestCase
 
except ImportError:
 
    pass
 
else:
 

	
 
    class SAErrorHelper(object):
 

	
 
        def integrity_or_flush_error(self):
 
            try:
 
                from sqlalchemy.exc import FlushError
 
            except ImportError:
 
                return IntegrityError
 
            else:
 
                return (IntegrityError, FlushError)
 

	
 

	
 
    class TestCustomerPerson(DataTestCase, SAErrorHelper):
 

	
 
        def test_customer_required(self):
 
            assoc = model.CustomerPerson(person=model.Person())
 
            self.session.add(assoc)
 
            self.assertRaises(self.integrity_or_flush_error(), self.session.commit)
 
            self.session.rollback()
 
            self.assertEqual(self.session.query(model.CustomerPerson).count(), 0)
 
            assoc.customer = model.Customer()
 
            self.session.add(assoc)
 
            self.session.commit()
 
            self.assertEqual(self.session.query(model.CustomerPerson).count(), 1)
 

	
 
        def test_person_required(self):
 
            assoc = model.CustomerPerson(customer=model.Customer())
 
            self.session.add(assoc)
 
            self.assertRaises(IntegrityError, self.session.commit)
 
            self.session.rollback()
 
            self.assertEqual(self.session.query(model.CustomerPerson).count(), 0)
 
            assoc.person = model.Person()
 
            self.session.add(assoc)
 
            self.session.commit()
 
            self.assertEqual(self.session.query(model.CustomerPerson).count(), 1)
 

	
 
        def test_ordinal_autoincrement(self):
 
            customer = model.Customer()
 
            self.session.add(customer)
 
            assoc = model.CustomerPerson(person=model.Person())
 
            customer._people.append(assoc)
 
            self.session.commit()
 
            self.assertEqual(assoc.ordinal, 1)
 
            assoc = model.CustomerPerson(person=model.Person())
 
            customer._people.append(assoc)
 
            self.session.commit()
 
            self.assertEqual(assoc.ordinal, 2)
 

	
 

	
 
    class TestCustomerGroupAssignment(DataTestCase, SAErrorHelper):
 

	
 
        def test_customer_required(self):
 
            assignment = model.CustomerGroupAssignment(group=model.CustomerGroup())
 
            self.session.add(assignment)
 
            self.assertRaises(self.integrity_or_flush_error(), self.session.commit)
 
            self.session.rollback()
 
            self.assertEqual(self.session.query(model.CustomerGroupAssignment).count(), 0)
 
            assignment.customer = model.Customer()
 
            self.session.add(assignment)
 
            self.session.commit()
 
            self.assertEqual(self.session.query(model.CustomerGroupAssignment).count(), 1)
 

	
 
        def test_group_required(self):
 
            assignment = model.CustomerGroupAssignment(customer=model.Customer())
 
            self.session.add(assignment)
 
            self.assertRaises(IntegrityError, self.session.commit)
 
            self.session.rollback()
 
            self.assertEqual(self.session.query(model.CustomerGroupAssignment).count(), 0)
 
            assignment.group = model.CustomerGroup()
 
            self.session.add(assignment)
 
            self.session.commit()
 
            self.assertEqual(self.session.query(model.CustomerGroupAssignment).count(), 1)
 

	
 
        def test_ordinal_autoincrement(self):
 
            customer = model.Customer()
 
            self.session.add(customer)
 
            assignment = model.CustomerGroupAssignment(group=model.CustomerGroup())
 
            customer._groups.append(assignment)
 
            self.session.commit()
 
            self.assertEqual(assignment.ordinal, 1)
 
            assignment = model.CustomerGroupAssignment(group=model.CustomerGroup())
 
            customer._groups.append(assignment)
 
            self.session.commit()
 
            self.assertEqual(assignment.ordinal, 2)
 

	
 

	
 
    # class TestCustomerEmailAddress(DataTestCase):
 

	
 
    #     def test_pop(self):
 
    #         customer = model.Customer()
 
    #         customer.add_email_address('fred.home@mailinator.com')
 
    #         customer.add_email_address('fred.work@mailinator.com')
 
    #         self.session.add(customer)
 
    #         self.session.commit()
 
    #         self.assertEqual(self.session.query(model.Customer).count(), 1)
 
    #         self.assertEqual(self.session.query(model.CustomerEmailAddress).count(), 2)
 

	
 
    #         while customer.emails:
 
    #             customer.emails.pop()
 
    #         self.session.commit()
 
    #         self.assertEqual(self.session.query(model.Customer).count(), 1)
 
    #         self.assertEqual(self.session.query(model.CustomerEmailAddress).count(), 0)
 

	
 
    #         # changes weren't being recorded
 
    #         self.assertEqual(self.session.query(model.Change).count(), 0)
 

	
 
    #     def test_pop_with_changes(self):
 
    #         record_changes(self.session)
 

	
 
    #         customer = model.Customer()
 
    #         customer.add_email_address('fred.home@mailinator.com')
 
    #         customer.add_email_address('fred.work@mailinator.com')
 
    #         self.session.add(customer)
 
    #         self.session.commit()
 
    #         self.assertEqual(self.session.query(model.Customer).count(), 1)
 
    #         self.assertEqual(self.session.query(model.CustomerEmailAddress).count(), 2)
 

	
 
    #         while customer.emails:
 
    #             customer.emails.pop()
 
    #         self.session.commit()
 
    #         self.assertEqual(self.session.query(model.Customer).count(), 1)
 
    #         self.assertEqual(self.session.query(model.CustomerEmailAddress).count(), 0)
 

	
 
    #         # changes should have been recorded
 
    #         changes = self.session.query(model.Change)
 
    #         self.assertEqual(changes.count(), 3)
 

	
 
    #         customer_change = changes.filter_by(class_name='Customer').one()
 
    #         self.assertEqual(customer_change.uuid, customer.uuid)
 
    #         self.assertFalse(customer_change.deleted)
 

	
 
    #         email_changes = changes.filter_by(class_name='CustomerEmailAddress')
 
    #         self.assertEqual(email_changes.count(), 2)
 
    #         self.assertEqual([x.deleted for x in email_changes], [True, True])
 

	
 

	
 
    class TestLabelProfile(DataTestCase):
 

	
 
        def test_get_printer_setting(self):
 
            profile = model.LabelProfile()
 
            self.session.add(profile)
 

	
 
            self.assertTrue(profile.uuid is None)
 
            setting = profile.get_printer_setting('some_setting')
 
            self.assertTrue(setting is None)
 
            self.assertTrue(profile.uuid is None)
 

	
 
            profile.uuid = 'some_uuid'
 
            self.session.add(model.Setting(
 
                    name='labels.some_uuid.printer.some_setting',
 
                    value='some_value'))
 
            self.session.flush()
 
            setting = profile.get_printer_setting('some_setting')
 
            self.assertEqual(setting, 'some_value')
 

	
 
        def test_save_printer_setting(self):
 
            self.assertEqual(self.session.query(model.Setting).count(), 0)
 
            profile = model.LabelProfile()
 
            self.session.add(profile)
 

	
 
            self.assertTrue(profile.uuid is None)
 
            profile.save_printer_setting('some_setting', 'some_value')
 
            self.assertFalse(profile.uuid is None)
 
            self.assertEqual(self.session.query(model.Setting).count(), 1)
 

	
 
            profile.uuid = 'some_uuid'
 
            profile.save_printer_setting('some_setting', 'some_value')
 
            self.assertEqual(self.session.query(model.Setting).count(), 2)
 
            setting = self.session.query(model.Setting)\
 
                .filter_by(name='labels.some_uuid.printer.some_setting')\
 
                .one()
 
            self.assertEqual(setting.value, 'some_value')
tests/db/test_util.py
Show inline comments
 
@@ -3,9 +3,7 @@
 
from unittest import TestCase
 
from unittest.mock import MagicMock
 

	
 
from sqlalchemy import orm
 

	
 
from rattail.db import util, Session
 
from rattail.db import util
 

	
 

	
 
class TestFunctions(TestCase):
 
@@ -47,49 +45,56 @@ class TestFunctions(TestCase):
 
        self.assertEqual(number, '(417) 555-1234')
 

	
 

	
 
class TestShortSession(TestCase):
 

	
 
    def test_none(self):
 
        with util.short_session() as s:
 
            self.assertIsInstance(s, Session.class_)
 

	
 
    def test_factory(self):
 
        TestSession = orm.sessionmaker()
 
        with util.short_session(factory=TestSession) as s:
 
            self.assertIsInstance(s, TestSession.class_)
 

	
 
    def test_Session(self):
 
        TestSession = orm.sessionmaker()
 
        with util.short_session(Session=TestSession) as s:
 
            self.assertIsInstance(s, TestSession.class_)
 

	
 
    def test_instance(self):
 
        # nb. nothing really happens if we provide the session instance
 
        session = MagicMock()
 
        with util.short_session(session=session) as s:
 
            pass
 
        session.commit.assert_not_called()
 
        session.close.assert_not_called()
 

	
 
    def test_config(self):
 
        config = MagicMock()
 
        TestSession = orm.sessionmaker()
 
        config.get_app.return_value.make_session = TestSession
 
        with util.short_session(config=config) as s:
 
            self.assertIsInstance(s, TestSession.class_)
 

	
 
    def test_without_commit(self):
 
        session = MagicMock()
 
        TestSession = MagicMock(return_value=session)
 
        with util.short_session(factory=TestSession, commit=False) as s:
 
            pass
 
        session.commit.assert_not_called()
 
        session.close.assert_called_once_with()
 

	
 
    def test_with_commit(self):
 
        session = MagicMock()
 
        TestSession = MagicMock(return_value=session)
 
        with util.short_session(factory=TestSession, commit=True) as s:
 
            pass
 
        session.commit.assert_called_once_with()
 
        session.close.assert_called_once_with()
 
try:
 
    from sqlalchemy import orm
 
    from rattail.db import Session
 
except ImportError:
 
    pass
 
else:
 

	
 
    class TestShortSession(TestCase):
 

	
 
        def test_none(self):
 
            with util.short_session() as s:
 
                self.assertIsInstance(s, Session.class_)
 

	
 
        def test_factory(self):
 
            TestSession = orm.sessionmaker()
 
            with util.short_session(factory=TestSession) as s:
 
                self.assertIsInstance(s, TestSession.class_)
 

	
 
        def test_Session(self):
 
            TestSession = orm.sessionmaker()
 
            with util.short_session(Session=TestSession) as s:
 
                self.assertIsInstance(s, TestSession.class_)
 

	
 
        def test_instance(self):
 
            # nb. nothing really happens if we provide the session instance
 
            session = MagicMock()
 
            with util.short_session(session=session) as s:
 
                pass
 
            session.commit.assert_not_called()
 
            session.close.assert_not_called()
 

	
 
        def test_config(self):
 
            config = MagicMock()
 
            TestSession = orm.sessionmaker()
 
            config.get_app.return_value.make_session = TestSession
 
            with util.short_session(config=config) as s:
 
                self.assertIsInstance(s, TestSession.class_)
 

	
 
        def test_without_commit(self):
 
            session = MagicMock()
 
            TestSession = MagicMock(return_value=session)
 
            with util.short_session(factory=TestSession, commit=False) as s:
 
                pass
 
            session.commit.assert_not_called()
 
            session.close.assert_called_once_with()
 

	
 
        def test_with_commit(self):
 
            session = MagicMock()
 
            TestSession = MagicMock(return_value=session)
 
            with util.short_session(factory=TestSession, commit=True) as s:
 
                pass
 
            session.commit.assert_called_once_with()
 
            session.close.assert_called_once_with()
tests/filemon/test_actions.py
Show inline comments
 
@@ -6,8 +6,7 @@ import time
 
import tempfile
 
import queue
 
from unittest import TestCase
 

	
 
from mock import Mock, patch, call
 
from unittest.mock import Mock, patch, call
 

	
 
from rattail.config import make_config, RattailConfig, ConfigProfileAction
 
from rattail.filemon import actions
tests/filemon/test_linux.py
Show inline comments
 
@@ -6,8 +6,7 @@ import tempfile
 

	
 
import queue
 
from unittest import TestCase
 

	
 
from mock import Mock
 
from unittest.mock import Mock
 

	
 
from rattail.config import make_config
 
from rattail.filemon import linux
tests/importing/lib.py
Show inline comments
 
# -*- coding: utf-8 -*-
 

	
 
from __future__ import unicode_literals, absolute_import
 
# -*- coding: utf-8; -*-
 

	
 
import copy
 
from contextlib import contextmanager
 

	
 
from mock import patch
 
from unittest.mock import patch
 

	
 
from .. import NullProgress
 

	
tests/importing/test_handlers.py
Show inline comments
 
# -*- coding: utf-8; -*-
 

	
 
import datetime
 
import unittest
 
from unittest import TestCase
 
from unittest.mock import patch, Mock
 
from collections import OrderedDict
 

	
 
import pytz
 
from sqlalchemy import orm
 
from mock import patch, Mock
 

	
 
from rattail.importing import handlers, Importer
 
from rattail.config import make_config
 
from .. import RattailTestCase
 
from . import ImporterTester
 
from .test_importers import MockImporter
 
from .test_postgresql import MockBulkImporter
 

	
 

	
 
class ImportHandlerBattery(ImporterTester):
 
@@ -282,7 +279,7 @@ class BulkImportHandlerBattery(ImportHandlerBattery):
 
            self.assertFalse(process.called)
 

	
 

	
 
class TestImportHandler(unittest.TestCase, ImportHandlerBattery):
 
class TestImportHandler(TestCase, ImportHandlerBattery):
 
    handler_class = handlers.ImportHandler
 

	
 
    def setUp(self):
 
@@ -340,7 +337,7 @@ class TestImportHandler(unittest.TestCase, ImportHandlerBattery):
 
    #     self.assertEqual(send_email.call_count, 1)
 

	
 

	
 
class TestBulkImportHandler(unittest.TestCase, BulkImportHandlerBattery):
 
class TestBulkImportHandler(TestCase, BulkImportHandlerBattery):
 
    handler_class = handlers.BulkImportHandler
 

	
 

	
 
@@ -359,7 +356,7 @@ class MockImportHandler(handlers.ImportHandler):
 
        return result
 

	
 

	
 
class TestImportHandlerImportData(ImporterTester, unittest.TestCase):
 
class TestImportHandlerImportData(ImporterTester, TestCase):
 

	
 
    sample_data = OrderedDict([
 
        ('16oz', {'upc': '00074305001161', 'description': "Apple Cider Vinegar 16oz"}),
 
@@ -544,96 +541,103 @@ class TestImportHandlerImportData(ImporterTester, unittest.TestCase):
 
    #     self.assert_import_deleted('bogus')
 

	
 

	
 
Session = orm.sessionmaker()
 

	
 

	
 
class MockFromSQLAlchemyHandler(handlers.FromSQLAlchemyHandler):
 

	
 
    def make_host_session(self):
 
        return Session()
 

	
 

	
 
class MockToSQLAlchemyHandler(handlers.ToSQLAlchemyHandler):
 

	
 
    def make_session(self):
 
        return Session()
 

	
 

	
 
class TestFromSQLAlchemyHandler(unittest.TestCase):
 

	
 
    def test_init(self):
 
        handler = handlers.FromSQLAlchemyHandler()
 
        self.assertRaises(NotImplementedError, handler.make_host_session)
 

	
 
    def test_get_importer_kwargs(self):
 
        session = object()
 
        handler = handlers.FromSQLAlchemyHandler(host_session=session)
 
        kwargs = handler.get_importer_kwargs(None)
 
        self.assertEqual(list(kwargs.keys()), ['host_session'])
 
        self.assertIs(kwargs['host_session'], session)
 

	
 
    def test_begin_host_transaction(self):
 
        handler = MockFromSQLAlchemyHandler()
 
        self.assertIsNone(handler.host_session)
 
        handler.begin_host_transaction()
 
        self.assertIsInstance(handler.host_session, orm.Session)
 
        handler.host_session.close()
 

	
 
    def test_commit_host_transaction(self):
 
        # TODO: test actual commit for data changes
 
        session = Session()
 
        handler = handlers.FromSQLAlchemyHandler(host_session=session)
 
        self.assertIs(handler.host_session, session)
 
        handler.commit_host_transaction()
 
        self.assertIsNone(handler.host_session)
 

	
 
    def test_rollback_host_transaction(self):
 
        # TODO: test actual rollback for data changes
 
        session = Session()
 
        handler = handlers.FromSQLAlchemyHandler(host_session=session)
 
        self.assertIs(handler.host_session, session)
 
        handler.rollback_host_transaction()
 
        self.assertIsNone(handler.host_session)
 

	
 

	
 
class TestToSQLAlchemyHandler(unittest.TestCase):
 

	
 
    def test_init(self):
 
        handler = handlers.ToSQLAlchemyHandler()
 
        self.assertRaises(NotImplementedError, handler.make_session)
 

	
 
    def test_get_importer_kwargs(self):
 
        session = object()
 
        handler = handlers.ToSQLAlchemyHandler(session=session)
 
        kwargs = handler.get_importer_kwargs(None)
 
        self.assertEqual(list(kwargs.keys()), ['session'])
 
        self.assertIs(kwargs['session'], session)
 

	
 
    def test_begin_local_transaction(self):
 
        handler = MockToSQLAlchemyHandler()
 
        self.assertIsNone(handler.session)
 
        handler.begin_local_transaction()
 
        self.assertIsInstance(handler.session, orm.Session)
 
        handler.session.close()
 

	
 
    def test_commit_local_transaction(self):
 
        # TODO: test actual commit for data changes
 
        session = Session()
 
        handler = handlers.ToSQLAlchemyHandler(session=session)
 
        self.assertIs(handler.session, session)
 
        with patch.object(handler, 'session') as session:
 
            handler.commit_local_transaction()
 
            session.commit.assert_called_once_with()
 
            self.assertFalse(session.rollback.called)
 
        # self.assertIsNone(handler.session)
 

	
 
    def test_rollback_local_transaction(self):
 
        # TODO: test actual rollback for data changes
 
        session = Session()
 
        handler = handlers.ToSQLAlchemyHandler(session=session)
 
        self.assertIs(handler.session, session)
 
        with patch.object(handler, 'session') as session:
 
            handler.rollback_local_transaction()
 
            session.rollback.assert_called_once_with()
 
            self.assertFalse(session.commit.called)
 
        # self.assertIsNone(handler.session)
 
try:
 
    from sqlalchemy import orm
 
    from .test_postgresql import MockBulkImporter
 
except ImportError:
 
    pass
 
else:
 

	
 
    Session = orm.sessionmaker()
 

	
 

	
 
    class MockFromSQLAlchemyHandler(handlers.FromSQLAlchemyHandler):
 

	
 
        def make_host_session(self):
 
            return Session()
 

	
 

	
 
    class MockToSQLAlchemyHandler(handlers.ToSQLAlchemyHandler):
 

	
 
        def make_session(self):
 
            return Session()
 

	
 

	
 
    class TestFromSQLAlchemyHandler(TestCase):
 

	
 
        def test_init(self):
 
            handler = handlers.FromSQLAlchemyHandler()
 
            self.assertRaises(NotImplementedError, handler.make_host_session)
 

	
 
        def test_get_importer_kwargs(self):
 
            session = object()
 
            handler = handlers.FromSQLAlchemyHandler(host_session=session)
 
            kwargs = handler.get_importer_kwargs(None)
 
            self.assertEqual(list(kwargs.keys()), ['host_session'])
 
            self.assertIs(kwargs['host_session'], session)
 

	
 
        def test_begin_host_transaction(self):
 
            handler = MockFromSQLAlchemyHandler()
 
            self.assertIsNone(handler.host_session)
 
            handler.begin_host_transaction()
 
            self.assertIsInstance(handler.host_session, orm.Session)
 
            handler.host_session.close()
 

	
 
        def test_commit_host_transaction(self):
 
            # TODO: test actual commit for data changes
 
            session = Session()
 
            handler = handlers.FromSQLAlchemyHandler(host_session=session)
 
            self.assertIs(handler.host_session, session)
 
            handler.commit_host_transaction()
 
            self.assertIsNone(handler.host_session)
 

	
 
        def test_rollback_host_transaction(self):
 
            # TODO: test actual rollback for data changes
 
            session = Session()
 
            handler = handlers.FromSQLAlchemyHandler(host_session=session)
 
            self.assertIs(handler.host_session, session)
 
            handler.rollback_host_transaction()
 
            self.assertIsNone(handler.host_session)
 

	
 

	
 
    class TestToSQLAlchemyHandler(TestCase):
 

	
 
        def test_init(self):
 
            handler = handlers.ToSQLAlchemyHandler()
 
            self.assertRaises(NotImplementedError, handler.make_session)
 

	
 
        def test_get_importer_kwargs(self):
 
            session = object()
 
            handler = handlers.ToSQLAlchemyHandler(session=session)
 
            kwargs = handler.get_importer_kwargs(None)
 
            self.assertEqual(list(kwargs.keys()), ['session'])
 
            self.assertIs(kwargs['session'], session)
 

	
 
        def test_begin_local_transaction(self):
 
            handler = MockToSQLAlchemyHandler()
 
            self.assertIsNone(handler.session)
 
            handler.begin_local_transaction()
 
            self.assertIsInstance(handler.session, orm.Session)
 
            handler.session.close()
 

	
 
        def test_commit_local_transaction(self):
 
            # TODO: test actual commit for data changes
 
            session = Session()
 
            handler = handlers.ToSQLAlchemyHandler(session=session)
 
            self.assertIs(handler.session, session)
 
            with patch.object(handler, 'session') as session:
 
                handler.commit_local_transaction()
 
                session.commit.assert_called_once_with()
 
                self.assertFalse(session.rollback.called)
 
            # self.assertIsNone(handler.session)
 

	
 
        def test_rollback_local_transaction(self):
 
            # TODO: test actual rollback for data changes
 
            session = Session()
 
            handler = handlers.ToSQLAlchemyHandler(session=session)
 
            self.assertIs(handler.session, session)
 
            with patch.object(handler, 'session') as session:
 
                handler.rollback_local_transaction()
 
                session.rollback.assert_called_once_with()
 
                self.assertFalse(session.commit.called)
 
            # self.assertIsNone(handler.session)
tests/importing/test_importers.py
Show inline comments
 
@@ -5,13 +5,9 @@ from collections import OrderedDict
 
from unittest import TestCase
 
from unittest.mock import Mock, patch, call
 

	
 
import sqlalchemy as sa
 

	
 
from rattail.db import model, Session
 
from rattail.db.util import QuerySequence
 
from rattail.importing import importers
 
from rattail.config import make_config, RattailConfig
 
from .. import NullProgress, RattailTestCase
 
from .. import NullProgress
 
from . import ImporterTester
 

	
 

	
 
@@ -290,36 +286,6 @@ class TestImporter(TestCase):
 
                flush.reset_mock()
 

	
 

	
 
class TestFromQuery(TestCase):
 

	
 
    def setUp(self):
 
        self.config = RattailConfig(defaults={
 
            'rattail.timezone.default': 'America/Chicago',
 
        })
 
        engine_url = os.environ.get('RATTAIL_TEST_ENGINE_URL', 'sqlite://')
 
        self.engine = sa.create_engine(engine_url)
 
        model.Base.metadata.create_all(bind=self.engine)
 
        Session.configure(bind=self.engine)
 

	
 
    def tearDown(self):
 
        model.Base.metadata.drop_all(bind=self.engine)
 
        Session.configure(bind=None)
 

	
 
    def test_query(self):
 
        importer = importers.FromQuery(self.config)
 
        self.assertRaises(NotImplementedError, importer.query)
 

	
 
    def test_get_host_objects(self):
 
        app = self.config.get_app()
 
        session = app.make_session()
 
        query = session.query(model.Product)
 
        importer = importers.FromQuery(self.config)
 
        with patch.object(importer, 'query', Mock(return_value=query)):
 
            objects = importer.get_host_objects()
 
        self.assertIsInstance(objects, QuerySequence)
 
        session.close()
 

	
 

	
 
class TestBulkImporter(BulkImporterBattery, TestCase):
 
    importer_class = importers.BulkImporter
 
        
 
@@ -525,3 +491,41 @@ class TestMockImporter(ImporterTester, TestCase):
 
        self.assert_import_created('16oz')
 
        self.assert_import_updated('32oz', '1gal')
 
        self.assert_import_deleted()
 

	
 

	
 
try:
 
    import sqlalchemy as sa
 
    from rattail.db import model, Session
 
    from rattail.db.util import QuerySequence
 
except ImportError:
 
    pass
 
else:
 

	
 
    class TestFromQuery(TestCase):
 

	
 
        def setUp(self):
 
            self.config = RattailConfig(defaults={
 
                'rattail.timezone.default': 'America/Chicago',
 
            })
 
            engine_url = os.environ.get('RATTAIL_TEST_ENGINE_URL', 'sqlite://')
 
            self.engine = sa.create_engine(engine_url)
 
            model.Base.metadata.create_all(bind=self.engine)
 
            Session.configure(bind=self.engine)
 

	
 
        def tearDown(self):
 
            model.Base.metadata.drop_all(bind=self.engine)
 
            Session.configure(bind=None)
 

	
 
        def test_query(self):
 
            importer = importers.FromQuery(self.config)
 
            self.assertRaises(NotImplementedError, importer.query)
 

	
 
        def test_get_host_objects(self):
 
            app = self.config.get_app()
 
            session = app.make_session()
 
            query = session.query(model.Product)
 
            importer = importers.FromQuery(self.config)
 
            with patch.object(importer, 'query', Mock(return_value=query)):
 
                objects = importer.get_host_objects()
 
            self.assertIsInstance(objects, QuerySequence)
 
            session.close()
tests/importing/test_model.py
Show inline comments
 
@@ -4,76 +4,79 @@ import os
 
from unittest import TestCase
 
from unittest.mock import Mock
 

	
 
import sqlalchemy as sa
 

	
 
from rattail.db import model, auth, Session, ConfigExtension
 
from rattail.importing import model as import_model
 
from rattail.config import RattailConfig
 
from .. import RattailTestCase
 

	
 

	
 
class TestAdminUser(TestCase):
 

	
 
    def setUp(self):
 
        self.config = RattailConfig(defaults={
 
            'rattail.timezone.default': 'America/Chicago',
 
        })
 
        engine_url = os.environ.get('RATTAIL_TEST_ENGINE_URL', 'sqlite://')
 
        self.engine = sa.create_engine(engine_url)
 
        model.Base.metadata.create_all(bind=self.engine)
 
        Session.configure(bind=self.engine)
 
        self.app = self.config.get_app()
 
        self.session = self.app.make_session()
 

	
 
    def tearDown(self):
 
        self.session.close()
 
        model.Base.metadata.drop_all(bind=self.engine)
 
        Session.configure(bind=None)
 

	
 
    def make_importer(self, **kwargs):
 
        kwargs.setdefault('config', self.config)
 
        kwargs.setdefault('session', self.session)
 
        return import_model.AdminUserImporter(**kwargs)
 

	
 
    def get_admin(self):
 
        return auth.administrator_role(self.session)
 

	
 
    def test_supported_fields(self):
 
        importer = import_model.UserImporter(self.config)
 
        standard_fields = importer.fields
 
        importer = self.make_importer()
 
        extra_fields = set(importer.fields) - set(standard_fields)
 
        self.assertEqual(len(extra_fields), 1)
 
        self.assertEqual(list(extra_fields)[0], 'admin')
 

	
 
    def test_normalize_local_object(self):
 
        importer = self.make_importer()
 
        importer.setup()
 

	
 
        user = model.User()
 
        user.username = 'fred'
 
        self.session.add(user)
 
        self.session.flush()
 

	
 
        data = importer.normalize_local_object(user)
 
        self.assertFalse(data['admin'])
 

	
 
        user.roles.append(self.get_admin())
 
        self.session.flush()
 
        data = importer.normalize_local_object(user)
 
        self.assertTrue(data['admin'])
 

	
 
    def test_update_object(self):
 
        importer = self.make_importer(fields=['uuid', 'admin'])
 
        data = {'uuid': 'ccb1915419e511e6a3ad3ca9f40bc550'}
 
        user = model.User(**data)
 
        admin = self.get_admin()
 
        self.assertNotIn(admin, user.roles)
 

	
 
        data['admin'] = True
 
        importer.update_object(user, data)
 
        self.assertIn(admin, user.roles)
 

	
 
        data['admin'] = False
 
        importer.update_object(user, data)
 
        self.assertNotIn(admin, user.roles)
 

	
 
try:
 
    import sqlalchemy as sa
 
    from rattail.db import model, auth, Session, ConfigExtension
 
    from rattail.importing import model as import_model
 
    from .. import RattailTestCase
 
except ImportError:
 
    pass
 
else:
 

	
 
    class TestAdminUser(TestCase):
 

	
 
        def setUp(self):
 
            self.config = RattailConfig(defaults={
 
                'rattail.timezone.default': 'America/Chicago',
 
            })
 
            engine_url = os.environ.get('RATTAIL_TEST_ENGINE_URL', 'sqlite://')
 
            self.engine = sa.create_engine(engine_url)
 
            model.Base.metadata.create_all(bind=self.engine)
 
            Session.configure(bind=self.engine)
 
            self.app = self.config.get_app()
 
            self.session = self.app.make_session()
 

	
 
        def tearDown(self):
 
            self.session.close()
 
            model.Base.metadata.drop_all(bind=self.engine)
 
            Session.configure(bind=None)
 

	
 
        def make_importer(self, **kwargs):
 
            kwargs.setdefault('config', self.config)
 
            kwargs.setdefault('session', self.session)
 
            return import_model.AdminUserImporter(**kwargs)
 

	
 
        def get_admin(self):
 
            return auth.administrator_role(self.session)
 

	
 
        def test_supported_fields(self):
 
            importer = import_model.UserImporter(self.config)
 
            standard_fields = importer.fields
 
            importer = self.make_importer()
 
            extra_fields = set(importer.fields) - set(standard_fields)
 
            self.assertEqual(len(extra_fields), 1)
 
            self.assertEqual(list(extra_fields)[0], 'admin')
 

	
 
        def test_normalize_local_object(self):
 
            importer = self.make_importer()
 
            importer.setup()
 

	
 
            user = model.User()
 
            user.username = 'fred'
 
            self.session.add(user)
 
            self.session.flush()
 

	
 
            data = importer.normalize_local_object(user)
 
            self.assertFalse(data['admin'])
 

	
 
            user.roles.append(self.get_admin())
 
            self.session.flush()
 
            data = importer.normalize_local_object(user)
 
            self.assertTrue(data['admin'])
 

	
 
        def test_update_object(self):
 
            importer = self.make_importer(fields=['uuid', 'admin'])
 
            data = {'uuid': 'ccb1915419e511e6a3ad3ca9f40bc550'}
 
            user = model.User(**data)
 
            admin = self.get_admin()
 
            self.assertNotIn(admin, user.roles)
 

	
 
            data['admin'] = True
 
            importer.update_object(user, data)
 
            self.assertIn(admin, user.roles)
 

	
 
            data['admin'] = False
 
            importer.update_object(user, data)
 
            self.assertNotIn(admin, user.roles)
tests/importing/test_postgresql.py
Show inline comments
 
# -*- coding: utf-8 -*-
 
# -*- coding: utf-8; -*-
 

	
 
import datetime
 
import shutil
 
import tempfile
 
from unittest import TestCase
 

	
 
import sqlalchemy as sa
 
from sqlalchemy import orm
 

	
 
from rattail.db import Session, model
 
from rattail.importing import postgresql as pgimport
 
from rattail.config import RattailConfig
 
from rattail.exceptions import ConfigurationError
 
from .. import RattailTestCase, NullProgress
 
from . import ImporterTester
 
from .test_rattail import DualRattailTestCase
 
from rattail.time import localtime
 

	
 

	
 
class Widget(object):
 
try:
 
    import sqlalchemy as sa
 
    from sqlalchemy import orm
 
    from rattail.db import Session, model
 
    from rattail.importing import postgresql as pgimport
 
    from .. import RattailTestCase, NullProgress
 
    from . import ImporterTester
 
    from .test_rattail import DualRattailTestCase
 
except ImportError:
 
    pass
 
else:
 

	
 

	
 
class TestBulkToPostgreSQL(TestCase):
 

	
 
    def setUp(self):
 
        self.tempdir = tempfile.mkdtemp()
 
        self.config = self.make_config()
 

	
 
    def tearDown(self):
 
        shutil.rmtree(self.tempdir)
 

	
 
    def make_config(self, workdir=True):
 
        cfg = RattailConfig()
 
        if workdir:
 
            cfg.setdefault('rattail', 'workdir', self.tempdir)
 
        cfg.setdefault('rattail', 'timezone.default', 'America/Chicago')
 
        return cfg
 

	
 
    def make_importer(self, **kwargs):
 
        kwargs.setdefault('config', self.config)
 
        kwargs.setdefault('fields', ['id']) # hack
 
        return pgimport.BulkToPostgreSQL(**kwargs)
 

	
 
    def test_data_path_property(self):
 
        self.config = self.make_config(workdir=False)
 
        self.config.setdefault('rattail', 'workdir', '/tmp')
 
        importer = pgimport.BulkToPostgreSQL(config=self.config, fields=['id'])
 

	
 
        # path leverages model name, so default is None
 
        self.assertEqual(importer.data_path, '/tmp/import_bulk_postgresql_None.csv')
 

	
 
        # but it will reflect model name
 
        importer.model_name = 'Foo'
 
        self.assertEqual(importer.data_path, '/tmp/import_bulk_postgresql_Foo.csv')
 

	
 
    def test_setup(self):
 
        importer = self.make_importer()
 
        self.assertFalse(hasattr(importer, 'data_buffer'))
 
        importer.setup()
 
        self.assertIsNotNone(importer.data_buffer)
 
        importer.data_buffer.close()
 

	
 
    def test_teardown(self):
 
        importer = self.make_importer()
 
        importer.data_buffer = open(importer.data_path, 'wb')
 
        importer.teardown()
 
        self.assertIsNone(importer.data_buffer)
 

	
 
    def test_prep_value_for_postgres(self):
 
        importer = self.make_importer()
 

	
 
        # constants
 
        self.assertEqual(importer.prep_value_for_postgres(None), '\\N')
 
        self.assertEqual(importer.prep_value_for_postgres(True), 't')
 
        self.assertEqual(importer.prep_value_for_postgres(False), 'f')
 

	
 
        # datetime (local zone is Chicago/CDT; UTC-5)
 
        value = localtime(self.config, datetime.datetime(2016, 5, 13, 12))
 
        self.assertEqual(importer.prep_value_for_postgres(value), '2016-05-13 17:00:00')
 

	
 
        # strings...
 

	
 
        # backslash is escaped by doubling
 
        self.assertEqual(importer.prep_value_for_postgres('\\'), '\\\\')
 

	
 
        # newlines are collapsed (\r\n -> \n) and escaped
 
        self.assertEqual(importer.prep_value_for_postgres('one\rtwo\nthree\r\nfour\r\nfive\nsix\rseven'), 'one\\rtwo\\nthree\\r\\nfour\\r\\nfive\\nsix\\rseven')
 

	
 
    def test_prep_data_for_postgres(self):
 
        importer = self.make_importer()
 
        time = localtime(self.config, datetime.datetime(2016, 5, 13, 12))
 
        data = {
 
            'none': None,
 
            'true': True,
 
            'false': False,
 
            'datetime': time,
 
            'backslash': '\\',
 
            'newlines': 'one\rtwo\nthree\r\nfour\r\nfive\nsix\rseven',
 
        }
 
        data = importer.prep_data_for_postgres(data)
 
        self.assertEqual(data['none'], '\\N')
 
        self.assertEqual(data['true'], 't')
 
        self.assertEqual(data['false'], 'f')
 
        self.assertEqual(data['datetime'], '2016-05-13 17:00:00')
 
        self.assertEqual(data['backslash'], '\\\\')
 
        self.assertEqual(data['newlines'], 'one\\rtwo\\nthree\\r\\nfour\\r\\nfive\\nsix\\rseven')
 

	
 

	
 
######################################################################
 
# fake importer class, tested mostly for basic coverage
 
######################################################################
 

	
 
class MockBulkImporter(pgimport.BulkToPostgreSQL):
 
    model_class = model.Department
 
    key = 'uuid'
 

	
 
    def normalize_local_object(self, obj):
 
        return obj
 

	
 
    def update_object(self, obj, host_data, local_data=None):
 
        return host_data
 

	
 

	
 
class TestMockBulkImporter(DualRattailTestCase, ImporterTester):
 
    importer_class = MockBulkImporter
 

	
 
    sample_data = {
 
        1: {'number': 1, 'name': "Grocery", 'uuid': 'decd909a194011e688093ca9f40bc550'},
 
        2: {'number': 2, 'name': "Bulk", 'uuid': 'e633d54c194011e687e33ca9f40bc550'},
 
        3: {'number': 3, 'name': "HBA", 'uuid': 'e2bad79e194011e6a4783ca9f40bc550'},
 
    }
 

	
 
    def setUp(self):
 
        self.setup_rattail()
 
        self.tempdir = tempfile.mkdtemp()
 
        self.config.setdefault('rattail', 'workdir', self.tempdir)
 
        self.importer = self.make_importer()
 

	
 
    def tearDown(self):
 
        self.teardown_rattail()
 
        shutil.rmtree(self.tempdir)
 

	
 
    def make_importer(self, **kwargs):
 
        kwargs.setdefault('config', self.config)
 
        return super(TestMockBulkImporter, self).make_importer(**kwargs)
 

	
 
    def import_data(self, **kwargs):
 
        self.importer.session = self.session
 
        self.importer.host_session = self.host_session
 
        self.result = self.importer.import_data(**kwargs)
 

	
 
    def assert_import_created(self, *keys):
 
    class Widget(object):
 
        pass
 

	
 
    def assert_import_updated(self, *keys):
 
        pass
 

	
 
    def assert_import_deleted(self, *keys):
 
        pass
 
    class TestBulkToPostgreSQL(TestCase):
 

	
 
        def setUp(self):
 
            self.tempdir = tempfile.mkdtemp()
 
            self.config = self.make_config()
 

	
 
        def tearDown(self):
 
            shutil.rmtree(self.tempdir)
 

	
 
        def make_config(self, workdir=True):
 
            cfg = RattailConfig()
 
            if workdir:
 
                cfg.setdefault('rattail', 'workdir', self.tempdir)
 
            cfg.setdefault('rattail', 'timezone.default', 'America/Chicago')
 
            return cfg
 

	
 
        def make_importer(self, **kwargs):
 
            kwargs.setdefault('config', self.config)
 
            kwargs.setdefault('fields', ['id']) # hack
 
            return pgimport.BulkToPostgreSQL(**kwargs)
 

	
 
        def test_data_path_property(self):
 
            self.config = self.make_config(workdir=False)
 
            self.config.setdefault('rattail', 'workdir', '/tmp')
 
            importer = pgimport.BulkToPostgreSQL(config=self.config, fields=['id'])
 

	
 
            # path leverages model name, so default is None
 
            self.assertEqual(importer.data_path, '/tmp/import_bulk_postgresql_None.csv')
 

	
 
            # but it will reflect model name
 
            importer.model_name = 'Foo'
 
            self.assertEqual(importer.data_path, '/tmp/import_bulk_postgresql_Foo.csv')
 

	
 
        def test_setup(self):
 
            importer = self.make_importer()
 
            self.assertFalse(hasattr(importer, 'data_buffer'))
 
            importer.setup()
 
            self.assertIsNotNone(importer.data_buffer)
 
            importer.data_buffer.close()
 

	
 
        def test_teardown(self):
 
            importer = self.make_importer()
 
            importer.data_buffer = open(importer.data_path, 'wb')
 
            importer.teardown()
 
            self.assertIsNone(importer.data_buffer)
 

	
 
        def test_prep_value_for_postgres(self):
 
            importer = self.make_importer()
 

	
 
            # constants
 
            self.assertEqual(importer.prep_value_for_postgres(None), '\\N')
 
            self.assertEqual(importer.prep_value_for_postgres(True), 't')
 
            self.assertEqual(importer.prep_value_for_postgres(False), 'f')
 

	
 
            # datetime (local zone is Chicago/CDT; UTC-5)
 
            value = localtime(self.config, datetime.datetime(2016, 5, 13, 12))
 
            self.assertEqual(importer.prep_value_for_postgres(value), '2016-05-13 17:00:00')
 

	
 
            # strings...
 

	
 
            # backslash is escaped by doubling
 
            self.assertEqual(importer.prep_value_for_postgres('\\'), '\\\\')
 

	
 
            # newlines are collapsed (\r\n -> \n) and escaped
 
            self.assertEqual(importer.prep_value_for_postgres('one\rtwo\nthree\r\nfour\r\nfive\nsix\rseven'), 'one\\rtwo\\nthree\\r\\nfour\\r\\nfive\\nsix\\rseven')
 

	
 
        def test_prep_data_for_postgres(self):
 
            importer = self.make_importer()
 
            time = localtime(self.config, datetime.datetime(2016, 5, 13, 12))
 
            data = {
 
                'none': None,
 
                'true': True,
 
                'false': False,
 
                'datetime': time,
 
                'backslash': '\\',
 
                'newlines': 'one\rtwo\nthree\r\nfour\r\nfive\nsix\rseven',
 
            }
 
            data = importer.prep_data_for_postgres(data)
 
            self.assertEqual(data['none'], '\\N')
 
            self.assertEqual(data['true'], 't')
 
            self.assertEqual(data['false'], 'f')
 
            self.assertEqual(data['datetime'], '2016-05-13 17:00:00')
 
            self.assertEqual(data['backslash'], '\\\\')
 
            self.assertEqual(data['newlines'], 'one\\rtwo\\nthree\\r\\nfour\\r\\nfive\\nsix\\rseven')
 

	
 

	
 
    ######################################################################
 
    # fake importer class, tested mostly for basic coverage
 
    ######################################################################
 

	
 
    class MockBulkImporter(pgimport.BulkToPostgreSQL):
 
        model_class = model.Department
 
        key = 'uuid'
 

	
 
        def normalize_local_object(self, obj):
 
            return obj
 

	
 
        def update_object(self, obj, host_data, local_data=None):
 
            return host_data
 

	
 

	
 
    class TestMockBulkImporter(DualRattailTestCase, ImporterTester):
 
        importer_class = MockBulkImporter
 

	
 
        sample_data = {
 
            1: {'number': 1, 'name': "Grocery", 'uuid': 'decd909a194011e688093ca9f40bc550'},
 
            2: {'number': 2, 'name': "Bulk", 'uuid': 'e633d54c194011e687e33ca9f40bc550'},
 
            3: {'number': 3, 'name': "HBA", 'uuid': 'e2bad79e194011e6a4783ca9f40bc550'},
 
        }
 

	
 
    def test_create(self):
 
        if self.postgresql():
 
            with self.host_data(self.sample_data):
 
                self.import_data()
 
            self.assert_import_created(3)
 

	
 
    def test_create_empty(self):
 
        if self.postgresql():
 
            with self.host_data({}):
 
                self.import_data()
 
            self.assert_import_created(0)
 

	
 
    def test_max_create(self):
 
        if self.postgresql():
 
            with self.host_data(self.sample_data):
 
                with self.local_data({}):
 
                    self.import_data(max_create=1)
 
            self.assert_import_created(1)
 

	
 
    def test_max_total_create(self):
 
        if self.postgresql():
 
            with self.host_data(self.sample_data):
 
                with self.local_data({}):
 
                    self.import_data(max_total=1)
 
            self.assert_import_created(1)
 

	
 
    # # TODO: a bit hacky, leveraging the fact that 'user' is a reserved word
 
    # def test_table_name_is_reserved_word(self):
 
    #     if self.postgresql():
 
    #         from rattail.importing.rattail_bulk import UserImporter
 
    #         data = {
 
    #             '521a788e195911e688c13ca9f40bc550': {
 
    #                 'uuid': '521a788e195911e688c13ca9f40bc550',
 
    #                 'username': 'fred',
 
    #                 'active': True,
 
    #             },
 
    #         }
 
    #         self.importer = UserImporter(config=self.config)
 
    #         # with self.host_data(data):
 
    #         self.import_data(host_data=data)
 
    #         # self.assert_import_created(3)
 
        def setUp(self):
 
            self.setup_rattail()
 
            self.tempdir = tempfile.mkdtemp()
 
            self.config.setdefault('rattail', 'workdir', self.tempdir)
 
            self.importer = self.make_importer()
 

	
 
        def tearDown(self):
 
            self.teardown_rattail()
 
            shutil.rmtree(self.tempdir)
 

	
 
        def make_importer(self, **kwargs):
 
            kwargs.setdefault('config', self.config)
 
            return super(TestMockBulkImporter, self).make_importer(**kwargs)
 

	
 
        def import_data(self, **kwargs):
 
            self.importer.session = self.session
 
            self.importer.host_session = self.host_session
 
            self.result = self.importer.import_data(**kwargs)
 

	
 
        def assert_import_created(self, *keys):
 
            pass
 

	
 
        def assert_import_updated(self, *keys):
 
            pass
 

	
 
        def assert_import_deleted(self, *keys):
 
            pass
 

	
 
        def test_create(self):
 
            if self.postgresql():
 
                with self.host_data(self.sample_data):
 
                    self.import_data()
 
                self.assert_import_created(3)
 

	
 
        def test_create_empty(self):
 
            if self.postgresql():
 
                with self.host_data({}):
 
                    self.import_data()
 
                self.assert_import_created(0)
 

	
 
        def test_max_create(self):
 
            if self.postgresql():
 
                with self.host_data(self.sample_data):
 
                    with self.local_data({}):
 
                        self.import_data(max_create=1)
 
                self.assert_import_created(1)
 

	
 
        def test_max_total_create(self):
 
            if self.postgresql():
 
                with self.host_data(self.sample_data):
 
                    with self.local_data({}):
 
                        self.import_data(max_total=1)
 
                self.assert_import_created(1)
 

	
 
        # # TODO: a bit hacky, leveraging the fact that 'user' is a reserved word
 
        # def test_table_name_is_reserved_word(self):
 
        #     if self.postgresql():
 
        #         from rattail.importing.rattail_bulk import UserImporter
 
        #         data = {
 
        #             '521a788e195911e688c13ca9f40bc550': {
 
        #                 'uuid': '521a788e195911e688c13ca9f40bc550',
 
        #                 'username': 'fred',
 
        #                 'active': True,
 
        #             },
 
        #         }
 
        #         self.importer = UserImporter(config=self.config)
 
        #         # with self.host_data(data):
 
        #         self.import_data(host_data=data)
 
        #         # self.assert_import_created(3)

Changeset was too big and was cut off... Show full diff anyway

0 comments (0 inline, 0 general)