spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From gurwls...@apache.org
Subject spark git commit: [SPARK-22530][PYTHON][SQL] Adding Arrow support for ArrayType
Date Mon, 01 Jan 2018 22:13:27 GMT
Repository: spark
Updated Branches:
  refs/heads/master c284c4e1f -> 1c9f95cb7


[SPARK-22530][PYTHON][SQL] Adding Arrow support for ArrayType

## What changes were proposed in this pull request?

This change adds `ArrayType` support to PySpark's Arrow-based code paths: creating a DataFrame from pandas,
calling `toPandas()`, and using vectorized `pandas_udf`s.

## How was this patch tested?

Added new Python unit tests that exercise `ArrayType` columns.

Author: Bryan Cutler <cutlerb@gmail.com>

Closes #20114 from BryanCutler/arrow-ArrayType-support-SPARK-22530.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1c9f95cb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1c9f95cb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1c9f95cb

Branch: refs/heads/master
Commit: 1c9f95cb771ac78775a77edd1abfeb2d8ae2a124
Parents: c284c4e
Author: Bryan Cutler <cutlerb@gmail.com>
Authored: Tue Jan 2 07:13:27 2018 +0900
Committer: hyukjinkwon <gurwls223@gmail.com>
Committed: Tue Jan 2 07:13:27 2018 +0900

----------------------------------------------------------------------
 python/pyspark/sql/tests.py                     | 47 +++++++++++++++++++-
 python/pyspark/sql/types.py                     |  4 ++
 .../execution/vectorized/ArrowColumnVector.java | 13 +++++-
 3 files changed, 61 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1c9f95cb/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1c34c89..67bdb3d 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -3372,6 +3372,31 @@ class ArrowTests(ReusedSQLTestCase):
         schema_rt = from_arrow_schema(arrow_schema)
         self.assertEquals(self.schema, schema_rt)
 
+    def test_createDataFrame_with_array_type(self):
+        import pandas as pd
+        pdf = pd.DataFrame({"a": [[1, 2], [3, 4]], "b": [[u"x", u"y"], [u"y", u"z"]]})
+        df, df_arrow = self._createDataFrame_toggle(pdf)
+        result = df.collect()
+        result_arrow = df_arrow.collect()
+        expected = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)]
+        for r in range(len(expected)):
+            for e in range(len(expected[r])):
+                self.assertTrue(expected[r][e] == result_arrow[r][e] and
+                                result[r][e] == result_arrow[r][e])
+
+    def test_toPandas_with_array_type(self):
+        expected = [([1, 2], [u"x", u"y"]), ([3, 4], [u"y", u"z"])]
+        array_schema = StructType([StructField("a", ArrayType(IntegerType())),
+                                   StructField("b", ArrayType(StringType()))])
+        df = self.spark.createDataFrame(expected, schema=array_schema)
+        pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
+        result = [tuple(list(e) for e in rec) for rec in pdf.to_records(index=False)]
+        result_arrow = [tuple(list(e) for e in rec) for rec in pdf_arrow.to_records(index=False)]
+        for r in range(len(expected)):
+            for e in range(len(expected[r])):
+                self.assertTrue(expected[r][e] == result_arrow[r][e] and
+                                result[r][e] == result_arrow[r][e])
+
 
 @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed")
 class PandasUDFTests(ReusedSQLTestCase):
@@ -3651,6 +3676,24 @@ class VectorizedUDFTests(ReusedSQLTestCase):
                         bool_f(col('bool')))
         self.assertEquals(df.collect(), res.collect())
 
+    def test_vectorized_udf_array_type(self):
+        from pyspark.sql.functions import pandas_udf, col
+        data = [([1, 2],), ([3, 4],)]
+        array_schema = StructType([StructField("array", ArrayType(IntegerType()))])
+        df = self.spark.createDataFrame(data, schema=array_schema)
+        array_f = pandas_udf(lambda x: x, ArrayType(IntegerType()))
+        result = df.select(array_f(col('array')))
+        self.assertEquals(df.collect(), result.collect())
+
+    def test_vectorized_udf_null_array(self):
+        from pyspark.sql.functions import pandas_udf, col
+        data = [([1, 2],), (None,), (None,), ([3, 4],), (None,)]
+        array_schema = StructType([StructField("array", ArrayType(IntegerType()))])
+        df = self.spark.createDataFrame(data, schema=array_schema)
+        array_f = pandas_udf(lambda x: x, ArrayType(IntegerType()))
+        result = df.select(array_f(col('array')))
+        self.assertEquals(df.collect(), result.collect())
+
     def test_vectorized_udf_complex(self):
         from pyspark.sql.functions import pandas_udf, col, expr
         df = self.spark.range(10).select(
@@ -3705,7 +3748,7 @@ class VectorizedUDFTests(ReusedSQLTestCase):
     def test_vectorized_udf_wrong_return_type(self):
         from pyspark.sql.functions import pandas_udf, col
         df = self.spark.range(10)
-        f = pandas_udf(lambda x: x * 1.0, ArrayType(LongType()))
+        f = pandas_udf(lambda x: x * 1.0, MapType(LongType(), LongType()))
         with QuietTest(self.sc):
             with self.assertRaisesRegexp(Exception, 'Unsupported.*type.*conversion'):
                 df.select(f(col('id'))).collect()
@@ -4009,7 +4052,7 @@ class GroupbyApplyTests(ReusedSQLTestCase):
 
         foo = pandas_udf(
             lambda pdf: pdf,
-            'id long, v array<int>',
+            'id long, v map<int, int>',
             PandasUDFType.GROUP_MAP
         )
 

http://git-wip-us.apache.org/repos/asf/spark/blob/1c9f95cb/python/pyspark/sql/types.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 02b2457..146e673 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1625,6 +1625,8 @@ def to_arrow_type(dt):
     elif type(dt) == TimestampType:
         # Timestamps should be in UTC, JVM Arrow timestamps require a timezone to be read
         arrow_type = pa.timestamp('us', tz='UTC')
+    elif type(dt) == ArrayType:
+        arrow_type = pa.list_(to_arrow_type(dt.elementType))
     else:
         raise TypeError("Unsupported type in conversion to Arrow: " + str(dt))
     return arrow_type
@@ -1665,6 +1667,8 @@ def from_arrow_type(at):
         spark_type = DateType()
     elif types.is_timestamp(at):
         spark_type = TimestampType()
+    elif types.is_list(at):
+        spark_type = ArrayType(from_arrow_type(at.value_type))
     else:
         raise TypeError("Unsupported type in conversion from Arrow: " + str(at))
     return spark_type

http://git-wip-us.apache.org/repos/asf/spark/blob/1c9f95cb/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java
index 528f66f..af5673e 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java
@@ -326,7 +326,8 @@ public final class ArrowColumnVector extends ColumnVector {
       this.vector = vector;
     }
 
-    final boolean isNullAt(int rowId) {
+    // TODO: should be final after removing ArrayAccessor workaround
+    boolean isNullAt(int rowId) {
       return vector.isNull(rowId);
     }
 
@@ -590,6 +591,16 @@ public final class ArrowColumnVector extends ColumnVector {
     }
 
     @Override
+    final boolean isNullAt(int rowId) {
+      // TODO: Workaround if vector has all non-null values, see ARROW-1948
+      if (accessor.getValueCount() > 0 && accessor.getValidityBuffer().capacity() == 0) {
+        return false;
+      } else {
+        return super.isNullAt(rowId);
+      }
+    }
+
+    @Override
     final int getArrayLength(int rowId) {
       return accessor.getInnerValueCountAt(rowId);
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message