spark-commits mailing list archives
From joshro...@apache.org
Subject spark git commit: [SPARK-8797] [SPARK-9146] [SPARK-9145] [SPARK-9147] Support NaN ordering and equality comparisons in Spark SQL
Date Tue, 21 Jul 2015 05:38:29 GMT
Repository: spark
Updated Branches:
  refs/heads/master 4d97be953 -> c032b0bf9


[SPARK-8797] [SPARK-9146] [SPARK-9145] [SPARK-9147] Support NaN ordering and equality comparisons in Spark SQL

This patch addresses an issue where queries that sorted float or double columns containing
NaN values could fail with "Comparison method violates its general contract!" errors from
TimSort. The root of this problem is that `NaN > anything`, `NaN == anything`, and
`NaN < anything` all return `false`.

Per the design specified in SPARK-9079, we have decided that `NaN = NaN` should return true
and that NaN should appear last when sorting in ascending order (i.e. it is larger than any
other numeric value).

In addition to implementing these semantics, this patch also adds canonicalization of NaN
values in UnsafeRow, which is necessary in order to be able to do binary equality comparisons
on equal NaNs that might have different bit representations (see SPARK-9147).
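The ordering and equality semantics above can be sketched in a small standalone Java example (an illustration of the described behavior, not the patch's actual code, which lives in `Utils.nanSafeCompareDoubles`):

```java
// Sketch of the NaN semantics described above: NaN == NaN is treated as
// equal, and NaN orders after every other double (including +Infinity).
public class NanSemanticsSketch {
    static int nanSafeCompare(double x, double y) {
        boolean xIsNaN = Double.isNaN(x);
        boolean yIsNaN = Double.isNaN(y);
        if ((xIsNaN && yIsNaN) || (x == y)) return 0;
        if (xIsNaN) return 1;   // NaN is larger than any non-NaN value
        if (yIsNaN) return -1;
        return (x > y) ? 1 : -1;
    }

    public static void main(String[] args) {
        System.out.println(nanSafeCompare(Double.NaN, Double.NaN));               // 0
        System.out.println(nanSafeCompare(Double.NaN, Double.POSITIVE_INFINITY)); // 1
        System.out.println(nanSafeCompare(1.0, 2.0));                             // -1
    }
}
```

Unlike the plain `<`/`>` operators, this comparison is a total order, which is what TimSort requires.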

Author: Josh Rosen <joshrosen@databricks.com>

Closes #7194 from JoshRosen/nan and squashes the following commits:

983d4fc [Josh Rosen] Merge remote-tracking branch 'origin/master' into nan
88bd73c [Josh Rosen] Fix Row.equals()
a702e2e [Josh Rosen] normalization -> canonicalization
a7267cf [Josh Rosen] Normalize NaNs in UnsafeRow
fe629ae [Josh Rosen] Merge remote-tracking branch 'origin/master' into nan
fbb2a29 [Josh Rosen] Fix NaN comparisons in BinaryComparison expressions
c1fd4fe [Josh Rosen] Fold NaN test into existing test framework
b31eb19 [Josh Rosen] Uncomment failing tests
7fe67af [Josh Rosen] Support NaN == NaN (SPARK-9145)
58bad2c [Josh Rosen] Revert "Compare rows' string representations to work around NaN incomparability."
fc6b4d2 [Josh Rosen] Update CodeGenerator
3998ef2 [Josh Rosen] Remove unused code
a2ba2e7 [Josh Rosen] Fix prefix comparision for NaNs
a30d371 [Josh Rosen] Compare rows' string representations to work around NaN incomparability.
6f03f85 [Josh Rosen] Fix bug in Double / Float ordering
42a1ad5 [Josh Rosen] Stop filtering NaNs in UnsafeExternalSortSuite
bfca524 [Josh Rosen] Change ordering so that NaN is maximum value.
8d7be61 [Josh Rosen] Update randomized test to use ScalaTest's assume()
b20837b [Josh Rosen] Add failing test for new NaN comparision ordering
5b88b2b [Josh Rosen] Fix compilation of CodeGenerationSuite
d907b5b [Josh Rosen] Merge remote-tracking branch 'origin/master' into nan
630ebc5 [Josh Rosen] Specify an ordering for NaN values.
9bf195a [Josh Rosen] Re-enable NaNs in CodeGenerationSuite to produce more regression tests
13fc06a [Josh Rosen] Add regression test for NaN sorting issue
f9efbb5 [Josh Rosen] Fix ORDER BY NULL
e7dc4fb [Josh Rosen] Add very generic test for ordering
7d5c13e [Josh Rosen] Add regression test for SPARK-8782 (ORDER BY NULL)
b55875a [Josh Rosen] Generate doubles and floats over entire possible range.
5acdd5c [Josh Rosen] Infinity and NaN are interesting.
ab76cbd [Josh Rosen] Move code to Catalyst package.
d2b4a4a [Josh Rosen] Add random data generator test utilities to Spark SQL.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c032b0bf
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c032b0bf
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c032b0bf

Branch: refs/heads/master
Commit: c032b0bf92130dc4facb003f0deaeb1228aefded
Parents: 4d97be9
Author: Josh Rosen <joshrosen@databricks.com>
Authored: Mon Jul 20 22:38:05 2015 -0700
Committer: Josh Rosen <joshrosen@databricks.com>
Committed: Mon Jul 20 22:38:05 2015 -0700

----------------------------------------------------------------------
 .../unsafe/sort/PrefixComparators.java          |  5 ++-
 .../scala/org/apache/spark/util/Utils.scala     | 28 ++++++++++++++
 .../org/apache/spark/util/UtilsSuite.scala      | 31 ++++++++++++++++
 .../unsafe/sort/PrefixComparatorsSuite.scala    | 25 +++++++++++++
 .../sql/catalyst/expressions/UnsafeRow.java     |  6 +++
 .../main/scala/org/apache/spark/sql/Row.scala   | 24 ++++++++----
 .../expressions/codegen/CodeGenerator.scala     |  4 ++
 .../sql/catalyst/expressions/predicates.scala   | 22 +++++++++--
 .../org/apache/spark/sql/types/DoubleType.scala |  5 ++-
 .../org/apache/spark/sql/types/FloatType.scala  |  5 ++-
 .../expressions/CodeGenerationSuite.scala       | 39 ++++++++++++++++++++
 .../catalyst/expressions/PredicateSuite.scala   | 13 ++++---
 .../expressions/UnsafeRowConverterSuite.scala   | 22 +++++++++++
 .../org/apache/spark/sql/DataFrameSuite.scala   | 22 +++++++++++
 .../scala/org/apache/spark/sql/RowSuite.scala   | 12 ++++++
 .../sql/execution/UnsafeExternalSortSuite.scala |  6 +--
 16 files changed, 243 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c032b0bf/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
index 4387425..bf1bc5d 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/PrefixComparators.java
@@ -23,6 +23,7 @@ import com.google.common.primitives.UnsignedBytes;
 
 import org.apache.spark.annotation.Private;
 import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.util.Utils;
 
 @Private
 public class PrefixComparators {
@@ -82,7 +83,7 @@ public class PrefixComparators {
     public int compare(long aPrefix, long bPrefix) {
       float a = Float.intBitsToFloat((int) aPrefix);
       float b = Float.intBitsToFloat((int) bPrefix);
-      return (a < b) ? -1 : (a > b) ? 1 : 0;
+      return Utils.nanSafeCompareFloats(a, b);
     }
 
     public long computePrefix(float value) {
@@ -97,7 +98,7 @@ public class PrefixComparators {
     public int compare(long aPrefix, long bPrefix) {
       double a = Double.longBitsToDouble(aPrefix);
       double b = Double.longBitsToDouble(bPrefix);
-      return (a < b) ? -1 : (a > b) ? 1 : 0;
+      return Utils.nanSafeCompareDoubles(a, b);
     }
 
     public long computePrefix(double value) {

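To see why the pre-patch comparator (`(a < b) ? -1 : (a > b) ? 1 : 0`, visible in the removed lines above) breaks TimSort: every comparison against NaN evaluates to 0, so NaN compares "equal" to values that are not equal to each other, violating the Comparator contract. A quick illustrative sketch:

```java
// Demonstration of why the old prefix comparison logic is not a valid
// total order in the presence of NaN values.
public class BrokenComparatorSketch {
    // The pre-patch comparison logic for floats:
    static int oldCompare(float a, float b) {
        return (a < b) ? -1 : (a > b) ? 1 : 0;
    }

    public static void main(String[] args) {
        // NaN compares "equal" to both 1 and 2, since NaN < x and NaN > x
        // are both false...
        System.out.println(oldCompare(1f, Float.NaN)); // 0
        System.out.println(oldCompare(Float.NaN, 2f)); // 0
        // ...yet 1 and 2 are not equal, so the relation is not transitive.
        // This is the inconsistency TimSort's contract check can detect.
        System.out.println(oldCompare(1f, 2f));        // -1
    }
}
```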
http://git-wip-us.apache.org/repos/asf/spark/blob/c032b0bf/core/src/main/scala/org/apache/spark/util/Utils.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index e6374f1..c581694 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -1586,6 +1586,34 @@ private[spark] object Utils extends Logging {
     hashAbs
   }
 
+  /**
+   * NaN-safe version of [[java.lang.Double.compare()]] which allows NaN values to be compared
+   * according to semantics where NaN == NaN and NaN > any non-NaN double.
+   */
+  def nanSafeCompareDoubles(x: Double, y: Double): Int = {
+    val xIsNan: Boolean = java.lang.Double.isNaN(x)
+    val yIsNan: Boolean = java.lang.Double.isNaN(y)
+    if ((xIsNan && yIsNan) || (x == y)) 0
+    else if (xIsNan) 1
+    else if (yIsNan) -1
+    else if (x > y) 1
+    else -1
+  }
+
+  /**
+   * NaN-safe version of [[java.lang.Float.compare()]] which allows NaN values to be compared
+   * according to semantics where NaN == NaN and NaN > any non-NaN float.
+   */
+  def nanSafeCompareFloats(x: Float, y: Float): Int = {
+    val xIsNan: Boolean = java.lang.Float.isNaN(x)
+    val yIsNan: Boolean = java.lang.Float.isNaN(y)
+    if ((xIsNan && yIsNan) || (x == y)) 0
+    else if (xIsNan) 1
+    else if (yIsNan) -1
+    else if (x > y) 1
+    else -1
+  }
+
   /** Returns the system properties map that is thread-safe to iterator over. It gets the
     * properties which have been set explicitly, as well as those for which only a default value
     * has been defined. */

http://git-wip-us.apache.org/repos/asf/spark/blob/c032b0bf/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
index c763850..8f7e402 100644
--- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.util
 
 import java.io.{File, ByteArrayOutputStream, ByteArrayInputStream, FileOutputStream}
+import java.lang.{Double => JDouble, Float => JFloat}
 import java.net.{BindException, ServerSocket, URI}
 import java.nio.{ByteBuffer, ByteOrder}
 import java.text.DecimalFormatSymbols
@@ -689,4 +690,34 @@ class UtilsSuite extends SparkFunSuite with ResetSystemProperties with Logging {
     // scalastyle:on println
     assert(buffer.toString === "t circular test circular\n")
   }
+
+  test("nanSafeCompareDoubles") {
+    def shouldMatchDefaultOrder(a: Double, b: Double): Unit = {
+      assert(Utils.nanSafeCompareDoubles(a, b) === JDouble.compare(a, b))
+      assert(Utils.nanSafeCompareDoubles(b, a) === JDouble.compare(b, a))
+    }
+    shouldMatchDefaultOrder(0d, 0d)
+    shouldMatchDefaultOrder(0d, 1d)
+    shouldMatchDefaultOrder(Double.MinValue, Double.MaxValue)
+    assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NaN) === 0)
+    assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.PositiveInfinity) === 1)
+    assert(Utils.nanSafeCompareDoubles(Double.NaN, Double.NegativeInfinity) === 1)
+    assert(Utils.nanSafeCompareDoubles(Double.PositiveInfinity, Double.NaN) === -1)
+    assert(Utils.nanSafeCompareDoubles(Double.NegativeInfinity, Double.NaN) === -1)
+  }
+
+  test("nanSafeCompareFloats") {
+    def shouldMatchDefaultOrder(a: Float, b: Float): Unit = {
+      assert(Utils.nanSafeCompareFloats(a, b) === JFloat.compare(a, b))
+      assert(Utils.nanSafeCompareFloats(b, a) === JFloat.compare(b, a))
+    }
+    shouldMatchDefaultOrder(0f, 0f)
+    shouldMatchDefaultOrder(1f, 1f)
+    shouldMatchDefaultOrder(Float.MinValue, Float.MaxValue)
+    assert(Utils.nanSafeCompareFloats(Float.NaN, Float.NaN) === 0)
+    assert(Utils.nanSafeCompareFloats(Float.NaN, Float.PositiveInfinity) === 1)
+    assert(Utils.nanSafeCompareFloats(Float.NaN, Float.NegativeInfinity) === 1)
+    assert(Utils.nanSafeCompareFloats(Float.PositiveInfinity, Float.NaN) === -1)
+    assert(Utils.nanSafeCompareFloats(Float.NegativeInfinity, Float.NaN) === -1)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c032b0bf/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
index dd505df..dc03e37 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/unsafe/sort/PrefixComparatorsSuite.scala
@@ -47,4 +47,29 @@ class PrefixComparatorsSuite extends SparkFunSuite with PropertyChecks {
     forAll (regressionTests) { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
     forAll { (s1: String, s2: String) => testPrefixComparison(s1, s2) }
   }
+
+  test("float prefix comparator handles NaN properly") {
+    val nan1: Float = java.lang.Float.intBitsToFloat(0x7f800001)
+    val nan2: Float = java.lang.Float.intBitsToFloat(0x7fffffff)
+    assert(nan1.isNaN)
+    assert(nan2.isNaN)
+    val nan1Prefix = PrefixComparators.FLOAT.computePrefix(nan1)
+    val nan2Prefix = PrefixComparators.FLOAT.computePrefix(nan2)
+    assert(nan1Prefix === nan2Prefix)
+    val floatMaxPrefix = PrefixComparators.FLOAT.computePrefix(Float.MaxValue)
+    assert(PrefixComparators.FLOAT.compare(nan1Prefix, floatMaxPrefix) === 1)
+  }
+
+  test("double prefix comparator handles NaNs properly") {
+    val nan1: Double = java.lang.Double.longBitsToDouble(0x7ff0000000000001L)
+    val nan2: Double = java.lang.Double.longBitsToDouble(0x7fffffffffffffffL)
+    assert(nan1.isNaN)
+    assert(nan2.isNaN)
+    val nan1Prefix = PrefixComparators.DOUBLE.computePrefix(nan1)
+    val nan2Prefix = PrefixComparators.DOUBLE.computePrefix(nan2)
+    assert(nan1Prefix === nan2Prefix)
+    val doubleMaxPrefix = PrefixComparators.DOUBLE.computePrefix(Double.MaxValue)
+    assert(PrefixComparators.DOUBLE.compare(nan1Prefix, doubleMaxPrefix) === 1)
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c032b0bf/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 87294a0..8cd9e7b 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -215,6 +215,9 @@ public final class UnsafeRow extends MutableRow {
   public void setDouble(int ordinal, double value) {
     assertIndexIsValid(ordinal);
     setNotNullAt(ordinal);
+    if (Double.isNaN(value)) {
+      value = Double.NaN;
+    }
     PlatformDependent.UNSAFE.putDouble(baseObject, getFieldOffset(ordinal), value);
   }
 
@@ -243,6 +246,9 @@ public final class UnsafeRow extends MutableRow {
   public void setFloat(int ordinal, float value) {
     assertIndexIsValid(ordinal);
     setNotNullAt(ordinal);
+    if (Float.isNaN(value)) {
+      value = Float.NaN;
+    }
     PlatformDependent.UNSAFE.putFloat(baseObject, getFieldOffset(ordinal), value);
   }
 

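For context on the canonicalization added to `setDouble`/`setFloat` above (SPARK-9147): IEEE 754 permits many distinct NaN bit patterns, so two equal-under-our-semantics NaN values can differ byte-for-byte, which would defeat binary equality comparison of UnsafeRows. A minimal sketch of the problem and the fix (the bit patterns match those used in the UnsafeRowConverterSuite test below):

```java
// Why NaN canonicalization is needed before byte-wise row comparison.
public class NanBitsSketch {
    public static void main(String[] args) {
        double nan1 = Double.longBitsToDouble(0x7ff0000000000001L);
        double nan2 = Double.longBitsToDouble(0x7fffffffffffffffL);
        // Both values are NaN...
        System.out.println(Double.isNaN(nan1) && Double.isNaN(nan2));
        // ...but their raw bit patterns differ, so a byte-wise equality
        // check on the serialized rows would report them as unequal:
        System.out.println(
            Double.doubleToRawLongBits(nan1) == Double.doubleToRawLongBits(nan2));
        // Canonicalizing (writing the single canonical Double.NaN instead)
        // makes the stored bits identical:
        double c1 = Double.isNaN(nan1) ? Double.NaN : nan1;
        double c2 = Double.isNaN(nan2) ? Double.NaN : nan2;
        System.out.println(
            Double.doubleToRawLongBits(c1) == Double.doubleToRawLongBits(c2));
    }
}
```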
http://git-wip-us.apache.org/repos/asf/spark/blob/c032b0bf/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 2cb64d0..9144947 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -403,20 +403,28 @@ trait Row extends Serializable {
       if (!isNullAt(i)) {
         val o1 = get(i)
         val o2 = other.get(i)
-        if (o1.isInstanceOf[Array[Byte]]) {
-          // handle equality of Array[Byte]
-          val b1 = o1.asInstanceOf[Array[Byte]]
-          if (!o2.isInstanceOf[Array[Byte]] ||
-            !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
+        o1 match {
+          case b1: Array[Byte] =>
+            if (!o2.isInstanceOf[Array[Byte]] ||
+                !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
+              return false
+            }
+          case f1: Float if java.lang.Float.isNaN(f1) =>
+            if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
+              return false
+            }
+          case d1: Double if java.lang.Double.isNaN(d1) =>
+            if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
+              return false
+            }
+          case _ => if (o1 != o2) {
             return false
           }
-        } else if (o1 != o2) {
-          return false
         }
       }
       i += 1
     }
-    return true
+    true
   }
 
   /* ---------------------- utility methods for Scala ---------------------- */

http://git-wip-us.apache.org/repos/asf/spark/blob/c032b0bf/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 10f411f..606f770 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -194,6 +194,8 @@ class CodeGenContext {
    */
   def genEqual(dataType: DataType, c1: String, c2: String): String = dataType match {
     case BinaryType => s"java.util.Arrays.equals($c1, $c2)"
+    case FloatType => s"(java.lang.Float.isNaN($c1) && java.lang.Float.isNaN($c2)) || $c1 == $c2"
+    case DoubleType => s"(java.lang.Double.isNaN($c1) && java.lang.Double.isNaN($c2)) || $c1 == $c2"
     case dt: DataType if isPrimitiveType(dt) => s"$c1 == $c2"
     case other => s"$c1.equals($c2)"
   }
@@ -204,6 +206,8 @@ class CodeGenContext {
   def genComp(dataType: DataType, c1: String, c2: String): String = dataType match {
     // java boolean doesn't support > or < operator
     case BooleanType => s"($c1 == $c2 ? 0 : ($c1 ? 1 : -1))"
+    case DoubleType => s"org.apache.spark.util.Utils.nanSafeCompareDoubles($c1, $c2)"
+    case FloatType => s"org.apache.spark.util.Utils.nanSafeCompareFloats($c1, $c2)"
     // use c1 - c2 may overflow
    case dt: DataType if isPrimitiveType(dt) => s"($c1 > $c2 ? 1 : $c1 < $c2 ? -1 : 0)"
    case BinaryType => s"org.apache.spark.sql.catalyst.util.TypeUtils.compareBinary($c1, $c2)"

http://git-wip-us.apache.org/repos/asf/spark/blob/c032b0bf/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index 40ec3df..a53ec31 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.types._
+import org.apache.spark.util.Utils
 
 
 object InterpretedPredicate {
@@ -222,7 +223,9 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
 abstract class BinaryComparison extends BinaryOperator with Predicate {
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
-    if (ctx.isPrimitiveType(left.dataType)) {
+    if (ctx.isPrimitiveType(left.dataType)
+        && left.dataType != FloatType
+        && left.dataType != DoubleType) {
       // faster version
       defineCodeGen(ctx, ev, (c1, c2) => s"$c1 $symbol $c2")
     } else {
@@ -254,8 +257,15 @@ case class EqualTo(left: Expression, right: Expression) extends BinaryComparison
   override def symbol: String = "="
 
   protected override def nullSafeEval(input1: Any, input2: Any): Any = {
-    if (left.dataType != BinaryType) input1 == input2
-    else java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
+    if (left.dataType == FloatType) {
+      Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
+    } else if (left.dataType == DoubleType) {
+      Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
+    } else if (left.dataType != BinaryType) {
+      input1 == input2
+    } else {
+      java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])
+    }
   }
 
   override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
@@ -280,7 +290,11 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
     } else if (input1 == null || input2 == null) {
       false
     } else {
-      if (left.dataType != BinaryType) {
+      if (left.dataType == FloatType) {
+        Utils.nanSafeCompareFloats(input1.asInstanceOf[Float], input2.asInstanceOf[Float]) == 0
+      } else if (left.dataType == DoubleType) {
+        Utils.nanSafeCompareDoubles(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) == 0
+      } else if (left.dataType != BinaryType) {
         input1 == input2
       } else {
         java.util.Arrays.equals(input1.asInstanceOf[Array[Byte]], input2.asInstanceOf[Array[Byte]])

http://git-wip-us.apache.org/repos/asf/spark/blob/c032b0bf/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala
index 986c2ab..2a1bf09 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/DoubleType.scala
@@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.typeTag
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.catalyst.ScalaReflectionLock
+import org.apache.spark.util.Utils
 
 /**
  * :: DeveloperApi ::
@@ -37,7 +38,9 @@ class DoubleType private() extends FractionalType {
   @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
   private[sql] val numeric = implicitly[Numeric[Double]]
   private[sql] val fractional = implicitly[Fractional[Double]]
-  private[sql] val ordering = implicitly[Ordering[InternalType]]
+  private[sql] val ordering = new Ordering[Double] {
+    override def compare(x: Double, y: Double): Int = Utils.nanSafeCompareDoubles(x, y)
+  }
   private[sql] val asIntegral = DoubleAsIfIntegral
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/c032b0bf/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala
index 9bd48ec..08e2225 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/FloatType.scala
@@ -23,6 +23,7 @@ import scala.reflect.runtime.universe.typeTag
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.sql.catalyst.ScalaReflectionLock
+import org.apache.spark.util.Utils
 
 /**
  * :: DeveloperApi ::
@@ -37,7 +38,9 @@ class FloatType private() extends FractionalType {
   @transient private[sql] lazy val tag = ScalaReflectionLock.synchronized { typeTag[InternalType] }
   private[sql] val numeric = implicitly[Numeric[Float]]
   private[sql] val fractional = implicitly[Fractional[Float]]
-  private[sql] val ordering = implicitly[Ordering[InternalType]]
+  private[sql] val ordering = new Ordering[Float] {
+    override def compare(x: Float, y: Float): Int = Utils.nanSafeCompareFloats(x, y)
+  }
   private[sql] val asIntegral = FloatAsIfIntegral
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/c032b0bf/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index e05218a..f4fbc49 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -17,9 +17,14 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import scala.math._
+
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.RandomDataGenerator
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.types.{DataTypeTestUtils, NullType, StructField, StructType}
 
 /**
  * Additional tests for code generation.
@@ -43,6 +48,40 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     futures.foreach(Await.result(_, 10.seconds))
   }
 
+  // Test GenerateOrdering for all common types. For each type, we construct random input rows that
+  // contain two columns of that type, then for pairs of randomly-generated rows we check that
+  // GenerateOrdering agrees with RowOrdering.
+  (DataTypeTestUtils.atomicTypes ++ Set(NullType)).foreach { dataType =>
+    test(s"GenerateOrdering with $dataType") {
+      val rowOrdering = RowOrdering.forSchema(Seq(dataType, dataType))
+      val genOrdering = GenerateOrdering.generate(
+        BoundReference(0, dataType, nullable = true).asc ::
+          BoundReference(1, dataType, nullable = true).asc :: Nil)
+      val rowType = StructType(
+        StructField("a", dataType, nullable = true) ::
+          StructField("b", dataType, nullable = true) :: Nil)
+      val maybeDataGenerator = RandomDataGenerator.forType(rowType, nullable = false)
+      assume(maybeDataGenerator.isDefined)
+      val randGenerator = maybeDataGenerator.get
+      val toCatalyst = CatalystTypeConverters.createToCatalystConverter(rowType)
+      for (_ <- 1 to 50) {
+        val a = toCatalyst(randGenerator()).asInstanceOf[InternalRow]
+        val b = toCatalyst(randGenerator()).asInstanceOf[InternalRow]
+        withClue(s"a = $a, b = $b") {
+          assert(genOrdering.compare(a, a) === 0)
+          assert(genOrdering.compare(b, b) === 0)
+          assert(rowOrdering.compare(a, a) === 0)
+          assert(rowOrdering.compare(b, b) === 0)
+          assert(signum(genOrdering.compare(a, b)) === -1 * signum(genOrdering.compare(b, a)))
+          assert(signum(rowOrdering.compare(a, b)) === -1 * signum(rowOrdering.compare(b, a)))
+          assert(
+            signum(rowOrdering.compare(a, b)) === signum(genOrdering.compare(a, b)),
+            "Generated and non-generated orderings should agree")
+        }
+      }
+    }
+  }
+
   test("SPARK-8443: split wide projections into blocks due to JVM code size limit") {
     val length = 5000
     val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1)))

http://git-wip-us.apache.org/repos/asf/spark/blob/c032b0bf/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
index 2173a0c..0bc2812 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala
@@ -136,11 +136,14 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkEvaluation(And(InSet(one, hS), InSet(two, hS)), true)
   }
 
-  private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_))
-  private val largeValues = Seq(2, Decimal(2), Array(2.toByte), "b").map(Literal(_))
-
-  private val equalValues1 = smallValues
-  private val equalValues2 = Seq(1, Decimal(1), Array(1.toByte), "a").map(Literal(_))
+  private val smallValues = Seq(1, Decimal(1), Array(1.toByte), "a", 0f, 0d).map(Literal(_))
+  private val largeValues =
+    Seq(2, Decimal(2), Array(2.toByte), "b", Float.NaN, Double.NaN).map(Literal(_))
+
+  private val equalValues1 =
+    Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_))
+  private val equalValues2 =
+    Seq(1, Decimal(1), Array(1.toByte), "a", Float.NaN, Double.NaN).map(Literal(_))
 
   test("BinaryComparison: <") {
     for (i <- 0 until smallValues.length) {

http://git-wip-us.apache.org/repos/asf/spark/blob/c032b0bf/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index d00aeb4d..dff5faf 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -316,4 +316,26 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
     assert(setToNullAfterCreation.get(11) === rowWithNoNullColumns.get(11))
   }
 
+  test("NaN canonicalization") {
+    val fieldTypes: Array[DataType] = Array(FloatType, DoubleType)
+
+    val row1 = new SpecificMutableRow(fieldTypes)
+    row1.setFloat(0, java.lang.Float.intBitsToFloat(0x7f800001))
+    row1.setDouble(1, java.lang.Double.longBitsToDouble(0x7ff0000000000001L))
+
+    val row2 = new SpecificMutableRow(fieldTypes)
+    row2.setFloat(0, java.lang.Float.intBitsToFloat(0x7fffffff))
+    row2.setDouble(1, java.lang.Double.longBitsToDouble(0x7fffffffffffffffL))
+
+    val converter = new UnsafeRowConverter(fieldTypes)
+    val row1Buffer = new Array[Byte](converter.getSizeRequirement(row1))
+    val row2Buffer = new Array[Byte](converter.getSizeRequirement(row2))
+    converter.writeRow(
+      row1, row1Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row1Buffer.length, null)
+    converter.writeRow(
+      row2, row2Buffer, PlatformDependent.BYTE_ARRAY_OFFSET, row2Buffer.length, null)
+
+    assert(row1Buffer.toSeq === row2Buffer.toSeq)
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c032b0bf/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 192cc0a..f67f2c6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import java.io.File
 
 import scala.language.postfixOps
+import scala.util.Random
 
 import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
 import org.apache.spark.sql.functions._
@@ -742,6 +743,27 @@ class DataFrameSuite extends QueryTest with SQLTestUtils {
     df.col("t.``")
   }
 
+  test("SPARK-8797: sort by float column containing NaN should not crash") {
+    val inputData = Seq.fill(10)(Tuple1(Float.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toFloat))
+    val df = Random.shuffle(inputData).toDF("a")
+    df.orderBy("a").collect()
+  }
+
+  test("SPARK-8797: sort by double column containing NaN should not crash") {
+    val inputData = Seq.fill(10)(Tuple1(Double.NaN)) ++ (1 to 1000).map(x => Tuple1(x.toDouble))
+    val df = Random.shuffle(inputData).toDF("a")
+    df.orderBy("a").collect()
+  }
+
+  test("NaN is greater than all other non-NaN numeric values") {
+    val maxDouble = Seq(Double.NaN, Double.PositiveInfinity, Double.MaxValue)
+      .map(Tuple1.apply).toDF("a").selectExpr("max(a)").first()
+    assert(java.lang.Double.isNaN(maxDouble.getDouble(0)))
+    val maxFloat = Seq(Float.NaN, Float.PositiveInfinity, Float.MaxValue)
+      .map(Tuple1.apply).toDF("a").selectExpr("max(a)").first()
+    assert(java.lang.Float.isNaN(maxFloat.getFloat(0)))
+  }
+
   test("SPARK-8072: Better Exception for Duplicate Columns") {
     // only one duplicate column present
     val e = intercept[org.apache.spark.sql.AnalysisException] {

http://git-wip-us.apache.org/repos/asf/spark/blob/c032b0bf/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
index d84b57a..7cc6ffd 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
@@ -73,4 +73,16 @@ class RowSuite extends SparkFunSuite {
       row.getAs[Int]("c")
     }
   }
+
+  test("float NaN == NaN") {
+    val r1 = Row(Float.NaN)
+    val r2 = Row(Float.NaN)
+    assert(r1 === r2)
+  }
+
+  test("double NaN == NaN") {
+    val r1 = Row(Double.NaN)
+    val r2 = Row(Double.NaN)
+    assert(r1 === r2)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/c032b0bf/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
index 4f4c1f2..5fe73f7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeExternalSortSuite.scala
@@ -83,11 +83,7 @@ class UnsafeExternalSortSuite extends SparkPlanTest with BeforeAndAfterAll {
     randomDataGenerator <- RandomDataGenerator.forType(dataType, nullable)
   ) {
     test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {
-      val inputData = Seq.fill(1000)(randomDataGenerator()).filter {
-        case d: Double => !d.isNaN
-        case f: Float => !java.lang.Float.isNaN(f)
-        case x => true
-      }
+      val inputData = Seq.fill(1000)(randomDataGenerator())
       val inputDf = TestSQLContext.createDataFrame(
         TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
         StructType(StructField("a", dataType, nullable = true) :: Nil)



