flink-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From fhue...@apache.org
Subject [1/4] flink git commit: [FLINK-3234] [dataSet] Add KeySelector support to sortPartition operation.
Date Wed, 10 Feb 2016 10:52:07 GMT
Repository: flink
Updated Branches:
  refs/heads/master 59b237b5d -> 0a63797a6


[FLINK-3234] [dataSet] Add KeySelector support to sortPartition operation.

This closes #1585


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/0a63797a
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/0a63797a
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/0a63797a

Branch: refs/heads/master
Commit: 0a63797a6a5418b2363bca25bd77c33c217ff257
Parents: 572855d
Author: Chiwan Park <chiwanpark@apache.org>
Authored: Thu Feb 4 20:46:10 2016 +0900
Committer: Fabian Hueske <fhueske@apache.org>
Committed: Wed Feb 10 11:51:26 2016 +0100

----------------------------------------------------------------------
 .../java/org/apache/flink/api/java/DataSet.java |  18 ++
 .../java/operators/SortPartitionOperator.java   | 174 +++++++++++++------
 .../api/java/operator/SortPartitionTest.java    |  82 +++++++++
 .../org/apache/flink/api/scala/DataSet.scala    |  25 +++
 .../api/scala/PartitionSortedDataSet.scala      |  22 ++-
 .../javaApiOperators/SortPartitionITCase.java   |  61 +++++++
 .../scala/operators/SortPartitionITCase.scala   |  59 +++++++
 7 files changed, 385 insertions(+), 56 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java b/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
index bfb97f4..c315920 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
@@ -1381,6 +1381,24 @@ public abstract class DataSet<T> {
 		return new SortPartitionOperator<>(this, field, order, Utils.getCallLocationName());
 	}
 
+	/**
+	 * Locally sorts the partitions of the DataSet on the extracted key in the specified order.
+	 * The DataSet can be sorted on multiple values by returning a tuple from the KeySelector.
+	 *
+	 * Note that no additional sort keys can be appended to KeySelector sort keys. To sort
+	 * the partitions by multiple values using a KeySelector, the KeySelector must return a tuple
+	 * consisting of the values.
+	 *
+	 * @param keyExtractor The KeySelector function which extracts, from each element, the key
+	 *                     values on which the DataSet is sorted.
+	 * @param order The order in which the DataSet is sorted.
+	 * @return The DataSet with sorted local partitions.
+	 */
+	public <K> SortPartitionOperator<T> sortPartition(KeySelector<T, K> keyExtractor,
Order order) {
+		final TypeInformation<K> keyType = TypeExtractor.getKeySelectorTypes(keyExtractor,
getType());
+		return new SortPartitionOperator<>(this, new Keys.SelectorFunctionKeys<>(clean(keyExtractor),
getType(), keyType), order, Utils.getCallLocationName());
+	}
+
 	// --------------------------------------------------------------------------------------------
 	//  Top-K
 	// --------------------------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-java/src/main/java/org/apache/flink/api/java/operators/SortPartitionOperator.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/SortPartitionOperator.java
b/flink-java/src/main/java/org/apache/flink/api/java/operators/SortPartitionOperator.java
index 354a0cd..7f30a30 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/SortPartitionOperator.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/SortPartitionOperator.java
@@ -26,9 +26,13 @@ import org.apache.flink.api.common.operators.Order;
 import org.apache.flink.api.common.operators.Ordering;
 import org.apache.flink.api.common.operators.UnaryOperatorInformation;
 import org.apache.flink.api.common.operators.base.SortPartitionOperatorBase;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
 
