spark-reviews mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From yinxusen <...@git.apache.org>
Subject [GitHub] spark pull request: [SPARK-1543][MLlib] Add ADMM for solving Lasso...
Date Tue, 22 Apr 2014 00:12:03 GMT
Github user yinxusen commented on a diff in the pull request:

    https://github.com/apache/spark/pull/458#discussion_r11833024
  
    --- Diff: mllib/src/main/scala/org/apache/spark/mllib/optimization/ADMMLasso.scala ---
    @@ -0,0 +1,217 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.mllib.optimization
    +
    +import scala.collection.mutable.ArrayBuffer
    +
    +import breeze.linalg.{Vector => BV, DenseVector => BDV, DenseMatrix => BDM, cholesky, norm}
    +
    +import org.apache.spark.mllib.linalg.{Vectors, Vector}
    +import org.apache.spark.{Partitioner, HashPartitioner, Logging}
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.SparkContext._
    +import org.apache.spark.annotation.DeveloperApi
    +
    +
    +/**
    + * :: DeveloperApi ::
    + * Class used to solve the optimization problem in ADMMLasso
    + */
    +@DeveloperApi
    +class ADMMLasso
    +  extends Optimizer with Logging
    +{
    +  private var numPartitions: Int = 10
    +  private var numIterations: Int = 100
    +  private var l1RegParam: Double = 1.0
    +  private var l2RegParam: Double = .0
    +  private var penalty: Double = 10.0
    +
    +
    +  /**
    +   * Set the number of partitions for ADMM. Default 10
    +   */
    +  def setNumPartitions(parts: Int): this.type = {
    +    this.numPartitions = parts
    +    this
    +  }
    +
    +  /**
    +   * Set the number of iterations for ADMM. Default 100.
    +   */
    +  def setNumIterations(iters: Int): this.type = {
    +    this.numIterations = iters
    +    this
    +  }
    +
    +  /**
    +   * Set the l1-regularization parameter. Default 1.0.
    +   */
    +  def setL1RegParam(regParam: Double): this.type = {
    +    this.l1RegParam = regParam
    +    this
    +  }
    +
    +  /**
    +   * Set the l2-regularization parameter. Default .0
    +   */
    +  def setL2RegParam(regParam: Double): this.type = {
    +    this.l2RegParam = regParam
    +    this
    +  }
    +
    +  /**
    +   * Set the penalty parameter. Default 10.0
    +   */
    +  def setPenalty(penalty: Double): this.type = {
    +    this.penalty = penalty
    +    this
    +  }
    +
    +  def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = {
    +    val (weights, _) = ADMMLasso.runADMM(data, numPartitions, numIterations, l1RegParam,
    +      l2RegParam, penalty, initialWeights)
    +    weights
    +  }
    +
    +}
    +
    +/**
    + * :: DeveloperApi ::
    + * Top-level method to run ADMMLasso.
    + */
    +@DeveloperApi
    +object ADMMLasso extends Logging {
    +
    +  /**
    +   * @param data  Input data for ADMMLasso. RDD of the set of data examples, each of
    +   *               the form (label, [feature values]).
    +   * @param numPartitions  number of data blocks to partition the RDD into
    +   * @param numIterations  number of iterations that ADMM should be run.
    +   * @param l1RegParam  l1-regularization parameter
    +   * @param l2RegParam  l2-regularization parameter
    +   * @param penalty  The penalty parameter in ADMM
    +   * @param initialWeights  Initial set of weights to be used. Array should be equal in size to
    +   *        the number of features in the data.
    +   * @return A tuple containing two elements. The first element is a column matrix containing
    +   *         weights for every feature, and the second element is an array containing the loss
    +   *         computed for every iteration.
    +   */
    +  def runADMM(
    +      data: RDD[(Double, Vector)],
    +      numPartitions: Int,
    +      numIterations: Int,
    +      l1RegParam: Double,
    +      l2RegParam: Double,
    +      penalty: Double,
    +      initialWeights: Vector): (Vector, Array[Double]) = {
    +
    +    val lossHistory = new ArrayBuffer[Double](numIterations)
    +
    +    // Initialize weights as a column vector
    +    val p = initialWeights.size
    +
    +    // Consensus variable
    +    var z =  BDV.zeros[Double](p)
    +
    +    // Transform the input data into ADMM format
    +    def collectBlock(it: Iterator[(Vector, Double)]):
    +        Iterator[((BDV[Double], BDM[Double], BDM[Double]), (BDV[Double], BDV[Double]))] = {
    +      val lab = new ArrayBuffer[Double]()
    +      val features = new ArrayBuffer[Double]()
    +      var row = 0
    +      it.foreach {case (f, l) =>
    +        lab += l
    +        features ++= f.toArray
    +        row += 1
    +      }
    +      val col = features.length/row
    +
    +      val designMatrix = new BDM(col, features.toArray).t
    +
    +      // Precompute the cholesky decomposition for solving linear system inside each partition
    +      val chol = if (row >= col) {
    +        cholesky((designMatrix.t * designMatrix) + (BDM.eye[Double](col) :* penalty))
    +      }
    +      else cholesky(((designMatrix * designMatrix.t) :/ penalty) + BDM.eye[Double](row))
    +
    +      Iterator(((BDV(lab.toArray), designMatrix, chol),
    +        (BDV(initialWeights.toArray), BDV.zeros[Double](col))))
    +    }
    +
    +    val partitionedData = data.map{case (x, y) => (y, x)}
    +      .partitionBy(new HashPartitioner(numPartitions)).cache()
    +
    +    // ((lab, design, chol), (x, u))
    +    var dividedData = partitionedData.mapPartitions(collectBlock, true)
    +
    +    var iter = 1
    +    var minorChange: Boolean = false
    +    while (iter <= numIterations && !minorChange) {
    +      val zBroadcast = z
    +      def localUpdate(
    +           it: Iterator[((BDV[Double], BDM[Double], BDM[Double]), (BDV[Double], BDV[Double]))]):
    +           Iterator[((BDV[Double], BDM[Double], BDM[Double]), (BDV[Double], BDV[Double]))] = {
    +        if (it.hasNext) {
    +          val localData = it.next()
    +          val (x, u) = localData._2
    +          val updatedU = u + ((x - zBroadcast) :* penalty)
    +          // Update local x by solving linear system Ax = q
    +          val (lab, design, chol) = localData._1
    +          val (row, col) = (design.rows, design.cols)
    +          val q = (design.t * lab) + (zBroadcast :* penalty) - u
    +
    +          val updatedX = if (row >= col) {
    +            chol.t \ (chol \ q)
    +          }else {
    --- End diff --
    
    add a space after `}`


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

Mime
View raw message