flink-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From aljos...@apache.org
Subject [02/27] flink git commit: [FLINK-4380] Introduce KeyGroupAssigner and Max-Parallelism Parameter
Date Wed, 31 Aug 2016 17:28:20 GMT
http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
index 5a86c5c..17bea68 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/DataStreamTest.java
@@ -52,7 +52,7 @@ import org.apache.flink.streaming.api.windowing.triggers.PurgingTrigger;
 import org.apache.flink.streaming.api.windowing.windows.GlobalWindow;
 import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
 import org.apache.flink.streaming.runtime.partitioner.CustomPartitionerWrapper;
-import org.apache.flink.streaming.runtime.partitioner.HashPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.KeyGroupStreamPartitioner;
 import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
 import org.apache.flink.streaming.runtime.partitioner.GlobalPartitioner;
 import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner;
@@ -672,7 +672,7 @@ public class DataStreamTest {
 		assertNotNull(env.getStreamGraph().getStreamNode(sink2.getTransformation().getId()).getStateKeySerializer());
 		assertNotNull(env.getStreamGraph().getStreamNode(sink2.getTransformation().getId()).getStateKeySerializer());
 		assertEquals(key1, env.getStreamGraph().getStreamNode(sink2.getTransformation().getId()).getStatePartitioner1());
-		assertTrue(env.getStreamGraph().getStreamNode(sink2.getTransformation().getId()).getInEdges().get(0).getPartitioner() instanceof HashPartitioner);
+		assertTrue(env.getStreamGraph().getStreamNode(sink2.getTransformation().getId()).getInEdges().get(0).getPartitioner() instanceof KeyGroupStreamPartitioner);
 
 		KeySelector<Long, Long> key2 = new KeySelector<Long, Long>() {
 
@@ -688,7 +688,7 @@ public class DataStreamTest {
 
 		assertTrue(env.getStreamGraph().getStreamNode(sink3.getTransformation().getId()).getStatePartitioner1() != null);
 		assertEquals(key2, env.getStreamGraph().getStreamNode(sink3.getTransformation().getId()).getStatePartitioner1());
-		assertTrue(env.getStreamGraph().getStreamNode(sink3.getTransformation().getId()).getInEdges().get(0).getPartitioner() instanceof HashPartitioner);
+		assertTrue(env.getStreamGraph().getStreamNode(sink3.getTransformation().getId()).getInEdges().get(0).getPartitioner() instanceof KeyGroupStreamPartitioner);
 	}
 
 	@Test
@@ -783,7 +783,7 @@ public class DataStreamTest {
 	private static boolean isPartitioned(List<StreamEdge> edges) {
 		boolean result = true;
 		for (StreamEdge edge: edges) {
-			if (!(edge.getPartitioner() instanceof HashPartitioner)) {
+			if (!(edge.getPartitioner() instanceof KeyGroupStreamPartitioner)) {
 				result = false;
 			}
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/RestartStrategyTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/RestartStrategyTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/RestartStrategyTest.java
index c57bea7..d6fcd61 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/RestartStrategyTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/RestartStrategyTest.java
@@ -22,11 +22,11 @@ import org.apache.flink.api.common.restartstrategy.RestartStrategies;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.graph.StreamGraph;
-
+import org.apache.flink.util.TestLogger;
 import org.junit.Assert;
 import org.junit.Test;
 
-public class RestartStrategyTest {
+public class RestartStrategyTest extends TestLogger {
 
 	/**
 	 * Tests that in a streaming use case where checkpointing is enabled, a

http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/SlotAllocationTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/SlotAllocationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/SlotAllocationTest.java
index bab43fa..d873771 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/SlotAllocationTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/SlotAllocationTest.java
@@ -28,6 +28,7 @@ import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 
 import org.apache.flink.streaming.api.functions.co.CoMapFunction;
+import org.apache.flink.util.TestLogger;
 import org.junit.Test;
 
 /**
@@ -37,7 +38,7 @@ import org.junit.Test;
  * resource groups/slot sharing groups.
  */
 @SuppressWarnings("serial")
-public class SlotAllocationTest {
+public class SlotAllocationTest extends TestLogger {
 	
 	@Test
 	public void testTwoPipelines() {

http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java
index a4ee18e..06d381f 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamGraphGeneratorTest.java
@@ -19,13 +19,17 @@
 package org.apache.flink.streaming.api.graph;
 
 import org.apache.flink.api.common.ExecutionConfig;
+import org.apache.flink.api.common.state.KeyGroupAssigner;
 import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.runtime.state.HashKeyGroupAssigner;
 import org.apache.flink.streaming.api.datastream.ConnectedStreams;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.datastream.SingleOutputStreamOperator;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
+import org.apache.flink.streaming.api.functions.co.CoMapFunction;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.operators.Output;
@@ -34,8 +38,10 @@ import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.partitioner.BroadcastPartitioner;
 import org.apache.flink.streaming.runtime.partitioner.GlobalPartitioner;
+import org.apache.flink.streaming.runtime.partitioner.KeyGroupStreamPartitioner;
 import org.apache.flink.streaming.runtime.partitioner.RebalancePartitioner;
 import org.apache.flink.streaming.runtime.partitioner.ShufflePartitioner;
+import org.apache.flink.streaming.runtime.partitioner.StreamPartitioner;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.streaming.runtime.tasks.StreamTask;
 import org.apache.flink.streaming.util.EvenOddOutputSelector;
@@ -236,6 +242,207 @@ public class StreamGraphGeneratorTest {
 		assertEquals(BasicTypeInfo.INT_TYPE_INFO, outputTypeConfigurableOperation.getTypeInformation());
 	}
 
+	/**
+	 * Tests that the KeyGroupStreamPartitioner are properly set up with the correct value of
+	 * maximum parallelism.
+	 */
+	@Test
+	public void testSetupOfKeyGroupPartitioner() {
+		int maxParallelism = 42;
+		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		env.getConfig().setMaxParallelism(maxParallelism);
+
+		DataStream<Integer> source = env.fromElements(1, 2, 3);
+
+		DataStream<Integer> keyedResult = source.keyBy(new KeySelector<Integer, Integer>() {
+			private static final long serialVersionUID = 9205556348021992189L;
+
+			@Override
+			public Integer getKey(Integer value) throws Exception {
+				return value;
+			}
+		}).map(new NoOpIntMap());
+
+		keyedResult.addSink(new DiscardingSink<Integer>());
+
+		StreamGraph graph = env.getStreamGraph();
+
+		StreamNode keyedResultNode = graph.getStreamNode(keyedResult.getId());
+
+		StreamPartitioner<?> streamPartitioner = keyedResultNode.getInEdges().get(0).getPartitioner();
+
+		HashKeyGroupAssigner<?> hashKeyGroupAssigner = extractHashKeyGroupAssigner(streamPartitioner);
+
+		assertEquals(maxParallelism, hashKeyGroupAssigner.getNumberKeyGroups());
+	}
+
+	/**
+	 * Tests that the global and operator-wide max parallelism setting is respected
+	 */
+	@Test
+	public void testMaxParallelismForwarding() {
+		int globalMaxParallelism = 42;
+		int keyedResult2MaxParallelism = 17;
+
+		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		env.getConfig().setMaxParallelism(globalMaxParallelism);
+
+		DataStream<Integer> source = env.fromElements(1, 2, 3);
+
+		DataStream<Integer> keyedResult1 = source.keyBy(new KeySelector<Integer, Integer>() {
+			private static final long serialVersionUID = 9205556348021992189L;
+
+			@Override
+			public Integer getKey(Integer value) throws Exception {
+				return value;
+			}
+		}).map(new NoOpIntMap());
+
+		DataStream<Integer> keyedResult2 = keyedResult1.keyBy(new KeySelector<Integer, Integer>() {
+			private static final long serialVersionUID = 1250168178707154838L;
+
+			@Override
+			public Integer getKey(Integer value) throws Exception {
+				return value;
+			}
+		}).map(new NoOpIntMap()).setMaxParallelism(keyedResult2MaxParallelism);
+
+		keyedResult2.addSink(new DiscardingSink<Integer>());
+
+		StreamGraph graph = env.getStreamGraph();
+
+		StreamNode keyedResult1Node = graph.getStreamNode(keyedResult1.getId());
+		StreamNode keyedResult2Node = graph.getStreamNode(keyedResult2.getId());
+
+		assertEquals(globalMaxParallelism, keyedResult1Node.getMaxParallelism());
+		assertEquals(keyedResult2MaxParallelism, keyedResult2Node.getMaxParallelism());
+	}
+
+	/**
+	 * Tests that the max parallelism is automatically set to the parallelism if it has not been
+	 * specified.
+	 */
+	@Test
+	public void testAutoMaxParallelism() {
+		int globalParallelism = 42;
+		int mapParallelism = 17;
+		int maxParallelism = 21;
+		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		env.setParallelism(globalParallelism);
+
+		DataStream<Integer> source = env.fromElements(1, 2, 3);
+
+		DataStream<Integer> keyedResult1 = source.keyBy(new KeySelector<Integer, Integer>() {
+			private static final long serialVersionUID = 9205556348021992189L;
+
+			@Override
+			public Integer getKey(Integer value) throws Exception {
+				return value;
+			}
+		}).map(new NoOpIntMap());
+
+		DataStream<Integer> keyedResult2 = keyedResult1.keyBy(new KeySelector<Integer, Integer>() {
+			private static final long serialVersionUID = 1250168178707154838L;
+
+			@Override
+			public Integer getKey(Integer value) throws Exception {
+				return value;
+			}
+		}).map(new NoOpIntMap()).setParallelism(mapParallelism);
+
+		DataStream<Integer> keyedResult3 = keyedResult2.keyBy(new KeySelector<Integer, Integer>() {
+			private static final long serialVersionUID = 1250168178707154838L;
+
+			@Override
+			public Integer getKey(Integer value) throws Exception {
+				return value;
+			}
+		}).map(new NoOpIntMap()).setMaxParallelism(maxParallelism);
+
+		DataStream<Integer> keyedResult4 = keyedResult3.keyBy(new KeySelector<Integer, Integer>() {
+			private static final long serialVersionUID = 1250168178707154838L;
+
+			@Override
+			public Integer getKey(Integer value) throws Exception {
+				return value;
+			}
+		}).map(new NoOpIntMap()).setMaxParallelism(maxParallelism).setParallelism(mapParallelism);
+
+		keyedResult4.addSink(new DiscardingSink<Integer>());
+
+		StreamGraph graph = env.getStreamGraph();
+
+		StreamNode keyedResult1Node = graph.getStreamNode(keyedResult1.getId());
+		StreamNode keyedResult2Node = graph.getStreamNode(keyedResult2.getId());
+		StreamNode keyedResult3Node = graph.getStreamNode(keyedResult3.getId());
+		StreamNode keyedResult4Node = graph.getStreamNode(keyedResult4.getId());
+
+		assertEquals(globalParallelism, keyedResult1Node.getMaxParallelism());
+		assertEquals(mapParallelism, keyedResult2Node.getMaxParallelism());
+		assertEquals(maxParallelism, keyedResult3Node.getMaxParallelism());
+		assertEquals(maxParallelism, keyedResult4Node.getMaxParallelism());
+	}
+
+	/**
+	 * Tests that the max parallelism and the key group partitioner is properly set for connected
+	 * streams.
+	 */
+	@Test
+	public void testMaxParallelismWithConnectedKeyedStream() {
+		int maxParallelism = 42;
+
+		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		DataStream<Integer> input1 = env.fromElements(1, 2, 3, 4).setMaxParallelism(128);
+		DataStream<Integer> input2 = env.fromElements(1, 2, 3, 4).setMaxParallelism(129);
+
+		env.getConfig().setMaxParallelism(maxParallelism);
+
+		DataStream<Integer> keyedResult = input1.connect(input2).keyBy(
+			 new KeySelector<Integer, Integer>() {
+				 private static final long serialVersionUID = -6908614081449363419L;
+
+				 @Override
+				 public Integer getKey(Integer value) throws Exception {
+					 return value;
+				 }
+			},
+			new KeySelector<Integer, Integer>() {
+				private static final long serialVersionUID = 3195683453223164931L;
+
+				@Override
+				public Integer getKey(Integer value) throws Exception {
+					return value;
+				}
+			}).map(new NoOpIntCoMap());
+
+		keyedResult.addSink(new DiscardingSink<Integer>());
+
+		StreamGraph graph = env.getStreamGraph();
+
+		StreamNode keyedResultNode = graph.getStreamNode(keyedResult.getId());
+
+		StreamPartitioner<?> streamPartitioner1 = keyedResultNode.getInEdges().get(0).getPartitioner();
+		StreamPartitioner<?> streamPartitioner2 = keyedResultNode.getInEdges().get(1).getPartitioner();
+
+		HashKeyGroupAssigner<?> hashKeyGroupAssigner1 = extractHashKeyGroupAssigner(streamPartitioner1);
+		assertEquals(maxParallelism, hashKeyGroupAssigner1.getNumberKeyGroups());
+
+		HashKeyGroupAssigner<?> hashKeyGroupAssigner2 = extractHashKeyGroupAssigner(streamPartitioner2);
+		assertEquals(maxParallelism, hashKeyGroupAssigner2.getNumberKeyGroups());
+	}
+
+	private HashKeyGroupAssigner<?> extractHashKeyGroupAssigner(StreamPartitioner<?> streamPartitioner) {
+		assertTrue(streamPartitioner instanceof KeyGroupStreamPartitioner);
+
+		KeyGroupStreamPartitioner<?, ?> keyGroupStreamPartitioner = (KeyGroupStreamPartitioner<?, ?>) streamPartitioner;
+
+		KeyGroupAssigner<?> keyGroupAssigner = keyGroupStreamPartitioner.getKeyGroupAssigner();
+
+		assertTrue(keyGroupAssigner instanceof HashKeyGroupAssigner);
+
+		return (HashKeyGroupAssigner<?>) keyGroupAssigner;
+	}
+
 	private static class OutputTypeConfigurableOperationWithTwoInputs
 			extends AbstractStreamOperator<Integer>
 			implements TwoInputStreamOperator<Integer, Integer, Integer>, OutputTypeConfigurable<Integer> {
@@ -297,4 +504,17 @@ public class StreamGraphGeneratorTest {
 		}
 	}
 
+	static class NoOpIntCoMap implements CoMapFunction<Integer, Integer, Integer> {
+		private static final long serialVersionUID = 1886595528149124270L;
+
+		public Integer map1(Integer value) throws Exception {
+			return value;
+		}
+
+		public Integer map2(Integer value) throws Exception {
+			return value;
+		}
+
+	};
+
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
index 7f94aa0..277fab4 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/api/graph/StreamingJobGraphGeneratorTest.java
@@ -18,24 +18,33 @@
 package org.apache.flink.streaming.api.graph;
 
 import java.io.IOException;
+import java.util.List;
 import java.util.Random;
 
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.state.KeyGroupAssigner;
+import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.jobgraph.JobVertex;
+import org.apache.flink.runtime.state.HashKeyGroupAssigner;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
 import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+import org.apache.flink.streaming.util.NoOpIntMap;
 import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.SerializedValue;
 
+import org.apache.flink.util.TestLogger;
 import org.junit.Test;
 
 import static org.junit.Assert.*;
 
 @SuppressWarnings("serial")
-public class StreamingJobGraphGeneratorTest {
+public class StreamingJobGraphGeneratorTest extends TestLogger {
 	
 	@Test
 	public void testExecutionConfigSerialization() throws IOException, ClassNotFoundException {
@@ -114,6 +123,8 @@ public class StreamingJobGraphGeneratorTest {
 		DataStream<Tuple2<String, String>> input = env
 				.fromElements("a", "b", "c", "d", "e", "f")
 				.map(new MapFunction<String, Tuple2<String, String>>() {
+					private static final long serialVersionUID = 471891682418382583L;
+
 					@Override
 					public Tuple2<String, String> map(String value) {
 						return new Tuple2<>(value, value);
@@ -124,6 +135,8 @@ public class StreamingJobGraphGeneratorTest {
 				.keyBy(0)
 				.map(new MapFunction<Tuple2<String, String>, Tuple2<String, String>>() {
 
+					private static final long serialVersionUID = 3583760206245136188L;
+
 					@Override
 					public Tuple2<String, String> map(Tuple2<String, String> value) {
 						return value;
@@ -131,6 +144,8 @@ public class StreamingJobGraphGeneratorTest {
 				});
 
 		result.addSink(new SinkFunction<Tuple2<String, String>>() {
+			private static final long serialVersionUID = -5614849094269539342L;
+
 			@Override
 			public void invoke(Tuple2<String, String> value) {}
 		});
@@ -145,4 +160,203 @@ public class StreamingJobGraphGeneratorTest {
 		assertEquals(1, jobGraph.getVerticesAsArray()[0].getParallelism());
 		assertEquals(1, jobGraph.getVerticesAsArray()[1].getParallelism());
 	}
+
+	/**
+	 * Tests that the KeyGroupAssigner is properly set in the {@link StreamConfig} if the max
+	 * parallelism is set for the whole job.
+	 */
+	@Test
+	public void testKeyGroupAssignerProperlySet() {
+		int maxParallelism = 42;
+
+		final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		env.getConfig().setMaxParallelism(maxParallelism);
+
+		DataStream<Integer> input = env.fromElements(1, 2, 3);
+
+		DataStream<Integer> keyedResult = input.keyBy(new KeySelector<Integer, Integer>() {
+			private static final long serialVersionUID = 350461576474507944L;
+
+			@Override
+			public Integer getKey(Integer value) throws Exception {
+				return value;
+			}
+		}).map(new NoOpIntMap());
+
+		keyedResult.addSink(new DiscardingSink<Integer>());
+
+		JobGraph jobGraph = env.getStreamGraph().getJobGraph();
+
+		List<JobVertex> jobVertices = jobGraph.getVerticesSortedTopologicallyFromSources();
+
+		assertEquals(maxParallelism, jobVertices.get(1).getMaxParallelism());
+
+		HashKeyGroupAssigner<Integer> hashKeyGroupAssigner = extractHashKeyGroupAssigner(jobVertices.get(1));
+
+		assertEquals(maxParallelism, hashKeyGroupAssigner.getNumberKeyGroups());
+	}
+
+	/**
+	 * Tests that the key group assigner for the keyed streams in the stream config is properly
+	 * initialized with the max parallelism value if there is no max parallelism defined for the
+	 * whole job.
+	 */
+	@Test
+	public void testKeyGroupAssignerProperlySetAutoMaxParallelism() {
+		int globalParallelism = 42;
+		int mapParallelism = 17;
+		int maxParallelism = 43;
+		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		env.setParallelism(globalParallelism);
+
+		DataStream<Integer> source = env.fromElements(1, 2, 3);
+
+		DataStream<Integer> keyedResult1 = source.keyBy(new KeySelector<Integer, Integer>() {
+			private static final long serialVersionUID = 9205556348021992189L;
+
+			@Override
+			public Integer getKey(Integer value) throws Exception {
+				return value;
+			}
+		}).map(new NoOpIntMap());
+
+		DataStream<Integer> keyedResult2 = keyedResult1.keyBy(new KeySelector<Integer, Integer>() {
+			private static final long serialVersionUID = 1250168178707154838L;
+
+			@Override
+			public Integer getKey(Integer value) throws Exception {
+				return value;
+			}
+		}).map(new NoOpIntMap()).setParallelism(mapParallelism);
+
+		DataStream<Integer> keyedResult3 = keyedResult2.keyBy(new KeySelector<Integer, Integer>() {
+			private static final long serialVersionUID = 1250168178707154838L;
+
+			@Override
+			public Integer getKey(Integer value) throws Exception {
+				return value;
+			}
+		}).map(new NoOpIntMap()).setMaxParallelism(maxParallelism);
+
+		DataStream<Integer> keyedResult4 = keyedResult3.keyBy(new KeySelector<Integer, Integer>() {
+			private static final long serialVersionUID = 1250168178707154838L;
+
+			@Override
+			public Integer getKey(Integer value) throws Exception {
+				return value;
+			}
+		}).map(new NoOpIntMap()).setMaxParallelism(maxParallelism).setParallelism(mapParallelism);
+
+		keyedResult4.addSink(new DiscardingSink<Integer>());
+
+		JobGraph jobGraph = env.getStreamGraph().getJobGraph();
+
+		List<JobVertex> vertices = jobGraph.getVerticesSortedTopologicallyFromSources();
+
+		JobVertex keyedResultJV1 = vertices.get(1);
+		JobVertex keyedResultJV2 = vertices.get(2);
+		JobVertex keyedResultJV3 = vertices.get(3);
+		JobVertex keyedResultJV4 = vertices.get(4);
+
+		HashKeyGroupAssigner<Integer> hashKeyGroupAssigner1 = extractHashKeyGroupAssigner(keyedResultJV1);
+		HashKeyGroupAssigner<Integer> hashKeyGroupAssigner2 = extractHashKeyGroupAssigner(keyedResultJV2);
+		HashKeyGroupAssigner<Integer> hashKeyGroupAssigner3 = extractHashKeyGroupAssigner(keyedResultJV3);
+		HashKeyGroupAssigner<Integer> hashKeyGroupAssigner4 = extractHashKeyGroupAssigner(keyedResultJV4);
+
+		assertEquals(globalParallelism, hashKeyGroupAssigner1.getNumberKeyGroups());
+		assertEquals(mapParallelism, hashKeyGroupAssigner2.getNumberKeyGroups());
+		assertEquals(maxParallelism, hashKeyGroupAssigner3.getNumberKeyGroups());
+		assertEquals(maxParallelism, hashKeyGroupAssigner4.getNumberKeyGroups());
+	}
+
+	/**
+	 * Tests that the {@link KeyGroupAssigner} is properly set in the {@link StreamConfig} for
+	 * connected streams.
+	 */
+	@Test
+	public void testMaxParallelismWithConnectedKeyedStream() {
+		int maxParallelism = 42;
+
+		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		DataStream<Integer> input1 = env.fromElements(1, 2, 3, 4).setMaxParallelism(128).name("input1");
+		DataStream<Integer> input2 = env.fromElements(1, 2, 3, 4).setMaxParallelism(129).name("input2");
+
+		env.getConfig().setMaxParallelism(maxParallelism);
+
+		DataStream<Integer> keyedResult = input1.connect(input2).keyBy(
+			new KeySelector<Integer, Integer>() {
+				private static final long serialVersionUID = -6908614081449363419L;
+
+				@Override
+				public Integer getKey(Integer value) throws Exception {
+					return value;
+				}
+			},
+			new KeySelector<Integer, Integer>() {
+				private static final long serialVersionUID = 3195683453223164931L;
+
+				@Override
+				public Integer getKey(Integer value) throws Exception {
+					return value;
+				}
+			}).map(new StreamGraphGeneratorTest.NoOpIntCoMap());
+
+		keyedResult.addSink(new DiscardingSink<Integer>());
+
+		JobGraph jobGraph = env.getStreamGraph().getJobGraph();
+
+		List<JobVertex> jobVertices = jobGraph.getVerticesSortedTopologicallyFromSources();
+
+		JobVertex input1JV = jobVertices.get(0);
+		JobVertex input2JV = jobVertices.get(1);
+		JobVertex connectedJV = jobVertices.get(2);
+
+		// disambiguate the partial order of the inputs
+		if (input1JV.getName().equals("Source: input1")) {
+			assertEquals(128, input1JV.getMaxParallelism());
+			assertEquals(129, input2JV.getMaxParallelism());
+		} else {
+			assertEquals(128, input2JV.getMaxParallelism());
+			assertEquals(129, input1JV.getMaxParallelism());
+		}
+
+		assertEquals(maxParallelism, connectedJV.getMaxParallelism());
+
+		HashKeyGroupAssigner<Integer> hashKeyGroupAssigner = extractHashKeyGroupAssigner(connectedJV);
+
+		assertEquals(maxParallelism, hashKeyGroupAssigner.getNumberKeyGroups());
+	}
+
+	/**
+	 * Tests that the {@link JobGraph} creation fails if the parallelism is greater than the max
+	 * parallelism.
+	 */
+	@Test(expected=IllegalStateException.class)
+	public void testFailureOfJobJobCreationIfParallelismGreaterThanMaxParallelism() {
+		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		env.getConfig().setMaxParallelism(42);
+
+		DataStream<Integer> input = env.fromElements(1, 2, 3, 4);
+
+		DataStream<Integer> result = input.map(new NoOpIntMap()).setParallelism(43);
+
+		result.addSink(new DiscardingSink<Integer>());
+
+		env.getStreamGraph().getJobGraph();
+
+		fail("The JobGraph should not have been created because the parallelism is greater than " +
+			"the max parallelism.");
+	}
+
+	private HashKeyGroupAssigner<Integer> extractHashKeyGroupAssigner(JobVertex jobVertex) {
+		Configuration config = jobVertex.getConfiguration();
+
+		StreamConfig streamConfig = new StreamConfig(config);
+
+		KeyGroupAssigner<Integer> keyGroupAssigner = streamConfig.getKeyGroupAssigner(getClass().getClassLoader());
+
+		assertTrue(keyGroupAssigner instanceof HashKeyGroupAssigner);
+
+		return (HashKeyGroupAssigner<Integer>) keyGroupAssigner;
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/StreamingJobGraphGeneratorNodeHashTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/StreamingJobGraphGeneratorNodeHashTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/StreamingJobGraphGeneratorNodeHashTest.java
index bcf621a..340981b 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/StreamingJobGraphGeneratorNodeHashTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/graph/StreamingJobGraphGeneratorNodeHashTest.java
@@ -27,12 +27,11 @@ import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
-import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
 import org.apache.flink.streaming.api.functions.sink.SinkFunction;
 import org.apache.flink.streaming.api.functions.source.ParallelSourceFunction;
 import org.apache.flink.streaming.api.graph.StreamGraph;
 import org.apache.flink.streaming.api.graph.StreamNode;
-
+import org.apache.flink.util.TestLogger;
 import org.junit.Test;
 
 import java.util.HashMap;
@@ -52,7 +51,7 @@ import static org.junit.Assert.assertTrue;
  * {@link JobGraph} instances.
  */
 @SuppressWarnings("serial")
-public class StreamingJobGraphGeneratorNodeHashTest {
+public class StreamingJobGraphGeneratorNodeHashTest extends TestLogger {
 
 	// ------------------------------------------------------------------------
 	// Deterministic hash assignment
@@ -126,53 +125,6 @@ public class StreamingJobGraphGeneratorNodeHashTest {
 	}
 
 	/**
-	 * Verifies that parallelism affects the node hash.
-	 */
-	@Test
-	public void testNodeHashParallelism() throws Exception {
-		StreamExecutionEnvironment env = StreamExecutionEnvironment.createLocalEnvironment();
-		env.disableOperatorChaining();
-
-		env.addSource(new NoOpSourceFunction(), "src").setParallelism(4)
-				.addSink(new DiscardingSink<String>()).name("sink").setParallelism(4);
-
-		JobGraph jobGraph = env.getStreamGraph().getJobGraph();
-
-		Map<JobVertexID, String> ids = rememberIds(jobGraph);
-
-		// Change parallelism of source
-		env = StreamExecutionEnvironment.createLocalEnvironment();
-		env.disableOperatorChaining();
-
-		env.addSource(new NoOpSourceFunction(), "src").setParallelism(8)
-				.addSink(new DiscardingSink<String>()).name("sink").setParallelism(4);
-
-		jobGraph = env.getStreamGraph().getJobGraph();
-
-		verifyIdsNotEqual(jobGraph, ids);
-
-		// Change parallelism of sink
-		env = StreamExecutionEnvironment.createLocalEnvironment();
-		env.disableOperatorChaining();
-
-		env.addSource(new NoOpSourceFunction(), "src").setParallelism(4)
-				.addSink(new DiscardingSink<String>()).name("sink").setParallelism(8);
-
-		jobGraph = env.getStreamGraph().getJobGraph();
-
-		// The source hash will should be the same
-		JobVertex[] vertices = jobGraph.getVerticesAsArray();
-		if (vertices[0].isInputVertex()) {
-			assertTrue(ids.containsKey(vertices[0].getID()));
-			assertFalse(ids.containsKey(vertices[1].getID()));
-		}
-		else {
-			assertTrue(ids.containsKey(vertices[1].getID()));
-			assertFalse(ids.containsKey(vertices[0].getID()));
-		}
-	}
-
-	/**
 	 * Tests that there are no collisions with two identical sources.
 	 *
 	 * <pre>
@@ -516,6 +468,8 @@ public class StreamingJobGraphGeneratorNodeHashTest {
 
 	private static class NoOpSourceFunction implements ParallelSourceFunction<String> {
 
+		private static final long serialVersionUID = -5459224792698512636L;
+
 		@Override
 		public void run(SourceContext<String> ctx) throws Exception {
 		}
@@ -527,6 +481,8 @@ public class StreamingJobGraphGeneratorNodeHashTest {
 
 	private static class NoOpSinkFunction implements SinkFunction<String> {
 
+		private static final long serialVersionUID = -5654199886203297279L;
+
 		@Override
 		public void invoke(String value) throws Exception {
 		}
@@ -534,6 +490,8 @@ public class StreamingJobGraphGeneratorNodeHashTest {
 
 	private static class NoOpMapFunction implements MapFunction<String, String> {
 
+		private static final long serialVersionUID = 6584823409744624276L;
+
 		@Override
 		public String map(String value) throws Exception {
 			return value;
@@ -542,6 +500,8 @@ public class StreamingJobGraphGeneratorNodeHashTest {
 
 	private static class NoOpFilterFunction implements FilterFunction<String> {
 
+		private static final long serialVersionUID = 500005424900187476L;
+
 		@Override
 		public boolean filter(String value) throws Exception {
 			return true;
@@ -550,6 +510,8 @@ public class StreamingJobGraphGeneratorNodeHashTest {
 
 	private static class NoOpKeySelector implements KeySelector<String, String> {
 
+		private static final long serialVersionUID = -96127515593422991L;
+
 		@Override
 		public String getKey(String value) throws Exception {
 			return value;
@@ -557,6 +519,8 @@ public class StreamingJobGraphGeneratorNodeHashTest {
 	}
 
 	private static class NoOpReduceFunction implements ReduceFunction<String> {
+		private static final long serialVersionUID = -8775747640749256372L;
+
 		@Override
 		public String reduce(String value1, String value2) throws Exception {
 			return value1;

http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AllWindowTranslationTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AllWindowTranslationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AllWindowTranslationTest.java
index ebe6bea..7ac9e13 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AllWindowTranslationTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/AllWindowTranslationTest.java
@@ -253,6 +253,8 @@ public class AllWindowTranslationTest {
 
 		try {
 			windowedStream.fold("", new FoldFunction<String, String>() {
+				private static final long serialVersionUID = -8722899157560218917L;
+
 				@Override
 				public String fold(String accumulator, String value) throws Exception {
 					return accumulator;
@@ -278,6 +280,8 @@ public class AllWindowTranslationTest {
 
 		try {
 			windowedStream.trigger(new Trigger<String, TimeWindow>() {
+				private static final long serialVersionUID = 8360971631424870421L;
+
 				@Override
 				public TriggerResult onElement(String element,
 						long timestamp,

http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowTranslationTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowTranslationTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowTranslationTest.java
index 39d89cf..2707108 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowTranslationTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/operators/windowing/WindowTranslationTest.java
@@ -76,6 +76,8 @@ public class WindowTranslationTest {
 			.keyBy(0)
 			.window(SlidingEventTimeWindows.of(Time.of(1, TimeUnit.SECONDS), Time.of(100, TimeUnit.MILLISECONDS)))
 			.reduce(new RichReduceFunction<Tuple2<String, Integer>>() {
+				private static final long serialVersionUID = -6448847205314995812L;
+
 				@Override
 				public Tuple2<String, Integer> reduce(Tuple2<String, Integer> value1,
 					Tuple2<String, Integer> value2) throws Exception {
@@ -242,6 +244,8 @@ public class WindowTranslationTest {
 
 		WindowedStream<String, String, TimeWindow> windowedStream = env.fromElements("Hello", "Ciao")
 				.keyBy(new KeySelector<String, String>() {
+					private static final long serialVersionUID = -3298887124448443076L;
+
 					@Override
 					public String getKey(String value) throws Exception {
 						return value;
@@ -251,6 +255,8 @@ public class WindowTranslationTest {
 
 		try {
 			windowedStream.fold("", new FoldFunction<String, String>() {
+				private static final long serialVersionUID = -4567902917104921706L;
+
 				@Override
 				public String fold(String accumulator, String value) throws Exception {
 					return accumulator;
@@ -273,6 +279,8 @@ public class WindowTranslationTest {
 
 		WindowedStream<String, String, TimeWindow> windowedStream = env.fromElements("Hello", "Ciao")
 				.keyBy(new KeySelector<String, String>() {
+					private static final long serialVersionUID = 598309916882894293L;
+
 					@Override
 					public String getKey(String value) throws Exception {
 						return value;
@@ -282,6 +290,8 @@ public class WindowTranslationTest {
 
 		try {
 			windowedStream.trigger(new Trigger<String, TimeWindow>() {
+				private static final long serialVersionUID = 6558046711583024443L;
+
 				@Override
 				public TriggerResult onElement(String element,
 						long timestamp,

http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/HashPartitionerTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/HashPartitionerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/HashPartitionerTest.java
deleted file mode 100644
index 6dbf932..0000000
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/HashPartitionerTest.java
+++ /dev/null
@@ -1,71 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.streaming.runtime.partitioner;
-
-import static org.junit.Assert.assertArrayEquals;
-import static org.junit.Assert.assertEquals;
-
-import org.apache.flink.api.java.functions.KeySelector;
-import org.apache.flink.api.java.tuple.Tuple2;
-import org.apache.flink.runtime.plugable.SerializationDelegate;
-import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
-import org.junit.Before;
-import org.junit.Test;
-
-public class HashPartitionerTest {
-
-	private HashPartitioner<Tuple2<String, Integer>> hashPartitioner;
-	private StreamRecord<Tuple2<String, Integer>> streamRecord1 = new StreamRecord<Tuple2<String, Integer>>(new Tuple2<String, Integer>("test", 0));
-	private StreamRecord<Tuple2<String, Integer>> streamRecord2 = new StreamRecord<Tuple2<String, Integer>>(new Tuple2<String, Integer>("test", 42));
-	private SerializationDelegate<StreamRecord<Tuple2<String, Integer>>> sd1 = new SerializationDelegate<StreamRecord<Tuple2<String, Integer>>>(null);
-	private SerializationDelegate<StreamRecord<Tuple2<String, Integer>>> sd2 = new SerializationDelegate<StreamRecord<Tuple2<String, Integer>>>(null);
-
-	@Before
-	public void setPartitioner() {
-		hashPartitioner = new HashPartitioner<Tuple2<String, Integer>>(new KeySelector<Tuple2<String, Integer>, String>() {
-
-			private static final long serialVersionUID = 1L;
-
-			@Override
-			public String getKey(Tuple2<String, Integer> value) throws Exception {
-				return value.getField(0);
-			}
-		});
-	}
-
-	@Test
-	public void testSelectChannelsLength() {
-		sd1.setInstance(streamRecord1);
-		assertEquals(1, hashPartitioner.selectChannels(sd1, 1).length);
-		assertEquals(1, hashPartitioner.selectChannels(sd1, 2).length);
-		assertEquals(1, hashPartitioner.selectChannels(sd1, 1024).length);
-	}
-
-	@Test
-	public void testSelectChannelsGrouping() {
-		sd1.setInstance(streamRecord1);
-		sd2.setInstance(streamRecord2);
-
-		assertArrayEquals(hashPartitioner.selectChannels(sd1, 1),
-				hashPartitioner.selectChannels(sd2, 1));
-		assertArrayEquals(hashPartitioner.selectChannels(sd1, 2),
-				hashPartitioner.selectChannels(sd2, 2));
-		assertArrayEquals(hashPartitioner.selectChannels(sd1, 1024),
-				hashPartitioner.selectChannels(sd2, 1024));
-	}
-}

http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/KeyGroupStreamPartitionerTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/KeyGroupStreamPartitionerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/KeyGroupStreamPartitionerTest.java
new file mode 100644
index 0000000..6fbf35e
--- /dev/null
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/KeyGroupStreamPartitionerTest.java
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.runtime.partitioner;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.runtime.plugable.SerializationDelegate;
+import org.apache.flink.runtime.state.HashKeyGroupAssigner;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.util.TestLogger;
+import org.junit.Before;
+import org.junit.Test;
+
+public class KeyGroupStreamPartitionerTest extends TestLogger {
+
+	private KeyGroupStreamPartitioner<Tuple2<String, Integer>, String> keyGroupPartitioner;
+	private StreamRecord<Tuple2<String, Integer>> streamRecord1 = new StreamRecord<Tuple2<String, Integer>>(new Tuple2<String, Integer>("test", 0));
+	private StreamRecord<Tuple2<String, Integer>> streamRecord2 = new StreamRecord<Tuple2<String, Integer>>(new Tuple2<String, Integer>("test", 42));
+	private SerializationDelegate<StreamRecord<Tuple2<String, Integer>>> sd1 = new SerializationDelegate<StreamRecord<Tuple2<String, Integer>>>(null);
+	private SerializationDelegate<StreamRecord<Tuple2<String, Integer>>> sd2 = new SerializationDelegate<StreamRecord<Tuple2<String, Integer>>>(null);
+
+	@Before
+	public void setPartitioner() {
+		keyGroupPartitioner = new KeyGroupStreamPartitioner<Tuple2<String, Integer>, String>(new KeySelector<Tuple2<String, Integer>, String>() {
+
+			private static final long serialVersionUID = 1L;
+
+			@Override
+			public String getKey(Tuple2<String, Integer> value) throws Exception {
+				return value.getField(0);
+			}
+		},
+		new HashKeyGroupAssigner<String>(1024));
+	}
+
+	@Test
+	public void testSelectChannelsLength() {
+		sd1.setInstance(streamRecord1);
+		assertEquals(1, keyGroupPartitioner.selectChannels(sd1, 1).length);
+		assertEquals(1, keyGroupPartitioner.selectChannels(sd1, 2).length);
+		assertEquals(1, keyGroupPartitioner.selectChannels(sd1, 1024).length);
+	}
+
+	@Test
+	public void testSelectChannelsGrouping() {
+		sd1.setInstance(streamRecord1);
+		sd2.setInstance(streamRecord2);
+
+		assertArrayEquals(keyGroupPartitioner.selectChannels(sd1, 1),
+				keyGroupPartitioner.selectChannels(sd2, 1));
+		assertArrayEquals(keyGroupPartitioner.selectChannels(sd1, 2),
+				keyGroupPartitioner.selectChannels(sd2, 2));
+		assertArrayEquals(keyGroupPartitioner.selectChannels(sd1, 1024),
+				keyGroupPartitioner.selectChannels(sd2, 1024));
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/RescalePartitionerTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/RescalePartitionerTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/RescalePartitionerTest.java
index 8c7360a..37ea68a 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/RescalePartitionerTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/partitioner/RescalePartitionerTest.java
@@ -94,6 +94,8 @@ public class RescalePartitionerTest extends TestLogger {
 
 		// get input data
 		DataStream<String> text = env.addSource(new ParallelSourceFunction<String>() {
+			private static final long serialVersionUID = 7772338606389180774L;
+
 			@Override
 			public void run(SourceContext<String> ctx) throws Exception {
 
@@ -108,6 +110,8 @@ public class RescalePartitionerTest extends TestLogger {
 		DataStream<Tuple2<String, Integer>> counts = text
 			.rescale()
 			.flatMap(new FlatMapFunction<String, Tuple2<String, Integer>>() {
+				private static final long serialVersionUID = -5255930322161596829L;
+
 				@Override
 				public void flatMap(String value,
 					Collector<Tuple2<String, Integer>> out) throws Exception {

http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java
index 145edc2..5f73e25 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTestHarness.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.java.ClosureCleaner;
 import org.apache.flink.api.java.functions.KeySelector;
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.io.network.partition.consumer.StreamTestSingleInputGate;
+import org.apache.flink.runtime.state.HashKeyGroupAssigner;
 
 import java.io.IOException;
 
@@ -105,6 +106,7 @@ public class OneInputStreamTaskTestHarness<IN, OUT> extends StreamTaskTestHarnes
 		ClosureCleaner.clean(keySelector, false);
 		streamConfig.setStatePartitioner(0, keySelector);
 		streamConfig.setStateKeySerializer(keyType.createSerializer(executionConfig));
+		streamConfig.setKeyGroupAssigner(new HashKeyGroupAssigner<K>(10));
 	}
 }
 

http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
index bcd8a5f..3d9d50f 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
@@ -19,8 +19,11 @@
 package org.apache.flink.streaming.runtime.tasks;
 
 import akka.actor.ActorRef;
+
+import akka.dispatch.Futures;
 import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.blob.BlobKey;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
@@ -40,9 +43,11 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
+import org.apache.flink.runtime.messages.TaskMessages;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.taskmanager.Task;
 import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
+import org.apache.flink.streaming.api.TimeCharacteristic;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.streaming.api.graph.StreamConfig;
 import org.apache.flink.streaming.api.operators.Output;
@@ -51,17 +56,27 @@ import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
 import org.apache.flink.util.ExceptionUtils;
 import org.apache.flink.util.SerializedValue;
 import org.junit.Test;
+
+import scala.concurrent.Await;
 import scala.concurrent.ExecutionContext;
 import scala.concurrent.Future;
+import scala.concurrent.duration.Deadline;
 import scala.concurrent.duration.FiniteDuration;
+import scala.concurrent.impl.Promise;
 
 import java.io.IOException;
 import java.io.ObjectInputStream;
 import java.net.URL;
 import java.util.Collections;
+import java.util.Comparator;
+import java.util.PriorityQueue;
 import java.util.UUID;
 import java.util.concurrent.TimeUnit;
 
+import static org.mockito.Matchers.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
@@ -72,55 +87,140 @@ import static org.mockito.Mockito.when;
 
 public class StreamTaskTest {
 
-	/**
+		/**
 	 * This test checks that cancel calls that are issued before the operator is
 	 * instantiated still lead to proper canceling.
 	 */
 	@Test
-	public void testEarlyCanceling() {
-		try {
-			StreamConfig cfg = new StreamConfig(new Configuration());
-			cfg.setStreamOperator(new SlowlyDeserializingOperator());
-			
-			Task task = createTask(SourceStreamTask.class, cfg);
-			task.startTaskThread();
-			
-			// wait until the task thread reached state RUNNING 
-			while (task.getExecutionState() == ExecutionState.CREATED ||
-					task.getExecutionState() == ExecutionState.DEPLOYING)
-			{
-				Thread.sleep(5);
-			}
-			
-			// make sure the task is really running
-			if (task.getExecutionState() != ExecutionState.RUNNING) {
-				fail("Task entered state " + task.getExecutionState() + " with error "
-						+ ExceptionUtils.stringifyException(task.getFailureCause()));
-			}
-			
-			// send a cancel. because the operator takes a long time to deserialize, this should
-			// hit the task before the operator is deserialized
-			task.cancelExecution();
-			
-			// the task should reach state canceled eventually
-			assertTrue(task.getExecutionState() == ExecutionState.CANCELING ||
-					task.getExecutionState() == ExecutionState.CANCELED);
-			
-			task.getExecutingThread().join(60000);
-			
-			assertFalse("Task did not cancel", task.getExecutingThread().isAlive());
-			assertEquals(ExecutionState.CANCELED, task.getExecutionState());
-		}
-		catch (Exception e) {
-			e.printStackTrace();
-			fail(e.getMessage());
+	public void testEarlyCanceling() throws Exception {
+		Deadline deadline = new FiniteDuration(2, TimeUnit.MINUTES).fromNow();
+		StreamConfig cfg = new StreamConfig(new Configuration());
+		cfg.setStreamOperator(new SlowlyDeserializingOperator());
+		cfg.setTimeCharacteristic(TimeCharacteristic.ProcessingTime);
+
+		Task task = createTask(SourceStreamTask.class, cfg);
+
+		ExecutionStateListener executionStateListener = new ExecutionStateListener();
+
+		task.registerExecutionListener(executionStateListener);
+		task.startTaskThread();
+
+		Future<ExecutionState> running = executionStateListener.notifyWhenExecutionState(ExecutionState.RUNNING);
+
+		// wait until the task thread reached state RUNNING
+		ExecutionState executionState = Await.result(running, deadline.timeLeft());
+
+		// make sure the task is really running
+		if (executionState != ExecutionState.RUNNING) {
+			fail("Task entered state " + task.getExecutionState() + " with error "
+					+ ExceptionUtils.stringifyException(task.getFailureCause()));
 		}
+
+		// send a cancel. because the operator takes a long time to deserialize, this should
+		// hit the task before the operator is deserialized
+		task.cancelExecution();
+
+		Future<ExecutionState> canceling = executionStateListener.notifyWhenExecutionState(ExecutionState.CANCELING);
+
+		executionState = Await.result(canceling, deadline.timeLeft());
+
+		// the task should reach state canceled eventually
+		assertTrue(executionState == ExecutionState.CANCELING ||
+				executionState == ExecutionState.CANCELED);
+
+		task.getExecutingThread().join(deadline.timeLeft().toMillis());
+
+		assertFalse("Task did not cancel", task.getExecutingThread().isAlive());
+		assertEquals(ExecutionState.CANCELED, task.getExecutionState());
 	}
 
+
 	// ------------------------------------------------------------------------
 	//  Test Utilities
 	// ------------------------------------------------------------------------
 
+	private static class ExecutionStateListener implements ActorGateway {
+
+		private static final long serialVersionUID = 8926442805035692182L;
+
+		ExecutionState executionState = null;
+
+		PriorityQueue<Tuple2<ExecutionState, Promise<ExecutionState>>> priorityQueue = new PriorityQueue<>(
+			1,
+			new Comparator<Tuple2<ExecutionState, Promise<ExecutionState>>>() {
+				@Override
+				public int compare(Tuple2<ExecutionState, Promise<ExecutionState>> o1, Tuple2<ExecutionState, Promise<ExecutionState>> o2) {
+					return o1.f0.ordinal() - o2.f0.ordinal();
+				}
+			});
+
+		public Future<ExecutionState> notifyWhenExecutionState(ExecutionState executionState) {
+			synchronized (priorityQueue) {
+				if (this.executionState != null && this.executionState.ordinal() >= executionState.ordinal()) {
+					return Futures.successful(executionState);
+				} else {
+					Promise<ExecutionState> promise = new Promise.DefaultPromise<ExecutionState>();
+
+					priorityQueue.offer(Tuple2.of(executionState, promise));
+
+					return promise.future();
+				}
+			}
+		}
+
+		@Override
+		public Future<Object> ask(Object message, FiniteDuration timeout) {
+			return null;
+		}
+
+		@Override
+		public void tell(Object message) {
+			this.tell(message, null);
+		}
+
+		@Override
+		public void tell(Object message, ActorGateway sender) {
+			if (message instanceof TaskMessages.UpdateTaskExecutionState) {
+				TaskMessages.UpdateTaskExecutionState updateTaskExecutionState = (TaskMessages.UpdateTaskExecutionState) message;
+
+				synchronized (priorityQueue) {
+					this.executionState = updateTaskExecutionState.taskExecutionState().getExecutionState();
+
+					while (!priorityQueue.isEmpty() && priorityQueue.peek().f0.ordinal() <= this.executionState.ordinal()) {
+						Promise<ExecutionState> promise = priorityQueue.poll().f1;
+
+						promise.success(this.executionState);
+					}
+				}
+			}
+		}
+
+		@Override
+		public void forward(Object message, ActorGateway sender) {
+
+		}
+
+		@Override
+		public Future<Object> retry(Object message, int numberRetries, FiniteDuration timeout, ExecutionContext executionContext) {
+			return null;
+		}
+
+		@Override
+		public String path() {
+			return null;
+		}
+
+		@Override
+		public ActorRef actor() {
+			return null;
+		}
+
+		@Override
+		public UUID leaderSessionID() {
+			return null;
+		}
+	}
+
 	private Task createTask(Class<? extends AbstractInvokable> invokable, StreamConfig taskConfig) throws Exception {
 		LibraryCacheManager libCache = mock(LibraryCacheManager.class);
 		when(libCache.getClassLoader(any(JobID.class))).thenReturn(getClass().getClassLoader());

http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
index 00e95b9..cb10c5c 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTestHarness.java
@@ -186,23 +186,58 @@ public class StreamTaskTestHarness<OUT> {
 		taskThread.start();
 	}
 
+	/**
+	 * Waits for the task completion.
+	 *
+	 * @throws Exception
+	 */
 	public void waitForTaskCompletion() throws Exception {
+		waitForTaskCompletion(Long.MAX_VALUE);
+	}
+
+	/**
+	 * Waits for the task completion. If this does not happen within the timeout, then a
+	 * TimeoutException is thrown.
+	 *
+	 * @param timeout Timeout for the task completion
+	 * @throws Exception
+	 */
+	public void waitForTaskCompletion(long timeout) throws Exception {
 		if (taskThread == null) {
 			throw new IllegalStateException("Task thread was not started.");
 		}
 
-		taskThread.join();
+		taskThread.join(timeout);
 		if (taskThread.getError() != null) {
 			throw new Exception("error in task", taskThread.getError());
 		}
 	}
 
+	/**
+	 * Waits for the task to be running.
+	 *
+	 * @throws Exception
+	 */
 	public void waitForTaskRunning() throws Exception {
+		waitForTaskRunning(Long.MAX_VALUE);
+	}
+
+	/**
+	 * Waits for the task to be running. If this does not happen within the timeout, then a
+	 * TimeoutException is thrown.
+	 *
+	 * @param timeout Timeout for the task to be running.
+	 * @throws Exception
+	 */
+	public void waitForTaskRunning(long timeout) throws Exception {
 		if (taskThread == null) {
 			throw new IllegalStateException("Task thread was not started.");
 		}
 		else {
 			if (taskThread.task instanceof StreamTask) {
+				long base = System.currentTimeMillis();
+				long now = 0;
+
 				StreamTask<?, ?> streamTask = (StreamTask<?, ?>) taskThread.task;
 				while (!streamTask.isRunning()) {
 					Thread.sleep(100);

http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-java/src/test/resources/log4j-test.properties
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/resources/log4j-test.properties b/flink-streaming-java/src/test/resources/log4j-test.properties
index 0b686e5..881dc06 100644
--- a/flink-streaming-java/src/test/resources/log4j-test.properties
+++ b/flink-streaming-java/src/test/resources/log4j-test.properties
@@ -24,4 +24,4 @@ log4j.appender.A1=org.apache.log4j.ConsoleAppender
 
 # A1 uses PatternLayout.
 log4j.appender.A1.layout=org.apache.log4j.PatternLayout
-log4j.appender.A1.layout.ConversionPattern=%-4r [%t] %-5p %c %x - %m%n
\ No newline at end of file
+log4j.appender.A1.layout.ConversionPattern=%-4r [%t] %-5p %c %x - %m%n

http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala
----------------------------------------------------------------------
diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala
index 8693834..4fe73e9 100644
--- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala
+++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/DataStream.scala
@@ -133,6 +133,17 @@ class DataStream[T](stream: JavaStream[T]) {
     this
   }
 
+  def setMaxParallelism(maxParallelism: Int): DataStream[T] = {
+    stream match {
+      case ds: SingleOutputStreamOperator[T] => ds.setMaxParallelism(maxParallelism)
+      case _ =>
+        throw new UnsupportedOperationException("Operator " + stream + " cannot set the maximum" +
+                                                  " parallelism")
+    }
+
+    this
+  }
+
   /**
    * Gets the name of the current data stream. This name is
    * used by the visualization and logging during runtime.

http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/StreamExecutionEnvironment.scala
----------------------------------------------------------------------
diff --git a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/StreamExecutionEnvironment.scala b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/StreamExecutionEnvironment.scala
index 9cb36a5..2e432ba 100644
--- a/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/StreamExecutionEnvironment.scala
+++ b/flink-streaming-scala/src/main/scala/org/apache/flink/streaming/api/scala/StreamExecutionEnvironment.scala
@@ -59,12 +59,30 @@ class StreamExecutionEnvironment(javaEnv: JavaEnv) {
   }
 
   /**
+    * Sets the maximum degree of parallelism defined for the program.
+    * The maximum degree of parallelism specifies the upper limit for dynamic scaling. It also
+    * defines the number of key groups used for partitioned state.
+    **/
+  def setMaxParallelism(maxParallelism: Int): Unit = {
+    javaEnv.setMaxParallelism(maxParallelism)
+  }
+
+  /**
    * Returns the default parallelism for this execution environment. Note that this
    * value can be overridden by individual operations using [[DataStream#setParallelism(int)]]
    */
   def getParallelism = javaEnv.getParallelism
 
   /**
+    * Returns the maximum degree of parallelism defined for the program.
+    *
+    * The maximum degree of parallelism specifies the upper limit for dynamic scaling. It also
+    * defines the number of key groups used for partitioned state.
+    *
+    */
+  def getMaxParallelism = javaEnv.getMaxParallelism
+
+  /**
    * Sets the maximum time frequency (milliseconds) for the flushing of the
    * output buffers. By default the output buffers flush frequently to provide
    * low latency and to aid smooth developer experience. Setting the parameter

http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/DataStreamTest.scala
----------------------------------------------------------------------
diff --git a/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/DataStreamTest.scala b/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/DataStreamTest.scala
index 16fcfc3..b73eae8 100644
--- a/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/DataStreamTest.scala
+++ b/flink-streaming-scala/src/test/scala/org/apache/flink/streaming/api/scala/DataStreamTest.scala
@@ -512,7 +512,7 @@ class DataStreamTest extends StreamingMultipleProgramsTestBase {
 
   private def isPartitioned(edges: java.util.List[StreamEdge]): Boolean = {
     import scala.collection.JavaConverters._
-    edges.asScala.forall( _.getPartitioner.isInstanceOf[HashPartitioner[_]])
+    edges.asScala.forall( _.getPartitioner.isInstanceOf[KeyGroupStreamPartitioner[_, _]])
   }
 
   private def isCustomPartitioned(edges: java.util.List[StreamEdge]): Boolean = {

http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-tests/src/test/java/org/apache/flink/test/checkpointing/EventTimeAllWindowCheckpointingITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/EventTimeAllWindowCheckpointingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/EventTimeAllWindowCheckpointingITCase.java
index 6faee45..163fb42 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/EventTimeAllWindowCheckpointingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/EventTimeAllWindowCheckpointingITCase.java
@@ -162,7 +162,7 @@ public class EventTimeAllWindowCheckpointingITCase extends TestLogger {
 			env.setParallelism(PARALLELISM);
 			env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime);
 			env.enableCheckpointing(100);
-			env.setRestartStrategy(RestartStrategies.fixedDelayRestart(3, 0));
+			env.setRestartStrategy(RestartStrategies.fixedDelayRestart(1, 0));
 			env.getConfig().disableSysoutLogging();
 
 			env

http://git-wip-us.apache.org/repos/asf/flink/blob/ec975aab/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
new file mode 100644
index 0000000..0de2a75
--- /dev/null
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
@@ -0,0 +1,683 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.test.checkpointing;
+
+import io.netty.util.internal.ConcurrentSet;
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.state.ValueState;
+import org.apache.flink.api.common.state.ValueStateDescriptor;
+import org.apache.flink.api.java.functions.KeySelector;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.configuration.ConfigConstants;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.runtime.checkpoint.savepoint.SavepointStoreFactory;
+import org.apache.flink.runtime.client.JobExecutionException;
+import org.apache.flink.runtime.execution.SuppressRestartsException;
+import org.apache.flink.runtime.instance.ActorGateway;
+import org.apache.flink.runtime.jobgraph.JobGraph;
+import org.apache.flink.runtime.messages.JobManagerMessages;
+import org.apache.flink.runtime.state.HashKeyGroupAssigner;
+import org.apache.flink.runtime.state.filesystem.FsStateBackendFactory;
+import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages;
+import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
+import org.apache.flink.streaming.api.functions.sink.SinkFunction;
+import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.test.util.ForkableFlinkMiniCluster;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.TestLogger;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.ClassRule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+import scala.concurrent.Await;
+import scala.concurrent.Future;
+import scala.concurrent.duration.Deadline;
+import scala.concurrent.duration.FiniteDuration;
+
+import java.io.File;
+import java.util.HashSet;
+import java.util.Set;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/**
+ * Integration tests for rescaling via savepoints: a job is brought to a savepoint, cancelled,
+ * and resubmitted with a different parallelism against a local mini cluster.
+ */
+public class RescalingITCase extends TestLogger {
+
+	// two TaskManagers x two slots = 4 slots; tests scale between numSlots/2 and numSlots
+	private static int numTaskManagers = 2;
+	private static int slotsPerTaskManager = 2;
+	private static int numSlots = numTaskManagers * slotsPerTaskManager;
+
+	private static ForkableFlinkMiniCluster cluster;
+
+	@ClassRule
+	public static TemporaryFolder temporaryFolder = new TemporaryFolder();
+
+	@BeforeClass
+	public static void setup() throws Exception {
+		Configuration config = new Configuration();
+		config.setInteger(ConfigConstants.LOCAL_NUMBER_TASK_MANAGER, numTaskManagers);
+		config.setInteger(ConfigConstants.TASK_MANAGER_NUM_TASK_SLOTS, slotsPerTaskManager);
+
+		final File checkpointDir = temporaryFolder.newFolder();
+		final File savepointDir = temporaryFolder.newFolder();
+
+		// filesystem-backed checkpoints and savepoints in temporary folders, so that a
+		// savepoint taken by one job can be restored by the rescaled successor job
+		config.setString(ConfigConstants.STATE_BACKEND, "filesystem");
+		config.setString(FsStateBackendFactory.CHECKPOINT_DIRECTORY_URI_CONF_KEY, checkpointDir.toURI().toString());
+		config.setString(SavepointStoreFactory.SAVEPOINT_BACKEND_KEY, "filesystem");
+		config.setString(SavepointStoreFactory.SAVEPOINT_DIRECTORY_KEY, savepointDir.toURI().toString());
+
+		cluster = new ForkableFlinkMiniCluster(config);
+		cluster.start();
+	}
+
+	@AfterClass
+	public static void teardown() {
+		// cluster may still be null if setup() failed before the mini cluster was created
+		if (cluster != null) {
+			cluster.shutdown();
+		}
+	}
+
+	/**
+	 * Tests that a job with purely partitioned state can be restarted from a savepoint
+	 * with a different parallelism.
+	 */
+	@Test
+	public void testSavepointRescalingWithPartitionedState() throws Exception {
+		int numberKeys = 42;
+		int numberElements = 1000;
+		int numberElements2 = 500;
+		int parallelism = numSlots / 2;
+		int parallelism2 = numSlots;
+		// not a multiple of either parallelism, so key groups distribute unevenly
+		int maxParallelism = 13;
+
+		FiniteDuration timeout = new FiniteDuration(3, TimeUnit.MINUTES);
+		Deadline deadline = timeout.fromNow();
+
+		ActorGateway jobManager = null;
+		JobID jobID = null;
+
+		try {
+			jobManager = cluster.getLeaderGateway(deadline.timeLeft());
+
+			JobGraph jobGraph = createPartitionedStateJobGraph(parallelism, maxParallelism, numberKeys, numberElements, false, 100);
+
+			jobID = jobGraph.getJobID();
+
+			cluster.submitJobDetached(jobGraph);
+
+			// wait until the sources have emitted numberElements for each key and completed a checkpoint
+			SubtaskIndexFlatMapper.workCompletedLatch.await(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
+
+			// verify the current state
+
+			Set<Tuple2<Integer, Integer>> actualResult = CollectionSink.getElementsSet();
+
+			Set<Tuple2<Integer, Integer>> expectedResult = new HashSet<>();
+
+			HashKeyGroupAssigner<Integer> keyGroupAssigner = new HashKeyGroupAssigner<>(maxParallelism);
+
+			// the expectation encodes: each key's sum is numberElements * key, emitted by the
+			// subtask whose index equals the key's group index modulo the current parallelism
+			for (int key = 0; key < numberKeys; key++) {
+				int keyGroupIndex = keyGroupAssigner.getKeyGroupIndex(key);
+
+				expectedResult.add(Tuple2.of(keyGroupIndex % parallelism, numberElements * key));
+			}
+
+			assertEquals(expectedResult, actualResult);
+
+			// clear the CollectionSink set for the restarted job
+			CollectionSink.clearElementsSet();
+
+			Future<Object> savepointPathFuture = jobManager.ask(new JobManagerMessages.TriggerSavepoint(jobID), deadline.timeLeft());
+
+			final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess)
+				Await.result(savepointPathFuture, deadline.timeLeft())).savepointPath();
+
+			// register the removal notification BEFORE cancelling, so it cannot be missed
+			Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), deadline.timeLeft());
+
+			Future<Object> cancellationResponseFuture = jobManager.ask(new JobManagerMessages.CancelJob(jobID), deadline.timeLeft());
+
+			Object cancellationResponse = Await.result(cancellationResponseFuture, deadline.timeLeft());
+
+			assertTrue(cancellationResponse instanceof JobManagerMessages.CancellationSuccess);
+
+			Await.ready(jobRemovedFuture, deadline.timeLeft());
+
+			jobID = null;
+
+			// resubmit from the savepoint at the higher parallelism; keyed state must be redistributed
+			JobGraph scaledJobGraph = createPartitionedStateJobGraph(parallelism2, maxParallelism, numberKeys, numberElements2, true, 100);
+
+			scaledJobGraph.setSavepointPath(savepointPath);
+
+			jobID = scaledJobGraph.getJobID();
+
+			cluster.submitJobAndWait(scaledJobGraph, false);
+
+			jobID = null;
+
+			Set<Tuple2<Integer, Integer>> actualResult2 = CollectionSink.getElementsSet();
+
+			Set<Tuple2<Integer, Integer>> expectedResult2 = new HashSet<>();
+
+			HashKeyGroupAssigner<Integer> keyGroupAssigner2 = new HashKeyGroupAssigner<>(maxParallelism);
+
+			// restored sums continue from the savepoint: key * (numberElements + numberElements2)
+			for (int key = 0; key < numberKeys; key++) {
+				int keyGroupIndex = keyGroupAssigner2.getKeyGroupIndex(key);
+				expectedResult2.add(Tuple2.of(keyGroupIndex % parallelism2, key * (numberElements + numberElements2)));
+			}
+
+			assertEquals(expectedResult2, actualResult2);
+
+		} finally {
+			// clear the CollectionSink set for the restarted job
+			CollectionSink.clearElementsSet();
+
+			// clear any left overs from a possibly failed job
+			if (jobID != null && jobManager != null) {
+				Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), timeout);
+
+				try {
+					Await.ready(jobRemovedFuture, timeout);
+				} catch (TimeoutException | InterruptedException ie) {
+					fail("Failed while cleaning up the cluster.");
+				}
+			}
+		}
+	}
+
+	/**
+	 * Tests that a job cannot be restarted from a savepoint with a different parallelism if the
+	 * rescaled operator has non-partitioned state.
+	 */
+	@Test
+	public void testSavepointRescalingFailureWithNonPartitionedState() throws Exception {
+		int parallelism = numSlots / 2;
+		int parallelism2 = numSlots;
+		int maxParallelism = 13;
+
+		FiniteDuration timeout = new FiniteDuration(3, TimeUnit.MINUTES);
+		Deadline deadline = timeout.fromNow();
+
+		JobID jobID = null;
+		ActorGateway jobManager = null;
+
+		try {
+			jobManager = cluster.getLeaderGateway(deadline.timeLeft());
+
+			JobGraph jobGraph = createNonPartitionedStateJobGraph(parallelism, maxParallelism, 500);
+
+			jobID = jobGraph.getJobID();
+
+			cluster.submitJobDetached(jobGraph);
+
+			Future<Object> allTasksRunning = jobManager.ask(new TestingJobManagerMessages.WaitForAllVerticesToBeRunning(jobID), deadline.timeLeft());
+
+			Await.ready(allTasksRunning, deadline.timeLeft());
+
+			Future<Object> savepointPathFuture = jobManager.ask(new JobManagerMessages.TriggerSavepoint(jobID), deadline.timeLeft());
+
+			Object savepointResponse = Await.result(savepointPathFuture, deadline.timeLeft());
+
+			assertTrue(savepointResponse instanceof JobManagerMessages.TriggerSavepointSuccess);
+
+			final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess)savepointResponse).savepointPath();
+
+			// register the removal notification BEFORE cancelling, so it cannot be missed
+			Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), deadline.timeLeft());
+
+			Future<Object> cancellationResponseFuture = jobManager.ask(new JobManagerMessages.CancelJob(jobID), deadline.timeLeft());
+
+			Object cancellationResponse = Await.result(cancellationResponseFuture, deadline.timeLeft());
+
+			assertTrue(cancellationResponse instanceof JobManagerMessages.CancellationSuccess);
+
+			Await.ready(jobRemovedFuture, deadline.timeLeft());
+
+			// job successfully removed
+			jobID = null;
+
+			// resubmitting at a different parallelism must fail, because the source carries
+			// non-partitioned state which cannot be redistributed across subtasks
+			JobGraph scaledJobGraph = createNonPartitionedStateJobGraph(parallelism2, maxParallelism, 500);
+
+			scaledJobGraph.setSavepointPath(savepointPath);
+
+			jobID = scaledJobGraph.getJobID();
+
+			cluster.submitJobAndWait(scaledJobGraph, false);
+
+			jobID = null;
+
+			// bug fix: previously the test silently passed if the rescaled submission
+			// succeeded; the expected failure must be enforced explicitly
+			fail("The resubmission with a changed parallelism should have failed, because the job contains non-partitioned state.");
+
+		} catch (JobExecutionException exception) {
+			if (exception.getCause() instanceof SuppressRestartsException) {
+				SuppressRestartsException suppressRestartsException = (SuppressRestartsException) exception.getCause();
+
+				if (suppressRestartsException.getCause() instanceof IllegalStateException) {
+					// we expect an IllegalStateException wrapped in a SuppressRestartsException wrapped
+					// in a JobExecutionException, because the job containing non-partitioned state
+					// is being rescaled
+				} else {
+					throw exception;
+				}
+			} else {
+				throw exception;
+			}
+		} finally {
+			// clear any left overs from a possibly failed job
+			if (jobID != null && jobManager != null) {
+				Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), timeout);
+
+				try {
+					Await.ready(jobRemovedFuture, timeout);
+				} catch (TimeoutException | InterruptedException ie) {
+					fail("Failed while cleaning up the cluster.");
+				}
+			}
+		}
+	}
+
+	/**
+	 * Tests that a job with non-partitioned state can be restarted from a savepoint with a
+	 * different parallelism if the operators with non-partitioned state are not rescaled.
+	 */
+	@Test
+	public void testSavepointRescalingWithPartiallyNonPartitionedState() throws Exception {
+		int numberKeys = 42;
+		int numberElements = 1000;
+		int numberElements2 = 500;
+		int parallelism = numSlots / 2;
+		int parallelism2 = numSlots;
+		int maxParallelism = 13;
+
+		FiniteDuration timeout = new FiniteDuration(3, TimeUnit.MINUTES);
+		Deadline deadline = timeout.fromNow();
+
+		ActorGateway jobManager = null;
+		JobID jobID = null;
+
+		try {
+			 jobManager = cluster.getLeaderGateway(deadline.timeLeft());
+
+			// the stateful source runs at a fixed parallelism; only the keyed operator rescales
+			JobGraph jobGraph = createPartitionedNonPartitionedStateJobGraph(
+				parallelism,
+				maxParallelism,
+				parallelism,
+				numberKeys,
+				numberElements,
+				false,
+				100);
+
+			jobID = jobGraph.getJobID();
+
+			cluster.submitJobDetached(jobGraph);
+
+			// wait until the sources have emitted numberElements for each key and completed a checkpoint
+			SubtaskIndexFlatMapper.workCompletedLatch.await(deadline.timeLeft().toMillis(), TimeUnit.MILLISECONDS);
+
+			// verify the current state
+
+			Set<Tuple2<Integer, Integer>> actualResult = CollectionSink.getElementsSet();
+
+			Set<Tuple2<Integer, Integer>> expectedResult = new HashSet<>();
+
+			HashKeyGroupAssigner<Integer> keyGroupAssigner = new HashKeyGroupAssigner<>(maxParallelism);
+
+			// each key's sum is numberElements * key, attributed to keyGroupIndex % parallelism
+			for (int key = 0; key < numberKeys; key++) {
+				int keyGroupIndex = keyGroupAssigner.getKeyGroupIndex(key);
+
+				expectedResult.add(Tuple2.of(keyGroupIndex % parallelism, numberElements * key));
+			}
+
+			assertEquals(expectedResult, actualResult);
+
+			// clear the CollectionSink set for the restarted job
+			CollectionSink.clearElementsSet();
+
+			Future<Object> savepointPathFuture = jobManager.ask(new JobManagerMessages.TriggerSavepoint(jobID), deadline.timeLeft());
+
+			final String savepointPath = ((JobManagerMessages.TriggerSavepointSuccess)
+				Await.result(savepointPathFuture, deadline.timeLeft())).savepointPath();
+
+			Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), deadline.timeLeft());
+
+			Future<Object> cancellationResponseFuture = jobManager.ask(new JobManagerMessages.CancelJob(jobID), deadline.timeLeft());
+
+			Object cancellationResponse = Await.result(cancellationResponseFuture, deadline.timeLeft());
+
+			assertTrue(cancellationResponse instanceof JobManagerMessages.CancellationSuccess);
+
+			Await.ready(jobRemovedFuture, deadline.timeLeft());
+
+			jobID = null;
+
+			// rescale only the keyed operator (parallelism2); the source keeps its original
+			// parallelism so its non-partitioned state can be restored one-to-one
+			JobGraph scaledJobGraph = createPartitionedNonPartitionedStateJobGraph(
+				parallelism2,
+				maxParallelism,
+				parallelism,
+				numberKeys,
+				numberElements + numberElements2,
+				true,
+				100);
+
+			scaledJobGraph.setSavepointPath(savepointPath);
+
+			jobID = scaledJobGraph.getJobID();
+
+			cluster.submitJobAndWait(scaledJobGraph, false);
+
+			jobID = null;
+
+			Set<Tuple2<Integer, Integer>> actualResult2 = CollectionSink.getElementsSet();
+
+			Set<Tuple2<Integer, Integer>> expectedResult2 = new HashSet<>();
+
+			HashKeyGroupAssigner<Integer> keyGroupAssigner2 = new HashKeyGroupAssigner<>(maxParallelism);
+
+			// restored sums continue from the savepoint: key * (numberElements + numberElements2)
+			for (int key = 0; key < numberKeys; key++) {
+				int keyGroupIndex = keyGroupAssigner2.getKeyGroupIndex(key);
+				expectedResult2.add(Tuple2.of(keyGroupIndex % parallelism2, key * (numberElements + numberElements2)));
+			}
+
+			assertEquals(expectedResult2, actualResult2);
+
+		} finally {
+			// clear the CollectionSink set for the restarted job
+			CollectionSink.clearElementsSet();
+
+			// clear any left overs from a possibly failed job
+			if (jobID != null && jobManager != null) {
+				Future<Object> jobRemovedFuture = jobManager.ask(new TestingJobManagerMessages.NotifyWhenJobRemoved(jobID), timeout);
+
+				try {
+					Await.ready(jobRemovedFuture, timeout);
+				} catch (TimeoutException | InterruptedException ie) {
+					fail("Failed while cleaning up the cluster.");
+				}
+			}
+		}
+	}
+
+	private static JobGraph createNonPartitionedStateJobGraph(int parallelism, int maxParallelism, long checkpointInterval) {
+		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		env.setParallelism(parallelism);
+		env.getConfig().setMaxParallelism(maxParallelism);
+		env.enableCheckpointing(checkpointInterval);
+		env.setRestartStrategy(RestartStrategies.noRestart());
+
+		DataStream<Integer> input = env.addSource(new NonPartitionedStateSource());
+
+		input.addSink(new DiscardingSink<Integer>());
+
+		return env.getStreamGraph().getJobGraph();
+	}
+
+	/**
+	 * Builds a job that keys the generated values by identity and accumulates per-key sums
+	 * in partitioned (keyed) state; results are gathered in the static CollectionSink.
+	 */
+	private static JobGraph createPartitionedStateJobGraph(
+		int parallelism,
+		int maxParallelism,
+		int numberKeys,
+		int numberElements,
+		boolean terminateAfterEmission,
+		int checkpointingInterval) {
+
+		final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		env.setParallelism(parallelism);
+		env.getConfig().setMaxParallelism(maxParallelism);
+		env.enableCheckpointing(checkpointingInterval);
+		env.setRestartStrategy(RestartStrategies.noRestart());
+
+		// identity key selector: every record is its own key
+		KeySelector<Integer, Integer> identityKey = new KeySelector<Integer, Integer>() {
+			private static final long serialVersionUID = -7952298871120320940L;
+
+			@Override
+			public Integer getKey(Integer value) throws Exception {
+				return value;
+			}
+		};
+
+		DataStream<Integer> keyedSource = env
+			.addSource(new SubtaskIndexSource(numberKeys, numberElements, terminateAfterEmission))
+			.keyBy(identityKey);
+
+		// one latch count per key: released once each key has reached its target count
+		SubtaskIndexFlatMapper.workCompletedLatch = new CountDownLatch(numberKeys);
+
+		keyedSource.flatMap(new SubtaskIndexFlatMapper(numberElements)).addSink(new CollectionSink());
+
+		return env.getStreamGraph().getJobGraph();
+	}
+
+	/**
+	 * Like {@link #createPartitionedStateJobGraph} but the source additionally carries
+	 * non-partitioned state and runs at a fixed parallelism, independent of the job's
+	 * default parallelism.
+	 */
+	private static JobGraph createPartitionedNonPartitionedStateJobGraph(
+		int parallelism,
+		int maxParallelism,
+		int fixedParallelism,
+		int numberKeys,
+		int numberElements,
+		boolean terminateAfterEmission,
+		int checkpointingInterval) {
+
+		final StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		env.setParallelism(parallelism);
+		env.getConfig().setMaxParallelism(maxParallelism);
+		env.enableCheckpointing(checkpointingInterval);
+		env.setRestartStrategy(RestartStrategies.noRestart());
+
+		// identity key selector: every record is its own key
+		KeySelector<Integer, Integer> identityKey = new KeySelector<Integer, Integer>() {
+			private static final long serialVersionUID = -7952298871120320940L;
+
+			@Override
+			public Integer getKey(Integer value) throws Exception {
+				return value;
+			}
+		};
+
+		DataStream<Integer> keyedSource = env
+			.addSource(new SubtaskIndexNonPartitionedStateSource(numberKeys, numberElements, terminateAfterEmission))
+			.setParallelism(fixedParallelism)
+			.keyBy(identityKey);
+
+		// one latch count per key: released once each key has reached its target count
+		SubtaskIndexFlatMapper.workCompletedLatch = new CountDownLatch(numberKeys);
+
+		keyedSource.flatMap(new SubtaskIndexFlatMapper(numberElements)).addSink(new CollectionSink());
+
+		return env.getStreamGraph().getJobGraph();
+	}
+
+	/**
+	 * Parallel source that repeatedly emits every key in [0, numberKeys) assigned to its
+	 * subtask (round-robin by subtask index), numberElements rounds per key, holding the
+	 * checkpoint lock during each round. If terminateAfterEmission is false the source
+	 * idles after the last round so the job stays running until it is cancelled.
+	 */
+	private static class SubtaskIndexSource
+		extends RichParallelSourceFunction<Integer> {
+
+		private static final long serialVersionUID = -400066323594122516L;
+
+		private final int numberKeys;
+		private final int numberElements;
+		private final boolean terminateAfterEmission;
+
+		// completed emission rounds; snapshotted by the Checkpointed subclass
+		protected int counter = 0;
+
+		// fix: written by cancel() from a different thread, so it must be volatile for the
+		// emission loop to reliably observe cancellation
+		private volatile boolean running = true;
+
+		SubtaskIndexSource(
+			int numberKeys,
+			int numberElements,
+			boolean terminateAfterEmission) {
+
+			this.numberKeys = numberKeys;
+			this.numberElements = numberElements;
+			this.terminateAfterEmission = terminateAfterEmission;
+		}
+
+		@Override
+		public void run(SourceContext<Integer> ctx) throws Exception {
+			final Object lock = ctx.getCheckpointLock();
+			final int subtaskIndex = getRuntimeContext().getIndexOfThisSubtask();
+
+			while (running) {
+
+				if (counter < numberElements) {
+					// emit one full round of this subtask's keys atomically w.r.t. checkpoints
+					synchronized (lock) {
+						for (int value = subtaskIndex;
+							 value < numberKeys;
+							 value += getRuntimeContext().getNumberOfParallelSubtasks()) {
+
+							ctx.collect(value);
+						}
+
+						counter++;
+					}
+				} else {
+					if (terminateAfterEmission) {
+						running = false;
+					} else {
+						// keep the source alive so a savepoint can still be taken
+						Thread.sleep(100);
+					}
+				}
+			}
+		}
+
+		@Override
+		public void cancel() {
+			running = false;
+		}
+	}
+
+	/**
+	 * Variant of {@link SubtaskIndexSource} that additionally snapshots its round counter as
+	 * non-partitioned (Checkpointed) state, which makes the source non-rescalable.
+	 */
+	private static class SubtaskIndexNonPartitionedStateSource extends SubtaskIndexSource implements Checkpointed<Integer> {
+
+		private static final long serialVersionUID = 8388073059042040203L;
+
+		SubtaskIndexNonPartitionedStateSource(int numberKeys, int numberElements, boolean terminateAfterEmission) {
+			super(numberKeys, numberElements, terminateAfterEmission);
+		}
+
+		@Override
+		public Integer snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
+			return counter;
+		}
+
+		@Override
+		public void restoreState(Integer state) throws Exception {
+			counter = state;
+		}
+	}
+
+	/**
+	 * Keyed flat mapper that keeps a per-key element count and running sum in partitioned
+	 * ValueState. Each time a key's count reaches a multiple of numberElements it emits
+	 * (subtaskIndex, sum) and counts the static work-completed latch down once.
+	 */
+	private static class SubtaskIndexFlatMapper extends RichFlatMapFunction<Integer, Tuple2<Integer, Integer>> {
+
+		private static final long serialVersionUID = 5273172591283191348L;
+
+		// reassigned by the job-graph factories before each submission: one count per key
+		private static volatile CountDownLatch workCompletedLatch = new CountDownLatch(1);
+
+		private transient ValueState<Integer> counter;
+		private transient ValueState<Integer> sum;
+
+		private final int numberElements;
+
+		SubtaskIndexFlatMapper(int numberElements) {
+			this.numberElements = numberElements;
+		}
+
+		@Override
+		public void open(Configuration configuration) {
+			// both states default to 0 for previously unseen keys
+			counter = getRuntimeContext().getState(new ValueStateDescriptor<Integer>("counter", Integer.class, 0));
+			sum = getRuntimeContext().getState(new ValueStateDescriptor<Integer>("sum", Integer.class, 0));
+		}
+
+		@Override
+		public void flatMap(Integer value, Collector<Tuple2<Integer, Integer>> out) throws Exception {
+			int count = counter.value() + 1;
+			counter.update(count);
+
+			int s = sum.value() + value;
+			sum.update(s);
+
+			if (count % numberElements == 0) {
+				out.collect(Tuple2.of(getRuntimeContext().getIndexOfThisSubtask(), s));
+				workCompletedLatch.countDown();
+			}
+		}
+	}
+
+	/**
+	 * Sink that collects every element into a static, concurrent set so test code can inspect
+	 * the results after the job finished. The set is shared JVM-wide; callers must clear it
+	 * between jobs via {@link #clearElementsSet()}.
+	 */
+	private static class CollectionSink<IN> implements SinkFunction<IN> {
+
+		// fix: never reassigned, so declared final
+		private static final ConcurrentSet<Object> elements = new ConcurrentSet<Object>();
+
+		private static final long serialVersionUID = -1652452958040267745L;
+
+		// the cast is unchecked but safe as long as a single element type is used per test run
+		@SuppressWarnings("unchecked")
+		public static <IN> Set<IN> getElementsSet() {
+			return (Set<IN>) elements;
+		}
+
+		public static void clearElementsSet() {
+			elements.clear();
+		}
+
+		@Override
+		public void invoke(IN value) throws Exception {
+			elements.add(value);
+		}
+	}
+
+	/**
+	 * Parallel source with only non-partitioned (Checkpointed) state: a plain counter that is
+	 * snapshotted and restored as a whole. Used to provoke the rescaling failure case.
+	 */
+	private static class NonPartitionedStateSource extends RichParallelSourceFunction<Integer> implements Checkpointed<Integer> {
+
+		private static final long serialVersionUID = -8108185918123186841L;
+
+		private int counter = 0;
+		// volatile: written by cancel() from a different thread than the emission loop
+		private volatile boolean running = true;
+
+		@Override
+		public Integer snapshotState(long checkpointId, long checkpointTimestamp) throws Exception {
+			return counter;
+		}
+
+		@Override
+		public void restoreState(Integer state) throws Exception {
+			counter = state;
+		}
+
+		@Override
+		public void run(SourceContext<Integer> ctx) throws Exception {
+			final Object lock = ctx.getCheckpointLock();
+
+			while (running) {
+				synchronized (lock) {
+					counter++;
+
+					ctx.collect(counter * getRuntimeContext().getIndexOfThisSubtask());
+				}
+
+				Thread.sleep(100);
+			}
+		}
+
+		@Override
+		public void cancel() {
+			// bug fix: previously set running = true, so the source could never be cancelled
+			running = false;
+		}
+	}
+}


Mime
View raw message