flink-commits mailing list archives

From: aljos...@apache.org
Subject: [2/8] flink git commit: [FLINK-4844] Partitionable Raw Keyed/Operator State
Date: Thu, 20 Oct 2016 14:15:20 GMT
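
The test changes below exercise the CheckpointedFunction interface alongside
ListCheckpointed. As a minimal, illustrative sketch of that contract, built
only from the types the diffs import (this class is not part of the commit):

    import org.apache.flink.api.common.state.ListState;
    import org.apache.flink.runtime.state.FunctionInitializationContext;
    import org.apache.flink.runtime.state.FunctionSnapshotContext;
    import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;

    public class CounterFunction implements CheckpointedFunction {

        private transient ListState<Integer> checkpointedState;
        private int counter;

        @Override
        public void snapshotState(FunctionSnapshotContext context) throws Exception {
            // write the current value into managed operator state on every checkpoint
            checkpointedState.clear();
            checkpointedState.add(counter);
        }

        @Override
        public void initializeState(FunctionInitializationContext context) throws Exception {
            // getSerializableListState is the same accessor PartitionedStateSource uses below
            checkpointedState = context.getManagedOperatorStateStore()
                    .getSerializableListState("counter");
            if (context.isRestored()) {
                // after rescaling, a subtask may receive several list entries
                for (int v : checkpointedState.get()) {
                    counter += v;
                }
            }
        }
    }
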
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
index 0e513fa..95115d6 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java
@@ -22,6 +22,7 @@ import io.netty.util.internal.ConcurrentSet;
 import org.apache.flink.api.common.JobID;
 import org.apache.flink.api.common.functions.RichFlatMapFunction;
 import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.api.common.state.ListState;
 import org.apache.flink.api.common.state.ValueState;
 import org.apache.flink.api.common.state.ValueStateDescriptor;
 import org.apache.flink.api.java.functions.KeySelector;
@@ -32,17 +33,21 @@ import org.apache.flink.runtime.client.JobExecutionException;
 import org.apache.flink.runtime.instance.ActorGateway;
 import org.apache.flink.runtime.jobgraph.JobGraph;
 import org.apache.flink.runtime.messages.JobManagerMessages;
+import org.apache.flink.runtime.state.FunctionInitializationContext;
+import org.apache.flink.runtime.state.FunctionSnapshotContext;
 import org.apache.flink.runtime.state.KeyGroupRangeAssignment;
 import org.apache.flink.runtime.state.filesystem.FsStateBackendFactory;
 import org.apache.flink.runtime.testingUtils.TestingCluster;
 import org.apache.flink.runtime.testingUtils.TestingJobManagerMessages;
 import org.apache.flink.streaming.api.checkpoint.Checkpointed;
+import org.apache.flink.streaming.api.checkpoint.CheckpointedFunction;
 import org.apache.flink.streaming.api.checkpoint.ListCheckpointed;
 import org.apache.flink.streaming.api.datastream.DataStream;
 import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
 import org.apache.flink.streaming.api.functions.sink.DiscardingSink;
 import org.apache.flink.streaming.api.functions.sink.SinkFunction;
 import org.apache.flink.streaming.api.functions.source.RichParallelSourceFunction;
+import org.apache.flink.streaming.api.functions.source.SourceFunction;
 import org.apache.flink.util.Collector;
 import org.apache.flink.util.TestLogger;
 import org.junit.AfterClass;
@@ -71,7 +76,6 @@ import static org.junit.Assert.fail;
 
 /**
  * TODO : parameterize to test all different state backends!
- * TODO: reactivate ignored test as soon as savepoints work with deactivated checkpoints.
  */
 public class RescalingITCase extends TestLogger {
 
@@ -79,6 +83,10 @@ public class RescalingITCase extends TestLogger {
 	private static final int slotsPerTaskManager = 2;
 	private static final int numSlots = numTaskManagers * slotsPerTaskManager;
 
+	enum OperatorCheckpointMethod {
+		NON_PARTITIONED, CHECKPOINTED_FUNCTION, LIST_CHECKPOINTED
+	}
+
 	private static TestingCluster cluster;
 
 	@ClassRule
@@ -242,7 +250,7 @@ public class RescalingITCase extends TestLogger {
 		try {
 			jobManager = cluster.getLeaderGateway(deadline.timeLeft());
 
-			JobGraph jobGraph = createJobGraphWithOperatorState(parallelism, maxParallelism, false);
+			JobGraph jobGraph = createJobGraphWithOperatorState(parallelism, maxParallelism, OperatorCheckpointMethod.NON_PARTITIONED);
 
 			jobID = jobGraph.getJobID();
 
@@ -280,7 +288,7 @@ public class RescalingITCase extends TestLogger {
 			// job successfully removed
 			jobID = null;
 
-			JobGraph scaledJobGraph = createJobGraphWithOperatorState(parallelism2, maxParallelism, false);
+			JobGraph scaledJobGraph = createJobGraphWithOperatorState(parallelism2, maxParallelism, OperatorCheckpointMethod.NON_PARTITIONED);
 
 			scaledJobGraph.setSavepointPath(savepointPath);
 
@@ -433,12 +441,22 @@ public class RescalingITCase extends TestLogger {
 
 	@Test
 	public void testSavepointRescalingInPartitionedOperatorState() throws Exception {
-		testSavepointRescalingPartitionedOperatorState(false);
+		testSavepointRescalingPartitionedOperatorState(false, OperatorCheckpointMethod.CHECKPOINTED_FUNCTION);
 	}
 
 	@Test
 	public void testSavepointRescalingOutPartitionedOperatorState() throws Exception {
-		testSavepointRescalingPartitionedOperatorState(true);
+		testSavepointRescalingPartitionedOperatorState(true, OperatorCheckpointMethod.CHECKPOINTED_FUNCTION);
+	}
+
+	@Test
+	public void testSavepointRescalingInPartitionedOperatorStateList() throws Exception {
+		testSavepointRescalingPartitionedOperatorState(false, OperatorCheckpointMethod.LIST_CHECKPOINTED);
+	}
+
+	@Test
+	public void testSavepointRescalingOutPartitionedOperatorStateList() throws Exception {
+		testSavepointRescalingPartitionedOperatorState(true, OperatorCheckpointMethod.LIST_CHECKPOINTED);
 	}
 
 
@@ -446,7 +464,7 @@ public class RescalingITCase extends TestLogger {
 	 * Tests rescaling of partitioned operator state. More specifically, we test the mechanism with {@link ListCheckpointed}
 	 * as it subsumes {@link org.apache.flink.streaming.api.checkpoint.CheckpointedFunction}.
 	 */
-	public void testSavepointRescalingPartitionedOperatorState(boolean scaleOut) throws Exception {
+	public void testSavepointRescalingPartitionedOperatorState(boolean scaleOut, OperatorCheckpointMethod checkpointMethod) throws Exception {
 		final int parallelism = scaleOut ? numSlots : numSlots / 2;
 		final int parallelism2 = scaleOut ? numSlots / 2 : numSlots;
 		final int maxParallelism = 13;
@@ -459,13 +477,18 @@ public class RescalingITCase extends TestLogger {
 
 		int counterSize = Math.max(parallelism, parallelism2);
 
-		PartitionedStateSource.CHECK_CORRECT_SNAPSHOT = new int[counterSize];
-		PartitionedStateSource.CHECK_CORRECT_RESTORE = new int[counterSize];
+		if (checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION) {
+			PartitionedStateSource.CHECK_CORRECT_SNAPSHOT = new int[counterSize];
+			PartitionedStateSource.CHECK_CORRECT_RESTORE = new int[counterSize];
+		} else {
+			PartitionedStateSourceListCheckpointed.CHECK_CORRECT_SNAPSHOT = new int[counterSize];
+			PartitionedStateSourceListCheckpointed.CHECK_CORRECT_RESTORE = new int[counterSize];
+		}
 
 		try {
 			jobManager = cluster.getLeaderGateway(deadline.timeLeft());
 
-			JobGraph jobGraph = createJobGraphWithOperatorState(parallelism, maxParallelism, true);
+			JobGraph jobGraph = createJobGraphWithOperatorState(parallelism, maxParallelism, checkpointMethod);
 
 			jobID = jobGraph.getJobID();
 
@@ -504,7 +527,7 @@ public class RescalingITCase extends TestLogger {
 			// job successfully removed
 			jobID = null;
 
-			JobGraph scaledJobGraph = createJobGraphWithOperatorState(parallelism2, maxParallelism, true);
+			JobGraph scaledJobGraph = createJobGraphWithOperatorState(parallelism2, maxParallelism, checkpointMethod);
 
 			scaledJobGraph.setSavepointPath(savepointPath);
 
@@ -515,12 +538,22 @@ public class RescalingITCase extends TestLogger {
 			int sumExp = 0;
 			int sumAct = 0;
 
-			for (int c : PartitionedStateSource.CHECK_CORRECT_SNAPSHOT) {
-				sumExp += c;
-			}
+			if (checkpointMethod == OperatorCheckpointMethod.CHECKPOINTED_FUNCTION) {
+				for (int c : PartitionedStateSource.CHECK_CORRECT_SNAPSHOT) {
+					sumExp += c;
+				}
+
+				for (int c : PartitionedStateSource.CHECK_CORRECT_RESTORE) {
+					sumAct += c;
+				}
+			} else {
+				for (int c : PartitionedStateSourceListCheckpointed.CHECK_CORRECT_SNAPSHOT) {
+					sumExp += c;
+				}
 
-			for (int c : PartitionedStateSource.CHECK_CORRECT_RESTORE) {
-				sumAct += c;
+				for (int c : PartitionedStateSourceListCheckpointed.CHECK_CORRECT_RESTORE) {
+					sumAct += c;
+				}
 			}
 
 			assertEquals(sumExp, sumAct);
@@ -543,7 +576,7 @@ public class RescalingITCase extends TestLogger {
 	//------------------------------------------------------------------------------------------------------------------
 
 	private static JobGraph createJobGraphWithOperatorState(
-			int parallelism, int maxParallelism, boolean partitionedOperatorState) {
+			int parallelism, int maxParallelism, OperatorCheckpointMethod checkpointMethod) {
 
 		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
 		env.setParallelism(parallelism);
@@ -553,8 +586,23 @@ public class RescalingITCase extends TestLogger {
 
 		StateSourceBase.workStartedLatch = new CountDownLatch(1);
 
-		DataStream<Integer> input = env.addSource(
-				partitionedOperatorState ? new PartitionedStateSource() : new NonPartitionedStateSource());
+		SourceFunction<Integer> src;
+
+		switch (checkpointMethod) {
+			case CHECKPOINTED_FUNCTION:
+				src = new PartitionedStateSource();
+				break;
+			case LIST_CHECKPOINTED:
+				src = new PartitionedStateSourceListCheckpointed();
+				break;
+			case NON_PARTITIONED:
+				src = new NonPartitionedStateSource();
+				break;
+			default:
+				throw new IllegalArgumentException();
+		}
+
+		DataStream<Integer> input = env.addSource(src);
 
 		input.addSink(new DiscardingSink<Integer>());
 
@@ -711,7 +759,7 @@ public class RescalingITCase extends TestLogger {
 		}
 	}
 
-	private static class SubtaskIndexFlatMapper extends RichFlatMapFunction<Integer, Tuple2<Integer, Integer>> {
+	private static class SubtaskIndexFlatMapper extends RichFlatMapFunction<Integer, Tuple2<Integer, Integer>> implements CheckpointedFunction {
 
 		private static final long serialVersionUID = 5273172591283191348L;
 
@@ -727,12 +775,6 @@ public class RescalingITCase extends TestLogger {
 		}
 
 		@Override
-		public void open(Configuration configuration) {
-			counter = getRuntimeContext().getState(new ValueStateDescriptor<>("counter", Integer.class, 0));
-			sum = getRuntimeContext().getState(new ValueStateDescriptor<>("sum", Integer.class, 0));
-		}
-
-		@Override
 		public void flatMap(Integer value, Collector<Tuple2<Integer, Integer>> out) throws Exception {
 
 			int count = counter.value() + 1;
@@ -746,6 +788,17 @@ public class RescalingITCase extends TestLogger {
 				workCompletedLatch.countDown();
 			}
 		}
+
+		@Override
+		public void snapshotState(FunctionSnapshotContext context) throws Exception {
+			// all managed, nothing to do.
+		}
+
+		@Override
+		public void initializeState(FunctionInitializationContext context) throws Exception {
+			counter = context.getManagedKeyedStateStore().getState(new ValueStateDescriptor<>("counter", Integer.class, 0));
+			sum = context.getManagedKeyedStateStore().getState(new ValueStateDescriptor<>("sum", Integer.class, 0));
+		}
 	}
 
 	private static class CollectionSink<IN> implements SinkFunction<IN> {
@@ -817,9 +870,9 @@ public class RescalingITCase extends TestLogger {
 		}
 	}
 
-	private static class PartitionedStateSource extends StateSourceBase implements ListCheckpointed<Integer> {
+	private static class PartitionedStateSourceListCheckpointed extends StateSourceBase implements ListCheckpointed<Integer> {
 
-		private static final long serialVersionUID = -359715965103593462L;
+		private static final long serialVersionUID = -4357864582992546L;
 		private static final int NUM_PARTITIONS = 7;
 
 		private static int[] CHECK_CORRECT_SNAPSHOT;
@@ -853,4 +906,46 @@ public class RescalingITCase extends TestLogger {
 			CHECK_CORRECT_RESTORE[getRuntimeContext().getIndexOfThisSubtask()] = counter;
 		}
 	}
+
+	private static class PartitionedStateSource extends StateSourceBase implements CheckpointedFunction {
+
+		private static final long serialVersionUID = -359715965103593462L;
+		private static final int NUM_PARTITIONS = 7;
+
+		private ListState<Integer> counterPartitions;
+
+		private static int[] CHECK_CORRECT_SNAPSHOT;
+		private static int[] CHECK_CORRECT_RESTORE;
+
+
+		@Override
+		public void snapshotState(FunctionSnapshotContext context) throws Exception {
+
+			CHECK_CORRECT_SNAPSHOT[getRuntimeContext().getIndexOfThisSubtask()] = counter;
+
+			int div = counter / NUM_PARTITIONS;
+			int mod = counter % NUM_PARTITIONS;
+
+			for (int i = 0; i < NUM_PARTITIONS; ++i) {
+				int partitionValue = div;
+				if (mod > 0) {
+					--mod;
+					++partitionValue;
+				}
+				counterPartitions.add(partitionValue);
+			}
+		}
+
+		@Override
+		public void initializeState(FunctionInitializationContext context) throws Exception {
+			this.counterPartitions =
+					context.getManagedOperatorStateStore().getSerializableListState("counter_partitions");
+			if (context.isRestored()) {
+				for (int v : counterPartitions.get()) {
+					counter += v;
+				}
+				CHECK_CORRECT_RESTORE[getRuntimeContext().getIndexOfThisSubtask()] = counter;
+			}
+		}
+	}
 }

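The split in PartitionedStateSource.snapshotState above spreads the counter
across NUM_PARTITIONS list entries as evenly as possible, so the runtime can
redistribute the entries when the parallelism changes; the restore path simply
sums whatever entries a subtask receives, which is why the test only asserts
that the snapshot and restore totals match. The arithmetic, pulled out as a
standalone sketch (nothing assumed beyond what the diff shows):

    // Standalone rendering of the split logic from snapshotState.
    static int[] splitCounter(int counter, int numPartitions) {
        int div = counter / numPartitions;
        int mod = counter % numPartitions;
        int[] parts = new int[numPartitions];
        for (int i = 0; i < numPartitions; ++i) {
            // the first `mod` partitions get one extra unit
            parts[i] = div + (mod > 0 ? 1 : 0);
            if (mod > 0) {
                --mod;
            }
        }
        // e.g. counter = 10, numPartitions = 7 -> [2, 2, 2, 1, 1, 1, 1]
        return parts;
    }
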
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
index 92e1f41..fc48719 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/SavepointITCase.java
@@ -338,7 +338,8 @@ public class SavepointITCase extends TestLogger {
 
 					assertNotNull(subtaskState);
 					errMsg = "Initial operator state mismatch.";
-					assertEquals(errMsg, subtaskState.getChainedStateHandle(), tdd.getOperatorState());
+					assertEquals(errMsg, subtaskState.getLegacyOperatorState(),
+							tdd.getTaskStateHandles().getLegacyOperatorState());
 				}
 			}
 
@@ -364,7 +365,7 @@ public class SavepointITCase extends TestLogger {
 
 			for (TaskState stateForTaskGroup : savepoint.getTaskStates()) {
 				for (SubtaskState subtaskState : stateForTaskGroup.getStates()) {
-					ChainedStateHandle<StreamStateHandle> streamTaskState = subtaskState.getChainedStateHandle();
+					ChainedStateHandle<StreamStateHandle> streamTaskState = subtaskState.getLegacyOperatorState();
 
 					for (int i = 0; i < streamTaskState.getLength(); i++) {
 						if (streamTaskState.get(i) != null) {

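The SavepointITCase hunks track an accessor rename: the chained operator state
that SubtaskState used to expose as getChainedStateHandle() is now
getLegacyOperatorState(), and on the deployment side the same handle is reached
through the new TaskStateHandles container. Side by side (surrounding test
setup assumed, names taken from the diff):

    // before this commit:
    ChainedStateHandle<StreamStateHandle> state = subtaskState.getChainedStateHandle();
    // after: the handle is exposed as "legacy" operator state, and the
    // TaskDeploymentDescriptor nests it inside TaskStateHandles
    ChainedStateHandle<StreamStateHandle> restored = subtaskState.getLegacyOperatorState();
    ChainedStateHandle<StreamStateHandle> deployed = tdd.getTaskStateHandles().getLegacyOperatorState();
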
http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
index 2a635ab..963d18a 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/streaming/runtime/StateBackendITCase.java
@@ -38,7 +38,7 @@ import org.apache.flink.streaming.util.StreamingMultipleProgramsTestBase;
 import org.junit.Test;
 
 import java.io.IOException;
-import java.util.List;
+import java.util.Collection;
 
 import static org.junit.Assert.fail;
 
@@ -119,7 +119,7 @@ public class StateBackendITCase extends StreamingMultipleProgramsTestBase {
 				TypeSerializer<K> keySerializer,
 				int numberOfKeyGroups,
 				KeyGroupRange keyGroupRange,
-				List<KeyGroupsStateHandle> restoredState,
+				Collection<KeyGroupsStateHandle> restoredState,
 				TaskKvStateRegistry kvStateRegistry) throws Exception {
 			throw new SuccessException();
 		}

http://git-wip-us.apache.org/repos/asf/flink/blob/cab9cd44/flink-yarn/src/main/java/org/apache/flink/yarn/cli/FlinkYarnSessionCli.java
----------------------------------------------------------------------
diff --git a/flink-yarn/src/main/java/org/apache/flink/yarn/cli/FlinkYarnSessionCli.java b/flink-yarn/src/main/java/org/apache/flink/yarn/cli/FlinkYarnSessionCli.java
index 88708b6..7ce040b 100644
--- a/flink-yarn/src/main/java/org/apache/flink/yarn/cli/FlinkYarnSessionCli.java
+++ b/flink-yarn/src/main/java/org/apache/flink/yarn/cli/FlinkYarnSessionCli.java
@@ -31,12 +31,12 @@ import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.GlobalConfiguration;
 import org.apache.flink.configuration.HighAvailabilityOptions;
 import org.apache.flink.configuration.IllegalConfigurationException;
+import org.apache.flink.runtime.clusterframework.ApplicationStatus;
+import org.apache.flink.runtime.clusterframework.messages.GetClusterStatusResponse;
 import org.apache.flink.runtime.security.SecurityContext;
 import org.apache.flink.yarn.AbstractYarnClusterDescriptor;
-import org.apache.flink.yarn.YarnClusterDescriptor;
 import org.apache.flink.yarn.YarnClusterClient;
-import org.apache.flink.runtime.clusterframework.ApplicationStatus;
-import org.apache.flink.runtime.clusterframework.messages.GetClusterStatusResponse;
+import org.apache.flink.yarn.YarnClusterDescriptor;
 import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.yarn.util.ConverterUtils;
 import org.slf4j.Logger;

