airflow-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From bo...@apache.org
Subject incubator-airflow git commit: [AIRFLOW-1873] Set TI.try_number to right value depending TI state
Date Thu, 07 Dec 2017 13:41:55 GMT
Repository: incubator-airflow
Updated Branches:
  refs/heads/v1-9-test 1b210337e -> 9a552c4fe


[AIRFLOW-1873] Set TI.try_number to right value depending TI state

Rather than having try_number+1 in various places,
try_number
will now automatically contain the right value for
when the TI
will next be run, and handle the case where
try_number is
accessed when the task is currently running.

This showed up as a bug where the logs from
running operators would
show up in the next log file (2.log for the first
try)

Closes #2832 from ashb/AIRFLOW-1873-task-operator-
log-try-number

(cherry picked from commit 4b4e504eeae81e48f3c9d796a61dd9e86000c663)
Signed-off-by: Bolke de Bruin <bolke@xs4all.nl>


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/9a552c4f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/9a552c4f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/9a552c4f

Branch: refs/heads/v1-9-test
Commit: 9a552c4fe88d362f79d644ec04340b7b63af795f
Parents: 1b21033
Author: Ash Berlin-Taylor <ash_github@firemirror.com>
Authored: Thu Dec 7 13:31:38 2017 +0000
Committer: Bolke de Bruin <bolke@xs4all.nl>
Committed: Thu Dec 7 13:41:34 2017 +0000

----------------------------------------------------------------------
 airflow/models.py                      | 44 ++++++++++----
 airflow/utils/log/file_task_handler.py |  8 +--
 airflow/utils/log/gcs_task_handler.py  |  4 +-
 airflow/utils/log/s3_task_handler.py   |  4 +-
 tests/jobs.py                          |  3 +-
 tests/models.py                        | 89 +++++++++++++++++++----------
 tests/utils/test_log_handlers.py       | 52 ++++++++++++++---
 7 files changed, 144 insertions(+), 60 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9a552c4f/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index 76f879c..e979b07 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -132,13 +132,13 @@ def clear_task_instances(tis, session, activate_dag_runs=True, dag=None):
             if dag and dag.has_task(task_id):
                 task = dag.get_task(task_id)
                 task_retries = task.retries
-                ti.max_tries = ti.try_number + task_retries
+                ti.max_tries = ti.try_number + task_retries - 1
             else:
                 # Ignore errors when updating max_tries if dag is None or
                 # task not found in dag since database records could be
                 # outdated. We make max_tries the maximum value of its
                 # original max_tries or the current task try number.
-                ti.max_tries = max(ti.max_tries, ti.try_number)
+                ti.max_tries = max(ti.max_tries, ti.try_number - 1)
             ti.state = State.NONE
             session.merge(ti)
 
@@ -769,7 +769,7 @@ class TaskInstance(Base, LoggingMixin):
     end_date = Column(DateTime)
     duration = Column(Float)
     state = Column(String(20))
-    try_number = Column(Integer, default=0)
+    _try_number = Column('try_number', Integer, default=0)
     max_tries = Column(Integer)
     hostname = Column(String(1000))
     unixname = Column(String(1000))
@@ -811,6 +811,24 @@ class TaskInstance(Base, LoggingMixin):
         """ Initialize the attributes that aren't stored in the DB. """
         self.test_mode = False  # can be changed when calling 'run'
 
