spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From sro...@apache.org
Subject spark git commit: [SPARK-22450][WIP][CORE][MLLIB][FOLLOWUP] Safely register MultivariateGaussian
Date Thu, 15 Nov 2018 15:22:34 GMT
Repository: spark
Updated Branches:
  refs/heads/master 9610efc25 -> 91405b3b6


[SPARK-22450][WIP][CORE][MLLIB][FOLLOWUP] Safely register MultivariateGaussian

## What changes were proposed in this pull request?
Register the following classes in Kryo:
"org.apache.spark.ml.stat.distribution.MultivariateGaussian",
"org.apache.spark.mllib.stat.distribution.MultivariateGaussian"

## How was this patch tested?
Added tests.

Due to the existing module dependency, I cannot import spark-core in mllib-local's test suites,
so I did not add a test suite in `org.apache.spark.ml.stat.distribution.MultivariateGaussianSuite`.
I also noticed that class `ClusterStats` in `ClusteringEvaluator` is registered in a different
way; should it be modified to keep in line with the others in ML? srowen

Closes #22974 from zhengruifeng/kryo_MultivariateGaussian.

Authored-by: zhengruifeng <ruifengz@foxmail.com>
Signed-off-by: Sean Owen <sean.owen@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/91405b3b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/91405b3b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/91405b3b

Branch: refs/heads/master
Commit: 91405b3b6eb4fa8047123d951859b6e2a1e46b6a
Parents: 9610efc
Author: zhengruifeng <ruifengz@foxmail.com>
Authored: Thu Nov 15 09:22:31 2018 -0600
Committer: Sean Owen <sean.owen@databricks.com>
Committed: Thu Nov 15 09:22:31 2018 -0600

