spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-16446][SPARKR][ML] Gaussian Mixture Model wrapper in SparkR
Date Wed, 17 Aug 2016 18:18:35 GMT
Repository: spark
Updated Branches:
  refs/heads/master e3fec51fa -> 4d92af310


[SPARK-16446][SPARKR][ML] Gaussian Mixture Model wrapper in SparkR

## What changes were proposed in this pull request?
Gaussian Mixture Model wrapper in SparkR, similarly to R's ```mvnormalmixEM```.

## How was this patch tested?
Unit test.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #14392 from yanboliang/spark-16446.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4d92af31
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4d92af31
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4d92af31

Branch: refs/heads/master
Commit: 4d92af310ad29ade039e4130f91f2a3d9180deef
Parents: e3fec51
Author: Yanbo Liang <ybliang8@gmail.com>
Authored: Wed Aug 17 11:18:33 2016 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Wed Aug 17 11:18:33 2016 -0700

----------------------------------------------------------------------
 R/pkg/NAMESPACE                                 |   3 +-
 R/pkg/R/generics.R                              |   7 +
 R/pkg/R/mllib.R                                 | 139 ++++++++++++++++++-
 R/pkg/inst/tests/testthat/test_mllib.R          |  62 +++++++++
 .../spark/ml/r/GaussianMixtureWrapper.scala     | 128 +++++++++++++++++
 .../scala/org/apache/spark/ml/r/RWrappers.scala |   2 +
 6 files changed, 338 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4d92af31/R/pkg/NAMESPACE
----------------------------------------------------------------------
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index 1e23b23..c71eec5 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -25,7 +25,8 @@ exportMethods("glm",
               "fitted",
               "spark.naiveBayes",
               "spark.survreg",
-              "spark.isoreg")
+              "spark.isoreg",
+              "spark.gaussianMixture")
 
 # Job group lifecycle management methods
 export("setJobGroup",

http://git-wip-us.apache.org/repos/asf/spark/blob/4d92af31/R/pkg/R/generics.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index ebacc11..06bb25d 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1308,6 +1308,13 @@ setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spar
 #' @export
 setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") })
 
+#' @rdname spark.gaussianMixture
+#' @export
+setGeneric("spark.gaussianMixture",
+           function(data, formula, ...) {
+             standardGeneric("spark.gaussianMixture")
+           })
+
 #' @rdname write.ml
 #' @export
 setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") })

http://git-wip-us.apache.org/repos/asf/spark/blob/4d92af31/R/pkg/R/mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index 0dcc54d..db74046 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -60,6 +60,13 @@ setClass("KMeansModel", representation(jobj = "jobj"))
 #' @note IsotonicRegressionModel since 2.1.0
 setClass("IsotonicRegressionModel", representation(jobj = "jobj"))
 
+#' S4 class that represents a GaussianMixtureModel
+#'
+#' @param jobj a Java object reference to the backing Scala GaussianMixtureModel
+#' @export
+#' @note GaussianMixtureModel since 2.1.0
+setClass("GaussianMixtureModel", representation(jobj = "jobj"))
+
 #' Saves the MLlib model to the input path
 #'
 #' Saves the MLlib model to the input path. For more information, see the specific
@@ -67,7 +74,7 @@ setClass("IsotonicRegressionModel", representation(jobj = "jobj"))
 #' @rdname write.ml
 #' @name write.ml
 #' @export
-#' @seealso \link{spark.glm}, \link{glm}
+#' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
 #' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
 #' @seealso \link{spark.isoreg}
 #' @seealso \link{read.ml}
@@ -80,7 +87,7 @@ NULL
 #' @rdname predict
 #' @name predict
 #' @export
-#' @seealso \link{spark.glm}, \link{glm}
+#' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
 #' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
 #' @seealso \link{spark.isoreg}
 NULL
@@ -649,6 +656,25 @@ setMethod("write.ml", signature(object = "IsotonicRegressionModel", path = "char
            invisible(callJMethod(writer, "save", path))
           })
 
+#  Save fitted MLlib model to the input path
+
+#' @param path the directory where the model is saved.
+#' @param overwrite overwrites or not if the output path already exists. Default is FALSE
+#'                  which means throw exception if the output path exists.
+#'
+#' @aliases write.ml,GaussianMixtureModel,character-method
+#' @rdname spark.gaussianMixture
+#' @export
+#' @note write.ml(GaussianMixtureModel, character) since 2.1.0
+setMethod("write.ml", signature(object = "GaussianMixtureModel", path = "character"),
+          function(object, path, overwrite = FALSE) {
+            writer <- callJMethod(object@jobj, "write")
+            if (overwrite) {
+              writer <- callJMethod(writer, "overwrite")
+            }
+            invisible(callJMethod(writer, "save", path))
+          })
+
 #' Load a fitted MLlib model from the input path.
 #'
 #' @param path Path of the model to read.
@@ -676,6 +702,8 @@ read.ml <- function(path) {
       return(new("KMeansModel", jobj = jobj))
   } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.IsotonicRegressionWrapper")) {
       return(new("IsotonicRegressionModel", jobj = jobj))
+  } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GaussianMixtureWrapper")) {
+      return(new("GaussianMixtureModel", jobj = jobj))
   } else {
     stop(paste("Unsupported model: ", jobj))
   }
