From: meng@apache.org
To: commits@spark.apache.org
Subject: spark git commit: [SPARK-7612] [MLLIB] update NB training to use mllib's BLAS
Date: Thu, 14 May 2015 04:27:21 +0000 (UTC)

Repository: spark
Updated Branches:
  refs/heads/master 3113da9c7 -> d5f18de16


[SPARK-7612] [MLLIB] update NB training to use mllib's BLAS

This is similar to the changes to k-means, which give us better control over performance.

dbtsai

Author: Xiangrui Meng

Closes #6128 from mengxr/SPARK-7612 and squashes the following commits:

b5c24c5 [Xiangrui Meng] merge master
a90e3ec [Xiangrui Meng] update NB training to use mllib's BLAS


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d5f18de1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d5f18de1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d5f18de1

Branch: refs/heads/master
Commit: d5f18de1657bfabf5493011e0b2c7ec29c02c64c
Parents: 3113da9
Author: Xiangrui Meng
Authored: Wed May 13 21:27:17 2015 -0700
Committer: Xiangrui Meng
Committed: Wed May 13 21:27:17 2015 -0700

----------------------------------------------------------------------
 .../spark/mllib/classification/NaiveBayes.scala | 43 +++++++++-----------
 1 file changed, 20 insertions(+), 23 deletions(-)
----------------------------------------------------------------------
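The heart of the patch, visible in the diff below, is replacing Breeze's in-place
"c._2 += v.toBreeze" with mllib's own "BLAS.axpy(1.0, v, c._2)", which computes
y += a * x directly on mllib vectors and so avoids converting every record to a
Breeze vector. A minimal sketch of the axpy semantics follows; because
org.apache.spark.mllib.linalg.BLAS is private[spark], the local axpy here is an
illustrative stand-in, not Spark's implementation:

  import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}

  // Stand-in for the package-private BLAS.axpy(a, x, y): computes y += a * x in
  // place. A sparse x touches only its nonzero entries, which is what makes the
  // in-place accumulation cheap for sparse training data.
  def axpy(a: Double, x: Vector, y: DenseVector): Unit = x match {
    case sx: SparseVector =>
      var k = 0
      while (k < sx.indices.length) {
        y.values(sx.indices(k)) += a * sx.values(k)
        k += 1
      }
    case dx: DenseVector =>
      var i = 0
      while (i < dx.size) {
        y.values(i) += a * dx.values(i)
        i += 1
      }
  }

  val acc = Vectors.dense(1.0, 2.0, 0.0).toDense        // running per-label sum
  axpy(1.0, Vectors.sparse(3, Array(1), Array(3.0)), acc)
  // acc is now [1.0, 5.0, 0.0]: the same accumulator the new mergeValue builds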
http://git-wip-us.apache.org/repos/asf/spark/blob/d5f18de1/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index b381dc2..af24ab6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -21,15 +21,13 @@ import java.lang.{Iterable => JIterable}
 
 import scala.collection.JavaConverters._
 
-import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum, Axis}
+import breeze.linalg.{Axis, DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
 import breeze.numerics.{exp => brzExp, log => brzLog}
-
 import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
-import org.json4s.{DefaultFormats, JValue}
 
 import org.apache.spark.{Logging, SparkContext, SparkException}
-import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector}
+import org.apache.spark.mllib.linalg.{BLAS, DenseVector, SparseVector, Vector}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.util.{Loader, Saveable}
 import org.apache.spark.rdd.RDD
@@ -90,13 +88,13 @@ class NaiveBayesModel private[mllib] (
     val brzData = testData.toBreeze
     modelType match {
       case "Multinomial" =>
-        labels (brzArgmax (brzPi + brzTheta * brzData) )
+        labels(brzArgmax(brzPi + brzTheta * brzData))
       case "Bernoulli" =>
         if (!brzData.forall(v => v == 0.0 || v == 1.0)) {
           throw new SparkException(
             s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
         }
-        labels (brzArgmax (brzPi +
+        labels(brzArgmax(brzPi +
           (brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get))
       case _ =>
         // This should never happen.
@@ -152,7 +150,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
       // Check schema explicitly since erasure makes it hard to use match-case for checking.
       checkSchema[Data](dataRDD.schema)
       val dataArray = dataRDD.select("labels", "pi", "theta", "modelType").take(1)
-      assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
+      assert(dataArray.length == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
       val data = dataArray(0)
       val labels = data.getAs[Seq[Double]](0).toArray
       val pi = data.getAs[Seq[Double]](1).toArray
@@ -198,7 +196,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
       // Check schema explicitly since erasure makes it hard to use match-case for checking.
       checkSchema[Data](dataRDD.schema)
       val dataArray = dataRDD.select("labels", "pi", "theta").take(1)
-      assert(dataArray.size == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
+      assert(dataArray.length == 1, s"Unable to load NaiveBayesModel data from: ${dataPath(path)}")
       val data = dataArray(0)
       val labels = data.getAs[Seq[Double]](0).toArray
       val pi = data.getAs[Seq[Double]](1).toArray
@@ -288,10 +286,8 @@ class NaiveBayes private (
   def run(data: RDD[LabeledPoint]): NaiveBayesModel = {
     val requireNonnegativeValues: Vector => Unit = (v: Vector) => {
       val values = v match {
-        case SparseVector(size, indices, values) =>
-          values
-        case DenseVector(values) =>
-          values
+        case sv: SparseVector => sv.values
+        case dv: DenseVector => dv.values
       }
       if (!values.forall(_ >= 0.0)) {
         throw new SparkException(s"Naive Bayes requires nonnegative feature values but found $v.")
@@ -300,10 +296,8 @@ class NaiveBayes private (
 
     val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => {
       val values = v match {
-        case SparseVector(size, indices, values) =>
-          values
-        case DenseVector(values) =>
-          values
+        case sv: SparseVector => sv.values
+        case dv: DenseVector => dv.values
       }
       if (!values.forall(v => v == 0.0 || v == 1.0)) {
         throw new SparkException(
@@ -314,21 +308,24 @@ class NaiveBayes private (
     // Aggregates term frequencies per label.
     // TODO: Calling combineByKey and collect creates two stages, we can implement something
     // TODO: similar to reduceByKeyLocally to save one stage.
-    val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])](
+    val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, DenseVector)](
       createCombiner = (v: Vector) => {
         if (modelType == "Bernoulli") {
           requireZeroOneBernoulliValues(v)
         } else {
           requireNonnegativeValues(v)
         }
-        (1L, v.toBreeze.toDenseVector)
+        (1L, v.copy.toDense)
       },
-      mergeValue = (c: (Long, BDV[Double]), v: Vector) => {
+      mergeValue = (c: (Long, DenseVector), v: Vector) => {
         requireNonnegativeValues(v)
-        (c._1 + 1L, c._2 += v.toBreeze)
+        BLAS.axpy(1.0, v, c._2)
+        (c._1 + 1L, c._2)
       },
-      mergeCombiners = (c1: (Long, BDV[Double]), c2: (Long, BDV[Double])) =>
-        (c1._1 + c2._1, c1._2 += c2._2)
+      mergeCombiners = (c1: (Long, DenseVector), c2: (Long, DenseVector)) => {
+        BLAS.axpy(1.0, c2._2, c1._2)
+        (c1._1 + c2._1, c1._2)
+      }
     ).collect()
 
     val numLabels = aggregated.length
@@ -348,7 +345,7 @@ class NaiveBayes private (
       labels(i) = label
       pi(i) = math.log(n + lambda) - piLogDenom
       val thetaLogDenom = modelType match {
-        case "Multinomial" => math.log(brzSum(sumTermFreqs) + numFeatures * lambda)
+        case "Multinomial" => math.log(sumTermFreqs.values.sum + numFeatures * lambda)
         case "Bernoulli" => math.log(n + 2.0 * lambda)
         case _ =>
           // This should never happen.
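For completeness, a minimal usage sketch of the trainer this file implements.
The tiny dataset, the appName, and the local[2] master are illustrative only;
lambda = 1.0 is the default smoothing, and the three-argument train overload is
assumed to be the modelType-aware one that backs the run() method above:

  import org.apache.spark.mllib.classification.NaiveBayes
  import org.apache.spark.mllib.linalg.Vectors
  import org.apache.spark.mllib.regression.LabeledPoint
  import org.apache.spark.{SparkConf, SparkContext}

  val sc = new SparkContext(new SparkConf().setAppName("nb-demo").setMaster("local[2]"))

  // Multinomial NB accepts any nonnegative features (e.g. term frequencies);
  // Bernoulli NB requires strictly 0/1 features, as run() enforces above.
  val data = sc.parallelize(Seq(
    LabeledPoint(0.0, Vectors.dense(2.0, 0.0, 0.0)),
    LabeledPoint(0.0, Vectors.sparse(3, Array(0), Array(3.0))),
    LabeledPoint(1.0, Vectors.dense(0.0, 1.0, 2.0))
  ))

  val model = NaiveBayes.train(data, lambda = 1.0, modelType = "Multinomial")
  println(model.predict(Vectors.dense(1.0, 0.0, 0.0)))  // expected: 0.0

  sc.stop()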
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org