spark-reviews mailing list archives

From jkbradley <...@git.apache.org>
Subject [GitHub] spark pull request #15211: [SPARK-14709][ML] spark.ml API for linear SVM
Date Wed, 28 Dec 2016 00:10:46 GMT
Github user jkbradley commented on a diff in the pull request:

    https://github.com/apache/spark/pull/15211#discussion_r93972268
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala ---
    @@ -0,0 +1,558 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.ml.classification
    +
    +import scala.collection.mutable
    +
    +import breeze.linalg.{DenseVector => BDV}
    +import breeze.optimize.{CachedDiffFunction, DiffFunction, OWLQN => BreezeOWLQN}
    +import org.apache.hadoop.fs.Path
    +
    +import org.apache.spark.SparkException
    +import org.apache.spark.annotation.{Experimental, Since}
    +import org.apache.spark.broadcast.Broadcast
    +import org.apache.spark.internal.Logging
    +import org.apache.spark.ml.feature.Instance
    +import org.apache.spark.ml.linalg._
    +import org.apache.spark.ml.linalg.BLAS._
    +import org.apache.spark.ml.param._
    +import org.apache.spark.ml.param.shared._
    +import org.apache.spark.ml.util._
    +import org.apache.spark.mllib.linalg.VectorImplicits._
    +import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
    +import org.apache.spark.rdd.RDD
    +import org.apache.spark.sql.{Dataset, Row}
    +import org.apache.spark.sql.functions.{col, lit}
    +
    +/** Params for linear SVM Classifier. */
    +private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam
    +  with HasMaxIter with HasFitIntercept with HasTol with HasStandardization with HasWeightCol
    +  with HasThreshold with HasAggregationDepth {
    +
    +}
    +
    +/**
    + * :: Experimental ::
    + * Linear SVM Classifier with Hinge Loss and OWLQN optimizer
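    + *
    + * A hypothetical usage sketch (assumes a DataFrame `training` with "label" and
    + * "features" columns; parameter values are illustrative only):
    + * {{{
    + *   val svc = new LinearSVC().setRegParam(0.1).setMaxIter(50)
    + *   val model = svc.fit(training)
    + * }}}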
    + */
    +@Since("2.2.0")
    +@Experimental
    +class LinearSVC @Since("2.2.0") (
    +    @Since("2.2.0") override val uid: String)
    +  extends Classifier[Vector, LinearSVC, LinearSVCModel]
    +  with LinearSVCParams with DefaultParamsWritable {
    +
    +  @Since("2.2.0")
    +  def this() = this(Identifiable.randomUID("linearsvc"))
    +
    +  /**
    +   * Set the regularization parameter.
    +   * Default is 0.0.
    +   *
    +   * @group setParam
    +   */
    +  @Since("2.2.0")
    +  def setRegParam(value: Double): this.type = set(regParam, value)
    +  setDefault(regParam -> 0.0)
    +
    +  /**
    +   * Set the maximum number of iterations.
    +   * Default is 100.
    +   *
    +   * @group setParam
    +   */
    +  @Since("2.2.0")
    +  def setMaxIter(value: Int): this.type = set(maxIter, value)
    +  setDefault(maxIter -> 100)
    +
    +  /**
    +   * Whether to fit an intercept term.
    +   * Default is true.
    +   *
    +   * @group setParam
    +   */
    +  @Since("2.2.0")
    +  def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
    +  setDefault(fitIntercept -> true)
    +
    +  /**
    +   * Set the convergence tolerance of iterations.
    +   * Smaller value will lead to higher accuracy at the cost of more iterations.
    +   * Default is 1E-6.
    +   *
    +   * @group setParam
    +   */
    +  @Since("2.2.0")
    +  def setTol(value: Double): this.type = set(tol, value)
    +  setDefault(tol -> 1E-6)
    +
    +  /**
    +   * Whether to standardize the training features before fitting the model.
    +   * Default is true.
    +   *
    +   * @group setParam
    +   */
    +  @Since("2.2.0")
    +  def setStandardization(value: Boolean): this.type = set(standardization, value)
    +  setDefault(standardization -> true)
    +
    +  /**
    +   * Sets the value of param [[weightCol]].
    +   * If this is not set or empty, we treat all instance weights as 1.0.
    +   * Default is not set, so all instances have weight one.
    +   *
    +   * @group setParam
    +   */
    +  @Since("2.2.0")
    +  def setWeightCol(value: String): this.type = set(weightCol, value)
    +
    +  /**
    +   * Set the threshold applied to the prediction margin in binary classification.
    +   * With the default of 0.0, an instance is predicted as class 1 exactly when
    +   * its margin is positive.
    +   *
    +   * @group setParam
    +   */
    +  @Since("2.2.0")
    +  def setThreshold(value: Double): this.type = set(threshold, value)
    +  setDefault(threshold -> 0.0)
    +
    +  /**
    +   * Suggested depth for treeAggregate (greater than or equal to 2).
    +   * If the dimensions of features or the number of partitions are large,
    +   * this param could be adjusted to a larger size.
    +   * Default is 2.
    +   *
    +   * @group expertSetParam
    +   */
    +  @Since("2.2.0")
    +  def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
    +  setDefault(aggregationDepth -> 2)
    +
    +  @Since("2.2.0")
    +  override def copy(extra: ParamMap): LinearSVC = defaultCopy(extra)
    +
    +  /**
    +   * Linear SVM Classifier (https://en.wikipedia.org/wiki/Support_vector_machine#Linear_SVM)
    +   *
    +   * This binary classifier optimizes the Hinge Loss using the OWLQN optimizer.
    +   *
    +   * @param dataset Training dataset
    +   * @return Fitted model
    +   */
    +  override protected def train(dataset: Dataset[_]): LinearSVCModel = {
    +    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
    +    val instances: RDD[Instance] =
    +      dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
    +        case Row(label: Double, weight: Double, features: Vector) =>
    +          Instance(label, weight, features)
    +      }
    +
    +    val instr = Instrumentation.create(this, instances)
    +    instr.logParams(params: _*)
    +
    +    val (summarizer, labelSummarizer) = {
    +      val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer),
    +        instance: Instance) =>
    +          (c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight))
    +
    +      val combOp = (c1: (MultivariateOnlineSummarizer, MultiClassSummarizer),
    +        c2: (MultivariateOnlineSummarizer, MultiClassSummarizer)) =>
    +          (c1._1.merge(c2._1), c1._2.merge(c2._2))
    +
    +      instances.treeAggregate(
    +        new MultivariateOnlineSummarizer, new MultiClassSummarizer
    +      )(seqOp, combOp, $(aggregationDepth))
    +    }
    +
    +    val histogram = labelSummarizer.histogram
    +    val numInvalid = labelSummarizer.countInvalid
    +    val numFeatures = summarizer.mean.size
    +    val numFeaturesPlusIntercept = if (getFitIntercept) numFeatures + 1 else numFeatures
    +
    +    val numClasses = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
    +      case Some(n: Int) =>
    +        require(n >= histogram.length, s"Specified number of classes $n was " +
    +          s"less than the number of unique labels ${histogram.length}.")
    +        n
    +      case None => histogram.length
    +    }
    +    require(numClasses == 2, "LinearSVC only supports binary classification.")
    +    instr.logNumClasses(numClasses)
    +    instr.logNumFeatures(numFeatures)
    +
    +    val (coefficientMatrix, interceptVector, objectiveHistory) = {
    +      if (numInvalid != 0) {
    +        val msg = s"Classification labels should be in [0 to ${numClasses - 1}]. " +
    +          s"Found $numInvalid invalid labels."
    +        logError(msg)
    +        throw new SparkException(msg)
    +      }
    +
    +      val featuresStd = summarizer.variance.toArray.map(math.sqrt)
    +      val regParamL2 = $(regParam)
    +      val bcFeaturesStd = instances.context.broadcast(featuresStd)
    +      val costFun = new LinearSVCCostFun(instances, numClasses, $(fitIntercept),
    +        $(standardization), bcFeaturesStd, regParamL2, $(aggregationDepth))
    +
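    +      // With an L1 penalty that is identically zero, OWLQN reduces to plain
    +      // L-BFGS; presumably OWLQN is used here to leave room for adding L1
    +      // regularization later.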
    +      def regParamL1Fun = (index: Int) => 0D
    +      val optimizer = new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun,
    +        $(tol))
    +      val initialCoefficientsWithIntercept = Vectors.zeros(numFeaturesPlusIntercept)
    +      if ($(fitIntercept)) {
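    +        // Start the intercept at the log of the label-count ratio (the prior
    +        // log-odds) so that the optimizer begins from a sensible baseline.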
    +        initialCoefficientsWithIntercept.toArray(numFeatures) = math.log(
    +          histogram(1) / histogram(0))
    +      }
    +
    +      val states = optimizer.iterations(new CachedDiffFunction(costFun),
    +        initialCoefficientsWithIntercept.asBreeze.toDenseVector)
    +
    +      val adjustedValues = mutable.ArrayBuilder.make[Double]
    +      var state: optimizer.State = null
    +      while (states.hasNext) {
    +        state = states.next()
    +        adjustedValues += state.adjustedValue
    +      }
    +
    +      bcFeaturesStd.destroy(blocking = false)
    +      if (state == null) {
    +        val msg = s"${optimizer.getClass.getName} failed."
    +        logError(msg)
    +        throw new SparkException(msg)
    +      }
    +
    +      /*
    +         The coefficients are trained in the scaled space; we're converting them back to
    +         the original space.
    +         Note that the intercept in scaled space and original space is the same;
    +         as a result, no scaling is needed.
    +       */
    +      val rawCoefficients = state.x.toArray.clone()
    +      val coefficientArray = Array.tabulate(numFeatures) { i =>
    +        // flatIndex will loop through rawCoefficients, and skip the intercept terms.
    +        val flatIndex = if ($(fitIntercept)) i + i / numFeatures else i
    +        val featureIndex = i % numFeatures
    +        if (featuresStd(featureIndex) != 0.0) {
    +          rawCoefficients(flatIndex) / featuresStd(featureIndex)
    +        } else {
    +          0.0
    +        }
    +      }
    +
    +      val intercept = if ($(fitIntercept)) {
    +        rawCoefficients(numFeaturesPlusIntercept - 1)
    +      } else {
    +        0.0
    +      }
    +      (Vectors.dense(coefficientArray), intercept, adjustedValues.result())
    +    }
    +
    +    val model = copyValues(new LinearSVCModel(uid, coefficientMatrix, interceptVector))
    +    instr.logSuccess(model)
    +    model
    +  }
    +}
    +
    +@Since("2.2.0")
    +object LinearSVC extends DefaultParamsReadable[LinearSVC] {
    +
    +  @Since("2.2.0")
    +  override def load(path: String): LinearSVC = super.load(path)
    +}
    +
    +/**
    + * :: Experimental ::
    + * SVM Model trained by [[LinearSVC]]
    + */
    +@Since("2.2.0")
    +@Experimental
    +class LinearSVCModel private[classification] (
    +    @Since("2.2.0") override val uid: String,
    +    @Since("2.2.0") val coefficients: Vector,
    +    @Since("2.2.0") val intercept: Double)
    +  extends ClassificationModel[Vector, LinearSVCModel]
    +  with LinearSVCParams with MLWritable {
    +
    +  @Since("2.2.0")
    +  override val numClasses: Int = 2
    +
    +  @Since("2.2.0")
    +  override val numFeatures: Int = coefficients.size
    +
    +  @Since("2.2.0")
    +  def setThreshold(value: Double): this.type = set(threshold, value)
    +
    +  @Since("2.2.0")
    +  def setWeightCol(value: String): this.type = set(weightCol, value)
    +
    +  private val margin: Vector => Double = (features) => {
    +    BLAS.dot(features, coefficients) + intercept
    +  }
    +
    +  override protected def predict(features: Vector): Double = {
    +    if (margin(features) > $(threshold)) 1.0 else 0.0
    +  }
    +
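    +  // Raw prediction is (-margin, margin): the component for class 1 is the margin itself.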
    +  override protected def predictRaw(features: Vector): Vector = {
    +    val m = margin(features)
    +    Vectors.dense(-m, m)
    +  }
    +
    +  @Since("2.2.0")
    +  override def copy(extra: ParamMap): LinearSVCModel = {
    +    copyValues(new LinearSVCModel(uid, coefficients, intercept), extra).setParent(parent)
    +  }
    +
    +  @Since("2.2.0")
    +  override def write: MLWriter = new LinearSVCModel.LinearSVCWriter(this)
    +
    +}
    +
    +@Since("2.2.0")
    +object LinearSVCModel extends MLReadable[LinearSVCModel] {
    +
    +  @Since("2.2.0")
    +  override def read: MLReader[LinearSVCModel] = new LinearSVCReader
    +
    +  @Since("2.2.0")
    +  override def load(path: String): LinearSVCModel = super.load(path)
    +
    +  /** [[MLWriter]] instance for [[LinearSVCModel]] */
    +  private[LinearSVCModel]
    +  class LinearSVCWriter(instance: LinearSVCModel)
    +    extends MLWriter with Logging {
    +
    +    private case class Data(coefficients: Vector, intercept: Double)
    +
    +    override protected def saveImpl(path: String): Unit = {
    +      // Save metadata and Params
    +      DefaultParamsWriter.saveMetadata(instance, path, sc)
    +      val data = Data(instance.coefficients, instance.intercept)
    +      val dataPath = new Path(path, "data").toString
    +      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
    +    }
    +  }
    +
    +  private class LinearSVCReader extends MLReader[LinearSVCModel] {
    +
    +    /** Checked against metadata when loading model */
    +    private val className = classOf[LinearSVCModel].getName
    +
    +    override def load(path: String): LinearSVCModel = {
    +      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
    +      val dataPath = new Path(path, "data").toString
    +      val data = sparkSession.read.format("parquet").load(dataPath)
    +      val Row(coefficients: Vector, intercept: Double) =
    +        data.select("coefficients", "intercept").head()
    +      val model = new LinearSVCModel(metadata.uid, coefficients, intercept)
    +      DefaultParamsReader.getAndSetParams(model, metadata)
    +      model
    +    }
    +  }
    +}
    +
    +/**
    + * LinearSVCCostFun implements Breeze's DiffFunction[T] for the hinge loss function.
    + */
    +private class LinearSVCCostFun(
    +    instances: RDD[Instance],
    +    numClasses: Int,
    +    fitIntercept: Boolean,
    +    standardization: Boolean,
    +    bcFeaturesStd: Broadcast[Array[Double]],
    +    regParamL2: Double,
    +    aggregationDepth: Int) extends DiffFunction[BDV[Double]] {
    +
    +  override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
    +    val coeffs = Vectors.fromBreeze(coefficients)
    +    val bcCoeffs = instances.context.broadcast(coeffs)
    +    val featuresStd = bcFeaturesStd.value
    +    val numFeatures = featuresStd.length
    +
    +    val svmAggregator = {
    +      val seqOp = (c: LinearSVCAggregator, instance: Instance) => c.add(instance)
    +      val combOp = (c1: LinearSVCAggregator, c2: LinearSVCAggregator) => c1.merge(c2)
    +
    +      instances.treeAggregate(
    +        new LinearSVCAggregator(bcCoeffs, bcFeaturesStd, numClasses, fitIntercept)
    +      )(seqOp, combOp, aggregationDepth)
    +    }
    +
    +    val totalGradientArray = svmAggregator.gradient.toArray
    +    // regVal is the sum of the squared coefficients, excluding the intercept, for L2 regularization.
    +    val regVal = if (regParamL2 == 0.0) {
    +      0.0
    +    } else {
    +      var sum = 0.0
    +      coeffs.foreachActive { case (index, value) =>
    +        // We do not apply regularization to the intercepts
    +        val isIntercept = fitIntercept && ((index + 1) % (numFeatures + 1) == 0)
    +        if (!isIntercept) {
    +          // The following code computes both the regularization loss and its
    +          // gradient, adding the gradient contribution back into totalGradientArray.
    +          sum += {
    +            if (standardization) {
    +              totalGradientArray(index) += regParamL2 * value
    +              value * value
    +            } else {
    +              val featureIndex = if (fitIntercept) {
    +                index % (numFeatures + 1)
    +              } else {
    +                index % numFeatures
    +              }
    +              if (featuresStd(featureIndex) != 0.0) {
    +                // If `standardization` is false, we still standardize the data
    +                // to improve the rate of convergence; as a result, we have to
    +                // perform this reverse standardization by penalizing each component
    +                // differently to get effectively the same objective function when
    +                // the training dataset is not standardized.
    +                val temp = value / (featuresStd(featureIndex) * featuresStd(featureIndex))
    +                totalGradientArray(index) += regParamL2 * temp
    +                value * temp
    +              } else {
    +                0.0
    +              }
    +            }
    +          }
    +        }
    +      }
    +      0.5 * regParamL2 * sum
    +    }
    +    bcCoeffs.destroy(blocking = false)
    +
    +    (svmAggregator.loss + regVal, new BDV(totalGradientArray))
    +  }
    +}
    +
    +/**
    + * LinearSVCAggregator computes the gradient and loss for the hinge loss function,
    + * as used in binary classification, for instances in sparse or dense vector format,
    + * in an online fashion.
    + *
    + * Two LinearSVCAggregators can be merged together to produce a summary of the loss
    + * and gradient of the corresponding joint dataset.
    + *
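    + * For a binary label y in {0, 1} and margin m = w . x + b, the hinge loss is
    + * max(0, 1 - (2y - 1) * m), scaled by the instance weight.
    + *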
    + * @param bcCoefficients The coefficients corresponding to the features.
    + * @param fitIntercept Whether to fit an intercept term.
    + * @param bcFeaturesStd The standard deviation values of the features.
    + */
    +private class LinearSVCAggregator(
    +    bcCoefficients: Broadcast[Vector],
    +    bcFeaturesStd: Broadcast[Array[Double]],
    +    numClasses: Int,
    +    fitIntercept: Boolean) extends Serializable {
    +
    +  private val numFeatures = bcFeaturesStd.value.length
    +  private val numFeaturesPlusIntercept = if (fitIntercept) numFeatures + 1 else numFeatures
    +  private val coefficients = bcCoefficients.value
    +  private var weightSum = 0.0
    --- End diff --
    
    I know it's not done elsewhere, but I'd like us to put explicit types for class
    fields to avoid type errors in the future.
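    
    For example, a minimal sketch of the pattern on the fields above (constructor
    arguments elided with ...):
    
        private class LinearSVCAggregator(...) extends Serializable {
          private val numFeatures: Int = bcFeaturesStd.value.length
          private val numFeaturesPlusIntercept: Int =
            if (fitIntercept) numFeatures + 1 else numFeatures
          private val coefficients: Vector = bcCoefficients.value
          private var weightSum: Double = 0.0
        }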


