ariatosca-dev mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From dankil...@apache.org
Subject incubator-ariatosca git commit: ARIA-63 Implement attribute tracking for subprocesses [Forced Update!]
Date Tue, 17 Jan 2017 03:09:03 GMT
Repository: incubator-ariatosca
Updated Branches:
  refs/heads/ARIA-63-runtime-properties-modification 9f29d2912 -> 036f9c8ea (forced update)


ARIA-63 Implement attribute tracking for subprocesses


Project: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/commit/036f9c8e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/tree/036f9c8e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/diff/036f9c8e

Branch: refs/heads/ARIA-63-runtime-properties-modification
Commit: 036f9c8eadffc774a12b1df8a11b277191bbd916
Parents: dac4da7
Author: Dan Kilman <dank@gigaspaces.com>
Authored: Sun Jan 15 17:42:23 2017 +0200
Committer: Dan Kilman <dank@gigaspaces.com>
Committed: Tue Jan 17 05:06:58 2017 +0200

----------------------------------------------------------------------
 aria/orchestrator/workflows/executor/process.py |  47 ++--
 aria/storage/instrumentation.py                 | 157 +++++++++++
 aria/storage/type.py                            |  35 ++-
 tests/.pylintrc                                 |   2 +-
 .../test_process_executor_tracked_changes.py    |  95 +++++++
 tests/storage/test_instrumentation.py           | 274 +++++++++++++++++++
 6 files changed, 590 insertions(+), 20 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/036f9c8e/aria/orchestrator/workflows/executor/process.py
----------------------------------------------------------------------
diff --git a/aria/orchestrator/workflows/executor/process.py b/aria/orchestrator/workflows/executor/process.py
index e0a8aeb..cd80287 100644
--- a/aria/orchestrator/workflows/executor/process.py
+++ b/aria/orchestrator/workflows/executor/process.py
@@ -42,6 +42,8 @@ import jsonpickle
 from aria.utils import imports
 from aria.orchestrator.workflows.executor import base
 from aria.orchestrator.context import serialization
+from aria.storage import instrumentation
+from aria.storage import type as storage_type
 
 _IS_WIN = os.name == 'nt'
 
@@ -139,10 +141,17 @@ class ProcessExecutor(base.BaseExecutor):
                 if message_type == 'started':
                     self._task_started(self._tasks[task_id])
                 elif message_type == 'succeeded':
-                    self._task_succeeded(self._remove_task(task_id))
+                    task = self._remove_task(task_id)
+                    instrumentation.apply_tracked_changes(
+                        tracked_changes=message['tracked_changes'],
+                        model=task.context.model)
+                    self._task_succeeded(task)
                 elif message_type == 'failed':
-                    self._task_failed(self._remove_task(task_id),
-                                      exception=message['exception'])
+                    task = self._remove_task(task_id)
+                    instrumentation.apply_tracked_changes(
+                        tracked_changes=message['tracked_changes'],
+                        model=task.context.model)
+                    self._task_failed(task, exception=message['exception'])
                 else:
                     raise RuntimeError('Invalid state')
             except BaseException as e:
@@ -227,26 +236,27 @@ class _Messenger(object):
         """Task started message"""
         self._send_message(type='started')
 
-    def succeeded(self):
+    def succeeded(self, tracked_changes):
         """Task succeeded message"""
-        self._send_message(type='succeeded')
+        self._send_message(type='succeeded', tracked_changes=tracked_changes)
 
-    def failed(self, exception):
+    def failed(self, tracked_changes, exception):
         """Task failed message"""
-        self._send_message(type='failed', exception=exception)
+        self._send_message(type='failed', tracked_changes=tracked_changes, exception=exception)
 
     def closed(self):
         """Executor closed message"""
         self._send_message(type='closed')
 
-    def _send_message(self, type, exception=None):
+    def _send_message(self, type, tracked_changes=None, exception=None):
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         sock.connect(('localhost', self.port))
         try:
             data = jsonpickle.dumps({
                 'type': type,
                 'task_id': self.task_id,
-                'exception': exception
+                'exception': exception,
+                'tracked_changes': tracked_changes
             })
             sock.send(struct.pack(_INT_FMT, len(data)))
             sock.sendall(data)
