airflow-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From bo...@apache.org
Subject incubator-airflow git commit: [AIRFLOW-492] Make sure stat updates cannot fail a task
Date Wed, 26 Apr 2017 18:39:57 GMT
Repository: incubator-airflow
Updated Branches:
  refs/heads/master 4147d6b80 -> c2472ffa1


[AIRFLOW-492] Make sure stat updates cannot fail a task

Previously, a failed commit of the statistics into the db
could also fail a task. Secondly, the UI could display
out-of-date statistics.

This patch reworks DagStat so that a failure to update the
statistics does not propagate. In addition, it makes sure
the UI always displays the latest statistics.

Closes #2254 from bolkedebruin/AIRFLOW-492


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/c2472ffa
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/c2472ffa
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/c2472ffa

Branch: refs/heads/master
Commit: c2472ffa124ffc65b8762ea583554494624dbb6a
Parents: 4147d6b
Author: Bolke de Bruin <bolke@xs4all.nl>
Authored: Wed Apr 26 20:39:48 2017 +0200
Committer: Bolke de Bruin <bolke@xs4all.nl>
Committed: Wed Apr 26 20:39:48 2017 +0200

----------------------------------------------------------------------
 airflow/jobs.py      |   4 +-
 airflow/models.py    | 135 ++++++++++++++++++++++++++++++++--------------
 airflow/www/views.py |   7 +--
 tests/core.py        |  38 +++++++------
 tests/models.py      |  53 +++++++++++++++++-
 5 files changed, 171 insertions(+), 66 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c2472ffa/airflow/jobs.py
----------------------------------------------------------------------
diff --git a/airflow/jobs.py b/airflow/jobs.py
index b5d68b0..02449c5 100644
--- a/airflow/jobs.py
+++ b/airflow/jobs.py
@@ -1176,7 +1176,7 @@ class SchedulerJob(BaseJob):
             self._process_task_instances(dag, tis_out)
             self.manage_slas(dag)
 
-        models.DagStat.clean_dirty([d.dag_id for d in dags])
+        models.DagStat.update([d.dag_id for d in dags])
 
     def _process_executor_events(self):
         """
@@ -1968,7 +1968,7 @@ class BackfillJob(BaseJob):
                     active_dag_runs.remove(run)
 
                 if run.dag.is_paused:
-                    models.DagStat.clean_dirty([run.dag_id], session=session)
+                    models.DagStat.update([run.dag_id], session=session)
 
             msg = ' | '.join([
                 "[backfill progress]",

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c2472ffa/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index e6374d4..d2e41cf 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -30,6 +30,7 @@ import functools
 import getpass
 import imp
 import importlib
+import itertools
 import inspect
 import zipfile
 import jinja2
@@ -739,6 +740,7 @@ class TaskInstance(Base):
     even while multiple schedulers may be firing task instances.
     """
 
+
     __tablename__ = "task_instance"
 
     task_id = Column(String(ID_LEN), primary_key=True)
@@ -3129,7 +3131,7 @@ class DAG(BaseDag, LoggingMixin):
         for dr in drs:
             dr.state = state
             dirty_ids.append(dr.dag_id)
-        DagStat.clean_dirty(dirty_ids, session=session)
+        DagStat.update(dirty_ids, session=session)
 
     def clear(
             self, start_date=None, end_date=None,
@@ -3423,6 +3425,9 @@ class DAG(BaseDag, LoggingMixin):
             state=state
         )
         session.add(run)
+
+        DagStat.set_dirty(dag_id=self.dag_id, session=session)
+
         session.commit()
 
         run.dag = self
@@ -3432,12 +3437,7 @@ class DAG(BaseDag, LoggingMixin):
         run.verify_integrity(session=session)
 
         run.refresh_from_db()
-        DagStat.set_dirty(self.dag_id, session=session)
 
-        # add a placeholder row into DagStat table
-        if not session.query(DagStat).filter(DagStat.dag_id == self.dag_id).first():
-            session.add(DagStat(dag_id=self.dag_id, state=state, count=0, dirty=True))
-        session.commit()
         return run
 
     @staticmethod
@@ -3848,7 +3848,7 @@ class DagStat(Base):
     count = Column(Integer, default=0)
     dirty = Column(Boolean, default=False)
 
