ariatosca-dev mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From dankil...@apache.org
Subject incubator-ariatosca git commit: ARIA-79-concurrent-modifications
Date Tue, 31 Jan 2017 14:23:26 GMT
Repository: incubator-ariatosca
Updated Branches:
  refs/heads/ARIA-79-concurrent-storage-modifications [created] 67a42409b


ARIA-79-concurrent-modifications


Project: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/commit/67a42409
Tree: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/tree/67a42409
Diff: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/diff/67a42409

Branch: refs/heads/ARIA-79-concurrent-storage-modifications
Commit: 67a42409b94307479dbdb5fcc246cee2a3eb55e0
Parents: 9e62fca
Author: Dan Kilman <dank@gigaspaces.com>
Authored: Mon Jan 30 16:49:00 2017 +0200
Committer: Dan Kilman <dank@gigaspaces.com>
Committed: Tue Jan 31 16:22:59 2017 +0200

----------------------------------------------------------------------
 aria/orchestrator/workflows/executor/process.py | 156 ++++++++++++-------
 aria/storage/base_model.py                      |   3 +
 aria/storage/instrumentation.py                 |  22 ++-
 aria/storage/sql_mapi.py                        |   4 +
 ...process_executor_concurrent_modifications.py |  83 ++++++++++
 tests/storage/__init__.py                       |   5 +-
 6 files changed, 213 insertions(+), 60 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/67a42409/aria/orchestrator/workflows/executor/process.py
----------------------------------------------------------------------
diff --git a/aria/orchestrator/workflows/executor/process.py b/aria/orchestrator/workflows/executor/process.py
index 7d990fa..319982e 100644
--- a/aria/orchestrator/workflows/executor/process.py
+++ b/aria/orchestrator/workflows/executor/process.py
@@ -74,6 +74,13 @@ class ProcessExecutor(base.BaseExecutor):
         # Contains reference to all currently running tasks
         self._tasks = {}
 
+        self._request_handlers = {
+            'started': self._handle_task_started_request,
+            'succeeded': self._handle_task_succeeded_request,
+            'failed': self._handle_task_failed_request,
+            'apply_tracked_changes': self._handle_apply_tracked_changes_request
+        }
+
         # Server socket used to accept task status messages from subprocesses
         self._server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         self._server_socket.bind(('localhost', 0))
@@ -131,58 +138,6 @@ class ProcessExecutor(base.BaseExecutor):
     def _remove_task(self, task_id):
         return self._tasks.pop(task_id)
 
-    def _listener(self):
-        # Notify __init__ method this thread has actually started
-        self._listener_started.put(True)
-        while not self._stopped:
-            try:
-                # Accept messages written to the server socket
-                with contextlib.closing(self._server_socket.accept()[0]) as connection:
-                    message = self._recv_message(connection)
-                    message_type = message['type']
-                    if message_type == 'closed':
-                        break
-                    task_id = message['task_id']
-                    if message_type == 'started':
-                        self._task_started(self._tasks[task_id])
-                    elif message_type == 'apply_tracked_changes':
-                        task = self._tasks[task_id]
-                        instrumentation.apply_tracked_changes(
-                            tracked_changes=message['tracked_changes'],
-                            model=task.context.model)
-                    elif message_type == 'succeeded':
-                        task = self._remove_task(task_id)
-                        instrumentation.apply_tracked_changes(
-                            tracked_changes=message['tracked_changes'],
-                            model=task.context.model)
-                        self._task_succeeded(task)
-                    elif message_type == 'failed':
-                        task = self._remove_task(task_id)
-                        instrumentation.apply_tracked_changes(
-                            tracked_changes=message['tracked_changes'],
-                            model=task.context.model)
-                        self._task_failed(task, exception=message['exception'])
-                    else:
-                        raise RuntimeError('Invalid state')
-            except BaseException as e:
-                self.logger.debug('Error in process executor listener: {0}'.format(e))
-
-    def _recv_message(self, connection):
-        message_len, = struct.unpack(_INT_FMT, self._recv_bytes(connection, _INT_SIZE))
-        return jsonpickle.loads(self._recv_bytes(connection, message_len))
-
-    @staticmethod
-    def _recv_bytes(connection, count):
-        result = io.BytesIO()
-        while True:
-            if not count:
-                return result.getvalue()
-            read = connection.recv(count)
-            if not read:
-                return result.getvalue()
-            result.write(read)
-            count -= len(read)
-
     def _check_closed(self):
         if self._stopped:
             raise RuntimeError('Executor closed')