-import java.util.Arrays;
+import java.util.ArrayList;
+import java.util.List;
 
 /**
  * This operator represents a DataSet with locally sorted partitions.
@@ -38,27 +42,58 @@ import java.util.Arrays;
 @Public
 public class SortPartitionOperator<T> extends SingleInputOperator<T, T, SortPartitionOperator<T>>
{
 
-	private int[] sortKeyPositions;
+	private List<Keys<T>> keys;
 
-	private Order[] sortOrders;
+	private List<Order> orders;
 
 	private final String sortLocationName;
 
+	private boolean useKeySelector;
 
-	public SortPartitionOperator(DataSet<T> dataSet, int sortField, Order sortOrder, String
sortLocationName) {
+	private SortPartitionOperator(DataSet<T> dataSet, String sortLocationName) {
 		super(dataSet, dataSet.getType());
+
+		keys = new ArrayList<>();
+		orders = new ArrayList<>();
 		this.sortLocationName = sortLocationName;
+	}
+
+
+	public SortPartitionOperator(DataSet<T> dataSet, int sortField, Order sortOrder, String
sortLocationName) {
+		this(dataSet, sortLocationName);
+		this.useKeySelector = false;
+
+		ensureSortableKey(sortField);
 
-		int[] flatOrderKeys = getFlatFields(sortField);
-		this.appendSorting(flatOrderKeys, sortOrder);
+		keys.add(new Keys.ExpressionKeys<>(sortField, getType()));
+		orders.add(sortOrder);
 	}
 
 	public SortPartitionOperator(DataSet<T> dataSet, String sortField, Order sortOrder,
String sortLocationName) {
-		super(dataSet, dataSet.getType());
-		this.sortLocationName = sortLocationName;
+		this(dataSet, sortLocationName);
+		this.useKeySelector = false;
+
+		ensureSortableKey(sortField);
+
+		keys.add(new Keys.ExpressionKeys<>(sortField, getType()));
+		orders.add(sortOrder);
+	}
+
+	public <K> SortPartitionOperator(DataSet<T> dataSet, Keys.SelectorFunctionKeys<T,
K> sortKey, Order sortOrder, String sortLocationName) {
+		this(dataSet, sortLocationName);
+		this.useKeySelector = true;
+
+		ensureSortableKey(sortKey);
 
-		int[] flatOrderKeys = getFlatFields(sortField);
-		this.appendSorting(flatOrderKeys, sortOrder);
+		keys.add(sortKey);
+		orders.add(sortOrder);
+	}
+
+	/**
+	 * Returns whether the sort key of this operator is defined by a KeySelector function.
+	 */
+	public boolean useKeySelector() {
+		return useKeySelector;
 	}
 
 	/**
@@ -70,9 +105,14 @@ public class SortPartitionOperator<T> extends SingleInputOperator<T,
T, SortPart
 	 * @return The DataSet with sorted local partitions.
 	 */
 	public SortPartitionOperator<T> sortPartition(int field, Order order) {
+		if (useKeySelector) {
+			throw new InvalidProgramException("Expression keys cannot be appended after a KeySelector");
+		}
+
+		ensureSortableKey(field);
+		keys.add(new Keys.ExpressionKeys<>(field, getType()));
+		orders.add(order);
 
-		int[] flatOrderKeys = getFlatFields(field);
-		this.appendSorting(flatOrderKeys, order);
 		return this;
 	}
 
@@ -81,58 +121,41 @@ public class SortPartitionOperator<T> extends SingleInputOperator<T,
T, SortPart
 	 * local partition sorting of the DataSet.
 	 *
 	 * @param field The field expression referring to the field of the additional sort order
of
-	 *                 the local partition sorting.
-	 * @param order The order  of the additional sort order of the local partition sorting.
+	 *              the local partition sorting.
+	 * @param order The order of the additional sort order of the local partition sorting.
 	 * @return The DataSet with sorted local partitions.
 	 */
 	public SortPartitionOperator<T> sortPartition(String field, Order order) {
-		int[] flatOrderKeys = getFlatFields(field);
-		this.appendSorting(flatOrderKeys, order);
+		if (useKeySelector) {
+			throw new InvalidProgramException("Expression keys cannot be appended after a KeySelector");
+		}
+
+		ensureSortableKey(field);
+		keys.add(new Keys.ExpressionKeys<>(field, getType()));
+		orders.add(order);
+
 		return this;
 	}
 
