spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From marmb...@apache.org
Subject git commit: [SPARK-2569][SQL] Fix shipping of TEMPORARY hive UDFs.
Date Wed, 23 Jul 2014 23:26:59 GMT
Repository: spark
Updated Branches:
  refs/heads/master e060d3ee2 -> 1871574a2


[SPARK-2569][SQL] Fix shipping of TEMPORARY hive UDFs.

Instead of shipping just the name and then looking up the info on the workers, we now ship
the whole classname.  Also, because the file was getting pretty large, I refactored it to move
the type conversion code out into its own file.

Author: Michael Armbrust <michael@databricks.com>

Closes #1552 from marmbrus/fixTempUdfs and squashes the following commits:

b695904 [Michael Armbrust] Make add jar execute with Hive.  Ship the whole function class
name since sometimes we cannot look up temporary functions on the workers.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1871574a
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1871574a
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1871574a

Branch: refs/heads/master
Commit: 1871574a240e6f28adeb6bc8accc98c851cafae5
Parents: e060d3e
Author: Michael Armbrust <michael@databricks.com>
Authored: Wed Jul 23 16:26:55 2014 -0700
Committer: Michael Armbrust <michael@databricks.com>
Committed: Wed Jul 23 16:26:55 2014 -0700

----------------------------------------------------------------------
 .../apache/spark/sql/hive/HiveInspectors.scala  | 230 ++++++++++++++++
 .../org/apache/spark/sql/hive/HiveQl.scala      |   4 +-
 .../org/apache/spark/sql/hive/hiveUdfs.scala    | 262 +++----------------
 3 files changed, 261 insertions(+), 235 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/1871574a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
