flink-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From mbala...@apache.org
Subject [3/7] git commit: [FLINK-1188] [streaming] Updated aggregations to work also on arrays by default
Date Mon, 27 Oct 2014 12:03:27 GMT
[FLINK-1188] [streaming] Updated aggregations to also work on arrays by default


Project: http://git-wip-us.apache.org/repos/asf/incubator-flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-flink/commit/7709a3a6
Tree: http://git-wip-us.apache.org/repos/asf/incubator-flink/tree/7709a3a6
Diff: http://git-wip-us.apache.org/repos/asf/incubator-flink/diff/7709a3a6

Branch: refs/heads/master
Commit: 7709a3a6906d015e82c8d3a3fce30a5a90da5167
Parents: a221796
Author: Gyula Fora <gyfora@apache.org>
Authored: Tue Oct 21 22:30:12 2014 +0200
Committer: mbalassi <balassi.marton@gmail.com>
Committed: Mon Oct 27 12:23:58 2014 +0100

----------------------------------------------------------------------
 .../api/datastream/BatchedDataStream.java       | 12 ++--
 .../streaming/api/datastream/DataStream.java    | 37 +++++++++--
 .../aggregation/AggregationFunction.java        | 12 +++-
 .../ComparableAggregationFunction.java          | 29 ++++++--
 .../aggregation/MaxAggregationFunction.java     |  6 +-
 .../aggregation/MaxByAggregationFunction.java   |  6 +-
 .../aggregation/MinAggregationFunction.java     |  6 +-
 .../aggregation/MinByAggregationFunction.java   | 19 +++++-
 .../aggregation/SumAggregationFunction.java     | 69 +++++++++++---------
 .../streaming/api/AggregationFunctionTest.java  | 34 ++++++----
 10 files changed, 160 insertions(+), 70 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/7709a3a6/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/BatchedDataStream.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/BatchedDataStream.java