@@ -271,13 +281,18 @@ def _main():
     operation_mapping = arguments['operation_mapping']
     operation_inputs = arguments['operation_inputs']
     context_dict = arguments['context']
-    try:
-        ctx = serialization.operation_context_from_dict(context_dict)
-        task_func = imports.load_attribute(operation_mapping)
-        task_func(ctx=ctx, **operation_inputs)
-        messenger.succeeded()
-    except BaseException as e:
-        messenger.failed(exception=e)
+
+    # See docstring of `remove_mutable_association_listener`
+    storage_type.remove_mutable_association_listener()
+
+    with instrumentation.track_changes() as instrument:
+        try:
+            ctx = serialization.operation_context_from_dict(context_dict)
+            task_func = imports.load_attribute(operation_mapping)
+            task_func(ctx=ctx, **operation_inputs)
+            messenger.succeeded(tracked_changes=instrument.tracked_changes)
+        except BaseException as e:
+            messenger.failed(exception=e, tracked_changes=instrument.tracked_changes)
 
 if __name__ == '__main__':
     _main()

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/036f9c8e/aria/storage/instrumentation.py
----------------------------------------------------------------------
diff --git a/aria/storage/instrumentation.py b/aria/storage/instrumentation.py
new file mode 100644
index 0000000..7ac649a
--- /dev/null
+++ b/aria/storage/instrumentation.py
@@ -0,0 +1,157 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+
+import sqlalchemy.event
+
+from . import api
+from . import model as _model
+
+_STUB = object()
+_INSTRUMENTED = {
+    _model.NodeInstance.runtime_properties: dict
+}
+
+
+def track_changes(instrumented=None):
+    """Track changes in the specified model columns
+
+    This call will register event listeners using sqlalchemy's event mechanism. The listeners
+    instrument all returned objects such that the attributes specified in ``instrumented`` will
+    be replaced with a value that is stored in the returned instrumentation context
+    ``tracked_changes`` property.
+
+    Why should this be implemented when sqlalchemy already does a fantastic job at tracking changes
+    you ask? Well, when sqlalchemy is used with sqlite, due to how sqlite works, only one process
+    can hold a write lock to the database. This does not work well when ARIA runs tasks in
+    subprocesses (by the process executor) and these tasks wish to change some state as well. These
+    tasks certainly deserve a chance to do so!
+
+    To enable this, the subprocess calls ``track_changes()`` before any state changes are made. At the
+    end of the subprocess execution, it should return the ``tracked_changes`` attribute of the
+    context returned from this call to the parent process. The parent process will then call
+    ``apply_tracked_changes()`` that resides in this module as well. At that point, the changes
+    will actually be written back to the database.
+
+    :param instrumented: A dict from model columns to their python native type
+    :return: The instrumentation context
+    """
+    return _Instrumentation(instrumented or _INSTRUMENTED)
+
+
+class _Instrumentation(object):
+
+    def __init__(self, instrumented):
+        self.tracked_changes = {}
+        self.listeners = []
+        self._track_changes(instrumented)
+
+    def _track_changes(self, instrumented):
+        instrumented_classes = {}
+        for instrumented_attribute, attribute_type in instrumented.items():
+            self._register_set_attribute_listener(
+                instrumented_attribute=instrumented_attribute,
+                attribute_type=attribute_type)
+            instrumented_class = instrumented_attribute.parent.entity
+            instrumented_class_attributes = instrumented_classes.setdefault(instrumented_class, {})
+            instrumented_class_attributes[instrumented_attribute.key] = attribute_type
+        for instrumented_class, instrumented_attributes in instrumented_classes.items():
+            self._register_instance_listeners(
+                instrumented_class=instrumented_class,
+                instrumented_attributes=instrumented_attributes)
+
+    def _register_set_attribute_listener(self, instrumented_attribute, attribute_type):
+        def listener(target, value, *_):
+            mapi_name = api.generate_lower_name(target.__class__)
+            tracked_instances = self.tracked_changes.setdefault(mapi_name, {})
+            tracked_attributes = tracked_instances.setdefault(target.id, {})
+            if value is None:
+                current = None
+            else:
+                current = copy.deepcopy(attribute_type(value))
+            tracked_attributes[instrumented_attribute.key] = _Value(_STUB, current)
+            return current
+        listener_args = (instrumented_attribute, 'set', listener)
+        sqlalchemy.event.listen(*listener_args, retval=True)
+        self.listeners.append(listener_args)
+
+    def _register_instance_listeners(self, instrumented_class, instrumented_attributes):
+        def listener(target, *_):
+            mapi_name = api.generate_lower_name(instrumented_class)
+            tracked_instances = self.tracked_changes.setdefault(mapi_name, {})
+            tracked_attributes = tracked_instances.setdefault(target.id, {})
+            for attribute_name, attribute_type in instrumented_attributes.items():
+                if attribute_name not in tracked_attributes:
+                    initial = getattr(target, attribute_name)
+                    if initial is None:
+                        current = None
+                    else:
+                        current = copy.deepcopy(attribute_type(initial))
+                    tracked_attributes[attribute_name] = _Value(initial, current)
+                target.__dict__[attribute_name] = tracked_attributes[attribute_name].current
+        for listener_args in [(instrumented_class, 'load', listener),
+                              (instrumented_class, 'refresh', listener),
+                              (instrumented_class, 'refresh_flush', listener)]:
+            sqlalchemy.event.listen(*listener_args)
+            self.listeners.append(listener_args)
+
+    def restore(self):
+        """Remove all listeners registered by this instrumentation"""
+        for listener_args in self.listeners:
+            if sqlalchemy.event.contains(*listener_args):
+                sqlalchemy.event.remove(*listener_args)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.restore()
+
+
+class _Value(object):
+    # You may wonder why this is a full-blown class and not a named tuple. The reason is that
+    # jsonpickle, which is used to serialize the tracked_changes, does not handle named tuples very
+    # well. At the very least, I could not get it to behave.
+
+    def __init__(self, initial, current):
+        self.initial = initial
+        self.current = current
+
+    def __eq__(self, other):
+        if not isinstance(other, _Value):
+            return False
+        return self.initial == other.initial and self.current == other.current
+
+    def __hash__(self):
+        return hash(self.initial) ^ hash(self.current)
+
+
+def apply_tracked_changes(tracked_changes, model):
+    """Write tracked changes back to the database using provided model storage
+
+    :param tracked_changes: The tracked_changes attribute of the context returned by calling
+                            ``track_changes()``
+    :param model: The model storage used to actually apply the changes.
+    """
+    for mapi_name, tracked_instances in tracked_changes.items():
+        mapi = getattr(model, mapi_name)
+        for instance_id, tracked_attributes in tracked_instances.items():
+            instance = None
+            for attribute_name, value in tracked_attributes.items():
+                if value.initial != value.current:
+                    if not instance:
+                        instance = mapi.get(instance_id)
+                    setattr(instance, attribute_name, value.current)

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/036f9c8e/aria/storage/type.py
----------------------------------------------------------------------
diff --git a/aria/storage/type.py b/aria/storage/type.py
index ab50b0f..3fe206f 100644
--- a/aria/storage/type.py
+++ b/aria/storage/type.py
@@ -16,7 +16,8 @@ import json
 
 from sqlalchemy import (
     TypeDecorator,
-    VARCHAR
+    VARCHAR,
+    event
 )
 
 from sqlalchemy.ext import mutable
