flink-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tzuli...@apache.org
Subject [09/28] flink git commit: [FLINK-8398] [kinesis, tests] Cleanup confusing implementations in KinesisDataFetcherTest and related classes
Date Tue, 06 Feb 2018 19:03:09 GMT
[FLINK-8398] [kinesis, tests] Cleanup confusing implementations in KinesisDataFetcherTest and
related classes

The previous implementation of the TestableKinesisDataFetcher was
confusing in various ways, making it hard to re-use for other
tests. This commit contains the following cleanups:

- Remove confusing mocks of source context and checkpoint lock. We now
  allow users of the TestableKinesisDataFetcher to provide a source
  context, which should provide the checkpoint lock.
- Remove override of emitRecordAndUpdateState(). Strictly speaking, that
  method should be final. It was previously overridden to allow
  verifying how many records were output by the fetcher. That
  verification would be better implemented within a mock source context.
- Properly parameterize the output type for the
  TestableKinesisDataFetcher.
- Remove use of PowerMockito in KinesisDataFetcherTest.
- Use CheckedThreads to properly capture any exceptions in fetcher /
  consumer threads in unit tests.
- Use assertEquals / assertNull instead of assertTrue wherever
  appropriate.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/6922e4c6
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/6922e4c6
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/6922e4c6

Branch: refs/heads/release-1.4
Commit: 6922e4c6b6d6ffe455d550a9f9a7c9a38fe8b8d1
Parents: 0ad82b3
Author: Tzu-Li (Gordon) Tai <tzulitai@apache.org>
Authored: Wed Jan 10 10:11:31 2018 +0800
Committer: Tzu-Li (Gordon) Tai <tzulitai@apache.org>
Committed: Tue Feb 6 17:31:32 2018 +0100

----------------------------------------------------------------------
 .../kinesis/internals/KinesisDataFetcher.java   |  10 +-
 .../internals/KinesisDataFetcherTest.java       | 239 +++++++++----------
 .../kinesis/internals/ShardConsumerTest.java    |  39 ++-
 .../kinesis/testutils/TestSourceContext.java    |  64 +++++
 .../testutils/TestableKinesisDataFetcher.java   | 104 +++-----
 5 files changed, 252 insertions(+), 204 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/6922e4c6/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
index bbfbb20..a8f37a5 100644
--- a/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
+++ b/flink-connectors/flink-connector-kinesis/src/main/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcher.java
@@ -17,6 +17,7 @@
 
 package org.apache.flink.streaming.connectors.kinesis.internals;
 
+import org.apache.flink.annotation.VisibleForTesting;
 import org.apache.flink.api.common.functions.RuntimeContext;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants;
@@ -183,7 +184,7 @@ public class KinesisDataFetcher<T> {
 			KinesisProxy.create(configProps));
 	}
 
