spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-9981] [ML] Made labels public for StringIndexerModel
Date Fri, 14 Aug 2015 21:11:30 GMT
Repository: spark
Updated Branches:
  refs/heads/branch-1.5 59cdcc079 -> 0f4ccdc4c


[SPARK-9981] [ML] Made labels public for StringIndexerModel

Also added a unit test for the integration between StringIndexerModel and IndexToString.

CC: holdenk — We realized we should have kept your unit test (it catches the regression caused by removing
the inverse() method), so this commit adds it back.  mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #8211 from jkbradley/stridx-labels.

(cherry picked from commit 2a6590e510aba3bfc6603d280023128b3f5ac702)
Signed-off-by: Xiangrui Meng <meng@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0f4ccdc4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0f4ccdc4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0f4ccdc4

Branch: refs/heads/branch-1.5
Commit: 0f4ccdc4cfa02ad78f2c4949ddb3822d07d65104
Parents: 59cdcc0
Author: Joseph K. Bradley <joseph@databricks.com>
Authored: Fri Aug 14 14:05:03 2015 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Fri Aug 14 14:11:26 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/feature/StringIndexer.scala   |  5 ++++-
 .../spark/ml/feature/StringIndexerSuite.scala     | 18 ++++++++++++++++++
 2 files changed, 22 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0f4ccdc4/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index f5dfba1..76f017d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -93,14 +93,17 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
 /**
  * :: Experimental ::
  * Model fitted by [[StringIndexer]].
+ *
  * NOTE: During transformation, if the input column does not exist,
  * [[StringIndexerModel.transform]] would return the input dataset unmodified.
  * This is a temporary fix for the case when target labels do not exist during prediction.
+ *
+ * @param labels  Ordered list of labels, corresponding to indices to be assigned
  */
 @Experimental
 class StringIndexerModel (
     override val uid: String,
-    labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase {
+    val labels: Array[String]) extends Model[StringIndexerModel] with StringIndexerBase {
 
   def this(labels: Array[String]) = this(Identifiable.randomUID("strIdx"), labels)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0f4ccdc4/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index d960861..5fe66a3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -116,4 +116,22 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
         assert(actual === expected)
     }
   }
+
+  test("StringIndexer, IndexToString are inverses") {
+    val data = sc.parallelize(Seq((0, "a"), (1, "b"), (2, "c"), (3, "a"), (4, "a"), (5, "c")), 2)
+    val df = sqlContext.createDataFrame(data).toDF("id", "label")
+    val indexer = new StringIndexer()
+      .setInputCol("label")
+      .setOutputCol("labelIndex")
+      .fit(df)
+    val transformed = indexer.transform(df)
+    val idx2str = new IndexToString()
+      .setInputCol("labelIndex")
+      .setOutputCol("sameLabel")
+      .setLabels(indexer.labels)
+    idx2str.transform(transformed).select("label", "sameLabel").collect().foreach {
+      case Row(a: String, b: String) =>
+        assert(a === b)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message