-	// --------------------------------------------------------------------------------------------
-	//  Key Extraction
-	// --------------------------------------------------------------------------------------------
-
-	private int[] getFlatFields(int field) {
+	public <K> SortPartitionOperator<T> sortPartition(KeySelector<T, K> keyExtractor,
Order order) {
+		throw new InvalidProgramException("KeySelector cannot be chained.");
+	}
 
-		if (!Keys.ExpressionKeys.isSortKey(field, super.getType())) {
+	private void ensureSortableKey(int field) throws InvalidProgramException {
+		if (!Keys.ExpressionKeys.isSortKey(field, getType())) {
 			throw new InvalidProgramException("Selected sort key is not a sortable type");
 		}
-
-		Keys.ExpressionKeys<T> ek = new Keys.ExpressionKeys<>(field, super.getType());
-		return ek.computeLogicalKeyPositions();
 	}
 
-	private int[] getFlatFields(String fields) {
-
-		if (!Keys.ExpressionKeys.isSortKey(fields, super.getType())) {
+	private void ensureSortableKey(String field) throws InvalidProgramException {
+		if (!Keys.ExpressionKeys.isSortKey(field, getType())) {
 			throw new InvalidProgramException("Selected sort key is not a sortable type");
 		}
-
-		Keys.ExpressionKeys<T> ek = new Keys.ExpressionKeys<>(fields, super.getType());
-		return ek.computeLogicalKeyPositions();
 	}
 
-	private void appendSorting(int[] flatOrderFields, Order order) {
-
-		if(this.sortKeyPositions == null) {
-			// set sorting info
-			this.sortKeyPositions = flatOrderFields;
-			this.sortOrders = new Order[flatOrderFields.length];
-			Arrays.fill(this.sortOrders, order);
-		} else {
-			// append sorting info to exising info
-			int oldLength = this.sortKeyPositions.length;
-			int newLength = oldLength + flatOrderFields.length;
-			this.sortKeyPositions = Arrays.copyOf(this.sortKeyPositions, newLength);
-			this.sortOrders = Arrays.copyOf(this.sortOrders, newLength);
-
-			for(int i=0; i<flatOrderFields.length; i++) {
-				this.sortKeyPositions[oldLength+i] = flatOrderFields[i];
-				this.sortOrders[oldLength+i] = order;
-			}
+	private <K> void ensureSortableKey(Keys.SelectorFunctionKeys<T, K> sortKey)
{
+		if (!sortKey.getKeyType().isSortKeyType()) {
+			throw new InvalidProgramException("Selected sort key is not a sortable type");
 		}
 	}
 
@@ -144,16 +167,33 @@ public class SortPartitionOperator<T> extends SingleInputOperator<T,
T, SortPart
 
 		String name = "Sort at " + sortLocationName;
 
+		if (useKeySelector) {
+			return translateToDataFlowWithKeyExtractor(input, (Keys.SelectorFunctionKeys<T, ?>)
keys.get(0), orders.get(0), name);
+		}
+
+		// flatten sort key positions
+		List<Integer> allKeyPositions = new ArrayList<>();
+		List<Order> allOrders = new ArrayList<>();
+		for (int i = 0, length = keys.size(); i < length; i++) {
+			int[] sortKeyPositions = keys.get(i).computeLogicalKeyPositions();
+			Order order = orders.get(i);
+
+			for (int sortKeyPosition : sortKeyPositions) {
+				allKeyPositions.add(sortKeyPosition);
+				allOrders.add(order);
+			}
+		}
+
 		Ordering partitionOrdering = new Ordering();
-		for (int i = 0; i < this.sortKeyPositions.length; i++) {
-			partitionOrdering.appendOrdering(this.sortKeyPositions[i], null, this.sortOrders[i]);
+		for (int i = 0, length = allKeyPositions.size(); i < length; i++) {
+			partitionOrdering.appendOrdering(allKeyPositions.get(i), null, allOrders.get(i));
 		}
 
 		// distinguish between partition types
 		UnaryOperatorInformation<T, T> operatorInfo = new UnaryOperatorInformation<>(getType(),
getType());
-		SortPartitionOperatorBase<T> noop = new  SortPartitionOperatorBase<>(operatorInfo,
partitionOrdering, name);
+		SortPartitionOperatorBase<T> noop = new SortPartitionOperatorBase<>(operatorInfo,
partitionOrdering, name);
 		noop.setInput(input);
-		if(this.getParallelism() < 0) {
+		if (this.getParallelism() < 0) {
 			// use parallelism of input if not explicitly specified
 			noop.setParallelism(input.getParallelism());
 		} else {
@@ -165,4 +205,32 @@ public class SortPartitionOperator<T> extends SingleInputOperator<T,
T, SortPart
 
 	}
 
+	private <K> org.apache.flink.api.common.operators.SingleInputOperator<?, T, ?>
translateToDataFlowWithKeyExtractor(
+		Operator<T> input, Keys.SelectorFunctionKeys<T, K> keys, Order order, String
name) {
+		TypeInformation<Tuple2<K, T>> typeInfoWithKey = KeyFunctions.createTypeWithKey(keys);
+		Keys.ExpressionKeys<Tuple2<K, T>> newKey = new Keys.ExpressionKeys<>(0,
typeInfoWithKey);
+
+		Operator<Tuple2<K, T>> keyedInput = KeyFunctions.appendKeyExtractor(input,
keys);
+
+		int[] sortKeyPositions = newKey.computeLogicalKeyPositions();
+		Ordering partitionOrdering = new Ordering();
+		for (int keyPosition : sortKeyPositions) {
+			partitionOrdering.appendOrdering(keyPosition, null, order);
+		}
+
+		// distinguish between partition types
+		UnaryOperatorInformation<Tuple2<K, T>, Tuple2<K, T>> operatorInfo = new
UnaryOperatorInformation<>(typeInfoWithKey, typeInfoWithKey);
+		SortPartitionOperatorBase<Tuple2<K, T>> noop = new SortPartitionOperatorBase<>(operatorInfo,
partitionOrdering, name);
+		noop.setInput(keyedInput);
+		if (this.getParallelism() < 0) {
+			// use parallelism of input if not explicitly specified
+			noop.setParallelism(input.getParallelism());
+		} else {
+			// use explicitly specified parallelism
+			noop.setParallelism(this.getParallelism());
+		}
+
+		return KeyFunctions.appendKeyRemover(noop, keys);
+	}
+
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-java/src/test/java/org/apache/flink/api/java/operator/SortPartitionTest.java
----------------------------------------------------------------------
diff --git a/flink-java/src/test/java/org/apache/flink/api/java/operator/SortPartitionTest.java
b/flink-java/src/test/java/org/apache/flink/api/java/operator/SortPartitionTest.java
index a4e2bbc..3540e6a 100644
--- a/flink-java/src/test/java/org/apache/flink/api/java/operator/SortPartitionTest.java
+++ b/flink-java/src/test/java/org/apache/flink/api/java/operator/SortPartitionTest.java
@@ -169,6 +169,88 @@ public class SortPartitionTest {
 		tupleDs.sortPartition("f3", Order.ASCENDING);
 	}
 
+	@Test
+	public void testSortPartitionWithKeySelector1() {
+		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+		DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData,
tupleWithCustomInfo);
+
+		// should work
+		try {
+			tupleDs.sortPartition(new KeySelector<Tuple4<Integer, Long, CustomType, Long[]>,
Integer>() {
+				@Override
+				public Integer getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws Exception
{
+					return value.f0;
+				}
+			}, Order.ASCENDING);
+		} catch (Exception e) {
+			Assert.fail();
+		}
+	}
+
+	@Test(expected = InvalidProgramException.class)
+	public void testSortPartitionWithKeySelector2() {
+		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+		DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData,
tupleWithCustomInfo);
+
+		// must not work
+		tupleDs.sortPartition(new KeySelector<Tuple4<Integer, Long, CustomType, Long[]>,
Long[]>() {
+			@Override
+			public Long[] getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws Exception
{
+				return value.f3;
+			}
+		}, Order.ASCENDING);
+	}
+
+	@Test(expected = InvalidProgramException.class)
+	public void testSortPartitionWithKeySelector3() {
+		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+		DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData,
tupleWithCustomInfo);
+
+		// must not work
+		tupleDs
+			.sortPartition("f1", Order.ASCENDING)
+			.sortPartition(new KeySelector<Tuple4<Integer, Long, CustomType, Long[]>, CustomType>()
{
+				@Override
+				public CustomType getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws
Exception {
+					return value.f2;
+				}
+			}, Order.ASCENDING);
+	}
+
+	@Test
+	public void testSortPartitionWithKeySelector4() {
+		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+		DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData,
tupleWithCustomInfo);
+
+		// should work
+		try {
+			tupleDs.sortPartition(new KeySelector<Tuple4<Integer,Long,CustomType,Long[]>,
Tuple2<Integer, Long>>() {
+				@Override
+				public Tuple2<Integer, Long> getKey(Tuple4<Integer, Long, CustomType, Long[]>
value) throws Exception {
+					return new Tuple2<>(value.f0, value.f1);
+				}
+			}, Order.ASCENDING);
+		} catch (Exception e) {
+			Assert.fail();
+		}
+	}
+
+	@Test(expected = InvalidProgramException.class)
+	public void testSortPartitionWithKeySelector5() {
+		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+		DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData,
tupleWithCustomInfo);
+
+		// must not work
+		tupleDs
+			.sortPartition(new KeySelector<Tuple4<Integer, Long, CustomType, Long[]>, CustomType>()
{
+				@Override
+				public CustomType getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws
Exception {
+					return value.f2;
+				}
+			}, Order.ASCENDING)
+			.sortPartition("f1", Order.ASCENDING);
+	}
+
 	public static class CustomType implements Serializable {
 		
 		public static class Nest {

http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
index e47bc42..5735b32 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
@@ -1511,6 +1511,31 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
       new SortPartitionOperator[T](javaSet, field, order, getCallLocationName()))
   }
 
+  /**
+    * Locally sorts the partitions of the DataSet on the extracted key in the specified order.
+    * The DataSet can be sorted on multiple values by returning a tuple from the KeySelector.
+    *
+    * Note that no additional sort keys can be appended to KeySelector sort keys. To sort
+    * the partitions by multiple values using a KeySelector, the KeySelector must return a tuple
+    * consisting of the values.
+    */
+  def sortPartition[K: TypeInformation](fun: T => K, order: Order): DataSet[T] ={
+    val keyExtractor = new KeySelector[T, K] {
+      val cleanFun = clean(fun)
+      def getKey(in: T) = cleanFun(in)
+    }
+
+    val keyType = implicitly[TypeInformation[K]]
+    new PartitionSortedDataSet[T](
+      new SortPartitionOperator[T](javaSet,
+        new Keys.SelectorFunctionKeys[T, K](
+          keyExtractor,
+          javaSet.getType,
+          keyType),
+        order,
+        getCallLocationName()))
+  }
+
   // --------------------------------------------------------------------------------------------
   //  Result writing
   // --------------------------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-scala/src/main/scala/org/apache/flink/api/scala/PartitionSortedDataSet.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/PartitionSortedDataSet.scala
b/flink-scala/src/main/scala/org/apache/flink/api/scala/PartitionSortedDataSet.scala
index c924a76..a402dd9 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/PartitionSortedDataSet.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/PartitionSortedDataSet.scala
@@ -18,7 +18,9 @@
 package org.apache.flink.api.scala
 
 import org.apache.flink.annotation.Public
+import org.apache.flink.api.common.InvalidProgramException
 import org.apache.flink.api.common.operators.Order
+import org.apache.flink.api.common.typeinfo.TypeInformation
 import org.apache.flink.api.java.operators.SortPartitionOperator
 
 import scala.reflect.ClassTag
@@ -37,16 +39,30 @@ class PartitionSortedDataSet[T: ClassTag](set: SortPartitionOperator[T])
    * Appends the given field and order to the sort-partition operator.
    */
   override def sortPartition(field: Int, order: Order): DataSet[T] = {
+    if (set.useKeySelector()) {
+      throw new InvalidProgramException("Expression keys cannot be appended after selector
" +
+        "function keys")
+    }
+
     this.set.sortPartition(field, order)
     this
   }
 
-/**
- * Appends the given field and order to the sort-partition operator.
- */
+  /**
+   * Appends the given field and order to the sort-partition operator.
+   */
   override def sortPartition(field: String, order: Order): DataSet[T] = {
+    if (set.useKeySelector()) {
+      throw new InvalidProgramException("Expression keys cannot be appended after selector
" +
+        "function keys")
+    }
+
     this.set.sortPartition(field, order)
     this
   }
 
+  override def sortPartition[K: TypeInformation](fun: T => K, order: Order): DataSet[T]
= {
+    throw new InvalidProgramException("KeySelector cannot be chained.")
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/SortPartitionITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/SortPartitionITCase.java
b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/SortPartitionITCase.java
index 2423420..c7f07f6 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/SortPartitionITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/SortPartitionITCase.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.functions.MapPartitionFunction;
 import org.apache.flink.api.common.operators.Order;
 import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.tuple.Tuple1;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.tuple.Tuple3;
@@ -197,6 +198,58 @@ public class SortPartitionITCase extends MultipleProgramsTestBase {
 		compareResultAsText(result, expected);
 	}
 
+	@Test
+	public void testSortPartitionWithKeySelector1() throws Exception {
+		/*
+		 * Test sort partition on an extracted key
+		 */
+
+		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+		env.setParallelism(4);
+
+		DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);
+		List<Tuple1<Boolean>> result = ds
+			.map(new IdMapper<Tuple3<Integer, Long, String>>()).setParallelism(4) // parallelize
input
+			.sortPartition(new KeySelector<Tuple3<Integer, Long, String>, Long>() {
+				@Override
+				public Long getKey(Tuple3<Integer, Long, String> value) throws Exception {
+					return value.f1;
+				}
+			}, Order.ASCENDING)
+			.mapPartition(new OrderCheckMapper<>(new Tuple3AscendingChecker()))
+			.distinct().collect();
+
+		String expected = "(true)\n";
+
+		compareResultAsText(result, expected);
+	}
+
+	@Test
+	public void testSortPartitionWithKeySelector2() throws Exception {
+		/*
+		 * Test sort partition on an extracted key
+		 */
+
+		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+		env.setParallelism(4);
+
+		DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);
+		List<Tuple1<Boolean>> result = ds
+			.map(new IdMapper<Tuple3<Integer, Long, String>>()).setParallelism(4) // parallelize
input
+			.sortPartition(new KeySelector<Tuple3<Integer, Long, String>, Tuple2<Integer,
Long>>() {
+				@Override
+				public Tuple2<Integer, Long> getKey(Tuple3<Integer, Long, String> value)
throws Exception {
+					return new Tuple2<>(value.f0, value.f1);
+				}
+			}, Order.DESCENDING)
+			.mapPartition(new OrderCheckMapper<>(new Tuple3Checker()))
+			.distinct().collect();
+
+		String expected = "(true)\n";
+
+		compareResultAsText(result, expected);
+	}
+
 	public interface OrderChecker<T> extends Serializable {
 		boolean inOrder(T t1, T t2);
 	}
@@ -210,6 +263,14 @@ public class SortPartitionITCase extends MultipleProgramsTestBase {
 	}
 
 	@SuppressWarnings("serial")
+	public static class Tuple3AscendingChecker implements OrderChecker<Tuple3<Integer,
Long, String>> {
+		@Override
+		public boolean inOrder(Tuple3<Integer, Long, String> t1, Tuple3<Integer, Long,
String> t2) {
+			return t1.f1 <= t2.f1;
+		}
+	}
+
+	@SuppressWarnings("serial")
 	public static class Tuple5Checker implements OrderChecker<Tuple5<Integer, Long, Integer,
String, Long>> {
 		@Override
 		public boolean inOrder(Tuple5<Integer, Long, Integer, String, Long> t1,

http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/SortPartitionITCase.scala
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/SortPartitionITCase.scala
b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/SortPartitionITCase.scala
index 3f67063..cda8f4f 100644
--- a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/SortPartitionITCase.scala
+++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/SortPartitionITCase.scala
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._
 
 import org.apache.flink.api.common.functions.MapPartitionFunction
 import org.apache.flink.api.common.operators.Order
+import org.apache.flink.api.common.InvalidProgramException
 import org.apache.flink.api.scala._
 import org.apache.flink.api.scala.util.CollectionDataSets
 import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
@@ -166,6 +167,58 @@ class SortPartitionITCase(mode: TestExecutionMode) extends MultipleProgramsTestB
     TestBaseUtils.compareResultAsText(result.asJava, expected)
   }
 
+  @Test
+  def testSortPartitionWithKeySelector1(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    env.setParallelism(4)
+    val ds = CollectionDataSets.get3TupleDataSet(env)
+
+    val result = ds
+      .map { x => x }.setParallelism(4)
+      .sortPartition(_._2, Order.ASCENDING)
+      .mapPartition(new OrderCheckMapper(new Tuple3AscendingChecker))
+      .distinct()
+      .collect()
+
+    val expected: String = "(true)\n"
+    TestBaseUtils.compareResultAsText(result.asJava, expected)
+  }
+
+  @Test
+  def testSortPartitionWithKeySelector2(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    env.setParallelism(4)
+    val ds = CollectionDataSets.get3TupleDataSet(env)
+
+    val result = ds
+      .map { x => x }.setParallelism(4)
+      .sortPartition(x => (x._2, x._1), Order.DESCENDING)
+      .mapPartition(new OrderCheckMapper(new Tuple3Checker))
+      .distinct()
+      .collect()
+
+    val expected: String = "(true)\n"
+    TestBaseUtils.compareResultAsText(result.asJava, expected)
+  }
+
+  @Test(expected = classOf[InvalidProgramException])
+  def testSortPartitionWithKeySelector3(): Unit = {
+    val env = ExecutionEnvironment.getExecutionEnvironment
+    env.setParallelism(4)
+    val ds = CollectionDataSets.get3TupleDataSet(env)
+
+    val result = ds
+      .map { x => x }.setParallelism(4)
+      .sortPartition(x => (x._2, x._1), Order.DESCENDING)
+      .sortPartition(0, Order.DESCENDING)
+      .mapPartition(new OrderCheckMapper(new Tuple3Checker))
+      .distinct()
+      .collect()
+
+    val expected: String = "(true)\n"
+    TestBaseUtils.compareResultAsText(result.asJava, expected)
+  }
+
 }
 
 trait OrderChecker[T] extends Serializable {
@@ -178,6 +231,12 @@ class Tuple3Checker extends OrderChecker[(Int, Long, String)] {
   }
 }
 
+class Tuple3AscendingChecker extends OrderChecker[(Int, Long, String)] {
+  def inOrder(t1: (Int, Long, String), t2: (Int, Long, String)): Boolean = {
+    t1._2 <= t2._2
+  }
+}
+
 class Tuple5Checker extends OrderChecker[(Int, Long, Int, String, Long)] {
   def inOrder(t1: (Int, Long, Int, String, Long), t2: (Int, Long, Int, String, Long)): Boolean
= {
     t1._5 < t2._5 || t1._5 == t2._5 && t1._3 >= t2._3


Mime
View raw message