airflow-commits mailing list archives

From bo...@apache.org
Subject [2/3] incubator-airflow git commit: [AIRFLOW-1184] SparkSubmitHook does not split args
Date Sat, 13 May 2017 10:51:07 GMT
[AIRFLOW-1184] SparkSubmitHook does not split args
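
Application arguments containing whitespace (e.g. '-f foo') were previously passed
to spark-submit as a single token. The hook now splits such arguments into separate
tokens before appending them to the command, and the YARN kill call uses the fully
qualified subprocess.Popen. A minimal standalone sketch of the new splitting
behaviour (the helper name below is illustrative only, not part of the hook):

    def append_application_args(connection_cmd, application_args):
        # Whitespace-separated application args are tokenized individually.
        for arg in application_args:
            if len(arg.split()) > 1:
                # e.g. '-f foo' becomes ['-f', 'foo']
                for splitted_option in arg.split():
                    connection_cmd += [splitted_option]
            else:
                connection_cmd += [arg]
        return connection_cmd

    # Example: ['-f foo', '--bar bar', 'baz'] -> ['-f', 'foo', '--bar', 'bar', 'baz']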


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/b113432e
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/b113432e
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/b113432e

Branch: refs/heads/master
Commit: b113432e7253811f55b373ab73c98079c724ead9
Parents: f29dc7c
Author: Vianney Foucault <vianney.foucault@gmail.com>
Authored: Fri May 12 14:44:45 2017 +0200
Committer: Vianney Foucault <vianney.foucault@gmail.com>
Committed: Fri May 12 16:32:42 2017 +0200

----------------------------------------------------------------------
 airflow/contrib/hooks/spark_submit_hook.py    |  13 +-
 tests/contrib/hooks/test_spark_submit_hook.py | 222 ++++++++++++---------
 2 files changed, 133 insertions(+), 102 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b113432e/airflow/contrib/hooks/spark_submit_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/spark_submit_hook.py b/airflow/contrib/hooks/spark_submit_hook.py
index c34538e..208b74f 100644
--- a/airflow/contrib/hooks/spark_submit_hook.py
+++ b/airflow/contrib/hooks/spark_submit_hook.py
@@ -196,8 +196,11 @@ class SparkSubmitHook(BaseHook):
         # Append any application arguments
         if self._application_args:
             for arg in self._application_args:
-                connection_cmd += [arg]
-
+                if len(arg.split()) > 1:
+                    for splitted_option in arg.split():
+                        connection_cmd += [splitted_option]
+                else:
+                    connection_cmd += [arg]
         logging.debug("Spark-Submit cmd: {}".format(connection_cmd))
 
         return connection_cmd
@@ -257,7 +260,7 @@ class SparkSubmitHook(BaseHook):
 
             if self._yarn_application_id:
                 logging.info('Killing application on YARN')
-                yarn_kill = Popen("yarn application -kill {0}".format(self._yarn_application_id),
-                                  stdout=subprocess.PIPE,
-                                  stderr=subprocess.PIPE)
+                yarn_kill = subprocess.Popen("yarn application -kill {0}".format(self._yarn_application_id),
+                                             stdout=subprocess.PIPE,
+                                             stderr=subprocess.PIPE)
                 logging.info("YARN killed with return code: {0}".format(yarn_kill.wait()))

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/b113432e/tests/contrib/hooks/test_spark_submit_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_spark_submit_hook.py b/tests/contrib/hooks/test_spark_submit_hook.py
index e06d44c..ee5b9e0 100644
--- a/tests/contrib/hooks/test_spark_submit_hook.py
+++ b/tests/contrib/hooks/test_spark_submit_hook.py
@@ -16,10 +16,10 @@ import sys
 import unittest
 from io import StringIO
 
-import mock
-
 from airflow import configuration, models
 from airflow.utils import db
+from mock import patch, call
+
 from airflow.contrib.hooks.spark_submit_hook import SparkSubmitHook
 
 
@@ -51,6 +51,15 @@ class TestSparkSubmitHook(unittest.TestCase):
         ]
     }
 
+    @staticmethod
+    def cmd_args_to_dict(list_cmd):
+        return_dict = {}
+        for arg in list_cmd:
+            if arg.startswith("--"):
+                pos = list_cmd.index(arg)
+                return_dict[arg] = list_cmd[pos+1]
+        return return_dict
+
     def setUp(self):
 
         if sys.version_info[0] == 3:
@@ -83,119 +92,136 @@ class TestSparkSubmitHook(unittest.TestCase):
         )
 
     def test_build_command(self):
+        # Given
         hook = SparkSubmitHook(**self._config)
 
-        # The subprocess requires an array but we build the cmd by joining on a space
-        cmd = ' '.join(hook._build_command(self._spark_job_file))
-
-        # Check if the URL gets build properly and everything exists.
-        assert self._spark_job_file in cmd
-
-        # Check all the parameters
-        assert "--files {}".format(self._config['files']) in cmd
-        assert "--py-files {}".format(self._config['py_files']) in cmd
-        assert "--jars {}".format(self._config['jars']) in cmd
-        assert "--total-executor-cores {}".format(self._config['total_executor_cores']) in
cmd
-        assert "--executor-cores {}".format(self._config['executor_cores']) in cmd
-        assert "--executor-memory {}".format(self._config['executor_memory']) in cmd
-        assert "--keytab {}".format(self._config['keytab']) in cmd
-        assert "--principal {}".format(self._config['principal']) in cmd
-        assert "--name {}".format(self._config['name']) in cmd
-        assert "--num-executors {}".format(self._config['num_executors']) in cmd
-        assert "--class {}".format(self._config['java_class']) in cmd
-        assert "--driver-memory {}".format(self._config['driver_memory']) in cmd
-
-        # Check if all config settings are there
-        for k in self._config['conf']:
-            assert "--conf {0}={1}".format(k, self._config['conf'][k]) in cmd
-
-        # Check the application arguments are there
-        for a in self._config['application_args']:
-            assert a in cmd
-
-        # Check if application arguments are after the application
-        application_idx = cmd.find(self._spark_job_file)
-        for a in self._config['application_args']:
-            assert cmd.find(a) > application_idx
-
-        if self._config['verbose']:
-            assert "--verbose" in cmd
-
-    @mock.patch('airflow.contrib.hooks.spark_submit_hook.subprocess')
-    def test_submit(self, mock_process):
-        # We don't have spark-submit available, and this is hard to mock, so let's
-        # just use this simple mock.
-        mock_Popen = mock_process.Popen.return_value
-        mock_Popen.stdout = StringIO(u'stdout')
-        mock_Popen.stderr = StringIO(u'stderr')
-        mock_Popen.returncode = None
-        mock_Popen.communicate.return_value = ['extra stdout', 'extra stderr']
-        hook = SparkSubmitHook()
-        hook.submit(self._spark_job_file)
-
-    def test_resolve_connection(self):
-
-        # Default to the standard yarn connection because conn_id does not exists
+        # When
+        cmd = hook._build_command(self._spark_job_file)
+
+        # Then
+        expected_build_cmd = [
+            'spark-submit',
+            '--master', 'yarn',
+            '--conf', 'parquet.compression=SNAPPY',
+            '--files', 'hive-site.xml',
+            '--py-files', 'sample_library.py',
+            '--jars', 'parquet.jar',
+            '--num-executors', '10',
+            '--total-executor-cores', '4',
+            '--executor-cores', '4',
+            '--executor-memory', '22g',
+            '--driver-memory', '3g',
+            '--keytab', 'privileged_user.keytab',
+            '--principal', 'user/spark@airflow.org',
+            '--name', 'spark-job',
+            '--class', 'com.foo.bar.AppMain',
+            '--verbose',
+            'test_application.py',
+            '-f', 'foo',
+            '--bar', 'bar',
+            'baz'
+        ]
+        self.assertEquals(expected_build_cmd, cmd)
+
+
+
+    @patch('subprocess.Popen')
+    def test_SparkProcess_runcmd(self, mock_popen):
+        # Given
+        mock_popen.return_value.stdout = StringIO(u'stdout')
+        mock_popen.return_value.stderr = StringIO(u'stderr')
+        mock_popen.return_value.returncode = 0
+        mock_popen.return_value.communicate.return_value = [StringIO(u'stdout\nstdout'), StringIO(u'stderr\nstderr')]
+
+        # When
         hook = SparkSubmitHook(conn_id='')
-        self.assertEqual(hook._resolve_connection(), ('yarn', None, None, None))
-        assert "--master yarn" in ' '.join(hook._build_command(self._spark_job_file))
+        hook.submit()
+
+        # Then
+        self.assertEqual(mock_popen.mock_calls[0], call(['spark-submit', '--master', 'yarn', '--name', 'default-name', ''], stderr=-1, stdout=-1))
+
+    def test_resolve_connection_yarn_default(self):
+        # Given
+        hook = SparkSubmitHook(conn_id='')
+
+        # When
+        connection = hook._resolve_connection()
+        cmd = hook._build_command(self._spark_job_file)
+
+        # Then
+        dict_cmd = self.cmd_args_to_dict(cmd)
+        self.assertSequenceEqual(connection, ('yarn', None, None, None))
+        self.assertEqual(dict_cmd["--master"], "yarn")
 
-        # Default to the standard yarn connection
+    def test_resolve_connection_yarn_default_connection(self):
+        # Given
         hook = SparkSubmitHook(conn_id='spark_default')
-        self.assertEqual(
-            hook._resolve_connection(),
-            ('yarn', 'root.default', None, None)
-        )
-        cmd = ' '.join(hook._build_command(self._spark_job_file))
-        assert "--master yarn" in cmd
-        assert "--queue root.default" in cmd
 
-        # Connect to a mesos master
+        # When
+        connection = hook._resolve_connection()
+        cmd = hook._build_command(self._spark_job_file)
+
+        # Then
+        dict_cmd = self.cmd_args_to_dict(cmd)
+        self.assertSequenceEqual(connection, ('yarn', 'root.default', None, None))
+        self.assertEqual(dict_cmd["--master"], "yarn")
+        self.assertEqual(dict_cmd["--queue"], "root.default")
+
+    def test_resolve_connection_mesos_default_connection(self):
+        # Given
         hook = SparkSubmitHook(conn_id='spark_default_mesos')
-        self.assertEqual(
-            hook._resolve_connection(),
-            ('mesos://host:5050', None, None, None)
-        )
 
-        cmd = ' '.join(hook._build_command(self._spark_job_file))
-        assert "--master mesos://host:5050" in cmd
+        # When
+        connection = hook._resolve_connection()
+        cmd = hook._build_command(self._spark_job_file)
 
-        # Set specific queue and deploy mode
+        # Then
+        dict_cmd = self.cmd_args_to_dict(cmd)
+        self.assertSequenceEqual(connection, ('mesos://host:5050', None, None, None))
+        self.assertEqual(dict_cmd["--master"], "mesos://host:5050")
+
+    def test_resolve_connection_spark_yarn_cluster_connection(self):
+        # Given
         hook = SparkSubmitHook(conn_id='spark_yarn_cluster')
-        self.assertEqual(
-            hook._resolve_connection(),
-            ('yarn://yarn-master', 'root.etl', 'cluster', None)
-        )
 
-        cmd = ' '.join(hook._build_command(self._spark_job_file))
-        assert "--master yarn://yarn-master" in cmd
-        assert "--queue root.etl" in cmd
-        assert "--deploy-mode cluster" in cmd
+        # When
+        connection = hook._resolve_connection()
+        cmd = hook._build_command(self._spark_job_file)
+
+        # Then
+        dict_cmd = self.cmd_args_to_dict(cmd)
+        self.assertSequenceEqual(connection, ('yarn://yarn-master', 'root.etl', 'cluster', None))
+        self.assertEqual(dict_cmd["--master"], "yarn://yarn-master")
+        self.assertEqual(dict_cmd["--queue"], "root.etl")
+        self.assertEqual(dict_cmd["--deploy-mode"], "cluster")
 
-        # Set the spark home
+    def test_resolve_connection_spark_home_set_connection(self):
+        # Given
         hook = SparkSubmitHook(conn_id='spark_home_set')
-        self.assertEqual(
-            hook._resolve_connection(),
-            ('yarn://yarn-master', None, None, '/opt/myspark')
-        )
 
-        cmd = ' '.join(hook._build_command(self._spark_job_file))
-        assert cmd.startswith('/opt/myspark/bin/spark-submit')
+        # When
+        connection = hook._resolve_connection()
+        cmd = hook._build_command(self._spark_job_file)
+
+        # Then
+        self.assertSequenceEqual(connection, ('yarn://yarn-master', None, None, '/opt/myspark'))
+        self.assertEqual(cmd[0], '/opt/myspark/bin/spark-submit')
 
-        # Spark home not set
+    def test_resolve_connection_spark_home_not_set_connection(self):
+        # Given
         hook = SparkSubmitHook(conn_id='spark_home_not_set')
-        self.assertEqual(
-            hook._resolve_connection(),
-            ('yarn://yarn-master', None, None, None)
-        )
 
-        cmd = ' '.join(hook._build_command(self._spark_job_file))
-        assert cmd.startswith('spark-submit')
+        # When
+        connection = hook._resolve_connection()
+        cmd = hook._build_command(self._spark_job_file)
+
+        # Then
+        self.assertSequenceEqual(connection, ('yarn://yarn-master', None, None, None))
+        self.assertEqual(cmd[0], 'spark-submit')
 
     def test_process_log(self):
-        # Must select yarn connection
+        # Given
         hook = SparkSubmitHook(conn_id='spark_yarn_cluster')
-
         log_lines = [
             'SPARK_MAJOR_VERSION is set to 2, using Spark2',
             'WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable',
@@ -203,10 +229,12 @@ class TestSparkSubmitHook(unittest.TestCase):
             'INFO Client: Requesting a new application from cluster with 10 NodeManagers',
             'INFO Client: Submitting application application_1486558679801_1820 to ResourceManager'
         ]
-
+        # When
         hook._process_log(log_lines)
 
-        assert hook._yarn_application_id == 'application_1486558679801_1820'
+        # Then
+
+        self.assertEqual(hook._yarn_application_id, 'application_1486558679801_1820')
 
 
 if __name__ == '__main__':

