[FLINK-3234] [dataSet] Add KeySelector support to sortPartition operation.
This closes #1585
Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/0a63797a
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/0a63797a
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/0a63797a
Branch: refs/heads/tableOnCalcite
Commit: 0a63797a6a5418b2363bca25bd77c33c217ff257
Parents: 572855d
Author: Chiwan Park <chiwanpark@apache.org>
Authored: Thu Feb 4 20:46:10 2016 +0900
Committer: Fabian Hueske <fhueske@apache.org>
Committed: Wed Feb 10 11:51:26 2016 +0100
----------------------------------------------------------------------
.../java/org/apache/flink/api/java/DataSet.java | 18 ++
.../java/operators/SortPartitionOperator.java | 174 +++++++++++++------
.../api/java/operator/SortPartitionTest.java | 82 +++++++++
.../org/apache/flink/api/scala/DataSet.scala | 25 +++
.../api/scala/PartitionSortedDataSet.scala | 22 ++-
.../javaApiOperators/SortPartitionITCase.java | 61 +++++++
.../scala/operators/SortPartitionITCase.scala | 59 +++++++
7 files changed, 385 insertions(+), 56 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java b/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
index bfb97f4..c315920 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/DataSet.java
@@ -1381,6 +1381,24 @@ public abstract class DataSet<T> {
return new SortPartitionOperator<>(this, field, order, Utils.getCallLocationName());
}
+ /**
+ * Locally sorts the partitions of the DataSet on the extracted key in the specified order.
+ * The DataSet can be sorted on multiple values by returning a tuple from the KeySelector.
+ *
+ * Note that no additional sort keys can be appended to a KeySelector sort key. To sort
+ * the partitions by multiple values using a KeySelector, the KeySelector must return a tuple
+ * consisting of the values.
+ *
+ * @param keyExtractor The KeySelector function which extracts the key values from the DataSet
+ * on which the DataSet is sorted.
+ * @param order The order in which the DataSet is sorted.
+ * @return The DataSet with sorted local partitions.
+ */
+ public <K> SortPartitionOperator<T> sortPartition(KeySelector<T, K> keyExtractor, Order order) {
+ final TypeInformation<K> keyType = TypeExtractor.getKeySelectorTypes(keyExtractor, getType());
+ return new SortPartitionOperator<>(this, new Keys.SelectorFunctionKeys<>(clean(keyExtractor), getType(), keyType), order, Utils.getCallLocationName());
+ }
+
// --------------------------------------------------------------------------------------------
// Top-K
// --------------------------------------------------------------------------------------------
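
A minimal usage sketch of the sortPartition(KeySelector, Order) method added above (not part of
this commit; the class name and input data are illustrative). The KeySelector returns a Tuple2 so
that each partition is sorted on two values:

    import org.apache.flink.api.common.operators.Order;
    import org.apache.flink.api.java.DataSet;
    import org.apache.flink.api.java.ExecutionEnvironment;
    import org.apache.flink.api.java.functions.KeySelector;
    import org.apache.flink.api.java.tuple.Tuple2;
    import org.apache.flink.api.java.tuple.Tuple3;

    public class SortPartitionKeySelectorExample {
        public static void main(String[] args) throws Exception {
            ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

            DataSet<Tuple3<Integer, Long, String>> input = env.fromElements(
                new Tuple3<>(2, 10L, "b"),
                new Tuple3<>(1, 20L, "a"),
                new Tuple3<>(1, 5L, "c"));

            // Sort every partition on (f0, f1) by returning a composite key from the KeySelector.
            DataSet<Tuple3<Integer, Long, String>> sorted = input
                .sortPartition(new KeySelector<Tuple3<Integer, Long, String>, Tuple2<Integer, Long>>() {
                    @Override
                    public Tuple2<Integer, Long> getKey(Tuple3<Integer, Long, String> value) {
                        return new Tuple2<>(value.f0, value.f1);
                    }
                }, Order.ASCENDING);

            sorted.print();
        }
    }

No further sort keys can be appended after this call; the operator and the tests below enforce
that restriction.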
http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-java/src/main/java/org/apache/flink/api/java/operators/SortPartitionOperator.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/SortPartitionOperator.java b/flink-java/src/main/java/org/apache/flink/api/java/operators/SortPartitionOperator.java
index 354a0cd..7f30a30 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/SortPartitionOperator.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/SortPartitionOperator.java
@@ -26,9 +26,13 @@ import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.common.operators.Ordering;
import org.apache.flink.api.common.operators.UnaryOperatorInformation;
import org.apache.flink.api.common.operators.base.SortPartitionOperatorBase;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
-import java.util.Arrays;
+import java.util.ArrayList;
+import java.util.List;
/**
* This operator represents a DataSet with locally sorted partitions.
@@ -38,27 +42,58 @@ import java.util.Arrays;
@Public
public class SortPartitionOperator<T> extends SingleInputOperator<T, T, SortPartitionOperator<T>> {
- private int[] sortKeyPositions;
+ private List<Keys<T>> keys;
- private Order[] sortOrders;
+ private List<Order> orders;
private final String sortLocationName;
+ private boolean useKeySelector;
- public SortPartitionOperator(DataSet<T> dataSet, int sortField, Order sortOrder, String sortLocationName) {
+ private SortPartitionOperator(DataSet<T> dataSet, String sortLocationName) {
super(dataSet, dataSet.getType());
+
+ keys = new ArrayList<>();
+ orders = new ArrayList<>();
this.sortLocationName = sortLocationName;
+ }
+
+
+ public SortPartitionOperator(DataSet<T> dataSet, int sortField, Order sortOrder, String sortLocationName) {
+ this(dataSet, sortLocationName);
+ this.useKeySelector = false;
+
+ ensureSortableKey(sortField);
- int[] flatOrderKeys = getFlatFields(sortField);
- this.appendSorting(flatOrderKeys, sortOrder);
+ keys.add(new Keys.ExpressionKeys<>(sortField, getType()));
+ orders.add(sortOrder);
}
public SortPartitionOperator(DataSet<T> dataSet, String sortField, Order sortOrder, String sortLocationName) {
- super(dataSet, dataSet.getType());
- this.sortLocationName = sortLocationName;
+ this(dataSet, sortLocationName);
+ this.useKeySelector = false;
+
+ ensureSortableKey(sortField);
+
+ keys.add(new Keys.ExpressionKeys<>(sortField, getType()));
+ orders.add(sortOrder);
+ }
+
+ public <K> SortPartitionOperator(DataSet<T> dataSet, Keys.SelectorFunctionKeys<T, K> sortKey, Order sortOrder, String sortLocationName) {
+ this(dataSet, sortLocationName);
+ this.useKeySelector = true;
+
+ ensureSortableKey(sortKey);
- int[] flatOrderKeys = getFlatFields(sortField);
- this.appendSorting(flatOrderKeys, sortOrder);
+ keys.add(sortKey);
+ orders.add(sortOrder);
+ }
+
+ /**
+ * Returns whether a KeySelector is used or not.
+ */
+ public boolean useKeySelector() {
+ return useKeySelector;
}
/**
@@ -70,9 +105,14 @@ public class SortPartitionOperator<T> extends SingleInputOperator<T, T, SortPart
* @return The DataSet with sorted local partitions.
*/
public SortPartitionOperator<T> sortPartition(int field, Order order) {
+ if (useKeySelector) {
+ throw new InvalidProgramException("Expression keys cannot be appended after a KeySelector");
+ }
+
+ ensureSortableKey(field);
+ keys.add(new Keys.ExpressionKeys<>(field, getType()));
+ orders.add(order);
- int[] flatOrderKeys = getFlatFields(field);
- this.appendSorting(flatOrderKeys, order);
return this;
}
@@ -81,58 +121,41 @@ public class SortPartitionOperator<T> extends SingleInputOperator<T, T, SortPart
* local partition sorting of the DataSet.
*
* @param field The field expression referring to the field of the additional sort order of
- * the local partition sorting.
- * @param order The order of the additional sort order of the local partition sorting.
+ * the local partition sorting.
+ * @param order The order of the additional sort order of the local partition sorting.
* @return The DataSet with sorted local partitions.
*/
public SortPartitionOperator<T> sortPartition(String field, Order order) {
- int[] flatOrderKeys = getFlatFields(field);
- this.appendSorting(flatOrderKeys, order);
+ if (useKeySelector) {
+ throw new InvalidProgramException("Expression keys cannot be appended after a KeySelector");
+ }
+
+ ensureSortableKey(field);
+ keys.add(new Keys.ExpressionKeys<>(field, getType()));
+ orders.add(order);
+
return this;
}
- // --------------------------------------------------------------------------------------------
- // Key Extraction
- // --------------------------------------------------------------------------------------------
-
- private int[] getFlatFields(int field) {
+ public <K> SortPartitionOperator<T> sortPartition(KeySelector<T, K> keyExtractor, Order order) {
+ throw new InvalidProgramException("KeySelector cannot be chained.");
+ }
- if (!Keys.ExpressionKeys.isSortKey(field, super.getType())) {
+ private void ensureSortableKey(int field) throws InvalidProgramException {
+ if (!Keys.ExpressionKeys.isSortKey(field, getType())) {
throw new InvalidProgramException("Selected sort key is not a sortable type");
}
-
- Keys.ExpressionKeys<T> ek = new Keys.ExpressionKeys<>(field, super.getType());
- return ek.computeLogicalKeyPositions();
}
- private int[] getFlatFields(String fields) {
-
- if (!Keys.ExpressionKeys.isSortKey(fields, super.getType())) {
+ private void ensureSortableKey(String field) throws InvalidProgramException {
+ if (!Keys.ExpressionKeys.isSortKey(field, getType())) {
throw new InvalidProgramException("Selected sort key is not a sortable type");
}
-
- Keys.ExpressionKeys<T> ek = new Keys.ExpressionKeys<>(fields, super.getType());
- return ek.computeLogicalKeyPositions();
}
- private void appendSorting(int[] flatOrderFields, Order order) {
-
- if(this.sortKeyPositions == null) {
- // set sorting info
- this.sortKeyPositions = flatOrderFields;
- this.sortOrders = new Order[flatOrderFields.length];
- Arrays.fill(this.sortOrders, order);
- } else {
- // append sorting info to exising info
- int oldLength = this.sortKeyPositions.length;
- int newLength = oldLength + flatOrderFields.length;
- this.sortKeyPositions = Arrays.copyOf(this.sortKeyPositions, newLength);
- this.sortOrders = Arrays.copyOf(this.sortOrders, newLength);
-
- for(int i=0; i<flatOrderFields.length; i++) {
- this.sortKeyPositions[oldLength+i] = flatOrderFields[i];
- this.sortOrders[oldLength+i] = order;
- }
+ private <K> void ensureSortableKey(Keys.SelectorFunctionKeys<T, K> sortKey) {
+ if (!sortKey.getKeyType().isSortKeyType()) {
+ throw new InvalidProgramException("Selected sort key is not a sortable type");
}
}
@@ -144,16 +167,33 @@ public class SortPartitionOperator<T> extends SingleInputOperator<T, T, SortPart
String name = "Sort at " + sortLocationName;
+ if (useKeySelector) {
+ return translateToDataFlowWithKeyExtractor(input, (Keys.SelectorFunctionKeys<T, ?>) keys.get(0), orders.get(0), name);
+ }
+
+ // flatten sort key positions
+ List<Integer> allKeyPositions = new ArrayList<>();
+ List<Order> allOrders = new ArrayList<>();
+ for (int i = 0, length = keys.size(); i < length; i++) {
+ int[] sortKeyPositions = keys.get(i).computeLogicalKeyPositions();
+ Order order = orders.get(i);
+
+ for (int sortKeyPosition : sortKeyPositions) {
+ allKeyPositions.add(sortKeyPosition);
+ allOrders.add(order);
+ }
+ }
+
Ordering partitionOrdering = new Ordering();
- for (int i = 0; i < this.sortKeyPositions.length; i++) {
- partitionOrdering.appendOrdering(this.sortKeyPositions[i], null, this.sortOrders[i]);
+ for (int i = 0, length = allKeyPositions.size(); i < length; i++) {
+ partitionOrdering.appendOrdering(allKeyPositions.get(i), null, allOrders.get(i));
}
// distinguish between partition types
UnaryOperatorInformation<T, T> operatorInfo = new UnaryOperatorInformation<>(getType(), getType());
- SortPartitionOperatorBase<T> noop = new SortPartitionOperatorBase<>(operatorInfo, partitionOrdering, name);
+ SortPartitionOperatorBase<T> noop = new SortPartitionOperatorBase<>(operatorInfo, partitionOrdering, name);
noop.setInput(input);
- if(this.getParallelism() < 0) {
+ if (this.getParallelism() < 0) {
// use parallelism of input if not explicitly specified
noop.setParallelism(input.getParallelism());
} else {
@@ -165,4 +205,32 @@ public class SortPartitionOperator<T> extends SingleInputOperator<T, T, SortPart
}
+ private <K> org.apache.flink.api.common.operators.SingleInputOperator<?, T, ?> translateToDataFlowWithKeyExtractor(
+ Operator<T> input, Keys.SelectorFunctionKeys<T, K> keys, Order order, String name) {
+ TypeInformation<Tuple2<K, T>> typeInfoWithKey = KeyFunctions.createTypeWithKey(keys);
+ Keys.ExpressionKeys<Tuple2<K, T>> newKey = new Keys.ExpressionKeys<>(0, typeInfoWithKey);
+
+ Operator<Tuple2<K, T>> keyedInput = KeyFunctions.appendKeyExtractor(input, keys);
+
+ int[] sortKeyPositions = newKey.computeLogicalKeyPositions();
+ Ordering partitionOrdering = new Ordering();
+ for (int keyPosition : sortKeyPositions) {
+ partitionOrdering.appendOrdering(keyPosition, null, order);
+ }
+
+ // distinguish between partition types
+ UnaryOperatorInformation<Tuple2<K, T>, Tuple2<K, T>> operatorInfo = new UnaryOperatorInformation<>(typeInfoWithKey, typeInfoWithKey);
+ SortPartitionOperatorBase<Tuple2<K, T>> noop = new SortPartitionOperatorBase<>(operatorInfo, partitionOrdering, name);
+ noop.setInput(keyedInput);
+ if (this.getParallelism() < 0) {
+ // use parallelism of input if not explicitly specified
+ noop.setParallelism(input.getParallelism());
+ } else {
+ // use explicitly specified parallelism
+ noop.setParallelism(this.getParallelism());
+ }
+
+ return KeyFunctions.appendKeyRemover(noop, keys);
+ }
+
}
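
For intuition, translateToDataFlowWithKeyExtractor above sorts on the extracted key by wrapping
each record with its key, sorting on the wrapped key position, and removing the key again
(KeyFunctions.appendKeyExtractor / appendKeyRemover on Tuple2<K, T>). Below is a plain-Java
analogue of that wrap-sort-unwrap scheme for a single in-memory partition (purely illustrative,
not Flink code; the class and method names are made up):

    import java.util.AbstractMap.SimpleEntry;
    import java.util.ArrayList;
    import java.util.List;
    import java.util.function.Function;

    public class WrapSortUnwrap {

        // Extract the key once, keep it next to the record ("appendKeyExtractor"),
        // sort on the key field only, then drop the key again ("appendKeyRemover").
        static <T, K extends Comparable<K>> List<T> sortPartition(List<T> partition,
                Function<T, K> keyExtractor) {
            List<SimpleEntry<K, T>> wrapped = new ArrayList<>();
            for (T record : partition) {
                wrapped.add(new SimpleEntry<>(keyExtractor.apply(record), record)); // wrap as (key, record)
            }

            wrapped.sort((a, b) -> a.getKey().compareTo(b.getKey())); // sort on the extracted key

            List<T> result = new ArrayList<>();
            for (SimpleEntry<K, T> entry : wrapped) {
                result.add(entry.getValue()); // unwrap back to the original record
            }
            return result;
        }
    }

In the operator, the wrapped type is Tuple2<K, T> and the sort key is logical position 0 of that
type, which is why ensureSortableKey only has to check that the extracted key type is sortable.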
http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-java/src/test/java/org/apache/flink/api/java/operator/SortPartitionTest.java
----------------------------------------------------------------------
diff --git a/flink-java/src/test/java/org/apache/flink/api/java/operator/SortPartitionTest.java b/flink-java/src/test/java/org/apache/flink/api/java/operator/SortPartitionTest.java
index a4e2bbc..3540e6a 100644
--- a/flink-java/src/test/java/org/apache/flink/api/java/operator/SortPartitionTest.java
+++ b/flink-java/src/test/java/org/apache/flink/api/java/operator/SortPartitionTest.java
@@ -169,6 +169,88 @@ public class SortPartitionTest {
tupleDs.sortPartition("f3", Order.ASCENDING);
}
+ @Test
+ public void testSortPartitionWithKeySelector1() {
+ final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData, tupleWithCustomInfo);
+
+ // should work
+ try {
+ tupleDs.sortPartition(new KeySelector<Tuple4<Integer, Long, CustomType, Long[]>, Integer>() {
+ @Override
+ public Integer getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws Exception {
+ return value.f0;
+ }
+ }, Order.ASCENDING);
+ } catch (Exception e) {
+ Assert.fail();
+ }
+ }
+
+ @Test(expected = InvalidProgramException.class)
+ public void testSortPartitionWithKeySelector2() {
+ final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData, tupleWithCustomInfo);
+
+ // must not work
+ tupleDs.sortPartition(new KeySelector<Tuple4<Integer, Long, CustomType, Long[]>, Long[]>() {
+ @Override
+ public Long[] getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws Exception {
+ return value.f3;
+ }
+ }, Order.ASCENDING);
+ }
+
+ @Test(expected = InvalidProgramException.class)
+ public void testSortPartitionWithKeySelector3() {
+ final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData, tupleWithCustomInfo);
+
+ // must not work
+ tupleDs
+ .sortPartition("f1", Order.ASCENDING)
+ .sortPartition(new KeySelector<Tuple4<Integer, Long, CustomType, Long[]>, CustomType>() {
+ @Override
+ public CustomType getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws Exception {
+ return value.f2;
+ }
+ }, Order.ASCENDING);
+ }
+
+ @Test
+ public void testSortPartitionWithKeySelector4() {
+ final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData, tupleWithCustomInfo);
+
+ // should work
+ try {
+ tupleDs.sortPartition(new KeySelector<Tuple4<Integer,Long,CustomType,Long[]>, Tuple2<Integer, Long>>() {
+ @Override
+ public Tuple2<Integer, Long> getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws Exception {
+ return new Tuple2<>(value.f0, value.f1);
+ }
+ }, Order.ASCENDING);
+ } catch (Exception e) {
+ Assert.fail();
+ }
+ }
+
+ @Test(expected = InvalidProgramException.class)
+ public void testSortPartitionWithKeySelector5() {
+ final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ DataSet<Tuple4<Integer, Long, CustomType, Long[]>> tupleDs = env.fromCollection(tupleWithCustomData, tupleWithCustomInfo);
+
+ // must not work
+ tupleDs
+ .sortPartition(new KeySelector<Tuple4<Integer, Long, CustomType, Long[]>, CustomType>() {
+ @Override
+ public CustomType getKey(Tuple4<Integer, Long, CustomType, Long[]> value) throws Exception {
+ return value.f2;
+ }
+ }, Order.ASCENDING)
+ .sortPartition("f1", Order.ASCENDING);
+ }
+
public static class CustomType implements Serializable {
public static class Nest {
http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
index e47bc42..5735b32 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/DataSet.scala
@@ -1511,6 +1511,31 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
new SortPartitionOperator[T](javaSet, field, order, getCallLocationName()))
}
+ /**
+ * Locally sorts the partitions of the DataSet on the extracted key in the specified order.
+ * The DataSet can be sorted on multiple values by returning a tuple from the KeySelector.
+ *
+ * Note that no additional sort keys can be appended to a KeySelector sort key. To sort
+ * the partitions by multiple values using a KeySelector, the KeySelector must return a tuple
+ * consisting of the values.
+ */
+ def sortPartition[K: TypeInformation](fun: T => K, order: Order): DataSet[T] = {
+ val keyExtractor = new KeySelector[T, K] {
+ val cleanFun = clean(fun)
+ def getKey(in: T) = cleanFun(in)
+ }
+
+ val keyType = implicitly[TypeInformation[K]]
+ new PartitionSortedDataSet[T](
+ new SortPartitionOperator[T](javaSet,
+ new Keys.SelectorFunctionKeys[T, K](
+ keyExtractor,
+ javaSet.getType,
+ keyType),
+ order,
+ getCallLocationName()))
+ }
+
// --------------------------------------------------------------------------------------------
// Result writing
// --------------------------------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-scala/src/main/scala/org/apache/flink/api/scala/PartitionSortedDataSet.scala
----------------------------------------------------------------------
diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/PartitionSortedDataSet.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/PartitionSortedDataSet.scala
index c924a76..a402dd9 100644
--- a/flink-scala/src/main/scala/org/apache/flink/api/scala/PartitionSortedDataSet.scala
+++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/PartitionSortedDataSet.scala
@@ -18,7 +18,9 @@
package org.apache.flink.api.scala
import org.apache.flink.annotation.Public
+import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.api.common.operators.Order
+import org.apache.flink.api.common.typeinfo.TypeInformation
import org.apache.flink.api.java.operators.SortPartitionOperator
import scala.reflect.ClassTag
@@ -37,16 +39,30 @@ class PartitionSortedDataSet[T: ClassTag](set: SortPartitionOperator[T])
* Appends the given field and order to the sort-partition operator.
*/
override def sortPartition(field: Int, order: Order): DataSet[T] = {
+ if (set.useKeySelector()) {
+ throw new InvalidProgramException("Expression keys cannot be appended after selector " +
+ "function keys")
+ }
+
this.set.sortPartition(field, order)
this
}
-/**
- * Appends the given field and order to the sort-partition operator.
- */
+ /**
+ * Appends the given field and order to the sort-partition operator.
+ */
override def sortPartition(field: String, order: Order): DataSet[T] = {
+ if (set.useKeySelector()) {
+ throw new InvalidProgramException("Expression keys cannot be appended after selector " +
+ "function keys")
+ }
+
this.set.sortPartition(field, order)
this
}
+ override def sortPartition[K: TypeInformation](fun: T => K, order: Order): DataSet[T] = {
+ throw new InvalidProgramException("KeySelector cannot be chained.")
+ }
+
}
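
Both the Java and the Scala API now enforce the same restriction: a KeySelector sort cannot be
combined with further sort keys in either direction. A Java sketch of the two rejected chains
(illustrative class name and input data; both chains throw InvalidProgramException while the
program is being constructed):

    import org.apache.flink.api.common.InvalidProgramException;
    import org.apache.flink.api.common.operators.Order;
    import org.apache.flink.api.java.DataSet;
    import org.apache.flink.api.java.ExecutionEnvironment;
    import org.apache.flink.api.java.functions.KeySelector;
    import org.apache.flink.api.java.tuple.Tuple2;

    public class SortPartitionChainingExample {
        public static void main(String[] args) {
            ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
            DataSet<Tuple2<String, Integer>> words =
                env.fromElements(new Tuple2<>("a", 1), new Tuple2<>("b", 2));

            try {
                // Appending an expression key after a KeySelector sort is rejected.
                words.sortPartition(new KeySelector<Tuple2<String, Integer>, String>() {
                        @Override
                        public String getKey(Tuple2<String, Integer> value) {
                            return value.f0;
                        }
                    }, Order.ASCENDING)
                    .sortPartition("f1", Order.DESCENDING);
            } catch (InvalidProgramException expected) {
                // "Expression keys cannot be appended after a KeySelector"
            }

            try {
                // Appending a KeySelector after an expression-key sort is rejected as well.
                words.sortPartition(0, Order.ASCENDING)
                    .sortPartition(new KeySelector<Tuple2<String, Integer>, Integer>() {
                        @Override
                        public Integer getKey(Tuple2<String, Integer> value) {
                            return value.f1;
                        }
                    }, Order.ASCENDING);
            } catch (InvalidProgramException expected) {
                // "KeySelector cannot be chained."
            }
        }
    }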
http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/SortPartitionITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/SortPartitionITCase.java b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/SortPartitionITCase.java
index 2423420..c7f07f6 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/SortPartitionITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/javaApiOperators/SortPartitionITCase.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple1;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
@@ -197,6 +198,58 @@ public class SortPartitionITCase extends MultipleProgramsTestBase {
compareResultAsText(result, expected);
}
+ @Test
+ public void testSortPartitionWithKeySelector1() throws Exception {
+ /*
+ * Test sort partition on an extracted key
+ */
+
+ final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ env.setParallelism(4);
+
+ DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);
+ List<Tuple1<Boolean>> result = ds
+ .map(new IdMapper<Tuple3<Integer, Long, String>>()).setParallelism(4) // parallelize input
+ .sortPartition(new KeySelector<Tuple3<Integer, Long, String>, Long>() {
+ @Override
+ public Long getKey(Tuple3<Integer, Long, String> value) throws Exception {
+ return value.f1;
+ }
+ }, Order.ASCENDING)
+ .mapPartition(new OrderCheckMapper<>(new Tuple3AscendingChecker()))
+ .distinct().collect();
+
+ String expected = "(true)\n";
+
+ compareResultAsText(result, expected);
+ }
+
+ @Test
+ public void testSortPartitionWithKeySelector2() throws Exception {
+ /*
+ * Test sort partition on an extracted key
+ */
+
+ final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+ env.setParallelism(4);
+
+ DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);
+ List<Tuple1<Boolean>> result = ds
+ .map(new IdMapper<Tuple3<Integer, Long, String>>()).setParallelism(4) // parallelize input
+ .sortPartition(new KeySelector<Tuple3<Integer, Long, String>, Tuple2<Integer, Long>>() {
+ @Override
+ public Tuple2<Integer, Long> getKey(Tuple3<Integer, Long, String> value) throws Exception {
+ return new Tuple2<>(value.f0, value.f1);
+ }
+ }, Order.DESCENDING)
+ .mapPartition(new OrderCheckMapper<>(new Tuple3Checker()))
+ .distinct().collect();
+
+ String expected = "(true)\n";
+
+ compareResultAsText(result, expected);
+ }
+
public interface OrderChecker<T> extends Serializable {
boolean inOrder(T t1, T t2);
}
@@ -210,6 +263,14 @@ public class SortPartitionITCase extends MultipleProgramsTestBase {
}
@SuppressWarnings("serial")
+ public static class Tuple3AscendingChecker implements OrderChecker<Tuple3<Integer, Long, String>> {
+ @Override
+ public boolean inOrder(Tuple3<Integer, Long, String> t1, Tuple3<Integer, Long, String> t2) {
+ return t1.f1 <= t2.f1;
+ }
+ }
+
+ @SuppressWarnings("serial")
public static class Tuple5Checker implements OrderChecker<Tuple5<Integer, Long, Integer, String, Long>> {
@Override
public boolean inOrder(Tuple5<Integer, Long, Integer, String, Long> t1,
http://git-wip-us.apache.org/repos/asf/flink/blob/0a63797a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/SortPartitionITCase.scala
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/SortPartitionITCase.scala b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/SortPartitionITCase.scala
index 3f67063..cda8f4f 100644
--- a/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/SortPartitionITCase.scala
+++ b/flink-tests/src/test/scala/org/apache/flink/api/scala/operators/SortPartitionITCase.scala
@@ -24,6 +24,7 @@ import scala.collection.JavaConverters._
import org.apache.flink.api.common.functions.MapPartitionFunction
import org.apache.flink.api.common.operators.Order
+import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.api.scala._
import org.apache.flink.api.scala.util.CollectionDataSets
import org.apache.flink.test.util.MultipleProgramsTestBase.TestExecutionMode
@@ -166,6 +167,58 @@ class SortPartitionITCase(mode: TestExecutionMode) extends MultipleProgramsTestB
TestBaseUtils.compareResultAsText(result.asJava, expected)
}
+ @Test
+ def testSortPartitionWithKeySelector1(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ env.setParallelism(4)
+ val ds = CollectionDataSets.get3TupleDataSet(env)
+
+ val result = ds
+ .map { x => x }.setParallelism(4)
+ .sortPartition(_._2, Order.ASCENDING)
+ .mapPartition(new OrderCheckMapper(new Tuple3AscendingChecker))
+ .distinct()
+ .collect()
+
+ val expected: String = "(true)\n"
+ TestBaseUtils.compareResultAsText(result.asJava, expected)
+ }
+
+ @Test
+ def testSortPartitionWithKeySelector2(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ env.setParallelism(4)
+ val ds = CollectionDataSets.get3TupleDataSet(env)
+
+ val result = ds
+ .map { x => x }.setParallelism(4)
+ .sortPartition(x => (x._2, x._1), Order.DESCENDING)
+ .mapPartition(new OrderCheckMapper(new Tuple3Checker))
+ .distinct()
+ .collect()
+
+ val expected: String = "(true)\n"
+ TestBaseUtils.compareResultAsText(result.asJava, expected)
+ }
+
+ @Test(expected = classOf[InvalidProgramException])
+ def testSortPartitionWithKeySelector3(): Unit = {
+ val env = ExecutionEnvironment.getExecutionEnvironment
+ env.setParallelism(4)
+ val ds = CollectionDataSets.get3TupleDataSet(env)
+
+ val result = ds
+ .map { x => x }.setParallelism(4)
+ .sortPartition(x => (x._2, x._1), Order.DESCENDING)
+ .sortPartition(0, Order.DESCENDING)
+ .mapPartition(new OrderCheckMapper(new Tuple3Checker))
+ .distinct()
+ .collect()
+
+ val expected: String = "(true)\n"
+ TestBaseUtils.compareResultAsText(result.asJava, expected)
+ }
+
}
trait OrderChecker[T] extends Serializable {
@@ -178,6 +231,12 @@ class Tuple3Checker extends OrderChecker[(Int, Long, String)] {
}
}
+class Tuple3AscendingChecker extends OrderChecker[(Int, Long, String)] {
+ def inOrder(t1: (Int, Long, String), t2: (Int, Long, String)): Boolean = {
+ t1._2 <= t2._2
+ }
+}
+
class Tuple5Checker extends OrderChecker[(Int, Long, Int, String, Long)] {
def inOrder(t1: (Int, Long, Int, String, Long), t2: (Int, Long, Int, String, Long)): Boolean = {
t1._5 < t2._5 || t1._5 == t2._5 && t1._3 >= t2._3