flink-commits mailing list archives

From mbala...@apache.org
Subject [11/36] flink git commit: [scala] [streaming] Modified aggregations to work on scala tuples
Date Wed, 07 Jan 2015 14:12:50 GMT
[scala] [streaming] Modified aggregations to work on scala tuples


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/f7291ea1
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/f7291ea1
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/f7291ea1

Branch: refs/heads/release-0.8
Commit: f7291ea1c9fa4a0484ab6bc13e4a594ff6d7c2d2
Parents: 40a3b6b
Author: Gyula Fora <gyfora@apache.org>
Authored: Sat Dec 20 23:46:35 2014 +0100
Committer: mbalassi <mbalassi@apache.org>
Committed: Mon Jan 5 17:57:44 2015 +0100

----------------------------------------------------------------------
 .../aggregation/AggregationFunction.java        |   2 +-
 .../aggregation/ComparableAggregator.java       |   8 +-
 .../api/function/aggregation/SumFunction.java   |  12 +-
 .../operator/StreamReduceInvokable.java         |   1 +
 .../streaming/ScalaStreamingAggregator.java     | 111 +++++++++++++++++++
 .../flink/api/scala/streaming/DataStream.scala  |  52 +++++----
 .../scala/streaming/WindowedDataStream.scala    |  64 ++++-------
 7 files changed, 171 insertions(+), 79 deletions(-)
----------------------------------------------------------------------
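For orientation only (not part of the commit): a minimal sketch of how the reworked positional aggregations might be called on a stream of Scala tuples. The method names and positions come from the DataStream.scala changes below; the stream itself and the value names are illustrative assumptions.

    import org.apache.flink.api.scala.streaming.DataStream

    // Hypothetical usage; 'stream' is assumed to come from a StreamExecutionEnvironment elsewhere.
    def tupleAggregations(stream: DataStream[(String, Int, Double)]): Unit = {
      val sums     = stream.sum(1)                  // rolling sum of the Int field at position 1
      val maxima   = stream.max(2)                  // rolling maximum of the Double field at position 2
      val firstMin = stream.minBy(1)                // element holding the minimal field 1; ties keep the first
      val lastMax  = stream.maxBy(1, first = false) // ties keep the last element instead
    }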


http://git-wip-us.apache.org/repos/asf/flink/blob/f7291ea1/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java
index d95c37e..1c273d3 100644
--- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java
+++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/AggregationFunction.java
@@ -22,7 +22,7 @@ import org.apache.flink.api.common.functions.ReduceFunction;
 public abstract class AggregationFunction<T> implements ReduceFunction<T> {
 	private static final long serialVersionUID = 1L;
 
-	int position;
+	public int position;
 
 	public AggregationFunction(int pos) {
 		this.position = pos;

http://git-wip-us.apache.org/repos/asf/flink/blob/f7291ea1/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregator.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregator.java b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregator.java
index 6e2a400..5fb8f62 100644
--- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregator.java
+++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/ComparableAggregator.java
@@ -35,11 +35,11 @@ public abstract class ComparableAggregator<T> extends AggregationFunction<T> {
 
 	private static final long serialVersionUID = 1L;
 
-	Comparator comparator;
-	boolean byAggregate;
-	boolean first;
+	public Comparator comparator;
+	public boolean byAggregate;
+	public boolean first;
 
-	private ComparableAggregator(int pos, AggregationType aggregationType, boolean first) {
+	public ComparableAggregator(int pos, AggregationType aggregationType, boolean first) {
 		super(pos);
 		this.comparator = Comparator.getForAggregation(aggregationType);
 		this.byAggregate = (aggregationType == AggregationType.MAXBY)

http://git-wip-us.apache.org/repos/asf/flink/blob/f7291ea1/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumFunction.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumFunction.java b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumFunction.java
index 1ac236d..2aef19c 100644
--- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumFunction.java
+++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/function/aggregation/SumFunction.java
@@ -45,7 +45,7 @@ public abstract class SumFunction implements Serializable{
 		}
 	}
 
-	private static class IntSum extends SumFunction {
+	public static class IntSum extends SumFunction {
 		private static final long serialVersionUID = 1L;
 
 		@Override
@@ -54,7 +54,7 @@ public abstract class SumFunction implements Serializable{
 		}
 	}
 
-	private static class LongSum extends SumFunction {
+	public static class LongSum extends SumFunction {
 		private static final long serialVersionUID = 1L;
 
 		@Override
@@ -63,7 +63,7 @@ public abstract class SumFunction implements Serializable{
 		}
 	}
 
-	private static class DoubleSum extends SumFunction {
+	public static class DoubleSum extends SumFunction {
 
 		private static final long serialVersionUID = 1L;
 
@@ -73,7 +73,7 @@ public abstract class SumFunction implements Serializable{
 		}
 	}
 
-	private static class ShortSum extends SumFunction {
+	public static class ShortSum extends SumFunction {
 		private static final long serialVersionUID = 1L;
 
 		@Override
@@ -82,7 +82,7 @@ public abstract class SumFunction implements Serializable{
 		}
 	}
 
-	private static class FloatSum extends SumFunction {
+	public static class FloatSum extends SumFunction {
 		private static final long serialVersionUID = 1L;
 
 		@Override
@@ -91,7 +91,7 @@ public abstract class SumFunction implements Serializable{
 		}
 	}
 
-	private static class ByteSum extends SumFunction {
+	public static class ByteSum extends SumFunction {
 		private static final long serialVersionUID = 1L;
 
 		@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/f7291ea1/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/invokable/operator/StreamReduceInvokable.java
----------------------------------------------------------------------
diff --git a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/invokable/operator/StreamReduceInvokable.java b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/invokable/operator/StreamReduceInvokable.java
index 4bb78b8..5f5cb12 100644
--- a/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/invokable/operator/StreamReduceInvokable.java
+++ b/flink-addons/flink-streaming/flink-streaming-core/src/main/java/org/apache/flink/streaming/api/invokable/operator/StreamReduceInvokable.java
@@ -52,6 +52,7 @@ public class StreamReduceInvokable<IN> extends StreamInvokable<IN, IN> {
 			currentValue = reducer.reduce(currentValue, nextValue);
 		} else {
 			currentValue = nextValue;
+
 		}
 		collector.collect(currentValue);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/f7291ea1/flink-scala/src/main/java/org/apache/flink/api/scala/streaming/ScalaStreamingAggregator.java
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/java/org/apache/flink/api/scala/streaming/ScalaStreamingAggregator.java b/flink-scala/src/main/java/org/apache/flink/api/scala/streaming/ScalaStreamingAggregator.java
new file mode 100644
index 0000000..2f587d7
--- /dev/null
+++ b/flink-scala/src/main/java/org/apache/flink/api/scala/streaming/ScalaStreamingAggregator.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.api.scala.streaming;
+
+import java.io.Serializable;
+
+import org.apache.flink.api.common.typeutils.TypeSerializer;
+import org.apache.flink.api.java.typeutils.runtime.TupleSerializerBase;
+import org.apache.flink.streaming.api.function.aggregation.AggregationFunction;
+import org.apache.flink.streaming.api.function.aggregation.ComparableAggregator;
+import org.apache.flink.streaming.api.function.aggregation.SumFunction;
+
+import scala.Product;
+
+public class ScalaStreamingAggregator<IN extends Product> implements Serializable {
+
+	private static final long serialVersionUID = 1L;
+
+	TupleSerializerBase<IN> serializer;
+	Object[] fields;
+	int length;
+	int position;
+
+	public ScalaStreamingAggregator(TypeSerializer<IN> serializer, int pos) {
+		this.serializer = (TupleSerializerBase<IN>) serializer;
+		this.length = this.serializer.getArity();
+		this.fields = new Object[this.length];
+		this.position = pos;
+	}
+
+	public class Sum extends AggregationFunction<IN> {
+		private static final long serialVersionUID = 1L;
+		SumFunction sumFunction;
+
+		public Sum(SumFunction func) {
+			super(ScalaStreamingAggregator.this.position);
+			this.sumFunction = func;
+		}
+
+		@Override
+		public IN reduce(IN value1, IN value2) throws Exception {
+			for (int i = 0; i < length; i++) {
+				fields[i] = value2.productElement(i);
+			}
+
+			fields[position] = sumFunction.add(fields[position], value1.productElement(position));
+
+			return serializer.createInstance(fields);
+		}
+	}
+
+	public class ProductComparableAggregator extends ComparableAggregator<IN> {
+
+		private static final long serialVersionUID = 1L;
+
+		public ProductComparableAggregator(AggregationFunction.AggregationType aggregationType,
+				boolean first) {
+			super(ScalaStreamingAggregator.this.position, aggregationType, first);
+		}
+
+		@SuppressWarnings("unchecked")
+		@Override
+		public IN reduce(IN value1, IN value2) throws Exception {
+			Object v1 = value1.productElement(position);
+			Object v2 = value2.productElement(position);
+
+			int c = comparator.isExtremal((Comparable<Object>) v1, v2);
+
+			if (byAggregate) {
+				if (c == 1) {
+					return value1;
+				}
+				if (first) {
+					if (c == 0) {
+						return value1;
+					}
+				}
+
+				return value2;
+			} else {
+				for (int i = 0; i < length; i++) {
+					fields[i] = value2.productElement(i);
+				}
+
+				if (c == 1) {
+					fields[position] = v1;
+				}
+
+				return serializer.createInstance(fields);
+			}
+		}
+
+	}
+
+}
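A stand-alone sketch (editorial, not part of the diff) of the reduce semantics implemented above, assuming Int values at the aggregated position: every field is copied from the newest element and only the aggregated position accumulates.

    // Illustrative re-creation of Sum.reduce without Flink types; the real code builds the
    // result through the tuple serializer instead of returning a raw array.
    def sumReduce(value1: Product, value2: Product, position: Int): Array[Any] = {
      // copy every field of the second (newest) element into a fresh buffer ...
      val fields = Array.tabulate[Any](value2.productArity)(value2.productElement)
      // ... then fold the first element's field at 'position' into the running sum
      fields(position) = fields(position).asInstanceOf[Int] + value1.productElement(position).asInstanceOf[Int]
      fields
    }

    // sumReduce(("a", 3), ("b", 4), 1) yields Array("b", 7)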

http://git-wip-us.apache.org/repos/asf/flink/blob/f7291ea1/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/DataStream.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/DataStream.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/DataStream.scala
index 42ec709..ecf5615 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/DataStream.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/DataStream.scala
@@ -44,6 +44,11 @@ import org.apache.flink.streaming.api.windowing.policy.TriggerPolicy
 import org.apache.flink.streaming.api.collector.OutputSelector
 import scala.collection.JavaConversions._
 import java.util.HashMap
+import org.apache.flink.streaming.api.function.aggregation.SumFunction
+import org.apache.flink.api.java.typeutils.TupleTypeInfoBase
+import org.apache.flink.streaming.api.function.aggregation.AggregationFunction
+import org.apache.flink.streaming.api.function.aggregation.AggregationFunction.AggregationType
+import com.amazonaws.services.cloudfront_2012_03_15.model.InvalidArgumentException
 
 class DataStream[T](javaStream: JavaStream[T]) {
 
@@ -230,53 +235,52 @@ class DataStream[T](javaStream: JavaStream[T]) {
    * the given position.
    *
    */
-  def max(field: Any): DataStream[T] = field match {
-    case field: Int => return new DataStream[T](javaStream.max(field))
-    case field: String => return new DataStream[T](javaStream.max(field))
-    case _ => throw new IllegalArgumentException("Aggregations are only supported by field position (Int) or field expression (String)")
-  }
+  def max(position: Int): DataStream[T] = aggregate(AggregationType.MAX, position)
 
   /**
    * Applies an aggregation that that gives the current minimum of the data stream at
    * the given position.
    *
    */
-  def min(field: Any): DataStream[T] = field match {
-    case field: Int => return new DataStream[T](javaStream.min(field))
-    case field: String => return new DataStream[T](javaStream.min(field))
-    case _ => throw new IllegalArgumentException("Aggregations are only supported by field position (Int) or field expression (String)")
-  }
+  def min(position: Int): DataStream[T] = aggregate(AggregationType.MIN, position)
 
   /**
    * Applies an aggregation that sums the data stream at the given position.
    *
    */
-  def sum(field: Any): DataStream[T] = field match {
-    case field: Int => return new DataStream[T](javaStream.sum(field))
-    case field: String => return new DataStream[T](javaStream.sum(field))
-    case _ => throw new IllegalArgumentException("Aggregations are only supported by field position (Int) or field expression (String)")
-  }
+  def sum(position: Int): DataStream[T] = aggregate(AggregationType.SUM, position)
 
   /**
    * Applies an aggregation that that gives the current minimum element of the data stream by
    * the given position. When equality, the user can set to get the first or last element with the minimal value.
    *
    */
-  def minBy(field: Any, first: Boolean = true): DataStream[T] = field match {
-    case field: Int => return new DataStream[T](javaStream.minBy(field, first))
-    case field: String => return new DataStream[T](javaStream.minBy(field, first))
-    case _ => throw new IllegalArgumentException("Aggregations are only supported by field position (Int) or field expression (String)")
-  }
+  def minBy(position: Int, first: Boolean = true): DataStream[T] = aggregate(AggregationType.MINBY, position, first)
 
   /**
    * Applies an aggregation that that gives the current maximum element of the data stream by
    * the given position. When equality, the user can set to get the first or last element with the maximal value.
    *
    */
-  def maxBy(field: Any, first: Boolean = true): DataStream[T] = field match {
-    case field: Int => return new DataStream[T](javaStream.maxBy(field, first))
-    case field: String => return new DataStream[T](javaStream.maxBy(field, first))
-    case _ => throw new IllegalArgumentException("Aggregations are only supported by field position (Int) or field expression (String)")
+  def maxBy(position: Int, first: Boolean = true): DataStream[T] = aggregate(AggregationType.MAXBY, position, first)
+
+  private def aggregate(aggregationType: AggregationType, position: Int, first: Boolean = true): DataStream[T] = {
+
+    val jStream = javaStream.asInstanceOf[JavaStream[Product]]
+    val outType = jStream.getType().asInstanceOf[TupleTypeInfoBase[_]]
+
+    val agg = new ScalaStreamingAggregator[Product](jStream.getType().createSerializer(), position)
+
+    val reducer = aggregationType match {
+      case AggregationType.SUM => new agg.Sum(SumFunction.getForClass(outType.getTypeAt(position).getTypeClass()));
+      case _ => new agg.ProductComparableAggregator(aggregationType, first)
+    }
+
+    val invokable = jStream match {
+      case groupedStream: GroupedDataStream[_] => new GroupedReduceInvokable(reducer, groupedStream.getKeySelector())
+      case _ => new StreamReduceInvokable(reducer)
+    }
+    new DataStream[Product](jStream.transform("aggregation", jStream.getType(), invokable)).asInstanceOf[DataStream[T]]
   }
 
   /**

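Because the reducer is executed by StreamReduceInvokable (or GroupedReduceInvokable for grouped streams), which collects the running value after every input, the aggregations above are rolling ones. A small stand-alone simulation (editorial, with illustrative data) of what sum(1) emits:

    // Simulation of the rolling sum at position 1; one result per consumed element.
    val input = List(("a", 1), ("b", 2), ("c", 4))
    val rolling = input.tail.scanLeft(input.head) { (current, next) =>
      // mirror of Sum.reduce: non-aggregated fields from the newest element, sum accumulated
      (next._1, current._2 + next._2)
    }
    // rolling == List(("a", 1), ("b", 3), ("c", 7))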
http://git-wip-us.apache.org/repos/asf/flink/blob/f7291ea1/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/WindowedDataStream.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/WindowedDataStream.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/WindowedDataStream.scala
index c686497..c037305 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/WindowedDataStream.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/streaming/WindowedDataStream.scala
@@ -36,6 +36,9 @@ import org.apache.flink.streaming.api.windowing.helper.WindowingHelper
 import org.apache.flink.api.common.functions.GroupReduceFunction
 import org.apache.flink.streaming.api.invokable.StreamInvokable
 import scala.collection.JavaConversions._
+import org.apache.flink.streaming.api.function.aggregation.AggregationFunction.AggregationType
+import org.apache.flink.api.java.typeutils.TupleTypeInfoBase
+import org.apache.flink.streaming.api.function.aggregation.SumFunction
 
 class WindowedDataStream[T](javaStream: JavaWStream[T]) {
 
@@ -158,75 +161,48 @@ class WindowedDataStream[T](javaStream: JavaWStream[T]) {
    * the given position.
    *
    */
-  def max(field: Any): DataStream[T] = field match {
-    case field: Int => return new DataStream[T](javaStream.max(field))
-    case field: String => return new DataStream[T](javaStream.max(field))
-    case _ => throw new IllegalArgumentException("Aggregations are only supported by field position (Int) or field expression (String)")
-  }
+  def max(position: Int): DataStream[T] = aggregate(AggregationType.MAX, position)
 
   /**
    * Applies an aggregation that that gives the minimum of the elements in the window at
    * the given position.
    *
    */
-  def min(field: Any): DataStream[T] = field match {
-    case field: Int => return new DataStream[T](javaStream.min(field))
-    case field: String => return new DataStream[T](javaStream.min(field))
-    case _ => throw new IllegalArgumentException("Aggregations are only supported by field position (Int) or field expression (String)")
-  }
+  def min(position: Int): DataStream[T] = aggregate(AggregationType.MIN, position)
 
   /**
    * Applies an aggregation that sums the elements in the window at the given position.
    *
    */
-  def sum(field: Any): DataStream[T] = field match {
-    case field: Int => return new DataStream[T](javaStream.sum(field))
-    case field: String => return new DataStream[T](javaStream.sum(field))
-    case _ => throw new IllegalArgumentException("Aggregations are only supported by field position (Int) or field expression (String)")
-  }
+  def sum(position: Int): DataStream[T] = aggregate(AggregationType.SUM, position)
 
   /**
    * Applies an aggregation that that gives the maximum element of the window by
    * the given position. When equality, returns the first.
    *
    */
-  def maxBy(field: Any): DataStream[T] = field match {
-    case field: Int => return new DataStream[T](javaStream.maxBy(field))
-    case field: String => return new DataStream[T](javaStream.maxBy(field))
-    case _ => throw new IllegalArgumentException("Aggregations are only supported by field position (Int) or field expression (String)")
-  }
+  def maxBy(position: Int, first: Boolean = true): DataStream[T] = aggregate(AggregationType.MAXBY, position, first)
 
   /**
    * Applies an aggregation that that gives the minimum element of the window by
    * the given position. When equality, returns the first.
    *
    */
-  def minBy(field: Any): DataStream[T] = field match {
-    case field: Int => return new DataStream[T](javaStream.minBy(field))
-    case field: String => return new DataStream[T](javaStream.minBy(field))
-    case _ => throw new IllegalArgumentException("Aggregations are only supported by field position (Int) or field expression (String)")
-  }
+  def minBy(position: Int, first: Boolean = true): DataStream[T] = aggregate(AggregationType.MINBY, position, first)
 
-  /**
-   * Applies an aggregation that that gives the minimum element of the window by
-   * the given position. When equality, the user can set to get the first or last element with the minimal value.
-   *
-   */
-  def minBy(field: Any, first: Boolean): DataStream[T] = field match {
-    case field: Int => return new DataStream[T](javaStream.minBy(field, first))
-    case field: String => return new DataStream[T](javaStream.minBy(field, first))
-    case _ => throw new IllegalArgumentException("Aggregations are only supported by field position (Int) or field expression (String)")
-  }
+  def aggregate(aggregationType: AggregationType, position: Int, first: Boolean = true): DataStream[T] = {
 
-  /**
-   * Applies an aggregation that that gives the maximum element of the window by
-   * the given position. When equality, the user can set to get the first or last element with the maximal value.
-   *
-   */
-  def maxBy(field: Any, first: Boolean): DataStream[T] = field match {
-    case field: Int => return new DataStream[T](javaStream.maxBy(field, first))
-    case field: String => return new DataStream[T](javaStream.maxBy(field, first))
-    case _ => throw new IllegalArgumentException("Aggregations are only supported by field position (Int) or field expression (String)")
+    val jStream = javaStream.asInstanceOf[JavaWStream[Product]]
+    val outType = jStream.getType().asInstanceOf[TupleTypeInfoBase[_]]
+
+    val agg = new ScalaStreamingAggregator[Product](jStream.getType().createSerializer(), position)
+
+    val reducer = aggregationType match {
+      case AggregationType.SUM => new agg.Sum(SumFunction.getForClass(outType.getTypeAt(position).getTypeClass()));
+      case _ => new agg.ProductComparableAggregator(aggregationType, first)
+    }
+
+    new DataStream[Product](jStream.reduce(reducer)).asInstanceOf[DataStream[T]]
   }
 
 }
\ No newline at end of file
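For completeness (not part of the commit): a hedged sketch of the corresponding windowed calls, assuming a WindowedDataStream[(String, Int)] obtained elsewhere (e.g. via a windowing helper); the value names are illustrative.

    import org.apache.flink.api.scala.streaming.WindowedDataStream

    // Hypothetical usage; how 'windowed' is created is out of scope here.
    def windowedAggregations(windowed: WindowedDataStream[(String, Int)]): Unit = {
      val perWindowSum = windowed.sum(1)                  // sum of field 1 within each window
      val perWindowMax = windowed.max(1)                  // maximum of field 1 within each window
      val lastMinElem  = windowed.minBy(1, first = false) // element with the minimal field 1; ties keep the last
    }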

