spark-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From m...@apache.org
Subject spark git commit: [SPARK-13295][ ML, MLLIB ] AFTSurvivalRegression.AFTAggregator improvements - avoid creating new instances of arrays/vectors for each record
Date Tue, 23 Feb 2016 01:26:34 GMT
Repository: spark
Updated Branches:
  refs/heads/master 02b1fefff -> 33ef3aa7e


[SPARK-13295][ ML, MLLIB ] AFTSurvivalRegression.AFTAggregator improvements - avoid creating
new instances of arrays/vectors for each record

As noted by the TODO in the AFTAggregator.add(data: AFTPoint) method, a new array is created
for the intercept value on every call and concatenated with another array containing the betas.
The resulting array is then converted into a dense vector, which in turn is converted into a
Breeze vector. This is expensive and not particularly elegant.

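I've tried to solve the above-mentioned problem with a simple algebraic decomposition: the
intercept is stored and treated separately from the coefficients. Since beta = (intercept,
coefficients) and xi = (interceptFlag, features), beta.dot(xi) = intercept * interceptFlag +
coefficients.dot(features), so no concatenated vector needs to be built per record.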

Please let me know what you think and whether you have any questions.

Thanks,
Narine

Author: Narine Kokhlikyan <narine.kokhlikyan@gmail.com>

Closes #11179 from NarineK/survivaloptim.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/33ef3aa7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/33ef3aa7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/33ef3aa7

Branch: refs/heads/master
Commit: 33ef3aa7eabbe323620eb77fa94a53996ed0251d
Parents: 02b1fef
Author: Narine Kokhlikyan <narine.kokhlikyan@gmail.com>
Authored: Mon Feb 22 17:26:32 2016 -0800
Committer: Xiangrui Meng <meng@databricks.com>
Committed: Mon Feb 22 17:26:32 2016 -0800

----------------------------------------------------------------------
 .../ml/regression/AFTSurvivalRegression.scala   | 32 +++++++++++---------
 1 file changed, 17 insertions(+), 15 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/33ef3aa7/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index e8a1ff2..1e5b4cb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -437,23 +437,25 @@ object AFTSurvivalRegressionModel extends MLReadable[AFTSurvivalRegressionModel]
 private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
   extends Serializable {
 
-  // beta is the intercept and regression coefficients to the covariates
-  private val beta = parameters.slice(1, parameters.length)
+  // the regression coefficients to the covariates
+  private val coefficients = parameters.slice(2, parameters.length)
+  private val intercept = parameters.valueAt(1)
   // sigma is the scale parameter of the AFT model
   private val sigma = math.exp(parameters(0))
 
   private var totalCnt: Long = 0L
   private var lossSum = 0.0
-  private var gradientBetaSum = BDV.zeros[Double](beta.length)
+  private var gradientCoefficientSum = BDV.zeros[Double](coefficients.length)
+  private var gradientInterceptSum = 0.0
   private var gradientLogSigmaSum = 0.0
 
   def count: Long = totalCnt
 
   def loss: Double = if (totalCnt == 0) 1.0 else lossSum / totalCnt
 
-  // Here we optimize loss function over beta and log(sigma)
+  // Here we optimize loss function over coefficients, intercept and log(sigma)
   def gradient: BDV[Double] = BDV.vertcat(BDV(Array(gradientLogSigmaSum / totalCnt.toDouble)),
-    gradientBetaSum/totalCnt.toDouble)
+    BDV(Array(gradientInterceptSum/totalCnt.toDouble)), gradientCoefficientSum/totalCnt.toDouble)
 
   /**
    * Add a new training data to this AFTAggregator, and update the loss and gradient
@@ -464,15 +466,12 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
    */
   def add(data: AFTPoint): this.type = {
 
-    // TODO: Don't create a new xi vector each time.
-    val xi = if (fitIntercept) {
-      Vectors.dense(Array(1.0) ++ data.features.toArray).toBreeze
-    } else {
-      Vectors.dense(Array(0.0) ++ data.features.toArray).toBreeze
-    }
+    val interceptFlag = if (fitIntercept) 1.0 else 0.0
+
+    val xi = data.features.toBreeze
     val ti = data.label
     val delta = data.censor
-    val epsilon = (math.log(ti) - beta.dot(xi)) / sigma
+    val epsilon = (math.log(ti) - coefficients.dot(xi) - intercept * interceptFlag ) / sigma
 
     lossSum += math.log(sigma) * delta
     lossSum += (math.exp(epsilon) - delta * epsilon)
@@ -481,8 +480,10 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
     assert(!lossSum.isInfinity,
       s"AFTAggregator loss sum is infinity. Error for unknown reason.")
 
-    gradientBetaSum += xi * (delta - math.exp(epsilon)) / sigma
-    gradientLogSigmaSum += delta + (delta - math.exp(epsilon)) * epsilon
+    val deltaMinusExpEps = delta - math.exp(epsilon)
+    gradientCoefficientSum += xi * deltaMinusExpEps / sigma
+    gradientInterceptSum += interceptFlag * deltaMinusExpEps / sigma
+    gradientLogSigmaSum += delta + deltaMinusExpEps * epsilon
 
     totalCnt += 1
     this
@@ -501,7 +502,8 @@ private class AFTAggregator(parameters: BDV[Double], fitIntercept: Boolean)
       totalCnt += other.totalCnt
       lossSum += other.lossSum
 
-      gradientBetaSum += other.gradientBetaSum
+      gradientCoefficientSum += other.gradientCoefficientSum
+      gradientInterceptSum += other.gradientInterceptSum
       gradientLogSigmaSum += other.gradientLogSigmaSum
     }
     this


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org


Mime
View raw message