----------------------------------------------------------------------
 .../spark/serializer/KryoSerializer.scala       | 10 ++++++++-
 .../distribution/MultivariateGaussian.scala     |  4 ++--
 .../distribution/MultivariateGaussian.scala     |  4 ++--
 .../spark/ml/attribute/AttributeSuite.scala     | 19 ++++++++++++++++-
 .../MultivariateGaussianSuite.scala             | 22 +++++++++++++++++++-
 5 files changed, 52 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/91405b3b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 3795d5c..66812a5 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -215,6 +215,12 @@ class KryoSerializer(conf: SparkConf)
     // We can't load those class directly in order to avoid unnecessary jar dependencies.
     // We load them safely, ignore it if the class not found.
     Seq(
+      "org.apache.spark.ml.attribute.Attribute",
+      "org.apache.spark.ml.attribute.AttributeGroup",
+      "org.apache.spark.ml.attribute.BinaryAttribute",
+      "org.apache.spark.ml.attribute.NominalAttribute",
+      "org.apache.spark.ml.attribute.NumericAttribute",
+
       "org.apache.spark.ml.feature.Instance",
       "org.apache.spark.ml.feature.LabeledPoint",
       "org.apache.spark.ml.feature.OffsetInstance",
@@ -224,6 +230,7 @@ class KryoSerializer(conf: SparkConf)
       "org.apache.spark.ml.linalg.SparseMatrix",
       "org.apache.spark.ml.linalg.SparseVector",
       "org.apache.spark.ml.linalg.Vector",
+      "org.apache.spark.ml.stat.distribution.MultivariateGaussian",
       "org.apache.spark.ml.tree.impl.TreePoint",
       "org.apache.spark.mllib.clustering.VectorWithNorm",
       "org.apache.spark.mllib.linalg.DenseMatrix",
@@ -232,7 +239,8 @@ class KryoSerializer(conf: SparkConf)
       "org.apache.spark.mllib.linalg.SparseMatrix",
       "org.apache.spark.mllib.linalg.SparseVector",
       "org.apache.spark.mllib.linalg.Vector",
-      "org.apache.spark.mllib.regression.LabeledPoint"
+      "org.apache.spark.mllib.regression.LabeledPoint",
+      "org.apache.spark.mllib.stat.distribution.MultivariateGaussian"
     ).foreach { name =>
       try {
         val clazz = Utils.classForName(name)

http://git-wip-us.apache.org/repos/asf/spark/blob/91405b3b/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala
----------------------------------------------------------------------
diff --git a/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala
b/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala
index 3167e0c..e7f7a8e 100644
--- a/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala
+++ b/mllib-local/src/main/scala/org/apache/spark/ml/stat/distribution/MultivariateGaussian.scala
@@ -48,14 +48,14 @@ class MultivariateGaussian @Since("2.0.0") (
     this(Vectors.fromBreeze(mean), Matrices.fromBreeze(cov))
   }
 
-  private val breezeMu = mean.asBreeze.toDenseVector
+  @transient private lazy val breezeMu = mean.asBreeze.toDenseVector
 
   /**
    * Compute distribution dependent constants:
    *    rootSigmaInv = D^(-1/2)^ * U.t, where sigma = U * D * U.t
    *    u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
    */
-  private val (rootSigmaInv: BDM[Double], u: Double) = calculateCovarianceConstants
+  @transient private lazy val (rootSigmaInv: BDM[Double], u: Double) = calculateCovarianceConstants
 
   /**
    * Returns density of this multivariate Gaussian at given point, x

http://git-wip-us.apache.org/repos/asf/spark/blob/91405b3b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
index 4cf662e..9a746dc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala
@@ -43,7 +43,7 @@ class MultivariateGaussian @Since("1.3.0") (
   require(sigma.numCols == sigma.numRows, "Covariance matrix must be square")
   require(mu.size == sigma.numCols, "Mean vector length must match covariance matrix size")
 
-  private val breezeMu = mu.asBreeze.toDenseVector
+  @transient private lazy val breezeMu = mu.asBreeze.toDenseVector
 
   /**
    * private[mllib] constructor
@@ -60,7 +60,7 @@ class MultivariateGaussian @Since("1.3.0") (
    *    rootSigmaInv = D^(-1/2)^ * U.t, where sigma = U * D * U.t
    *    u = log((2*pi)^(-k/2)^ * det(sigma)^(-1/2)^)
    */
-  private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants
+  @transient private lazy val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants
 
   /**
    * Returns density of this multivariate Gaussian at given point, x

http://git-wip-us.apache.org/repos/asf/spark/blob/91405b3b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
index 6355e0f..eb5f3ca 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.ml.attribute
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkConf, SparkFunSuite}
+import org.apache.spark.serializer.KryoSerializer
 import org.apache.spark.sql.types._
 
 class AttributeSuite extends SparkFunSuite {
@@ -221,4 +222,20 @@ class AttributeSuite extends SparkFunSuite {
     val decimalFldWithMeta = new StructField("x", DecimalType(38, 18), false, metadata)
     assert(Attribute.fromStructField(decimalFldWithMeta).isNumeric)
   }
+
+  test("Kryo class register") {
+    val conf = new SparkConf(false)
+    conf.set("spark.kryo.registrationRequired", "true")
+
+    val ser = new KryoSerializer(conf).newInstance()
+
+    val numericAttr = new NumericAttribute(Some("numeric"), Some(1), Some(1.0), Some(2.0))
+    val nominalAttr = new NominalAttribute(Some("nominal"), Some(2), Some(false))
+    val binaryAttr = new BinaryAttribute(Some("binary"), Some(3), Some(Array("i", "j")))
+
+    Seq(numericAttr, nominalAttr, binaryAttr).foreach { i =>
+      val i2 = ser.deserialize[Attribute](ser.serialize(i))
+      assert(i === i2)
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/91405b3b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
index 669d442..5b4a260 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussianSuite.scala
@@ -17,10 +17,11 @@
 
 package org.apache.spark.mllib.stat.distribution
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkConf, SparkFunSuite}
 import org.apache.spark.mllib.linalg.{Matrices, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.serializer.KryoSerializer
 
 class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext {
   test("univariate") {
@@ -80,4 +81,23 @@ class MultivariateGaussianSuite extends SparkFunSuite with MLlibTestSparkContext
     assert(dist.pdf(x) ~== 7.154782224045512E-5 absTol 1E-9)
   }
 
+  test("Kryo class register") {
+    val conf = new SparkConf(false)
+    conf.set("spark.kryo.registrationRequired", "true")
+
+    val ser = new KryoSerializer(conf).newInstance()
+
+    val mu = Vectors.dense(0.0, 0.0)
+    val sigma1 = Matrices.dense(2, 2, Array(1.0, 0.0, 0.0, 1.0))
+    val dist1 = new MultivariateGaussian(mu, sigma1)
+
+    val sigma2 = Matrices.dense(2, 2, Array(4.0, -1.0, -1.0, 2.0))
+    val dist2 = new MultivariateGaussian(mu, sigma2)
+
+    Seq(dist1, dist2).foreach { i =>
+      val i2 = ser.deserialize[MultivariateGaussian](ser.serialize(i))
+      assert(i.sigma === i2.sigma)
+      assert(i.mu === i2.mu)
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message