spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From marmb...@apache.org
Subject spark git commit: [SPARK-11553][SQL] Primitive Row accessors should not convert null to default value
Date Mon, 16 Nov 2015 23:14:43 GMT
Repository: spark
Updated Branches:
  refs/heads/master bcea0bfda -> 31296628a


[SPARK-11553][SQL] Primitive Row accessors should not convert null to default value

Previously, invoking a getter for a type extending AnyVal (e.g. getInt) returned the type's default value
when the field was null, instead of throwing a NullPointerException. See the comments on the SPARK-11553 issue for more details.

Author: Bartlomiej Alberski <bartlomiej.alberski@allegrogroup.com>

Closes #9642 from alberskib/bugfix/SPARK-11553.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/31296628
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/31296628
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/31296628

Branch: refs/heads/master
Commit: 31296628ac7cd7be71e0edca335dc8604f62bb47
Parents: bcea0bf
Author: Bartlomiej Alberski <bartlomiej.alberski@allegrogroup.com>
Authored: Mon Nov 16 15:14:38 2015 -0800
Committer: Michael Armbrust <michael@databricks.com>
Committed: Mon Nov 16 15:14:38 2015 -0800

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/sql/Row.scala   | 32 ++++++++++++-----
 .../scala/org/apache/spark/sql/RowTest.scala    | 20 +++++++++++
 .../local/NestedLoopJoinNodeSuite.scala         | 36 ++++++++++++--------
 3 files changed, 65 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/31296628/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 0f0f200..b14c66c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -191,7 +191,7 @@ trait Row extends Serializable {
    * @throws ClassCastException when data type does not match.
    * @throws NullPointerException when value is null.
    */
-  def getBoolean(i: Int): Boolean = getAs[Boolean](i)
+  def getBoolean(i: Int): Boolean = getAnyValAs[Boolean](i)
 
   /**
    * Returns the value at position i as a primitive byte.
@@ -199,7 +199,7 @@ trait Row extends Serializable {
    * @throws ClassCastException when data type does not match.
    * @throws NullPointerException when value is null.
    */
-  def getByte(i: Int): Byte = getAs[Byte](i)
+  def getByte(i: Int): Byte = getAnyValAs[Byte](i)
 
   /**
    * Returns the value at position i as a primitive short.
@@ -207,7 +207,7 @@ trait Row extends Serializable {
    * @throws ClassCastException when data type does not match.
    * @throws NullPointerException when value is null.
    */
-  def getShort(i: Int): Short = getAs[Short](i)
+  def getShort(i: Int): Short = getAnyValAs[Short](i)
 
   /**
    * Returns the value at position i as a primitive int.
@@ -215,7 +215,7 @@ trait Row extends Serializable {
    * @throws ClassCastException when data type does not match.
    * @throws NullPointerException when value is null.
    */
-  def getInt(i: Int): Int = getAs[Int](i)
+  def getInt(i: Int): Int = getAnyValAs[Int](i)
 
   /**
    * Returns the value at position i as a primitive long.
@@ -223,7 +223,7 @@ trait Row extends Serializable {
    * @throws ClassCastException when data type does not match.
    * @throws NullPointerException when value is null.
    */
-  def getLong(i: Int): Long = getAs[Long](i)
+  def getLong(i: Int): Long = getAnyValAs[Long](i)
 
   /**
    * Returns the value at position i as a primitive float.
@@ -232,7 +232,7 @@ trait Row extends Serializable {
    * @throws ClassCastException when data type does not match.
    * @throws NullPointerException when value is null.
    */
-  def getFloat(i: Int): Float = getAs[Float](i)
+  def getFloat(i: Int): Float = getAnyValAs[Float](i)
 
   /**
    * Returns the value at position i as a primitive double.
@@ -240,13 +240,12 @@ trait Row extends Serializable {
    * @throws ClassCastException when data type does not match.
    * @throws NullPointerException when value is null.
    */
-  def getDouble(i: Int): Double = getAs[Double](i)
+  def getDouble(i: Int): Double = getAnyValAs[Double](i)
 
   /**
    * Returns the value at position i as a String object.
    *
    * @throws ClassCastException when data type does not match.
-   * @throws NullPointerException when value is null.
    */
   def getString(i: Int): String = getAs[String](i)
 
@@ -318,6 +317,8 @@ trait Row extends Serializable {
 
   /**
    * Returns the value at position i.
+   * For primitive types, if the value is null it returns the primitive's 'zero value',
+   * i.e. 0 for Int - use isNullAt to ensure that the value is not null.
    *
    * @throws ClassCastException when data type does not match.
    */
@@ -325,6 +326,8 @@ trait Row extends Serializable {
 
   /**
    * Returns the value of a given fieldName.
+   * For primitive types, if the value is null it returns the primitive's 'zero value',
+   * i.e. 0 for Int - use isNullAt to ensure that the value is not null.
    *
    * @throws UnsupportedOperationException when schema is not defined.
    * @throws IllegalArgumentException when fieldName do not exist.
@@ -344,6 +347,8 @@ trait Row extends Serializable {
 
   /**
    * Returns a Map(name -> value) for the requested fieldNames
+   * For primitive types, if the value is null it returns the primitive's 'zero value',
+   * i.e. 0 for Int - use isNullAt to ensure that the value is not null.
    *
    * @throws UnsupportedOperationException when schema is not defined.
    * @throws IllegalArgumentException when fieldName do not exist.
@@ -458,4 +463,15 @@ trait Row extends Serializable {
    * start, end, and separator strings.
    */
   def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end)
+
+  /**
+   * Returns the value at position i as a primitive (AnyVal) type, failing fast on null
+   * instead of silently unboxing it to the type's zero value.
+   *
+   * @throws ClassCastException when data type does not match.
+   * @throws NullPointerException when value is null.
+   */
+  private def getAnyValAs[T <: AnyVal](i: Int): T =
+    if (isNullAt(i)) throw new NullPointerException(s"Value at index $i in null")
+    else getAs[T](i)
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/31296628/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
index 01ff84c..5c22a72 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala
@@ -29,8 +29,10 @@ class RowTest extends FunSpec with Matchers {
     StructField("col2", StringType) ::
     StructField("col3", IntegerType) :: Nil)
   val values = Array("value1", "value2", 1)
+  val valuesWithoutCol3 = Array[Any](null, "value2", null)
 
   val sampleRow: Row = new GenericRowWithSchema(values, schema)
+  val sampleRowWithoutCol3: Row = new GenericRowWithSchema(valuesWithoutCol3, schema)
   val noSchemaRow: Row = new GenericRow(values)
 
   describe("Row (without schema)") {
@@ -68,6 +70,24 @@ class RowTest extends FunSpec with Matchers {
       )
       sampleRow.getValuesMap(List("col1", "col2")) shouldBe expected
     }
+
+    it("getValuesMap() retrieves null value on non AnyVal Type") {
+      val expected = Map(
+        "col1" -> null,
+        "col2" -> "value2"
+      )
+      sampleRowWithoutCol3.getValuesMap[String](List("col1", "col2")) shouldBe expected
+    }
+
+    it("getAs() on type extending AnyVal throws an exception when accessing field that is null") {
+      intercept[NullPointerException] {
+        sampleRowWithoutCol3.getInt(sampleRowWithoutCol3.fieldIndex("col3"))
+      }
+    }
+
+    it("getAs() on type extending AnyVal does not throw exception when value is null"){
+      sampleRowWithoutCol3.getAs[String](sampleRowWithoutCol3.fieldIndex("col1")) shouldBe null
+    }
   }
 
   describe("row equals") {

http://git-wip-us.apache.org/repos/asf/spark/blob/31296628/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala
index 252f7cc..45df2ea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala
@@ -58,8 +58,14 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
       val hashJoinNode = makeUnsafeNode(leftNode, rightNode)
       val expectedOutput = generateExpectedOutput(leftInput, rightInput, joinType)
       val actualOutput = hashJoinNode.collect().map { row =>
-        // (id, name, id, nickname)
-        (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3))
+        // (
+        //   id, name,
+        //   id, nickname
+        // )
+        (
+          Option(row.get(0)).map(_.asInstanceOf[Int]), Option(row.getString(1)),
+          Option(row.get(2)).map(_.asInstanceOf[Int]), Option(row.getString(3))
+        )
       }
       assert(actualOutput.toSet === expectedOutput.toSet)
     }
