flink-commits mailing list archives

From trohrm...@apache.org
Subject flink git commit: [FLINK-2053] [ml] Adds automatic type registration of flink-ml types. Adds de-duplication of registered types at ExecutionConfig. Fixes bug in Breeze SparseVector to Flink SparseVector conversion.
Date Tue, 26 May 2015 09:45:07 GMT
Repository: flink
Updated Branches:
  refs/heads/master a1d2df614 -> ae446388b


[FLINK-2053] [ml] Adds automatic type registration of flink-ml types. Adds de-duplication of registered types at ExecutionConfig. Fixes bug in Breeze SparseVector to Flink SparseVector conversion.

This closes #723.
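For readers skimming the diff: the pipeline operators (Estimator#fit, Predictor#predict, Transformer#transform) now call FlinkMLTools.registerFlinkMLTypes on the execution environment of their input DataSet, so flink-ml users no longer have to register the math types by hand. A minimal sketch of the manual registration that this commit automates (local environment assumed; purely illustrative):

    import org.apache.flink.api.scala._
    import org.apache.flink.ml.common.FlinkMLTools

    val env = ExecutionEnvironment.getExecutionEnvironment

    // Registers DenseVector, SparseVector, DenseMatrix, SparseMatrix and their
    // Breeze counterparts with the environment's Kryo configuration.
    FlinkMLTools.registerFlinkMLTypes(env)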


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/ae446388
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/ae446388
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/ae446388

Branch: refs/heads/master
Commit: ae446388b91ecc0f08887da19400395b96b32f6c
Parents: a1d2df6
Author: Till Rohrmann <trohrmann@apache.org>
Authored: Tue May 26 00:35:05 2015 +0200
Committer: Till Rohrmann <trohrmann@apache.org>
Committed: Tue May 26 11:44:46 2015 +0200

----------------------------------------------------------------------
 .../flink/api/common/ExecutionConfig.java       |  11 +-
 .../flink/api/common/ExecutionConfigTest.java   |  47 +++
 .../java/typeutils/runtime/PojoSerializer.java  |   3 +-
 .../typeutils/runtime/kryo/KryoSerializer.java  |   5 +-
 .../apache/flink/ml/classification/CoCoA.scala  |   9 +-
 .../apache/flink/ml/common/FlinkMLTools.scala   | 423 +++++++++++++++++++
 .../org/apache/flink/ml/common/FlinkTools.scala | 392 -----------------
 .../scala/org/apache/flink/ml/math/Breeze.scala |   4 +-
 .../flink/ml/math/BreezeVectorConverter.scala   |  10 +-
 .../apache/flink/ml/pipeline/Estimator.scala    |   3 +-
 .../apache/flink/ml/pipeline/Predictor.scala    |   3 +-
 .../apache/flink/ml/pipeline/Transformer.scala  |   3 +-
 .../flink/ml/preprocessing/StandardScaler.scala |   2 -
 .../apache/flink/ml/recommendation/ALS.scala    |   6 +-
 .../flink/ml/common/FlinkMLToolsSuite.scala     |  60 +++
 .../ml/feature/PolynomialBaseITSuite.scala      | 127 ------
 .../ml/feature/PolynomialFeaturesITSuite.scala  | 127 ++++++
 .../flink/ml/pipeline/PipelineITSuite.scala     |   5 +-
 18 files changed, 698 insertions(+), 542 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/ae446388/flink-core/src/main/java/org/apache/flink/api/common/ExecutionConfig.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/ExecutionConfig.java b/flink-core/src/main/java/org/apache/flink/api/common/ExecutionConfig.java
index 3af153a..8baedb4 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/ExecutionConfig.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/ExecutionConfig.java
@@ -22,6 +22,7 @@ import com.esotericsoftware.kryo.Serializer;
 
 import java.io.Serializable;
 import java.util.ArrayList;
+import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
 
@@ -97,9 +98,9 @@ public class ExecutionConfig implements Serializable {
 	private final List<Entry<Class<?>, Class<? extends Serializer<?>>>> defaultKryoSerializerClasses =
 			new ArrayList<Entry<Class<?>, Class<? extends Serializer<?>>>>();
 
-	private final List<Class<?>> registeredKryoTypes = new ArrayList<Class<?>>();
+	private final LinkedHashSet<Class<?>> registeredKryoTypes = new LinkedHashSet<Class<?>>();
 
-	private final List<Class<?>> registeredPojoTypes = new ArrayList<Class<?>>();
+	private final LinkedHashSet<Class<?>> registeredPojoTypes = new LinkedHashSet<Class<?>>();
 
 	// --------------------------------------------------------------------------------------------
 
@@ -505,11 +506,11 @@ public class ExecutionConfig implements Serializable {
 	/**
 	 * Returns the registered Kryo types.
 	 */
-	public List<Class<?>> getRegisteredKryoTypes() {
+	public LinkedHashSet<Class<?>> getRegisteredKryoTypes() {
 		if (isForceKryoEnabled()) {
 			// if we force kryo, we must also return all the types that
 			// were previously only registered as POJO
-			List<Class<?>> result = new ArrayList<Class<?>>();
+			LinkedHashSet<Class<?>> result = new LinkedHashSet<Class<?>>();
 			result.addAll(registeredKryoTypes);
 			for(Class<?> t : registeredPojoTypes) {
 				if (!result.contains(t)) {
@@ -525,7 +526,7 @@ public class ExecutionConfig implements Serializable {
 	/**
 	 * Returns the registered POJO types.
 	 */
-	public List<Class<?>> getRegisteredPojoTypes() {
+	public LinkedHashSet<Class<?>> getRegisteredPojoTypes() {
 		return registeredPojoTypes;
 	}
 

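Replacing the ArrayLists with LinkedHashSets makes repeated registrations idempotent while preserving insertion order, which matters because Kryo assigns registration ids in that order. A short Scala sketch of the resulting behaviour (illustrative, mirroring the new ExecutionConfigTest below):

    val config = new org.apache.flink.api.common.ExecutionConfig

    config.registerKryoType(classOf[java.lang.Double])
    config.registerKryoType(classOf[java.lang.Integer])
    config.registerKryoType(classOf[java.lang.Double]) // duplicate, dropped by the set

    // Only two registrations remain, in first-registration order: Double, Integer
    assert(config.getRegisteredKryoTypes.size == 2)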
http://git-wip-us.apache.org/repos/asf/flink/blob/ae446388/flink-core/src/test/java/org/apache/flink/api/common/ExecutionConfigTest.java
----------------------------------------------------------------------
diff --git a/flink-core/src/test/java/org/apache/flink/api/common/ExecutionConfigTest.java b/flink-core/src/test/java/org/apache/flink/api/common/ExecutionConfigTest.java
new file mode 100644
index 0000000..ad3ad91
--- /dev/null
+++ b/flink-core/src/test/java/org/apache/flink/api/common/ExecutionConfigTest.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.common;
+
+import org.junit.Test;
+import static org.junit.Assert.*;
+
+import java.util.Arrays;
+import java.util.List;
+
+public class ExecutionConfigTest {
+
+	@Test
+	public void testDoubleTypeRegistration() {
+		ExecutionConfig config = new ExecutionConfig();
+		List<Class<?>> types = Arrays.asList((Class<?>)Double.class, Integer.class, Double.class);
+		List<Class<?>> expectedTypes = Arrays.asList((Class<?>)Double.class, Integer.class);
+
+		for(Class<?> tpe: types) {
+			config.registerKryoType(tpe);
+		}
+
+		int counter = 0;
+
+		for(Class<?> tpe: config.getRegisteredKryoTypes()){
+			assertEquals(expectedTypes.get(counter++), tpe);
+		}
+
+		assertEquals(expectedTypes.size(), counter);
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/ae446388/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoSerializer.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoSerializer.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoSerializer.java
index c61ad8d..5d4553d 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoSerializer.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoSerializer.java
@@ -27,6 +27,7 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
 import java.util.List;
 import java.util.Map;
 
@@ -78,7 +79,7 @@ public final class PojoSerializer<T> extends TypeSerializer<T> {
 		this.numFields = fieldSerializers.length;
 		this.executionConfig = executionConfig;
 
-		List<Class<?>> registeredPojoTypes = executionConfig.getRegisteredPojoTypes();
+		LinkedHashSet<Class<?>> registeredPojoTypes = executionConfig.getRegisteredPojoTypes();
 
 		for (int i = 0; i < numFields; i++) {
 			this.fields[i].setAccessible(true);

http://git-wip-us.apache.org/repos/asf/flink/blob/ae446388/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/kryo/KryoSerializer.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/kryo/KryoSerializer.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/kryo/KryoSerializer.java
index e14546e..8ae3562 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/kryo/KryoSerializer.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/kryo/KryoSerializer.java
@@ -42,6 +42,7 @@ import java.io.ByteArrayOutputStream;
 import java.io.EOFException;
 import java.io.IOException;
 import java.lang.reflect.Modifier;
+import java.util.LinkedHashSet;
 import java.util.List;
 
 /**
@@ -63,7 +64,7 @@ public class KryoSerializer<T> extends TypeSerializer<T> {
 	private final List<ExecutionConfig.Entry<Class<?>, Class<? extends Serializer<?>>>> registeredTypesWithSerializerClasses;
 	private final List<ExecutionConfig.Entry<Class<?>, Serializer<?>>> defaultSerializers;
 	private final List<ExecutionConfig.Entry<Class<?>, Class<? extends Serializer<?>>>> defaultSerializerClasses;
-	private final List<Class<?>> registeredTypes;
+	private final LinkedHashSet<Class<?>> registeredTypes;
 
 	private final Class<T> type;
 	
@@ -305,7 +306,7 @@ public class KryoSerializer<T> extends TypeSerializer<T> {
 	// For testing
 	// --------------------------------------------------------------------------------------------
 	
-	Kryo getKryo() {
+	public Kryo getKryo() {
 		checkKryoInitialized();
 		return this.kryo;
 	}

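getKryo() is widened from package-private to public so that tests outside the package, such as the FlinkMLToolsSuite added below, can inspect the registrations. An illustrative check along the lines of that suite (executionConfig stands in for an existing ExecutionConfig instance):

    import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer

    val serializer = new KryoSerializer[Nothing](classOf[Nothing], executionConfig)
    val kryo = serializer.getKryo()

    // Explicitly registered types receive a positive registration id
    val registered =
      kryo.getRegistration(classOf[org.apache.flink.ml.math.DenseVector]).getId > 0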
http://git-wip-us.apache.org/repos/asf/flink/blob/ae446388/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/CoCoA.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/CoCoA.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/CoCoA.scala
index 4ba9299..fea6be5 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/CoCoA.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/classification/CoCoA.scala
@@ -26,7 +26,7 @@ import scala.util.Random
 import org.apache.flink.api.common.functions.RichMapFunction
 import org.apache.flink.api.scala._
 import org.apache.flink.configuration.Configuration
-import org.apache.flink.ml.common.FlinkTools.ModuloKeyPartitioner
+import org.apache.flink.ml.common.FlinkMLTools.ModuloKeyPartitioner
 import org.apache.flink.ml.common._
 import org.apache.flink.ml.math.Vector
 import org.apache.flink.ml.math.Breeze._
@@ -244,6 +244,11 @@ object CoCoA{
           case Some(weights) => {
             input.map(new PredictionMapper[T]).withBroadcastSet(weights, WEIGHT_VECTOR)
           }
+
+          case None => {
+            throw new RuntimeException("The CoCoA model has not been trained. Call fit first " +
+              "before calling the predict operation.")
+          }
         }
       }
     }
@@ -310,7 +315,7 @@ object CoCoA{
         val numberVectors = input map { x => 1 } reduce { _ + _ }
 
         // Group the input data into blocks in round robin fashion
-        val blockedInputNumberElements = FlinkTools.block(
+        val blockedInputNumberElements = FlinkMLTools.block(
           input,
           blocks,
           Some(ModuloKeyPartitioner)).

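With the added None case, calling predict on an untrained CoCoA model now fails fast with a descriptive RuntimeException instead of a scala.MatchError. Illustrative usage, assuming the companion object's apply factory and a hypothetical test DataSet:

    val svm = CoCoA()         // weightsOption is None until fit has been called
    svm.predict(testData)     // throws: "The CoCoA model has not been trained. ..."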
http://git-wip-us.apache.org/repos/asf/flink/blob/ae446388/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkMLTools.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkMLTools.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkMLTools.scala
new file mode 100644
index 0000000..553ec00
--- /dev/null
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkMLTools.scala
@@ -0,0 +1,423 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common
+
+import org.apache.flink.api.common.functions.Partitioner
+import org.apache.flink.api.common.io.FileOutputFormat.OutputDirectoryMode
+import org.apache.flink.api.common.typeinfo.TypeInformation
+import org.apache.flink.api.java.io.{TypeSerializerInputFormat, TypeSerializerOutputFormat}
+import org.apache.flink.api.scala._
+import org.apache.flink.core.fs.FileSystem.WriteMode
+import org.apache.flink.core.fs.Path
+
+import scala.reflect.ClassTag
+
+/** FlinkMLTools contains a set of convenience functions for Flink's machine learning library:
+  *
+  *  - persist:
+  *  Takes up to 5 [[DataSet]]s and file paths. Each [[DataSet]] is written to the specified
+  *  path and subsequently re-read from disk. This method can be used to effectively split the
+  *  execution graph at the given [[DataSet]]. Writing it to disk triggers its materialization
+  *  and specifying it as a source will prevent the re-execution of it.
+  *
+  *  - block:
+  *  Takes a DataSet of elements T and groups them into n blocks.
+  *
+  */
+object FlinkMLTools {
+
+  /** Registers the different FlinkML related types for Kryo serialization.
+    *
+    * @param env The Flink execution environment where the types are registered
+    */
+  def registerFlinkMLTypes(env: ExecutionEnvironment): Unit = {
+
+    // Vector types
+    env.registerType(classOf[org.apache.flink.ml.math.DenseVector])
+    env.registerType(classOf[org.apache.flink.ml.math.SparseVector])
+
+    // Matrix types
+    env.registerType(classOf[org.apache.flink.ml.math.DenseMatrix])
+    env.registerType(classOf[org.apache.flink.ml.math.SparseMatrix])
+
+    // Breeze Vector types
+    env.registerType(classOf[breeze.linalg.DenseVector[_]])
+    env.registerType(classOf[breeze.linalg.SparseVector[_]])
+
+    // Breeze specialized types
+    env.registerType(breeze.linalg.DenseVector.zeros[Double](0).getClass)
+    env.registerType(breeze.linalg.SparseVector.zeros[Double](0).getClass)
+
+    // Breeze Matrix types
+    env.registerType(classOf[breeze.linalg.DenseMatrix[Double]])
+    env.registerType(classOf[breeze.linalg.CSCMatrix[Double]])
+
+    // Breeze specialized types
+    env.registerType(breeze.linalg.DenseMatrix.zeros[Double](0, 0).getClass)
+    env.registerType(breeze.linalg.CSCMatrix.zeros[Double](0, 0).getClass)
+  }
+
+  /** Writes a [[DataSet]] to the specified path and returns it as a DataSource for subsequent
+    * operations.
+    *
+    * @param dataset [[DataSet]] to write to disk
+    * @param path File path to write dataset to
+    * @tparam T Type of the [[DataSet]] elements
+    * @return [[DataSet]] reading the just written file
+    */
+  def persist[T: ClassTag: TypeInformation](dataset: DataSet[T], path: String): DataSet[T] = {
+    val env = dataset.getExecutionEnvironment
+    val outputFormat = new TypeSerializerOutputFormat[T]
+
+    val filePath = new Path(path)
+
+    outputFormat.setOutputFilePath(filePath)
+    outputFormat.setWriteMode(WriteMode.OVERWRITE)
+
+    dataset.output(outputFormat)
+    env.execute("FlinkTools persist")
+
+    val inputFormat = new TypeSerializerInputFormat[T](dataset.getType)
+    inputFormat.setFilePath(filePath)
+
+    env.createInput(inputFormat)
+  }
+
+  /** Writes multiple [[DataSet]]s to the specified paths and returns them as DataSources for
+    * subsequent operations.
+    *
+    * @param ds1 First [[DataSet]] to write to disk
+    * @param ds2 Second [[DataSet]] to write to disk
+    * @param path1 Path for ds1
+    * @param path2 Path for ds2
+    * @tparam A Type of the first [[DataSet]]'s elements
+    * @tparam B Type of the second [[DataSet]]'s elements
+    * @return Tuple of [[DataSet]]s reading the just written files
+    */
+  def persist[A: ClassTag: TypeInformation ,B: ClassTag: TypeInformation](ds1: DataSet[A], ds2:
+  DataSet[B], path1: String, path2: String):(DataSet[A], DataSet[B])  = {
+    val env = ds1.getExecutionEnvironment
+
+    val f1 = new Path(path1)
+
+    val of1 = new TypeSerializerOutputFormat[A]
+    of1.setOutputFilePath(f1)
+    of1.setWriteMode(WriteMode.OVERWRITE)
+
+    ds1.output(of1)
+
+    val f2 = new Path(path2)
+
+    val of2 = new TypeSerializerOutputFormat[B]
+    of2.setOutputFilePath(f2)
+    of2.setWriteMode(WriteMode.OVERWRITE)
+
+    ds2.output(of2)
+
+    env.execute("FlinkTools persist")
+
+    val if1 = new TypeSerializerInputFormat[A](ds1.getType)
+    if1.setFilePath(f1)
+
+    val if2 = new TypeSerializerInputFormat[B](ds2.getType)
+    if2.setFilePath(f2)
+
+    (env.createInput(if1), env.createInput(if2))
+  }
+
+  /** Writes multiple [[DataSet]]s to the specified paths and returns them as DataSources for
+    * subsequent operations.
+    *
+    * @param ds1 First [[DataSet]] to write to disk
+    * @param ds2 Second [[DataSet]] to write to disk
+    * @param ds3 Third [[DataSet]] to write to disk
+    * @param path1 Path for ds1
+    * @param path2 Path for ds2
+    * @param path3 Path for ds3
+    * @tparam A Type of first [[DataSet]]'s elements
+    * @tparam B Type of second [[DataSet]]'s elements
+    * @tparam C Type of third [[DataSet]]'s elements
+    * @return Tuple of [[DataSet]]s reading the just written files
+    */
+  def persist[A: ClassTag: TypeInformation ,B: ClassTag: TypeInformation,
+  C: ClassTag: TypeInformation](ds1: DataSet[A], ds2:  DataSet[B], ds3: DataSet[C], path1:
+  String, path2: String, path3: String): (DataSet[A], DataSet[B], DataSet[C])  = {
+    val env = ds1.getExecutionEnvironment
+
+    val f1 = new Path(path1)
+
+    val of1 = new TypeSerializerOutputFormat[A]
+    of1.setOutputFilePath(f1)
+    of1.setWriteMode(WriteMode.OVERWRITE)
+
+    ds1.output(of1)
+
+    val f2 = new Path(path2)
+
+    val of2 = new TypeSerializerOutputFormat[B]
+    of2.setOutputFilePath(f2)
+    of2.setWriteMode(WriteMode.OVERWRITE)
+
+    ds2.output(of2)
+
+    val f3 = new Path(path3)
+
+    val of3 = new TypeSerializerOutputFormat[C]
+    of3.setOutputFilePath(f3)
+    of3.setWriteMode(WriteMode.OVERWRITE)
+
+    ds3.output(of3)
+
+    env.execute("FlinkTools persist")
+
+    val if1 = new TypeSerializerInputFormat[A](ds1.getType)
+    if1.setFilePath(f1)
+
+    val if2 = new TypeSerializerInputFormat[B](ds2.getType)
+    if2.setFilePath(f2)
+
+    val if3 = new TypeSerializerInputFormat[C](ds3.getType)
+    if3.setFilePath(f3)
+
+    (env.createInput(if1), env.createInput(if2), env.createInput(if3))
+  }
+
+  /** Writes multiple [[DataSet]]s to the specified paths and returns them as DataSources for
+    * subsequent operations.
+    *
+    * @param ds1 First [[DataSet]] to write to disk
+    * @param ds2 Second [[DataSet]] to write to disk
+    * @param ds3 Third [[DataSet]] to write to disk
+    * @param ds4 Fourth [[DataSet]] to write to disk
+    * @param path1 Path for ds1
+    * @param path2 Path for ds2
+    * @param path3 Path for ds3
+    * @param path4 Path for ds4
+    * @tparam A Type of first [[DataSet]]'s elements
+    * @tparam B Type of second [[DataSet]]'s elements
+    * @tparam C Type of third [[DataSet]]'s elements
+    * @tparam D Type of fourth [[DataSet]]'s elements
+    * @return Tuple of [[DataSet]]s reading the just written files
+    */
+  def persist[A: ClassTag: TypeInformation ,B: ClassTag: TypeInformation,
+  C: ClassTag: TypeInformation, D: ClassTag: TypeInformation](ds1: DataSet[A], ds2:  DataSet[B],
+                                                              ds3: DataSet[C], ds4: DataSet[D],
+                                                              path1: String, path2: String, path3:
+                                                              String, path4: String):
+  (DataSet[A], DataSet[B], DataSet[C], DataSet[D])  = {
+    val env = ds1.getExecutionEnvironment
+
+    val f1 = new Path(path1)
+
+    val of1 = new TypeSerializerOutputFormat[A]
+    of1.setOutputFilePath(f1)
+    of1.setWriteMode(WriteMode.OVERWRITE)
+
+    ds1.output(of1)
+
+    val f2 = new Path(path2)
+
+    val of2 = new TypeSerializerOutputFormat[B]
+    of2.setOutputFilePath(f2)
+    of2.setWriteMode(WriteMode.OVERWRITE)
+
+    ds2.output(of2)
+
+    val f3 = new Path(path3)
+
+    val of3 = new TypeSerializerOutputFormat[C]
+    of3.setOutputFilePath(f3)
+    of3.setWriteMode(WriteMode.OVERWRITE)
+
+    ds3.output(of3)
+
+    val f4 = new Path(path4)
+
+    val of4 = new TypeSerializerOutputFormat[D]
+    of4.setOutputFilePath(f4)
+    of4.setWriteMode(WriteMode.OVERWRITE)
+
+    ds4.output(of4)
+
+    env.execute("FlinkTools persist")
+
+    val if1 = new TypeSerializerInputFormat[A](ds1.getType)
+    if1.setFilePath(f1)
+
+    val if2 = new TypeSerializerInputFormat[B](ds2.getType)
+    if2.setFilePath(f2)
+
+    val if3 = new TypeSerializerInputFormat[C](ds3.getType)
+    if3.setFilePath(f3)
+
+    val if4 = new TypeSerializerInputFormat[D](ds4.getType)
+    if4.setFilePath(f4)
+
+    (env.createInput(if1), env.createInput(if2), env.createInput(if3), env.createInput(if4))
+  }
+
+  /** Writes multiple [[DataSet]]s to the specified paths and returns them as DataSources for
+    * subsequent operations.
+    *
+    * @param ds1 First [[DataSet]] to write to disk
+    * @param ds2 Second [[DataSet]] to write to disk
+    * @param ds3 Third [[DataSet]] to write to disk
+    * @param ds4 Fourth [[DataSet]] to write to disk
+    * @param ds5 Fifth [[DataSet]] to write to disk
+    * @param path1 Path for ds1
+    * @param path2 Path for ds2
+    * @param path3 Path for ds3
+    * @param path4 Path for ds4
+    * @param path5 Path for ds5
+    * @tparam A Type of first [[DataSet]]'s elements
+    * @tparam B Type of second [[DataSet]]'s elements
+    * @tparam C Type of third [[DataSet]]'s elements
+    * @tparam D Type of fourth [[DataSet]]'s elements
+    * @tparam E Type of fifth [[DataSet]]'s elements
+    * @return Tuple of [[DataSet]]s reading the just written files
+    */
+  def persist[A: ClassTag: TypeInformation ,B: ClassTag: TypeInformation,
+  C: ClassTag: TypeInformation, D: ClassTag: TypeInformation, E: ClassTag: TypeInformation]
+  (ds1: DataSet[A], ds2:  DataSet[B], ds3: DataSet[C], ds4: DataSet[D], ds5: DataSet[E], path1:
+  String, path2: String, path3: String, path4: String, path5: String): (DataSet[A], DataSet[B],
+    DataSet[C], DataSet[D], DataSet[E])  = {
+    val env = ds1.getExecutionEnvironment
+
+    val f1 = new Path(path1)
+
+    val of1 = new TypeSerializerOutputFormat[A]
+    of1.setOutputFilePath(f1)
+    of1.setWriteMode(WriteMode.OVERWRITE)
+
+    ds1.output(of1)
+
+    val f2 = new Path(path2)
+
+    val of2 = new TypeSerializerOutputFormat[B]
+    of2.setOutputFilePath(f2)
+    of2.setOutputDirectoryMode(OutputDirectoryMode.ALWAYS)
+    of2.setWriteMode(WriteMode.OVERWRITE)
+
+    ds2.output(of2)
+
+    val f3 = new Path(path3)
+
+    val of3 = new TypeSerializerOutputFormat[C]
+    of3.setOutputFilePath(f3)
+    of3.setWriteMode(WriteMode.OVERWRITE)
+
+    ds3.output(of3)
+
+    val f4 = new Path(path4)
+
+    val of4 = new TypeSerializerOutputFormat[D]
+    of4.setOutputFilePath(f4)
+    of4.setWriteMode(WriteMode.OVERWRITE)
+
+    ds4.output(of4)
+
+    val f5 = new Path(path5)
+
+    val of5 = new TypeSerializerOutputFormat[E]
+    of5.setOutputFilePath(f5)
+    of5.setWriteMode(WriteMode.OVERWRITE)
+
+    ds5.output(of5)
+
+    env.execute("FlinkTools persist")
+
+    val if1 = new TypeSerializerInputFormat[A](ds1.getType)
+    if1.setFilePath(f1)
+
+    val if2 = new TypeSerializerInputFormat[B](ds2.getType)
+    if2.setFilePath(f2)
+
+    val if3 = new TypeSerializerInputFormat[C](ds3.getType)
+    if3.setFilePath(f3)
+
+    val if4 = new TypeSerializerInputFormat[D](ds4.getType)
+    if4.setFilePath(f4)
+
+    val if5 = new TypeSerializerInputFormat[E](ds5.getType)
+    if5.setFilePath(f5)
+
+    (env.createInput(if1), env.createInput(if2), env.createInput(if3), env.createInput(if4), env
+      .createInput(if5))
+  }
+
+  /** Groups the DataSet input into numBlocks blocks.
+    *
+    * @param input DataSet to be grouped into blocks
+    * @param numBlocks Number of blocks
+    * @param partitionerOption Optional partitioner to control the partitioning
+    * @tparam T Type of the input DataSet's elements
+    * @return DataSet of Blocks containing the elements of the input DataSet
+    */
+  def block[T: TypeInformation: ClassTag](
+    input: DataSet[T],
+    numBlocks: Int,
+    partitionerOption: Option[Partitioner[Int]] = None)
+  : DataSet[Block[T]] = {
+    val blockIDInput = input map {
+      element =>
+        val blockID = element.hashCode() % numBlocks
+
+        val blockIDResult = if(blockID < 0){
+          blockID + numBlocks
+        } else {
+          blockID
+        }
+
+        (blockIDResult, element)
+    }
+
+    val preGroupBlockIDInput = partitionerOption match {
+      case Some(partitioner) =>
+        blockIDInput partitionCustom(partitioner, 0)
+
+      case None => blockIDInput
+    }
+
+    preGroupBlockIDInput.groupBy(0).reduceGroup {
+      iter => {
+        val array = iter.toVector
+
+        val blockID = array(0)._1
+        val elements = array.map(_._2)
+
+        Block[T](blockID, elements)
+      }
+    }.withForwardedFields("0 -> index")
+  }
+
+  /** Partitioner which distributes the elements to channels based on the modulo of their keys
+    * with respect to the number of channels.
+    */
+  object ModuloKeyPartitioner extends Partitioner[Int] {
+    override def partition(key: Int, numPartitions: Int): Int = {
+      val result = key % numPartitions
+
+      if(result < 0) {
+        result + numPartitions
+      } else {
+        result
+      }
+    }
+  }
+}

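As a usage note on the relocated helpers: persist materializes a DataSet to disk and returns a source reading it back, cutting the execution graph at that point, while block groups elements into a fixed number of blocks. A hedged sketch (path and data are illustrative):

    import org.apache.flink.api.scala._
    import org.apache.flink.ml.common.FlinkMLTools

    val env = ExecutionEnvironment.getExecutionEnvironment
    val data = env.fromElements(1, 2, 3, 4, 5, 6)

    // Writes the DataSet to the given path, executes the job and returns a
    // DataSource reading the just-written file.
    val persisted = FlinkMLTools.persist(data, "/tmp/flinkml/persisted")

    // Groups the elements into 2 blocks; the optional partitioner controls how
    // block ids are assigned to channels.
    val blocked = FlinkMLTools.block(persisted, 2, Some(FlinkMLTools.ModuloKeyPartitioner))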
http://git-wip-us.apache.org/repos/asf/flink/blob/ae446388/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkTools.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkTools.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkTools.scala
deleted file mode 100644
index 57bf98e..0000000
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/common/FlinkTools.scala
+++ /dev/null
@@ -1,392 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.common
-
-import org.apache.flink.api.common.functions.Partitioner
-import org.apache.flink.api.common.io.FileOutputFormat.OutputDirectoryMode
-import org.apache.flink.api.common.typeinfo.TypeInformation
-import org.apache.flink.api.java.io.{TypeSerializerInputFormat, TypeSerializerOutputFormat}
-import org.apache.flink.api.scala._
-import org.apache.flink.core.fs.FileSystem.WriteMode
-import org.apache.flink.core.fs.Path
-
-import scala.reflect.ClassTag
-
-/** FlinkTools contains a set of convenience functions for Flink's machine learning library:
-  *
-  *  - persist:
-  *  Takes up to 5 [[DataSet]]s and file paths. Each [[DataSet]] is written to the specified
-  *  path and subsequently re-read from disk. This method can be used to effectively split the
-  *  execution graph at the given [[DataSet]]. Writing it to disk triggers its materialization
-  *  and specifying it as a source will prevent the re-execution of it.
-  *
-  *  - block:
-  *  Takes a DataSet of elements T and groups them in n blocks.
-  *
-  */
-object FlinkTools {
-
-  /** Writes a [[DataSet]] to the specified path and returns it as a DataSource for subsequent
-    * operations.
-    *
-    * @param dataset [[DataSet]] to write to disk
-    * @param path File path to write dataset to
-    * @tparam T Type of the [[DataSet]] elements
-    * @return [[DataSet]] reading the just written file
-    */
-  def persist[T: ClassTag: TypeInformation](dataset: DataSet[T], path: String): DataSet[T] = {
-    val env = dataset.getExecutionEnvironment
-    val outputFormat = new TypeSerializerOutputFormat[T]
-
-    val filePath = new Path(path)
-
-    outputFormat.setOutputFilePath(filePath)
-    outputFormat.setWriteMode(WriteMode.OVERWRITE)
-
-    dataset.output(outputFormat)
-    env.execute("FlinkTools persist")
-
-    val inputFormat = new TypeSerializerInputFormat[T](dataset.getType)
-    inputFormat.setFilePath(filePath)
-
-    env.createInput(inputFormat)
-  }
-
-  /** Writes multiple [[DataSet]]s to the specified paths and returns them as DataSources for
-    * subsequent operations.
-    *
-    * @param ds1 First [[DataSet]] to write to disk
-    * @param ds2 Second [[DataSet]] to write to disk
-    * @param path1 Path for ds1
-    * @param path2 Path for ds2
-    * @tparam A Type of the first [[DataSet]]'s elements
-    * @tparam B Type of the second [[DataSet]]'s elements
-    * @return Tuple of [[DataSet]]s reading the just written files
-    */
-  def persist[A: ClassTag: TypeInformation ,B: ClassTag: TypeInformation](ds1: DataSet[A], ds2:
-  DataSet[B], path1: String, path2: String):(DataSet[A], DataSet[B])  = {
-    val env = ds1.getExecutionEnvironment
-
-    val f1 = new Path(path1)
-
-    val of1 = new TypeSerializerOutputFormat[A]
-    of1.setOutputFilePath(f1)
-    of1.setWriteMode(WriteMode.OVERWRITE)
-
-    ds1.output(of1)
-
-    val f2 = new Path(path2)
-
-    val of2 = new TypeSerializerOutputFormat[B]
-    of2.setOutputFilePath(f2)
-    of2.setWriteMode(WriteMode.OVERWRITE)
-
-    ds2.output(of2)
-
-    env.execute("FlinkTools persist")
-
-    val if1 = new TypeSerializerInputFormat[A](ds1.getType)
-    if1.setFilePath(f1)
-
-    val if2 = new TypeSerializerInputFormat[B](ds2.getType)
-    if2.setFilePath(f2)
-
-    (env.createInput(if1), env.createInput(if2))
-  }
-
-  /** Writes multiple [[DataSet]]s to the specified paths and returns them as DataSources for
-    * subsequent operations.
-    *
-    * @param ds1 First [[DataSet]] to write to disk
-    * @param ds2 Second [[DataSet]] to write to disk
-    * @param ds3 Third [[DataSet]] to write to disk
-    * @param path1 Path for ds1
-    * @param path2 Path for ds2
-    * @param path3 Path for ds3
-    * @tparam A Type of first [[DataSet]]'s elements
-    * @tparam B Type of second [[DataSet]]'s elements
-    * @tparam C Type of third [[DataSet]]'s elements
-    * @return Tuple of [[DataSet]]s reading the just written files
-    */
-  def persist[A: ClassTag: TypeInformation ,B: ClassTag: TypeInformation,
-  C: ClassTag: TypeInformation](ds1: DataSet[A], ds2:  DataSet[B], ds3: DataSet[C], path1:
-  String, path2: String, path3: String): (DataSet[A], DataSet[B], DataSet[C])  = {
-    val env = ds1.getExecutionEnvironment
-
-    val f1 = new Path(path1)
-
-    val of1 = new TypeSerializerOutputFormat[A]
-    of1.setOutputFilePath(f1)
-    of1.setWriteMode(WriteMode.OVERWRITE)
-
-    ds1.output(of1)
-
-    val f2 = new Path(path2)
-
-    val of2 = new TypeSerializerOutputFormat[B]
-    of2.setOutputFilePath(f2)
-    of2.setWriteMode(WriteMode.OVERWRITE)
-
-    ds2.output(of2)
-
-    val f3 = new Path(path3)
-
-    val of3 = new TypeSerializerOutputFormat[C]
-    of3.setOutputFilePath(f3)
-    of3.setWriteMode(WriteMode.OVERWRITE)
-
-    ds3.output(of3)
-
-    env.execute("FlinkTools persist")
-
-    val if1 = new TypeSerializerInputFormat[A](ds1.getType)
-    if1.setFilePath(f1)
-
-    val if2 = new TypeSerializerInputFormat[B](ds2.getType)
-    if2.setFilePath(f2)
-
-    val if3 = new TypeSerializerInputFormat[C](ds3.getType)
-    if3.setFilePath(f3)
-
-    (env.createInput(if1), env.createInput(if2), env.createInput(if3))
-  }
-
-  /** Writes multiple [[DataSet]]s to the specified paths and returns them as DataSources for
-    * subsequent operations.
-    *
-    * @param ds1 First [[DataSet]] to write to disk
-    * @param ds2 Second [[DataSet]] to write to disk
-    * @param ds3 Third [[DataSet]] to write to disk
-    * @param ds4 Fourth [[DataSet]] to write to disk
-    * @param path1 Path for ds1
-    * @param path2 Path for ds2
-    * @param path3 Path for ds3
-    * @param path4 Path for ds4
-    * @tparam A Type of first [[DataSet]]'s elements
-    * @tparam B Type of second [[DataSet]]'s elements
-    * @tparam C Type of third [[DataSet]]'s elements
-    * @tparam D Type of fourth [[DataSet]]'s elements
-    * @return Tuple of [[DataSet]]s reading the just written files
-    */
-  def persist[A: ClassTag: TypeInformation ,B: ClassTag: TypeInformation,
-  C: ClassTag: TypeInformation, D: ClassTag: TypeInformation](ds1: DataSet[A], ds2:  DataSet[B],
-                                                              ds3: DataSet[C], ds4: DataSet[D],
-                                                              path1: String, path2: String, path3:
-                                                              String, path4: String):
-  (DataSet[A], DataSet[B], DataSet[C], DataSet[D])  = {
-    val env = ds1.getExecutionEnvironment
-
-    val f1 = new Path(path1)
-
-    val of1 = new TypeSerializerOutputFormat[A]
-    of1.setOutputFilePath(f1)
-    of1.setWriteMode(WriteMode.OVERWRITE)
-
-    ds1.output(of1)
-
-    val f2 = new Path(path2)
-
-    val of2 = new TypeSerializerOutputFormat[B]
-    of2.setOutputFilePath(f2)
-    of2.setWriteMode(WriteMode.OVERWRITE)
-
-    ds2.output(of2)
-
-    val f3 = new Path(path3)
-
-    val of3 = new TypeSerializerOutputFormat[C]
-    of3.setOutputFilePath(f3)
-    of3.setWriteMode(WriteMode.OVERWRITE)
-
-    ds3.output(of3)
-
-    val f4 = new Path(path4)
-
-    val of4 = new TypeSerializerOutputFormat[D]
-    of4.setOutputFilePath(f4)
-    of4.setWriteMode(WriteMode.OVERWRITE)
-
-    ds4.output(of4)
-
-    env.execute("FlinkTools persist")
-
-    val if1 = new TypeSerializerInputFormat[A](ds1.getType)
-    if1.setFilePath(f1)
-
-    val if2 = new TypeSerializerInputFormat[B](ds2.getType)
-    if2.setFilePath(f2)
-
-    val if3 = new TypeSerializerInputFormat[C](ds3.getType)
-    if3.setFilePath(f3)
-
-    val if4 = new TypeSerializerInputFormat[D](ds4.getType)
-    if4.setFilePath(f4)
-
-    (env.createInput(if1), env.createInput(if2), env.createInput(if3), env.createInput(if4))
-  }
-
-  /** Writes multiple [[DataSet]]s to the specified paths and returns them as DataSources for
-    * subsequent operations.
-    *
-    * @param ds1 First [[DataSet]] to write to disk
-    * @param ds2 Second [[DataSet]] to write to disk
-    * @param ds3 Third [[DataSet]] to write to disk
-    * @param ds4 Fourth [[DataSet]] to write to disk
-    * @param ds5 Fifth [[DataSet]] to write to disk
-    * @param path1 Path for ds1
-    * @param path2 Path for ds2
-    * @param path3 Path for ds3
-    * @param path4 Path for ds4
-    * @param path5 Path for ds5
-    * @tparam A Type of first [[DataSet]]'s elements
-    * @tparam B Type of second [[DataSet]]'s elements
-    * @tparam C Type of third [[DataSet]]'s elements
-    * @tparam D Type of fourth [[DataSet]]'s elements
-    * @tparam E Type of fifth [[DataSet]]'s elements
-    * @return Tuple of [[DataSet]]s reading the just written files
-    */
-  def persist[A: ClassTag: TypeInformation ,B: ClassTag: TypeInformation,
-  C: ClassTag: TypeInformation, D: ClassTag: TypeInformation, E: ClassTag: TypeInformation]
-  (ds1: DataSet[A], ds2:  DataSet[B], ds3: DataSet[C], ds4: DataSet[D], ds5: DataSet[E], path1:
-  String, path2: String, path3: String, path4: String, path5: String): (DataSet[A], DataSet[B],
-    DataSet[C], DataSet[D], DataSet[E])  = {
-    val env = ds1.getExecutionEnvironment
-
-    val f1 = new Path(path1)
-
-    val of1 = new TypeSerializerOutputFormat[A]
-    of1.setOutputFilePath(f1)
-    of1.setWriteMode(WriteMode.OVERWRITE)
-
-    ds1.output(of1)
-
-    val f2 = new Path(path2)
-
-    val of2 = new TypeSerializerOutputFormat[B]
-    of2.setOutputFilePath(f2)
-    of2.setOutputDirectoryMode(OutputDirectoryMode.ALWAYS)
-    of2.setWriteMode(WriteMode.OVERWRITE)
-
-    ds2.output(of2)
-
-    val f3 = new Path(path3)
-
-    val of3 = new TypeSerializerOutputFormat[C]
-    of3.setOutputFilePath(f3)
-    of3.setWriteMode(WriteMode.OVERWRITE)
-
-    ds3.output(of3)
-
-    val f4 = new Path(path4)
-
-    val of4 = new TypeSerializerOutputFormat[D]
-    of4.setOutputFilePath(f4)
-    of4.setWriteMode(WriteMode.OVERWRITE)
-
-    ds4.output(of4)
-
-    val f5 = new Path(path5)
-
-    val of5 = new TypeSerializerOutputFormat[E]
-    of5.setOutputFilePath(f5)
-    of5.setWriteMode(WriteMode.OVERWRITE)
-
-    ds5.output(of5)
-
-    env.execute("FlinkTools persist")
-
-    val if1 = new TypeSerializerInputFormat[A](ds1.getType)
-    if1.setFilePath(f1)
-
-    val if2 = new TypeSerializerInputFormat[B](ds2.getType)
-    if2.setFilePath(f2)
-
-    val if3 = new TypeSerializerInputFormat[C](ds3.getType)
-    if3.setFilePath(f3)
-
-    val if4 = new TypeSerializerInputFormat[D](ds4.getType)
-    if4.setFilePath(f4)
-
-    val if5 = new TypeSerializerInputFormat[E](ds5.getType)
-    if5.setFilePath(f5)
-
-    (env.createInput(if1), env.createInput(if2), env.createInput(if3), env.createInput(if4), env
-      .createInput(if5))
-  }
-
-  /** Groups the DataSet input into numBlocks blocks.
-    * 
-    * @param input
-    * @param numBlocks Number of Blocks
-    * @param partitionerOption Optional partitioner to control the partitioning
-    * @tparam T
-    * @return
-    */
-  def block[T: TypeInformation: ClassTag](
-    input: DataSet[T],
-    numBlocks: Int,
-    partitionerOption: Option[Partitioner[Int]] = None)
-  : DataSet[Block[T]] = {
-    val blockIDInput = input map {
-      element =>
-        val blockID = element.hashCode() % numBlocks
-
-        val blockIDResult = if(blockID < 0){
-          blockID + numBlocks
-        } else {
-          blockID
-        }
-
-        (blockIDResult, element)
-    }
-
-    val preGroupBlockIDInput = partitionerOption match {
-      case Some(partitioner) =>
-        blockIDInput partitionCustom(partitioner, 0)
-
-      case None => blockIDInput
-    }
-
-    preGroupBlockIDInput.groupBy(0).reduceGroup {
-      iter => {
-        val array = iter.toVector
-
-        val blockID = array(0)._1
-        val elements = array.map(_._2)
-
-        Block[T](blockID, elements)
-      }
-    }.withForwardedFields("0 -> index")
-  }
-
-  /** Distributes the elements by taking the modulo of their keys and assigning it to this channel
-    *
-    */
-  object ModuloKeyPartitioner extends Partitioner[Int] {
-    override def partition(key: Int, numPartitions: Int): Int = {
-      val result = key % numPartitions
-
-      if(result < 0) {
-        result + numPartitions
-      } else {
-        result
-      }
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/ae446388/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Breeze.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Breeze.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Breeze.scala
index fbe35d4..74d4d8f 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Breeze.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/Breeze.scala
@@ -78,10 +78,10 @@ object Breeze {
     def asBreeze: BreezeVector[Double] = {
       vector match {
         case dense: DenseVector =>
-          new BreezeDenseVector[Double](dense.data)
+          new breeze.linalg.DenseVector(dense.data)
 
         case sparse: SparseVector =>
-          new BreezeSparseVector[Double](sparse.indices, sparse.data, sparse.size)
+          new BreezeSparseVector(sparse.indices, sparse.data, sparse.size)
       }
     }
   }

http://git-wip-us.apache.org/repos/asf/flink/blob/ae446388/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/BreezeVectorConverter.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/BreezeVectorConverter.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/BreezeVectorConverter.scala
index 687772e..f5f7469 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/BreezeVectorConverter.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/math/BreezeVectorConverter.scala
@@ -56,7 +56,10 @@ object BreezeVectorConverter{
             dense.length,
             dense.iterator.toIterable)
         case sparse: BreezeSparseVector[Double] =>
-          new SparseVector(sparse.length, sparse.index, sparse.data)
+          new SparseVector(
+            sparse.length,
+            sparse.index.take(sparse.used),
+            sparse.data.take(sparse.used))
       }
     }
   }
@@ -68,7 +71,10 @@ object BreezeVectorConverter{
         case dense: BreezeDenseVector[Double] => new DenseVector(dense.data)
 
         case sparse: BreezeSparseVector[Double] =>
-          new SparseVector(sparse.length, sparse.index, sparse.data)
+          new SparseVector(
+            sparse.length,
+            sparse.index.take(sparse.used),
+            sparse.data.take(sparse.used))
       }
     }
   }

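The root cause of the conversion bug: Breeze's SparseVector over-allocates its index and data arrays, and only the first `used` entries are valid. Copying the raw arrays therefore carried stale slots into the Flink SparseVector; truncating with take(used) keeps exactly the populated entries while the logical dimension stays sparse.length. A short sketch (values illustrative):

    val bv = breeze.linalg.SparseVector.zeros[Double](10)
    bv(3) = 1.0
    bv(7) = 2.0

    // bv.length == 10 (dimension), bv.used == 2 (valid entries);
    // bv.index and bv.data may be longer than bv.used, hence the truncation
    val fv = new org.apache.flink.ml.math.SparseVector(
      bv.length, bv.index.take(bv.used), bv.data.take(bv.used))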
http://git-wip-us.apache.org/repos/asf/flink/blob/ae446388/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Estimator.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Estimator.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Estimator.scala
index 6acac8f..088b184 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Estimator.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Estimator.scala
@@ -21,7 +21,7 @@ package org.apache.flink.ml.pipeline
 import scala.reflect.ClassTag
 
 import org.apache.flink.api.scala.DataSet
-import org.apache.flink.ml.common.{ParameterMap, WithParameters}
+import org.apache.flink.ml.common.{FlinkMLTools, ParameterMap, WithParameters}
 
 /** Base trait for Flink's pipeline operators.
   *
@@ -50,6 +50,7 @@ trait Estimator[Self] extends WithParameters with Serializable {
       training: DataSet[Training],
       fitParameters: ParameterMap = ParameterMap.Empty)(implicit
       fitOperation: FitOperation[Self, Training]): Unit = {
+    FlinkMLTools.registerFlinkMLTypes(training.getExecutionEnvironment)
     fitOperation.fit(this, fitParameters, training)
   }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/ae446388/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Predictor.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Predictor.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Predictor.scala
index ebfa787..c0e66a0 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Predictor.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Predictor.scala
@@ -21,7 +21,7 @@ package org.apache.flink.ml.pipeline
 import scala.reflect.ClassTag
 
 import org.apache.flink.api.scala.DataSet
-import org.apache.flink.ml.common.{ParameterMap, WithParameters}
+import org.apache.flink.ml.common.{FlinkMLTools, ParameterMap, WithParameters}
 
 /** Predictor trait for Flink's pipeline operators.
   *
@@ -53,6 +53,7 @@ trait Predictor[Self] extends Estimator[Self] with WithParameters with Serializa
       predictParameters: ParameterMap = ParameterMap.Empty)(implicit
       predictor: PredictOperation[Self, Testing, Prediction])
     : DataSet[Prediction] = {
+    FlinkMLTools.registerFlinkMLTypes(testing.getExecutionEnvironment)
     predictor.predict(this, predictParameters, testing)
   }
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/ae446388/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Transformer.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Transformer.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Transformer.scala
index 52e3f7f..02360bc 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Transformer.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/pipeline/Transformer.scala
@@ -21,7 +21,7 @@ package org.apache.flink.ml.pipeline
 import scala.reflect.ClassTag
 
 import org.apache.flink.api.scala.DataSet
-import org.apache.flink.ml.common.{ParameterMap, WithParameters}
+import org.apache.flink.ml.common.{FlinkMLTools, ParameterMap, WithParameters}
 
 /** Transformer trait for Flink's pipeline operators.
   *
@@ -60,6 +60,7 @@ trait Transformer[Self <: Transformer[Self]]
       transformParameters: ParameterMap = ParameterMap.Empty)
       (implicit transformOperation: TransformOperation[Self, Input, Output])
     : DataSet[Output] = {
+    FlinkMLTools.registerFlinkMLTypes(input.getExecutionEnvironment)
     transformOperation.transform(that, transformParameters, input)
   }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/ae446388/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/preprocessing/StandardScaler.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/preprocessing/StandardScaler.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/preprocessing/StandardScaler.scala
index bd952c3..2e3ed95 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/preprocessing/StandardScaler.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/preprocessing/StandardScaler.scala
@@ -61,7 +61,6 @@ import scala.reflect.ClassTag
   */
 class StandardScaler extends Transformer[StandardScaler] {
 
-
   var metricsOption: Option[DataSet[(linalg.Vector[Double], linalg.Vector[Double])]] = None
 
   /** Sets the target mean of the transformed data
@@ -183,7 +182,6 @@ object StandardScaler {
             varianceVector.update(i, 1.0)
           }
         }
-
         (metric._2 / metric._1, varianceVector)
       }
     }

http://git-wip-us.apache.org/repos/asf/flink/blob/ae446388/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/recommendation/ALS.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/recommendation/ALS.scala
index d8efdaf..c5db6e4 100644
--- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/recommendation/ALS.scala
+++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/recommendation/ALS.scala
@@ -471,12 +471,12 @@ object ALS {
         blockIDPartitioner)
 
       val (userIn, userOut) = persistencePath match {
-        case Some(path) => FlinkTools.persist(uIn, uOut, path + "userIn", path + "userOut")
+        case Some(path) => FlinkMLTools.persist(uIn, uOut, path + "userIn", path + "userOut")
         case None => (uIn, uOut)
       }
 
       val (itemIn, itemOut) = persistencePath match {
-        case Some(path) => FlinkTools.persist(iIn, iOut, path + "itemIn", path + "itemOut")
+        case Some(path) => FlinkMLTools.persist(iIn, iOut, path + "itemIn", path + "itemOut")
         case None => (iIn, iOut)
       }
 
@@ -502,7 +502,7 @@ object ALS {
       }
 
       val pItems = persistencePath match {
-        case Some(path) => FlinkTools.persist(items, path + "items")
+        case Some(path) => FlinkMLTools.persist(items, path + "items")
         case None => items
       }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/ae446388/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/common/FlinkMLToolsSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/common/FlinkMLToolsSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/common/FlinkMLToolsSuite.scala
new file mode 100644
index 0000000..525ba4d
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/common/FlinkMLToolsSuite.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common
+
+import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer
+import org.apache.flink.api.scala.ExecutionEnvironment
+import org.apache.flink.test.util.FlinkTestBase
+import org.scalatest.{FlatSpec, Matchers}
+
+class FlinkMLToolsSuite extends FlatSpec with Matchers with FlinkTestBase {
+  behavior of "FlinkMLTools"
+
+  it should "register the required types" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    FlinkMLTools.registerFlinkMLTypes(env)
+
+    val executionConfig = env.getConfig
+
+    val serializer = new KryoSerializer[Nothing](classOf[Nothing], executionConfig)
+
+    val kryo = serializer.getKryo()
+
+    kryo.getRegistration(classOf[org.apache.flink.ml.math.DenseVector]).getId > 0 should be(true)
+    kryo.getRegistration(classOf[org.apache.flink.ml.math.SparseVector]).getId > 0 should be(true)
+    kryo.getRegistration(classOf[org.apache.flink.ml.math.DenseMatrix]).getId > 0 should be(true)
+    kryo.getRegistration(classOf[org.apache.flink.ml.math.SparseMatrix]).getId > 0 should be(true)
+
+    kryo.getRegistration(classOf[breeze.linalg.DenseMatrix[_]]).getId > 0 should be(true)
+    kryo.getRegistration(classOf[breeze.linalg.CSCMatrix[_]]).getId > 0 should be(true)
+    kryo.getRegistration(classOf[breeze.linalg.DenseVector[_]]).getId > 0 should be(true)
+    kryo.getRegistration(classOf[breeze.linalg.SparseVector[_]]).getId > 0 should be(true)
+
+    kryo.getRegistration(breeze.linalg.DenseVector.zeros[Double](0).getClass).getId > 0 should
+      be(true)
+    kryo.getRegistration(breeze.linalg.SparseVector.zeros[Double](0).getClass).getId > 0 should
+      be(true)
+    kryo.getRegistration(breeze.linalg.DenseMatrix.zeros[Double](0, 0).getClass).getId > 0 should
+      be(true)
+    kryo.getRegistration(breeze.linalg.CSCMatrix.zeros[Double](0, 0).getClass).getId > 0 should
+      be(true)
+  }
+
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/ae446388/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/feature/PolynomialBaseITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/feature/PolynomialBaseITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/feature/PolynomialBaseITSuite.scala
deleted file mode 100644
index 0f045ab..0000000
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/feature/PolynomialBaseITSuite.scala
+++ /dev/null
@@ -1,127 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.ml.feature
-
-import org.apache.flink.api.scala.ExecutionEnvironment
-import org.apache.flink.ml.common.LabeledVector
-import org.apache.flink.ml.math.DenseVector
-import org.apache.flink.ml.preprocessing.PolynomialFeatures
-import org.scalatest.{Matchers, FlatSpec}
-
-import org.apache.flink.api.scala._
-import org.apache.flink.test.util.FlinkTestBase
-
-class PolynomialBaseITSuite
-  extends FlatSpec
-  with Matchers
-  with FlinkTestBase {
-
-  behavior of "The polynomial base implementation"
-
-  it should "map single element vectors to the polynomial vector space" in {
-    val env = ExecutionEnvironment.getExecutionEnvironment
-
-    env.setParallelism (2)
-
-    val input = Seq (
-    LabeledVector (1.0, DenseVector (1)),
-    LabeledVector (2.0, DenseVector (2))
-    )
-
-    val inputDS = env.fromCollection (input)
-
-    val transformer = PolynomialFeatures()
-    .setDegree (3)
-
-    val transformedDS = transformer.transform(inputDS)
-
-    val expectedMap = List (
-    (1.0 -> DenseVector (1.0, 1.0, 1.0) ),
-    (2.0 -> DenseVector (8.0, 4.0, 2.0) )
-    ) toMap
-
-    val result = transformedDS.collect()
-
-    for (entry <- result) {
-    expectedMap.contains (entry.label) should be (true)
-    entry.vector should equal (expectedMap (entry.label) )
-    }
-  }
-
-  it should "map vectors to the polynomial vector space" in {
-    val env = ExecutionEnvironment.getExecutionEnvironment
-
-    env.setParallelism(2)
-
-    val input = Seq(
-      LabeledVector(1.0, DenseVector(2, 3)),
-      LabeledVector(2.0, DenseVector(2, 3, 4))
-    )
-
-    val expectedMap = List(
-      (1.0 -> DenseVector(8.0, 12.0, 18.0, 27.0, 4.0, 6.0, 9.0, 2.0, 3.0)),
-      (2.0 -> DenseVector(8.0, 12.0, 16.0, 18.0, 24.0, 32.0, 27.0, 36.0, 48.0, 64.0, 4.0, 6.0, 8.0,
-        9.0, 12.0, 16.0, 2.0, 3.0, 4.0))
-    ) toMap
-
-    val inputDS = env.fromCollection(input)
-
-    val transformer = PolynomialFeatures()
-      .setDegree(3)
-
-    val transformedDS = transformer.transform(inputDS)
-
-    val result = transformedDS.collect()
-
-    for(entry <- result) {
-      expectedMap.contains(entry.label) should be(true)
-      entry.vector should equal(expectedMap(entry.label))
-    }
-  }
-
-  it should "return an empty vector if the max degree is zero" in {
-    val env = ExecutionEnvironment.getExecutionEnvironment
-
-    env.setParallelism(2)
-
-    val input = Seq(
-      LabeledVector(1.0, DenseVector(2, 3)),
-      LabeledVector(2.0, DenseVector(2, 3, 4))
-    )
-
-    val inputDS = env.fromCollection(input)
-
-    val transformer = PolynomialFeatures()
-      .setDegree(0)
-
-    val transformedDS = transformer.transform(inputDS)
-
-    val result = transformedDS.collect()
-
-    val expectedMap = List(
-      (1.0 -> DenseVector()),
-      (2.0 -> DenseVector())
-    ) toMap
-
-    for(entry <- result) {
-      expectedMap.contains(entry.label) should be(true)
-      entry.vector should equal(expectedMap(entry.label))
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/ae446388/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/feature/PolynomialFeaturesITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/feature/PolynomialFeaturesITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/feature/PolynomialFeaturesITSuite.scala
new file mode 100644
index 0000000..674c1c4
--- /dev/null
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/feature/PolynomialFeaturesITSuite.scala
@@ -0,0 +1,127 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature
+
+import org.apache.flink.api.scala.ExecutionEnvironment
+import org.apache.flink.ml.common.LabeledVector
+import org.apache.flink.ml.math.DenseVector
+import org.apache.flink.ml.preprocessing.PolynomialFeatures
+import org.scalatest.{Matchers, FlatSpec}
+
+import org.apache.flink.api.scala._
+import org.apache.flink.test.util.FlinkTestBase
+
+class PolynomialFeaturesITSuite
+  extends FlatSpec
+  with Matchers
+  with FlinkTestBase {
+
+  behavior of "The polynomial base implementation"
+
+  it should "map single element vectors to the polynomial vector space" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    env.setParallelism (2)
+
+    val input = Seq (
+    LabeledVector (1.0, DenseVector (1)),
+    LabeledVector (2.0, DenseVector (2))
+    )
+
+    val inputDS = env.fromCollection (input)
+
+    val transformer = PolynomialFeatures()
+    .setDegree (3)
+
+    val transformedDS = transformer.transform(inputDS)
+
+    val expectedMap = List (
+    (1.0 -> DenseVector (1.0, 1.0, 1.0) ),
+    (2.0 -> DenseVector (8.0, 4.0, 2.0) )
+    ) toMap
+
+    val result = transformedDS.collect()
+
+    for (entry <- result) {
+    expectedMap.contains (entry.label) should be (true)
+    entry.vector should equal (expectedMap (entry.label) )
+    }
+  }
+
+  it should "map vectors to the polynomial vector space" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    env.setParallelism(2)
+
+    val input = Seq(
+      LabeledVector(1.0, DenseVector(2, 3)),
+      LabeledVector(2.0, DenseVector(2, 3, 4))
+    )
+
+    val expectedMap = List(
+      (1.0 -> DenseVector(8.0, 12.0, 18.0, 27.0, 4.0, 6.0, 9.0, 2.0, 3.0)),
+      (2.0 -> DenseVector(8.0, 12.0, 16.0, 18.0, 24.0, 32.0, 27.0, 36.0, 48.0, 64.0, 4.0, 6.0, 8.0,
+        9.0, 12.0, 16.0, 2.0, 3.0, 4.0))
+    ) toMap
+
+    val inputDS = env.fromCollection(input)
+
+    val transformer = PolynomialFeatures()
+      .setDegree(3)
+
+    val transformedDS = transformer.transform(inputDS)
+
+    val result = transformedDS.collect()
+
+    for(entry <- result) {
+      expectedMap.contains(entry.label) should be(true)
+      entry.vector should equal(expectedMap(entry.label))
+    }
+  }
+
+  it should "return an empty vector if the max degree is zero" in {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+
+    env.setParallelism(2)
+
+    val input = Seq(
+      LabeledVector(1.0, DenseVector(2, 3)),
+      LabeledVector(2.0, DenseVector(2, 3, 4))
+    )
+
+    val inputDS = env.fromCollection(input)
+
+    val transformer = PolynomialFeatures()
+      .setDegree(0)
+
+    val transformedDS = transformer.transform(inputDS)
+
+    val result = transformedDS.collect()
+
+    val expectedMap = List(
+      (1.0 -> DenseVector()),
+      (2.0 -> DenseVector())
+    ) toMap
+
+    for(entry <- result) {
+      expectedMap.contains(entry.label) should be(true)
+      entry.vector should equal(expectedMap(entry.label))
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/ae446388/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/pipeline/PipelineITSuite.scala
----------------------------------------------------------------------
diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/pipeline/PipelineITSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/pipeline/PipelineITSuite.scala
index 8803195..9909a18 100644
--- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/pipeline/PipelineITSuite.scala
+++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/pipeline/PipelineITSuite.scala
@@ -18,9 +18,12 @@
 
 package org.apache.flink.ml.pipeline
 
+import breeze.linalg
+import org.apache.flink.api.common.ExecutionConfig
+import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer
 import org.apache.flink.api.scala._
 import org.apache.flink.ml.common.LabeledVector
-import org.apache.flink.ml.math.DenseVector
+import org.apache.flink.ml.math._
 import org.apache.flink.ml.preprocessing.{PolynomialFeatures, StandardScaler}
 import org.apache.flink.ml.regression.MultipleLinearRegression
 import org.apache.flink.test.util.FlinkTestBase

