spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From jkbrad...@apache.org
Subject spark git commit: [SPARK-7752] [MLLIB] Use lowercase letters for NaiveBayes.modelType
Date Thu, 21 May 2015 17:30:17 GMT
Repository: spark
Updated Branches:
  refs/heads/master a25c1ab8f -> 13348e21b


[SPARK-7752] [MLLIB] Use lowercase letters for NaiveBayes.modelType

This change makes the NaiveBayes model type strings lowercase, to be consistent with other
string names in MLlib. This PR also updates the implementation
to use vals instead of hardcoded strings. jkbradley leahmcguire

Author: Xiangrui Meng <meng@databricks.com>

Closes #6277 from mengxr/SPARK-7752 and squashes the following commits:

f38b662 [Xiangrui Meng] add another case _ back in test
ae5c66a [Xiangrui Meng] model type -> modelType
711d1c6 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7752
40ae53e [Xiangrui Meng] fix Java test suite
264a814 [Xiangrui Meng] add case _ back
3c456a8 [Xiangrui Meng] update NB user guide
17bba53 [Xiangrui Meng] update naive Bayes to use lowercase model type strings


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/13348e21
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/13348e21
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/13348e21

Branch: refs/heads/master
Commit: 13348e21b6b1c0df42c18b82b86c613291228863
Parents: a25c1ab
Author: Xiangrui Meng <meng@databricks.com>
Authored: Thu May 21 10:30:08 2015 -0700
Committer: Joseph K. Bradley <joseph@databricks.com>
Committed: Thu May 21 10:30:08 2015 -0700

----------------------------------------------------------------------
 docs/mllib-naive-bayes.md                       |  9 +--
 .../spark/mllib/classification/NaiveBayes.scala | 75 +++++++++++---------
 .../classification/JavaNaiveBayesSuite.java     |  4 +-
 .../mllib/classification/NaiveBayesSuite.scala  | 46 ++++++------
 4 files changed, 75 insertions(+), 59 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/13348e21/docs/mllib-naive-bayes.md
----------------------------------------------------------------------
diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md
index 9780ea5..56a2e9c 100644
--- a/docs/mllib-naive-bayes.md
+++ b/docs/mllib-naive-bayes.md
@@ -21,7 +21,7 @@ Within that context, each observation is a document and each
 feature represents a term whose value is the frequency of the term (in multinomial naive Bayes) or
 a zero or one indicating whether the term was found in the document (in Bernoulli naive Bayes).
 Feature values must be nonnegative. The model type is selected with an optional parameter
-"Multinomial" or "Bernoulli" with "Multinomial" as the default.
+"multinomial" or "bernoulli" with "multinomial" as the default.
 [Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by
 setting the parameter $\lambda$ (default to $1.0$). For document classification, the input feature
 vectors are usually sparse, and sparse vectors should be supplied as input to take advantage of
@@ -35,7 +35,7 @@ sparsity. Since the training data is only used once, it is not necessary
to cach
 [NaiveBayes](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayes$) implements
 multinomial naive Bayes. It takes an RDD of
 [LabeledPoint](api/scala/index.html#org.apache.spark.mllib.regression.LabeledPoint) and an optional
-smoothing parameter `lambda` as input, an optional model type parameter (default is Multinomial), and outputs a
+smoothing parameter `lambda` as input, an optional model type parameter (default is "multinomial"), and outputs a
 [NaiveBayesModel](api/scala/index.html#org.apache.spark.mllib.classification.NaiveBayesModel), which
 can be used for evaluation and prediction.
 
@@ -54,7 +54,7 @@ val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L)
 val training = splits(0)
 val test = splits(1)
 
-val model = NaiveBayes.train(training, lambda = 1.0, model = "Multinomial")
+val model = NaiveBayes.train(training, lambda = 1.0, model = "multinomial")
 
 val predictionAndLabel = test.map(p => (model.predict(p.features), p.label))
 val accuracy = 1.0 * predictionAndLabel.filter(x => x._1 == x._2).count() / test.count()
@@ -75,6 +75,8 @@ optionally smoothing parameter `lambda` as input, and output a
 can be used for evaluation and prediction.
 
 {% highlight java %}
+import scala.Tuple2;
+
 import org.apache.spark.api.java.JavaPairRDD;
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.function.Function;
@@ -82,7 +84,6 @@ import org.apache.spark.api.java.function.PairFunction;
 import org.apache.spark.mllib.classification.NaiveBayes;
 import org.apache.spark.mllib.classification.NaiveBayesModel;
 import org.apache.spark.mllib.regression.LabeledPoint;
-import scala.Tuple2;
 
 JavaRDD<LabeledPoint> training = ... // training set
 JavaRDD<LabeledPoint> test = ... // test set

http://git-wip-us.apache.org/repos/asf/spark/blob/13348e21/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index cffe9ef..f51ee36 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -25,13 +25,12 @@ import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.{Logging, SparkContext, SparkException}
-import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector, Vectors}
+import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, SQLContext}
 
