airflow-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From bo...@apache.org
Subject incubator-airflow git commit: [AIRFLOW-1854] Improve Spark Submit operator for standalone cluster mode
Date Tue, 12 Dec 2017 11:46:06 GMT
Repository: incubator-airflow
Updated Branches:
  refs/heads/master 22453d037 -> 3e6babe8e


[AIRFLOW-1854] Improve Spark Submit operator for standalone cluster mode

Closes #2852 from milanvdmria/svend/submit2


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/3e6babe8
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/3e6babe8
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/3e6babe8

Branch: refs/heads/master
Commit: 3e6babe8ed8f8f281b67aa3f4e03bf3cfc1bcbaa
Parents: 22453d0
Author: milanvdmria <milan.vandermeer@realimpactanalytics.com>
Authored: Tue Dec 12 12:45:41 2017 +0100
Committer: Bolke de Bruin <bolke@xs4all.nl>
Committed: Tue Dec 12 12:45:52 2017 +0100

----------------------------------------------------------------------
 airflow/contrib/hooks/spark_submit_hook.py      | 217 ++++++++++++++++---
 .../contrib/operators/spark_submit_operator.py  |  17 +-
 tests/contrib/hooks/test_spark_submit_hook.py   | 175 ++++++++++++---
 .../operators/test_spark_submit_operator.py     |  19 +-
 4 files changed, 354 insertions(+), 74 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3e6babe8/airflow/contrib/hooks/spark_submit_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/spark_submit_hook.py b/airflow/contrib/hooks/spark_submit_hook.py
index c0bc84f..16e14b4 100644
--- a/airflow/contrib/hooks/spark_submit_hook.py
+++ b/airflow/contrib/hooks/spark_submit_hook.py
@@ -15,6 +15,7 @@
 import os
 import subprocess
 import re
+import time
 
 from airflow.hooks.base_hook import BaseHook
 from airflow.exceptions import AirflowException
@@ -42,15 +43,20 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
     :type jars: str
     :param java_class: the main class of the Java application
     :type java_class: str
-    :param packages: Comma-separated list of maven coordinates of jars to include on the driver and executor classpaths
+    :param packages: Comma-separated list of maven coordinates of jars to include on the
+    driver and executor classpaths
     :type packages: str
-    :param exclude_packages: Comma-separated list of maven coordinates of jars to exclude while resolving the dependencies provided in 'packages'
+    :param exclude_packages: Comma-separated list of maven coordinates of jars to exclude
+    while resolving the dependencies provided in 'packages'
     :type exclude_packages: str
-    :param repositories: Comma-separated list of additional remote repositories to search for the maven coordinates given with 'packages'
+    :param repositories: Comma-separated list of additional remote repositories to search
+    for the maven coordinates given with 'packages'
     :type repositories: str
-    :param total_executor_cores: (Standalone & Mesos only) Total cores for all executors (Default: all the available cores on the worker)
+    :param total_executor_cores: (Standalone & Mesos only) Total cores for all executors
+    (Default: all the available cores on the worker)
     :type total_executor_cores: int
-    :param executor_cores: (Standalone & YARN only) Number of cores per executor (Default: 2)
+    :param executor_cores: (Standalone & YARN only) Number of cores per executor
+    (Default: 2)
     :type executor_cores: int
     :param executor_memory: Memory per executor (e.g. 1000M, 2G) (Default: 1G)
     :type executor_memory: str
@@ -110,12 +116,25 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         self._num_executors = num_executors
         self._application_args = application_args
         self._verbose = verbose
-        self._sp = None
+        self._submit_sp = None
         self._yarn_application_id = None
 
         self._connection = self._resolve_connection()
         self._is_yarn = 'yarn' in self._connection['master']
 
