airflow-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From bo...@apache.org
Subject [26/45] incubator-airflow git commit: [AIRFLOW-802][AIRFLOW-1] Add spark-submit operator/hook
Date Mon, 13 Mar 2017 04:45:24 GMT
[AIRFLOW-802][AIRFLOW-1] Add spark-submit operator/hook

Add an operator for spark-submit to kick off Apache Spark jobs
using Airflow. This allows the user to maintain the configuration
of the master and YARN queue within Airflow by using connections.
Add a default connection_id to the initdb routine that points Spark
to YARN by default. Add unit tests to verify the behaviour of the
spark-submit operator and hook.

Closes #2042 from Fokko/airflow-802


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/01494fd4
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/01494fd4
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/01494fd4

Branch: refs/heads/v1-8-stable
Commit: 01494fd4c0633dbb57f231ee17e015f42a5ecf24
Parents: c29af46
Author: Fokko Driesprong <fokkodriesprong@godatadriven.com>
Authored: Mon Feb 27 13:45:24 2017 +0100
Committer: Bolke de Bruin <bolke@Bolkes-MacBook-Pro.local>
Committed: Sun Mar 12 08:19:37 2017 -0700

----------------------------------------------------------------------
 airflow/contrib/hooks/__init__.py               |   1 +
 airflow/contrib/hooks/spark_submit_hook.py      | 226 +++++++++++++++++++
 airflow/contrib/operators/__init__.py           |   1 +
 .../contrib/operators/spark_submit_operator.py  | 112 +++++++++
 airflow/utils/db.py                             |   4 +
 tests/contrib/hooks/spark_submit_hook.py        | 148 ++++++++++++
 .../contrib/operators/spark_submit_operator.py  |  75 ++++++
 7 files changed, 567 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/01494fd4/airflow/contrib/hooks/__init__.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/__init__.py b/airflow/contrib/hooks/__init__.py
