spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From r...@apache.org
Subject spark git commit: [SPARK-17495][SQL] Add more tests for hive hash
Date Fri, 24 Feb 2017 17:46:46 GMT
Repository: spark
Updated Branches:
  refs/heads/master a920a4369 -> 3e40f6c3d


[SPARK-17495][SQL] Add more tests for hive hash

## What changes were proposed in this pull request?

This PR adds tests for hive-hash by comparing its output against the output generated by Hive 1.2.1. The following
datatypes are covered by this PR:
- null
- boolean
- byte
- short
- int
- long
- float
- double
- string
- array
- map
- struct

Datatypes that I have _NOT_ covered here, but will work on separately, are:
- Decimal (handled separately in https://github.com/apache/spark/pull/17056)
- TimestampType
- DateType
- CalendarIntervalType

## How was this patch tested?

NA

Author: Tejas Patil <tejasp@fb.com>

Closes #17049 from tejasapatil/SPARK-17495_remaining_types.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3e40f6c3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3e40f6c3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3e40f6c3

Branch: refs/heads/master
Commit: 3e40f6c3d6fc0bcd828d09031fa3994925394889
Parents: a920a43
Author: Tejas Patil <tejasp@fb.com>
Authored: Fri Feb 24 09:46:42 2017 -0800
Committer: Reynold Xin <rxin@databricks.com>
Committed: Fri Feb 24 09:46:42 2017 -0800

----------------------------------------------------------------------
 .../sql/catalyst/expressions/HiveHasher.java    |   2 +-
 .../spark/sql/catalyst/expressions/hash.scala   |  11 +-
 .../expressions/HashExpressionsSuite.scala      | 247 ++++++++++++++++++-
 3 files changed, 252 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3e40f6c3/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java
----------------------------------------------------------------------
diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java
index c7ea908..7357743 100644
--- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java
+++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions;
 import org.apache.spark.unsafe.Platform;
 
 /**
- * Simulates Hive's hashing function at
+ * Simulates Hive's hashing function from Hive v1.2.1
  * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode()
  */
 public class HiveHasher {

http://git-wip-us.apache.org/repos/asf/spark/blob/3e40f6c3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index e14f054..2d9c2e4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -573,10 +573,9 @@ object XxHash64Function extends InterpretedHashFunction {
   }
 }
 
