spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From cutl...@apache.org
Subject spark git commit: [SPARK-23691][PYTHON] Use sql_conf util in PySpark tests where possible
Date Tue, 20 Mar 2018 04:25:49 GMT
Repository: spark
Updated Branches:
  refs/heads/master 5f4deff19 -> 566321852


[SPARK-23691][PYTHON] Use sql_conf util in PySpark tests where possible

## What changes were proposed in this pull request?

https://github.com/apache/spark/commit/d6632d185e147fcbe6724545488ad80dce20277e added a useful util

```python
@contextmanager
def sql_conf(self, pairs):
    ...
```

to allow configuration set/unset within a block:

```python
with self.sql_conf({"spark.blah.blah.blah": "blah"}):
    # test codes
```

This PR proposes to use this util where possible in PySpark tests.

Note that there already appear to be a few places in the unittest classes that set a configuration
without restoring its original value afterwards, which can affect other tests.

## How was this patch tested?

Manually tested via:

```
./run-tests --modules=pyspark-sql --python-executables=python2
./run-tests --modules=pyspark-sql --python-executables=python3
```

Author: hyukjinkwon <gurwls223@gmail.com>

Closes #20830 from HyukjinKwon/cleanup-sql-conf.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/56632185
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/56632185
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/56632185

Branch: refs/heads/master
Commit: 566321852b2d60641fe86acbc8914b4a7063b58e
Parents: 5f4deff
Author: hyukjinkwon <gurwls223@gmail.com>
Authored: Mon Mar 19 21:25:37 2018 -0700
Committer: Bryan Cutler <cutlerb@gmail.com>
Committed: Mon Mar 19 21:25:37 2018 -0700

----------------------------------------------------------------------
 python/pyspark/sql/tests.py | 130 +++++++++++++++------------------------
 1 file changed, 50 insertions(+), 80 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/56632185/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index a0d547a..39d6c52 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -2461,17 +2461,13 @@ class SQLTests(ReusedSQLTestCase):
         df1 = self.spark.range(1).toDF("a")
         df2 = self.spark.range(1).toDF("b")
 
-        try:
-            self.spark.conf.set("spark.sql.crossJoin.enabled", "false")
+        with self.sql_conf({"spark.sql.crossJoin.enabled": False}):
             self.assertRaises(AnalysisException, lambda: df1.join(df2, how="inner").collect())
 
-            self.spark.conf.set("spark.sql.crossJoin.enabled", "true")
+        with self.sql_conf({"spark.sql.crossJoin.enabled": True}):
             actual = df1.join(df2, how="inner").collect()
             expected = [Row(a=0, b=0)]
             self.assertEqual(actual, expected)
-        finally:
-            # We should unset this. Otherwise, other tests are affected.
-            self.spark.conf.unset("spark.sql.crossJoin.enabled")
 
     # Regression test for invalid join methods when on is None, Spark-14761
     def test_invalid_join_method(self):
@@ -2943,21 +2939,18 @@ class SQLTests(ReusedSQLTestCase):
         self.assertPandasEqual(pdf, df.toPandas())
 
         orig_env_tz = os.environ.get('TZ', None)
-        orig_session_tz = self.spark.conf.get('spark.sql.session.timeZone')
         try:
             tz = 'America/Los_Angeles'
             os.environ['TZ'] = tz
             time.tzset()
-            self.spark.conf.set('spark.sql.session.timeZone', tz)
-
-            df = self.spark.createDataFrame(pdf)
-            self.assertPandasEqual(pdf, df.toPandas())
+            with self.sql_conf({'spark.sql.session.timeZone': tz}):
+                df = self.spark.createDataFrame(pdf)
+                self.assertPandasEqual(pdf, df.toPandas())
         finally:
             del os.environ['TZ']
             if orig_env_tz is not None:
                 os.environ['TZ'] = orig_env_tz
             time.tzset()
-            self.spark.conf.set('spark.sql.session.timeZone', orig_session_tz)
 
 
 class HiveSparkSubmitTests(SparkSubmitTests):
@@ -3562,12 +3555,11 @@ class ArrowTests(ReusedSQLTestCase):
         self.assertTrue(all([c == 1 for c in null_counts]))
 
     def _toPandas_arrow_toggle(self, df):