-    def __init__(self, dag_id, state, count, dirty=False):
+    def __init__(self, dag_id, state, count=0, dirty=False):
         self.dag_id = dag_id
         self.state = state
         self.count = count
@@ -3857,48 +3857,104 @@ class DagStat(Base):
     @staticmethod
     @provide_session
     def set_dirty(dag_id, session=None):
-        for dag in session.query(DagStat).filter(DagStat.dag_id == dag_id):
-            dag.dirty = True
-        session.commit()
+        """
+        :param dag_id: the dag_id to mark dirty
+        :param session: database session
+        :return: 
+        """
+        DagStat.create(dag_id=dag_id, session=session)
+
+        try:
+            stats = session.query(DagStat).filter(
+                DagStat.dag_id == dag_id
+            ).with_for_update().all()
+
+            for stat in stats:
+                stat.dirty = True
+            session.commit()
+        except Exception as e:
+            session.rollback()
+            logging.warning("Could not update dag stats for {}".format(dag_id))
+            logging.exception(e)
 
     @staticmethod
     @provide_session
-    def clean_dirty(dag_ids, session=None):
+    def update(dag_ids=None, dirty_only=True, session=None):
         """
-        Cleans out the dirty/out-of-sync rows from dag_stats table
+        Updates the stats for dirty/out-of-sync dags
 
-        :param dag_ids: dag_ids that may be dirty
+        :param dag_ids: dag_ids to be updated
         :type dag_ids: list
+        :param dirty_only: only updated for marked dirty, defaults to True
+        :type dirty_only: bool
+        :param session: db session to use
+        :type session: Session
         """
-        # avoid querying with an empty IN clause
-        if not dag_ids:
-            return
+        if dag_ids is not None:
+            dag_ids = set(dag_ids)
 
-        dag_ids = set(dag_ids)
+        try:
+            qry = session.query(DagStat)
 
-        qry = (
-            session.query(DagStat)
-            .filter(and_(DagStat.dag_id.in_(dag_ids), DagStat.dirty == True))
-        )
+            if dag_ids is not None:
+                qry = qry.filter(DagStat.dag_id.in_(dag_ids))
+            if dirty_only:
+                qry = qry.filter(DagStat.dirty == True)
 
-        dirty_ids = {dag.dag_id for dag in qry.all()}
-        qry.delete(synchronize_session='fetch')
-        session.commit()
+            qry = qry.with_for_update().all()
 
-        # avoid querying with an empty IN clause
-        if not dirty_ids:
-            return
+            ids = set([dag_stat.dag_id for dag_stat in qry])
 
-        qry = (
-            session.query(DagRun.dag_id, DagRun.state, func.count('*'))
-            .filter(DagRun.dag_id.in_(dirty_ids))
-            .group_by(DagRun.dag_id, DagRun.state)
-        )
+            # avoid querying with an empty IN clause
+            if len(ids) == 0:
+                session.commit()
+                return
 
-        for dag_id, state, count in qry:
-            session.add(DagStat(dag_id=dag_id, state=state, count=count))
+            dagstat_states = set(itertools.product(ids, State.dag_states))
+            qry = (
+                session.query(DagRun.dag_id, DagRun.state, func.count('*'))
+                .filter(DagRun.dag_id.in_(ids))
+                .group_by(DagRun.dag_id, DagRun.state)
+            )
 
-        session.commit()
+            counts = {(dag_id, state): count for dag_id, state, count in qry}
+            for dag_id, state in dagstat_states:
+                count = 0
+                if (dag_id, state) in counts:
+                    count = counts[(dag_id, state)]
+
+                session.merge(
+                    DagStat(dag_id=dag_id, state=state, count=count, dirty=False)
+                )
+
+            session.commit()
+        except Exception as e:
+            session.rollback()
+            logging.warning("Could not update dag stat table")
+            logging.exception(e)
+
+    @staticmethod
+    @provide_session
+    def create(dag_id, session=None):
+        """
+        Creates the missing states the stats table for the dag specified 
+        
+        :param dag_id: dag id of the dag to create stats for
+        :param session: database session
+        :return: 
+        """
+        # unfortunately sqlalchemy does not know upsert
+        qry = session.query(DagStat).filter(DagStat.dag_id == dag_id).all()
+        states = [dag_stat.state for dag_stat in qry]
+        for state in State.dag_states:
+            if state not in states:
+                try:
+                    session.merge(DagStat(dag_id=dag_id, state=state))
+                    session.commit()
+                except Exception as e:
+                    session.rollback()
+                    logging.warning("Could not create stat record")
+                    logging.exception(e)
 
 
 class DagRun(Base):
