spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-7612] [MLLIB] update NB training to use mllib's BLAS
Date Thu, 14 May 2015 04:27:21 GMT
Repository: spark
Updated Branches:
  refs/heads/master 3113da9c7 -> d5f18de16


[SPARK-7612] [MLLIB] update NB training to use mllib's BLAS

This is similar to the changes to k-means, which give us better control over performance.
dbtsai

Author: Xiangrui Meng <meng@databricks.com>

Closes #6128 from mengxr/SPARK-7612 and squashes the following commits:

b5c24c5 [Xiangrui Meng] merge master
a90e3ec [Xiangrui Meng] update NB training to use mllib's BLAS


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d5f18de1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d5f18de1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d5f18de1

Branch: refs/heads/master
Commit: d5f18de1657bfabf5493011e0b2c7ec29c02c64c
Parents: 3113da9
Author: Xiangrui Meng <meng@databricks.com>
Authored: Wed May 13 21:27:17 2015 -0700
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Wed May 13 21:27:17 2015 -0700

----------------------------------------------------------------------
 .../spark/mllib/classification/NaiveBayes.scala | 43 +++++++++-----------
 1 file changed, 20 insertions(+), 23 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d5f18de1/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index b381dc2..af24ab6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -21,15 +21,13 @@ import java.lang.{Iterable => JIterable}
 
 import scala.collection.JavaConverters._
 
-import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis}
+import breeze.linalg.{Axis, DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
 import breeze.numerics.{exp => brzExp, log => brzLog}
-
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
-import org.json4s.{DefaultFormats, JValue}
 
 import org.apache.spark.{Logging, SparkContext, SparkException}
-import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
+import org.apache.spark.mllib.linalg.{BLAS, DenseVector, SparseVector, Vector}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
@@ -90,13 +88,13 @@ class NaiveBayesModel private[mllib] (
     val brzData = testData.toBreeze
     modelType match {
       case "Multinomial" =>
-        labels (brzArgmax (brzPi + brzTheta * brzData) )
+        labels(brzArgmax(brzPi + brzTheta * brzData))
       case "Bernoulli" =>
         if (!brzData.forall(v => v == 0.0 || v == 1.0)) {
           throw new SparkException(
             s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
         }
-        labels (brzArgmax (brzPi +
+        labels(brzArgmax(brzPi +
           (brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get))
       case _ =>
         // This should never happen.
@@ -152,7 +150,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
       // Check schema explicitly since erasure makes it hard to use match-case for checking.
       checkSchema[Data](dataRDD.schema)
       val dataArray = dataRDD.select("labels", "pi", "theta", "modelType").take(1)
-      assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
+      assert(dataArray.length == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
       val data = dataArray(0)
       val labels = data.getAs[Seq[Double]](0).toArray
       val pi = data.getAs[Seq[Double]](1).toArray
@@ -198,7 +196,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
       // Check schema explicitly since erasure makes it hard to use match-case for checking.
       checkSchema[Data](dataRDD.schema)
       val dataArray = dataRDD.select("labels", "pi", "theta").take(1)
-      assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
+      assert(dataArray.length == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
       val data = dataArray(0)
       val labels = data.getAs[Seq[Double]](0).toArray
       val pi = data.getAs[Seq[Double]](1).toArray
@@ -288,10 +286,8 @@ class NaiveBayes private (
   def run(data: RDD[LabeledPoint]): NaiveBayesModel = {
     val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
       val values = v match {
-        case SparseVector(size, indices, values) =>
-          values
-        case DenseVector(values) =>
-          values
+        case sv: SparseVector => sv.values
+        case dv: DenseVector => dv.values
       }
       if (!values.forall(_ >= 0.0)) {
         throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.")
@@ -300,10 +296,8 @@ class NaiveBayes private (
 
     val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => {
       val values = v match {
-        case SparseVector(size, indices, values) =>
-          values
-        case DenseVector(values) =>
-          values
+        case sv: SparseVector => sv.values
+        case dv: DenseVector => dv.values
       }
       if (!values.forall(v => v == 0.0 || v == 1.0)) {
         throw new SparkException(
@@ -314,21 +308,24 @@ class NaiveBayes private (
     // Aggregates term frequencies per label.
     // TODO: Calling combineByKey and collect creates two stages, we can implement something
     // TODO: similar to reduceByKeyLocally to save one stage.
-    val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])](
+    val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)](
       createCombiner = (v: Vector) => {
         if (modelType == "Bernoulli") {
           requireZeroOneBernoulliValues(v)
         } else {
           requireNonnegativeValues(v)
         }
-        (1L, v.toBreeze.toDenseVector)
+        (1L, v.copy.toDense)
       },
-      mergeValue = (c: (Long, BDV[Double]), v: Vector) => {
+      mergeValue = (c: (Long, DenseVector), v: Vector) => {
         requireNonnegativeValues(v)
-        (c._1 + 1L, c._2 += v.toBreeze)
+        BLAS.axpy(1.0, v, c._2)
+        (c._1 + 1L, c._2)
       },
-      mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) =>
-        (c1._1 + c2._1, c1._2 += c2._2)
+      mergeCombiners = (c1: (Long, DenseVector), c2: (Long, DenseVector)) => {
+        BLAS.axpy(1.0, c2._2, c1._2)
+        (c1._1 + c2._1, c1._2)
+      }
     ).collect()
 
     val numLabels = aggregated.length
@@ -348,7 +345,7 @@ class NaiveBayes private (
       labels(i) = label
       pi(i) = math.log(n + lambda) - piLogDenom
       val thetaLogDenom = modelType match {
-        case "Multinomial" => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
+        case "Multinomial" => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
         case "Bernoulli" => math.log(n + 2.0 * lambda)
         case _ =>
           // This should never happen.


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message