+    @property
+    def try_number(self):
+        """
+        Return the try number that this task number will be when it is actually
+        run.
+
+        If the TI is currently running, this will match the column in the
+        database, in all other cases this will be incremented.
+        """
+        # This is designed so that task logs end up in the right file.
+        if self.state == State.RUNNING:
+            return self._try_number
+        return self._try_number + 1
+
+    @try_number.setter
+    def try_number(self, value):
+        self._try_number = value
+
     def command(
             self,
             mark_success=False,
@@ -1039,7 +1057,9 @@ class TaskInstance(Base, LoggingMixin):
             self.state = ti.state
             self.start_date = ti.start_date
             self.end_date = ti.end_date
-            self.try_number = ti.try_number
+            # Get the raw value of try_number column, don't read through the
+            # accessor here otherwise it will be incremented by one already.
+            self.try_number = ti._try_number
             self.max_tries = ti.max_tries
             self.hostname = ti.hostname
             self.pid = ti.pid
@@ -1339,7 +1359,7 @@ class TaskInstance(Base, LoggingMixin):
         # not 0-indexed lists (i.e. Attempt 1 instead of
         # Attempt 0 for the first attempt).
         msg = "Starting attempt {attempt} of {total}".format(
-            attempt=self.try_number + 1,
+            attempt=self.try_number,
             total=self.max_tries + 1)
         self.start_date = datetime.utcnow()
 
@@ -1361,7 +1381,7 @@ class TaskInstance(Base, LoggingMixin):
             self.state = State.NONE
             msg = ("FIXME: Rescheduling due to concurrency limits reached at task "
                    "runtime. Attempt {attempt} of {total}. State set to NONE.").format(
-                attempt=self.try_number + 1,
+                attempt=self.try_number,
                 total=self.max_tries + 1)
             self.log.warning(hr + msg + hr)
 
@@ -1381,7 +1401,7 @@ class TaskInstance(Base, LoggingMixin):
 
         # print status message
         self.log.info(hr + msg + hr)
-        self.try_number += 1
+        self._try_number += 1
 
         if not test_mode:
             session.add(Log(State.RUNNING, self))
@@ -1583,10 +1603,10 @@ class TaskInstance(Base, LoggingMixin):
 
         # Let's go deeper
         try:
-            # try_number is incremented by 1 during task instance run. So the
-            # current task instance try_number is the try_number for the next
-            # task instance run. We only mark task instance as FAILED if the
-            # next task instance try_number exceeds the max_tries.
+            # Since this function is called only when the TI state is running,
+            # try_number contains the current try_number (not the next). We
+            # only mark task instance as FAILED if the next task instance
+            # try_number exceeds the max_tries.
             if task.retries and self.try_number <= self.max_tries:
                 self.state = State.UP_FOR_RETRY
                 self.log.info('Marking task as UP_FOR_RETRY')
@@ -1751,7 +1771,7 @@ class TaskInstance(Base, LoggingMixin):
             "Host: {self.hostname}<br>"
             "Log file: {self.log_filepath}<br>"
             "Mark success: <a href='{self.mark_success_url}'>Link</a><br>"
-        ).format(try_number=self.try_number + 1, max_tries=self.max_tries + 1, **locals())
+        ).format(try_number=self.try_number, max_tries=self.max_tries + 1, **locals())
         send_email(task.email, title, body)
 
     def set_duration(self):

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9a552c4f/airflow/utils/log/file_task_handler.py
----------------------------------------------------------------------
diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py
index 6038fbf..9a8061a 100644
--- a/airflow/utils/log/file_task_handler.py
+++ b/airflow/utils/log/file_task_handler.py
@@ -89,7 +89,7 @@ class FileTaskHandler(logging.Handler):
         # Task instance here might be different from task instance when
         # initializing the handler. Thus explicitly getting log location
         # is needed to get correct log path.
-        log_relative_path = self._render_filename(ti, try_number + 1)
+        log_relative_path = self._render_filename(ti, try_number)
         location = os.path.join(self.local_base, log_relative_path)
 
         log = ""
@@ -142,8 +142,8 @@ class FileTaskHandler(logging.Handler):
         next_try = task_instance.try_number
 
         if try_number is None:
-            try_numbers = list(range(next_try))
-        elif try_number < 0:
+            try_numbers = list(range(1, next_try))
+        elif try_number < 1:
             logs = ['Error fetching the logs. Try number {} is invalid.'.format(try_number)]
             return logs
         else:
@@ -174,7 +174,7 @@ class FileTaskHandler(logging.Handler):
         # writable by both users, then it's possible that re-running a task
         # via the UI (or vice versa) results in a permission error as the task
         # tries to write to a log file created by the other user.
-        relative_path = self._render_filename(ti, ti.try_number + 1)
+        relative_path = self._render_filename(ti, ti.try_number)
         full_path = os.path.join(self.local_base, relative_path)
         directory = os.path.dirname(full_path)
         # Create the log file and give it group writable permissions

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9a552c4f/airflow/utils/log/gcs_task_handler.py
----------------------------------------------------------------------
diff --git a/airflow/utils/log/gcs_task_handler.py b/airflow/utils/log/gcs_task_handler.py
index c11e7ad..1520347 100644
--- a/airflow/utils/log/gcs_task_handler.py
+++ b/airflow/utils/log/gcs_task_handler.py
@@ -58,7 +58,7 @@ class GCSTaskHandler(FileTaskHandler, LoggingMixin):
         # Log relative path is used to construct local and remote
         # log path to upload log files into GCS and read from the
         # remote location.
-        self.log_relative_path = self._render_filename(ti, ti.try_number + 1)
+        self.log_relative_path = self._render_filename(ti, ti.try_number)
 
     def close(self):
         """
@@ -94,7 +94,7 @@ class GCSTaskHandler(FileTaskHandler, LoggingMixin):
         # Explicitly getting log relative path is necessary as the given
         # task instance might be different than task instance passed in
         # in set_context method.
-        log_relative_path = self._render_filename(ti, try_number + 1)
+        log_relative_path = self._render_filename(ti, try_number)
         remote_loc = os.path.join(self.remote_base, log_relative_path)
 
         if self.gcs_log_exists(remote_loc):

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9a552c4f/airflow/utils/log/s3_task_handler.py
----------------------------------------------------------------------
diff --git a/airflow/utils/log/s3_task_handler.py b/airflow/utils/log/s3_task_handler.py
index cfa966a..5ff90c6 100644
--- a/airflow/utils/log/s3_task_handler.py
+++ b/airflow/utils/log/s3_task_handler.py
@@ -53,7 +53,7 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
         super(S3TaskHandler, self).set_context(ti)
         # Local location and remote location is needed to open and
         # upload local log file to S3 remote storage.
-        self.log_relative_path = self._render_filename(ti, ti.try_number + 1)
+        self.log_relative_path = self._render_filename(ti, ti.try_number)
 
     def close(self):
         """
@@ -89,7 +89,7 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
         # Explicitly getting log relative path is necessary as the given
         # task instance might be different than task instance passed in
         # in set_context method.
-        log_relative_path = self._render_filename(ti, try_number + 1)
+        log_relative_path = self._render_filename(ti, try_number)
         remote_loc = os.path.join(self.remote_base, log_relative_path)
 
         if self.s3_log_exists(remote_loc):

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9a552c4f/tests/jobs.py
----------------------------------------------------------------------
diff --git a/tests/jobs.py b/tests/jobs.py
index e8fff7e..9d2f363 100644
--- a/tests/jobs.py
+++ b/tests/jobs.py
@@ -2352,10 +2352,11 @@ class SchedulerJobTest(unittest.TestCase):
         (command, priority, queue, ti) = ti_tuple
         ti.task = dag_task1
 
+        self.assertEqual(ti.try_number, 1)
         # fail execution
         run_with_error(ti)
         self.assertEqual(ti.state, State.UP_FOR_RETRY)
-        self.assertEqual(ti.try_number, 1)
+        self.assertEqual(ti.try_number, 2)
 
         ti.refresh_from_db(lock_for_update=True, session=session)
         ti.state = State.SCHEDULED

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9a552c4f/tests/models.py
----------------------------------------------------------------------
diff --git a/tests/models.py b/tests/models.py
index a1de17d..88be50e 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -923,10 +923,11 @@ class TaskInstanceTest(unittest.TestCase):
         ti = TI(
             task=task, execution_date=datetime.datetime.now())
 