@@ -3944,10 +4000,11 @@ class DagRun(Base):
     def set_state(self, state):
         if self._state != state:
             self._state = state
-            # something really weird goes on here: if you try to close the session
-            # dag runs will end up detached
-            session = settings.Session()
-            DagStat.set_dirty(self.dag_id, session=session)
+            if self.dag_id is not None:
+                # something really weird goes on here: if you try to close the session
+                # dag runs will end up detached
+                session = settings.Session()
+                DagStat.set_dirty(self.dag_id, session=session)
 
     @declared_attr
     def state(self):

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c2472ffa/airflow/www/views.py
----------------------------------------------------------------------
diff --git a/airflow/www/views.py b/airflow/www/views.py
index b0a952c..c71c5f9 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -474,6 +474,8 @@ class Airflow(BaseView):
         ds = models.DagStat
         session = Session()
 
+        ds.update()
+
         qry = (
             session.query(ds.dag_id, ds.state, ds.count)
         )
@@ -2280,9 +2282,8 @@ class DagRunModelView(ModelViewOnly):
         session.commit()
         dirty_ids = []
         for row in deleted:
-            models.DagStat.set_dirty(row.dag_id, session=session)
             dirty_ids.append(row.dag_id)
-        models.DagStat.clean_dirty(dirty_ids, session=session)
+        models.DagStat.update(dirty_ids, dirty_only=False, session=session)
         session.close()
 
     @action('set_running', "Set state to 'running'", None)
@@ -2312,7 +2313,7 @@ class DagRunModelView(ModelViewOnly):
                 else:
                     dr.end_date = datetime.now()
             session.commit()
-            models.DagStat.clean_dirty(dirty_ids, session=session)
+            models.DagStat.update(dirty_ids, session=session)
             flash(
                 "{count} dag runs were set to '{target_state}'".format(**locals()))
         except Exception as ex:

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c2472ffa/tests/core.py
----------------------------------------------------------------------
diff --git a/tests/core.py b/tests/core.py
index 5b18a18..a4757a7 100644
--- a/tests/core.py
+++ b/tests/core.py
@@ -984,50 +984,48 @@ class CoreTest(unittest.TestCase):
         session.query(models.DagStat).delete()
         session.commit()
 
-        with warnings.catch_warnings(record=True) as caught_warnings:
-            models.DagStat.clean_dirty([], session=session)
-        self.assertEqual([], caught_warnings)
+        models.DagStat.update([], session=session)
 
         run1 = self.dag_bash.create_dagrun(
             run_id="run1",
             execution_date=DEFAULT_DATE,
             state=State.RUNNING)
 
-        with warnings.catch_warnings(record=True) as caught_warnings:
-            models.DagStat.clean_dirty([self.dag_bash.dag_id], session=session)
-        self.assertEqual([], caught_warnings)
+        models.DagStat.update([self.dag_bash.dag_id], session=session)
 
         qry = session.query(models.DagStat).all()
 
-        self.assertEqual(1, len(qry))
+        self.assertEqual(3, len(qry))
         self.assertEqual(self.dag_bash.dag_id, qry[0].dag_id)
-        self.assertEqual(State.RUNNING, qry[0].state)
-        self.assertEqual(1, qry[0].count)
-        self.assertFalse(qry[0].dirty)
+        for stats in qry:
+            if stats.state == State.RUNNING:
+                self.assertEqual(stats.count, 1)
+            else:
+                self.assertEqual(stats.count, 0)
+            self.assertFalse(stats.dirty)
 
         run2 = self.dag_bash.create_dagrun(
             run_id="run2",
             execution_date=DEFAULT_DATE+timedelta(days=1),
             state=State.RUNNING)
 
-        with warnings.catch_warnings(record=True) as caught_warnings:
-            models.DagStat.clean_dirty([self.dag_bash.dag_id], session=session)
-        self.assertEqual([], caught_warnings)
+        models.DagStat.update([self.dag_bash.dag_id], session=session)
 
         qry = session.query(models.DagStat).all()
 
