ariatosca-dev mailing list archives

From mxm...@apache.org
Subject [2/8] incubator-ariatosca git commit: ARIA-7 Initial celery based executor implementation
Date Tue, 08 Nov 2016 10:16:31 GMT
ARIA-7 Initial celery based executor implementation

This commit also includes a code re-org: executors now live under
their own package.
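
For reference, the re-org moves every executor class out of the single
aria/workflows/core/executor.py module and into a dedicated
aria/workflows/executor package, one module per execution strategy.
Import paths change accordingly; a sketch based on the diff below:

    # Before this commit:
    from aria.workflows.core.executor import ThreadExecutor

    # After this commit, one module per strategy:
    from aria.workflows.executor.thread import ThreadExecutor
    from aria.workflows.executor.blocking import CurrentThreadBlockingExecutor
    from aria.workflows.executor.multiprocess import MultiprocessExecutor
    from aria.workflows.executor.celery import CeleryExecutor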


Project: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/commit/c0bf3479
Tree: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/tree/c0bf3479
Diff: http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/diff/c0bf3479

Branch: refs/heads/ARIA-9-API-for-operation-context
Commit: c0bf34791bc73adbbe1dbf2d9dfb16a109589ad3
Parents: e1c919d
Author: Dan Kilman <dankilman@gmail.com>
Authored: Sun Oct 30 17:21:36 2016 +0200
Committer: Dan Kilman <dankilman@gmail.com>
Committed: Tue Nov 1 11:27:14 2016 +0200

----------------------------------------------------------------------
 aria/cli/commands.py                    |   2 +-
 aria/workflows/core/executor.py         | 192 ---------------------------
 aria/workflows/executor/__init__.py     |  14 ++
 aria/workflows/executor/base.py         |  54 ++++++++
 aria/workflows/executor/blocking.py     |  37 ++++++
 aria/workflows/executor/celery.py       |  96 ++++++++++++++
 aria/workflows/executor/multiprocess.py |  98 ++++++++++++++
 aria/workflows/executor/thread.py       |  67 ++++++++++
 tests/workflows/test_executor.py        |  47 +++++--
 9 files changed, 400 insertions(+), 207 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c0bf3479/aria/cli/commands.py
----------------------------------------------------------------------
diff --git a/aria/cli/commands.py b/aria/cli/commands.py
index a38229d..ddc27b5 100644
--- a/aria/cli/commands.py
+++ b/aria/cli/commands.py
@@ -31,7 +31,7 @@ from aria.storage import FileSystemModelDriver, FileSystemResourceDriver
 from aria.tools.application import StorageManager
 from aria.contexts import WorkflowContext
 from aria.workflows.core.engine import Engine
-from aria.workflows.core.executor import ThreadExecutor
+from aria.workflows.executor.thread import ThreadExecutor
 
 from .storage import (
     local_resource_storage,

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c0bf3479/aria/workflows/core/executor.py
----------------------------------------------------------------------
diff --git a/aria/workflows/core/executor.py b/aria/workflows/core/executor.py
deleted file mode 100644
index ace445a..0000000
--- a/aria/workflows/core/executor.py
+++ /dev/null
@@ -1,192 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-Executors for workflow tasks
-"""
-
-import threading
-import multiprocessing
-import Queue
-
-import jsonpickle
-
-from aria import events
-from aria.tools import module
-
-
-class BaseExecutor(object):
-    """
-    Base class for executors for running tasks
-    """
-
-    def __init__(self, *args, **kwargs):
-        pass
-
-    def execute(self, task):
-        """
-        Execute a task
-        :param task: task to execute
-        """
-        raise NotImplementedError
-
-    def close(self):
-        """
-        Close the executor
-        """
-        pass
-
-    @staticmethod
-    def _task_started(task):
-        events.start_task_signal.send(task)
-
-    @staticmethod
-    def _task_failed(task, exception):
-        events.on_failure_task_signal.send(task, exception=exception)
-
-    @staticmethod
-    def _task_succeeded(task):
-        events.on_success_task_signal.send(task)
-
-
-class CurrentThreadBlockingExecutor(BaseExecutor):
-    """
-    Executor which runs tasks in the current thread (blocking)
-    """
-
-    def execute(self, task):
-        self._task_started(task)
-        try:
-            operation_context = task.context
-            task_func = module.load_attribute(operation_context.operation_details['operation'])
-            task_func(**operation_context.inputs)
-            self._task_succeeded(task)
-        except BaseException as e:
-            self._task_failed(task, exception=e)
-
-
-class ThreadExecutor(BaseExecutor):
-    """
-    Executor which runs tasks in a separate thread
-    """
-
-    def __init__(self, pool_size=1, *args, **kwargs):
-        super(ThreadExecutor, self).__init__(*args, **kwargs)
-        self._stopped = False
-        self._queue = Queue.Queue()
-        self._pool = []
-        for i in range(pool_size):
-            name = 'ThreadExecutor-{index}'.format(index=i+1)
-            thread = threading.Thread(target=self._processor, name=name)
-            thread.daemon = True
-            thread.start()
-            self._pool.append(thread)
-
-    def execute(self, task):
-        self._queue.put(task)
-
-    def close(self):
-        self._stopped = True
-        for thread in self._pool:
-            thread.join()
-
-    def _processor(self):
-        while not self._stopped:
-            try:
-                task = self._queue.get(timeout=1)
-                self._task_started(task)
-                try:
-                    operation_context = task.context
-                    task_func = module.load_attribute(
-                        operation_context.operation_details['operation'])
-                    task_func(**operation_context.inputs)
-                    self._task_succeeded(task)
-                except BaseException as e:
-                    self._task_failed(task, exception=e)
-            # Daemon threads
-            except BaseException:
-                pass
-
-
-class MultiprocessExecutor(BaseExecutor):
-    """
-    Executor which runs tasks in a multiprocess environment
-    """
-
-    def __init__(self, pool_size=1, *args, **kwargs):
-        super(MultiprocessExecutor, self).__init__(*args, **kwargs)
-        self._stopped = False
-        self._manager = multiprocessing.Manager()
-        self._queue = self._manager.Queue()
-        self._tasks = {}
-        self._listener_thread = threading.Thread(target=self._listener)
-        self._listener_thread.daemon = True
-        self._listener_thread.start()
-        self._pool = multiprocessing.Pool(processes=pool_size,
-                                          maxtasksperchild=1)
-
-    def execute(self, task):
-        self._tasks[task.id] = task
-        self._pool.apply_async(_multiprocess_handler, args=(
-            self._queue,
-            task.id,
-            task.context.operation_details,
-            task.context.inputs))
-
-    def close(self):
-        self._pool.close()
-        self._stopped = True
-        self._pool.join()
-        self._listener_thread.join()
-
-    def _listener(self):
-        while not self._stopped:
-            try:
-                message = self._queue.get(timeout=1)
-                if message.type == 'task_started':
-                    self._task_started(self._tasks[message.task_id])
-                elif message.type == 'task_succeeded':
-                    self._task_succeeded(self._remove_task(message.task_id))
-                elif message.type == 'task_failed':
-                    self._task_failed(self._remove_task(message.task_id),
-                                      exception=jsonpickle.loads(message.exception))
-                else:
-                    # TODO: something
-                    raise RuntimeError()
-            # Daemon threads
-            except BaseException:
-                pass
-
-    def _remove_task(self, task_id):
-        return self._tasks.pop(task_id)
-
-
-class _MultiprocessMessage(object):
-
-    def __init__(self, type, task_id, exception=None):
-        self.type = type
-        self.task_id = task_id
-        self.exception = exception
-
-
-def _multiprocess_handler(queue, task_id, operation_details, operation_inputs):
-    queue.put(_MultiprocessMessage(type='task_started', task_id=task_id))
-    try:
-        task_func = module.load_attribute(operation_details['operation'])
-        task_func(**operation_inputs)
-        queue.put(_MultiprocessMessage(type='task_succeeded', task_id=task_id))
-    except BaseException as e:
-        queue.put(_MultiprocessMessage(type='task_failed', task_id=task_id,
-                                       exception=jsonpickle.dumps(e)))

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c0bf3479/aria/workflows/executor/__init__.py
----------------------------------------------------------------------
diff --git a/aria/workflows/executor/__init__.py b/aria/workflows/executor/__init__.py
new file mode 100644
index 0000000..ae1e83e
--- /dev/null
+++ b/aria/workflows/executor/__init__.py
@@ -0,0 +1,14 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c0bf3479/aria/workflows/executor/base.py
----------------------------------------------------------------------
diff --git a/aria/workflows/executor/base.py b/aria/workflows/executor/base.py
new file mode 100644
index 0000000..118ab2b
--- /dev/null
+++ b/aria/workflows/executor/base.py
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Base executor module
+"""
+
+from aria import events
+
+
+class BaseExecutor(object):
+    """
+    Base class for executors for running tasks
+    """
+
+    def __init__(self, *args, **kwargs):
+        pass
+
+    def execute(self, task):
+        """
+        Execute a task
+        :param task: task to execute
+        """
+        raise NotImplementedError
+
+    def close(self):
+        """
+        Close the executor
+        """
+        pass
+
+    @staticmethod
+    def _task_started(task):
+        events.start_task_signal.send(task)
+
+    @staticmethod
+    def _task_failed(task, exception):
+        events.on_failure_task_signal.send(task, exception=exception)
+
+    @staticmethod
+    def _task_succeeded(task):
+        events.on_success_task_signal.send(task)

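The base class above defines the whole executor contract: subclasses
implement execute(), optionally override close(), and report lifecycle
transitions through the _task_started/_task_succeeded/_task_failed
helpers, which broadcast over the aria.events signals. Assuming these are
blinker-style signals (the send(task, ...) calls above match that API),
an observer could hook any executor like this — a hypothetical sketch,
not part of this commit:

    from aria import events

    # blinker receivers take the sender plus the keyword arguments
    # passed to send(); here the sender is the task itself.
    def on_start(sender, **kwargs):
        print('task started: %s' % sender.id)

    def on_failure(sender, exception=None, **kwargs):
        print('task failed: %s (%r)' % (sender.id, exception))

    events.start_task_signal.connect(on_start)
    events.on_failure_task_signal.connect(on_failure)
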
http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c0bf3479/aria/workflows/executor/blocking.py
----------------------------------------------------------------------
diff --git a/aria/workflows/executor/blocking.py b/aria/workflows/executor/blocking.py
new file mode 100644
index 0000000..86171ba
--- /dev/null
+++ b/aria/workflows/executor/blocking.py
@@ -0,0 +1,37 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Blocking executor
+"""
+
+from aria.tools import module
+from .base import BaseExecutor
+
+
+class CurrentThreadBlockingExecutor(BaseExecutor):
+    """
+    Executor which runs tasks in the current thread (blocking)
+    """
+
+    def execute(self, task):
+        self._task_started(task)
+        try:
+            operation_context = task.context
+            task_func = module.load_attribute(operation_context.operation_details['operation'])
+            task_func(**operation_context.inputs)
+            self._task_succeeded(task)
+        except BaseException as e:
+            self._task_failed(task, exception=e)

http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c0bf3479/aria/workflows/executor/celery.py
----------------------------------------------------------------------
diff --git a/aria/workflows/executor/celery.py b/aria/workflows/executor/celery.py
new file mode 100644
index 0000000..2d486f2
--- /dev/null
+++ b/aria/workflows/executor/celery.py
@@ -0,0 +1,96 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Celery based executor
+"""
+
+import threading
+import Queue
+
+from .base import BaseExecutor
+
+
+class CeleryExecutor(BaseExecutor):
+    """
+    Executor which runs tasks using celery
+    """
+
+    def __init__(self, app, *args, **kwargs):
+        super(CeleryExecutor, self).__init__(*args, **kwargs)
+        self._app = app
+        self._started_signaled = False
+        self._started_queue = Queue.Queue(maxsize=1)
+        self._tasks = {}
+        self._results = {}
+        self._receiver = None
+        self._stopped = False
+        self._receiver_thread = threading.Thread(target=self._events_receiver)
+        self._receiver_thread.daemon = True
+        self._receiver_thread.start()
+        self._started_queue.get(timeout=30)
+
+    def execute(self, task):
+        operation_context = task.context
+        self._tasks[task.id] = task
+        self._results[task.id] = self._app.send_task(
+            operation_context.operation_details['operation'],
+            kwargs=operation_context.inputs,
+            task_id=task.id,
+            queue=self._get_queue(task))
+
+    def close(self):
+        self._stopped = True
+        if self._receiver:
+            self._receiver.should_stop = True
+        self._receiver_thread.join()
+
+    @staticmethod
+    def _get_queue(task):
+        return None if task else None  # TODO
+
+    def _events_receiver(self):
+        with self._app.connection() as connection:
+            self._receiver = self._app.events.Receiver(connection, handlers={
+                'task-started': self._celery_task_started,
+                'task-succeeded': self._celery_task_succeeded,
+                'task-failed': self._celery_task_failed,
+            })
+            for _ in self._receiver.itercapture(limit=None, timeout=None, wakeup=True):
+                if not self._started_signaled:
+                    self._started_queue.put(True)
+                    self._started_signaled = True
+                if self._stopped:
+                    return
+
+    def _celery_task_started(self, event):
+        self._task_started(self._tasks[event['uuid']])
+
+    def _celery_task_succeeded(self, event):
+        task, _ = self._remove_task(event['uuid'])
+        self._task_succeeded(task)
+
+    def _celery_task_failed(self, event):
+        task, async_result = self._remove_task(event['uuid'])
+        try:
+            exception = async_result.result
+        except BaseException as e:
+            exception = RuntimeError(
+                'Could not de-serialize exception of task {0} --> {1}: {2}'
+                .format(task.name, type(e).__name__, str(e)))
+        self._task_failed(task, exception=exception)
+
+    def _remove_task(self, task_id):
+        return self._tasks.pop(task_id), self._results.pop(task_id)

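A sketch of how the new executor might be wired up, inferred from the
constructor above and the (commented-out) test parametrization below. The
broker URL and operation name here are placeholders, and a worker must be
running with events enabled (celery worker -E) for the Receiver to see
the task-started/succeeded/failed events:

    from celery import Celery
    from aria.workflows.executor.celery import CeleryExecutor

    # Placeholder configuration; the tests below use
    # CELERY_RESULT_BACKEND='amqp://'.
    app = Celery(broker='amqp://')
    app.conf.update(CELERY_RESULT_BACKEND='amqp://')

    # Operations must be registered as Celery tasks so that
    # app.send_task(operation_details['operation'], ...) can resolve them.
    @app.task(name='my_package.my_operation')  # hypothetical operation
    def my_operation(**inputs):
        pass

    executor = CeleryExecutor(app=app)  # blocks until the event receiver is up
    try:
        executor.execute(task)  # `task` assumed to come from the workflow engine
    finally:
        executor.close()
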
http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c0bf3479/aria/workflows/executor/multiprocess.py
----------------------------------------------------------------------
diff --git a/aria/workflows/executor/multiprocess.py b/aria/workflows/executor/multiprocess.py
new file mode 100644
index 0000000..e6faf5f
--- /dev/null
+++ b/aria/workflows/executor/multiprocess.py
@@ -0,0 +1,98 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Multiprocess based executor
+"""
+
+import threading
+import multiprocessing
+
+import jsonpickle
+
+from aria.tools import module
+from .base import BaseExecutor
+
+
+class MultiprocessExecutor(BaseExecutor):
+    """
+    Executor which runs tasks in a multiprocess environment
+    """
+
+    def __init__(self, pool_size=1, *args, **kwargs):
+        super(MultiprocessExecutor, self).__init__(*args, **kwargs)
+        self._stopped = False
+        self._manager = multiprocessing.Manager()
+        self._queue = self._manager.Queue()
+        self._tasks = {}
+        self._listener_thread = threading.Thread(target=self._listener)
+        self._listener_thread.daemon = True
+        self._listener_thread.start()
+        self._pool = multiprocessing.Pool(processes=pool_size,
+                                          maxtasksperchild=1)
+
+    def execute(self, task):
+        self._tasks[task.id] = task
+        self._pool.apply_async(_multiprocess_handler, args=(
+            self._queue,
+            task.id,
+            task.context.operation_details,
+            task.context.inputs))
+
+    def close(self):
+        self._pool.close()
+        self._stopped = True
+        self._pool.join()
+        self._listener_thread.join()
+
+    def _listener(self):
+        while not self._stopped:
+            try:
+                message = self._queue.get(timeout=1)
+                if message.type == 'task_started':
+                    self._task_started(self._tasks[message.task_id])
+                elif message.type == 'task_succeeded':
+                    self._task_succeeded(self._remove_task(message.task_id))
+                elif message.type == 'task_failed':
+                    self._task_failed(self._remove_task(message.task_id),
+                                      exception=jsonpickle.loads(message.exception))
+                else:
+                    # TODO: something
+                    raise RuntimeError()
+            # Daemon threads
+            except BaseException:
+                pass
+
+    def _remove_task(self, task_id):
+        return self._tasks.pop(task_id)
+
+
+class _MultiprocessMessage(object):
+
+    def __init__(self, type, task_id, exception=None):
+        self.type = type
+        self.task_id = task_id
+        self.exception = exception
+
+
+def _multiprocess_handler(queue, task_id, operation_details, operation_inputs):
+    queue.put(_MultiprocessMessage(type='task_started', task_id=task_id))
+    try:
+        task_func = module.load_attribute(operation_details['operation'])
+        task_func(**operation_inputs)
+        queue.put(_MultiprocessMessage(type='task_succeeded', task_id=task_id))
+    except BaseException as e:
+        queue.put(_MultiprocessMessage(type='task_failed', task_id=task_id,
+                                       exception=jsonpickle.dumps(e)))

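The failure path is worth noting: _multiprocess_handler runs in a child
process, so the raised exception is flattened to a JSON string with
jsonpickle, crosses the multiprocessing.Manager queue as plain text, and
is rebuilt by the listener thread in the parent before the failure signal
fires. A minimal round-trip sketch of that mechanism:

    import jsonpickle

    # Child process side: serialize the exception to a JSON string.
    encoded = jsonpickle.dumps(RuntimeError('boom'))

    # Parent listener side: rebuild it before calling _task_failed().
    exception = jsonpickle.loads(encoded)
    print('%s: %s' % (type(exception).__name__, exception))  # RuntimeError: boom
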
http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c0bf3479/aria/workflows/executor/thread.py
----------------------------------------------------------------------
diff --git a/aria/workflows/executor/thread.py b/aria/workflows/executor/thread.py
new file mode 100644
index 0000000..dfc0f18
--- /dev/null
+++ b/aria/workflows/executor/thread.py
@@ -0,0 +1,67 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Thread based executor
+"""
+
+import threading
+import Queue
+
+from aria.tools import module
+from .base import BaseExecutor
+
+
+class ThreadExecutor(BaseExecutor):
+    """
+    Executor which runs tasks in a separate thread
+    """
+
+    def __init__(self, pool_size=1, *args, **kwargs):
+        super(ThreadExecutor, self).__init__(*args, **kwargs)
+        self._stopped = False
+        self._queue = Queue.Queue()
+        self._pool = []
+        for i in range(pool_size):
+            name = 'ThreadExecutor-{index}'.format(index=i+1)
+            thread = threading.Thread(target=self._processor, name=name)
+            thread.daemon = True
+            thread.start()
+            self._pool.append(thread)
+
+    def execute(self, task):
+        self._queue.put(task)
+
+    def close(self):
+        self._stopped = True
+        for thread in self._pool:
+            thread.join()
+
+    def _processor(self):
+        while not self._stopped:
+            try:
+                task = self._queue.get(timeout=1)
+                self._task_started(task)
+                try:
+                    operation_context = task.context
+                    task_func = module.load_attribute(
+                        operation_context.operation_details['operation'])
+                    task_func(**operation_context.inputs)
+                    self._task_succeeded(task)
+                except BaseException as e:
+                    self._task_failed(task, exception=e)
+            # Daemon threads
+            except BaseException:
+                pass

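Shutdown semantics, for context: the worker threads are daemons that poll
the queue with a one-second timeout, and the bare except around the loop
body swallows the resulting Queue.Empty so an idle worker simply
re-checks _stopped; close() therefore returns within about a second once
the workers are idle. A usage sketch (`task` is assumed to be an
engine-produced task):

    from aria.workflows.executor.thread import ThreadExecutor

    executor = ThreadExecutor(pool_size=2)
    try:
        executor.execute(task)  # enqueue; a pool thread picks it up
    finally:
        executor.close()  # sets _stopped; workers exit after their next timeout
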
http://git-wip-us.apache.org/repos/asf/incubator-ariatosca/blob/c0bf3479/tests/workflows/test_executor.py
----------------------------------------------------------------------
diff --git a/tests/workflows/test_executor.py b/tests/workflows/test_executor.py
index 27cb2ad..7457fd0 100644
--- a/tests/workflows/test_executor.py
+++ b/tests/workflows/test_executor.py
@@ -21,20 +21,34 @@ import retrying
 
 from aria import events
 from aria.storage import models
-from aria.workflows.core import executor
+from aria.workflows.executor import (
+    thread,
+    multiprocess,
+    blocking,
+    # celery
+)
+
+try:
+    import celery as _celery
+    app = _celery.Celery()
+    app.conf.update(CELERY_RESULT_BACKEND='amqp://')
+except ImportError:
+    _celery = None
+    app = None
 
 
 class TestExecutor(object):
 
-    @pytest.mark.parametrize('pool_size,executor_cls', [
-        (1, executor.ThreadExecutor),
-        (2, executor.ThreadExecutor),
-        (1, executor.MultiprocessExecutor),
-        (2, executor.MultiprocessExecutor),
-        (0, executor.CurrentThreadBlockingExecutor)
+    @pytest.mark.parametrize('executor_cls,executor_kwargs', [
+        (thread.ThreadExecutor, {'pool_size': 1}),
+        (thread.ThreadExecutor, {'pool_size': 2}),
+        (multiprocess.MultiprocessExecutor, {'pool_size': 1}),
+        (multiprocess.MultiprocessExecutor, {'pool_size': 2}),
+        (blocking.CurrentThreadBlockingExecutor, {}),
+        # (celery.CeleryExecutor, {'app': app})
     ])
-    def test_execute(self, pool_size, executor_cls):
-        self.executor = executor_cls(pool_size)
+    def test_execute(self, executor_cls, executor_kwargs):
+        self.executor = executor_cls(**executor_kwargs)
         expected_value = 'value'
         successful_task = MockTask(mock_successful_task)
         failing_task = MockTask(mock_failing_task)
@@ -48,8 +62,8 @@ class TestExecutor(object):
             assert successful_task.states == ['start', 'success']
             assert failing_task.states == ['start', 'failure']
             assert task_with_inputs.states == ['start', 'failure']
-            assert isinstance(failing_task.exception, TestException)
-            assert isinstance(task_with_inputs.exception, TestException)
+            assert isinstance(failing_task.exception, MockException)
+            assert isinstance(task_with_inputs.exception, MockException)
             assert task_with_inputs.exception.message == expected_value
         assertion()
 
@@ -71,14 +85,19 @@ def mock_successful_task():
 
 
 def mock_failing_task():
-    raise TestException
+    raise MockException
 
 
 def mock_task_with_input(input):
-    raise TestException(input)
+    raise MockException(input)
 
+if app:
+    mock_successful_task = app.task(mock_successful_task)
+    mock_failing_task = app.task(mock_failing_task)
+    mock_task_with_input = app.task(mock_task_with_input)
 
-class TestException(Exception):
+
+class MockException(Exception):
     pass
 
 

