flink-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tzuli...@apache.org
Subject [2/2] flink git commit: [FLINK-4821] [kinesis] General improvements to rescalable FlinkKinesisConsumer
Date Sun, 07 May 2017 09:35:29 GMT
[FLINK-4821] [kinesis] General improvements to rescalable FlinkKinesisConsumer

This commit adds some general improvements to the rescalable
implementation of FlinkKinesisConsumer, including:
- Refactor setup procedures in KinesisDataFetcher so that duplicate work
  isn't done on a restored run
- Harden corner cases where the fetcher was not yet fully seeded with
  its initial state at the time a snapshot is taken

This closes #3001.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/e5b65a7f
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/e5b65a7f
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/e5b65a7f

Branch: refs/heads/master
Commit: e5b65a7fc2b4a7532ca40748f81bcbf8ace46815
Parents: a05b574
Author: Tzu-Li (Gordon) Tai <tzulitai@apache.org>
Authored: Sun May 7 16:29:32 2017 +0800
Committer: Tzu-Li (Gordon) Tai <tzulitai@apache.org>
Committed: Sun May 7 17:33:04 2017 +0800

----------------------------------------------------------------------
 .../kinesis/FlinkKinesisConsumer.java           | 150 ++++++++---------
 .../kinesis/internals/KinesisDataFetcher.java   |  52 +-----
 .../FlinkKinesisConsumerMigrationTest.java      |   5 +-
 .../kinesis/FlinkKinesisConsumerTest.java       | 159 +++++++++++--------
 .../internals/KinesisDataFetcherTest.java       |  65 ++++++--
 .../testutils/TestableKinesisDataFetcher.java   |  14 ++
 6 files changed, 233 insertions(+), 212 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
index dfcd552..4982f7f 100644
--- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
+++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumer.java
@@ -25,13 +25,14 @@ import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
 import org.apache.flink.api.java.typeutils.TupleTypeInfo;
-import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.state.FunctionInitializationContext;
 import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.checkpoint.CheckpointedRestoring;
 import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants;
+import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants.InitialPosition;
 import org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
@@ -67,9 +68,9 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
  * @param <T> the type of data emitted
  */
 public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> implements
