spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject git commit: [SPARK-3130][MLLIB] detect negative values in naive Bayes
Date Wed, 20 Aug 2014 04:01:36 GMT
Repository: spark
Updated Branches:
  refs/heads/master 0e3ab94d4 -> 068b6fe6a


[SPARK-3130][MLLIB] detect negative values in naive Bayes

because NB treats feature values as term frequencies. jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #2038 from mengxr/nb-neg and squashes the following commits:

52c37c3 [Xiangrui Meng] address comments
65f892d [Xiangrui Meng] detect negative values in nb


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/068b6fe6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/068b6fe6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/068b6fe6

Branch: refs/heads/master
Commit: 068b6fe6a10eb1c6b2102d88832203267f030e85
Parents: 0e3ab94
Author: Xiangrui Meng <meng@databricks.com>
Authored: Tue Aug 19 21:01:23 2014 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Tue Aug 19 21:01:23 2014 -0700

----------------------------------------------------------------------
 docs/mllib-naive-bayes.md                       |  3 ++-
 .../spark/mllib/classification/NaiveBayes.scala | 28 ++++++++++++++++----
 .../mllib/classification/NaiveBayesSuite.scala  | 28 ++++++++++++++++++++
 3 files changed, 53 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/068b6fe6/docs/mllib-naive-bayes.md
----------------------------------------------------------------------
diff --git a/docs/mllib-naive-bayes.md b/docs/mllib-naive-bayes.md
index 86d94ae..7f9d4c6 100644
--- a/docs/mllib-naive-bayes.md
+++ b/docs/mllib-naive-bayes.md
@@ -17,7 +17,8 @@ Bayes](http://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bay
 which is typically used for [document
 classification](http://nlp.stanford.edu/IR-book/html/htmledition/naive-bayes-text-classification-1.html).
 Within that context, each observation is a document and each
-feature represents a term whose value is the frequency of the term. 
+feature represents a term whose value is the frequency of the term.
+Feature values must be nonnegative to represent term frequencies.
 [Additive smoothing](http://en.wikipedia.org/wiki/Lidstone_smoothing) can be used by
 setting the parameter $\lambda$ (default to $1.0$). For document classification, the input feature
 vectors are usually sparse, and sparse vectors should be supplied as input to take advantage of

http://git-wip-us.apache.org/repos/asf/spark/blob/068b6fe6/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 6c7be0a..8c8e4a1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -19,9 +19,9 @@ package org.apache.spark.mllib.classification
 
 import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
 
-import org.apache.spark.Logging
+import org.apache.spark.{SparkException, Logging}
 import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.rdd.RDD
 
@@ -73,7 +73,7 @@ class NaiveBayesModel private[mllib] (
  * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
  * discrete data.  For example, by converting documents into TF-IDF vectors, it can be used for
  * document classification.  By making every vector a 0-1 vector, it can also be used as
- * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
+ * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). The input feature values must be nonnegative.
  */
 class NaiveBayes private (private var lambda: Double) extends Serializable with Logging {
 
@@ -91,12 +91,30 @@ class NaiveBayes private (private var lambda: Double) extends Serializable with
    * @param data RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    */
   def run(data: RDD[LabeledPoint]) = {
+    val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
+      val values = v match {
+        case sv: SparseVector =>
+          sv.values
+        case dv: DenseVector =>
+          dv.values
+      }
+      if (!values.forall(_ >= 0.0)) {
+        throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.")
+      }
+    }
+
     // Aggregates term frequencies per label.
     // TODO: Calling combineByKey and collect creates two stages, we can implement something
     // TODO: similar to reduceByKeyLocally to save one stage.
     val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])](
-      createCombiner = (v: Vector) => (1L, v.toBreeze.toDenseVector),
-      mergeValue = (c: (Long, BDV[Double]), v: Vector) => (c._1 + 1L, c._2 += v.toBreeze),
+      createCombiner = (v: Vector) => {
+        requireNonnegativeValues(v)
+        (1L, v.toBreeze.toDenseVector)
+      },
+      mergeValue = (c: (Long, BDV[Double]), v: Vector) => {
+        requireNonnegativeValues(v)
+        (c._1 + 1L, c._2 += v.toBreeze)
+      },
       mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) =>
         (c1._1 + c2._1, c1._2 += c2._2)
     ).collect()

http://git-wip-us.apache.org/repos/asf/spark/blob/068b6fe6/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index 06cdd04..80989bc 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -21,6 +21,7 @@ import scala.util.Random
 
 import org.scalatest.FunSuite
 
+import org.apache.spark.SparkException
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
@@ -95,6 +96,33 @@ class NaiveBayesSuite extends FunSuite with LocalSparkContext {
     // Test prediction on Array.
     validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
   }
+
+  test("detect negative values") {
+    val dense = Seq(
+      LabeledPoint(1.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(-1.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0)))
+    intercept[SparkException] {
+      NaiveBayes.train(sc.makeRDD(dense, 2))
+    }
+    val sparse = Seq(
+      LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))),
+      LabeledPoint(0.0, Vectors.sparse(1, Array(0), Array(-1.0))),
+      LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))),
+      LabeledPoint(1.0, Vectors.sparse(1, Array.empty, Array.empty)))
+    intercept[SparkException] {
+      NaiveBayes.train(sc.makeRDD(sparse, 2))
+    }
+    val nan = Seq(
+      LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))),
+      LabeledPoint(0.0, Vectors.sparse(1, Array(0), Array(Double.NaN))),
+      LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))),
+      LabeledPoint(1.0, Vectors.sparse(1, Array.empty, Array.empty)))
+    intercept[SparkException] {
+      NaiveBayes.train(sc.makeRDD(nan, 2))
+    }
+  }
 }
 
 class NaiveBayesClusterSuite extends FunSuite with LocalClusterSparkContext {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message