-        self.assertEqual(1, len(qry))
+        self.assertEqual(3, len(qry))
         self.assertEqual(self.dag_bash.dag_id, qry[0].dag_id)
-        self.assertEqual(State.RUNNING, qry[0].state)
-        self.assertEqual(2, qry[0].count)
-        self.assertFalse(qry[0].dirty)
+        for stats in qry:
+            if stats.state == State.RUNNING:
+                self.assertEqual(stats.count, 2)
+            else:
+                self.assertEqual(stats.count, 0)
+            self.assertFalse(stats.dirty)
 
         session.query(models.DagRun).first().state = State.SUCCESS
         session.commit()
 
-        with warnings.catch_warnings(record=True) as caught_warnings:
-            models.DagStat.clean_dirty([self.dag_bash.dag_id], session=session)
-        self.assertEqual([], caught_warnings)
+        models.DagStat.update([self.dag_bash.dag_id], session=session)
 
         qry = session.query(models.DagStat).filter(models.DagStat.state == State.SUCCESS).all()
         self.assertEqual(1, len(qry))

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/c2472ffa/tests/models.py
----------------------------------------------------------------------
diff --git a/tests/models.py b/tests/models.py
index a30830e..2180896 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -27,7 +27,7 @@ from airflow import models, settings, AirflowException
 from airflow.exceptions import AirflowSkipException
 from airflow.models import DAG, TaskInstance as TI
 from airflow.models import State as ST
-from airflow.models import DagModel
+from airflow.models import DagModel, DagStat
 from airflow.operators.dummy_operator import DummyOperator
 from airflow.operators.bash_operator import BashOperator
 from airflow.operators.python_operator import PythonOperator
@@ -234,6 +234,55 @@ class DagTest(unittest.TestCase):
         session.close()
 
 
+class DagStatTest(unittest.TestCase):
+    def test_dagstats_crud(self):
+        DagStat.create(dag_id='test_dagstats_crud')
+
+        session = settings.Session()
+        qry = session.query(DagStat).filter(DagStat.dag_id == 'test_dagstats_crud')
+        self.assertEqual(len(qry.all()), len(State.dag_states))
+
+        DagStat.set_dirty(dag_id='test_dagstats_crud')
+        res = qry.all()
+
+        for stat in res:
+            self.assertTrue(stat.dirty)
+
+        # create missing
+        DagStat.set_dirty(dag_id='test_dagstats_crud_2')
+        qry2 = session.query(DagStat).filter(DagStat.dag_id == 'test_dagstats_crud_2')
+        self.assertEqual(len(qry2.all()), len(State.dag_states))
+
+        dag = DAG(
+            'test_dagstats_crud',
+            start_date=DEFAULT_DATE,
+            default_args={'owner': 'owner1'})
+
+        with dag:
+            op1 = DummyOperator(task_id='A')
+
+        now = datetime.datetime.now()
+        dr = dag.create_dagrun(
+            run_id='manual__' + now.isoformat(),
+            execution_date=now,
+            start_date=now,
+            state=State.FAILED,
+            external_trigger=False,
+        )
+
+        DagStat.update(dag_ids=['test_dagstats_crud'])
+        res = qry.all()
+        for stat in res:
+            if stat.state == State.FAILED:
+                self.assertEqual(stat.count, 1)
+            else:
+                self.assertEqual(stat.count, 0)
+
+        DagStat.update()
+        res = qry2.all()
+        for stat in res:
+            self.assertFalse(stat.dirty)
+
 class DagRunTest(unittest.TestCase):
 
     def create_dag_run(self, dag, state=State.RUNNING, task_states=None, execution_date=None):
@@ -419,7 +468,7 @@ class DagRunTest(unittest.TestCase):
         dag = DAG(
             dag_id='test_latest_runs_1',
             start_date=DEFAULT_DATE)
-        dag_1_run_1 = self.create_dag_run(dag, 
+        dag_1_run_1 = self.create_dag_run(dag,
                 execution_date=datetime.datetime(2015, 1, 1))
         dag_1_run_2 = self.create_dag_run(dag,
                 execution_date=datetime.datetime(2015, 1, 2))


Mime
View raw message