spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From mln...@apache.org
Subject spark git commit: [SPARK-15668][ML] ml.feature: update check schema to avoid confusion when user use MLlib.vector as input type
Date Thu, 02 Jun 2016 23:37:05 GMT
Repository: spark
Updated Branches:
  refs/heads/master ccd298eb6 -> 5855e0057


[SPARK-15668][ML] ml.feature: update check schema to avoid confusion when user use MLlib.vector
as input type

## What changes were proposed in this pull request?

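ml.feature: update the input schema check in MaxAbsScaler, MinMaxScaler, PCA and StandardScaler to use `SchemaUtils.checkColumnType`, to avoid confusion when a user passes an MLlib.vector as the input type.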

## How was this patch tested?
Existing unit tests.

Author: Yuhao Yang <yuhao.yang@intel.com>

Closes #13411 from hhbyyh/schemaCheck.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5855e005
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5855e005
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5855e005

Branch: refs/heads/master
Commit: 5855e0057defeab8006ca4f7b0196003bbc9e899
Parents: ccd298e
Author: Yuhao Yang <yuhao.yang@intel.com>
Authored: Thu Jun 2 16:37:01 2016 -0700
Committer: Nick Pentreath <nickp@za.ibm.com>
Committed: Thu Jun 2 16:37:01 2016 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/feature/MaxAbsScaler.scala  |  4 +--
 .../apache/spark/ml/feature/MinMaxScaler.scala  |  4 +--
 .../scala/org/apache/spark/ml/feature/PCA.scala | 28 +++++++++-----------
 .../spark/ml/feature/StandardScaler.scala       | 25 ++++++++---------
 4 files changed, 25 insertions(+), 36 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5855e005/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
index 1b51599..7298a18 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
@@ -39,9 +39,7 @@ private[feature] trait MaxAbsScalerParams extends Params with HasInputCol
with H
 
    /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    val inputType = schema($(inputCol)).dataType
-    require(inputType.isInstanceOf[VectorUDT],
-      s"Input column ${$(inputCol)} must be a vector column")
+    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
     require(!schema.fieldNames.contains($(outputCol)),
       s"Output column ${$(outputCol)} already exists.")
     val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)

http://git-wip-us.apache.org/repos/asf/spark/blob/5855e005/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index d15f1b8..a27bed5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -63,9 +63,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol
with H
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
     require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})")
-    val inputType = schema($(inputCol)).dataType
-    require(inputType.isInstanceOf[VectorUDT],
-      s"Input column ${$(inputCol)} must be a vector column")
+    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
     require(!schema.fieldNames.contains($(outputCol)),
       s"Output column ${$(outputCol)} already exists.")
     val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)

http://git-wip-us.apache.org/repos/asf/spark/blob/5855e005/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index dbbaa5a..2f667af 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -49,8 +49,16 @@ private[feature] trait PCAParams extends Params with HasInputCol with HasOutputC
   /** @group getParam */
   def getK: Int = $(k)
 
-}
+  /** Validates and transforms the input schema. */
+  protected def validateAndTransformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
+    require(!schema.fieldNames.contains($(outputCol)),
+      s"Output column ${$(outputCol)} already exists.")
+    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
+    StructType(outputFields)
+  }
 
+}
 /**
  * :: Experimental ::
  * PCA trains a model to project vectors to a lower dimensional space of the top [[PCA!.k]]
@@ -86,13 +94,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    val inputType = schema($(inputCol)).dataType
-    require(inputType.isInstanceOf[VectorUDT],
-      s"Input column ${$(inputCol)} must be a vector column")
-    require(!schema.fieldNames.contains($(outputCol)),
-      s"Output column ${$(outputCol)} already exists.")
-    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
-    StructType(outputFields)
+    validateAndTransformSchema(schema)
   }
 
   override def copy(extra: ParamMap): PCA = defaultCopy(extra)
@@ -148,13 +150,7 @@ class PCAModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    val inputType = schema($(inputCol)).dataType
-    require(inputType.isInstanceOf[VectorUDT],
-      s"Input column ${$(inputCol)} must be a vector column")
-    require(!schema.fieldNames.contains($(outputCol)),
-      s"Output column ${$(outputCol)} already exists.")
-    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
-    StructType(outputFields)
+    validateAndTransformSchema(schema)
   }
 
   override def copy(extra: ParamMap): PCAModel = {
@@ -201,7 +197,7 @@ object PCAModel extends MLReadable[PCAModel] {
       val versionRegex = "([0-9]+)\\.([0-9]+).*".r
       val hasExplainedVariance = metadata.sparkVersion match {
         case versionRegex(major, minor) =>
-          (major.toInt >= 2 || (major.toInt == 1 && minor.toInt > 6))
+          major.toInt >= 2 || (major.toInt == 1 && minor.toInt > 6)
         case _ => false
       }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/5855e005/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 9d084b5..7cec369 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -62,6 +62,15 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol
with
   /** @group getParam */
   def getWithStd: Boolean = $(withStd)
 
+  /** Validates and transforms the input schema. */
+  protected def validateAndTransformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
+    require(!schema.fieldNames.contains($(outputCol)),
+      s"Output column ${$(outputCol)} already exists.")
+    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
+    StructType(outputFields)
+  }
+
   setDefault(withMean -> false, withStd -> true)
 }
 
@@ -105,13 +114,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    val inputType = schema($(inputCol)).dataType
-    require(inputType.isInstanceOf[VectorUDT],
-      s"Input column ${$(inputCol)} must be a vector column")
-    require(!schema.fieldNames.contains($(outputCol)),
-      s"Output column ${$(outputCol)} already exists.")
-    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
-    StructType(outputFields)
+    validateAndTransformSchema(schema)
   }
 
   override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra)
@@ -159,13 +162,7 @@ class StandardScalerModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
-    val inputType = schema($(inputCol)).dataType
-    require(inputType.isInstanceOf[VectorUDT],
-      s"Input column ${$(inputCol)} must be a vector column")
-    require(!schema.fieldNames.contains($(outputCol)),
-      s"Output column ${$(outputCol)} already exists.")
-    val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
-    StructType(outputFields)
+    validateAndTransformSchema(schema)
   }
 
   override def copy(extra: ParamMap): StandardScalerModel = {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message