+        self._should_track_driver_status = self._resolve_should_track_driver_status()
+        self._driver_id = None
+        self._driver_status = None
+
+    def _resolve_should_track_driver_status(self):
+        """
+        Determines whether or not this hook should poll the spark driver status through
+        subsequent spark-submit status requests after the initial spark-submit request
+        :return: if the driver status should be tracked
+        """
+        return ('spark://' in self._connection['master'] and
+                self._connection['deploy_mode'] == 'cluster')
+
     def _resolve_connection(self):
         # Build from connection master or default to yarn if not available
         conn_data = {'master': 'yarn',
@@ -149,21 +168,27 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
     def get_conn(self):
         pass
 
-    def _build_command(self, application):
-        """
-        Construct the spark-submit command to execute.
-        :param application: command to append to the spark-submit command
-        :type application: str
-        :return: full command to be executed
-        """
+    def _get_spark_binary_path(self):
         # If the spark_home is passed then build the spark-submit executable path using
         # the spark_home; otherwise assume that spark-submit is present in the path to
         # the executing user
         if self._connection['spark_home']:
-            connection_cmd = [os.path.join(self._connection['spark_home'], 'bin', self._connection['spark_binary'])]
+            connection_cmd = [os.path.join(self._connection['spark_home'], 'bin',
+                                           self._connection['spark_binary'])]
         else:
             connection_cmd = [self._connection['spark_binary']]
 
+        return connection_cmd
+
+    def _build_spark_submit_command(self, application):
+        """
+        Construct the spark-submit command to execute.
+        :param application: command to append to the spark-submit command
+        :type application: str
+        :return: full command to be executed
+        """
+        connection_cmd = self._get_spark_binary_path()
+
         # The url ot the spark master
         connection_cmd += ["--master", self._connection['master']]
 
@@ -216,7 +241,30 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         if self._application_args:
             connection_cmd += self._application_args
 
-        self.log.debug("Spark-Submit cmd: %s", connection_cmd)
+        self.log.info("Spark-Submit cmd: %s", connection_cmd)
+
+        return connection_cmd
+
+    def _build_track_driver_status_command(self):
+        """
+        Construct the command to poll the driver status.
+
+        :return: full command to be executed
+        """
+        connection_cmd = self._get_spark_binary_path()
+
+        # The url ot the spark master
+        connection_cmd += ["--master", self._connection['master']]
+
+        # The driver id so we can poll for its status
+        if self._driver_id:
+            connection_cmd += ["--status", self._driver_id]
+        else:
+            raise AirflowException(
+                "Invalid status: attempted to poll driver " +
+                "status but no driver id is known. Giving up.")
+
+        self.log.debug("Poll driver status cmd: %s", connection_cmd)
 
         return connection_cmd
 
@@ -228,16 +276,16 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         :type application: str
         :param kwargs: extra arguments to Popen (see subprocess.Popen)
         """
-        spark_submit_cmd = self._build_command(application)
-        self._sp = subprocess.Popen(spark_submit_cmd,
-                                    stdout=subprocess.PIPE,
-                                    stderr=subprocess.STDOUT,
-                                    bufsize=-1,
-                                    universal_newlines=True,
-                                    **kwargs)
+        spark_submit_cmd = self._build_spark_submit_command(application)
+        self._submit_sp = subprocess.Popen(spark_submit_cmd,
+                                           stdout=subprocess.PIPE,
+                                           stderr=subprocess.STDOUT,
+                                           bufsize=-1,
+                                           universal_newlines=True,
+                                           **kwargs)
 
-        self._process_log(iter(self._sp.stdout.readline, ''))
-        returncode = self._sp.wait()
+        self._process_spark_submit_log(iter(self._submit_sp.stdout.readline, ''))
+        returncode = self._submit_sp.wait()
 
         if returncode:
             raise AirflowException(
@@ -246,9 +294,34 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
                 )
             )
 
-    def _process_log(self, itr):
+        self.log.debug("Should track driver: {}".format(self._should_track_driver_status))
+
+        # We want the Airflow job to wait until the Spark driver is finished
+        if self._should_track_driver_status:
+            if self._driver_id is None:
+                raise AirflowException(
+                    "No driver id is known: something went wrong when executing " +
+                    "the spark submit command"
+                )
+
+            # We start with the SUBMITTED status as initial status
+            self._driver_status = "SUBMITTED"
+
+            # Start tracking the driver status (blocking function)
+            self._start_driver_status_tracking()
+
+            if self._driver_status != "FINISHED":
+                raise AirflowException(
+                    "ERROR : Driver {} badly exited with status {}"
+                    .format(self._driver_id, self._driver_status)
+                )
+
+    def _process_spark_submit_log(self, itr):
         """
-        Processes the log files and extracts useful information out of it
+        Processes the log files and extracts useful information out of it.
+
+        Remark: If the driver needs to be tracked for its status, the log-level of the
+        spark deploy needs to be at least INFO (log4j.logger.org.apache.spark.deploy=INFO)
 
         :param itr: An iterator which iterates over the input of the subprocess
         """
@@ -262,16 +335,94 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
                 if match:
                     self._yarn_application_id = match.groups()[0]
 
-            # Pass to logging
-            self.log.info(line)
+            # if we run in standalone cluster mode and we want to track the driver status
+            # we need to extract the driver id from the logs. This allows us to poll for
+            # the status using the driver id. Also, we can kill the driver when needed.
+            if self._should_track_driver_status and not self._driver_id:
+                match_driver_id = re.search('(driver-[0-9\-]+)', line)
+                if match_driver_id:
+                    self._driver_id = match_driver_id.groups()[0]
+                    self.log.info("identified spark driver id: {}"
+                                  .format(self._driver_id))
+
+            self.log.debug("spark submit log: {}".format(line))
+
+    def _process_spark_status_log(self, itr):
+        """
+        parses the logs of the spark driver status query process
+
+        :param itr: An iterator which iterates over the input of the subprocess
+        """
+        # Consume the iterator
+        for line in itr:
+            line = line.strip()
+
+            # Check if the log line is about the driver status and extract the status.
+            if "driverState" in line:
+                self._driver_status = line.split(' : ')[1] \
+                    .replace(',', '').replace('\"', '').strip()
+
+            self.log.debug("spark driver status log: {}".format(line))
+
+    def _start_driver_status_tracking(self):
+        """
+        Polls the driver based on self._driver_id to get the status.
+        Finish successfully when the status is FINISHED.
+        Finish failed when the status is ERROR/UNKNOWN/KILLED/FAILED.
+
+        Possible status:
+            SUBMITTED: Submitted but not yet scheduled on a worker
+            RUNNING: Has been allocated to a worker to run
+            FINISHED: Previously ran and exited cleanly
+            RELAUNCHING: Exited non-zero or due to worker failure, but has not yet
+            started running again
+            UNKNOWN: The status of the driver is temporarily not known due to
+             master failure recovery
+            KILLED: A user manually killed this driver
+            FAILED: The driver exited non-zero and was not supervised
+            ERROR: Unable to run or restart due to an unrecoverable error
+            (e.g. missing jar file)
+        """
+        # Keep polling as long as the driver is processing
+        while self._driver_status not in ["FINISHED", "UNKNOWN",
+                                          "KILLED", "FAILED", "ERROR"]:
+
+            # Sleep for 1 second as we do not want to spam the cluster
+            time.sleep(1)
+
+            self.log.debug("polling status of spark driver with id {}"
+                           .format(self._driver_id))
+
+            poll_drive_status_cmd = self._build_track_driver_status_command()
+            status_process = subprocess.Popen(poll_drive_status_cmd,
+                                              stdout=subprocess.PIPE,
+                                              stderr=subprocess.STDOUT,
+                                              bufsize=-1,
+                                              universal_newlines=True)
+
+            self._process_spark_status_log(iter(status_process.stdout.readline, ''))
+            returncode = status_process.wait()
+
+            if returncode:
+                raise AirflowException(
+                    "Failed to poll for the driver status: returncode = {}"
+                    .format(returncode)
+                )
 
     def on_kill(self):
-        if self._sp and self._sp.poll() is None:
+
+        if self._submit_sp and self._submit_sp.poll() is None:
             self.log.info('Sending kill signal to %s', self._connection['spark_binary'])
-            self._sp.kill()
+            self._submit_sp.kill()
 
             if self._yarn_application_id:
-                self.log.info('Killing application on YARN')
-                kill_cmd = "yarn application -kill {0}".format(self._yarn_application_id).split()
-                yarn_kill = subprocess.Popen(kill_cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+                self.log.info('Killing application {} on YARN'
+                              .format(self._yarn_application_id))
+
+                kill_cmd = "yarn application -kill {}" \
+                    .format(self._yarn_application_id).split()
+                yarn_kill = subprocess.Popen(kill_cmd,
+                                             stdout=subprocess.PIPE,
+                                             stderr=subprocess.PIPE)
+
                 self.log.info("YARN killed with return code: %s", yarn_kill.wait())

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3e6babe8/airflow/contrib/operators/spark_submit_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/spark_submit_operator.py b/airflow/contrib/operators/spark_submit_operator.py
index ae821fa..d743393 100644
--- a/airflow/contrib/operators/spark_submit_operator.py
+++ b/airflow/contrib/operators/spark_submit_operator.py
@@ -42,15 +42,20 @@ class SparkSubmitOperator(BaseOperator):
     :type jars: str
     :param java_class: the main class of the Java application
     :type java_class: str
-    :param packages: Comma-separated list of maven coordinates of jars to include on the driver and executor classpaths
+    :param packages: Comma-separated list of maven coordinates of jars to include on the
+    driver and executor classpaths
     :type packages: str
-    :param exclude_packages: Comma-separated list of maven coordinates of jars to exclude while resolving the dependencies provided in 'packages'
+    :param exclude_packages: Comma-separated list of maven coordinates of jars to exclude
+    while resolving the dependencies provided in 'packages'
     :type exclude_packages: str
-    :param repositories: Comma-separated list of additional remote repositories to search for the maven coordinates given with 'packages'
+    :param repositories: Comma-separated list of additional remote repositories to search
+    for the maven coordinates given with 'packages'
     :type repositories: str
-    :param total_executor_cores: (Standalone & Mesos only) Total cores for all executors (Default: all the available cores on the worker)
+    :param total_executor_cores: (Standalone & Mesos only) Total cores for all executors
+    (Default: all the available cores on the worker)
     :type total_executor_cores: int
-    :param executor_cores: (Standalone & YARN only) Number of cores per executor (Default: 2)
+    :param executor_cores: (Standalone & YARN only) Number of cores per executor
+    (Default: 2)
     :type executor_cores: int
     :param executor_memory: Memory per executor (e.g. 1000M, 2G) (Default: 1G)
     :type executor_memory: str
@@ -69,7 +74,7 @@ class SparkSubmitOperator(BaseOperator):
     :param verbose: Whether to pass the verbose flag to spark-submit process for debugging
     :type verbose: bool
     """
-    template_fields = ('_name', '_application_args','_packages')
+    template_fields = ('_name', '_application_args', '_packages')
     ui_color = WEB_COLORS['LIGHTORANGE']
 
     @apply_defaults

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3e6babe8/tests/contrib/hooks/test_spark_submit_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/test_spark_submit_hook.py b/tests/contrib/hooks/test_spark_submit_hook.py
index 5cb7132..6c55ce2 100644
--- a/tests/contrib/hooks/test_spark_submit_hook.py
+++ b/tests/contrib/hooks/test_spark_submit_hook.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 #
 import six
-import sys
 import unittest
 
 from airflow import configuration, models
@@ -61,7 +60,7 @@ class TestSparkSubmitHook(unittest.TestCase):
         for arg in list_cmd:
             if arg.startswith("--"):
                 pos = list_cmd.index(arg)
-                return_dict[arg] = list_cmd[pos+1]
+                return_dict[arg] = list_cmd[pos + 1]
         return return_dict
 
     def setUp(self):
@@ -70,7 +69,8 @@ class TestSparkSubmitHook(unittest.TestCase):
         db.merge_conn(
             models.Connection(
                 conn_id='spark_yarn_cluster', conn_type='spark',
-                host='yarn://yarn-master', extra='{"queue": "root.etl", "deploy-mode": "cluster"}')
+                host='yarn://yarn-master',
+                extra='{"queue": "root.etl", "deploy-mode": "cluster"}')
         )
         db.merge_conn(
             models.Connection(
@@ -98,15 +98,23 @@ class TestSparkSubmitHook(unittest.TestCase):
         db.merge_conn(
             models.Connection(
                 conn_id='spark_binary_and_home_set', conn_type='spark',
-                host='yarn', extra='{"spark-home": "/path/to/spark_home", "spark-binary": "custom-spark-submit"}')
+                host='yarn',
+                extra='{"spark-home": "/path/to/spark_home", ' +
+                      '"spark-binary": "custom-spark-submit"}')
+        )
+        db.merge_conn(
+            models.Connection(
+                conn_id='spark_standalone_cluster', conn_type='spark',
+                host='spark://spark-standalone-master:6066',
+                extra='{"spark-home": "/path/to/spark_home", "deploy-mode": "cluster"}')
         )
 
-    def test_build_command(self):
+    def test_build_spark_submit_command(self):
         # Given
         hook = SparkSubmitHook(**self._config)
 
         # When
-        cmd = hook._build_command(self._spark_job_file)
+        cmd = hook._build_spark_submit_command(self._spark_job_file)
 
         # Then
         expected_build_cmd = [
@@ -149,7 +157,51 @@ class TestSparkSubmitHook(unittest.TestCase):
         hook.submit()
 
         # Then
-        self.assertEqual(mock_popen.mock_calls[0], call(['spark-submit', '--master', 'yarn', '--name', 'default-name', ''], stderr=-2, stdout=-1, universal_newlines=True, bufsize=-1))
+        self.assertEqual(mock_popen.mock_calls[0],
+                         call(['spark-submit', '--master', 'yarn',
+                               '--name', 'default-name', ''],
+                              stderr=-2, stdout=-1, universal_newlines=True, bufsize=-1))
+
+    def test_resolve_should_track_driver_status(self):
+        # Given
+        hook_default = SparkSubmitHook(conn_id='')
+        hook_spark_yarn_cluster = SparkSubmitHook(conn_id='spark_yarn_cluster')
+        hook_spark_default_mesos = SparkSubmitHook(conn_id='spark_default_mesos')
+        hook_spark_home_set = SparkSubmitHook(conn_id='spark_home_set')
+        hook_spark_home_not_set = SparkSubmitHook(conn_id='spark_home_not_set')
+        hook_spark_binary_set = SparkSubmitHook(conn_id='spark_binary_set')
+        hook_spark_binary_and_home_set = SparkSubmitHook(
+            conn_id='spark_binary_and_home_set')
+        hook_spark_standalone_cluster = SparkSubmitHook(
+            conn_id='spark_standalone_cluster')
+
+        # When
+        should_track_driver_status_default = hook_default \
+            ._resolve_should_track_driver_status()
+        should_track_driver_status_spark_yarn_cluster = hook_spark_yarn_cluster \
+            ._resolve_should_track_driver_status()
+        should_track_driver_status_spark_default_mesos = hook_spark_default_mesos \
+            ._resolve_should_track_driver_status()
+        should_track_driver_status_spark_home_set = hook_spark_home_set \
+            ._resolve_should_track_driver_status()
+        should_track_driver_status_spark_home_not_set = hook_spark_home_not_set \
+            ._resolve_should_track_driver_status()
+        should_track_driver_status_spark_binary_set = hook_spark_binary_set \
+            ._resolve_should_track_driver_status()
+        should_track_driver_status_spark_binary_and_home_set = \
+            hook_spark_binary_and_home_set._resolve_should_track_driver_status()
+        should_track_driver_status_spark_standalone_cluster = \
+            hook_spark_standalone_cluster._resolve_should_track_driver_status()
+
+        # Then
+        self.assertEqual(should_track_driver_status_default, False)
+        self.assertEqual(should_track_driver_status_spark_yarn_cluster, False)
+        self.assertEqual(should_track_driver_status_spark_default_mesos, False)
+        self.assertEqual(should_track_driver_status_spark_home_set, False)
+        self.assertEqual(should_track_driver_status_spark_home_not_set, False)
+        self.assertEqual(should_track_driver_status_spark_binary_set, False)
+        self.assertEqual(should_track_driver_status_spark_binary_and_home_set, False)
+        self.assertEqual(should_track_driver_status_spark_standalone_cluster, True)
 
     def test_resolve_connection_yarn_default(self):
         # Given
@@ -157,7 +209,7 @@ class TestSparkSubmitHook(unittest.TestCase):
 
         # When
         connection = hook._resolve_connection()
-        cmd = hook._build_command(self._spark_job_file)
+        cmd = hook._build_spark_submit_command(self._spark_job_file)
 
         # Then
         dict_cmd = self.cmd_args_to_dict(cmd)
@@ -175,7 +227,7 @@ class TestSparkSubmitHook(unittest.TestCase):
 
         # When
         connection = hook._resolve_connection()
-        cmd = hook._build_command(self._spark_job_file)
+        cmd = hook._build_spark_submit_command(self._spark_job_file)
 
         # Then
         dict_cmd = self.cmd_args_to_dict(cmd)
@@ -194,7 +246,7 @@ class TestSparkSubmitHook(unittest.TestCase):
 
         # When
         connection = hook._resolve_connection()
-        cmd = hook._build_command(self._spark_job_file)
+        cmd = hook._build_spark_submit_command(self._spark_job_file)
 
         # Then
         dict_cmd = self.cmd_args_to_dict(cmd)
@@ -212,7 +264,7 @@ class TestSparkSubmitHook(unittest.TestCase):
 
         # When
         connection = hook._resolve_connection()
-        cmd = hook._build_command(self._spark_job_file)
+        cmd = hook._build_spark_submit_command(self._spark_job_file)
 
         # Then
         dict_cmd = self.cmd_args_to_dict(cmd)
@@ -232,7 +284,7 @@ class TestSparkSubmitHook(unittest.TestCase):
 
         # When
         connection = hook._resolve_connection()
-        cmd = hook._build_command(self._spark_job_file)
+        cmd = hook._build_spark_submit_command(self._spark_job_file)
 
         # Then
         expected_spark_connection = {"master": "yarn://yarn-master",
@@ -249,7 +301,7 @@ class TestSparkSubmitHook(unittest.TestCase):
 
         # When
         connection = hook._resolve_connection()
-        cmd = hook._build_command(self._spark_job_file)
+        cmd = hook._build_spark_submit_command(self._spark_job_file)
 
         # Then
         expected_spark_connection = {"master": "yarn://yarn-master",
@@ -266,7 +318,7 @@ class TestSparkSubmitHook(unittest.TestCase):
 
         # When
         connection = hook._resolve_connection()
-        cmd = hook._build_command(self._spark_job_file)
+        cmd = hook._build_spark_submit_command(self._spark_job_file)
 
         # Then
         expected_spark_connection = {"master": "yarn",
@@ -283,7 +335,7 @@ class TestSparkSubmitHook(unittest.TestCase):
 
         # When
         connection = hook._resolve_connection()
-        cmd = hook._build_command(self._spark_job_file)
+        cmd = hook._build_spark_submit_command(self._spark_job_file)
 
         # Then
         expected_spark_connection = {"master": "yarn",
@@ -294,25 +346,87 @@ class TestSparkSubmitHook(unittest.TestCase):
         self.assertEqual(connection, expected_spark_connection)
         self.assertEqual(cmd[0], '/path/to/spark_home/bin/custom-spark-submit')
 
-    def test_process_log(self):
+    def test_resolve_connection_spark_standalone_cluster_connection(self):
+        # Given
+        hook = SparkSubmitHook(conn_id='spark_standalone_cluster')
+
+        # When
+        connection = hook._resolve_connection()
+        cmd = hook._build_spark_submit_command(self._spark_job_file)
+
+        # Then
+        expected_spark_connection = {"master": "spark://spark-standalone-master:6066",
+                                     "spark_binary": "spark-submit",
+                                     "deploy_mode": "cluster",
+                                     "queue": None,
+                                     "spark_home": "/path/to/spark_home"}
+        self.assertEqual(connection, expected_spark_connection)
+        self.assertEqual(cmd[0], '/path/to/spark_home/bin/spark-submit')
+
+    def test_process_spark_submit_log_yarn(self):
         # Given
         hook = SparkSubmitHook(conn_id='spark_yarn_cluster')
         log_lines = [
             'SPARK_MAJOR_VERSION is set to 2, using Spark2',
-            'WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable',
-            'WARN DomainSocketFactory: The short-circuit local reads feature cannot be used because libhadoop cannot be loaded.',
+            'WARN NativeCodeLoader: Unable to load native-hadoop library for your ' +
+            'platform... using builtin-java classes where applicable',
+            'WARN DomainSocketFactory: The short-circuit local reads feature cannot '
+            'be used because libhadoop cannot be loaded.',
             'INFO Client: Requesting a new application from cluster with 10 NodeManagers',
-            'INFO Client: Submitting application application_1486558679801_1820 to ResourceManager'
+            'INFO Client: Submitting application application_1486558679801_1820 ' +
+            'to ResourceManager'
         ]
         # When
-        hook._process_log(log_lines)
+        hook._process_spark_submit_log(log_lines)
 
         # Then
 
         self.assertEqual(hook._yarn_application_id, 'application_1486558679801_1820')
 
+    def test_process_spark_submit_log_standalone_cluster(self):
+        # Given
+        hook = SparkSubmitHook(conn_id='spark_standalone_cluster')
+        log_lines = [
+            'Running Spark using the REST application submission protocol.',
+            '17/11/28 11:14:15 INFO RestSubmissionClient: Submitting a request '
+            'to launch an application in spark://spark-standalone-master:6066',
+            '17/11/28 11:14:15 INFO RestSubmissionClient: Submission successfully ' +
+            'created as driver-20171128111415-0001. Polling submission state...'
+        ]
+        # When
+        hook._process_spark_submit_log(log_lines)
+
+        # Then
+
+        self.assertEqual(hook._driver_id, 'driver-20171128111415-0001')
+
+    def test_process_spark_driver_status_log(self):
+        # Given
+        hook = SparkSubmitHook(conn_id='spark_standalone_cluster')
+        log_lines = [
+            'Submitting a request for the status of submission ' +
+            'driver-20171128111415-0001 in spark://spark-standalone-master:6066',
+            '17/11/28 11:15:37 INFO RestSubmissionClient: Server responded with ' +
+            'SubmissionStatusResponse:',
+            '{',
+            '"action" : "SubmissionStatusResponse",',
+            '"driverState" : "RUNNING",',
+            '"serverSparkVersion" : "1.6.0",',
+            '"submissionId" : "driver-20171128111415-0001",',
+            '"success" : true,',
+            '"workerHostPort" : "172.18.0.7:38561",',
+            '"workerId" : "worker-20171128110741-172.18.0.7-38561"',
+            '}'
+        ]
+        # When
+        hook._process_spark_status_log(log_lines)
+
+        # Then
+
+        self.assertEqual(hook._driver_status, 'RUNNING')
+
     @patch('airflow.contrib.hooks.spark_submit_hook.subprocess.Popen')
-    def test_spark_process_on_kill(self, mock_popen):
+    def test_yarn_process_on_kill(self, mock_popen):
         # Given
         mock_popen.return_value.stdout = six.StringIO('stdout')
         mock_popen.return_value.stderr = six.StringIO('stderr')
@@ -320,20 +434,27 @@ class TestSparkSubmitHook(unittest.TestCase):
         mock_popen.return_value.wait.return_value = 0
         log_lines = [
             'SPARK_MAJOR_VERSION is set to 2, using Spark2',
-            'WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable',
-            'WARN DomainSocketFactory: The short-circuit local reads feature cannot be used because libhadoop cannot be loaded.',
-            'INFO Client: Requesting a new application from cluster with 10 NodeManagerapplication_1486558679801_1820s',
-            'INFO Client: Submitting application application_1486558679801_1820 to ResourceManager'
+            'WARN NativeCodeLoader: Unable to load native-hadoop library for your ' +
+            'platform... using builtin-java classes where applicable',
+            'WARN DomainSocketFactory: The short-circuit local reads feature cannot ' +
+            'be used because libhadoop cannot be loaded.',
+            'INFO Client: Requesting a new application from cluster with 10 ' +
+            'NodeManagerapplication_1486558679801_1820s',
+            'INFO Client: Submitting application application_1486558679801_1820 ' +
+            'to ResourceManager'
         ]
         hook = SparkSubmitHook(conn_id='spark_yarn_cluster')
-        hook._process_log(log_lines)
+        hook._process_spark_submit_log(log_lines)
         hook.submit()
 
         # When
         hook.on_kill()
 
         # Then
-        self.assertIn(call(['yarn', 'application', '-kill', 'application_1486558679801_1820'], stderr=-1, stdout=-1), mock_popen.mock_calls)
+        self.assertIn(call(['yarn', 'application', '-kill',
+                            'application_1486558679801_1820'],
+                           stderr=-1, stdout=-1),
+                      mock_popen.mock_calls)
 
 
 if __name__ == '__main__':

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3e6babe8/tests/contrib/operators/test_spark_submit_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/test_spark_submit_operator.py b/tests/contrib/operators/test_spark_submit_operator.py
index 05ddc32..5903fcd 100644
--- a/tests/contrib/operators/test_spark_submit_operator.py
+++ b/tests/contrib/operators/test_spark_submit_operator.py
@@ -14,7 +14,6 @@
 #
 
 import unittest
-import sys
 
 from airflow import DAG, configuration
 from airflow.models import TaskInstance
@@ -40,7 +39,7 @@ class TestSparkSubmitOperator(unittest.TestCase):
         'packages': 'com.databricks:spark-avro_2.11:3.2.0',
         'exclude_packages': 'org.bad.dependency:1.0.0',
         'repositories': 'http://myrepo.org',
-        'total_executor_cores':4,
+        'total_executor_cores': 4,
         'executor_cores': 4,
         'executor_memory': '22g',
         'keytab': 'privileged_user.keytab',
@@ -107,7 +106,6 @@ class TestSparkSubmitOperator(unittest.TestCase):
                 '--end', '{{ ds }}',
                 '--with-spaces', 'args should keep embdedded spaces',
             ]
-
         }
 
         self.assertEqual(conn_id, operator._conn_id)
@@ -120,7 +118,8 @@ class TestSparkSubmitOperator(unittest.TestCase):
         self.assertEqual(expected_dict['packages'], operator._packages)
         self.assertEqual(expected_dict['exclude_packages'], operator._exclude_packages)
         self.assertEqual(expected_dict['repositories'], operator._repositories)
-        self.assertEqual(expected_dict['total_executor_cores'], operator._total_executor_cores)
+        self.assertEqual(expected_dict['total_executor_cores'],
+                         operator._total_executor_cores)
         self.assertEqual(expected_dict['executor_cores'], operator._executor_cores)
         self.assertEqual(expected_dict['executor_memory'], operator._executor_memory)
         self.assertEqual(expected_dict['keytab'], operator._keytab)
@@ -134,7 +133,8 @@ class TestSparkSubmitOperator(unittest.TestCase):
 
     def test_render_template(self):
         # Given
-        operator = SparkSubmitOperator(task_id='spark_submit_job', dag=self.dag, **self._config)
+        operator = SparkSubmitOperator(task_id='spark_submit_job',
+                                       dag=self.dag, **self._config)
         ti = TaskInstance(operator, DEFAULT_DATE)
 
         # When
@@ -143,12 +143,15 @@ class TestSparkSubmitOperator(unittest.TestCase):
         # Then
         expected_application_args = [u'-f', 'foo',
                                      u'--bar', 'bar',
-                                     u'--start', (DEFAULT_DATE - timedelta(days=1)).strftime("%Y-%m-%d"),
+                                     u'--start', (DEFAULT_DATE - timedelta(days=1))
+                                     .strftime("%Y-%m-%d"),
                                      u'--end', DEFAULT_DATE.strftime("%Y-%m-%d"),
-                                     u'--with-spaces', u'args should keep embdedded spaces',
+                                     u'--with-spaces',
+                                     u'args should keep embdedded spaces',
                                      ]
         expected_name = "spark_submit_job"
-        self.assertListEqual(expected_application_args, getattr(operator, '_application_args'))
+        self.assertListEqual(expected_application_args,
+                             getattr(operator, '_application_args'))
         self.assertEqual(expected_name, getattr(operator, '_name'))
 
 


Mime
View raw message