@@ -231,6 +186,87 @@ class ProcessExecutor(base.BaseExecutor):
                 os.pathsep,
                 env.get('PYTHONPATH', ''))
 
+    def _listener(self):
+        # Notify __init__ method this thread has actually started
+        self._listener_started.put(True)
+        while not self._stopped:
+            try:
+                with self._accept_request() as (request, response):
+                    request_type = request['type']
+                    if request_type == 'closed':
+                        break
+                    request_handler = self._request_handlers.get(request_type)
+                    if not request_handler:
+                        raise RuntimeError('Invalid request type: {0}'.format(request_type))
+                    request_handler(task_id=request['task_id'], request=request, response=response)
+            except BaseException as e:
+                self.logger.debug('Error in process executor listener: {0}'.format(e))
+
+    @contextlib.contextmanager
+    def _accept_request(self):
+        with contextlib.closing(self._server_socket.accept()[0]) as connection:
+            message = _recv_message(connection)
+            response = {}
+            yield message, response
+            _send_message(connection, response)
+
+    def _handle_task_started_request(self, task_id, **kwargs):
+        self._task_started(self._tasks[task_id])
+
+    def _handle_task_succeeded_request(self, task_id, request, **kwargs):
+        task = self._remove_task(task_id)
+        try:
+            self._apply_tracked_changes(task, request)
+        except BaseException as e:
+            self._task_failed(task, exception=e)
+        else:
+            self._task_succeeded(task)
+
+    def _handle_task_failed_request(self, task_id, request, **kwargs):
+        task = self._remove_task(task_id)
+        try:
+            self._apply_tracked_changes(task, request)
+        except BaseException as e:
+            self._task_failed(task, exception=e)
+        else:
+            self._task_failed(task, exception=request['exception'])
+
+    def _handle_apply_tracked_changes_request(self, task_id, request, response):
+        task = self._tasks[task_id]
+        try:
+            self._apply_tracked_changes(task, request)
+        except BaseException as e:
+            response['exception'] = exceptions.wrap_if_needed(e)
+
+    @staticmethod
+    def _apply_tracked_changes(task, request):
+        instrumentation.apply_tracked_changes(
+            tracked_changes=request['tracked_changes'],
+            model=task.context.model)
+
+
+def _send_message(connection, message):
+    data = jsonpickle.dumps(message)
+    connection.send(struct.pack(_INT_FMT, len(data)))
+    connection.sendall(data)
+
+
+def _recv_message(connection):
+    message_len, = struct.unpack(_INT_FMT, _recv_bytes(connection, _INT_SIZE))
+    return jsonpickle.loads(_recv_bytes(connection, message_len))
+
+
+def _recv_bytes(connection, count):
+    result = io.BytesIO()
+    while True:
+        if not count:
+            return result.getvalue()
+        read = connection.recv(count)
+        if not read:
+            return result.getvalue()
+        result.write(read)
+        count -= len(read)
+
 
 class _Messenger(object):
 
@@ -261,17 +297,16 @@ class _Messenger(object):
         sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
         sock.connect(('localhost', self.port))
         try:
-            data = jsonpickle.dumps({
+            _send_message(sock, {
                 'type': type,
                 'task_id': self.task_id,
                 'exception': exceptions.wrap_if_needed(exception),
                 'tracked_changes': tracked_changes
             })
-            sock.send(struct.pack(_INT_FMT, len(data)))
-            sock.sendall(data)
-            # send message will block until the server side closes the connection socket
-            # because we want it to be synchronous
-            sock.recv(1)
+            response = _recv_message(sock)
+            response_exception = response.get('exception')
+            if response_exception:
+                raise response_exception
         finally:
             sock.close()
 
@@ -294,12 +329,17 @@ def _patch_session(ctx, messenger, instrument):
         messenger.apply_tracked_changes(instrument.tracked_changes)
         instrument.clear()
 
+    def patched_rollback():
+        # Rollback is performed on parent process when commit fails
+        pass
+
     # when autoflush is set to true (the default), refreshing an object will trigger
     # an auto flush by sqlalchemy, this autoflush will attempt to commit changes made so
     # far on the session. this is not the desired behavior in the subprocess
     session.autoflush = False
 
     session.commit = patched_commit
+    session.rollback = patched_rollback
     session.refresh = patched_refresh
 
 

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/67a42409/aria/storage/base_model.py
----------------------------------------------------------------------
diff --git a/aria/storage/base_model.py b/aria/storage/base_model.py
index f7d0e5b..56605fc 100644
--- a/aria/storage/base_model.py
+++ b/aria/storage/base_model.py
@@ -479,6 +479,7 @@ class NodeInstanceBase(ModelMixin):
     __tablename__ = 'node_instances'
     _private_fields = ['node_fk', 'host_fk']
 
+    version_id = Column(Integer, nullable=False)
     runtime_properties = Column(Dict)
     scaling_groups = Column(List)
     state = Column(Text, nullable=False)
@@ -528,6 +529,8 @@ class NodeInstanceBase(ModelMixin):
             return host_node.properties['ip']
         return None
 
+    __mapper_args__ = {'version_id_col': version_id}
+
 
 class RelationshipInstanceBase(ModelMixin):
     """

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/67a42409/aria/storage/instrumentation.py
----------------------------------------------------------------------
diff --git a/aria/storage/instrumentation.py b/aria/storage/instrumentation.py
index 537dbb5..1e39121 100644
--- a/aria/storage/instrumentation.py
+++ b/aria/storage/instrumentation.py
@@ -15,11 +15,15 @@
 
 import copy
 
+import sqlalchemy
 import sqlalchemy.event
 
+from . import exceptions
 from . import api
 from . import model as _model
 
+
+_VERSION_ID_COL = 'version_id'
 _STUB = object()
 _INSTRUMENTED = {
     _model.NodeInstance.runtime_properties: dict
@@ -93,6 +97,9 @@ class _Instrumentation(object):
             mapi_name = self._mapi_name(instrumented_class)
             tracked_instances = self.tracked_changes.setdefault(mapi_name, {})
             tracked_attributes = tracked_instances.setdefault(target.id, {})
+            if hasattr(target, _VERSION_ID_COL):
+                tracked_attributes.setdefault(_VERSION_ID_COL,
+                                              _Value(_STUB, getattr(target, _VERSION_ID_COL)))
             for attribute_name, attribute_type in instrumented_attributes.items():
                 if attribute_name not in tracked_attributes:
                     initial = getattr(target, attribute_name)
@@ -148,7 +155,7 @@ class _Value(object):
         return self.initial == other.initial and self.current == other.current
 
     def __hash__(self):
-        return hash(self.initial) ^ hash(self.current)
+        return hash((self.initial, self.current))
 
 
 def apply_tracked_changes(tracked_changes, model):
@@ -168,4 +175,17 @@ def apply_tracked_changes(tracked_changes, model):
                         instance = mapi.get(instance_id)
                     setattr(instance, attribute_name, value.current)
             if instance:
+                version_id = sqlalchemy.inspect(instance).committed_state.get(_VERSION_ID_COL)
+                # There are two version conflict code paths:
+                # 1. The committed state already loaded for the instance holds a newer
+                #    version; in this case we raise the error manually.
+                # 2. The UPDATE statement is executed with version validation and sqlalchemy
+                #    will raise a StaleDataError if there is a version mismatch.
+                if version_id and getattr(instance, _VERSION_ID_COL) != version_id:
+                    raise exceptions.StorageError(
+                        'Version conflict: committed and object {0} differ '
+                        '[committed {0}={1}, object {0}={2}]'
+                        .format(_VERSION_ID_COL,
+                                version_id,
+                                getattr(instance, _VERSION_ID_COL)))
                 mapi.update(instance)

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/67a42409/aria/storage/sql_mapi.py
----------------------------------------------------------------------
diff --git a/aria/storage/sql_mapi.py b/aria/storage/sql_mapi.py
index 809f677..0c08e48 100644
--- a/aria/storage/sql_mapi.py
+++ b/aria/storage/sql_mapi.py
@@ -17,6 +17,7 @@ SQLAlchemy based MAPI
 """
 
 from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.orm.exc import StaleDataError
 
 from aria.utils.collections import OrderedDict
 from aria.storage import (
@@ -152,6 +153,9 @@ class SQLAlchemyModelAPI(api.ModelAPI):
         """
         try:
             self._session.commit()
+        except StaleDataError as e:
+            self._session.rollback()
+            raise exceptions.StorageError('Version conflict: {0}'.format(str(e)))
         except (SQLAlchemyError, ValueError) as e:
             self._session.rollback()
             raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e)))

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/67a42409/tests/orchestrator/workflows/executor/test_process_executor_concurrent_modifications.py
----------------------------------------------------------------------
diff --git a/tests/orchestrator/workflows/executor/test_process_executor_concurrent_modifications.py
b/tests/orchestrator/workflows/executor/test_process_executor_concurrent_modifications.py
new file mode 100644
index 0000000..434180a
--- /dev/null
+++ b/tests/orchestrator/workflows/executor/test_process_executor_concurrent_modifications.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+
+import pytest
+
+from aria.orchestrator.workflows import api
+from aria.orchestrator.workflows.core import engine
+from aria.orchestrator.workflows.executor import process
+from aria.orchestrator import workflow, operation
+
+import tests
+from tests import mock
+from tests import storage
+
+
+def _test_concurrent_runtime_properties_modification(context, executor):
+    props = _run_workflow(context, executor)
+    print props
+
+
+def _run_workflow(context, executor):
+    @workflow
+    def mock_workflow(ctx, graph):
+        key = 'key'
+        op = 'test.op'
+        op_dict = {'operation': '{0}.{1}'.format(__name__, _mock_operation.__name__)}
+        node_instance = ctx.model.node_instance.get_by_name(
+            mock.models.DEPENDENCY_NODE_INSTANCE_NAME)
+        node_instance.node.operations[op] = op_dict
+        task1 = api.task.OperationTask.node_instance(
+            instance=node_instance, name=op, inputs={'sleep': 0, 'key': key, 'value': 1})
+        task2 = api.task.OperationTask.node_instance(
+            instance=node_instance, name=op, inputs={'sleep': 4, 'key': key, 'value': 2})
+        graph.add_tasks(task1, task2)
+        return graph
+    graph = mock_workflow(ctx=context)  # pylint: disable=no-value-for-parameter
+    eng = engine.Engine(executor=executor, workflow_context=context, tasks_graph=graph)
+    eng.execute()
+    return context.model.node_instance.get_by_name(
+        mock.models.DEPENDENCY_NODE_INSTANCE_NAME).runtime_properties
+
+
+@operation
+def _mock_operation(ctx, sleep, key, value):
+    if not sleep:
+        time.sleep(2)
+    instance = ctx.node_instance
+    print instance.version_id, value, instance.runtime_properties
+    time.sleep(sleep)
+    if sleep:
+        print instance.version_id
+        # ctx.model.node_instance.refresh(instance)
+        # print instance.version_id
+        # ctx.model.node_instance.update(instance)
+    instance.runtime_properties[key] = value
+
+
+@pytest.fixture
+def executor():
+    result = process.ProcessExecutor(python_path=[tests.ROOT_DIR])
+    yield result
+    result.close()
+
+
+@pytest.fixture
+def context(tmpdir):
+    result = mock.context.simple(storage.get_sqlite_api_kwargs(str(tmpdir)))
+    yield result
+    storage.release_sqlite_storage(result.model)

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/67a42409/tests/storage/__init__.py
----------------------------------------------------------------------
diff --git a/tests/storage/__init__.py b/tests/storage/__init__.py
index 3b3715e..8fb4f30 100644
--- a/tests/storage/__init__.py
+++ b/tests/storage/__init__.py
@@ -75,8 +75,11 @@ def get_sqlite_api_kwargs(base_dir=None, filename='db.sqlite'):
 
     engine = create_engine(uri, **engine_kwargs)
     session_factory = orm.sessionmaker(bind=engine)
-    session = orm.scoped_session(session_factory=session_factory) if base_dir else session_factory()
 
+    if base_dir:
+        session = orm.scoped_session(session_factory=session_factory)
+    else:
+        session = session_factory()
     model.DeclarativeBase.metadata.create_all(bind=engine)
     return dict(engine=engine, session=session)
 


Mime
View raw message