@@ -757,3 +785,110 @@ setMethod("predict", signature(object = "AFTSurvivalRegressionModel"),
           function(object, newData) {
             return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
           })
+
+#' Multivariate Gaussian Mixture Model (GMM)
+#'
+#' Fits multivariate gaussian mixture model against a Spark DataFrame, similarly to R's
+#' mvnormalmixEM(). Users can call \code{summary} to print a summary of the fitted model,
+#' \code{predict} to make predictions on new data, and \code{write.ml}/\code{read.ml}
+#' to save/load fitted models.
+#'
+#' @param data a SparkDataFrame for training.
+#' @param formula a symbolic description of the model to be fitted. Currently only a few formula
+#'                operators are supported, including '~', '.', ':', '+', and '-'.
+#'                Note that the response variable of formula is empty in spark.gaussianMixture.
+#' @param k number of independent Gaussians in the mixture model.
+#' @param maxIter maximum iteration number.
+#' @param tol the convergence tolerance.
+#' @aliases spark.gaussianMixture,SparkDataFrame,formula-method
+#' @return \code{spark.gaussianMixture} returns a fitted multivariate gaussian mixture model.
+#' @rdname spark.gaussianMixture
+#' @name spark.gaussianMixture
+#' @seealso mixtools: \url{https://cran.r-project.org/web/packages/mixtools/}
+#' @export
+#' @examples
+#' \dontrun{
+#' sparkR.session()
+#' library(mvtnorm)
+#' set.seed(100)
+#' a <- rmvnorm(4, c(0, 0))
+#' b <- rmvnorm(6, c(3, 4))
+#' data <- rbind(a, b)
+#' df <- createDataFrame(as.data.frame(data))
+#' model <- spark.gaussianMixture(df, ~ V1 + V2, k = 2)
+#' summary(model)
+#'
+#' # fitted values on training data
+#' fitted <- predict(model, df)
+#' head(select(fitted, "V1", "prediction"))
+#'
+#' # save fitted model to input path
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#'
+#' # can also read back the saved model and print
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#' }
+#' @note spark.gaussianMixture since 2.1.0
+#' @seealso \link{predict}, \link{read.ml}, \link{write.ml}
+setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula = "formula"),
+          function(data, formula, k = 2, maxIter = 100, tol = 0.01) {
+            formula <- paste(deparse(formula), collapse = "")
+            jobj <- callJStatic("org.apache.spark.ml.r.GaussianMixtureWrapper", "fit", data@sdf,
+                                formula, as.integer(k), as.integer(maxIter), as.numeric(tol))
+            return(new("GaussianMixtureModel", jobj = jobj))
+          })
+
+#  Get the summary of a multivariate gaussian mixture model
+
+#' @param object a fitted gaussian mixture model.
+#' @param ... currently not used argument(s) passed to the method.
+#' @return \code{summary} returns the model's lambda, mu, sigma and posterior.
+#' @aliases spark.gaussianMixture,SparkDataFrame,formula-method
+#' @rdname spark.gaussianMixture
+#' @export
+#' @note summary(GaussianMixtureModel) since 2.1.0
+setMethod("summary", signature(object = "GaussianMixtureModel"),
+          function(object, ...) {
+            jobj <- object@jobj
+            is.loaded <- callJMethod(jobj, "isLoaded")
+            lambda <- unlist(callJMethod(jobj, "lambda"))
+            muList <- callJMethod(jobj, "mu")
+            sigmaList <- callJMethod(jobj, "sigma")
+            k <- callJMethod(jobj, "k")
+            dim <- callJMethod(jobj, "dim")
+            mu <- c()
+            for (i in 1 : k) {
+              start <- (i - 1) * dim + 1
+              end <- i * dim
+              mu[[i]] <- unlist(muList[start : end])
+            }
+            sigma <- c()
+            for (i in 1 : k) {
+              start <- (i - 1) * dim * dim + 1
+              end <- i * dim * dim
+              sigma[[i]] <- t(matrix(sigmaList[start : end], ncol = dim))
+            }
+            posterior <- if (is.loaded) {
+              NULL
+            } else {
+              dataFrame(callJMethod(jobj, "posterior"))
+            }
+            return(list(lambda = lambda, mu = mu, sigma = sigma,
+                   posterior = posterior, is.loaded = is.loaded))
+          })
+
+#  Predicted values based on a gaussian mixture model
+
+#' @param newData a SparkDataFrame for testing.
+#' @return \code{predict} returns a SparkDataFrame containing predicted labels in a column named
+#'         "prediction".
+#' @aliases predict,GaussianMixtureModel,SparkDataFrame-method
+#' @rdname spark.gaussianMixture
+#' @export
+#' @note predict(GaussianMixtureModel) since 2.1.0
+setMethod("predict", signature(object = "GaussianMixtureModel"),
+          function(object, newData) {
+            return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+          })

