spark-commits mailing list archives

From joshro...@apache.org
Subject spark git commit: [SPARK-6194] [SPARK-677] [PySpark] fix memory leak in collect()
Date Tue, 10 Mar 2015 01:20:14 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.2 e753f9c9b -> d7c359b49


[SPARK-6194] [SPARK-677] [PySpark] fix memory leak in collect()

Because of a circular reference between JavaObject and JavaMember, a Java object cannot be released until the Python GC kicks in, which causes a memory leak in collect() and can consume a lot of memory in the JVM.

This PR changes the way collected data is sent back into Python, from a local file to a socket, which avoids any disk IO during collect() and also avoids keeping Python referrers to the Java object.
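
For illustration only, here is a minimal Python sketch of the consumer side of this approach (read_pickled_stream is a hypothetical name, not part of this patch): connect to a local port, wrap the socket in a buffered file object, and lazily deserialize a stream of pickled items, so nothing touches disk and no Java object is referenced from Python.

    import pickle
    import socket

    def read_pickled_stream(port):
        # Hypothetical helper, analogous to the _load_from_socket function below:
        # connect to a local one-shot server and lazily yield pickled items.
        sock = socket.socket()
        try:
            sock.connect(("localhost", port))
            rf = sock.makefile("rb", 65536)
            while True:
                try:
                    yield pickle.load(rf)
                except EOFError:
                    break
        finally:
            sock.close()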

cc JoshRosen

Author: Davies Liu <davies@databricks.com>

Closes #4923 from davies/fix_collect and squashes the following commits:

d730286 [Davies Liu] address comments
24c92a4 [Davies Liu] fix style
ba54614 [Davies Liu] use socket to transfer data from JVM
9517c8f [Davies Liu] fix memory leak in collect()

(cherry picked from commit 8767565cef01d847f57b7293d8b63b2422009b90)
Signed-off-by: Josh Rosen <joshrosen@databricks.com>

Conflicts:
	core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
	python/pyspark/rdd.py
	python/pyspark/sql/dataframe.py


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d7c359b4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d7c359b4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d7c359b4

Branch: refs/heads/branch-1.2
Commit: d7c359b495a10484e7240eae491d00e67e2dee2d
Parents: e753f9c
Author: Davies Liu <davies@databricks.com>
Authored: Mon Mar 9 16:24:06 2015 -0700
Committer: Josh Rosen <joshrosen@databricks.com>
Committed: Mon Mar 9 18:19:46 2015 -0700

----------------------------------------------------------------------
 .../org/apache/spark/api/python/PythonRDD.scala | 74 +++++++++++++++-----
 python/pyspark/context.py                       | 13 ++--
 python/pyspark/rdd.py                           | 43 +++++++-----
 python/pyspark/sql.py                           |  8 ++-
 .../scala/org/apache/spark/sql/SchemaRDD.scala  | 12 ----
 5 files changed, 96 insertions(+), 54 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d7c359b4/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index bfd36c7..2715722 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -19,27 +19,28 @@ package org.apache.spark.api.python
 
 import java.io._
 import java.net._
-import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections}
-
-import org.apache.spark.input.PortableDataStream
+import java.util.{Collections, ArrayList => JArrayList, List => JList, Map => JMap}
 
 import scala.collection.JavaConversions._
 import scala.collection.mutable
 import scala.language.existentials
 
 import com.google.common.base.Charsets.UTF_8
-
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.io.compress.CompressionCodec
-import org.apache.hadoop.mapred.{InputFormat, OutputFormat, JobConf}
+import org.apache.hadoop.mapred.{InputFormat, JobConf, OutputFormat}
 import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, OutputFormat => NewOutputFormat}
+
 import org.apache.spark._
 import org.apache.spark.SparkContext._
 import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
 import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.input.PortableDataStream
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.Utils
 
+import scala.util.control.NonFatal
+
 private[spark] class PythonRDD(
     @transient parent: RDD[_],
     command: Array[Byte],
@@ -331,21 +332,33 @@ private[spark] object PythonRDD extends Logging {
   /**
    * Adapter for calling SparkContext#runJob from Python.
    *
-   * This method will return an iterator of an array that contains all elements in the RDD
+   * This method will serve an iterator of an array that contains all elements in the RDD
    * (effectively a collect()), but allows you to run on a certain subset of partitions,
    * or to enable local execution.
+   *
+   * @return the port number of a local socket which serves the data collected from this job.
    */
   def runJob(
       sc: SparkContext,
       rdd: JavaRDD[Array[Byte]],
       partitions: JArrayList[Int],
-      allowLocal: Boolean): Iterator[Array[Byte]] = {
+      allowLocal: Boolean): Int = {
     type ByteArray = Array[Byte]
     type UnrolledPartition = Array[ByteArray]
     val allPartitions: Array[UnrolledPartition] =
       sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
     val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
-    flattenedPartition.iterator
+    serveIterator(flattenedPartition.iterator,
+      s"serve RDD ${rdd.id} with partitions ${partitions.mkString(",")}")
+  }
+
+  /**
+   * A helper function to collect an RDD as an iterator, then serve it via socket.
+   *
+   * @return the port number of a local socket which serves the data collected from this job.
+   */
+  def collectAndServe[T](rdd: RDD[T]): Int = {
+    serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
   }
 
   def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
@@ -594,15 +607,44 @@ private[spark] object PythonRDD extends Logging {
     dataOut.write(bytes)
   }
 
-  def writeToFile[T](items: java.util.Iterator[T], filename: String) {
-    import scala.collection.JavaConverters._
-    writeToFile(items.asScala, filename)
-  }
+  /**
+   * Create a socket server and a background thread to serve the data in `items`,
+   *
+   * The socket server can only accept one connection, or close if no connection
+   * in 3 seconds.
+   *
+   * Once a connection comes in, it tries to serialize all the data in `items`
+   * and send them into this connection.
+   *
+   * The thread will terminate after all the data are sent or any exceptions happen.
+   */
+  private def serveIterator[T](items: Iterator[T], threadName: String): Int = {
+    val serverSocket = new ServerSocket(0, 1)
+    serverSocket.setReuseAddress(true)
+    // Close the socket if no connection in 3 seconds
+    serverSocket.setSoTimeout(3000)
+
+    new Thread(threadName) {
+      setDaemon(true)
+      override def run() {
+        try {
+          val sock = serverSocket.accept()
+          val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
+          try {
+            writeIteratorToStream(items, out)
+          } finally {
+            out.close()
+          }
+        } catch {
+          case NonFatal(e) =>
+            logError(s"Error while sending iterator", e)
+        } finally {
+          serverSocket.close()
+        }
+      }
+    }.start()
 
-  def writeToFile[T](items: Iterator[T], filename: String) {
-    val file = new DataOutputStream(new FileOutputStream(filename))
-    writeIteratorToStream(items, file)
-    file.close()
+    serverSocket.getLocalPort
   }
 
   private def getMergedConf(confAsMap: java.util.HashMap[String, String],

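As a rough Python analogue of the serveIterator helper added above (serve_items is an assumed name, not a Spark API), the same pattern is: bind a one-shot server to an ephemeral port, give it a short accept timeout, serve the items from a daemon thread, and return the port to the caller.

    import pickle
    import socket
    import threading

    def serve_items(items, timeout=3.0):
        # Hypothetical one-shot server sketching the serveIterator pattern:
        # accept a single connection, pickle every item into it, then close.
        server = socket.socket()
        server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        server.bind(("localhost", 0))   # port 0 = pick an ephemeral port
        server.listen(1)                # backlog of 1, like ServerSocket(0, 1)
        server.settimeout(timeout)      # give up if no client connects in time
        port = server.getsockname()[1]

        def run():
            try:
                conn, _ = server.accept()
                with conn.makefile("wb", 65536) as out:
                    for item in items:
                        pickle.dump(item, out)
                conn.close()
            except OSError:
                pass                    # no client in time, or the client went away
            finally:
                server.close()

        threading.Thread(target=run, daemon=True).start()
        return port
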
http://git-wip-us.apache.org/repos/asf/spark/blob/d7c359b4/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 23ff8cc..50d06d3 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -22,6 +22,8 @@ from threading import Lock
 from tempfile import NamedTemporaryFile
 import atexit
 
+from py4j.java_collections import ListConverter
+
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
 from pyspark.broadcast import Broadcast
@@ -31,11 +33,9 @@ from pyspark.java_gateway import launch_gateway
 from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
     PairDeserializer, AutoBatchedSerializer, NoOpSerializer
 from pyspark.storagelevel import StorageLevel
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, _load_from_socket
 from pyspark.traceback_utils import CallSite, first_spark_call
 
-from py4j.java_collections import ListConverter
-
 
 __all__ = ['SparkContext']
 
@@ -58,7 +58,6 @@ class SparkContext(object):
 
     _gateway = None
     _jvm = None
-    _writeToFile = None
     _next_accum_id = 0
     _active_spark_context = None
     _lock = Lock()
@@ -211,7 +210,6 @@ class SparkContext(object):
             if not SparkContext._gateway:
                 SparkContext._gateway = gateway or launch_gateway()
                 SparkContext._jvm = SparkContext._gateway.jvm
-                SparkContext._writeToFile = SparkContext._jvm.PythonRDD.writeToFile
 
             if instance:
                 if (SparkContext._active_spark_context and
@@ -824,8 +822,9 @@ class SparkContext(object):
         # by runJob() in order to avoid having to pass a Python lambda into
         # SparkContext#runJob.
         mappedRDD = rdd.mapPartitions(partitionFunc)
-        it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
-        return list(mappedRDD._collect_iterator_through_file(it))
+        port = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions,
+                                          allowLocal)
+        return list(_load_from_socket(port, mappedRDD._jrdd_deserializer))
 
     def _add_profile(self, id, profileAcc):
         if not self._profile_stats:

http://git-wip-us.apache.org/repos/asf/spark/blob/d7c359b4/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 9d676d7..f1037e0 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -19,7 +19,6 @@ import copy
 from collections import defaultdict
 from itertools import chain, ifilter, imap
 import operator
-import os
 import sys
 import shlex
 from subprocess import Popen, PIPE
@@ -29,6 +28,7 @@ import warnings
 import heapq
 import bisect
 import random
+import socket
 from math import sqrt, log, isinf, isnan
 
 from pyspark.accumulators import PStatsParam
@@ -112,6 +112,30 @@ def _parse_memory(s):
     return int(float(s[:-1]) * units[s[-1].lower()])
 
 
+def _load_from_socket(port, serializer):
+    sock = socket.socket()
+    try:
+        sock.connect(("localhost", port))
+        rf = sock.makefile("rb", 65536)
+        for item in serializer.load_stream(rf):
+            yield item
+    finally:
+        sock.close()
+
+
+class Partitioner(object):
+    def __init__(self, numPartitions, partitionFunc):
+        self.numPartitions = numPartitions
+        self.partitionFunc = partitionFunc
+
+    def __eq__(self, other):
+        return (isinstance(other, Partitioner) and self.numPartitions == other.numPartitions
+                and self.partitionFunc == other.partitionFunc)
+
+    def __call__(self, k):
+        return self.partitionFunc(k) % self.numPartitions
+
+
 class RDD(object):
 
     """
@@ -683,21 +707,8 @@ class RDD(object):
         Return a list that contains all of the elements in this RDD.
         """
         with SCCallSiteSync(self.context) as css:
-            bytesInJava = self._jrdd.collect().iterator()
-        return list(self._collect_iterator_through_file(bytesInJava))
-
-    def _collect_iterator_through_file(self, iterator):
-        # Transferring lots of data through Py4J can be slow because
-        # socket.readline() is inefficient.  Instead, we'll dump the data to a
-        # file and read it back.
-        tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
-        tempFile.close()
-        self.ctx._writeToFile(iterator, tempFile.name)
-        # Read the data into Python and deserialize it:
-        with open(tempFile.name, 'rb') as tempFile:
-            for item in self._jrdd_deserializer.load_stream(tempFile):
-                yield item
-        os.unlink(tempFile.name)
+            port = self.ctx._jvm.PythonRDD.collectAndServe(self._jrdd.rdd())
+        return list(_load_from_socket(port, self._jrdd_deserializer))
 
     def reduce(self, f):
         """

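Purely as a usage sketch of the hypothetical helpers above, the round trip mirrors the new collect() path: the producer serves an iterator on a throwaway port (collectAndServe on the JVM side in Spark) and the consumer drains it over the socket (_load_from_socket on the Python side), with no temp file in between.

    # Assumes the hypothetical serve_items/read_pickled_stream sketches above.
    data = ({"id": i, "value": i * i} for i in range(1000))
    port = serve_items(data)                 # producer: serve on an ephemeral port
    rows = list(read_pickled_stream(port))   # consumer: drain the socket stream
    assert len(rows) == 1000
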
http://git-wip-us.apache.org/repos/asf/spark/blob/d7c359b4/python/pyspark/sql.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 4410925..8c68801 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -44,7 +44,7 @@ from itertools import imap
 from py4j.protocol import Py4JError
 from py4j.java_collections import ListConverter, MapConverter
 
-from pyspark.rdd import RDD
+from pyspark.rdd import RDD, _load_from_socket
 from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
     CloudPickleSerializer, UTF8Deserializer
 from pyspark.storagelevel import StorageLevel
@@ -1996,9 +1996,11 @@ class SchemaRDD(RDD):
         [Row(field1=1, field2=u'row1'), ..., Row(field1=3, field2=u'row3')]
         """
         with SCCallSiteSync(self.context) as css:
-            bytesInJava = self._jschema_rdd.baseSchemaRDD().collectToPython().iterator()
+            rdd = self._jschema_rdd.baseSchemaRDD().javaToPython().rdd()
+            port = self._sc._jvm.PythonRDD.collectAndServe(rdd)
+        rs = list(_load_from_socket(port, BatchedSerializer(PickleSerializer())))
         cls = _create_cls(self.schema())
-        return map(cls, self._collect_iterator_through_file(bytesInJava))
+        return [cls(r) for r in rs]
 
     def take(self, num):
         """Take the first num rows of the RDD.

http://git-wip-us.apache.org/repos/asf/spark/blob/d7c359b4/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index c6d4dab..03689d3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -413,18 +413,6 @@ class SchemaRDD(
   }
 
   /**
-   * Serializes the Array[Row] returned by SchemaRDD's optimized collect(), using the same
-   * format as javaToPython. It is used by pyspark.
-   */
-  private[sql] def collectToPython: JList[Array[Byte]] = {
-    val fieldTypes = schema.fields.map(_.dataType)
-    val pickle = new Pickler
-    new java.util.ArrayList(collect().map { row =>
-      EvaluatePython.rowToArray(row, fieldTypes)
-    }.grouped(100).map(batched => pickle.dumps(batched.toArray)).toIterable)
-  }
-
-  /**
    * Creates SchemaRDD by applying own schema to derived RDD. Typically used to wrap return value
    * of base RDD functions that do not change schema.
    *



