airflow-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From ephraimanier...@apache.org
Subject [airflow] branch main updated: Improve `dag_maker` fixture (#17324)
Date Mon, 02 Aug 2021 06:38:04 GMT
This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5c1e09c  Improve `dag_maker` fixture (#17324)
5c1e09c is described below

commit 5c1e09cafacea922b9281e901db7da7cadb3e9be
Author: Ephraim Anierobi <splendidzigy24@gmail.com>
AuthorDate: Mon Aug 2 07:37:40 2021 +0100

    Improve `dag_maker` fixture (#17324)
    
    This PR improves the dag_maker fixture to enable creation of dagrun, dag and dag_model separately
    
    Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
---
 tests/conftest.py                 |  53 +++++-----
 tests/jobs/test_backfill_job.py   | 204 ++++++++++++++++++++------------------
 tests/jobs/test_local_task_job.py | 103 ++++++++-----------
 3 files changed, 175 insertions(+), 185 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index 896e32a..48ac9b2 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -428,8 +428,9 @@ def app():
 
 @pytest.fixture
 def dag_maker(request):
-    from airflow.models import DAG
+    from airflow.models import DAG, DagModel
     from airflow.utils import timezone
+    from airflow.utils.session import provide_session
     from airflow.utils.state import State
 
     DEFAULT_DATE = timezone.datetime(2016, 1, 1)
@@ -444,33 +445,39 @@ def dag_maker(request):
             dag.__exit__(type, value, traceback)
             if type is None:
                 dag.clear()
-                self.dag_run = dag.create_dagrun(
-                    run_id=self.kwargs.get("run_id", "test"),
-                    state=self.kwargs.get('state', State.RUNNING),
-                    execution_date=self.kwargs.get('execution_date', self.kwargs['start_date']),
-                    start_date=self.kwargs['start_date'],
-                )
+
+        @provide_session
+        def make_dagmodel(self, session=None, **kwargs):
+            dag = self.dag
+            defaults = dict(dag_id=dag.dag_id, next_dagrun=dag.start_date, is_active=True)
+            kwargs = {**defaults, **kwargs}
+            dag_model = DagModel(**kwargs)
+            session.add(dag_model)
+            session.flush()
+            return dag_model
+
+        def create_dagrun(self, **kwargs):
+            dag = self.dag
+            defaults = dict(
+                run_id='test',
+                state=State.RUNNING,
+                execution_date=self.start_date,
+                start_date=self.start_date,
+            )
+            kwargs = {**defaults, **kwargs}
+            self.dag_run = dag.create_dagrun(**kwargs)
+            return self.dag_run
 
         def __call__(self, dag_id='test_dag', **kwargs):
             self.kwargs = kwargs
-            if "start_date" not in kwargs:
+            self.start_date = self.kwargs.get('start_date', None)
+            if not self.start_date:
                 if hasattr(request.module, 'DEFAULT_DATE'):
-                    kwargs['start_date'] = getattr(request.module, 'DEFAULT_DATE')
+                    self.start_date = getattr(request.module, 'DEFAULT_DATE')
                 else:
-                    kwargs['start_date'] = DEFAULT_DATE
-            dagrun_fields_not_in_dag = [
-                'state',
-                'execution_date',
-                'run_type',
-                'queued_at',
-                "run_id",
-                "creating_job_id",
-                "external_trigger",
-                "last_scheduling_decision",
-                "dag_hash",
-            ]
-            kwargs = {k: v for k, v in kwargs.items() if k not in dagrun_fields_not_in_dag}
-            self.dag = DAG(dag_id, **kwargs)
+                    self.start_date = DEFAULT_DATE
+            self.kwargs['start_date'] = self.start_date
+            self.dag = DAG(dag_id, **self.kwargs)
             return self
 
     return DagFactory()
diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py
index c110e63..d70606a 100644
--- a/tests/jobs/test_backfill_job.py
+++ b/tests/jobs/test_backfill_job.py
@@ -46,7 +46,7 @@ from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.timeout import timeout
 from airflow.utils.types import DagRunType
-from tests.test_utils.db import clear_db_pools, clear_db_runs, set_default_pool_slots
+from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs, set_default_pool_slots
 from tests.test_utils.mock_executor import MockExecutor
 
 logger = logging.getLogger(__name__)
@@ -59,44 +59,10 @@ def dag_bag():
     return DagBag(include_examples=True)
 
 
-@pytest.fixture
-def get_dummy_dag_and_run(dag_maker):
-    def _get_dummy_dag_and_run(
-        dag_id='test_dag', pool=Pool.DEFAULT_POOL_NAME, task_concurrency=None, task_id='op', **kwargs
-    ):
-        with dag_maker(dag_id=dag_id, schedule_interval='@daily', **kwargs) as dag:
-            DummyOperator(task_id=task_id, pool=pool, task_concurrency=task_concurrency)
-
-        return dag, dag_maker.dag_run
-
-    return _get_dummy_dag_and_run
-
-
-@pytest.fixture
-def get_dag_test_max_active_limits(dag_maker):
-    def _get_dag_test_max_active_limits(dag_id='test_dag', max_active_runs=1, **kwargs):
-        with dag_maker(
-            dag_id=dag_id,
-            start_date=DEFAULT_DATE,
-            schedule_interval="@hourly",
-            max_active_runs=max_active_runs,
-            **kwargs,
-        ) as dag:
-            op1 = DummyOperator(task_id='leave1')
-            op2 = DummyOperator(task_id='leave2')
-            op3 = DummyOperator(task_id='upstream_level_1')
-            op4 = DummyOperator(task_id='upstream_level_2')
-
-            op1 >> op2 >> op3
-            op4 >> op3
-        return dag, dag_maker.dag_run
-
-    return _get_dag_test_max_active_limits
-
-
 class TestBackfillJob:
     @staticmethod
     def clean_db():
+        clear_db_dags()
         clear_db_runs()
         clear_db_pools()
 
@@ -106,6 +72,20 @@ class TestBackfillJob:
         self.parser = cli_parser.get_parser()
         self.dagbag = dag_bag
 
+    def _get_dummy_dag(
+        self,
+        dag_maker_fixture,
+        dag_id='test_dag',
+        pool=Pool.DEFAULT_POOL_NAME,
+        task_concurrency=None,
+        task_id='op',
+        **kwargs,
+    ):
+        with dag_maker_fixture(dag_id=dag_id, schedule_interval='@daily', **kwargs) as dag:
+            DummyOperator(task_id=task_id, pool=pool, task_concurrency=task_concurrency)
+
+        return dag
+
     def _times_called_with(self, method, class_):
         count = 0
         for args in method.call_args_list:
@@ -113,8 +93,9 @@ class TestBackfillJob:
                 count += 1
         return count
 
-    def test_unfinished_dag_runs_set_to_failed(self, get_dummy_dag_and_run):
-        dag, dag_run = get_dummy_dag_and_run(dag_id='dummy_dag')
+    def test_unfinished_dag_runs_set_to_failed(self, dag_maker):
+        dag = self._get_dummy_dag(dag_maker)
+        dag_run = dag_maker.create_dagrun()
 
         job = BackfillJob(
             dag=dag,
@@ -129,8 +110,9 @@ class TestBackfillJob:
 
         assert State.FAILED == dag_run.state
 
-    def test_dag_run_with_finished_tasks_set_to_success(self, get_dummy_dag_and_run):
-        dag, dag_run = get_dummy_dag_and_run(dag_id='dummy_dag')
+    def test_dag_run_with_finished_tasks_set_to_success(self, dag_maker):
+        dag = self._get_dummy_dag(dag_maker)
+        dag_run = dag_maker.create_dagrun()
 
         for ti in dag_run.get_task_instances():
             ti.set_state(State.SUCCESS)
@@ -289,8 +271,9 @@ class TestBackfillJob:
             for task_id in expected_execution_order
         ] == executor.sorted_tasks
 
-    def test_backfill_conf(self, get_dummy_dag_and_run):
-        dag, _ = get_dummy_dag_and_run(dag_id='test_backfill_conf')
+    def test_backfill_conf(self, dag_maker):
+        dag = self._get_dummy_dag(dag_maker, dag_id='test_backfill_conf')
+        dag_maker.create_dagrun()
 
         executor = MockExecutor()
 
@@ -312,12 +295,14 @@ class TestBackfillJob:
         assert conf_ == dr[0].conf
 
     @patch('airflow.jobs.backfill_job.BackfillJob.log')
-    def test_backfill_respect_task_concurrency_limit(self, mock_log, get_dummy_dag_and_run):
+    def test_backfill_respect_task_concurrency_limit(self, mock_log, dag_maker):
         task_concurrency = 2
-        dag, _ = get_dummy_dag_and_run(
+        dag = self._get_dummy_dag(
+            dag_maker,
             dag_id='test_backfill_respect_task_concurrency_limit',
             task_concurrency=task_concurrency,
         )
+        dag_maker.create_dagrun()
 
         executor = MockExecutor()
 
@@ -364,9 +349,9 @@ class TestBackfillJob:
         assert times_task_concurrency_limit_reached_in_debug > 0
 
     @patch('airflow.jobs.backfill_job.BackfillJob.log')
-    def test_backfill_respect_dag_concurrency_limit(self, mock_log, get_dummy_dag_and_run):
-
-        dag, _ = get_dummy_dag_and_run(dag_id='test_backfill_respect_concurrency_limit')
+    def test_backfill_respect_dag_concurrency_limit(self, mock_log, dag_maker):
+        dag = self._get_dummy_dag(dag_maker, dag_id='test_backfill_respect_concurrency_limit')
+        dag_maker.create_dagrun()
         dag.max_active_tasks = 2
 
         executor = MockExecutor()
@@ -415,11 +400,12 @@ class TestBackfillJob:
         assert times_dag_concurrency_limit_reached_in_debug > 0
 
     @patch('airflow.jobs.backfill_job.BackfillJob.log')
-    def test_backfill_respect_default_pool_limit(self, mock_log, get_dummy_dag_and_run):
+    def test_backfill_respect_default_pool_limit(self, mock_log, dag_maker):
         default_pool_slots = 2
         set_default_pool_slots(default_pool_slots)
 
-        dag, _ = get_dummy_dag_and_run(dag_id='test_backfill_with_no_pool_limit')
+        dag = self._get_dummy_dag(dag_maker, dag_id='test_backfill_with_no_pool_limit')
+        dag_maker.create_dagrun()
 
         executor = MockExecutor()
 
@@ -469,11 +455,13 @@ class TestBackfillJob:
         assert 0 == times_task_concurrency_limit_reached_in_debug
         assert times_pool_limit_reached_in_debug > 0
 
-    def test_backfill_pool_not_found(self, get_dummy_dag_and_run):
-        dag, _ = get_dummy_dag_and_run(
+    def test_backfill_pool_not_found(self, dag_maker):
+        dag = self._get_dummy_dag(
+            dag_maker,
             dag_id='test_backfill_pool_not_found',
             pool='king_pool',
         )
+        dag_maker.create_dagrun()
 
         executor = MockExecutor()
 
@@ -490,7 +478,7 @@ class TestBackfillJob:
             return
 
     @patch('airflow.jobs.backfill_job.BackfillJob.log')
-    def test_backfill_respect_pool_limit(self, mock_log, get_dummy_dag_and_run):
+    def test_backfill_respect_pool_limit(self, mock_log, dag_maker):
         session = settings.Session()
 
         slots = 2
@@ -501,10 +489,12 @@ class TestBackfillJob:
         session.add(pool)
         session.commit()
 
-        dag, _ = get_dummy_dag_and_run(
+        dag = self._get_dummy_dag(
+            dag_maker,
             dag_id='test_backfill_respect_pool_limit',
             pool=pool.pool,
         )
+        dag_maker.create_dagrun()
 
         executor = MockExecutor()
 
@@ -550,10 +540,11 @@ class TestBackfillJob:
         assert 0 == times_dag_concurrency_limit_reached_in_debug
         assert times_pool_limit_reached_in_debug > 0
 
-    def test_backfill_run_rescheduled(self, get_dummy_dag_and_run):
-        dag, _ = get_dummy_dag_and_run(
-            dag_id="test_backfill_run_rescheduled", task_id="test_backfill_run_rescheduled_task-1"
+    def test_backfill_run_rescheduled(self, dag_maker):
+        dag = self._get_dummy_dag(
+            dag_maker, dag_id="test_backfill_run_rescheduled", task_id="test_backfill_run_rescheduled_task-1"
         )
+        dag_maker.create_dagrun()
 
         executor = MockExecutor()
 
@@ -581,10 +572,11 @@ class TestBackfillJob:
         ti.refresh_from_db()
         assert ti.state == State.SUCCESS
 
-    def test_backfill_rerun_failed_tasks(self, get_dummy_dag_and_run):
-        dag, _ = get_dummy_dag_and_run(
-            dag_id="test_backfill_rerun_failed", task_id="test_backfill_rerun_failed_task-1"
+    def test_backfill_rerun_failed_tasks(self, dag_maker):
+        dag = self._get_dummy_dag(
+            dag_maker, dag_id="test_backfill_rerun_failed", task_id="test_backfill_rerun_failed_task-1"
         )
+        dag_maker.create_dagrun()
 
         executor = MockExecutor()
 
@@ -614,12 +606,11 @@ class TestBackfillJob:
 
     def test_backfill_rerun_upstream_failed_tasks(self, dag_maker):
 
-        with dag_maker(
-            dag_id='test_backfill_rerun_upstream_failed', start_date=DEFAULT_DATE, schedule_interval='@daily'
-        ) as dag:
+        with dag_maker(dag_id='test_backfill_rerun_upstream_failed', schedule_interval='@daily') as dag:
             op1 = DummyOperator(task_id='test_backfill_rerun_upstream_failed_task-1')
             op2 = DummyOperator(task_id='test_backfill_rerun_upstream_failed_task-2')
             op1.set_upstream(op2)
+        dag_maker.create_dagrun()
 
         executor = MockExecutor()
 
@@ -647,10 +638,11 @@ class TestBackfillJob:
         ti.refresh_from_db()
         assert ti.state == State.SUCCESS
 
-    def test_backfill_rerun_failed_tasks_without_flag(self, get_dummy_dag_and_run):
-        dag, _ = get_dummy_dag_and_run(
-            dag_id='test_backfill_rerun_failed', task_id='test_backfill_rerun_failed_task-1'
+    def test_backfill_rerun_failed_tasks_without_flag(self, dag_maker):
+        dag = self._get_dummy_dag(
+            dag_maker, dag_id='test_backfill_rerun_failed', task_id='test_backfill_rerun_failed_task-1'
         )
+        dag_maker.create_dagrun()
 
         executor = MockExecutor()
 
@@ -680,7 +672,6 @@ class TestBackfillJob:
     def test_backfill_retry_intermittent_failed_task(self, dag_maker):
         with dag_maker(
             dag_id='test_intermittent_failure_job',
-            start_date=DEFAULT_DATE,
             schedule_interval="@daily",
             default_args={
                 'retries': 2,
@@ -688,6 +679,7 @@ class TestBackfillJob:
             },
         ) as dag:
             task1 = DummyOperator(task_id="task1")
+        dag_maker.create_dagrun()
 
         executor = MockExecutor(parallelism=16)
         executor.mock_task_results[
@@ -707,7 +699,6 @@ class TestBackfillJob:
     def test_backfill_retry_always_failed_task(self, dag_maker):
         with dag_maker(
             dag_id='test_always_failure_job',
-            start_date=DEFAULT_DATE,
             schedule_interval="@daily",
             default_args={
                 'retries': 1,
@@ -715,6 +706,7 @@ class TestBackfillJob:
             },
         ) as dag:
             task1 = DummyOperator(task_id="task1")
+        dag_maker.create_dagrun()
 
         executor = MockExecutor(parallelism=16)
         executor.mock_task_results[
@@ -734,7 +726,6 @@ class TestBackfillJob:
 
         with dag_maker(
             dag_id='test_backfill_ordered_concurrent_execute',
-            start_date=DEFAULT_DATE,
             schedule_interval="@daily",
         ) as dag:
             op1 = DummyOperator(task_id='leave1')
@@ -747,6 +738,7 @@ class TestBackfillJob:
             op1.set_downstream(op3)
             op4.set_downstream(op5)
             op3.set_downstream(op4)
+        dag_maker.create_dagrun()
 
         executor = MockExecutor(parallelism=16)
         job = BackfillJob(
@@ -881,10 +873,29 @@ class TestBackfillJob:
         parsed_args = self.parser.parse_args(args)
         assert 0.5 == parsed_args.delay_on_limit
 
-    def test_backfill_max_limit_check_within_limit(self, get_dag_test_max_active_limits):
-        dag, _ = get_dag_test_max_active_limits(
-            dag_id='test_backfill_max_limit_check_within_limit', max_active_runs=16
+    def _get_dag_test_max_active_limits(
+        self, dag_maker_fixture, dag_id='test_dag', max_active_runs=1, **kwargs
+    ):
+        with dag_maker_fixture(
+            dag_id=dag_id,
+            schedule_interval="@hourly",
+            max_active_runs=max_active_runs,
+            **kwargs,
+        ) as dag:
+            op1 = DummyOperator(task_id='leave1')
+            op2 = DummyOperator(task_id='leave2')
+            op3 = DummyOperator(task_id='upstream_level_1')
+            op4 = DummyOperator(task_id='upstream_level_2')
+
+            op1 >> op2 >> op3
+            op4 >> op3
+        return dag
+
+    def test_backfill_max_limit_check_within_limit(self, dag_maker):
+        dag = self._get_dag_test_max_active_limits(
+            dag_maker, dag_id='test_backfill_max_limit_check_within_limit', max_active_runs=16
         )
+        dag_maker.create_dagrun()
         start_date = DEFAULT_DATE - datetime.timedelta(hours=1)
         end_date = DEFAULT_DATE
 
@@ -898,7 +909,7 @@ class TestBackfillJob:
         assert 2 == len(dagruns)
         assert all(run.state == State.SUCCESS for run in dagruns)
 
-    def test_backfill_max_limit_check(self, get_dag_test_max_active_limits):
+    def test_backfill_max_limit_check(self, dag_maker):
         dag_id = 'test_backfill_max_limit_check'
         run_id = 'test_dag_run'
         start_date = DEFAULT_DATE - datetime.timedelta(hours=1)
@@ -911,9 +922,12 @@ class TestBackfillJob:
             # this session object is different than the one in the main thread
             with create_session() as thread_session:
                 try:
-                    dag, _ = get_dag_test_max_active_limits(
-                        # Existing dagrun that is not within the backfill range
+                    dag = self._get_dag_test_max_active_limits(
+                        dag_maker,
                         dag_id=dag_id,
+                    )
+                    dag_maker.create_dagrun(
+                        # Existing dagrun that is not within the backfill range
                         run_id=run_id,
                         execution_date=DEFAULT_DATE + datetime.timedelta(hours=1),
                     )
@@ -960,11 +974,14 @@ class TestBackfillJob:
             finally:
                 dag_run_created_cond.release()
 
-    def test_backfill_max_limit_check_no_count_existing(self, get_dag_test_max_active_limits):
+    def test_backfill_max_limit_check_no_count_existing(self, dag_maker):
         start_date = DEFAULT_DATE
         end_date = DEFAULT_DATE
         # Existing dagrun that is within the backfill range
-        dag, _ = get_dag_test_max_active_limits(dag_id='test_backfill_max_limit_check_no_count_existing')
+        dag = self._get_dag_test_max_active_limits(
+            dag_maker, dag_id='test_backfill_max_limit_check_no_count_existing'
+        )
+        dag_maker.create_dagrun()
 
         executor = MockExecutor()
         job = BackfillJob(
@@ -980,8 +997,11 @@ class TestBackfillJob:
         assert 1 == len(dagruns)
         assert State.SUCCESS == dagruns[0].state
 
-    def test_backfill_max_limit_check_complete_loop(self, get_dag_test_max_active_limits):
-        dag, _ = get_dag_test_max_active_limits(dag_id='test_backfill_max_limit_check_complete_loop')
+    def test_backfill_max_limit_check_complete_loop(self, dag_maker):
+        dag = self._get_dag_test_max_active_limits(
+            dag_maker, dag_id='test_backfill_max_limit_check_complete_loop'
+        )
+        dag_maker.create_dagrun()
         start_date = DEFAULT_DATE - datetime.timedelta(hours=1)
         end_date = DEFAULT_DATE
 
@@ -1003,9 +1023,6 @@ class TestBackfillJob:
 
         with dag_maker(
             'test_sub_set_subdag',
-            start_date=DEFAULT_DATE,
-            default_args={'owner': 'owner1'},
-            execution_date=DEFAULT_DATE,
         ) as dag:
             op1 = DummyOperator(task_id='leave1')
             op2 = DummyOperator(task_id='leave2')
@@ -1018,7 +1035,7 @@ class TestBackfillJob:
             op4.set_downstream(op5)
             op3.set_downstream(op4)
 
-        dr = dag_maker.dag_run
+        dr = dag_maker.create_dagrun()
 
         executor = MockExecutor()
         sub_dag = dag.partial_subset(
@@ -1043,9 +1060,6 @@ class TestBackfillJob:
     def test_backfill_fill_blanks(self, dag_maker):
         with dag_maker(
             'test_backfill_fill_blanks',
-            start_date=DEFAULT_DATE,
-            default_args={'owner': 'owner1'},
-            execution_date=DEFAULT_DATE,
         ) as dag:
             op1 = DummyOperator(task_id='op1')
             op2 = DummyOperator(task_id='op2')
@@ -1054,7 +1068,7 @@ class TestBackfillJob:
             op5 = DummyOperator(task_id='op5')
             op6 = DummyOperator(task_id='op6')
 
-        dr = dag_maker.dag_run
+        dr = dag_maker.create_dagrun()
 
         executor = MockExecutor()
 
@@ -1231,11 +1245,9 @@ class TestBackfillJob:
         dag.clear()
 
     def test_update_counters(self, dag_maker):
-        with dag_maker(
-            dag_id='test_manage_executor_state', start_date=DEFAULT_DATE, execution_date=DEFAULT_DATE
-        ) as dag:
-            task1 = DummyOperator(task_id='dummy', dag=dag, owner='airflow')
-        dr = dag_maker.dag_run
+        with dag_maker(dag_id='test_manage_executor_state', start_date=DEFAULT_DATE) as dag:
+            task1 = DummyOperator(task_id='dummy', owner='airflow')
+        dr = dag_maker.create_dagrun()
         job = BackfillJob(dag=dag)
         session = settings.Session()
 
@@ -1380,9 +1392,7 @@ class TestBackfillJob:
         states_to_reset = [State.QUEUED, State.SCHEDULED, State.NONE]
 
         tasks = []
-        with dag_maker(
-            dag_id=prefix, start_date=DEFAULT_DATE, schedule_interval="@daily", run_id='test1'
-        ) as dag:
+        with dag_maker(dag_id=prefix, start_date=DEFAULT_DATE, schedule_interval="@daily") as dag:
             for i in range(len(states)):
                 task_id = f"{prefix}_task_{i}"
                 task = DummyOperator(task_id=task_id)
@@ -1392,7 +1402,7 @@ class TestBackfillJob:
         job = BackfillJob(dag=dag)
 
         # create dagruns
-        dr1 = dag_maker.dag_run
+        dr1 = dag_maker.create_dagrun()
         dr2 = dag.create_dagrun(run_id='test2', state=State.SUCCESS)
 
         # create taskinstances and set states
@@ -1445,15 +1455,13 @@ class TestBackfillJob:
             dag_id=dag_id,
             start_date=DEFAULT_DATE,
             schedule_interval='@daily',
-            state=State.SUCCESS,
-            run_id='test1',
         ) as dag:
             DummyOperator(task_id=task_id, dag=dag)
 
         job = BackfillJob(dag=dag)
         session = settings.Session()
         # make two dagruns, only reset for one
-        dr1 = dag_maker.dag_run  # Already created in dag_maker with state=SUCCESS
+        dr1 = dag_maker.create_dagrun(state=State.SUCCESS)
         dr2 = dag.create_dagrun(run_id='test2', state=State.RUNNING)
         ti1 = dr1.get_task_instances(session=session)[0]
         ti2 = dr2.get_task_instances(session=session)[0]
diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py
index 7aa596c..1d6d572 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -27,14 +27,12 @@ from unittest import mock
 from unittest.mock import patch
 
 import pytest
-from parameterized import parameterized
 
 from airflow import settings
 from airflow.exceptions import AirflowException, AirflowFailException
 from airflow.executors.sequential_executor import SequentialExecutor
 from airflow.jobs.local_task_job import LocalTaskJob
 from airflow.jobs.scheduler_job import SchedulerJob
-from airflow.models.dag import DAG, DagModel
 from airflow.models.dagbag import DagBag
 from airflow.models.taskinstance import TaskInstance
 from airflow.operators.dummy import DummyOperator
@@ -73,10 +71,19 @@ def clear_db_class():
     db.clear_db_task_fail()
 
 
+@pytest.fixture(scope='module')
+def dagbag():
+    return DagBag(
+        dag_folder=TEST_DAG_FOLDER,
+        include_examples=False,
+    )
+
+
 @pytest.mark.usefixtures('clear_db_class', 'clear_db')
 class TestLocalTaskJob:
     @pytest.fixture(autouse=True)
-    def set_instance_attrs(self):
+    def set_instance_attrs(self, dagbag):
+        self.dagbag = dagbag
         with patch('airflow.jobs.base_job.sleep') as self.mock_base_job_sleep:
             yield
 
@@ -92,12 +99,10 @@ class TestLocalTaskJob:
         of LocalTaskJob can be assigned with
         proper values without intervention
         """
-        with dag_maker(
-            'test_localtaskjob_essential_attr', start_date=DEFAULT_DATE, default_args={'owner': 'owner1'}
-        ):
+        with dag_maker('test_localtaskjob_essential_attr'):
             op1 = DummyOperator(task_id='op1')
 
-        dr = dag_maker.dag_run
+        dr = dag_maker.create_dagrun()
 
         ti = dr.get_task_instance(task_id=op1.task_id)
 
@@ -116,7 +121,7 @@ class TestLocalTaskJob:
         with dag_maker('test_localtaskjob_heartbeat'):
             op1 = DummyOperator(task_id='op1')
 
-        dr = dag_maker.dag_run
+        dr = dag_maker.create_dagrun()
         ti = dr.get_task_instance(task_id=op1.task_id, session=session)
         ti.state = State.RUNNING
         ti.hostname = "blablabla"
@@ -148,7 +153,7 @@ class TestLocalTaskJob:
         session = settings.Session()
         with dag_maker('test_localtaskjob_heartbeat'):
             op1 = DummyOperator(task_id='op1', run_as_user='myuser')
-        dr = dag_maker.dag_run
+        dr = dag_maker.create_dagrun()
         ti = dr.get_task_instance(task_id=op1.task_id, session=session)
         ti.state = State.RUNNING
         ti.pid = 2
@@ -190,7 +195,7 @@ class TestLocalTaskJob:
         session = settings.Session()
         with dag_maker('test_localtaskjob_heartbeat'):
             op1 = DummyOperator(task_id='op1')
-        dr = dag_maker.dag_run
+        dr = dag_maker.create_dagrun()
         ti = dr.get_task_instance(task_id=op1.task_id, session=session)
         ti.state = State.RUNNING
         ti.pid = 2
@@ -234,13 +239,10 @@ class TestLocalTaskJob:
         dag_id = 'test_heartbeat_failed_fast'
         task_id = 'test_heartbeat_failed_fast_op'
         with create_session() as session:
-            dagbag = DagBag(
-                dag_folder=TEST_DAG_FOLDER,
-                include_examples=False,
-            )
+
             dag_id = 'test_heartbeat_failed_fast'
             task_id = 'test_heartbeat_failed_fast_op'
-            dag = dagbag.get_dag(dag_id)
+            dag = self.dagbag.get_dag(dag_id)
             task = dag.get_task(task_id)
 
             dag.create_dagrun(
@@ -276,11 +278,7 @@ class TestLocalTaskJob:
         Test that ensures that mark_success in the UI doesn't cause
         the task to fail, and that the task exits
         """
-        dagbag = DagBag(
-            dag_folder=TEST_DAG_FOLDER,
-            include_examples=False,
-        )
-        dag = dagbag.dags.get('test_mark_success')
+        dag = self.dagbag.dags.get('test_mark_success')
         task = dag.get_task('task1')
 
         session = settings.Session()
@@ -316,11 +314,7 @@ class TestLocalTaskJob:
 
     def test_localtaskjob_double_trigger(self):
 
-        dagbag = DagBag(
-            dag_folder=TEST_DAG_FOLDER,
-            include_examples=False,
-        )
-        dag = dagbag.dags.get('test_localtaskjob_double_trigger')
+        dag = self.dagbag.dags.get('test_localtaskjob_double_trigger')
         task = dag.get_task('test_localtaskjob_double_trigger_task')
 
         session = settings.Session()
@@ -356,11 +350,8 @@ class TestLocalTaskJob:
 
     @pytest.mark.quarantined
     def test_localtaskjob_maintain_heart_rate(self):
-        dagbag = DagBag(
-            dag_folder=TEST_DAG_FOLDER,
-            include_examples=False,
-        )
-        dag = dagbag.dags.get('test_localtaskjob_double_trigger')
+
+        dag = self.dagbag.dags.get('test_localtaskjob_double_trigger')
         task = dag.get_task('test_localtaskjob_double_trigger_task')
 
         session = settings.Session()
@@ -439,6 +430,7 @@ class TestLocalTaskJob:
                 python_callable=task_function,
                 on_failure_callback=check_failure,
             )
+        dag_maker.create_dagrun()
         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
         ti.refresh_from_db()
 
@@ -480,6 +472,7 @@ class TestLocalTaskJob:
                 python_callable=task_function,
                 on_failure_callback=failure_callback,
             )
+        dag_maker.create_dagrun()
         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
         ti.refresh_from_db()
 
@@ -653,7 +646,8 @@ class TestLocalTaskJob:
         assert task_terminated_externally.value == 1
         assert not process.is_alive()
 
-    @parameterized.expand(
+    @pytest.mark.parametrize(
+        "conf, dependencies, init_state, first_run_state, second_run_state, error_message",
         [
             (
                 {('scheduler', 'schedule_after_task_execution'): 'True'},
@@ -687,27 +681,17 @@ class TestLocalTaskJob:
                 None,
                "A -> C & B -> C, when A is QUEUED but B has FAILED, C is marked UPSTREAM_FAILED.",
             ),
-        ]
+        ],
     )
     def test_fast_follow(
-        self, conf, dependencies, init_state, first_run_state, second_run_state, error_message
+        self, conf, dependencies, init_state, first_run_state, second_run_state, error_message, dag_maker
     ):
 
         with conf_vars(conf):
             session = settings.Session()
 
-            dag = DAG('test_dagrun_fast_follow', start_date=DEFAULT_DATE)
-
-            dag_model = DagModel(
-                dag_id=dag.dag_id,
-                next_dagrun=dag.start_date,
-                is_active=True,
-            )
-            session.add(dag_model)
-            session.flush()
-
             python_callable = lambda: True
-            with dag:
+            with dag_maker('test_dagrun_fast_follow') as dag:
                 task_a = PythonOperator(task_id='A', python_callable=python_callable)
                 task_b = PythonOperator(task_id='B', python_callable=python_callable)
                 task_c = PythonOperator(task_id='C', python_callable=python_callable)
@@ -716,6 +700,8 @@ class TestLocalTaskJob:
                 for upstream, downstream in dependencies.items():
                     dag.set_dependency(upstream, downstream)
 
+            dag_maker.make_dagmodel()
+
             scheduler_job = SchedulerJob(subdir=os.devnull)
             scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
 
@@ -851,34 +837,24 @@ class TestLocalTaskJob:
         assert retry_callback_called.value == 1
         assert task_terminated_externally.value == 1
 
-    def test_task_exit_should_update_state_of_finished_dagruns_with_dag_paused(self):
+    def test_task_exit_should_update_state_of_finished_dagruns_with_dag_paused(self, dag_maker):
         """Test that with DAG paused, DagRun state will update when the tasks finishes the run"""
-        dag = DAG(dag_id='test_dags', start_date=DEFAULT_DATE)
-        op1 = PythonOperator(task_id='dummy', dag=dag, owner='airflow', python_callable=lambda: True)
+        with dag_maker(dag_id='test_dags') as dag:
+            op1 = PythonOperator(task_id='dummy', python_callable=lambda: True)
 
         session = settings.Session()
-        orm_dag = DagModel(
-            dag_id=dag.dag_id,
+        dag_maker.make_dagmodel(
             has_task_concurrency_limits=False,
-            next_dagrun=dag.start_date,
             next_dagrun_create_after=dag.following_schedule(DEFAULT_DATE),
             is_active=True,
             is_paused=True,
         )
-        session.add(orm_dag)
-        session.flush()
         # Write Dag to DB
         dagbag = DagBag(dag_folder="/dev/null", include_examples=False, read_dags_from_db=False)
         dagbag.bag_dag(dag, root_dag=dag)
         dagbag.sync_to_db()
 
-        dr = dag.create_dagrun(
-            run_type=DagRunType.SCHEDULED,
-            state=State.RUNNING,
-            execution_date=DEFAULT_DATE,
-            start_date=DEFAULT_DATE,
-            session=session,
-        )
+        dr = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
 
         assert dr.state == State.RUNNING
         ti = TaskInstance(op1, dr.execution_date)
@@ -901,13 +877,12 @@ def clean_db_helper():
 class TestLocalTaskJobPerformance:
     @pytest.mark.parametrize("return_codes", [[0], 9 * [None] + [0]])  # type: ignore
     @mock.patch("airflow.jobs.local_task_job.get_task_runner")
-    def test_number_of_queries_single_loop(self, mock_get_task_runner, return_codes):
+    def test_number_of_queries_single_loop(self, mock_get_task_runner, return_codes, dag_maker):
         unique_prefix = str(uuid.uuid4())
-        dag = DAG(dag_id=f'{unique_prefix}_test_number_of_queries', start_date=DEFAULT_DATE)
-        task = DummyOperator(task_id='test_state_succeeded1', dag=dag)
+        with dag_maker(dag_id=f'{unique_prefix}_test_number_of_queries'):
+            task = DummyOperator(task_id='test_state_succeeded1')
 
-        dag.clear()
-        dag.create_dagrun(run_id=unique_prefix, execution_date=DEFAULT_DATE, state=State.NONE)
+        dag_maker.create_dagrun(run_id=unique_prefix, state=State.NONE)
 
         ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)
 

Mime
View raw message