@@ -95,36 +101,36 @@ class NestedLoopJoinNodeSuite extends LocalNodeTest {
   private def generateExpectedOutput(
       leftInput: Array[(Int, String)],
       rightInput: Array[(Int, String)],
-      joinType: JoinType): Array[(Int, String, Int, String)] = {
+      joinType: JoinType): Array[(Option[Int], Option[String], Option[Int], Option[String])] = {
     joinType match {
       case LeftOuter =>
         val rightInputMap = rightInput.toMap
         leftInput.map { case (k, v) =>
-          val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0)
-          val rightValue = rightInputMap.getOrElse(k, null)
-          (k, v, rightKey, rightValue)
+          val rightKey = rightInputMap.get(k).map { _ => k }
+          val rightValue = rightInputMap.get(k)
+          (Some(k), Some(v), rightKey, rightValue)
         }
 
       case RightOuter =>
         val leftInputMap = leftInput.toMap
         rightInput.map { case (k, v) =>
-          val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0)
-          val leftValue = leftInputMap.getOrElse(k, null)
-          (leftKey, leftValue, k, v)
+          val leftKey = leftInputMap.get(k).map { _ => k }
+          val leftValue = leftInputMap.get(k)
+          (leftKey, leftValue, Some(k), Some(v))
         }
 
       case FullOuter =>
         val leftInputMap = leftInput.toMap
         val rightInputMap = rightInput.toMap
         val leftOutput = leftInput.map { case (k, v) =>
-          val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0)
-          val rightValue = rightInputMap.getOrElse(k, null)
-          (k, v, rightKey, rightValue)
+          val rightKey = rightInputMap.get(k).map { _ => k }
+          val rightValue = rightInputMap.get(k)
+          (Some(k), Some(v), rightKey, rightValue)
         }
         val rightOutput = rightInput.map { case (k, v) =>
-          val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0)
-          val leftValue = leftInputMap.getOrElse(k, null)
-          (leftKey, leftValue, k, v)
+          val leftKey = leftInputMap.get(k).map { _ => k }
+          val leftValue = leftInputMap.get(k)
+          (leftKey, leftValue, Some(k), Some(v))
         }
         (leftOutput ++ rightOutput).distinct
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message