+        self.assertEqual(ti.try_number, 1)
         # first run -- up for retry
         run_with_error(ti)
         self.assertEqual(ti.state, State.UP_FOR_RETRY)
-        self.assertEqual(ti.try_number, 1)
+        self.assertEqual(ti.try_number, 2)
 
         # second run -- still up for retry because retry_delay hasn't expired
         run_with_error(ti)
@@ -963,16 +964,19 @@ class TaskInstanceTest(unittest.TestCase):
 
         ti = TI(
             task=task, execution_date=datetime.datetime.now())
+        self.assertEqual(ti.try_number, 1)
 
         # first run -- up for retry
         run_with_error(ti)
         self.assertEqual(ti.state, State.UP_FOR_RETRY)
-        self.assertEqual(ti.try_number, 1)
+        self.assertEqual(ti._try_number, 1)
+        self.assertEqual(ti.try_number, 2)
 
         # second run -- fail
         run_with_error(ti)
         self.assertEqual(ti.state, State.FAILED)
-        self.assertEqual(ti.try_number, 2)
+        self.assertEqual(ti._try_number, 2)
+        self.assertEqual(ti.try_number, 3)
 
         # Clear the TI state since you can't run a task with a FAILED state without
         # clearing it first
@@ -981,12 +985,15 @@ class TaskInstanceTest(unittest.TestCase):
         # third run -- up for retry
         run_with_error(ti)
         self.assertEqual(ti.state, State.UP_FOR_RETRY)
-        self.assertEqual(ti.try_number, 3)
+        self.assertEqual(ti._try_number, 3)
+        self.assertEqual(ti.try_number, 4)
 
         # fourth run -- fail
         run_with_error(ti)
+        ti.refresh_from_db()
         self.assertEqual(ti.state, State.FAILED)
-        self.assertEqual(ti.try_number, 4)
+        self.assertEqual(ti._try_number, 4)
+        self.assertEqual(ti.try_number, 5)
 
     def test_next_retry_datetime(self):
         delay = datetime.timedelta(seconds=30)
@@ -1007,17 +1014,16 @@ class TaskInstanceTest(unittest.TestCase):
             task=task, execution_date=DEFAULT_DATE)
         ti.end_date = datetime.datetime.now()
 
-        ti.try_number = 1
         dt = ti.next_retry_datetime()
         # between 30 * 2^0.5 and 30 * 2^1 (15 and 30)
         self.assertEqual(dt, ti.end_date + datetime.timedelta(seconds=20.0))
 
-        ti.try_number = 4
+        ti.try_number = 3
         dt = ti.next_retry_datetime()
         # between 30 * 2^2 and 30 * 2^3 (120 and 240)
         self.assertEqual(dt, ti.end_date + datetime.timedelta(seconds=181.0))
 
-        ti.try_number = 6
+        ti.try_number = 5
         dt = ti.next_retry_datetime()
         # between 30 * 2^4 and 30 * 2^5 (480 and 960)
         self.assertEqual(dt, ti.end_date + datetime.timedelta(seconds=825.0))
@@ -1224,7 +1230,11 @@ class TaskInstanceTest(unittest.TestCase):
         task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
         ti = TI(
             task=task, execution_date=datetime.datetime.now())
+        self.assertEqual(ti._try_number, 0)
         self.assertTrue(ti._check_and_change_state_before_execution())
+        # State should be running, and try_number column should be incremented
+        self.assertEqual(ti.state, State.RUNNING)
+        self.assertEqual(ti._try_number, 1)
 
     def test_check_and_change_state_before_execution_dep_not_met(self):
         dag = models.DAG(dag_id='test_check_and_change_state_before_execution')
@@ -1235,6 +1245,20 @@ class TaskInstanceTest(unittest.TestCase):
             task=task2, execution_date=datetime.datetime.now())
         self.assertFalse(ti._check_and_change_state_before_execution())
 
