flink-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From trohrm...@apache.org
Subject flink git commit: [FLINK-2102] [ml] Add predict function for labeled data for SVM and MLR.
Date Tue, 02 Jun 2015 11:25:03 GMT
Repository: flink
Updated Branches:
  refs/heads/master 7571959a1 -> d163a817f


[FLINK-2102] [ml] Add predict function for labeled data for SVM and MLR.

For each example in the input DataSet[LabeledVector], these functions return a
(truth, prediction) pair.

Added documentation for new predict functions

This closes #744.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/d163a817
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/d163a817
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/d163a817

Branch: refs/heads/master
Commit: d163a817fa2e330e86384d0bbcd104f051a6fb48
Parents: 7571959
Author: Theodore Vasiloudis <tvas@sics.se>
Authored: Thu May 28 18:51:17 2015 +0200
Committer: Till Rohrmann <trohrmann@apache.org>
Committed: Tue Jun 2 13:24:05 2015 +0200

----------------------------------------------------------------------
 docs/libs/ml/multiple_linear_regression.md      |  8 +++
 docs/libs/ml/svm.md                             |  8 +++
 .../apache/flink/ml/classification/SVM.scala    | 53 +++++++++++++++++-
 .../regression/MultipleLinearRegression.scala   | 58 +++++++++++++++++++-
 .../flink/ml/classification/SVMITSuite.scala    | 31 +++++++++++
 .../MultipleLinearRegressionITSuite.scala       | 24 ++++++++
 6 files changed, 178 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/d163a817/docs/libs/ml/multiple_linear_regression.md
----------------------------------------------------------------------
diff --git a/docs/libs/ml/multiple_linear_regression.md b/docs/libs/ml/multiple_linear_regression.md
index d9bc951..aaf1fbf 100644
--- a/docs/libs/ml/multiple_linear_regression.md
+++ b/docs/libs/ml/multiple_linear_regression.md
@@ -77,6 +77,14 @@ MultipleLinearRegression predicts for all subtypes of `Vector` the corresponding
 
 * `predict[T <: Vector]: DataSet[T] => DataSet[LabeledVector]`
 
+If we call predict with a `DataSet[LabeledVector]`, we make a prediction on the regression value
+for each example, and return a `DataSet[(Double, Double)]`. In each tuple the first element
+is the true value, as provided by the input `DataSet[LabeledVector]`, and the second element
+is the predicted value. You can then use these `(truth, prediction)` tuples to evaluate
+the algorithm's performance.
+
+* `predict: DataSet[LabeledVector] => DataSet[(Double, Double)]`
+
 ## Parameters
 
   The multiple linear regression implementation can be controlled by the following parameters:

http://git-wip-us.apache.org/repos/asf/flink/blob/d163a817/docs/libs/ml/svm.md
----------------------------------------------------------------------
diff --git a/docs/libs/ml/svm.md b/docs/libs/ml/svm.md
index a9c94ec..e649949 100644
--- a/docs/libs/ml/svm.md
+++ b/docs/libs/ml/svm.md
@@ -74,6 +74,14 @@ SVM predicts for all subtypes of `Vector` the corresponding class label:
 
 * `predict[T <: Vector]: DataSet[T] => DataSet[LabeledVector]`
 
+If we call predict with a `DataSet[LabeledVector]`, we make a prediction for each example
+and return a `DataSet[(Double, Double)]`. In each tuple the first element is the true label,
+as provided by the input `DataSet[LabeledVector]`, and the second element is the raw prediction
+value (the dot product of the weight vector and the example), whose sign gives the class label.
+You can then use these `(truth, prediction)` tuples to evaluate the algorithm's performance.
+
+* `predict: DataSet[LabeledVector] => DataSet[(Double, Double)]`
+
 ## Parameters
 
 The SVM implementation can be controlled by the following parameters:

http://git-wip-us.apache.org/repos/asf/flink/blob/d163a817/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala
b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala
index a186c5d..95f2b23 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala
@@ -33,7 +33,7 @@ import org.apache.flink.ml.math.Breeze._
 
 import breeze.linalg.{Vector => BreezeVector, DenseVector => BreezeDenseVector}
 
