spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wenc...@apache.org
Subject spark git commit: [SPARK-22328][CORE] ClosureCleaner should not miss referenced superclass fields
Date Thu, 26 Oct 2017 20:41:48 GMT
Repository: spark
Updated Branches:
  refs/heads/master 0e9a750a8 -> 4f8dc6b01


[SPARK-22328][CORE] ClosureCleaner should not miss referenced superclass fields

## What changes were proposed in this pull request?

When the given closure uses fields defined in a superclass, `ClosureCleaner` can't find them
and doesn't set them properly, so those fields end up with null values in the cleaned closure.

## How was this patch tested?

Added test.

Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #19556 from viirya/SPARK-22328.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4f8dc6b0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4f8dc6b0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4f8dc6b0

Branch: refs/heads/master
Commit: 4f8dc6b01ea787243a38678ea8199fbb0814cffc
Parents: 0e9a750
Author: Liang-Chi Hsieh <viirya@gmail.com>
Authored: Thu Oct 26 21:41:45 2017 +0100
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Thu Oct 26 21:41:45 2017 +0100

----------------------------------------------------------------------
 .../org/apache/spark/util/ClosureCleaner.scala  | 73 ++++++++++++++++----
 .../apache/spark/util/ClosureCleanerSuite.scala | 72 +++++++++++++++++++
 2 files changed, 133 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4f8dc6b0/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
index 48a1d7b..dfece5d 100644
--- a/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/ClosureCleaner.scala
@@ -91,6 +91,54 @@ private[spark] object ClosureCleaner extends Logging {
     (seen - obj.getClass).toList
   }
 