new file mode 100644
index 0000000..ad7dc0e
--- /dev/null
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import org.apache.hadoop.hive.common.`type`.HiveDecimal
+import org.apache.hadoop.hive.serde2.objectinspector._
+import org.apache.hadoop.hive.serde2.objectinspector.primitive._
+import org.apache.hadoop.hive.serde2.{io => hiveIo}
+import org.apache.hadoop.{io => hadoopIo}
+
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.types
+import org.apache.spark.sql.catalyst.types._
+
+/* Implicit conversions */
+import scala.collection.JavaConversions._
+
+private[hive] trait HiveInspectors {
+
+  def javaClassToDataType(clz: Class[_]): DataType = clz match {
+    // writable
+    case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType
+    case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType
+    case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType
+    case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType
+    case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType
+    case c: Class[_] if c == classOf[hiveIo.TimestampWritable] => TimestampType
+    case c: Class[_] if c == classOf[hadoopIo.Text] => StringType
+    case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType
+    case c: Class[_] if c == classOf[hadoopIo.LongWritable] => LongType
+    case c: Class[_] if c == classOf[hadoopIo.FloatWritable] => FloatType
+    case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType
+    case c: Class[_] if c == classOf[hadoopIo.BytesWritable] => BinaryType
+
+    // java class
+    case c: Class[_] if c == classOf[java.lang.String] => StringType
+    case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType
+    case c: Class[_] if c == classOf[HiveDecimal] => DecimalType
+    case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType
+    case c: Class[_] if c == classOf[Array[Byte]] => BinaryType
+    case c: Class[_] if c == classOf[java.lang.Short] => ShortType
+    case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
+    case c: Class[_] if c == classOf[java.lang.Long] => LongType
+    case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
+    case c: Class[_] if c == classOf[java.lang.Byte] => ByteType
+    case c: Class[_] if c == classOf[java.lang.Float] => FloatType
+    case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType
+
+    // primitive type
+    case c: Class[_] if c == java.lang.Short.TYPE => ShortType
+    case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType
+    case c: Class[_] if c == java.lang.Long.TYPE => LongType
+    case c: Class[_] if c == java.lang.Double.TYPE => DoubleType
+    case c: Class[_] if c == java.lang.Byte.TYPE => ByteType
+    case c: Class[_] if c == java.lang.Float.TYPE => FloatType
+    case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType
+
+    case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType))
+  }
+
+  /** Converts hive types to native catalyst types. */
+  def unwrap(a: Any): Any = a match {
+    case null => null
+    case i: hadoopIo.IntWritable => i.get
+    case t: hadoopIo.Text => t.toString
+    case l: hadoopIo.LongWritable => l.get
+    case d: hadoopIo.DoubleWritable => d.get
+    case d: hiveIo.DoubleWritable => d.get
+    case s: hiveIo.ShortWritable => s.get
+    case b: hadoopIo.BooleanWritable => b.get
+    case b: hiveIo.ByteWritable => b.get
+    case b: hadoopIo.FloatWritable => b.get
+    case b: hadoopIo.BytesWritable => {
+      val bytes = new Array[Byte](b.getLength)
+      System.arraycopy(b.getBytes(), 0, bytes, 0, b.getLength)
+      bytes
+    }
+    case t: hiveIo.TimestampWritable => t.getTimestamp
+    case b: hiveIo.HiveDecimalWritable => BigDecimal(b.getHiveDecimal().bigDecimalValue())
+    case list: java.util.List[_] => list.map(unwrap)
+    case map: java.util.Map[_,_] => map.map { case (k, v) => (unwrap(k), unwrap(v)) }.toMap
+    case array: Array[_] => array.map(unwrap).toSeq
+    case p: java.lang.Short => p
+    case p: java.lang.Long => p
+    case p: java.lang.Float => p
+    case p: java.lang.Integer => p
+    case p: java.lang.Double => p
+    case p: java.lang.Byte => p
+    case p: java.lang.Boolean => p
+    case str: String => str
+    case p: java.math.BigDecimal => p
+    case p: Array[Byte] => p
+    case p: java.sql.Timestamp => p
+  }
+
+  def unwrapData(data: Any, oi: ObjectInspector): Any = oi match {
+    case hvoi: HiveVarcharObjectInspector =>
+      if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue
+    case hdoi: HiveDecimalObjectInspector =>
+      if (data == null) null else BigDecimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue())
+    case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data)
+    case li: ListObjectInspector =>
+      Option(li.getList(data))
+        .map(_.map(unwrapData(_, li.getListElementObjectInspector)).toSeq)
+        .orNull
+    case mi: MapObjectInspector =>
+      Option(mi.getMap(data)).map(
+        _.map {
+          case (k,v) =>
+            (unwrapData(k, mi.getMapKeyObjectInspector),
+              unwrapData(v, mi.getMapValueObjectInspector))
+        }.toMap).orNull
+    case si: StructObjectInspector =>
+      val allRefs = si.getAllStructFieldRefs
+      new GenericRow(
+        allRefs.map(r =>
+          unwrapData(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray)
+  }
+
+  /** Converts native catalyst types to the types expected by Hive */
+  def wrap(a: Any): AnyRef = a match {
+    case s: String => new hadoopIo.Text(s) // TODO why should be Text?
+    case i: Int => i: java.lang.Integer
+    case b: Boolean => b: java.lang.Boolean
+    case f: Float => f: java.lang.Float
+    case d: Double => d: java.lang.Double
+    case l: Long => l: java.lang.Long
+    case l: Short => l: java.lang.Short
+    case l: Byte => l: java.lang.Byte
+    case b: BigDecimal => b.bigDecimal
+    case b: Array[Byte] => b
+    case t: java.sql.Timestamp => t
+    case s: Seq[_] => seqAsJavaList(s.map(wrap))
+    case m: Map[_,_] =>
+      mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) })
+    case null => null
+  }
+
+  def toInspector(dataType: DataType): ObjectInspector = dataType match {
+    case ArrayType(tpe) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe))
+    case MapType(keyType, valueType) =>
+      ObjectInspectorFactory.getStandardMapObjectInspector(
+        toInspector(keyType), toInspector(valueType))
+    case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector
+    case IntegerType => PrimitiveObjectInspectorFactory.javaIntObjectInspector
+    case DoubleType => PrimitiveObjectInspectorFactory.javaDoubleObjectInspector
+    case BooleanType => PrimitiveObjectInspectorFactory.javaBooleanObjectInspector
+    case LongType => PrimitiveObjectInspectorFactory.javaLongObjectInspector
+    case FloatType => PrimitiveObjectInspectorFactory.javaFloatObjectInspector
+    case ShortType => PrimitiveObjectInspectorFactory.javaShortObjectInspector
+    case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector
+    case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector
+    case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector
+    case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector
+    case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector
+    case StructType(fields) =>
+      ObjectInspectorFactory.getStandardStructObjectInspector(
+        fields.map(f => f.name), fields.map(f => toInspector(f.dataType)))
+  }
+
+  def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match {
+    case s: StructObjectInspector =>
+      StructType(s.getAllStructFieldRefs.map(f => {
+        types.StructField(
+          f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true)
+      }))
+    case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector))
+    case m: MapObjectInspector =>
+      MapType(
+        inspectorToDataType(m.getMapKeyObjectInspector),
+        inspectorToDataType(m.getMapValueObjectInspector))
+    case _: WritableStringObjectInspector => StringType
+    case _: JavaStringObjectInspector => StringType
+    case _: WritableIntObjectInspector => IntegerType
+    case _: JavaIntObjectInspector => IntegerType
+    case _: WritableDoubleObjectInspector => DoubleType
+    case _: JavaDoubleObjectInspector => DoubleType
+    case _: WritableBooleanObjectInspector => BooleanType
+    case _: JavaBooleanObjectInspector => BooleanType
+    case _: WritableLongObjectInspector => LongType
+    case _: JavaLongObjectInspector => LongType
+    case _: WritableShortObjectInspector => ShortType
+    case _: JavaShortObjectInspector => ShortType
+    case _: WritableByteObjectInspector => ByteType
+    case _: JavaByteObjectInspector => ByteType
+    case _: WritableFloatObjectInspector => FloatType
+    case _: JavaFloatObjectInspector => FloatType
+    case _: WritableBinaryObjectInspector => BinaryType
+    case _: JavaBinaryObjectInspector => BinaryType
+    case _: WritableHiveDecimalObjectInspector => DecimalType
+    case _: JavaHiveDecimalObjectInspector => DecimalType
+    case _: WritableTimestampObjectInspector => TimestampType
+    case _: JavaTimestampObjectInspector => TimestampType
+  }
+
+  implicit class typeInfoConversions(dt: DataType) {
+    import org.apache.hadoop.hive.serde2.typeinfo._
+    import TypeInfoFactory._
+
+    def toTypeInfo: TypeInfo = dt match {
+      case BinaryType => binaryTypeInfo
+      case BooleanType => booleanTypeInfo
+      case ByteType => byteTypeInfo
+      case DoubleType => doubleTypeInfo
+      case FloatType => floatTypeInfo
+      case IntegerType => intTypeInfo
+      case LongType => longTypeInfo
+      case ShortType => shortTypeInfo
+      case StringType => stringTypeInfo
+      case DecimalType => decimalTypeInfo
+      case TimestampType => timestampTypeInfo
+      case NullType => voidTypeInfo
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/1871574a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index 53480a5..c4ca9f3 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -42,8 +42,6 @@ private[hive] case class ShellCommand(cmd: String) extends Command
 
 private[hive] case class SourceCommand(filePath: String) extends Command
 
-private[hive] case class AddJar(jarPath: String) extends Command
-
 private[hive] case class AddFile(filePath: String) extends Command
 
 /** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */
@@ -229,7 +227,7 @@ private[hive] object HiveQl {
       } else if (sql.trim.toLowerCase.startsWith("uncache table")) {
         CacheCommand(sql.trim.drop(14).trim, false)
       } else if (sql.trim.toLowerCase.startsWith("add jar")) {
-        AddJar(sql.trim.drop(8))
+        NativeCommand(sql)
       } else if (sql.trim.toLowerCase.startsWith("add file")) {
         AddFile(sql.trim.drop(9))
       } else if (sql.trim.toLowerCase.startsWith("dfs")) {

http://git-wip-us.apache.org/repos/asf/spark/blob/1871574a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
----------------------------------------------------------------------
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index fc33c5b..057eb60 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -24,22 +24,19 @@ import org.apache.hadoop.hive.ql.exec.UDF
 import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry}
 import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType}
 import org.apache.hadoop.hive.ql.udf.generic._
-import org.apache.hadoop.hive.serde2.objectinspector._
-import org.apache.hadoop.hive.serde2.objectinspector.primitive._
-import org.apache.hadoop.hive.serde2.{io => hiveIo}
-import org.apache.hadoop.{io => hadoopIo}
 
 import org.apache.spark.sql.Logging
 import org.apache.spark.sql.catalyst.analysis
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.types
 import org.apache.spark.sql.catalyst.types._
+import org.apache.spark.util.Utils.getContextOrSparkClassLoader
 
 /* Implicit conversions */
 import scala.collection.JavaConversions._
 
-private[hive] object HiveFunctionRegistry
-  extends analysis.FunctionRegistry with HiveFunctionFactory with HiveInspectors {
+private[hive] object HiveFunctionRegistry extends analysis.FunctionRegistry with HiveInspectors {
+
+  def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name)
 
   def lookupFunction(name: String, children: Seq[Expression]): Expression = {
     // We only look it up to see if it exists, but do not include it in the HiveUDF since it is
@@ -47,111 +44,37 @@ private[hive] object HiveFunctionRegistry
     val functionInfo: FunctionInfo = Option(FunctionRegistry.getFunctionInfo(name)).getOrElse(
       sys.error(s"Couldn't find function $name"))
 
+    val functionClassName = functionInfo.getFunctionClass.getName()
+
     if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
-      val function = createFunction[UDF](name)
+      val function = functionInfo.getFunctionClass.newInstance().asInstanceOf[UDF]
       val method = function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo))
 
       lazy val expectedDataTypes = method.getParameterTypes.map(javaClassToDataType)
 
       HiveSimpleUdf(
-        name,
+        functionClassName,
         children.zip(expectedDataTypes).map { case (e, t) => Cast(e, t) }
       )
     } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
