spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-6723] [MLLIB] Model import/export for ChiSqSelector
Date Fri, 23 Oct 2015 15:45:19 GMT
Repository: spark
Updated Branches:
  refs/heads/master 282a15f78 -> 4e38defae


[SPARK-6723] [MLLIB] Model import/export for ChiSqSelector

This is a PR for Parquet-based model import/export.

* Added save/load for ChiSqSelectorModel
* Updated the test suite ChiSqSelectorSuite

Author: Jayant Shekar <jayant@user-MBPMBA-3.local>

Closes #6785 from jayantshekhar/SPARK-6723.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4e38defa
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4e38defa
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4e38defa

Branch: refs/heads/master
Commit: 4e38defae13b2b13e196b4d172722ef5e6266c66
Parents: 282a15f
Author: Jayant Shekar <jayant@user-MBPMBA-3.local>
Authored: Fri Oct 23 08:45:13 2015 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Fri Oct 23 08:45:13 2015 -0700

----------------------------------------------------------------------
 .../spark/mllib/feature/ChiSqSelector.scala     | 70 +++++++++++++++++++-
 .../mllib/feature/ChiSqSelectorSuite.scala      | 26 ++++++++
 2 files changed, 95 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4e38defa/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
index b1524cf..5246faf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/ChiSqSelector.scala
@@ -19,11 +19,18 @@ package org.apache.spark.mllib.feature
 
 import scala.collection.mutable.ArrayBuilder
 
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.stat.Statistics
+import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.{SQLContext, Row}
 
 /**
  * :: Experimental ::
@@ -34,7 +41,7 @@ import org.apache.spark.rdd.RDD
 @Since("1.3.0")
 @Experimental
 class ChiSqSelectorModel @Since("1.3.0") (
-  @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer {
+  @Since("1.3.0") val selectedFeatures: Array[Int]) extends VectorTransformer with Saveable {
 
   require(isSorted(selectedFeatures), "Array has to be sorted asc")
 
@@ -102,6 +109,67 @@ class ChiSqSelectorModel @Since("1.3.0") (
           s"Only sparse and dense vectors are supported but got ${other.getClass}.")
     }
   }
+
+  @Since("1.6.0")
+  override def save(sc: SparkContext, path: String): Unit = {
+    ChiSqSelectorModel.SaveLoadV1_0.save(sc, this, path)
+  }
+
+  override protected def formatVersion: String = "1.0"
+}
+
+object ChiSqSelectorModel extends Loader[ChiSqSelectorModel] {
+  @Since("1.6.0")
+  override def load(sc: SparkContext, path: String): ChiSqSelectorModel = {
+    ChiSqSelectorModel.SaveLoadV1_0.load(sc, path)
+  }
+
+  private[feature]
+  object SaveLoadV1_0 {
+
+    private val thisFormatVersion = "1.0"
+
+    /** Model data for import/export */
+    case class Data(feature: Int)
+
+    private[feature]
+    val thisClassName = "org.apache.spark.mllib.feature.ChiSqSelectorModel"
+
+    def save(sc: SparkContext, model: ChiSqSelectorModel, path: String): Unit = {
+      val sqlContext = new SQLContext(sc)
+      import sqlContext.implicits._
+      val metadata = compact(render(
+        ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
+      sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+      // Create Parquet data.
+      val dataArray = Array.tabulate(model.selectedFeatures.length) { i =>
+        Data(model.selectedFeatures(i))
+      }
+      sc.parallelize(dataArray, 1).toDF().write.parquet(Loader.dataPath(path))
+
+    }
+
+    def load(sc: SparkContext, path: String): ChiSqSelectorModel = {
+      implicit val formats = DefaultFormats
+      val sqlContext = new SQLContext(sc)
+      val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+      assert(className == thisClassName)
+      assert(formatVersion == thisFormatVersion)
+
+      val dataFrame = sqlContext.read.parquet(Loader.dataPath(path))
+      val dataArray = dataFrame.select("feature")
+
+      // Check schema explicitly since erasure makes it hard to use match-case for checking.
+      Loader.checkSchema[Data](dataFrame.schema)
+
+      val features = dataArray.map {
+        case Row(feature: Int) => (feature)
+      }.collect()
+
+      return new ChiSqSelectorModel(features)
+    }
+  }
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/4e38defa/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
index 889727f..734800a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/ChiSqSelectorSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
 
 class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
 
@@ -63,4 +64,29 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext {
     }.collect().toSet
     assert(filteredData == preFilteredData)
   }
+
+  test("model load / save") {
+    val model = ChiSqSelectorSuite.createModel()
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+    try {
+      model.save(sc, path)
+      val sameModel = ChiSqSelectorModel.load(sc, path)
+      ChiSqSelectorSuite.checkEqual(model, sameModel)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+}
+
+object ChiSqSelectorSuite extends SparkFunSuite {
+
+  def createModel(): ChiSqSelectorModel = {
+    val arr = Array(1, 2, 3, 4)
+    new ChiSqSelectorModel(arr)
+  }
+
+  def checkEqual(a: ChiSqSelectorModel, b: ChiSqSelectorModel): Unit = {
+    assert(a.selectedFeatures.deep == b.selectedFeatures.deep)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message