spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-5902] [ml] Made PipelineStage.transformSchema public instead of private to ml
Date Thu, 19 Feb 2015 20:46:40 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.3 55d91d92b -> 0c494cf9a


[SPARK-5902] [ml] Made PipelineStage.transformSchema public instead of private to ml

For users to implement their own PipelineStages, we need to make PipelineStage.transformSchema
public instead of private to ml.  This would be nice to include in Spark 1.3.

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #4682 from jkbradley/SPARK-5902 and squashes the following commits:

6f02357 [Joseph K. Bradley] Made transformSchema public
0e6d0a0 [Joseph K. Bradley] made implementations of transformSchema protected as well
fdaf26a [Joseph K. Bradley] Made PipelineStage.transformSchema protected instead of private[ml]

(cherry picked from commit a5fed34355b403188ad50b567ab62ee54597b493)
Signed-off-by: Xiangrui Meng <meng@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0c494cf9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0c494cf9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0c494cf9

Branch: refs/heads/branch-1.3
Commit: 0c494cf9a3d2b717d86f53445b35f725afa89ac8
Parents: 55d91d9
Author: Joseph K. Bradley <joseph@databricks.com>
Authored: Thu Feb 19 12:46:27 2015 -0800
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Thu Feb 19 12:46:37 2015 -0800

----------------------------------------------------------------------
 .../main/scala/org/apache/spark/ml/Pipeline.scala   | 16 ++++++++++++----
 .../apache/spark/ml/feature/StandardScaler.scala    |  4 ++--
 .../apache/spark/ml/impl/estimator/Predictor.scala  |  4 ++--
 .../org/apache/spark/ml/recommendation/ALS.scala    |  4 ++--
 .../org/apache/spark/ml/tuning/CrossValidator.scala |  4 ++--
 5 files changed, 20 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0c494cf9/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 5607ed2..5bbcd2e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml
 import scala.collection.mutable.ListBuffer
 
 import org.apache.spark.Logging
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
 import org.apache.spark.ml.param.{Param, ParamMap}
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
@@ -33,9 +33,17 @@ import org.apache.spark.sql.types.StructType
 abstract class PipelineStage extends Serializable with Logging {
 
   /**
+   * :: DeveloperApi ::
+   *
    * Derives the output schema from the input schema and parameters.
+   * The schema describes the columns and types of the data.
+   *
+   * @param schema  Input schema to this stage
+   * @param paramMap  Parameters passed to this stage
+   * @return  Output schema from this stage
    */
-  private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType
+  @DeveloperApi
+  def transformSchema(schema: StructType, paramMap: ParamMap): StructType
 
   /**
    * Derives the output schema from the input schema and parameters, optionally with logging.
@@ -126,7 +134,7 @@ class Pipeline extends Estimator[PipelineModel] {
     new PipelineModel(this, map, transformers.toArray)
   }
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     val map = this.paramMap ++ paramMap
     val theStages = map(stages)
     require(theStages.toSet.size == theStages.size,
@@ -171,7 +179,7 @@ class PipelineModel private[ml] (
     stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map))
   }
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
     val map = (fittingParamMap ++ this.paramMap) ++ paramMap
     stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map))

http://git-wip-us.apache.org/repos/asf/spark/blob/0c494cf9/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index ddbd648..1142aa4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -55,7 +55,7 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
     model
   }
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     val map = this.paramMap ++ paramMap
     val inputType = schema(map(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
@@ -91,7 +91,7 @@ class StandardScalerModel private[ml] (
     dataset.withColumn(map(outputCol), scale(col(map(inputCol))))
   }
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     val map = this.paramMap ++ paramMap
     val inputType = schema(map(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],

http://git-wip-us.apache.org/repos/asf/spark/blob/0c494cf9/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
index 7daeff9..dfb89cc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
@@ -132,7 +132,7 @@ private[spark] abstract class Predictor[
   @DeveloperApi
   protected def featuresDataType: DataType = new VectorUDT
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     validateAndTransformSchema(schema, paramMap, fitting = true, featuresDataType)
   }
 
@@ -184,7 +184,7 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel
   @DeveloperApi
   protected def featuresDataType: DataType = new VectorUDT
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     validateAndTransformSchema(schema, paramMap, fitting = false, featuresDataType)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0c494cf9/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 8d70e43..c2ec716 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -188,7 +188,7 @@ class ALSModel private[ml] (
       .select(dataset("*"), predict(users("features"), items("features")).as(map(predictionCol)))
   }
 
-  override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     validateAndTransformSchema(schema, paramMap)
   }
 }
@@ -292,7 +292,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
     model
   }
 
-  override private[ml] def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     validateAndTransformSchema(schema, paramMap)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/0c494cf9/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index b07a682..2eb1dac 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -129,7 +129,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
     cvModel
   }
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     val map = this.paramMap ++ paramMap
     map(estimator).transformSchema(schema, paramMap)
   }
@@ -150,7 +150,7 @@ class CrossValidatorModel private[ml] (
     bestModel.transform(dataset, paramMap)
   }
 
-  private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
     bestModel.transformSchema(schema, paramMap)
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message