airflow-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From criccom...@apache.org
Subject [12/36] incubator-airflow git commit: Merge pull request #2195 from bolkedebruin/AIRFLOW-719
Date Tue, 09 May 2017 17:35:41 GMT
Merge pull request #2195 from bolkedebruin/AIRFLOW-719

(cherry picked from commit 4a6bef69d1817a5fc3ddd6ffe14c2578eaa49cf0)
Signed-off-by: Bolke de Bruin <bolke@xs4all.nl>


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/dff6d21b
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/dff6d21b
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/dff6d21b

Branch: refs/heads/v1-8-stable
Commit: dff6d21bfd9a2585ca484fc8fd56aa100f640908
Parents: 9070a82
Author: Bolke de Bruin <bolke@xs4all.nl>
Authored: Tue Apr 4 17:04:12 2017 +0200
Committer: Bolke de Bruin <bolke@xs4all.nl>
Committed: Wed Apr 5 19:16:22 2017 +0200

----------------------------------------------------------------------
 airflow/operators/latest_only_operator.py     |  30 ++-
 airflow/operators/python_operator.py          |  82 +++++--
 airflow/ti_deps/deps/trigger_rule_dep.py      |   6 +-
 scripts/ci/requirements.txt                   |   1 +
 tests/dags/test_dagrun_short_circuit_false.py |  38 ----
 tests/models.py                               |  77 +++----
 tests/operators/__init__.py                   |   2 +
 tests/operators/latest_only_operator.py       |  12 +-
 tests/operators/python_operator.py            | 244 +++++++++++++++++++++
 9 files changed, 384 insertions(+), 108 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/airflow/operators/latest_only_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/latest_only_operator.py b/airflow/operators/latest_only_operator.py
index 8b4e614..9d5defb 100644
--- a/airflow/operators/latest_only_operator.py
+++ b/airflow/operators/latest_only_operator.py
@@ -34,7 +34,7 @@ class LatestOnlyOperator(BaseOperator):
     def execute(self, context):
         # If the DAG Run is externally triggered, then return without
         # skipping downstream tasks
-        if context['dag_run'].external_trigger:
+        if context['dag_run'] and context['dag_run'].external_trigger:
             logging.info("""Externally triggered DAG_Run:
                          allowing execution to proceed.""")
             return
@@ -46,17 +46,39 @@ class LatestOnlyOperator(BaseOperator):
         logging.info(
             'Checking latest only with left_window: %s right_window: %s '
             'now: %s', left_window, right_window, now)
+
         if not left_window < now <= right_window:
             logging.info('Not latest execution, skipping downstream.')
             session = settings.Session()
-            for task in context['task'].downstream_list:
-                ti = TaskInstance(
-                    task, execution_date=context['ti'].execution_date)
+
+            TI = TaskInstance
+            tis = session.query(TI).filter(
+                TI.execution_date == context['ti'].execution_date,
+                TI.task_id.in_(context['task'].downstream_task_ids)
+            ).with_for_update().all()
+
+            for ti in tis:
                 logging.info('Skipping task: %s', ti.task_id)
                 ti.state = State.SKIPPED
                 ti.start_date = now
                 ti.end_date = now
                 session.merge(ti)
+
+            # this is defensive against dag runs that are not complete
+            for task in context['task'].downstream_list:
+                if task.task_id in tis:
+                    continue
+
+                logging.warning("Task {} was not part of a dag run. "
+                                "This should not happen."
+                                .format(task))
+                now = datetime.datetime.now()
+                ti = TaskInstance(task, execution_date=context['ti'].execution_date)
+                ti.state = State.SKIPPED
+                ti.start_date = now
+                ti.end_date = now
+                session.merge(ti)
+
             session.commit()
             session.close()
             logging.info('Done.')

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/airflow/operators/python_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py
index b5f6386..114bc7e 100644
--- a/airflow/operators/python_operator.py
+++ b/airflow/operators/python_operator.py
@@ -106,14 +106,36 @@ class BranchPythonOperator(PythonOperator):
         logging.info("Following branch " + branch)
         logging.info("Marking other directly downstream tasks as skipped")
         session = settings.Session()
+
+        TI = TaskInstance
+        tis = session.query(TI).filter(
+            TI.execution_date == context['ti'].execution_date,
+            TI.task_id.in_(context['task'].downstream_task_ids),
+            TI.task_id != branch,
+        ).with_for_update().all()
+
+        for ti in tis:
+            logging.info('Skipping task: %s', ti.task_id)
+            ti.state = State.SKIPPED
+            ti.start_date = datetime.now()
+            ti.end_date = datetime.now()
+
+        # this is defensive against dag runs that are not complete
         for task in context['task'].downstream_list:
-            if task.task_id != branch:
-                ti = TaskInstance(
-                    task, execution_date=context['ti'].execution_date)
-                ti.state = State.SKIPPED
-                ti.start_date = datetime.now()
-                ti.end_date = datetime.now()
-                session.merge(ti)
+            if task.task_id in tis:
+                continue
+
+            if task.task_id == branch:
+                continue
+
+            logging.warning("Task {} was not part of a dag run. This should not happen."
+                            .format(task))
+            ti = TaskInstance(task, execution_date=context['ti'].execution_date)
+            ti.state = State.SKIPPED
+            ti.start_date = datetime.now()
+            ti.end_date = datetime.now()
+            session.merge(ti)
+
         session.commit()
         session.close()
         logging.info("Done.")
@@ -134,19 +156,39 @@ class ShortCircuitOperator(PythonOperator):
     def execute(self, context):
         condition = super(ShortCircuitOperator, self).execute(context)
         logging.info("Condition result is {}".format(condition))
+
         if condition:
             logging.info('Proceeding with downstream tasks...')
             return
-        else:
-            logging.info('Skipping downstream tasks...')
-            session = settings.Session()
-            for task in context['task'].downstream_list:
-                ti = TaskInstance(
-                    task, execution_date=context['ti'].execution_date)
-                ti.state = State.SKIPPED
-                ti.start_date = datetime.now()
-                ti.end_date = datetime.now()
-                session.merge(ti)
-            session.commit()
-            session.close()
-            logging.info("Done.")
+
+        logging.info('Skipping downstream tasks...')
+        session = settings.Session()
+
+        TI = TaskInstance
+        tis = session.query(TI).filter(
+            TI.execution_date == context['ti'].execution_date,
+            TI.task_id.in_(context['task'].downstream_task_ids),
+        ).with_for_update().all()
+
+        for ti in tis:
+            logging.info('Skipping task: %s', ti.task_id)
+            ti.state = State.SKIPPED
+            ti.start_date = datetime.now()
+            ti.end_date = datetime.now()
+
+        # this is defensive against dag runs that are not complete
+        for task in context['task'].downstream_list:
+            if task.task_id in tis:
+                continue
+
+            logging.warning("Task {} was not part of a dag run. This should not happen."
+                            .format(task))
+            ti = TaskInstance(task, execution_date=context['ti'].execution_date)
+            ti.state = State.SKIPPED
+            ti.start_date = datetime.now()
+            ti.end_date = datetime.now()
+            session.merge(ti)
+
+        session.commit()
+        session.close()
+        logging.info("Done.")

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/airflow/ti_deps/deps/trigger_rule_dep.py
----------------------------------------------------------------------
diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py
index da13bba..281ed51 100644
--- a/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -135,7 +135,7 @@ class TriggerRuleDep(BaseTIDep):
             if tr == TR.ALL_SUCCESS:
                 if upstream_failed or failed:
                     ti.set_state(State.UPSTREAM_FAILED, session)
-                elif skipped == upstream:
+                elif skipped:
                     ti.set_state(State.SKIPPED, session)
             elif tr == TR.ALL_FAILED:
                 if successes or skipped:
@@ -148,7 +148,7 @@ class TriggerRuleDep(BaseTIDep):
                     ti.set_state(State.SKIPPED, session)
 
         if tr == TR.ONE_SUCCESS:
-            if successes <= 0 and skipped <= 0:
+            if successes <= 0:
                 yield self._failing_status(
                     reason="Task's trigger rule '{0}' requires one upstream "
                     "task success, but none were found. "
@@ -162,7 +162,7 @@ class TriggerRuleDep(BaseTIDep):
                     "upstream_tasks_state={1}, upstream_task_ids={2}"
                     .format(tr, upstream_tasks_state, task.upstream_task_ids))
         elif tr == TR.ALL_SUCCESS:
