airflow-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From bo...@apache.org
Subject incubator-airflow git commit: [AIRFLOW-1296] Propagate SKIPPED to all downstream tasks
Date Wed, 21 Jun 2017 08:12:18 GMT
Repository: incubator-airflow
Updated Branches:
  refs/heads/master ea240cd1d -> a45e2d188


[AIRFLOW-1296] Propagate SKIPPED to all downstream tasks

The ShortCircuitOperator and LatestOnlyOperator
did not mark all downstream tasks as skipped,
but only their direct downstream tasks.

Closes #2365 from bolkedebruin/AIRFLOW-719-3


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/a45e2d18
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/a45e2d18
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/a45e2d18

Branch: refs/heads/master
Commit: a45e2d1888ffb19dab8401e07b10724090bf20f0
Parents: ea240cd
Author: Bolke de Bruin <bolke@xs4all.nl>
Authored: Wed Jun 21 10:12:09 2017 +0200
Committer: Bolke de Bruin <bolke@xs4all.nl>
Committed: Wed Jun 21 10:12:09 2017 +0200

----------------------------------------------------------------------
 airflow/models.py                         | 41 ++++++++++++
 airflow/operators/latest_only_operator.py | 43 +++----------
 airflow/operators/python_operator.py      | 88 ++++++--------------------
 tests/operators/latest_only_operator.py   | 82 ++++++++++++++++++++++++
 tests/operators/python_operator.py        | 16 +++--
 5 files changed, 161 insertions(+), 109 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a45e2d18/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index c628958..2c433ad 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -1802,6 +1802,47 @@ class Log(Base):
         self.owner = owner or task_owner
 
 
+class SkipMixin(object):
+    def skip(self, dag_run, execution_date, tasks):
+        """
+        Sets tasks instances to skipped from the same dag run.
+        :param dag_run: the DagRun for which to set the tasks to skipped
+        :param execution_date: execution_date
+        :param tasks: tasks to skip (not task_ids)
+        """
+        if not tasks:
+            return
+
+        task_ids = [d.task_id for d in tasks]
+        now = datetime.now()
+        session = settings.Session()
+
+        if dag_run:
+            session.query(TaskInstance).filter(
+                TaskInstance.dag_id == dag_run.dag_id,
+                TaskInstance.execution_date == dag_run.execution_date,
+                TaskInstance.task_id.in_(task_ids)
+            ).update({TaskInstance.state : State.SKIPPED,
+                      TaskInstance.start_date: now,
+                      TaskInstance.end_date: now},
+                     synchronize_session=False)
+            session.commit()
+        else:
+            assert execution_date is not None, "Execution date is None and no dag run"
+
+            logging.warning("No DAG RUN present this should not happen")
+            # this is defensive against dag runs that are not complete
+            for task in tasks:
+                ti = TaskInstance(task, execution_date=execution_date)
+                ti.state = State.SKIPPED
+                ti.start_date = now
+                ti.end_date = now
+                session.merge(ti)
+
+            session.commit()
+        session.close()
+
+
 @functools.total_ordering
 class BaseOperator(object):
     """

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a45e2d18/airflow/operators/latest_only_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/latest_only_operator.py b/airflow/operators/latest_only_operator.py
index f1d8085..909a211 100644
--- a/airflow/operators/latest_only_operator.py
+++ b/airflow/operators/latest_only_operator.py
@@ -15,12 +15,10 @@
 import datetime
 import logging
 
-from airflow.models import BaseOperator, TaskInstance
-from airflow.utils.state import State
-from airflow import settings
+from airflow.models import BaseOperator, SkipMixin
 
 
-class LatestOnlyOperator(BaseOperator):
+class LatestOnlyOperator(BaseOperator, SkipMixin):
     """
     Allows a workflow to skip tasks that are not running during the most
     recent schedule interval.
@@ -49,39 +47,14 @@ class LatestOnlyOperator(BaseOperator):
 
         if not left_window < now <= right_window:
             logging.info('Not latest execution, skipping downstream.')
-            downstream_task_ids = context['task'].downstream_task_ids
-            if downstream_task_ids:
-                session = settings.Session()
-                TI = TaskInstance
-                tis = session.query(TI).filter(
-                    TI.execution_date == context['ti'].execution_date,
-                    TI.task_id.in_(downstream_task_ids)
-                ).with_for_update().all()
 
