spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jkbrad...@apache.org
Subject spark git commit: [SPARK-19899][ML] Replace featuresCol with itemsCol in ml.fpm.FPGrowth
Date Mon, 20 Mar 2017 17:58:33 GMT
Repository: spark
Updated Branches:
  refs/heads/master fc7554599 -> bec6b16c1


[SPARK-19899][ML] Replace featuresCol with itemsCol in ml.fpm.FPGrowth

## What changes were proposed in this pull request?

Replaces `featuresCol` `Param` with `itemsCol`. See [SPARK-19899](https://issues.apache.org/jira/browse/SPARK-19899).

## How was this patch tested?

Manual tests. Existing unit tests.

Author: zero323 <zero323@users.noreply.github.com>

Closes #17321 from zero323/SPARK-19899.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/bec6b16c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/bec6b16c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/bec6b16c

Branch: refs/heads/master
Commit: bec6b16c1900fe93def89cc5eb51cbef498196cb
Parents: fc75545
Author: zero323 <zero323@users.noreply.github.com>
Authored: Mon Mar 20 10:58:30 2017 -0700
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Mon Mar 20 10:58:30 2017 -0700

----------------------------------------------------------------------
 .../org/apache/spark/ml/fpm/FPGrowth.scala      | 35 ++++++++++++++------
 .../org/apache/spark/ml/fpm/FPGrowthSuite.scala | 14 ++++----
 2 files changed, 31 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/bec6b16c/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index fa39dd9..e2bc270 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -25,7 +25,7 @@ import org.apache.hadoop.fs.Path
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.{Estimator, Model}
 import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasPredictionCol}
+import org.apache.spark.ml.param.shared.HasPredictionCol
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.fpm.{AssociationRules => MLlibAssociationRules,
   FPGrowth => MLlibFPGrowth}
@@ -37,7 +37,20 @@ import org.apache.spark.sql.types._
 /**
  * Common params for FPGrowth and FPGrowthModel
  */
