flink-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From srich...@apache.org
Subject [2/6] flink git commit: [FLINK-7720] [checkpoints] Centralize creation of backends and state related resources
Date Mon, 22 Jan 2018 13:08:46 GMT
http://git-wip-us.apache.org/repos/asf/flink/blob/517b3f87/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
index 6650118..6bd4fac 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/AbstractStreamOperatorTestHarness.java
@@ -30,6 +30,7 @@ import org.apache.flink.runtime.checkpoint.OperatorStateRepartitioner;
 import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
 import org.apache.flink.runtime.checkpoint.RoundRobinOperatorStateRepartitioner;
 import org.apache.flink.runtime.checkpoint.StateAssignmentOperation;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.operators.testutils.MockEnvironment;
@@ -37,9 +38,9 @@ import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
 import org.apache.flink.runtime.state.CheckpointStreamFactory;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyedStateHandle;
-import org.apache.flink.runtime.state.OperatorStateBackend;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StateBackend;
+import org.apache.flink.runtime.state.TestTaskStateManager;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.streaming.api.TimeCharacteristic;
 import org.apache.flink.streaming.api.graph.StreamConfig;
@@ -48,6 +49,8 @@ import org.apache.flink.streaming.api.operators.AbstractStreamOperatorTest;
 import org.apache.flink.streaming.api.operators.OperatorSnapshotResult;
 import org.apache.flink.streaming.api.operators.Output;
 import org.apache.flink.streaming.api.operators.StreamOperator;