-      HiveGenericUdf(name, children)
+      HiveGenericUdf(functionClassName, children)
     } else if (
          classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) {
-      HiveGenericUdaf(name, children)
+      HiveGenericUdaf(functionClassName, children)
 
     } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
-      HiveGenericUdtf(name, Nil, children)
+      HiveGenericUdtf(functionClassName, Nil, children)
     } else {
       sys.error(s"No handler for udf ${functionInfo.getFunctionClass}")
     }
   }
-
-  def javaClassToDataType(clz: Class[_]): DataType = clz match {
-    // writable
-    case c: Class[_] if c == classOf[hadoopIo.DoubleWritable] => DoubleType
-    case c: Class[_] if c == classOf[hiveIo.DoubleWritable] => DoubleType
-    case c: Class[_] if c == classOf[hiveIo.HiveDecimalWritable] => DecimalType
-    case c: Class[_] if c == classOf[hiveIo.ByteWritable] => ByteType
-    case c: Class[_] if c == classOf[hiveIo.ShortWritable] => ShortType
-    case c: Class[_] if c == classOf[hiveIo.TimestampWritable] => TimestampType
-    case c: Class[_] if c == classOf[hadoopIo.Text] => StringType
-    case c: Class[_] if c == classOf[hadoopIo.IntWritable] => IntegerType
-    case c: Class[_] if c == classOf[hadoopIo.LongWritable] => LongType
-    case c: Class[_] if c == classOf[hadoopIo.FloatWritable] => FloatType
-    case c: Class[_] if c == classOf[hadoopIo.BooleanWritable] => BooleanType
-    case c: Class[_] if c == classOf[hadoopIo.BytesWritable] => BinaryType
-
-    // java class
-    case c: Class[_] if c == classOf[java.lang.String] => StringType
-    case c: Class[_] if c == classOf[java.sql.Timestamp] => TimestampType
-    case c: Class[_] if c == classOf[HiveDecimal] => DecimalType
-    case c: Class[_] if c == classOf[java.math.BigDecimal] => DecimalType
-    case c: Class[_] if c == classOf[Array[Byte]] => BinaryType
-    case c: Class[_] if c == classOf[java.lang.Short] => ShortType
-    case c: Class[_] if c == classOf[java.lang.Integer] => IntegerType
-    case c: Class[_] if c == classOf[java.lang.Long] => LongType
-    case c: Class[_] if c == classOf[java.lang.Double] => DoubleType
-    case c: Class[_] if c == classOf[java.lang.Byte] => ByteType
-    case c: Class[_] if c == classOf[java.lang.Float] => FloatType
-    case c: Class[_] if c == classOf[java.lang.Boolean] => BooleanType
-
-    // primitive type
-    case c: Class[_] if c == java.lang.Short.TYPE => ShortType
-    case c: Class[_] if c == java.lang.Integer.TYPE => IntegerType
-    case c: Class[_] if c == java.lang.Long.TYPE => LongType
-    case c: Class[_] if c == java.lang.Double.TYPE => DoubleType
-    case c: Class[_] if c == java.lang.Byte.TYPE => ByteType
-    case c: Class[_] if c == java.lang.Float.TYPE => FloatType
-    case c: Class[_] if c == java.lang.Boolean.TYPE => BooleanType
-
-    case c: Class[_] if c.isArray => ArrayType(javaClassToDataType(c.getComponentType))
-  }
 }
 
 private[hive] trait HiveFunctionFactory {
-  def getFunctionInfo(name: String) = FunctionRegistry.getFunctionInfo(name)
-  def getFunctionClass(name: String) = getFunctionInfo(name).getFunctionClass
-  def createFunction[UDFType](name: String) =
-    getFunctionClass(name).newInstance.asInstanceOf[UDFType]
-
-  /** Converts hive types to native catalyst types. */
-  def unwrap(a: Any): Any = a match {
-    case null => null
-    case i: hadoopIo.IntWritable => i.get
-    case t: hadoopIo.Text => t.toString
-    case l: hadoopIo.LongWritable => l.get
-    case d: hadoopIo.DoubleWritable => d.get
-    case d: hiveIo.DoubleWritable => d.get
-    case s: hiveIo.ShortWritable => s.get
-    case b: hadoopIo.BooleanWritable => b.get
-    case b: hiveIo.ByteWritable => b.get
-    case b: hadoopIo.FloatWritable => b.get
-    case b: hadoopIo.BytesWritable => {
-      val bytes = new Array[Byte](b.getLength)
-      System.arraycopy(b.getBytes(), 0, bytes, 0, b.getLength)
-      bytes
-    }
-    case t: hiveIo.TimestampWritable => t.getTimestamp
-    case b: hiveIo.HiveDecimalWritable => BigDecimal(b.getHiveDecimal().bigDecimalValue())
-    case list: java.util.List[_] => list.map(unwrap)
-    case map: java.util.Map[_,_] => map.map { case (k, v) => (unwrap(k), unwrap(v)) }.toMap
-    case array: Array[_] => array.map(unwrap).toSeq
-    case p: java.lang.Short => p
-    case p: java.lang.Long => p
-    case p: java.lang.Float => p
-    case p: java.lang.Integer => p
-    case p: java.lang.Double => p
-    case p: java.lang.Byte => p
-    case p: java.lang.Boolean => p
-    case str: String => str
-    case p: java.math.BigDecimal => p
-    case p: Array[Byte] => p
-    case p: java.sql.Timestamp => p
-  }
+  val functionClassName: String
+
+  def createFunction[UDFType]() =
+    getContextOrSparkClassLoader.loadClass(functionClassName).newInstance.asInstanceOf[UDFType]
 }
 
 private[hive] abstract class HiveUdf extends Expression with Logging with HiveFunctionFactory {
@@ -160,19 +83,17 @@ private[hive] abstract class HiveUdf extends Expression with Logging with HiveFu
   type UDFType
   type EvaluatedType = Any
 
-  val name: String
-
   def nullable = true
   def references = children.flatMap(_.references).toSet
 
-  // FunctionInfo is not serializable so we must look it up here again.
-  lazy val functionInfo = getFunctionInfo(name)
-  lazy val function = createFunction[UDFType](name)
+  lazy val function = createFunction[UDFType]()
 
-  override def toString = s"$nodeName#${functionInfo.getDisplayName}(${children.mkString(",")})"
+  override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
 }
 
