spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wenc...@apache.org
Subject spark git commit: [SPARK-24257][SQL] LongToUnsafeRowMap calculate the new size may be wrong
Date Thu, 24 May 2018 03:24:30 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.1 b51ce9377 -> 1d8287671


[SPARK-24257][SQL] LongToUnsafeRowMap calculate the new size may be wrong

LongToUnsafeRowMap has a mistake when growing its page array: it blindly grows to
`oldSize * 2`, while the new record may be larger than `oldSize * 2`. As a result, querying
this map may return a malformed UnsafeRow whose actual data is smaller than its declared
size, i.e. the data is corrupted.

Author: sychen <sychen@ctrip.com>

Closes #21311 from cxzl25/fix_LongToUnsafeRowMap_page_size.

(cherry picked from commit 888340151f737bb68d0e419b1e949f11469881f9)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1d828767
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1d828767
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1d828767

Branch: refs/heads/branch-2.1
Commit: 1d828767114fb9d0ab81d7aefcf6ad9e12e10a71
Parents: b51ce93
Author: sychen <sychen@ctrip.com>
Authored: Thu May 24 11:02:09 2018 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Thu May 24 11:20:23 2018 +0800

----------------------------------------------------------------------
 .../sql/execution/joins/HashedRelation.scala    | 38 ++++++++++++--------
 .../execution/joins/HashedRelationSuite.scala   | 26 +++++++++++++-
 2 files changed, 48 insertions(+), 16 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1d828767/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index b9f6601..f7e8ea6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -533,7 +533,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
   def append(key: Long, row: UnsafeRow): Unit = {
     val sizeInBytes = row.getSizeInBytes
     if (sizeInBytes >= (1 << SIZE_BITS)) {
-      sys.error("Does not support row that is larger than 256M")
+      throw new UnsupportedOperationException("Does not support row that is larger than 256M")
     }
 
     if (key < minKey) {
@@ -543,19 +543,7 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
       maxKey = key
     }
 
-    // There is 8 bytes for the pointer to next value
-    if (cursor + 8 + row.getSizeInBytes > page.length * 8L + Platform.LONG_ARRAY_OFFSET) {
-      val used = page.length
-      if (used >= (1 << 30)) {
-        sys.error("Can not build a HashedRelation that is larger than 8G")
-      }
-      ensureAcquireMemory(used * 8L * 2)
-      val newPage = new Array[Long](used * 2)
-      Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
-        cursor - Platform.LONG_ARRAY_OFFSET)
-      page = newPage
-      freeMemory(used * 8L)
-    }
+    grow(row.getSizeInBytes)
 
     // copy the bytes of UnsafeRow
     val offset = cursor
@@ -588,7 +576,8 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
           growArray()
         } else if (numKeys > array.length / 2 * 0.75) {
           // The fill ratio should be less than 0.75
-          sys.error("Cannot build HashedRelation with more than 1/3 billions unique keys")
+          throw new UnsupportedOperationException(
+            "Cannot build HashedRelation with more than 1/3 billions unique keys")
         }
       }
     } else {
@@ -599,6 +588,25 @@ private[execution] final class LongToUnsafeRowMap(val mm: TaskMemoryManager, cap
     }
   }
 
+  private def grow(inputRowSize: Int): Unit = {
+    // There is 8 bytes for the pointer to next value
+    val neededNumWords = (cursor - Platform.LONG_ARRAY_OFFSET + 8 + inputRowSize + 7) / 8
+    if (neededNumWords > page.length) {
+      if (neededNumWords > (1 << 30)) {
+        throw new UnsupportedOperationException(
+          "Can not build a HashedRelation that is larger than 8G")
+      }
+      val newNumWords = math.max(neededNumWords, math.min(page.length * 2, 1 << 30))
+      ensureAcquireMemory(newNumWords * 8L)
+      val newPage = new Array[Long](newNumWords.toInt)
+      Platform.copyMemory(page, Platform.LONG_ARRAY_OFFSET, newPage, Platform.LONG_ARRAY_OFFSET,
+        cursor - Platform.LONG_ARRAY_OFFSET)
+      val used = page.length
+      page = newPage
+      freeMemory(used * 8L)
+    }
+  }
+
   private def growArray(): Unit = {
     var old_array = array
     val n = array.length

http://git-wip-us.apache.org/repos/asf/spark/blob/1d828767/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index ede63fe..b575e55 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.serializer.KryoSerializer
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.map.BytesToBytesMap
 import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.collection.CompactBuffer
@@ -253,6 +253,30 @@ class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
     map.free()
   }
 
+  test("SPARK-24257: insert big values into LongToUnsafeRowMap") {
+    val taskMemoryManager = new TaskMemoryManager(
+      new StaticMemoryManager(
+        new SparkConf().set(MEMORY_OFFHEAP_ENABLED.key, "false"),
+        Long.MaxValue,
+        Long.MaxValue,
+        1),
+      0)
+    val unsafeProj = UnsafeProjection.create(Array[DataType](StringType))
+    val map = new LongToUnsafeRowMap(taskMemoryManager, 1)
+
+    val key = 0L
+    // the page array is initialized with length 1 << 17 (1M bytes),
+    // so here we need a value larger than 1 << 18 (2M bytes), to trigger the bug
+    val bigStr = UTF8String.fromString("x" * (1 << 19))
+
+    map.append(key, unsafeProj(InternalRow(bigStr)))
+    map.optimize()
+
+    val resultRow = new UnsafeRow(1)
+    assert(map.getValue(key, resultRow).getUTF8String(0) === bigStr)
+    map.free()
+  }
+
   test("Spark-14521") {
     val ser = new KryoSerializer(
       (new SparkConf).set("spark.kryo.referenceTracking", "false")).newInstance()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message