-        self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
-        try:
+        with self.sql_conf({"spark.sql.execution.arrow.enabled": False}):
             pdf = df.toPandas()
-        finally:
-            self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+
         pdf_arrow = df.toPandas()
+
         return pdf, pdf_arrow
 
     def test_toPandas_arrow_toggle(self):
@@ -3579,16 +3571,17 @@ class ArrowTests(ReusedSQLTestCase):
 
     def test_toPandas_respect_session_timezone(self):
         df = self.spark.createDataFrame(self.data, schema=self.schema)
-        orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
-        try:
-            timezone = "America/New_York"
-            self.spark.conf.set("spark.sql.session.timeZone", timezone)
-            self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
-            try:
-                pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
-                self.assertPandasEqual(pdf_arrow_la, pdf_la)
-            finally:
-                self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
+
+        timezone = "America/New_York"
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": False,
+                "spark.sql.session.timeZone": timezone}):
+            pdf_la, pdf_arrow_la = self._toPandas_arrow_toggle(df)
+            self.assertPandasEqual(pdf_arrow_la, pdf_la)
+
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": True,
+                "spark.sql.session.timeZone": timezone}):
             pdf_ny, pdf_arrow_ny = self._toPandas_arrow_toggle(df)
             self.assertPandasEqual(pdf_arrow_ny, pdf_ny)
 
@@ -3601,8 +3594,6 @@ class ArrowTests(ReusedSQLTestCase):
                     pdf_la_corrected[field.name] = _check_series_convert_timestamps_local_tz(
                         pdf_la_corrected[field.name], timezone)
             self.assertPandasEqual(pdf_ny, pdf_la_corrected)
-        finally:
-            self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
 
     def test_pandas_round_trip(self):
         pdf = self.create_pandas_data_frame()
@@ -3618,12 +3609,11 @@ class ArrowTests(ReusedSQLTestCase):
         self.assertTrue(pdf.empty)
 
     def _createDataFrame_toggle(self, pdf, schema=None):
-        self.spark.conf.set("spark.sql.execution.arrow.enabled", "false")
-        try:
+        with self.sql_conf({"spark.sql.execution.arrow.enabled": False}):
             df_no_arrow = self.spark.createDataFrame(pdf, schema=schema)
-        finally:
-            self.spark.conf.set("spark.sql.execution.arrow.enabled", "true")
+
         df_arrow = self.spark.createDataFrame(pdf, schema=schema)
+
         return df_no_arrow, df_arrow
 
     def test_createDataFrame_toggle(self):
@@ -3634,18 +3624,18 @@ class ArrowTests(ReusedSQLTestCase):
     def test_createDataFrame_respect_session_timezone(self):
         from datetime import timedelta
         pdf = self.create_pandas_data_frame()
-        orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
-        try:
-            timezone = "America/New_York"
-            self.spark.conf.set("spark.sql.session.timeZone", timezone)
-            self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
-            try:
-                df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema)
-                result_la = df_no_arrow_la.collect()
-                result_arrow_la = df_arrow_la.collect()
-                self.assertEqual(result_la, result_arrow_la)
-            finally:
-                self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "true")
+        timezone = "America/New_York"
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": False,
+                "spark.sql.session.timeZone": timezone}):
+            df_no_arrow_la, df_arrow_la = self._createDataFrame_toggle(pdf, schema=self.schema)
+            result_la = df_no_arrow_la.collect()
+            result_arrow_la = df_arrow_la.collect()
+            self.assertEqual(result_la, result_arrow_la)
+
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": True,
+                "spark.sql.session.timeZone": timezone}):
             df_no_arrow_ny, df_arrow_ny = self._createDataFrame_toggle(pdf, schema=self.schema)
             result_ny = df_no_arrow_ny.collect()
             result_arrow_ny = df_arrow_ny.collect()
@@ -3658,8 +3648,6 @@ class ArrowTests(ReusedSQLTestCase):
                                           for k, v in row.asDict().items()})
                                    for row in result_la]
             self.assertEqual(result_ny, result_la_corrected)
-        finally:
-            self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
 
     def test_createDataFrame_with_schema(self):
         pdf = self.create_pandas_data_frame()