-
 /**
- * Simulates Hive's hashing function at
- * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode() in Hive
+ * Simulates Hive's hashing function from Hive v1.2.1 at
+ * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode()
  *
  * We should use this hash function for both shuffle and bucket of Hive tables, so that
  * we can guarantee shuffle and bucketing have same data distribution
@@ -595,7 +594,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
   override protected def hasherClassName: String = classOf[HiveHasher].getName
 
   override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = {
-    HiveHashFunction.hash(value, dataType, seed).toInt
+    HiveHashFunction.hash(value, dataType, this.seed).toInt
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -781,12 +780,12 @@ object HiveHashFunction extends InterpretedHashFunction {
         var i = 0
         val length = struct.numFields
         while (i < length) {
-          result = (31 * result) + hash(struct.get(i, types(i)), types(i), seed + 1).toInt
+          result = (31 * result) + hash(struct.get(i, types(i)), types(i), 0).toInt
           i += 1
         }
         result
 
-      case _ => super.hash(value, dataType, seed)
+      case _ => super.hash(value, dataType, 0)
     }
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/3e40f6c3/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
index 0326292..0cb3a79 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
@@ -19,16 +19,20 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.nio.charset.StandardCharsets
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.commons.codec.digest.DigestUtils
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
+import org.apache.spark.sql.types.{ArrayType, StructType, _}
 import org.apache.spark.unsafe.types.UTF8String
 
 class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+  val random = new scala.util.Random
 
   test("md5") {
     checkEvaluation(Md5(Literal("ABC".getBytes(StandardCharsets.UTF_8))),
@@ -71,6 +75,247 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
   }
 
+
+  def checkHiveHash(input: Any, dataType: DataType, expected: Long): Unit = {
+    // Note : All expected hashes need to be computed using Hive 1.2.1
+    val actual = HiveHashFunction.hash(input, dataType, seed = 0)
+
+    withClue(s"hash mismatch for input = `$input` of type `$dataType`.") {
+      assert(actual == expected)
+    }
+  }
+
+  def checkHiveHashForIntegralType(dataType: DataType): Unit = {
+    // corner cases
+    checkHiveHash(null, dataType, 0)
+    checkHiveHash(1, dataType, 1)
+    checkHiveHash(0, dataType, 0)
+    checkHiveHash(-1, dataType, -1)
+    checkHiveHash(Int.MaxValue, dataType, Int.MaxValue)
+    checkHiveHash(Int.MinValue, dataType, Int.MinValue)
+
+    // random values
+    for (_ <- 0 until 10) {
+      val input = random.nextInt()
+      checkHiveHash(input, dataType, input)
+    }
+  }
+
+  test("hive-hash for null") {
+    checkHiveHash(null, NullType, 0)
+  }
+
+  test("hive-hash for boolean") {
+    checkHiveHash(true, BooleanType, 1)
+    checkHiveHash(false, BooleanType, 0)
+  }
+
+  test("hive-hash for byte") {
+    checkHiveHashForIntegralType(ByteType)
+  }
+
+  test("hive-hash for short") {
+    checkHiveHashForIntegralType(ShortType)
+  }
+
+  test("hive-hash for int") {
+    checkHiveHashForIntegralType(IntegerType)
+  }
+
+  test("hive-hash for long") {
+    checkHiveHash(1L, LongType, 1L)
+    checkHiveHash(0L, LongType, 0L)
+    checkHiveHash(-1L, LongType, 0L)
+    checkHiveHash(Long.MaxValue, LongType, -2147483648)
+    // Hive fails to parse this, but the hashing function itself can handle this input
+    checkHiveHash(Long.MinValue, LongType, -2147483648)
+
+    for (_ <- 0 until 10) {
+      val input = random.nextLong()
+      checkHiveHash(input, LongType, ((input >>> 32) ^ input).toInt)
+    }
+  }
+
+  test("hive-hash for float") {
+    checkHiveHash(0F, FloatType, 0)
+    checkHiveHash(0.0F, FloatType, 0)
+    checkHiveHash(1.1F, FloatType, 1066192077L)
+    checkHiveHash(-1.1F, FloatType, -1081291571)
+    checkHiveHash(99999999.99999999999F, FloatType, 1287568416L)
+    checkHiveHash(Float.MaxValue, FloatType, 2139095039)
+    checkHiveHash(Float.MinValue, FloatType, -8388609)
+  }
+
+  test("hive-hash for double") {
+    checkHiveHash(0, DoubleType, 0)
+    checkHiveHash(0.0, DoubleType, 0)
+    checkHiveHash(1.1, DoubleType, -1503133693)
+    checkHiveHash(-1.1, DoubleType, 644349955)
+    checkHiveHash(1000000000.000001, DoubleType, 1104006509)
+    checkHiveHash(1000000000.0000000000000000000000001, DoubleType, 1104006501)
+    checkHiveHash(9999999999999999999.9999999999999999999, DoubleType, 594568676)
+    checkHiveHash(Double.MaxValue, DoubleType, -2146435072)
+    checkHiveHash(Double.MinValue, DoubleType, 1048576)
+  }
+
+  test("hive-hash for string") {
+    checkHiveHash(UTF8String.fromString("apache spark"), StringType, 1142704523L)
+    checkHiveHash(UTF8String.fromString("!@#$%^&*()_+=-"), StringType, -613724358L)
+    checkHiveHash(UTF8String.fromString("abcdefghijklmnopqrstuvwxyz"), StringType, 958031277L)
+    checkHiveHash(UTF8String.fromString("AbCdEfGhIjKlMnOpQrStUvWxYz012"), StringType, -648013852L)
+    // scalastyle:off nonascii
+    checkHiveHash(UTF8String.fromString("数据砖头"), StringType, -898686242L)
+    checkHiveHash(UTF8String.fromString("नमस्ते"), StringType, 2006045948L)
+    // scalastyle:on nonascii
+  }
+
+  test("hive-hash for array") {
+    // empty array
+    checkHiveHash(
+      input = new GenericArrayData(Array[Int]()),
+      dataType = ArrayType(IntegerType, containsNull = false),
+      expected = 0)
+
+    // basic case
+    checkHiveHash(
+      input = new GenericArrayData(Array(1, 10000, Int.MaxValue)),
+      dataType = ArrayType(IntegerType, containsNull = false),
+      expected = -2147172688L)
+
+    // with negative values
+    checkHiveHash(
+      input = new GenericArrayData(Array(-1L, 0L, 999L, Int.MinValue.toLong)),
+      dataType = ArrayType(LongType, containsNull = false),
+      expected = -2147452680L)
+
+    // with nulls only
+    val arrayTypeWithNull = ArrayType(IntegerType, containsNull = true)
+    checkHiveHash(
+      input = new GenericArrayData(Array(null, null)),
+      dataType = arrayTypeWithNull,
+      expected = 0)
+
+    // mix with null
+    checkHiveHash(
+      input = new GenericArrayData(Array(-12221, 89, null, 767)),
+      dataType = arrayTypeWithNull,
+      expected = -363989515)
+
+    // nested with array
+    checkHiveHash(
+      input = new GenericArrayData(
+        Array(
+          new GenericArrayData(Array(1234L, -9L, 67L)),
+          new GenericArrayData(Array(null, null)),
+          new GenericArrayData(Array(55L, -100L, -2147452680L))
+        )),
+      dataType = ArrayType(ArrayType(LongType)),
+      expected = -1007531064)
+
+    // nested with map
+    checkHiveHash(
+      input = new GenericArrayData(
+        Array(
+          new ArrayBasedMapData(
+            new GenericArrayData(Array(-99, 1234)),
+            new GenericArrayData(Array(UTF8String.fromString("sql"), null))),
+          new ArrayBasedMapData(
+            new GenericArrayData(Array(67)),
+            new GenericArrayData(Array(UTF8String.fromString("apache spark"))))
+        )),
+      dataType = ArrayType(MapType(IntegerType, StringType)),
+      expected = 1139205955)
+  }
+
+  test("hive-hash for map") {
+    val mapType = MapType(IntegerType, StringType)
+
+    // empty map
+    checkHiveHash(
+      input = new ArrayBasedMapData(new GenericArrayData(Array()), new GenericArrayData(Array())),
+      dataType = mapType,
+      expected = 0)
+
+    // basic case
+    checkHiveHash(
+      input = new ArrayBasedMapData(
+        new GenericArrayData(Array(1, 2)),
+        new GenericArrayData(Array(UTF8String.fromString("foo"), UTF8String.fromString("bar")))),
+      dataType = mapType,
+      expected = 198872)
+
+    // with null value
+    checkHiveHash(
+      input = new ArrayBasedMapData(
+        new GenericArrayData(Array(55, -99)),
+        new GenericArrayData(Array(UTF8String.fromString("apache spark"), null))),
+      dataType = mapType,
+      expected = 1142704473)
+
+    // nesting (only values can be nested as keys have to be primitive datatype)
+    val nestedMapType = MapType(IntegerType, MapType(IntegerType, StringType))
+    checkHiveHash(
+      input = new ArrayBasedMapData(
+        new GenericArrayData(Array(1, -100)),
+        new GenericArrayData(
+          Array(
+            new ArrayBasedMapData(
+              new GenericArrayData(Array(-99, 1234)),
+              new GenericArrayData(Array(UTF8String.fromString("sql"), null))),
+            new ArrayBasedMapData(
+              new GenericArrayData(Array(67)),
+              new GenericArrayData(Array(UTF8String.fromString("apache spark"))))
+          ))),
+      dataType = nestedMapType,
+      expected = -1142817416)
+  }
+
+  test("hive-hash for struct") {
+    // basic
+    val row = new GenericInternalRow(Array[Any](1, 2, 3))
+    checkHiveHash(
+      input = row,
+      dataType =
+        new StructType()
+          .add("col1", IntegerType)
+          .add("col2", IntegerType)
+          .add("col3", IntegerType),
+      expected = 1026)
+
+    // mix of several datatypes
+    val structType = new StructType()
+      .add("null", NullType)
+      .add("boolean", BooleanType)
+      .add("byte", ByteType)
+      .add("short", ShortType)
+      .add("int", IntegerType)
+      .add("long", LongType)
+      .add("arrayOfString", arrayOfString)
+      .add("mapOfString", mapOfString)
+
+    val rowValues = new ArrayBuffer[Any]()
+    rowValues += null
+    rowValues += true
+    rowValues += 1
+    rowValues += 2
+    rowValues += Int.MaxValue
+    rowValues += Long.MinValue
+    rowValues += new GenericArrayData(Array(
+      UTF8String.fromString("apache spark"),
+      UTF8String.fromString("hello world")
+    ))
+    rowValues += new ArrayBasedMapData(
+      new GenericArrayData(Array(UTF8String.fromString("project"), UTF8String.fromString("meta"))),
+      new GenericArrayData(Array(UTF8String.fromString("apache spark"), null))
+    )
+
+    val row2 = new GenericInternalRow(rowValues.toArray)
+    checkHiveHash(
+      input = row2,
+      dataType = structType,
+      expected = -2119012447)
+  }
+
   private val structOfString = new StructType().add("str", StringType)
   private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false)
   private val arrayOfString = ArrayType(StringType)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message