airflow-commits mailing list archives

From bo...@apache.org
Subject [2/3] incubator-airflow git commit: [AIRFLOW-1191] : SparkSubmitHook custom cmd
Date Tue, 30 May 2017 19:55:53 GMT
[AIRFLOW-1191] : SparkSubmitHook custom cmd

Add the capability to set the spark-submit binary to call.
Previously the command defaulted to 'spark-submit' or had to
be set via a Spark environment variable; the spark binary can
now be set in the Spark connection.

Test coverage has been extended for the new setting.
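
As an illustration (not part of the commit), the new extra field can be
registered the same way the test fixtures below do; the conn_id
'spark_custom_binary' and the binary name 'spark2-submit' are made up for
the example:

    from airflow import models
    from airflow.utils import db

    # Illustrative connection: the 'spark-binary' extra selects the executable,
    # optionally combined with 'spark-home' to build an absolute path.
    db.merge_conn(
        models.Connection(
            conn_id='spark_custom_binary', conn_type='spark',
            host='yarn',
            extra='{"spark-home": "/opt/spark", "spark-binary": "spark2-submit"}'))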


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/d06ab68f
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/d06ab68f
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/d06ab68f

Branch: refs/heads/master
Commit: d06ab68f2c83ad5dce3cae1c5aa9a9a9f32cf934
Parents: d165377
Author: vfoucault <vianney.foucault@gmail.com>
Authored: Sun May 14 23:23:11 2017 +0200
Committer: vfoucault <vianney.foucault@gmail.com>
Committed: Mon May 22 23:43:46 2017 +0200

----------------------------------------------------------------------
 airflow/contrib/hooks/spark_submit_hook.py    | 49 ++++++------
 tests/contrib/hooks/test_spark_submit_hook.py | 88 ++++++++++++++++++++--
 2 files changed, 104 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/d06ab68f/airflow/contrib/hooks/spark_submit_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/spark_submit_hook.py b/airflow/contrib/hooks/spark_submit_hook.py
index 208b74f..ae51959 100644
--- a/airflow/contrib/hooks/spark_submit_hook.py
+++ b/airflow/contrib/hooks/spark_submit_hook.py
@@ -100,40 +100,39 @@ class SparkSubmitHook(BaseHook):
         self._sp = None
         self._yarn_application_id = None
 
-        (self._master, self._queue, self._deploy_mode, self._spark_home) = self._resolve_connection()
-        self._is_yarn = 'yarn' in self._master
+        self._connection = self._resolve_connection()
+        self._is_yarn = 'yarn' in self._connection['master']
 
     def _resolve_connection(self):
         # Build from connection master or default to yarn if not available
-        master = 'yarn'
-        queue = None
-        deploy_mode = None
-        spark_home = None
+        conn_data = {'master': 'yarn',
+                     'queue': None,
+                     'deploy_mode': None,
+                     'spark_home': None,
+                     'spark_binary': 'spark-submit'}
 
         try:
             # Master can be local, yarn, spark://HOST:PORT or mesos://HOST:PORT
             conn = self.get_connection(self._conn_id)
             if conn.port:
-                master = "{}:{}".format(conn.host, conn.port)
+                conn_data['master'] = "{}:{}".format(conn.host, conn.port)
             else:
-                master = conn.host
+                conn_data['master'] = conn.host
 
             # Determine optional yarn queue from the extra field
             extra = conn.extra_dejson
-            if 'queue' in extra:
-                queue = extra['queue']
-            if 'deploy-mode' in extra:
-                deploy_mode = extra['deploy-mode']
-            if 'spark-home' in extra:
-                spark_home = extra['spark-home']
+            conn_data['queue'] = extra.get('queue', None)
+            conn_data['deploy_mode'] = extra.get('deploy-mode', None)
+            conn_data['spark_home'] = extra.get('spark-home', None)
+            conn_data['spark_binary'] = extra.get('spark-binary', 'spark-submit')
         except AirflowException:
             logging.debug(
                 "Could not load connection string {}, defaulting to {}".format(
-                    self._conn_id, master
+                    self._conn_id, conn_data['master']
                 )
             )
 
-        return master, queue, deploy_mode, spark_home
+        return conn_data
 
     def get_conn(self):
         pass
@@ -148,13 +147,13 @@ class SparkSubmitHook(BaseHook):
         # If the spark_home is passed then build the spark-submit executable path using
         # the spark_home; otherwise assume that spark-submit is present in the path to
         # the executing user
-        if self._spark_home:
-            connection_cmd = [os.path.join(self._spark_home, 'bin', 'spark-submit')]
+        if self._connection['spark_home']:
+            connection_cmd = [os.path.join(self._connection['spark_home'], 'bin', self._connection['spark_binary'])]
         else:
-            connection_cmd = ['spark-submit']
+            connection_cmd = [self._connection['spark_binary']]
 
         # The url of the spark master
-        connection_cmd += ["--master", self._master]
+        connection_cmd += ["--master", self._connection['master']]
 
         if self._conf:
             for key in self._conf:
@@ -185,10 +184,10 @@ class SparkSubmitHook(BaseHook):
             connection_cmd += ["--class", self._java_class]
         if self._verbose:
             connection_cmd += ["--verbose"]
-        if self._queue:
-            connection_cmd += ["--queue", self._queue]
-        if self._deploy_mode:
-            connection_cmd += ["--deploy-mode", self._deploy_mode]
+        if self._connection['queue']:
+            connection_cmd += ["--queue", self._connection['queue']]
+        if self._connection['deploy_mode']:
+            connection_cmd += ["--deploy-mode", self._connection['deploy_mode']]
 
         # The actual script to execute
         connection_cmd += [application]
@@ -245,7 +244,7 @@ class SparkSubmitHook(BaseHook):
             line = line.decode('utf-8').strip()
             # If we run yarn cluster mode, we want to extract the application id from
             # the logs so we can kill the application when we stop it unexpectedly
-            if self._is_yarn and self._deploy_mode == 'cluster':
+            if self._is_yarn and self._connection['deploy_mode'] == 'cluster':
                 match = re.search('(application[0-9_]+)', line)
                 if match:
                     self._yarn_application_id = match.groups()[0]

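For readability, the executable-resolution behaviour of _build_command above
can be summarised in a standalone sketch (plain Python, not part of the
commit):

    import os

    def resolve_spark_submit(connection):
        # Mirrors the hook logic: when spark_home is set, the binary is taken
        # from <spark_home>/bin/<spark_binary>; otherwise the bare binary name
        # is used and must be resolvable on the PATH of the executing user.
        if connection['spark_home']:
            return os.path.join(connection['spark_home'], 'bin',
                                connection['spark_binary'])
        return connection['spark_binary']

    # {'spark_home': '/opt/myspark', 'spark_binary': 'custom-spark-submit'}
    #   -> '/opt/myspark/bin/custom-spark-submit'
    # {'spark_home': None, 'spark_binary': 'spark-submit'}
    #   -> 'spark-submit'
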
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/d06ab68f/tests/contrib/hooks/test_spark_submit_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_spark_submit_hook.py b/tests/contrib/hooks/test_spark_submit_hook.py
index ee5b9e0..80b5ce0 100644
--- a/tests/contrib/hooks/test_spark_submit_hook.py
+++ b/tests/contrib/hooks/test_spark_submit_hook.py
@@ -90,6 +90,16 @@ class TestSparkSubmitHook(unittest.TestCase):
                 conn_id='spark_home_not_set', conn_type='spark',
                 host='yarn://yarn-master')
         )
+        db.merge_conn(
+            models.Connection(
+                conn_id='spark_binary_set', conn_type='spark',
+                host='yarn', extra='{"spark-binary": "custom-spark-submit"}')
+        )
+        db.merge_conn(
+            models.Connection(
+                conn_id='spark_binary_and_home_set', conn_type='spark',
+                host='yarn', extra='{"spark-home": "/path/to/spark_home", "spark-binary":
"custom-spark-submit"}')
+        )
 
     def test_build_command(self):
         # Given
@@ -123,8 +133,6 @@ class TestSparkSubmitHook(unittest.TestCase):
         ]
         self.assertEquals(expected_build_cmd, cmd)
 
-
-
     @patch('subprocess.Popen')
     def test_SparkProcess_runcmd(self, mock_popen):
         # Given
@@ -150,7 +158,12 @@ class TestSparkSubmitHook(unittest.TestCase):
 
         # Then
         dict_cmd = self.cmd_args_to_dict(cmd)
-        self.assertSequenceEqual(connection, ('yarn', None, None, None))
+        expected_spark_connection = {"master": u"yarn",
+                                     "spark_binary": "spark-submit",
+                                     "deploy_mode": None,
+                                     "queue": None,
+                                     "spark_home": None}
+        self.assertEqual(connection, expected_spark_connection)
         self.assertEqual(dict_cmd["--master"], "yarn")
 
     def test_resolve_connection_yarn_default_connection(self):
@@ -163,7 +176,12 @@ class TestSparkSubmitHook(unittest.TestCase):
 
         # Then
         dict_cmd = self.cmd_args_to_dict(cmd)
-        self.assertSequenceEqual(connection, ('yarn', 'root.default', None, None))
+        expected_spark_connection = {"master": u"yarn",
+                                     "spark_binary": "spark-submit",
+                                     "deploy_mode": None,
+                                     "queue": u"root.default",
+                                     "spark_home": None}
+        self.assertEqual(connection, expected_spark_connection)
         self.assertEqual(dict_cmd["--master"], "yarn")
         self.assertEqual(dict_cmd["--queue"], "root.default")
 
@@ -177,7 +195,12 @@ class TestSparkSubmitHook(unittest.TestCase):
 
         # Then
         dict_cmd = self.cmd_args_to_dict(cmd)
-        self.assertSequenceEqual(connection, ('mesos://host:5050', None, None, None))
+        expected_spark_connection = {"master": u"mesos://host:5050",
+                                     "spark_binary": "spark-submit",
+                                     "deploy_mode": None,
+                                     "queue": None,
+                                     "spark_home": None}
+        self.assertEqual(connection, expected_spark_connection)
         self.assertEqual(dict_cmd["--master"], "mesos://host:5050")
 
     def test_resolve_connection_spark_yarn_cluster_connection(self):
@@ -190,7 +213,12 @@ class TestSparkSubmitHook(unittest.TestCase):
 
         # Then
         dict_cmd = self.cmd_args_to_dict(cmd)
-        self.assertSequenceEqual(connection, ('yarn://yarn-master', 'root.etl', 'cluster', None))
+        expected_spark_connection = {"master": u"yarn://yarn-master",
+                                     "spark_binary": "spark-submit",
+                                     "deploy_mode": u"cluster",
+                                     "queue": u"root.etl",
+                                     "spark_home": None}
+        self.assertEqual(connection, expected_spark_connection)
         self.assertEqual(dict_cmd["--master"], "yarn://yarn-master")
         self.assertEqual(dict_cmd["--queue"], "root.etl")
         self.assertEqual(dict_cmd["--deploy-mode"], "cluster")
@@ -204,7 +232,12 @@ class TestSparkSubmitHook(unittest.TestCase):
         cmd = hook._build_command(self._spark_job_file)
 
         # Then
-        self.assertSequenceEqual(connection, ('yarn://yarn-master', None, None, '/opt/myspark'))
+        expected_spark_connection = {"master": u"yarn://yarn-master",
+                                     "spark_binary": "spark-submit",
+                                     "deploy_mode": None,
+                                     "queue": None,
+                                     "spark_home": u"/opt/myspark"}
+        self.assertEqual(connection, expected_spark_connection)
         self.assertEqual(cmd[0], '/opt/myspark/bin/spark-submit')
 
     def test_resolve_connection_spark_home_not_set_connection(self):
@@ -216,9 +249,48 @@ class TestSparkSubmitHook(unittest.TestCase):
         cmd = hook._build_command(self._spark_job_file)
 
         # Then
-        self.assertSequenceEqual(connection, ('yarn://yarn-master', None, None, None))
+        expected_spark_connection = {"master": u"yarn://yarn-master",
+                                     "spark_binary": "spark-submit",
+                                     "deploy_mode": None,
+                                     "queue": None,
+                                     "spark_home": None}
+        self.assertEqual(connection, expected_spark_connection)
         self.assertEqual(cmd[0], 'spark-submit')
 
+    def test_resolve_connection_spark_binary_set_connection(self):
+        # Given
+        hook = SparkSubmitHook(conn_id='spark_binary_set')
+
+        # When
+        connection = hook._resolve_connection()
+        cmd = hook._build_command(self._spark_job_file)
+
+        # Then
+        expected_spark_connection = {"master": u"yarn",
+                                     "spark_binary": u"custom-spark-submit",
+                                     "deploy_mode": None,
+                                     "queue": None,
+                                     "spark_home": None}
+        self.assertEqual(connection, expected_spark_connection)
+        self.assertEqual(cmd[0], 'custom-spark-submit')
+
+    def test_resolve_connection_spark_binary_and_home_set_connection(self):
+        # Given
+        hook = SparkSubmitHook(conn_id='spark_binary_and_home_set')
+
+        # When
+        connection = hook._resolve_connection()
+        cmd = hook._build_command(self._spark_job_file)
+
+        # Then
+        expected_spark_connection = {"master": u"yarn",
+                                     "spark_binary": u"custom-spark-submit",
+                                     "deploy_mode": None,
+                                     "queue": None,
+                                     "spark_home": u"/path/to/spark_home"}
+        self.assertEqual(connection, expected_spark_connection)
+        self.assertEqual(cmd[0], '/path/to/spark_home/bin/custom-spark-submit')
+
     def test_process_log(self):
         # Given
         hook = SparkSubmitHook(conn_id='spark_yarn_cluster')

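Once such a connection exists, the hook picks the binary up transparently; a
rough usage sketch (the application path is a placeholder, and _build_command
is the hook-internal helper exercised by the tests above):

    from airflow.contrib.hooks.spark_submit_hook import SparkSubmitHook

    # Assumes the 'spark_binary_set' connection from the tests above is
    # present in the metadata database.
    hook = SparkSubmitHook(conn_id='spark_binary_set')
    cmd = hook._build_command('/path/to/app.py')
    # cmd[0] == 'custom-spark-submit'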
