Date: Sun, 9 Sep 2018 05:25:01 +0000 (UTC)
From: "ASF GitHub Bot (JIRA)"
To: commits@airflow.incubator.apache.org
Subject: [jira] [Commented] (AIRFLOW-867) Tons of unit tests are ignored

    [ https://issues.apache.org/jira/browse/AIRFLOW-867?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=16608317#comment-16608317 ]

ASF GitHub Bot commented on AIRFLOW-867:
----------------------------------------

r39132 closed pull request #2078: [AIRFLOW-867] Enable and fix lots of untested unit tests
URL: https://github.com/apache/incubator-airflow/pull/2078

This is a PR merged from a forked repository.
As GitHub hides the original diff of a foreign (forked) pull request on merge, it is displayed below for the sake of provenance:

diff --git a/airflow/contrib/operators/dataflow_operator.py b/airflow/contrib/operators/dataflow_operator.py
index c1dca246f0..b93ee3f184 100644
--- a/airflow/contrib/operators/dataflow_operator.py
+++ b/airflow/contrib/operators/dataflow_operator.py
@@ -217,7 +217,7 @@ def google_cloud_to_local(self, file_name):
         # Extracts bucket_id and object_id by first removing 'gs://' prefix and
         # then split the remaining by path delimiter '/'.
         path_components = file_name[self.GCS_PREFIX_LENGTH:].split('/')
-        if path_components < 2:
+        if len(path_components) < 2:
             raise Exception(
                 'Invalid Google Cloud Storage (GCS) object path: {}.'
                 .format(file_name))
diff --git a/airflow/contrib/operators/ecs_operator.py b/airflow/contrib/operators/ecs_operator.py
index 7415d32e83..50e3fd92ae 100644
--- a/airflow/contrib/operators/ecs_operator.py
+++ b/airflow/contrib/operators/ecs_operator.py
@@ -88,7 +88,7 @@ def execute(self, context):
 
     def _wait_for_task_ended(self):
         waiter = self.client.get_waiter('tasks_stopped')
-        waiter.config.max_attempts = sys.maxint  # timeout is managed by airflow
+        waiter.config.max_attempts = sys.maxsize  # timeout is managed by airflow
         waiter.wait(
             cluster=self.cluster,
             tasks=[self.arn]
diff --git a/airflow/models.py b/airflow/models.py
index 8965adeca9..264696a0f8 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -1108,7 +1108,6 @@ def are_dependencies_met(
         :param verbose: whether or not to print details on failed dependencies
         :type verbose: boolean
         """
-        dep_context = dep_context or DepContext()
         failed = False
         for dep_status in self.get_failed_dep_statuses(
                 dep_context=dep_context,
@@ -1131,7 +1130,6 @@ def get_failed_dep_statuses(
             self,
             dep_context=None,
             session=None):
-        dep_context = dep_context or DepContext()
         for dep in dep_context.deps | self.task.deps:
             for dep_status in dep.get_dep_statuses(
                     self,
diff --git a/airflow/operators/latest_only_operator.py b/airflow/operators/latest_only_operator.py
index 8b4e614f44..fc1bb2fd00 100644
--- a/airflow/operators/latest_only_operator.py
+++ b/airflow/operators/latest_only_operator.py
@@ -34,7 +34,7 @@ class LatestOnlyOperator(BaseOperator):
     def execute(self, context):
         # If the DAG Run is externally triggered, then return without
         # skipping downstream tasks
-        if context['dag_run'].external_trigger:
+        if context['dag_run'] and context['dag_run'].external_trigger:
             logging.info("""Externally triggered DAG_Run:
                          allowing execution to proceed.""")
             return
diff --git a/airflow/ti_deps/deps/base_ti_dep.py b/airflow/ti_deps/deps/base_ti_dep.py
index d73526417c..31b0b4c890 100644
--- a/airflow/ti_deps/deps/base_ti_dep.py
+++ b/airflow/ti_deps/deps/base_ti_dep.py
@@ -69,7 +69,7 @@ def _get_dep_statuses(self, ti, session, dep_context):
         raise NotImplementedError
 
     @provide_session
-    def get_dep_statuses(self, ti, session, dep_context):
+    def get_dep_statuses(self, ti, session, dep_context=None):
         """
         Wrapper around the private _get_dep_statuses method that contains some
         global checks for all dependencies.
@@ -81,6 +81,10 @@ def get_dep_statuses(self, ti, session, dep_context):
         :param dep_context: the context for which this dependency should be evaluated for
         :type dep_context: DepContext
         """
+        from airflow.ti_deps.dep_context import DepContext
+        if dep_context is None:
+            dep_context = DepContext()
+
         if self.IGNOREABLE and dep_context.ignore_all_deps:
             yield self._passing_status(
                 reason="Context specified all dependencies should be ignored.")
@@ -95,7 +99,7 @@ def get_dep_statuses(self, ti, session, dep_context):
             yield dep_status
 
     @provide_session
-    def is_met(self, ti, session, dep_context):
+    def is_met(self, ti, session, dep_context=None):
         """
         Returns whether or not this dependency is met for a given task instance. A
         dependency is considered met if all of the dependency statuses it reports are
@@ -113,7 +117,7 @@ def is_met(self, ti, session, dep_context):
             self.get_dep_statuses(ti, session, dep_context))
 
     @provide_session
-    def get_failure_reasons(self, ti, session, dep_context):
+    def get_failure_reasons(self, ti, session, dep_context=None):
         """
         Returns an iterable of strings that explain why this dependency wasn't met.
diff --git a/scripts/ci/requirements.txt b/scripts/ci/requirements.txt
index d969572ac0..ef0dd147f9 100644
--- a/scripts/ci/requirements.txt
+++ b/scripts/ci/requirements.txt
@@ -1,6 +1,7 @@
 alembic
 bcrypt
 boto
+boto3
 celery
 cgroupspy
 chartkick
@@ -9,6 +10,7 @@ coverage
 coveralls
 croniter
 cryptography
+datadog
 dill
 distributed
 docker-py
@@ -21,7 +23,9 @@ flask-cache
 flask-login==0.2.11
 Flask-WTF
 flower
+freezegun
 future
+google-api-python-client
 gunicorn
 hdfs
 hive-thrift-py
@@ -34,6 +38,7 @@ ldap3
 lxml
 markdown
 mock
+moto
 mysqlclient
 nose
 nose-exclude
diff --git a/tests/__init__.py b/tests/__init__.py
index 0c0a01b1cb..9d7677a99b 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -11,16 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from __future__ import absolute_import
-
-from .api import *
-from .configuration import *
-from .contrib import *
-from .core import *
-from .executors import *
-from .jobs import *
-from .impersonation import *
-from .models import *
-from .operators import *
-from .utils import *
diff --git a/tests/api/__init__.py b/tests/api/__init__.py
index 37d59f0d34..9d7677a99b 100644
--- a/tests/api/__init__.py
+++ b/tests/api/__init__.py
@@ -11,9 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from __future__ import absolute_import
-
-from .client import *
-from .common import *
-
diff --git a/tests/api/client/local_client.py b/tests/api/client/test_local_client.py
similarity index 100%
rename from tests/api/client/local_client.py
rename to tests/api/client/test_local_client.py
diff --git a/tests/api/common/mark_tasks.py b/tests/api/common/test_mark_tasks.py
similarity index 100%
rename from tests/api/common/mark_tasks.py
rename to tests/api/common/test_mark_tasks.py
diff --git a/tests/contrib/__init__.py b/tests/contrib/__init__.py
index ff6f9e2529..9d7677a99b 100644
--- a/tests/contrib/__init__.py
+++ b/tests/contrib/__init__.py
@@ -11,7 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from __future__ import absolute_import
-from .operators import *
-from .sensors import *
diff --git a/tests/contrib/hooks/aws_hook.py b/tests/contrib/hooks/test_aws_hook.py
similarity index 86%
rename from tests/contrib/hooks/aws_hook.py
rename to tests/contrib/hooks/test_aws_hook.py
index 6f13e58661..96bc72f729 100644
--- a/tests/contrib/hooks/aws_hook.py
+++ b/tests/contrib/hooks/test_aws_hook.py
@@ -14,24 +14,20 @@
 #
 
 import unittest
+
 import boto3
+from moto import mock_emr
 
 from airflow import configuration
 from airflow.contrib.hooks.aws_hook import AwsHook
 
-try:
-    from moto import mock_emr
-except ImportError:
-    mock_emr = None
-
-
 class TestAwsHook(unittest.TestCase):
+
+    @mock_emr
     def setUp(self):
         configuration.load_test_config()
 
-    @unittest.skipIf(mock_emr is None, 'mock_emr package not present')
     @mock_emr
     def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
         client = boto3.client('emr', region_name='us-east-1')
@@ -42,6 +38,3 @@ def test_get_client_type_returns_a_boto3_client_of_the_requested_type(self):
         client_from_hook = hook.get_client_type('emr')
 
         self.assertEqual(client_from_hook.list_clusters()['Clusters'], [])
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/tests/contrib/hooks/bigquery_hook.py b/tests/contrib/hooks/test_bigquery_hook.py
similarity index 94%
rename from tests/contrib/hooks/bigquery_hook.py
rename to tests/contrib/hooks/test_bigquery_hook.py
index 68856f8732..028316102e 100644
--- a/tests/contrib/hooks/bigquery_hook.py
+++ b/tests/contrib/hooks/test_bigquery_hook.py
@@ -104,13 +104,15 @@ def test_invalid_syntax_tiple_dot_var(self):
 
         self.assertIn('Format exception for var_x:', str(context.exception), "")
 
+
 class TestBigQueryHookSourceFormat(unittest.TestCase):
     def test_invalid_source_format(self):
         with self.assertRaises(Exception) as context:
-            hook.BigQueryBaseCursor("test", "test").run_load("test.test", "test_schema.json", ["test_data.json"], source_format="json")
-
-        # since we passed 'json' in, and it's not valid, make sure it's present in the error string.
- self.assertIn("json", str(context.exception)) + hook.BigQueryBaseCursor("test", "test").run_load("test.test", + ["test_schema.json"], + ["test_data.json"], + source_format="json") + self.assertIn("JSON", str(context.exception)) class TestBigQueryBaseCursor(unittest.TestCase): @@ -134,6 +136,3 @@ def test_invalid_schema_update_and_write_disposition(self): write_disposition='WRITE_EMPTY' ) self.assertIn("schema_update_options is only", str(context.exception)) - -if __name__ == '__main__': - unittest.main() diff --git a/tests/contrib/hooks/emr_hook.py b/tests/contrib/hooks/test_emr_hook.py similarity index 83% rename from tests/contrib/hooks/emr_hook.py rename to tests/contrib/hooks/test_emr_hook.py index 119df99a7c..a3fa439320 100644 --- a/tests/contrib/hooks/emr_hook.py +++ b/tests/contrib/hooks/test_emr_hook.py @@ -14,30 +14,25 @@ # import unittest + import boto3 +from moto import mock_emr from airflow import configuration from airflow.contrib.hooks.emr_hook import EmrHook -try: - from moto import mock_emr -except ImportError: - mock_emr = None - - class TestEmrHook(unittest.TestCase): + @mock_emr def setUp(self): configuration.load_test_config() - @unittest.skipIf(mock_emr is None, 'mock_emr package not present') @mock_emr def test_get_conn_returns_a_boto3_connection(self): hook = EmrHook(aws_conn_id='aws_default') self.assertIsNotNone(hook.get_conn().list_clusters()) - @unittest.skipIf(mock_emr is None, 'mock_emr package not present') @mock_emr def test_create_job_flow_uses_the_emr_config_to_create_a_cluster(self): client = boto3.client('emr', region_name='us-east-1') @@ -47,7 +42,5 @@ def test_create_job_flow_uses_the_emr_config_to_create_a_cluster(self): hook = EmrHook(aws_conn_id='aws_default', emr_conn_id='emr_default') cluster = hook.create_job_flow({'Name': 'test_cluster'}) - self.assertEqual(client.list_clusters()['Clusters'][0]['Id'], cluster['JobFlowId']) - -if __name__ == '__main__': - unittest.main() + self.assertEqual(client.list_clusters()['Clusters'][0]['Id'], + cluster['JobFlowId']) diff --git a/tests/contrib/hooks/test_ftp_hook.py b/tests/contrib/hooks/test_ftp_hook.py index ab6f459aa5..dbb10273a7 100644 --- a/tests/contrib/hooks/test_ftp_hook.py +++ b/tests/contrib/hooks/test_ftp_hook.py @@ -77,7 +77,3 @@ def test_rename(self): self.conn_mock.rename.assert_called_once_with(from_path, to_path) self.conn_mock.quit.assert_called_once_with() - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/contrib/hooks/gcp_dataflow_hook.py b/tests/contrib/hooks/test_gcp_dataflow_hook.py similarity index 64% rename from tests/contrib/hooks/gcp_dataflow_hook.py rename to tests/contrib/hooks/test_gcp_dataflow_hook.py index 797d40cca2..a4403d0763 100644 --- a/tests/contrib/hooks/gcp_dataflow_hook.py +++ b/tests/contrib/hooks/test_gcp_dataflow_hook.py @@ -14,16 +14,11 @@ # import unittest -from airflow.contrib.hooks.gcp_dataflow_hook import DataFlowHook - -try: - from unittest import mock -except ImportError: - try: - import mock - except ImportError: - mock = None +import mock +from airflow.models import Connection +from airflow.contrib.hooks.gcp_api_base_hook import GoogleCloudBaseHook +from airflow.contrib.hooks.gcp_dataflow_hook import DataFlowHook TASK_ID = 'test-python-dataflow' PY_FILE = 'apache_beam.examples.wordcount' @@ -32,24 +27,15 @@ 'project': 'test', 'staging_location': 'gs://test/staging' } -BASE_STRING = 'airflow.contrib.hooks.gcp_api_base_hook.{}' -DATAFLOW_STRING = 'airflow.contrib.hooks.gcp_dataflow_hook.{}' - - -def mock_init(self, gcp_conn_id, 
delegate_to=None): - pass class DataFlowHookTest(unittest.TestCase): - def setUp(self): - with mock.patch(BASE_STRING.format('GoogleCloudBaseHook.__init__'), - new=mock_init): - self.dataflow_hook = DataFlowHook(gcp_conn_id='test') - - @mock.patch(DATAFLOW_STRING.format('DataFlowHook._start_dataflow')) + @mock.patch('airflow.contrib.hooks.gcp_dataflow_hook.DataFlowHook._start_dataflow') + @mock.patch.object(GoogleCloudBaseHook, 'get_connection', Connection) def test_start_python_dataflow(self, internal_dataflow_mock): - self.dataflow_hook.start_python_dataflow( + dataflow_hook = DataFlowHook(gcp_conn_id='test') + dataflow_hook.start_python_dataflow( task_id=TASK_ID, variables=OPTIONS, dataflow=PY_FILE, py_options=PY_OPTIONS) internal_dataflow_mock.assert_called_once_with( diff --git a/tests/contrib/hooks/test_jira_hook.py b/tests/contrib/hooks/test_jira_hook.py index 977944e006..a3a848062f 100644 --- a/tests/contrib/hooks/test_jira_hook.py +++ b/tests/contrib/hooks/test_jira_hook.py @@ -23,12 +23,12 @@ from airflow import models from airflow.utils import db -jira_client_mock = Mock( - name="jira_client" -) + +jira_client_mock = Mock(name="jira_client") class TestJiraHook(unittest.TestCase): + def setUp(self): configuration.load_test_config() db.merge_conn( @@ -45,7 +45,3 @@ def test_jira_client_connection(self, jira_mock): self.assertTrue(jira_mock.called) self.assertIsInstance(jira_hook.client, Mock) self.assertEqual(jira_hook.client.name, jira_mock.return_value.name) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/contrib/operators/__init__.py b/tests/contrib/operators/__init__.py index 6e38beaf58..9d7677a99b 100644 --- a/tests/contrib/operators/__init__.py +++ b/tests/contrib/operators/__init__.py @@ -11,8 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# - -from __future__ import absolute_import -from .ssh_execute_operator import * -from .fs_operator import * diff --git a/tests/contrib/operators/dataflow_operator.py b/tests/contrib/operators/test_dataflow_operator.py similarity index 77% rename from tests/contrib/operators/dataflow_operator.py rename to tests/contrib/operators/test_dataflow_operator.py index 7455a45f18..c43952c32b 100644 --- a/tests/contrib/operators/dataflow_operator.py +++ b/tests/contrib/operators/test_dataflow_operator.py @@ -14,17 +14,10 @@ # import unittest +import mock from airflow.contrib.operators.dataflow_operator import DataFlowPythonOperator -try: - from unittest import mock -except ImportError: - try: - import mock - except ImportError: - mock = None - TASK_ID = 'test-python-dataflow' PY_FILE = 'gs://my-bucket/my-object.py' @@ -36,7 +29,6 @@ ADDITIONAL_OPTIONS = { 'output': 'gs://test/output' } -GCS_HOOK_STRING = 'airflow.contrib.operators.dataflow_operator.{}' class DataFlowPythonOperatorTest(unittest.TestCase): @@ -60,23 +52,25 @@ def test_init(self): ADDITIONAL_OPTIONS) @mock.patch('airflow.contrib.operators.dataflow_operator.DataFlowHook') - @mock.patch(GCS_HOOK_STRING.format('GoogleCloudStorageHook')) - def test_exec(self, gcs_hook, dataflow_mock): + @mock.patch('airflow.contrib.operators.dataflow_operator.GoogleCloudStorageHook') + def test_exec(self, mock_gcs_hook, mock_dataflow_hook): """Test DataFlowHook is created and the right args are passed to start_python_workflow. 
""" - start_python_hook = dataflow_mock.return_value.start_python_dataflow - gcs_download_hook = gcs_hook.return_value.download + mock_gcs_download = mock.Mock(return_value=42) + mock_gcs_hook.return_value.download = mock_gcs_download + mock_start_python = mock_dataflow_hook.return_value.start_python_dataflow + self.dataflow.execute(None) - self.assertTrue(dataflow_mock.called) + self.assertTrue(mock_gcs_hook.called) expected_options = { 'project': 'test', 'staging_location': 'gs://test/staging', 'output': 'gs://test/output' } - gcs_download_hook.assert_called_once_with( - 'my-bucket', 'my-object.py', mock.ANY) - start_python_hook.assert_called_once_with(TASK_ID, expected_options, + mock_gcs_download.assert_called_once_with('my-bucket', 'my-object.py', + mock.ANY) + mock_start_python.assert_called_once_with(TASK_ID, expected_options, mock.ANY, PY_OPTIONS) self.assertTrue(self.dataflow.py_file.startswith('/tmp/dataflow')) diff --git a/tests/contrib/operators/ecs_operator.py b/tests/contrib/operators/test_ecs_operator.py similarity index 85% rename from tests/contrib/operators/ecs_operator.py rename to tests/contrib/operators/test_ecs_operator.py index 5a593a6a6e..97abbe2285 100644 --- a/tests/contrib/operators/ecs_operator.py +++ b/tests/contrib/operators/test_ecs_operator.py @@ -17,18 +17,12 @@ import unittest from copy import deepcopy +from mock import Mock, patch + from airflow import configuration from airflow.exceptions import AirflowException from airflow.contrib.operators.ecs_operator import ECSOperator -try: - from unittest import mock -except ImportError: - try: - import mock - except ImportError: - mock = None - RESPONSE_WITHOUT_FAILURES = { "failures": [], @@ -53,7 +47,7 @@ class TestECSOperator(unittest.TestCase): - @mock.patch('airflow.contrib.operators.ecs_operator.AwsHook') + @patch('airflow.contrib.operators.ecs_operator.AwsHook') def setUp(self, aws_hook_mock): configuration.load_test_config() @@ -80,8 +74,8 @@ def test_init(self): def test_template_fields_overrides(self): self.assertEqual(self.ecs.template_fields, ('overrides',)) - @mock.patch.object(ECSOperator, '_wait_for_task_ended') - @mock.patch.object(ECSOperator, '_check_success_task') + @patch.object(ECSOperator, '_wait_for_task_ended') + @patch.object(ECSOperator, '_check_success_task') def test_execute_without_failures(self, check_mock, wait_mock): client_mock = self.aws_hook_mock.return_value.get_client_type.return_value @@ -93,7 +87,7 @@ def test_execute_without_failures(self, check_mock, wait_mock): client_mock.run_task.assert_called_once_with( cluster='c', overrides={}, - startedBy='Airflow', + startedBy='airflow', taskDefinition='t' ) @@ -115,23 +109,22 @@ def test_execute_with_failures(self): client_mock.run_task.assert_called_once_with( cluster='c', overrides={}, - startedBy='Airflow', + startedBy='airflow', taskDefinition='t' ) def test_wait_end_tasks(self): - - client_mock = mock.Mock() + client_mock = Mock() self.ecs.arn = 'arn' self.ecs.client = client_mock self.ecs._wait_for_task_ended() client_mock.get_waiter.assert_called_once_with('tasks_stopped') client_mock.get_waiter.return_value.wait.assert_called_once_with(cluster='c', tasks=['arn']) - self.assertEquals(sys.maxint, client_mock.get_waiter.return_value.config.max_attempts) + self.assertEquals(sys.maxsize, client_mock.get_waiter.return_value.config.max_attempts) def test_check_success_tasks_raises(self): - client_mock = mock.Mock() + client_mock = Mock() self.ecs.arn = 'arn' self.ecs.client = client_mock @@ -144,14 +137,14 @@ def 
test_check_success_tasks_raises(self): }] }] } - with self.assertRaises(Exception) as e: + with self.assertRaises(AirflowException) as e: self.ecs._check_success_task() - self.assertEquals(str(e.exception), "This task is not in success state {'containers': [{'lastStatus': 'STOPPED', 'name': 'foo', 'exitCode': 1}]}") + self.assertIn("This task is not in success state", str(e.exception)) client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) def test_check_success_tasks_raises_pending(self): - client_mock = mock.Mock() + client_mock = Mock() self.ecs.client = client_mock self.ecs.arn = 'arn' client_mock.describe_tasks.return_value = { @@ -164,11 +157,11 @@ def test_check_success_tasks_raises_pending(self): } with self.assertRaises(Exception) as e: self.ecs._check_success_task() - self.assertEquals(str(e.exception), "This task is still pending {'containers': [{'lastStatus': 'PENDING', 'name': 'container-name'}]}") + self.assertIn("This task is still pending", str(e.exception)) client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) def test_check_success_tasks_raises_mutliple(self): - client_mock = mock.Mock() + client_mock = Mock() self.ecs.client = client_mock self.ecs.arn = 'arn' client_mock.describe_tasks.return_value = { @@ -187,7 +180,7 @@ def test_check_success_tasks_raises_mutliple(self): client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) def test_check_success_task_not_raises(self): - client_mock = mock.Mock() + client_mock = Mock() self.ecs.client = client_mock self.ecs.arn = 'arn' client_mock.describe_tasks.return_value = { @@ -201,7 +194,3 @@ def test_check_success_task_not_raises(self): } self.ecs._check_success_task() client_mock.describe_tasks.assert_called_once_with(cluster='c', tasks=['arn']) - - -if __name__ == '__main__': - unittest.main() diff --git a/tests/contrib/operators/emr_add_steps_operator.py b/tests/contrib/operators/test_emr_add_steps_operator.py similarity index 97% rename from tests/contrib/operators/emr_add_steps_operator.py rename to tests/contrib/operators/test_emr_add_steps_operator.py index 37f9a4c6c8..2ce0aaf1fe 100644 --- a/tests/contrib/operators/emr_add_steps_operator.py +++ b/tests/contrib/operators/test_emr_add_steps_operator.py @@ -18,6 +18,7 @@ from airflow import configuration from airflow.contrib.operators.emr_add_steps_operator import EmrAddStepsOperator + ADD_STEPS_SUCCESS_RETURN = { 'ResponseMetadata': { 'HTTPStatusCode': 200 @@ -37,7 +38,6 @@ def setUp(self): # Mock out the emr_client creator self.boto3_client_mock = MagicMock(return_value=mock_emr_client) - def test_execute_adds_steps_to_the_job_flow_and_returns_step_ids(self): with patch('boto3.client', self.boto3_client_mock): @@ -48,6 +48,3 @@ def test_execute_adds_steps_to_the_job_flow_and_returns_step_ids(self): ) self.assertEqual(operator.execute(None), ['s-2LH3R5GW3A53T']) - -if __name__ == '__main__': - unittest.main() diff --git a/tests/contrib/operators/emr_create_job_flow_operator.py b/tests/contrib/operators/test_emr_create_job_flow_operator.py similarity index 97% rename from tests/contrib/operators/emr_create_job_flow_operator.py rename to tests/contrib/operators/test_emr_create_job_flow_operator.py index 4aa4cd260a..2f69da10cb 100644 --- a/tests/contrib/operators/emr_create_job_flow_operator.py +++ b/tests/contrib/operators/test_emr_create_job_flow_operator.py @@ -19,6 +19,7 @@ from airflow import configuration from airflow.contrib.operators.emr_create_job_flow_operator import EmrCreateJobFlowOperator + 
 RUN_JOB_FLOW_SUCCESS_RETURN = {
     'ResponseMetadata': {
         'HTTPStatusCode': 200
@@ -26,6 +27,7 @@
     'JobFlowId': 'j-8989898989'
 }
 
+
 class TestEmrCreateJobFlowOperator(unittest.TestCase):
     def setUp(self):
         configuration.load_test_config()
@@ -48,6 +50,3 @@ def test_execute_uses_the_emr_config_to_create_a_cluster_and_returns_job_id(self
         )
 
         self.assertEqual(operator.execute(None), 'j-8989898989')
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/tests/contrib/operators/emr_terminate_job_flow_operator.py b/tests/contrib/operators/test_emr_terminate_job_flow_operator.py
similarity index 97%
rename from tests/contrib/operators/emr_terminate_job_flow_operator.py
rename to tests/contrib/operators/test_emr_terminate_job_flow_operator.py
index 94c0124964..952e09a538 100644
--- a/tests/contrib/operators/emr_terminate_job_flow_operator.py
+++ b/tests/contrib/operators/test_emr_terminate_job_flow_operator.py
@@ -47,6 +47,3 @@ def test_execute_terminates_the_job_flow_and_does_not_error(self):
         )
 
         operator.execute(None)
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/tests/contrib/operators/fs_operator.py b/tests/contrib/operators/test_fs_operator.py
similarity index 97%
rename from tests/contrib/operators/fs_operator.py
rename to tests/contrib/operators/test_fs_operator.py
index f990157db6..d4b6fb9e2c 100644
--- a/tests/contrib/operators/fs_operator.py
+++ b/tests/contrib/operators/test_fs_operator.py
@@ -59,6 +59,3 @@ def test_simple(self):
             dag=self.dag,
         )
         task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/tests/contrib/operators/hipchat_operator.py b/tests/contrib/operators/test_hipchat_operator.py
similarity index 86%
rename from tests/contrib/operators/hipchat_operator.py
rename to tests/contrib/operators/test_hipchat_operator.py
index 65a2edd74a..4c14d853b8 100644
--- a/tests/contrib/operators/hipchat_operator.py
+++ b/tests/contrib/operators/test_hipchat_operator.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 
 import unittest
+
+import mock
 import requests
 
 from airflow.contrib.operators.hipchat_operator import \
@@ -20,20 +22,12 @@
 from airflow.exceptions import AirflowException
 from airflow import configuration
 
-try:
-    from unittest import mock
-except ImportError:
-    try:
-        import mock
-    except ImportError:
-        mock = None
-
 
 class HipChatOperatorTest(unittest.TestCase):
+
     def setUp(self):
         configuration.load_test_config()
 
-    @unittest.skipIf(mock is None, 'mock package not present')
     @mock.patch('requests.request')
     def test_execute(self, request_mock):
         resp = requests.Response()
@@ -50,7 +44,6 @@ def test_execute(self, request_mock):
 
         operator.execute(None)
 
-    @unittest.skipIf(mock is None, 'mock package not present')
     @mock.patch('requests.request')
     def test_execute_error_response(self, request_mock):
         resp = requests.Response()
@@ -68,7 +61,3 @@ def test_execute_error_response(self, request_mock):
 
         with self.assertRaises(AirflowException):
             operator.execute(None)
-
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/tests/contrib/operators/jira_operator_test.py b/tests/contrib/operators/test_jira_operator_test.py
similarity index 97%
rename from tests/contrib/operators/jira_operator_test.py
rename to tests/contrib/operators/test_jira_operator_test.py
index 6d615df004..496a4b7384 100644
--- a/tests/contrib/operators/jira_operator_test.py
+++ b/tests/contrib/operators/test_jira_operator_test.py
@@ -23,10 +23,9 @@
 from airflow import models
 from airflow.utils import db
 
+
 DEFAULT_DATE = datetime.datetime(2017, 1, 1)
-jira_client_mock = Mock(
-    name="jira_client_for_test"
-)
+jira_client_mock = Mock(name="jira_client_for_test")
 
 minimal_test_ticket = {
     "id": "911539",
@@ -43,6 +42,7 @@
 
 class TestJiraOperator(unittest.TestCase):
+
     def setUp(self):
         configuration.load_test_config()
         args = {
@@ -95,7 +95,3 @@ def test_update_issue(self, jira_mock):
 
         self.assertTrue(jira_mock.called)
         self.assertTrue(jira_mock.return_value.add_comment.called)
-
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/tests/contrib/operators/ssh_execute_operator.py b/tests/contrib/operators/test_ssh_execute_operator.py
similarity index 97%
rename from tests/contrib/operators/ssh_execute_operator.py
rename to tests/contrib/operators/test_ssh_execute_operator.py
index ef8162c2af..2d20eb205b 100644
--- a/tests/contrib/operators/ssh_execute_operator.py
+++ b/tests/contrib/operators/test_ssh_execute_operator.py
@@ -73,7 +73,3 @@ def test_with_env(self):
             dag=self.dag,
         )
         task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
-
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/tests/contrib/sensors/datadog_sensor.py b/tests/contrib/sensors/test_datadog_sensor.py
similarity index 86%
rename from tests/contrib/sensors/datadog_sensor.py
rename to tests/contrib/sensors/test_datadog_sensor.py
index 4d601e1dc1..a532c62687 100644
--- a/tests/contrib/sensors/datadog_sensor.py
+++ b/tests/contrib/sensors/test_datadog_sensor.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import unittest
-from mock import patch
+from mock import Mock, patch
 
 from airflow.contrib.sensors.datadog_sensor import DatadogSensor
 
@@ -49,14 +49,17 @@
 zero_events = []
 
 
+mock_connection = Mock(extra_dejson={'api_key': 'foo', 'app_key': 'bar'})
+patch_connection = patch('airflow.contrib.hooks.datadog_hook.DatadogHook.get_connection',
+                         return_value=mock_connection)
+
+
 class TestDatadogSensor(unittest.TestCase):
-    @patch('airflow.contrib.hooks.datadog_hook.api.Event.query')
-    @patch('airflow.contrib.sensors.datadog_sensor.api.Event.query')
-    def test_sensor_ok(self, api1, api2):
-        api1.return_value = at_least_one_event
-        api2.return_value = at_least_one_event
 
+    @patch('airflow.contrib.sensors.datadog_sensor.api.Event.query',
+           return_value=at_least_one_event)
+    @patch_connection
+    def test_sensor_ok(self, *_):
         sensor = DatadogSensor(
             task_id='test_datadog',
             datadog_conn_id='datadog_default',
@@ -69,12 +72,10 @@ def test_sensor_ok(self, api1, api2):
 
         self.assertTrue(sensor.poke({}))
 
-    @patch('airflow.contrib.hooks.datadog_hook.api.Event.query')
-    @patch('airflow.contrib.sensors.datadog_sensor.api.Event.query')
-    def test_sensor_fail(self, api1, api2):
-        api1.return_value = zero_events
-        api2.return_value = zero_events
-
+    @patch('airflow.contrib.sensors.datadog_sensor.api.Event.query',
+           return_value=[])
+    @patch_connection
+    def test_sensor_fail(self, *_):
         sensor = DatadogSensor(
             task_id='test_datadog',
             datadog_conn_id='datadog_default',
@@ -86,6 +87,3 @@ def test_sensor_fail(self, api1, api2):
             response_check=None)
 
         self.assertFalse(sensor.poke({}))
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/tests/contrib/sensors/emr_base_sensor.py b/tests/contrib/sensors/test_emr_base_sensor.py
similarity index 97%
rename from tests/contrib/sensors/emr_base_sensor.py
rename to tests/contrib/sensors/test_emr_base_sensor.py
index 0b8ad2f479..edd1ec45d7 100644
--- a/tests/contrib/sensors/emr_base_sensor.py
+++ b/tests/contrib/sensors/test_emr_base_sensor.py
@@ -20,6 +20,7 @@
 
 class TestEmrBaseSensor(unittest.TestCase):
+
     def setUp(self):
         configuration.load_test_config()
 
@@ -92,7 +93,6 @@ def state_from_response(self, response):
 
         self.assertEqual(operator.poke(None), False)
 
-
     def test_poke_raises_error_when_job_has_failed(self):
         class EmrBaseSensorSubclass(EmrBaseSensor):
             NON_TERMINAL_STATES = ['PENDING', 'RUNNING', 'CONTINUE']
@@ -115,12 +115,6 @@ def state_from_response(self, response):
         )
 
         with self.assertRaises(AirflowException) as context:
-            operator.poke(None)
-
-        self.assertTrue('EMR job failed' in context.exception)
-
-
-if __name__ == '__main__':
-    unittest.main()
+            self.assertTrue('EMR job failed' in str(context.exception))
diff --git a/tests/contrib/sensors/emr_job_flow_sensor.py b/tests/contrib/sensors/test_emr_job_flow_sensor.py
similarity index 98%
rename from tests/contrib/sensors/emr_job_flow_sensor.py
rename to tests/contrib/sensors/test_emr_job_flow_sensor.py
index f9937866f1..de0bef6bab 100644
--- a/tests/contrib/sensors/emr_job_flow_sensor.py
+++ b/tests/contrib/sensors/test_emr_job_flow_sensor.py
@@ -20,6 +20,7 @@
 from airflow import configuration
 from airflow.contrib.sensors.emr_job_flow_sensor import EmrJobFlowSensor
 
+
 DESCRIBE_CLUSTER_RUNNING_RETURN = {
     'Cluster': {
         'Applications': [
@@ -86,6 +87,7 @@
 
 class TestEmrJobFlowSensor(unittest.TestCase):
+
     def setUp(self):
         configuration.load_test_config()
 
@@ -99,7 +101,6 @@ def setUp(self):
         # Mock out the emr_client creator
         self.boto3_client_mock = MagicMock(return_value=self.mock_emr_client)
 
-
     def test_execute_calls_with_the_job_flow_id_until_it_reaches_a_terminal_state(self):
         with patch('boto3.client', self.boto3_client_mock):
 
@@ -117,7 +118,3 @@ def test_execute_calls_with_the_job_flow_id_until_it_reaches_a_terminal_state(se
 
         # make sure it was called with the job_flow_id
         self.mock_emr_client.describe_cluster.assert_called_with(ClusterId='j-8989898989')
-
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/tests/contrib/sensors/emr_step_sensor.py b/tests/contrib/sensors/test_emr_step_sensor.py
similarity index 96%
rename from tests/contrib/sensors/emr_step_sensor.py
rename to tests/contrib/sensors/test_emr_step_sensor.py
index 58ee461f81..09c4e1e151 100644
--- a/tests/contrib/sensors/emr_step_sensor.py
+++ b/tests/contrib/sensors/test_emr_step_sensor.py
@@ -16,11 +16,11 @@
 import datetime
 from dateutil.tz import tzlocal
 from mock import MagicMock, patch
-import boto3
 
 from airflow import configuration
 from airflow.contrib.sensors.emr_step_sensor import EmrStepSensor
 
+
 DESCRIBE_JOB_STEP_RUNNING_RETURN = {
     'ResponseMetadata': {
         'HTTPStatusCode': 200,
@@ -94,10 +94,8 @@ def setUp(self):
         # Mock out the emr_client creator
         self.boto3_client_mock = MagicMock(return_value=self.mock_emr_client)
 
-
     def test_execute_calls_with_the_job_flow_id_and_step_id_until_it_reaches_a_terminal_state(self):
         with patch('boto3.client', self.boto3_client_mock):
-
             operator = EmrStepSensor(
                 task_id='test_task',
                 poke_interval=1,
@@ -105,15 +103,11 @@ def test_execute_calls_with_the_job_flow_id_and_step_id_until_it_reaches_a_termi
                 step_id='s-VK57YR1Z9Z5N',
                 aws_conn_id='aws_default',
             )
-
             operator.execute(None)
 
             # make sure we called twice
             self.assertEqual(self.mock_emr_client.describe_step.call_count, 2)
 
             # make sure it was called with the job_flow_id and step_id
-            self.mock_emr_client.describe_step.assert_called_with(ClusterId='j-8989898989', StepId='s-VK57YR1Z9Z5N')
-
-
-if __name__ == '__main__':
-    unittest.main()
+            self.mock_emr_client.describe_step.assert_called_with(ClusterId='j-8989898989',
+                                                                  StepId='s-VK57YR1Z9Z5N')
diff --git a/tests/contrib/sensors/hdfs_sensors.py b/tests/contrib/sensors/test_hdfs_sensors.py
similarity index 90%
rename from tests/contrib/sensors/hdfs_sensors.py
rename to tests/contrib/sensors/test_hdfs_sensors.py
index 0e2ed0c7a6..fe2f7600d8 100644
--- a/tests/contrib/sensors/hdfs_sensors.py
+++ b/tests/contrib/sensors/test_hdfs_sensors.py
@@ -11,21 +11,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import logging
 import sys
 import unittest
 import re
 from datetime import timedelta
+
 from airflow.contrib.sensors.hdfs_sensors import HdfsSensorFolder, HdfsSensorRegex
 from airflow.exceptions import AirflowSensorTimeout
+from tests.operators.test_sensors import MockHDFSHook
 
 
 class HdfsSensorFolderTests(unittest.TestCase):
+
+    @unittest.skipIf(sys.version_info[0] == 3, "HdfsSensor won't work with python3")
     def setUp(self):
-        if sys.version_info[0] == 3:
-            raise unittest.SkipTest('HdfsSensor won\'t work with python3. No need to test anything here')
-        from tests.core import FakeHDFSHook
-        self.hook = FakeHDFSHook
         self.logger = logging.getLogger()
         self.logger.setLevel(logging.DEBUG)
@@ -44,7 +45,7 @@ def test_should_be_empty_directory(self):
             timeout=1,
             retry_delay=timedelta(seconds=1),
             poke_interval=1,
-            hook=self.hook)
+            hook=MockHDFSHook)
 
         # When
         task.execute(None)
@@ -67,7 +68,7 @@ def test_should_be_empty_directory_fail(self):
             timeout=1,
             retry_delay=timedelta(seconds=1),
             poke_interval=1,
-            hook=self.hook)
+            hook=MockHDFSHook)
 
         # When
         # Then
@@ -88,7 +89,7 @@ def test_should_be_a_non_empty_directory(self):
             timeout=1,
             retry_delay=timedelta(seconds=1),
             poke_interval=1,
-            hook=self.hook)
+            hook=MockHDFSHook)
 
         # When
         task.execute(None)
@@ -110,7 +111,7 @@ def test_should_be_non_empty_directory_fail(self):
             timeout=1,
             retry_delay=timedelta(seconds=1),
             poke_interval=1,
-            hook=self.hook)
+            hook=MockHDFSHook)
 
         # When
         # Then
@@ -119,11 +120,9 @@ def test_should_be_non_empty_directory_fail(self):
 
 class HdfsSensorRegexTests(unittest.TestCase):
+
+    @unittest.skipIf(sys.version_info[0] == 3, "HdfsSensor won't work with python3")
     def setUp(self):
-        if sys.version_info[0] == 3:
-            raise unittest.SkipTest('HdfsSensor won\'t work with python3. No need to test anything here')
-        from tests.core import FakeHDFSHook
-        self.hook = FakeHDFSHook
         self.logger = logging.getLogger()
         self.logger.setLevel(logging.DEBUG)
@@ -143,7 +142,7 @@ def test_should_match_regex(self):
             timeout=1,
             retry_delay=timedelta(seconds=1),
             poke_interval=1,
-            hook=self.hook)
+            hook=MockHDFSHook)
 
         # When
         task.execute(None)
@@ -167,7 +166,7 @@ def test_should_not_match_regex(self):
             timeout=1,
             retry_delay=timedelta(seconds=1),
             poke_interval=1,
-            hook=self.hook)
+            hook=MockHDFSHook)
 
         # When
         # Then
@@ -193,7 +192,7 @@ def test_should_match_regex_and_filesize(self):
             timeout=1,
             retry_delay=timedelta(seconds=1),
             poke_interval=1,
-            hook=self.hook)
+            hook=MockHDFSHook)
 
         # When
         task.execute(None)
@@ -218,7 +217,7 @@ def test_should_match_regex_but_filesize(self):
             timeout=1,
             retry_delay=timedelta(seconds=1),
             poke_interval=1,
-            hook=self.hook)
+            hook=MockHDFSHook)
 
         # When
         # Then
@@ -243,7 +242,7 @@ def test_should_match_regex_but_copyingext(self):
             timeout=1,
             retry_delay=timedelta(seconds=1),
             poke_interval=1,
-            hook=self.hook)
+            hook=MockHDFSHook)
 
         # When
         # Then
diff --git a/tests/contrib/sensors/jira_sensor_test.py b/tests/contrib/sensors/test_jira_sensor_test.py
similarity index 96%
rename from tests/contrib/sensors/jira_sensor_test.py
rename to tests/contrib/sensors/test_jira_sensor_test.py
index 77ca97fc59..e3b6aaaf54 100644
--- a/tests/contrib/sensors/jira_sensor_test.py
+++ b/tests/contrib/sensors/test_jira_sensor_test.py
@@ -23,11 +23,9 @@
 from airflow import models
 from airflow.utils import db
 
-DEFAULT_DATE = datetime.datetime(2017, 1, 1)
-jira_client_mock = Mock(
-    name="jira_client_for_test"
-)
 
+DEFAULT_DATE = datetime.datetime(2017, 1, 1)
+jira_client_mock = Mock(name="jira_client_for_test")
 minimal_test_ticket = {
     "id": "911539",
     "self": "https://sandbox.localhost/jira/rest/api/2/issue/911539",
@@ -43,6 +41,7 @@
 
 class TestJiraSensor(unittest.TestCase):
+
     def setUp(self):
         configuration.load_test_config()
         args = {
@@ -79,7 +78,3 @@ def test_issue_label_set(self, jira_mock):
     @staticmethod
     def field_checker_func(context, issue):
         return "test-label-1" in issue['fields']['labels']
-
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/tests/executors/__init__.py b/tests/executors/__init__.py
index f694969280..9d7677a99b 100644
--- a/tests/executors/__init__.py
+++ b/tests/executors/__init__.py
@@ -4,12 +4,10 @@
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
-# http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from .dask_executor import *
diff --git a/tests/executors/dask_executor.py b/tests/executors/test_dask_executor.py
similarity index 100%
rename from tests/executors/dask_executor.py
rename to tests/executors/test_dask_executor.py
diff --git a/tests/operators/__init__.py b/tests/operators/__init__.py
index 1fb0e5e090..9d7677a99b 100644
--- a/tests/operators/__init__.py
+++ b/tests/operators/__init__.py
@@ -11,10 +11,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from .docker_operator import *
-from .subdag_operator import *
-from .operators import *
-from .sensors import *
-from .hive_operator import *
-from .s3_to_hive_operator import *
diff --git a/tests/operators/docker_operator.py b/tests/operators/test_docker_operator.py
similarity index 92%
rename from tests/operators/docker_operator.py
rename to tests/operators/test_docker_operator.py
index cdfae94b0d..7561a116c5 100644
--- a/tests/operators/docker_operator.py
+++ b/tests/operators/test_docker_operator.py
@@ -13,26 +13,18 @@
 # limitations under the License.
 
 import unittest
+import mock
 
+from airflow.exceptions import AirflowException
 try:
     from airflow.operators.docker_operator import DockerOperator
     from docker.client import Client
 except ImportError:
     pass
 
-from airflow.exceptions import AirflowException
-
-try:
-    from unittest import mock
-except ImportError:
-    try:
-        import mock
-    except ImportError:
-        mock = None
-
 
 class DockerOperatorTestCase(unittest.TestCase):
-    @unittest.skipIf(mock is None, 'mock package not present')
+
     @mock.patch('airflow.utils.file.mkdtemp')
     @mock.patch('airflow.operators.docker_operator.Client')
     def test_execute(self, client_class_mock, mkdtemp_mock):
@@ -73,7 +65,6 @@ def test_execute(self, client_class_mock, mkdtemp_mock):
         client_mock.pull.assert_called_with('ubuntu:latest', stream=True)
         client_mock.wait.assert_called_with('some_id')
 
-    @unittest.skipIf(mock is None, 'mock package not present')
     @mock.patch('airflow.operators.docker_operator.tls.TLSConfig')
     @mock.patch('airflow.operators.docker_operator.Client')
     def test_execute_tls(self, client_class_mock, tls_class_mock):
@@ -101,7 +92,6 @@ def test_execute_tls(self, client_class_mock, tls_class_mock):
         client_class_mock.assert_called_with(base_url='https://127.0.0.1:2376',
                                              tls=tls_mock, version=None)
 
-    @unittest.skipIf(mock is None, 'mock package not present')
     @mock.patch('airflow.operators.docker_operator.Client')
     def test_execute_container_fails(self, client_class_mock):
         client_mock = mock.Mock(spec=Client)
@@ -119,7 +109,6 @@ def test_execute_container_fails(self, client_class_mock):
         with self.assertRaises(AirflowException):
             operator.execute(None)
 
-    @unittest.skipIf(mock is None, 'mock package not present')
     def test_on_kill(self):
         client_mock = mock.Mock(spec=Client)
 
@@ -130,7 +119,3 @@ def test_on_kill(self):
         operator.on_kill()
 
         client_mock.stop.assert_called_with('some_id')
-
-
-if __name__ == "__main__":
-    unittest.main()
diff --git a/tests/operators/hive_operator.py b/tests/operators/test_hive_operator.py
similarity index 99%
rename from tests/operators/hive_operator.py
rename to tests/operators/test_hive_operator.py
index 69166fd92e..2932697674 100644
--- a/tests/operators/hive_operator.py
+++ b/tests/operators/test_hive_operator.py
@@ -89,19 +89,18 @@ def test_select_conn_with_schema(self, connect_mock):
 
     def test_get_results_with_schema(self):
         from airflow.hooks.hive_hooks import HiveServer2Hook
-        from unittest.mock import MagicMock
 
         # Configure
         sql = "select 1"
         schema = "notdefault"
         hook = HiveServer2Hook()
-        cursor_mock = MagicMock(
+        cursor_mock = mock.MagicMock(
             __enter__=cursor_mock,
             __exit__=None,
             execute=None,
             fetchall=[],
         )
-        get_conn_mock = MagicMock(
+        get_conn_mock = mock.MagicMock(
             __enter__=get_conn_mock,
             __exit__=None,
             cursor=cursor_mock,
diff --git a/tests/operators/latest_only_operator.py b/tests/operators/test_latest_only_operator.py
similarity index 100%
rename from tests/operators/latest_only_operator.py
rename to tests/operators/test_latest_only_operator.py
diff --git a/tests/operators/operators.py b/tests/operators/test_operators.py
similarity index 100%
rename from tests/operators/operators.py
rename to tests/operators/test_operators.py
diff --git a/tests/operators/s3_to_hive_operator.py b/tests/operators/test_s3_to_hive_operator.py
similarity index 96%
rename from tests/operators/s3_to_hive_operator.py
rename to tests/operators/test_s3_to_hive_operator.py
index faab11e15f..075c7976e2 100644
--- a/tests/operators/s3_to_hive_operator.py
+++ b/tests/operators/test_s3_to_hive_operator.py
@@ -12,25 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import unittest
-try:
-    from unittest import mock
-except ImportError:
-    try:
-        import mock
-    except ImportError:
-        mock = None
+import bz2
+import errno
+import filecmp
+import gzip
+import itertools
 import logging
-from itertools import product
-from airflow.operators.s3_to_hive_operator import S3ToHiveTransfer
+import shutil
+import unittest
+
 from collections import OrderedDict
-from airflow.exceptions import AirflowException
 from tempfile import NamedTemporaryFile, mkdtemp
-import gzip
-import bz2
-import shutil
-import filecmp
-import errno
+
+import mock
+
+from airflow.exceptions import AirflowException
+from airflow.operators.s3_to_hive_operator import S3ToHiveTransfer
 
 
 class S3ToHiveTransferTest(unittest.TestCase):
@@ -204,12 +201,11 @@ def test__delete_top_row_and_compress(self):
         self.assertTrue(self._check_file_equality(bz2_txt_nh, fn_bz2, '.bz2'),
                         msg="bz2 Compressed file not as expected")
 
-    @unittest.skipIf(mock is None, 'mock package not present')
     @mock.patch('airflow.operators.s3_to_hive_operator.HiveCliHook')
     @mock.patch('airflow.operators.s3_to_hive_operator.S3Hook')
     def test_execute(self, mock_s3hook, mock_hiveclihook):
         # Testing txt, zip, bz2 files with and without header row
-        for test in product(['.txt', '.gz', '.bz2'], [True, False]):
+        for test in itertools.product(['.txt', '.gz', '.bz2'], [True, False]):
             ext = test[0]
             has_header = test[1]
             self.kwargs['headers'] = has_header
@@ -241,7 +237,3 @@ def test_execute(self, mock_s3hook, mock_hiveclihook):
             # Execute S3ToHiveTransfer
             s32hive = S3ToHiveTransfer(**self.kwargs)
             s32hive.execute(None)
-
-
-if __name__ == '__main__':
-    unittest.main()
diff --git a/tests/operators/sensors.py b/tests/operators/test_sensors.py
similarity index 50%
rename from tests/operators/sensors.py
rename to tests/operators/test_sensors.py
index e77216b580..58c7c187de 100644
--- a/tests/operators/sensors.py
+++ b/tests/operators/test_sensors.py
@@ -14,25 +14,165 @@
 
 import logging
-import os
 import sys
 import time
 import unittest
-
 from datetime import datetime, timedelta
 
+from mock import Mock
+
 from airflow import DAG, configuration
 from airflow.operators.sensors import HttpSensor, BaseSensorOperator, HdfsSensor
 from airflow.utils.decorators import apply_defaults
 from airflow.exceptions import (AirflowException,
                                 AirflowSensorTimeout,
                                 AirflowSkipException)
 
+
 configuration.load_test_config()
 
 DEFAULT_DATE = datetime(2015, 1, 1)
 TEST_DAG_ID = 'unit_test_dag'
 
 
+def mock_ls(path, include_toplevel=False):
+    """
+    the fake snakebite client
+    :param path: the array of path to test
+    :param include_toplevel: to return the toplevel directory info
+    :return: a list for path for the matching queries
+    """
+    p = path[0]
+
+    if p == '/datadirectory/empty_directory' and not include_toplevel:
+        return []
+
+    if p == '/datadirectory/datafile':
+        return [{'group': u'supergroup',
+                 'permission': 420,
+                 'file_type': 'f',
+                 'access_time': 1481122343796,
+                 'block_replication': 3,
+                 'modification_time': 1481122343862,
+                 'length': 0,
+                 'blocksize': 134217728,
+                 'owner': u'hdfs',
+                 'path': '/datadirectory/datafile',
+                 }]
+
+    if p == '/datadirectory/empty_directory' and include_toplevel:
+        return [{'group': u'supergroup',
+                 'permission': 493,
+                 'file_type': 'd',
+                 'access_time': 0,
+                 'block_replication': 0,
+                 'modification_time': 1481132141540,
+                 'length': 0,
+                 'blocksize': 0,
+                 'owner': u'hdfs',
+                 'path': '/datadirectory/empty_directory',
+                 }]
+
+    if p == '/datadirectory/not_empty_directory' and include_toplevel:
+        return [{'group': u'supergroup',
+                 'permission': 493,
+                 'file_type': 'd',
+                 'access_time': 0,
+                 'block_replication': 0,
+                 'modification_time': 1481132141540,
+                 'length': 0,
+                 'blocksize': 0,
+                 'owner': u'hdfs',
+                 'path': '/datadirectory/empty_directory'
+                 },
+                {'group': u'supergroup',
+                 'permission': 420,
+                 'file_type': 'f',
+                 'access_time': 1481122343796,
+                 'block_replication': 3,
+                 'modification_time': 1481122343862,
+                 'length': 0,
+                 'blocksize': 134217728,
+                 'owner': u'hdfs',
+                 'path': '/datadirectory/not_empty_directory/test_file',
+                 }]
+
+    if p == '/datadirectory/not_empty_directory':
+        return [{'group': u'supergroup',
+                 'permission': 420,
+                 'file_type': 'f',
+                 'access_time': 1481122343796,
+                 'block_replication': 3,
+                 'modification_time': 1481122343862,
+                 'length': 0,
+                 'blocksize': 134217728,
+                 'owner': u'hdfs',
+                 'path': '/datadirectory/not_empty_directory/test_file',
+                 }]
+
+    if p == '/datadirectory/regex_dir':
+        return [{'group': u'supergroup',
+                 'permission': 420,
+                 'file_type': 'f',
+                 'access_time': 1481122343796,
+                 'block_replication': 3,
+                 'modification_time': 1481122343862,
+                 'length': 12582912,
+                 'blocksize': 134217728,
+                 'owner': u'hdfs',
+                 'path': '/datadirectory/regex_dir/test1file'
+                 },
+                {'group': u'supergroup',
+                 'permission': 420,
+                 'file_type': 'f',
+                 'access_time': 1481122343796,
+                 'block_replication': 3,
+                 'modification_time': 1481122343862,
+                 'length': 12582912,
+                 'blocksize': 134217728,
+                 'owner': u'hdfs',
+                 'path': '/datadirectory/regex_dir/test2file'
+                 },
+                {'group': u'supergroup',
+                 'permission': 420,
+                 'file_type': 'f',
+                 'access_time': 1481122343796,
+                 'block_replication': 3,
+                 'modification_time': 1481122343862,
+                 'length': 12582912,
+                 'blocksize': 134217728,
+                 'owner': u'hdfs',
+                 'path': '/datadirectory/regex_dir/test3file'
+                 },
+                {'group': u'supergroup',
+                 'permission': 420,
+                 'file_type': 'f',
+                 'access_time': 1481122343796,
+                 'block_replication': 3,
+                 'modification_time': 1481122343862,
+                 'length': 12582912,
+                 'blocksize': 134217728,
+                 'owner': u'hdfs',
+                 'path': '/datadirectory/regex_dir/copying_file_1.txt._COPYING_'
+                 },
+                {'group': u'supergroup',
+                 'permission': 420,
+                 'file_type': 'f',
+                 'access_time': 1481122343796,
+                 'block_replication': 3,
+                 'modification_time': 1481122343862,
+                 'length': 12582912,
+                 'blocksize': 134217728,
+                 'owner': u'hdfs',
+                 'path': '/datadirectory/regex_dir/copying_file_3.txt.sftp'
+                 }]
+
+    raise Exception(p)
+
+
+MockSnakeBiteClient = Mock(return_value=Mock(ls=mock_ls))
+MockHDFSHook = Mock(return_value=Mock(get_conn=MockSnakeBiteClient))
+
+
 class TimeoutTestSensor(BaseSensorOperator):
     """
     Sensor that always returns the return_value provided
@@ -115,11 +255,8 @@ def resp_check(resp):
 
 class HdfsSensorTests(unittest.TestCase):
+
+    @unittest.skipIf(sys.version_info[0] == 3, "HdfsSensor won't work with python3")
     def setUp(self):
-        if sys.version_info[0] == 3:
-            raise unittest.SkipTest('HdfsSensor won\'t work with python3. No need to test anything here')
-        from tests.core import FakeHDFSHook
-        self.hook = FakeHDFSHook
         self.logger = logging.getLogger()
         self.logger.setLevel(logging.DEBUG)
@@ -136,7 +273,7 @@ def test_legacy_file_exist(self):
             timeout=1,
             retry_delay=timedelta(seconds=1),
             poke_interval=1,
-            hook=self.hook)
+            hook=MockHDFSHook)
         task.execute(None)
 
         # Then
@@ -156,7 +293,7 @@ def test_legacy_file_exist_but_filesize(self):
             file_size=20,
             retry_delay=timedelta(seconds=1),
             poke_interval=1,
-            hook=self.hook)
+            hook=MockHDFSHook)
 
         # When
         # Then
@@ -175,7 +312,7 @@ def test_legacy_file_does_not_exists(self):
             timeout=1,
             retry_delay=timedelta(seconds=1),
             poke_interval=1,
-            hook=self.hook)
+            hook=MockHDFSHook)
 
         # When
         # Then
diff --git a/tests/operators/subdag_operator.py b/tests/operators/test_subdag_operator.py
similarity index 100%
rename from tests/operators/subdag_operator.py
rename to tests/operators/test_subdag_operator.py
diff --git a/tests/configuration.py b/tests/test_configuration.py
similarity index 100%
rename from tests/configuration.py
rename to tests/test_configuration.py
diff --git a/tests/core.py b/tests/test_core.py
similarity index 90%
rename from tests/core.py
rename to tests/test_core.py
index 47a7d2b0da..676974ce2d 100644
--- a/tests/core.py
+++ b/tests/test_core.py
@@ -15,52 +15,59 @@
 from __future__ import print_function
 
 import doctest
+import multiprocessing
 import os
+import psutil
 import re
-import unittest
-import multiprocessing
-import mock
-from numpy.testing import assert_array_almost_equal
+import signal
+import socket
+import subprocess
 import tempfile
+import unittest
+from six.moves.urllib.parse import urlparse, parse_qsl
+import warnings
+
 from datetime import datetime, time, timedelta
 from email.mime.multipart import MIMEMultipart
 from email.mime.application import MIMEApplication
-import signal
 from time import sleep
-import warnings
 
+import mock
 from dateutil.relativedelta import relativedelta
-import sqlalchemy
+from freezegun import freeze_time
+from lxml import html
+from numpy.testing import assert_array_almost_equal
+from six import StringIO
+from six.moves import cPickle as pickle
+from sqlalchemy.engine import Engine
 
 from airflow import configuration
 from airflow.executors import SequentialExecutor, LocalExecutor
 from airflow.models import Variable
-from tests.test_utils.fake_datetime import FakeDatetime
-
 configuration.load_test_config()
+
 from airflow import jobs, models, DAG, utils, macros, settings, exceptions
+from airflow.bin import cli
+from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
+
+from airflow.hooks.base_hook import BaseHook
+from airflow.hooks.sqlite_hook import SqliteHook
+
+from airflow.operators import sensors
 from airflow.operators.bash_operator import BashOperator
 from airflow.operators.check_operator import CheckOperator, ValueCheckOperator
 from airflow.operators.dagrun_operator import TriggerDagRunOperator
 from airflow.operators.python_operator import PythonOperator
 from airflow.operators.dummy_operator import DummyOperator
 from airflow.operators.http_operator import SimpleHttpOperator
-from airflow.operators import sensors
-from airflow.hooks.base_hook import BaseHook
-from airflow.hooks.sqlite_hook import SqliteHook
-from airflow.hooks.postgres_hook import PostgresHook
-from airflow.bin import cli
-from airflow.www import app as application
-from airflow.settings import Session
+from airflow.operators.sqlite_operator import SqliteOperator
+
 from airflow.utils.state import State
 from airflow.utils.dates import infer_time_unit, round_time, scale_time_units
 from airflow.utils.logging import LoggingMixin
-from lxml import html
-from airflow.exceptions import AirflowException
-from airflow.configuration import AirflowConfigException
+from airflow.www import app as application
 
-import six
 
 NUM_EXAMPLE_DAGS = 18
 DEV_NULL = '/dev/null'
@@ -72,15 +79,8 @@
 
 TEST_DAG_ID = 'unit_tests'
 
-try:
-    import cPickle as pickle
-except ImportError:
-    # Python 3
-    import pickle
-
-
 def reset(dag_id=TEST_DAG_ID):
-    session = Session()
+    session = settings.Session()
     tis = session.query(models.TaskInstance).filter_by(dag_id=dag_id)
     tis.delete()
     session.commit()
@@ -250,7 +250,7 @@ def test_schedule_dag_start_end_dates(self):
 
         self.assertIsNone(additional_dag_run)
 
-    @mock.patch('airflow.jobs.datetime', FakeDatetime)
+    @freeze_time('2016-01-01')
     def test_schedule_dag_no_end_date_up_to_today_only(self):
         """
         Tests that a Dag created without an end_date can only be scheduled up
@@ -260,9 +260,6 @@ def test_schedule_dag_no_end_date_up_to_today_only(self):
         start_date of 2015-01-01, only jobs up to, but not including
         2016-01-01 should be scheduled.
         """
-        from datetime import datetime
-        FakeDatetime.now = classmethod(lambda cls: datetime(2016, 1, 1))
-
         session = settings.Session()
         delta = timedelta(days=1)
         start_date = DEFAULT_DATE
@@ -417,8 +414,6 @@ def test_bash_operator_multi_byte_output(self):
         t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
 
     def test_bash_operator_kill(self):
-        import subprocess
-        import psutil
         sleep_time = "100%d" % os.getpid()
         t = BashOperator(
             task_id='test_bash_operator_kill',
@@ -458,8 +453,7 @@ def test_dryrun(self):
         t.dry_run()
 
     def test_sqlite(self):
-        import airflow.operators.sqlite_operator
-        t = airflow.operators.sqlite_operator.SqliteOperator(
+        t = SqliteOperator(
             task_id='time_sqlite',
             sql="CREATE TABLE IF NOT EXISTS unitest (dummy VARCHAR(20))",
             dag=self.dag)
@@ -663,9 +657,14 @@ def test_local_task_job(self):
         job = jobs.LocalTaskJob(task_instance=ti, ignore_ti_state=True)
         job.run()
 
-    @mock.patch('airflow.utils.dag_processing.datetime', FakeDatetime)
+    # XXX: For some obscure reason, @freeze_time() causes test_scheduler_job to
+    # hang on Python 3 using LocalExecutor. So work around it by mocking and
+    # patching just the necessary datetime.now instances
+    mock_datetime = mock.Mock(side_effect=datetime,
+                              now=mock.Mock(return_value=datetime(2016, 1, 1)))
+    @mock.patch('airflow.jobs.datetime', mock_datetime)
+    @mock.patch('airflow.utils.dag_processing.datetime', mock_datetime)
     def test_scheduler_job(self):
-        FakeDatetime.now = classmethod(lambda cls: datetime(2016, 1, 1))
         job = jobs.SchedulerJob(dag_id='example_bash_operator',
                                 **self.default_scheduler_args)
         job.run()
@@ -764,7 +763,7 @@ def test_config_throw_error_when_original_and_fallback_is_absent(self):
         FERNET_KEY = configuration.get("core", "FERNET_KEY")
         configuration.remove_option("core", "FERNET_KEY")
 
-        with self.assertRaises(AirflowConfigException) as cm:
+        with self.assertRaises(exceptions.AirflowConfigException) as cm:
             configuration.get("core", "FERNET_KEY")
 
         exception = str(cm.exception)
@@ -806,7 +805,7 @@ def test_class_with_logger_should_have_logger_with_correct_name(self):
         class Blah(LoggingMixin):
             pass
 
-        self.assertEqual("tests.core.Blah", Blah().logger.name)
+        self.assertEqual("tests.test_core.Blah", Blah().logger.name)
         self.assertEqual("airflow.executors.sequential_executor.SequentialExecutor",
                          SequentialExecutor().logger.name)
         self.assertEqual("airflow.executors.local_executor.LocalExecutor",
                          LocalExecutor().logger.name)
@@ -1070,8 +1069,7 @@ def test_cli_initdb(self):
         cli.initdb(self.parser.parse_args(['initdb']))
 
     def test_cli_connections_list(self):
-        with mock.patch('sys.stdout',
-                        new_callable=six.StringIO) as mock_stdout:
+        with mock.patch('sys.stdout', new_callable=StringIO) as mock_stdout:
             cli.connections(self.parser.parse_args(['connections', '--list']))
             stdout = mock_stdout.getvalue()
             conns = [[x.strip("'") for x in re.findall("'\w+'", line)[:2]]
@@ -1090,8 +1088,7 @@ def test_cli_connections_list(self):
         self.assertIn(['postgres_default', 'postgres'], conns)
 
         # Attempt to list connections with invalid cli args
-        with mock.patch('sys.stdout',
-                        new_callable=six.StringIO) as mock_stdout:
+        with mock.patch('sys.stdout', new_callable=StringIO) as mock_stdout:
             cli.connections(self.parser.parse_args(
                 ['connections', '--list', '--conn_id=fake',
                  '--conn_uri=fake-uri']))
@@ -1107,8 +1104,7 @@ def test_cli_connections_list(self):
     def test_cli_connections_add_delete(self):
         # Add connections:
         uri = 'postgresql://airflow:airflow@host:5432/airflow'
-        with mock.patch('sys.stdout',
-                        new_callable=six.StringIO) as mock_stdout:
+        with mock.patch('sys.stdout', new_callable=StringIO) as mock_stdout:
             cli.connections(self.parser.parse_args(
                 ['connections', '--add', '--conn_id=new1',
                  '--conn_uri=%s' % uri]))
@@ -1137,8 +1133,7 @@ def test_cli_connections_add_delete(self):
         ])
 
         # Attempt to add duplicate
-        with mock.patch('sys.stdout',
-                        new_callable=six.StringIO) as mock_stdout:
+        with mock.patch('sys.stdout', new_callable=StringIO) as mock_stdout:
             cli.connections(self.parser.parse_args(
                 ['connections', '--add', '--conn_id=new1',
                  '--conn_uri=%s' % uri]))
@@ -1151,8 +1146,7 @@ def test_cli_connections_add_delete(self):
         ])
 
         # Attempt to add without providing conn_id
-        with mock.patch('sys.stdout',
-                        new_callable=six.StringIO) as mock_stdout:
+        with mock.patch('sys.stdout', new_callable=StringIO) as mock_stdout:
             cli.connections(self.parser.parse_args(
                 ['connections', '--add', '--conn_uri=%s' % uri]))
             stdout = mock_stdout.getvalue()
@@ -1165,8 +1159,7 @@ def test_cli_connections_add_delete(self):
         ])
 
         # Attempt to add without providing conn_uri
-        with mock.patch('sys.stdout',
-                        new_callable=six.StringIO) as mock_stdout:
+        with mock.patch('sys.stdout', new_callable=StringIO) as mock_stdout:
             cli.connections(self.parser.parse_args(
                 ['connections', '--add', '--conn_id=new']))
             stdout = mock_stdout.getvalue()
@@ -1197,8 +1190,7 @@ def test_cli_connections_add_delete(self):
                 extra[conn_id]))
 
         # Delete connections
-        with mock.patch('sys.stdout',
-                        new_callable=six.StringIO) as mock_stdout:
+        with mock.patch('sys.stdout', new_callable=StringIO) as mock_stdout:
             cli.connections(self.parser.parse_args(
                 ['connections', '--delete', '--conn_id=new1']))
             cli.connections(self.parser.parse_args(
@@ -1228,8 +1220,7 @@ def test_cli_connections_add_delete(self):
         self.assertTrue(result is None)
 
         # Attempt to delete a non-existing connnection
-        with mock.patch('sys.stdout',
-                        new_callable=six.StringIO) as mock_stdout:
+        with mock.patch('sys.stdout', new_callable=StringIO) as mock_stdout:
             cli.connections(self.parser.parse_args(
                 ['connections', '--delete', '--conn_id=fake']))
             stdout = mock_stdout.getvalue()
@@ -1241,8 +1232,7 @@ def test_cli_connections_add_delete(self):
         ])
 
         # Attempt to delete with invalid cli args
-        with mock.patch('sys.stdout',
-                        new_callable=six.StringIO) as mock_stdout:
+        with mock.patch('sys.stdout', new_callable=StringIO) as mock_stdout:
             cli.connections(self.parser.parse_args(
                 ['connections', '--delete', '--conn_id=fake',
                  '--conn_uri=%s' % uri]))
@@ -1407,6 +1397,7 @@ def test_variables(self):
         os.remove('variables1.json')
         os.remove('variables2.json')
 
+
 class CSRFTests(unittest.TestCase):
     def setUp(self):
         configuration.load_test_config()
@@ -1446,6 +1437,7 @@ def tearDown(self):
         configuration.conf.set("webserver", "expose_config", "False")
         self.dag_bash.clear(start_date=DEFAULT_DATE, end_date=datetime.now())
 
+
 class WebUiTests(unittest.TestCase):
     def setUp(self):
         configuration.load_test_config()
@@ -1610,7 +1602,7 @@ def test_dag_views(self):
         self.assertIn("Xcoms", response.data.decode('utf-8'))
 
     def test_charts(self):
-        session = Session()
+        session = settings.Session()
         chart_label = "Airflow task instance by type"
         chart = session.query(
             models.Chart).filter(models.Chart.label == chart_label).first()
@@ -1639,7 +1631,7 @@ def test_fetch_task_instance(self):
     def tearDown(self):
         configuration.conf.set("webserver", "expose_config", "False")
         self.dag_bash.clear(start_date=DEFAULT_DATE, end_date=datetime.now())
-        session = Session()
+        session = settings.Session()
         session.query(models.DagRun).delete()
         session.query(models.TaskInstance).delete()
         session.commit()
@@ -1649,14 +1641,15 @@
 class WebPasswordAuthTest(unittest.TestCase):
     def setUp(self):
         configuration.conf.set("webserver", "authenticate", "True")
-        configuration.conf.set("webserver", "auth_backend", "airflow.contrib.auth.backends.password_auth")
+        configuration.conf.set("webserver", "auth_backend",
+                               "airflow.contrib.auth.backends.password_auth")
 
         app = application.create_app()
         app.config['TESTING'] = True
         self.app = app.test_client()
 
         from airflow.contrib.auth.backends.password_auth import PasswordUser
-        session = Session()
+        session = settings.Session()
         user = models.User()
         password_user = PasswordUser(user)
         password_user.username = 'airflow_passwordauth'
@@ -1706,7 +1699,7 @@ def test_unauthorized_password_auth(self):
 
     def tearDown(self):
         configuration.load_test_config()
-        session = Session()
+        session = settings.Session()
         session.query(models.User).delete()
         session.commit()
         session.close()
@@ -1716,7 +1709,8 @@
 class WebLdapAuthTest(unittest.TestCase):
     def setUp(self):
         configuration.conf.set("webserver",
"authenticate", "True") - configuration.conf.set("webserver", "auth_backend", "airflow.contrib.auth.backends.ldap_auth") + configuration.conf.set("webserver", "auth_backend", + "airflow.contrib.auth.backends.ldap_auth") try: configuration.conf.add_section("ldap") except: @@ -1790,7 +1784,7 @@ def test_with_filters(self): def tearDown(self): configuration.load_test_config() - session = Session() + session = settings.Session() session.query(models.User).delete() session.commit() session.close() @@ -1800,7 +1794,8 @@ def tearDown(self): class LdapGroupTest(unittest.TestCase): def setUp(self): configuration.conf.set("webserver", "authenticate", "True") - configuration.conf.set("webserver", "auth_backend", "airflow.contrib.auth.backends.ldap_auth") + configuration.conf.set("webserver", "auth_backend", + "airflow.contrib.auth.backends.ldap_auth") try: configuration.conf.add_section("ldap") except: @@ -1815,12 +1810,9 @@ def setUp(self): def test_group_belonging(self): from airflow.contrib.auth.backends.ldap_auth import LdapUser - users = {"user1": ["group1", "group3"], - "user2": ["group2"] - } + users = {"user1": ["group1", "group3"], "user2": ["group2"]} for user in users: - mu = models.User(username=user, - is_superuser=False) + mu = models.User(username=user, is_superuser=False) auth = LdapUser(mu) self.assertEqual(set(users[user]), set(auth.ldap_groups)) @@ -1829,30 +1821,26 @@ def tearDown(self): configuration.conf.set("webserver", "authenticate", "False") -class FakeSession(object): - def __init__(self): - from requests import Response - self.response = Response() - self.response.status_code = 200 - self.response._content = 'airbnb/airflow'.encode('ascii', 'ignore') - - def send(self, request, **kwargs): - return self.response +def mock_session_send(request, **kwargs): + text = 'airbnb/airflow' + params = dict(parse_qsl(urlparse(request.url).query)) + if 'date' in params: + text += '/' + params['date'] + return mock.Mock(text=text) - def prepare_request(self, request): - if 'date' in request.params: - self.response._content += ( - '/' + request.params['date']).encode('ascii', 'ignore') - return self.response class HttpOpSensorTest(unittest.TestCase): + + patch_session_send = mock.patch('requests.Session.send', + mock.Mock(side_effect=mock_session_send)) + def setUp(self): configuration.load_test_config() args = {'owner': 'airflow', 'start_date': DEFAULT_DATE_ISO} dag = DAG(TEST_DAG_ID, default_args=args) self.dag = dag - @mock.patch('requests.Session', FakeSession) + @patch_session_send def test_get(self): t = SimpleHttpOperator( task_id='get_op', @@ -1863,21 +1851,20 @@ def test_get(self): dag=self.dag) t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - @mock.patch('requests.Session', FakeSession) + @patch_session_send def test_get_response_check(self): t = SimpleHttpOperator( task_id='get_op', method='GET', endpoint='/search', data={"client": "ubuntu", "q": "airflow"}, - response_check=lambda response: ("airbnb/airflow" in response.text), + response_check=lambda response: "airbnb/airflow" in response.text, headers={}, dag=self.dag) t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) - @mock.patch('requests.Session', FakeSession) + @patch_session_send def test_sensor(self): - sensor = sensors.HttpSensor( task_id='http_sensor_check', http_conn_id='http_default', @@ -1892,87 +1879,6 @@ def test_sensor(self): dag=self.dag) sensor.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) -class FakeWebHDFSHook(object): - def 
__init__(self, conn_id): - self.conn_id = conn_id - - def get_conn(self): - return self.conn_id - - def check_for_path(self, hdfs_path): - return hdfs_path - - -class FakeSnakeBiteClientException(Exception): - pass - - -class FakeSnakeBiteClient(object): - - def __init__(self): - self.started = True - - def ls(self, path, include_toplevel=False): - """ - the fake snakebite client - :param path: the array of path to test - :param include_toplevel: to return the toplevel directory info - :return: a list for path for the matching queries - """ - if path[0] == '/datadirectory/empty_directory' and not include_toplevel: - return [] - elif path[0] == '/datadirectory/datafile': - return [{'group': u'supergroup', 'permission': 420, 'file_type': 'f', 'access_time': 1481122343796, - 'block_replication': 3, 'modification_time': 1481122343862, 'length': 0, 'blocksize': 134217728, - 'owner': u'hdfs', 'path': '/datadirectory/datafile'}] - elif path[0] == '/datadirectory/empty_directory' and include_toplevel: - return [ - {'group': u'supergroup', 'permission': 493, 'file_type': 'd', 'access_time': 0, 'block_replication': 0, - 'modification_time': 1481132141540, 'length': 0, 'blocksize': 0, 'owner': u'hdfs', - 'path': '/datadirectory/empty_directory'}] - elif path[0] == '/datadirectory/not_empty_directory' and include_toplevel: - return [ - {'group': u'supergroup', 'permission': 493, 'file_type': 'd', 'access_time': 0, 'block_replication': 0, - 'modification_time': 1481132141540, 'length': 0, 'blocksize': 0, 'owner': u'hdfs', - 'path': '/datadirectory/empty_directory'}, - {'group': u'supergroup', 'permission': 420, 'file_type': 'f', 'access_time': 1481122343796, - 'block_replication': 3, 'modification_time': 1481122343862, 'length': 0, 'blocksize': 134217728, - 'owner': u'hdfs', 'path': '/datadirectory/not_empty_directory/test_file'}] - elif path[0] == '/datadirectory/not_empty_directory': - return [{'group': u'supergroup', 'permission': 420, 'file_type': 'f', 'access_time': 1481122343796, - 'block_replication': 3, 'modification_time': 1481122343862, 'length': 0, 'blocksize': 134217728, - 'owner': u'hdfs', 'path': '/datadirectory/not_empty_directory/test_file'}] - elif path[0] == '/datadirectory/not_existing_file_or_directory': - raise FakeSnakeBiteClientException - elif path[0] == '/datadirectory/regex_dir': - return [{'group': u'supergroup', 'permission': 420, 'file_type': 'f', 'access_time': 1481122343796, - 'block_replication': 3, 'modification_time': 1481122343862, 'length': 12582912, 'blocksize': 134217728, - 'owner': u'hdfs', 'path': '/datadirectory/regex_dir/test1file'}, - {'group': u'supergroup', 'permission': 420, 'file_type': 'f', 'access_time': 1481122343796, - 'block_replication': 3, 'modification_time': 1481122343862, 'length': 12582912, 'blocksize': 134217728, - 'owner': u'hdfs', 'path': '/datadirectory/regex_dir/test2file'}, - {'group': u'supergroup', 'permission': 420, 'file_type': 'f', 'access_time': 1481122343796, - 'block_replication': 3, 'modification_time': 1481122343862, 'length': 12582912, 'blocksize': 134217728, - 'owner': u'hdfs', 'path': '/datadirectory/regex_dir/test3file'}, - {'group': u'supergroup', 'permission': 420, 'file_type': 'f', 'access_time': 1481122343796, - 'block_replication': 3, 'modification_time': 1481122343862, 'length': 12582912, 'blocksize': 134217728, - 'owner': u'hdfs', 'path': '/datadirectory/regex_dir/copying_file_1.txt._COPYING_'}, - {'group': u'supergroup', 'permission': 420, 'file_type': 'f', 'access_time': 1481122343796, - 'block_replication': 3, 
'modification_time': 1481122343862, 'length': 12582912, 'blocksize': 134217728, - 'owner': u'hdfs', 'path': '/datadirectory/regex_dir/copying_file_3.txt.sftp'} - ] - else: - raise FakeSnakeBiteClientException - - -class FakeHDFSHook(object): - def __init__(self, conn_id=None): - self.conn_id = conn_id - - def get_conn(self): - client = FakeSnakeBiteClient() - return client - class ConnectionTest(unittest.TestCase): def setUp(self): @@ -2032,7 +1938,8 @@ def test_env_var_priority(self): def test_dbapi_get_uri(self): conn = BaseHook.get_connection(conn_id='test_uri') hook = conn.get_hook() - self.assertEqual('postgres://username:password@ec2.compute.com:5432/the_database', hook.get_uri()) + self.assertEqual('postgres://username:password@ec2.compute.com:5432/the_database', + hook.get_uri()) conn2 = BaseHook.get_connection(conn_id='test_uri_no_creds') hook2 = conn2.get_hook() self.assertEqual('postgres://ec2.compute.com/the_database', hook2.get_uri()) @@ -2041,8 +1948,9 @@ def test_dbapi_get_sqlalchemy_engine(self): conn = BaseHook.get_connection(conn_id='test_uri') hook = conn.get_hook() engine = hook.get_sqlalchemy_engine() - self.assertIsInstance(engine, sqlalchemy.engine.Engine) - self.assertEqual('postgres://username:password@ec2.compute.com:5432/the_database', str(engine.url)) + self.assertIsInstance(engine, Engine) + self.assertEqual('postgres://username:password@ec2.compute.com:5432/the_database', + str(engine.url)) class WebHDFSHookTest(unittest.TestCase): @@ -2106,8 +2014,6 @@ def test_remote_cmd(self): def test_tunnel(self): print("Setting up remote listener") - import subprocess - import socket self.handle = self.hook.Popen([ "python", "-c", '"{0}"'.format(HELLO_SERVER_CMD) @@ -2147,7 +2053,7 @@ def test_default_backend(self, mock_send_email): @mock.patch('airflow.utils.email.send_email_smtp') def test_custom_backend(self, mock_send_email): - configuration.set('email', 'EMAIL_BACKEND', 'tests.core.send_email_test') + configuration.set('email', 'EMAIL_BACKEND', 'tests.test_core.send_email_test') utils.email.send_email('to', 'subject', 'content') send_email_test.assert_called_with('to', 'subject', 'content', files=None, dryrun=False, cc=None, bcc=None, mime_subtype='mixed') self.assertFalse(mock_send_email.called) @@ -2249,6 +2155,3 @@ def test_send_mime_dryrun(self, mock_smtp, mock_smtp_ssl): utils.email.send_MIME_email('from', 'to', MIMEMultipart(), dryrun=True) self.assertFalse(mock_smtp.called) self.assertFalse(mock_smtp_ssl.called) - -if __name__ == '__main__': - unittest.main() diff --git a/tests/impersonation.py b/tests/test_impersonation.py similarity index 100% rename from tests/impersonation.py rename to tests/test_impersonation.py diff --git a/tests/jobs.py b/tests/test_jobs.py similarity index 90% rename from tests/jobs.py rename to tests/test_jobs.py index 71470e3f0f..b4d9c4fcd3 100644 --- a/tests/jobs.py +++ b/tests/test_jobs.py @@ -17,41 +17,36 @@ from __future__ import print_function from __future__ import unicode_literals -import datetime import logging import os import shutil import unittest -import six -import sys + +from datetime import datetime, timedelta from tempfile import mkdtemp +import mock +import six + from airflow import AirflowException, settings, models from airflow.bin import cli from airflow.jobs import BackfillJob, SchedulerJob -from airflow.models import DAG, DagModel, DagBag, DagRun, Pool, TaskInstance as TI +from airflow.models import DAG, DagModel, DagBag, DagRun, Pool, TaskInstance from airflow.operators.dummy_operator import DummyOperator from 
airflow.operators.bash_operator import BashOperator from airflow.utils.db import provide_session from airflow.utils.state import State from airflow.utils.timeout import timeout from airflow.utils.dag_processing import SimpleDagBag -from mock import patch + from tests.executors.test_executor import TestExecutor + from airflow import configuration configuration.load_test_config() -try: - from unittest import mock -except ImportError: - try: - import mock - except ImportError: - mock = None - DEV_NULL = '/dev/null' -DEFAULT_DATE = datetime.datetime(2016, 1, 1) +DEFAULT_DATE = datetime(2016, 1, 1) # Include the words "airflow" and "dag" in the file contents, tricking airflow into thinking these # files contain a DAG (otherwise Airflow will skip them) @@ -86,8 +81,7 @@ def test_trigger_controller_dag(self): dag=dag, start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, - ignore_first_depends_on_past=True - ) + ignore_first_depends_on_past=True) job.run() scheduler = SchedulerJob() @@ -101,27 +95,26 @@ def test_trigger_controller_dag(self): @unittest.skipIf('sqlite' in configuration.get('core', 'sql_alchemy_conn'), "concurrent access not supported in sqlite") def test_backfill_multi_dates(self): + session = settings.Session() + session.query(models.DagRun).delete() dag = self.dagbag.get_dag('example_bash_operator') dag.clear() job = BackfillJob( dag=dag, start_date=DEFAULT_DATE, - end_date=DEFAULT_DATE + datetime.timedelta(days=1), - ignore_first_depends_on_past=True - ) + end_date=DEFAULT_DATE + timedelta(days=1), + ignore_first_depends_on_past=True) job.run() - session = settings.Session() drs = session.query(DagRun).filter( DagRun.dag_id=='example_bash_operator' ).order_by(DagRun.execution_date).all() - self.assertTrue(drs[0].execution_date == DEFAULT_DATE) - self.assertTrue(drs[0].state == State.SUCCESS) - self.assertTrue(drs[1].execution_date == - DEFAULT_DATE + datetime.timedelta(days=1)) - self.assertTrue(drs[1].state == State.SUCCESS) + self.assertEqual(drs[0].execution_date, DEFAULT_DATE) + self.assertEqual(drs[0].state, State.SUCCESS) + self.assertEqual(drs[1].execution_date, DEFAULT_DATE + timedelta(days=1)) + self.assertEqual(drs[1].state, State.SUCCESS) dag.clear() session.close() @@ -186,7 +179,7 @@ def test_backfill_pooled_tasks(self): with timeout(seconds=30): job.run() - ti = TI( + ti = TaskInstance( task=dag.get_task('test_backfill_pooled_task'), execution_date=DEFAULT_DATE) ti.refresh_from_db() @@ -198,7 +191,7 @@ def test_backfill_depends_on_past(self): """ dag = self.dagbag.get_dag('test_depends_on_past') dag.clear() - run_date = DEFAULT_DATE + datetime.timedelta(days=5) + run_date = DEFAULT_DATE + timedelta(days=5) # backfill should deadlock self.assertRaisesRegexp( @@ -213,7 +206,7 @@ def test_backfill_depends_on_past(self): ignore_first_depends_on_past=True).run() # ti should have succeeded - ti = TI(dag.tasks[0], run_date) + ti = TaskInstance(dag.tasks[0], run_date) ti.refresh_from_db() self.assertEquals(ti.state, State.SUCCESS) @@ -222,7 +215,7 @@ def test_cli_backfill_depends_on_past(self): Test that CLI respects -I argument """ dag_id = 'test_dagrun_states_deadlock' - run_date = DEFAULT_DATE + datetime.timedelta(days=1) + run_date = DEFAULT_DATE + timedelta(days=1) args = [ 'backfill', dag_id, @@ -240,7 +233,7 @@ def test_cli_backfill_depends_on_past(self): self.parser.parse_args(args)) cli.backfill(self.parser.parse_args(args + ['-I'])) - ti = TI(dag.get_task('test_depends_on_past'), run_date) + ti = TaskInstance(dag.get_task('test_depends_on_past'), run_date) 
ti.refresh_from_db() # task ran self.assertEqual(ti.state, State.SUCCESS) @@ -311,7 +304,7 @@ def evaluate_dagrun( # test tasks for task_id, expected_state in expected_task_states.items(): task = dag.get_task(task_id) - ti = TI(task, ex_date) + ti = TaskInstance(task, ex_date) ti.refresh_from_db() self.assertEqual(ti.state, expected_state) @@ -409,8 +402,8 @@ def test_scheduler_start_date(self): # zero tasks ran session = settings.Session() - self.assertEqual( - len(session.query(TI).filter(TI.dag_id == dag_id).all()), 0) + self.assertEqual(len(session.query(TaskInstance). + filter(TaskInstance.dag_id == dag_id).all()), 0) # previously, running this backfill would kick off the Scheduler # because it would take the most recent run and start from there @@ -424,8 +417,8 @@ def test_scheduler_start_date(self): # one task ran session = settings.Session() - self.assertEqual( - len(session.query(TI).filter(TI.dag_id == dag_id).all()), 1) + self.assertEqual(len(session.query(TaskInstance). + filter(TaskInstance.dag_id == dag_id).all()), 1) scheduler = SchedulerJob(dag_id, num_runs=2, @@ -434,8 +427,8 @@ def test_scheduler_start_date(self): # still one task session = settings.Session() - self.assertEqual( - len(session.query(TI).filter(TI.dag_id == dag_id).all()), 1) + self.assertEqual(len(session.query(TaskInstance). + filter(TaskInstance.dag_id == dag_id).all()), 1) def test_scheduler_multiprocessing(self): """ @@ -455,8 +448,8 @@ def test_scheduler_multiprocessing(self): # zero tasks ran dag_id = 'test_start_date_scheduling' session = settings.Session() - self.assertEqual( - len(session.query(TI).filter(TI.dag_id == dag_id).all()), 0) + self.assertEqual(len(session.query(TaskInstance). + filter(TaskInstance.dag_id == dag_id).all()), 0) def test_scheduler_dagrun_once(self): """ @@ -465,7 +458,7 @@ """ dag = DAG( 'test_scheduler_dagrun_once', - start_date=datetime.datetime(2015, 1, 1), + start_date=datetime(2015, 1, 1), schedule_interval="@once") scheduler = SchedulerJob() @@ -539,7 +532,7 @@ def test_scheduler_do_not_schedule_removed_task(self): def test_scheduler_do_not_schedule_too_early(self): dag = DAG( dag_id='test_scheduler_do_not_schedule_too_early', - start_date=datetime.datetime(2200, 1, 1)) + start_date=datetime(2200, 1, 1)) dag_task1 = DummyOperator( task_id='dummy', dag=dag, @@ -669,7 +662,7 @@ def test_scheduler_fail_dagrun_timeout(self): dag = DAG( dag_id='test_scheduler_fail_dagrun_timeout', start_date=DEFAULT_DATE) - dag.dagrun_timeout = datetime.timedelta(seconds=60) + dag.dagrun_timeout = timedelta(seconds=60) dag_task1 = DummyOperator( task_id='dummy', @@ -686,7 +679,7 @@ def test_scheduler_fail_dagrun_timeout(self): dr = scheduler.create_dag_run(dag) self.assertIsNotNone(dr) - dr.start_date = datetime.datetime.now() - datetime.timedelta(days=1) + dr.start_date = datetime.now() - timedelta(days=1) session.merge(dr) session.commit() @@ -698,14 +691,16 @@ def test_scheduler_fail_dagrun_timeout(self): def test_scheduler_verify_max_active_runs_and_dagrun_timeout(self): """ - Test if a a dagrun will not be scheduled if max_dag_runs has been reached and dagrun_timeout is not reached - Test if a a dagrun will be scheduled if max_dag_runs has been reached but dagrun_timeout is also reached + Test if a dagrun will not be scheduled if max_dag_runs has been + reached and dagrun_timeout is not reached + Test if a dagrun will be scheduled if max_dag_runs has been reached + but dagrun_timeout is also reached """ dag = DAG(
dag_id='test_scheduler_verify_max_active_runs_and_dagrun_timeout', start_date=DEFAULT_DATE) dag.max_active_runs = 1 - dag.dagrun_timeout = datetime.timedelta(seconds=60) + dag.dagrun_timeout = timedelta(seconds=60) dag_task1 = DummyOperator( task_id='dummy', @@ -729,7 +724,7 @@ def test_scheduler_verify_max_active_runs_and_dagrun_timeout(self): self.assertIsNone(new_dr) # Should be scheduled as dagrun_timeout has passed - dr.start_date = datetime.datetime.now() - datetime.timedelta(days=1) + dr.start_date = datetime.now() - timedelta(days=1) session.merge(dr) session.commit() new_dr = scheduler.create_dag_run(dag) @@ -774,7 +769,7 @@ def test_scheduler_max_active_runs_respected_after_clear(self): (dag.dag_id, dag_task1.task_id, DEFAULT_DATE) ) - @patch.object(TI, 'pool_full') + @mock.patch.object(TaskInstance, 'pool_full') def test_scheduler_verify_pool_full(self, mock_pool_full): """ Test task instances not queued when pool is full @@ -816,7 +811,7 @@ def test_scheduler_verify_pool_full(self, mock_pool_full): # Recreated part of the scheduler here, to kick off tasks -> executor for ti_key in queue: task = dag.get_task(ti_key[1]) - ti = TI(task, ti_key[2]) + ti = TaskInstance(task, ti_key[2]) # Task starts out in the scheduled state. All tasks in the # scheduled state will be sent to the executor ti.state = State.SCHEDULED @@ -825,9 +820,8 @@ def test_scheduler_verify_pool_full(self, mock_pool_full): session.merge(ti) session.commit() - scheduler._execute_task_instances(dagbag, - (State.SCHEDULED, - State.UP_FOR_RETRY)) + scheduler._execute_task_instances(dagbag, (State.SCHEDULED, + State.UP_FOR_RETRY)) self.assertEquals(len(scheduler.executor.queued_tasks), 1) @@ -840,9 +834,8 @@ def test_scheduler_auto_align(self): """ dag = DAG( dag_id='test_scheduler_auto_align_1', - start_date=datetime.datetime(2016, 1, 1, 10, 10, 0), - schedule_interval="4 5 * * *" - ) + start_date=datetime(2016, 1, 1, 10, 10, 0), + schedule_interval="4 5 * * *") dag_task1 = DummyOperator( task_id='dummy', dag=dag, @@ -858,13 +851,12 @@ def test_scheduler_auto_align(self): dr = scheduler.create_dag_run(dag) self.assertIsNotNone(dr) - self.assertEquals(dr.execution_date, datetime.datetime(2016, 1, 2, 5, 4)) + self.assertEquals(dr.execution_date, datetime(2016, 1, 2, 5, 4)) dag = DAG( dag_id='test_scheduler_auto_align_2', - start_date=datetime.datetime(2016, 1, 1, 10, 10, 0), - schedule_interval="10 10 * * *" - ) + start_date=datetime(2016, 1, 1, 10, 10, 0), + schedule_interval="10 10 * * *") dag_task1 = DummyOperator( task_id='dummy', dag=dag, @@ -880,7 +872,7 @@ def test_scheduler_auto_align(self): dr = scheduler.create_dag_run(dag) self.assertIsNotNone(dr) - self.assertEquals(dr.execution_date, datetime.datetime(2016, 1, 1, 10, 10)) + self.assertEquals(dr.execution_date, datetime(2016, 1, 1, 10, 10)) def test_scheduler_reschedule(self): """ @@ -1028,8 +1020,10 @@ def test_retry_handling_job(self): scheduler.run() session = settings.Session() - ti = session.query(TI).filter(TI.dag_id==dag.dag_id, - TI.task_id==dag_task1.task_id).first() + ti = session.query(TaskInstance).filter( + TaskInstance.dag_id==dag.dag_id, + TaskInstance.task_id==dag_task1.task_id + ).first() # make sure the counter has increased self.assertEqual(ti.try_number, 2) @@ -1045,12 +1039,12 @@ def test_scheduler_run_duration(self): self.assertTrue(dag.start_date > DEFAULT_DATE) expected_run_duration = 5 - start_time = datetime.datetime.now() + start_time = datetime.now() scheduler = SchedulerJob(dag_id, run_duration=expected_run_duration, 
**self.default_scheduler_args) scheduler.run() - end_time = datetime.datetime.now() + end_time = datetime.now() run_duration = (end_time - start_time).total_seconds() logging.info("Test ran in %.2fs, expected %.2fs", @@ -1082,20 +1076,18 @@ def test_dag_with_system_exit(self): **self.default_scheduler_args) scheduler.run() session = settings.Session() - self.assertEqual( - len(session.query(TI).filter(TI.dag_id == dag_id).all()), 1) + self.assertEqual(len(session.query(TaskInstance). + filter(TaskInstance.dag_id == dag_id).all()), 1) def test_dag_get_active_runs(self): """ Test to check that a DAG returns it's active runs """ - - now = datetime.datetime.now() - six_hours_ago_to_the_hour = (now - datetime.timedelta(hours=6)).replace(minute=0, second=0, microsecond=0) - - START_DATE = six_hours_ago_to_the_hour + now = datetime.now() + six_hours_ago_to_the_hour = now - timedelta(hours=6) + START_DATE = six_hours_ago_to_the_hour.replace(minute=0, second=0, + microsecond=0) DAG_NAME1 = 'get_active_runs_test' - default_args = { 'owner': 'airflow', 'depends_on_past': False, @@ -1105,8 +1097,7 @@ def test_dag_get_active_runs(self): dag1 = DAG(DAG_NAME1, schedule_interval='* * * * *', max_active_runs=1, - default_args=default_args - ) + default_args=default_args) run_this_1 = DummyOperator(task_id='run_this_1', dag=dag1) run_this_2 = DummyOperator(task_id='run_this_2', dag=dag1) @@ -1131,29 +1122,22 @@ def test_dag_get_active_runs(self): execution_date = dr.execution_date running_dates = dag1.get_active_runs() - - try: - running_date = running_dates[0] - except: - running_date = 'Except' - - self.assertEqual(execution_date, running_date, 'Running Date must match Execution Date') + self.assertEqual(execution_date, running_dates[0]) def test_dag_catchup_option(self): """ - Test to check that a DAG with catchup = False only schedules beginning now, not back to the start date + Test to check that a DAG with catchup = False only schedules beginning + now, not back to the start date """ - - now = datetime.datetime.now() - six_hours_ago_to_the_hour = (now - datetime.timedelta(hours=6)).replace(minute=0, second=0, microsecond=0) - three_minutes_ago = now - datetime.timedelta(minutes=3) - two_hours_and_three_minutes_ago = three_minutes_ago - datetime.timedelta(hours=2) - - START_DATE = six_hours_ago_to_the_hour + now = datetime.now() + six_hours_ago_to_the_hour = now - timedelta(hours=6) + three_minutes_ago = now - timedelta(minutes=3) + two_hours_and_three_minutes_ago = three_minutes_ago - timedelta(hours=2) + START_DATE = six_hours_ago_to_the_hour.replace(minute=0, second=0, + microsecond=0) DAG_NAME1 = 'no_catchup_test1' DAG_NAME2 = 'no_catchup_test2' DAG_NAME3 = 'no_catchup_test3' - default_args = { 'owner': 'airflow', 'depends_on_past': False, @@ -1163,8 +1147,7 @@ def test_dag_catchup_option(self): dag1 = DAG(DAG_NAME1, schedule_interval='* * * * *', max_active_runs=1, - default_args=default_args - ) + default_args=default_args) default_catchup = configuration.getboolean('scheduler', 'catchup_by_default') # Test configs have catchup by default ON @@ -1175,11 +1158,10 @@ def test_dag_catchup_option(self): self.assertEqual(dag1.catchup, True) dag2 = DAG(DAG_NAME2, - schedule_interval='* * * * *', - max_active_runs=1, - catchup=False, - default_args=default_args - ) + schedule_interval='* * * * *', + max_active_runs=1, + catchup=False, + default_args=default_args) run_this_1 = DummyOperator(task_id='run_this_1', dag=dag2) run_this_2 = DummyOperator(task_id='run_this_2', dag=dag2) @@ -1205,14 +1187,13 @@ 
def test_dag_catchup_option(self): self.assertGreater(dr.execution_date, three_minutes_ago) # The DR should be scheduled BEFORE now - self.assertLess(dr.execution_date, datetime.datetime.now()) + self.assertLess(dr.execution_date, datetime.now()) dag3 = DAG(DAG_NAME3, schedule_interval='@hourly', max_active_runs=1, catchup=False, - default_args=default_args - ) + default_args=default_args) run_this_1 = DummyOperator(task_id='run_this_1', dag=dag3) run_this_2 = DummyOperator(task_id='run_this_2', dag=dag3) @@ -1239,7 +1220,7 @@ def test_dag_catchup_option(self): self.assertGreater(dr.execution_date, two_hours_and_three_minutes_ago) # The DR should be scheduled BEFORE now - self.assertLess(dr.execution_date, datetime.datetime.now()) + self.assertLess(dr.execution_date, datetime.now()) def test_add_unparseable_file_before_sched_start_creates_import_error(self): try: diff --git a/tests/models.py b/tests/test_models.py similarity index 100% rename from tests/models.py rename to tests/test_models.py diff --git a/tests/plugins_manager.py b/tests/test_plugins_manager.py similarity index 100% rename from tests/plugins_manager.py rename to tests/test_plugins_manager.py diff --git a/tests/utils.py b/tests/test_utils.py similarity index 100% rename from tests/utils.py rename to tests/test_utils.py diff --git a/tests/test_utils/README.md b/tests/test_utils/README.md deleted file mode 100644 index 8a5c90dfd2..0000000000 --- a/tests/test_utils/README.md +++ /dev/null @@ -1 +0,0 @@ -Utilities for use in tests. diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py deleted file mode 100644 index 9d7677a99b..0000000000 --- a/tests/test_utils/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/test_utils/fake_datetime.py b/tests/test_utils/fake_datetime.py deleted file mode 100644 index 9b8102f38f..0000000000 --- a/tests/test_utils/fake_datetime.py +++ /dev/null @@ -1,24 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from datetime import datetime - - -class FakeDatetime(datetime): - """ - A fake replacement for datetime that can be mocked for testing. 
- """ - - def __new__(cls, *args, **kwargs): - return date.__new__(datetime, *args, **kwargs) diff --git a/tests/ti_deps/contexts/__init__.py b/tests/ti_deps/contexts/__init__.py deleted file mode 100644 index 9d7677a99b..0000000000 --- a/tests/ti_deps/contexts/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/ti_deps/deps/__init__.py b/tests/ti_deps/deps/__init__.py deleted file mode 100644 index 9d7677a99b..0000000000 --- a/tests/ti_deps/deps/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/tests/ti_deps/deps/fake_models.py b/tests/ti_deps/deps/fake_models.py deleted file mode 100644 index 777b7f2d15..0000000000 --- a/tests/ti_deps/deps/fake_models.py +++ /dev/null @@ -1,57 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -# A collection of fake models used for unit testing - - -class FakeTI(object): - - def __init__(self, **kwds): - self.__dict__.update(kwds) - - def pool_full(self): - # Allow users of this fake to set pool_filled in the contructor to make this - # return True - try: - return self.pool_filled - except AttributeError: - # If pool_filled was not set default to false - return False - - def get_dagrun(self, _): - return self.dagrun - - def are_dependents_done(self, session): - return self.dependents_done - - -class FakeTask(object): - - def __init__(self, **kwds): - self.__dict__.update(kwds) - - -class FakeDag(object): - - def __init__(self, **kwds): - self.__dict__.update(kwds) - - def get_running_dagruns(self, _): - return self.running_dagruns - - -class FakeContext(object): - - def __init__(self, **kwds): - self.__dict__.update(kwds) diff --git a/tests/ti_deps/deps/pool_has_space_dep.py b/tests/ti_deps/deps/pool_has_space_dep.py deleted file mode 100644 index 411547a901..0000000000 --- a/tests/ti_deps/deps/pool_has_space_dep.py +++ /dev/null @@ -1,37 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest - -from airflow.ti_deps.deps.pool_has_space_dep import PoolHasSpaceDep -from fake_models import FakeTI - - -class PoolHasSpaceDepTest(unittest.TestCase): - - def test_pool_full(self): - """ - Full pools should fail this dep - """ - ti = FakeTI(pool="fake_pool", pool_filled=True) - - self.assertFalse(PoolHasSpaceDep().is_met(ti=ti, dep_context=None)) - - def test_not_skipped(self): - """ - Pools with room should pass this dep - """ - ti = FakeTI(pool="fake_pool", pool_filled=False) - - self.assertTrue(PoolHasSpaceDep().is_met(ti=ti, dep_context=None)) diff --git a/tests/ti_deps/deps/prev_dagrun_dep.py b/tests/ti_deps/deps/prev_dagrun_dep.py deleted file mode 100644 index 4873467780..0000000000 --- a/tests/ti_deps/deps/prev_dagrun_dep.py +++ /dev/null @@ -1,143 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
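
The tests deleted below are not lost: they reappear later in this diff as tests/ti_deps/test_prev_dagrun_dep.py, rewritten against real model objects. Condensed from the new file shown further below, the converted setup replaces FakeTask/FakeDag with a genuine operator on a genuine DAG, and FakeContext with the real DepContext:

    from datetime import datetime
    from airflow.models import DAG, BaseOperator
    from airflow.ti_deps.dep_context import DepContext

    # A real task on a real DAG stands in for FakeTask/FakeDag...
    task = BaseOperator(task_id='test_task', dag=DAG('test_dag'),
                        depends_on_past=False, start_date=datetime(2016, 1, 1))
    # ...and the real DepContext replaces FakeContext.
    dep_context = DepContext(ignore_depends_on_past=False)
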
- -import unittest -from datetime import datetime - -from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep -from airflow.utils.state import State -from fake_models import FakeContext, FakeTask, FakeTI - - -class PrevDagrunDepTest(unittest.TestCase): - - def test_not_depends_on_past(self): - """ - If depends on past isn't set in the task then the previous dagrun should be - ignored, even though there is no previous_ti which would normally fail the dep - """ - task = FakeTask( - depends_on_past=False, - start_date=datetime(2016, 1, 1), - wait_for_downstream=False) - prev_ti = FakeTI( - task=task, - execution_date=datetime(2016, 1, 2), - state=State.SUCCESS, - dependents_done=True) - ti = FakeTI( - task=task, - previous_ti=prev_ti, - execution_date=datetime(2016, 1, 3)) - dep_context = FakeContext(ignore_depends_on_past=False) - - self.assertTrue(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)) - - def test_context_ignore_depends_on_past(self): - """ - If the context overrides depends_on_past then the dep should be met, even though - there is no previous_ti which would normally fail the dep - """ - task = FakeTask( - depends_on_past=True, - start_date=datetime(2016, 1, 1), - wait_for_downstream=False) - prev_ti = FakeTI( - task=task, - execution_date=datetime(2016, 1, 2), - state=State.SUCCESS, - dependents_done=True) - ti = FakeTI( - task=task, - previous_ti=prev_ti, - execution_date=datetime(2016, 1, 3)) - dep_context = FakeContext(ignore_depends_on_past=True) - - self.assertTrue(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)) - - def test_first_task_run(self): - """ - The first task run for a TI should pass since it has no previous dagrun. - """ - task = FakeTask( - depends_on_past=True, - start_date=datetime(2016, 1, 1), - wait_for_downstream=False) - prev_ti = None - ti = FakeTI( - task=task, - previous_ti=prev_ti, - execution_date=datetime(2016, 1, 1)) - dep_context = FakeContext(ignore_depends_on_past=False) - - self.assertTrue(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)) - - def test_prev_ti_bad_state(self): - """ - If the previous TI did not complete execution this dep should fail. - """ - task = FakeTask( - depends_on_past=True, - start_date=datetime(2016, 1, 1), - wait_for_downstream=False) - prev_ti = FakeTI( - state=State.NONE, - dependents_done=True) - ti = FakeTI( - task=task, - previous_ti=prev_ti, - execution_date=datetime(2016, 1, 2)) - dep_context = FakeContext(ignore_depends_on_past=False) - - self.assertFalse(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)) - - def test_failed_wait_for_downstream(self): - """ - If the previous TI specified to wait for the downstream tasks of the previous - dagrun then it should fail this dep if the downstream TIs of the previous TI are - not done. 
- """ - task = FakeTask( - depends_on_past=True, - start_date=datetime(2016, 1, 1), - wait_for_downstream=True) - prev_ti = FakeTI( - state=State.SUCCESS, - dependents_done=False) - ti = FakeTI( - task=task, - previous_ti=prev_ti, - execution_date=datetime(2016, 1, 2)) - dep_context = FakeContext(ignore_depends_on_past=False) - - self.assertFalse(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)) - - def test_all_met(self): - """ - Test to make sure all of the conditions for the dep are met - """ - task = FakeTask( - depends_on_past=True, - start_date=datetime(2016, 1, 1), - wait_for_downstream=True) - prev_ti = FakeTI( - state=State.SUCCESS, - dependents_done=True) - ti = FakeTI( - task=task, - previous_ti=prev_ti, - execution_date=datetime(2016, 1, 2)) - dep_context = FakeContext(ignore_depends_on_past=False) - - self.assertTrue(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)) diff --git a/tests/ti_deps/deps/runnable_exec_date_dep.py b/tests/ti_deps/deps/runnable_exec_date_dep.py deleted file mode 100644 index ae09ddbd8e..0000000000 --- a/tests/ti_deps/deps/runnable_exec_date_dep.py +++ /dev/null @@ -1,92 +0,0 @@ -# -*- coding: utf-8 -*- -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -from datetime import datetime -from mock import patch - -from airflow.ti_deps.deps.runnable_exec_date_dep import RunnableExecDateDep -from fake_models import FakeDag, FakeTask, FakeTI -from tests.test_utils.fake_datetime import FakeDatetime - - -class RunnableExecDateDepTest(unittest.TestCase): - - @patch('airflow.ti_deps.deps.runnable_exec_date_dep.datetime', FakeDatetime) - def test_exec_date_after_end_date(self): - """ - If the dag's execution date is in the future this dep should fail - """ - FakeDatetime.now = classmethod(lambda cls: datetime(2016, 1, 1)) - dag = FakeDag(end_date=datetime(2016, 1, 3)) - task = FakeTask(dag=dag, end_date=datetime(2016, 1, 3)) - ti = FakeTI(task=task, execution_date=datetime(2016, 1, 2)) - - self.assertFalse(RunnableExecDateDep().is_met(ti=ti, dep_context=None)) - - def test_exec_date_before_task_end_date(self): - """ - If the task instance execution date is before the DAG's end date this dep should - fail - """ - FakeDatetime.now = classmethod(lambda cls: datetime(2016, 1, 3)) - dag = FakeDag(end_date=datetime(2016, 1, 1)) - task = FakeTask(dag=dag, end_date=datetime(2016, 1, 2)) - ti = FakeTI(task=task, execution_date=datetime(2016, 1, 1)) - - self.assertFalse(RunnableExecDateDep().is_met(ti=ti, dep_context=None)) - - def test_exec_date_after_task_end_date(self): - """ - If the task instance execution date is after the DAG's end date this dep should - fail - """ - FakeDatetime.now = classmethod(lambda cls: datetime(2016, 1, 3)) - dag = FakeDag(end_date=datetime(2016, 1, 3)) - task = FakeTask(dag=dag, end_date=datetime(2016, 1, 1)) - ti = FakeTI(task=task, execution_date=datetime(2016, 1, 2)) - - self.assertFalse(RunnableExecDateDep().is_met(ti=ti, dep_context=None)) - - def test_exec_date_before_dag_end_date(self): - """ - If the task 
instance execution date is before the dag's end date this dep should - fail - """ - dag = FakeDag(start_date=datetime(2016, 1, 2)) - task = FakeTask(dag=dag, start_date=datetime(2016, 1, 1)) - ti = FakeTI(task=task, execution_date=datetime(2016, 1, 1)) - - self.assertFalse(RunnableExecDateDep().is_met(ti=ti, dep_context=None)) - - def test_exec_date_after_dag_end_date(self): - """ - If the task instance execution date is after the dag's end date this dep should - fail - """ - dag = FakeDag(end_date=datetime(2016, 1, 1)) - task = FakeTask(dag=dag, end_date=datetime(2016, 1, 3)) - ti = FakeTI(task=task, execution_date=datetime(2016, 1, 2)) - - self.assertFalse(RunnableExecDateDep().is_met(ti=ti, dep_context=None)) - - def test_all_deps_met(self): - """ - Test to make sure all of the conditions for the dep are met - """ - dag = FakeDag(end_date=datetime(2016, 1, 2)) - task = FakeTask(dag=dag, end_date=datetime(2016, 1, 2)) - ti = FakeTI(task=task, execution_date=datetime(2016, 1, 1)) - - self.assertTrue(RunnableExecDateDep().is_met(ti=ti, dep_context=None)) diff --git a/tests/ti_deps/deps/dag_ti_slots_available_dep.py b/tests/ti_deps/test_dag_ti_slots_available_dep.py similarity index 65% rename from tests/ti_deps/deps/dag_ti_slots_available_dep.py rename to tests/ti_deps/test_dag_ti_slots_available_dep.py index 6077d9685c..6910b66ee1 100644 --- a/tests/ti_deps/deps/dag_ti_slots_available_dep.py +++ b/tests/ti_deps/test_dag_ti_slots_available_dep.py @@ -13,9 +13,10 @@ # limitations under the License. import unittest +from mock import Mock +from airflow.models import TaskInstance from airflow.ti_deps.deps.dag_ti_slots_available_dep import DagTISlotsAvailableDep -from fake_models import FakeDag, FakeTask, FakeTI class DagTISlotsAvailableDepTest(unittest.TestCase): @@ -24,18 +25,18 @@ def test_concurrency_reached(self): """ Test concurrency reached should fail dep """ - dag = FakeDag(concurrency=1, concurrency_reached=True) - task = FakeTask(dag=dag) - ti = FakeTI(task=task, dag_id="fake_dag") + dag = Mock(concurrency=1, concurrency_reached=True) + task = Mock(dag=dag) + ti = TaskInstance(task, execution_date=None) - self.assertFalse(DagTISlotsAvailableDep().is_met(ti=ti, dep_context=None)) + self.assertFalse(DagTISlotsAvailableDep().is_met(ti=ti)) def test_all_conditions_met(self): """ Test all conditions met should pass dep """ - dag = FakeDag(concurrency=1, concurrency_reached=False) - task = FakeTask(dag=dag) - ti = FakeTI(task=task, dag_id="fake_dag") + dag = Mock(concurrency=1, concurrency_reached=False) + task = Mock(dag=dag) + ti = TaskInstance(task, execution_date=None) - self.assertTrue(DagTISlotsAvailableDep().is_met(ti=ti, dep_context=None)) + self.assertTrue(DagTISlotsAvailableDep().is_met(ti=ti)) diff --git a/tests/ti_deps/deps/dag_unpaused_dep.py b/tests/ti_deps/test_dag_unpaused_dep.py similarity index 67% rename from tests/ti_deps/deps/dag_unpaused_dep.py rename to tests/ti_deps/test_dag_unpaused_dep.py index 8721a51166..969889a37a 100644 --- a/tests/ti_deps/deps/dag_unpaused_dep.py +++ b/tests/ti_deps/test_dag_unpaused_dep.py @@ -13,9 +13,10 @@ # limitations under the License. 
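
One pattern repeats across this dep-test conversion and the ones around it: is_met is now exercised as is_met(ti=ti), with no explicit dep_context=None argument, and the fakes give way to a real TaskInstance wrapping a Mock task. Condensed from the DagUnpausedDep hunk below:

    from mock import Mock
    from airflow.models import TaskInstance
    from airflow.ti_deps.deps.dag_unpaused_dep import DagUnpausedDep

    # A paused DAG should fail the dep; an unpaused one should pass it.
    ti = TaskInstance(task=Mock(dag=Mock(is_paused=True)), execution_date=None)
    assert not DagUnpausedDep().is_met(ti=ti)
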
import unittest +from mock import Mock +from airflow.models import TaskInstance from airflow.ti_deps.deps.dag_unpaused_dep import DagUnpausedDep -from fake_models import FakeDag, FakeTask, FakeTI class DagUnpausedDepTest(unittest.TestCase): @@ -24,18 +25,18 @@ def test_concurrency_reached(self): """ Test paused DAG should fail dependency """ - dag = FakeDag(is_paused=True) - task = FakeTask(dag=dag) - ti = FakeTI(task=task, dag_id="fake_dag") + dag = Mock(is_paused=True) + task = Mock(dag=dag) + ti = TaskInstance(task=task, execution_date=None) - self.assertFalse(DagUnpausedDep().is_met(ti=ti, dep_context=None)) + self.assertFalse(DagUnpausedDep().is_met(ti=ti)) def test_all_conditions_met(self): """ Test all conditions met should pass dep """ - dag = FakeDag(is_paused=False) - task = FakeTask(dag=dag) - ti = FakeTI(task=task, dag_id="fake_dag") + dag = Mock(is_paused=False) + task = Mock(dag=dag) + ti = TaskInstance(task=task, execution_date=None) - self.assertTrue(DagUnpausedDep().is_met(ti=ti, dep_context=None)) + self.assertTrue(DagUnpausedDep().is_met(ti=ti)) diff --git a/tests/ti_deps/deps/dagrun_exists_dep.py b/tests/ti_deps/test_dagrun_exists_dep.py similarity index 61% rename from tests/ti_deps/deps/dagrun_exists_dep.py rename to tests/ti_deps/test_dagrun_exists_dep.py index 1141647342..daad269fda 100644 --- a/tests/ti_deps/deps/dagrun_exists_dep.py +++ b/tests/ti_deps/test_dagrun_exists_dep.py @@ -13,29 +13,28 @@ # limitations under the License. import unittest +from airflow.utils.state import State +from mock import Mock, patch +from airflow.models import DAG, DagRun from airflow.ti_deps.deps.dagrun_exists_dep import DagrunRunningDep -from fake_models import FakeDag, FakeTask, FakeTI class DagrunRunningDepTest(unittest.TestCase): - def test_dagrun_doesnt_exist(self): + @patch('airflow.models.DagRun.find', return_value=()) + def test_dagrun_doesnt_exist(self, dagrun_find): """ Task instances without dagruns should fail this dep """ - dag = FakeDag(running_dagruns=[], max_active_runs=1) - task = FakeTask(dag=dag) - ti = FakeTI(dagrun=None, task=task, dag_id="fake_dag") - - self.assertFalse(DagrunRunningDep().is_met(ti=ti, dep_context=None)) + dag = DAG('test_dag', max_active_runs=2) + ti = Mock(task=Mock(dag=dag), get_dagrun=Mock(return_value=None)) + self.assertFalse(DagrunRunningDep().is_met(ti=ti)) def test_dagrun_exists(self): """ Task instances with a dagrun should pass this dep """ - dag = FakeDag(running_dagruns=[], max_active_runs=1) - task = FakeTask(dag=dag) - ti = FakeTI(dagrun="Fake Dagrun", task=task, dag_id="fake_dag") - - self.assertTrue(DagrunRunningDep().is_met(ti=ti, dep_context=None)) + dagrun = DagRun(state=State.RUNNING) + ti = Mock(get_dagrun=Mock(return_value=dagrun)) + self.assertTrue(DagrunRunningDep().is_met(ti=ti)) diff --git a/tests/ti_deps/deps/not_in_retry_period_dep.py b/tests/ti_deps/test_not_in_retry_period_dep.py similarity index 53% rename from tests/ti_deps/deps/not_in_retry_period_dep.py rename to tests/ti_deps/test_not_in_retry_period_dep.py index a6657ba8c0..0f23aab333 100644 --- a/tests/ti_deps/deps/not_in_retry_period_dep.py +++ b/tests/ti_deps/test_not_in_retry_period_dep.py @@ -14,48 +14,46 @@ import unittest from datetime import datetime, timedelta +from freezegun import freeze_time +from mock import Mock +from airflow.models import TaskInstance from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep from airflow.utils.state import State -from fake_models import FakeDag, FakeTask, FakeTI class 
NotInRetryPeriodDepTest(unittest.TestCase): + def _get_task_instance(self, state, end_date=None, + retry_delay=timedelta(minutes=15)): + task = Mock(retry_delay=retry_delay, retry_exponential_backoff=False) + ti = TaskInstance(task=task, state=state, execution_date=None) + ti.end_date = end_date + return ti + + @freeze_time('2016-01-01 15:44') def test_still_in_retry_period(self): """ Task instances that are in their retry period should fail this dep """ - dag = FakeDag() - task = FakeTask(dag=dag, retry_delay=timedelta(minutes=1)) - ti = FakeTI( - task=task, - state=State.UP_FOR_RETRY, - end_date=datetime(2016, 1, 1), - is_premature=True) - - self.assertFalse(NotInRetryPeriodDep().is_met(ti=ti, dep_context=None)) + ti = self._get_task_instance(State.UP_FOR_RETRY, + end_date=datetime(2016, 1, 1, 15, 30)) + self.assertTrue(ti.is_premature) + self.assertFalse(NotInRetryPeriodDep().is_met(ti=ti)) + @freeze_time('2016-01-01 15:46') def test_retry_period_finished(self): """ Task instance's that have had their retry period elapse should pass this dep """ - dag = FakeDag() - task = FakeTask(dag=dag, retry_delay=timedelta(minutes=1)) - ti = FakeTI( - task=task, - state=State.UP_FOR_RETRY, - end_date=datetime(2016, 1, 1), - is_premature=False) - - self.assertTrue(NotInRetryPeriodDep().is_met(ti=ti, dep_context=None)) + ti = self._get_task_instance(State.UP_FOR_RETRY, + end_date=datetime(2016, 1, 1)) + self.assertFalse(ti.is_premature) + self.assertTrue(NotInRetryPeriodDep().is_met(ti=ti)) def test_not_in_retry_period(self): """ Task instance's that are not up for retry can not be in their retry period """ - dag = FakeDag() - task = FakeTask(dag=dag) - ti = FakeTI(task=task, state=State.SUCCESS) - - self.assertTrue(NotInRetryPeriodDep().is_met(ti=ti, dep_context=None)) + ti = self._get_task_instance(State.SUCCESS) + self.assertTrue(NotInRetryPeriodDep().is_met(ti=ti)) diff --git a/tests/ti_deps/deps/not_running_dep.py b/tests/ti_deps/test_not_running_dep.py similarity index 75% rename from tests/ti_deps/deps/not_running_dep.py rename to tests/ti_deps/test_not_running_dep.py index 159d923318..7f8f0cd527 100644 --- a/tests/ti_deps/deps/not_running_dep.py +++ b/tests/ti_deps/test_not_running_dep.py @@ -14,10 +14,10 @@ import unittest from datetime import datetime +from mock import Mock from airflow.ti_deps.deps.not_running_dep import NotRunningDep from airflow.utils.state import State -from fake_models import FakeTI class NotRunningDepTest(unittest.TestCase): @@ -26,14 +26,12 @@ def test_ti_running(self): """ Running task instances should fail this dep """ - ti = FakeTI(state=State.RUNNING, start_date=datetime(2016, 1, 1)) - - self.assertFalse(NotRunningDep().is_met(ti=ti, dep_context=None)) + ti = Mock(state=State.RUNNING, start_date=datetime(2016, 1, 1)) + self.assertFalse(NotRunningDep().is_met(ti=ti)) def test_ti_not_running(self): """ Non-running task instances should pass this dep """ - ti = FakeTI(state=State.NONE, start_date=datetime(2016, 1, 1)) - - self.assertTrue(NotRunningDep().is_met(ti=ti, dep_context=None)) + ti = Mock(state=State.NONE, start_date=datetime(2016, 1, 1)) + self.assertTrue(NotRunningDep().is_met(ti=ti)) diff --git a/tests/ti_deps/deps/not_skipped_dep.py b/tests/ti_deps/test_not_skipped_dep.py similarity index 78% rename from tests/ti_deps/deps/not_skipped_dep.py rename to tests/ti_deps/test_not_skipped_dep.py index 6d7ef55f6f..8a31bf9483 100644 --- a/tests/ti_deps/deps/not_skipped_dep.py +++ b/tests/ti_deps/test_not_skipped_dep.py @@ -13,10 +13,10 @@ # limitations under the 
License. import unittest +from mock import Mock from airflow.ti_deps.deps.not_skipped_dep import NotSkippedDep from airflow.utils.state import State -from fake_models import FakeTI class NotSkippedDepTest(unittest.TestCase): @@ -25,14 +25,12 @@ def test_skipped(self): """ Skipped task instances should fail this dep """ - ti = FakeTI(state=State.SKIPPED) - - self.assertFalse(NotSkippedDep().is_met(ti=ti, dep_context=None)) + ti = Mock(state=State.SKIPPED) + self.assertFalse(NotSkippedDep().is_met(ti=ti)) def test_not_skipped(self): """ Non-skipped task instances should pass this dep """ - ti = FakeTI(state=State.RUNNING) - - self.assertTrue(NotSkippedDep().is_met(ti=ti, dep_context=None)) + ti = Mock(state=State.RUNNING) + self.assertTrue(NotSkippedDep().is_met(ti=ti)) diff --git a/tests/ti_deps/test_prev_dagrun_dep.py b/tests/ti_deps/test_prev_dagrun_dep.py new file mode 100644 index 0000000000..0f6f5da5df --- /dev/null +++ b/tests/ti_deps/test_prev_dagrun_dep.py @@ -0,0 +1,123 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from datetime import datetime +from mock import Mock + +from airflow.models import DAG, BaseOperator +from airflow.ti_deps.dep_context import DepContext +from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep +from airflow.utils.state import State + + +class PrevDagrunDepTest(unittest.TestCase): + + def _get_task(self, **kwargs): + return BaseOperator(task_id='test_task', dag=DAG('test_dag'), **kwargs) + + def test_not_depends_on_past(self): + """ + If depends on past isn't set in the task then the previous dagrun should be + ignored, even though there is no previous_ti which would normally fail the dep + """ + task = self._get_task(depends_on_past=False, + start_date=datetime(2016, 1, 1), + wait_for_downstream=False) + prev_ti = Mock(task=task, state=State.SUCCESS, + are_dependents_done=Mock(return_value=True), + execution_date=datetime(2016, 1, 2)) + ti = Mock(task=task, previous_ti=prev_ti, + execution_date=datetime(2016, 1, 3)) + dep_context = DepContext(ignore_depends_on_past=False) + + self.assertTrue(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)) + + def test_context_ignore_depends_on_past(self): + """ + If the context overrides depends_on_past then the dep should be met, + even though there is no previous_ti which would normally fail the dep + """ + task = self._get_task(depends_on_past=True, + start_date=datetime(2016, 1, 1), + wait_for_downstream=False) + prev_ti = Mock(task=task, state=State.SUCCESS, + are_dependents_done=Mock(return_value=True), + execution_date=datetime(2016, 1, 2)) + ti = Mock(task=task, previous_ti=prev_ti, + execution_date=datetime(2016, 1, 3)) + dep_context = DepContext(ignore_depends_on_past=True) + + self.assertTrue(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)) + + def test_first_task_run(self): + """ + The first task run for a TI should pass since it has no previous dagrun. 
+ """ + task = self._get_task(depends_on_past=True, + start_date=datetime(2016, 1, 1), + wait_for_downstream=False) + prev_ti = None + ti = Mock(task=task, previous_ti=prev_ti, + execution_date=datetime(2016, 1, 1)) + dep_context = DepContext(ignore_depends_on_past=False) + + self.assertTrue(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)) + + def test_prev_ti_bad_state(self): + """ + If the previous TI did not complete execution this dep should fail. + """ + task = self._get_task(depends_on_past=True, + start_date=datetime(2016, 1, 1), + wait_for_downstream=False) + prev_ti = Mock(state=State.NONE, + are_dependents_done=Mock(return_value=True)) + ti = Mock(task=task, previous_ti=prev_ti, + execution_date=datetime(2016, 1, 2)) + dep_context = DepContext(ignore_depends_on_past=False) + + self.assertFalse(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)) + + def test_failed_wait_for_downstream(self): + """ + If the previous TI specified to wait for the downstream tasks of the + previous dagrun then it should fail this dep if the downstream TIs of + the previous TI are not done. + """ + task = self._get_task(depends_on_past=True, + start_date=datetime(2016, 1, 1), + wait_for_downstream=True) + prev_ti = Mock(state=State.SUCCESS, + are_dependents_done=Mock(return_value=False)) + ti = Mock(task=task, previous_ti=prev_ti, + execution_date=datetime(2016, 1, 2)) + dep_context = DepContext(ignore_depends_on_past=False) + + self.assertFalse(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)) + + def test_all_met(self): + """ + Test to make sure all of the conditions for the dep are met + """ + task = self._get_task(depends_on_past=True, + start_date=datetime(2016, 1, 1), + wait_for_downstream=True) + prev_ti = Mock(state=State.SUCCESS, + are_dependents_done=Mock(return_value=True)) + ti = Mock(task=task, previous_ti=prev_ti, + execution_date=datetime(2016, 1, 2)) + dep_context = DepContext(ignore_depends_on_past=False) + + self.assertTrue(PrevDagrunDep().is_met(ti=ti, dep_context=dep_context)) diff --git a/tests/ti_deps/test_runnable_exec_date_dep.py b/tests/ti_deps/test_runnable_exec_date_dep.py new file mode 100644 index 0000000000..e1a396c8d8 --- /dev/null +++ b/tests/ti_deps/test_runnable_exec_date_dep.py @@ -0,0 +1,76 @@ +# -*- coding: utf-8 -*- +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +from datetime import datetime +from freezegun import freeze_time +from mock import Mock + +from airflow.models import TaskInstance +from airflow.ti_deps.deps.runnable_exec_date_dep import RunnableExecDateDep + + +class RunnableExecDateDepTest(unittest.TestCase): + + def _get_task_instance(self, execution_date, dag_end_date=None, task_end_date=None): + dag = Mock(end_date=dag_end_date) + task = Mock(dag=dag, end_date=task_end_date) + return TaskInstance(task=task, execution_date=execution_date) + + @freeze_time('2016-01-01') + def test_exec_date_after_end_date(self): + """ + If the dag's execution date is in the future this dep should fail + """ + ti = self._get_task_instance( + dag_end_date=datetime(2016, 1, 3), + task_end_date=datetime(2016, 1, 3), + execution_date=datetime(2016, 1, 2), + ) + self.assertFalse(RunnableExecDateDep().is_met(ti=ti)) + + def test_exec_date_after_task_end_date(self): + """ + If the task instance execution date is after the task's end date + this dep should fail + """ + ti = self._get_task_instance( + dag_end_date=datetime(2016, 1, 3), + task_end_date=datetime(2016, 1, 1), + execution_date=datetime(2016, 1, 2), + ) + self.assertFalse(RunnableExecDateDep().is_met(ti=ti)) + + def test_exec_date_after_dag_end_date(self): + """ + If the task instance execution date is after the dag's end date + this dep should fail + """ + ti = self._get_task_instance( + dag_end_date=datetime(2016, 1, 1), + task_end_date=datetime(2016, 1, 3), + execution_date=datetime(2016, 1, 2), + ) + self.assertFalse(RunnableExecDateDep().is_met(ti=ti)) + + def test_all_deps_met(self): + """ + Test to make sure all of the conditions for the dep are met + """ + ti = self._get_task_instance( + dag_end_date=datetime(2016, 1, 2), + task_end_date=datetime(2016, 1, 2), + execution_date=datetime(2016, 1, 1), + ) + self.assertTrue(RunnableExecDateDep().is_met(ti=ti))
diff --git a/tests/ti_deps/deps/trigger_rule_dep.py b/tests/ti_deps/test_trigger_rule_dep.py similarity index 72% rename from tests/ti_deps/deps/trigger_rule_dep.py rename to tests/ti_deps/test_trigger_rule_dep.py index 04a7737dea..a61ff0d63f 100644 --- a/tests/ti_deps/deps/trigger_rule_dep.py +++ b/tests/ti_deps/test_trigger_rule_dep.py @@ -13,50 +13,43 @@ # limitations under the License.
import unittest +from datetime import datetime +from airflow.models import BaseOperator, TaskInstance from airflow.utils.trigger_rule import TriggerRule from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep from airflow.utils.state import State -from fake_models import FakeTask, FakeTI class TriggerRuleDepTest(unittest.TestCase): + def _get_task_instance(self, trigger_rule=TriggerRule.ALL_SUCCESS, + state=None, upstream_task_ids=None): + task = BaseOperator(task_id='test_task', trigger_rule=trigger_rule, + start_date=datetime(2015, 1, 1)) + if upstream_task_ids: + task._upstream_task_ids.extend(upstream_task_ids) + return TaskInstance(task=task, state=state, execution_date=None) + def test_no_upstream_tasks(self): """ If the TI has no upstream TIs then there is nothing to check and the dep is passed """ - task = FakeTask( - trigger_rule=TriggerRule.ALL_DONE, - upstream_list=[]) - ti = FakeTI( - task=task, - state=State.UP_FOR_RETRY) - - self.assertTrue(TriggerRuleDep().is_met(ti=ti, dep_context=None)) + ti = self._get_task_instance(TriggerRule.ALL_DONE, State.UP_FOR_RETRY) + self.assertTrue(TriggerRuleDep().is_met(ti=ti)) def test_dummy_tr(self): """ The dummy trigger rule should always pass this dep """ - task = FakeTask( - trigger_rule=TriggerRule.DUMMY, - upstream_list=[]) - ti = FakeTI( - task=task, - state=State.UP_FOR_RETRY) - - self.assertTrue(TriggerRuleDep().is_met(ti=ti, dep_context=None)) + ti = self._get_task_instance(TriggerRule.DUMMY, State.UP_FOR_RETRY) + self.assertTrue(TriggerRuleDep().is_met(ti=ti)) def test_one_success_tr_success(self): """ One-success trigger rule success """ - task = FakeTask( - trigger_rule=TriggerRule.ONE_SUCCESS, - upstream_task_ids=[]) - ti = FakeTI(task=task) - + ti = self._get_task_instance(TriggerRule.ONE_SUCCESS, State.UP_FOR_RETRY) dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( ti=ti, successes=1, @@ -66,18 +59,13 @@ def test_one_success_tr_success(self): done=2, flag_upstream_failed=False, session="Fake Session")) - self.assertEqual(len(dep_statuses), 0) def test_one_success_tr_failure(self): """ One-success trigger rule failure """ - task = FakeTask( - trigger_rule=TriggerRule.ONE_SUCCESS, - upstream_task_ids=[]) - ti = FakeTI(task=task) - + ti = self._get_task_instance(TriggerRule.ONE_SUCCESS) dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( ti=ti, successes=0, @@ -87,7 +75,6 @@ def test_one_success_tr_failure(self): done=2, flag_upstream_failed=False, session="Fake Session")) - self.assertEqual(len(dep_statuses), 1) self.assertFalse(dep_statuses[0].passed) @@ -95,11 +82,7 @@ def test_one_failure_tr_failure(self): """ One-failure trigger rule failure """ - task = FakeTask( - trigger_rule=TriggerRule.ONE_FAILED, - upstream_task_ids=[]) - ti = FakeTI(task=task) - + ti = self._get_task_instance(TriggerRule.ONE_FAILED) dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( ti=ti, successes=2, @@ -109,16 +92,14 @@ def test_one_failure_tr_failure(self): done=2, flag_upstream_failed=False, session="Fake Session")) + self.assertEqual(len(dep_statuses), 1) + self.assertFalse(dep_statuses[0].passed) def test_one_failure_tr_success(self): """ One-failure trigger rule success """ - task = FakeTask( - trigger_rule=TriggerRule.ONE_FAILED, - upstream_task_ids=[]) - ti = FakeTI(task=task) - + ti = self._get_task_instance(TriggerRule.ONE_FAILED) dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( ti=ti, successes=0, @@ -128,7 +109,6 @@ def test_one_failure_tr_success(self): done=2, 
flag_upstream_failed=False, session="Fake Session")) - self.assertEqual(len(dep_statuses), 0) dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( @@ -140,18 +120,14 @@ def test_one_failure_tr_success(self): done=2, flag_upstream_failed=False, session="Fake Session")) - self.assertEqual(len(dep_statuses), 0) def test_all_success_tr_success(self): """ All-success trigger rule success """ - task = FakeTask( - trigger_rule=TriggerRule.ALL_SUCCESS, - upstream_task_ids=["FakeTaskID"]) - ti = FakeTI(task=task) - + ti = self._get_task_instance(TriggerRule.ALL_SUCCESS, + upstream_task_ids=["FakeTaskID"]) dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( ti=ti, successes=1, @@ -161,18 +137,15 @@ def test_all_success_tr_success(self): done=1, flag_upstream_failed=False, session="Fake Session")) - self.assertEqual(len(dep_statuses), 0) def test_all_success_tr_failure(self): """ All-success trigger rule failure """ - task = FakeTask( - trigger_rule=TriggerRule.ALL_SUCCESS, - upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]) - ti = FakeTI(task=task) - + ti = self._get_task_instance(TriggerRule.ALL_SUCCESS, + upstream_task_ids=["FakeTaskID", + "OtherFakeTaskID"]) dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( ti=ti, successes=1, @@ -182,7 +155,6 @@ def test_all_success_tr_failure(self): done=2, flag_upstream_failed=False, session="Fake Session")) - self.assertEqual(len(dep_statuses), 1) self.assertFalse(dep_statuses[0].passed) @@ -190,11 +162,9 @@ def test_all_failed_tr_success(self): """ All-failed trigger rule success """ - task = FakeTask( - trigger_rule=TriggerRule.ALL_FAILED, - upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]) - ti = FakeTI(task=task) - + ti = self._get_task_instance(TriggerRule.ALL_FAILED, + upstream_task_ids=["FakeTaskID", + "OtherFakeTaskID"]) dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( ti=ti, successes=0, @@ -204,18 +174,15 @@ def test_all_failed_tr_success(self): done=2, flag_upstream_failed=False, session="Fake Session")) - self.assertEqual(len(dep_statuses), 0) def test_all_failed_tr_failure(self): """ All-failed trigger rule failure """ - task = FakeTask( - trigger_rule=TriggerRule.ALL_FAILED, - upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]) - ti = FakeTI(task=task) - + ti = self._get_task_instance(TriggerRule.ALL_FAILED, + upstream_task_ids=["FakeTaskID", + "OtherFakeTaskID"]) dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( ti=ti, successes=2, @@ -225,7 +192,6 @@ def test_all_failed_tr_failure(self): done=2, flag_upstream_failed=False, session="Fake Session")) - self.assertEqual(len(dep_statuses), 1) self.assertFalse(dep_statuses[0].passed) @@ -233,11 +199,9 @@ def test_all_done_tr_success(self): """ All-done trigger rule success """ - task = FakeTask( - trigger_rule=TriggerRule.ALL_DONE, - upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]) - ti = FakeTI(task=task) - + ti = self._get_task_instance(TriggerRule.ALL_DONE, + upstream_task_ids=["FakeTaskID", + "OtherFakeTaskID"]) dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( ti=ti, successes=2, @@ -247,18 +211,15 @@ def test_all_done_tr_success(self): done=2, flag_upstream_failed=False, session="Fake Session")) - self.assertEqual(len(dep_statuses), 0) def test_all_done_tr_failure(self): """ All-done trigger rule failure """ - task = FakeTask( - trigger_rule=TriggerRule.ALL_DONE, - upstream_task_ids=["FakeTaskID", "OtherFakeTaskID"]) - ti = FakeTI(task=task) - + ti = self._get_task_instance(TriggerRule.ALL_DONE, + 
upstream_task_ids=["FakeTaskID", + "OtherFakeTaskID"]) dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( ti=ti, successes=1, @@ -268,7 +229,6 @@ def test_all_done_tr_failure(self): done=1, flag_upstream_failed=False, session="Fake Session")) - self.assertEqual(len(dep_statuses), 1) self.assertFalse(dep_statuses[0].passed) @@ -276,11 +236,8 @@ def test_unknown_tr(self): """ Unknown trigger rules should cause this dep to fail """ - task = FakeTask( - trigger_rule="Unknown Trigger Rule", - upstream_task_ids=[]) - ti = FakeTI(task=task) - + ti = self._get_task_instance() + ti.task.trigger_rule = "Unknown Trigger Rule" dep_statuses = tuple(TriggerRuleDep()._evaluate_trigger_rule( ti=ti, successes=1, diff --git a/tests/ti_deps/deps/valid_state_dep.py b/tests/ti_deps/test_valid_state_dep.py similarity index 74% rename from tests/ti_deps/deps/valid_state_dep.py rename to tests/ti_deps/test_valid_state_dep.py index 6bc08357fa..2ece718e84 100644 --- a/tests/ti_deps/deps/valid_state_dep.py +++ b/tests/ti_deps/test_valid_state_dep.py @@ -14,11 +14,11 @@ import unittest from datetime import datetime +from mock import Mock from airflow import AirflowException from airflow.ti_deps.deps.valid_state_dep import ValidStateDep from airflow.utils.state import State -from fake_models import FakeTI class ValidStateDepTest(unittest.TestCase): @@ -27,23 +27,20 @@ def test_valid_state(self): """ Valid state should pass this dep """ - ti = FakeTI(state=State.QUEUED, end_date=datetime(2016, 1, 1)) - - self.assertTrue(ValidStateDep({State.QUEUED}).is_met(ti=ti, dep_context=None)) + ti = Mock(state=State.QUEUED, end_date=datetime(2016, 1, 1)) + self.assertTrue(ValidStateDep({State.QUEUED}).is_met(ti=ti)) def test_invalid_state(self): """ Invalid state should fail this dep """ - ti = FakeTI(state=State.SUCCESS, end_date=datetime(2016, 1, 1)) - - self.assertFalse(ValidStateDep({State.FAILURE}).is_met(ti=ti, dep_context=None)) + ti = Mock(state=State.SUCCESS, end_date=datetime(2016, 1, 1)) + self.assertFalse(ValidStateDep({State.FAILED}).is_met(ti=ti)) def test_no_valid_states(self): """ If there are no valid states the dependency should throw """ - ti = FakeTI(state=State.SUCCESS, end_date=datetime(2016, 1, 1)) - + ti = Mock(state=State.SUCCESS, end_date=datetime(2016, 1, 1)) with self.assertRaises(AirflowException): - ValidStateDep({}).is_met(ti=ti, dep_context=None) + ValidStateDep({}).is_met(ti=ti) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py index 6b15998714..9d7677a99b 100644 --- a/tests/utils/__init__.py +++ b/tests/utils/__init__.py @@ -11,6 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- -from .compression import * -from .dates import *
diff --git a/tests/utils/compression.py b/tests/utils/test_compression.py similarity index 98% rename from tests/utils/compression.py rename to tests/utils/test_compression.py index f8e0ebbb2a..13af92e8a7 100644 --- a/tests/utils/compression.py +++ b/tests/utils/test_compression.py @@ -91,7 +91,3 @@ def test_uncompress_file(self): txt_bz2 = compression.uncompress_file(fn_bz2, '.bz2', self.tmp_dir) self.assertTrue(filecmp.cmp(txt_bz2, fn_txt, shallow=False), msg="Uncompressed file doesn't match original") - - -if __name__ == '__main__': - unittest.main()
diff --git a/tests/utils/dates.py b/tests/utils/test_dates.py similarity index 100% rename from tests/utils/dates.py rename to tests/utils/test_dates.py
---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: users@infra.apache.org
> Tons of unit tests are ignored > ------------------------------ > > Key: AIRFLOW-867 > URL: https://issues.apache.org/jira/browse/AIRFLOW-867 > Project: Apache Airflow > Issue Type: Bug > Components: tests > Reporter: George Sakkis > Assignee: Muhammad Ahmmad > Priority: Major > > I was poking around in tests and found out that lots of tests are not discovered by nosetests: > {noformat} > $ nosetests -q --collect-only > ---------------------------------------------------------------------- > Ran 254 tests in 0.948s > $ grep -R 'def test' tests/ | wc -l > 360 > {noformat} > Initially I thought it might be related to not having installed all extra dependencies but it turns out it's because apparently nosetests expects explicit import of the related modules instead of discovering them automatically (like py.test). For example, when adding an {{from .ti_deps.deps.runnable_exec_date_dep import *}} in {{tests/__init__.py}} it finds 260 tests, while when commenting out all imports in this module it finds only 15! > h4. Possible options > * Quick fix: Add the necessary missing "import *" to discover all current tests. > * Better fix: Rename all test modules to start with "test_" > -Move from nosetests to py.test and get rid of the ugly error-prone 'import *' hack.-
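Worth spelling out why the "rename" option (the one this PR implements) works: nose decides what to collect by matching names against its testMatch regular expression, so a module named tests/utils/compression.py is invisible to it while tests/utils/test_compression.py is picked up. The snippet below is an illustration only, not part of the PR; it assumes nose's documented default testMatch pattern and applies it to base file names, which is roughly what nose's selector does:

{noformat}
import re

# nose's documented default testMatch pattern: a name counts as a test
# only when "test"/"Test" appears at the start or right after _ . / -
TEST_MATCH = re.compile(r'(?:^|[\b_\./-])[Tt]est')

# Base file names taken from the renames in the diff above.
for name in ['compression.py', 'test_compression.py',
             'not_running_dep.py', 'test_not_running_dep.py']:
    status = 'collected' if TEST_MATCH.search(name) else 'ignored'
    print('{0:<26} {1}'.format(name, status))
{noformat}

Running it reports the old names as ignored and the renamed ones as collected. The renaming also keeps the struck-out py.test option open: py.test's default python_files setting matches test_*.py as well, so the same files would be discovered there without any 'import *' shims in tests/__init__.py.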