index a16a3f7..19fc2b4 100644
--- a/airflow/contrib/hooks/__init__.py
+++ b/airflow/contrib/hooks/__init__.py
@@ -42,6 +42,7 @@ _hooks = {
     'datastore_hook': ['DatastoreHook'],
     'gcp_dataproc_hook': ['DataProcHook'],
     'gcp_dataflow_hook': ['DataFlowHook'],
+    'spark_submit_hook': ['SparkSubmitHook'],
     'cloudant_hook': ['CloudantHook'],
     'fs_hook': ['FSHook']
 }

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/01494fd4/airflow/contrib/hooks/spark_submit_hook.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/hooks/spark_submit_hook.py b/airflow/contrib/hooks/spark_submit_hook.py
new file mode 100644
index 0000000..619cc71
--- /dev/null
+++ b/airflow/contrib/hooks/spark_submit_hook.py
@@ -0,0 +1,226 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import logging
+import subprocess
+import re
+
+from airflow.hooks.base_hook import BaseHook
+from airflow.exceptions import AirflowException
+
+log = logging.getLogger(__name__)
+
+
+class SparkSubmitHook(BaseHook):
+    """
+    This hook is a wrapper around the spark-submit binary to kick off a spark-submit job.
+    It requires that the "spark-submit" binary is in the PATH.
+
+    :param conf: Arbitrary Spark configuration properties, passed as --conf key=value
+    :type conf: dict
+    :param conn_id: The connection id as configured in Airflow administration. When an
+                    invalid connection_id is supplied, it will default to yarn.
+    :type conn_id: str
+    :param files: Upload additional files to the container running the job,
+                  separated by a comma. For example hive-site.xml.
+    :type files: str
+    :param py_files: Additional python files used by the job, can be .zip, .egg or .py.
+    :type py_files: str
+    :param jars: Submit additional jars to upload and place them in executor classpath.
+    :type jars: str
+    :param executor_cores: Number of cores per executor. When omitted, no flag is
+                           passed and spark-submit's own default applies.
+    :type executor_cores: int
+    :param executor_memory: Memory per executor (e.g. 1000M, 2G). When omitted, no
+                            flag is passed and spark-submit's own default applies.
+    :type executor_memory: str
+    :param keytab: Full path to the file that contains the keytab
+    :type keytab: str
+    :param principal: The name of the kerberos principal used for keytab
+    :type principal: str
+    :param name: Name of the job (default: default-name)
+    :type name: str
+    :param num_executors: Number of executors to launch
+    :type num_executors: int
+    :param verbose: Whether to pass the verbose flag to spark-submit process for debugging
+    :type verbose: bool
+    """
+
+    def __init__(self,
+                 conf=None,
+                 conn_id='spark_default',
+                 files=None,
+                 py_files=None,
+                 jars=None,
+                 executor_cores=None,
+                 executor_memory=None,
+                 keytab=None,
+                 principal=None,
+                 name='default-name',
+                 num_executors=None,
+                 verbose=False):
+        self._conf = conf
+        self._conn_id = conn_id
+        self._files = files
+        self._py_files = py_files
+        self._jars = jars
+        self._executor_cores = executor_cores
+        self._executor_memory = executor_memory
+        self._keytab = keytab
+        self._principal = principal
+        self._name = name
+        self._num_executors = num_executors
+        self._verbose = verbose
+        # Handle to the running spark-submit process (set by submit())
+        self._sp = None
+        # Filled in by _process_log() when running in yarn cluster mode
+        self._yarn_application_id = None
+
+        (self._master, self._queue, self._deploy_mode) = self._resolve_connection()
+        self._is_yarn = 'yarn' in self._master
+
+    def _resolve_connection(self):
+        """
+        Resolve the Spark master, yarn queue and deploy mode from the Airflow connection.
+
+        Falls back to master 'yarn' (no queue, no deploy mode) when the connection
+        cannot be loaded, and keeps 'yarn' as master when the connection has no host.
+
+        :return: tuple of (master, queue, deploy_mode)
+        """
+        # Build from connection master or default to yarn if not available
+        master = 'yarn'
+        queue = None
+        deploy_mode = None
+
+        try:
+            # Master can be local, yarn, spark://HOST:PORT or mesos://HOST:PORT
+            conn = self.get_connection(self._conn_id)
+        except AirflowException:
+            log.debug(
+                "Could not load connection string {}, defaulting to {}".format(
+                    self._conn_id, master
+                )
+            )
+            return master, queue, deploy_mode
+
+        # Without a host there is nothing to build a master from; keep 'yarn' rather
+        # than producing None, which would break the 'yarn' check in __init__.
+        if conn.host:
+            if conn.port:
+                master = "{}:{}".format(conn.host, conn.port)
+            else:
+                master = conn.host
+
+        # Determine optional yarn queue and deploy mode from the extra field
+        extra = conn.extra_dejson
+        queue = extra.get('queue')
+        deploy_mode = extra.get('deploy-mode')
+
+        return master, queue, deploy_mode
+
+    def get_conn(self):
+        """Not used: this hook shells out to spark-submit instead of holding a connection."""
+        pass
+
+    def _build_command(self, application):
+        """
+        Construct the spark-submit command to execute.
+
+        :param application: command to append to the spark-submit command
+        :type application: str
+        :return: full command to be executed, as an argument list for subprocess
+        """
+        # The spark-submit binary needs to be in the path
+        connection_cmd = ["spark-submit"]
+
+        # The url of the spark master
+        connection_cmd += ["--master", self._master]
+
+        if self._conf:
+            for key in self._conf:
+                connection_cmd += ["--conf", "{}={}".format(key, str(self._conf[key]))]
+        if self._files:
+            connection_cmd += ["--files", self._files]
+        if self._py_files:
+            connection_cmd += ["--py-files", self._py_files]
+        if self._jars:
+            connection_cmd += ["--jars", self._jars]
+        if self._num_executors:
+            connection_cmd += ["--num-executors", str(self._num_executors)]
+        if self._executor_cores:
+            connection_cmd += ["--executor-cores", str(self._executor_cores)]
+        if self._executor_memory:
+            connection_cmd += ["--executor-memory", self._executor_memory]
+        if self._keytab:
+            connection_cmd += ["--keytab", self._keytab]
+        if self._principal:
+            connection_cmd += ["--principal", self._principal]
+        if self._name:
+            connection_cmd += ["--name", self._name]
+        if self._verbose:
+            connection_cmd += ["--verbose"]
+        if self._queue:
+            connection_cmd += ["--queue", self._queue]
+        if self._deploy_mode:
+            connection_cmd += ["--deploy-mode", self._deploy_mode]
+
+        # The actual script to execute
+        connection_cmd += [application]
+
+        log.debug("Spark-Submit cmd: {}".format(connection_cmd))
+
+        return connection_cmd
+
+    def submit(self, application="", **kwargs):
+        """
+        Remote Popen to execute the spark-submit job
+
+        :param application: Submitted application, jar or py file
+        :type application: str
+        :param kwargs: extra arguments to Popen (see subprocess.Popen)
+        :raises AirflowException: when spark-submit exits with a non-zero return code
+        """
+        spark_submit_cmd = self._build_command(application)
+        # stderr is merged into stdout. Spark's default log4j configuration logs to
+        # stderr, so draining stdout to EOF before touching stderr could deadlock once
+        # the stderr pipe buffer fills. It would also only reveal the yarn application
+        # id after the process has already finished.
+        self._sp = subprocess.Popen(spark_submit_cmd,
+                                    stdout=subprocess.PIPE,
+                                    stderr=subprocess.STDOUT,
+                                    **kwargs)
+
+        # Read line by line to support 'real-time' logging
+        self._process_log(iter(self._sp.stdout.readline, b''))
+        returncode = self._sp.wait()
+
+        if returncode:
+            raise AirflowException(
+                "Cannot execute: {}. Error code is: {}.".format(
+                    spark_submit_cmd, returncode
+                )
+            )
+
+    def _process_log(self, itr):
+        """
+        Processes the log lines and extracts useful information out of them
+
+        :param itr: An iterable over the output lines of the subprocess; lines may be
+                    bytes (as read from the pipe) or already-decoded text
+        """
+        # Consume the iterator
+        for line in itr:
+            if isinstance(line, bytes):
+                line = line.decode('utf-8', 'replace')
+            line = line.strip()
+            # If we run yarn cluster mode, we want to extract the application id from
+            # the logs so we can kill the application when we stop it unexpectedly
+            if self._is_yarn and self._deploy_mode == 'cluster':
+                match = re.search(r'(application[0-9_]+)', line)
+                if match:
+                    self._yarn_application_id = match.groups()[0]
+
+            # Pass to logging
+            log.info(line)
+
+    def on_kill(self):
+        """
+        Kill the spark-submit process if it is still running and, when a yarn
+        application id was extracted from its output, kill that application too.
+        """
+        if self._sp and self._sp.poll() is None:
+            log.info('Sending kill signal to spark-submit')
+            self._sp.kill()
+
+            if self._yarn_application_id:
+                log.info('Killing application on YARN')
+                # Argument list, not one string: without a shell, Popen would look for
+                # an executable literally named "yarn application -kill ...".
+                yarn_kill = subprocess.Popen(
+                    ["yarn", "application", "-kill", self._yarn_application_id],
+                    stdout=subprocess.PIPE,
+                    stderr=subprocess.PIPE)
+                # communicate() drains both pipes; wait() with PIPEs can deadlock
+                yarn_kill.communicate()
+                log.info("YARN killed with return code: {0}".format(yarn_kill.returncode))

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/01494fd4/airflow/contrib/operators/__init__.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/__init__.py b/airflow/contrib/operators/__init__.py
index ae481ea..bef3433 100644
--- a/airflow/contrib/operators/__init__.py
+++ b/airflow/contrib/operators/__init__.py
@@ -36,6 +36,7 @@ _operators = {
     'vertica_operator': ['VerticaOperator'],
     'vertica_to_hive': ['VerticaToHiveTransfer'],
     'qubole_operator': ['QuboleOperator'],
+    'spark_submit_operator': ['SparkSubmitOperator'],
     'fs_operator': ['FileSensor']
 }
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/01494fd4/airflow/contrib/operators/spark_submit_operator.py
----------------------------------------------------------------------
diff --git a/airflow/contrib/operators/spark_submit_operator.py b/airflow/contrib/operators/spark_submit_operator.py
new file mode 100644
index 0000000..a5e6145
--- /dev/null
+++ b/airflow/contrib/operators/spark_submit_operator.py
@@ -0,0 +1,112 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import logging
+
+from airflow.contrib.hooks.spark_submit_hook import SparkSubmitHook
+from airflow.models import BaseOperator
+from airflow.utils.decorators import apply_defaults
+
+log = logging.getLogger(__name__)
+
+
+class SparkSubmitOperator(BaseOperator):
+    """
+    This operator is a wrapper around the spark-submit binary to kick off a
+    spark-submit job; the submission itself is delegated to SparkSubmitHook.
+    It requires that the "spark-submit" binary is in the PATH.
+
+    :param application: The application that is submitted as a job, either a jar or
+                        a py file.
+    :type application: str
+    :param conf: Arbitrary Spark configuration properties
+    :type conf: dict
+    :param conn_id: The connection id as configured in Airflow administration. When an
+                    invalid connection_id is supplied, it will default to yarn.
+    :type conn_id: str
+    :param files: Upload additional files to the container running the job,
+                  separated by a comma. For example hive-site.xml.
+    :type files: str
+    :param py_files: Additional python files used by the job, can be .zip, .egg or .py.
+    :type py_files: str
+    :param jars: Submit additional jars to upload and place them in executor classpath.
+    :type jars: str
+    :param executor_cores: Number of cores per executor. When omitted, no flag is
+                           passed and spark-submit's own default applies.
+    :type executor_cores: int
+    :param executor_memory: Memory per executor (e.g. 1000M, 2G). When omitted, no
+                            flag is passed and spark-submit's own default applies.
+    :type executor_memory: str
+    :param keytab: Full path to the file that contains the keytab
+    :type keytab: str
+    :param principal: The name of the kerberos principal used for keytab
+    :type principal: str
+    :param name: Name of the job (default airflow-spark)
+    :type name: str
+    :param num_executors: Number of executors to launch
+    :type num_executors: int
+    :param verbose: Whether to pass the verbose flag to spark-submit process for debugging
+    :type verbose: bool
+    """
+
+    @apply_defaults
+    def __init__(self,
+                 application='',
+                 conf=None,
+                 conn_id='spark_default',
+                 files=None,
+                 py_files=None,
+                 jars=None,
+                 executor_cores=None,
+                 executor_memory=None,
+                 keytab=None,
+                 principal=None,
+                 name='airflow-spark',
+                 num_executors=None,
+                 verbose=False,
+                 *args,
+                 **kwargs):
+        super(SparkSubmitOperator, self).__init__(*args, **kwargs)
+        self._application = application
+        self._conf = conf
+        self._files = files
+        self._py_files = py_files
+        self._jars = jars
+        self._executor_cores = executor_cores
+        self._executor_memory = executor_memory
+        self._keytab = keytab
+        self._principal = principal
+        self._name = name
+        self._num_executors = num_executors
+        self._verbose = verbose
+        # Created in execute(); on_kill() forwards to it
+        self._hook = None
+        self._conn_id = conn_id
+
+    def execute(self, context):
+        """
+        Call the SparkSubmitHook to run the provided spark job
+        """
+        self._hook = SparkSubmitHook(
+            conf=self._conf,
+            conn_id=self._conn_id,
+            files=self._files,
+            py_files=self._py_files,
+            jars=self._jars,
+            executor_cores=self._executor_cores,
+            executor_memory=self._executor_memory,
+            keytab=self._keytab,
+            principal=self._principal,
+            name=self._name,
+            num_executors=self._num_executors,
+            verbose=self._verbose
+        )
+        self._hook.submit(self._application)
+
+    def on_kill(self):
+        """Forward the kill request to the hook, if execute() has created one."""
+        # on_kill may be invoked before execute() has created the hook
+        if self._hook:
+            self._hook.on_kill()

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/01494fd4/airflow/utils/db.py
----------------------------------------------------------------------
diff --git a/airflow/utils/db.py b/airflow/utils/db.py
index 2502219..977a949 100644
--- a/airflow/utils/db.py
+++ b/airflow/utils/db.py
@@ -192,6 +192,10 @@ def initdb():
             extra='{"region_name": "us-east-1"}'))
     merge_conn(
         models.Connection(
+            conn_id='spark_default', conn_type='spark',
+            host='yarn', extra='{"queue": "root.default"}'))
+    merge_conn(
+        models.Connection(
             conn_id='emr_default', conn_type='emr',
             extra='''
                 {   "Name": "default_job_flow_name",

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/01494fd4/tests/contrib/hooks/spark_submit_hook.py
----------------------------------------------------------------------
diff --git a/tests/contrib/hooks/spark_submit_hook.py b/tests/contrib/hooks/spark_submit_hook.py
new file mode 100644
index 0000000..b18925a
--- /dev/null
+++ b/tests/contrib/hooks/spark_submit_hook.py
@@ -0,0 +1,148 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+
+from airflow import configuration, models
+from airflow.utils import db
+from airflow.exceptions import AirflowException
+from airflow.contrib.hooks.spark_submit_hook import SparkSubmitHook
+
+
+class TestSparkSubmitHook(unittest.TestCase):
+    _spark_job_file = 'test_application.py'
+    _config = {
+        'conf': {
+            'parquet.compression': 'SNAPPY'
+        },
+        # Not a defined connection: the hook falls back to yarn
+        'conn_id': 'default_spark',
+        'files': 'hive-site.xml',
+        'py_files': 'sample_library.py',
+        'jars': 'parquet.jar',
+        'executor_cores': 4,
+        'executor_memory': '22g',
+        'keytab': 'privileged_user.keytab',
+        'principal': 'user/spark@airflow.org',
+        'name': 'spark-job',
+        'num_executors': 10,
+        'verbose': True
+    }
+
+    def setUp(self):
+        configuration.load_test_config()
+        db.merge_conn(
+            models.Connection(
+                conn_id='spark_yarn_cluster', conn_type='spark',
+                host='yarn://yarn-master',
+                extra='{"queue": "root.etl", "deploy-mode": "cluster"}')
+        )
+        db.merge_conn(
+            models.Connection(
+                conn_id='spark_default_mesos', conn_type='spark',
+                host='mesos://host', port=5050)
+        )
+
+    def test_build_command(self):
+        hook = SparkSubmitHook(**self._config)
+
+        # The subprocess requires an array but we build the cmd by joining on a space
+        cmd = ' '.join(hook._build_command(self._spark_job_file))
+
+        # Check if the URL gets build properly and everything exists.
+        self.assertIn(self._spark_job_file, cmd)
+
+        # Check all the parameters
+        self.assertIn("--files {}".format(self._config['files']), cmd)
+        self.assertIn("--py-files {}".format(self._config['py_files']), cmd)
+        self.assertIn("--jars {}".format(self._config['jars']), cmd)
+        self.assertIn("--executor-cores {}".format(self._config['executor_cores']), cmd)
+        self.assertIn("--executor-memory {}".format(self._config['executor_memory']), cmd)
+        self.assertIn("--keytab {}".format(self._config['keytab']), cmd)
+        self.assertIn("--principal {}".format(self._config['principal']), cmd)
+        self.assertIn("--name {}".format(self._config['name']), cmd)
+        self.assertIn("--num-executors {}".format(self._config['num_executors']), cmd)
+
+        # Check if all config settings are there
+        for k in self._config['conf']:
+            self.assertIn("--conf {0}={1}".format(k, self._config['conf'][k]), cmd)
+
+        if self._config['verbose']:
+            self.assertIn("--verbose", cmd)
+
+    def test_submit(self):
+        hook = SparkSubmitHook()
+
+        # spark-submit is not available on the test machines, and this is hard to
+        # mock. A missing binary surfaces as OSError from Popen; a failing run as
+        # AirflowException. Accept either for now.
+        with self.assertRaises((AirflowException, OSError)):
+            hook.submit(self._spark_job_file)
+
+    def test_resolve_connection(self):
+
+        # Default to the standard yarn connection because conn_id does not exists
+        hook = SparkSubmitHook(conn_id='')
+        self.assertEqual(hook._resolve_connection(), ('yarn', None, None))
+        self.assertIn("--master yarn", ' '.join(hook._build_command(self._spark_job_file)))
+
+        # Default to the standard yarn connection
+        hook = SparkSubmitHook(conn_id='spark_default')
+        self.assertEqual(
+            hook._resolve_connection(),
+            ('yarn', 'root.default', None)
+        )
+        cmd = ' '.join(hook._build_command(self._spark_job_file))
+        self.assertIn("--master yarn", cmd)
+        self.assertIn("--queue root.default", cmd)
+
+        # Connect to a mesos master
+        hook = SparkSubmitHook(conn_id='spark_default_mesos')
+        self.assertEqual(
+            hook._resolve_connection(),
+            ('mesos://host:5050', None, None)
+        )
+
+        cmd = ' '.join(hook._build_command(self._spark_job_file))
+        self.assertIn("--master mesos://host:5050", cmd)
+
+        # Set specific queue and deploy mode
+        hook = SparkSubmitHook(conn_id='spark_yarn_cluster')
+        self.assertEqual(
+            hook._resolve_connection(),
+            ('yarn://yarn-master', 'root.etl', 'cluster')
+        )
+
+        cmd = ' '.join(hook._build_command(self._spark_job_file))
+        self.assertIn("--master yarn://yarn-master", cmd)
+        self.assertIn("--queue root.etl", cmd)
+        self.assertIn("--deploy-mode cluster", cmd)
+
+    def test_process_log(self):
+        # Must select yarn connection
+        hook = SparkSubmitHook(conn_id='spark_yarn_cluster')
+
+        log_lines = [
+            'SPARK_MAJOR_VERSION is set to 2, using Spark2',
+            'WARN NativeCodeLoader: Unable to load native-hadoop library for your '
+            'platform... using builtin-java classes where applicable',
+            'WARN DomainSocketFactory: The short-circuit local reads feature cannot '
+            'be used because libhadoop cannot be loaded.',
+            'INFO Client: Requesting a new application from cluster with 10 NodeManagers',
+            'INFO Client: Submitting application application_1486558679801_1820 to '
+            'ResourceManager'
+        ]
+
+        # Feed bytes, as the hook reads them from the subprocess pipe
+        hook._process_log([line.encode('utf-8') for line in log_lines])
+
+        self.assertEqual(hook._yarn_application_id, 'application_1486558679801_1820')
+
+
+if __name__ == '__main__':
+    unittest.main()

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/01494fd4/tests/contrib/operators/spark_submit_operator.py
----------------------------------------------------------------------
diff --git a/tests/contrib/operators/spark_submit_operator.py b/tests/contrib/operators/spark_submit_operator.py
new file mode 100644
index 0000000..c080f76
--- /dev/null
+++ b/tests/contrib/operators/spark_submit_operator.py
@@ -0,0 +1,75 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import unittest
+import datetime
+
+from airflow import DAG, configuration
+from airflow.contrib.operators.spark_submit_operator import SparkSubmitOperator
+
+DEFAULT_DATE = datetime.datetime(2017, 1, 1)
+
+
+class TestSparkSubmitOperator(unittest.TestCase):
+    # Each key maps to the operator attribute of the same name prefixed with '_'
+    _config = {
+        'conf': {
+            'parquet.compression': 'SNAPPY'
+        },
+        'files': 'hive-site.xml',
+        'py_files': 'sample_library.py',
+        'jars': 'parquet.jar',
+        'executor_cores': 4,
+        'executor_memory': '22g',
+        'keytab': 'privileged_user.keytab',
+        'principal': 'user/spark@airflow.org',
+        'name': 'spark-job',
+        'num_executors': 10,
+        'verbose': True,
+        'application': 'test_application.py'
+    }
+
+    def setUp(self):
+        configuration.load_test_config()
+        args = {
+            'owner': 'airflow',
+            'start_date': DEFAULT_DATE
+        }
+        self.dag = DAG('test_dag_id', default_args=args)
+
+    def test_execute(self):
+        """The operator keeps every constructor argument for use in execute()."""
+        operator = SparkSubmitOperator(
+            task_id='spark_submit_job',
+            dag=self.dag,
+            **self._config
+        )
+
+        # conn_id is not part of _config, so the operator default applies
+        self.assertEqual('spark_default', operator._conn_id)
+
+        for key, value in self._config.items():
+            self.assertEqual(value, getattr(operator, '_' + key), key)
+
+
+if __name__ == '__main__':
+    unittest.main()


Mime
View raw message