-private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPredictionCol {
+private[fpm] trait FPGrowthParams extends Params with HasPredictionCol {
+
+  /**
+   * Items column name.
+   * Default: "items"
+   * @group param
+   */
+  @Since("2.2.0")
+  val itemsCol: Param[String] = new Param[String](this, "itemsCol", "items column name")
+  setDefault(itemsCol -> "items")
+
+  /** @group getParam */
+  @Since("2.2.0")
+  def getItemsCol: String = $(itemsCol)
 
   /**
    * Minimal support level of the frequent pattern. [0.0, 1.0]. Any pattern that appears
@@ -91,10 +104,10 @@ private[fpm] trait FPGrowthParams extends Params with HasFeaturesCol with HasPredictionCol {
    */
   @Since("2.2.0")
   protected def validateAndTransformSchema(schema: StructType): StructType = {
-    val inputType = schema($(featuresCol)).dataType
+    val inputType = schema($(itemsCol)).dataType
     require(inputType.isInstanceOf[ArrayType],
       s"The input column must be ArrayType, but got $inputType.")
-    SchemaUtils.appendColumn(schema, $(predictionCol), schema($(featuresCol)).dataType)
+    SchemaUtils.appendColumn(schema, $(predictionCol), schema($(itemsCol)).dataType)
   }
 }
 
@@ -133,7 +146,7 @@ class FPGrowth @Since("2.2.0") (
 
   /** @group setParam */
   @Since("2.2.0")
-  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+  def setItemsCol(value: String): this.type = set(itemsCol, value)
 
   /** @group setParam */
   @Since("2.2.0")
@@ -146,8 +159,8 @@ class FPGrowth @Since("2.2.0") (
   }
 
   private def genericFit[T: ClassTag](dataset: Dataset[_]): FPGrowthModel = {
-    val data = dataset.select($(featuresCol))
-    val items = data.where(col($(featuresCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray)
+    val data = dataset.select($(itemsCol))
+    val items = data.where(col($(itemsCol)).isNotNull).rdd.map(r => r.getSeq[T](0).toArray)
     val mllibFP = new MLlibFPGrowth().setMinSupport($(minSupport))
     if (isSet(numPartitions)) {
       mllibFP.setNumPartitions($(numPartitions))
@@ -156,7 +169,7 @@ class FPGrowth @Since("2.2.0") (
     val rows = parentModel.freqItemsets.map(f => Row(f.items, f.freq))
 
     val schema = StructType(Seq(
-      StructField("items", dataset.schema($(featuresCol)).dataType, nullable = false),
+      StructField("items", dataset.schema($(itemsCol)).dataType, nullable = false),
       StructField("freq", LongType, nullable = false)))
     val frequentItems = dataset.sparkSession.createDataFrame(rows, schema)
     copyValues(new FPGrowthModel(uid, frequentItems)).setParent(this)
@@ -198,7 +211,7 @@ class FPGrowthModel private[ml] (
 
   /** @group setParam */
   @Since("2.2.0")
-  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+  def setItemsCol(value: String): this.type = set(itemsCol, value)
 
   /** @group setParam */
   @Since("2.2.0")
@@ -235,7 +248,7 @@ class FPGrowthModel private[ml] (
       .collect().asInstanceOf[Array[(Seq[Any], Seq[Any])]]
     val brRules = dataset.sparkSession.sparkContext.broadcast(rules)
 
-    val dt = dataset.schema($(featuresCol)).dataType
+    val dt = dataset.schema($(itemsCol)).dataType
     // For each rule, examine the input items and summarize the consequents
     val predictUDF = udf((items: Seq[_]) => {
       if (items != null) {
@@ -249,7 +262,7 @@ class FPGrowthModel private[ml] (
       } else {
         Seq.empty
       }}, dt)
-    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+    dataset.withColumn($(predictionCol), predictUDF(col($(itemsCol))))
   }
 
   @Since("2.2.0")

http://git-wip-us.apache.org/repos/asf/spark/blob/bec6b16c/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
index 910d4b0..4603a61 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
@@ -34,7 +34,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
 
   test("FPGrowth fit and transform with different data types") {
     Array(IntegerType, StringType, ShortType, LongType, ByteType).foreach { dt =>
-      val data = dataset.withColumn("features", col("features").cast(ArrayType(dt)))
+      val data = dataset.withColumn("items", col("items").cast(ArrayType(dt)))
       val model = new FPGrowth().setMinSupport(0.5).fit(data)
       val generatedRules = model.setMinConfidence(0.5).associationRules
       val expectedRules = spark.createDataFrame(Seq(
@@ -52,8 +52,8 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
         (0, Array("1", "2"), Array.emptyIntArray),
         (0, Array("1", "2"), Array.emptyIntArray),
         (0, Array("1", "3"), Array(2))
-      )).toDF("id", "features", "prediction")
-        .withColumn("features", col("features").cast(ArrayType(dt)))
+      )).toDF("id", "items", "prediction")
+        .withColumn("items", col("items").cast(ArrayType(dt)))
         .withColumn("prediction", col("prediction").cast(ArrayType(dt)))
       assert(expectedTransformed.collect().toSet.equals(
         transformed.collect().toSet))
@@ -79,7 +79,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
       (1, Array("1", "2", "3", "5")),
       (2, Array("1", "2", "3", "4")),
       (3, null.asInstanceOf[Array[String]])
-    )).toDF("id", "features")
+    )).toDF("id", "items")
     val model = new FPGrowth().setMinSupport(0.7).fit(dataset)
     val prediction = model.transform(df)
     assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty)
@@ -108,11 +108,11 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
     val dataset = spark.createDataFrame(Seq(
       Array("1", "3"),
       Array("2", "3")
-    ).map(Tuple1(_))).toDF("features")
+    ).map(Tuple1(_))).toDF("items")
     val model = new FPGrowth().fit(dataset)
 
     val prediction = model.transform(
-      spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("features")
+      spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items")
     ).first().getAs[Seq[String]]("prediction")
 
     assert(prediction === Seq("3"))
@@ -127,7 +127,7 @@ object FPGrowthSuite {
       (0, Array("1", "2")),
       (0, Array("1", "2")),
       (0, Array("1", "3"))
-    )).toDF("id", "features")
+    )).toDF("id", "items")
   }
 
   /**


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message