-	ResultTypeQueryable<T>,
-	CheckpointedFunction,
-	CheckpointedRestoring<HashMap<KinesisStreamShard, SequenceNumber>> {
+		ResultTypeQueryable<T>,
+		CheckpointedFunction,
+		CheckpointedRestoring<HashMap<KinesisStreamShard, SequenceNumber>> {
 
 	private static final long serialVersionUID = 4724006128720664870L;
 
@@ -86,7 +87,7 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
 	 * shard list retrieval behaviours, etc */
 	private final Properties configProps;
 
-	/** User supplied deseriliazation schema to convert Kinesis byte messages to Flink objects */
+	/** User supplied deserialization schema to convert Kinesis byte messages to Flink objects */
 	private final KinesisDeserializationSchema<T> deserializer;
 
 	// ------------------------------------------------------------------------
@@ -96,9 +97,6 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
 	/** Per-task fetcher for Kinesis data records, where each fetcher pulls data from one or more Kinesis shards */
 	private transient KinesisDataFetcher<T> fetcher;
 
-	/** The sequence numbers in the last state snapshot of this subtask */
-	private transient HashMap<KinesisStreamShard, SequenceNumber> lastStateSnapshot;
-
 	/** The sequence numbers to restore to upon restore from failure */
 	private transient HashMap<KinesisStreamShard, SequenceNumber> sequenceNumsToRestore;
 
@@ -108,7 +106,7 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
 	//  State for Checkpoint
 	// ------------------------------------------------------------------------
 
-	/** The name is the key for sequence numbers state, and cannot be changed. */
+	/** State name to access shard sequence number states; cannot be changed */
 	private static final String sequenceNumsStateStoreName = "Kinesis-Stream-Shard-State";
 
 	private transient ListState<Tuple2<KinesisStreamShard, SequenceNumber>> sequenceNumsStateForCheckpoint;
@@ -191,57 +189,33 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
 	// ------------------------------------------------------------------------
 
 	@Override
-	public void open(Configuration parameters) throws Exception {
-		super.open(parameters);
-
-		// restore to the last known sequence numbers from the latest complete snapshot
-		if (sequenceNumsToRestore != null) {
-			if (LOG.isInfoEnabled()) {
-				LOG.info("Subtask {} is restoring sequence numbers {} from previous checkpointed state",
-					getRuntimeContext().getIndexOfThisSubtask(), sequenceNumsToRestore.toString());
-			}
-
-			// initialize sequence numbers with restored state
-			lastStateSnapshot = sequenceNumsToRestore;
-		} else {
-			// start fresh with empty sequence numbers if there are no snapshots to restore from.
-			lastStateSnapshot = new HashMap<>();
-		}
-	}
-
-	@Override
 	public void run(SourceContext<T> sourceContext) throws Exception {
 
 		// all subtasks will run a fetcher, regardless of whether or not the subtask will initially have
 		// shards to subscribe to; fetchers will continuously poll for changes in the shard list, so all subtasks
 		// can potentially have new shards to subscribe to later on
-		fetcher = createFetcher(streams, sourceContext, getRuntimeContext(), configProps, deserializer);
-
-		boolean isRestoringFromFailure = (sequenceNumsToRestore != null);
-		fetcher.setIsRestoringFromFailure(isRestoringFromFailure);
-
-		// if we are restoring from a checkpoint, we iterate over the restored
-		// state and accordingly seed the fetcher with subscribed shards states
-		if (isRestoringFromFailure) {
-			// Since there may have a situation that some subtasks did not finish discovering before rescale,
-			// and KinesisDataFetcher will always discover the shard from the largest shard id. To prevent from
-			// missing some shards which didn't be discovered and whose id is not the largest one, we force the
-			// consumer to discover once from the smallest id and make sure each shard have its initial sequence
-			// number from restored state or SENTINEL_EARLIEST_SEQUENCE_NUM.
-			List<KinesisStreamShard> newShardsCreatedWhileNotRunning = fetcher.discoverNewShardsToSubscribe();
-			for (KinesisStreamShard shard : newShardsCreatedWhileNotRunning) {
-				SequenceNumber startingStateForNewShard;
-
-				if (lastStateSnapshot.containsKey(shard)) {
-					startingStateForNewShard = lastStateSnapshot.get(shard);
+		KinesisDataFetcher<T> fetcher = createFetcher(streams, sourceContext, getRuntimeContext(), configProps, deserializer);
+
+		// initial discovery
+		List<KinesisStreamShard> allShards = fetcher.discoverNewShardsToSubscribe();
+
+		for (KinesisStreamShard shard : allShards) {
+			if (sequenceNumsToRestore != null) {
+				if (sequenceNumsToRestore.containsKey(shard)) {
+					// if the shard was already seen and is contained in the state,
+					// just use the sequence number stored in the state
+					fetcher.registerNewSubscribedShardState(
+						new KinesisStreamShardState(shard, sequenceNumsToRestore.get(shard)));
 
 					if (LOG.isInfoEnabled()) {
 						LOG.info("Subtask {} is seeding the fetcher with restored shard {}," +
 								" starting state set to the restored sequence number {}",
-							getRuntimeContext().getIndexOfThisSubtask(), shard.toString(), startingStateForNewShard);
+							getRuntimeContext().getIndexOfThisSubtask(), shard.toString(), sequenceNumsToRestore.get(shard));
 					}
 				} else {
-					startingStateForNewShard = SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get();
+					// the shard wasn't discovered in the previous run, therefore should be consumed from the beginning
+					fetcher.registerNewSubscribedShardState(
+						new KinesisStreamShardState(shard, SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get()));
 
 					if (LOG.isInfoEnabled()) {
 						LOG.info("Subtask {} is seeding the fetcher with new discovered shard {}," +
@@ -249,9 +223,20 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
 							getRuntimeContext().getIndexOfThisSubtask(), shard.toString());
 					}
 				}
+			} else {
+				// we're starting fresh; use the configured start position as initial state
+				SentinelSequenceNumber startingSeqNum =
+					InitialPosition.valueOf(configProps.getProperty(
+						ConsumerConfigConstants.STREAM_INITIAL_POSITION,
+						ConsumerConfigConstants.DEFAULT_STREAM_INITIAL_POSITION)).toSentinelSequenceNumber();
 
 				fetcher.registerNewSubscribedShardState(
-					new KinesisStreamShardState(shard, startingStateForNewShard));
+					new KinesisStreamShardState(shard, startingSeqNum.get()));
+
+				if (LOG.isInfoEnabled()) {
+					LOG.info("Subtask {} will be seeded with initial shard {}, starting state set as sequence number {}",
+						getRuntimeContext().getIndexOfThisSubtask(), shard.toString(), startingSeqNum.get());
+				}
 			}
 		}
 
@@ -260,6 +245,10 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
 			return;
 		}
 
+		// expose the fetcher from this point, so that state
+		// snapshots can be taken from the fetcher's state holders
+		this.fetcher = fetcher;
+
 		// start the fetcher loop. The fetcher will stop running only when cancel() or
 		// close() is called, or an error is thrown by threads created by the fetcher
 		fetcher.runFetcher();
@@ -306,13 +295,12 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
 
 	@Override
 	public void initializeState(FunctionInitializationContext context) throws Exception {
-		TypeInformation<Tuple2<KinesisStreamShard, SequenceNumber>> tuple = new TupleTypeInfo<>(
+		TypeInformation<Tuple2<KinesisStreamShard, SequenceNumber>> shardsStateTypeInfo = new TupleTypeInfo<>(
 			TypeInformation.of(KinesisStreamShard.class),
-			TypeInformation.of(SequenceNumber.class)
-		);
+			TypeInformation.of(SequenceNumber.class));
 
 		sequenceNumsStateForCheckpoint = context.getOperatorStateStore().getUnionListState(
-			new ListStateDescriptor<>(sequenceNumsStateStoreName, tuple));
+			new ListStateDescriptor<>(sequenceNumsStateStoreName, shardsStateTypeInfo));
 
 		if (context.isRestored()) {
 			if (sequenceNumsToRestore == null) {
@@ -323,8 +311,6 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
 
 				LOG.info("Setting restore state in the FlinkKinesisConsumer. Using the following offsets: {}",
 					sequenceNumsToRestore);
-			} else if (sequenceNumsToRestore.isEmpty()) {
-				sequenceNumsToRestore = null;
 			}
 		} else {
 			LOG.info("No restore state for FlinkKinesisConsumer.");
@@ -333,11 +319,7 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
 
 	@Override
 	public void snapshotState(FunctionSnapshotContext context) throws Exception {
-		if (lastStateSnapshot == null) {
-			LOG.debug("snapshotState() requested on not yet opened source; returning null.");
-		} else if (fetcher == null) {
-			LOG.debug("snapshotState() requested on not yet running source; returning null.");
-		} else if (!running) {
+		if (!running) {
 			LOG.debug("snapshotState() called on closed source; returning null.");
 		} else {
 			if (LOG.isDebugEnabled()) {
@@ -345,15 +327,33 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
 			}
 
 			sequenceNumsStateForCheckpoint.clear();
-			lastStateSnapshot = fetcher.snapshotState();
 
-			if (LOG.isDebugEnabled()) {
-				LOG.debug("Snapshotted state, last processed sequence numbers: {}, checkpoint id: {}, timestamp: {}",
-					lastStateSnapshot.toString(), context.getCheckpointId(), context.getCheckpointTimestamp());
-			}
+			if (fetcher == null) {
+				if (sequenceNumsToRestore != null) {
+					for (Map.Entry<KinesisStreamShard, SequenceNumber> entry : sequenceNumsToRestore.entrySet()) {
+						// sequenceNumsToRestore is the restored global union state;
+						// should only snapshot shards that actually belong to us
+
+						if (KinesisDataFetcher.isThisSubtaskShouldSubscribeTo(
+								entry.getKey(),
+								getRuntimeContext().getNumberOfParallelSubtasks(),
+								getRuntimeContext().getIndexOfThisSubtask())) {
+
+							sequenceNumsStateForCheckpoint.add(Tuple2.of(entry.getKey(), entry.getValue()));
+						}
+					}
+				}
+			} else {
+				HashMap<KinesisStreamShard, SequenceNumber> lastStateSnapshot = fetcher.snapshotState();
+
+				if (LOG.isDebugEnabled()) {
+					LOG.debug("Snapshotted state, last processed sequence numbers: {}, checkpoint id: {}, timestamp: {}",
+						lastStateSnapshot.toString(), context.getCheckpointId(), context.getCheckpointTimestamp());
+				}
 
-			for (Map.Entry<KinesisStreamShard, SequenceNumber> entry : lastStateSnapshot.entrySet()) {
-				sequenceNumsStateForCheckpoint.add(Tuple2.of(entry.getKey(), entry.getValue()));
+				for (Map.Entry<KinesisStreamShard, SequenceNumber> entry : lastStateSnapshot.entrySet()) {
+					sequenceNumsStateForCheckpoint.add(Tuple2.of(entry.getKey(), entry.getValue()));
+				}
 			}
 		}
 	}
@@ -366,12 +366,14 @@ public class FlinkKinesisConsumer<T> extends RichParallelSourceFunction<T> imple
 		sequenceNumsToRestore = restoredState.isEmpty() ? null : restoredState;
 	}
 
-	/** This method is created for tests that can mock the KinesisDataFetcher in the consumer. */
-	protected KinesisDataFetcher<T> createFetcher(List<String> streams,
-													SourceFunction.SourceContext<T> sourceContext,
-													RuntimeContext runtimeContext,
-													Properties configProps,
-													KinesisDeserializationSchema<T> deserializationSchema) {
+	/** This method is exposed for tests that need to mock the KinesisDataFetcher in the consumer. */
+	protected KinesisDataFetcher<T> createFetcher(
+			List<String> streams,
+			SourceFunction.SourceContext<T> sourceContext,
+			RuntimeContext runtimeContext,
+			Properties configProps,
+			KinesisDeserializationSchema<T> deserializationSchema) {
+
 		return new KinesisDataFetcher<>(streams, sourceContext, runtimeContext, configProps, deserializationSchema);
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
index c5b4b04..99305cb 100644
--- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
+++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
@@ -19,9 +19,7 @@ package org.apache.flink.streaming.connectors.kinesis.internals;
 
 import org.apache.flink.api.common.functions.RuntimeContext;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
-import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisConsumer;
 import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants;
-import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants.InitialPosition;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
 import org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumber;
@@ -99,12 +97,6 @@ public class KinesisDataFetcher<T> {
 
 	private final int indexOfThisConsumerSubtask;
 
-	/**
-	 * This flag should be set by {@link FlinkKinesisConsumer} using
-	 * {@link KinesisDataFetcher#setIsRestoringFromFailure(boolean)}
-	 */
-	private boolean isRestoredFromFailure;
-
 	// ------------------------------------------------------------------------
 	//  Executor services to run created threads
 	// ------------------------------------------------------------------------
@@ -235,41 +227,7 @@ public class KinesisDataFetcher<T> {
 		//  Procedures before starting the infinite while loop:
 		// ------------------------------------------------------------------------
 
-		//  1. query for any new shards that may have been created while the Kinesis consumer was not running,
-		//     and register them to the subscribedShardState list.
-		if (LOG.isDebugEnabled()) {
-			String logFormat = (!isRestoredFromFailure)
-				? "Subtask {} is trying to discover initial shards ..."
-				: "Subtask {} is trying to discover any new shards that were created while the consumer wasn't " +
-				"running due to failure ...";
-
-			LOG.debug(logFormat, indexOfThisConsumerSubtask);
-		}
-		List<KinesisStreamShard> newShardsCreatedWhileNotRunning = discoverNewShardsToSubscribe();
-		for (KinesisStreamShard shard : newShardsCreatedWhileNotRunning) {
-			// the starting state for new shards created while the consumer wasn't running depends on whether or not
-			// we are starting fresh (not restoring from a checkpoint); when we are starting fresh, this simply means
-			// all existing shards of streams we are subscribing to are new shards; when we are restoring from checkpoint,
-			// any new shards due to Kinesis resharding from the time of the checkpoint will be considered new shards.
-			InitialPosition initialPosition = InitialPosition.valueOf(configProps.getProperty(
-				ConsumerConfigConstants.STREAM_INITIAL_POSITION, ConsumerConfigConstants.DEFAULT_STREAM_INITIAL_POSITION));
-
-			SentinelSequenceNumber startingStateForNewShard = (isRestoredFromFailure)
-				? SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM
-				: initialPosition.toSentinelSequenceNumber();
-
-			if (LOG.isInfoEnabled()) {
-				String logFormat = (!isRestoredFromFailure)
-					? "Subtask {} will be seeded with initial shard {}, starting state set as sequence number {}"
-					: "Subtask {} will be seeded with new shard {} that was created while the consumer wasn't " +
-					"running due to failure, starting state set as sequence number {}";
-
-				LOG.info(logFormat, indexOfThisConsumerSubtask, shard.toString(), startingStateForNewShard.get());
-			}
-			registerNewSubscribedShardState(new KinesisStreamShardState(shard, startingStateForNewShard.get()));
-		}
-
-		//  2. check that there is at least one shard in the subscribed streams to consume from (can be done by
+		//  1. check that there is at least one shard in the subscribed streams to consume from (can be done by
 		//     checking if at least one value in subscribedStreamsToLastDiscoveredShardIds is not null)
 		boolean hasShards = false;
 		StringBuilder streamsWithNoShardsFound = new StringBuilder();
@@ -290,7 +248,7 @@ public class KinesisDataFetcher<T> {
 			throw new RuntimeException("No shards can be found for all subscribed streams: " + streams);
 		}
 
-		//  3. start consuming any shard state we already have in the subscribedShardState up to this point; the
+		//  2. start consuming any shard state we already have in the subscribedShardState up to this point; the
 		//     subscribedShardState may already be seeded with values due to step 1., or explicitly added by the
 		//     consumer using a restored state checkpoint
 		for (int seededStateIndex = 0; seededStateIndex < subscribedShardsState.size(); seededStateIndex++) {
@@ -489,10 +447,6 @@ public class KinesisDataFetcher<T> {
 	//  Functions to get / set information about the consumer
 	// ------------------------------------------------------------------------
 
-	public void setIsRestoringFromFailure(boolean bool) {
-		this.isRestoredFromFailure = bool;
-	}
-
 	protected Properties getConsumerConfiguration() {
 		return configProps;
 	}
@@ -595,7 +549,7 @@ public class KinesisDataFetcher<T> {
 	 * @param totalNumberOfConsumerSubtasks total number of consumer subtasks
 	 * @param indexOfThisConsumerSubtask index of this consumer subtask
 	 */
-	private static boolean isThisSubtaskShouldSubscribeTo(KinesisStreamShard shard,
+	public static boolean isThisSubtaskShouldSubscribeTo(KinesisStreamShard shard,
 														int totalNumberOfConsumerSubtasks,
 														int indexOfThisConsumerSubtask) {
 		return (Math.abs(shard.hashCode() % totalNumberOfConsumerSubtasks)) == indexOfThisConsumerSubtask;

http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
index 2f46e09..7629f9d 100644
--- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
+++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerMigrationTest.java
@@ -42,10 +42,7 @@ import static org.mockito.Mockito.mock;
 
 /**
  * Tests for checking whether {@link FlinkKinesisConsumer} can restore from snapshots that were
- * done using the Flink 1.1 {@link FlinkKinesisConsumer}.
- *
- * <p>For regenerating the binary snapshot file you have to run the commented out portion
- * of each test on a checkout of the Flink 1.1 branch.
+ * done using the Flink 1.1 {@code FlinkKinesisConsumer}.
  */
 public class FlinkKinesisConsumerMigrationTest {
 

http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
index bf8e44f..4b178c7 100644
--- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
+++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/FlinkKinesisConsumerTest.java
@@ -40,6 +40,7 @@ import org.apache.flink.streaming.connectors.kinesis.testutils.KinesisShardIdGen
 import org.apache.flink.streaming.connectors.kinesis.testutils.TestableFlinkKinesisConsumer;
 import org.apache.flink.streaming.connectors.kinesis.util.KinesisConfigUtil;
 import org.apache.flink.streaming.util.serialization.SimpleStringSchema;
+import org.junit.Assert;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.ExpectedException;
@@ -57,10 +58,8 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.Properties;
 import java.util.UUID;
-import java.io.Serializable;
 
 import static org.junit.Assert.fail;
-import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotEquals;
 import static org.mockito.Mockito.mock;
@@ -530,7 +529,7 @@ public class FlinkKinesisConsumerTest {
 	// ----------------------------------------------------------------------
 
 	@Test
-	public void testSnapshotStateShouldNotClearListStateIfSourceNotOpened() throws Exception {
+	public void testUseRestoredStateForSnapshotIfFetcherNotInitialized() throws Exception {
 		Properties config = new Properties();
 		config.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1");
 		config.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId");
@@ -538,57 +537,63 @@ public class FlinkKinesisConsumerTest {
 
 		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
 
-		TestingListState<Serializable> listState = new TestingListState<>();
-
-		FlinkKinesisConsumer<String> consumer = new FlinkKinesisConsumer<>("fakeStream", new SimpleStringSchema(), config);
-
-		when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
-
-		StateInitializationContext initializationContext = mock(StateInitializationContext.class);
-
-		when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
-		when(initializationContext.isRestored()).thenReturn(false);
-
-		consumer.initializeState(initializationContext);
-
-		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(123, 123)); //arbitrary checkpoint id and timestamp
-
-		assertFalse(listState.isClearCalled());
-	}
-
-	@Test
-	public void testSnapshotStateShouldNotClearListStateIfSourceNotRun() throws Exception {
-		Properties config = new Properties();
-		config.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1");
-		config.setProperty(AWSConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId");
-		config.setProperty(AWSConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey");
-
-		OperatorStateStore operatorStateStore = mock(OperatorStateStore.class);
+		List<Tuple2<KinesisStreamShard, SequenceNumber>> globalUnionState = new ArrayList<>(4);
+		globalUnionState.add(Tuple2.of(
+			new KinesisStreamShard("fakeStream",
+				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
+			new SequenceNumber("1")));
+		globalUnionState.add(Tuple2.of(
+			new KinesisStreamShard("fakeStream",
+				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))),
+			new SequenceNumber("1")));
+		globalUnionState.add(Tuple2.of(
+			new KinesisStreamShard("fakeStream",
+				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))),
+			new SequenceNumber("1")));
+		globalUnionState.add(Tuple2.of(
+			new KinesisStreamShard("fakeStream",
+				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(3))),
+			new SequenceNumber("1")));
 
-		TestingListState<Serializable> listState = new TestingListState<>();
+		TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> listState = new TestingListState<>();
+		for (Tuple2<KinesisStreamShard, SequenceNumber> state : globalUnionState) {
+			listState.add(state);
+		}
 
 		FlinkKinesisConsumer<String> consumer = new FlinkKinesisConsumer<>("fakeStream", new SimpleStringSchema(), config);
+		RuntimeContext context = mock(RuntimeContext.class);
+		when(context.getIndexOfThisSubtask()).thenReturn(0);
+		when(context.getNumberOfParallelSubtasks()).thenReturn(2);
+		consumer.setRuntimeContext(context);
 
 		when(operatorStateStore.getUnionListState(Matchers.any(ListStateDescriptor.class))).thenReturn(listState);
 
 		StateInitializationContext initializationContext = mock(StateInitializationContext.class);
 
 		when(initializationContext.getOperatorStateStore()).thenReturn(operatorStateStore);
-		when(initializationContext.isRestored()).thenReturn(false);
+		when(initializationContext.isRestored()).thenReturn(true);
 
 		consumer.initializeState(initializationContext);
 
-		consumer.open(new Configuration()); // only opened, not run
+		// only opened, not run
+		consumer.open(new Configuration());
+
+		// arbitrary checkpoint id and timestamp
+		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(123, 123));
 
-		consumer.snapshotState(new StateSnapshotContextSynchronousImpl(123, 123)); //arbitrary checkpoint id and timestamp
+		Assert.assertTrue(listState.isClearCalled());
 
-		assertFalse(listState.isClearCalled());
+		// the checkpointed list state should contain only the shards that it should subscribe to
+		Assert.assertEquals(globalUnionState.size() / 2, listState.getList().size());
+		Assert.assertTrue(listState.getList().contains(globalUnionState.get(0)));
+		Assert.assertTrue(listState.getList().contains(globalUnionState.get(2)));
 	}
 
 	@Test
 	public void testListStateChangedAfterSnapshotState() throws Exception {
+
 		// ----------------------------------------------------------------------
-		// setting config, initial state and state after snapshot
+		// setup config, initial state and expected state snapshot
 		// ----------------------------------------------------------------------
 		Properties config = new Properties();
 		config.setProperty(AWSConfigConstants.AWS_REGION, "us-east-1");
@@ -601,16 +606,16 @@ public class FlinkKinesisConsumerTest {
 				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
 			new SequenceNumber("1")));
 
-		ArrayList<Tuple2<KinesisStreamShard, SequenceNumber>> snapShotState = new ArrayList<>(3);
-		snapShotState.add(Tuple2.of(
+		ArrayList<Tuple2<KinesisStreamShard, SequenceNumber>> expectedStateSnapshot = new ArrayList<>(3);
+		expectedStateSnapshot.add(Tuple2.of(
 			new KinesisStreamShard("fakeStream1",
 				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(0))),
 			new SequenceNumber("12")));
-		snapShotState.add(Tuple2.of(
+		expectedStateSnapshot.add(Tuple2.of(
 			new KinesisStreamShard("fakeStream1",
 				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(1))),
 			new SequenceNumber("11")));
-		snapShotState.add(Tuple2.of(
+		expectedStateSnapshot.add(Tuple2.of(
 			new KinesisStreamShard("fakeStream1",
 				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))),
 			new SequenceNumber("31")));
@@ -618,8 +623,9 @@ public class FlinkKinesisConsumerTest {
 		// ----------------------------------------------------------------------
 		// mock operator state backend and initial state for initializeState()
 		// ----------------------------------------------------------------------
-		TestingListState<Serializable> listState = new TestingListState<>();
-		for (Serializable state: initialState) {
+
+		TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> listState = new TestingListState<>();
+		for (Tuple2<KinesisStreamShard, SequenceNumber> state: initialState) {
 			listState.add(state);
 		}
 
@@ -633,8 +639,9 @@ public class FlinkKinesisConsumerTest {
 		// ----------------------------------------------------------------------
 		// mock a running fetcher and its state for snapshot
 		// ----------------------------------------------------------------------
+
 		HashMap<KinesisStreamShard, SequenceNumber> stateSnapshot = new HashMap<>();
-		for (Tuple2<KinesisStreamShard, SequenceNumber> tuple: snapShotState) {
+		for (Tuple2<KinesisStreamShard, SequenceNumber> tuple: expectedStateSnapshot) {
 			stateSnapshot.put(tuple.f0, tuple.f1);
 		}
 
@@ -644,6 +651,7 @@ public class FlinkKinesisConsumerTest {
 		// ----------------------------------------------------------------------
 		// create a consumer and test the snapshotState()
 		// ----------------------------------------------------------------------
+
 		FlinkKinesisConsumer<String> consumer = new FlinkKinesisConsumer<>("fakeStream", new SimpleStringSchema(), config);
 		FlinkKinesisConsumer<?> mockedConsumer = spy(consumer);
 
@@ -653,22 +661,22 @@ public class FlinkKinesisConsumerTest {
 		mockedConsumer.setRuntimeContext(context);
 		mockedConsumer.initializeState(initializationContext);
 		mockedConsumer.open(new Configuration());
-		Whitebox.setInternalState(mockedConsumer, "fetcher", mockedFetcher); // mock as consumer is running.
+		Whitebox.setInternalState(mockedConsumer, "fetcher", mockedFetcher); // mock consumer as running.
 
 		mockedConsumer.snapshotState(mock(FunctionSnapshotContext.class));
 
 		assertEquals(true, listState.clearCalled);
 		assertEquals(3, listState.getList().size());
 
-		for (Serializable state: initialState) {
-			for (Serializable currentState: listState.getList()) {
+		for (Tuple2<KinesisStreamShard, SequenceNumber> state: initialState) {
+			for (Tuple2<KinesisStreamShard, SequenceNumber> currentState: listState.getList()) {
 				assertNotEquals(state, currentState);
 			}
 		}
 
-		for (Serializable state: snapShotState) {
+		for (Tuple2<KinesisStreamShard, SequenceNumber> state: expectedStateSnapshot) {
 			boolean hasOneIsSame = false;
-			for (Serializable currentState: listState.getList()) {
+			for (Tuple2<KinesisStreamShard, SequenceNumber> currentState: listState.getList()) {
 				hasOneIsSame = hasOneIsSame || state.equals(currentState);
 			}
 			assertEquals(true, hasOneIsSame);
@@ -693,8 +701,6 @@ public class FlinkKinesisConsumerTest {
 			"fakeStream", new Properties(), 10, 2);
 		consumer.open(new Configuration());
 		consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
-
-		Mockito.verify(mockedFetcher).setIsRestoringFromFailure(false);
 	}
 
 	@Test
@@ -718,7 +724,6 @@ public class FlinkKinesisConsumerTest {
 		consumer.open(new Configuration());
 		consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
 
-		Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
 		for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredState.entrySet()) {
 			Mockito.verify(mockedFetcher).registerNewSubscribedShardState(
 				new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));
@@ -728,15 +733,18 @@ public class FlinkKinesisConsumerTest {
 	@Test
 	@SuppressWarnings("unchecked")
 	public void testFetcherShouldBeCorrectlySeededIfRestoringFromCheckpoint() throws Exception {
+
 		// ----------------------------------------------------------------------
-		// setting initial state
+		// setup initial state
 		// ----------------------------------------------------------------------
+
 		HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = getFakeRestoredStore("all");
 
 		// ----------------------------------------------------------------------
 		// mock operator state backend and initial state for initializeState()
 		// ----------------------------------------------------------------------
-		TestingListState<Serializable> listState = new TestingListState<>();
+
+		TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> listState = new TestingListState<>();
 		for (Map.Entry<KinesisStreamShard, SequenceNumber> state: fakeRestoredState.entrySet()) {
 			listState.add(Tuple2.of(state.getKey(), state.getValue()));
 		}
@@ -751,6 +759,7 @@ public class FlinkKinesisConsumerTest {
 		// ----------------------------------------------------------------------
 		// mock fetcher
 		// ----------------------------------------------------------------------
+
 		KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class);
 		List<KinesisStreamShard> shards = new ArrayList<>();
 		shards.addAll(fakeRestoredState.keySet());
@@ -762,15 +771,15 @@ public class FlinkKinesisConsumerTest {
 		PowerMockito.doNothing().when(KinesisConfigUtil.class);
 
 		// ----------------------------------------------------------------------
-		// start to test seed initial state to fetcher
+		// start to test fetcher's initial state seeding
 		// ----------------------------------------------------------------------
+
 		TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer(
 			"fakeStream", new Properties(), 10, 2);
 		consumer.initializeState(initializationContext);
 		consumer.open(new Configuration());
 		consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
 
-		Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
 		for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredState.entrySet()) {
 			Mockito.verify(mockedFetcher).registerNewSubscribedShardState(
 				new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));
@@ -780,9 +789,11 @@ public class FlinkKinesisConsumerTest {
 	@Test
 	@SuppressWarnings("unchecked")
 	public void testFetcherShouldBeCorrectlySeededOnlyItsOwnStates() throws Exception {
+
 		// ----------------------------------------------------------------------
-		// setting initial state
+		// setup initial state
 		// ----------------------------------------------------------------------
+
 		HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = getFakeRestoredStore("fakeStream1");
 
 		HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredStateForOthers = getFakeRestoredStore("fakeStream2");
@@ -790,7 +801,8 @@ public class FlinkKinesisConsumerTest {
 		// ----------------------------------------------------------------------
 		// mock operator state backend and initial state for initializeState()
 		// ----------------------------------------------------------------------
-		TestingListState<Serializable> listState = new TestingListState<>();
+
+		TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> listState = new TestingListState<>();
 		for (Map.Entry<KinesisStreamShard, SequenceNumber> state: fakeRestoredState.entrySet()) {
 			listState.add(Tuple2.of(state.getKey(), state.getValue()));
 		}
@@ -808,6 +820,7 @@ public class FlinkKinesisConsumerTest {
 		// ----------------------------------------------------------------------
 		// mock fetcher
 		// ----------------------------------------------------------------------
+
 		KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class);
 		List<KinesisStreamShard> shards = new ArrayList<>();
 		shards.addAll(fakeRestoredState.keySet());
@@ -819,15 +832,15 @@ public class FlinkKinesisConsumerTest {
 		PowerMockito.doNothing().when(KinesisConfigUtil.class);
 
 		// ----------------------------------------------------------------------
-		// start to test seed initial state to fetcher
+		// start to test fetcher's initial state seeding
 		// ----------------------------------------------------------------------
+
 		TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer(
 			"fakeStream", new Properties(), 10, 2);
 		consumer.initializeState(initializationContext);
 		consumer.open(new Configuration());
 		consumer.run(Mockito.mock(SourceFunction.SourceContext.class));
 
-		Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
 		for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredStateForOthers.entrySet()) {
 			// should never get restored state not belonging to itself
 			Mockito.verify(mockedFetcher, never()).registerNewSubscribedShardState(
@@ -841,42 +854,49 @@ public class FlinkKinesisConsumerTest {
 	}
 
 	/*
-	 * If the original parallelism is 2 and states is:
+	 * This tests that the consumer correctly picks up shards that were not discovered on the previous run.
+	 *
+	 * Case under test:
+	 *
+	 * If the original parallelism is 2 and states are:
 	 *   Consumer subtask 1:
 	 *     stream1, shard1, SequenceNumber(xxx)
 	 *   Consumer subtask 2:
 	 *     stream1, shard2, SequenceNumber(yyy)
-	 * After discoverNewShardsToSubscribe() if there are two shards (shard3, shard4) been created:
+	 *
+	 * After discoverNewShardsToSubscribe() if there were two shards (shard3, shard4) created:
 	 *   Consumer subtask 1 (late for discoverNewShardsToSubscribe()):
 	 *     stream1, shard1, SequenceNumber(xxx)
 	 *   Consumer subtask 2:
 	 *     stream1, shard2, SequenceNumber(yyy)
 	 *     stream1, shard4, SequenceNumber(zzz)
-	 *  If snapshotState() occur and parallelism is changed to 1:
-	 *    Union state will be:
+	 *
+	 * If snapshotState() occurs and parallelism is changed to 1:
+	 *   Union state will be:
 	 *     stream1, shard1, SequenceNumber(xxx)
 	 *     stream1, shard2, SequenceNumber(yyy)
 	 *     stream1, shard4, SequenceNumber(zzz)
-	 *    Fetcher should be seeded with:
+	 *   Fetcher should be seeded with:
 	 *     stream1, shard1, SequenceNumber(xxx)
 	 *     stream1, shard2, SequenceNumber(yyy)
 	 *     stream1, shard3, SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM
 	 *     stream1, shard4, SequenceNumber(zzz)
-	 *
-	 *  This test is to guarantee the fetcher will be seeded correctly for such situation.
 	 */
 	@Test
 	@SuppressWarnings("unchecked")
 	public void testFetcherShouldBeCorrectlySeededWithNewDiscoveredKinesisStreamShard() throws Exception {
+
 		// ----------------------------------------------------------------------
-		// setting initial state
+		// setup initial state
 		// ----------------------------------------------------------------------
+
 		HashMap<KinesisStreamShard, SequenceNumber> fakeRestoredState = getFakeRestoredStore("all");
 
 		// ----------------------------------------------------------------------
 		// mock operator state backend and initial state for initializeState()
 		// ----------------------------------------------------------------------
-		TestingListState<Serializable> listState = new TestingListState<>();
+
+		TestingListState<Tuple2<KinesisStreamShard, SequenceNumber>> listState = new TestingListState<>();
 		for (Map.Entry<KinesisStreamShard, SequenceNumber> state: fakeRestoredState.entrySet()) {
 			listState.add(Tuple2.of(state.getKey(), state.getValue()));
 		}
@@ -891,6 +911,7 @@ public class FlinkKinesisConsumerTest {
 		// ----------------------------------------------------------------------
 		// mock fetcher
 		// ----------------------------------------------------------------------
+
 		KinesisDataFetcher mockedFetcher = Mockito.mock(KinesisDataFetcher.class);
 		List<KinesisStreamShard> shards = new ArrayList<>();
 		shards.addAll(fakeRestoredState.keySet());
@@ -904,8 +925,9 @@ public class FlinkKinesisConsumerTest {
 		PowerMockito.doNothing().when(KinesisConfigUtil.class);
 
 		// ----------------------------------------------------------------------
-		// start to test seed initial state to fetcher
+		// start to test fetcher's initial state seeding
 		// ----------------------------------------------------------------------
+
 		TestableFlinkKinesisConsumer consumer = new TestableFlinkKinesisConsumer(
 			"fakeStream", new Properties(), 10, 2);
 		consumer.initializeState(initializationContext);
@@ -915,7 +937,6 @@ public class FlinkKinesisConsumerTest {
 		fakeRestoredState.put(new KinesisStreamShard("fakeStream2",
 				new Shard().withShardId(KinesisShardIdGenerator.generateFromShardOrder(2))),
 			SentinelSequenceNumber.SENTINEL_EARLIEST_SEQUENCE_NUM.get());
-		Mockito.verify(mockedFetcher).setIsRestoringFromFailure(true);
 		for (Map.Entry<KinesisStreamShard, SequenceNumber> restoredShard : fakeRestoredState.entrySet()) {
 			Mockito.verify(mockedFetcher).registerNewSubscribedShardState(
 				new KinesisStreamShardState(restoredShard.getKey(), restoredShard.getValue()));

http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java
index e79f9b1..800fde5 100644
--- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java
+++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java
@@ -18,9 +18,14 @@
 package org.apache.flink.streaming.connectors.kinesis.internals;
 
 import com.amazonaws.services.kinesis.model.Shard;
+import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisConsumer;
+import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShard;
 import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
+import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema;
 import org.apache.flink.streaming.connectors.kinesis.testutils.FakeKinesisBehavioursFactory;
 import org.apache.flink.streaming.connectors.kinesis.testutils.KinesisShardIdGenerator;
 import org.apache.flink.streaming.connectors.kinesis.testutils.TestableKinesisDataFetcher;
@@ -42,6 +47,8 @@ import java.util.UUID;
 import java.util.concurrent.atomic.AtomicReference;
 
 import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
 
 @RunWith(PowerMockRunner.class)
 @PrepareForTest(TestableKinesisDataFetcher.class)
@@ -67,8 +74,6 @@ public class KinesisDataFetcherTest {
 				subscribedStreamsToLastSeenShardIdsUnderTest,
 				FakeKinesisBehavioursFactory.noShardsFoundForRequestedStreamsBehaviour());
 
-		fetcher.setIsRestoringFromFailure(false); // not restoring
-
 		fetcher.runFetcher(); // this should throw RuntimeException
 	}
 
@@ -100,23 +105,30 @@ public class KinesisDataFetcherTest {
 				subscribedStreamsToLastSeenShardIdsUnderTest,
 				FakeKinesisBehavioursFactory.nonReshardedStreamsBehaviour(streamToShardCount));
 
-		fetcher.setIsRestoringFromFailure(false);
+		Properties testConfig = new Properties();
+		testConfig.setProperty(ConsumerConfigConstants.AWS_REGION, "us-east-1");
+		testConfig.setProperty(ConsumerConfigConstants.AWS_CREDENTIALS_PROVIDER, "BASIC");
+		testConfig.setProperty(ConsumerConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId");
+		testConfig.setProperty(ConsumerConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey");
+
+		final DummyFlinkKafkaConsumer<String> consumer = new DummyFlinkKafkaConsumer<>(testConfig, fetcher);
 
 		PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
-		Thread runFetcherThread = new Thread(new Runnable() {
+		Thread consumerThread = new Thread(new Runnable() {
 			@Override
 			public void run() {
 				try {
-					fetcher.runFetcher();
+					consumer.run(mock(SourceFunction.SourceContext.class));
 				} catch (Exception e) {
 					//
 				}
 			}
 		});
-		runFetcherThread.start();
-		Thread.sleep(1000); // sleep a while before closing
-		fetcher.shutdownFetcher();
+		consumerThread.start();
 
+		fetcher.waitUntilRun();
+		consumer.cancel();
+		consumerThread.join();
 
 		// assert that the streams tracked in the state are identical to the subscribed streams
 		Set<String> streamsInState = subscribedStreamsToLastSeenShardIdsUnderTest.keySet();
@@ -192,8 +204,6 @@ public class KinesisDataFetcherTest {
 				new KinesisStreamShardState(restoredState.getKey(), new SequenceNumber(restoredState.getValue())));
 		}
 
-		fetcher.setIsRestoringFromFailure(true);
-
 		PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
 		Thread runFetcherThread = new Thread(new Runnable() {
 			@Override
@@ -284,8 +294,6 @@ public class KinesisDataFetcherTest {
 				new KinesisStreamShardState(restoredState.getKey(), new SequenceNumber(restoredState.getValue())));
 		}
 
-		fetcher.setIsRestoringFromFailure(true);
-
 		PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
 		Thread runFetcherThread = new Thread(new Runnable() {
 			@Override
@@ -380,8 +388,6 @@ public class KinesisDataFetcherTest {
 				new KinesisStreamShardState(restoredState.getKey(), new SequenceNumber(restoredState.getValue())));
 		}
 
-		fetcher.setIsRestoringFromFailure(true);
-
 		PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
 		Thread runFetcherThread = new Thread(new Runnable() {
 			@Override
@@ -477,8 +483,6 @@ public class KinesisDataFetcherTest {
 				new KinesisStreamShardState(restoredState.getKey(), new SequenceNumber(restoredState.getValue())));
 		}
 
-		fetcher.setIsRestoringFromFailure(true);
-
 		PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
 		Thread runFetcherThread = new Thread(new Runnable() {
 			@Override
@@ -507,4 +511,33 @@ public class KinesisDataFetcherTest {
 		assertTrue(subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream3") == null);
 		assertTrue(subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream4") == null);
 	}
+
+	private static class DummyFlinkKafkaConsumer<T> extends FlinkKinesisConsumer<T> {
+		private static final long serialVersionUID = 1L;
+
+		private KinesisDataFetcher<T> fetcher;
+
+		@SuppressWarnings("unchecked")
+		DummyFlinkKafkaConsumer(Properties properties, KinesisDataFetcher<T> fetcher) {
+			super("test", mock(KinesisDeserializationSchema.class), properties);
+			this.fetcher = fetcher;
+		}
+
+		@Override
+		protected KinesisDataFetcher<T> createFetcher(List<String> streams,
+													  SourceFunction.SourceContext<T> sourceContext,
+													  RuntimeContext runtimeContext,
+													  Properties configProps,
+													  KinesisDeserializationSchema<T> deserializationSchema) {
+			return fetcher;
+		}
+
+		@Override
+		public RuntimeContext getRuntimeContext() {
+			RuntimeContext context = mock(RuntimeContext.class);
+			when(context.getIndexOfThisSubtask()).thenReturn(0);
+			when(context.getNumberOfParallelSubtasks()).thenReturn(1);
+			return context;
+		}
+	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/e5b65a7f/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java
index 57886fe..bb644ba 100644
--- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java
+++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java
@@ -18,6 +18,7 @@
 package org.apache.flink.streaming.connectors.kinesis.testutils;
 
 import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
@@ -42,6 +43,8 @@ public class TestableKinesisDataFetcher extends KinesisDataFetcher<String> {
 
 	private long numElementsCollected;
 
+	private OneShotLatch runWaiter;
+
 	public TestableKinesisDataFetcher(List<String> fakeStreams,
 									  Properties fakeConfiguration,
 									  int fakeTotalCountOfSubtasks,
@@ -62,6 +65,7 @@ public class TestableKinesisDataFetcher extends KinesisDataFetcher<String> {
 			fakeKinesis);
 
 		this.numElementsCollected = 0;
+		this.runWaiter = new OneShotLatch();
 	}
 
 	public long getNumOfElementsCollected() {
@@ -81,6 +85,16 @@ public class TestableKinesisDataFetcher extends KinesisDataFetcher<String> {
 		}
 	}
 
+	@Override
+	public void runFetcher() throws Exception {
+		runWaiter.trigger();
+		super.runFetcher();
+	}
+
+	public void waitUntilRun() throws Exception {
+		runWaiter.await();
+	}
+
 	@SuppressWarnings("unchecked")
 	private static SourceFunction.SourceContext<String> getMockedSourceContext() {
 		return Mockito.mock(SourceFunction.SourceContext.class);


Mime
View raw message