+    def test_try_number(self):
+        """
+        Test the try_number accessor behaves in various running states
+        """
+        dag = models.DAG(dag_id='test_check_and_change_state_before_execution')
+        task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
+        ti = TI(task=task, execution_date=datetime.datetime.utcnow())
+        self.assertEqual(1, ti.try_number)
+        ti.try_number = 2
+        ti.state = State.RUNNING
+        self.assertEqual(2, ti.try_number)
+        ti.state = State.SUCCESS
+        self.assertEqual(3, ti.try_number)
+
     def test_get_num_running_task_instances(self):
         session = settings.Session()
 
@@ -1257,7 +1281,7 @@ class TaskInstanceTest(unittest.TestCase):
         self.assertEquals(1, ti1.get_num_running_task_instances(session=session))
         self.assertEquals(1, ti2.get_num_running_task_instances(session=session))
         self.assertEquals(1, ti3.get_num_running_task_instances(session=session))
-        
+
 
 class ClearTasksTest(unittest.TestCase):
     def test_clear_task_instances(self):
@@ -1277,9 +1301,10 @@ class ClearTasksTest(unittest.TestCase):
         session.commit()
         ti0.refresh_from_db()
         ti1.refresh_from_db()
-        self.assertEqual(ti0.try_number, 1)
+        # Next try to run will be try 2
+        self.assertEqual(ti0.try_number, 2)
         self.assertEqual(ti0.max_tries, 1)
-        self.assertEqual(ti1.try_number, 1)
+        self.assertEqual(ti1.try_number, 2)
         self.assertEqual(ti1.max_tries, 3)
 
     def test_clear_task_instances_without_task(self):
@@ -1305,9 +1330,10 @@ class ClearTasksTest(unittest.TestCase):
         # When dag is None, max_tries will be maximum of original max_tries or try_number.
         ti0.refresh_from_db()
         ti1.refresh_from_db()
-        self.assertEqual(ti0.try_number, 1)
+        # Next try to run will be try 2
+        self.assertEqual(ti0.try_number, 2)
         self.assertEqual(ti0.max_tries, 1)
-        self.assertEqual(ti1.try_number, 1)
+        self.assertEqual(ti1.try_number, 2)
         self.assertEqual(ti1.max_tries, 2)
 
     def test_clear_task_instances_without_dag(self):
@@ -1328,9 +1354,10 @@ class ClearTasksTest(unittest.TestCase):
         # When dag is None, max_tries will be maximum of original max_tries or try_number.
         ti0.refresh_from_db()
         ti1.refresh_from_db()
-        self.assertEqual(ti0.try_number, 1)
+        # Next try to run will be try 2
+        self.assertEqual(ti0.try_number, 2)
         self.assertEqual(ti0.max_tries, 1)
-        self.assertEqual(ti1.try_number, 1)
+        self.assertEqual(ti1.try_number, 2)
         self.assertEqual(ti1.max_tries, 2)
 
     def test_dag_clear(self):
@@ -1338,12 +1365,13 @@ class ClearTasksTest(unittest.TestCase):
                   end_date=DEFAULT_DATE + datetime.timedelta(days=10))
         task0 = DummyOperator(task_id='test_dag_clear_task_0', owner='test', dag=dag)
         ti0 = TI(task=task0, execution_date=DEFAULT_DATE)
-        self.assertEqual(ti0.try_number, 0)
-        ti0.run()
+        # Next try to run will be try 1
         self.assertEqual(ti0.try_number, 1)
+        ti0.run()
+        self.assertEqual(ti0.try_number, 2)
         dag.clear()
         ti0.refresh_from_db()
-        self.assertEqual(ti0.try_number, 1)
+        self.assertEqual(ti0.try_number, 2)
         self.assertEqual(ti0.state, State.NONE)
         self.assertEqual(ti0.max_tries, 1)
 
@@ -1352,8 +1380,9 @@ class ClearTasksTest(unittest.TestCase):
         ti1 = TI(task=task1, execution_date=DEFAULT_DATE)
         self.assertEqual(ti1.max_tries, 2)
         ti1.try_number = 1
