flink-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From srich...@apache.org
Subject [6/7] flink git commit: [FLINK-7213] Introduce state management by OperatorID in TaskManager
Date Tue, 15 Aug 2017 12:57:14 GMT
http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
index 344b340..88b95f5 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorFailureTest.java
@@ -23,14 +23,15 @@ import org.apache.flink.runtime.concurrent.Executors;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.ExecutionVertex;
 import org.apache.flink.runtime.jobgraph.JobStatus;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
-import org.apache.flink.runtime.state.ChainedStateHandle;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.StreamStateHandle;
 import org.apache.flink.util.TestLogger;
+
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.powermock.core.classloader.annotations.PrepareForTest;
@@ -42,8 +43,8 @@ import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
-import static org.mockito.Matchers.anyInt;
 import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
 
@@ -89,29 +90,26 @@ public class CheckpointCoordinatorFailureTest extends TestLogger {
 		assertFalse(pendingCheckpoint.isDiscarded());
 
 		final long checkpointId = coord.getPendingCheckpoints().keySet().iterator().next();
-		
-		SubtaskState subtaskState = mock(SubtaskState.class);
+
 
 		StreamStateHandle legacyHandle = mock(StreamStateHandle.class);
-		ChainedStateHandle<StreamStateHandle> chainedLegacyHandle = mock(ChainedStateHandle.class);
-		when(chainedLegacyHandle.get(anyInt())).thenReturn(legacyHandle);
-		when(subtaskState.getLegacyOperatorState()).thenReturn(chainedLegacyHandle);
+		KeyedStateHandle managedKeyedHandle = mock(KeyedStateHandle.class);
+		KeyedStateHandle rawKeyedHandle = mock(KeyedStateHandle.class);
+		OperatorStateHandle managedOpHandle = mock(OperatorStateHandle.class);
+		OperatorStateHandle rawOpHandle = mock(OperatorStateHandle.class);
 
-		OperatorStateHandle managedHandle = mock(OperatorStateHandle.class);
-		ChainedStateHandle<OperatorStateHandle> chainedManagedHandle = mock(ChainedStateHandle.class);
-		when(chainedManagedHandle.get(anyInt())).thenReturn(managedHandle);
-		when(subtaskState.getManagedOperatorState()).thenReturn(chainedManagedHandle);
+		final OperatorSubtaskState operatorSubtaskState = spy(new OperatorSubtaskState(
+			legacyHandle,
+			managedOpHandle,
+			rawOpHandle,
+			managedKeyedHandle,
+			rawKeyedHandle));
 
-		OperatorStateHandle rawHandle = mock(OperatorStateHandle.class);
-		ChainedStateHandle<OperatorStateHandle> chainedRawHandle = mock(ChainedStateHandle.class);
-		when(chainedRawHandle.get(anyInt())).thenReturn(rawHandle);
-		when(subtaskState.getRawOperatorState()).thenReturn(chainedRawHandle);
+		TaskStateSnapshot subtaskState = spy(new TaskStateSnapshot());
+		subtaskState.putSubtaskStateByOperatorID(new OperatorID(), operatorSubtaskState);
+
+		when(subtaskState.getSubtaskStateByOperatorID(OperatorID.fromJobVertexID(vertex.getJobvertexId()))).thenReturn(operatorSubtaskState);
 
-		KeyedStateHandle managedKeyedHandle = mock(KeyedStateHandle.class);
-		when(subtaskState.getRawKeyedState()).thenReturn(managedKeyedHandle);
-		KeyedStateHandle managedRawHandle = mock(KeyedStateHandle.class);
-		when(subtaskState.getManagedKeyedState()).thenReturn(managedRawHandle);
-		
 		AcknowledgeCheckpoint acknowledgeMessage = new AcknowledgeCheckpoint(jid, executionAttemptId, checkpointId, new CheckpointMetrics(), subtaskState);
 		
 		try {
@@ -126,11 +124,12 @@ public class CheckpointCoordinatorFailureTest extends TestLogger {
 		assertTrue(pendingCheckpoint.isDiscarded());
 
 		// make sure that the subtask state has been discarded after we could not complete it.
-		verify(subtaskState.getLegacyOperatorState().get(0)).discardState();
-		verify(subtaskState.getManagedOperatorState().get(0)).discardState();
-		verify(subtaskState.getRawOperatorState().get(0)).discardState();
-		verify(subtaskState.getManagedKeyedState()).discardState();
-		verify(subtaskState.getRawKeyedState()).discardState();
+		verify(operatorSubtaskState).discardState();
+		verify(operatorSubtaskState.getLegacyOperatorState()).discardState();
+		verify(operatorSubtaskState.getManagedOperatorState().iterator().next()).discardState();
+		verify(operatorSubtaskState.getRawOperatorState().iterator().next()).discardState();
+		verify(operatorSubtaskState.getManagedKeyedState().iterator().next()).discardState();
+		verify(operatorSubtaskState.getRawKeyedState().iterator().next()).discardState();
 	}
 
 	private static final class FailingCompletedCheckpointStore implements CompletedCheckpointStore {

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
index cb92df6..d9af879 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java
@@ -44,7 +44,6 @@ import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.SharedStateRegistry;
 import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.state.filesystem.FileStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.runtime.testutils.CommonTestUtils;
@@ -93,7 +92,6 @@ import static org.junit.Assert.fail;
 import static org.mockito.Matchers.any;
 import static org.mockito.Matchers.anyLong;
 import static org.mockito.Mockito.doAnswer;
-import static org.mockito.Mockito.doReturn;
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.mock;
 import static org.mockito.Mockito.never;
@@ -102,7 +100,6 @@ import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.times;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
-import static org.mockito.Mockito.withSettings;
 
 /**
  * Tests for the checkpoint coordinator.
@@ -555,31 +552,29 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			assertFalse(checkpoint.isDiscarded());
 			assertFalse(checkpoint.isFullyAcknowledged());
 
-			OperatorID opID1 = OperatorID.fromJobVertexID(vertex1.getJobvertexId());
-			OperatorID opID2 = OperatorID.fromJobVertexID(vertex2.getJobvertexId());
-
-			Map<OperatorID, OperatorState> operatorStates = checkpoint.getOperatorStates();
-
-			operatorStates.put(opID1, new SpyInjectingOperatorState(
-				opID1, vertex1.getTotalNumberOfParallelSubtasks(), vertex1.getMaxParallelism()));
-			operatorStates.put(opID2, new SpyInjectingOperatorState(
-				opID2, vertex2.getTotalNumberOfParallelSubtasks(), vertex2.getMaxParallelism()));
-
 			// check that the vertices received the trigger checkpoint message
 			{
 				verify(vertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId), eq(timestamp), any(CheckpointOptions.class));
 				verify(vertex2.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId), eq(timestamp), any(CheckpointOptions.class));
 			}
 
+			OperatorID opID1 = OperatorID.fromJobVertexID(vertex1.getJobvertexId());
+			OperatorID opID2 = OperatorID.fromJobVertexID(vertex2.getJobvertexId());
+			TaskStateSnapshot taskOperatorSubtaskStates1 = mock(TaskStateSnapshot.class);
+			TaskStateSnapshot taskOperatorSubtaskStates2 = mock(TaskStateSnapshot.class);
+			OperatorSubtaskState subtaskState1 = mock(OperatorSubtaskState.class);
+			OperatorSubtaskState subtaskState2 = mock(OperatorSubtaskState.class);
+			when(taskOperatorSubtaskStates1.getSubtaskStateByOperatorID(opID1)).thenReturn(subtaskState1);
+			when(taskOperatorSubtaskStates2.getSubtaskStateByOperatorID(opID2)).thenReturn(subtaskState2);
+
 			// acknowledge from one of the tasks
-			AcknowledgeCheckpoint acknowledgeCheckpoint1 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class));
+			AcknowledgeCheckpoint acknowledgeCheckpoint1 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStates2);
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint1);
-			OperatorSubtaskState subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex());
 			assertEquals(1, checkpoint.getNumberOfAcknowledgedTasks());
 			assertEquals(1, checkpoint.getNumberOfNonAcknowledgedTasks());
 			assertFalse(checkpoint.isDiscarded());
 			assertFalse(checkpoint.isFullyAcknowledged());
-			verify(subtaskState2, never()).registerSharedStates(any(SharedStateRegistry.class));
+			verify(taskOperatorSubtaskStates2, never()).registerSharedStates(any(SharedStateRegistry.class));
 
 			// acknowledge the same task again (should not matter)
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint1);
@@ -588,8 +583,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			verify(subtaskState2, never()).registerSharedStates(any(SharedStateRegistry.class));
 
 			// acknowledge the other task.
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class)));
-			OperatorSubtaskState subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex());
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStates1));
 
 			// the checkpoint is internally converted to a successful checkpoint and the
 			// pending checkpoint object is disposed
@@ -628,9 +622,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 			long checkpointIdNew = coord.getPendingCheckpoints().entrySet().iterator().next().getKey();
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointIdNew));
-			subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex());
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointIdNew));
-			subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex());
 
 			assertEquals(0, coord.getNumberOfPendingCheckpoints());
 			assertEquals(1, coord.getNumberOfRetainedSuccessfulCheckpoints());
@@ -852,18 +844,20 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			OperatorID opID2 = OperatorID.fromJobVertexID(ackVertex2.getJobvertexId());
 			OperatorID opID3 = OperatorID.fromJobVertexID(ackVertex3.getJobvertexId());
 
-			Map<OperatorID, OperatorState> operatorStates1 = pending1.getOperatorStates();
+			TaskStateSnapshot taskOperatorSubtaskStates1_1 = spy(new TaskStateSnapshot());
+			TaskStateSnapshot taskOperatorSubtaskStates1_2 = spy(new TaskStateSnapshot());
+			TaskStateSnapshot taskOperatorSubtaskStates1_3 = spy(new TaskStateSnapshot());
 
-			operatorStates1.put(opID1, new SpyInjectingOperatorState(
-				opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism()));
-			operatorStates1.put(opID2, new SpyInjectingOperatorState(
-				opID2, ackVertex2.getTotalNumberOfParallelSubtasks(), ackVertex2.getMaxParallelism()));
-			operatorStates1.put(opID3, new SpyInjectingOperatorState(
-				opID3, ackVertex3.getTotalNumberOfParallelSubtasks(), ackVertex3.getMaxParallelism()));
+			OperatorSubtaskState subtaskState1_1 = mock(OperatorSubtaskState.class);
+			OperatorSubtaskState subtaskState1_2 = mock(OperatorSubtaskState.class);
+			OperatorSubtaskState subtaskState1_3 = mock(OperatorSubtaskState.class);
+			taskOperatorSubtaskStates1_1.putSubtaskStateByOperatorID(opID1, subtaskState1_1);
+			taskOperatorSubtaskStates1_2.putSubtaskStateByOperatorID(opID2, subtaskState1_2);
+			taskOperatorSubtaskStates1_3.putSubtaskStateByOperatorID(opID3, subtaskState1_3);
 
 			// acknowledge one of the three tasks
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new CheckpointMetrics(), mock(SubtaskState.class)));
-			OperatorSubtaskState subtaskState1_2 = operatorStates1.get(opID2).getState(ackVertex2.getParallelSubtaskIndex());
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates1_2));
+
 			// start the second checkpoint
 			// trigger the first checkpoint. this should succeed
 			assertTrue(coord.triggerCheckpoint(timestamp2, false));
@@ -880,14 +874,17 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			}
 			long checkpointId2 = pending2.getCheckpointId();
 
-			Map<OperatorID, OperatorState> operatorStates2 = pending2.getOperatorStates();
+			TaskStateSnapshot taskOperatorSubtaskStates2_1 = spy(new TaskStateSnapshot());
+			TaskStateSnapshot taskOperatorSubtaskStates2_2 = spy(new TaskStateSnapshot());
+			TaskStateSnapshot taskOperatorSubtaskStates2_3 = spy(new TaskStateSnapshot());
+
+			OperatorSubtaskState subtaskState2_1 = mock(OperatorSubtaskState.class);
+			OperatorSubtaskState subtaskState2_2 = mock(OperatorSubtaskState.class);
+			OperatorSubtaskState subtaskState2_3 = mock(OperatorSubtaskState.class);
 
-			operatorStates2.put(opID1, new SpyInjectingOperatorState(
-				opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism()));
-			operatorStates2.put(opID2, new SpyInjectingOperatorState(
-				opID2, ackVertex2.getTotalNumberOfParallelSubtasks(), ackVertex2.getMaxParallelism()));
-			operatorStates2.put(opID3, new SpyInjectingOperatorState(
-				opID3, ackVertex3.getTotalNumberOfParallelSubtasks(), ackVertex3.getMaxParallelism()));
+			taskOperatorSubtaskStates2_1.putSubtaskStateByOperatorID(opID1, subtaskState2_1);
+			taskOperatorSubtaskStates2_2.putSubtaskStateByOperatorID(opID2, subtaskState2_2);
+			taskOperatorSubtaskStates2_3.putSubtaskStateByOperatorID(opID3, subtaskState2_3);
 
 			// trigger messages should have been sent
 			verify(triggerVertex1.getCurrentExecutionAttempt(), times(1)).triggerCheckpoint(eq(checkpointId2), eq(timestamp2), any(CheckpointOptions.class));
@@ -896,17 +893,13 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			// we acknowledge one more task from the first checkpoint and the second
 			// checkpoint completely. The second checkpoint should then subsume the first checkpoint
 
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new CheckpointMetrics(), mock(SubtaskState.class)));
-			OperatorSubtaskState subtaskState2_3 = operatorStates2.get(opID3).getState(ackVertex3.getParallelSubtaskIndex());
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates2_3));
 
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new CheckpointMetrics(), mock(SubtaskState.class)));
-			OperatorSubtaskState subtaskState2_1 = operatorStates2.get(opID1).getState(ackVertex1.getParallelSubtaskIndex());
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates2_1));
 
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new CheckpointMetrics(), mock(SubtaskState.class)));
-			OperatorSubtaskState subtaskState1_1 = operatorStates1.get(opID1).getState(ackVertex1.getParallelSubtaskIndex());
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates1_1));
 
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new CheckpointMetrics(), mock(SubtaskState.class)));
-			OperatorSubtaskState subtaskState2_2 = operatorStates2.get(opID2).getState(ackVertex2.getParallelSubtaskIndex());
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID2, checkpointId2, new CheckpointMetrics(), taskOperatorSubtaskStates2_2));
 
 			// now, the second checkpoint should be confirmed, and the first discarded
 			// actually both pending checkpoints are discarded, and the second has been transformed
@@ -938,8 +931,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			verify(commitVertex.getCurrentExecutionAttempt(), times(1)).notifyCheckpointComplete(eq(checkpointId2), eq(timestamp2));
 
 			// send the last remaining ack for the first checkpoint. This should not do anything
-			SubtaskState subtaskState1_3 = mock(SubtaskState.class);
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new CheckpointMetrics(), subtaskState1_3));
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID3, checkpointId1, new CheckpointMetrics(), taskOperatorSubtaskStates1_3));
 			verify(subtaskState1_3, times(1)).discardState();
 
 			coord.shutdown(JobStatus.FINISHED);
@@ -1005,13 +997,11 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 			OperatorID opID1 = OperatorID.fromJobVertexID(ackVertex1.getJobvertexId());
 
-			Map<OperatorID, OperatorState> operatorStates = checkpoint.getOperatorStates();
+			TaskStateSnapshot taskOperatorSubtaskStates1 = spy(new TaskStateSnapshot());
+			OperatorSubtaskState subtaskState1 = mock(OperatorSubtaskState.class);
+			taskOperatorSubtaskStates1.putSubtaskStateByOperatorID(opID1, subtaskState1);
 
-			operatorStates.put(opID1, new SpyInjectingOperatorState(
-				opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism()));
-
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpoint.getCheckpointId(), new CheckpointMetrics(), mock(SubtaskState.class)));
-			OperatorSubtaskState subtaskState = operatorStates.get(opID1).getState(ackVertex1.getParallelSubtaskIndex());
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, ackAttemptID1, checkpoint.getCheckpointId(), new CheckpointMetrics(), taskOperatorSubtaskStates1));
 
 			// wait until the checkpoint must have expired.
 			// we check every 250 msecs conservatively for 5 seconds
@@ -1029,7 +1019,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
 
 			// validate that the received states have been discarded
-			verify(subtaskState, times(1)).discardState();
+			verify(subtaskState1, times(1)).discardState();
 
 			// no confirm message must have been sent
 			verify(commitVertex.getCurrentExecutionAttempt(), times(0)).notifyCheckpointComplete(anyLong(), anyLong());
@@ -1147,26 +1137,18 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		long checkpointId = pendingCheckpoint.getCheckpointId();
 
 		OperatorID opIDtrigger = OperatorID.fromJobVertexID(triggerVertex.getJobvertexId());
-		OperatorID opID1 = OperatorID.fromJobVertexID(ackVertex1.getJobvertexId());
-		OperatorID opID2 = OperatorID.fromJobVertexID(ackVertex2.getJobvertexId());
-
-		Map<OperatorID, OperatorState> operatorStates = pendingCheckpoint.getOperatorStates();
 
-		operatorStates.put(opIDtrigger, new SpyInjectingOperatorState(
-			opIDtrigger, triggerVertex.getTotalNumberOfParallelSubtasks(), triggerVertex.getMaxParallelism()));
-		operatorStates.put(opID1, new SpyInjectingOperatorState(
-			opID1, ackVertex1.getTotalNumberOfParallelSubtasks(), ackVertex1.getMaxParallelism()));
-		operatorStates.put(opID2, new SpyInjectingOperatorState(
-			opID2, ackVertex2.getTotalNumberOfParallelSubtasks(), ackVertex2.getMaxParallelism()));
+		TaskStateSnapshot taskOperatorSubtaskStatesTrigger = spy(new TaskStateSnapshot());
+		OperatorSubtaskState subtaskStateTrigger = mock(OperatorSubtaskState.class);
+		taskOperatorSubtaskStatesTrigger.putSubtaskStateByOperatorID(opIDtrigger, subtaskStateTrigger);
 
 		// acknowledge the first trigger vertex
-		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class)));
-		OperatorSubtaskState storedTriggerSubtaskState = operatorStates.get(opIDtrigger).getState(triggerVertex.getParallelSubtaskIndex());
+		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStatesTrigger));
 
 		// verify that the subtask state has not been discarded
-		verify(storedTriggerSubtaskState, never()).discardState();
+		verify(subtaskStateTrigger, never()).discardState();
 
-		SubtaskState unknownSubtaskState = mock(SubtaskState.class);
+		TaskStateSnapshot unknownSubtaskState = mock(TaskStateSnapshot.class);
 
 		// receive an acknowledge message for an unknown vertex
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), unknownSubtaskState));
@@ -1174,7 +1156,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		// we should discard acknowledge messages from an unknown vertex belonging to our job
 		verify(unknownSubtaskState, times(1)).discardState();
 
-		SubtaskState differentJobSubtaskState = mock(SubtaskState.class);
+		TaskStateSnapshot differentJobSubtaskState = mock(TaskStateSnapshot.class);
 
 		// receive an acknowledge message from an unknown job
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(new JobID(), new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), differentJobSubtaskState));
@@ -1183,22 +1165,22 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		verify(differentJobSubtaskState, never()).discardState();
 
 		// duplicate acknowledge message for the trigger vertex
-		SubtaskState triggerSubtaskState = mock(SubtaskState.class);
+		TaskStateSnapshot triggerSubtaskState = mock(TaskStateSnapshot.class);
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, triggerAttemptId, checkpointId, new CheckpointMetrics(), triggerSubtaskState));
 
 		// duplicate acknowledge messages for a known vertex should not trigger discarding the state
 		verify(triggerSubtaskState, never()).discardState();
 
 		// let the checkpoint fail at the first ack vertex
-		reset(storedTriggerSubtaskState);
+		reset(subtaskStateTrigger);
 		coord.receiveDeclineMessage(new DeclineCheckpoint(jobId, ackAttemptId1, checkpointId));
 
 		assertTrue(pendingCheckpoint.isDiscarded());
 
 		// check that we've cleaned up the already acknowledged state
-		verify(storedTriggerSubtaskState, times(1)).discardState();
+		verify(subtaskStateTrigger, times(1)).discardState();
 
-		SubtaskState ackSubtaskState = mock(SubtaskState.class);
+		TaskStateSnapshot ackSubtaskState = mock(TaskStateSnapshot.class);
 
 		// late acknowledge message from the second ack vertex
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, ackAttemptId2, checkpointId, new CheckpointMetrics(), ackSubtaskState));
@@ -1213,7 +1195,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		// we should not interfere with different jobs
 		verify(differentJobSubtaskState, never()).discardState();
 
-		SubtaskState unknownSubtaskState2 = mock(SubtaskState.class);
+		TaskStateSnapshot unknownSubtaskState2 = mock(TaskStateSnapshot.class);
 
 		// receive an acknowledge message for an unknown vertex
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jobId, new ExecutionAttemptID(), checkpointId, new CheckpointMetrics(), unknownSubtaskState2));
@@ -1470,18 +1452,16 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 		OperatorID opID1 = OperatorID.fromJobVertexID(vertex1.getJobvertexId());
 		OperatorID opID2 = OperatorID.fromJobVertexID(vertex2.getJobvertexId());
-
-		Map<OperatorID, OperatorState> operatorStates = pending.getOperatorStates();
-
-		operatorStates.put(opID1, new SpyInjectingOperatorState(
-			opID1, vertex1.getTotalNumberOfParallelSubtasks(), vertex1.getMaxParallelism()));
-		operatorStates.put(opID2, new SpyInjectingOperatorState(
-			opID2, vertex2.getTotalNumberOfParallelSubtasks(), vertex1.getMaxParallelism()));
+		TaskStateSnapshot taskOperatorSubtaskStates1 = mock(TaskStateSnapshot.class);
+		TaskStateSnapshot taskOperatorSubtaskStates2 = mock(TaskStateSnapshot.class);
+		OperatorSubtaskState subtaskState1 = mock(OperatorSubtaskState.class);
+		OperatorSubtaskState subtaskState2 = mock(OperatorSubtaskState.class);
+		when(taskOperatorSubtaskStates1.getSubtaskStateByOperatorID(opID1)).thenReturn(subtaskState1);
+		when(taskOperatorSubtaskStates2.getSubtaskStateByOperatorID(opID2)).thenReturn(subtaskState2);
 
 		// acknowledge from one of the tasks
-		AcknowledgeCheckpoint acknowledgeCheckpoint2 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class));
+		AcknowledgeCheckpoint acknowledgeCheckpoint2 = new AcknowledgeCheckpoint(jid, attemptID2, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStates2);
 		coord.receiveAcknowledgeMessage(acknowledgeCheckpoint2);
-		OperatorSubtaskState subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex());
 		assertEquals(1, pending.getNumberOfAcknowledgedTasks());
 		assertEquals(1, pending.getNumberOfNonAcknowledgedTasks());
 		assertFalse(pending.isDiscarded());
@@ -1495,8 +1475,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		assertFalse(savepointFuture.isDone());
 
 		// acknowledge the other task.
-		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), mock(SubtaskState.class)));
-		OperatorSubtaskState subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex());
+		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointId, new CheckpointMetrics(), taskOperatorSubtaskStates1));
 
 		// the checkpoint is internally converted to a successful checkpoint and the
 		// pending checkpoint object is disposed
@@ -1536,9 +1515,6 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID1, checkpointIdNew));
 		coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, attemptID2, checkpointIdNew));
 
-		subtaskState1 = operatorStates.get(opID1).getState(vertex1.getParallelSubtaskIndex());
-		subtaskState2 = operatorStates.get(opID2).getState(vertex2.getParallelSubtaskIndex());
-
 		assertEquals(0, coord.getNumberOfPendingCheckpoints());
 		assertEquals(0, coord.getNumberOfRetainedSuccessfulCheckpoints());
 
@@ -2037,20 +2013,8 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		List<KeyGroupRange> keyGroupPartitions1 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism1, parallelism1);
 		List<KeyGroupRange> keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2);
 
-		PendingCheckpoint pending = coord.getPendingCheckpoints().get(checkpointId);
-
-		OperatorID opID1 = OperatorID.fromJobVertexID(jobVertexID1);
-		OperatorID opID2 = OperatorID.fromJobVertexID(jobVertexID2);
-
-		Map<OperatorID, OperatorState> operatorStates = pending.getOperatorStates();
-
-		operatorStates.put(opID1, new SpyInjectingOperatorState(
-			opID1, jobVertex1.getParallelism(), jobVertex1.getMaxParallelism()));
-		operatorStates.put(opID2, new SpyInjectingOperatorState(
-			opID2, jobVertex2.getParallelism(), jobVertex2.getMaxParallelism()));
-
 		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
-			SubtaskState subtaskState = mockSubtaskState(jobVertexID1, index, keyGroupPartitions1.get(index));
+			TaskStateSnapshot subtaskState = mockSubtaskState(jobVertexID1, index, keyGroupPartitions1.get(index));
 
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
@@ -2063,7 +2027,7 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		}
 
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
-			SubtaskState subtaskState = mockSubtaskState(jobVertexID2, index, keyGroupPartitions2.get(index));
+			TaskStateSnapshot subtaskState = mockSubtaskState(jobVertexID2, index, keyGroupPartitions2.get(index));
 
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
@@ -2165,30 +2129,34 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		List<KeyGroupRange> keyGroupPartitions2 = StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2);
 
 		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
-			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index);
+			StreamStateHandle valueSizeTuple = generateStateForVertex(jobVertexID1, index);
 			KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false);
-			SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null);
+			OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(valueSizeTuple, null, null, keyGroupState, null);
+			TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot();
+			taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID1), operatorSubtaskState);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
 					checkpointId,
 					new CheckpointMetrics(),
-					checkpointStateHandles);
+				taskOperatorSubtaskStates);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
 
 
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
-			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID2, index);
+			StreamStateHandle valueSizeTuple = generateStateForVertex(jobVertexID2, index);
 			KeyGroupsStateHandle keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false);
-			SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null);
+			OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(valueSizeTuple, null, null, keyGroupState, null);
+			TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot();
+			taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID2), operatorSubtaskState);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
 					checkpointId,
 					new CheckpointMetrics(),
-					checkpointStateHandles);
+					taskOperatorSubtaskStates);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
@@ -2284,17 +2252,20 @@ public class CheckpointCoordinatorTest extends TestLogger {
 				StateAssignmentOperation.createKeyGroupPartitions(maxParallelism2, parallelism2);
 
 		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
-			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index);
+			StreamStateHandle valueSizeTuple = generateStateForVertex(jobVertexID1, index);
 			KeyGroupsStateHandle keyGroupState = generateKeyGroupState(
 					jobVertexID1, keyGroupPartitions1.get(index), false);
 
-			SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, null, null, keyGroupState, null);
+			OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(valueSizeTuple, null, null, keyGroupState, null);
+			TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot();
+			taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID1), operatorSubtaskState);
+
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
 					checkpointId,
 					new CheckpointMetrics(),
-					checkpointStateHandles);
+					taskOperatorSubtaskStates);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
@@ -2302,17 +2273,19 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
 
-			ChainedStateHandle<StreamStateHandle> state = generateStateForVertex(jobVertexID2, index);
+			StreamStateHandle state = generateStateForVertex(jobVertexID2, index);
 			KeyGroupsStateHandle keyGroupState = generateKeyGroupState(
 					jobVertexID2, keyGroupPartitions2.get(index), false);
 
-			SubtaskState checkpointStateHandles = new SubtaskState(state, null, null, keyGroupState, null);
+			OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(state, null, null, keyGroupState, null);
+			TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot();
+			taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID2), operatorSubtaskState);
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
 					checkpointId,
 					new CheckpointMetrics(),
-					checkpointStateHandles);
+					taskOperatorSubtaskStates);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
@@ -2438,18 +2411,21 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 		//vertex 1
 		for (int index = 0; index < jobVertex1.getParallelism(); index++) {
-			ChainedStateHandle<StreamStateHandle> valueSizeTuple = generateStateForVertex(jobVertexID1, index);
-			ChainedStateHandle<OperatorStateHandle> opStateBackend = generateChainedPartitionableStateHandle(jobVertexID1, index, 2, 8, false);
+			StreamStateHandle valueSizeTuple = generateStateForVertex(jobVertexID1, index);
+			OperatorStateHandle opStateBackend = generatePartitionableStateHandle(jobVertexID1, index, 2, 8, false);
 			KeyGroupsStateHandle keyedStateBackend = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), false);
 			KeyGroupsStateHandle keyedStateRaw = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index), true);
 
-			SubtaskState checkpointStateHandles = new SubtaskState(valueSizeTuple, opStateBackend, null, keyedStateBackend, keyedStateRaw);
+			OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(valueSizeTuple, opStateBackend, null, keyedStateBackend, keyedStateRaw);
+			TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot();
+			taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID1), operatorSubtaskState);
+
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
 					checkpointId,
 					new CheckpointMetrics(),
-					checkpointStateHandles);
+					taskOperatorSubtaskStates);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
@@ -2460,19 +2436,21 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		for (int index = 0; index < jobVertex2.getParallelism(); index++) {
 			KeyGroupsStateHandle keyedStateBackend = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), false);
 			KeyGroupsStateHandle keyedStateRaw = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index), true);
-			ChainedStateHandle<OperatorStateHandle> opStateBackend = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, false);
-			ChainedStateHandle<OperatorStateHandle> opStateRaw = generateChainedPartitionableStateHandle(jobVertexID2, index, 2, 8, true);
-			expectedOpStatesBackend.add(opStateBackend);
-			expectedOpStatesRaw.add(opStateRaw);
-			SubtaskState checkpointStateHandles =
-					new SubtaskState(new ChainedStateHandle<>(
-							Collections.<StreamStateHandle>singletonList(null)), opStateBackend, opStateRaw, keyedStateBackend, keyedStateRaw);
+			OperatorStateHandle opStateBackend = generatePartitionableStateHandle(jobVertexID2, index, 2, 8, false);
+			OperatorStateHandle opStateRaw = generatePartitionableStateHandle(jobVertexID2, index, 2, 8, true);
+			expectedOpStatesBackend.add(new ChainedStateHandle<>(Collections.singletonList(opStateBackend)));
+			expectedOpStatesRaw.add(new ChainedStateHandle<>(Collections.singletonList(opStateRaw)));
+
+			OperatorSubtaskState operatorSubtaskState = new OperatorSubtaskState(null, opStateBackend, opStateRaw, keyedStateBackend, keyedStateRaw);
+			TaskStateSnapshot taskOperatorSubtaskStates = new TaskStateSnapshot();
+			taskOperatorSubtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID2), operatorSubtaskState);
+
 			AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint(
 					jid,
 					jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(),
 					checkpointId,
 					new CheckpointMetrics(),
-					checkpointStateHandles);
+					taskOperatorSubtaskStates);
 
 			coord.receiveAcknowledgeMessage(acknowledgeCheckpoint);
 		}
@@ -2506,27 +2484,37 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		List<List<Collection<OperatorStateHandle>>> actualOpStatesBackend = new ArrayList<>(newJobVertex2.getParallelism());
 		List<List<Collection<OperatorStateHandle>>> actualOpStatesRaw = new ArrayList<>(newJobVertex2.getParallelism());
 		for (int i = 0; i < newJobVertex2.getParallelism(); i++) {
-			KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), false);
-			KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), true);
 
-			TaskStateHandles taskStateHandles = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles();
+			List<OperatorID> operatorIDs = newJobVertex2.getOperatorIDs();
 
-			ChainedStateHandle<StreamStateHandle> operatorState = taskStateHandles.getLegacyOperatorState();
-			List<Collection<OperatorStateHandle>> opStateBackend = taskStateHandles.getManagedOperatorState();
-			List<Collection<OperatorStateHandle>> opStateRaw = taskStateHandles.getRawOperatorState();
-			Collection<KeyedStateHandle> keyedStateBackend = taskStateHandles.getManagedKeyedState();
-			Collection<KeyedStateHandle> keyGroupStateRaw = taskStateHandles.getRawKeyedState();
+			KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), false);
+			KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i), true);
 
-			actualOpStatesBackend.add(opStateBackend);
-			actualOpStatesRaw.add(opStateRaw);
-			// the 'non partition state' is not null because it is recombined.
-			assertNotNull(operatorState);
-			for (int index = 0; index < operatorState.getLength(); index++) {
-				assertNull(operatorState.get(index));
+			TaskStateSnapshot taskStateHandles = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot();
+
+			final int headOpIndex = operatorIDs.size() - 1;
+			List<Collection<OperatorStateHandle>> allParallelManagedOpStates = new ArrayList<>(operatorIDs.size());
+			List<Collection<OperatorStateHandle>> allParallelRawOpStates = new ArrayList<>(operatorIDs.size());
+
+			for (int idx = 0; idx < operatorIDs.size(); ++idx) {
+				OperatorID operatorID = operatorIDs.get(idx);
+				OperatorSubtaskState opState = taskStateHandles.getSubtaskStateByOperatorID(operatorID);
+				Assert.assertNull(opState.getLegacyOperatorState());
+				Collection<OperatorStateHandle> opStateBackend = opState.getManagedOperatorState();
+				Collection<OperatorStateHandle> opStateRaw = opState.getRawOperatorState();
+				allParallelManagedOpStates.add(opStateBackend);
+				allParallelRawOpStates.add(opStateRaw);
+				if (idx == headOpIndex) {
+					Collection<KeyedStateHandle> keyedStateBackend = opState.getManagedKeyedState();
+					Collection<KeyedStateHandle> keyGroupStateRaw = opState.getRawKeyedState();
+					compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend);
+					compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw);
+				}
 			}
-			compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend);
-			compareKeyedState(Collections.singletonList(originalKeyedStateRaw), keyGroupStateRaw);
+			actualOpStatesBackend.add(allParallelManagedOpStates);
+			actualOpStatesRaw.add(allParallelRawOpStates);
 		}
+
 		comparePartitionableState(expectedOpStatesBackend, actualOpStatesBackend);
 		comparePartitionableState(expectedOpStatesRaw, actualOpStatesRaw);
 	}
@@ -2578,14 +2566,11 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			operatorStates.put(id.f1, taskState);
 			for (int index = 0; index < taskState.getParallelism(); index++) {
 				StreamStateHandle subNonPartitionedState = 
-					generateStateForVertex(id.f0, index)
-						.get(0);
+					generateStateForVertex(id.f0, index);
 				OperatorStateHandle subManagedOperatorState =
-					generateChainedPartitionableStateHandle(id.f0, index, 2, 8, false)
-						.get(0);
+					generatePartitionableStateHandle(id.f0, index, 2, 8, false);
 				OperatorStateHandle subRawOperatorState =
-					generateChainedPartitionableStateHandle(id.f0, index, 2, 8, true)
-						.get(0);
+					generatePartitionableStateHandle(id.f0, index, 2, 8, true);
 
 				OperatorSubtaskState subtaskState = new OperatorSubtaskState(subNonPartitionedState,
 					subManagedOperatorState,
@@ -2707,57 +2692,75 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 		for (int i = 0; i < newJobVertex1.getParallelism(); i++) {
 
-			TaskStateHandles taskStateHandles = newJobVertex1.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles();
-			ChainedStateHandle<StreamStateHandle> actualSubNonPartitionedState = taskStateHandles.getLegacyOperatorState();
-			List<Collection<OperatorStateHandle>> actualSubManagedOperatorState = taskStateHandles.getManagedOperatorState();
-			List<Collection<OperatorStateHandle>> actualSubRawOperatorState = taskStateHandles.getRawOperatorState();
+			final List<OperatorID> operatorIds = newJobVertex1.getOperatorIDs();
 
-			assertNull(taskStateHandles.getManagedKeyedState());
-			assertNull(taskStateHandles.getRawKeyedState());
+			TaskStateSnapshot stateSnapshot = newJobVertex1.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot();
+
+			OperatorSubtaskState headOpState = stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIds.size() - 1));
+			assertTrue(headOpState.getManagedKeyedState().isEmpty());
+			assertTrue(headOpState.getRawKeyedState().isEmpty());
 
 			// operator5
 			{
 				int operatorIndexInChain = 2;
-				assertNull(actualSubNonPartitionedState.get(operatorIndexInChain));
-				assertNull(actualSubManagedOperatorState.get(operatorIndexInChain));
-				assertNull(actualSubRawOperatorState.get(operatorIndexInChain));
+				OperatorSubtaskState opState =
+					stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
+
+				assertNull(opState.getLegacyOperatorState());
+				assertTrue(opState.getManagedOperatorState().isEmpty());
+				assertTrue(opState.getRawOperatorState().isEmpty());
 			}
 			// operator1
 			{
 				int operatorIndexInChain = 1;
-				ChainedStateHandle<StreamStateHandle> expectSubNonPartitionedState = generateStateForVertex(id1.f0, i);
-				ChainedStateHandle<OperatorStateHandle> expectedManagedOpState = generateChainedPartitionableStateHandle(
+				OperatorSubtaskState opState =
+					stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
+
+				StreamStateHandle expectSubNonPartitionedState = generateStateForVertex(id1.f0, i);
+				OperatorStateHandle expectedManagedOpState = generatePartitionableStateHandle(
 					id1.f0, i, 2, 8, false);
-				ChainedStateHandle<OperatorStateHandle> expectedRawOpState = generateChainedPartitionableStateHandle(
+				OperatorStateHandle expectedRawOpState = generatePartitionableStateHandle(
 					id1.f0, i, 2, 8, true);
 
 				assertTrue(CommonTestUtils.isSteamContentEqual(
-					expectSubNonPartitionedState.get(0).openInputStream(),
-					actualSubNonPartitionedState.get(operatorIndexInChain).openInputStream()));
-
-				assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.get(0).openInputStream(),
-					actualSubManagedOperatorState.get(operatorIndexInChain).iterator().next().openInputStream()));
-
-				assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.get(0).openInputStream(),
-					actualSubRawOperatorState.get(operatorIndexInChain).iterator().next().openInputStream()));
+					expectSubNonPartitionedState.openInputStream(),
+					opState.getLegacyOperatorState().openInputStream()));
+
+				Collection<OperatorStateHandle> managedOperatorState = opState.getManagedOperatorState();
+				assertEquals(1, managedOperatorState.size());
+				assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.openInputStream(),
+					managedOperatorState.iterator().next().openInputStream()));
+
+				Collection<OperatorStateHandle> rawOperatorState = opState.getRawOperatorState();
+				assertEquals(1, rawOperatorState.size());
+				assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.openInputStream(),
+					rawOperatorState.iterator().next().openInputStream()));
 			}
 			// operator2
 			{
 				int operatorIndexInChain = 0;
-				ChainedStateHandle<StreamStateHandle> expectSubNonPartitionedState = generateStateForVertex(id2.f0, i);
-				ChainedStateHandle<OperatorStateHandle> expectedManagedOpState = generateChainedPartitionableStateHandle(
+				OperatorSubtaskState opState =
+					stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
+
+				StreamStateHandle expectSubNonPartitionedState = generateStateForVertex(id2.f0, i);
+				OperatorStateHandle expectedManagedOpState = generatePartitionableStateHandle(
 					id2.f0, i, 2, 8, false);
-				ChainedStateHandle<OperatorStateHandle> expectedRawOpState = generateChainedPartitionableStateHandle(
+				OperatorStateHandle expectedRawOpState = generatePartitionableStateHandle(
 					id2.f0, i, 2, 8, true);
 
-				assertTrue(CommonTestUtils.isSteamContentEqual(expectSubNonPartitionedState.get(0).openInputStream(),
-					actualSubNonPartitionedState.get(operatorIndexInChain).openInputStream()));
-
-				assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.get(0).openInputStream(),
-					actualSubManagedOperatorState.get(operatorIndexInChain).iterator().next().openInputStream()));
-
-				assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.get(0).openInputStream(),
-					actualSubRawOperatorState.get(operatorIndexInChain).iterator().next().openInputStream()));
+				assertTrue(CommonTestUtils.isSteamContentEqual(
+					expectSubNonPartitionedState.openInputStream(),
+					opState.getLegacyOperatorState().openInputStream()));
+
+				Collection<OperatorStateHandle> managedOperatorState = opState.getManagedOperatorState();
+				assertEquals(1, managedOperatorState.size());
+				assertTrue(CommonTestUtils.isSteamContentEqual(expectedManagedOpState.openInputStream(),
+					managedOperatorState.iterator().next().openInputStream()));
+
+				Collection<OperatorStateHandle> rawOperatorState = opState.getRawOperatorState();
+				assertEquals(1, rawOperatorState.size());
+				assertTrue(CommonTestUtils.isSteamContentEqual(expectedRawOpState.openInputStream(),
+					rawOperatorState.iterator().next().openInputStream()));
 			}
 		}
 
@@ -2765,38 +2768,48 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		List<List<Collection<OperatorStateHandle>>> actualRawOperatorStates = new ArrayList<>(newJobVertex2.getParallelism());
 
 		for (int i = 0; i < newJobVertex2.getParallelism(); i++) {
-			TaskStateHandles taskStateHandles = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles();
+
+			final List<OperatorID> operatorIds = newJobVertex2.getOperatorIDs();
+
+			TaskStateSnapshot stateSnapshot = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot();
 
 			// operator 3
 			{
 				int operatorIndexInChain = 1;
+				OperatorSubtaskState opState =
+					stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
+
 				List<Collection<OperatorStateHandle>> actualSubManagedOperatorState = new ArrayList<>(1);
-				actualSubManagedOperatorState.add(taskStateHandles.getManagedOperatorState().get(operatorIndexInChain));
+				actualSubManagedOperatorState.add(opState.getManagedOperatorState());
 
 				List<Collection<OperatorStateHandle>> actualSubRawOperatorState = new ArrayList<>(1);
-				actualSubRawOperatorState.add(taskStateHandles.getRawOperatorState().get(operatorIndexInChain));
+				actualSubRawOperatorState.add(opState.getRawOperatorState());
 
 				actualManagedOperatorStates.add(actualSubManagedOperatorState);
 				actualRawOperatorStates.add(actualSubRawOperatorState);
 
-				assertNull(taskStateHandles.getLegacyOperatorState().get(operatorIndexInChain));
+				assertNull(opState.getLegacyOperatorState());
 			}
 
 			// operator 6
 			{
 				int operatorIndexInChain = 0;
-				assertNull(taskStateHandles.getManagedOperatorState().get(operatorIndexInChain));
-				assertNull(taskStateHandles.getRawOperatorState().get(operatorIndexInChain));
-				assertNull(taskStateHandles.getLegacyOperatorState().get(operatorIndexInChain));
+				OperatorSubtaskState opState =
+					stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIndexInChain));
+				assertNull(opState.getLegacyOperatorState());
+				assertTrue(opState.getManagedOperatorState().isEmpty());
+				assertTrue(opState.getRawOperatorState().isEmpty());
 
 			}
 
 			KeyGroupsStateHandle originalKeyedStateBackend = generateKeyGroupState(id3.f0, newKeyGroupPartitions2.get(i), false);
 			KeyGroupsStateHandle originalKeyedStateRaw = generateKeyGroupState(id3.f0, newKeyGroupPartitions2.get(i), true);
 
+			OperatorSubtaskState headOpState =
+				stateSnapshot.getSubtaskStateByOperatorID(operatorIds.get(operatorIds.size() - 1));
 
-			Collection<KeyedStateHandle> keyedStateBackend = taskStateHandles.getManagedKeyedState();
-			Collection<KeyedStateHandle> keyGroupStateRaw = taskStateHandles.getRawKeyedState();
+			Collection<KeyedStateHandle> keyedStateBackend = headOpState.getManagedKeyedState();
+			Collection<KeyedStateHandle> keyGroupStateRaw = headOpState.getRawKeyedState();
 
 
 			compareKeyedState(Collections.singletonList(originalKeyedStateBackend), keyedStateBackend);
@@ -2974,19 +2987,50 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		return new Tuple2<>(allSerializedValuesConcatenated, offsets);
 	}
 
-	public static ChainedStateHandle<StreamStateHandle> generateStateForVertex(
+	public static StreamStateHandle generateStateForVertex(
 			JobVertexID jobVertexID,
 			int index) throws IOException {
 
 		Random random = new Random(jobVertexID.hashCode() + index);
 		int value = random.nextInt();
-		return generateChainedStateHandle(value);
+		return generateStreamStateHandle(value);
+	}
+
+	public static StreamStateHandle generateStreamStateHandle(Serializable value) throws IOException {
+		return TestByteStreamStateHandleDeepCompare.fromSerializable(String.valueOf(UUID.randomUUID()), value);
 	}
 
 	public static ChainedStateHandle<StreamStateHandle> generateChainedStateHandle(
 			Serializable value) throws IOException {
 		return ChainedStateHandle.wrapSingleHandle(
-				TestByteStreamStateHandleDeepCompare.fromSerializable(String.valueOf(UUID.randomUUID()), value));
+				generateStreamStateHandle(value));
+	}
+
+	public static OperatorStateHandle generatePartitionableStateHandle(
+		JobVertexID jobVertexID,
+		int index,
+		int namedStates,
+		int partitionsPerState,
+		boolean rawState) throws IOException {
+
+		Map<String, List<? extends Serializable>> statesListsMap = new HashMap<>(namedStates);
+
+		for (int i = 0; i < namedStates; ++i) {
+			List<Integer> testStatesLists = new ArrayList<>(partitionsPerState);
+			// generate state
+			int seed = jobVertexID.hashCode() * index + i * namedStates;
+			if (rawState) {
+				seed = (seed + 1) * 31;
+			}
+			Random random = new Random(seed);
+			for (int j = 0; j < partitionsPerState; ++j) {
+				int simulatedStateValue = random.nextInt();
+				testStatesLists.add(simulatedStateValue);
+			}
+			statesListsMap.put("state-" + i, testStatesLists);
+		}
+
+		return generatePartitionableStateHandle(statesListsMap);
 	}
 
 	public static ChainedStateHandle<OperatorStateHandle> generateChainedPartitionableStateHandle(
@@ -3013,11 +3057,11 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			statesListsMap.put("state-" + i, testStatesLists);
 		}
 
-		return generateChainedPartitionableStateHandle(statesListsMap);
+		return ChainedStateHandle.wrapSingleHandle(generatePartitionableStateHandle(statesListsMap));
 	}
 
-	private static ChainedStateHandle<OperatorStateHandle> generateChainedPartitionableStateHandle(
-			Map<String, List<? extends Serializable>> states) throws IOException {
+	private static OperatorStateHandle generatePartitionableStateHandle(
+		Map<String, List<? extends Serializable>> states) throws IOException {
 
 		List<List<? extends Serializable>> namedStateSerializables = new ArrayList<>(states.size());
 
@@ -3032,20 +3076,18 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		int idx = 0;
 		for (Map.Entry<String, List<? extends Serializable>> entry : states.entrySet()) {
 			offsetsMap.put(
-					entry.getKey(),
-					new OperatorStateHandle.StateMetaInfo(
-							serializationWithOffsets.f1.get(idx),
-							OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
+				entry.getKey(),
+				new OperatorStateHandle.StateMetaInfo(
+					serializationWithOffsets.f1.get(idx),
+					OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
 			++idx;
 		}
 
 		ByteStreamStateHandle streamStateHandle = new TestByteStreamStateHandleDeepCompare(
-				String.valueOf(UUID.randomUUID()),
-				serializationWithOffsets.f0);
+			String.valueOf(UUID.randomUUID()),
+			serializationWithOffsets.f0);
 
-		OperatorStateHandle operatorStateHandle =
-				new OperatorStateHandle(offsetsMap, streamStateHandle);
-		return ChainedStateHandle.wrapSingleHandle(operatorStateHandle);
+		return new OperatorStateHandle(offsetsMap, streamStateHandle);
 	}
 
 	static ExecutionJobVertex mockExecutionJobVertex(
@@ -3139,24 +3181,23 @@ public class CheckpointCoordinatorTest extends TestLogger {
 		return vertex;
 	}
 
-	static SubtaskState mockSubtaskState(
+	static TaskStateSnapshot mockSubtaskState(
 		JobVertexID jobVertexID,
 		int index,
 		KeyGroupRange keyGroupRange) throws IOException {
 
-		ChainedStateHandle<StreamStateHandle> nonPartitionedState = generateStateForVertex(jobVertexID, index);
-		ChainedStateHandle<OperatorStateHandle> partitionableState = generateChainedPartitionableStateHandle(jobVertexID, index, 2, 8, false);
+		StreamStateHandle nonPartitionedState = generateStateForVertex(jobVertexID, index);
+		OperatorStateHandle partitionableState = generatePartitionableStateHandle(jobVertexID, index, 2, 8, false);
 		KeyGroupsStateHandle partitionedKeyGroupState = generateKeyGroupState(jobVertexID, keyGroupRange, false);
 
-		SubtaskState subtaskState = mock(SubtaskState.class, withSettings().serializable());
+		TaskStateSnapshot subtaskStates = spy(new TaskStateSnapshot());
+		OperatorSubtaskState subtaskState = spy(new OperatorSubtaskState(
+			nonPartitionedState, partitionableState, null, partitionedKeyGroupState, null)
+		);
 
-		doReturn(nonPartitionedState).when(subtaskState).getLegacyOperatorState();
-		doReturn(partitionableState).when(subtaskState).getManagedOperatorState();
-		doReturn(null).when(subtaskState).getRawOperatorState();
-		doReturn(partitionedKeyGroupState).when(subtaskState).getManagedKeyedState();
-		doReturn(null).when(subtaskState).getRawKeyedState();
+		subtaskStates.putSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID), subtaskState);
 
-		return subtaskState;
+		return subtaskStates;
 	}
 
 	public static void verifyStateRestore(
@@ -3165,27 +3206,27 @@ public class CheckpointCoordinatorTest extends TestLogger {
 
 		for (int i = 0; i < executionJobVertex.getParallelism(); i++) {
 
-			TaskStateHandles taskStateHandles = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateHandles();
+			final List<OperatorID> operatorIds = executionJobVertex.getOperatorIDs();
 
-			ChainedStateHandle<StreamStateHandle> expectNonPartitionedState = generateStateForVertex(jobVertexID, i);
-			ChainedStateHandle<StreamStateHandle> actualNonPartitionedState = taskStateHandles.getLegacyOperatorState();
+			TaskStateSnapshot stateSnapshot = executionJobVertex.getTaskVertices()[i].getCurrentExecutionAttempt().getTaskStateSnapshot();
+
+			OperatorSubtaskState operatorState = stateSnapshot.getSubtaskStateByOperatorID(OperatorID.fromJobVertexID(jobVertexID));
+
+			StreamStateHandle expectNonPartitionedState = generateStateForVertex(jobVertexID, i);
 			assertTrue(CommonTestUtils.isSteamContentEqual(
-					expectNonPartitionedState.get(0).openInputStream(),
-					actualNonPartitionedState.get(0).openInputStream()));
+					expectNonPartitionedState.openInputStream(),
+				operatorState.getLegacyOperatorState().openInputStream()));
 
 			ChainedStateHandle<OperatorStateHandle> expectedOpStateBackend =
 					generateChainedPartitionableStateHandle(jobVertexID, i, 2, 8, false);
 
-			List<Collection<OperatorStateHandle>> actualPartitionableState = taskStateHandles.getManagedOperatorState();
-
 			assertTrue(CommonTestUtils.isSteamContentEqual(
 					expectedOpStateBackend.get(0).openInputStream(),
-					actualPartitionableState.get(0).iterator().next().openInputStream()));
+					operatorState.getManagedOperatorState().iterator().next().openInputStream()));
 
 			KeyGroupsStateHandle expectPartitionedKeyGroupState = generateKeyGroupState(
 					jobVertexID, keyGroupPartitions.get(i), false);
-			Collection<KeyedStateHandle> actualPartitionedKeyGroupState = taskStateHandles.getManagedKeyedState();
-			compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), actualPartitionedKeyGroupState);
+			compareKeyedState(Collections.singletonList(expectPartitionedKeyGroupState), operatorState.getManagedKeyedState());
 		}
 	}
 
@@ -3632,17 +3673,4 @@ public class CheckpointCoordinatorTest extends TestLogger {
 			"The latest completed (proper) checkpoint should have been added to the completed checkpoint store.",
 			completedCheckpointStore.getLatestCheckpoint().getCheckpointID() == checkpointIDCounter.getLast());
 	}
-
-	private static final class SpyInjectingOperatorState extends OperatorState {
-
-		private static final long serialVersionUID = -4004437428483663815L;
-
-		public SpyInjectingOperatorState(OperatorID taskID, int parallelism, int maxParallelism) {
-			super(taskID, parallelism, maxParallelism);
-		}
-
-		public void putState(int subtaskIndex, OperatorSubtaskState subtaskState) {
-			super.putState(subtaskIndex, spy(subtaskState));
-		}
-	}
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
index 7d24568..6ce071b 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointStateRestoreTest.java
@@ -34,18 +34,18 @@ import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
 import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.util.SerializableObject;
+
 import org.hamcrest.BaseMatcher;
 import org.hamcrest.Description;
 import org.junit.Test;
 import org.mockito.Mockito;
 
-import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.fail;
@@ -118,10 +118,20 @@ public class CheckpointStateRestoreTest {
 			PendingCheckpoint pending = coord.getPendingCheckpoints().values().iterator().next();
 			final long checkpointId = pending.getCheckpointId();
 
-			SubtaskState checkpointStateHandles = new SubtaskState(serializedState, null, null, serializedKeyGroupStates, null);
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles));
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles));
-			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, new CheckpointMetrics(), checkpointStateHandles));
+			final TaskStateSnapshot subtaskStates = new TaskStateSnapshot();
+
+			subtaskStates.putSubtaskStateByOperatorID(
+				OperatorID.fromJobVertexID(statefulId),
+				new OperatorSubtaskState(
+					serializedState.get(0),
+					Collections.<OperatorStateHandle>emptyList(),
+					Collections.<OperatorStateHandle>emptyList(),
+					Collections.singletonList(serializedKeyGroupStates),
+					Collections.<KeyedStateHandle>emptyList()));
+
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec1.getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskStates));
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec2.getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskStates));
+			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statefulExec3.getAttemptId(), checkpointId, new CheckpointMetrics(), subtaskStates));
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec1.getAttemptId(), checkpointId));
 			coord.receiveAcknowledgeMessage(new AcknowledgeCheckpoint(jid, statelessExec2.getAttemptId(), checkpointId));
 
@@ -133,33 +143,26 @@ public class CheckpointStateRestoreTest {
 
 			// verify that each stateful vertex got the state
 
-			final TaskStateHandles taskStateHandles = new TaskStateHandles(
-					serializedState,
-					Collections.<Collection<OperatorStateHandle>>singletonList(null),
-					Collections.<Collection<OperatorStateHandle>>singletonList(null),
-					Collections.singletonList(serializedKeyGroupStates),
-					null);
-
-			BaseMatcher<TaskStateHandles> matcher = new BaseMatcher<TaskStateHandles>() {
+			BaseMatcher<TaskStateSnapshot> matcher = new BaseMatcher<TaskStateSnapshot>() {
 				@Override
 				public boolean matches(Object o) {
-					if (o instanceof TaskStateHandles) {
-						return o.equals(taskStateHandles);
+					if (o instanceof TaskStateSnapshot) {
+						return Objects.equals(o, subtaskStates);
 					}
 					return false;
 				}
 
 				@Override
 				public void describeTo(Description description) {
-					description.appendValue(taskStateHandles);
+					description.appendValue(subtaskStates);
 				}
 			};
 
 			verify(statefulExec1, times(1)).setInitialState(Mockito.argThat(matcher));
 			verify(statefulExec2, times(1)).setInitialState(Mockito.argThat(matcher));
 			verify(statefulExec3, times(1)).setInitialState(Mockito.argThat(matcher));
-			verify(statelessExec1, times(0)).setInitialState(Mockito.<TaskStateHandles>any());
-			verify(statelessExec2, times(0)).setInitialState(Mockito.<TaskStateHandles>any());
+			verify(statelessExec1, times(0)).setInitialState(Mockito.<TaskStateSnapshot>any());
+			verify(statelessExec2, times(0)).setInitialState(Mockito.<TaskStateSnapshot>any());
 		}
 		catch (Exception e) {
 			e.printStackTrace();
@@ -250,9 +253,9 @@ public class CheckpointStateRestoreTest {
 		Map<OperatorID, OperatorState> checkpointTaskStates = new HashMap<>();
 		{
 			OperatorState taskState = new OperatorState(operatorId1, 3, 3);
-			taskState.putState(0, new OperatorSubtaskState(serializedState, null, null, null, null));
-			taskState.putState(1, new OperatorSubtaskState(serializedState, null, null, null, null));
-			taskState.putState(2, new OperatorSubtaskState(serializedState, null, null, null, null));
+			taskState.putState(0, new OperatorSubtaskState(serializedState));
+			taskState.putState(1, new OperatorSubtaskState(serializedState));
+			taskState.putState(2, new OperatorSubtaskState(serializedState));
 
 			checkpointTaskStates.put(operatorId1, taskState);
 		}
@@ -279,7 +282,7 @@ public class CheckpointStateRestoreTest {
 		// There is no task for this
 		{
 			OperatorState taskState = new OperatorState(newOperatorID, 1, 1);
-			taskState.putState(0, new OperatorSubtaskState(serializedState, null, null, null, null));
+			taskState.putState(0, new OperatorSubtaskState(serializedState));
 
 			checkpointTaskStates.put(newOperatorID, taskState);
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
index 1fe4e65..320dc2d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
@@ -331,7 +331,7 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 		boolean discarded;
 
 		public TestOperatorSubtaskState() {
-			super(null, null, null, null, null);
+			super();
 			this.registered = false;
 			this.discarded = false;
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
index 7d103d0..7ebb49a 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PendingCheckpointTest.java
@@ -324,7 +324,7 @@ public class PendingCheckpointTest {
 	@Test
 	public void testNonNullSubtaskStateLeadsToStatefulTask() throws Exception {
 		PendingCheckpoint pending = createPendingCheckpoint(CheckpointProperties.forStandardCheckpoint(), null);
-		pending.acknowledgeTask(ATTEMPT_ID, mock(SubtaskState.class), mock(CheckpointMetrics.class));
+		pending.acknowledgeTask(ATTEMPT_ID, mock(TaskStateSnapshot.class), mock(CheckpointMetrics.class));
 		Assert.assertFalse(pending.getOperatorStates().isEmpty());
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java
index 36c9cad..9ed4851 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/deployment/TaskDeploymentDescriptorTest.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.JobID;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.core.testutils.CommonTestUtils;
 import org.apache.flink.runtime.blob.BlobKey;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.executiongraph.JobInformation;
@@ -30,7 +31,6 @@ import org.apache.flink.runtime.executiongraph.TaskInformation;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.operators.BatchTask;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.util.SerializedValue;
 
 import org.junit.Test;
@@ -73,7 +73,7 @@ public class TaskDeploymentDescriptorTest {
 			final SerializedValue<TaskInformation> serializedJobVertexInformation = new SerializedValue<>(new TaskInformation(
 				vertexID, taskName, currentNumberOfSubtasks, numberOfKeyGroups, invokableClass.getName(), taskConfiguration));
 			final int targetSlotNumber = 47;
-			final TaskStateHandles taskStateHandles = new TaskStateHandles();
+			final TaskStateSnapshot taskStateHandles = new TaskStateSnapshot();
 
 			final TaskDeploymentDescriptor orig = new TaskDeploymentDescriptor(
 				serializedJobInformation,

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java
index 0eed90d..c9b7a40 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/executiongraph/ExecutionVertexLocalityTest.java
@@ -23,6 +23,7 @@ import org.apache.flink.api.common.time.Time;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.metrics.groups.UnregisteredMetricsGroup;
 import org.apache.flink.runtime.checkpoint.StandaloneCheckpointRecoveryFactory;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.clusterframework.types.ResourceID;
 import org.apache.flink.runtime.clusterframework.types.ResourceProfile;
@@ -38,7 +39,6 @@ import org.apache.flink.runtime.jobgraph.JobVertexID;
 import org.apache.flink.runtime.jobmanager.slots.AllocatedSlot;
 import org.apache.flink.runtime.jobmanager.slots.SlotOwner;
 import org.apache.flink.runtime.jobmanager.slots.TaskManagerGateway;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.taskmanager.TaskManagerLocation;
 import org.apache.flink.runtime.testingUtils.TestingUtils;
 import org.apache.flink.runtime.testtasks.NoOpInvokable;
@@ -51,8 +51,10 @@ import java.net.InetAddress;
 import java.util.Iterator;
 import java.util.concurrent.TimeUnit;
 
-import static org.mockito.Mockito.*;
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.mockito.Mockito.mock;
 
 /**
  * Tests that the execution vertex handles locality preferences well.
@@ -169,7 +171,7 @@ public class ExecutionVertexLocalityTest extends TestLogger {
 
 			// target state
 			ExecutionVertex target = graph.getAllVertices().get(targetVertexId).getTaskVertices()[i];
-			target.getCurrentExecutionAttempt().setInitialState(mock(TaskStateHandles.class));
+			target.getCurrentExecutionAttempt().setInitialState(mock(TaskStateSnapshot.class));
 		}
 
 		// validate that the target vertices have the state's location as the location preference

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
index a63b02d..23f0a38 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
@@ -18,16 +18,6 @@
 
 package org.apache.flink.runtime.jobmanager;
 
-import akka.actor.ActorRef;
-import akka.actor.ActorSystem;
-import akka.actor.Identify;
-import akka.actor.PoisonPill;
-import akka.actor.Props;
-import akka.japi.pf.FI;
-import akka.japi.pf.ReceiveBuilder;
-import akka.pattern.Patterns;
-import akka.testkit.CallingThreadDispatcher;
-import akka.testkit.JavaTestKit;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.configuration.ConfigConstants;
 import org.apache.flink.configuration.Configuration;
@@ -44,8 +34,9 @@ import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
 import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory;
 import org.apache.flink.runtime.checkpoint.CompletedCheckpointStore;
+import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
 import org.apache.flink.runtime.checkpoint.StandaloneCheckpointIDCounter;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.clusterframework.types.ResourceID;
 import org.apache.flink.runtime.execution.librarycache.BlobLibraryCacheManager;
 import org.apache.flink.runtime.executiongraph.restart.FixedDelayRestartStrategy;
@@ -59,6 +50,7 @@ import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.jobgraph.JobStatus;
 import org.apache.flink.runtime.jobgraph.JobVertex;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.jobgraph.tasks.ExternalizedCheckpointSettings;
 import org.apache.flink.runtime.jobgraph.tasks.JobCheckpointingSettings;
@@ -69,9 +61,6 @@ import org.apache.flink.runtime.leaderelection.TestingLeaderElectionService;
 import org.apache.flink.runtime.leaderelection.TestingLeaderRetrievalService;
 import org.apache.flink.runtime.messages.JobManagerMessages;
 import org.apache.flink.runtime.metrics.MetricRegistry;
-import org.apache.flink.runtime.state.ChainedStateHandle;
-import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.runtime.taskmanager.TaskManager;
 import org.apache.flink.runtime.testingUtils.TestingJobManager;
@@ -83,23 +72,24 @@ import org.apache.flink.runtime.testingUtils.TestingUtils;
 import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore;
 import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare;
 import org.apache.flink.util.InstantiationUtil;
-
 import org.apache.flink.util.TestLogger;
+
+import akka.actor.ActorRef;
+import akka.actor.ActorSystem;
+import akka.actor.Identify;
+import akka.actor.PoisonPill;
+import akka.actor.Props;
+import akka.japi.pf.FI;
+import akka.japi.pf.ReceiveBuilder;
+import akka.pattern.Patterns;
+import akka.testkit.CallingThreadDispatcher;
+import akka.testkit.JavaTestKit;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.Rule;
 import org.junit.Test;
 import org.junit.rules.TemporaryFolder;
 
-import scala.Int;
-import scala.Option;
-import scala.PartialFunction;
-import scala.concurrent.Await;
-import scala.concurrent.Future;
-import scala.concurrent.duration.Deadline;
-import scala.concurrent.duration.FiniteDuration;
-import scala.runtime.BoxedUnit;
-
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
@@ -113,6 +103,15 @@ import java.util.concurrent.Executor;
 import java.util.concurrent.ScheduledExecutorService;
 import java.util.concurrent.TimeUnit;
 
+import scala.Int;
+import scala.Option;
+import scala.PartialFunction;
+import scala.concurrent.Await;
+import scala.concurrent.Future;
+import scala.concurrent.duration.Deadline;
+import scala.concurrent.duration.FiniteDuration;
+import scala.runtime.BoxedUnit;
+
 import static org.hamcrest.Matchers.containsInAnyOrder;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertThat;
@@ -552,10 +551,10 @@ public class JobManagerHARecoveryTest extends TestLogger {
 
 		@Override
 		public void setInitialState(
-				TaskStateHandles taskStateHandles) throws Exception {
+			TaskStateSnapshot taskStateHandles) throws Exception {
 			int subtaskIndex = getIndexInSubtaskGroup();
 			if (subtaskIndex < recoveredStates.length) {
-				try (FSDataInputStream in = taskStateHandles.getLegacyOperatorState().get(0).openInputStream()) {
+				try (FSDataInputStream in = taskStateHandles.getSubtaskStateMappings().iterator().next().getValue().getLegacyOperatorState().openInputStream()) {
 					recoveredStates[subtaskIndex] = InstantiationUtil.deserializeObject(in, getUserCodeClassLoader());
 				}
 			}
@@ -567,10 +566,11 @@ public class JobManagerHARecoveryTest extends TestLogger {
 					String.valueOf(UUID.randomUUID()),
 					InstantiationUtil.serializeObject(checkpointMetaData.getCheckpointId()));
 
-			ChainedStateHandle<StreamStateHandle> chainedStateHandle =
-					new ChainedStateHandle<StreamStateHandle>(Collections.singletonList(byteStreamStateHandle));
-			SubtaskState checkpointStateHandles =
-					new SubtaskState(chainedStateHandle, null, null, null, null);
+			TaskStateSnapshot checkpointStateHandles = new TaskStateSnapshot();
+			checkpointStateHandles.putSubtaskStateByOperatorID(
+				OperatorID.fromJobVertexID(getEnvironment().getJobVertexId()),
+				new OperatorSubtaskState(byteStreamStateHandle)
+			);
 
 			getEnvironment().acknowledgeCheckpoint(
 					checkpointMetaData.getCheckpointId(),

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
index bc420cc..d022cdc 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/messages/CheckpointMessagesTest.java
@@ -24,14 +24,17 @@ import org.apache.flink.core.testutils.CommonTestUtils;
 import org.apache.flink.runtime.checkpoint.CheckpointCoordinatorTest;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint;
 import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete;
 import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint;
 import org.apache.flink.runtime.state.KeyGroupRange;
 import org.apache.flink.runtime.state.StreamStateHandle;
+
 import org.junit.Test;
 
 import java.io.IOException;
@@ -68,13 +71,17 @@ public class CheckpointMessagesTest {
 
 			KeyGroupRange keyGroupRange = KeyGroupRange.of(42,42);
 
-			SubtaskState checkpointStateHandles =
-					new SubtaskState(
-							CheckpointCoordinatorTest.generateChainedStateHandle(new MyHandle()),
-							CheckpointCoordinatorTest.generateChainedPartitionableStateHandle(new JobVertexID(), 0, 2, 8, false),
-							null,
-							CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle())),
-							null);
+			TaskStateSnapshot checkpointStateHandles = new TaskStateSnapshot();
+			checkpointStateHandles.putSubtaskStateByOperatorID(
+				new OperatorID(),
+				new OperatorSubtaskState(
+					CheckpointCoordinatorTest.generateStreamStateHandle(new MyHandle()),
+					CheckpointCoordinatorTest.generatePartitionableStateHandle(new JobVertexID(), 0, 2, 8, false),
+					null,
+					CheckpointCoordinatorTest.generateKeyGroupState(keyGroupRange, Collections.singletonList(new MyHandle())),
+					null
+				)
+			);
 
 			AcknowledgeCheckpoint withState = new AcknowledgeCheckpoint(
 					new JobID(),

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
index 851fa96..8ed06b2 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
@@ -26,7 +26,7 @@ import org.apache.flink.core.fs.Path;
 import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
@@ -156,7 +156,7 @@ public class DummyEnvironment implements Environment {
 	}
 
 	@Override
-	public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, SubtaskState subtaskState) {
+	public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, TaskStateSnapshot subtaskState) {
 	}
 
 	@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
index 4f0242e..7514cc4 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
@@ -27,7 +27,7 @@ import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.accumulators.AccumulatorRegistry;
 import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
-import org.apache.flink.runtime.checkpoint.SubtaskState;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.execution.Environment;
 import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
 import org.apache.flink.runtime.io.disk.iomanager.IOManager;
@@ -50,8 +50,8 @@ import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
 import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo;
 import org.apache.flink.types.Record;
 import org.apache.flink.util.MutableObjectIterator;
-
 import org.apache.flink.util.Preconditions;
+
 import org.mockito.invocation.InvocationOnMock;
 import org.mockito.stubbing.Answer;
 
@@ -354,7 +354,7 @@ public class MockEnvironment implements Environment {
 	}
 
 	@Override
-	public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, SubtaskState subtaskState) {
+	public void acknowledgeCheckpoint(long checkpointId, CheckpointMetrics checkpointMetrics, TaskStateSnapshot subtaskState) {
 		throw new UnsupportedOperationException();
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/b71154a7/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
index c6d2fec..085a386 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/taskmanager/TaskAsyncCallTest.java
@@ -27,6 +27,7 @@ import org.apache.flink.runtime.broadcast.BroadcastVariableManager;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.clusterframework.types.AllocationID;
 import org.apache.flink.runtime.deployment.InputGateDeploymentDescriptor;
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
@@ -49,7 +50,6 @@ import org.apache.flink.runtime.memory.MemoryManager;
 import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
 import org.apache.flink.runtime.metrics.groups.TaskMetricGroup;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
-import org.apache.flink.runtime.state.TaskStateHandles;
 import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo;
 import org.apache.flink.util.SerializedValue;
 
@@ -187,7 +187,7 @@ public class TaskAsyncCallTest {
 			Collections.<ResultPartitionDeploymentDescriptor>emptyList(),
 			Collections.<InputGateDeploymentDescriptor>emptyList(),
 			0,
-			new TaskStateHandles(),
+			new TaskStateSnapshot(),
 			mock(MemoryManager.class),
 			mock(IOManager.class),
 			networkEnvironment,
@@ -228,7 +228,7 @@ public class TaskAsyncCallTest {
 		}
 
 		@Override
-		public void setInitialState(TaskStateHandles taskStateHandles) throws Exception {}
+		public void setInitialState(TaskStateSnapshot taskStateHandles) throws Exception {}
 
 		@Override
 		public boolean triggerCheckpoint(CheckpointMetaData checkpointMetaData, CheckpointOptions checkpointOptions) {


Mime
View raw message