http://git-wip-us.apache.org/repos/asf/spark/blob/4d92af31/R/pkg/inst/tests/testthat/test_mllib.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index b759b28..9617986 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -508,4 +508,66 @@ test_that("spark.isotonicRegression", {
   unlink(modelPath)
 })
 
+test_that("spark.gaussianMixture", {
+  # R code to reproduce the result.
+  # nolint start
+  #' library(mvtnorm)
+  #' set.seed(100)
+  #' a <- rmvnorm(4, c(0, 0))
+  #' b <- rmvnorm(6, c(3, 4))
+  #' data <- rbind(a, b)
+  #' model <- mvnormalmixEM(data, k = 2)
+  #' model$lambda
+  #
+  #  [1] 0.4 0.6
+  #
+  #' model$mu
+  #
+  #  [1] -0.2614822  0.5128697
+  #  [1] 2.647284 4.544682
+  #
+  #' model$sigma
+  #
+  #  [[1]]
+  #  [,1]       [,2]
+  #  [1,] 0.08427399 0.00548772
+  #  [2,] 0.00548772 0.09090715
+  #
+  #  [[2]]
+  #  [,1]       [,2]
+  #  [1,]  0.1641373 -0.1673806
+  #  [2,] -0.1673806  0.7508951
+  # nolint end
+  data <- list(list(-0.50219235, 0.1315312), list(-0.07891709, 0.8867848),
+               list(0.11697127, 0.3186301), list(-0.58179068, 0.7145327),
+               list(2.17474057, 3.6401379), list(3.08988614, 4.0962745),
+               list(2.79836605, 4.7398405), list(3.12337950, 3.9706833),
+               list(2.61114575, 4.5108563), list(2.08618581, 6.3102968))
+  df <- createDataFrame(data, c("x1", "x2"))
+  model <- spark.gaussianMixture(df, ~ x1 + x2, k = 2)
+  stats <- summary(model)
+  rLambda <- c(0.4, 0.6)
+  rMu <- c(-0.2614822, 0.5128697, 2.647284, 4.544682)
+  rSigma <- c(0.08427399, 0.00548772, 0.00548772, 0.09090715,
+              0.1641373, -0.1673806, -0.1673806, 0.7508951)
+  expect_equal(stats$lambda, rLambda)
+  expect_equal(unlist(stats$mu), rMu, tolerance = 1e-3)
+  expect_equal(unlist(stats$sigma), rSigma, tolerance = 1e-3)
+  p <- collect(select(predict(model, df), "prediction"))
+  expect_equal(p$prediction, c(0, 0, 0, 0, 1, 1, 1, 1, 1, 1))
+
+  # Test model save/load
+  modelPath <- tempfile(pattern = "spark-gaussianMixture", fileext = ".tmp")
+  write.ml(model, modelPath)
+  expect_error(write.ml(model, modelPath))
+  write.ml(model, modelPath, overwrite = TRUE)
+  model2 <- read.ml(modelPath)
+  stats2 <- summary(model2)
+  expect_equal(stats$lambda, stats2$lambda)
+  expect_equal(unlist(stats$mu), unlist(stats2$mu))
+  expect_equal(unlist(stats$sigma), unlist(stats2$sigma))
+
+  unlink(modelPath)
+})
+
 sparkR.session.stop()