-	/** This constructor is exposed for testing purposes. */
+	@VisibleForTesting
 	protected KinesisDataFetcher(List<String> streams,
 								SourceFunction.SourceContext<T> sourceContext,
 								Object checkpointLock,
@@ -481,7 +482,7 @@ public class KinesisDataFetcher<T> {
 	 *                        when the shard state was registered.
 	 * @param lastSequenceNumber the last sequence number value to update
 	 */
-	protected void emitRecordAndUpdateState(T record, long recordTimestamp, int shardStateIndex,
SequenceNumber lastSequenceNumber) {
+	protected final void emitRecordAndUpdateState(T record, long recordTimestamp, int shardStateIndex,
SequenceNumber lastSequenceNumber) {
 		synchronized (checkpointLock) {
 			sourceContext.collectWithTimestamp(record, recordTimestamp);
 			updateState(shardStateIndex, lastSequenceNumber);
@@ -498,7 +499,7 @@ public class KinesisDataFetcher<T> {
 	 *                        when the shard state was registered.
 	 * @param lastSequenceNumber the last sequence number value to update
 	 */
-	protected void updateState(int shardStateIndex, SequenceNumber lastSequenceNumber) {
+	protected final void updateState(int shardStateIndex, SequenceNumber lastSequenceNumber)
{
 		synchronized (checkpointLock) {
 			subscribedShardsState.get(shardStateIndex).setLastProcessedSequenceNum(lastSequenceNumber);
 
@@ -559,7 +560,8 @@ public class KinesisDataFetcher<T> {
 		return (Math.abs(shard.hashCode() % totalNumberOfConsumerSubtasks)) == indexOfThisConsumerSubtask;
 	}
 
-	private static ExecutorService createShardConsumersThreadPool(final String subtaskName)
{
+	@VisibleForTesting
+	protected ExecutorService createShardConsumersThreadPool(final String subtaskName) {
 		return Executors.newCachedThreadPool(new ThreadFactory() {
 			@Override
 			public Thread newThread(Runnable runnable) {

http://git-wip-us.apache.org/repos/asf/flink/blob/6922e4c6/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java
b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java
index 2e1adb6..56566c0 100644
--- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java
+++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/KinesisDataFetcherTest.java
@@ -18,27 +18,26 @@
 package org.apache.flink.streaming.connectors.kinesis.internals;
 
 import org.apache.flink.api.common.functions.RuntimeContext;
+import org.apache.flink.api.common.serialization.SimpleStringSchema;
+import org.apache.flink.core.testutils.CheckedThread;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.streaming.connectors.kinesis.FlinkKinesisConsumer;
-import org.apache.flink.streaming.connectors.kinesis.config.ConsumerConfigConstants;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
 import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
 import org.apache.flink.streaming.connectors.kinesis.model.StreamShardHandle;
 import org.apache.flink.streaming.connectors.kinesis.model.StreamShardMetadata;
 import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema;
+import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchemaWrapper;
 import org.apache.flink.streaming.connectors.kinesis.testutils.FakeKinesisBehavioursFactory;
 import org.apache.flink.streaming.connectors.kinesis.testutils.KinesisShardIdGenerator;
+import org.apache.flink.streaming.connectors.kinesis.testutils.TestSourceContext;
+import org.apache.flink.streaming.connectors.kinesis.testutils.TestUtils;
 import org.apache.flink.streaming.connectors.kinesis.testutils.TestableKinesisDataFetcher;
 
 import com.amazonaws.services.kinesis.model.HashKeyRange;
 import com.amazonaws.services.kinesis.model.SequenceNumberRange;
 import com.amazonaws.services.kinesis.model.Shard;
 import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.mockito.Mockito;
-import org.powermock.api.mockito.PowerMockito;
-import org.powermock.core.classloader.annotations.PrepareForTest;
-import org.powermock.modules.junit4.PowerMockRunner;
 
 import java.util.HashMap;
 import java.util.LinkedList;
@@ -51,6 +50,7 @@ import java.util.UUID;
 import java.util.concurrent.atomic.AtomicReference;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
 import static org.junit.Assert.assertTrue;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.when;
@@ -58,8 +58,6 @@ import static org.mockito.Mockito.when;
 /**
  * Tests for the {@link KinesisDataFetcher}.
  */
-@RunWith(PowerMockRunner.class)
-@PrepareForTest(TestableKinesisDataFetcher.class)
 public class KinesisDataFetcherTest {
 
 	@Test(expected = RuntimeException.class)
@@ -71,14 +69,16 @@ public class KinesisDataFetcherTest {
 		HashMap<String, String> subscribedStreamsToLastSeenShardIdsUnderTest =
 			KinesisDataFetcher.createInitialSubscribedStreamsToLastDiscoveredShardsState(fakeStreams);
 
-		TestableKinesisDataFetcher fetcher =
-			new TestableKinesisDataFetcher(
+		TestableKinesisDataFetcher<String> fetcher =
+			new TestableKinesisDataFetcher<>(
 				fakeStreams,
-				new Properties(),
+				new TestSourceContext<>(),
+				TestUtils.getStandardProperties(),
+				new KinesisDeserializationSchemaWrapper<>(new SimpleStringSchema()),
 				10,
 				2,
-				new AtomicReference<Throwable>(),
-				new LinkedList<KinesisStreamShardState>(),
+				new AtomicReference<>(),
+				new LinkedList<>(),
 				subscribedStreamsToLastSeenShardIdsUnderTest,
 				FakeKinesisBehavioursFactory.noShardsFoundForRequestedStreamsBehaviour());
 
@@ -102,52 +102,44 @@ public class KinesisDataFetcherTest {
 			streamToShardCount.put(fakeStream, rand.nextInt(5) + 1);
 		}
 
-		final TestableKinesisDataFetcher fetcher =
-			new TestableKinesisDataFetcher(
+		final TestableKinesisDataFetcher<String> fetcher =
+			new TestableKinesisDataFetcher<>(
 				fakeStreams,
-				new Properties(),
+				new TestSourceContext<>(),
+				TestUtils.getStandardProperties(),
+				new KinesisDeserializationSchemaWrapper<>(new SimpleStringSchema()),
 				10,
 				2,
-				new AtomicReference<Throwable>(),
-				new LinkedList<KinesisStreamShardState>(),
+				new AtomicReference<>(),
+				new LinkedList<>(),
 				subscribedStreamsToLastSeenShardIdsUnderTest,
 				FakeKinesisBehavioursFactory.nonReshardedStreamsBehaviour(streamToShardCount));
 
-		Properties testConfig = new Properties();
-		testConfig.setProperty(ConsumerConfigConstants.AWS_REGION, "us-east-1");
-		testConfig.setProperty(ConsumerConfigConstants.AWS_CREDENTIALS_PROVIDER, "BASIC");
-		testConfig.setProperty(ConsumerConfigConstants.AWS_ACCESS_KEY_ID, "accessKeyId");
-		testConfig.setProperty(ConsumerConfigConstants.AWS_SECRET_ACCESS_KEY, "secretKey");
+		final DummyFlinkKafkaConsumer<String> consumer = new DummyFlinkKafkaConsumer<>(
+				TestUtils.getStandardProperties(), fetcher, 1, 0);
 
-		final DummyFlinkKafkaConsumer<String> consumer = new DummyFlinkKafkaConsumer<>(testConfig,
fetcher);
-
-		PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
-		Thread consumerThread = new Thread(new Runnable() {
+		CheckedThread consumerThread = new CheckedThread() {
 			@Override
-			public void run() {
-				try {
-					consumer.run(mock(SourceFunction.SourceContext.class));
-				} catch (Exception e) {
-					//
-				}
+			public void go() throws Exception {
+				consumer.run(new TestSourceContext<>());
 			}
-		});
+		};
 		consumerThread.start();
 
 		fetcher.waitUntilRun();
 		consumer.cancel();
-		consumerThread.join();
+		consumerThread.sync();
 
 		// assert that the streams tracked in the state are identical to the subscribed streams
 		Set<String> streamsInState = subscribedStreamsToLastSeenShardIdsUnderTest.keySet();
-		assertTrue(streamsInState.size() == fakeStreams.size());
+		assertEquals(fakeStreams.size(), streamsInState.size());
 		assertTrue(streamsInState.containsAll(fakeStreams));
 
 		// assert that the last seen shards in state is correctly set
 		for (Map.Entry<String, String> streamToLastSeenShard : subscribedStreamsToLastSeenShardIdsUnderTest.entrySet())
{
-			assertTrue(
-				streamToLastSeenShard.getValue().equals(
-					KinesisShardIdGenerator.generateFromShardOrder(streamToShardCount.get(streamToLastSeenShard.getKey())
- 1)));
+			assertEquals(
+				KinesisShardIdGenerator.generateFromShardOrder(streamToShardCount.get(streamToLastSeenShard.getKey())
- 1),
+				streamToLastSeenShard.getValue());
 		}
 	}
 
@@ -195,14 +187,16 @@ public class KinesisDataFetcherTest {
 		HashMap<String, String> subscribedStreamsToLastSeenShardIdsUnderTest =
 			KinesisDataFetcher.createInitialSubscribedStreamsToLastDiscoveredShardsState(fakeStreams);
 
-		final TestableKinesisDataFetcher fetcher =
-			new TestableKinesisDataFetcher(
+		final TestableKinesisDataFetcher<String> fetcher =
+			new TestableKinesisDataFetcher<>(
 				fakeStreams,
-				new Properties(),
+				new TestSourceContext<>(),
+				TestUtils.getStandardProperties(),
+				new KinesisDeserializationSchemaWrapper<>(new SimpleStringSchema()),
 				10,
 				2,
-				new AtomicReference<Throwable>(),
-				new LinkedList<KinesisStreamShardState>(),
+				new AtomicReference<>(),
+				new LinkedList<>(),
 				subscribedStreamsToLastSeenShardIdsUnderTest,
 				FakeKinesisBehavioursFactory.nonReshardedStreamsBehaviour(streamToShardCount));
 
@@ -213,31 +207,27 @@ public class KinesisDataFetcherTest {
 					restoredState.getKey(), new SequenceNumber(restoredState.getValue())));
 		}
 
-		PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
-		Thread runFetcherThread = new Thread(new Runnable() {
+		CheckedThread runFetcherThread = new CheckedThread() {
 			@Override
-			public void run() {
-				try {
-					fetcher.runFetcher();
-				} catch (Exception e) {
-					//
-				}
+			public void go() throws Exception {
+				fetcher.runFetcher();
 			}
-		});
+		};
 		runFetcherThread.start();
 		Thread.sleep(1000); // sleep a while before closing
 		fetcher.shutdownFetcher();
+		runFetcherThread.sync();
 
 		// assert that the streams tracked in the state are identical to the subscribed streams
 		Set<String> streamsInState = subscribedStreamsToLastSeenShardIdsUnderTest.keySet();
-		assertTrue(streamsInState.size() == fakeStreams.size());
+		assertEquals(fakeStreams.size(), streamsInState.size());
 		assertTrue(streamsInState.containsAll(fakeStreams));
 
 		// assert that the last seen shards in state is correctly set
 		for (Map.Entry<String, String> streamToLastSeenShard : subscribedStreamsToLastSeenShardIdsUnderTest.entrySet())
{
-			assertTrue(
-				streamToLastSeenShard.getValue().equals(
-					KinesisShardIdGenerator.generateFromShardOrder(streamToShardCount.get(streamToLastSeenShard.getKey())
- 1)));
+			assertEquals(
+				KinesisShardIdGenerator.generateFromShardOrder(streamToShardCount.get(streamToLastSeenShard.getKey())
- 1),
+				streamToLastSeenShard.getValue());
 		}
 	}
 
@@ -286,14 +276,16 @@ public class KinesisDataFetcherTest {
 			KinesisDataFetcher.createInitialSubscribedStreamsToLastDiscoveredShardsState(fakeStreams);
 
 		// using a non-resharded streams kinesis behaviour to represent that Kinesis is not resharded
AFTER the restore
-		final TestableKinesisDataFetcher fetcher =
-			new TestableKinesisDataFetcher(
+		final TestableKinesisDataFetcher<String> fetcher =
+			new TestableKinesisDataFetcher<>(
 				fakeStreams,
-				new Properties(),
+				new TestSourceContext<>(),
+				TestUtils.getStandardProperties(),
+				new KinesisDeserializationSchemaWrapper<>(new SimpleStringSchema()),
 				10,
 				2,
-				new AtomicReference<Throwable>(),
-				new LinkedList<KinesisStreamShardState>(),
+				new AtomicReference<>(),
+				new LinkedList<>(),
 				subscribedStreamsToLastSeenShardIdsUnderTest,
 				FakeKinesisBehavioursFactory.nonReshardedStreamsBehaviour(streamToShardCount));
 
@@ -304,31 +296,27 @@ public class KinesisDataFetcherTest {
 					restoredState.getKey(), new SequenceNumber(restoredState.getValue())));
 		}
 
-		PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
-		Thread runFetcherThread = new Thread(new Runnable() {
+		CheckedThread runFetcherThread = new CheckedThread() {
 			@Override
-			public void run() {
-				try {
-					fetcher.runFetcher();
-				} catch (Exception e) {
-					//
-				}
+			public void go() throws Exception {
+				fetcher.runFetcher();
 			}
-		});
+		};
 		runFetcherThread.start();
 		Thread.sleep(1000); // sleep a while before closing
 		fetcher.shutdownFetcher();
+		runFetcherThread.sync();
 
 		// assert that the streams tracked in the state are identical to the subscribed streams
 		Set<String> streamsInState = subscribedStreamsToLastSeenShardIdsUnderTest.keySet();
-		assertTrue(streamsInState.size() == fakeStreams.size());
+		assertEquals(fakeStreams.size(), streamsInState.size());
 		assertTrue(streamsInState.containsAll(fakeStreams));
 
 		// assert that the last seen shards in state is correctly set
 		for (Map.Entry<String, String> streamToLastSeenShard : subscribedStreamsToLastSeenShardIdsUnderTest.entrySet())
{
-			assertTrue(
-				streamToLastSeenShard.getValue().equals(
-					KinesisShardIdGenerator.generateFromShardOrder(streamToShardCount.get(streamToLastSeenShard.getKey())
- 1)));
+			assertEquals(
+				KinesisShardIdGenerator.generateFromShardOrder(streamToShardCount.get(streamToLastSeenShard.getKey())
- 1),
+				streamToLastSeenShard.getValue());
 		}
 	}
 
@@ -381,14 +369,16 @@ public class KinesisDataFetcherTest {
 			KinesisDataFetcher.createInitialSubscribedStreamsToLastDiscoveredShardsState(fakeStreams);
 
 		// using a non-resharded streams kinesis behaviour to represent that Kinesis is not resharded
AFTER the restore
-		final TestableKinesisDataFetcher fetcher =
-			new TestableKinesisDataFetcher(
+		final TestableKinesisDataFetcher<String> fetcher =
+			new TestableKinesisDataFetcher<>(
 				fakeStreams,
-				new Properties(),
+				new TestSourceContext<>(),
+				TestUtils.getStandardProperties(),
+				new KinesisDeserializationSchemaWrapper<>(new SimpleStringSchema()),
 				10,
 				2,
-				new AtomicReference<Throwable>(),
-				new LinkedList<KinesisStreamShardState>(),
+				new AtomicReference<>(),
+				new LinkedList<>(),
 				subscribedStreamsToLastSeenShardIdsUnderTest,
 				FakeKinesisBehavioursFactory.nonReshardedStreamsBehaviour(streamToShardCount));
 
@@ -399,33 +389,31 @@ public class KinesisDataFetcherTest {
 					restoredState.getKey(), new SequenceNumber(restoredState.getValue())));
 		}
 
-		PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
-		Thread runFetcherThread = new Thread(new Runnable() {
+		CheckedThread runFetcherThread = new CheckedThread() {
 			@Override
-			public void run() {
-				try {
-					fetcher.runFetcher();
-				} catch (Exception e) {
-					//
-				}
+			public void go() throws Exception {
+				fetcher.runFetcher();
 			}
-		});
+		};
 		runFetcherThread.start();
 		Thread.sleep(1000); // sleep a while before closing
 		fetcher.shutdownFetcher();
+		runFetcherThread.sync();
 
 		// assert that the streams tracked in the state are identical to the subscribed streams
 		Set<String> streamsInState = subscribedStreamsToLastSeenShardIdsUnderTest.keySet();
-		assertTrue(streamsInState.size() == fakeStreams.size());
+		assertEquals(fakeStreams.size(), streamsInState.size());
 		assertTrue(streamsInState.containsAll(fakeStreams));
 
 		// assert that the last seen shards in state is correctly set
-		assertTrue(subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream1").equals(
-			KinesisShardIdGenerator.generateFromShardOrder(2)));
-		assertTrue(subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream2").equals(
-			KinesisShardIdGenerator.generateFromShardOrder(1)));
-		assertTrue(subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream3") == null);
-		assertTrue(subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream4") == null);
+		assertEquals(
+			KinesisShardIdGenerator.generateFromShardOrder(2),
+			subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream1"));
+		assertEquals(
+			KinesisShardIdGenerator.generateFromShardOrder(1),
+			subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream2"));
+		assertNull(subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream3"));
+		assertNull(subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream4"));
 	}
 
 	@Test
@@ -477,14 +465,16 @@ public class KinesisDataFetcherTest {
 			KinesisDataFetcher.createInitialSubscribedStreamsToLastDiscoveredShardsState(fakeStreams);
 
 		// using a non-resharded streams kinesis behaviour to represent that Kinesis is not resharded
AFTER the restore
-		final TestableKinesisDataFetcher fetcher =
-			new TestableKinesisDataFetcher(
+		final TestableKinesisDataFetcher<String> fetcher =
+			new TestableKinesisDataFetcher<>(
 				fakeStreams,
-				new Properties(),
+				new TestSourceContext<>(),
+				TestUtils.getStandardProperties(),
+				new KinesisDeserializationSchemaWrapper<>(new SimpleStringSchema()),
 				10,
 				2,
-				new AtomicReference<Throwable>(),
-				new LinkedList<KinesisStreamShardState>(),
+				new AtomicReference<>(),
+				new LinkedList<>(),
 				subscribedStreamsToLastSeenShardIdsUnderTest,
 				FakeKinesisBehavioursFactory.nonReshardedStreamsBehaviour(streamToShardCount));
 
@@ -495,33 +485,31 @@ public class KinesisDataFetcherTest {
 					restoredState.getKey(), new SequenceNumber(restoredState.getValue())));
 		}
 
-		PowerMockito.whenNew(ShardConsumer.class).withAnyArguments().thenReturn(Mockito.mock(ShardConsumer.class));
-		Thread runFetcherThread = new Thread(new Runnable() {
+		CheckedThread runFetcherThread = new CheckedThread() {
 			@Override
-			public void run() {
-				try {
-					fetcher.runFetcher();
-				} catch (Exception e) {
-					//
-				}
+			public void go() throws Exception {
+				fetcher.runFetcher();
 			}
-		});
+		};
 		runFetcherThread.start();
 		Thread.sleep(1000); // sleep a while before closing
 		fetcher.shutdownFetcher();
+		runFetcherThread.sync();
 
 		// assert that the streams tracked in the state are identical to the subscribed streams
 		Set<String> streamsInState = subscribedStreamsToLastSeenShardIdsUnderTest.keySet();
-		assertTrue(streamsInState.size() == fakeStreams.size());
+		assertEquals(fakeStreams.size(), streamsInState.size());
 		assertTrue(streamsInState.containsAll(fakeStreams));
 
 		// assert that the last seen shards in state is correctly set
-		assertTrue(subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream1").equals(
-			KinesisShardIdGenerator.generateFromShardOrder(3)));
-		assertTrue(subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream2").equals(
-			KinesisShardIdGenerator.generateFromShardOrder(4)));
-		assertTrue(subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream3") == null);
-		assertTrue(subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream4") == null);
+		assertEquals(
+			KinesisShardIdGenerator.generateFromShardOrder(3),
+			subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream1"));
+		assertEquals(
+			KinesisShardIdGenerator.generateFromShardOrder(4),
+			subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream2"));
+		assertNull(subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream3"));
+		assertNull(subscribedStreamsToLastSeenShardIdsUnderTest.get("fakeStream4"));
 	}
 
 	@Test
@@ -564,12 +552,21 @@ public class KinesisDataFetcherTest {
 	private static class DummyFlinkKafkaConsumer<T> extends FlinkKinesisConsumer<T>
{
 		private static final long serialVersionUID = 1L;
 
-		private KinesisDataFetcher<T> fetcher;
+		private final KinesisDataFetcher<T> fetcher;
+
+		private final int numParallelSubtasks;
+		private final int subtaskIndex;
 
 		@SuppressWarnings("unchecked")
-		DummyFlinkKafkaConsumer(Properties properties, KinesisDataFetcher<T> fetcher) {
+		DummyFlinkKafkaConsumer(
+				Properties properties,
+				KinesisDataFetcher<T> fetcher,
+				int numParallelSubtasks,
+				int subtaskIndex) {
 			super("test", mock(KinesisDeserializationSchema.class), properties);
 			this.fetcher = fetcher;
+			this.numParallelSubtasks = numParallelSubtasks;
+			this.subtaskIndex = subtaskIndex;
 		}
 
 		@Override
@@ -585,8 +582,8 @@ public class KinesisDataFetcherTest {
 		@Override
 		public RuntimeContext getRuntimeContext() {
 			RuntimeContext context = mock(RuntimeContext.class);
-			when(context.getIndexOfThisSubtask()).thenReturn(0);
-			when(context.getNumberOfParallelSubtasks()).thenReturn(1);
+			when(context.getIndexOfThisSubtask()).thenReturn(subtaskIndex);
+			when(context.getNumberOfParallelSubtasks()).thenReturn(numParallelSubtasks);
 			return context;
 		}
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/6922e4c6/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/ShardConsumerTest.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/ShardConsumerTest.java
b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/ShardConsumerTest.java
index a194835..efc98a4 100644
--- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/ShardConsumerTest.java
+++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/internals/ShardConsumerTest.java
@@ -22,9 +22,12 @@ import org.apache.flink.streaming.connectors.kinesis.model.SentinelSequenceNumbe
 import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
 import org.apache.flink.streaming.connectors.kinesis.model.StreamShardHandle;
 import org.apache.flink.streaming.connectors.kinesis.proxy.KinesisProxyInterface;
+import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchemaWrapper;
 import org.apache.flink.streaming.connectors.kinesis.testutils.FakeKinesisBehavioursFactory;
 import org.apache.flink.streaming.connectors.kinesis.testutils.KinesisShardIdGenerator;
+import org.apache.flink.streaming.connectors.kinesis.testutils.TestSourceContext;
 import org.apache.flink.streaming.connectors.kinesis.testutils.TestableKinesisDataFetcher;
+import org.apache.flink.streaming.util.serialization.SimpleStringSchema;
 
 import com.amazonaws.services.kinesis.model.HashKeyRange;
 import com.amazonaws.services.kinesis.model.Shard;
@@ -38,7 +41,7 @@ import java.util.LinkedList;
 import java.util.Properties;
 import java.util.concurrent.atomic.AtomicReference;
 
-import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.assertEquals;
 
 /**
  * Tests for the {@link ShardConsumer}.
@@ -61,13 +64,17 @@ public class ShardConsumerTest {
 			new KinesisStreamShardState(KinesisDataFetcher.convertToStreamShardMetadata(fakeToBeConsumedShard),
 				fakeToBeConsumedShard, new SequenceNumber("fakeStartingState")));
 
-		TestableKinesisDataFetcher fetcher =
-			new TestableKinesisDataFetcher(
+		TestSourceContext<String> sourceContext = new TestSourceContext<>();
+
+		TestableKinesisDataFetcher<String> fetcher =
+			new TestableKinesisDataFetcher<>(
 				Collections.singletonList("fakeStream"),
+				sourceContext,
 				new Properties(),
+				new KinesisDeserializationSchemaWrapper<>(new SimpleStringSchema()),
 				10,
 				2,
-				new AtomicReference<Throwable>(),
+				new AtomicReference<>(),
 				subscribedShardsStateUnderTest,
 				KinesisDataFetcher.createInitialSubscribedStreamsToLastDiscoveredShardsState(Collections.singletonList("fakeStream")),
 				Mockito.mock(KinesisProxyInterface.class));
@@ -79,9 +86,10 @@ public class ShardConsumerTest {
 			subscribedShardsStateUnderTest.get(0).getLastProcessedSequenceNum(),
 			FakeKinesisBehavioursFactory.totalNumOfRecordsAfterNumOfGetRecordsCalls(1000, 9)).run();
 
-		assertTrue(fetcher.getNumOfElementsCollected() == 1000);
-		assertTrue(subscribedShardsStateUnderTest.get(0).getLastProcessedSequenceNum().equals(
-			SentinelSequenceNumber.SENTINEL_SHARD_ENDING_SEQUENCE_NUM.get()));
+		assertEquals(1000, sourceContext.getCollectedOutputs().size());
+		assertEquals(
+			SentinelSequenceNumber.SENTINEL_SHARD_ENDING_SEQUENCE_NUM.get(),
+			subscribedShardsStateUnderTest.get(0).getLastProcessedSequenceNum());
 	}
 
 	@Test
@@ -100,13 +108,17 @@ public class ShardConsumerTest {
 			new KinesisStreamShardState(KinesisDataFetcher.convertToStreamShardMetadata(fakeToBeConsumedShard),
 				fakeToBeConsumedShard, new SequenceNumber("fakeStartingState")));
 
-		TestableKinesisDataFetcher fetcher =
-			new TestableKinesisDataFetcher(
+		TestSourceContext<String> sourceContext = new TestSourceContext<>();
+
+		TestableKinesisDataFetcher<String> fetcher =
+			new TestableKinesisDataFetcher<>(
 				Collections.singletonList("fakeStream"),
+				sourceContext,
 				new Properties(),
+				new KinesisDeserializationSchemaWrapper<>(new SimpleStringSchema()),
 				10,
 				2,
-				new AtomicReference<Throwable>(),
+				new AtomicReference<>(),
 				subscribedShardsStateUnderTest,
 				KinesisDataFetcher.createInitialSubscribedStreamsToLastDiscoveredShardsState(Collections.singletonList("fakeStream")),
 				Mockito.mock(KinesisProxyInterface.class));
@@ -120,9 +132,10 @@ public class ShardConsumerTest {
 			// and the 7th getRecords() call will encounter an unexpected expired shard iterator
 			FakeKinesisBehavioursFactory.totalNumOfRecordsAfterNumOfGetRecordsCallsWithUnexpectedExpiredIterator(1000,
9, 7)).run();
 
-		assertTrue(fetcher.getNumOfElementsCollected() == 1000);
-		assertTrue(subscribedShardsStateUnderTest.get(0).getLastProcessedSequenceNum().equals(
-			SentinelSequenceNumber.SENTINEL_SHARD_ENDING_SEQUENCE_NUM.get()));
+		assertEquals(1000, sourceContext.getCollectedOutputs().size());
+		assertEquals(
+			SentinelSequenceNumber.SENTINEL_SHARD_ENDING_SEQUENCE_NUM.get(),
+			subscribedShardsStateUnderTest.get(0).getLastProcessedSequenceNum());
 	}
 
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/6922e4c6/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestSourceContext.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestSourceContext.java
b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestSourceContext.java
new file mode 100644
index 0000000..4fe5c54
--- /dev/null
+++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestSourceContext.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.streaming.connectors.kinesis.testutils;
+
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
+import org.apache.flink.streaming.api.watermark.Watermark;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+
+import java.util.concurrent.ConcurrentLinkedQueue;
+
+/**
+ * A testable {@link SourceFunction.SourceContext}.
+ */
+public class TestSourceContext<T> implements SourceFunction.SourceContext<T> {
+
+	private final Object checkpointLock = new Object();
+
+	private ConcurrentLinkedQueue<StreamRecord<T>> collectedOutputs = new ConcurrentLinkedQueue<>();
+
+	@Override
+	public void collect(T element) {
+		this.collectedOutputs.add(new StreamRecord<>(element));
+	}
+
+	@Override
+	public void collectWithTimestamp(T element, long timestamp) {
+		this.collectedOutputs.add(new StreamRecord<>(element, timestamp));
+	}
+
+	@Override
+	public void emitWatermark(Watermark mark) {
+		throw new UnsupportedOperationException();
+	}
+
+	@Override
+	public void markAsTemporarilyIdle() {}
+
+	@Override
+	public Object getCheckpointLock() {
+		return checkpointLock;
+	}
+
+	@Override
+	public void close() {}
+
+	public ConcurrentLinkedQueue<StreamRecord<T>> getCollectedOutputs() {
+		return collectedOutputs;
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/6922e4c6/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java
----------------------------------------------------------------------
diff --git a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java
index b6f3cbc..65ae6cc 100644
--- a/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java
+++ b/flink-connectors/flink-connector-kinesis/src/test/java/org/apache/flink/streaming/connectors/kinesis/testutils/TestableKinesisDataFetcher.java
@@ -22,72 +22,55 @@ import org.apache.flink.core.testutils.OneShotLatch;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.streaming.connectors.kinesis.internals.KinesisDataFetcher;
 import org.apache.flink.streaming.connectors.kinesis.model.KinesisStreamShardState;
-import org.apache.flink.streaming.connectors.kinesis.model.SequenceNumber;
 import org.apache.flink.streaming.connectors.kinesis.proxy.KinesisProxyInterface;
 import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchema;
-import org.apache.flink.streaming.connectors.kinesis.serialization.KinesisDeserializationSchemaWrapper;
-import org.apache.flink.streaming.util.serialization.SimpleStringSchema;
 
 import org.mockito.Mockito;
 import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
 
 import java.util.HashMap;
 import java.util.LinkedList;
 import java.util.List;
 import java.util.Properties;
+import java.util.concurrent.ExecutorService;
 import java.util.concurrent.atomic.AtomicReference;
 
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
 /**
  * Extension of the {@link KinesisDataFetcher} for testing.
  */
-public class TestableKinesisDataFetcher extends KinesisDataFetcher<String> {
-
-	private static final Object fakeCheckpointLock = new Object();
-
-	private long numElementsCollected;
+public class TestableKinesisDataFetcher<T> extends KinesisDataFetcher<T> {
 
 	private OneShotLatch runWaiter;
+	private volatile boolean running;
 
 	public TestableKinesisDataFetcher(
 			List<String> fakeStreams,
+			SourceFunction.SourceContext<T> sourceContext,
 			Properties fakeConfiguration,
+			KinesisDeserializationSchema<T> deserializationSchema,
 			int fakeTotalCountOfSubtasks,
-			int fakeTndexOfThisSubtask,
+			int fakeIndexOfThisSubtask,
 			AtomicReference<Throwable> thrownErrorUnderTest,
 			LinkedList<KinesisStreamShardState> subscribedShardsStateUnderTest,
 			HashMap<String, String> subscribedStreamsToLastDiscoveredShardIdsStateUnderTest,
 			KinesisProxyInterface fakeKinesis) {
-		super(fakeStreams,
-			getMockedSourceContext(),
-			fakeCheckpointLock,
-			getMockedRuntimeContext(fakeTotalCountOfSubtasks, fakeTndexOfThisSubtask),
+		super(
+			fakeStreams,
+			sourceContext,
+			sourceContext.getCheckpointLock(),
+			getMockedRuntimeContext(fakeTotalCountOfSubtasks, fakeIndexOfThisSubtask),
 			fakeConfiguration,
-			new KinesisDeserializationSchemaWrapper<>(new SimpleStringSchema()),
+			deserializationSchema,
 			thrownErrorUnderTest,
 			subscribedShardsStateUnderTest,
 			subscribedStreamsToLastDiscoveredShardIdsStateUnderTest,
 			fakeKinesis);
 
-		this.numElementsCollected = 0;
 		this.runWaiter = new OneShotLatch();
-	}
-
-	public long getNumOfElementsCollected() {
-		return numElementsCollected;
-	}
-
-	@Override
-	protected KinesisDeserializationSchema<String> getClonedDeserializationSchema() {
-		return new KinesisDeserializationSchemaWrapper<>(new SimpleStringSchema());
-	}
-
-	@Override
-	protected void emitRecordAndUpdateState(String record, long recordTimestamp, int shardStateIndex, SequenceNumber lastSequenceNumber) {
-		synchronized (fakeCheckpointLock) {
-			this.numElementsCollected++;
-			updateState(shardStateIndex, lastSequenceNumber);
-		}
+		this.running = true;
 	}
 
 	@Override
@@ -100,41 +83,30 @@ public class TestableKinesisDataFetcher extends KinesisDataFetcher<String> {
 		runWaiter.await();
 	}
 
-	@SuppressWarnings("unchecked")
-	private static SourceFunction.SourceContext<String> getMockedSourceContext() {
-		return Mockito.mock(SourceFunction.SourceContext.class);
+	@Override
+	protected ExecutorService createShardConsumersThreadPool(String subtaskName) {
+		// this is just a dummy fetcher, so no need to create a thread pool for shard consumers
+		ExecutorService mockExecutor = mock(ExecutorService.class);
+		when(mockExecutor.isTerminated()).thenAnswer((InvocationOnMock invocation) -> !running);
+		return mockExecutor;
+	}
+
+	@Override
+	public void awaitTermination() throws InterruptedException {
+		this.running = false;
+		super.awaitTermination();
 	}
 
-	private static RuntimeContext getMockedRuntimeContext(final int fakeTotalCountOfSubtasks, final int fakeTndexOfThisSubtask) {
-		RuntimeContext mockedRuntimeContext = Mockito.mock(RuntimeContext.class);
-
-		Mockito.when(mockedRuntimeContext.getNumberOfParallelSubtasks()).thenAnswer(new Answer<Integer>() {
-			@Override
-			public Integer answer(InvocationOnMock invocationOnMock) throws Throwable {
-				return fakeTotalCountOfSubtasks;
-			}
-		});
-
-		Mockito.when(mockedRuntimeContext.getIndexOfThisSubtask()).thenAnswer(new Answer<Integer>() {
-			@Override
-			public Integer answer(InvocationOnMock invocationOnMock) throws Throwable {
-				return fakeTndexOfThisSubtask;
-			}
-		});
-
-		Mockito.when(mockedRuntimeContext.getTaskName()).thenAnswer(new Answer<String>() {
-			@Override
-			public String answer(InvocationOnMock invocationOnMock) throws Throwable {
-				return "Fake Task";
-			}
-		});
-
-		Mockito.when(mockedRuntimeContext.getTaskNameWithSubtasks()).thenAnswer(new Answer<String>() {
-			@Override
-			public String answer(InvocationOnMock invocationOnMock) throws Throwable {
-				return "Fake Task (" + fakeTndexOfThisSubtask + "/" + fakeTotalCountOfSubtasks + ")";
-			}
-		});
+	private static RuntimeContext getMockedRuntimeContext(final int fakeTotalCountOfSubtasks, final int fakeIndexOfThisSubtask) {
+		RuntimeContext mockedRuntimeContext = mock(RuntimeContext.class);
+
+		Mockito.when(mockedRuntimeContext.getNumberOfParallelSubtasks()).thenReturn(fakeTotalCountOfSubtasks);
+		Mockito.when(mockedRuntimeContext.getIndexOfThisSubtask()).thenReturn(fakeIndexOfThisSubtask);
+		Mockito.when(mockedRuntimeContext.getTaskName()).thenReturn("Fake Task");
+		Mockito.when(mockedRuntimeContext.getTaskNameWithSubtasks()).thenReturn(
+				"Fake Task (" + fakeIndexOfThisSubtask + "/" + fakeTotalCountOfSubtasks + ")");
+		Mockito.when(mockedRuntimeContext.getUserCodeClassLoader()).thenReturn(
+				Thread.currentThread().getContextClassLoader());
 
 		return mockedRuntimeContext;
 	}


Mime
View raw message