spark-reviews mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From mateiz <...@git.apache.org>
Subject [GitHub] spark pull request: [WIP] [SPARK-1328] Add vector statistics
Date Thu, 10 Apr 2014 05:13:28 GMT
Github user mateiz commented on a diff in the pull request:

    https://github.com/apache/spark/pull/268#discussion_r11470050
  
    --- Diff: mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala ---
    @@ -19,13 +19,144 @@ package org.apache.spark.mllib.linalg.distributed
     
     import java.util
     
    -import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, svd => brzSvd}
    +import breeze.linalg.{Vector => BV, DenseMatrix => BDM, DenseVector => BDV, svd => brzSvd}
     import breeze.numerics.{sqrt => brzSqrt}
     import com.github.fommil.netlib.BLAS.{getInstance => blas}
     
     import org.apache.spark.mllib.linalg._
     import org.apache.spark.rdd.RDD
     import org.apache.spark.Logging
    +import org.apache.spark.mllib.stat.MultivariateStatisticalSummary
    +
    +/**
    + * Column statistics aggregator implementing
    + * [[org.apache.spark.mllib.stat.MultivariateStatisticalSummary]]
    + * together with add() and merge() function.
    + * A numerically stable algorithm is implemented to compute sample mean and variance:
    +  *[[http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance variance-wiki]].
    + * Zero elements (including explicit zero values) are skipped when calling add() and merge(),
    + * to have time complexity O(nnz) instead of O(n) for each column.
    + */
    +private class ColumnStatisticsAggregator(private val n: Int)
    +    extends MultivariateStatisticalSummary with Serializable {
    +
    +  private val currMean: BDV[Double] = BDV.zeros[Double](n)
    +  private val currM2n: BDV[Double] = BDV.zeros[Double](n)
    +  private var totalCnt = 0.0
    +  private val nnz: BDV[Double] = BDV.zeros[Double](n)
    +  private val currMax: BDV[Double] = BDV.fill(n)(Double.MinValue)
    +  private val currMin: BDV[Double] = BDV.fill(n)(Double.MaxValue)
    +
    +  override def mean: Vector = {
    +    val realMean = BDV.zeros[Double](n)
    +    var i = 0
    +    while (i < n) {
    +      realMean(i) = currMean(i) * nnz(i) / totalCnt
    +      i += 1
    +    }
    +    Vectors.fromBreeze(realMean)
    +  }
    +
    +  override def variance: Vector = {
    +    val realVariance = BDV.zeros[Double](n)
    +
    +    val denominator = totalCnt - 1.0
    +
    +    // Sample variance is computed, if the denominator is 0, the variance is just 0.
    +    if (denominator != 0.0) {
    +      val deltaMean = currMean
    +      var i = 0
    +      while (i < currM2n.size) {
    +        realVariance(i) =
    +          currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * (totalCnt - nnz(i)) / totalCnt
    +        realVariance(i) /= denominator
    +        i += 1
    +      }
    +    }
    +
    +    Vectors.fromBreeze(realVariance)
    +  }
    +
    +  override def count: Long = totalCnt.toLong
    +
    +  override def numNonzeros: Vector = Vectors.fromBreeze(nnz)
    +
    +  override def max: Vector = {
    +    var i = 0
    +    while (i < n) {
    +      if ((nnz(i) < totalCnt) && (currMax(i) < 0.0)) currMax(i) = 0.0
    +      i += 1
    +    }
    +    Vectors.fromBreeze(currMax)
    +  }
    +
    +  override def min: Vector = {
    +    var i = 0
    +    while (i < n) {
    +      if ((nnz(i) < totalCnt) && (currMin(i) > 0.0)) currMin(i) = 0.0
    +      i += 1
    +    }
    +    Vectors.fromBreeze(currMin)
    +  }
    +
    +  /**
    +   * Aggregates a row.
    +   */
    +  def add(currData: BV[Double]): this.type = {
    +    currData.activeIterator.foreach {
    +      case (_, 0.0) => // Skip explicit zero elements.
    +      case (i, value) =>
    +        if (currMax(i) < value) currMax(i) = value
    +        if (currMin(i) > value) currMin(i) = value
    +
    +        val tmpPrevMean = currMean(i)
    +        currMean(i) = (currMean(i) * nnz(i) + value) / (nnz(i) + 1.0)
    +        currM2n(i) += (value - currMean(i)) * (value - tmpPrevMean)
    +
    +        nnz(i) += 1.0
    +    }
    +
    +    totalCnt += 1.0
    +    this
    +  }
    +
    +  /**
    +   * Merges another aggregator.
    +   */
    +  def merge(other: ColumnStatisticsAggregator): this.type = {
    +
    +    require(n == other.n, s"Dimensions mismatch. Expecting $n but got ${other.n}.")
    +
    +    totalCnt += other.totalCnt
    +
    +    val deltaMean = currMean - other.currMean
    +
    +    var i = 0
    +    while (i < n) {
    +      // merge mean together
    +      if (other.currMean(i) != 0.0) {
    +        currMean(i) = (currMean(i) * nnz(i) + other.currMean(i) * other.nnz(i)) /
    +          (nnz(i) + other.nnz(i))
    +      }
    +
    +      // merge m2n together
    +      if (nnz(i) + other.nnz(i) != 0.0) {
    +        currM2n(i) += other.currM2n(i) + deltaMean(i) * deltaMean(i) * nnz(i) * other.nnz(i) /
    +          (nnz(i) + other.nnz(i))
    +      }
    +
    +      if (currMax(i) < other.currMax(i)) currMax(i) = other.currMax(i)
    --- End diff --
    
    Use curly braces around the if bodies in case we want to extend them later, like this:
    ```
    if (currMax(i) < other.currMax(i)) {
      currMax(i) = other.currMax(i)
    }
    ```


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes to enable it, or if the feature is enabled but not working,
please contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

Mime
View raw message