-/** Implements a soft-maring SVM using the communication-efficient distributed dual coordinate
+/** Implements a soft-margin SVM using the communication-efficient distributed dual coordinate
   * ascent algorithm (CoCoA) with hinge-loss function.
   *
   * The algorithm solves the following minimization problem:
@@ -276,6 +276,57 @@ object SVM{
     }
   }
 
+  /** [[org.apache.flink.ml.pipeline.PredictOperation]] for [[LabeledVector]] types. The result
+    * type is a `(Double, Double)` tuple, corresponding to (truth, prediction). The prediction is
+    * the raw value computed by [[LabeledPredictionMapper]] (weight vector dotted with the
+    * example's features); it is not thresholded into a class label.
+    *
+    * Calling predict on a model that has not been fitted throws a [[RuntimeException]].
+    *
+    * @return A DataSet[(Double, Double)] where each tuple is a (truth, prediction) pair.
+    */
+  implicit def predictLabeledValues: PredictOperation[SVM, LabeledVector, (Double, Double)] = {
+    new PredictOperation[SVM, LabeledVector, (Double, Double)] {
+      override def predict(
+          instance: SVM,
+          predictParameters: ParameterMap,
+          input: DataSet[LabeledVector])
+        : DataSet[(Double, Double)] = {
+
+        instance.weightsOption match {
+          case Some(weights) =>
+            // The mapper reads the weight vector from this broadcast set in its open() method.
+            input.map(new LabeledPredictionMapper).withBroadcastSet(weights, WEIGHT_VECTOR)
+
+          case None =>
+            // The trailing space keeps "fit" and "before" separated in the concatenated message.
+            throw new RuntimeException("The SVM model has not been trained. Call first fit " +
+              "before calling the predict operation.")
+        }
+      }
+    }
+  }
+
+  /** Pairs each example's true label with the model's raw prediction value. It is a
+    * RichMapFunction so that `open` can read the broadcast weight vector before any record
+    * is mapped.
+    */
+  class LabeledPredictionMapper extends RichMapFunction[LabeledVector, (Double, Double)] {
+
+    var weights: BreezeDenseVector[Double] = _
+
+    @throws(classOf[Exception])
+    override def open(configuration: Configuration): Unit = {
+      val broadcastWeights = getRuntimeContext
+        .getBroadcastVariable[BreezeDenseVector[Double]](WEIGHT_VECTOR)
+      // Only the first element of the broadcast set is used as the current weight vector.
+      weights = broadcastWeights.get(0)
+    }
+
+    override def map(labeledVector: LabeledVector): (Double, Double) = {
+      // Raw decision value (scaled distance from the separating hyperplane); no thresholding.
+      val features = labeledVector.vector.asBreeze
+      val rawPrediction = weights dot features
+      (labeledVector.label, rawPrediction)
+    }
+  }
+
+
   /** [[FitOperation]] which trains a SVM with soft-margin based on the given training data
set.
     *
     */

http://git-wip-us.apache.org/repos/asf/flink/blob/d163a817/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
index 64b24dc..32746a1 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/regression/MultipleLinearRegression.scala
@@ -21,11 +21,9 @@ package org.apache.flink.ml.regression
 import org.apache.flink.api.common.functions.RichMapFunction
 import org.apache.flink.api.scala.DataSet
 import org.apache.flink.configuration.Configuration
-import org.apache.flink.ml.math.Vector
+import org.apache.flink.ml.math.{DenseVector, BLAS, Vector, vector2Array}
 import org.apache.flink.ml.common._
 
-import org.apache.flink.ml.math.vector2Array
-
 import org.apache.flink.api.scala._
 
 import com.github.fommil.netlib.BLAS.{ getInstance => blas }
@@ -348,6 +346,60 @@ object MultipleLinearRegression {
       LabeledVector(prediction, value)
     }
   }