-                for ti in tis:
-                    logging.info('Skipping task: %s', ti.task_id)
-                    ti.state = State.SKIPPED
-                    ti.start_date = now
-                    ti.end_date = now
-                    session.merge(ti)
+            downstream_tasks = context['task'].get_flat_relatives(upstream=False)
+            logging.debug("Downstream task_ids {}".format(downstream_tasks))
 
-                # this is defensive against dag runs that are not complete
-                for task in context['task'].downstream_list:
-                    if task.task_id in tis:
-                        continue
-
-                    logging.warning("Task {} was not part of a dag run. "
-                                    "This should not happen."
-                                    .format(task))
-                    now = datetime.datetime.now()
-                    ti = TaskInstance(task, execution_date=context['ti'].execution_date)
-                    ti.state = State.SKIPPED
-                    ti.start_date = now
-                    ti.end_date = now
-                    session.merge(ti)
-
-                session.commit()
-                session.close()
+            if downstream_tasks:
+                self.skip(context['dag_run'],
+                          context['ti'].execution_date,
+                          downstream_tasks)
 
             logging.info('Done.')
         else:

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a45e2d18/airflow/operators/python_operator.py
----------------------------------------------------------------------
diff --git a/airflow/operators/python_operator.py b/airflow/operators/python_operator.py
index cf240f2..bef9bb0 100644
--- a/airflow/operators/python_operator.py
+++ b/airflow/operators/python_operator.py
@@ -13,14 +13,11 @@
 # limitations under the License.
 
 from builtins import str
-from datetime import datetime
 import logging
 
 from airflow.exceptions import AirflowException
-from airflow.models import BaseOperator, TaskInstance
-from airflow.utils.state import State
+from airflow.models import BaseOperator, SkipMixin
 from airflow.utils.decorators import apply_defaults
-from airflow import settings
 
 
 class PythonOperator(BaseOperator):
@@ -85,7 +82,7 @@ class PythonOperator(BaseOperator):
         return return_value
 
 
-class BranchPythonOperator(PythonOperator):
+class BranchPythonOperator(PythonOperator, SkipMixin):
     """
     Allows a workflow to "branch" or follow a single path following the
     execution of this task.
@@ -106,45 +103,20 @@ class BranchPythonOperator(PythonOperator):
     """
     def execute(self, context):
         branch = super(BranchPythonOperator, self).execute(context)
-        logging.info("Following branch " + branch)
+        logging.info("Following branch {}".format(branch))
         logging.info("Marking other directly downstream tasks as skipped")
-        session = settings.Session()
-
-        TI = TaskInstance
-        tis = session.query(TI).filter(
-            TI.execution_date == context['ti'].execution_date,
-            TI.task_id.in_(context['task'].downstream_task_ids),
-            TI.task_id != branch,
-        ).with_for_update().all()
-
-        for ti in tis:
-            logging.info('Skipping task: %s', ti.task_id)
-            ti.state = State.SKIPPED
-            ti.start_date = datetime.now()
-            ti.end_date = datetime.now()
-
-        # this is defensive against dag runs that are not complete
-        for task in context['task'].downstream_list:
-            if task.task_id in tis:
-                continue
-
-            if task.task_id == branch:
-                continue
-
-            logging.warning("Task {} was not part of a dag run. This should not happen."
-                            .format(task))
-            ti = TaskInstance(task, execution_date=context['ti'].execution_date)
-            ti.state = State.SKIPPED
-            ti.start_date = datetime.now()
-            ti.end_date = datetime.now()
-            session.merge(ti)
-
-        session.commit()
-        session.close()
+
+        downstream_tasks = context['task'].downstream_list
+        logging.debug("Downstream task_ids {}".format(downstream_tasks))
+
+        skip_tasks = [t for t in downstream_tasks if t.task_id != branch]
+        if downstream_tasks:
+            self.skip(context['dag_run'], context['ti'].execution_date, skip_tasks)
+
         logging.info("Done.")
 
 