@@ -4336,9 +4324,7 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
     def test_vectorized_udf_check_config(self):
         from pyspark.sql.functions import pandas_udf, col
         import pandas as pd
-        orig_value = self.spark.conf.get("spark.sql.execution.arrow.maxRecordsPerBatch", None)
-        self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 3)
-        try:
+        with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": 3}):
             df = self.spark.range(10, numPartitions=1)
 
             @pandas_udf(returnType=LongType())
@@ -4348,11 +4334,6 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
             result = df.select(check_records_per_batch(col("id"))).collect()
             for (r,) in result:
                 self.assertTrue(r <= 3)
-        finally:
-            if orig_value is None:
-                self.spark.conf.unset("spark.sql.execution.arrow.maxRecordsPerBatch")
-            else:
-                self.spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", orig_value)
 
     def test_vectorized_udf_timestamps_respect_session_timezone(self):
         from pyspark.sql.functions import pandas_udf, col
@@ -4371,30 +4352,27 @@ class ScalarPandasUDFTests(ReusedSQLTestCase):
         internal_value = pandas_udf(
             lambda ts: ts.apply(lambda ts: ts.value if ts is not pd.NaT else None), LongType())
 
-        orig_tz = self.spark.conf.get("spark.sql.session.timeZone")
-        try:
-            timezone = "America/New_York"
-            self.spark.conf.set("spark.sql.session.timeZone", timezone)
-            self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone", "false")
-            try:
-                df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
-                    .withColumn("internal_value", internal_value(col("timestamp")))
-                result_la = df_la.select(col("idx"), col("internal_value")).collect()
-                # Correct result_la by adjusting 3 hours difference between Los Angeles and New York
-                diff = 3 * 60 * 60 * 1000 * 1000 * 1000
-                result_la_corrected = \
-                    df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect()
-            finally:
-                self.spark.conf.set("spark.sql.execution.pandas.respectSessionTimeZone",
"true")
+        timezone = "America/New_York"
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": False,
+                "spark.sql.session.timeZone": timezone}):
+            df_la = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
+                .withColumn("internal_value", internal_value(col("timestamp")))
+            result_la = df_la.select(col("idx"), col("internal_value")).collect()
+            # Correct result_la by adjusting 3 hours difference between Los Angeles and New York
+            diff = 3 * 60 * 60 * 1000 * 1000 * 1000
+            result_la_corrected = \
+                df_la.select(col("idx"), col("tscopy"), col("internal_value") + diff).collect()
 
+        with self.sql_conf({
+                "spark.sql.execution.pandas.respectSessionTimeZone": True,
+                "spark.sql.session.timeZone": timezone}):
             df_ny = df.withColumn("tscopy", f_timestamp_copy(col("timestamp"))) \
                 .withColumn("internal_value", internal_value(col("timestamp")))
             result_ny = df_ny.select(col("idx"), col("tscopy"), col("internal_value")).collect()
 
             self.assertNotEqual(result_ny, result_la)
             self.assertEqual(result_ny, result_la_corrected)
-        finally:
-            self.spark.conf.set("spark.sql.session.timeZone", orig_tz)
 
     def test_nondeterministic_vectorized_udf(self):
         # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations
@@ -5170,9 +5148,7 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
 
     def test_retain_group_columns(self):
         from pyspark.sql.functions import sum, lit, col
-        orig_value = self.spark.conf.get("spark.sql.retainGroupColumns", None)
-        self.spark.conf.set("spark.sql.retainGroupColumns", False)
-        try:
+        with self.sql_conf({"spark.sql.retainGroupColumns": False}):
             df = self.data
             sum_udf = self.pandas_agg_sum_udf
 
@@ -5180,12 +5156,6 @@ class GroupedAggPandasUDFTests(ReusedSQLTestCase):
             expected1 = df.groupby(df.id).agg(sum(df.v))
             self.assertPandasEqual(expected1.toPandas(), result1.toPandas())
 
-        finally:
-            if orig_value is None:
-                self.spark.conf.unset("spark.sql.retainGroupColumns")
-            else:
-                self.spark.conf.set("spark.sql.retainGroupColumns", orig_value)
-
     def test_invalid_args(self):
         from pyspark.sql.functions import mean
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message