spark-commits mailing list archives

From: sro...@apache.org
Subject: spark git commit: [SPARK-19054][ML] Eliminate extra pass in NB
Date: Wed, 04 Jan 2017 11:54:17 GMT
Repository: spark
Updated Branches:
  refs/heads/master 101556d0f -> 7a8250581


[SPARK-19054][ML] Eliminate extra pass in NB

## What changes were proposed in this pull request?
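Eliminate an unnecessary extra pass over the data in NaiveBayes training: call `getNumClasses` only when `thresholds` is defined, because its result is used only to validate the length of `thresholds`.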
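
The effect of the reordered guard can be illustrated without Spark. The sketch below is a hypothetical stand-in, not the Spark implementation: `GuardedPassSketch`, `numClassesByScan`, and the `passes` counter are names invented for illustration, and the scan assumes labels are 0-based class indices. It only shows that placing the cheap `thresholds` check in the `if` condition skips the scan entirely when no thresholds are set.

```scala
object GuardedPassSketch {
  // Counts how many times the (hypothetical) expensive scan runs.
  var passes = 0

  // Stand-in for getNumClasses: a full scan over the labels.
  // Assumes labels are 0-based class indices stored as doubles.
  def numClassesByScan(labels: Seq[Double]): Int = {
    passes += 1
    labels.max.toInt + 1
  }

  // Mirrors the shape of the patched guard: scan only when thresholds is defined.
  def checkThresholds(
      labels: Seq[Double],
      positiveLabel: Boolean,
      thresholds: Option[Array[Double]]): Unit = {
    if (positiveLabel && thresholds.isDefined) {
      val numClasses = numClassesByScan(labels)
      require(thresholds.get.length == numClasses,
        s"numClasses=$numClasses, but thresholds has length ${thresholds.get.length}")
    }
  }
}
```

Calling `checkThresholds` with `thresholds = None` leaves `passes` unchanged, whereas the pre-patch shape (scan first, then check `thresholds`) would have incremented it on every call with `positiveLabel = true`.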

## How was this patch tested?
Existing tests.

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #16453 from zhengruifeng/nb_getNC.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7a825058
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7a825058
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7a825058

Branch: refs/heads/master
Commit: 7a82505817d479007adff6424473063d2003fcc1
Parents: 101556d
Author: Zheng RuiFeng <ruifengz@foxmail.com>
Authored: Wed Jan 4 11:54:13 2017 +0000
Committer: Sean Owen <sowen@cloudera.com>
Committed: Wed Jan 4 11:54:13 2017 +0000

----------------------------------------------------------------------
 .../org/apache/spark/ml/classification/NaiveBayes.scala   | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7a825058/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 94ee2a2..e90040d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -127,13 +127,11 @@ class NaiveBayes @Since("1.5.0") (
   private[spark] def trainWithLabelCheck(
       dataset: Dataset[_],
       positiveLabel: Boolean): NaiveBayesModel = {
-    if (positiveLabel) {
+    if (positiveLabel && isDefined(thresholds)) {
       val numClasses = getNumClasses(dataset)
-      if (isDefined(thresholds)) {
-        require($(thresholds).length == numClasses, this.getClass.getSimpleName +
-          ".train() called with non-matching numClasses and thresholds.length." +
-          s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
-      }
+      require($(thresholds).length == numClasses, this.getClass.getSimpleName +
+        ".train() called with non-matching numClasses and thresholds.length." +
+        s" numClasses=$numClasses, but thresholds has length ${$(thresholds).length}")
     }
 
     val modelTypeValue = $(modelType)