+  /** Initializes the accessed fields for outer classes and their super classes. */
+  private def initAccessedFields(
+      accessedFields: Map[Class[_], Set[String]],
+      outerClasses: Seq[Class[_]]): Unit = {
+    for (cls <- outerClasses) {
+      var currentClass = cls
+      assert(currentClass != null, "The outer class can't be null.")
+
+      while (currentClass != null) {
+        accessedFields(currentClass) = Set.empty[String]
+        currentClass = currentClass.getSuperclass()
+      }
+    }
+  }
+
+  /** Sets accessed fields for given class in clone object based on given object. */
+  private def setAccessedFields(
+      outerClass: Class[_],
+      clone: AnyRef,
+      obj: AnyRef,
+      accessedFields: Map[Class[_], Set[String]]): Unit = {
+    for (fieldName <- accessedFields(outerClass)) {
+      val field = outerClass.getDeclaredField(fieldName)
+      field.setAccessible(true)
+      val value = field.get(obj)
+      field.set(clone, value)
+    }
+  }
+
+  /** Clones a given object and sets accessed fields in cloned object. */
+  private def cloneAndSetFields(
+      parent: AnyRef,
+      obj: AnyRef,
+      outerClass: Class[_],
+      accessedFields: Map[Class[_], Set[String]]): AnyRef = {
+    val clone = instantiateClass(outerClass, parent)
+
+    var currentClass = outerClass
+    assert(currentClass != null, "The outer class can't be null.")
+
+    while (currentClass != null) {
+      setAccessedFields(currentClass, clone, obj, accessedFields)
+      currentClass = currentClass.getSuperclass()
+    }
+
+    clone
+  }
+
   /**
    * Clean the given closure in place.
    *
@@ -202,9 +250,8 @@ private[spark] object ClosureCleaner extends Logging {
       logDebug(s" + populating accessed fields because this is the starting closure")
       // Initialize accessed fields with the outer classes first
       // This step is needed to associate the fields to the correct classes later
-      for (cls <- outerClasses) {
-        accessedFields(cls) = Set.empty[String]
-      }
+      initAccessedFields(accessedFields, outerClasses)
+
       // Populate accessed fields by visiting all fields and methods accessed by this and
       // all of its inner closures. If transitive cleaning is enabled, this may recursively
      // visits methods that belong to other classes in search of transitively referenced fields.
@@ -250,13 +297,8 @@ private[spark] object ClosureCleaner extends Logging {
       // required fields from the original object. We need the parent here because the Java
      // language specification requires the first constructor parameter of any closure to be
       // its enclosing object.
-      val clone = instantiateClass(cls, parent)
-      for (fieldName <- accessedFields(cls)) {
-        val field = cls.getDeclaredField(fieldName)
-        field.setAccessible(true)
-        val value = field.get(obj)
-        field.set(clone, value)
-      }
+      val clone = cloneAndSetFields(parent, obj, cls, accessedFields)
+
       // If transitive cleaning is enabled, we recursively clean any enclosing closure using
       // the already populated accessed fields map of the starting closure
       if (cleanTransitively && isClosure(clone.getClass)) {
@@ -395,8 +437,15 @@ private[util] class FieldAccessFinder(
             if (!visitedMethods.contains(m)) {
               // Keep track of visited methods to avoid potential infinite cycles
               visitedMethods += m
-              ClosureCleaner.getClassReader(cl).accept(
-                new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0)
+
+              var currentClass = cl
+              assert(currentClass != null, "The outer class can't be null.")
+
+              while (currentClass != null) {
+                ClosureCleaner.getClassReader(currentClass).accept(
+                  new FieldAccessFinder(fields, findTransitively, Some(m), visitedMethods), 0)
+                currentClass = currentClass.getSuperclass()
+              }
             }
           }
         }

http://git-wip-us.apache.org/repos/asf/spark/blob/4f8dc6b0/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
index 4920b7e..9a19bae 100644
--- a/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/ClosureCleanerSuite.scala
@@ -119,6 +119,63 @@ class ClosureCleanerSuite extends SparkFunSuite {
   test("createNullValue") {
     new TestCreateNullValue().run()
   }
+
+  test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 1") {
+    val concreteObject = new TestAbstractClass {
+      val n2 = 222
+      val s2 = "bbb"
+      val d2 = 2.0d
+
+      def run(): Seq[(Int, Int, String, String, Double, Double)] = {
+        withSpark(new SparkContext("local", "test")) { sc =>
+          val rdd = sc.parallelize(1 to 1)
+          body(rdd)
+        }
+      }
+
+      def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)] = rdd.map { _ =>
+        (n1, n2, s1, s2, d1, d2)
+      }.collect()
+    }
+    assert(concreteObject.run() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d)))
+  }
+
+  test("SPARK-22328: ClosureCleaner misses referenced superclass fields: case 2") {
+    val concreteObject = new TestAbstractClass2 {
+      val n2 = 222
+      val s2 = "bbb"
+      val d2 = 2.0d
+      def getData: Int => (Int, Int, String, String, Double, Double) = _ => (n1, n2, s1, s2, d1, d2)
+    }
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val rdd = sc.parallelize(1 to 1).map(concreteObject.getData)
+      assert(rdd.collect() === Seq((111, 222, "aaa", "bbb", 1.0d, 2.0d)))
+    }
+  }
+
+  test("SPARK-22328: multiple outer classes have the same parent class") {
+    val concreteObject = new TestAbstractClass2 {
+
+      val innerObject = new TestAbstractClass2 {
+        override val n1 = 222
+        override val s1 = "bbb"
+      }
+
+      val innerObject2 = new TestAbstractClass2 {
+        override val n1 = 444
+        val n3 = 333
+        val s3 = "ccc"
+        val d3 = 3.0d
+
+        def getData: Int => (Int, Int, String, String, Double, Double, Int, String) =
+          _ => (n1, n3, s1, s3, d1, d3, innerObject.n1, innerObject.s1)
+      }
+    }
+    withSpark(new SparkContext("local", "test")) { sc =>
+      val rdd = sc.parallelize(1 to 1).map(concreteObject.innerObject2.getData)
+      assert(rdd.collect() === Seq((444, 333, "aaa", "ccc", 1.0d, 3.0d, 222, "bbb")))
+    }
+  }
 }
 
 // A non-serializable class we create in closures to make sure that we aren't
@@ -377,3 +434,18 @@ class TestCreateNullValue {
     nestedClosure()
   }
 }
+
+abstract class TestAbstractClass extends Serializable {
+  val n1 = 111
+  val s1 = "aaa"
+  protected val d1 = 1.0d
+
+  def run(): Seq[(Int, Int, String, String, Double, Double)]
+  def body(rdd: RDD[Int]): Seq[(Int, Int, String, String, Double, Double)]
+}
+
+abstract class TestAbstractClass2 extends Serializable {
+  val n1 = 111
+  val s1 = "aaa"
+  protected val d1 = 1.0d
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message