-private[hive] case class HiveSimpleUdf(name: String, children: Seq[Expression]) extends HiveUdf {
+private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[Expression])
+  extends HiveUdf {
+
   import org.apache.spark.sql.hive.HiveFunctionRegistry._
   type UDFType = UDF
 
@@ -226,7 +147,7 @@ private[hive] case class HiveSimpleUdf(name: String, children: Seq[Expression])
   }
 }
 
-private[hive] case class HiveGenericUdf(name: String, children: Seq[Expression])
+private[hive] case class HiveGenericUdf(functionClassName: String, children: Seq[Expression])
   extends HiveUdf with HiveInspectors {
 
   import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._
@@ -277,131 +198,8 @@ private[hive] case class HiveGenericUdf(name: String, children: Seq[Expression])
   }
 }
 
-private[hive] trait HiveInspectors {
-
-  def unwrapData(data: Any, oi: ObjectInspector): Any = oi match {
-    case hvoi: HiveVarcharObjectInspector => 
-      if (data == null) null else hvoi.getPrimitiveJavaObject(data).getValue
-    case hdoi: HiveDecimalObjectInspector => 
-      if (data == null) null else BigDecimal(hdoi.getPrimitiveJavaObject(data).bigDecimalValue())
-    case pi: PrimitiveObjectInspector => pi.getPrimitiveJavaObject(data)
-    case li: ListObjectInspector =>
-      Option(li.getList(data))
-        .map(_.map(unwrapData(_, li.getListElementObjectInspector)).toSeq)
-        .orNull
-    case mi: MapObjectInspector =>
-      Option(mi.getMap(data)).map(
-        _.map {
-          case (k,v) =>
-            (unwrapData(k, mi.getMapKeyObjectInspector),
-              unwrapData(v, mi.getMapValueObjectInspector))
-        }.toMap).orNull
-    case si: StructObjectInspector =>
-      val allRefs = si.getAllStructFieldRefs
-      new GenericRow(
-        allRefs.map(r =>
-          unwrapData(si.getStructFieldData(data,r), r.getFieldObjectInspector)).toArray)
-  }
-
-  /** Converts native catalyst types to the types expected by Hive */
-  def wrap(a: Any): AnyRef = a match {
-    case s: String => new hadoopIo.Text(s) // TODO why should be Text?
-    case i: Int => i: java.lang.Integer
-    case b: Boolean => b: java.lang.Boolean
-    case f: Float => f: java.lang.Float
-    case d: Double => d: java.lang.Double
-    case l: Long => l: java.lang.Long
-    case l: Short => l: java.lang.Short
-    case l: Byte => l: java.lang.Byte
-    case b: BigDecimal => b.bigDecimal
-    case b: Array[Byte] => b
-    case t: java.sql.Timestamp => t
-    case s: Seq[_] => seqAsJavaList(s.map(wrap))
-    case m: Map[_,_] =>
-      mapAsJavaMap(m.map { case (k, v) => wrap(k) -> wrap(v) })
-    case null => null
-  }
-
-  def toInspector(dataType: DataType): ObjectInspector = dataType match {
-    case ArrayType(tpe) => ObjectInspectorFactory.getStandardListObjectInspector(toInspector(tpe))
-    case MapType(keyType, valueType) =>
-      ObjectInspectorFactory.getStandardMapObjectInspector(
-        toInspector(keyType), toInspector(valueType))
-    case StringType => PrimitiveObjectInspectorFactory.javaStringObjectInspector
-    case IntegerType => PrimitiveObjectInspectorFactory.javaIntObjectInspector
-    case DoubleType => PrimitiveObjectInspectorFactory.javaDoubleObjectInspector
-    case BooleanType => PrimitiveObjectInspectorFactory.javaBooleanObjectInspector
-    case LongType => PrimitiveObjectInspectorFactory.javaLongObjectInspector
-    case FloatType => PrimitiveObjectInspectorFactory.javaFloatObjectInspector
-    case ShortType => PrimitiveObjectInspectorFactory.javaShortObjectInspector
-    case ByteType => PrimitiveObjectInspectorFactory.javaByteObjectInspector
-    case NullType => PrimitiveObjectInspectorFactory.javaVoidObjectInspector
-    case BinaryType => PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector
-    case TimestampType => PrimitiveObjectInspectorFactory.javaTimestampObjectInspector
-    case DecimalType => PrimitiveObjectInspectorFactory.javaHiveDecimalObjectInspector
-    case StructType(fields) =>
-      ObjectInspectorFactory.getStandardStructObjectInspector(
-        fields.map(f => f.name), fields.map(f => toInspector(f.dataType)))
-  }
-
-  def inspectorToDataType(inspector: ObjectInspector): DataType = inspector match {
-    case s: StructObjectInspector =>
-      StructType(s.getAllStructFieldRefs.map(f => {
-        types.StructField(
-          f.getFieldName, inspectorToDataType(f.getFieldObjectInspector), nullable = true)
-      }))
-    case l: ListObjectInspector => ArrayType(inspectorToDataType(l.getListElementObjectInspector))
-    case m: MapObjectInspector =>
-      MapType(
-        inspectorToDataType(m.getMapKeyObjectInspector),
-        inspectorToDataType(m.getMapValueObjectInspector))
-    case _: WritableStringObjectInspector => StringType
-    case _: JavaStringObjectInspector => StringType
-    case _: WritableIntObjectInspector => IntegerType
-    case _: JavaIntObjectInspector => IntegerType
-    case _: WritableDoubleObjectInspector => DoubleType
-    case _: JavaDoubleObjectInspector => DoubleType
-    case _: WritableBooleanObjectInspector => BooleanType
-    case _: JavaBooleanObjectInspector => BooleanType
-    case _: WritableLongObjectInspector => LongType
-    case _: JavaLongObjectInspector => LongType
-    case _: WritableShortObjectInspector => ShortType
-    case _: JavaShortObjectInspector => ShortType
-    case _: WritableByteObjectInspector => ByteType
-    case _: JavaByteObjectInspector => ByteType
-    case _: WritableFloatObjectInspector => FloatType
-    case _: JavaFloatObjectInspector => FloatType
-    case _: WritableBinaryObjectInspector => BinaryType
-    case _: JavaBinaryObjectInspector => BinaryType
-    case _: WritableHiveDecimalObjectInspector => DecimalType
-    case _: JavaHiveDecimalObjectInspector => DecimalType
-    case _: WritableTimestampObjectInspector => TimestampType
-    case _: JavaTimestampObjectInspector => TimestampType
-  }
-
-  implicit class typeInfoConversions(dt: DataType) {
-    import org.apache.hadoop.hive.serde2.typeinfo._
-    import TypeInfoFactory._
-
-    def toTypeInfo: TypeInfo = dt match {
-      case BinaryType => binaryTypeInfo
-      case BooleanType => booleanTypeInfo
-      case ByteType => byteTypeInfo
-      case DoubleType => doubleTypeInfo
-      case FloatType => floatTypeInfo
-      case IntegerType => intTypeInfo
-      case LongType => longTypeInfo
-      case ShortType => shortTypeInfo
-      case StringType => stringTypeInfo
-      case DecimalType => decimalTypeInfo
-      case TimestampType => timestampTypeInfo
-      case NullType => voidTypeInfo
-    }
-  }
-}
-
 private[hive] case class HiveGenericUdaf(
-    name: String,
+    functionClassName: String,
     children: Seq[Expression]) extends AggregateExpression
   with HiveInspectors
   with HiveFunctionFactory {
@@ -409,7 +207,7 @@ private[hive] case class HiveGenericUdaf(
   type UDFType = AbstractGenericUDAFResolver
 
   @transient
-  protected lazy val resolver: AbstractGenericUDAFResolver = createFunction(name)
+  protected lazy val resolver: AbstractGenericUDAFResolver = createFunction()
 
   @transient
   protected lazy val objectInspector  = {
@@ -426,9 +224,9 @@ private[hive] case class HiveGenericUdaf(
 
   def references: Set[Attribute] = children.map(_.references).flatten.toSet
 
-  override def toString = s"$nodeName#$name(${children.mkString(",")})"
+  override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
 
-  def newInstance() = new HiveUdafFunction(name, children, this)
+  def newInstance() = new HiveUdafFunction(functionClassName, children, this)
 }
 
 /**
@@ -443,7 +241,7 @@ private[hive] case class HiveGenericUdaf(
  * user defined aggregations, which have clean semantics even in a partitioned execution.
  */
 private[hive] case class HiveGenericUdtf(
-    name: String,
+    functionClassName: String,
     aliasNames: Seq[String],
     children: Seq[Expression])
   extends Generator with HiveInspectors with HiveFunctionFactory {
@@ -451,7 +249,7 @@ private[hive] case class HiveGenericUdtf(
   override def references = children.flatMap(_.references).toSet
 
   @transient
-  protected lazy val function: GenericUDTF = createFunction(name)
+  protected lazy val function: GenericUDTF = createFunction()
 
   protected lazy val inputInspectors = children.map(_.dataType).map(toInspector)
 
@@ -506,11 +304,11 @@ private[hive] case class HiveGenericUdtf(
     }
   }
 
-  override def toString = s"$nodeName#$name(${children.mkString(",")})"
+  override def toString = s"$nodeName#$functionClassName(${children.mkString(",")})"
 }
 
 private[hive] case class HiveUdafFunction(
-    functionName: String,
+    functionClassName: String,
     exprs: Seq[Expression],
     base: AggregateExpression)
   extends AggregateFunction
@@ -519,7 +317,7 @@ private[hive] case class HiveUdafFunction(
 
   def this() = this(null, null, null)
 
-  private val resolver = createFunction[AbstractGenericUDAFResolver](functionName)
+  private val resolver = createFunction[AbstractGenericUDAFResolver]()
 
   private val inspectors = exprs.map(_.dataType).map(toInspector).toArray
 


Mime
View raw message