spark-commits mailing list archives

From wenc...@apache.org
Subject spark git commit: [SPARK-22984] Fix incorrect bitmap copying and offset adjustment in GenerateUnsafeRowJoiner
Date Tue, 09 Jan 2018 03:59:49 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-2.3 850b9f391 -> fd46a276c


[SPARK-22984] Fix incorrect bitmap copying and offset adjustment in GenerateUnsafeRowJoiner

## What changes were proposed in this pull request?

This PR fixes two longstanding correctness bugs in `GenerateUnsafeRowJoiner`. This class was
introduced in https://github.com/apache/spark/pull/7821 (July 2015 / Spark 1.5.0+) and is
used to combine pairs of UnsafeRows in TungstenAggregationIterator, CartesianProductExec,
and AppendColumns.
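
For context, the joiner's API surface is small: you generate a joiner for a pair of schemas
and then concatenate pairs of rows with it. A minimal sketch, mirroring how the test suite
below drives it (the schemas and values here are illustrative placeholders):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.types._

val schema1 = StructType(Seq(StructField("a", IntegerType)))
val schema2 = StructType(Seq(StructField("b", StringType)))
val row1: UnsafeRow = UnsafeProjection.create(schema1).apply(InternalRow(1))
val row2: UnsafeRow = UnsafeProjection.create(schema2).apply(InternalRow(null))  // null string field

// Generate (code-gen and compile) a joiner for this schema pair, then join rows:
val joiner = GenerateUnsafeRowJoiner.create(schema1, schema2)
val combined: UnsafeRow = joiner.join(row1, row2)  // merged schema: (a INT, b STRING)
```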

### Bugs fixed by this patch

1. **Incorrect combining of null-tracking bitmaps**: when concatenating two UnsafeRows, the
implementation claims to "Concatenate the two bitsets together into a single one, taking padding
into account". If one row has no columns then its bitset has size 0, but the code incorrectly
assumed that if the left row had a non-zero number of fields then the right row would also
have at least one field, so it copied invalid bytes and treated them as part of the bitset.
I'm not sure whether this bug was also present in the original implementation or whether it
was introduced in https://github.com/apache/spark/pull/7892 (which fixed another bug in this
code).
2. **Incorrect updating of data offsets for null variable-length fields**: after updating
the bitsets and copying the fixed-length and variable-length data, we need to adjust the
offsets that point to the start of each variable-length field's data. The existing code
_unconditionally_ added a fixed shift to correct for the new length of the combined row, but
it is unsafe to do this if the variable-length field has a null value: we always represent
nulls by storing `0` in the fixed-length slot, and this code was incorrectly incrementing
those zero values. This bug has been present since the original version of
`GenerateUnsafeRowJoiner`. (Both layout invariants at play here are illustrated in the sketch
after this list.)
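
Both bugs are easiest to see in terms of UnsafeRow's layout invariants. A minimal sketch of
those invariants, using the same public helpers the test suite relies on (the assertions
simply restate the layout rules described above):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types._

// Bug 1's precondition: the null-tracking bitset is padded out to whole 8-byte
// words, and a row with zero fields has a zero-byte bitset -- there is nothing
// to copy from it.
assert(UnsafeRow.calculateBitSetWidthInBytes(0) == 0)
assert(UnsafeRow.calculateBitSetWidthInBytes(1) == 8)
assert(UnsafeRow.calculateBitSetWidthInBytes(65) == 16)

// Bug 2's precondition: a null variable-length field stores 0 in its
// fixed-length (offset, size) slot, so blindly adding a shift to that slot
// turns the 0 into a bogus non-zero offset.
val schema = StructType(Seq(StructField("s", StringType)))
val row = UnsafeProjection.create(schema).apply(InternalRow(null))
assert(row.isNullAt(0))
assert(row.getLong(0) == 0L)  // the null field's fixed-length slot is all zeros
```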

### Why these bugs remained latent for so long

The PR which introduced `GenerateUnsafeRowJoiner` included several randomized tests, covering
the cases where one side of the join has no fields and where string-valued fields are null.
However, the existing assertions were too weak to uncover these bugs:

- If a null field has a non-zero value in its fixed-length data slot then this will not cause
problems for field accesses because the null-tracking bitmap should still be correct and we
will not try to use the incorrect offset for anything.
- If the null tracking bitmap is corrupted by joining against a row with no fields then the
corruption occurs in field numbers past the actual field numbers contained in the row. Thus
valid `isNullAt()` calls will not read the incorrectly-set bits.

The existing `GenerateUnsafeRowJoinerSuite` tests only exercised `.get()` and `isNullAt()`
but never checked the UnsafeRows for bit-for-bit equality, so these bugs never tripped an
assertion. It turns out that there was even a [GenerateUnsafeRowJoinerBitsetSuite](https://github.com/apache/spark/blob/03377d2522776267a07b7d6ae9bddf79a4e0f516/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerBitsetSuite.scala),
but it did not catch this problem either because it only tested the bitsets in an end-to-end
fashion, accessing them through the `UnsafeRow` interface instead of comparing the bitsets'
bytes directly.

### Impact of these bugs

- These bugs cause `equals()` and `hashCode()` to be incorrect for the joined rows, which is
problematic in case `GenerateUnsafeRowJoiner`'s results are used as join or grouping keys
(see the sketch after this list).
- Chained / repeated invocations of `GenerateUnsafeRowJoiner` may result in reads from invalid
null bitmap positions causing fields to incorrectly become NULL (see the end-to-end example
below).
  - It looks like this generally only happens in `CartesianProductExec`, which our query optimizer
often avoids executing (usually we try to plan a `BroadcastNestedLoopJoin` instead).
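
To see why stray bitset bytes break `equals()` and `hashCode()` even though every field still
reads back correctly, note that both methods operate on the row's raw bytes. A hedged
illustration with hand-built buffers (the specific hash values are not guaranteed, but these
two inputs do hash differently):

```scala
import org.apache.spark.sql.catalyst.expressions.UnsafeRow

def rowFrom(bytes: Array[Byte], numFields: Int): UnsafeRow = {
  val row = new UnsafeRow(numFields)
  row.pointTo(bytes, bytes.length)
  row
}

// One-field row: an 8-byte bitset word followed by one 8-byte fixed-length slot.
val a = new Array[Byte](16)
val b = a.clone()
b(7) = 1  // flip a bitset bit that no real field uses (bits 1..63 are padding here)

val (ra, rb) = (rowFrom(a, 1), rowFrom(b, 1))
assert(ra.getLong(0) == rb.getLong(0))    // the field's value is identical
assert(ra.isNullAt(0) == rb.isNullAt(0))  // and so is its null bit
assert(!ra.equals(rb))                    // ...but the rows compare unequal
assert(ra.hashCode() != rb.hashCode())    // and hash differently
```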

### End-to-end test case demonstrating the problem

The following query demonstrates how this bug may result in incorrect query results:

```sql
set spark.sql.autoBroadcastJoinThreshold=-1; -- Needed to trigger CartesianProductExec

create table a as select * from values 1;
create table b as select * from values 2;

SELECT
  t3.col1,
  t1.col1
FROM a t1
CROSS JOIN b t2
CROSS JOIN b t3
```

This should return `(2, 1)` but instead was returning `(null, 1)`.

Column pruning ends up trimming off all columns from `t2`, so when `t2` joins with another
table this triggers the bitmap-copying bug. This incorrect bitmap is subsequently copied again
when performing the final join, causing the final output to have an incorrectly-set null bit
for the first field.
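
In `UnsafeRowJoiner` terms, the plan above performs two chained concatenations, the first of
which involves a zero-column row. A sketch of the shape of that computation under the assumed
post-pruning schemas (not an exact reproduction recipe, since the real rows flow through
`CartesianProductExec`):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner
import org.apache.spark.sql.types._

val schemaT1 = StructType(Seq(StructField("col1", IntegerType)))
val schemaT2 = StructType(Nil)  // all of t2's columns were pruned away
val schemaT3 = StructType(Seq(StructField("col1", IntegerType)))

val rowT1 = UnsafeProjection.create(schemaT1).apply(InternalRow(1))
val rowT2 = UnsafeProjection.create(schemaT2).apply(InternalRow())
val rowT3 = UnsafeProjection.create(schemaT3).apply(InternalRow(2))

// First cross join: schemaT1 has a partially-used bitset word while schemaT2
// has no bitset at all -- this is the case where the buggy code copied invalid
// bytes into the output's bitset.
val row12 = GenerateUnsafeRowJoiner.create(schemaT1, schemaT2).join(rowT1, rowT2)

// Second cross join: the (possibly corrupted) bitset is copied again, and a
// stray bit can now land on a real field, making t3.col1 spuriously NULL.
val row123 = GenerateUnsafeRowJoiner.create(StructType(schemaT1 ++ schemaT2), schemaT3)
  .join(row12, rowT3)
```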

## How was this patch tested?

Strengthened the assertions in the existing tests in GenerateUnsafeRowJoinerSuite. Also verified
that the end-to-end test case that uncovered these bugs now passes.

Author: Josh Rosen <joshrosen@databricks.com>

Closes #20181 from JoshRosen/SPARK-22984-fix-generate-unsaferow-joiner-bitmap-bugs.

(cherry picked from commit f20131dd35939734fe16b0005a086aa72400893b)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fd46a276
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fd46a276
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fd46a276

Branch: refs/heads/branch-2.3
Commit: fd46a276c04d2a6617c56abe84be80a1be5d7f99
Parents: 850b9f3
Author: Josh Rosen <joshrosen@databricks.com>
Authored: Tue Jan 9 11:49:10 2018 +0800
Committer: Wenchen Fan <wenchen@databricks.com>
Committed: Tue Jan 9 11:59:39 2018 +0800

----------------------------------------------------------------------
 .../codegen/GenerateUnsafeRowJoiner.scala       | 52 ++++++++++-
 .../codegen/GenerateUnsafeRowJoinerSuite.scala  | 95 ++++++++++++++++++--
 2 files changed, 138 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fd46a276/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
index be5f5a7..febf7b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoiner.scala
@@ -70,7 +70,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
 
     // --------------------- copy bitset from row 1 and row 2 --------------------------- //
     val copyBitset = Seq.tabulate(outputBitsetWords) { i =>
-      val bits = if (bitset1Remainder > 0) {
+      val bits = if (bitset1Remainder > 0 && bitset2Words != 0) {
         if (i < bitset1Words - 1) {
           s"$getLong(obj1, offset1 + ${i * 8})"
         } else if (i == bitset1Words - 1) {
@@ -152,7 +152,9 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
      } else {
        // Number of bytes to increase for the offset. Note that since in UnsafeRow we store the
        // offset in the upper 32 bit of the words, we can just shift the offset to the left by
-        // 32 and increment that amount in place.
+        // 32 and increment that amount in place. However, we need to handle the important special
+        // case of a null field, in which case the offset should be zero and should not have a
+        // shift added to it.
         val shift =
           if (i < schema1.size) {
             s"${(outputBitsetWords - bitset1Words + schema2.size) * 8}L"
@@ -160,14 +162,55 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
             s"(${(outputBitsetWords - bitset2Words + schema1.size) * 8}L + numBytesVariableRow1)"
           }
         val cursor = offset + outputBitsetWords * 8 + i * 8
-        s"$putLong(buf, $cursor, $getLong(buf, $cursor) + ($shift << 32));\n"
+        // UnsafeRow is a little underspecified, so in what follows we'll treat UnsafeRowWriter's
+        // output as a de-facto specification for the internal layout of data.
+        //
+        // Null-valued fields will always have a data offset of 0 because
+        // UnsafeRowWriter.setNullAt(ordinal) sets the null bit and stores 0 in the field's
+        // position in the fixed-length section of the row. As a result, we must NOT add
+        // `shift` to the offset for null fields.
+        //
+        // We could perform a null-check here by inspecting the null-tracking bitmap, but doing
+        // so could be expensive and will add significant bloat to the generated code. Instead,
+        // we'll rely on the invariant "stored offset == 0 for variable-length data type implies
+        // that the field's value is null."
+        //
+        // To establish that this invariant holds, we'll prove that a non-null field can never
+        // have a stored offset of 0. There are two cases to consider:
+        //
+        //   1. The non-null field's data is of non-zero length: reading this field's value
+        //      must read data from the variable-length section of the row, so the stored offset
+        //      will actually be used in address calculation and must be correct. The offsets
+        //      count bytes from the start of the UnsafeRow so these offsets will always be
+        //      non-zero because the storage of the offsets themselves takes up space at the
+        //      start of the row.
+        //   2. The non-null field's data is of zero length (i.e. its data is empty). In this
+        //      case, we have to worry about the possibility that an arbitrary offset value was
+        //      stored because we never actually read any bytes using this offset and therefore
+        //      would not crash if it was incorrect. The variable-sized data writing paths in
+        //      UnsafeRowWriter unconditionally call setOffsetAndSize(ordinal, numBytes) with
+        //      no special handling for the case where `numBytes == 0`. Internally,
+        //      setOffsetAndSize computes the offset without taking the size into account. Thus
+        //      the stored offset is the same non-zero offset that would be used if the field's
+        //      dataSize was non-zero (and in (1) above we've shown that case behaves as we
+        //      expect).
+        //
+        // Thus it is safe to perform `existingOffset != 0` checks here in the place of
+        // more expensive null-bit checks.
+        s"""
+           |existingOffset = $getLong(buf, $cursor);
+           |if (existingOffset != 0) {
+           |    $putLong(buf, $cursor, existingOffset + ($shift << 32));
+           |}
+         """.stripMargin
       }
     }
 
     val updateOffsets = ctx.splitExpressions(
       expressions = updateOffset,
       funcName = "copyBitsetFunc",
-      arguments = ("long", "numBytesVariableRow1") :: Nil)
+      arguments = ("long", "numBytesVariableRow1") :: Nil,
+      makeSplitFunction = (s: String) => "long existingOffset;\n" + s)
 
     // ------------------------ Finally, put everything together  --------------------------- //
     val codeBody = s"""
@@ -200,6 +243,7 @@ object GenerateUnsafeRowJoiner extends CodeGenerator[(StructType, StructType), U
        |    $copyFixedLengthRow2
        |    $copyVariableLengthRow1
        |    $copyVariableLengthRow2
+       |    long existingOffset;
        |    $updateOffsets
        |
        |    out.pointTo(buf, sizeInBytes);

http://git-wip-us.apache.org/repos/asf/spark/blob/fd46a276/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
index f203f25..75c6bee 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeRowJoinerSuite.scala
@@ -22,8 +22,10 @@ import scala.util.Random
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.RandomDataGenerator
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
+import org.apache.spark.sql.catalyst.expressions.{JoinedRow, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * Test suite for [[GenerateUnsafeRowJoiner]].
@@ -45,6 +47,32 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {
     testConcat(64, 64, fixed)
   }
 
+  test("rows with all empty strings") {
+    val schema = StructType(Seq(
+      StructField("f1", StringType), StructField("f2", StringType)))
+    val row: UnsafeRow = UnsafeProjection.create(schema).apply(
+      InternalRow(UTF8String.EMPTY_UTF8, UTF8String.EMPTY_UTF8))
+    testConcat(schema, row, schema, row)
+  }
+
+  test("rows with all empty int arrays") {
+    val schema = StructType(Seq(
+      StructField("f1", ArrayType(IntegerType)), StructField("f2", ArrayType(IntegerType))))
+    val emptyIntArray =
+      ExpressionEncoder[Array[Int]]().resolveAndBind().toRow(Array.emptyIntArray).getArray(0)
+    val row: UnsafeRow = UnsafeProjection.create(schema).apply(
+      InternalRow(emptyIntArray, emptyIntArray))
+    testConcat(schema, row, schema, row)
+  }
+
+  test("alternating empty and non-empty strings") {
+    val schema = StructType(Seq(
+      StructField("f1", StringType), StructField("f2", StringType)))
+    val row: UnsafeRow = UnsafeProjection.create(schema).apply(
+      InternalRow(UTF8String.EMPTY_UTF8, UTF8String.fromString("foo")))
+    testConcat(schema, row, schema, row)
+  }
+
   test("randomized fix width types") {
     for (i <- 0 until 20) {
       testConcatOnce(Random.nextInt(100), Random.nextInt(100), fixed)
@@ -94,27 +122,84 @@ class GenerateUnsafeRowJoinerSuite extends SparkFunSuite {
     val extRow2 = RandomDataGenerator.forType(schema2, nullable = false).get.apply()
     val row1 = converter1.apply(internalConverter1.apply(extRow1).asInstanceOf[InternalRow])
     val row2 = converter2.apply(internalConverter2.apply(extRow2).asInstanceOf[InternalRow])
+    testConcat(schema1, row1, schema2, row2)
+  }
+
+  private def testConcat(
+      schema1: StructType,
+      row1: UnsafeRow,
+      schema2: StructType,
+      row2: UnsafeRow) {
 
     // Run the joiner.
     val mergedSchema = StructType(schema1 ++ schema2)
     val concater = GenerateUnsafeRowJoiner.create(schema1, schema2)
-    val output = concater.join(row1, row2)
+    val output: UnsafeRow = concater.join(row1, row2)
+
+    // We'll also compare to an UnsafeRow produced with JoinedRow + UnsafeProjection. This ensures
+    // that unused space in the row (e.g. leftover bits in the null-tracking bitmap) is written
+    // correctly.
+    val expectedOutput: UnsafeRow = {
+      val joinedRowProjection = UnsafeProjection.create(mergedSchema)
+      val joined = new JoinedRow()
+      joinedRowProjection.apply(joined.apply(row1, row2))
+    }
 
     // Test everything equals ...
     for (i <- mergedSchema.indices) {
+      val dataType = mergedSchema(i).dataType
       if (i < schema1.size) {
         assert(output.isNullAt(i) === row1.isNullAt(i))
         if (!output.isNullAt(i)) {
-          assert(output.get(i, mergedSchema(i).dataType) === row1.get(i, mergedSchema(i).dataType))
+          assert(output.get(i, dataType) === row1.get(i, dataType))
+          assert(output.get(i, dataType) === expectedOutput.get(i, dataType))
         }
       } else {
         assert(output.isNullAt(i) === row2.isNullAt(i - schema1.size))
         if (!output.isNullAt(i)) {
-          assert(output.get(i, mergedSchema(i).dataType) ===
-            row2.get(i - schema1.size, mergedSchema(i).dataType))
+          assert(output.get(i, dataType) === row2.get(i - schema1.size, dataType))
+          assert(output.get(i, dataType) === expectedOutput.get(i, dataType))
         }
       }
     }
+
+
+    assert(
+      expectedOutput.getSizeInBytes == output.getSizeInBytes,
+      "output isn't same size in bytes as slow path")
+
+    // Compare the UnsafeRows byte-by-byte so that we can print more useful debug information in
+    // case this assertion fails:
+    val actualBytes = output.getBaseObject.asInstanceOf[Array[Byte]]
+      .take(output.getSizeInBytes)
+    val expectedBytes = expectedOutput.getBaseObject.asInstanceOf[Array[Byte]]
+      .take(expectedOutput.getSizeInBytes)
+
+    val bitsetWidth = UnsafeRow.calculateBitSetWidthInBytes(expectedOutput.numFields())
+    val actualBitset = actualBytes.take(bitsetWidth)
+    val expectedBitset = expectedBytes.take(bitsetWidth)
+    assert(actualBitset === expectedBitset, "bitsets were not equal")
+
+    val fixedLengthSize = expectedOutput.numFields() * 8
+    val actualFixedLength = actualBytes.slice(bitsetWidth, bitsetWidth + fixedLengthSize)
+    val expectedFixedLength = expectedBytes.slice(bitsetWidth, bitsetWidth + fixedLengthSize)
+    if (actualFixedLength !== expectedFixedLength) {
+      actualFixedLength.grouped(8)
+        .zip(expectedFixedLength.grouped(8))
+        .zip(mergedSchema.fields.toIterator)
+        .foreach {
+          case ((actual, expected), field) =>
+            assert(actual === expected, s"Fixed length sections are not equal for field $field")
+      }
+      fail("Fixed length sections were not equal")
+    }
+
+    val variableLengthStart = bitsetWidth + fixedLengthSize
+    val actualVariableLength = actualBytes.drop(variableLengthStart)
+    val expectedVariableLength = expectedBytes.drop(variableLengthStart)
+    assert(actualVariableLength === expectedVariableLength, "variable length sections were not equal")
+
+    assert(output.hashCode() == expectedOutput.hashCode(), "hash codes were not equal")
   }
 
 }