b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/BatchedDataStream.java
index dbf436d..75eadcf 100755
--- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/BatchedDataStream.java
+++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/BatchedDataStream.java
@@ -137,7 +137,7 @@ public class BatchedDataStream<OUT> {
 	public SingleOutputStreamOperator<OUT, ?> sum(int positionToSum) {
 		dataStream.checkFieldRange(positionToSum);
 		return aggregate((AggregationFunction<OUT>) SumAggregationFunction.getSumFunction(
-				positionToSum, dataStream.getClassAtPos(positionToSum)));
+				positionToSum, dataStream.getClassAtPos(positionToSum), dataStream.getOutputType()));
 	}
 
 	/**
@@ -159,7 +159,7 @@ public class BatchedDataStream<OUT> {
 	 */
 	public SingleOutputStreamOperator<OUT, ?> min(int positionToMin) {
 		dataStream.checkFieldRange(positionToMin);
-		return aggregate(new MinAggregationFunction<OUT>(positionToMin));
+		return aggregate(new MinAggregationFunction<OUT>(positionToMin, dataStream.getOutputType()));
 	}
 
 	/**
@@ -191,7 +191,8 @@ public class BatchedDataStream<OUT> {
 	 */
 	public SingleOutputStreamOperator<OUT, ?> minBy(int positionToMinBy, boolean first)
{
 		dataStream.checkFieldRange(positionToMinBy);
-		return aggregate(new MinByAggregationFunction<OUT>(positionToMinBy, first));
+		return aggregate(new MinByAggregationFunction<OUT>(positionToMinBy, first,
+				dataStream.getOutputType()));
 	}
 
 	/**
@@ -213,7 +214,7 @@ public class BatchedDataStream<OUT> {
 	 */
 	public SingleOutputStreamOperator<OUT, ?> max(int positionToMax) {
 		dataStream.checkFieldRange(positionToMax);
-		return aggregate(new MaxAggregationFunction<OUT>(positionToMax));
+		return aggregate(new MaxAggregationFunction<OUT>(positionToMax, dataStream.getOutputType()));
 	}
 
 	/**
@@ -244,7 +245,8 @@ public class BatchedDataStream<OUT> {
 	 */
 	public SingleOutputStreamOperator<OUT, ?> maxBy(int positionToMaxBy, boolean first)
{
 		dataStream.checkFieldRange(positionToMaxBy);
-		return aggregate(new MaxByAggregationFunction<OUT>(positionToMaxBy, first));
+		return aggregate(new MaxByAggregationFunction<OUT>(positionToMaxBy, first,
+				dataStream.getOutputType()));
 	}
 
 	/**

http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/7709a3a6/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
index 36649cc..98058df 100644
--- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
+++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/datastream/DataStream.java
@@ -32,6 +32,8 @@ import org.apache.flink.api.common.functions.RichFilterFunction;
 import org.apache.flink.api.common.functions.RichFlatMapFunction;
 import org.apache.flink.api.common.functions.RichMapFunction;
 import org.apache.flink.api.common.functions.RichReduceFunction;
+import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.tuple.Tuple;
@@ -201,6 +203,31 @@ public class DataStream<OUT> {
 		TypeInformation<OUT> outTypeInfo = outTypeWrapper.getTypeInfo();
 		if (outTypeInfo.isTupleType()) {
 			type = ((TupleTypeInfo) outTypeInfo).getTypeAt(pos).getTypeClass();
+
+		} else if (outTypeInfo instanceof BasicArrayTypeInfo) {
+
+			type = ((BasicArrayTypeInfo) outTypeInfo).getComponentTypeClass();
+
+		} else if (outTypeInfo instanceof PrimitiveArrayTypeInfo) {
+			Class<?> clazz = outTypeInfo.getTypeClass();
+			if (clazz == boolean[].class) {
+				type = Boolean.class;
+			} else if (clazz == short[].class) {
+				type = Short.class;
+			} else if (clazz == int[].class) {
+				type = Integer.class;
+			} else if (clazz == long[].class) {
+				type = Long.class;
+			} else if (clazz == float[].class) {
+				type = Float.class;
+			} else if (clazz == double[].class) {
+				type = Double.class;
+			} else if (clazz == char[].class) {
+				type = Character.class;
+			} else {
+				throw new IndexOutOfBoundsException("Type could not be determined for array");
+			}
+
 		} else if (pos == 0) {
 			type = outTypeInfo.getTypeClass();
 		} else {
@@ -594,7 +621,7 @@ public class DataStream<OUT> {
 	public SingleOutputStreamOperator<OUT, ?> sum(int positionToSum) {
 		checkFieldRange(positionToSum);
 		return aggregate((AggregationFunction<OUT>) SumAggregationFunction.getSumFunction(
-				positionToSum, getClassAtPos(positionToSum)));
+				positionToSum, getClassAtPos(positionToSum), getOutputType()));
 	}
 
 	/**
@@ -616,7 +643,7 @@ public class DataStream<OUT> {
 	 */
 	public SingleOutputStreamOperator<OUT, ?> min(int positionToMin) {
 		checkFieldRange(positionToMin);
-		return aggregate(new MinAggregationFunction<OUT>(positionToMin));
+		return aggregate(new MinAggregationFunction<OUT>(positionToMin, getOutputType()));
 	}
 
 	/**
@@ -648,7 +675,7 @@ public class DataStream<OUT> {
 	 */
 	public SingleOutputStreamOperator<OUT, ?> minBy(int positionToMinBy, boolean first)
{
 		checkFieldRange(positionToMinBy);
-		return aggregate(new MinByAggregationFunction<OUT>(positionToMinBy, first));
+		return aggregate(new MinByAggregationFunction<OUT>(positionToMinBy, first, getOutputType()));
 	}
 
 	/**
@@ -670,7 +697,7 @@ public class DataStream<OUT> {
 	 */
 	public SingleOutputStreamOperator<OUT, ?> max(int positionToMax) {
 		checkFieldRange(positionToMax);
-		return aggregate(new MaxAggregationFunction<OUT>(positionToMax));
+		return aggregate(new MaxAggregationFunction<OUT>(positionToMax, getOutputType()));
 	}
 
 	/**
@@ -702,7 +729,7 @@ public class DataStream<OUT> {
 	 */
 	public SingleOutputStreamOperator<OUT, ?> maxBy(int positionToMaxBy, boolean first)
{
 		checkFieldRange(positionToMaxBy);
-		return aggregate(new MaxByAggregationFunction<OUT>(positionToMaxBy, first));
+		return aggregate(new MaxByAggregationFunction<OUT>(positionToMaxBy, first, getOutputType()));
 	}
 
 	/**

http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/7709a3a6/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java
b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java
index 512853a..825b4db 100644
--- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java
+++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java
@@ -18,17 +18,23 @@
 package org.apache.flink.streaming.api.function.aggregation;
 
 import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.typeinfo.BasicArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.PrimitiveArrayTypeInfo;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.tuple.Tuple;
 
 public abstract class AggregationFunction<T> implements ReduceFunction<T> {
 	private static final long serialVersionUID = 1L;
-	
+
 	public int position;
 	protected Tuple returnTuple;
+	protected boolean isTuple;
+	protected boolean isArray;
 
-	public AggregationFunction(int pos) {
+	public AggregationFunction(int pos, TypeInformation<?> type) {
 		this.position = pos;
+		this.isTuple = type.isTupleType();
+		this.isArray = type instanceof BasicArrayTypeInfo || type instanceof PrimitiveArrayTypeInfo;
 	}
 
-
 }

http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/7709a3a6/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregationFunction.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregationFunction.java
b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregationFunction.java
index 93444df..383c39c 100644
--- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregationFunction.java
+++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregationFunction.java
@@ -17,37 +17,56 @@
 
 package org.apache.flink.streaming.api.function.aggregation;
 
+import java.lang.reflect.Array;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.tuple.Tuple;
 
 public abstract class ComparableAggregationFunction<T> extends AggregationFunction<T>
{
 
 	private static final long serialVersionUID = 1L;
 
-	public ComparableAggregationFunction(int positionToAggregate) {
-		super(positionToAggregate);
+	public ComparableAggregationFunction(int positionToAggregate, TypeInformation<?> type)
{
+		super(positionToAggregate, type);
 	}
 
 	@SuppressWarnings("unchecked")
 	@Override
 	public T reduce(T value1, T value2) throws Exception {
-		if (value1 instanceof Tuple) {
+		if (isTuple) {
 			Tuple t1 = (Tuple) value1;
 			Tuple t2 = (Tuple) value2;
 
 			compare(t1, t2);
 
 			return (T) returnTuple;
+		} else if (isArray) {
+			return compareArray(value1, value2);
 		} else if (value1 instanceof Comparable) {
 			if (isExtremal((Comparable<Object>) value1, value2)) {
 				return value1;
-			}else{
+			} else {
 				return value2;
 			}
 		} else {
-			throw new RuntimeException("The values " + value1 +  " and "+ value2 + " cannot be compared.");
+			throw new RuntimeException("The values " + value1 + " and " + value2
+					+ " cannot be compared.");
 		}
 	}
 
+	@SuppressWarnings("unchecked")
+	public T compareArray(T array1, T array2) {
+		Object v1 = Array.get(array1, position);
+		Object v2 = Array.get(array2, position);
+		if (isExtremal((Comparable<Object>) v1, v2)) {
+			Array.set(array2, position, v1);
+		} else {
+			Array.set(array2, position, v2);
+		}
+
+		return array2;
+	}
+
 	public <R> void compare(Tuple tuple1, Tuple tuple2) throws InstantiationException,
 			IllegalAccessException {
 

http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/7709a3a6/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MaxAggregationFunction.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MaxAggregationFunction.java
b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MaxAggregationFunction.java
index dd63b2d..d013162 100644
--- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MaxAggregationFunction.java
+++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MaxAggregationFunction.java
@@ -17,12 +17,14 @@
 
 package org.apache.flink.streaming.api.function.aggregation;
 
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+
 public class MaxAggregationFunction<T> extends ComparableAggregationFunction<T>
{
 
 	private static final long serialVersionUID = 1L;
 
-	public MaxAggregationFunction(int pos) {
-		super(pos);
+	public MaxAggregationFunction(int pos, TypeInformation<?> type) {
+		super(pos, type);
 	}
 
 	@Override

http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/7709a3a6/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MaxByAggregationFunction.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MaxByAggregationFunction.java
b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MaxByAggregationFunction.java
index 274c8b6..4679028 100644
--- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MaxByAggregationFunction.java
+++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MaxByAggregationFunction.java
@@ -17,12 +17,14 @@
 
 package org.apache.flink.streaming.api.function.aggregation;
 
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+
 public class MaxByAggregationFunction<T> extends MinByAggregationFunction<T>
{
 
 	private static final long serialVersionUID = 1L;
 
-	public MaxByAggregationFunction(int pos, boolean first) {
-		super(pos, first);
+	public MaxByAggregationFunction(int pos, boolean first, TypeInformation<?> type) {
+		super(pos, first, type);
 	}
 
 	@Override

http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/7709a3a6/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MinAggregationFunction.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MinAggregationFunction.java
b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MinAggregationFunction.java
index ad903a8..83c20c7 100644
--- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MinAggregationFunction.java
+++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MinAggregationFunction.java
@@ -17,12 +17,14 @@
 
 package org.apache.flink.streaming.api.function.aggregation;
 
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+
 public class MinAggregationFunction<T> extends ComparableAggregationFunction<T>
{
 
 	private static final long serialVersionUID = 1L;
 
-	public MinAggregationFunction(int pos) {
-		super(pos);
+	public MinAggregationFunction(int pos, TypeInformation<?> type) {
+		super(pos, type);
 	}
 
 	@Override

http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/7709a3a6/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MinByAggregationFunction.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MinByAggregationFunction.java
b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MinByAggregationFunction.java
index a4a328c..31d6b37 100644
--- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MinByAggregationFunction.java
+++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/MinByAggregationFunction.java
@@ -17,6 +17,9 @@
 
 package org.apache.flink.streaming.api.function.aggregation;
 
+import java.lang.reflect.Array;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.tuple.Tuple;
 
 public class MinByAggregationFunction<T> extends ComparableAggregationFunction<T>
{
@@ -24,8 +27,8 @@ public class MinByAggregationFunction<T> extends ComparableAggregationFunction<T
 	private static final long serialVersionUID = 1L;
 	protected boolean first;
 
-	public MinByAggregationFunction(int pos, boolean first) {
-		super(pos);
+	public MinByAggregationFunction(int pos, boolean first, TypeInformation<?> type) {
+		super(pos, type);
 		this.first = first;
 	}
 
@@ -44,6 +47,18 @@ public class MinByAggregationFunction<T> extends ComparableAggregationFunction<T
 	}
 
 	@Override
+	@SuppressWarnings("unchecked")
+	public T compareArray(T array1, T array2) {
+		Object v1 = Array.get(array1, position);
+		Object v2 = Array.get(array2, position);
+		if (isExtremal((Comparable<Object>) v1, v2)) {
+			return array1;
+		} else {
+			return array2;
+		}
+	}
+
+	@Override
 	public <R> boolean isExtremal(Comparable<R> o1, R o2) {
 		if (first) {
 			return o1.compareTo(o2) <= 0;

http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/7709a3a6/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumAggregationFunction.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumAggregationFunction.java
b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumAggregationFunction.java
index 0429cdb..cd50072 100644
--- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumAggregationFunction.java
+++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumAggregationFunction.java
@@ -17,20 +17,23 @@
 
 package org.apache.flink.streaming.api.function.aggregation;
 
+import java.lang.reflect.Array;
+
+import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.tuple.Tuple;
 
 public abstract class SumAggregationFunction<T> extends AggregationFunction<T>
{
 
 	private static final long serialVersionUID = 1L;
 
-	public SumAggregationFunction(int pos) {
-		super(pos);
+	public SumAggregationFunction(int pos, TypeInformation<?> type) {
+		super(pos, type);
 	}
 
 	@SuppressWarnings("unchecked")
 	@Override
 	public T reduce(T value1, T value2) throws Exception {
-		if (value1 instanceof Tuple) {
+		if (isTuple) {
 			Tuple tuple1 = (Tuple) value1;
 			Tuple tuple2 = (Tuple) value2;
 
@@ -39,6 +42,11 @@ public abstract class SumAggregationFunction<T> extends AggregationFunction<T>
{
 					position);
 
 			return (T) returnTuple;
+		} else if (isArray) {
+			Object v1 = Array.get(value1, position);
+			Object v2 = Array.get(value2, position);
+			Array.set(value2, position, add(v1, v2));
+			return value2;
 		} else {
 			return (T) add(value1, value2);
 		}
@@ -47,23 +55,24 @@ public abstract class SumAggregationFunction<T> extends AggregationFunction<T>
{
 	protected abstract Object add(Object value1, Object value2);
 
 	@SuppressWarnings("rawtypes")
-	public static <T> SumAggregationFunction getSumFunction(int pos, Class<T> type)
{
-
-		if (type == Integer.class) {
-			return new IntSum<T>(pos);
-		} else if (type == Long.class) {
-			return new LongSum<T>(pos);
-		} else if (type == Short.class) {
-			return new ShortSum<T>(pos);
-		} else if (type == Double.class) {
-			return new DoubleSum<T>(pos);
-		} else if (type == Float.class) {
-			return new FloatSum<T>(pos);
-		} else if (type == Byte.class) {
-			return new ByteSum<T>(pos);
+	public static <T> SumAggregationFunction getSumFunction(int pos, Class<T> classAtPos,
+			TypeInformation<?> typeInfo) {
+
+		if (classAtPos == Integer.class) {
+			return new IntSum<T>(pos, typeInfo);
+		} else if (classAtPos == Long.class) {
+			return new LongSum<T>(pos, typeInfo);
+		} else if (classAtPos == Short.class) {
+			return new ShortSum<T>(pos, typeInfo);
+		} else if (classAtPos == Double.class) {
+			return new DoubleSum<T>(pos, typeInfo);
+		} else if (classAtPos == Float.class) {
+			return new FloatSum<T>(pos, typeInfo);
+		} else if (classAtPos == Byte.class) {
+			return new ByteSum<T>(pos, typeInfo);
 		} else {
 			throw new RuntimeException("DataStream cannot be summed because the class "
-					+ type.getSimpleName() + " does not support the + operator.");
+					+ classAtPos.getSimpleName() + " does not support the + operator.");
 		}
 
 	}
@@ -71,8 +80,8 @@ public abstract class SumAggregationFunction<T> extends AggregationFunction<T>
{
 	private static class IntSum<T> extends SumAggregationFunction<T> {
 		private static final long serialVersionUID = 1L;
 
-		public IntSum(int pos) {
-			super(pos);
+		public IntSum(int pos, TypeInformation<?> type) {
+			super(pos, type);
 		}
 
 		@Override
@@ -84,8 +93,8 @@ public abstract class SumAggregationFunction<T> extends AggregationFunction<T>
{
 	private static class LongSum<T> extends SumAggregationFunction<T> {
 		private static final long serialVersionUID = 1L;
 
-		public LongSum(int pos) {
-			super(pos);
+		public LongSum(int pos, TypeInformation<?> type) {
+			super(pos, type);
 		}
 
 		@Override
@@ -98,8 +107,8 @@ public abstract class SumAggregationFunction<T> extends AggregationFunction<T>
{
 
 		private static final long serialVersionUID = 1L;
 
-		public DoubleSum(int pos) {
-			super(pos);
+		public DoubleSum(int pos, TypeInformation<?> type) {
+			super(pos, type);
 		}
 
 		@Override
@@ -111,8 +120,8 @@ public abstract class SumAggregationFunction<T> extends AggregationFunction<T>
{
 	private static class ShortSum<T> extends SumAggregationFunction<T> {
 		private static final long serialVersionUID = 1L;
 
-		public ShortSum(int pos) {
-			super(pos);
+		public ShortSum(int pos, TypeInformation<?> type) {
+			super(pos, type);
 		}
 
 		@Override
@@ -124,8 +133,8 @@ public abstract class SumAggregationFunction<T> extends AggregationFunction<T>
{
 	private static class FloatSum<T> extends SumAggregationFunction<T> {
 		private static final long serialVersionUID = 1L;
 
-		public FloatSum(int pos) {
-			super(pos);
+		public FloatSum(int pos, TypeInformation<?> type) {
+			super(pos, type);
 		}
 
 		@Override
@@ -137,8 +146,8 @@ public abstract class SumAggregationFunction<T> extends AggregationFunction<T>
{
 	private static class ByteSum<T> extends SumAggregationFunction<T> {
 		private static final long serialVersionUID = 1L;
 
-		public ByteSum(int pos) {
-			super(pos);
+		public ByteSum(int pos, TypeInformation<?> type) {
+			super(pos, type);
 		}
 
 		@Override

http://git-wip-us.apache.org/repos/asf/incubator-flink/blob/7709a3a6/flink-addons/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/AggregationFunctionTest.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/AggregationFunctionTest.java
b/flink-addons/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/AggregationFunctionTest.java
index 07f1185..70e6118 100644
--- a/flink-addons/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/AggregationFunctionTest.java
+++ b/flink-addons/flink-streaming/flink-streaming-core/src/test/java/org/apache/flink/streaming/api/AggregationFunctionTest.java
@@ -23,7 +23,9 @@ import static org.junit.Assert.fail;
 import java.util.ArrayList;
 import java.util.List;
 
+import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.api.java.typeutils.TypeExtractor;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.function.aggregation.MaxAggregationFunction;
 import org.apache.flink.streaming.api.function.aggregation.MaxByAggregationFunction;
@@ -85,18 +87,21 @@ public class AggregationFunctionTest {
 			expectedGroupMaxList.add(new Tuple2<Integer, Integer>(i % 3, i));
 		}
 
+		TypeInformation<?> type1 = TypeExtractor.getForObject(new Tuple2<Integer, Integer>(0,
0));
+		TypeInformation<?> type2 = TypeExtractor.getForObject(2);
+
 		@SuppressWarnings("unchecked")
 		SumAggregationFunction<Tuple2<Integer, Integer>> sumFunction = SumAggregationFunction
-				.getSumFunction(1, Integer.class);
+				.getSumFunction(1, Integer.class, type1);
 		@SuppressWarnings("unchecked")
 		SumAggregationFunction<Integer> sumFunction0 = SumAggregationFunction.getSumFunction(0,
-				Integer.class);
+				Integer.class, type2);
 		MinAggregationFunction<Tuple2<Integer, Integer>> minFunction = new MinAggregationFunction<Tuple2<Integer,
Integer>>(
-				1);
-		MinAggregationFunction<Integer> minFunction0 = new MinAggregationFunction<Integer>(0);
+				1, type1);
+		MinAggregationFunction<Integer> minFunction0 = new MinAggregationFunction<Integer>(0,
type2);
 		MaxAggregationFunction<Tuple2<Integer, Integer>> maxFunction = new MaxAggregationFunction<Tuple2<Integer,
Integer>>(
-				1);
-		MaxAggregationFunction<Integer> maxFunction0 = new MaxAggregationFunction<Integer>(0);
+				1, type1);
+		MaxAggregationFunction<Integer> maxFunction0 = new MaxAggregationFunction<Integer>(0,
type2);
 
 		List<Tuple2<Integer, Integer>> sumList = MockInvokable.createAndExecute(
 				new StreamReduceInvokable<Tuple2<Integer, Integer>>(sumFunction), getInputList());
@@ -156,14 +161,14 @@ public class AggregationFunctionTest {
 		}
 
 		MaxByAggregationFunction<Tuple2<Integer, Integer>> maxByFunctionFirst = new
MaxByAggregationFunction<Tuple2<Integer, Integer>>(
-				0, true);
+				0, true, type1);
 		MaxByAggregationFunction<Tuple2<Integer, Integer>> maxByFunctionLast = new
MaxByAggregationFunction<Tuple2<Integer, Integer>>(
-				0, false);
+				0, false, type1);
 
 		MinByAggregationFunction<Tuple2<Integer, Integer>> minByFunctionFirst = new
MinByAggregationFunction<Tuple2<Integer, Integer>>(
-				0, true);
+				0, true, type1);
 		MinByAggregationFunction<Tuple2<Integer, Integer>> minByFunctionLast = new
MinByAggregationFunction<Tuple2<Integer, Integer>>(
-				0, false);
+				0, false, type1);
 
 		List<Tuple2<Integer, Integer>> maxByFirstExpected = new ArrayList<Tuple2<Integer,
Integer>>();
 		maxByFirstExpected.add(new Tuple2<Integer, Integer>(0, 0));
@@ -226,16 +231,17 @@ public class AggregationFunctionTest {
 
 	@Test
 	public void minMaxByTest() {
+		TypeInformation<?> type1 = TypeExtractor.getForObject(new Tuple2<Integer, Integer>(0,
0));
 
 		MaxByAggregationFunction<Tuple2<Integer, Integer>> maxByFunctionFirst = new
MaxByAggregationFunction<Tuple2<Integer, Integer>>(
-				0, true);
+				0, true, type1);
 		MaxByAggregationFunction<Tuple2<Integer, Integer>> maxByFunctionLast = new
MaxByAggregationFunction<Tuple2<Integer, Integer>>(
-				0, false);
+				0, false, type1);
 
 		MinByAggregationFunction<Tuple2<Integer, Integer>> minByFunctionFirst = new
MinByAggregationFunction<Tuple2<Integer, Integer>>(
-				0, true);
+				0, true, type1);
 		MinByAggregationFunction<Tuple2<Integer, Integer>> minByFunctionLast = new
MinByAggregationFunction<Tuple2<Integer, Integer>>(
-				0, false);
+				0, false, type1);
 
 		List<Tuple2<Integer, Integer>> maxByFirstExpected = new ArrayList<Tuple2<Integer,
Integer>>();
 		maxByFirstExpected.add(new Tuple2<Integer, Integer>(0, 0));


Mime
View raw message