+        # Next try will be 2
         ti1.run()
-        self.assertEqual(ti1.try_number, 2)
+        self.assertEqual(ti1.try_number, 3)
         self.assertEqual(ti1.max_tries, 2)
 
         dag.clear()
@@ -1361,9 +1390,9 @@ class ClearTasksTest(unittest.TestCase):
         ti1.refresh_from_db()
         # after clear dag, ti2 should show attempt 3 of 5
         self.assertEqual(ti1.max_tries, 4)
-        self.assertEqual(ti1.try_number, 2)
+        self.assertEqual(ti1.try_number, 3)
         # after clear dag, ti1 should show attempt 2 of 2
-        self.assertEqual(ti0.try_number, 1)
+        self.assertEqual(ti0.try_number, 2)
         self.assertEqual(ti0.max_tries, 1)
 
     def test_dags_clear(self):
@@ -1383,7 +1412,7 @@ class ClearTasksTest(unittest.TestCase):
         for i in range(num_of_dags):
             tis[i].run()
             self.assertEqual(tis[i].state, State.SUCCESS)
-            self.assertEqual(tis[i].try_number, 1)
+            self.assertEqual(tis[i].try_number, 2)
             self.assertEqual(tis[i].max_tries, 0)
 
         DAG.clear_dags(dags)
@@ -1391,14 +1420,14 @@ class ClearTasksTest(unittest.TestCase):
         for i in range(num_of_dags):
             tis[i].refresh_from_db()
             self.assertEqual(tis[i].state, State.NONE)
-            self.assertEqual(tis[i].try_number, 1)
+            self.assertEqual(tis[i].try_number, 2)
             self.assertEqual(tis[i].max_tries, 1)
 
         # test dry_run
         for i in range(num_of_dags):
             tis[i].run()
             self.assertEqual(tis[i].state, State.SUCCESS)
-            self.assertEqual(tis[i].try_number, 2)
+            self.assertEqual(tis[i].try_number, 3)
             self.assertEqual(tis[i].max_tries, 1)
 
         DAG.clear_dags(dags, dry_run=True)
@@ -1406,7 +1435,7 @@ class ClearTasksTest(unittest.TestCase):
         for i in range(num_of_dags):
             tis[i].refresh_from_db()
             self.assertEqual(tis[i].state, State.SUCCESS)
-            self.assertEqual(tis[i].try_number, 2)
+            self.assertEqual(tis[i].try_number, 3)
             self.assertEqual(tis[i].max_tries, 1)
 
         # test only_failed
@@ -1422,11 +1451,11 @@ class ClearTasksTest(unittest.TestCase):
             tis[i].refresh_from_db()
             if i != failed_dag_idx:
                 self.assertEqual(tis[i].state, State.SUCCESS)
-                self.assertEqual(tis[i].try_number, 2)
+                self.assertEqual(tis[i].try_number, 3)
                 self.assertEqual(tis[i].max_tries, 1)
             else:
                 self.assertEqual(tis[i].state, State.NONE)
-                self.assertEqual(tis[i].try_number, 2)
+                self.assertEqual(tis[i].try_number, 3)
                 self.assertEqual(tis[i].max_tries, 2)
 
     def test_operator_clear(self):
@@ -1441,17 +1470,17 @@ class ClearTasksTest(unittest.TestCase):
         ti2 = TI(task=t2, execution_date=DEFAULT_DATE)
         ti2.run()
         # Dependency not met
-        self.assertEqual(ti2.try_number, 0)
+        self.assertEqual(ti2.try_number, 1)
         self.assertEqual(ti2.max_tries, 1)
 
         t2.clear(upstream=True)
         ti1.run()
         ti2.run()
-        self.assertEqual(ti1.try_number, 1)
+        self.assertEqual(ti1.try_number, 2)
         # max_tries is 0 because there is no task instance in db for ti1
         # so clear won't change the max_tries.
         self.assertEqual(ti1.max_tries, 0)