-
 /**
  * Model for Naive Bayes Classifiers.
  *
@@ -39,7 +38,7 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
  * @param pi log of class priors, whose dimension is C, number of labels
  * @param theta log of class conditional probabilities, whose dimension is C-by-D,
  *              where D is number of features
- * @param modelType The type of NB model to fit  can be "Multinomial" or "Bernoulli"
+ * @param modelType The type of NB model to fit  can be "multinomial" or "bernoulli"
  */
 class NaiveBayesModel private[mllib] (
     val labels: Array[Double],
@@ -48,11 +47,13 @@ class NaiveBayesModel private[mllib] (
     val modelType: String)
   extends ClassificationModel with Serializable with Saveable {
 
+  import NaiveBayes.{Bernoulli, Multinomial, supportedModelTypes}
+
   private val piVector = new DenseVector(pi)
-  private val thetaMatrix = new DenseMatrix(labels.size, theta(0).size, theta.flatten, true)
+  private val thetaMatrix = new DenseMatrix(labels.length, theta(0).length, theta.flatten, true)
 
   private[mllib] def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
-    this(labels, pi, theta, "Multinomial")
+    this(labels, pi, theta, NaiveBayes.Multinomial)
 
   /** A Java-friendly constructor that takes three Iterable parameters. */
   private[mllib] def this(
@@ -61,12 +62,15 @@ class NaiveBayesModel private[mllib] (
       theta: JIterable[JIterable[Double]]) =
     this(labels.asScala.toArray, pi.asScala.toArray, theta.asScala.toArray.map(_.asScala.toArray))
 
+  require(supportedModelTypes.contains(modelType),
+    s"Invalid modelType $modelType. Supported modelTypes are $supportedModelTypes.")
+
   // Bernoulli scoring requires log(condprob) if 1, log(1-condprob) if 0.
   // This precomputes log(1.0 - exp(theta)) and its sum which are used for the linear algebra
   // application of this condition (in predict function).
   private val (thetaMinusNegTheta, negThetaSum) = modelType match {
-    case "Multinomial" => (None, None)
-    case "Bernoulli" =>
+    case Multinomial => (None, None)
+    case Bernoulli =>
       val negTheta = thetaMatrix.map(value => math.log(1.0 - math.exp(value)))
       val ones = new DenseVector(Array.fill(thetaMatrix.numCols){1.0})
       val thetaMinusNegTheta = thetaMatrix.map { value =>
@@ -75,7 +79,7 @@ class NaiveBayesModel private[mllib] (
       (Option(thetaMinusNegTheta), Option(negTheta.multiply(ones)))
     case _ =>
       // This should never happen.
-      throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
+      throw new UnknownError(s"Invalid modelType: $modelType.")
   }
 
   override def predict(testData: RDD[Vector]): RDD[Double] = {
@@ -88,15 +92,15 @@ class NaiveBayesModel private[mllib] (
 
   override def predict(testData: Vector): Double = {
     modelType match {
-      case "Multinomial" =>
+      case Multinomial =>
         val prob = thetaMatrix.multiply(testData)
         BLAS.axpy(1.0, piVector, prob)
         labels(prob.argmax)
-      case "Bernoulli" =>
+      case Bernoulli =>
         testData.foreachActive { (index, value) =>
           if (value != 0.0 && value != 1.0) {
             throw new SparkException(
-              s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
+              s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.")
           }
         }
         val prob = thetaMinusNegTheta.get.multiply(testData)
@@ -105,7 +109,7 @@ class NaiveBayesModel private[mllib] (
         labels(prob.argmax)
       case _ =>
         // This should never happen.
-        throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
+        throw new UnknownError(s"Invalid modelType: $modelType.")
     }
   }
 
@@ -230,16 +234,16 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
         s"($loadedClassName, $version).  Supported:\n" +
         s"  ($classNameV1_0, 1.0)")
     }
-    assert(model.pi.size == numClasses,
+    assert(model.pi.length == numClasses,
       s"NaiveBayesModel.load expected $numClasses classes," +
-        s" but class priors vector pi had ${model.pi.size} elements")
-    assert(model.theta.size == numClasses,
+        s" but class priors vector pi had ${model.pi.length} elements")
+    assert(model.theta.length == numClasses,
       s"NaiveBayesModel.load expected $numClasses classes," +
-        s" but class conditionals array theta had ${model.theta.size} elements")
-    assert(model.theta.forall(_.size == numFeatures),
+        s" but class conditionals array theta had ${model.theta.length} elements")
+    assert(model.theta.forall(_.length == numFeatures),
       s"NaiveBayesModel.load expected $numFeatures features," +
         s" but class conditionals array theta had elements of size:" +
-        s" ${model.theta.map(_.size).mkString(",")}")
+        s" ${model.theta.map(_.length).mkString(",")}")
     model
   }
 }
@@ -257,9 +261,11 @@ class NaiveBayes private (
     private var lambda: Double,
     private var modelType: String) extends Serializable with Logging {
 
-  def this(lambda: Double) = this(lambda, "Multinomial")
+  import NaiveBayes.{Bernoulli, Multinomial}
 
-  def this() = this(1.0, "Multinomial")
+  def this(lambda: Double) = this(lambda, NaiveBayes.Multinomial)
+
+  def this() = this(1.0, NaiveBayes.Multinomial)
 
   /** Set the smoothing parameter. Default: 1.0. */
   def setLambda(lambda: Double): NaiveBayes = {
@@ -272,12 +278,11 @@ class NaiveBayes private (
 
   /**
    * Set the model type using a string (case-sensitive).
-   * Supported options: "Multinomial" and "Bernoulli".
-   * (default: Multinomial)
+   * Supported options: "multinomial" (default) and "bernoulli".
    */
-  def setModelType(modelType:String): NaiveBayes = {
+  def setModelType(modelType: String): NaiveBayes = {
     require(NaiveBayes.supportedModelTypes.contains(modelType),
-      s"NaiveBayes was created with an unknown ModelType: $modelType")
+      s"NaiveBayes was created with an unknown modelType: $modelType.")
     this.modelType = modelType
     this
   }
@@ -308,7 +313,7 @@ class NaiveBayes private (
       }
       if (!values.forall(v => v == 0.0 || v == 1.0)) {
         throw new SparkException(
-          s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.")
+          s"Bernoulli naive Bayes requires 0 or 1 feature values but found $v.")
       }
     }
 
@@ -317,7 +322,7 @@ class NaiveBayes private (
     // TODO: similar to reduceByKeyLocally to save one stage.
     val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)](
       createCombiner = (v: Vector) => {
-        if (modelType == "Bernoulli") {
+        if (modelType == Bernoulli) {
           requireZeroOneBernoulliValues(v)
         } else {
           requireNonnegativeValues(v)
@@ -352,11 +357,11 @@ class NaiveBayes private (
       labels(i) = label
       pi(i) = math.log(n + lambda) - piLogDenom
       val thetaLogDenom = modelType match {
-        case "Multinomial" => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
-        case "Bernoulli" => math.log(n + 2.0 * lambda)
+        case Multinomial => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
+        case Bernoulli => math.log(n + 2.0 * lambda)
         case _ =>
           // This should never happen.
-          throw new UnknownError(s"NaiveBayes was created with an unknown ModelType: $modelType")
+          throw new UnknownError(s"Invalid modelType: $modelType.")
       }
       var j = 0
       while (j < numFeatures) {
@@ -375,8 +380,14 @@ class NaiveBayes private (
  */
 object NaiveBayes {
 
+  /** String name for multinomial model type. */
+  private[classification] val Multinomial: String = "multinomial"
+
+  /** String name for Bernoulli model type. */
+  private[classification] val Bernoulli: String = "bernoulli"
+
   /* Set of modelTypes that NaiveBayes supports */
-  private[mllib] val supportedModelTypes = Set("Multinomial", "Bernoulli")
+  private[classification] val supportedModelTypes = Set(Multinomial, Bernoulli)
 
   /**
    * Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
@@ -406,7 +417,7 @@ object NaiveBayes {
    * @param lambda The smoothing parameter
    */
   def train(input: RDD[LabeledPoint], lambda: Double): NaiveBayesModel = {
-    new NaiveBayes(lambda, "Multinomial").run(input)
+    new NaiveBayes(lambda, Multinomial).run(input)
   }
 
   /**
@@ -429,7 +440,7 @@ object NaiveBayes {
    */
   def train(input: RDD[LabeledPoint], lambda: Double, modelType: String): NaiveBayesModel = {
     require(supportedModelTypes.contains(modelType),
-      s"NaiveBayes was created with an unknown ModelType: $modelType")
+      s"NaiveBayes was created with an unknown modelType: $modelType.")
     new NaiveBayes(lambda, modelType).run(input)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/13348e21/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
----------------------------------------------------------------------
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
index 71fb7f1..3771c0e 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
@@ -108,7 +108,7 @@ public class JavaNaiveBayesSuite implements Serializable {
   @Test
   public void testModelTypeSetters() {
     NaiveBayes nb = new NaiveBayes()
-        .setModelType("Bernoulli")
-        .setModelType("Multinomial");
+      .setModelType("bernoulli")
+      .setModelType("multinomial");
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/13348e21/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index 40a79a1..c111a78 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -19,9 +19,8 @@ package org.apache.spark.mllib.classification
 
 import scala.util.Random
 
-import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
 import breeze.stats.distributions.{Multinomial => BrzMultinomial}
-
 import org.scalatest.FunSuite
 
 import org.apache.spark.SparkException
@@ -30,9 +29,10 @@ import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
 import org.apache.spark.util.Utils
 
-
 object NaiveBayesSuite {
 
+  import NaiveBayes.{Multinomial, Bernoulli}
+
   private def calcLabel(p: Double, pi: Array[Double]): Int = {
     var sum = 0.0
     for (j <- 0 until pi.length) {
@@ -48,7 +48,7 @@ object NaiveBayesSuite {
     theta: Array[Array[Double]],  // CXD
     nPoints: Int,
     seed: Int,
-    modelType: String = "Multinomial",
+    modelType: String = Multinomial,
     sample: Int = 10): Seq[LabeledPoint] = {
     val D = theta(0).length
     val rnd = new Random(seed)
@@ -58,10 +58,10 @@ object NaiveBayesSuite {
     for (i <- 0 until nPoints) yield {
       val y = calcLabel(rnd.nextDouble(), _pi)
       val xi = modelType match {
-        case "Bernoulli" => Array.tabulate[Double] (D) { j =>
+        case Bernoulli => Array.tabulate[Double] (D) { j =>
             if (rnd.nextDouble () < _theta(y)(j) ) 1 else 0
         }
-        case "Multinomial" =>
+        case Multinomial =>
           val mult = BrzMultinomial(BDV(_theta(y)))
           val emptyMap = (0 until D).map(x => (x, 0.0)).toMap
           val counts = emptyMap ++ mult.sample(sample).groupBy(x => x).map {
@@ -70,7 +70,7 @@ object NaiveBayesSuite {
           counts.toArray.sortBy(_._1).map(_._2)
         case _ =>
           // This should never happen.
-          throw new UnknownError(s"NaiveBayesSuite found unknown ModelType: $modelType")
+          throw new UnknownError(s"Invalid modelType: $modelType.")
       }
 
       LabeledPoint(y, Vectors.dense(xi))
@@ -79,17 +79,17 @@ object NaiveBayesSuite {
 
   /** Bernoulli NaiveBayes with binary labels, 3 features */
   private val binaryBernoulliModel = new NaiveBayesModel(labels = Array(0.0, 1.0),
-    pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)),
-    "Bernoulli")
+    pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Bernoulli)
 
   /** Multinomial NaiveBayes with binary labels, 3 features */
   private val binaryMultinomialModel = new NaiveBayesModel(labels = Array(0.0, 1.0),
-    pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)),
-    "Multinomial")
+    pi = Array(0.2, 0.8), theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), Multinomial)
 }
 
 class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
 
+  import NaiveBayes.{Multinomial, Bernoulli}
+
   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
     val numOfPredictions = predictions.zip(input).count {
       case (prediction, expected) =>
@@ -117,6 +117,11 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("model types") {
+    assert(Multinomial === "multinomial")
+    assert(Bernoulli === "bernoulli")
+  }
+
   test("get, set params") {
     val nb = new NaiveBayes()
     nb.setLambda(2.0)
@@ -134,16 +139,15 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
       Array(0.10, 0.10, 0.70, 0.10)  // label 2
     ).map(_.map(math.log))
 
-    val testData = NaiveBayesSuite.generateNaiveBayesInput(
-      pi, theta, nPoints, 42, "Multinomial")
+    val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42, Multinomial)
     val testRDD = sc.parallelize(testData, 2)
     testRDD.cache()
 
-    val model = NaiveBayes.train(testRDD, 1.0, "Multinomial")
+    val model = NaiveBayes.train(testRDD, 1.0, Multinomial)
     validateModelFit(pi, theta, model)
 
     val validationData = NaiveBayesSuite.generateNaiveBayesInput(
-      pi, theta, nPoints, 17, "Multinomial")
+      pi, theta, nPoints, 17, Multinomial)
     val validationRDD = sc.parallelize(validationData, 2)
 
     // Test prediction on RDD.
@@ -163,15 +167,15 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
     ).map(_.map(math.log))
 
     val testData = NaiveBayesSuite.generateNaiveBayesInput(
-      pi, theta, nPoints, 45, "Bernoulli")
+      pi, theta, nPoints, 45, Bernoulli)
     val testRDD = sc.parallelize(testData, 2)
     testRDD.cache()
 
-    val model = NaiveBayes.train(testRDD, 1.0, "Bernoulli")
+    val model = NaiveBayes.train(testRDD, 1.0, Bernoulli)
     validateModelFit(pi, theta, model)
 
     val validationData = NaiveBayesSuite.generateNaiveBayesInput(
-      pi, theta, nPoints, 20, "Bernoulli")
+      pi, theta, nPoints, 20, Bernoulli)
     val validationRDD = sc.parallelize(validationData, 2)
 
     // Test prediction on RDD.
@@ -216,7 +220,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
       LabeledPoint(1.0, Vectors.dense(0.0)))
 
     intercept[SparkException] {
-      NaiveBayes.train(sc.makeRDD(badTrain, 2), 1.0, "Bernoulli")
+      NaiveBayes.train(sc.makeRDD(badTrain, 2), 1.0, Bernoulli)
     }
 
     val okTrain = Seq(
@@ -235,7 +239,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
       Vectors.dense(1.0),
       Vectors.dense(0.0))
 
-    val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, "Bernoulli")
+    val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, Bernoulli)
     intercept[SparkException] {
       model.predict(sc.makeRDD(badPredict, 2)).collect()
     }
@@ -275,7 +279,7 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
       assert(model.labels === sameModel.labels)
       assert(model.pi === sameModel.pi)
       assert(model.theta === sameModel.theta)
-      assert(model.modelType === "Multinomial")
+      assert(model.modelType === Multinomial)
     } finally {
       Utils.deleteRecursively(tempDir)
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message