http://git-wip-us.apache.org/repos/asf/spark/blob/4d92af31/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
new file mode 100644
index 0000000..1e8b3bb
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.clustering.{GaussianMixture, GaussianMixtureModel}
+import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter}
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.functions._
+
+private[r] class GaussianMixtureWrapper private (
+    val pipeline: PipelineModel,
+    val dim: Int,
+    val isLoaded: Boolean = false) extends MLWritable {
+
+  private val gmm: GaussianMixtureModel = pipeline.stages(1).asInstanceOf[GaussianMixtureModel]
+
+  lazy val k: Int = gmm.getK
+
+  lazy val lambda: Array[Double] = gmm.weights
+
+  lazy val mu: Array[Double] = gmm.gaussians.flatMap(_.mean.toArray)
+
+  lazy val sigma: Array[Double] = gmm.gaussians.flatMap(_.cov.toArray)
+
+  lazy val vectorToArray = udf { probability: Vector => probability.toArray }
+  lazy val posterior: DataFrame = gmm.summary.probability
+    .withColumn("posterior", vectorToArray(col(gmm.summary.probabilityCol)))
+    .drop(gmm.summary.probabilityCol)
+
+  def transform(dataset: Dataset[_]): DataFrame = {
+    pipeline.transform(dataset).drop(gmm.getFeaturesCol)
+  }
+
+  override def write: MLWriter = new GaussianMixtureWrapper.GaussianMixtureWrapperWriter(this)
+
+}
+
+private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapper] {
+
+  def fit(
+      data: DataFrame,
+      formula: String,
+      k: Int,
+      maxIter: Int,
+      tol: Double): GaussianMixtureWrapper = {
+
+    val rFormulaModel = new RFormula()
+      .setFormula(formula)
+      .setFeaturesCol("features")
+      .fit(data)
+
+    // get feature names from output schema
+    val schema = rFormulaModel.transform(data).schema
+    val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
+      .attributes.get
+    val features = featureAttrs.map(_.name.get)
+    val dim = features.length
+
+    val gm = new GaussianMixture()
+      .setK(k)
+      .setMaxIter(maxIter)
+      .setTol(tol)
+
+    val pipeline = new Pipeline()
+      .setStages(Array(rFormulaModel, gm))
+      .fit(data)
+
+    new GaussianMixtureWrapper(pipeline, dim)
+  }
+
+  override def read: MLReader[GaussianMixtureWrapper] = new GaussianMixtureWrapperReader
+
+  override def load(path: String): GaussianMixtureWrapper = super.load(path)
+
+  class GaussianMixtureWrapperWriter(instance: GaussianMixtureWrapper) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("dim" -> instance.dim)
+      val rMetadataJson: String = compact(render(rMetadata))
+
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      instance.pipeline.save(pipelinePath)
+    }
+  }
+
+  class GaussianMixtureWrapperReader extends MLReader[GaussianMixtureWrapper] {
+
+    override def load(path: String): GaussianMixtureWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+      val pipeline = PipelineModel.load(pipelinePath)
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val dim = (rMetadata \ "dim").extract[Int]
+      new GaussianMixtureWrapper(pipeline, dim, isLoaded = true)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/4d92af31/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index f9a44d6..88ac26b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -46,6 +46,8 @@ private[r] object RWrappers extends MLReader[Object] {
         KMeansWrapper.load(path)
       case "org.apache.spark.ml.r.IsotonicRegressionWrapper" =>
         IsotonicRegressionWrapper.load(path)
+      case "org.apache.spark.ml.r.GaussianMixtureWrapper" =>
+        GaussianMixtureWrapper.load(path)
       case _ =>
         throw new SparkException(s"SparkR read.ml does not support load $className")
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message