+
+  /** Calculates the predictions for labeled data with respect to the learned linear model.
+    *
+    * Calling predict on a model that has not been fitted throws a [[RuntimeException]].
+    *
+    * @return A DataSet[(Double, Double)] where each tuple is a (truth, prediction) pair.
+    */
+  implicit def predictLabeledVectors
+    : PredictOperation[MultipleLinearRegression, LabeledVector, (Double, Double)] = {
+    new PredictOperation[MultipleLinearRegression, LabeledVector, (Double, Double)] {
+      override def predict(
+          instance: MultipleLinearRegression,
+          predictParameters: ParameterMap,
+          input: DataSet[LabeledVector])
+        : DataSet[(Double, Double)] = {
+        instance.weightsOption match {
+          case Some(weights) =>
+            // The mapper reads (weights, intercept) from this broadcast set in its open() method.
+            input.map(new LinearRegressionLabeledPrediction)
+              .withBroadcastSet(weights, WEIGHTVECTOR_BROADCAST)
+
+          case None =>
+            throw new RuntimeException("The MultipleLinearRegression has not been fitted to the " +
+              "data. This is necessary to learn the weight vector of the linear function.")
+        }
+      }
+    }
+  }
+
+  /** Maps each labeled example to a (truth, prediction) pair, where the prediction is
+    * `weights . features + intercept`. The weights and intercept are read from the broadcast
+    * variable once in `open` and reused for every record.
+    */
+  private class LinearRegressionLabeledPrediction
+    extends RichMapFunction[LabeledVector, (Double, Double)] {
+    // Weight vector, wrapped once in open() instead of on every map() call.
+    private var weightVector: DenseVector = _
+    // Intercept term added to the dot product.
+    private var weight0: Double = 0
+
+    @throws(classOf[Exception])
+    override def open(configuration: Configuration): Unit = {
+      val (weights, intercept) = getRuntimeContext
+        .getBroadcastVariable[(Array[Double], Double)](WEIGHTVECTOR_BROADCAST)
+        .get(0)
+
+      weightVector = DenseVector(weights)
+      weight0 = intercept
+    }
+
+    override def map(labeledVector: LabeledVector): (Double, Double) = {
+      val truth = labeledVector.label
+      val prediction = BLAS.dot(weightVector, labeledVector.vector) + weight0
+
+      (truth, prediction)
+    }
+  }
 }
 
 //--------------------------------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/d163a817/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala
b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala
index 55ef056..25c2afb 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/classification/SVMITSuite.scala
@@ -49,4 +49,35 @@ class SVMITSuite extends FlatSpec with Matchers with FlinkTestBase {
         weight should be(expectedWeight +- 0.1)
     }
   }
+
+  it should "make (mostly) correct predictions" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    val svm = SVM()
+      .setBlocks(env.getParallelism)
+      .setIterations(100)
+      .setLocalIterations(100)
+      .setRegularization(0.002)
+      .setStepsize(0.1)
+      .setSeed(0)
+
+    val trainingDS = env.fromCollection(Classification.trainingData)
+    svm.fit(trainingDS)
+
+    // predict yields raw decision values; map them onto {-1.0, 1.0} using 0.0 as the threshold.
+    val threshold = 0.0
+    val predictionPairs = svm.predict(trainingDS).map { pair =>
+      val (truth, rawPrediction) = pair
+      val label = if (rawPrediction > threshold) 1.0 else -1.0
+      (truth, label)
+    }
+
+    // Assuming the training labels are +/-1, each misclassification adds 2.0 to this sum.
+    val absoluteErrorSum = predictionPairs
+      .collect()
+      .map { case (truth, prediction) => Math.abs(truth - prediction) }
+      .sum
+
+    absoluteErrorSum should be < 15.0
+  }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/d163a817/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionITSuite.scala
b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionITSuite.scala
index 8be239a..30338e5 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionITSuite.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/regression/MultipleLinearRegressionITSuite.scala
@@ -106,4 +106,28 @@ class MultipleLinearRegressionITSuite
 
     srs should be(RegressionData.expectedPolynomialSquaredResidualSum +- 5)
   }
+
+  it should "make (mostly) correct predictions" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    val mlr = MultipleLinearRegression()
+
+    import RegressionData._
+
+    val parameters = ParameterMap()
+    parameters.add(MultipleLinearRegression.Stepsize, 1.0)
+    parameters.add(MultipleLinearRegression.Iterations, 10)
+    parameters.add(MultipleLinearRegression.ConvergenceThreshold, 0.001)
+
+    val inputDS = env.fromCollection(data)
+    mlr.fit(inputDS, parameters)
+
+    // Predict on the training data itself and sum |truth - prediction| over all examples.
+    val absoluteErrors = mlr.predict(inputDS).collect().map {
+      case (truth, prediction) => Math.abs(truth - prediction)
+    }
+    val absoluteErrorSum = absoluteErrors.sum
+
+    absoluteErrorSum should be < 50.0
+  }
 }


Mime
View raw message