Location: rattail-project/rattail/rattail/importing/importers.py

# -*- coding: utf-8 -*-
################################################################################
#
#  Rattail -- Retail Software Framework
#  Copyright © 2010-2016 Lance Edgar
#
#  This file is part of Rattail.
#
#  Rattail is free software: you can redistribute it and/or modify it under the
#  terms of the GNU Affero General Public License as published by the Free
#  Software Foundation, either version 3 of the License, or (at your option)
#  any later version.
#
#  Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
#  WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
#  FOR A PARTICULAR PURPOSE.  See the GNU Affero General Public License for
#  more details.
#
#  You should have received a copy of the GNU Affero General Public License
#  along with Rattail.  If not, see <http://www.gnu.org/licenses/>.
#
################################################################################
"""
Data Importers
"""

from __future__ import unicode_literals, absolute_import

import datetime
import logging

from rattail.db.util import QuerySequence
from rattail.time import make_utc


log = logging.getLogger(__name__)


class Importer(object):
    """
    Base class for all data importers.
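
    Note that a concrete importer is expected to subclass this and override
    various hooks.  As a minimal sketch, using a purely hypothetical model
    class and helper functions::

        class WidgetImporter(Importer):
            model_class = Widget    # hypothetical local model class
            key = 'id'
            supported_fields = ['id', 'name']
            simple_fields = ['id', 'name']

            def get_host_objects(self):
                # fetch raw records from the host system, somehow
                return get_widgets_from_host()

            def get_single_local_object(self, key):
                # look up the corresponding local record, or None
                return find_local_widget(id=key[0])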
    """
    # Set this to the data model class which is targeted on the local side.
    model_class = None

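    # Key field(s) which uniquely identify a record.  May be specified as a
    # single field name (string) or a tuple of field names.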
    key = None

    # The full list of field names supported by the importer, i.e. for the data
    # model to which the importer pertains.  By definition this list will be
    # restricted to what the local side can accommodate, but may be further
    # restricted by what the host side has to offer.
    supported_fields = []

    # The list of field names which may be considered "simple" and therefore
    # treated as such, i.e. with basic getattr/setattr calls.  Note that this
    # only applies to the local side; it has no effect on the host side.
    simple_fields = []

    allow_create = True
    allow_update = True
    allow_delete = True
    dry_run = False

    max_create = None
    max_update = None
    max_delete = None
    max_total = None
    progress = None

    caches_local_data = False
    cached_local_data = None

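    # Display titles for the host (source) and local (target) systems, as
    # used in logging and progress messages.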
    host_system_title = None
    local_system_title = None

    def __init__(self, config=None, fields=None, key=None, **kwargs):
        self.config = config
        self.fields = fields or self.supported_fields
        if key is not None:
            self.key = key
        if isinstance(self.key, basestring):
            self.key = (self.key,)
        if self.key:
            for field in self.key:
                if field not in self.fields:
                    raise ValueError("Key field '{}' must be included in effective fields "
                                     "for {}".format(field, self.__class__.__name__))
        self._setup(**kwargs)

    def _setup(self, **kwargs):
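        """
        Apply kwargs from the constructor.  Note that the ``create``,
        ``update`` and ``delete`` flags are honored only insofar as the
        corresponding ``allow_*`` flag permits; any remaining kwargs are set
        directly as attributes on the importer.
        """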
        self.create = kwargs.pop('create', self.allow_create) and self.allow_create
        self.update = kwargs.pop('update', self.allow_update) and self.allow_update
        self.delete = kwargs.pop('delete', self.allow_delete) and self.allow_delete
        for key, value in kwargs.iteritems():
            setattr(self, key, value)

    @property
    def model_name(self):
        if self.model_class:
            return self.model_class.__name__

    def setup(self):
        """
        Perform any setup necessary, e.g. cache lookups for existing data.
        """

    def teardown(self):
        """
        Perform any cleanup after import, if necessary.
        """

    def import_data(self, host_data=None, now=None, **kwargs):
        """
        Import some data!  This is the core body of logic for that, regardless
        of where data is coming from or where it's headed.  Note that this
        method handles deletions as well as adds/updates.
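
        Returns a 3-tuple of change lists: ``(created, updated, deleted)``.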
        """
        self.now = now or make_utc(datetime.datetime.utcnow(), tzinfo=True)
        if kwargs:
            self._setup(**kwargs)
        self.setup()
        # note, these must be distinct lists (not aliases of one list)
        created, updated, deleted = [], [], []

        # Get complete set of normalized host data.
        if host_data is None:
            host_data = self.normalize_host_data()

        # Prune duplicate keys from host/source data.  This is for the sake of
        # sanity since duplicates typically lead to a ping-pong effect, where a
        # "clean" (change-less) import is impossible.
        unique = {}
        for data in host_data:
            key = self.get_key(data)
            if key in unique:
                log.warning("duplicate records detected from {} for key: {}".format(
                    self.host_system_title, key))
            unique[key] = data
        host_data = [unique[key] for key in sorted(unique)]

        # Cache local data if appropriate.
        if self.caches_local_data:
            self.cached_local_data = self.cache_local_data(host_data)

        # Create and/or update data.
        if self.create or self.update:
            created, updated = self._import_create_update(host_data)

        # Delete data.
        if self.delete:
            changes = len(created) + len(updated)
            if self.max_total and changes >= self.max_total:
                log.warning("max of {} total changes already reached; skipping deletions".format(self.max_total))
            else:
                deleted = self._import_delete(host_data, set(unique), changes=changes)

        self.teardown()
        return created, updated, deleted

    def _import_create_update(self, data):
        """
        Import the given data; create and/or update records as needed.
        """
        created, updated = [], []
        count = len(data)
        if not count:
            return created, updated

        prog = None
        if self.progress:
            prog = self.progress("Importing {} data".format(self.model_name), count)

        for i, host_data in enumerate(data, 1):

            # Fetch local object, using key from host data.
            key = self.get_key(host_data)
            local_object = self.get_local_object(key)

            # If we have a local object, but its data differs from host, update it.
            if local_object and self.update:
                local_data = self.normalize_local_object(local_object)
                diffs = self.data_diffs(local_data, host_data)
                if diffs:
                    log.debug("fields '{}' differed for local data: {}, host data: {}".format(
                        ','.join(diffs), local_data, host_data))
                    local_object = self.update_object(local_object, host_data, local_data)
                    updated.append((local_object, local_data, host_data))
                    if self.max_update and len(updated) >= self.max_update:
                        log.warning("max of {} *updated* records has been reached; stopping now".format(self.max_update))
                        break
                    if self.max_total and (len(created) + len(updated)) >= self.max_total:
                        log.warning("max of {} *total changes* has been reached; stopping now".format(self.max_total))
                        break

            # If we did not yet have a local object, create it using host data.
            elif not local_object and self.create:
                local_object = self.create_object(key, host_data)
                log.debug("created new {} {}: {}".format(self.model_name, key, local_object))
                created.append((local_object, host_data))
                if self.caches_local_data and self.cached_local_data is not None:
                    self.cached_local_data[key] = {'object': local_object, 'data': self.normalize_local_object(local_object)}
                if self.max_create and len(created) >= self.max_create:
                    log.warning("max of {} *created* records has been reached; stopping now".format(self.max_create))
                    break
                if self.max_total and (len(created) + len(updated)) >= self.max_total:
                    log.warning("max of {} *total changes* has been reached; stopping now".format(self.max_total))
                    break

            self.flush_changes(i)
            # # TODO: this needs to be customizable etc. somehow maybe..
            # if i % 100 == 0 and hasattr(self, 'session'):
            #     self.session.flush()

            if prog:
                prog.update(i)
        if prog:
            prog.destroy()

        return created, updated

    # TODO: this surely goes elsewhere
    flush_every_x = 100

    def flush_changes(self, x):
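        """
        Flush the database session, if the importer has one, for every
        ``flush_every_x`` records processed; ``x`` is the number of host
        records handled so far.
        """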
        if self.flush_every_x and x % self.flush_every_x == 0:
            if hasattr(self, 'session'):
                self.session.flush()

    def _import_delete(self, host_data, host_keys, changes=0):
        """
        Import deletions for the given data set.
        """
        deleted = []
        deleting = self.get_deletion_keys() - host_keys
        count = len(deleting)
        log.debug("found {} instances to delete".format(count))
        if count:

            prog = None
            if self.progress:
                prog = self.progress("Deleting {} data".format(self.model_name), count)

            for i, key in enumerate(sorted(deleting), 1):

                cached = self.cached_local_data.pop(key)
                obj = cached['object']
                if self.delete_object(obj):
                    deleted.append((obj, cached['data']))

                    if self.max_delete and len(deleted) >= self.max_delete:
                        log.warning("max of {} *deleted* records has been reached; stopping now".format(self.max_delete))
                        break
                    if self.max_total and (changes + len(deleted)) >= self.max_total:
                        log.warning("max of {} *total changes* has been reached; stopping now".format(self.max_total))
                        break

                if prog:
                    prog.update(i)
            if prog:
                prog.destroy()

        return deleted

    def get_key(self, data):
        """
        Return the key value for the given data dict.
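
        For example, if ``self.key`` is ``('upc',)`` and ``data`` is
        ``{'upc': '012345', 'description': 'Widget'}`` then the return
        value is ``('012345',)``.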
        """
        return tuple(data[k] for k in self.key)

    def get_host_objects(self):
        """
        Return the "raw" (as-is, not normalized) host objects which are to be
        imported.  This may return any sequence-like object, which has a
        ``len()`` value and responds to iteration etc.  The objects contained
        within it may be of any type; no assumptions are made there.  (That is
        the job of the :meth:`normalize_host_data()` method.)
        """
        return []

    def normalize_host_data(self, host_objects=None):
        """
        Return a normalized version of the full set of host data.  Note that
        this calls :meth:`get_host_objects()` to obtain the initial raw
        objects, and then normalizes each object.  The normalization process
        may filter out some records from the set, in which case the return
        value will be smaller than the original data set.
        """
        if host_objects is None:
            host_objects = self.get_host_objects()
        normalized = []
        count = len(host_objects)
        if count == 0:
            return normalized
        prog = None
        if self.progress:
            prog = self.progress("Reading {} data from {}".format(self.model_name, self.host_system_title), count)
        for i, obj in enumerate(host_objects, 1):
            data = self.normalize_host_object(obj)
            if data:
                normalized.append(data)
            if prog:
                prog.update(i)
        if prog:
            prog.destroy()
        return normalized

    def normalize_host_object(self, obj):
        """
        Normalize a raw host object into a data dict, or return ``None`` if the
        object should be ignored for the importer's purposes.
        """
        return obj

    def cache_local_data(self, host_data=None):
        """
        Cache all raw objects and normalized data from the local system.
        """
        raise NotImplementedError

    def get_cache_key(self, obj, normal):
        """
        Get the primary cache key for a given object and normalized data.

        Note that this method's signature is designed for use with the
        :func:`rattail.db.cache.cache_model()` function, and as such the
        ``normal`` parameter is understood to be a dict with a ``'data'`` key,
        value for which is the normalized data dict for the raw object.
        """
        return tuple(normal['data'].get(k) for k in self.key)

    def normalize_cache_object(self, obj):
        """
        Normalizer for cached local data.  This returns a simple dict with
        ``'object'`` and ``'data'`` keys; values are the raw object and its
        normalized data dict, respectively.
        """
        return {'object': obj, 'data': self.normalize_local_object(obj)}

    def normalize_local_object(self, obj):
        """
        Normalize a local (raw) object into a data dict.
        """
        data = {}
        for field in self.simple_fields:
            if field in self.fields:
                data[field] = getattr(obj, field)
        return data

    def get_local_object(self, key):
        """
        Must return the local object corresponding to the given key, or
        ``None``.  Default behavior here will be to check the cache if one is
        in effect, otherwise return the value from
        :meth:`get_single_local_object()`.
        """
        if self.caches_local_data and self.cached_local_data is not None:
            data = self.cached_local_data.get(key)
            return data['object'] if data else None
        return self.get_single_local_object(key)

    def get_single_local_object(self, key):
        """
        Must return the local object corresponding to the given key, or ``None``.
        This method should not consult the cache; that is handled within the
        :meth:`get_local_object()` method.
        """
        raise NotImplementedError

    def data_diffs(self, local_data, host_data):
        """
        Find all (relevant) fields which differ between the host and local data
        values for a given record.
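
        For example, with ``fields = ['id', 'name']``, local data of
        ``{'id': 1, 'name': 'Foo'}`` and host data of ``{'id': 1, 'name': 'Bar'}``
        would yield ``['name']``.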
        """
        diffs = []
        for field in self.fields:
            if local_data[field] != host_data[field]:
                diffs.append(field)
        return diffs

    def make_object(self):
        """
        Make a new/empty local object from scratch.
        """
        return self.model_class()

    def new_object(self, key):
        """
        Return a new local object to correspond to the given key.  Note that
        this method should only populate the object's key, and leave the rest
        of the fields to :meth:`update_object()`.
        """
        obj = self.make_object()
        for i, k in enumerate(self.key):
            if hasattr(obj, k):
                setattr(obj, k, key[i])
        return obj

    def create_object(self, key, host_data):
        """
        Create and return a new local object for the given key, fully populated
        from the given host data.  This may return ``None`` if no object is
        created.
        """
        obj = self.new_object(key)
        if obj:
            return self.update_object(obj, host_data)

    def update_object(self, obj, host_data, local_data=None):
        """
        Update the local data object with the given host data, and return the
        object.
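
        If ``local_data`` is provided, only those fields whose local values
        differ from the host data are actually assigned.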
        """
        for field in self.simple_fields:
            if field in self.fields:
                if not local_data or local_data[field] != host_data[field]:
                    setattr(obj, field, host_data[field])
        return obj

    def get_deletion_keys(self):
        """
        Return a set of keys from the *local* data set, which are eligible for
        deletion.  By default this will be all keys from the local cached data
        set, or an empty set if local data isn't cached.
        """
        if self.caches_local_data and self.cached_local_data is not None:
            return set(self.cached_local_data)
        return set()

    def delete_object(self, obj):
        """
        Delete the given object from the local system (or not), and return a
        boolean indicating whether deletion was successful.  What exactly this
        entails may vary; default implementation does nothing at all.
        """
        return True


class FromQuery(Importer):
    """
    Generic base class for importers whose raw external data source is a
    SQLAlchemy (or Django, or possibly other?) query.
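
    A subclass typically need only provide :meth:`query()`.  As a minimal
    sketch, assuming a SQLAlchemy ``session`` attribute on the importer and
    a hypothetical ``Product`` model::

        class ProductImporter(FromQuery):

            def query(self):
                return self.session.query(Product)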
    """

    def query(self):
        """
        Subclasses must override this, and return the primary query which will
        define the data set.
        """
        raise NotImplementedError

    def get_host_objects(self, progress=None):
        """
        Return the (raw) query results as a sequence.
        """
        return QuerySequence(self.query())