flink-commits mailing list archives

From trohrm...@apache.org
Subject [1/2] flink git commit: [ml] Adds syntactic sugar for map with single broadcast element. Rewrites the optimization framework to consolidate the loss function.
Date Tue, 02 Jun 2015 14:45:37 GMT
Repository: flink
Updated Branches:
  refs/heads/master d163a817f -> 950b79c59


[ml] Adds syntactic sugar for map with single broadcast element. Rewrites the optimization framework to consolidate the loss function.

Adds closure cleaner to convenience functions on RichDataSet

Removing regularization from LossFunction and making it part of the optimizer.

This closes #758.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/44dae0c3
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/44dae0c3
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/44dae0c3

Branch: refs/heads/master
Commit: 44dae0c361c6af050d275cc58bcb81041f58db24
Parents: d163a81
Author: Till Rohrmann <trohrmann@apache.org>
Authored: Thu May 28 03:03:24 2015 +0200
Committer: Till Rohrmann <trohrmann@apache.org>
Committed: Tue Jun 2 14:36:00 2015 +0200

----------------------------------------------------------------------
 docs/libs/ml/optimization.md                    | 137 +++---
 .../apache/flink/ml/classification/SVM.scala    |   8 +-
 .../apache/flink/ml/common/WeightVector.scala   |   2 +-
 .../flink/ml/optimization/GradientDescent.scala | 462 ++++++++++---------
 .../flink/ml/optimization/LossFunction.scala    | 143 ++----
 .../ml/optimization/PartialLossFunction.scala   |  67 +++
 .../ml/optimization/PredictionFunction.scala    |   8 +-
 .../flink/ml/optimization/Regularization.scala  | 228 ---------
 .../apache/flink/ml/optimization/Solver.scala   |  85 +---
 .../scala/org/apache/flink/ml/package.scala     |  75 ++-
 .../optimization/GradientDescentITSuite.scala   |  65 ++-
 .../ml/optimization/LossFunctionITSuite.scala   |  25 +-
 .../PredictionFunctionITSuite.scala             |   6 +-
 .../ml/optimization/RegularizationITSuite.scala | 119 -----
 .../flink/ml/pipeline/PipelineITSuite.scala     |   6 +-
 15 files changed, 559 insertions(+), 877 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/44dae0c3/docs/libs/ml/optimization.md
----------------------------------------------------------------------
diff --git a/docs/libs/ml/optimization.md b/docs/libs/ml/optimization.md
index 26207e9..110383d 100644
--- a/docs/libs/ml/optimization.md
+++ b/docs/libs/ml/optimization.md
@@ -75,13 +75,10 @@ The $L_2$ penalty penalizes large weights, favoring solutions with more small we
 few large ones.
 The $L_1$ penalty can be used to drive a number of the solution coefficients to 0, thereby
 producing sparse solutions.
-The optimization framework in Flink supports the $L_1$ and $L_2$ penalties, as well as no
-regularization. The
-regularization parameter $\lambda$ in $\eqref{eq:objectiveFunc}$ determines the amount of
-regularization applied to the model,
-and is usually determined through model cross-validation. A good comparison of regularization
-types can
-be found in [this](http://www.robotics.stanford.edu/~ang/papers/icml04-l1l2.pdf) paper by Andrew Ng.
+The regularization constant $\lambda$ in $\eqref{eq:objectiveFunc}$ determines the amount of regularization applied to the model,
+and is usually determined through model cross-validation. 
+A good comparison of regularization types can be found in [this](http://www.robotics.stanford.edu/~ang/papers/icml04-l1l2.pdf) paper by Andrew Ng.
+Which regularization types are supported depends on the optimization algorithm used.
 
 ## Stochastic Gradient Descent
 
@@ -107,6 +104,33 @@ The current implementation of SGD  uses the whole partition, making it
 effectively a batch gradient descent. Once a sampling operator has been introduced in Flink, true
 mini-batch SGD will be performed.
 
+### Regularization
+
+FlinkML supports Stochastic Gradient Descent with L1, L2 and no regularization.
+The following table maps each implementing class to its regularization function.
+
+<table class="table table-bordered">
+  <thead>
+    <tr>
+      <th class="text-left" style="width: 20%">Class Name</th>
+      <th class="text-center">Regularization function $R(\wv)$</th>
+    </tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td><code>SimpleGradientDescent</code></td>
+      <td>$R(\wv) = 0$</td>
+    </tr>
+    <tr>
+      <td><code>GradientDescentL1</code></td>
+      <td>$R(\wv) = \norm{\wv}_1$</td>
+    </tr>
+    <tr>
+      <td><code>GradientDescentL2</code></td>
+      <td>$R(\wv) = \frac{1}{2}\norm{\wv}_2^2$</td>
+    </tr>
+  </tbody>
+</table>
 
 ### Parameters
 
@@ -124,22 +148,12 @@ mini-batch SGD will be performed.
         <td><strong>LossFunction</strong></td>
         <td>
           <p>
-            The class of the loss function to be used. See <a href="#loss-function-values">loss function values</a> for a list of supported values. (Default value:
-            <strong>SquaredLoss</strong>, used for regression tasks)
+            The loss function to be optimized. (Default value: <strong>None</strong>)
           </p>
         </td>
       </tr>
       <tr>
-        <td><strong>RegularizationType</strong></td>
-        <td>
-          <p>
-            The type of regularization penalty to apply. See <a href="#regularization-function-values">regularization function values</a> for a list of supported values. (Default value:
-            <strong>NoRegularization</strong>)
-          </p>
-        </td>
-      </tr>
-      <tr>
-        <td><strong>RegularizationParameter</strong></td>
+        <td><strong>RegularizationConstant</strong></td>
         <td>
           <p>
             The amount of regularization to apply. (Default value: <strong>0.0</strong>)
@@ -147,17 +161,6 @@ mini-batch SGD will be performed.
         </td>
       </tr>
       <tr>
-        <td><strong>PredictionFunction</strong></td>
-        <td>
-          <p>
-            Class that provides the prediction function, used to calculate $\hat{y}$ and the
-            prediction gradient based on the weights $\wv$ and the example features $\x$. See
-            <a href="#prediction-function-values">prediction function values</a> for a list of supported values.
-            (Default value: <strong>LinearPrediction</strong>)
-          </p>
-        </td>
-      </tr>
-      <tr>
         <td><strong>Iterations</strong></td>
         <td>
           <p>
@@ -166,10 +169,10 @@ mini-batch SGD will be performed.
         </td>
       </tr>
       <tr>
-        <td><strong>Stepsize</strong></td>
+        <td><strong>LearningRate</strong></td>
         <td>
           <p>
-            Initial step size for the gradient descent method.
+            Initial learning rate for the gradient descent method.
             This value controls how far the gradient descent method moves in the opposite direction
             of the gradient.
             (Default value: <strong>0.1</strong>)
@@ -189,7 +192,20 @@ mini-batch SGD will be performed.
     </tbody>
   </table>
   
-#### Loss Function Values ##
+### Loss Function
+
+The loss function to be minimized has to implement the `LossFunction` interface, which defines methods to compute the loss and its gradient.
+You can either define your own `LossFunction` or use the `GenericLossFunction` class, which constructs the loss function from an outer loss function and a prediction function.
+For example:
+
+```Scala
+val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction) 
+```
+
+The full list of supported outer loss functions can be found [here](#partial-loss-function-values).
+The full list of supported prediction functions can be found [here](#prediction-function-values).
+  
+#### Partial Loss Function Values ##
 
   <table class="table table-bordered">
     <thead>
@@ -240,53 +256,10 @@ mini-batch SGD will be performed.
       </tbody>
     </table>
 
-#### Regularization Function Values ##
-
-  <table class="table table-bordered">
-    <thead>
-      <tr>
-        <th class="text-left" style="width: 20%">Regularization Name</th>
-        <th class="text-center">Description</th>
-        <th class="text-center">$R(\wv)$</th>
-      </tr>
-    </thead>
-    <tbody>
-      <tr>
-        <td><strong>L1Regularization</strong></td>
-        <td>
-          <p>
-            This type of regularization will drive small weights to 0, potentially providing sparse
-            solutions.
-          </p>
-        </td>
-        <td class="text-center">$\norm{\wv}_1$</td>
-      </tr>
-      <tr>
-        <td><strong>L2Regularization</strong></td>
-        <td>
-          <p>
-            This type of regularization will keep weights from growing too large, favoring solutions
-            with more small weights, rather than few large ones.
-          </p>
-        </td>
-        <td class="text-center">$\frac{1}{2}\norm{\wv}_2^2$</td>
-      </tr>
-      <tr>
-        <td><strong>NoRegularization</strong></td>
-        <td>
-          <p>
-            No regularization is applied to the weights when this regularization type is used.
-          </p>
-        </td>
-        <td class="text-center">$0$</td>
-      </tr>
-    </tbody>
-  </table>
-
 ### Examples
 
 In the Flink implementation of SGD, given a set of examples in a `DataSet[LabeledVector]` and
-optionally some initial weights, we can use `GradientDescent.optimize()` in order to optimize
+optionally some initial weights, we can use `GradientDescentL1.optimize()` in order to optimize
 the weights for the given data.
 
 The user can provide an initial `DataSet[WeightVector]`,
@@ -298,12 +271,11 @@ weight vector. This allows us to avoid applying regularization to the intercept.
 
 {% highlight scala %}
 // Create stochastic gradient descent solver
-val sgd = GradientDescent()
+val sgd = GradientDescentL1()
   .setLossFunction(SquaredLoss())
-  .setRegularizationType(L1Regularization())
-  .setRegularizationParameter(0.2)
+  .setRegularizationConstant(0.2)
   .setIterations(100)
-  .setStepsize(0.01)
+  .setLearningRate(0.01)
 
 
 // Obtain data
@@ -311,7 +283,4 @@ val trainingDS: DataSet[LabeledVector] = ...
 
 // Optimize the weights, according to the provided data
 val weightDS = sgd.optimize(trainingDS)
-
-// We can now use weightDS to make predictions
-
 {% endhighlight %}
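
The documentation change above describes a loss function built from an outer loss function and a prediction function. The following plain-Scala sketch illustrates that composition without Flink. All names in it (`PartialLoss`, `Prediction`, `ComposedLoss`, and the two objects) are hypothetical stand-ins rather than the FlinkML classes, and it assumes a squared loss of the form 1/2 (prediction - label)^2 with a linear prediction w·x + b.

```scala
// Hypothetical stand-ins for illustration only; this is not the FlinkML API.
trait PartialLoss {
  def loss(prediction: Double, label: Double): Double
  // Derivative of the loss with respect to the prediction
  def derivative(prediction: Double, label: Double): Double
}

object SquaredPartialLoss extends PartialLoss {
  // Assumed form: 1/2 * (prediction - label)^2
  def loss(prediction: Double, label: Double): Double =
    0.5 * (prediction - label) * (prediction - label)

  def derivative(prediction: Double, label: Double): Double = prediction - label
}

trait Prediction {
  def predict(features: Array[Double], weights: Array[Double], intercept: Double): Double
  // Gradient of the prediction with respect to the weights
  def gradient(features: Array[Double], weights: Array[Double]): Array[Double]
}

object LinearPredictionSketch extends Prediction {
  def predict(features: Array[Double], weights: Array[Double], intercept: Double): Double =
    features.zip(weights).map { case (x, w) => x * w }.sum + intercept

  // For a linear model the gradient with respect to the weights is the feature vector
  def gradient(features: Array[Double], weights: Array[Double]): Array[Double] =
    features.clone()
}

final case class ComposedLoss(outer: PartialLoss, prediction: Prediction) {
  def loss(
      label: Double,
      features: Array[Double],
      weights: Array[Double],
      intercept: Double): Double =
    outer.loss(prediction.predict(features, weights, intercept), label)

  // Chain rule: outer'(yHat, y) * d(yHat)/dw.
  // The intercept gradient assumes d(yHat)/db = 1, which holds for the linear prediction.
  def gradient(
      label: Double,
      features: Array[Double],
      weights: Array[Double],
      intercept: Double): (Array[Double], Double) = {
    val yHat = prediction.predict(features, weights, intercept)
    val d = outer.derivative(yHat, label)
    (prediction.gradient(features, weights).map(_ * d), d)
  }
}
```

Keeping the outer loss and the prediction separate means either part can be swapped without touching the other, which appears to be the motivation for the `GenericLossFunction(SquaredLoss, LinearPrediction)` construction shown in the documentation.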

http://git-wip-us.apache.org/repos/asf/flink/blob/44dae0c3/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala
index 95f2b23..e01735f 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/SVM.scala
@@ -540,17 +540,17 @@ object SVM{
 
     // compute projected gradient
     var proj_grad = if(alpha  <= 0.0){
-      Math.min(grad, 0)
+      math.min(grad, 0)
     } else if(alpha >= 1.0) {
-      Math.max(grad, 0)
+      math.max(grad, 0)
     } else {
       grad
     }
 
-    if(Math.abs(grad) != 0.0){
+    if(math.abs(grad) != 0.0){
       val qii = x dot x
       val newAlpha = if(qii != 0.0){
-        Math.min(Math.max((alpha - (grad / qii)), 0.0), 1.0)
+        math.min(math.max((alpha - (grad / qii)), 0.0), 1.0)
       } else {
         1.0
       }

http://git-wip-us.apache.org/repos/asf/flink/blob/44dae0c3/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/WeightVector.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/WeightVector.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/WeightVector.scala
index 247d92e..4628c71 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/WeightVector.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/WeightVector.scala
@@ -29,4 +29,4 @@ import org.apache.flink.ml.math.Vector
   * @param weights The vector of weights
   * @param intercept The intercept (bias) weight
   */
-case class WeightVector(weights: Vector, var intercept: Double) extends Serializable {}
+case class WeightVector(weights: Vector, intercept: Double) extends Serializable {}
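
With `intercept` no longer declared as `var`, an intercept update has to produce a new `WeightVector` instead of assigning to the field. A minimal sketch of that pattern follows. `WeightVectorSketch` and `stepIntercept` are hypothetical names, and the sketch uses Scala's immutable `Vector[Double]` in place of `org.apache.flink.ml.math.Vector`, whose contents the real solver still updates in place through `BLAS`.

```scala
// Simplified stand-in for illustration; the real class wraps org.apache.flink.ml.math.Vector.
final case class WeightVectorSketch(weights: Vector[Double], intercept: Double)

object WeightVectorSketch {
  // Returns a new instance with the intercept moved against the gradient.
  // The input instance is left unchanged.
  def stepIntercept(
      wv: WeightVectorSketch,
      learningRate: Double,
      interceptGradient: Double): WeightVectorSketch =
    wv.copy(intercept = wv.intercept - learningRate * interceptGradient)
}
```

For example, `WeightVectorSketch.stepIntercept(WeightVectorSketch(Vector(1.0, 2.0), 0.5), 0.5, 2.0)` yields a copy whose intercept is `-0.5`, while the original keeps `0.5`.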

http://git-wip-us.apache.org/repos/asf/flink/blob/44dae0c3/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala
index ef171f5..78bad70 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/GradientDescent.scala
@@ -19,15 +19,14 @@
 
 package org.apache.flink.ml.optimization
 
-import org.apache.flink.api.common.functions.RichMapFunction
 import org.apache.flink.api.scala._
-import org.apache.flink.configuration.Configuration
 import org.apache.flink.ml.common._
 import org.apache.flink.ml.math._
-import org.apache.flink.ml.optimization.IterativeSolver.{ConvergenceThreshold, Iterations, Stepsize}
+import org.apache.flink.ml.optimization.IterativeSolver.{ConvergenceThreshold, Iterations, LearningRate}
 import org.apache.flink.ml.optimization.Solver._
+import org.apache.flink.ml._
 
-/** This [[Solver]] performs Stochastic Gradient Descent optimization using mini batches
+/** Base class which performs Stochastic Gradient Descent optimization using mini batches.
   *
   * For each labeled vector in a mini batch the gradient is computed and added to a partial
   * gradient. The partial gradients are then summed and divided by the size of the batches. The
@@ -38,46 +37,14 @@ import org.apache.flink.ml.optimization.Solver._
   *
   *  The parameters to tune the algorithm are:
   *                      [[Solver.LossFunction]] for the loss function to be used,
-  *                      [[Solver.RegularizationType]] for the type of regularization,
-  *                      [[Solver.RegularizationParameter]] for the regularization parameter,
+  *                      [[Solver.RegularizationConstant]] for the regularization parameter,
   *                      [[IterativeSolver.Iterations]] for the maximum number of iteration,
-  *                      [[IterativeSolver.Stepsize]] for the learning rate used.
+  *                      [[IterativeSolver.LearningRate]] for the learning rate used.
   *                      [[IterativeSolver.ConvergenceThreshold]] when provided the algorithm will
   *                      stop the iterations if the relative change in the value of the objective
   *                      function between successive iterations is smaller than this value.
   */
-class GradientDescent() extends IterativeSolver {
-
-  import Solver.WEIGHTVECTOR_BROADCAST
-
-  /** Performs one iteration of Stochastic Gradient Descent using mini batches
-    *
-    * @param data A Dataset of LabeledVector (label, features) pairs
-    * @param currentWeights A Dataset with the current weights to be optimized as its only element
-    * @return A Dataset containing the weights after one stochastic gradient descent step
-    */
-  private def SGDStep(data: DataSet[(LabeledVector)], currentWeights: DataSet[WeightVector]):
-  DataSet[WeightVector] = {
-
-    // TODO: Sample from input to realize proper SGD
-    data.map {
-      new GradientCalculation
-    }.withBroadcastSet(currentWeights, WEIGHTVECTOR_BROADCAST).reduce {
-      (left, right) =>
-        val (leftGradVector, leftLoss, leftCount) = left
-        val (rightGradVector, rightLoss, rightCount) = right
-        // Add the left gradient to the right one
-        BLAS.axpy(1.0, leftGradVector.weights, rightGradVector.weights)
-        val gradients = WeightVector(
-          rightGradVector.weights, leftGradVector.intercept + rightGradVector.intercept)
-
-        (gradients , leftLoss + rightLoss, leftCount + rightCount)
-    }.map {
-      new WeightsUpdate
-    }.withBroadcastSet(currentWeights, WEIGHTVECTOR_BROADCAST)
-  }
-
-
+abstract class GradientDescent extends IterativeSolver {
 
   /** Provides a solution for the given optimization problem
     *
@@ -88,217 +55,296 @@ class GradientDescent() extends IterativeSolver {
   override def optimize(
     data: DataSet[LabeledVector],
     initialWeights: Option[DataSet[WeightVector]]): DataSet[WeightVector] = {
+
     val numberOfIterations: Int = parameters(Iterations)
     val convergenceThresholdOption: Option[Double] = parameters.get(ConvergenceThreshold)
+    val lossFunction = parameters(LossFunction)
+    val learningRate = parameters(LearningRate)
+    val regularizationConstant = parameters(RegularizationConstant)
 
     // Initialize weights
     val initialWeightsDS: DataSet[WeightVector] = createInitialWeightsDS(initialWeights, data)
 
     // Perform the iterations
-    val optimizedWeights = convergenceThresholdOption match {
+    convergenceThresholdOption match {
       // No convergence criterion
       case None =>
-        initialWeightsDS.iterate(numberOfIterations) {
-          weightVectorDS => {
-            SGDStep(data, weightVectorDS)
-          }
-        }
+        optimizeWithoutConvergenceCriterion(
+          data,
+          initialWeightsDS,
+          numberOfIterations,
+          regularizationConstant,
+          learningRate,
+          lossFunction)
       case Some(convergence) =>
-        // Calculates the regularized loss, from the data and given weights
-        def lossCalculation(data: DataSet[LabeledVector], weightDS: DataSet[WeightVector]):
-        DataSet[Double] = {
-          data
-            .map {new LossCalculation}.withBroadcastSet(weightDS, WEIGHTVECTOR_BROADCAST)
-            .reduce {
-              (left, right) =>
-                val (leftLoss, leftCount) = left
-                val (rightLoss, rightCount) = right
-                (leftLoss + rightLoss, rightCount + leftCount)
+        optimizeWithConvergenceCriterion(
+          data,
+          initialWeightsDS,
+          numberOfIterations,
+          regularizationConstant,
+          learningRate,
+          convergence,
+          lossFunction
+        )
+    }
+  }
+
+  def optimizeWithConvergenceCriterion(
+      dataPoints: DataSet[LabeledVector],
+      initialWeightsDS: DataSet[WeightVector],
+      numberOfIterations: Int,
+      regularizationConstant: Double,
+      learningRate: Double,
+      convergenceThreshold: Double,
+      lossFunction: LossFunction)
+    : DataSet[WeightVector] = {
+    // We have to calculate for each weight vector the sum of squared residuals,
+    // and then sum them and apply regularization
+    val initialLossSumDS = calculateLoss(dataPoints, initialWeightsDS, lossFunction)
+
+    // Combine weight vector with the current loss
+    val initialWeightsWithLossSum = initialWeightsDS.mapWithBcVariable(initialLossSumDS){
+      (weights, loss) => (weights, loss)
+    }
+
+    val resultWithLoss = initialWeightsWithLossSum.iterateWithTermination(numberOfIterations) {
+      weightsWithPreviousLossSum =>
+
+        // Extract weight vector and loss
+        val previousWeightsDS = weightsWithPreviousLossSum.map{_._1}
+        val previousLossSumDS = weightsWithPreviousLossSum.map{_._2}
+
+        val currentWeightsDS = SGDStep(
+          dataPoints,
+          previousWeightsDS,
+          lossFunction,
+          regularizationConstant,
+          learningRate)
+
+        val currentLossSumDS = calculateLoss(dataPoints, currentWeightsDS, lossFunction)
+
+        // Check if the relative change in the loss is smaller than the
+        // convergence threshold. If yes, then terminate i.e. return empty termination data set
+        val termination = previousLossSumDS.filterWithBcVariable(currentLossSumDS){
+          (previousLoss, currentLoss) => {
+            if (previousLoss <= 0) {
+              false
+            } else {
+              scala.math.abs((previousLoss - currentLoss)/previousLoss) >= convergenceThreshold
             }
-            .map{new RegularizedLossCalculation}.withBroadcastSet(weightDS, WEIGHTVECTOR_BROADCAST)
+          }
         }
-        // We have to calculate for each weight vector the sum of squared residuals,
-        // and then sum them and apply regularization
-        val initialLossSumDS = lossCalculation(data, initialWeightsDS)
-
-        // Combine weight vector with the current loss
-        val initialWeightsWithLossSum = initialWeightsDS.
-          crossWithTiny(initialLossSumDS).setParallelism(1)
-
-        val resultWithLoss = initialWeightsWithLossSum.
-          iterateWithTermination(numberOfIterations) {
-          weightsWithLossSum =>
-
-            // Extract weight vector and loss
-            val previousWeightsDS = weightsWithLossSum.map{_._1}
-            val previousLossSumDS = weightsWithLossSum.map{_._2}
-
-            val currentWeightsDS = SGDStep(data, previousWeightsDS)
-
-            val currentLossSumDS = lossCalculation(data, currentWeightsDS)
-
-            // Check if the relative change in the loss is smaller than the
-            // convergence threshold. If yes, then terminate i.e. return empty termination data set
-            val termination = previousLossSumDS.crossWithTiny(currentLossSumDS).setParallelism(1).
-              filter{
-              pair => {
-                val (previousLoss, currentLoss) = pair
-
-                if (previousLoss <= 0) {
-                  false
-                } else {
-                  math.abs((previousLoss - currentLoss)/previousLoss) >= convergence
-                }
-              }
-            }
 
-            // Result for new iteration
-            (currentWeightsDS cross currentLossSumDS, termination)
-        }
-        // Return just the weights
-        resultWithLoss.map{_._1}
+        // Result for new iteration
+        (currentWeightsDS.mapWithBcVariable(currentLossSumDS)((w, l) => (w, l)), termination)
+    }
+    // Return just the weights
+    resultWithLoss.map{_._1}
+  }
+
+  def optimizeWithoutConvergenceCriterion(
+      data: DataSet[LabeledVector],
+      initialWeightsDS: DataSet[WeightVector],
+      numberOfIterations: Int,
+      regularizationConstant: Double,
+      learningRate: Double,
+      lossFunction: LossFunction)
+    : DataSet[WeightVector] = {
+    initialWeightsDS.iterate(numberOfIterations) {
+      weightVectorDS => {
+        SGDStep(data, weightVectorDS, lossFunction, regularizationConstant, learningRate)
+      }
     }
-    optimizedWeights
   }
 
-  /** Calculates the loss value, given a labeled vector and the current weight vector
+  /** Performs one iteration of Stochastic Gradient Descent using mini batches
     *
-    * The weight vector is received as a broadcast variable.
+    * @param data A Dataset of LabeledVector (label, features) pairs
+    * @param currentWeights A Dataset with the current weights to be optimized as its only element
+    * @return A Dataset containing the weights after one stochastic gradient descent step
     */
-  private class LossCalculation extends RichMapFunction[LabeledVector, (Double, Int)] {
+  private def SGDStep(
+    data: DataSet[(LabeledVector)],
+    currentWeights: DataSet[WeightVector],
+    lossFunction: LossFunction,
+    regularizationConstant: Double,
+    learningRate: Double)
+  : DataSet[WeightVector] = {
+
+    data.mapWithBcVariable(currentWeights){
+      (data, weightVector) => (lossFunction.gradient(data, weightVector), 1)
+    }.reduce{
+      (left, right) =>
+        val (leftGradVector, leftCount) = left
+        val (rightGradVector, rightCount) = right
+        // Add the left gradient to the right one
+        BLAS.axpy(1.0, leftGradVector.weights, rightGradVector.weights)
+        val gradients = WeightVector(
+          rightGradVector.weights, leftGradVector.intercept + rightGradVector.intercept)
 
-    var weightVector: WeightVector = null
+        (gradients , leftCount + rightCount)
+    }.mapWithBcVariableIteration(currentWeights){
+      (gradientCount, weightVector, iteration) => {
+        val (WeightVector(weights, intercept), count) = gradientCount
 
-    @throws(classOf[Exception])
-    override def open(configuration: Configuration): Unit = {
-      val list = this.getRuntimeContext.
-        getBroadcastVariable[WeightVector](WEIGHTVECTOR_BROADCAST)
+        BLAS.scal(1.0/count, weights)
 
-      weightVector = list.get(0)
-    }
+        val gradient = WeightVector(weights, intercept/count)
 
-    override def map(example: LabeledVector): (Double, Int) = {
-      val lossFunction = parameters(LossFunction)
-      val predictionFunction = parameters(PredictionFunction)
+        val effectiveLearningRate = learningRate/Math.sqrt(iteration)
 
-      val loss = lossFunction.lossValue(
-        example,
-        weightVector,
-        predictionFunction)
+        val newWeights = takeStep(
+          weightVector.weights,
+          gradient.weights,
+          regularizationConstant,
+          effectiveLearningRate)
 
-      (loss, 1)
+        WeightVector(
+          newWeights,
+          weightVector.intercept - effectiveLearningRate * gradient.intercept)
+      }
     }
   }
 
-/** Calculates the regularized loss value, given the loss and the current weight vector
-  *
-  * The weight vector is received as a broadcast variable.
-  */
-private class RegularizedLossCalculation extends RichMapFunction[(Double, Int), Double] {
-
-  var weightVector: WeightVector = null
-
-  @throws(classOf[Exception])
-  override def open(configuration: Configuration): Unit = {
-    val list = this.getRuntimeContext.
-      getBroadcastVariable[WeightVector](WEIGHTVECTOR_BROADCAST)
-
-    weightVector = list.get(0)
-  }
-
-  override def map(lossAndCount: (Double, Int)): Double = {
-    val (lossSum, count) = lossAndCount
-    val regType = parameters(RegularizationType)
-    val regParameter = parameters(RegularizationParameter)
-
-    val regularizedLoss = {
-      regType.regLoss(
-        lossSum/count,
-        weightVector.weights,
-        regParameter)
+  /** Calculates the new weights based on the gradient
+    *
+    * @param weightVector
+    * @param gradient
+    * @param regularizationConstant
+    * @param learningRate
+    * @return
+    */
+  def takeStep(
+    weightVector: Vector,
+    gradient: Vector,
+    regularizationConstant: Double,
+    learningRate: Double
+    ): Vector
+
+  /** Calculates the regularized loss, from the data and given weights.
+    *
+    * @param data
+    * @param weightDS
+    * @param lossFunction
+    * @return
+    */
+  private def calculateLoss(
+      data: DataSet[LabeledVector],
+      weightDS: DataSet[WeightVector],
+      lossFunction: LossFunction)
+    : DataSet[Double] = {
+    data.mapWithBcVariable(weightDS){
+      (data, weightVector) => (lossFunction.loss(data, weightVector), 1)
+    }.reduce{
+      (left, right) => (left._1 + right._1, left._2 + right._2)
+    }.map {
+      lossCount => lossCount._1 / lossCount._2
     }
-    regularizedLoss
   }
 }
 
-  /** Performs the update of the weights, according to the given gradients and regularization type.
+/** Implementation of a SGD solver with L2 regularization.
+  *
+  * The regularization function is `1/2 ||w||_2^2` with `w` being the weight vector.
+  */
+class GradientDescentL2 extends GradientDescent {
+
+  /** Calculates the new weights based on the gradient
     *
+    * @param weightVector
+    * @param gradient
+    * @param regularizationConstant
+    * @param learningRate
+    * @return
     */
-  private class WeightsUpdate() extends
-  RichMapFunction[(WeightVector, Double, Int), WeightVector] {
+  override def takeStep(
+      weightVector: Vector,
+      gradient: Vector,
+      regularizationConstant: Double,
+      learningRate: Double)
+    : Vector = {
+    // add the gradient of the L2 regularization
+    BLAS.axpy(regularizationConstant, weightVector, gradient)
+
+    // update the weights according to the learning rate
+    BLAS.axpy(-learningRate, gradient, weightVector)
+
+    weightVector
+  }
+}
 
-    var weightVector: WeightVector = null
+object GradientDescentL2 {
+  def apply() = new GradientDescentL2
+}
 
-    @throws(classOf[Exception])
-    override def open(configuration: Configuration): Unit = {
-      val list = this.getRuntimeContext.
-        getBroadcastVariable[WeightVector](WEIGHTVECTOR_BROADCAST)
+/** Implementation of a SGD solver with L1 regularization.
+  *
+  * The regularization function is `||w||_1` with `w` being the weight vector.
+  */
+class GradientDescentL1 extends GradientDescent {
 
-      weightVector = list.get(0)
+  /** Calculates the new weights based on the gradient.
+    *
+    * @param weightVector
+    * @param gradient
+    * @param regularizationConstant
+    * @param learningRate
+    * @return
+    */
+  override def takeStep(
+      weightVector: Vector,
+      gradient: Vector,
+      regularizationConstant: Double,
+      learningRate: Double)
+    : Vector = {
+    // Update weight vector with gradient. L1 regularization has no gradient, the proximal operator
+    // does the job.
+    BLAS.axpy(-learningRate, gradient, weightVector)
+
+    // Apply proximal operator (soft thresholding)
+    val shrinkageVal = regularizationConstant * learningRate
+    var i = 0
+    while (i < weightVector.size) {
+      val wi = weightVector(i)
+      weightVector(i) = scala.math.signum(wi) *
+        scala.math.max(0.0, scala.math.abs(wi) - shrinkageVal)
+      i += 1
     }
 
-    override def map(gradientLossAndCount: (WeightVector, Double, Int)): WeightVector = {
-      val regType = parameters(RegularizationType)
-      val regParameter = parameters(RegularizationParameter)
-      val stepsize = parameters(Stepsize)
-      val weightGradients = gradientLossAndCount._1
-      val lossSum = gradientLossAndCount._2
-      val count = gradientLossAndCount._3
-
-      // Scale the gradients according to batch size
-      BLAS.scal(1.0/count, weightGradients.weights)
-
-      // Calculate the regularized loss and, if the regularization is differentiable, add the
-      // regularization term to the gradient as well, in-place
-      // Note(tvas): adjustedLoss is never used currently, but I'd like to leave it here for now.
-      // We can probably maintain a loss history as the optimization package grows towards a
-      // Breeze-like interface (see breeze.optimize.FirstOrderMinimizer)
-      val adjustedLoss = {
-        regType match {
-          case x: DiffRegularization => {
-            x.regularizedLossAndGradient(
-              lossSum / count,
-              weightVector.weights,
-              weightGradients.weights,
-              regParameter)
-          }
-          case x: Regularization => {
-            x.regLoss(
-              lossSum / count,
-              weightVector.weights,
-              regParameter)
-          }
-        }
-      }
-
-      val weight0Gradient = weightGradients.intercept / count
-
-      val iteration = getIterationRuntimeContext.getSuperstepNumber
-
-      // Scale initial stepsize by the inverse square root of the iteration number
-      // TODO(tvas): There are more ways to determine the stepsize, possible low-effort extensions
-      // here
-      val effectiveStepsize = stepsize/math.sqrt(iteration)
-
-      // Take the gradient step for the intercept
-      weightVector.intercept -= effectiveStepsize * weight0Gradient
-
-      // Take the gradient step for the weight vector, possibly applying regularization
-      // TODO(tvas): This should be moved to a takeStep() function that takes regType plus all these
-      // arguments, this would decouple the update step from the regularization classes
-      regType.takeStep(weightVector.weights, weightGradients.weights,
-        effectiveStepsize, regParameter)
-
-      weightVector
-    }
+    weightVector
   }
 }
 
-object GradientDescent {
-  def apply(): GradientDescent = {
-    new GradientDescent()
-  }
+object GradientDescentL1 {
+  def apply() = new GradientDescentL1
 }
 
+/** Implementation of an SGD solver without regularization.
+  *
+  * The weights are updated with a plain gradient step; no regularization term is applied.
+  */
+class SimpleGradientDescent extends GradientDescent {
 
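+  /** Calculates the new weights based on the gradient.
+    *
+    * @param weightVector The weights to be updated (modified in place)
+    * @param gradient The gradient according to which the weights are updated
+    * @param regularizationConstant The regularization constant; unused by this solver
+    * @param learningRate The step size applied to the gradient
+    * @return The updated weight vector
+    */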
+  override def takeStep(
+      weightVector: Vector,
+      gradient: Vector,
+      regularizationConstant: Double,
+      learningRate: Double)
+    : Vector = {
+    // Update the weight vector
+    BLAS.axpy(-learningRate, gradient, weightVector)
+    weightVector
+  }
+}
 
+object SimpleGradientDescent{
+  def apply() = new SimpleGradientDescent
+}
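
For reference, the two update rules touched in GradientDescent.scala can be sketched on plain arrays. This is a minimal sketch rather than the Flink code: `StepSketch`, `plainStep` and `l1Step` are hypothetical names, and the shrinkage value is assumed to be regularizationConstant * learningRate, as in the L1Regularization.takeStep removed by this commit.

```scala
object StepSketch {
  /** Plain gradient step, w := w - learningRate * gradient, performed in place.
    * This mirrors what BLAS.axpy(-learningRate, gradient, weightVector) computes.
    */
  def plainStep(
      w: Array[Double],
      gradient: Array[Double],
      learningRate: Double): Array[Double] = {
    require(w.length == gradient.length, "dimension mismatch")
    var i = 0
    while (i < w.length) {
      w(i) -= learningRate * gradient(i)
      i += 1
    }
    w
  }

  /** Gradient step followed by soft thresholding (proximal operator of the L1 norm).
    * Assumption: shrinkageVal = regularizationConstant * learningRate.
    */
  def l1Step(
      w: Array[Double],
      gradient: Array[Double],
      learningRate: Double,
      regularizationConstant: Double): Array[Double] = {
    plainStep(w, gradient, learningRate)
    val shrinkageVal = regularizationConstant * learningRate
    var i = 0
    while (i < w.length) {
      val wi = w(i)
      // Shrink every weight towards zero; weights smaller than shrinkageVal become zero
      w(i) = math.signum(wi) * math.max(0.0, math.abs(wi) - shrinkageVal)
      i += 1
    }
    w
  }
}
```

Weights whose magnitude falls below the shrinkage value after the gradient step are set to zero, which is how the L1 variant produces sparse solutions.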

http://git-wip-us.apache.org/repos/asf/flink/blob/44dae0c3/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/LossFunction.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/LossFunction.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/LossFunction.scala
index d612b90..1ff5d97 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/LossFunction.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/LossFunction.scala
@@ -19,129 +19,78 @@
 package org.apache.flink.ml.optimization
 
 import org.apache.flink.ml.common.{WeightVector, LabeledVector}
-import org.apache.flink.ml.math.{Vector => FlinkVector, BLAS}
+import org.apache.flink.ml.math.BLAS
 
 /** Abstract class that implements some of the functionality for common loss functions
   *
   * A loss function determines the loss term $L(w)$ of the objective function $f(w) = L(w) +
   * \lambda R(w)$ for prediction tasks, the other being regularization, $R(w)$.
   *
+  * Regularization is specific to the optimization algorithm used and is therefore implemented there.
+  *
   * We currently only support differentiable loss functions, in the future this class
-  * could be changed to DiffLossFunction in order to support other types, such absolute loss.
+  * could be changed to DiffLossFunction in order to support other types, such as absolute loss.
   */
-abstract class LossFunction extends Serializable{
+trait LossFunction extends Serializable {
 
-  /** Calculates the loss for a given prediction/truth pair
+  /** Calculates the loss for a given data point and weight vector
     *
-    * @param prediction The predicted value
-    * @param truth The true value
+    * @param dataPoint The labeled data point (features and label)
+    * @param weightVector The current weight vector
+    * @return The loss value
     */
-  protected def loss(prediction: Double, truth: Double): Double
-
-  /** Calculates the derivative of the loss function with respect to the prediction
-    *
-    * @param prediction The predicted value
-    * @param truth The true value
-    */
-  protected def lossDerivative(prediction: Double, truth: Double): Double
+  def loss(dataPoint: LabeledVector, weightVector: WeightVector): Double = {
+    lossGradient(dataPoint, weightVector)._1
+  }
 
-  /** Compute the gradient and the loss for the given data.
-    * The provided cumGradient is updated in place.
+  /** Calculates the gradient of the loss function given a data point and weight vector
     *
-    * @param example The features and the label associated with the example
-    * @param weights The current weight vector
-    * @param cumGradient The vector to which the gradient will be added to, in place.
-    * @param predictionFunction A [[PredictionFunction]] object which provides a way to calculate
-    *                           a prediction and its gradient from the features and weights
-    * @return A tuple containing the computed loss as its first element and a the loss derivative as
-    *         its second element. The gradient is updated in-place.
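+    * @param dataPoint The labeled data point (features and label)
+    * @param weightVector The current weight vector
+    * @return The gradient of the loss with respect to the weight vector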
     */
-  def lossAndGradient(
-      example: LabeledVector,
-      weights: WeightVector,
-      cumGradient: FlinkVector,
-      predictionFunction: PredictionFunction):
-  (Double, Double) = {
-    val features = example.vector
-    val label = example.label
-    // TODO(tvas): We could also provide for the case where we don't want an intercept value
-    // i.e. data already centered
-    val prediction = predictionFunction.predict(features, weights)
-    val predictionGradient = predictionFunction.gradient(features, weights)
-    val lossValue: Double = loss(prediction, label)
-    // The loss derivative is used to update the intercept
-    val lossDeriv = lossDerivative(prediction, label)
-    // Restrict the value of the loss derivative to avoid numerical instabilities
-    val restrictedLossDeriv: Double = {
-      if (lossDeriv < -IterativeSolver.MAX_DLOSS) {
-        -IterativeSolver.MAX_DLOSS
-      }
-      else if (lossDeriv > IterativeSolver.MAX_DLOSS) {
-        IterativeSolver.MAX_DLOSS
-      }
-      else {
-        lossDeriv
-      }
-    }
-    // Update the gradient
-    BLAS.axpy(restrictedLossDeriv, predictionGradient, cumGradient)
-    (lossValue, lossDeriv)
+  def gradient(dataPoint: LabeledVector, weightVector: WeightVector): WeightVector = {
+    lossGradient(dataPoint, weightVector)._2
   }
 
-  /** Compute the loss for the given data.
+  /** Calculates the gradient as well as the loss given a data point and the weight vector
     *
-    * @param example The features and the label associated with the example
-    * @param weights The current weight vector
-    * @param predictionFunction A [[PredictionFunction]] object which provides a way to calculate
-    *                           a prediction and its gradient from the features and weights
-    * @return The calculated loss value
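+    * @param dataPoint The labeled data point (features and label)
+    * @param weightVector The current weight vector
+    * @return A tuple of the loss value and the gradient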
     */
-  def lossValue(
-      example: LabeledVector,
-      weights: WeightVector,
-      predictionFunction: PredictionFunction): Double = {
-    val features = example.vector
-    val label = example.label
-    // TODO(tvas): We could also provide for the case where we don't want an intercept value
-    // i.e. data already centered
-    val prediction = predictionFunction.predict(features, weights)
-    val lossValue: Double = loss(prediction, label)
-    lossValue
-  }
-
+  def lossGradient(dataPoint: LabeledVector, weightVector: WeightVector): (Double, WeightVector)
 }
 
-trait ClassificationLoss extends LossFunction
-trait RegressionLoss extends LossFunction
-
-// TODO(tvas): Implement LogisticLoss, HingeLoss.
-
-/** Squared loss function where $L(w) = \frac{1}{2} (w^{T} x - y)^2$
+/** Generic loss function which lets you build a loss function out of the [[PartialLossFunction]]
+  * and the [[PredictionFunction]].
   *
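+  * @param partialLossFunction The loss as a function of the prediction and the label
+  * @param predictionFunction The function computing the prediction and its gradient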
   */
-class SquaredLoss extends RegressionLoss {
-  /** Calculates the loss for a given prediction/truth pair
-    *
-    * @param prediction The predicted value
-    * @param truth The true value
-    */
-  protected override def loss(prediction: Double, truth: Double): Double = {
-    0.5 * (prediction - truth) * (prediction - truth)
-  }
+case class GenericLossFunction(
+    partialLossFunction: PartialLossFunction,
+    predictionFunction: PredictionFunction)
+  extends LossFunction {
 
-  /** Calculates the derivative of the loss function with respect to the prediction
+  /** Calculates the gradient as well as the loss given a data point and the weight vector
     *
-    * @param prediction The predicted value
-    * @param truth The true value
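+    * @param dataPoint The labeled data point (features and label)
+    * @param weightVector The current weight vector
+    * @return A tuple of the loss value and the gradient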
     */
-  protected override def lossDerivative(prediction: Double, truth: Double): Double = {
-    prediction - truth
-  }
+  def lossGradient(dataPoint: LabeledVector, weightVector: WeightVector): (Double, WeightVector) = {
+    val prediction = predictionFunction.predict(dataPoint.vector, weightVector)
 
-}
+    val loss = partialLossFunction.loss(prediction, dataPoint.label)
+
+    val lossDerivative = partialLossFunction.derivative(prediction, dataPoint.label)
+
+    val WeightVector(weightGradient, interceptGradient) =
+      predictionFunction.gradient(dataPoint.vector, weightVector)
+
+    BLAS.scal(lossDerivative, weightGradient)
 
-object SquaredLoss {
-  def apply(): SquaredLoss = {
-    new SquaredLoss
+    (loss, WeightVector(weightGradient, lossDerivative * interceptGradient))
   }
 }
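
The chain rule applied by GenericLossFunction.lossGradient can be sketched without Flink types. This is a minimal sketch under stated assumptions: `LossSketch` and `Weights` are hypothetical stand-ins for the Flink classes, the prediction is linear, and the partial loss is the squared loss.

```scala
object LossSketch {
  // Simplified, hypothetical stand-in for Flink's WeightVector
  final case class Weights(weights: Array[Double], intercept: Double)

  // Linear prediction: w^T x + b
  def predict(x: Array[Double], w: Weights): Double = {
    require(x.length == w.weights.length, "dimension mismatch")
    x.indices.map(i => x(i) * w.weights(i)).sum + w.intercept
  }

  // Squared loss 1/2 (prediction - label)^2
  def loss(prediction: Double, label: Double): Double =
    0.5 * (prediction - label) * (prediction - label)

  // Derivative of the squared loss with respect to the prediction
  def derivative(prediction: Double, label: Double): Double =
    prediction - label

  /** Chain rule: dL/dw = dL/dprediction * dprediction/dw.
    * For a linear prediction, dprediction/dw = x and dprediction/db = 1.
    */
  def lossGradient(x: Array[Double], label: Double, w: Weights): (Double, Weights) = {
    val prediction = predict(x, w)
    val lossDerivative = derivative(prediction, label)
    // Scale a copy of the features, so the input data point is left untouched
    val weightGradient = x.map(_ * lossDerivative)
    (loss(prediction, label), Weights(weightGradient, lossDerivative * 1.0))
  }
}
```

Scaling a copy of the features corresponds to LinearPrediction.gradient returning `features.copy` in this commit, which keeps the in-place BLAS.scal from modifying the data point.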

http://git-wip-us.apache.org/repos/asf/flink/blob/44dae0c3/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/PartialLossFunction.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/PartialLossFunction.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/PartialLossFunction.scala
new file mode 100644
index 0000000..5cf69b6
--- /dev/null
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/PartialLossFunction.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.optimization
+
+/** Represents loss functions which can be used with the [[GenericLossFunction]].
+  *
+  */
+trait PartialLossFunction extends Serializable {
+  /** Calculates the loss depending on the label and the prediction
+    *
+    * @param prediction The predicted value
+    * @param label The true value
+    * @return The loss value
+    */
+  def loss(prediction: Double, label: Double): Double
+
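+  /** Calculates the derivative of the [[PartialLossFunction]] with respect to the prediction
+    *
+    * @param prediction The predicted value
+    * @param label The true value
+    * @return The derivative of the loss with respect to the prediction
+    */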
+    */
+  def derivative(prediction: Double, label: Double): Double
+}
+
+/** Squared loss function which can be used with the [[GenericLossFunction]]
+  *
+  * The [[SquaredLoss]] function implements `1/2 (prediction - label)^2`
+  */
+object SquaredLoss extends PartialLossFunction {
+
+  /** Calculates the loss depending on the label and the prediction
+    *
+    * @param prediction The predicted value
+    * @param label The true value
+    * @return The loss value
+    */
+  override def loss(prediction: Double, label: Double): Double = {
+    0.5 * (prediction - label) * (prediction - label)
+  }
+
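+  /** Calculates the derivative of the [[PartialLossFunction]] with respect to the prediction
+    *
+    * @param prediction The predicted value
+    * @param label The true value
+    * @return The derivative of the loss with respect to the prediction
+    */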
+    */
+  override def derivative(prediction: Double, label: Double): Double = {
+    (prediction - label)
+  }
+}
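
Further losses can be plugged in by implementing the same two methods. The following is a minimal sketch and not part of this commit: `PartialLossSketch` and `LogisticLoss` are hypothetical, the trait is redeclared locally so the snippet stands alone, and labels are assumed to be in {-1, +1}.

```scala
object PartialLossSketch {
  // Local copy of the trait shape introduced in this commit
  trait PartialLossFunction extends Serializable {
    def loss(prediction: Double, label: Double): Double
    def derivative(prediction: Double, label: Double): Double
  }

  /** Hypothetical logistic loss: log(1 + exp(-label * prediction)).
    * Note: this naive form is not guarded against overflow for large arguments.
    */
  object LogisticLoss extends PartialLossFunction {
    override def loss(prediction: Double, label: Double): Double =
      math.log(1.0 + math.exp(-label * prediction))

    // d/dp log(1 + exp(-y p)) = -y / (1 + exp(y p))
    override def derivative(prediction: Double, label: Double): Double =
      -label / (1.0 + math.exp(label * prediction))
  }
}
```

Such an object would be combined with a prediction function in the same way as SquaredLoss, for example via GenericLossFunction(LogisticLoss, LinearPrediction), assuming it implements the Flink trait rather than the local copy above.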

http://git-wip-us.apache.org/repos/asf/flink/blob/44dae0c3/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/PredictionFunction.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/PredictionFunction.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/PredictionFunction.scala
index 91b0f39..38f340a 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/PredictionFunction.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/PredictionFunction.scala
@@ -25,14 +25,16 @@ import org.apache.flink.ml.math.{Vector => FlinkVector, BLAS}
 abstract class PredictionFunction extends Serializable {
   def predict(features: FlinkVector, weights: WeightVector): Double
 
-  def gradient(features: FlinkVector, weights: WeightVector): FlinkVector
+  def gradient(features: FlinkVector, weights: WeightVector): WeightVector
 }
 
 /** A linear prediction function **/
-class LinearPrediction extends PredictionFunction {
+object LinearPrediction extends PredictionFunction {
   override def predict(features: FlinkVector, weightVector: WeightVector): Double = {
     BLAS.dot(features, weightVector.weights) + weightVector.intercept
   }
 
-  override def gradient(features: FlinkVector, weights: WeightVector): FlinkVector = {features}
+  override def gradient(features: FlinkVector, weights: WeightVector): WeightVector = {
+    WeightVector(features.copy, 1)
+  }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/44dae0c3/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Regularization.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Regularization.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Regularization.scala
deleted file mode 100644
index 9e6df4a..0000000
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Regularization.scala
+++ /dev/null
@@ -1,228 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.optimization
-
-import org.apache.flink.ml.math.{Vector => FlinkVector, BLAS}
-
-/** Represents a type of regularization penalty
-  *
-  * Regularization penalties are used to restrict the optimization problem to solutions with
-  * certain desirable characteristics, such as sparsity for the L1 penalty, or penalizing large
-  * weights for the L2 penalty.
-  *
-  * The regularization term, $R(w)$ is added to the objective function, $f(w) = L(w) + \lambda R(w)$
-  * where $\lambda$ is the regularization parameter used to tune the amount of regularization
-  * applied.
-  */
-abstract class Regularization extends Serializable {
-
-  /** Updates the weights by taking a step according to the gradient and regularization applied
-    *
-    * @param oldWeights The weights to be updated
-    * @param gradient The gradient according to which we will update the weights
-    * @param effectiveStepSize The effective step size for this iteration
-    * @param regParameter The regularization parameter, $\lambda$.
-    */
-  def takeStep(
-      oldWeights: FlinkVector,
-      gradient: FlinkVector,
-      effectiveStepSize: Double,
-      regParameter: Double) {
-    BLAS.axpy(-effectiveStepSize, gradient, oldWeights)
-  }
-
-  /** Adds the regularization term to the loss value
-    *
-    * @param loss The loss value, before applying regularization.
-    * @param weightVector The current vector of weights.
-    * @param regularizationParameter The regularization parameter, $\lambda$.
-    * @return The loss value with regularization applied.
-    */
-  def regLoss(loss: Double, weightVector: FlinkVector, regularizationParameter: Double): Double
-
-}
-
-/** Abstract class for regularization penalties that are differentiable
-  *
-  */
-abstract class DiffRegularization extends Regularization {
-
-  /** Compute the regularized gradient loss for the given data.
-    * The provided cumGradient is updated in place.
-    *
-    * @param loss The loss value without regularization.
-    * @param weightVector The current vector of weights.
-    * @param lossGradient The loss gradient, without regularization. Updated in-place.
-    * @param regParameter The regularization parameter, $\lambda$.
-    * @return The loss value with regularization applied.
-    */
-  def regularizedLossAndGradient(
-      loss: Double,
-      weightVector: FlinkVector,
-      lossGradient: FlinkVector,
-      regParameter: Double) : Double ={
-    val adjustedLoss = regLoss(loss, weightVector, regParameter)
-    regGradient(weightVector, lossGradient, regParameter)
-
-    adjustedLoss
-  }
-
-  /** Adds the regularization gradient term to the loss gradient. The gradient is updated in place.
-    *
-    * @param weightVector The current vector of weights
-    * @param lossGradient The loss gradient, without regularization. Updated in-place.
-    * @param regParameter The regularization parameter, $\lambda$.
-    */
-  def regGradient(
-      weightVector: FlinkVector,
-      lossGradient: FlinkVector,
-      regParameter: Double)
-}
-
-/** Performs no regularization, equivalent to $R(w) = 0$ **/
-class NoRegularization extends DiffRegularization {
-  /** Adds the regularization term to the loss value
-    *
-    * @param loss The loss value, before applying regularization
-    * @param weightVector The current vector of weights
-    * @param regParameter The regularization parameter, $\lambda$
-    * @return The loss value with regularization applied.
-    */
-  override def regLoss(
-    loss: Double,
-    weightVector: FlinkVector,
-    regParameter: Double):  Double = {loss}
-
-  /** Adds the regularization gradient term to the loss gradient. The gradient is updated in place.
-    *
-    * Since we don't apply any regularization, the gradient will stay the same.
-    * @param weightVector The current vector of weights
-    * @param lossGradient The loss gradient, without regularization. Updated in-place.
-    * @param regParameter The regularization parameter, $\lambda$.
-    */
-  override def regGradient(
-      weightVector: FlinkVector,
-      lossGradient: FlinkVector,
-      regParameter: Double) = {}
-}
-
-object NoRegularization {
-  def apply(): NoRegularization = {
-    new NoRegularization
-  }
-}
-
-/** $L_2$ regularization penalty.
-  *
-  * Penalizes large weights, favoring solutions with more small weights rather than few large ones.
-  *
-  */
-class L2Regularization extends DiffRegularization {
-
-  /** Adds the regularization term to the loss value
-    *
-    * @param loss The loss value, before applying regularization
-    * @param weightVector The current vector of weights
-    * @param regParameter The regularization parameter, $\lambda$
-    * @return The loss value with regularization applied.
-    */
-  override def regLoss(loss: Double, weightVector: FlinkVector, regParameter: Double)
-    : Double = {
-    loss + regParameter * BLAS.dot(weightVector, weightVector) / 2
-  }
-
-  /** Adds the regularization gradient term to the loss gradient. The gradient is updated in place.
-    *
-    * @param weightVector The current vector of weights.
-    * @param lossGradient The loss gradient, without regularization. Updated in-place.
-    * @param regParameter The regularization parameter, $\lambda$.
-    */
-  override def regGradient(
-      weightVector: FlinkVector,
-      lossGradient: FlinkVector,
-      regParameter: Double): Unit = {
-    BLAS.axpy(regParameter, weightVector, lossGradient)
-  }
-}
-
-object L2Regularization {
-  def apply(): L2Regularization = {
-    new L2Regularization
-  }
-}
-
-/** $L_1$ regularization penalty.
-  *
-  * The $L_1$ penalty can be used to drive a number of the solution coefficients to 0, thereby
-  * producing sparse solutions.
-  *
-  */
-class L1Regularization extends Regularization {
-  /** Calculates and applies the regularization amount and the regularization parameter
-    *
-    * Implementation was taken from the Apache Spark Mllib library:
-    * http://git.io/vfZIT
-    *
-    * @param oldWeights The weights to be updated
-    * @param gradient The gradient according to which we will update the weights
-    * @param effectiveStepSize The effective step size for this iteration
-    * @param regParameter The regularization parameter to be applied in the case of L1
-    *                     regularization
-    */
-  override def takeStep(
-      oldWeights: FlinkVector,
-      gradient: FlinkVector,
-      effectiveStepSize: Double,
-      regParameter: Double) {
-    BLAS.axpy(-effectiveStepSize, gradient, oldWeights)
-
-    // Apply proximal operator (soft thresholding)
-    val shrinkageVal = regParameter * effectiveStepSize
-    var i = 0
-    while (i < oldWeights.size) {
-      val wi = oldWeights(i)
-      oldWeights(i) = math.signum(wi) * math.max(0.0, math.abs(wi) - shrinkageVal)
-      i += 1
-    }
-  }
-
-  /** Adds the regularization term to the loss value
-    *
-    * @param loss The loss value, before applying regularization.
-    * @param weightVector The current vector of weights.
-    * @param regularizationParameter The regularization parameter, $\lambda$.
-    * @return The loss value with regularization applied.
-    */
-  override def regLoss(loss: Double, weightVector: FlinkVector, regularizationParameter: Double):
-  Double = {
-    loss + l1Norm(weightVector) * regularizationParameter
-  }
-
-  // TODO(tvas): Replace once we decide on how we deal with vector ops (roll our own or use Breeze)
-  /** $L_1$ norm of a Vector **/
-  private def l1Norm(vector: FlinkVector) : Double = {
-    vector.valueIterator.fold(0.0){(a,b) => math.abs(a) + math.abs(b)}
-  }
-}
-
-object L1Regularization {
-  def apply(): L1Regularization = {
-    new L1Regularization
-  }
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/44dae0c3/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Solver.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Solver.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Solver.scala
index f2cbce3..39a031f 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Solver.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/optimization/Solver.scala
@@ -18,21 +18,17 @@
 
 package org.apache.flink.ml.optimization
 
-import org.apache.flink.api.common.functions.RichMapFunction
 import org.apache.flink.api.scala.DataSet
-import org.apache.flink.configuration.Configuration
 import org.apache.flink.ml.common._
 import org.apache.flink.ml.math.{SparseVector, DenseVector}
 import org.apache.flink.api.scala._
 import org.apache.flink.ml.optimization.IterativeSolver._
-// TODO(tvas): Kind of ugly that we have to do this. Why not define the parameters inside the class?
-import org.apache.flink.ml.optimization.Solver._
 
 /** Base class for optimization algorithms
  *
  */
-abstract class Solver() extends Serializable with WithParameters {
-
+abstract class Solver extends Serializable with WithParameters {
+  import Solver._
 
   /** Provides a solution for the given optimization problem
     *
@@ -41,8 +37,9 @@ abstract class Solver() extends Serializable with WithParameters {
     * @return A Vector of weights optimized to the given problem
     */
   def optimize(
-    data: DataSet[LabeledVector],
-    initialWeights: Option[DataSet[WeightVector]]): DataSet[WeightVector]
+      data: DataSet[LabeledVector],
+      initialWeights: Option[DataSet[WeightVector]])
+    : DataSet[WeightVector]
 
   /** Creates initial weights vector, creating a DataSet with a WeightVector element
     *
@@ -51,7 +48,7 @@ abstract class Solver() extends Serializable with WithParameters {
     * @return A DataSet containing a single WeightVector element
     */
   def createInitialWeightsDS(initialWeights: Option[DataSet[WeightVector]],
-                             data: DataSet[LabeledVector]):  DataSet[WeightVector] = {
+                             data: DataSet[LabeledVector]): DataSet[WeightVector] = {
     // TODO: Faster way to do this?
     val dimensionsDS = data.map(_.vector.size).reduce((a, b) => b)
 
@@ -78,7 +75,7 @@ abstract class Solver() extends Serializable with WithParameters {
     *                    vector
     * @return DataSet of a zero vector of dimension d
     */
-  def createInitialWeightVector(dimensionDS: DataSet[Int]):  DataSet[WeightVector] = {
+  def createInitialWeightVector(dimensionDS: DataSet[Int]): DataSet[WeightVector] = {
     dimensionDS.map {
       dimension =>
         val values = Array.fill(dimension)(0.0)
@@ -93,46 +90,23 @@ abstract class Solver() extends Serializable with WithParameters {
     this
   }
 
-  // TODO(tvas): Sanitize the input, i.e. depending on Solver type allow only certain types of
-  // regularization to be set.
-  def setRegularizationType(regularization: Regularization): this.type = {
-    parameters.add(RegularizationType, regularization)
-    this
-  }
-
-  def setRegularizationParameter(regularizationParameter: Double): this.type = {
-    parameters.add(RegularizationParameter, regularizationParameter)
-    this
-  }
-
-  def setPredictionFunction(predictionFunction: PredictionFunction): this.type = {
-    parameters.add(PredictionFunction, predictionFunction)
+  def setRegularizationConstant(regularizationConstant: Double): this.type = {
+    parameters.add(RegularizationConstant, regularizationConstant)
     this
   }
 }
 
 object Solver {
-  // TODO(tvas): Does this belong in IterativeSolver instead?
-  val WEIGHTVECTOR_BROADCAST = "weights_broadcast"
-
   // Define parameters for Solver
   case object LossFunction extends Parameter[LossFunction] {
     // TODO(tvas): Should depend on problem, here is where differentiating between classification
     // and regression could become useful
-    val defaultValue = Some(new SquaredLoss)
-  }
-
-  case object RegularizationType extends Parameter[Regularization] {
-    val defaultValue = Some(new NoRegularization)
+    val defaultValue = None
   }
 
-  case object RegularizationParameter extends Parameter[Double] {
+  case object RegularizationConstant extends Parameter[Double] {
     val defaultValue = Some(0.0) // TODO(tvas): Properly initialize this, ensure Parameter > 0!
   }
-
-  case object PredictionFunction extends Parameter[PredictionFunction] {
-    val defaultValue = Some(new LinearPrediction)
-  }
 }
 
 /** An abstract class for iterative optimization algorithms
@@ -149,7 +123,7 @@ abstract class IterativeSolver() extends Solver {
   }
 
   def setStepsize(stepsize: Double): this.type = {
-    parameters.add(Stepsize, stepsize)
+    parameters.add(LearningRate, stepsize)
     this
   }
 
@@ -157,39 +131,6 @@ abstract class IterativeSolver() extends Solver {
     parameters.add(ConvergenceThreshold, convergenceThreshold)
     this
   }
-
-  /** Mapping function that calculates the weight gradients from the data.
-    *
-    */
-  protected class GradientCalculation
-    extends RichMapFunction[LabeledVector, (WeightVector, Double, Int)] {
-
-    var weightVector: WeightVector = null
-
-    @throws(classOf[Exception])
-    override def open(configuration: Configuration): Unit = {
-      val list = this.getRuntimeContext.
-        getBroadcastVariable[WeightVector](WEIGHTVECTOR_BROADCAST)
-
-      weightVector = list.get(0)
-    }
-
-    override def map(example: LabeledVector): (WeightVector, Double, Int) = {
-
-      val lossFunction = parameters(LossFunction)
-      val predictionFunction = parameters(PredictionFunction)
-      val dimensions = example.vector.size
-      val weightGradient = new DenseVector(new Array[Double](dimensions))
-
-      val (loss, lossDeriv) = lossFunction.lossAndGradient(
-        example,
-        weightVector,
-        weightGradient,
-        predictionFunction)
-
-      (new WeightVector(weightGradient, lossDeriv), loss, 1)
-    }
-  }
 }
 
 object IterativeSolver {
@@ -197,7 +138,7 @@ object IterativeSolver {
   val MAX_DLOSS: Double = 1e12
 
   // Define parameters for IterativeSolver
-  case object Stepsize extends Parameter[Double] {
+  case object LearningRate extends Parameter[Double] {
     val defaultValue = Some(0.1)
   }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/44dae0c3/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/package.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/package.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/package.scala
index 250c8cb..554e155 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/package.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/package.scala
@@ -18,10 +18,15 @@
 
 package org.apache.flink
 
+import org.apache.flink.api.common.functions.{RichFilterFunction, RichMapFunction}
+import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.operators.DataSink
 import org.apache.flink.api.scala.{DataSet, ExecutionEnvironment}
+import org.apache.flink.configuration.Configuration
 import org.apache.flink.ml.common.LabeledVector
 
+import scala.reflect.ClassTag
+
 package object ml {
 
   /** Pimp my [[ExecutionEnvironment]] to directly support `readLibSVM`
@@ -38,9 +43,77 @@ package object ml {
     *
     * @param dataSet
     */
-  implicit class RichDataSet(dataSet: DataSet[LabeledVector]) {
+  implicit class RichLabeledDataSet(dataSet: DataSet[LabeledVector]) {
     def writeAsLibSVM(path: String): DataSink[String] = {
       MLUtils.writeLibSVM(path, dataSet)
     }
   }
+
+  implicit class RichDataSet[T](dataSet: DataSet[T]) {
+    def mapWithBcVariable[B, O: TypeInformation: ClassTag](
+        broadcastVariable: DataSet[B])(
+        fun: (T, B) => O)
+      : DataSet[O] = {
+      dataSet.map(new BroadcastSingleElementMapper[T, B, O](dataSet.clean(fun)))
+        .withBroadcastSet(broadcastVariable, "broadcastVariable")
+    }
+
+    def filterWithBcVariable[B, O](broadcastVariable: DataSet[B])(fun: (T, B) => Boolean)
+      : DataSet[T] = {
+      dataSet.filter(new BroadcastSingleElementFilter[T, B](dataSet.clean(fun)))
+        .withBroadcastSet(broadcastVariable, "broadcastVariable")
+    }
+
+    def mapWithBcVariableIteration[B, O: TypeInformation: ClassTag](
+        broadcastVariable: DataSet[B])(fun: (T, B, Int) => O)
+      : DataSet[O] = {
+      dataSet.map(new BroadcastSingleElementMapperWithIteration[T, B, O](dataSet.clean(fun)))
+        .withBroadcastSet(broadcastVariable, "broadcastVariable")
+    }
+  }
+
+  private class BroadcastSingleElementMapper[T, B, O](
+      fun: (T, B) => O)
+    extends RichMapFunction[T, O] {
+    var broadcastVariable: B = _
+
+    @throws(classOf[Exception])
+    override def open(configuration: Configuration): Unit = {
+      broadcastVariable = getRuntimeContext.getBroadcastVariable[B]("broadcastVariable").get(0)
+    }
+
+    override def map(value: T): O = {
+      fun(value, broadcastVariable)
+    }
+  }
+
+  private class BroadcastSingleElementMapperWithIteration[T, B, O](
+      fun: (T, B, Int) => O)
+    extends RichMapFunction[T, O] {
+    var broadcastVariable: B = _
+
+    @throws(classOf[Exception])
+    override def open(configuration: Configuration): Unit = {
+      broadcastVariable = getRuntimeContext.getBroadcastVariable[B]("broadcastVariable").get(0)
+    }
+
+    override def map(value: T): O = {
+      fun(value, broadcastVariable, getIterationRuntimeContext.getSuperstepNumber)
+    }
+  }
+
+  private class BroadcastSingleElementFilter[T, B](
+      fun: (T, B) => Boolean)
+    extends RichFilterFunction[T] {
+    var broadcastVariable: B = _
+
+    @throws(classOf[Exception])
+    override def open(configuration: Configuration): Unit = {
+      broadcastVariable = getRuntimeContext.getBroadcastVariable[B]("broadcastVariable").get(0)
+    }
+
+    override def filter(value: T): Boolean = {
+      fun(value, broadcastVariable)
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/44dae0c3/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/GradientDescentITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/GradientDescentITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/GradientDescentITSuite.scala
index bae0288..d84d017 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/GradientDescentITSuite.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/GradientDescentITSuite.scala
@@ -18,7 +18,7 @@
 
 package org.apache.flink.ml.optimization
 
-import org.apache.flink.ml.common.{LabeledVector, WeightVector, ParameterMap}
+import org.apache.flink.ml.common.{LabeledVector, WeightVector}
 import org.apache.flink.ml.math.DenseVector
 import org.apache.flink.ml.regression.RegressionData._
 import org.scalatest.{Matchers, FlatSpec}
@@ -38,12 +38,13 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
 
     env.setParallelism(2)
 
-    val sgd = GradientDescent()
+    val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
+
+    val sgd = GradientDescentL1()
       .setStepsize(0.01)
       .setIterations(2000)
-      .setLossFunction(SquaredLoss())
-      .setRegularizationType(L1Regularization())
-      .setRegularizationParameter(0.3)
+      .setLossFunction(lossFunction)
+      .setRegularizationConstant(0.3)
 
     val inputDS: DataSet[LabeledVector] = env.fromCollection(regularizationData)
 
@@ -69,12 +70,13 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
 
     env.setParallelism(2)
 
-    val sgd = GradientDescent()
+    val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
+
+    val sgd = GradientDescentL2()
       .setStepsize(0.1)
       .setIterations(1)
-      .setLossFunction(SquaredLoss())
-      .setRegularizationType(L2Regularization())
-      .setRegularizationParameter(1.0)
+      .setLossFunction(lossFunction)
+      .setRegularizationConstant(1.0)
 
     val inputDS: DataSet[LabeledVector] = env.fromElements(LabeledVector(1.0, DenseVector(2.0)))
     val currentWeights = new WeightVector(DenseVector(1.0), 1.0)
@@ -86,12 +88,9 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
 
     weightList.size should equal(1)
 
-    val weightVector: WeightVector = weightList.head
-
-    val updatedIntercept = weightVector.intercept
-    val updatedWeight = weightVector.weights(0)
+    val WeightVector(updatedWeights, updatedIntercept) = weightList.head
 
-    updatedWeight should be (0.5 +- 0.001)
+    updatedWeights(0) should be (0.5 +- 0.001)
     updatedIntercept should be (0.8 +- 0.01)
   }
 
@@ -100,12 +99,12 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
 
     env.setParallelism(2)
 
-    val sgd = GradientDescent()
+    val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
+
+    val sgd = SimpleGradientDescent()
       .setStepsize(1.0)
       .setIterations(800)
-      .setLossFunction(SquaredLoss())
-      .setRegularizationType(NoRegularization())
-      .setRegularizationParameter(0.0)
+      .setLossFunction(lossFunction)
 
     val inputDS = env.fromCollection(data)
     val weightDS = sgd.optimize(inputDS, None)
@@ -131,12 +130,12 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
 
     env.setParallelism(2)
 
-    val sgd = GradientDescent()
+    val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
+
+    val sgd = SimpleGradientDescent()
       .setStepsize(0.0001)
       .setIterations(100)
-      .setLossFunction(SquaredLoss())
-      .setRegularizationType(NoRegularization())
-      .setRegularizationParameter(0.0)
+      .setLossFunction(lossFunction)
 
     val inputDS = env.fromCollection(noInterceptData)
     val weightDS = sgd.optimize(inputDS, None)
@@ -162,12 +161,12 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
 
     env.setParallelism(2)
 
-    val sgd = GradientDescent()
+    val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
+
+    val sgd = SimpleGradientDescent()
       .setStepsize(0.1)
       .setIterations(1)
-      .setLossFunction(SquaredLoss())
-      .setRegularizationType(NoRegularization())
-      .setRegularizationParameter(0.0)
+      .setLossFunction(lossFunction)
 
     val inputDS: DataSet[LabeledVector] = env.fromElements(LabeledVector(1.0, DenseVector(2.0)))
     val currentWeights = new WeightVector(DenseVector(1.0), 1.0)
@@ -198,13 +197,13 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
 
     env.setParallelism(2)
 
-    val sgdEarlyTerminate = GradientDescent()
+    val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
+
+    val sgdEarlyTerminate = SimpleGradientDescent()
       .setConvergenceThreshold(1e2)
       .setStepsize(1.0)
       .setIterations(800)
-      .setLossFunction(SquaredLoss())
-      .setRegularizationType(NoRegularization())
-      .setRegularizationParameter(0.0)
+      .setLossFunction(lossFunction)
 
     val inputDS = env.fromCollection(data)
 
@@ -218,12 +217,10 @@ class GradientDescentITSuite extends FlatSpec with Matchers with FlinkTestBase {
     val weightsEarly = weightVectorEarly.weights.asInstanceOf[DenseVector].data
     val weight0Early = weightVectorEarly.intercept
 
-    val sgdNoConvergence = GradientDescent()
+    val sgdNoConvergence = SimpleGradientDescent()
       .setStepsize(1.0)
       .setIterations(800)
-      .setLossFunction(SquaredLoss())
-      .setRegularizationType(NoRegularization())
-      .setRegularizationParameter(0.0)
+      .setLossFunction(lossFunction)
 
     val weightDSNoConvergence = sgdNoConvergence.optimize(inputDS, None)
 

http://git-wip-us.apache.org/repos/asf/flink/blob/44dae0c3/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionITSuite.scala
index a0921e5..4152188 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionITSuite.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/LossFunctionITSuite.scala
@@ -18,8 +18,8 @@
 
 package org.apache.flink.ml.optimization
 
-import org.apache.flink.ml.common.{LabeledVector, WeightVector, ParameterMap}
-import org.apache.flink.ml.math.{BLAS, Vector => FlinkVector, DenseVector}
+import org.apache.flink.ml.common.{LabeledVector, WeightVector}
+import org.apache.flink.ml.math.DenseVector
 import org.scalatest.{Matchers, FlatSpec}
 
 import org.apache.flink.api.scala._
@@ -35,28 +35,17 @@ class LossFunctionITSuite extends FlatSpec with Matchers with FlinkTestBase {
 
     env.setParallelism(2)
 
-    val squaredLoss = new SquaredLoss
+    val lossFunction = GenericLossFunction(SquaredLoss, LinearPrediction)
+
 
     val example = LabeledVector(1.0, DenseVector(2))
     val weightVector = new WeightVector(DenseVector(1.0), 1.0)
-    val gradient = DenseVector(0.0)
-
-    val (loss, lossDerivative) = squaredLoss.lossAndGradient(
-      example,
-      weightVector,
-      gradient,
-      new LinearPrediction)
 
-    val onlyLoss = squaredLoss.lossValue(example, weightVector, new LinearPrediction)
+    val gradient = lossFunction.gradient(example, weightVector)
+    val loss = lossFunction.loss(example, weightVector)
 
     loss should be (2.0 +- 0.001)
 
-    onlyLoss should be (2.0 +- 0.001)
-
-    lossDerivative should be (2.0 +- 0.001)
-
-    gradient.data(0) should be (4.0 +- 0.001)
-
+    gradient.weights(0) should be (4.0 +- 0.001)
   }
-
 }
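
A worked check of the values asserted in the LossFunctionITSuite test above. This is a sketch under assumptions: `SquaredLoss` is taken to be half the squared residual and `LinearPrediction` to be the dot product plus intercept. Neither definition appears in this part of the diff, but both are consistent with the asserted 2.0 and 4.0. The inputs are label y = 1, feature x = 2, weight w = 1, and intercept b = 1.

```latex
\begin{aligned}
\hat{y} &= w x + b = 1 \cdot 2 + 1 = 3 \\
L &= \tfrac{1}{2}(\hat{y} - y)^2 = \tfrac{1}{2}(3 - 1)^2 = 2 \\
\frac{\partial L}{\partial w} &= (\hat{y} - y)\, x = 2 \cdot 2 = 4
\end{aligned}
```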

http://git-wip-us.apache.org/repos/asf/flink/blob/44dae0c3/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/PredictionFunctionITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/PredictionFunctionITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/PredictionFunctionITSuite.scala
index 69e67e9..6d2a239 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/PredictionFunctionITSuite.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/PredictionFunctionITSuite.scala
@@ -34,7 +34,7 @@ class PredictionFunctionITSuite extends FlatSpec with Matchers with FlinkTestBas
 
     env.setParallelism(2)
 
-    val predFunction = new LinearPrediction
+    val predFunction = LinearPrediction
 
     val weightVector = new WeightVector(DenseVector(-1.0, 1.0, 0.4, -0.4, 0.0), 1.0)
     val features = DenseVector(1.0, 1.0, 1.0, 1.0, 1.0)
@@ -49,14 +49,14 @@ class PredictionFunctionITSuite extends FlatSpec with Matchers with FlinkTestBas
 
     env.setParallelism(2)
 
-    val predFunction = new LinearPrediction
+    val predFunction = LinearPrediction
 
     val weightVector = new WeightVector(DenseVector(-1.0, 1.0, 0.4, -0.4, 0.0), 1.0)
     val features = DenseVector(1.0, 1.0, 1.0, 1.0, 1.0)
 
     val gradient = predFunction.gradient(features, weightVector)
 
-    gradient shouldEqual DenseVector(1.0, 1.0, 1.0, 1.0, 1.0)
+    gradient shouldEqual WeightVector(DenseVector(1.0, 1.0, 1.0, 1.0, 1.0), 1.0)
   }
 
 }
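
The new expected value in the gradient test above can be checked by hand, assuming `LinearPrediction` computes the dot product of weights and features plus the intercept (an assumption; its implementation is not shown in this part of the diff). The gradient with respect to the weights is the feature vector, and the derivative with respect to the intercept is 1, which matches the `WeightVector(DenseVector(1.0, 1.0, 1.0, 1.0, 1.0), 1.0)` the test now expects.

```latex
\begin{aligned}
f(\mathbf{w}, b) &= \mathbf{w}^\top \mathbf{x} + b \\
\nabla_{\mathbf{w}} f &= \mathbf{x} = (1, 1, 1, 1, 1) \\
\frac{\partial f}{\partial b} &= 1
\end{aligned}
```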

http://git-wip-us.apache.org/repos/asf/flink/blob/44dae0c3/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/RegularizationITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/RegularizationITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/RegularizationITSuite.scala
deleted file mode 100644
index 89c77f2..0000000
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/optimization/RegularizationITSuite.scala
+++ /dev/null
@@ -1,119 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.optimization
-
-import org.apache.flink.ml.common.WeightVector
-import org.apache.flink.ml.math.DenseVector
-import org.apache.flink.api.scala._
-import org.apache.flink.test.util.FlinkTestBase
-
-import org.scalatest.{Matchers, FlatSpec}
-
-
-
-
-class RegularizationITSuite extends FlatSpec with Matchers with FlinkTestBase {
-
-  behavior of "The regularization type implementations"
-
-  it should "not change the loss or gradient when no regularization is used" in {
-
-    val env = ExecutionEnvironment.getExecutionEnvironment
-
-    env.setParallelism(2)
-
-    val regularization = new NoRegularization
-
-    val weightVector = new WeightVector(DenseVector(1.0), 1.0)
-    val regParameter = 1.0
-    val gradient = DenseVector(0.0)
-    val originalLoss = 1.0
-
-    val adjustedLoss = regularization.regularizedLossAndGradient(
-      originalLoss,
-      weightVector.weights,
-      gradient,
-      regParameter)
-
-    adjustedLoss should be (originalLoss +- 0.0001)
-    gradient shouldEqual DenseVector(0.0)
-  }
-
-  it should "correctly apply L1 regularization" in {
-    val env = ExecutionEnvironment.getExecutionEnvironment
-
-    env.setParallelism(2)
-
-    val regularization = new L1Regularization
-
-    val weightVector = new WeightVector(DenseVector(-1.0, 1.0, 0.4, -0.4, 0.0), 1.0)
-    val effectiveStepsize = 1.0
-    val regParameter = 0.5
-    val gradient = DenseVector(0.0, 0.0, 0.0, 0.0, 0.0)
-
-    regularization.takeStep(weightVector.weights,  gradient, effectiveStepsize, regParameter)
-
-    val expectedWeights = DenseVector(-0.5, 0.5, 0.0, 0.0, 0.0)
-
-    weightVector.weights shouldEqual expectedWeights
-    weightVector.intercept should be (1.0 +- 0.0001)
-  }
-
-  it should "correctly calculate L1 loss"  in {
-    val env = ExecutionEnvironment.getExecutionEnvironment
-
-    env.setParallelism(2)
-
-    val regularization = new L1Regularization
-
-    val weightVector = new WeightVector(DenseVector(-1.0, 1.0, 0.4, -0.4, 0.0), 1.0)
-    val regParameter = 0.5
-    val originalLoss = 1.0
-
-    val adjustedLoss = regularization.regLoss(originalLoss, weightVector.weights, regParameter)
-
-    weightVector shouldEqual WeightVector(DenseVector(-1.0, 1.0, 0.4, -0.4, 0.0), 1.0)
-    adjustedLoss should be (2.4 +- 0.1)
-  }
-
-  it should "correctly adjust the gradient and loss for L2 regularization" in {
-    val env = ExecutionEnvironment.getExecutionEnvironment
-
-    env.setParallelism(2)
-
-    val regularization = new L2Regularization
-
-    val weightVector = new WeightVector(DenseVector(-1.0, 1.0, 0.4, -0.4, 0.0), 1.0)
-    val regParameter = 0.5
-    val lossGradient = DenseVector(0.0, 0.0, 0.0, 0.0, 0.0)
-    val originalLoss = 1.0
-
-    val adjustedLoss = regularization.regularizedLossAndGradient(
-      originalLoss,
-      weightVector.weights,
-      lossGradient,
-      regParameter)
-
-    val expectedGradient = DenseVector(-0.5, 0.5, 0.2, -0.2, 0.0)
-
-    weightVector shouldEqual WeightVector(DenseVector(-1.0, 1.0, 0.4, -0.4, 0.0), 1.0)
-    adjustedLoss should be (1.58 +- 0.1)
-    lossGradient shouldEqual expectedGradient
-  }
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/44dae0c3/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/pipeline/PipelineITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/pipeline/PipelineITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/pipeline/PipelineITSuite.scala
index d29da0c..a36a0d1 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/pipeline/PipelineITSuite.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/pipeline/PipelineITSuite.scala
@@ -18,9 +18,6 @@
 
 package org.apache.flink.ml.pipeline
 
-import breeze.linalg
-import org.apache.flink.api.common.ExecutionConfig
-import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer
 import org.apache.flink.api.scala._
 import org.apache.flink.ml.classification.SVM
 import org.apache.flink.ml.common.{ParameterMap, LabeledVector}
@@ -166,8 +163,7 @@ class PipelineITSuite extends FlatSpec with Matchers with FlinkTestBase {
     val chainedScalers5 = chainedScalers4.chainTransformer(StandardScaler())
 
     val predictor = MultipleLinearRegression()
-
-
+    
     val pipeline = chainedScalers5.chainPredictor(predictor)
 
     pipeline.fit(trainingData)

