spark-commits mailing list archives

From cutl...@apache.org
Subject spark git commit: [SPARK-25274][PYTHON][SQL] In toPandas with Arrow send un-ordered record batches to improve performance
Date Thu, 06 Dec 2018 18:08:15 GMT
Repository: spark
Updated Branches:
  refs/heads/master ab76900fe -> ecaa495b1


[SPARK-25274][PYTHON][SQL] In toPandas with Arrow send un-ordered record batches to improve performance

## What changes were proposed in this pull request?

When executing `toPandas` with Arrow enabled, partitions that arrive in the JVM out of order
must be buffered before they can be sent to Python. This causes excess memory usage in the
driver JVM and increases the time to complete, because data must sit in the JVM waiting for
preceding partitions to arrive.

This change sends un-ordered partitions to Python as soon as they arrive in the JVM, followed
by a list of partition indices so that Python can assemble the data in the correct order.
This way, data is not buffered in the JVM and there is no waiting on particular partitions,
so performance improves.
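
To illustrate the idea, here is a minimal sketch with made-up values (not the actual protocol code): batches are kept in arrival order, and the trailing index list maps each logical position back to an arrival position.

```python
# Hypothetical example: three batches arrive out of order, followed by an index list.
arrived = ["batch_from_partition_2", "batch_from_partition_0", "batch_from_partition_1"]
batch_order = [1, 2, 0]  # logical position -> position in arrival order

# Re-assemble in logical partition order, as the Python side does.
ordered = [arrived[i] for i in batch_order]
# ordered == ["batch_from_partition_0", "batch_from_partition_1", "batch_from_partition_2"]
```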

Followup to #21546

## How was this patch tested?

Added a new test with a large number of batches per partition, and a test that forces a small
delay in the first partition. These tests verify that partitions are collected out of order
and then put back in the correct order in Python.

## Performance Tests - toPandas

Tests were run on a 4-node standalone cluster with 32 cores total, Ubuntu 14.04.1 and OpenJDK 8,
measuring wall clock time to execute `toPandas()` and taking the average best time of 5 runs /
5 loops each.

Test code
```python
import time
from pyspark.sql.functions import rand

df = spark.range(1 << 25, numPartitions=32).toDF("id") \
    .withColumn("x1", rand()).withColumn("x2", rand()) \
    .withColumn("x3", rand()).withColumn("x4", rand())

for i in range(5):
    start = time.time()
    _ = df.toPandas()
    elapsed = time.time() - start
```

Spark config
```
spark.driver.memory 5g
spark.executor.memory 5g
spark.driver.maxResultSize 2g
spark.sql.execution.arrow.enabled true
```
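
For reference, a sketch of applying equivalent settings from PySpark (values are the ones above; note that driver and executor memory generally must be set before the JVM starts, e.g. via `spark-defaults.conf` or `spark-submit`):

```python
from pyspark.sql import SparkSession

spark = (SparkSession.builder
         .config("spark.driver.memory", "5g")       # effective only if set before JVM startup
         .config("spark.executor.memory", "5g")
         .config("spark.driver.maxResultSize", "2g")
         .config("spark.sql.execution.arrow.enabled", "true")
         .getOrCreate())
```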

Current Master w/ Arrow stream (seconds) | This PR (seconds)
---------------------|------------
5.16207 | 4.342533
5.133671 | 4.399408
5.147513 | 4.468471
5.105243 | 4.36524
5.018685 | 4.373791

Avg Master (seconds) | Avg This PR (seconds)
------------------|--------------
5.1134364 | 4.3898886

Speedup of **1.164821449**

Closes #22275 from BryanCutler/arrow-toPandas-oo-batches-SPARK-25274.

Authored-by: Bryan Cutler <cutlerb@gmail.com>
Signed-off-by: Bryan Cutler <cutlerb@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ecaa495b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ecaa495b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ecaa495b

Branch: refs/heads/master
Commit: ecaa495b1fe532c36e952ccac42f4715809476af
Parents: ab76900
Author: Bryan Cutler <cutlerb@gmail.com>
Authored: Thu Dec 6 10:07:28 2018 -0800
Committer: Bryan Cutler <cutlerb@gmail.com>
Committed: Thu Dec 6 10:07:28 2018 -0800

----------------------------------------------------------------------
 python/pyspark/serializers.py                   | 33 ++++++++++++++
 python/pyspark/sql/dataframe.py                 | 11 ++++-
 python/pyspark/sql/tests/test_arrow.py          | 28 ++++++++++++
 .../scala/org/apache/spark/sql/Dataset.scala    | 45 +++++++++++---------
 4 files changed, 95 insertions(+), 22 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ecaa495b/python/pyspark/serializers.py
----------------------------------------------------------------------
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index ff9a612..f3ebd37 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -185,6 +185,39 @@ class FramedSerializer(Serializer):
         raise NotImplementedError
 
 
+class ArrowCollectSerializer(Serializer):
+    """
+    Deserialize a stream of batches followed by batch order information. Used in
+    DataFrame._collectAsArrow() after invoking Dataset.collectAsArrowToPython() in the JVM.
+    """
+
+    def __init__(self):
+        self.serializer = ArrowStreamSerializer()
+
+    def dump_stream(self, iterator, stream):
+        return self.serializer.dump_stream(iterator, stream)
+
+    def load_stream(self, stream):
+        """
+        Load a stream of un-ordered Arrow RecordBatches, where the last iteration yields
+        a list of indices that can be used to put the RecordBatches in the correct order.
+        """
+        # load the batches
+        for batch in self.serializer.load_stream(stream):
+            yield batch
+
+        # load the batch order indices
+        num = read_int(stream)
+        batch_order = []
+        for i in xrange(num):
+            index = read_int(stream)
+            batch_order.append(index)
+        yield batch_order
+
+    def __repr__(self):
+        return "ArrowCollectSerializer(%s)" % self.serializer
+
+
 class ArrowStreamSerializer(Serializer):
     """
     Serializes Arrow record batches as a stream.

http://git-wip-us.apache.org/repos/asf/spark/blob/ecaa495b/python/pyspark/sql/dataframe.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 1b1092c..a1056d0 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -29,7 +29,7 @@ import warnings
 
 from pyspark import copy_func, since, _NoValue
 from pyspark.rdd import RDD, _load_from_socket, ignore_unicode_prefix
-from pyspark.serializers import ArrowStreamSerializer, BatchedSerializer, PickleSerializer, \
+from pyspark.serializers import ArrowCollectSerializer, BatchedSerializer, PickleSerializer, \
     UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
 from pyspark.traceback_utils import SCCallSiteSync
@@ -2168,7 +2168,14 @@ class DataFrame(object):
         """
         with SCCallSiteSync(self._sc) as css:
             sock_info = self._jdf.collectAsArrowToPython()
-        return list(_load_from_socket(sock_info, ArrowStreamSerializer()))
+
+        # Collect list of un-ordered batches where last element is a list of correct order indices
+        results = list(_load_from_socket(sock_info, ArrowCollectSerializer()))
+        batches = results[:-1]
+        batch_order = results[-1]
+
+        # Re-order the batch list using the correct order
+        return [batches[i] for i in batch_order]
 
     ##########################################################################################
     # Pandas compatibility

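For context, the re-ordered batch list returned by `_collectAsArrow` is later assembled into a pandas DataFrame by `toPandas` (not part of this diff); a minimal sketch of that final step with pyarrow, assuming `batches` holds the re-ordered `pyarrow.RecordBatch` objects:

```python
import pyarrow as pa

# batches: list of pyarrow.RecordBatch already in the correct logical order
table = pa.Table.from_batches(batches)
pdf = table.to_pandas()
```
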
http://git-wip-us.apache.org/repos/asf/spark/blob/ecaa495b/python/pyspark/sql/tests/test_arrow.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 6e75e82..21fe500 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -381,6 +381,34 @@ class ArrowTests(ReusedSQLTestCase):
         self.assertPandasEqual(pdf, df_from_python.toPandas())
         self.assertPandasEqual(pdf, df_from_pandas.toPandas())
 
+    def test_toPandas_batch_order(self):
+
+        def delay_first_part(partition_index, iterator):
+            if partition_index == 0:
+                time.sleep(0.1)
+            return iterator
+
+        # Collects Arrow RecordBatches out of order in driver JVM then re-orders in Python
+        def run_test(num_records, num_parts, max_records, use_delay=False):
+            df = self.spark.range(num_records, numPartitions=num_parts).toDF("a")
+            if use_delay:
+                df = df.rdd.mapPartitionsWithIndex(delay_first_part).toDF()
+            with self.sql_conf({"spark.sql.execution.arrow.maxRecordsPerBatch": max_records}):
+                pdf, pdf_arrow = self._toPandas_arrow_toggle(df)
+                self.assertPandasEqual(pdf, pdf_arrow)
+
+        cases = [
+            (1024, 512, 2),    # Use large num partitions for more likely collecting out of order
+            (64, 8, 2, True),  # Use delay in first partition to force collecting out of order
+            (64, 64, 1),       # Test single batch per partition
+            (64, 1, 64),       # Test single partition, single batch
+            (64, 1, 8),        # Test single partition, multiple batches
+            (30, 7, 2),        # Test different sized partitions
+        ]
+
+        for case in cases:
+            run_test(*case)
+
 
 class EncryptionArrowTests(ArrowTests):
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ecaa495b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index b10d66d..a664c73 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -17,9 +17,10 @@
 
 package org.apache.spark.sql
 
-import java.io.CharArrayWriter
+import java.io.{CharArrayWriter, DataOutputStream}
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
 import scala.language.implicitConversions
 import scala.util.control.NonFatal
 
@@ -3200,34 +3201,38 @@ class Dataset[T] private[sql](
     val timeZoneId = sparkSession.sessionState.conf.sessionLocalTimeZone
 
     withAction("collectAsArrowToPython", queryExecution) { plan =>
-      PythonRDD.serveToStream("serve-Arrow") { out =>
+      PythonRDD.serveToStream("serve-Arrow") { outputStream =>
+        val out = new DataOutputStream(outputStream)
         val batchWriter = new ArrowBatchStreamWriter(schema, out, timeZoneId)
         val arrowBatchRdd = toArrowBatchRdd(plan)
         val numPartitions = arrowBatchRdd.partitions.length
 
-        // Store collection results for worst case of 1 to N-1 partitions
-        val results = new Array[Array[Array[Byte]]](numPartitions - 1)
-        var lastIndex = -1  // index of last partition written
+        // Batches ordered by (index of partition, batch index in that partition) tuple
+        val batchOrder = new ArrayBuffer[(Int, Int)]()
+        var partitionCount = 0
 
-        // Handler to eagerly write partitions to Python in order
+        // Handler to eagerly write batches to Python as they arrive, un-ordered
         def handlePartitionBatches(index: Int, arrowBatches: Array[Array[Byte]]): Unit = {
-          // If result is from next partition in order
-          if (index - 1 == lastIndex) {
+          if (arrowBatches.nonEmpty) {
+            // Write all batches (can be more than 1) in the partition, store the batch order tuple
             batchWriter.writeBatches(arrowBatches.iterator)
-            lastIndex += 1
-            // Write stored partitions that come next in order
-            while (lastIndex < results.length && results(lastIndex) != null) {
-              batchWriter.writeBatches(results(lastIndex).iterator)
-              results(lastIndex) = null
-              lastIndex += 1
+            arrowBatches.indices.foreach {
+              partition_batch_index => batchOrder.append((index, partition_batch_index))
             }
-            // After last batch, end the stream
-            if (lastIndex == results.length) {
-              batchWriter.end()
+          }
+          partitionCount += 1
+
+          // After last batch, end the stream and write batch order indices
+          if (partitionCount == numPartitions) {
+            batchWriter.end()
+            out.writeInt(batchOrder.length)
+            // Sort by (index of partition, batch index in that partition) tuple to get the
+            // overall_batch_index from 0 to N-1 batches, which can be used to put the
+            // transferred batches in the correct order
+            batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overall_batch_index) =>
+              out.writeInt(overall_batch_index)
             }
-          } else {
-            // Store partitions received out of order
-            results(index - 1) = arrowBatches
+            out.flush()
           }
         }
 