@@ -84,5 +85,33 @@ class _MutableList(mutable.MutableList):
             raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e)))
 
 
-_MutableList.associate_with(List)
-_MutableDict.associate_with(Dict)
+def _mutable_association_listener(mapper, cls):
+    for prop in mapper.column_attrs:
+        column_type = prop.columns[0].type
+        if isinstance(column_type, Dict):
+            _MutableDict.associate_with_attribute(getattr(cls, prop.key))
+        if isinstance(column_type, List):
+            _MutableList.associate_with_attribute(getattr(cls, prop.key))
+_LISTENER_ARGS = (mutable.mapper, 'mapper_configured', _mutable_association_listener)
+
+
+def _register_mutable_association_listener():
+    event.listen(*_LISTENER_ARGS)
+
+
+def remove_mutable_association_listener():
+    """
+    Remove the event listener that associates Dict and List column types with MutableDict
+    and MutableList, respectively.
+
+    This call must happen before any model instance is instantiated, because the first
+    instantiation triggers the listener we are trying to remove.
+    Once it is triggered, many other listeners will be registered,
+    and at that point it is too late.
+
+    Note that the event listener this call removes is registered by default.
+    """
+    if event.contains(*_LISTENER_ARGS):
+        event.remove(*_LISTENER_ARGS)
+
+_register_mutable_association_listener()

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/036f9c8e/tests/.pylintrc
----------------------------------------------------------------------
diff --git a/tests/.pylintrc b/tests/.pylintrc
index 23251af..5de0691 100644
--- a/tests/.pylintrc
+++ b/tests/.pylintrc
@@ -77,7 +77,7 @@ confidence=
 # --enable=similarities". If you want to run only the classes checker, but have
 # no Warning level messages displayed, use"--disable=all --enable=classes
 # --disable=W"