-        self.assertEqual(ti2.try_number, 1)
+        self.assertEqual(ti2.try_number, 2)
         # try_number (0) + retries(1)
         self.assertEqual(ti2.max_tries, 1)
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/9a552c4f/tests/utils/test_log_handlers.py
----------------------------------------------------------------------
diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py
index 25faa7c..54e8cff 100644
--- a/tests/utils/test_log_handlers.py
+++ b/tests/utils/test_log_handlers.py
@@ -12,17 +12,18 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
 import logging
 import logging.config
-import mock
 import os
 import unittest
+import six
 
 from datetime import datetime
-from airflow.models import TaskInstance, DAG
+from airflow.models import TaskInstance, DAG, DagRun
 from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG
 from airflow.operators.dummy_operator import DummyOperator
+from airflow.settings import Session
+from airflow.utils.log.logging_mixin import set_context
 from airflow.utils.log.file_task_handler import FileTaskHandler
 
 DEFAULT_DATE = datetime(2016, 1, 1)
@@ -31,11 +32,22 @@ FILE_TASK_HANDLER = 'file.task'
 
 
 class TestFileTaskLogHandler(unittest.TestCase):
+    def cleanUp(self):
+        session = Session()
+
+        session.query(DagRun).delete()
+        session.query(TaskInstance).delete()
 
     def setUp(self):
         super(TestFileTaskLogHandler, self).setUp()
-        # We use file task handler by default.
         logging.config.dictConfig(DEFAULT_LOGGING_CONFIG)
+        logging.root.disabled = False
+        self.cleanUp()
+        # We use file task handler by default.
+
+    def tearDown(self):
+        self.cleanUp()
+        super(TestFileTaskLogHandler, self).tearDown()
 
     def test_default_task_logging_setup(self):
         # file task handler is used by default.
@@ -46,29 +58,51 @@ class TestFileTaskLogHandler(unittest.TestCase):
         self.assertEqual(handler.name, FILE_TASK_HANDLER)
 
     def test_file_task_handler(self):
+        def task_callable(ti, **kwargs):
+            ti.log.info("test")
         dag = DAG('dag_for_testing_file_task_handler', start_date=DEFAULT_DATE)
-        task = DummyOperator(task_id='task_for_testing_file_log_handler', dag=dag)
+        task = PythonOperator(
+            task_id='task_for_testing_file_log_handler',
+            dag=dag,
+            python_callable=task_callable,
+            provide_context=True
+        )
         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
 
-        logger = logging.getLogger(TASK_LOGGER)
+        logger = ti.log
+        ti.log.disabled = False
+
         file_handler = next((handler for handler in logger.handlers
                              if handler.name == FILE_TASK_HANDLER), None)
         self.assertIsNotNone(file_handler)
 
-        file_handler.set_context(ti)
+        set_context(logger, ti)
         self.assertIsNotNone(file_handler.handler)
         # We expect set_context generates a file locally.
         log_filename = file_handler.handler.baseFilename
         self.assertTrue(os.path.isfile(log_filename))
+        self.assertTrue(log_filename.endswith("1.log"), log_filename)
+
+        ti.run(ignore_ti_state=True)
 
-        logger.info("test")
-        ti.run()
+        file_handler.flush()
+        file_handler.close()
 
         self.assertTrue(hasattr(file_handler, 'read'))
         # Return value of read must be a list.
         logs = file_handler.read(ti)
         self.assertTrue(isinstance(logs, list))
         self.assertEqual(len(logs), 1)
+        target_re = r'\n\[[^\]]+\] {test_log_handlers.py:\d+} INFO - test\n'
+
+        # We should expect our log line from the callable above to appear in
+        # the logs we read back
+        six.assertRegex(
+            self,
+            logs[0],
+            target_re,
+            "Logs were " + str(logs)
+        )
 
         # Remove the generated tmp log file.
         os.remove(log_filename)


Mime
View raw message