-            num_failures = upstream - (successes + skipped)
+            num_failures = upstream - successes
             if num_failures > 0:
                 yield self._failing_status(
                     reason="Task's trigger rule '{0}' requires all upstream "

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/scripts/ci/requirements.txt
----------------------------------------------------------------------
diff --git a/scripts/ci/requirements.txt b/scripts/ci/requirements.txt
index a5786f6..9a2bce2 100644
--- a/scripts/ci/requirements.txt
+++ b/scripts/ci/requirements.txt
@@ -20,6 +20,7 @@ flask-cache
 flask-login==0.2.11
 Flask-WTF
 flower
+freezegun
 future
 gunicorn
 hdfs

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/tests/dags/test_dagrun_short_circuit_false.py
----------------------------------------------------------------------
diff --git a/tests/dags/test_dagrun_short_circuit_false.py b/tests/dags/test_dagrun_short_circuit_false.py
deleted file mode 100644
index 805ab67..0000000
--- a/tests/dags/test_dagrun_short_circuit_false.py
+++ /dev/null
@@ -1,38 +0,0 @@
-# -*- coding: utf-8 -*-
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from datetime import datetime
-
-from airflow.models import DAG
-from airflow.operators.python_operator import ShortCircuitOperator
-from airflow.operators.dummy_operator import DummyOperator
-
-
-# DAG that has its short circuit op fail and skip multiple downstream tasks
-dag = DAG(
-    dag_id='test_dagrun_short_circuit_false',
-    start_date=datetime(2017, 1, 1)
-)
-dag_task1 = ShortCircuitOperator(
-    task_id='test_short_circuit_false',
-    dag=dag,
-    python_callable=lambda: False)
-dag_task2 = DummyOperator(
-    task_id='test_state_skipped1',
-    dag=dag)
-dag_task3 = DummyOperator(
-    task_id='test_state_skipped2',
-    dag=dag)
-dag_task1.set_downstream(dag_task2)
-dag_task2.set_downstream(dag_task3)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/tests/models.py
----------------------------------------------------------------------
diff --git a/tests/models.py b/tests/models.py
index 83183f8..9478088 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -31,11 +31,12 @@ from airflow.models import DagModel
 from airflow.operators.dummy_operator import DummyOperator
 from airflow.operators.bash_operator import BashOperator
 from airflow.operators.python_operator import PythonOperator
+from airflow.operators.python_operator import ShortCircuitOperator
 from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
 from airflow.utils.state import State
 from mock import patch
 from nose_parameterized import parameterized
-from tests.core import TEST_DAG_FOLDER
+
 
 DEFAULT_DATE = datetime.datetime(2016, 1, 1)
 TEST_DAGS_FOLDER = os.path.join(
@@ -235,17 +236,13 @@ class DagTest(unittest.TestCase):
 
 class DagRunTest(unittest.TestCase):
 
-    def setUp(self):
-        self.dagbag = models.DagBag(dag_folder=TEST_DAG_FOLDER)
-
-    def create_dag_run(self, dag_id, state=State.RUNNING, task_states=None):
+    def create_dag_run(self, dag, state=State.RUNNING, task_states=None):
         now = datetime.datetime.now()
-        dag = self.dagbag.get_dag(dag_id)
         dag_run = dag.create_dagrun(
             run_id='manual__' + now.isoformat(),
             execution_date=now,
             start_date=now,
-            state=State.RUNNING,
+            state=state,
             external_trigger=False,
         )
 
@@ -298,33 +295,34 @@ class DagRunTest(unittest.TestCase):
         self.assertEqual(0, len(models.DagRun.find(dag_id=dag_id2, external_trigger=True)))
         self.assertEqual(1, len(models.DagRun.find(dag_id=dag_id2, external_trigger=False)))
 
-    def test_dagrun_running_when_upstream_skipped(self):
-        """
-        Tests that a DAG run is not failed when an upstream task is skipped
-        """
-        initial_task_states = {
-            'test_short_circuit_false': State.SUCCESS,
-            'test_state_skipped1': State.SKIPPED,
-            'test_state_skipped2': State.NONE,
-        }
-        # dags/test_dagrun_short_circuit_false.py
-        dag_run = self.create_dag_run('test_dagrun_short_circuit_false',
-                                      state=State.RUNNING,
-                                      task_states=initial_task_states)
-        updated_dag_state = dag_run.update_state()
-        self.assertEqual(State.RUNNING, updated_dag_state)
-
     def test_dagrun_success_when_all_skipped(self):
         """
         Tests that a DAG run succeeds when all tasks are skipped
         """
+        dag = DAG(
+            dag_id='test_dagrun_success_when_all_skipped',
+            start_date=datetime.datetime(2017, 1, 1)
+        )
+        dag_task1 = ShortCircuitOperator(
+            task_id='test_short_circuit_false',
+            dag=dag,
+            python_callable=lambda: False)
+        dag_task2 = DummyOperator(
+            task_id='test_state_skipped1',
+            dag=dag)
+        dag_task3 = DummyOperator(
+            task_id='test_state_skipped2',
+            dag=dag)
+        dag_task1.set_downstream(dag_task2)
+        dag_task2.set_downstream(dag_task3)
+
         initial_task_states = {
             'test_short_circuit_false': State.SUCCESS,
             'test_state_skipped1': State.SKIPPED,
             'test_state_skipped2': State.SKIPPED,
         }
-        # dags/test_dagrun_short_circuit_false.py
-        dag_run = self.create_dag_run('test_dagrun_short_circuit_false',
+
+        dag_run = self.create_dag_run(dag=dag,
                                       state=State.RUNNING,
                                       task_states=initial_task_states)
         updated_dag_state = dag_run.update_state()
@@ -385,10 +383,17 @@ class DagRunTest(unittest.TestCase):
         """
         Make sure that a proper value is returned when a dagrun has no task instances
         """
+        dag = DAG(
+            dag_id='test_get_task_instance_on_empty_dagrun',
+            start_date=datetime.datetime(2017, 1, 1)
+        )
+        dag_task1 = ShortCircuitOperator(
+            task_id='test_short_circuit_false',
+            dag=dag,
+            python_callable=lambda: False)
+
         session = settings.Session()
 
-        # Any dag will work for this
-        dag = self.dagbag.get_dag('test_dagrun_short_circuit_false')
         now = datetime.datetime.now()
 
         # Don't use create_dagrun since it will create the task instances too which we
@@ -784,7 +789,7 @@ class TaskInstanceTest(unittest.TestCase):
         self.assertEqual(dt, ti.end_date+max_delay)
 
     def test_depends_on_past(self):
-        dagbag = models.DagBag(dag_folder=TEST_DAG_FOLDER)
+        dagbag = models.DagBag()
         dag = dagbag.get_dag('test_depends_on_past')
         dag.clear()
         task = dag.tasks[0]
@@ -813,11 +818,10 @@ class TaskInstanceTest(unittest.TestCase):
         #
         # Tests for all_success
         #
-        ['all_success', 5, 0, 0, 0, 5, True, None, True],
-        ['all_success', 2, 0, 0, 0, 2, True, None, False],
-        ['all_success', 2, 0, 1, 0, 3, True, ST.UPSTREAM_FAILED, False],
-        ['all_success', 2, 1, 0, 0, 3, True, None, False],
-        ['all_success', 0, 5, 0, 0, 5, True, ST.SKIPPED, True],
+        ['all_success', 5, 0, 0, 0, 0, True, None, True],
+        ['all_success', 2, 0, 0, 0, 0, True, None, False],
+        ['all_success', 2, 0, 1, 0, 0, True, ST.UPSTREAM_FAILED, False],
+        ['all_success', 2, 1, 0, 0, 0, True, ST.SKIPPED, False],
         #
         # Tests for one_success
         #
@@ -825,7 +829,6 @@ class TaskInstanceTest(unittest.TestCase):
         ['one_success', 2, 0, 0, 0, 2, True, None, True],
         ['one_success', 2, 0, 1, 0, 3, True, None, True],
         ['one_success', 2, 1, 0, 0, 3, True, None, True],
-        ['one_success', 0, 2, 0, 0, 2, True, None, True],
         #
         # Tests for all_failed
         #
@@ -837,9 +840,9 @@ class TaskInstanceTest(unittest.TestCase):
         #
         # Tests for one_failed
         #
-        ['one_failed', 5, 0, 0, 0, 5, True, ST.SKIPPED, False],
-        ['one_failed', 2, 0, 0, 0, 2, True, None, False],
-        ['one_failed', 2, 0, 1, 0, 2, True, None, True],
+        ['one_failed', 5, 0, 0, 0, 0, True, None, False],
+        ['one_failed', 2, 0, 0, 0, 0, True, None, False],
+        ['one_failed', 2, 0, 1, 0, 0, True, None, True],
         ['one_failed', 2, 1, 0, 0, 3, True, None, False],
         ['one_failed', 2, 3, 0, 0, 5, True, ST.SKIPPED, False],
         #

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/tests/operators/__init__.py
----------------------------------------------------------------------
diff --git a/tests/operators/__init__.py b/tests/operators/__init__.py
index 1fb0e5e..aeb243c 100644
--- a/tests/operators/__init__.py
+++ b/tests/operators/__init__.py
@@ -18,3 +18,5 @@ from .operators import *
 from .sensors import *
 from .hive_operator import *
 from .s3_to_hive_operator import *
+from .python_operator import *
+from .latest_only_operator import *

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/tests/operators/latest_only_operator.py
----------------------------------------------------------------------
diff --git a/tests/operators/latest_only_operator.py b/tests/operators/latest_only_operator.py
index 37aec38..9137491 100644
--- a/tests/operators/latest_only_operator.py
+++ b/tests/operators/latest_only_operator.py
@@ -77,17 +77,17 @@ class LatestOnlyOperatorTest(unittest.TestCase):
         latest_instances = get_task_instances('latest')
         exec_date_to_latest_state = {
             ti.execution_date: ti.state for ti in latest_instances}
-        assert exec_date_to_latest_state == {
+        self.assertEqual({
             datetime.datetime(2016, 1, 1): 'success',
             datetime.datetime(2016, 1, 1, 12): 'success',
-            datetime.datetime(2016, 1, 2): 'success',
-        }
+            datetime.datetime(2016, 1, 2): 'success', },
+            exec_date_to_latest_state)
 
         downstream_instances = get_task_instances('downstream')
         exec_date_to_downstream_state = {
             ti.execution_date: ti.state for ti in downstream_instances}
-        assert exec_date_to_downstream_state == {
+        self.assertEqual({
             datetime.datetime(2016, 1, 1): 'skipped',
             datetime.datetime(2016, 1, 1, 12): 'skipped',
-            datetime.datetime(2016, 1, 2): 'success',
-        }
+            datetime.datetime(2016, 1, 2): 'success',},
+            exec_date_to_downstream_state)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dff6d21b/tests/operators/python_operator.py
----------------------------------------------------------------------
diff --git a/tests/operators/python_operator.py b/tests/operators/python_operator.py
new file mode 100644
index 0000000..3aa8b6c
--- /dev/null
+++ b/tests/operators/python_operator.py
@@ -0,0 +1,244 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function, unicode_literals
+
+import datetime
+import unittest
+
+from airflow import configuration, DAG
+from airflow.models import TaskInstance as TI
+from airflow.operators.python_operator import PythonOperator, BranchPythonOperator
+from airflow.operators.python_operator import ShortCircuitOperator
+from airflow.operators.dummy_operator import DummyOperator
+from airflow.settings import Session
+from airflow.utils.state import State
+
+from airflow.exceptions import AirflowException
+
+DEFAULT_DATE = datetime.datetime(2016, 1, 1)
+END_DATE = datetime.datetime(2016, 1, 2)
+INTERVAL = datetime.timedelta(hours=12)
+FROZEN_NOW = datetime.datetime(2016, 1, 2, 12, 1, 1)
+
+
+class PythonOperatorTest(unittest.TestCase):
+
+    def setUp(self):
+        super(PythonOperatorTest, self).setUp()
+        configuration.load_test_config()
+        self.dag = DAG(
+            'test_dag',
+            default_args={
+                'owner': 'airflow',
+                'start_date': DEFAULT_DATE},
+            schedule_interval=INTERVAL)
+        self.addCleanup(self.dag.clear)
+        self.clear_run()
+        self.addCleanup(self.clear_run)
+
+    def do_run(self):
+        self.run = True
+
+    def clear_run(self):
+        self.run = False
+
+    def is_run(self):
+        return self.run
+
+    def test_python_operator_run(self):
+        """Tests that the python callable is invoked on task run."""
+        task = PythonOperator(
+            python_callable=self.do_run,
+            task_id='python_operator',
+            dag=self.dag)
+        self.assertFalse(self.is_run())
+        task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        self.assertTrue(self.is_run())
+
+    def test_python_operator_python_callable_is_callable(self):
+        """Tests that PythonOperator will only instantiate if
+        the python_callable argument is callable."""
+        not_callable = {}
+        with self.assertRaises(AirflowException):
+            PythonOperator(
+                python_callable=not_callable,
+                task_id='python_operator',
+                dag=self.dag)
+        not_callable = None
+        with self.assertRaises(AirflowException):
+            PythonOperator(
+                python_callable=not_callable,
+                task_id='python_operator',
+                dag=self.dag)
+
+
+class BranchOperatorTest(unittest.TestCase):
+    def setUp(self):
+        self.dag = DAG('branch_operator_test',
+                       default_args={
+                           'owner': 'airflow',
+                           'start_date': DEFAULT_DATE},
+                       schedule_interval=INTERVAL)
+        self.branch_op = BranchPythonOperator(task_id='make_choice',
+                                              dag=self.dag,
+                                              python_callable=lambda: 'branch_1')
+
+        self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
+        self.branch_1.set_upstream(self.branch_op)
+        self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag)
+        self.branch_2.set_upstream(self.branch_op)
+        self.dag.clear()
+
+    def test_without_dag_run(self):
+        """This checks the defensive against non existent tasks in a dag run"""
+        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        session = Session()
+        tis = session.query(TI).filter(
+            TI.dag_id == self.dag.dag_id,
+            TI.execution_date == DEFAULT_DATE
+        )
+        session.close()
+
+        for ti in tis:
+            if ti.task_id == 'make_choice':
+                self.assertEquals(ti.state, State.SUCCESS)
+            elif ti.task_id == 'branch_1':
+                # should not exist
+                raise
+            elif ti.task_id == 'branch_2':
+                self.assertEquals(ti.state, State.SKIPPED)
+            else:
+                raise
+
+    def test_with_dag_run(self):
+        dr = self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=datetime.datetime.now(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING
+        )
+
+        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        tis = dr.get_task_instances()
+        for ti in tis:
+            if ti.task_id == 'make_choice':
+                self.assertEquals(ti.state, State.SUCCESS)
+            elif ti.task_id == 'branch_1':
+                self.assertEquals(ti.state, State.NONE)
+            elif ti.task_id == 'branch_2':
+                self.assertEquals(ti.state, State.SKIPPED)
+            else:
+                raise
+
+
+class ShortCircuitOperatorTest(unittest.TestCase):
+    def setUp(self):
+        self.dag = DAG('shortcircuit_operator_test',
+                       default_args={
+                           'owner': 'airflow',
+                           'start_date': DEFAULT_DATE},
+                       schedule_interval=INTERVAL)
+        self.short_op = ShortCircuitOperator(task_id='make_choice',
+                                             dag=self.dag,
+                                             python_callable=lambda: self.value)
+
+        self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
+        self.branch_1.set_upstream(self.short_op)
+        self.upstream = DummyOperator(task_id='upstream', dag=self.dag)
+        self.upstream.set_downstream(self.short_op)
+        self.dag.clear()
+
+        self.value = True
+
+    def test_without_dag_run(self):
+        """This checks the defensive against non existent tasks in a dag run"""
+        self.value = False
+        self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        session = Session()
+        tis = session.query(TI).filter(
+            TI.dag_id == self.dag.dag_id,
+            TI.execution_date == DEFAULT_DATE
+        )
+
+        for ti in tis:
+            if ti.task_id == 'make_choice':
+                self.assertEquals(ti.state, State.SUCCESS)
+            elif ti.task_id == 'upstream':
+                # should not exist
+                raise
+            elif ti.task_id == 'branch_1':
+                self.assertEquals(ti.state, State.SKIPPED)
+            else:
+                raise
+
+        self.value = True
+        self.dag.clear()
+
+        self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        for ti in tis:
+            if ti.task_id == 'make_choice':
+                self.assertEquals(ti.state, State.SUCCESS)
+            elif ti.task_id == 'upstream':
+                # should not exist
+                raise
+            elif ti.task_id == 'branch_1':
+                self.assertEquals(ti.state, State.NONE)
+            else:
+                raise
+
+        session.close()
+
+    def test_with_dag_run(self):
+        self.value = False
+        dr = self.dag.create_dagrun(
+            run_id="manual__",
+            start_date=datetime.datetime.now(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING
+        )
+
+        self.upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        tis = dr.get_task_instances()
+        for ti in tis:
+            if ti.task_id == 'make_choice':
+                self.assertEquals(ti.state, State.SUCCESS)
+            elif ti.task_id == 'upstream':
+                self.assertEquals(ti.state, State.SUCCESS)
+            elif ti.task_id == 'branch_1':
+                self.assertEquals(ti.state, State.SKIPPED)
+            else:
+                raise
+
+        self.value = True
+        self.dag.clear()
+
+        self.upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+        self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+        tis = dr.get_task_instances()
+        for ti in tis:
+            if ti.task_id == 'make_choice':
+                self.assertEquals(ti.state, State.SUCCESS)
+            elif ti.task_id == 'upstream':
+                self.assertEquals(ti.state, State.SUCCESS)
+            elif ti.task_id == 'branch_1':
+                self.assertEquals(ti.state, State.NONE)
+            else:
+                raise


Mime
View raw message