flink-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From srich...@apache.org
Subject [12/26] flink git commit: [FLINK-8360][checkpointing] Implement state storage for local recovery and integrate with task lifecycle
Date Sun, 25 Feb 2018 16:11:46 GMT
http://git-wip-us.apache.org/repos/asf/flink/blob/df3e6bb7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
index 96c95ea..8156964 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CompletedCheckpointStoreTest.java
@@ -328,7 +328,7 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 		}
 	}
 
-	static class TestOperatorSubtaskState extends OperatorSubtaskState {
+	public static class TestOperatorSubtaskState extends OperatorSubtaskState {
 		private static final long serialVersionUID = 522580433699164230L;
 
 		boolean registered;
@@ -359,6 +359,14 @@ public abstract class CompletedCheckpointStoreTest extends TestLogger {
 			registered = false;
 			discarded = false;
 		}
+
+		public boolean isRegistered() {
+			return registered;
+		}
+
+		public boolean isDiscarded() {
+			return discarded;
+		}
 	}
 
 }

http://git-wip-us.apache.org/repos/asf/flink/blob/df3e6bb7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PrioritizedOperatorSubtaskStateTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PrioritizedOperatorSubtaskStateTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PrioritizedOperatorSubtaskStateTest.java
new file mode 100644
index 0000000..09c9efb
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/PrioritizedOperatorSubtaskStateTest.java
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.OperatorStreamStateHandle;
+import org.apache.flink.runtime.state.StateObject;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.flink.util.TestLogger;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+import java.util.function.Function;
+
+import static org.apache.flink.runtime.checkpoint.StateHandleDummyUtil.createNewKeyedStateHandle;
+import static org.apache.flink.runtime.checkpoint.StateHandleDummyUtil.createNewOperatorStateHandle;
+import static org.apache.flink.runtime.checkpoint.StateHandleDummyUtil.deepDummyCopy;
+
+public class PrioritizedOperatorSubtaskStateTest extends TestLogger {
+
+	private final Random random = new Random(0x42);
+
+	/**
+	 * This test attempts to cover (almost) the full space of significantly different options for verifying and
+	 * prioritizing {@link OperatorSubtaskState} options for local recovery over primary/remote state handles.
+	 */
+	@Test
+	public void testPrioritization() {
+
+		for (int i = 0; i < 81; ++i) { // 3^4 possible configurations.
+
+			OperatorSubtaskState primaryAndFallback = generateForConfiguration(i);
+
+			for (int j = 0; j < 9; ++j) { // we test 3^2 configurations.
+				// mode 0: one valid state handle (deep copy of original).
+				// mode 1: empty StateObjectCollection.
+				// mode 2: one invalid state handle (e.g. wrong key group, different meta data)
+				int modeFirst = j % 3;
+				OperatorSubtaskState bestAlternative = createAlternativeSubtaskState(primaryAndFallback, modeFirst);
+				int modeSecond = (j / 3) % 3;
+				OperatorSubtaskState secondBestAlternative = createAlternativeSubtaskState(primaryAndFallback, modeSecond);
+
+				List<OperatorSubtaskState> orderedAlternativesList =
+					Arrays.asList(bestAlternative, secondBestAlternative);
+				List<OperatorSubtaskState> validAlternativesList = new ArrayList<>(3);
+				if (modeFirst == 0) {
+					validAlternativesList.add(bestAlternative);
+				}
+				if (modeSecond == 0) {
+					validAlternativesList.add(secondBestAlternative);
+				}
+				validAlternativesList.add(primaryAndFallback);
+
+				PrioritizedOperatorSubtaskState.Builder builder =
+					new PrioritizedOperatorSubtaskState.Builder(primaryAndFallback, orderedAlternativesList);
+
+				PrioritizedOperatorSubtaskState prioritizedOperatorSubtaskState = builder.build();
+
+				OperatorSubtaskState[] validAlternatives =
+					validAlternativesList.toArray(new OperatorSubtaskState[validAlternativesList.size()]);
+
+				OperatorSubtaskState[] onlyPrimary =
+					new OperatorSubtaskState[]{primaryAndFallback};
+
+				Assert.assertTrue(checkResultAsExpected(
+					OperatorSubtaskState::getManagedOperatorState,
+					PrioritizedOperatorSubtaskState::getPrioritizedManagedOperatorState,
+					prioritizedOperatorSubtaskState,
+					primaryAndFallback.getManagedOperatorState().size() == 1 ? validAlternatives : onlyPrimary));
+
+				Assert.assertTrue(checkResultAsExpected(
+					OperatorSubtaskState::getManagedKeyedState,
+					PrioritizedOperatorSubtaskState::getPrioritizedManagedKeyedState,
+					prioritizedOperatorSubtaskState,
+					primaryAndFallback.getManagedKeyedState().size() == 1 ? validAlternatives : onlyPrimary));
+
+				Assert.assertTrue(checkResultAsExpected(
+					OperatorSubtaskState::getRawOperatorState,
+					PrioritizedOperatorSubtaskState::getPrioritizedRawOperatorState,
+					prioritizedOperatorSubtaskState,
+					primaryAndFallback.getRawOperatorState().size() == 1 ? validAlternatives : onlyPrimary));
+
+				Assert.assertTrue(checkResultAsExpected(
+					OperatorSubtaskState::getRawKeyedState,
+					PrioritizedOperatorSubtaskState::getPrioritizedRawKeyedState,
+					prioritizedOperatorSubtaskState,
+					primaryAndFallback.getRawKeyedState().size() == 1 ? validAlternatives : onlyPrimary));
+			}
+		}
+	}
+
+	/**
+	 * Generator for all 3^4 = 81 possible configurations of an OperatorSubtaskState:
+	 * - 4 different sub-states:
+	 *      managed/raw + operator/keyed.
+	 * - 3 different options per sub-state:
+	 *      empty (simulate no state), single handle (simulate recovery), 2 handles (simulate e.g. rescaling)
+	 */
+	private OperatorSubtaskState generateForConfiguration(int conf) {
+
+		Preconditions.checkState(conf >= 0 && conf <= 80); // 3^4
+		final int numModes = 3;
+
+		KeyGroupRange keyGroupRange = new KeyGroupRange(0, 4);
+		KeyGroupRange keyGroupRange1 = new KeyGroupRange(0, 2);
+		KeyGroupRange keyGroupRange2 = new KeyGroupRange(3, 4);
+
+		int div = 1;
+		int mode = (conf / div) % numModes;
+		StateObjectCollection<OperatorStateHandle> s1 =
+			mode == 0 ?
+				StateObjectCollection.empty() :
+				mode == 1 ?
+					new StateObjectCollection<>(
+						Collections.singletonList(createNewOperatorStateHandle(2, random))) :
+					new StateObjectCollection<>(
+						Arrays.asList(
+							createNewOperatorStateHandle(2, random),
+							createNewOperatorStateHandle(2, random)));
+		div *= numModes;
+		mode = (conf / div) % numModes;
+		StateObjectCollection<OperatorStateHandle> s2 =
+			mode == 0 ?
+				StateObjectCollection.empty() :
+				mode == 1 ?
+					new StateObjectCollection<>(
+						Collections.singletonList(createNewOperatorStateHandle(2, random))) :
+					new StateObjectCollection<>(
+						Arrays.asList(
+							createNewOperatorStateHandle(2, random),
+							createNewOperatorStateHandle(2, random)));
+
+		div *= numModes;
+		mode = (conf / div) % numModes;
+		StateObjectCollection<KeyedStateHandle> s3 =
+			mode == 0 ?
+				StateObjectCollection.empty() :
+				mode == 1 ?
+					new StateObjectCollection<>(
+						Collections.singletonList(createNewKeyedStateHandle(keyGroupRange))) :
+					new StateObjectCollection<>(
+						Arrays.asList(
+							createNewKeyedStateHandle(keyGroupRange1),
+							createNewKeyedStateHandle(keyGroupRange2)));
+
+		div *= numModes;
+		mode = (conf / div) % numModes;
+		StateObjectCollection<KeyedStateHandle> s4 =
+			mode == 0 ?
+				StateObjectCollection.empty() :
+				mode == 1 ?
+					new StateObjectCollection<>(
+						Collections.singletonList(createNewKeyedStateHandle(keyGroupRange))) :
+					new StateObjectCollection<>(
+						Arrays.asList(
+							createNewKeyedStateHandle(keyGroupRange1),
+							createNewKeyedStateHandle(keyGroupRange2)));
+
+		return new OperatorSubtaskState(s1, s2, s3, s4);
+	}
+
+	/**
+	 * For all 4 sub-states:
+	 * - mode 0: One valid state handle (deep copy of original). Only this creates an OperatorSubtaskState that
+	 *           qualifies as alternative.
+	 * - mode 1: Empty StateObjectCollection.
+	 * - mode 2: One invalid state handle (e.g. wrong key group, different meta data)
+	 */
+	private OperatorSubtaskState createAlternativeSubtaskState(OperatorSubtaskState primaryOriginal, int mode) {
+		switch (mode) {
+			case 0:
+				return new OperatorSubtaskState(
+					deepCopyFirstElement(primaryOriginal.getManagedOperatorState()),
+					deepCopyFirstElement(primaryOriginal.getRawOperatorState()),
+					deepCopyFirstElement(primaryOriginal.getManagedKeyedState()),
+					deepCopyFirstElement(primaryOriginal.getRawKeyedState()));
+			case 1:
+				return new OperatorSubtaskState();
+			case 2:
+				KeyGroupRange otherRange = new KeyGroupRange(8, 16);
+				int numNamedStates = 2;
+				return new OperatorSubtaskState(
+					createNewOperatorStateHandle(numNamedStates, random),
+					createNewOperatorStateHandle(numNamedStates, random),
+					createNewKeyedStateHandle(otherRange),
+					createNewKeyedStateHandle(otherRange));
+			default:
+				throw new IllegalArgumentException("Mode: " + mode);
+		}
+	}
+
+	private <T extends StateObject> boolean checkResultAsExpected(
+		Function<OperatorSubtaskState, StateObjectCollection<T>> extractor,
+		Function<PrioritizedOperatorSubtaskState, Iterator<StateObjectCollection<T>>> extractor2,
+		PrioritizedOperatorSubtaskState prioritizedResult,
+		OperatorSubtaskState... expectedOrdered) {
+
+		List<StateObjectCollection<T>> collector = new ArrayList<>(expectedOrdered.length);
+		for (OperatorSubtaskState operatorSubtaskState : expectedOrdered) {
+			collector.add(extractor.apply(operatorSubtaskState));
+		}
+
+		return checkRepresentSameOrder(
+			extractor2.apply(prioritizedResult),
+			collector.toArray(new StateObjectCollection[collector.size()]));
+	}
+
+	private boolean checkRepresentSameOrder(
+		Iterator<? extends StateObjectCollection<?>> ordered,
+		StateObjectCollection<?>... expectedOrder) {
+
+		for (StateObjectCollection<?> objects : expectedOrder) {
+			if (!ordered.hasNext() || !checkContainedObjectsReferentialEquality(objects, ordered.next())) {
+				return false;
+			}
+		}
+
+		return !ordered.hasNext();
+	}
+
+	/**
+	 * Returns true iff, in iteration order, all objects in the first collection are equal by reference to their
+	 * corresponding object (by order) in the second collection and the size of the collections is equal.
+	 */
+	public boolean checkContainedObjectsReferentialEquality(StateObjectCollection<?> a, StateObjectCollection<?> b) {
+
+		if (a == b) {
+			return true;
+		}
+
+		if(a == null || b == null) {
+			return false;
+		}
+
+		if (a.size() != b.size()) {
+			return false;
+		}
+
+		Iterator<?> bIter = b.iterator();
+		for (StateObject stateObject : a) {
+			if (!bIter.hasNext() || bIter.next() != stateObject) {
+				return false;
+			}
+		}
+		return true;
+	}
+
+	/**
+	 * Creates a deep copy of the first state object in the given collection, or null if the collection is empty.
+	 */
+	private <T extends StateObject> T deepCopyFirstElement(StateObjectCollection<T> original) {
+		if (original.isEmpty()) {
+			return null;
+		}
+
+		T stateObject = original.iterator().next();
+		StateObject result;
+		if (stateObject instanceof OperatorStreamStateHandle) {
+			result = deepDummyCopy((OperatorStateHandle) stateObject);
+		} else if (stateObject instanceof KeyedStateHandle) {
+			result = deepDummyCopy((KeyedStateHandle) stateObject);
+		} else {
+			throw new IllegalStateException();
+		}
+		return (T) result;
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/df3e6bb7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java
new file mode 100644
index 0000000..548ca18e
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateHandleDummyUtil.java
@@ -0,0 +1,139 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.runtime.state.KeyGroupRange;
+import org.apache.flink.runtime.state.KeyedStateHandle;
+import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.OperatorStreamStateHandle;
+import org.apache.flink.runtime.state.SharedStateRegistry;
+import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Random;
+import java.util.UUID;
+
+public class StateHandleDummyUtil {
+
+	/**
+	 * Creates a new test {@link OperatorStreamStateHandle} with a given number of randomly created named states.
+	 */
+	public static OperatorStateHandle createNewOperatorStateHandle(int numNamedStates, Random random) {
+		Map<String, OperatorStateHandle.StateMetaInfo> operatorStateMetaData = new HashMap<>(numNamedStates);
+		byte[] streamData = new byte[numNamedStates * 4];
+		random.nextBytes(streamData);
+		long off = 0;
+		for (int i = 0; i < numNamedStates; ++i) {
+			long[] offsets = new long[4];
+			for (int o = 0; o < offsets.length; ++o) {
+				offsets[o] = off++;
+			}
+			OperatorStateHandle.StateMetaInfo metaInfo =
+				new OperatorStateHandle.StateMetaInfo(offsets, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE);
+			operatorStateMetaData.put(String.valueOf(UUID.randomUUID()), metaInfo);
+		}
+		ByteStreamStateHandle byteStreamStateHandle =
+			new ByteStreamStateHandle(String.valueOf(UUID.randomUUID()), streamData);
+		return new OperatorStreamStateHandle(operatorStateMetaData, byteStreamStateHandle);
+	}
+
+	/**
+	 * Creates a new test {@link KeyedStateHandle} for the given key-group.
+	 */
+	public static KeyedStateHandle createNewKeyedStateHandle(KeyGroupRange keyGroupRange) {
+		return new DummyKeyedStateHandle(keyGroupRange);
+	}
+
+	/**
+	 * Creates a deep copy of the given {@link OperatorStreamStateHandle}.
+	 */
+	public static OperatorStateHandle deepDummyCopy(OperatorStateHandle original) {
+
+		if (original == null) {
+			return null;
+		}
+
+		ByteStreamStateHandle stateHandle = (ByteStreamStateHandle) original.getDelegateStateHandle();
+		ByteStreamStateHandle stateHandleCopy = new ByteStreamStateHandle(
+			String.valueOf(stateHandle.getHandleName()),
+			stateHandle.getData().clone());
+		Map<String, OperatorStateHandle.StateMetaInfo> offsets = original.getStateNameToPartitionOffsets();
+		Map<String, OperatorStateHandle.StateMetaInfo> offsetsCopy = new HashMap<>(offsets.size());
+
+		for (Map.Entry<String, OperatorStateHandle.StateMetaInfo> entry : offsets.entrySet()) {
+			OperatorStateHandle.StateMetaInfo metaInfo = entry.getValue();
+			OperatorStateHandle.StateMetaInfo metaInfoCopy =
+				new OperatorStateHandle.StateMetaInfo(metaInfo.getOffsets(), metaInfo.getDistributionMode());
+			offsetsCopy.put(String.valueOf(entry.getKey()), metaInfoCopy);
+		}
+		return new OperatorStreamStateHandle(offsetsCopy, stateHandleCopy);
+	}
+
+	/**
+	 * Creates deep copy of the given {@link KeyedStateHandle}.
+	 */
+	public static KeyedStateHandle deepDummyCopy(KeyedStateHandle original) {
+
+		if (original == null) {
+			return null;
+		}
+
+		KeyGroupRange keyGroupRange = original.getKeyGroupRange();
+		return new DummyKeyedStateHandle(
+			new KeyGroupRange(keyGroupRange.getStartKeyGroup(), keyGroupRange.getEndKeyGroup()));
+	}
+
+	/**
+	 * KeyedStateHandle that only holds a key-group information.
+	 */
+	private static class DummyKeyedStateHandle implements KeyedStateHandle {
+
+		private static final long serialVersionUID = 1L;
+
+		private final KeyGroupRange keyGroupRange;
+
+		private DummyKeyedStateHandle(KeyGroupRange keyGroupRange) {
+			this.keyGroupRange = keyGroupRange;
+		}
+
+		@Override
+		public KeyGroupRange getKeyGroupRange() {
+			return keyGroupRange;
+		}
+
+		@Override
+		public KeyedStateHandle getIntersection(KeyGroupRange keyGroupRange) {
+			return new DummyKeyedStateHandle(this.keyGroupRange.getIntersection(keyGroupRange));
+		}
+
+		@Override
+		public void registerSharedStates(SharedStateRegistry stateRegistry) {
+		}
+
+		@Override
+		public void discardState() throws Exception {
+		}
+
+		@Override
+		public long getStateSize() {
+			return 0L;
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/df3e6bb7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateObjectCollectionTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateObjectCollectionTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateObjectCollectionTest.java
new file mode 100644
index 0000000..b12ee27
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/StateObjectCollectionTest.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.runtime.state.StateObject;
+import org.apache.flink.util.MethodForwardingTestUtil;
+import org.apache.flink.util.TestLogger;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.function.Function;
+
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests for {@link StateObjectCollection}.
+ */
+public class StateObjectCollectionTest extends TestLogger {
+
+	@Test
+	public void testEmptyCollection() {
+		StateObjectCollection<StateObject> empty = StateObjectCollection.empty();
+		Assert.assertEquals(0, empty.getStateSize());
+	}
+
+	@Test
+	public void testForwardingCollectionMethods() throws Exception {
+		MethodForwardingTestUtil.testMethodForwarding(
+			Collection.class,
+			((Function<Collection, StateObjectCollection>) StateObjectCollection::new));
+	}
+
+	@Test
+	public void testForwardingStateObjectMethods() throws Exception {
+		MethodForwardingTestUtil.testMethodForwarding(
+			StateObject.class,
+			object -> new StateObjectCollection<>(Collections.singletonList(object)));
+	}
+
+	@Test
+	public void testHasState() {
+		StateObjectCollection<StateObject> stateObjects = new StateObjectCollection<>(new ArrayList<>());
+		Assert.assertFalse(stateObjects.hasState());
+
+		stateObjects = new StateObjectCollection<>(Collections.singletonList(null));
+		Assert.assertFalse(stateObjects.hasState());
+
+		stateObjects = new StateObjectCollection<>(Collections.singletonList(mock(StateObject.class)));
+		Assert.assertTrue(stateObjects.hasState());
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/df3e6bb7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/TaskStateSnapshotTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/TaskStateSnapshotTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/TaskStateSnapshotTest.java
new file mode 100644
index 0000000..76f3906
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/TaskStateSnapshotTest.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.checkpoint;
+
+import org.apache.flink.runtime.jobgraph.OperatorID;
+import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.util.TestLogger;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.util.Random;
+
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.verify;
+
+public class TaskStateSnapshotTest extends TestLogger {
+
+	@Test
+	public void putGetSubtaskStateByOperatorID() {
+		TaskStateSnapshot taskStateSnapshot = new TaskStateSnapshot();
+
+		OperatorID operatorID_1 = new OperatorID();
+		OperatorID operatorID_2 = new OperatorID();
+		OperatorSubtaskState operatorSubtaskState_1 = new OperatorSubtaskState();
+		OperatorSubtaskState operatorSubtaskState_2 = new OperatorSubtaskState();
+		OperatorSubtaskState operatorSubtaskState_1_replace = new OperatorSubtaskState();
+
+		Assert.assertNull(taskStateSnapshot.getSubtaskStateByOperatorID(operatorID_1));
+		Assert.assertNull(taskStateSnapshot.getSubtaskStateByOperatorID(operatorID_2));
+		taskStateSnapshot.putSubtaskStateByOperatorID(operatorID_1, operatorSubtaskState_1);
+		taskStateSnapshot.putSubtaskStateByOperatorID(operatorID_2, operatorSubtaskState_2);
+		Assert.assertEquals(operatorSubtaskState_1, taskStateSnapshot.getSubtaskStateByOperatorID(operatorID_1));
+		Assert.assertEquals(operatorSubtaskState_2, taskStateSnapshot.getSubtaskStateByOperatorID(operatorID_2));
+		Assert.assertEquals(operatorSubtaskState_1, taskStateSnapshot.putSubtaskStateByOperatorID(operatorID_1, operatorSubtaskState_1_replace));
+		Assert.assertEquals(operatorSubtaskState_1_replace, taskStateSnapshot.getSubtaskStateByOperatorID(operatorID_1));
+	}
+
+	@Test
+	public void hasState() {
+		Random random = new Random(0x42);
+		TaskStateSnapshot taskStateSnapshot = new TaskStateSnapshot();
+		Assert.assertFalse(taskStateSnapshot.hasState());
+
+		OperatorSubtaskState emptyOperatorSubtaskState = new OperatorSubtaskState();
+		Assert.assertFalse(emptyOperatorSubtaskState.hasState());
+		taskStateSnapshot.putSubtaskStateByOperatorID(new OperatorID(), emptyOperatorSubtaskState);
+		Assert.assertFalse(taskStateSnapshot.hasState());
+
+		OperatorStateHandle stateHandle = StateHandleDummyUtil.createNewOperatorStateHandle(2, random);
+		OperatorSubtaskState nonEmptyOperatorSubtaskState = new OperatorSubtaskState(
+			stateHandle,
+			null,
+			null,
+			null
+		);
+
+		Assert.assertTrue(nonEmptyOperatorSubtaskState.hasState());
+		taskStateSnapshot.putSubtaskStateByOperatorID(new OperatorID(), nonEmptyOperatorSubtaskState);
+		Assert.assertTrue(taskStateSnapshot.hasState());
+	}
+
+	@Test
+	public void discardState() throws Exception {
+		TaskStateSnapshot taskStateSnapshot = new TaskStateSnapshot();
+		OperatorID operatorID_1 = new OperatorID();
+		OperatorID operatorID_2 = new OperatorID();
+
+		OperatorSubtaskState operatorSubtaskState_1 = mock(OperatorSubtaskState.class);
+		OperatorSubtaskState operatorSubtaskState_2 = mock(OperatorSubtaskState.class);
+
+		taskStateSnapshot.putSubtaskStateByOperatorID(operatorID_1, operatorSubtaskState_1);
+		taskStateSnapshot.putSubtaskStateByOperatorID(operatorID_2, operatorSubtaskState_2);
+
+		taskStateSnapshot.discardState();
+		verify(operatorSubtaskState_1).discardState();
+		verify(operatorSubtaskState_2).discardState();
+	}
+
+	@Test
+	public void getStateSize() {
+		Random random = new Random(0x42);
+		TaskStateSnapshot taskStateSnapshot = new TaskStateSnapshot();
+		Assert.assertEquals(0, taskStateSnapshot.getStateSize());
+
+		OperatorSubtaskState emptyOperatorSubtaskState = new OperatorSubtaskState();
+		Assert.assertFalse(emptyOperatorSubtaskState.hasState());
+		taskStateSnapshot.putSubtaskStateByOperatorID(new OperatorID(), emptyOperatorSubtaskState);
+		Assert.assertEquals(0, taskStateSnapshot.getStateSize());
+
+
+		OperatorStateHandle stateHandle_1 = StateHandleDummyUtil.createNewOperatorStateHandle(2, random);
+		OperatorSubtaskState nonEmptyOperatorSubtaskState_1 = new OperatorSubtaskState(
+			stateHandle_1,
+			null,
+			null,
+			null
+		);
+
+		OperatorStateHandle stateHandle_2 = StateHandleDummyUtil.createNewOperatorStateHandle(2, random);
+		OperatorSubtaskState nonEmptyOperatorSubtaskState_2 = new OperatorSubtaskState(
+			null,
+			stateHandle_2,
+			null,
+			null
+		);
+
+		taskStateSnapshot.putSubtaskStateByOperatorID(new OperatorID(), nonEmptyOperatorSubtaskState_1);
+		taskStateSnapshot.putSubtaskStateByOperatorID(new OperatorID(), nonEmptyOperatorSubtaskState_2);
+
+		long totalSize = stateHandle_1.getStateSize() + stateHandle_2.getStateSize();
+		Assert.assertEquals(totalSize, taskStateSnapshot.getStateSize());
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/df3e6bb7/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java
index d1d67ff..1963766 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/savepoint/CheckpointTestUtils.java
@@ -33,10 +33,10 @@ import org.apache.flink.runtime.state.KeyGroupRangeOffsets;
 import org.apache.flink.runtime.state.KeyGroupsStateHandle;
 import org.apache.flink.runtime.state.KeyedStateHandle;
 import org.apache.flink.runtime.state.OperatorStateHandle;
-import org.apache.flink.runtime.state.OperatorStateHandle.StateMetaInfo;
+import org.apache.flink.runtime.state.OperatorStreamStateHandle;
 import org.apache.flink.runtime.state.StateHandleID;
 import org.apache.flink.runtime.state.StreamStateHandle;
-import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare;
+import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.util.StringUtils;
 
 import java.util.ArrayList;
@@ -87,24 +87,24 @@ public class CheckpointTestUtils {
 			for (int subtaskIdx = 0; subtaskIdx < numSubtasksPerTask; subtaskIdx++) {
 
 				StreamStateHandle operatorStateBackend =
-					new TestByteStreamStateHandleDeepCompare("b", ("Beautiful").getBytes(ConfigConstants.DEFAULT_CHARSET));
+					new ByteStreamStateHandle("b", ("Beautiful").getBytes(ConfigConstants.DEFAULT_CHARSET));
 				StreamStateHandle operatorStateStream =
-					new TestByteStreamStateHandleDeepCompare("b", ("Beautiful").getBytes(ConfigConstants.DEFAULT_CHARSET));
+					new ByteStreamStateHandle("b", ("Beautiful").getBytes(ConfigConstants.DEFAULT_CHARSET));
 
 				OperatorStateHandle operatorStateHandleBackend = null;
 				OperatorStateHandle operatorStateHandleStream = null;
 				
-				Map<String, StateMetaInfo> offsetsMap = new HashMap<>();
+				Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>();
 				offsetsMap.put("A", new OperatorStateHandle.StateMetaInfo(new long[]{0, 10, 20}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
 				offsetsMap.put("B", new OperatorStateHandle.StateMetaInfo(new long[]{30, 40, 50}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
 				offsetsMap.put("C", new OperatorStateHandle.StateMetaInfo(new long[]{60, 70, 80}, OperatorStateHandle.Mode.UNION));
 
 				if (hasOperatorStateBackend) {
-					operatorStateHandleBackend = new OperatorStateHandle(offsetsMap, operatorStateBackend);
+					operatorStateHandleBackend = new OperatorStreamStateHandle(offsetsMap, operatorStateBackend);
 				}
 
 				if (hasOperatorStateStream) {
-					operatorStateHandleStream = new OperatorStateHandle(offsetsMap, operatorStateStream);
+					operatorStateHandleStream = new OperatorStreamStateHandle(offsetsMap, operatorStateStream);
 				}
 
 				KeyedStateHandle keyedStateBackend = null;
@@ -173,23 +173,23 @@ public class CheckpointTestUtils {
 				for (int chainIdx = 0; chainIdx < chainLength; ++chainIdx) {
 
 					StreamStateHandle operatorStateBackend =
-							new TestByteStreamStateHandleDeepCompare("b-" + chainIdx, ("Beautiful-" + chainIdx).getBytes(ConfigConstants.DEFAULT_CHARSET));
+							new ByteStreamStateHandle("b-" + chainIdx, ("Beautiful-" + chainIdx).getBytes(ConfigConstants.DEFAULT_CHARSET));
 					StreamStateHandle operatorStateStream =
-							new TestByteStreamStateHandleDeepCompare("b-" + chainIdx, ("Beautiful-" + chainIdx).getBytes(ConfigConstants.DEFAULT_CHARSET));
-					Map<String, StateMetaInfo> offsetsMap = new HashMap<>();
+							new ByteStreamStateHandle("b-" + chainIdx, ("Beautiful-" + chainIdx).getBytes(ConfigConstants.DEFAULT_CHARSET));
+					Map<String, OperatorStateHandle.StateMetaInfo> offsetsMap = new HashMap<>();
 					offsetsMap.put("A", new OperatorStateHandle.StateMetaInfo(new long[]{0, 10, 20}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
 					offsetsMap.put("B", new OperatorStateHandle.StateMetaInfo(new long[]{30, 40, 50}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
 					offsetsMap.put("C", new OperatorStateHandle.StateMetaInfo(new long[]{60, 70, 80}, OperatorStateHandle.Mode.UNION));
 
 					if (chainIdx != noOperatorStateBackendAtIndex) {
 						OperatorStateHandle operatorStateHandleBackend =
-								new OperatorStateHandle(offsetsMap, operatorStateBackend);
+								new OperatorStreamStateHandle(offsetsMap, operatorStateBackend);
 						operatorStatesBackend.add(operatorStateHandleBackend);
 					}
 
 					if (chainIdx != noOperatorStateStreamAtIndex) {
 						OperatorStateHandle operatorStateHandleStream =
-								new OperatorStateHandle(offsetsMap, operatorStateStream);
+								new OperatorStreamStateHandle(offsetsMap, operatorStateStream);
 						operatorStatesStream.add(operatorStateHandleStream);
 					}
 				}
@@ -284,7 +284,7 @@ public class CheckpointTestUtils {
 	}
 
 	public static StreamStateHandle createDummyStreamStateHandle(Random rnd) {
-		return new TestByteStreamStateHandleDeepCompare(
+		return new ByteStreamStateHandle(
 			String.valueOf(createRandomUUID(rnd)),
 			String.valueOf(createRandomUUID(rnd)).getBytes(ConfigConstants.DEFAULT_CHARSET));
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/df3e6bb7/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
index 8f1e12c..005dd98 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/JobManagerHARecoveryTest.java
@@ -37,7 +37,9 @@ import org.apache.flink.runtime.checkpoint.CheckpointRecoveryFactory;
 import org.apache.flink.runtime.checkpoint.CheckpointRetentionPolicy;
 import org.apache.flink.runtime.checkpoint.CompletedCheckpointStore;
 import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
+import org.apache.flink.runtime.checkpoint.PrioritizedOperatorSubtaskState;
 import org.apache.flink.runtime.checkpoint.StandaloneCheckpointIDCounter;
+import org.apache.flink.runtime.checkpoint.StateObjectCollection;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
 import org.apache.flink.runtime.checkpoint.TestingCheckpointRecoveryFactory;
 import org.apache.flink.runtime.clusterframework.types.ResourceID;
@@ -68,6 +70,7 @@ import org.apache.flink.runtime.metrics.NoOpMetricRegistry;
 import org.apache.flink.runtime.metrics.groups.JobManagerMetricGroup;
 import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
 import org.apache.flink.runtime.state.OperatorStateHandle;
+import org.apache.flink.runtime.state.OperatorStreamStateHandle;
 import org.apache.flink.runtime.state.TaskStateManager;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
 import org.apache.flink.runtime.taskmanager.TaskManager;
@@ -79,7 +82,6 @@ import org.apache.flink.runtime.testingUtils.TestingTaskManagerMessages;
 import org.apache.flink.runtime.testingUtils.TestingUtils;
 import org.apache.flink.runtime.testutils.InMemorySubmittedJobGraphStore;
 import org.apache.flink.runtime.testutils.RecoverableCompletedCheckpointStore;
-import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare;
 import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.Preconditions;
 import org.apache.flink.util.TestLogger;
@@ -95,6 +97,7 @@ import akka.pattern.Patterns;
 import akka.testkit.CallingThreadDispatcher;
 import akka.testkit.JavaTestKit;
 import org.junit.AfterClass;
+import org.junit.Assert;
 import org.junit.BeforeClass;
 import org.junit.Rule;
 import org.junit.Test;
@@ -105,6 +108,7 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
@@ -479,17 +483,21 @@ public class JobManagerHARecoveryTest extends TestLogger {
 
 			OperatorID operatorID = OperatorID.fromJobVertexID(getEnvironment().getJobVertexId());
 			TaskStateManager taskStateManager = getEnvironment().getTaskStateManager();
-			OperatorSubtaskState subtaskState = taskStateManager.operatorStates(operatorID);
+			PrioritizedOperatorSubtaskState subtaskState = taskStateManager.prioritizedOperatorState(operatorID);
 
-			if(subtaskState != null) {
-				int subtaskIndex = getIndexInSubtaskGroup();
-				if (subtaskIndex < BlockingStatefulInvokable.recoveredStates.length) {
-					OperatorStateHandle operatorStateHandle = subtaskState.getManagedOperatorState().iterator().next();
+			int subtaskIndex = getIndexInSubtaskGroup();
+			if (subtaskIndex < BlockingStatefulInvokable.recoveredStates.length) {
+				Iterator<OperatorStateHandle> iterator =
+					subtaskState.getJobManagerManagedOperatorState().iterator();
+
+				if (iterator.hasNext()) {
+					OperatorStateHandle operatorStateHandle = iterator.next();
 					try (FSDataInputStream in = operatorStateHandle.openInputStream()) {
 						BlockingStatefulInvokable.recoveredStates[subtaskIndex] =
 							InstantiationUtil.deserializeObject(in, getUserCodeClassLoader());
 					}
 				}
+				Assert.assertFalse(iterator.hasNext());
 			}
 
 			LATCH.await();
@@ -516,7 +524,7 @@ public class JobManagerHARecoveryTest extends TestLogger {
 
 		@Override
 		public boolean triggerCheckpoint(CheckpointMetaData checkpointMetaData, CheckpointOptions checkpointOptions) throws Exception {
-			ByteStreamStateHandle byteStreamStateHandle = new TestByteStreamStateHandleDeepCompare(
+			ByteStreamStateHandle byteStreamStateHandle = new ByteStreamStateHandle(
 					String.valueOf(UUID.randomUUID()),
 					InstantiationUtil.serializeObject(checkpointMetaData.getCheckpointId()));
 
@@ -525,16 +533,16 @@ public class JobManagerHARecoveryTest extends TestLogger {
 				"test-state",
 				new OperatorStateHandle.StateMetaInfo(new long[]{0L}, OperatorStateHandle.Mode.SPLIT_DISTRIBUTE));
 
-			OperatorStateHandle operatorStateHandle = new OperatorStateHandle(stateNameToPartitionOffsets, byteStreamStateHandle);
+			OperatorStateHandle operatorStateHandle = new OperatorStreamStateHandle(stateNameToPartitionOffsets, byteStreamStateHandle);
 
 			TaskStateSnapshot checkpointStateHandles = new TaskStateSnapshot();
 			checkpointStateHandles.putSubtaskStateByOperatorID(
 				OperatorID.fromJobVertexID(getEnvironment().getJobVertexId()),
 				new OperatorSubtaskState(
-					Collections.singletonList(operatorStateHandle),
-					Collections.emptyList(),
-					Collections.emptyList(),
-					Collections.emptyList()));
+					StateObjectCollection.singleton(operatorStateHandle),
+					StateObjectCollection.empty(),
+					StateObjectCollection.empty(),
+					StateObjectCollection.empty()));
 
 			getEnvironment().acknowledgeCheckpoint(
 					checkpointMetaData.getCheckpointId(),

http://git-wip-us.apache.org/repos/asf/flink/blob/df3e6bb7/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/ZooKeeperSubmittedJobGraphsStoreITCase.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/ZooKeeperSubmittedJobGraphsStoreITCase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/ZooKeeperSubmittedJobGraphsStoreITCase.java
index 9454d90..c1a7b53 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/ZooKeeperSubmittedJobGraphsStoreITCase.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/ZooKeeperSubmittedJobGraphsStoreITCase.java
@@ -18,7 +18,6 @@
 
 package org.apache.flink.runtime.jobmanager;
 
-import akka.actor.ActorRef;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.runtime.akka.ListeningBehaviour;
 import org.apache.flink.runtime.concurrent.Executors;
@@ -28,11 +27,12 @@ import org.apache.flink.runtime.jobmanager.SubmittedJobGraphStore.SubmittedJobGr
 import org.apache.flink.runtime.state.RetrievableStateHandle;
 import org.apache.flink.runtime.state.RetrievableStreamStateHandle;
 import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
-import org.apache.flink.runtime.util.TestByteStreamStateHandleDeepCompare;
 import org.apache.flink.runtime.zookeeper.RetrievableStateStorageHelper;
 import org.apache.flink.runtime.zookeeper.ZooKeeperTestEnvironment;
 import org.apache.flink.util.InstantiationUtil;
 import org.apache.flink.util.TestLogger;
+
+import akka.actor.ActorRef;
 import org.junit.AfterClass;
 import org.junit.Before;
 import org.junit.Test;
@@ -65,7 +65,7 @@ public class ZooKeeperSubmittedJobGraphsStoreITCase extends TestLogger {
 	private final static RetrievableStateStorageHelper<SubmittedJobGraph> localStateStorage = new RetrievableStateStorageHelper<SubmittedJobGraph>() {
 		@Override
 		public RetrievableStateHandle<SubmittedJobGraph> store(SubmittedJobGraph state) throws IOException {
-			ByteStreamStateHandle byteStreamStateHandle = new TestByteStreamStateHandleDeepCompare(
+			ByteStreamStateHandle byteStreamStateHandle = new ByteStreamStateHandle(
 					String.valueOf(UUID.randomUUID()),
 					InstantiationUtil.serializeObject(state));
 			return new RetrievableStreamStateHandle<>(byteStreamStateHandle);

http://git-wip-us.apache.org/repos/asf/flink/blob/df3e6bb7/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/scheduler/SchedulerTestUtils.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/scheduler/SchedulerTestUtils.java b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/scheduler/SchedulerTestUtils.java
index 4186255..b7aa97d 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/scheduler/SchedulerTestUtils.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/jobmanager/scheduler/SchedulerTestUtils.java
@@ -91,7 +91,7 @@ public class SchedulerTestUtils {
 		when(vertex.toString()).thenReturn("TEST-VERTEX");
 		when(vertex.getJobVertex()).thenReturn(executionJobVertex);
 		when(vertex.getJobvertexId()).thenReturn(new JobVertexID());
-		
+
 		Execution execution = mock(Execution.class);
 		when(execution.getVertex()).thenReturn(vertex);
 		
@@ -126,6 +126,7 @@ public class SchedulerTestUtils {
 		ExecutionVertex vertex = mock(ExecutionVertex.class);
 
 		when(vertex.getPreferredLocationsBasedOnInputs()).thenReturn(preferredLocationFutures);
+		when(vertex.getPreferredLocations()).thenReturn(preferredLocationFutures);
 		when(vertex.getJobId()).thenReturn(new JobID());
 		when(vertex.toString()).thenReturn("TEST-VERTEX");
 		when(vertex.getJobVertex()).thenReturn(executionJobVertex);
@@ -152,7 +153,7 @@ public class SchedulerTestUtils {
 		when(vertex.toString()).thenReturn("TEST-VERTEX");
 		when(vertex.getTaskNameWithSubtaskIndex()).thenReturn("TEST-VERTEX");
 		when(vertex.getJobVertex()).thenReturn(executionJobVertex);
-		
+
 		Execution execution = mock(Execution.class);
 		when(execution.getVertex()).thenReturn(vertex);
 		

http://git-wip-us.apache.org/repos/asf/flink/blob/df3e6bb7/flink-runtime/src/test/java/org/apache/flink/runtime/metrics/TaskManagerMetricsTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/metrics/TaskManagerMetricsTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/metrics/TaskManagerMetricsTest.java
index 1798851..db04023 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/metrics/TaskManagerMetricsTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/metrics/TaskManagerMetricsTest.java
@@ -21,6 +21,7 @@ package org.apache.flink.runtime.metrics;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.runtime.akka.AkkaUtils;
 import org.apache.flink.runtime.clusterframework.types.ResourceID;
+import org.apache.flink.runtime.concurrent.Executors;
 import org.apache.flink.runtime.highavailability.HighAvailabilityServices;
 import org.apache.flink.runtime.highavailability.nonha.embedded.EmbeddedHaServices;
 import org.apache.flink.runtime.jobmanager.JobManager;
@@ -98,6 +99,7 @@ public class TaskManagerMetricsTest extends TestLogger {
 			TaskManagerServices taskManagerServices = TaskManagerServices.fromConfiguration(
 				taskManagerServicesConfiguration,
 				tmResourceID,
+				Executors.directExecutor(),
 				EnvironmentInformation.getSizeOfFreeHeapMemoryWithDefrag(),
 				EnvironmentInformation.getMaxJvmHeapMemory());
 
@@ -115,6 +117,7 @@ public class TaskManagerMetricsTest extends TestLogger {
 				taskManagerServices.getMemoryManager(),
 				taskManagerServices.getIOManager(),
 				taskManagerServices.getNetworkEnvironment(),
+				taskManagerServices.getTaskManagerStateStore(),
 				highAvailabilityServices,
 				taskManagerMetricGroup);
 

http://git-wip-us.apache.org/repos/asf/flink/blob/df3e6bb7/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
index 3c73e3d..7d1a777 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/DummyEnvironment.java
@@ -41,6 +41,7 @@ import org.apache.flink.runtime.metrics.groups.UnregisteredMetricGroups;
 import org.apache.flink.runtime.query.KvStateRegistry;
 import org.apache.flink.runtime.query.TaskKvStateRegistry;
 import org.apache.flink.runtime.state.TaskStateManager;
+import org.apache.flink.runtime.state.TestTaskStateManager;
 import org.apache.flink.runtime.taskmanager.TaskManagerRuntimeInfo;
 import org.apache.flink.runtime.util.TestingTaskManagerRuntimeInfo;
 
@@ -59,12 +60,17 @@ public class DummyEnvironment implements Environment {
 	private TaskStateManager taskStateManager;
 	private final AccumulatorRegistry accumulatorRegistry = new AccumulatorRegistry(jobId, executionId);
 
+	public DummyEnvironment() {
+		this("Test Job", 1, 0, 1);
+	}
+
 	public DummyEnvironment(String taskName, int numSubTasks, int subTaskIndex) {
 		this(taskName, numSubTasks, subTaskIndex, numSubTasks);
 	}
 
 	public DummyEnvironment(String taskName, int numSubTasks, int subTaskIndex, int maxParallelism) {
 		this.taskInfo = new TaskInfo(taskName, maxParallelism, subTaskIndex, numSubTasks, 0);
+		this.taskStateManager = new TestTaskStateManager();
 	}
 
 	public void setKvStateRegistry(KvStateRegistry kvStateRegistry) {

http://git-wip-us.apache.org/repos/asf/flink/blob/df3e6bb7/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
index 4d1037e..ce19a5e 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/operators/testutils/MockEnvironment.java
@@ -89,7 +89,9 @@ public class MockEnvironment implements Environment, AutoCloseable {
 
 	private final List<ResultPartitionWriter> outputs;
 
-	private final JobID jobID = new JobID();
+	private final JobID jobID;
+
+	private final JobVertexID jobVertexID;
 
 	private final BroadcastVariableManager bcVarManager = new BroadcastVariableManager();
 
@@ -170,11 +172,11 @@ public class MockEnvironment implements Environment, AutoCloseable {
 			bufferSize,
 			taskConfiguration,
 			executionConfig,
+			taskStateManager,
 			maxParallelism,
 			parallelism,
 			subtaskIndex,
-			Thread.currentThread().getContextClassLoader(),
-			taskStateManager);
+			Thread.currentThread().getContextClassLoader());
 
 	}
 
@@ -185,11 +187,45 @@ public class MockEnvironment implements Environment, AutoCloseable {
 			int bufferSize,
 			Configuration taskConfiguration,
 			ExecutionConfig executionConfig,
+			TaskStateManager taskStateManager,
 			int maxParallelism,
 			int parallelism,
 			int subtaskIndex,
-			ClassLoader userCodeClassLoader,
-			TaskStateManager taskStateManager) {
+			ClassLoader userCodeClassLoader) {
+		this(
+			new JobID(),
+			new JobVertexID(),
+			taskName,
+			memorySize,
+			inputSplitProvider,
+			bufferSize,
+			taskConfiguration,
+			executionConfig,
+			taskStateManager,
+			maxParallelism,
+			parallelism,
+			subtaskIndex,
+			userCodeClassLoader);
+	}
+
+	public MockEnvironment(
+		JobID jobID,
+		JobVertexID jobVertexID,
+		String taskName,
+		long memorySize,
+		MockInputSplitProvider inputSplitProvider,
+		int bufferSize,
+		Configuration taskConfiguration,
+		ExecutionConfig executionConfig,
+		TaskStateManager taskStateManager,
+		int maxParallelism,
+		int parallelism,
+		int subtaskIndex,
+		ClassLoader userCodeClassLoader) {
+
+		this.jobID = jobID;
+		this.jobVertexID = jobVertexID;
+
 		this.taskInfo = new TaskInfo(taskName, maxParallelism, subtaskIndex, parallelism, 0);
 		this.jobConfiguration = new Configuration();
 		this.taskConfiguration = taskConfiguration;
@@ -325,7 +361,7 @@ public class MockEnvironment implements Environment, AutoCloseable {
 
 	@Override
 	public JobVertexID getJobVertexId() {
-		return new JobVertexID(new byte[16]);
+		return jobVertexID;
 	}
 
 	@Override

http://git-wip-us.apache.org/repos/asf/flink/blob/df3e6bb7/flink-runtime/src/test/java/org/apache/flink/runtime/state/CheckpointStreamWithResultProviderTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/CheckpointStreamWithResultProviderTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/CheckpointStreamWithResultProviderTest.java
new file mode 100644
index 0000000..2af25d9
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/CheckpointStreamWithResultProviderTest.java
@@ -0,0 +1,210 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
+import org.apache.flink.util.MethodForwardingTestUtil;
+import org.apache.flink.util.TestLogger;
+
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.BeforeClass;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.io.Closeable;
+import java.io.File;
+import java.io.IOException;
+
+public class CheckpointStreamWithResultProviderTest extends TestLogger {
+
+	private static TemporaryFolder temporaryFolder;
+
+	@BeforeClass
+	public static void beforeClass() throws IOException {
+		temporaryFolder = new TemporaryFolder();
+		temporaryFolder.create();
+	}
+
+	@AfterClass
+	public static void afterClass() {
+		temporaryFolder.delete();
+	}
+
+	@Test
+	public void testFactory() throws Exception {
+
+		CheckpointStreamFactory primaryFactory = createCheckpointStreamFactory();
+		try (
+			CheckpointStreamWithResultProvider primaryOnly =
+				CheckpointStreamWithResultProvider.createSimpleStream(
+					CheckpointedStateScope.EXCLUSIVE,
+					primaryFactory)) {
+
+			Assert.assertTrue(primaryOnly instanceof CheckpointStreamWithResultProvider.PrimaryStreamOnly);
+		}
+
+		LocalRecoveryDirectoryProvider directoryProvider = createLocalRecoveryDirectoryProvider();
+		try (
+			CheckpointStreamWithResultProvider primaryAndSecondary =
+				CheckpointStreamWithResultProvider.createDuplicatingStream(
+					42L,
+					CheckpointedStateScope.EXCLUSIVE,
+					primaryFactory,
+					directoryProvider)) {
+
+			Assert.assertTrue(primaryAndSecondary instanceof CheckpointStreamWithResultProvider.PrimaryAndSecondaryStream);
+		}
+	}
+
+	@Test
+	public void testCloseAndFinalizeCheckpointStreamResultPrimaryOnly() throws Exception {
+		CheckpointStreamFactory primaryFactory = createCheckpointStreamFactory();
+
+		CheckpointStreamWithResultProvider resultProvider =
+			CheckpointStreamWithResultProvider.createSimpleStream(CheckpointedStateScope.EXCLUSIVE, primaryFactory);
+
+		SnapshotResult<StreamStateHandle> result = writeCheckpointTestData(resultProvider);
+
+		Assert.assertNotNull(result.getJobManagerOwnedSnapshot());
+		Assert.assertNull(result.getTaskLocalSnapshot());
+
+		try (FSDataInputStream inputStream = result.getJobManagerOwnedSnapshot().openInputStream()) {
+			Assert.assertEquals(0x42, inputStream.read());
+			Assert.assertEquals(-1, inputStream.read());
+		}
+	}
+
+	@Test
+	public void testCloseAndFinalizeCheckpointStreamResultPrimaryAndSecondary() throws Exception {
+		CheckpointStreamFactory primaryFactory = createCheckpointStreamFactory();
+		LocalRecoveryDirectoryProvider directoryProvider = createLocalRecoveryDirectoryProvider();
+
+		CheckpointStreamWithResultProvider resultProvider =
+			CheckpointStreamWithResultProvider.createDuplicatingStream(
+				42L,
+				CheckpointedStateScope.EXCLUSIVE,
+				primaryFactory,
+				directoryProvider);
+
+		SnapshotResult<StreamStateHandle> result = writeCheckpointTestData(resultProvider);
+
+		Assert.assertNotNull(result.getJobManagerOwnedSnapshot());
+		Assert.assertNotNull(result.getTaskLocalSnapshot());
+
+		try (FSDataInputStream inputStream = result.getJobManagerOwnedSnapshot().openInputStream()) {
+			Assert.assertEquals(0x42, inputStream.read());
+			Assert.assertEquals(-1, inputStream.read());
+		}
+
+		try (FSDataInputStream inputStream = result.getTaskLocalSnapshot().openInputStream()) {
+			Assert.assertEquals(0x42, inputStream.read());
+			Assert.assertEquals(-1, inputStream.read());
+		}
+	}
+
+	@Test
+	public void testCompletedAndCloseStateHandling() throws Exception {
+		CheckpointStreamFactory primaryFactory = createCheckpointStreamFactory();
+
+		testCloseBeforeComplete(new CheckpointStreamWithResultProvider.PrimaryStreamOnly(
+			primaryFactory.createCheckpointStateOutputStream(CheckpointedStateScope.EXCLUSIVE)));
+		testCompleteBeforeClose(new CheckpointStreamWithResultProvider.PrimaryStreamOnly(
+			primaryFactory.createCheckpointStateOutputStream(CheckpointedStateScope.EXCLUSIVE)));
+
+		testCloseBeforeComplete(new CheckpointStreamWithResultProvider.PrimaryAndSecondaryStream(
+				primaryFactory.createCheckpointStateOutputStream(CheckpointedStateScope.EXCLUSIVE),
+				primaryFactory.createCheckpointStateOutputStream(CheckpointedStateScope.EXCLUSIVE)));
+		testCompleteBeforeClose(new CheckpointStreamWithResultProvider.PrimaryAndSecondaryStream(
+				primaryFactory.createCheckpointStateOutputStream(CheckpointedStateScope.EXCLUSIVE),
+				primaryFactory.createCheckpointStateOutputStream(CheckpointedStateScope.EXCLUSIVE)));
+	}
+
+	@Test
+	public void testCloseMethodForwarding() throws Exception {
+		CheckpointStreamFactory streamFactory = createCheckpointStreamFactory();
+
+		MethodForwardingTestUtil.testMethodForwarding(
+			Closeable.class,
+			CheckpointStreamWithResultProvider.PrimaryStreamOnly::new,
+			() -> {
+				try {
+					return streamFactory.createCheckpointStateOutputStream(CheckpointedStateScope.EXCLUSIVE);
+				} catch (IOException e) {
+					throw new RuntimeException(e);
+				}
+			});
+
+		MethodForwardingTestUtil.testMethodForwarding(
+			Closeable.class,
+			CheckpointStreamWithResultProvider.PrimaryAndSecondaryStream::new,
+			() -> {
+				try {
+					return new DuplicatingCheckpointOutputStream(
+						streamFactory.createCheckpointStateOutputStream(CheckpointedStateScope.EXCLUSIVE),
+						streamFactory.createCheckpointStateOutputStream(CheckpointedStateScope.EXCLUSIVE));
+				} catch (IOException e) {
+					throw new RuntimeException(e);
+				}
+			});
+	}
+
+	private SnapshotResult<StreamStateHandle> writeCheckpointTestData(
+		CheckpointStreamWithResultProvider resultProvider) throws IOException {
+
+		CheckpointStreamFactory.CheckpointStateOutputStream checkpointOutputStream =
+			resultProvider.getCheckpointOutputStream();
+		checkpointOutputStream.write(0x42);
+		return resultProvider.closeAndFinalizeCheckpointStreamResult();
+	}
+
+	private CheckpointStreamFactory createCheckpointStreamFactory() {
+		return new MemCheckpointStreamFactory(16 * 1024);
+	}
+
+	/**
+	 * Test that an exception is thrown if the stream was already closed before and we ask for a result later.
+	 */
+	private void testCloseBeforeComplete(CheckpointStreamWithResultProvider resultProvider) throws IOException {
+		resultProvider.getCheckpointOutputStream().write(0x42);
+		resultProvider.close();
+		try {
+			resultProvider.closeAndFinalizeCheckpointStreamResult();
+			Assert.fail();
+		} catch (IOException ignore) {
+		}
+	}
+
+	private void testCompleteBeforeClose(CheckpointStreamWithResultProvider resultProvider) throws IOException {
+		resultProvider.getCheckpointOutputStream().write(0x42);
+		Assert.assertNotNull(resultProvider.closeAndFinalizeCheckpointStreamResult());
+		resultProvider.close();
+	}
+
+	private LocalRecoveryDirectoryProvider createLocalRecoveryDirectoryProvider() throws IOException {
+		File localStateDir = temporaryFolder.newFolder();
+		JobID jobID = new JobID();
+		JobVertexID jobVertexID = new JobVertexID();
+		int subtaskIdx = 0;
+		return new LocalRecoveryDirectoryProviderImpl(localStateDir, jobID, jobVertexID, subtaskIdx);
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/df3e6bb7/flink-runtime/src/test/java/org/apache/flink/runtime/state/DuplicatingCheckpointOutputStreamTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/DuplicatingCheckpointOutputStreamTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/DuplicatingCheckpointOutputStreamTest.java
new file mode 100644
index 0000000..886dbdd
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/DuplicatingCheckpointOutputStreamTest.java
@@ -0,0 +1,310 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.runtime.testutils.CommonTestUtils;
+import org.apache.flink.util.TestLogger;
+
+import org.apache.commons.io.IOUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+import javax.annotation.Nullable;
+
+import java.io.IOException;
+import java.util.Random;
+
+public class DuplicatingCheckpointOutputStreamTest extends TestLogger {
+
+	/**
+	 * Tests that all writes are duplicated to both the primary and the secondary stream, and that the duplicated
+	 * content matches a reference stream that received the same sequence of writes.
+	 */
+	@Test
+	public void testDuplicatedWrite() throws Exception {
+		int streamCapacity = 1024 * 1024;
+		TestMemoryCheckpointOutputStream primaryStream = new TestMemoryCheckpointOutputStream(streamCapacity);
+		TestMemoryCheckpointOutputStream secondaryStream = new TestMemoryCheckpointOutputStream(streamCapacity);
+		TestMemoryCheckpointOutputStream referenceStream = new TestMemoryCheckpointOutputStream(streamCapacity);
+		DuplicatingCheckpointOutputStream duplicatingStream =
+			new DuplicatingCheckpointOutputStream(primaryStream, secondaryStream, 64);
+		Random random = new Random(42);
+		for (int i = 0; i < 500; ++i) {
+			// randomly exercise all three write(...) overloads
+			int choice = random.nextInt(3);
+			if (choice == 0) {
+				int val = random.nextInt();
+				referenceStream.write(val);
+				duplicatingStream.write(val);
+			} else {
+				byte[] bytes = new byte[random.nextInt(128)];
+				random.nextBytes(bytes);
+				if (choice == 1) {
+					referenceStream.write(bytes);
+					duplicatingStream.write(bytes);
+				} else {
+					// write a random sub-range; the guards avoid nextInt(0) for empty arrays
+					int off = bytes.length > 0 ? random.nextInt(bytes.length) : 0;
+					int len = bytes.length > 0 ? random.nextInt(bytes.length - off) : 0;
+					referenceStream.write(bytes, off, len);
+					duplicatingStream.write(bytes, off, len);
+				}
+			}
+			// positions of reference and duplicating stream must stay aligned after every write
+			Assert.assertEquals(referenceStream.getPos(), duplicatingStream.getPos());
+		}
+
+		StreamStateHandle refStateHandle = referenceStream.closeAndGetHandle();
+		StreamStateHandle primaryStateHandle = duplicatingStream.closeAndGetPrimaryHandle();
+		StreamStateHandle secondaryStateHandle = duplicatingStream.closeAndGetSecondaryHandle();
+
+		// both resulting handles must contain exactly the reference content
+		Assert.assertTrue(CommonTestUtils.isSteamContentEqual(
+			refStateHandle.openInputStream(),
+			primaryStateHandle.openInputStream()));
+
+		Assert.assertTrue(CommonTestUtils.isSteamContentEqual(
+			refStateHandle.openInputStream(),
+			secondaryStateHandle.openInputStream()));
+
+		refStateHandle.discardState();
+		primaryStateHandle.discardState();
+		secondaryStateHandle.discardState();
+	}
+
+	/**
+	 * This is the first of a set of tests that check that exceptions from the secondary stream do not prevent
+	 * creating a result for the primary stream. Here the failure comes from {@code write(int)}.
+	 */
+	@Test
+	public void testSecondaryWriteFail() throws Exception {
+		DuplicatingCheckpointOutputStream duplicatingStream = createDuplicatingStreamWithFailingSecondary();
+		testFailingSecondaryStream(duplicatingStream, () -> {
+			for (int i = 0; i < 128; i++) {
+				duplicatingStream.write(42);
+			}
+		});
+	}
+
+	/** Like {@link #testSecondaryWriteFail()}, but the secondary fails in {@code write(byte[])}. */
+	@Test
+	public void testFailingSecondaryWriteArrayFail() throws Exception {
+		DuplicatingCheckpointOutputStream duplicatingStream = createDuplicatingStreamWithFailingSecondary();
+		testFailingSecondaryStream(duplicatingStream, () -> duplicatingStream.write(new byte[512]));
+	}
+
+	/** Like {@link #testSecondaryWriteFail()}, but the secondary fails in {@code write(byte[], int, int)}. */
+	@Test
+	public void testFailingSecondaryWriteArrayOffsFail() throws Exception {
+		DuplicatingCheckpointOutputStream duplicatingStream = createDuplicatingStreamWithFailingSecondary();
+		testFailingSecondaryStream(duplicatingStream, () -> duplicatingStream.write(new byte[512], 20, 130));
+	}
+
+	/** Like {@link #testSecondaryWriteFail()}, but the secondary fails in {@code flush()}. */
+	@Test
+	public void testFailingSecondaryFlush() throws Exception {
+		DuplicatingCheckpointOutputStream duplicatingStream = createDuplicatingStreamWithFailingSecondary();
+		testFailingSecondaryStream(duplicatingStream, duplicatingStream::flush);
+	}
+
+	/** Like {@link #testSecondaryWriteFail()}, but the secondary fails in {@code sync()}. */
+	@Test
+	public void testFailingSecondarySync() throws Exception {
+		DuplicatingCheckpointOutputStream duplicatingStream = createDuplicatingStreamWithFailingSecondary();
+		testFailingSecondaryStream(duplicatingStream, duplicatingStream::sync);
+	}
+
+	/**
+	 * This is the first of a set of tests that check that exceptions from the primary stream are immediately
+	 * reported to the caller. Here the failure comes from {@code write(int)}.
+	 */
+	@Test
+	public void testPrimaryWriteFail() throws Exception {
+		DuplicatingCheckpointOutputStream duplicatingStream = createDuplicatingStreamWithFailingPrimary();
+		testFailingPrimaryStream(duplicatingStream, () -> {
+			for (int i = 0; i < 128; i++) {
+				duplicatingStream.write(42);
+			}
+		});
+	}
+
+	/** Like {@link #testPrimaryWriteFail()}, but the primary fails in {@code write(byte[])}. */
+	@Test
+	public void testFailingPrimaryWriteArrayFail() throws Exception {
+		DuplicatingCheckpointOutputStream duplicatingStream = createDuplicatingStreamWithFailingPrimary();
+		testFailingPrimaryStream(duplicatingStream, () -> duplicatingStream.write(new byte[512]));
+	}
+
+	/** Like {@link #testPrimaryWriteFail()}, but the primary fails in {@code write(byte[], int, int)}. */
+	@Test
+	public void testFailingPrimaryWriteArrayOffsFail() throws Exception {
+		DuplicatingCheckpointOutputStream duplicatingStream = createDuplicatingStreamWithFailingPrimary();
+		testFailingPrimaryStream(duplicatingStream, () -> duplicatingStream.write(new byte[512], 20, 130));
+	}
+
+	/** Like {@link #testPrimaryWriteFail()}, but the primary fails in {@code flush()}. */
+	@Test
+	public void testFailingPrimaryFlush() throws Exception {
+		DuplicatingCheckpointOutputStream duplicatingStream = createDuplicatingStreamWithFailingPrimary();
+		testFailingPrimaryStream(duplicatingStream, duplicatingStream::flush);
+	}
+
+	/** Like {@link #testPrimaryWriteFail()}, but the primary fails in {@code sync()}. */
+	@Test
+	public void testFailingPrimarySync() throws Exception {
+		DuplicatingCheckpointOutputStream duplicatingStream = createDuplicatingStreamWithFailingPrimary();
+		testFailingPrimaryStream(duplicatingStream, duplicatingStream::sync);
+	}
+
+	/**
+	 * Tests that an exception from interacting with the secondary stream does not affect duplicating to the primary
+	 * stream, but is reflected later when we want the secondary state handle.
+	 */
+	private void testFailingSecondaryStream(
+		DuplicatingCheckpointOutputStream duplicatingStream,
+		StreamTestMethod testMethod) throws Exception {
+
+		testMethod.call();
+
+		duplicatingStream.write(42);
+
+		FailingCheckpointOutStream secondary =
+			(FailingCheckpointOutStream) duplicatingStream.getSecondaryOutputStream();
+
+		// the failing secondary must have been closed by the duplicating stream
+		Assert.assertTrue(secondary.isClosed());
+
+		long pos = duplicatingStream.getPos();
+		StreamStateHandle primaryHandle = duplicatingStream.closeAndGetPrimaryHandle();
+
+		if (primaryHandle != null) {
+			Assert.assertEquals(pos, primaryHandle.getStateSize());
+		}
+
+		try {
+			duplicatingStream.closeAndGetSecondaryHandle();
+			Assert.fail();
+		} catch (IOException ioEx) {
+			// the reported exception must carry the original secondary-stream failure as its cause
+			Assert.assertEquals(ioEx.getCause(), duplicatingStream.getSecondaryStreamException());
+		}
+	}
+
+	/**
+	 * Tests that a failure in the primary stream is immediately surfaced as an {@link IOException} to the caller.
+	 */
+	private void testFailingPrimaryStream(
+		DuplicatingCheckpointOutputStream duplicatingStream,
+		StreamTestMethod testMethod) throws Exception {
+		try {
+			testMethod.call();
+			Assert.fail();
+		} catch (IOException ignore) {
+			// expected: primary failures must not be swallowed
+		} finally {
+			IOUtils.closeQuietly(duplicatingStream);
+		}
+	}
+
+	/**
+	 * Tests that in case of unaligned stream positions, the secondary stream is closed and the primary still works.
+	 * This is important because some code may rely on seeking to stream offsets in the created state files and if the
+	 * streams are not aligned this code could fail.
+	 */
+	@Test
+	public void testUnalignedStreamsException() throws IOException {
+		int streamCapacity = 1024 * 1024;
+		TestMemoryCheckpointOutputStream primaryStream = new TestMemoryCheckpointOutputStream(streamCapacity);
+		TestMemoryCheckpointOutputStream secondaryStream = new TestMemoryCheckpointOutputStream(streamCapacity);
+
+		// advance the primary before constructing the duplicating stream, so the positions are unaligned
+		primaryStream.write(42);
+
+		DuplicatingCheckpointOutputStream stream =
+			new DuplicatingCheckpointOutputStream(primaryStream, secondaryStream);
+
+		// misalignment must be detected at construction time: the secondary is failed and closed
+		Assert.assertNotNull(stream.getSecondaryStreamException());
+		Assert.assertTrue(secondaryStream.isClosed());
+
+		stream.write(23);
+
+		try {
+			stream.closeAndGetSecondaryHandle();
+			Assert.fail();
+		} catch (IOException ignore) {
+			// expected; must carry the recorded secondary-stream failure as its cause
+			Assert.assertEquals(ignore.getCause(), stream.getSecondaryStreamException());
+		}
+
+		StreamStateHandle primaryHandle = stream.closeAndGetPrimaryHandle();
+
+		// the primary handle must still contain everything written to the primary stream
+		try (FSDataInputStream inputStream = primaryHandle.openInputStream();) {
+			Assert.assertEquals(42, inputStream.read());
+			Assert.assertEquals(23, inputStream.read());
+			Assert.assertEquals(-1, inputStream.read());
+		}
+	}
+
+	/**
+	 * Creates a duplicating stream whose secondary delegate throws {@link IOException} on every tested operation.
+	 */
+	private DuplicatingCheckpointOutputStream createDuplicatingStreamWithFailingSecondary() throws IOException {
+		int streamCapacity = 1024 * 1024;
+		TestMemoryCheckpointOutputStream primaryStream = new TestMemoryCheckpointOutputStream(streamCapacity);
+		FailingCheckpointOutStream failSecondaryStream = new FailingCheckpointOutStream();
+		return new DuplicatingCheckpointOutputStream(primaryStream, failSecondaryStream, 64);
+	}
+
+	/**
+	 * Creates a duplicating stream whose primary delegate throws {@link IOException} on every tested operation.
+	 */
+	private DuplicatingCheckpointOutputStream createDuplicatingStreamWithFailingPrimary() throws IOException {
+		int streamCapacity = 1024 * 1024;
+		FailingCheckpointOutStream failPrimaryStream = new FailingCheckpointOutStream();
+		TestMemoryCheckpointOutputStream secondary = new TestMemoryCheckpointOutputStream(streamCapacity);
+		return new DuplicatingCheckpointOutputStream(failPrimaryStream, secondary, 64);
+	}
+
+	/**
+	 * Stream that throws {@link IOException} on all relevant methods under test. Note that {@code close()} does not
+	 * throw; it only records the call so tests can assert that the stream was closed.
+	 */
+	private static class FailingCheckpointOutStream extends CheckpointStreamFactory.CheckpointStateOutputStream {
+
+		private boolean closed = false;
+
+		@Nullable
+		@Override
+		public StreamStateHandle closeAndGetHandle() throws IOException {
+			throw new IOException();
+		}
+
+		@Override
+		public long getPos() throws IOException {
+			return 0;
+		}
+
+		@Override
+		public void write(int b) throws IOException {
+			throw new IOException();
+		}
+
+		@Override
+		public void flush() throws IOException {
+			throw new IOException();
+		}
+
+		@Override
+		public void sync() throws IOException {
+			throw new IOException();
+		}
+
+		@Override
+		public void close() throws IOException {
+			// deliberately does not throw; only tracks that close() happened
+			this.closed = true;
+		}
+
+		public boolean isClosed() {
+			return closed;
+		}
+	}
+
+	/** A test action against the stream under test that may throw {@link IOException}. */
+	@FunctionalInterface
+	private interface StreamTestMethod {
+		void call() throws IOException;
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/df3e6bb7/flink-runtime/src/test/java/org/apache/flink/runtime/state/LocalRecoveryDirectoryProviderImplTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/LocalRecoveryDirectoryProviderImplTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/LocalRecoveryDirectoryProviderImplTest.java
new file mode 100644
index 0000000..cc97c0e
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/LocalRecoveryDirectoryProviderImplTest.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.state;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.core.fs.Path;
+import org.apache.flink.runtime.jobgraph.JobVertexID;
+import org.apache.flink.util.TestLogger;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.io.File;
+import java.io.IOException;
+
+/**
+ * Tests for {@link LocalRecoveryDirectoryProviderImpl}.
+ */
+public class LocalRecoveryDirectoryProviderImplTest extends TestLogger {
+
+	private static final JobID JOB_ID = new JobID();
+	private static final JobVertexID JOB_VERTEX_ID = new JobVertexID();
+	private static final int SUBTASK_INDEX = 0;
+
+	@Rule
+	public TemporaryFolder tmpFolder = new TemporaryFolder();
+
+	// provider under test, created fresh for each test over three allocation base folders
+	private LocalRecoveryDirectoryProviderImpl directoryProvider;
+	private File[] allocBaseFolders;
+
+	@Before
+	public void setup() throws IOException {
+		this.allocBaseFolders = new File[]{tmpFolder.newFolder(), tmpFolder.newFolder(), tmpFolder.newFolder()};
+		this.directoryProvider = new LocalRecoveryDirectoryProviderImpl(
+			allocBaseFolders,
+			JOB_ID,
+			JOB_VERTEX_ID,
+			SUBTASK_INDEX);
+	}
+
+	/** Tests that checkpoint ids are mapped round-robin onto the configured allocation base directories. */
+	@Test
+	public void allocationBaseDir() {
+		for (int i = 0; i < 10; ++i) {
+			Assert.assertEquals(allocBaseFolders[i % allocBaseFolders.length], directoryProvider.allocationBaseDirectory(i));
+		}
+	}
+
+	/** Tests that allocation base directories can be selected directly by index. */
+	@Test
+	public void selectAllocationBaseDir() {
+		for (int i = 0; i < allocBaseFolders.length; ++i) {
+			Assert.assertEquals(allocBaseFolders[i], directoryProvider.selectAllocationBaseDirectory(i));
+		}
+	}
+
+	/** Tests that the provider reports the correct number of allocation base directories. */
+	@Test
+	public void allocationBaseDirectoriesCount() {
+		Assert.assertEquals(allocBaseFolders.length, directoryProvider.allocationBaseDirsCount());
+	}
+
+	/** Tests that the subtask base directory is the subtask-specific folder under the allocation base directory. */
+	@Test
+	public void subtaskSpecificDirectory() {
+		for (int i = 0; i < 10; ++i) {
+			Assert.assertEquals(
+				new File(
+					directoryProvider.allocationBaseDirectory(i),
+					directoryProvider.subtaskDirString()),
+				directoryProvider.subtaskBaseDirectory(i));
+		}
+	}
+
+	/** Tests that the checkpoint directory is the checkpoint-specific folder under the subtask base directory. */
+	@Test
+	public void subtaskCheckpointSpecificDirectory() {
+		for (int i = 0; i < 10; ++i) {
+			Assert.assertEquals(
+				new File(
+					directoryProvider.subtaskBaseDirectory(i),
+					directoryProvider.checkpointDirString(i)),
+				directoryProvider.subtaskSpecificCheckpointDirectory(i));
+		}
+	}
+
+	/** Tests the expected naming scheme of the subtask and checkpoint path components. */
+	@Test
+	public void testPathStringConstants() {
+
+		Assert.assertEquals(
+			directoryProvider.subtaskDirString(),
+			"jid_" + JOB_ID + Path.SEPARATOR + "vtx_" + JOB_VERTEX_ID + "_sti_" + SUBTASK_INDEX);
+
+		final long checkpointId = 42;
+		Assert.assertEquals(
+			directoryProvider.checkpointDirString(checkpointId),
+			"chk_" + checkpointId);
+	}
+
+	/** Tests that null entries in the allocation base directory array are rejected. */
+	@Test
+	public void testPreconditionsNotNullFiles() {
+		try {
+			new LocalRecoveryDirectoryProviderImpl(new File[]{null}, JOB_ID, JOB_VERTEX_ID, SUBTASK_INDEX);
+			Assert.fail();
+		} catch (NullPointerException ignore) {
+			// expected
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/df3e6bb7/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
index 4ac64e0..23493b5 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MemoryStateBackendTest.java
@@ -27,12 +27,12 @@ import org.apache.flink.api.common.typeutils.TypeSerializerSerializationUtil;
 import org.apache.flink.api.common.typeutils.base.IntSerializer;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
-import org.apache.flink.runtime.execution.Environment;
+import org.apache.flink.runtime.checkpoint.StateObjectCollection;
+import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
 import org.apache.flink.runtime.state.heap.HeapKeyedStateBackend;
 import org.apache.flink.runtime.state.memory.MemCheckpointStreamFactory;
 import org.apache.flink.runtime.state.memory.MemoryStateBackend;
 import org.apache.flink.util.FutureUtil;
-
 import org.junit.Assert;
 import org.junit.Ignore;
 import org.junit.Test;
@@ -43,7 +43,6 @@ import org.powermock.modules.junit4.PowerMockRunner;
 
 import java.io.IOException;
 import java.io.Serializable;
-import java.util.Collections;
 import java.util.concurrent.RunnableFuture;
 
 import static org.junit.Assert.assertEquals;
@@ -51,7 +50,6 @@ import static org.junit.Assert.fail;
 import static org.mockito.Matchers.any;
 import static org.mockito.Mockito.doThrow;
 import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.when;
 
 /**
  * Tests for the {@link org.apache.flink.runtime.state.memory.MemoryStateBackend}.
@@ -94,13 +92,11 @@ public class MemoryStateBackendTest extends StateBackendTestBase<MemoryStateBack
 	 */
 	@Test
 	public void testOperatorStateRestoreFailsIfSerializerDeserializationFails() throws Exception {
+		DummyEnvironment env = new DummyEnvironment();
 		AbstractStateBackend abstractStateBackend = new MemoryStateBackend(4096);
 
-		Environment env = mock(Environment.class);
-		when(env.getExecutionConfig()).thenReturn(new ExecutionConfig());
-		when(env.getUserClassLoader()).thenReturn(OperatorStateBackendTest.class.getClassLoader());
-
-		OperatorStateBackend operatorStateBackend = abstractStateBackend.createOperatorStateBackend(env, "test-op-name");
+		OperatorStateBackend operatorStateBackend =
+			abstractStateBackend.createOperatorStateBackend(env, "test-op-name");
 
 		// write some state
 		ListStateDescriptor<Serializable> stateDescriptor1 = new ListStateDescriptor<>("test1", new JavaSerializer<>());
@@ -124,9 +120,11 @@ public class MemoryStateBackendTest extends StateBackendTestBase<MemoryStateBack
 
 		CheckpointStreamFactory streamFactory = new MemCheckpointStreamFactory(MemoryStateBackend.DEFAULT_MAX_STATE_SIZE);
 
-		RunnableFuture<OperatorStateHandle> runnableFuture =
+		RunnableFuture<SnapshotResult<OperatorStateHandle>> runnableFuture =
 			operatorStateBackend.snapshot(1, 1, streamFactory, CheckpointOptions.forCheckpointWithDefaultLocation());
-		OperatorStateHandle stateHandle = FutureUtil.runIfNotDoneAndGet(runnableFuture);
+
+		SnapshotResult<OperatorStateHandle> snapshotResult = FutureUtil.runIfNotDoneAndGet(runnableFuture);
+		OperatorStateHandle stateHandle = snapshotResult.getJobManagerOwnedSnapshot();
 
 		try {
 
@@ -143,7 +141,7 @@ public class MemoryStateBackendTest extends StateBackendTestBase<MemoryStateBack
 			doThrow(new IOException()).when(mockProxy).read(any(DataInputViewStreamWrapper.class));
 			PowerMockito.whenNew(TypeSerializerSerializationUtil.TypeSerializerSerializationProxy.class).withAnyArguments().thenReturn(mockProxy);
 
-			operatorStateBackend.restore(Collections.singletonList(stateHandle));
+			operatorStateBackend.restore(StateObjectCollection.singleton(stateHandle));
 
 			fail("The operator state restore should have failed if the previous state serializer could not be loaded.");
 		} catch (IOException expected) {
@@ -186,10 +184,6 @@ public class MemoryStateBackendTest extends StateBackendTestBase<MemoryStateBack
 
 		// ========== restore snapshot ==========
 
-		Environment env = mock(Environment.class);
-		when(env.getExecutionConfig()).thenReturn(new ExecutionConfig());
-		when(env.getUserClassLoader()).thenReturn(OperatorStateBackendTest.class.getClassLoader());
-
 		// mock failure when deserializing serializer
 		TypeSerializerSerializationUtil.TypeSerializerSerializationProxy<?> mockProxy =
 				mock(TypeSerializerSerializationUtil.TypeSerializerSerializationProxy.class);
@@ -197,7 +191,7 @@ public class MemoryStateBackendTest extends StateBackendTestBase<MemoryStateBack
 		PowerMockito.whenNew(TypeSerializerSerializationUtil.TypeSerializerSerializationProxy.class).withAnyArguments().thenReturn(mockProxy);
 
 		try {
-			restoreKeyedBackend(IntSerializer.INSTANCE, snapshot, env);
+			restoreKeyedBackend(IntSerializer.INSTANCE, snapshot, new DummyEnvironment());
 
 			fail("The keyed state restore should have failed if the previous state serializer could not be loaded.");
 		} catch (IOException expected) {

http://git-wip-us.apache.org/repos/asf/flink/blob/df3e6bb7/flink-runtime/src/test/java/org/apache/flink/runtime/state/MultiStreamStateHandleTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MultiStreamStateHandleTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/MultiStreamStateHandleTest.java
deleted file mode 100644
index dd34f03..0000000
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/MultiStreamStateHandleTest.java
+++ /dev/null
@@ -1,135 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.flink.runtime.state;
-
-import org.apache.flink.core.fs.FSDataInputStream;
-import org.apache.flink.runtime.state.memory.ByteStreamStateHandle;
-import org.junit.Before;
-import org.junit.Test;
-
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-import java.util.Random;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.fail;
-
-public class MultiStreamStateHandleTest {
-
-	private static final int TEST_DATA_LENGTH = 123;
-	private Random random;
-	private byte[] testData;
-	private List<StreamStateHandle> streamStateHandles;
-
-	@Before
-	public void setup() {
-		random = new Random(0x42);
-		testData = new byte[TEST_DATA_LENGTH];
-		for (int i = 0; i < testData.length; ++i) {
-			testData[i] = (byte) i;
-		}
-
-		int idx = 0;
-		streamStateHandles = new ArrayList<>();
-		while (idx < testData.length) {
-			int len = random.nextInt(5);
-			byte[] sub = Arrays.copyOfRange(testData, idx, idx + len);
-			streamStateHandles.add(new ByteStreamStateHandle(String.valueOf(idx), sub));
-			idx += len;
-		}
-	}
-
-	@Test
-	public void testMetaData() throws IOException {
-		MultiStreamStateHandle multiStreamStateHandle = new MultiStreamStateHandle(streamStateHandles);
-		assertEquals(TEST_DATA_LENGTH, multiStreamStateHandle.getStateSize());
-	}
-
-	@Test
-	public void testLinearRead() throws IOException {
-		MultiStreamStateHandle multiStreamStateHandle = new MultiStreamStateHandle(streamStateHandles);
-		try (FSDataInputStream in = multiStreamStateHandle.openInputStream()) {
-
-			for (int i = 0; i < TEST_DATA_LENGTH; ++i) {
-				assertEquals(i, in.getPos());
-				assertEquals(testData[i], in.read());
-			}
-
-			assertEquals(-1, in.read());
-			assertEquals(TEST_DATA_LENGTH, in.getPos());
-			assertEquals(-1, in.read());
-			assertEquals(TEST_DATA_LENGTH, in.getPos());
-		}
-	}
-
-	@Test
-	public void testRandomRead() throws IOException {
-
-		MultiStreamStateHandle multiStreamStateHandle = new MultiStreamStateHandle(streamStateHandles);
-
-		try (FSDataInputStream in = multiStreamStateHandle.openInputStream()) {
-
-			for (int i = 0; i < 1000; ++i) {
-				int pos = random.nextInt(TEST_DATA_LENGTH);
-				int readLen = random.nextInt(TEST_DATA_LENGTH);
-				in.seek(pos);
-				while (--readLen > 0 && pos < TEST_DATA_LENGTH) {
-					assertEquals(pos, in.getPos());
-					assertEquals(testData[pos++], in.read());
-				}
-			}
-
-			in.seek(TEST_DATA_LENGTH);
-			assertEquals(TEST_DATA_LENGTH, in.getPos());
-			assertEquals(-1, in.read());
-
-			try {
-				in.seek(TEST_DATA_LENGTH + 1);
-				fail();
-			} catch (Exception ignored) {
-
-			}
-		}
-	}
-
-	@Test
-	public void testEmptyList() throws IOException {
-
-		MultiStreamStateHandle multiStreamStateHandle =
-				new MultiStreamStateHandle(Collections.<StreamStateHandle>emptyList());
-
-		try (FSDataInputStream in = multiStreamStateHandle.openInputStream()) {
-
-			assertEquals(0, in.getPos());
-			in.seek(0);
-			assertEquals(0, in.getPos());
-			assertEquals(-1, in.read());
-
-			try {
-				in.seek(1);
-				fail();
-			} catch (Exception ignored) {
-
-			}
-		}
-	}
-}
\ No newline at end of file


Mime
View raw message