-disable=import-star-module-level,old-octal-literal,oct-method,print-statement,unpacking-in-except,parameter-unpacking,backtick,old-raise-syntax,old-ne-operator,long-suffix,dict-view-method,dict-iter-method,metaclass-assignment,next-method-called,raising-string,indexing-exception,raw_input-builtin,long-builtin,file-builtin,execfile-builtin,coerce-builtin,cmp-builtin,buffer-builtin,basestring-builtin,apply-builtin,filter-builtin-not-iterating,using-cmp-argument,useless-suppression,range-builtin-not-iterating,suppressed-message,no-absolute-import,old-division,cmp-method,reload-builtin,zip-builtin-not-iterating,intern-builtin,unichr-builtin,reduce-builtin,standarderror-builtin,unicode-builtin,xrange-builtin,coerce-method,delslice-method,getslice-method,setslice-method,input-builtin,round-builtin,hex-method,nonzero-method,map-builtin-not-iterating,redefined-builtin,no-self-use,missing-docstring,attribute-defined-outside-init,redefined-outer-name,import-error,redefined-variable-type,broad-except,protected-access,global-statement,too-many-locals
+disable=import-star-module-level,old-octal-literal,oct-method,print-statement,unpacking-in-except,parameter-unpacking,backtick,old-raise-syntax,old-ne-operator,long-suffix,dict-view-method,dict-iter-method,metaclass-assignment,next-method-called,raising-string,indexing-exception,raw_input-builtin,long-builtin,file-builtin,execfile-builtin,coerce-builtin,cmp-builtin,buffer-builtin,basestring-builtin,apply-builtin,filter-builtin-not-iterating,using-cmp-argument,useless-suppression,range-builtin-not-iterating,suppressed-message,no-absolute-import,old-division,cmp-method,reload-builtin,zip-builtin-not-iterating,intern-builtin,unichr-builtin,reduce-builtin,standarderror-builtin,unicode-builtin,xrange-builtin,coerce-method,delslice-method,getslice-method,setslice-method,input-builtin,round-builtin,hex-method,nonzero-method,map-builtin-not-iterating,redefined-builtin,no-self-use,missing-docstring,attribute-defined-outside-init,redefined-outer-name,import-error,redefined-variable-type,broad-except,protected-access,global-statement,too-many-locals,abstract-method
 
 [REPORTS]
 

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/036f9c8e/tests/orchestrator/workflows/executor/test_process_executor_tracked_changes.py
----------------------------------------------------------------------
diff --git a/tests/orchestrator/workflows/executor/test_process_executor_tracked_changes.py b/tests/orchestrator/workflows/executor/test_process_executor_tracked_changes.py
new file mode 100644
index 0000000..1564292
--- /dev/null
+++ b/tests/orchestrator/workflows/executor/test_process_executor_tracked_changes.py
@@ -0,0 +1,95 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+
+from aria.orchestrator.workflows import api
+from aria.orchestrator.workflows.core import engine
+from aria.orchestrator.workflows.executor import process
+from aria.orchestrator import workflow, operation
+from aria.orchestrator.workflows import exceptions
+
+import tests
+from tests import mock
+from tests import storage
+
+
+_TEST_RUNTIME_PROPERTIES = {
+    'some': 'values', 'that': 'are', 'most': 'likely', 'only': 'set', 'here': 'yo'
+}
+
+
+def test_track_changes_of_successful_operation(context, executor):
+    _run_workflow(context=context, executor=executor, op_func=_mock_success_operation)
+    _assert_tracked_changes_are_applied(context)
+
+
+def test_track_changes_of_failed_operation(context, executor):
+    with pytest.raises(exceptions.ExecutorException):
+        _run_workflow(context=context, executor=executor, op_func=_mock_fail_operation)
+    _assert_tracked_changes_are_applied(context)
+
+
+def _assert_tracked_changes_are_applied(context):
+    instance = context.model.node_instance.get_by_name(mock.models.DEPENDENCY_NODE_INSTANCE_NAME)
+    assert instance.runtime_properties == _TEST_RUNTIME_PROPERTIES
+
+
+def _update_runtime_properties(context):
+    context.node_instance.runtime_properties.clear()
+    context.node_instance.runtime_properties.update(_TEST_RUNTIME_PROPERTIES)
+
+
+def _run_workflow(context, executor, op_func):
+    @workflow
+    def mock_workflow(ctx, graph):
+        node_instance = ctx.model.node_instance.get_by_name(
+            mock.models.DEPENDENCY_NODE_INSTANCE_NAME)
+        node_instance.node.operations['test.op'] = {'operation': _operation_mapping(op_func)}
+        task = api.task.OperationTask.node_instance(instance=node_instance, name='test.op')
+        graph.add_tasks(task)
+        return graph
+    graph = mock_workflow(ctx=context)  # pylint: disable=no-value-for-parameter
+    eng = engine.Engine(executor=executor, workflow_context=context, tasks_graph=graph)
+    eng.execute()
+
+
+@operation
+def _mock_success_operation(ctx):
+    _update_runtime_properties(ctx)
+
+
+@operation
+def _mock_fail_operation(ctx):
+    _update_runtime_properties(ctx)
+    raise RuntimeError
+
+
+def _operation_mapping(func):
+    return '{name}.{func.__name__}'.format(name=__name__, func=func)
+
+
+@pytest.fixture
+def executor():
+    result = process.ProcessExecutor(python_path=[tests.ROOT_DIR])
+    yield result
+    result.close()
+
+
+@pytest.fixture
+def context(tmpdir):
+    result = mock.context.simple(storage.get_sqlite_api_kwargs(str(tmpdir)))
+    yield result
+    storage.release_sqlite_storage(result.model)

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/036f9c8e/tests/storage/test_instrumentation.py
----------------------------------------------------------------------
diff --git a/tests/storage/test_instrumentation.py b/tests/storage/test_instrumentation.py
new file mode 100644
index 0000000..b00bbd3
--- /dev/null
+++ b/tests/storage/test_instrumentation.py
@@ -0,0 +1,274 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pytest
+from sqlalchemy import Column, Text, Integer, event
+
+from aria.storage import (
+    model,
+    structure,
+    type as aria_type,
+    ModelStorage,
+    sql_mapi,
+    instrumentation
+)
+from ..storage import get_sqlite_api_kwargs, release_sqlite_storage
+
+
+STUB = instrumentation._STUB
+Value = instrumentation._Value
+instruments_holder = []
+
+
+class TestInstrumentation(object):
+
+    def test_track_changes(self, storage):
+        model_kwargs = dict(
+            name='name',
+            dict1={'initial': 'value'},
+            dict2={'initial': 'value'},
+            list1=['initial'],
+            list2=['initial'],
+            int1=0,
+            int2=0,
+            string2='string')
+        model1_instance = MockModel1(**model_kwargs)
+        model2_instance = MockModel2(**model_kwargs)
+        storage.mock_model_1.put(model1_instance)
+        storage.mock_model_2.put(model2_instance)
+
+        instrument = self._track_changes({
+            MockModel1.dict1: dict,
+            MockModel1.list1: list,
+            MockModel1.int1: int,
+            MockModel1.string2: str,
+            MockModel2.dict2: dict,
+            MockModel2.list2: list,
+            MockModel2.int2: int,
+            MockModel2.name: str
+        })
+
+        assert not instrument.tracked_changes
+
+        storage_model1_instance = storage.mock_model_1.get(model1_instance.id)
+        storage_model2_instance = storage.mock_model_2.get(model2_instance.id)
+
+        storage_model1_instance.dict1 = {'hello': 'world'}
+        storage_model1_instance.dict2 = {'should': 'not track'}
+        storage_model1_instance.list1 = ['hello']
+        storage_model1_instance.list2 = ['should not track']
+        storage_model1_instance.int1 = 100
+        storage_model1_instance.int2 = 20000
+        storage_model1_instance.name = 'should not track'
+        storage_model1_instance.string2 = 'new_string'
+
+        storage_model2_instance.dict1.update({'should': 'not track'})
+        storage_model2_instance.dict2.update({'hello': 'world'})
+        storage_model2_instance.list1.append('should not track')
+        storage_model2_instance.list2.append('hello')
+        storage_model2_instance.int1 = 100
+        storage_model2_instance.int2 = 20000
+        storage_model2_instance.name = 'new_name'
+        storage_model2_instance.string2 = 'should not track'
+
+        assert instrument.tracked_changes == {
+            'mock_model_1': {
+                model1_instance.id: {
+                    'dict1': Value(STUB, {'hello': 'world'}),
+                    'list1': Value(STUB, ['hello']),
+                    'int1': Value(STUB, 100),
+                    'string2': Value(STUB, 'new_string')
+                }
+            },
+            'mock_model_2': {
+                model2_instance.id: {
+                    'dict2': Value({'initial': 'value'}, {'hello': 'world', 'initial': 'value'}),
+                    'list2': Value(['initial'], ['initial', 'hello']),
+                    'int2': Value(STUB, 20000),
+                    'name': Value(STUB, 'new_name'),
+                }
+            }
+        }
+
+    def test_attribute_initial_none_value(self, storage):
+        instance1 = MockModel1(name='name1', dict1=None)
+        instance2 = MockModel1(name='name2', dict1=None)
+        storage.mock_model_1.put(instance1)
+        storage.mock_model_1.put(instance2)
+        instrument = self._track_changes({MockModel1.dict1: dict})
+        instance1 = storage.mock_model_1.get(instance1.id)
+        instance2 = storage.mock_model_1.get(instance2.id)
+        instance1.dict1 = {'new': 'value'}
+        assert instrument.tracked_changes == {
+            'mock_model_1': {
+                instance1.id: {'dict1': Value(STUB, {'new': 'value'})},
+                instance2.id: {'dict1': Value(None, None)},
+            }
+        }
+
+    def test_attribute_set_none_value(self, storage):
+        instance = MockModel1(name='name')
+        storage.mock_model_1.put(instance)
+        instrument = self._track_changes({
+            MockModel1.dict1: dict,
+            MockModel1.list1: list,
+            MockModel1.string2: str,
+            MockModel1.int1: int
+        })
+        instance = storage.mock_model_1.get(instance.id)
+        instance.dict1 = None
+        instance.list1 = None
+        instance.string2 = None
+        instance.int1 = None
+        assert instrument.tracked_changes == {
+            'mock_model_1': {
+                instance.id: {
+                    'dict1': Value(STUB, None),
+                    'list1': Value(STUB, None),
+                    'string2': Value(STUB, None),
+                    'int1': Value(STUB, None)
+                }
+            }
+        }
+
+    def test_restore(self):
+        instrument = self._track_changes({MockModel1.dict1: dict})
+        # set instance attribute, load instance, refresh instance and flush_refresh listeners
+        assert len(instrument.listeners) == 4
+        for listener_args in instrument.listeners:
+            assert event.contains(*listener_args)
+        instrument.restore()
+        assert len(instrument.listeners) == 4
+        for listener_args in instrument.listeners:
+            assert not event.contains(*listener_args)
+        return instrument
+
+    def test_restore_twice(self):
+        instrument = self.test_restore()
+        instrument.restore()
+
+    def test_instrumentation_context_manager(self, storage):
+        instance = MockModel1(name='name')
+        storage.mock_model_1.put(instance)
+        with self._track_changes({MockModel1.dict1: dict}) as instrument:
+            instance = storage.mock_model_1.get(instance.id)
+            instance.dict1 = {'new': 'value'}
+            assert instrument.tracked_changes == {
+                'mock_model_1': {instance.id: {'dict1': Value(STUB, {'new': 'value'})}}
+            }
+            assert len(instrument.listeners) == 4
+            for listener_args in instrument.listeners:
+                assert event.contains(*listener_args)
+        for listener_args in instrument.listeners:
+            assert not event.contains(*listener_args)
+
+    def test_apply_tracked_changes(self, storage):
+        initial_values = {'dict1': {'initial': 'value'}, 'list1': ['initial']}
+        instance1_1 = MockModel1(name='instance1_1', **initial_values)
+        instance1_2 = MockModel1(name='instance1_2', **initial_values)
+        instance2_1 = MockModel2(name='instance2_1', **initial_values)
+        instance2_2 = MockModel2(name='instance2_2', **initial_values)
+        storage.mock_model_1.put(instance1_1)
+        storage.mock_model_1.put(instance1_2)
+        storage.mock_model_2.put(instance2_1)
+        storage.mock_model_2.put(instance2_2)
+
+        instrument = self._track_changes({
+            MockModel1.dict1: dict,
+            MockModel1.list1: list,
+            MockModel2.dict1: dict,
+            MockModel2.list1: list
+        })
+
+        def get_instances():
+            return (storage.mock_model_1.get(instance1_1.id),
+                    storage.mock_model_1.get(instance1_2.id),
+                    storage.mock_model_2.get(instance2_1.id),
+                    storage.mock_model_2.get(instance2_2.id))
+
+        instance1_1, instance1_2, instance2_1, instance2_2 = get_instances()
+        instance1_1.dict1 = {'new': 'value'}
+        instance1_2.list1 = ['new_value']
+        instance2_1.dict1.update({'new': 'value'})
+        instance2_2.list1.append('new_value')
+
+        instrument.restore()
+        storage.mock_model_1._session.expire_all()
+
+        instance1_1, instance1_2, instance2_1, instance2_2 = get_instances()
+        instance1_1.dict1 = {'overriding': 'value'}
+        instance1_2.list1 = ['overriding_value']
+        instance2_1.dict1 = {'overriding': 'value'}
+        instance2_2.list1 = ['overriding_value']
+        storage.mock_model_1.put(instance1_1)
+        storage.mock_model_1.put(instance1_2)
+        storage.mock_model_2.put(instance2_1)
+        storage.mock_model_2.put(instance2_2)
+        instance1_1, instance1_2, instance2_1, instance2_2 = get_instances()
+        assert instance1_1.dict1 == {'overriding': 'value'}
+        assert instance1_2.list1 == ['overriding_value']
+        assert instance2_1.dict1 == {'overriding': 'value'}
+        assert instance2_2.list1 == ['overriding_value']
+
+        instrumentation.apply_tracked_changes(
+            tracked_changes=instrument.tracked_changes,
+            model=storage)
+
+        instance1_1, instance1_2, instance2_1, instance2_2 = get_instances()
+        assert instance1_1.dict1 == {'new': 'value'}
+        assert instance1_2.list1 == ['new_value']
+        assert instance2_1.dict1 == {'initial': 'value', 'new': 'value'}
+        assert instance2_2.list1 == ['initial', 'new_value']
+
+    def _track_changes(self, instrumented):
+        instrument = instrumentation.track_changes(instrumented)
+        instruments_holder.append(instrument)
+        return instrument
+
+
+@pytest.fixture(autouse=True)
+def restore_instrumentation():
+    for instrument in instruments_holder:
+        instrument.restore()
+    del instruments_holder[:]
+
+
+@pytest.fixture
+def storage():
+    result = ModelStorage(
+        api_cls=sql_mapi.SQLAlchemyModelAPI,
+        api_kwargs=get_sqlite_api_kwargs(),
+        items=(MockModel1, MockModel2))
+    yield result
+    release_sqlite_storage(result)
+
+
+class _MockModel(structure.ModelMixin):
+    name = Column(Text)
+    dict1 = Column(aria_type.Dict)
+    dict2 = Column(aria_type.Dict)
+    list1 = Column(aria_type.List)
+    list2 = Column(aria_type.List)
+    int1 = Column(Integer)
+    int2 = Column(Integer)
+    string2 = Column(Text)
+
+
+class MockModel1(model.DeclarativeBase, _MockModel):
+    __tablename__ = 'mock_model1'
+
+
+class MockModel2(model.DeclarativeBase, _MockModel):
+    __tablename__ = 'mock_model2'


Mime
View raw message