spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From yli...@apache.org
Subject spark git commit: [SPARK-19291][SPARKR][ML] spark.gaussianMixture supports output log-likelihood.
Date Sun, 22 Jan 2017 05:26:19 GMT
Repository: spark
Updated Branches:
  refs/heads/master 3dcad9fab -> 0c589e371


[SPARK-19291][SPARKR][ML] spark.gaussianMixture supports output log-likelihood.

## What changes were proposed in this pull request?
```spark.gaussianMixture``` now outputs the total log-likelihood of the fitted model, like R's ```mvnormalmixEM```.

## How was this patch tested?
R unit test.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #16646 from yanboliang/spark-19291.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0c589e37
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0c589e37
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0c589e37

Branch: refs/heads/master
Commit: 0c589e3713655f25547d6945a40786da900ec2fc
Parents: 3dcad9f
Author: Yanbo Liang <ybliang8@gmail.com>
Authored: Sat Jan 21 21:26:14 2017 -0800
Committer: Yanbo Liang <ybliang8@gmail.com>
Committed: Sat Jan 21 21:26:14 2017 -0800

----------------------------------------------------------------------
 R/pkg/R/mllib_clustering.R                              |  5 +++--
 R/pkg/inst/tests/testthat/test_mllib_clustering.R       |  7 +++++++
 .../org/apache/spark/ml/r/GaussianMixtureWrapper.scala  | 12 +++++++++---
 3 files changed, 19 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0c589e37/R/pkg/R/mllib_clustering.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/mllib_clustering.R b/R/pkg/R/mllib_clustering.R
index fb8d9e7..fa40f9d 100644
--- a/R/pkg/R/mllib_clustering.R
+++ b/R/pkg/R/mllib_clustering.R
@@ -98,7 +98,7 @@ setMethod("spark.gaussianMixture", signature(data = "SparkDataFrame", formula =
 #' @param object a fitted gaussian mixture model.
 #' @return \code{summary} returns summary of the fitted model, which is a list.
 #'         The list includes the model's \code{lambda} (lambda), \code{mu} (mu),
-#'         \code{sigma} (sigma), and \code{posterior} (posterior).
+#'         \code{sigma} (sigma), \code{loglik} (loglik), and \code{posterior} (posterior).
 #' @aliases spark.gaussianMixture,SparkDataFrame,formula-method
 #' @rdname spark.gaussianMixture
 #' @export
@@ -112,6 +112,7 @@ setMethod("summary", signature(object = "GaussianMixtureModel"),
             sigmaList <- callJMethod(jobj, "sigma")
             k <- callJMethod(jobj, "k")
             dim <- callJMethod(jobj, "dim")
+            loglik <- callJMethod(jobj, "logLikelihood")
             mu <- c()
             for (i in 1 : k) {
               start <- (i - 1) * dim + 1
@@ -129,7 +130,7 @@ setMethod("summary", signature(object = "GaussianMixtureModel"),
             } else {
               dataFrame(callJMethod(jobj, "posterior"))
             }
-            list(lambda = lambda, mu = mu, sigma = sigma,
+            list(lambda = lambda, mu = mu, sigma = sigma, loglik = loglik,
                  posterior = posterior, is.loaded = is.loaded)
           })
 

http://git-wip-us.apache.org/repos/asf/spark/blob/0c589e37/R/pkg/inst/tests/testthat/test_mllib_clustering.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_mllib_clustering.R b/R/pkg/inst/tests/testthat/test_mllib_clustering.R
index cfbdea5..9de8362 100644
--- a/R/pkg/inst/tests/testthat/test_mllib_clustering.R
+++ b/R/pkg/inst/tests/testthat/test_mllib_clustering.R
@@ -56,6 +56,10 @@ test_that("spark.gaussianMixture", {
   #            [,1]     [,2]
   #  [1,] 0.2961543 0.160783
   #  [2,] 0.1607830 1.008878
+  #
+  #' model$loglik
+  #
+  #  [1] -46.89499
   # nolint end
   data <- list(list(-0.6264538, 0.1836433), list(-0.8356286, 1.5952808),
                list(0.3295078, -0.8204684), list(0.4874291, 0.7383247),
@@ -72,9 +76,11 @@ test_that("spark.gaussianMixture", {
   rMu <- c(0.11731091, -0.06192351, 10.363673, 9.897081)
   rSigma <- c(0.62049934, 0.06880802, 0.06880802, 1.27431874,
               0.2961543, 0.160783, 0.1607830, 1.008878)
+  rLoglik <- -46.89499
   expect_equal(stats$lambda, rLambda, tolerance = 1e-3)
   expect_equal(unlist(stats$mu), rMu, tolerance = 1e-3)
   expect_equal(unlist(stats$sigma), rSigma, tolerance = 1e-3)
+  expect_equal(unlist(stats$loglik), rLoglik, tolerance = 1e-3)
   p <- collect(select(predict(model, df), "prediction"))
   expect_equal(p$prediction, c(0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1))
 
@@ -88,6 +94,7 @@ test_that("spark.gaussianMixture", {
   expect_equal(stats$lambda, stats2$lambda)
   expect_equal(unlist(stats$mu), unlist(stats2$mu))
   expect_equal(unlist(stats$sigma), unlist(stats2$sigma))
+  expect_equal(unlist(stats$loglik), unlist(stats2$loglik))
 
   unlink(modelPath)
 })

http://git-wip-us.apache.org/repos/asf/spark/blob/0c589e37/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
index b708702..9a98a8b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
@@ -34,6 +34,7 @@ import org.apache.spark.sql.functions._
 private[r] class GaussianMixtureWrapper private (
     val pipeline: PipelineModel,
     val dim: Int,
+    val logLikelihood: Double,
     val isLoaded: Boolean = false) extends MLWritable {
 
   private val gmm: GaussianMixtureModel = pipeline.stages(1).asInstanceOf[GaussianMixtureModel]
@@ -91,7 +92,10 @@ private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapp
       .setStages(Array(rFormulaModel, gm))
       .fit(data)
 
-    new GaussianMixtureWrapper(pipeline, dim)
+    val gmm: GaussianMixtureModel = pipeline.stages(1).asInstanceOf[GaussianMixtureModel]
+    val logLikelihood: Double = gmm.summary.logLikelihood
+
+    new GaussianMixtureWrapper(pipeline, dim, logLikelihood)
   }
 
   override def read: MLReader[GaussianMixtureWrapper] = new GaussianMixtureWrapperReader
@@ -105,7 +109,8 @@ private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapp
       val pipelinePath = new Path(path, "pipeline").toString
 
       val rMetadata = ("class" -> instance.getClass.getName) ~
-        ("dim" -> instance.dim)
+        ("dim" -> instance.dim) ~
+        ("logLikelihood" -> instance.logLikelihood)
       val rMetadataJson: String = compact(render(rMetadata))
 
       sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
@@ -124,7 +129,8 @@ private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapp
       val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
       val rMetadata = parse(rMetadataStr)
       val dim = (rMetadata \ "dim").extract[Int]
-      new GaussianMixtureWrapper(pipeline, dim, isLoaded = true)
+      val logLikelihood = (rMetadata \ "logLikelihood").extract[Double]
+      new GaussianMixtureWrapper(pipeline, dim, logLikelihood, isLoaded = true)
     }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message