+import org.apache.flink.streaming.api.operators.StreamTaskStateInitializer;
+import org.apache.flink.streaming.api.operators.StreamTaskStateInitializerImpl;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.LatencyMarker;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
@@ -98,10 +101,14 @@ public class AbstractStreamOperatorTestHarness<OUT> implements AutoCloseable {
 
 	protected final StreamTask<?, ?> mockTask;
 
-	final Environment environment;
+	protected final TestTaskStateManager taskStateManager;
+
+	final MockEnvironment environment;
 
 	private final Optional<MockEnvironment> internalEnvironment;
 
+	protected StreamTaskStateInitializer streamTaskStateInitializer;
+
 	CloseableRegistry closableRegistry;
 
 	// use this as default for tests
@@ -134,6 +141,7 @@ public class AbstractStreamOperatorTestHarness<OUT> implements AutoCloseable {
 				1024,
 				new Configuration(),
 				new ExecutionConfig(),
+				new TestTaskStateManager(),
 				maxParallelism,
 				parallelism,
 				subtaskIndex),
@@ -142,33 +150,37 @@ public class AbstractStreamOperatorTestHarness<OUT> implements AutoCloseable {
 
 	public AbstractStreamOperatorTestHarness(
 			StreamOperator<OUT> operator,
-			final Environment environment) throws Exception {
-		this(operator, environment, false);
+			MockEnvironment env) throws Exception {
+		this(operator, env, false);
 	}
 
 	private AbstractStreamOperatorTestHarness(
 			StreamOperator<OUT> operator,
-			final Environment environment,
+			MockEnvironment env,
 			boolean environmentIsInternal) throws Exception {
 		this.operator = operator;
 		this.outputList = new ConcurrentLinkedQueue<>();
 		this.sideOutputLists = new HashMap<>();
 
-		Configuration underlyingConfig = environment.getTaskConfiguration();
+		Configuration underlyingConfig = env.getTaskConfiguration();
 		this.config = new StreamConfig(underlyingConfig);
 		this.config.setCheckpointingEnabled(true);
 		this.config.setOperatorID(new OperatorID());
-		this.executionConfig = environment.getExecutionConfig();
+		this.executionConfig = env.getExecutionConfig();
 		this.closableRegistry = new CloseableRegistry();
 		this.checkpointLock = new Object();
 
-		this.environment = Preconditions.checkNotNull(environment);
-		this.internalEnvironment = environmentIsInternal ? Optional.of((MockEnvironment) environment) : Optional.empty();
+		this.environment = Preconditions.checkNotNull(env);
+
+		this.taskStateManager = env.getTaskStateManager();
+		this.internalEnvironment = environmentIsInternal ? Optional.of(environment) : Optional.empty();
 
 		mockTask = mock(StreamTask.class);
 		processingTimeService = new TestProcessingTimeService();
 		processingTimeService.setCurrentTime(0);
 
+		this.streamTaskStateInitializer = createStreamTaskStateManager(environment, stateBackend, processingTimeService);
+
 		StreamStatusMaintainer mockStreamStatusMaintainer = new StreamStatusMaintainer() {
 			StreamStatus currentStreamStatus = StreamStatus.ACTIVE;
 
@@ -191,8 +203,9 @@ public class AbstractStreamOperatorTestHarness<OUT> implements AutoCloseable {
 		when(mockTask.getTaskConfiguration()).thenReturn(underlyingConfig);
 		when(mockTask.getEnvironment()).thenReturn(environment);
 		when(mockTask.getExecutionConfig()).thenReturn(executionConfig);
+		when(mockTask.createStreamTaskStateInitializer()).thenReturn(streamTaskStateInitializer);
 
-		ClassLoader cl = environment.getUserClassLoader();
+		ClassLoader cl = env.getUserClassLoader();
 		when(mockTask.getUserCodeClassLoader()).thenReturn(cl);
 
 		when(mockTask.getCancelables()).thenReturn(this.closableRegistry);
@@ -219,31 +232,6 @@ public class AbstractStreamOperatorTestHarness<OUT> implements AutoCloseable {
 			throw new RuntimeException(e.getMessage(), e);
 		}
 
-		try {
-			doAnswer(new Answer<OperatorStateBackend>() {
-				@Override
-				public OperatorStateBackend answer(InvocationOnMock invocationOnMock) throws Throwable {
-					final StreamOperator<?> operator = (StreamOperator<?>) invocationOnMock.getArguments()[0];
-					final Collection<OperatorStateHandle> stateHandles = (Collection<OperatorStateHandle>) invocationOnMock.getArguments()[1];
-					OperatorStateBackend osb;
-
-					osb = stateBackend.createOperatorStateBackend(
-						environment,
-						operator.getClass().getSimpleName());
-
-					mockTask.getCancelables().registerCloseable(osb);
-
-					if (null != stateHandles) {
-						osb.restore(stateHandles);
-					}
-
-					return osb;
-				}
-			}).when(mockTask).createOperatorStateBackend(any(StreamOperator.class), any(Collection.class));
-		} catch (Exception e) {
-			throw new RuntimeException(e.getMessage(), e);
-		}
-
 		doAnswer(new Answer<ProcessingTimeService>() {
 			@Override
 			public ProcessingTimeService answer(InvocationOnMock invocation) throws Throwable {
@@ -253,6 +241,16 @@ public class AbstractStreamOperatorTestHarness<OUT> implements AutoCloseable {
 
 	}
 
+	protected StreamTaskStateInitializer createStreamTaskStateManager(
+		Environment env,
+		StateBackend stateBackend,
+		ProcessingTimeService processingTimeService) {
+		return new StreamTaskStateInitializerImpl(
+			env,
+			stateBackend,
+			processingTimeService);
+	}
+
 	public void setStateBackend(StateBackend stateBackend) {
 		this.stateBackend = stateBackend;
 	}
@@ -261,8 +259,8 @@ public class AbstractStreamOperatorTestHarness<OUT> implements AutoCloseable {
 		return mockTask.getCheckpointLock();
 	}
 
-	public Environment getEnvironment() {
-		return this.mockTask.getEnvironment();
+	public MockEnvironment getEnvironment() {
+		return environment;
 	}
 
 	public ExecutionConfig getExecutionConfig() {
@@ -306,12 +304,17 @@ public class AbstractStreamOperatorTestHarness<OUT> implements AutoCloseable {
 	 * Calls {@link StreamOperator#setup(StreamTask, StreamConfig, Output)} ()}.
 	 */
 	public void setup(TypeSerializer<OUT> outputSerializer) {
-		operator.setup(mockTask, config, new MockOutput(outputSerializer));
-		setupCalled = true;
+		if (!setupCalled) {
+			this.streamTaskStateInitializer =
+				createStreamTaskStateManager(environment, stateBackend, processingTimeService);
+			when(mockTask.createStreamTaskStateInitializer()).thenReturn(streamTaskStateInitializer);
+			operator.setup(mockTask, config, new MockOutput(outputSerializer));
+			setupCalled = true;
+		}
 	}
 
 	/**
-	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#initializeState(OperatorSubtaskState)}.
+	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#initializeState()}.
 	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#setup(StreamTask, StreamConfig, Output)}
 	 * if it was not called before.
 	 *
@@ -368,16 +371,20 @@ public class AbstractStreamOperatorTestHarness<OUT> implements AutoCloseable {
 					rawOperatorState,
 					numSubtasks).get(subtaskIndex);
 
-			OperatorSubtaskState massagedOperatorStateHandles = new OperatorSubtaskState(
+			OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(
 				nullToEmptyCollection(localManagedOperatorState),
 				nullToEmptyCollection(localRawOperatorState),
 				nullToEmptyCollection(localManagedKeyGroupState),
 				nullToEmptyCollection(localRawKeyGroupState));
 
-			operator.initializeState(massagedOperatorStateHandles);
-		} else {
-			operator.initializeState(null);
+			TaskStateSnapshot taskStateSnapshot = new TaskStateSnapshot();
+			taskStateSnapshot.putSubtaskStateByOperatorID(operator.getOperatorID(), operatorSubtaskState);
+
+			taskStateManager.setReportedCheckpointId(0);
+			taskStateManager.setTaskStateSnapshotsByCheckpointId(Collections.singletonMap(0L, taskStateSnapshot));
 		}
+
+		operator.initializeState();
 		initializeCalled = true;
 	}
 
@@ -487,10 +494,10 @@ public class AbstractStreamOperatorTestHarness<OUT> implements AutoCloseable {
 	}
 
 	/**
-	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#notifyOfCompletedCheckpoint(long)} ()}.
+	 * Calls {@link org.apache.flink.streaming.api.operators.StreamOperator#notifyCheckpointComplete(long)} ()}.
 	 */
 	public void notifyOfCompletedCheckpoint(long checkpointId) throws Exception {
-		operator.notifyOfCompletedCheckpoint(checkpointId);
+		operator.notifyCheckpointComplete(checkpointId);
 	}
 
 	/**

http://git-wip-us.apache.org/repos/asf/flink/blob/517b3f87/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
index c2ec63a..2035c46 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedOneInputStreamOperatorTestHarness.java
@@ -18,32 +18,14 @@
 
 package org.apache.flink.streaming.util;
 
-import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.ClosureCleaner;
 import org.apache.flink.api.java.functions.KeySelector;
-import org.apache.flink.runtime.checkpoint.StateAssignmentOperation;
-import org.apache.flink.runtime.execution.Environment;
-import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
-import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.runtime.state.KeyedStateBackend;
-import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
-import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
-import org.apache.flink.util.Migration;
-
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
-
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.List;
-
-import static org.mockito.Matchers.any;
-import static org.mockito.Mockito.anyInt;
-import static org.mockito.Mockito.doAnswer;
 
 /**
  * Extension of {@link OneInputStreamOperatorTestHarness} that allows the operator to get
@@ -52,14 +34,6 @@ import static org.mockito.Mockito.doAnswer;
 public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 		extends OneInputStreamOperatorTestHarness<IN, OUT> {
 
-	// in case the operator creates one we store it here so that we
-	// can snapshot its state
-	private AbstractKeyedStateBackend<?> keyedStateBackend = null;
-
-	// when we restore we keep the state here so that we can call restore
-	// when the operator requests the keyed state backend
-	private List<KeyedStateHandle> restoredKeyedState = null;
-
 	public KeyedOneInputStreamOperatorTestHarness(
 			OneInputStreamOperator<IN, OUT> operator,
 			final KeySelector<IN, K> keySelector,
@@ -72,8 +46,6 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 		ClosureCleaner.clean(keySelector, false);
 		config.setStatePartitioner(0, keySelector);
 		config.setStateKeySerializer(keyType.createSerializer(executionConfig));
-
-		setupMockTaskCreateKeyedBackend();
 	}
 
 	public KeyedOneInputStreamOperatorTestHarness(
@@ -87,61 +59,18 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 			final OneInputStreamOperator<IN, OUT> operator,
 			final  KeySelector<IN, K> keySelector,
 			final TypeInformation<K> keyType,
-			final Environment environment) throws Exception {
+			final MockEnvironment environment) throws Exception {
 
 		super(operator, environment);
 
 		ClosureCleaner.clean(keySelector, false);
 		config.setStatePartitioner(0, keySelector);
 		config.setStateKeySerializer(keyType.createSerializer(executionConfig));
-
-		setupMockTaskCreateKeyedBackend();
-	}
-
-	private void setupMockTaskCreateKeyedBackend() {
-
-		try {
-			doAnswer(new Answer<KeyedStateBackend>() {
-				@Override
-				public KeyedStateBackend answer(InvocationOnMock invocationOnMock) throws Throwable {
-
-					final TypeSerializer keySerializer = (TypeSerializer) invocationOnMock.getArguments()[0];
-					final int numberOfKeyGroups = (Integer) invocationOnMock.getArguments()[1];
-					final KeyGroupRange keyGroupRange = (KeyGroupRange) invocationOnMock.getArguments()[2];
-
-					if (keyedStateBackend != null) {
-						keyedStateBackend.dispose();
-					}
-
-					keyedStateBackend = stateBackend.createKeyedStateBackend(
-							mockTask.getEnvironment(),
-							new JobID(),
-							"test_op",
-							keySerializer,
-							numberOfKeyGroups,
-							keyGroupRange,
-							mockTask.getEnvironment().getTaskKvStateRegistry());
-
-					keyedStateBackend.restore(restoredKeyedState);
-
-					return keyedStateBackend;
-				}
-			}).when(mockTask).createKeyedStateBackend(any(TypeSerializer.class), anyInt(), any(KeyGroupRange.class));
-		} catch (Exception e) {
-			throw new RuntimeException(e.getMessage(), e);
-		}
-	}
-
-	private static boolean hasMigrationHandles(Collection<KeyedStateHandle> allKeyGroupsHandles) {
-		for (KeyedStateHandle handle : allKeyGroupsHandles) {
-			if (handle instanceof Migration) {
-				return true;
-			}
-		}
-		return false;
 	}
 
 	public int numKeyedStateEntries() {
+		AbstractStreamOperator<?> abstractStreamOperator = (AbstractStreamOperator<?>) operator;
+		KeyedStateBackend<Object> keyedStateBackend = abstractStreamOperator.getKeyedStateBackend();
 		if (keyedStateBackend instanceof HeapKeyedStateBackend) {
 			return ((HeapKeyedStateBackend) keyedStateBackend).numStateEntries();
 		} else {
@@ -150,47 +79,12 @@ public class KeyedOneInputStreamOperatorTestHarness<K, IN, OUT>
 	}
 
 	public <N> int numKeyedStateEntries(N namespace) {
+		AbstractStreamOperator<?> abstractStreamOperator = (AbstractStreamOperator<?>) operator;
+		KeyedStateBackend<Object> keyedStateBackend = abstractStreamOperator.getKeyedStateBackend();
 		if (keyedStateBackend instanceof HeapKeyedStateBackend) {
 			return ((HeapKeyedStateBackend) keyedStateBackend).numStateEntries(namespace);
 		} else {
 			throw new UnsupportedOperationException();
 		}
 	}
-
-	@Override
-	public void initializeState(OperatorStateHandles operatorStateHandles) throws Exception {
-		if (operatorStateHandles != null) {
-			int numKeyGroups = getEnvironment().getTaskInfo().getMaxNumberOfParallelSubtasks();
-			int numSubtasks = getEnvironment().getTaskInfo().getNumberOfParallelSubtasks();
-			int subtaskIndex = getEnvironment().getTaskInfo().getIndexOfThisSubtask();
-
-			// create a new OperatorStateHandles that only contains the state for our key-groups
-
-			List<KeyGroupRange> keyGroupPartitions = StateAssignmentOperation.createKeyGroupPartitions(
-					numKeyGroups,
-					numSubtasks);
-
-			KeyGroupRange localKeyGroupRange =
-					keyGroupPartitions.get(subtaskIndex);
-
-			restoredKeyedState = null;
-			Collection<KeyedStateHandle> managedKeyedState = operatorStateHandles.getManagedKeyedState();
-			if (managedKeyedState != null) {
-
-				// if we have migration handles, don't reshuffle state and preserve
-				// the migration tag
-				if (hasMigrationHandles(managedKeyedState)) {
-					List<KeyedStateHandle> result = new ArrayList<>(managedKeyedState.size());
-					result.addAll(managedKeyedState);
-					restoredKeyedState = result;
-				} else {
-					restoredKeyedState = StateAssignmentOperation.getKeyedStateHandles(
-							managedKeyedState,
-							localKeyGroupRange);
-				}
-			}
-		}
-
-		super.initializeState(operatorStateHandles);
-	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/517b3f87/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java
index b0500ca..607eee0 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/KeyedTwoInputStreamOperatorTestHarness.java
@@ -18,27 +18,13 @@
 
 package org.apache.flink.streaming.util;
 
-import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
-import org.apache.flink.api.common.typeutils.TypeSerializer;
 import org.apache.flink.api.java.ClosureCleaner;
 import org.apache.flink.api.java.functions.KeySelector;
-import org.apache.flink.runtime.state.AbstractKeyedStateBackend;
-import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyedStateBackend;
-import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
-import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
-
-import org.mockito.invocation.InvocationOnMock;
-import org.mockito.stubbing.Answer;
-
-import java.util.Collection;
-
-import static org.mockito.Matchers.any;
-import static org.mockito.Mockito.anyInt;
-import static org.mockito.Mockito.doAnswer;
 
 /**
  * Extension of {@link TwoInputStreamOperatorTestHarness} that allows the operator to get
@@ -47,14 +33,6 @@ import static org.mockito.Mockito.doAnswer;
 public class KeyedTwoInputStreamOperatorTestHarness<K, IN1, IN2, OUT>
 		extends TwoInputStreamOperatorTestHarness<IN1, IN2, OUT> {
 
-	// in case the operator creates one we store it here so that we
-	// can snapshot its state
-	private AbstractKeyedStateBackend<?> keyedStateBackend = null;
-
-	// when we restore we keep the state here so that we can call restore
-	// when the operator requests the keyed state backend
-	private Collection<KeyedStateHandle> restoredKeyedState = null;
-
 	public KeyedTwoInputStreamOperatorTestHarness(
 			TwoInputStreamOperator<IN1, IN2, OUT> operator,
 			KeySelector<IN1, K> keySelector1,
@@ -70,8 +48,6 @@ public class KeyedTwoInputStreamOperatorTestHarness<K, IN1, IN2, OUT>
 		config.setStatePartitioner(0, keySelector1);
 		config.setStatePartitioner(1, keySelector2);
 		config.setStateKeySerializer(keyType.createSerializer(executionConfig));
-
-		setupMockTaskCreateKeyedBackend();
 	}
 
 	public KeyedTwoInputStreamOperatorTestHarness(
@@ -82,50 +58,9 @@ public class KeyedTwoInputStreamOperatorTestHarness<K, IN1, IN2, OUT>
 		this(operator, keySelector1, keySelector2, keyType, 1, 1, 0);
 	}
 
-	private void setupMockTaskCreateKeyedBackend() {
-
-		try {
-			doAnswer(new Answer<KeyedStateBackend>() {
-				@Override
-				public KeyedStateBackend answer(InvocationOnMock invocationOnMock) throws Throwable {
-
-					final TypeSerializer keySerializer = (TypeSerializer) invocationOnMock.getArguments()[0];
-					final int numberOfKeyGroups = (Integer) invocationOnMock.getArguments()[1];
-					final KeyGroupRange keyGroupRange = (KeyGroupRange) invocationOnMock.getArguments()[2];
-
-					if (keyedStateBackend != null) {
-						keyedStateBackend.close();
-					}
-
-					keyedStateBackend = stateBackend.createKeyedStateBackend(
-							mockTask.getEnvironment(),
-							new JobID(),
-							"test_op",
-							keySerializer,
-							numberOfKeyGroups,
-							keyGroupRange,
-							mockTask.getEnvironment().getTaskKvStateRegistry());
-					if (restoredKeyedState != null) {
-						keyedStateBackend.restore(restoredKeyedState);
-					}
-					return keyedStateBackend;
-				}
-			}).when(mockTask).createKeyedStateBackend(any(TypeSerializer.class), anyInt(), any(KeyGroupRange.class));
-		} catch (Exception e) {
-			throw new RuntimeException(e.getMessage(), e);
-		}
-	}
-
-	@Override
-	public void initializeState(OperatorStateHandles operatorStateHandles) throws Exception {
-		if (restoredKeyedState != null) {
-			restoredKeyedState = operatorStateHandles.getManagedKeyedState();
-		}
-
-		super.initializeState(operatorStateHandles);
-	}
-
 	public int numKeyedStateEntries() {
+		AbstractStreamOperator<?> abstractStreamOperator = (AbstractStreamOperator<?>) operator;
+		KeyedStateBackend<Object> keyedStateBackend = abstractStreamOperator.getKeyedStateBackend();
 		if (keyedStateBackend instanceof HeapKeyedStateBackend) {
 			return ((HeapKeyedStateBackend) keyedStateBackend).numStateEntries();
 		} else {

http://git-wip-us.apache.org/repos/asf/flink/blob/517b3f87/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
index 5c7d986..66d2f69 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/OneInputStreamOperatorTestHarness.java
@@ -19,7 +19,7 @@
 package org.apache.flink.streaming.util;
 
 import org.apache.flink.api.common.typeutils.TypeSerializer;
-import org.apache.flink.runtime.execution.Environment;
+import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.watermark.Watermark;
 import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
@@ -63,7 +63,7 @@ public class OneInputStreamOperatorTestHarness<IN, OUT>
 	public OneInputStreamOperatorTestHarness(
 		OneInputStreamOperator<IN, OUT> operator,
 		TypeSerializer<IN> typeSerializerIn,
-		Environment environment) throws Exception {
+		MockEnvironment environment) throws Exception {
 		this(operator, environment);
 
 		config.setTypeSerializerIn1(Preconditions.checkNotNull(typeSerializerIn));
@@ -85,7 +85,7 @@ public class OneInputStreamOperatorTestHarness<IN, OUT>
 
 	public OneInputStreamOperatorTestHarness(
 		OneInputStreamOperator<IN, OUT> operator,
-		Environment environment) throws Exception {
+		MockEnvironment environment) throws Exception {
 		super(operator, environment);
 
 		this.oneInputOperator = operator;

http://git-wip-us.apache.org/repos/asf/flink/blob/517b3f87/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/SourceFunctionUtil.java
----------------------------------------------------------------------
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/SourceFunctionUtil.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/SourceFunctionUtil.java
index 312891e..3f54081 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/SourceFunctionUtil.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/util/SourceFunctionUtil.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.functions.RuntimeContext;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.operators.testutils.MockEnvironment;
 import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider;
+import org.apache.flink.runtime.state.TestTaskStateManager;
 import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
 import org.apache.flink.streaming.api.operators.StreamingRuntimeContext;
@@ -50,7 +51,13 @@ public class SourceFunctionUtil {
 	}
 
 	private static <T extends Serializable> List<T> runRichSourceFunction(SourceFunction<T> sourceFunction) throws Exception {
-		try (MockEnvironment environment = new MockEnvironment("MockTask", 3 * 1024 * 1024, new MockInputSplitProvider(), 1024)) {
+		try (MockEnvironment environment = new MockEnvironment(
+			"MockTask",
+			3 * 1024 * 1024,
+			new MockInputSplitProvider(),
+			1024,
+			new TestTaskStateManager())) {
+
 			AbstractStreamOperator<?> operator = mock(AbstractStreamOperator.class);
 			when(operator.getExecutionConfig()).thenReturn(new ExecutionConfig());
 

http://git-wip-us.apache.org/repos/asf/flink/blob/517b3f87/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
index 21ce77a..9c023d4 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
@@ -38,6 +38,7 @@ import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.runtime.state.memory.MemoryBackendCheckpointStorage;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.util.ExceptionUtils;
 
 import org.junit.Test;
 
@@ -86,8 +87,7 @@ public class StateBackendITCase extends AbstractTestBase {
 			fail();
 		}
 		catch (JobExecutionException e) {
-			Throwable t = e.getCause();
-			assertTrue("wrong exception", t instanceof SuccessException);
+			assertTrue(ExceptionUtils.findThrowable(e, SuccessException.class).isPresent());
 		}
 	}
 
@@ -113,7 +113,7 @@ public class StateBackendITCase extends AbstractTestBase {
 		@Override
 		public CheckpointStreamFactory createSavepointStreamFactory(JobID jobId,
 			String operatorIdentifier, String targetLocation) throws IOException {
-			throw new UnsupportedOperationException();
+			throw new SuccessException();
 		}
 
 		@Override
@@ -133,7 +133,7 @@ public class StateBackendITCase extends AbstractTestBase {
 			Environment env,
 			String operatorIdentifier) throws Exception {
 
-			throw new UnsupportedOperationException();
+			throw new SuccessException();
 		}
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/517b3f87/flink-tests/src/test/java/org/apache/flink/test/typeserializerupgrade/PojoSerializerUpgradeTest.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/typeserializerupgrade/PojoSerializerUpgradeTest.java b/flink-tests/src/test/java/org/apache/flink/test/typeserializerupgrade/PojoSerializerUpgradeTest.java
index ffe220e..908666a 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/typeserializerupgrade/PojoSerializerUpgradeTest.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/typeserializerupgrade/PojoSerializerUpgradeTest.java
@@ -41,6 +41,7 @@ import org.apache.flink.runtime.state.FunctionInitializationContext;
 import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.apache.flink.runtime.state.StateBackend;
 import org.apache.flink.runtime.state.StateBackendLoader;
+import org.apache.flink.runtime.state.TestTaskStateManager;
 import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
 import org.apache.flink.streaming.api.operators.StreamMap;
@@ -48,6 +49,7 @@ import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
 import org.apache.flink.streaming.util.KeyedOneInputStreamOperatorTestHarness;
 import org.apache.flink.streaming.util.OneInputStreamOperatorTestHarness;
 import org.apache.flink.util.DynamicCodeLoadingException;
+import org.apache.flink.util.IOUtils;
 import org.apache.flink.util.StateMigrationException;
 import org.apache.flink.util.TestLogger;
 
@@ -350,49 +352,49 @@ public class PojoSerializerUpgradeTest extends TestLogger {
 			Iterable<Long> input) throws Exception {
 
 		try (final MockEnvironment environment = new MockEnvironment(
-				"test task",
-				32 * 1024,
-				new MockInputSplitProvider(),
-				256,
-				taskConfiguration,
-				executionConfig,
-				16,
-				1,
-				0,
-				classLoader)) {
-
-			OneInputStreamOperatorTestHarness<Long, Long> harness;
-
-			if (isKeyedState) {
-				harness = new KeyedOneInputStreamOperatorTestHarness<>(
-					operator,
-					keySelector,
-					BasicTypeInfo.LONG_TYPE_INFO,
-					environment);
-			} else {
-				harness = new OneInputStreamOperatorTestHarness<>(operator, LongSerializer.INSTANCE, environment);
-			}
-
-			harness.setStateBackend(stateBackend);
-
-			harness.setup();
-			harness.initializeState(operatorStateHandles);
-			harness.open();
+			"test task",
+			32 * 1024,
+			new MockInputSplitProvider(),
+			256,
+			taskConfiguration,
+			executionConfig,
+			16,
+			1,
+			0,
+			classLoader,
+			new TestTaskStateManager())) {
+
+			OneInputStreamOperatorTestHarness<Long, Long> harness = null;
+			try {
+				if (isKeyedState) {
+					harness = new KeyedOneInputStreamOperatorTestHarness<>(
+						operator,
+						keySelector,
+						BasicTypeInfo.LONG_TYPE_INFO,
+						environment);
+				} else {
+					harness = new OneInputStreamOperatorTestHarness<>(operator, LongSerializer.INSTANCE, environment);
+				}
 
-			long timestamp = 0L;
+				harness.setStateBackend(stateBackend);
 
-			for (Long value : input) {
-				harness.processElement(value, timestamp++);
-			}
+				harness.setup();
+				harness.initializeState(operatorStateHandles);
+				harness.open();
 
-			long checkpointId = 1L;
-			long checkpointTimestamp = timestamp + 1L;
+				long timestamp = 0L;
 
-			OperatorStateHandles stateHandles = harness.snapshot(checkpointId, checkpointTimestamp);
+				for (Long value : input) {
+					harness.processElement(value, timestamp++);
+				}
 
-			harness.close();
+				long checkpointId = 1L;
+				long checkpointTimestamp = timestamp + 1L;
 
-			return stateHandles;
+				return harness.snapshot(checkpointId, checkpointTimestamp);
+			} finally {
+				IOUtils.closeQuietly(harness);
+			}
 		}
 	}
 


Mime
View raw message