-class ShortCircuitOperator(PythonOperator):
+class ShortCircuitOperator(PythonOperator, SkipMixin):
     """
     Allows a workflow to continue only if a condition is met. Otherwise, the
     workflow "short-circuits" and downstream tasks are skipped.
@@ -165,33 +137,11 @@ class ShortCircuitOperator(PythonOperator):
             return
 
         logging.info('Skipping downstream tasks...')
-        session = settings.Session()
-
-        TI = TaskInstance
-        tis = session.query(TI).filter(
-            TI.execution_date == context['ti'].execution_date,
-            TI.task_id.in_(context['task'].downstream_task_ids),
-        ).with_for_update().all()
-
-        for ti in tis:
-            logging.info('Skipping task: %s', ti.task_id)
-            ti.state = State.SKIPPED
-            ti.start_date = datetime.now()
-            ti.end_date = datetime.now()
-
-        # this is defensive against dag runs that are not complete
-        for task in context['task'].downstream_list:
-            if task.task_id in tis:
-                continue
-
-            logging.warning("Task {} was not part of a dag run. This should not happen."
-                            .format(task))
-            ti = TaskInstance(task, execution_date=context['ti'].execution_date)
-            ti.state = State.SKIPPED
-            ti.start_date = datetime.now()
-            ti.end_date = datetime.now()
-            session.merge(ti)
-
-        session.commit()
-        session.close()
+
+        downstream_tasks = context['task'].get_flat_relatives(upstream=False)
+        logging.debug("Downstream task_ids {}".format(downstream_tasks))
+
+        if downstream_tasks:
+            self.skip(context['dag_run'], context['ti'].execution_date, downstream_tasks)
+
         logging.info("Done.")

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a45e2d18/tests/operators/latest_only_operator.py
----------------------------------------------------------------------
diff --git a/tests/operators/latest_only_operator.py b/tests/operators/latest_only_operator.py
index 9137491..225d24f 100644
--- a/tests/operators/latest_only_operator.py
+++ b/tests/operators/latest_only_operator.py
@@ -23,6 +23,7 @@ from airflow.jobs import BackfillJob
 from airflow.models import TaskInstance
 from airflow.operators.latest_only_operator import LatestOnlyOperator
 from airflow.operators.dummy_operator import DummyOperator
+from airflow.utils.state import State
 from freezegun import freeze_time
 
 DEFAULT_DATE = datetime.datetime(2016, 1, 1)
@@ -69,10 +70,82 @@ class LatestOnlyOperatorTest(unittest.TestCase):
         downstream_task = DummyOperator(
             task_id='downstream',
             dag=self.dag)
+        downstream_task2 = DummyOperator(
+            task_id='downstream_2',
+            dag=self.dag)
+
+        downstream_task.set_upstream(latest_task)
+        downstream_task2.set_upstream(downstream_task)
+
+        latest_task.run(start_date=DEFAULT_DATE, end_date=END_DATE)
+        downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE)
+        downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE)
+
+        latest_instances = get_task_instances('latest')
+        exec_date_to_latest_state = {
+            ti.execution_date: ti.state for ti in latest_instances}
+        self.assertEqual({
+            datetime.datetime(2016, 1, 1): 'success',
+            datetime.datetime(2016, 1, 1, 12): 'success',
+            datetime.datetime(2016, 1, 2): 'success', },
+            exec_date_to_latest_state)
+
+        downstream_instances = get_task_instances('downstream')
+        exec_date_to_downstream_state = {
+            ti.execution_date: ti.state for ti in downstream_instances}
+        self.assertEqual({
+            datetime.datetime(2016, 1, 1): 'skipped',
+            datetime.datetime(2016, 1, 1, 12): 'skipped',
+            datetime.datetime(2016, 1, 2): 'success',},
+            exec_date_to_downstream_state)
+
+        downstream_instances = get_task_instances('downstream_2')
+        exec_date_to_downstream_state = {
+            ti.execution_date: ti.state for ti in downstream_instances}
+        self.assertEqual({
+            datetime.datetime(2016, 1, 1): 'skipped',
+            datetime.datetime(2016, 1, 1, 12): 'skipped',
+            datetime.datetime(2016, 1, 2): 'success',},
+            exec_date_to_downstream_state)
+
+    def test_skipping_dagrun(self):
+        latest_task = LatestOnlyOperator(
+            task_id='latest',
+            dag=self.dag)
+        downstream_task = DummyOperator(
+            task_id='downstream',
+            dag=self.dag)
+        downstream_task2 = DummyOperator(
+            task_id='downstream_2',
+            dag=self.dag)
+
         downstream_task.set_upstream(latest_task)
+        downstream_task2.set_upstream(downstream_task)
+
+        dr1 = self.dag.create_dagrun(
+            run_id="manual__1",
+            start_date=datetime.datetime.now(),
+            execution_date=DEFAULT_DATE,
+            state=State.RUNNING
+        )
+
+        dr2 = self.dag.create_dagrun(
+            run_id="manual__2",
+            start_date=datetime.datetime.now(),
+            execution_date=datetime.datetime(2016, 1, 1, 12),
+            state=State.RUNNING
+        )
+
+        dr2 = self.dag.create_dagrun(
+            run_id="manual__3",
+            start_date=datetime.datetime.now(),
+            execution_date=END_DATE,
+            state=State.RUNNING
+        )
 
         latest_task.run(start_date=DEFAULT_DATE, end_date=END_DATE)
         downstream_task.run(start_date=DEFAULT_DATE, end_date=END_DATE)
+        downstream_task2.run(start_date=DEFAULT_DATE, end_date=END_DATE)
 
         latest_instances = get_task_instances('latest')
         exec_date_to_latest_state = {
@@ -91,3 +164,12 @@ class LatestOnlyOperatorTest(unittest.TestCase):
             datetime.datetime(2016, 1, 1, 12): 'skipped',
             datetime.datetime(2016, 1, 2): 'success',},
             exec_date_to_downstream_state)
+
+        downstream_instances = get_task_instances('downstream_2')
+        exec_date_to_downstream_state = {
+            ti.execution_date: ti.state for ti in downstream_instances}
+        self.assertEqual({
+            datetime.datetime(2016, 1, 1): 'skipped',
+            datetime.datetime(2016, 1, 1, 12): 'skipped',
+            datetime.datetime(2016, 1, 2): 'success',},
+            exec_date_to_downstream_state)

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/a45e2d18/tests/operators/python_operator.py
----------------------------------------------------------------------
diff --git a/tests/operators/python_operator.py b/tests/operators/python_operator.py
index 3aa8b6c..71432af 100644
--- a/tests/operators/python_operator.py
+++ b/tests/operators/python_operator.py
@@ -26,6 +26,7 @@ from airflow.settings import Session
 from airflow.utils.state import State
 
 from airflow.exceptions import AirflowException
+import logging
 
 DEFAULT_DATE = datetime.datetime(2016, 1, 1)
 END_DATE = datetime.datetime(2016, 1, 2)
@@ -158,6 +159,8 @@ class ShortCircuitOperatorTest(unittest.TestCase):
 
         self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
         self.branch_1.set_upstream(self.short_op)
+        self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag)
+        self.branch_2.set_upstream(self.branch_1)
         self.upstream = DummyOperator(task_id='upstream', dag=self.dag)
         self.upstream.set_downstream(self.short_op)
         self.dag.clear()
@@ -181,7 +184,7 @@ class ShortCircuitOperatorTest(unittest.TestCase):
             elif ti.task_id == 'upstream':
                 # should not exist
                 raise
-            elif ti.task_id == 'branch_1':
+            elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
                 self.assertEquals(ti.state, State.SKIPPED)
             else:
                 raise
@@ -196,7 +199,7 @@ class ShortCircuitOperatorTest(unittest.TestCase):
             elif ti.task_id == 'upstream':
                 # should not exist
                 raise
-            elif ti.task_id == 'branch_1':
+            elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
                 self.assertEquals(ti.state, State.NONE)
             else:
                 raise
@@ -205,6 +208,7 @@ class ShortCircuitOperatorTest(unittest.TestCase):
 
     def test_with_dag_run(self):
         self.value = False
+        logging.error("Tasks {}".format(self.dag.tasks))
         dr = self.dag.create_dagrun(
             run_id="manual__",
             start_date=datetime.datetime.now(),
@@ -216,29 +220,31 @@ class ShortCircuitOperatorTest(unittest.TestCase):
         self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
         tis = dr.get_task_instances()
+        self.assertEqual(len(tis), 4)
         for ti in tis:
             if ti.task_id == 'make_choice':
                 self.assertEquals(ti.state, State.SUCCESS)
             elif ti.task_id == 'upstream':
                 self.assertEquals(ti.state, State.SUCCESS)
-            elif ti.task_id == 'branch_1':
+            elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
                 self.assertEquals(ti.state, State.SKIPPED)
             else:
                 raise
 
         self.value = True
         self.dag.clear()
-
+        dr.verify_integrity()
         self.upstream.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
         self.short_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
 
         tis = dr.get_task_instances()
+        self.assertEqual(len(tis), 4)
         for ti in tis:
             if ti.task_id == 'make_choice':
                 self.assertEquals(ti.state, State.SUCCESS)
             elif ti.task_id == 'upstream':
                 self.assertEquals(ti.state, State.SUCCESS)
-            elif ti.task_id == 'branch_1':
+            elif ti.task_id == 'branch_1' or ti.task_id == 'branch_2':
                 self.assertEquals(ti.state, State.NONE)
             else:
                 raise


Mime
View raw message