From: aljoscha@apache.org
To: commits@flink.apache.org
Date: Wed, 31 Aug 2016 17:28:42 -0000
Subject: [24/27] flink git commit: [FLINK-4380] Add tests for new Key-Group/Max-Parallelism

[FLINK-4380] Add tests for new Key-Group/Max-Parallelism

This tests the rescaling features in CheckpointCoordinator and SavepointCoordinator.
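For readers skimming the diff: the rescaling tests below revolve around splitting the key-group index space [0, maxParallelism) into contiguous ranges, one per operator subtask, and mapping a key group back to the subtask that owns it (testCreateKeyGroupPartitions checks exactly that round trip). The following is a minimal, self-contained sketch of such a scheme. It is illustrative only: the class name KeyGroupPartitioningSketch, its methods, and the particular integer arithmetic are invented for this sketch and are not the CheckpointCoordinator/KeyGroupRange API exercised in the diff.

import java.util.ArrayList;
import java.util.List;

public class KeyGroupPartitioningSketch {

    // Splits the key-group space [0, maxParallelism) into `parallelism`
    // contiguous, inclusive ranges, one per operator subtask.
    static List<int[]> createKeyGroupPartitions(int maxParallelism, int parallelism) {
        List<int[]> ranges = new ArrayList<>(parallelism);
        for (int i = 0; i < parallelism; i++) {
            int start = (i * maxParallelism + parallelism - 1) / parallelism;
            int end = ((i + 1) * maxParallelism - 1) / parallelism;
            ranges.add(new int[] {start, end});
        }
        return ranges;
    }

    // Maps a key group back to the subtask index that owns it under the
    // partitioning above (integer division makes the two consistent).
    static int computeOperatorIndexForKeyGroup(int maxParallelism, int parallelism, int keyGroup) {
        return keyGroup * parallelism / maxParallelism;
    }

    // Verifies the round-trip property the test below asserts: every key group
    // lies inside the range of the subtask it maps to.
    public static void main(String[] args) {
        int maxParallelism = 13;
        int parallelism = 2;
        List<int[]> ranges = createKeyGroupPartitions(maxParallelism, parallelism);
        for (int keyGroup = 0; keyGroup < maxParallelism; keyGroup++) {
            int[] range = ranges.get(computeOperatorIndexForKeyGroup(maxParallelism, parallelism, keyGroup));
            if (keyGroup < range[0] || keyGroup > range[1]) {
                throw new AssertionError("key group " + keyGroup + " not in its owning range");
            }
        }
        System.out.println("all " + maxParallelism + " key groups map back into their own range");
    }
}

----------------------------------------------------------------------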
Project: http://git-wip-us.apache.org/repos/asf/flink/repo Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/516ad011 Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/516ad011 Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/516ad011 Branch: refs/heads/master Commit: 516ad011865ca5beece273ca9b985e2861b3435a Parents: 847ead0 Author: Till Rohrmann Authored: Thu Aug 11 12:14:18 2016 +0200 Committer: Aljoscha Krettek Committed: Wed Aug 31 19:10:01 2016 +0200 ---------------------------------------------------------------------- .../checkpoint/CheckpointCoordinatorTest.java | 733 ++++++++++++++++++- .../runtime/tasks/OneInputStreamTaskTest.java | 280 ++++++- .../test/checkpointing/RescalingITCase.java | 1 - 3 files changed, 1007 insertions(+), 7 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/flink/blob/516ad011/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java ---------------------------------------------------------------------- diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java index 50330fa..495dced 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/checkpoint/CheckpointCoordinatorTest.java @@ -18,28 +18,45 @@ package org.apache.flink.runtime.checkpoint; +import com.google.common.collect.Iterables; import org.apache.flink.api.common.JobID; +import org.apache.flink.core.fs.FSDataInputStream; import org.apache.flink.runtime.checkpoint.savepoint.HeapSavepointStore; import org.apache.flink.runtime.checkpoint.stats.DisabledCheckpointStatsTracker; import org.apache.flink.runtime.execution.ExecutionState; import org.apache.flink.runtime.executiongraph.Execution; import org.apache.flink.runtime.executiongraph.ExecutionAttemptID; +import org.apache.flink.runtime.executiongraph.ExecutionJobVertex; import org.apache.flink.runtime.executiongraph.ExecutionVertex; import org.apache.flink.runtime.jobgraph.JobVertexID; import org.apache.flink.runtime.messages.checkpoint.AcknowledgeCheckpoint; import org.apache.flink.runtime.messages.checkpoint.DeclineCheckpoint; import org.apache.flink.runtime.messages.checkpoint.NotifyCheckpointComplete; import org.apache.flink.runtime.messages.checkpoint.TriggerCheckpoint; +import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.KeyGroupRange; +import org.apache.flink.runtime.state.KeyGroupRangeOffsets; +import org.apache.flink.runtime.state.KeyGroupsStateHandle; +import org.apache.flink.runtime.state.StreamStateHandle; +import org.apache.flink.runtime.state.memory.ByteStreamStateHandle; +import org.apache.flink.util.InstantiationUtil; +import org.apache.flink.util.Preconditions; +import org.junit.Assert; import org.junit.Test; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; +import scala.concurrent.ExecutionContext; import scala.concurrent.Future; +import java.io.IOException; import java.io.Serializable; import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.Random; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ 
-47,12 +64,14 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.eq; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -124,7 +143,7 @@ public class CheckpointCoordinatorTest { final ExecutionAttemptID triggerAttemptID1 = new ExecutionAttemptID(); final ExecutionAttemptID triggerAttemptID2 = new ExecutionAttemptID(); ExecutionVertex triggerVertex1 = mockExecutionVertex(triggerAttemptID1); - ExecutionVertex triggerVertex2 = mockExecutionVertex(triggerAttemptID2, ExecutionState.FINISHED); + ExecutionVertex triggerVertex2 = mockExecutionVertex(triggerAttemptID2, new JobVertexID(), 1, 1, ExecutionState.FINISHED); // create some mock Execution vertices that need to ack the checkpoint final ExecutionAttemptID ackAttemptID1 = new ExecutionAttemptID(); @@ -1529,7 +1548,7 @@ public class CheckpointCoordinatorTest { coord.startCheckpointScheduler(); // after a while, there should be exactly as many checkpoints - // as concurrently permitted + // as concurrently permitted long now = System.currentTimeMillis(); long timeout = now + 60000; long minDuration = now + 100; @@ -1622,7 +1641,7 @@ public class CheckpointCoordinatorTest { } while (System.currentTimeMillis() < timeout && coord.getNumberOfPendingCheckpoints() == 0); - + assertTrue(coord.getNumberOfPendingCheckpoints() > 0); } catch (Exception e) { @@ -1738,4 +1757,712 @@ public class CheckpointCoordinatorTest { return vertex; } +/** + * Tests that the checkpointed partitioned and non-partitioned state is assigned properly to + * the {@link Execution} upon recovery. 
+ * + * @throws Exception + */ + @Test + public void testRestoreLatestCheckpointedState() throws Exception { + final JobID jid = new JobID(); + final long timestamp = System.currentTimeMillis(); + + final JobVertexID jobVertexID1 = new JobVertexID(); + final JobVertexID jobVertexID2 = new JobVertexID(); + int parallelism1 = 3; + int parallelism2 = 2; + int maxParallelism1 = 42; + int maxParallelism2 = 13; + + final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex( + jobVertexID1, + parallelism1, + maxParallelism1); + final ExecutionJobVertex jobVertex2 = mockExecutionJobVertex( + jobVertexID2, + parallelism2, + maxParallelism2); + + List allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2); + + allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices())); + allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices())); + + ExecutionVertex[] arrayExecutionVertices = allExecutionVertices.toArray(new ExecutionVertex[0]); + + // set up the coordinator and validate the initial state + CheckpointCoordinator coord = new CheckpointCoordinator( + jid, + 600000, + 600000, + 0, + Integer.MAX_VALUE, + arrayExecutionVertices, + arrayExecutionVertices, + arrayExecutionVertices, + cl, + new StandaloneCheckpointIDCounter(), + new StandaloneCompletedCheckpointStore(1, cl), + new HeapSavepointStore(), + new DisabledCheckpointStatsTracker()); + + // trigger the checkpoint + coord.triggerCheckpoint(timestamp); + + assertTrue(coord.getPendingCheckpoints().keySet().size() == 1); + long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet()); + + List keyGroupPartitions1 = coord.createKeyGroupPartitions(maxParallelism1, parallelism1); + List keyGroupPartitions2 = coord.createKeyGroupPartitions(maxParallelism2, parallelism2); + + for (int index = 0; index < jobVertex1.getParallelism(); index++) { + ChainedStateHandle nonPartitionedState = generateStateForVertex(jobVertexID1, index); + List partitionedKeyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index)); + + AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( + jid, + jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), + checkpointId, + nonPartitionedState, + partitionedKeyGroupState); + + coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); + } + + + for (int index = 0; index < jobVertex2.getParallelism(); index++) { + ChainedStateHandle nonPartitionedState = generateStateForVertex(jobVertexID2, index); + List partitionedKeyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index)); + + AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( + jid, + jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), + checkpointId, + nonPartitionedState, + partitionedKeyGroupState); + + coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); + } + + List completedCheckpoints = coord.getSuccessfulCheckpoints(); + + assertEquals(1, completedCheckpoints.size()); + + Map tasks = new HashMap<>(); + + tasks.put(jobVertexID1, jobVertex1); + tasks.put(jobVertexID2, jobVertex2); + + coord.restoreLatestCheckpointedState(tasks, true, true); + + // verify the restored state + verifiyStateRestore(jobVertexID1, jobVertex1, keyGroupPartitions1); + verifiyStateRestore(jobVertexID2, jobVertex2, keyGroupPartitions2); + } + + /** + * Tests that the checkpoint restoration fails if the max parallelism of the job vertices has + * changed. 
+ * + * @throws Exception + */ + @Test(expected=IllegalStateException.class) + public void testRestoreLatestCheckpointFailureWhenMaxParallelismChanges() throws Exception { + final JobID jid = new JobID(); + final long timestamp = System.currentTimeMillis(); + + final JobVertexID jobVertexID1 = new JobVertexID(); + final JobVertexID jobVertexID2 = new JobVertexID(); + int parallelism1 = 3; + int parallelism2 = 2; + int maxParallelism1 = 42; + int maxParallelism2 = 13; + + final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex( + jobVertexID1, + parallelism1, + maxParallelism1); + final ExecutionJobVertex jobVertex2 = mockExecutionJobVertex( + jobVertexID2, + parallelism2, + maxParallelism2); + + List allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2); + + allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices())); + allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices())); + + ExecutionVertex[] arrayExecutionVertices = allExecutionVertices.toArray(new ExecutionVertex[0]); + + // set up the coordinator and validate the initial state + CheckpointCoordinator coord = new CheckpointCoordinator( + jid, + 600000, + 600000, + 0, + Integer.MAX_VALUE, + arrayExecutionVertices, + arrayExecutionVertices, + arrayExecutionVertices, + cl, + new StandaloneCheckpointIDCounter(), + new StandaloneCompletedCheckpointStore(1, cl), + new HeapSavepointStore(), + new DisabledCheckpointStatsTracker()); + + // trigger the checkpoint + coord.triggerCheckpoint(timestamp); + + assertTrue(coord.getPendingCheckpoints().keySet().size() == 1); + long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet()); + + List keyGroupPartitions1 = coord.createKeyGroupPartitions(maxParallelism1, parallelism1); + List keyGroupPartitions2 = coord.createKeyGroupPartitions(maxParallelism2, parallelism2); + + for (int index = 0; index < jobVertex1.getParallelism(); index++) { + ChainedStateHandle valueSizeTuple = generateStateForVertex(jobVertexID1, index); + List keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index)); + + AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( + jid, + jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), + checkpointId, + valueSizeTuple, + keyGroupState); + + coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); + } + + + for (int index = 0; index < jobVertex2.getParallelism(); index++) { + ChainedStateHandle valueSizeTuple = generateStateForVertex(jobVertexID2, index); + List keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index)); + + AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( + jid, + jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), + checkpointId, + valueSizeTuple, + keyGroupState); + + coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); + } + + List completedCheckpoints = coord.getSuccessfulCheckpoints(); + + assertEquals(1, completedCheckpoints.size()); + + Map tasks = new HashMap<>(); + + int newMaxParallelism1 = 20; + int newMaxParallelism2 = 42; + + final ExecutionJobVertex newJobVertex1 = mockExecutionJobVertex( + jobVertexID1, + parallelism1, + newMaxParallelism1); + + final ExecutionJobVertex newJobVertex2 = mockExecutionJobVertex( + jobVertexID2, + parallelism2, + newMaxParallelism2); + + tasks.put(jobVertexID1, newJobVertex1); + tasks.put(jobVertexID2, newJobVertex2); + + coord.restoreLatestCheckpointedState(tasks, true, true); + + fail("The 
restoration should have failed because the max parallelism changed."); + } + + /** + * Tests that the checkpoint restoration fails if the parallelism of a job vertices with + * non-partitioned state has changed. + * + * @throws Exception + */ + @Test(expected=IllegalStateException.class) + public void testRestoreLatestCheckpointFailureWhenParallelismChanges() throws Exception { + final JobID jid = new JobID(); + final long timestamp = System.currentTimeMillis(); + + final JobVertexID jobVertexID1 = new JobVertexID(); + final JobVertexID jobVertexID2 = new JobVertexID(); + int parallelism1 = 3; + int parallelism2 = 2; + int maxParallelism1 = 42; + int maxParallelism2 = 13; + + final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex( + jobVertexID1, + parallelism1, + maxParallelism1); + final ExecutionJobVertex jobVertex2 = mockExecutionJobVertex( + jobVertexID2, + parallelism2, + maxParallelism2); + + List allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2); + + allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices())); + allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices())); + + ExecutionVertex[] arrayExecutionVertices = allExecutionVertices.toArray(new ExecutionVertex[0]); + + // set up the coordinator and validate the initial state + CheckpointCoordinator coord = new CheckpointCoordinator( + jid, + 600000, + 600000, + 0, + Integer.MAX_VALUE, + arrayExecutionVertices, + arrayExecutionVertices, + arrayExecutionVertices, + cl, + new StandaloneCheckpointIDCounter(), + new StandaloneCompletedCheckpointStore(1, cl), + new HeapSavepointStore(), + new DisabledCheckpointStatsTracker()); + + // trigger the checkpoint + coord.triggerCheckpoint(timestamp); + + assertTrue(coord.getPendingCheckpoints().keySet().size() == 1); + long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet()); + + List keyGroupPartitions1 = coord.createKeyGroupPartitions(maxParallelism1, parallelism1); + List keyGroupPartitions2 = coord.createKeyGroupPartitions(maxParallelism2, parallelism2); + + for (int index = 0; index < jobVertex1.getParallelism(); index++) { + ChainedStateHandle valueSizeTuple = generateStateForVertex(jobVertexID1, index); + List keyGroupState = generateKeyGroupState( + jobVertexID1, keyGroupPartitions1.get(index)); + + AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( + jid, + jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), + checkpointId, + valueSizeTuple, + keyGroupState); + + coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); + } + + + for (int index = 0; index < jobVertex2.getParallelism(); index++) { + + ChainedStateHandle state = generateStateForVertex(jobVertexID2, index); + List keyGroupState = generateKeyGroupState( + jobVertexID2, keyGroupPartitions2.get(index)); + + AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( + jid, + jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), + checkpointId, + state, + keyGroupState); + + coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); + } + + List completedCheckpoints = coord.getSuccessfulCheckpoints(); + + assertEquals(1, completedCheckpoints.size()); + + Map tasks = new HashMap<>(); + + int newParallelism1 = 4; + int newParallelism2 = 3; + + final ExecutionJobVertex newJobVertex1 = mockExecutionJobVertex( + jobVertexID1, + newParallelism1, + maxParallelism1); + + final ExecutionJobVertex newJobVertex2 = mockExecutionJobVertex( + jobVertexID2, + 
newParallelism2, + maxParallelism2); + + tasks.put(jobVertexID1, newJobVertex1); + tasks.put(jobVertexID2, newJobVertex2); + + coord.restoreLatestCheckpointedState(tasks, true, true); + + fail("The restoration should have failed because the parallelism of an vertex with " + + "non-partitioned state changed."); + } + + /** + * Tests the checkpoint restoration with changing parallelism of job vertex with partitioned + * state. + * + * @throws Exception + */ + @Test + public void testRestoreLatestCheckpointedStateWithChangingParallelism() throws Exception { + final JobID jid = new JobID(); + final long timestamp = System.currentTimeMillis(); + + final JobVertexID jobVertexID1 = new JobVertexID(); + final JobVertexID jobVertexID2 = new JobVertexID(); + int parallelism1 = 3; + int parallelism2 = 2; + int maxParallelism1 = 42; + int maxParallelism2 = 13; + + final ExecutionJobVertex jobVertex1 = mockExecutionJobVertex( + jobVertexID1, + parallelism1, + maxParallelism1); + final ExecutionJobVertex jobVertex2 = mockExecutionJobVertex( + jobVertexID2, + parallelism2, + maxParallelism2); + + List allExecutionVertices = new ArrayList<>(parallelism1 + parallelism2); + + allExecutionVertices.addAll(Arrays.asList(jobVertex1.getTaskVertices())); + allExecutionVertices.addAll(Arrays.asList(jobVertex2.getTaskVertices())); + + ExecutionVertex[] arrayExecutionVertices = allExecutionVertices.toArray(new ExecutionVertex[0]); + + // set up the coordinator and validate the initial state + CheckpointCoordinator coord = new CheckpointCoordinator( + jid, + 600000, + 600000, + 0, + Integer.MAX_VALUE, + arrayExecutionVertices, + arrayExecutionVertices, + arrayExecutionVertices, + cl, + new StandaloneCheckpointIDCounter(), + new StandaloneCompletedCheckpointStore(1, cl), + new HeapSavepointStore(), + new DisabledCheckpointStatsTracker()); + + // trigger the checkpoint + coord.triggerCheckpoint(timestamp); + + assertTrue(coord.getPendingCheckpoints().keySet().size() == 1); + long checkpointId = Iterables.getOnlyElement(coord.getPendingCheckpoints().keySet()); + + List keyGroupPartitions1 = coord.createKeyGroupPartitions(maxParallelism1, parallelism1); + List keyGroupPartitions2 = coord.createKeyGroupPartitions(maxParallelism2, parallelism2); + + for (int index = 0; index < jobVertex1.getParallelism(); index++) { + ChainedStateHandle valueSizeTuple = generateStateForVertex(jobVertexID1, index); + List keyGroupState = generateKeyGroupState(jobVertexID1, keyGroupPartitions1.get(index)); + + AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( + jid, + jobVertex1.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), + checkpointId, + valueSizeTuple, + keyGroupState); + + coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); + } + + + for (int index = 0; index < jobVertex2.getParallelism(); index++) { + List keyGroupState = generateKeyGroupState(jobVertexID2, keyGroupPartitions2.get(index)); + + AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( + jid, + jobVertex2.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), + checkpointId, + null, + keyGroupState); + + coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); + } + + List completedCheckpoints = coord.getSuccessfulCheckpoints(); + + assertEquals(1, completedCheckpoints.size()); + + Map tasks = new HashMap<>(); + + int newParallelism2 = 13; + + List newKeyGroupPartitions2 = coord.createKeyGroupPartitions(maxParallelism2, newParallelism2); + + final ExecutionJobVertex newJobVertex1 = 
mockExecutionJobVertex( + jobVertexID1, + parallelism1, + maxParallelism1); + + final ExecutionJobVertex newJobVertex2 = mockExecutionJobVertex( + jobVertexID2, + newParallelism2, + maxParallelism2); + + tasks.put(jobVertexID1, newJobVertex1); + tasks.put(jobVertexID2, newJobVertex2); + coord.restoreLatestCheckpointedState(tasks, true, true); + + // verify the restored state + verifiyStateRestore(jobVertexID1, newJobVertex1, keyGroupPartitions1); + + for (int i = 0; i < newJobVertex2.getParallelism(); i++) { + List originalKeyGroupState = generateKeyGroupState(jobVertexID2, newKeyGroupPartitions2.get(i)); + + ChainedStateHandle operatorState = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getChainedStateHandle(); + List keyGroupState = newJobVertex2.getTaskVertices()[i].getCurrentExecutionAttempt().getKeyGroupsStateHandles(); + + assertNull(operatorState); + comparePartitionedState(originalKeyGroupState, keyGroupState); + } + } + + // ------------------------------------------------------------------------ + // Utilities + // ------------------------------------------------------------------------ + + static void sendAckMessageToCoordinator( + CheckpointCoordinator coord, + long checkpointId, JobID jid, + ExecutionJobVertex jobVertex, + JobVertexID jobVertexID, + List keyGroupPartitions) throws Exception { + + for (int index = 0; index < jobVertex.getParallelism(); index++) { + ChainedStateHandle state = generateStateForVertex(jobVertexID, index); + List keyGroupState = generateKeyGroupState( + jobVertexID, + keyGroupPartitions.get(index)); + + AcknowledgeCheckpoint acknowledgeCheckpoint = new AcknowledgeCheckpoint( + jid, + jobVertex.getTaskVertices()[index].getCurrentExecutionAttempt().getAttemptId(), + checkpointId, + state, + keyGroupState); + + coord.receiveAcknowledgeMessage(acknowledgeCheckpoint); + } + } + + public static List generateKeyGroupState( + JobVertexID jobVertexID, + KeyGroupRange keyGroupPartition) throws IOException { + + KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupPartition); + List testStatesLists = new ArrayList<>(keyGroupPartition.getNumberOfKeyGroups()); + int runningGroupsOffset = 0; + // generate state for one keygroup + for (int keyGroupIndex : keyGroupPartition) { + Random random = new Random(jobVertexID.hashCode() + keyGroupIndex); + int simulatedStateValue = random.nextInt(); + testStatesLists.add(simulatedStateValue); + } + + return generateKeyGroupState(keyGroupPartition, testStatesLists); + } + + public static List generateKeyGroupState(KeyGroupRange keyGroupRange, List< ? 
extends Serializable> states) throws IOException { + Preconditions.checkArgument(keyGroupRange.getNumberOfKeyGroups() == states.size()); + + long[] offsets = new long[keyGroupRange.getNumberOfKeyGroups()]; + List serializedGroupValues = new ArrayList<>(offsets.length); + + KeyGroupRangeOffsets keyGroupRangeOffsets = new KeyGroupRangeOffsets(keyGroupRange, offsets); + + int runningGroupsOffset = 0; + // generate test state for all keygroups + int idx = 0; + for (int keyGroup : keyGroupRange) { + keyGroupRangeOffsets.setKeyGroupOffset(keyGroup,runningGroupsOffset); + byte[] serializedValue = InstantiationUtil.serializeObject(states.get(idx)); + runningGroupsOffset += serializedValue.length; + serializedGroupValues.add(serializedValue); + ++idx; + } + + //write all generated values in a single byte array, which is index by groupOffsetsInFinalByteArray + byte[] allSerializedValuesConcatenated = new byte[runningGroupsOffset]; + runningGroupsOffset = 0; + byte[] old = null; + for(byte[] serializedGroupValue : serializedGroupValues) { + System.arraycopy( + serializedGroupValue, + 0, + allSerializedValuesConcatenated, + runningGroupsOffset, + serializedGroupValue.length); + runningGroupsOffset += serializedGroupValue.length; + old = serializedGroupValue; + } + + ByteStreamStateHandle allSerializedStatesHandle = new ByteStreamStateHandle( + allSerializedValuesConcatenated); + KeyGroupsStateHandle keyGroupsStateHandle = new KeyGroupsStateHandle( + keyGroupRangeOffsets, + allSerializedStatesHandle); + List keyGroupsStateHandleList = new ArrayList<>(); + keyGroupsStateHandleList.add(keyGroupsStateHandle); + return keyGroupsStateHandleList; + } + + public static ChainedStateHandle generateStateForVertex( + JobVertexID jobVertexID, + int index) throws IOException { + + Random random = new Random(jobVertexID.hashCode() + index); + int value = random.nextInt(); + return generateChainedStateHandle(value); + } + + public static ChainedStateHandle generateChainedStateHandle( + Serializable value) throws IOException { + return ChainedStateHandle.wrapSingleHandle(ByteStreamStateHandle.fromSerializable(value)); + } + + public static ExecutionJobVertex mockExecutionJobVertex( + JobVertexID jobVertexID, + int parallelism, + int maxParallelism) { + final ExecutionJobVertex executionJobVertex = mock(ExecutionJobVertex.class); + + ExecutionVertex[] executionVertices = new ExecutionVertex[parallelism]; + + for (int i = 0; i < parallelism; i++) { + executionVertices[i] = mockExecutionVertex( + new ExecutionAttemptID(), + jobVertexID, + parallelism, + maxParallelism, + ExecutionState.RUNNING); + + when(executionVertices[i].getParallelSubtaskIndex()).thenReturn(i); + } + + when(executionJobVertex.getJobVertexId()).thenReturn(jobVertexID); + when(executionJobVertex.getTaskVertices()).thenReturn(executionVertices); + when(executionJobVertex.getParallelism()).thenReturn(parallelism); + when(executionJobVertex.getMaxParallelism()).thenReturn(maxParallelism); + + return executionJobVertex; + } + + private static ExecutionVertex mockExecutionVertex(ExecutionAttemptID attemptID) { + return mockExecutionVertex( + attemptID, + new JobVertexID(), + 1, + 1, + ExecutionState.RUNNING); + } + + private static ExecutionVertex mockExecutionVertex( + ExecutionAttemptID attemptID, + JobVertexID jobVertexID, + int parallelism, + int maxParallelism, + ExecutionState state, + ExecutionState ... 
successiveStates) { + + ExecutionVertex vertex = mock(ExecutionVertex.class); + + final Execution exec = spy(new Execution( + mock(ExecutionContext.class), + vertex, + 1, + 1L, + null + )); + when(exec.getAttemptId()).thenReturn(attemptID); + when(exec.getState()).thenReturn(state, successiveStates); + + when(vertex.getJobvertexId()).thenReturn(jobVertexID); + when(vertex.getCurrentExecutionAttempt()).thenReturn(exec); + when(vertex.getTotalNumberOfParallelSubtasks()).thenReturn(parallelism); + when(vertex.getMaxParallelism()).thenReturn(maxParallelism); + + return vertex; + } + + public static void verifiyStateRestore( + JobVertexID jobVertexID, ExecutionJobVertex executionJobVertex, + List keyGroupPartitions) throws Exception { + + for (int i = 0; i < executionJobVertex.getParallelism(); i++) { + + ChainedStateHandle expectNonPartitionedState = generateStateForVertex(jobVertexID, i); + ChainedStateHandle actualNonPartitionedState = executionJobVertex. + getTaskVertices()[i].getCurrentExecutionAttempt().getChainedStateHandle(); + assertEquals(expectNonPartitionedState.get(0), actualNonPartitionedState.get(0)); + + List expectPartitionedKeyGroupState = generateKeyGroupState( + jobVertexID, + keyGroupPartitions.get(i)); + List actualPartitionedKeyGroupState = executionJobVertex. + getTaskVertices()[i].getCurrentExecutionAttempt().getKeyGroupsStateHandles(); + comparePartitionedState(expectPartitionedKeyGroupState, actualPartitionedKeyGroupState); + } + } + + public static void comparePartitionedState( + List expectPartitionedKeyGroupState, + List actualPartitionedKeyGroupState) throws Exception { + + KeyGroupsStateHandle expectedHeadOpKeyGroupStateHandle = expectPartitionedKeyGroupState.get(0); + int expectedTotalKeyGroups = expectedHeadOpKeyGroupStateHandle.getNumberOfKeyGroups(); + int actualTotalKeyGroups = 0; + for(KeyGroupsStateHandle keyGroupsStateHandle: actualPartitionedKeyGroupState) { + actualTotalKeyGroups += keyGroupsStateHandle.getNumberOfKeyGroups(); + } + + assertEquals(expectedTotalKeyGroups, actualTotalKeyGroups); + + FSDataInputStream inputStream = expectedHeadOpKeyGroupStateHandle.getStateHandle().openInputStream(); + for(int groupId : expectedHeadOpKeyGroupStateHandle.keyGroups()) { + long offset = expectedHeadOpKeyGroupStateHandle.getOffsetForKeyGroup(groupId); + inputStream.seek(offset); + int expectedKeyGroupState = InstantiationUtil.deserializeObject(inputStream); + for(KeyGroupsStateHandle oneActualKeyGroupStateHandle : actualPartitionedKeyGroupState) { + if (oneActualKeyGroupStateHandle.containsKeyGroup(groupId)) { + long actualOffset = oneActualKeyGroupStateHandle.getOffsetForKeyGroup(groupId); + FSDataInputStream actualInputStream = oneActualKeyGroupStateHandle.getStateHandle().openInputStream(); + actualInputStream.seek(actualOffset); + int actualGroupState = InstantiationUtil.deserializeObject(actualInputStream); + + assertEquals(expectedKeyGroupState, actualGroupState); + } + } + } + } + + @Test + public void testCreateKeyGroupPartitions() { + testCreateKeyGroupPartitions(1, 1); + testCreateKeyGroupPartitions(13, 1); + testCreateKeyGroupPartitions(13, 2); + testCreateKeyGroupPartitions(Short.MAX_VALUE, 1); + testCreateKeyGroupPartitions(Short.MAX_VALUE, 13); + testCreateKeyGroupPartitions(Short.MAX_VALUE, Short.MAX_VALUE); + + Random r = new Random(1234); + for (int k = 0; k < 1000; ++k) { + int maxParallelism = 1 + r.nextInt(Short.MAX_VALUE - 1); + int parallelism = 1 + r.nextInt(maxParallelism); + testCreateKeyGroupPartitions(maxParallelism, parallelism); + 
} + } + + private void testCreateKeyGroupPartitions(int maxParallelism, int parallelism) { + List ranges = CheckpointCoordinator.createKeyGroupPartitions(maxParallelism, parallelism); + for (int i = 0; i < maxParallelism; ++i) { + KeyGroupRange range = ranges.get(KeyGroupRange.computeOperatorIndexForKeyGroup(maxParallelism, parallelism, i)); + if (!range.contains(i)) { + Assert.fail("Could not find expected key-group " + i + " in range " + range); + } + } + } + } http://git-wip-us.apache.org/repos/asf/flink/blob/516ad011/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java ---------------------------------------------------------------------- diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java index 5fcc59e..f757943 100644 --- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java +++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/OneInputStreamTaskTest.java @@ -18,26 +18,54 @@ package org.apache.flink.streaming.runtime.tasks; +import org.apache.flink.api.common.ExecutionConfig; import org.apache.flink.api.common.functions.MapFunction; import org.apache.flink.api.common.functions.RichMapFunction; import org.apache.flink.api.common.typeinfo.BasicTypeInfo; +import org.apache.flink.api.java.functions.KeySelector; import org.apache.flink.configuration.Configuration; +import org.apache.flink.core.fs.FSDataInputStream; +import org.apache.flink.core.fs.FSDataOutputStream; import org.apache.flink.runtime.io.network.api.CheckpointBarrier; import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter; +import org.apache.flink.runtime.operators.testutils.MockInputSplitProvider; +import org.apache.flink.runtime.state.ChainedStateHandle; +import org.apache.flink.runtime.state.KeyGroupsStateHandle; +import org.apache.flink.runtime.state.StreamStateHandle; import org.apache.flink.streaming.api.graph.StreamConfig; +import org.apache.flink.streaming.api.graph.StreamEdge; +import org.apache.flink.streaming.api.graph.StreamNode; +import org.apache.flink.streaming.api.operators.AbstractStreamOperator; +import org.apache.flink.streaming.api.operators.OneInputStreamOperator; import org.apache.flink.streaming.api.operators.StreamMap; import org.apache.flink.streaming.api.watermark.Watermark; import org.apache.flink.streaming.runtime.streamrecord.StreamRecord; import org.apache.flink.streaming.util.TestHarnessUtil; +import org.apache.flink.util.InstantiationUtil; +import org.apache.flink.util.Preconditions; +import org.apache.flink.util.TestLogger; import org.junit.Assert; import org.junit.Test; import org.junit.runner.RunWith; import org.powermock.core.classloader.annotations.PowerMockIgnore; import org.powermock.core.classloader.annotations.PrepareForTest; import org.powermock.modules.junit4.PowerMockRunner; +import scala.concurrent.duration.Deadline; +import scala.concurrent.duration.FiniteDuration; +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.Random; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.TimeUnit; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static 
org.junit.Assert.assertTrue; /** * Tests for {@link OneInputStreamTask}. @@ -51,7 +79,7 @@ import java.util.concurrent.ConcurrentLinkedQueue; @RunWith(PowerMockRunner.class) @PrepareForTest({ResultPartitionWriter.class}) @PowerMockIgnore({"javax.management.*", "com.sun.jndi.*"}) -public class OneInputStreamTaskTest { +public class OneInputStreamTaskTest extends TestLogger { /** * This test verifies that open() and close() are correctly called. This test also verifies @@ -82,7 +110,7 @@ public class OneInputStreamTaskTest { testHarness.waitForTaskCompletion(); - Assert.assertTrue("RichFunction methods where not called.", TestOpenCloseMapFunction.closeCalled); + assertTrue("RichFunction methods where not called.", TestOpenCloseMapFunction.closeCalled); TestHarnessUtil.assertOutputEquals("Output was not correct.", expectedOutput, @@ -165,7 +193,7 @@ public class OneInputStreamTaskTest { testHarness.waitForTaskCompletion(); List resultElements = TestHarnessUtil.getRawElementsFromOutput(testHarness.getOutput()); - Assert.assertEquals(2, resultElements.size()); + assertEquals(2, resultElements.size()); } /** @@ -293,6 +321,252 @@ public class OneInputStreamTaskTest { TestHarnessUtil.assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput()); } + /** + * Tests that the stream operator can snapshot and restore the operator state of chained + * operators + */ + @Test + public void testSnapshottingAndRestoring() throws Exception { + final Deadline deadline = new FiniteDuration(2, TimeUnit.MINUTES).fromNow(); + final OneInputStreamTask streamTask = new OneInputStreamTask(); + final OneInputStreamTaskTestHarness testHarness = new OneInputStreamTaskTestHarness(streamTask, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO); + IdentityKeySelector keySelector = new IdentityKeySelector<>(); + testHarness.configureForKeyedStream(keySelector, BasicTypeInfo.STRING_TYPE_INFO); + + long checkpointId = 1L; + long checkpointTimestamp = 1L; + long recoveryTimestamp = 3L; + long seed = 2L; + int numberChainedTasks = 11; + + StreamConfig streamConfig = testHarness.getStreamConfig(); + + configureChainedTestingStreamOperator(streamConfig, numberChainedTasks, seed, recoveryTimestamp); + + AcknowledgeStreamMockEnvironment env = new AcknowledgeStreamMockEnvironment( + testHarness.jobConfig, + testHarness.taskConfig, + testHarness.executionConfig, + testHarness.memorySize, + new MockInputSplitProvider(), + testHarness.bufferSize); + + // reset number of restore calls + TestingStreamOperator.numberRestoreCalls = 0; + + testHarness.invoke(env); + testHarness.waitForTaskRunning(deadline.timeLeft().toMillis()); + + streamTask.triggerCheckpoint(checkpointId, checkpointTimestamp); + + testHarness.endInput(); + testHarness.waitForTaskCompletion(deadline.timeLeft().toMillis()); + + // since no state was set, there shouldn't be restore calls + assertEquals(0, TestingStreamOperator.numberRestoreCalls); + + assertEquals(checkpointId, env.getCheckpointId()); + + final OneInputStreamTask restoredTask = new OneInputStreamTask(); + restoredTask.setInitialState(env.getState(), env.getKeyGroupStates()); + + final OneInputStreamTaskTestHarness restoredTaskHarness = new OneInputStreamTaskTestHarness(restoredTask, BasicTypeInfo.STRING_TYPE_INFO, BasicTypeInfo.STRING_TYPE_INFO); + restoredTaskHarness.configureForKeyedStream(keySelector, BasicTypeInfo.STRING_TYPE_INFO); + + StreamConfig restoredTaskStreamConfig = restoredTaskHarness.getStreamConfig(); + + 
configureChainedTestingStreamOperator(restoredTaskStreamConfig, numberChainedTasks, seed, recoveryTimestamp); + + TestingStreamOperator.numberRestoreCalls = 0; + + restoredTaskHarness.invoke(); + restoredTaskHarness.endInput(); + restoredTaskHarness.waitForTaskCompletion(deadline.timeLeft().toMillis()); + + // restore of every chained operator should have been called + assertEquals(numberChainedTasks, TestingStreamOperator.numberRestoreCalls); + + TestingStreamOperator.numberRestoreCalls = 0; + } + + //============================================================================================== + // Utility functions and classes + //============================================================================================== + + private void configureChainedTestingStreamOperator( + StreamConfig streamConfig, + int numberChainedTasks, + long seed, + long recoveryTimestamp) { + + Preconditions.checkArgument(numberChainedTasks >= 1, "The operator chain must at least " + + "contain one operator."); + + Random random = new Random(seed); + + TestingStreamOperator previousOperator = new TestingStreamOperator<>(random.nextLong(), recoveryTimestamp); + streamConfig.setStreamOperator(previousOperator); + + // create the chain of operators + Map chainedTaskConfigs = new HashMap<>(numberChainedTasks - 1); + List outputEdges = new ArrayList<>(numberChainedTasks - 1); + + for (int chainedIndex = 1; chainedIndex < numberChainedTasks; chainedIndex++) { + TestingStreamOperator chainedOperator = new TestingStreamOperator<>(random.nextLong(), recoveryTimestamp); + StreamConfig chainedConfig = new StreamConfig(new Configuration()); + chainedConfig.setStreamOperator(chainedOperator); + chainedTaskConfigs.put(chainedIndex, chainedConfig); + + StreamEdge outputEdge = new StreamEdge( + new StreamNode( + null, + chainedIndex - 1, + null, + null, + null, + null, + null + ), + new StreamNode( + null, + chainedIndex, + null, + null, + null, + null, + null + ), + 0, + Collections.emptyList(), + null + ); + + outputEdges.add(outputEdge); + } + + streamConfig.setChainedOutputs(outputEdges); + streamConfig.setTransitiveChainedTaskConfigs(chainedTaskConfigs); + } + + private static class IdentityKeySelector implements KeySelector { + + private static final long serialVersionUID = -3555913664416688425L; + + @Override + public IN getKey(IN value) throws Exception { + return value; + } + } + + private static class AcknowledgeStreamMockEnvironment extends StreamMockEnvironment { + private long checkpointId; + private ChainedStateHandle state; + private List keyGroupStates; + + public long getCheckpointId() { + return checkpointId; + } + + public ChainedStateHandle getState() { + return state; + } + + List getKeyGroupStates() { + List result = new ArrayList<>(); + for (int i = 0; i < keyGroupStates.size(); i++) { + if (keyGroupStates.get(i) != null) { + result.add(keyGroupStates.get(i)); + } + } + return result; + } + + AcknowledgeStreamMockEnvironment(Configuration jobConfig, Configuration taskConfig, + ExecutionConfig executionConfig, long memorySize, + MockInputSplitProvider inputSplitProvider, int bufferSize) { + super(jobConfig, taskConfig, executionConfig, memorySize, inputSplitProvider, bufferSize); + } + + + @Override + public void acknowledgeCheckpoint(long checkpointId, ChainedStateHandle state, + List keyGroupStates) { + this.checkpointId = checkpointId; + this.state = state; + this.keyGroupStates = keyGroupStates; + } + } + + private static class TestingStreamOperator + extends AbstractStreamOperator implements 
OneInputStreamOperator { + + private static final long serialVersionUID = 774614855940397174L; + + public static int numberRestoreCalls = 0; + + private final long seed; + private final long recoveryTimestamp; + + private transient Random random; + + TestingStreamOperator(long seed, long recoveryTimestamp) { + this.seed = seed; + this.recoveryTimestamp = recoveryTimestamp; + } + + @Override + public void processElement(StreamRecord element) throws Exception { + + } + + @Override + public void processWatermark(Watermark mark) throws Exception { + + } + + @Override + public void snapshotState(FSDataOutputStream out, long checkpointId, long timestamp) throws Exception { + if (random == null) { + random = new Random(seed); + } + + Serializable functionState = generateFunctionState(); + Integer operatorState = generateOperatorState(); + + InstantiationUtil.serializeObject(out, functionState); + InstantiationUtil.serializeObject(out, operatorState); + } + + @Override + public void restoreState(FSDataInputStream in) throws Exception { + numberRestoreCalls++; + + if (random == null) { + random = new Random(seed); + } + + assertEquals(this.recoveryTimestamp, recoveryTimestamp); + + assertNotNull(in); + + Serializable functionState= InstantiationUtil.deserializeObject(in); + Integer operatorState= InstantiationUtil.deserializeObject(in); + + assertEquals(random.nextInt(), functionState); + assertEquals(random.nextInt(), (int) operatorState); + } + + + private Serializable generateFunctionState() { + return random.nextInt(); + } + + private Integer generateOperatorState() { + return random.nextInt(); + } + } + + // This must only be used in one test, otherwise the static fields will be changed // by several tests concurrently private static class TestOpenCloseMapFunction extends RichMapFunction { http://git-wip-us.apache.org/repos/asf/flink/blob/516ad011/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java ---------------------------------------------------------------------- diff --git a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java index 8d1baeb..39f3086 100644 --- a/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java +++ b/flink-tests/src/test/java/org/apache/flink/test/checkpointing/RescalingITCase.java @@ -352,7 +352,6 @@ public class RescalingITCase extends TestLogger { for (int key = 0; key < numberKeys; key++) { int keyGroupIndex = keyGroupAssigner.getKeyGroupIndex(key); -// expectedResult.add(Tuple2.of(keyGroupIndex % parallelism, numberElements * key)); expectedResult.add(Tuple2.of(KeyGroupRange.computeOperatorIndexForKeyGroup(maxParallelism, parallelism, keyGroupIndex) , numberElements * key)); }
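----------------------------------------------------------------------

A note on the verification trick used by testSnapshottingAndRestoring and TestingStreamOperator above: the snapshotted state is generated from a Random seeded per operator, so the restore path can regenerate the expected values from the same seed and compare, instead of the test keeping a copy of the written state. A minimal standalone sketch of that pattern, with invented names and plain java.io streams rather than Flink's state handles and InstantiationUtil:

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.Random;

public class SeededSnapshotSketch {

    private final long seed;

    SeededSnapshotSketch(long seed) {
        this.seed = seed;
    }

    // "Snapshot": the written state is a pure function of the seed.
    void snapshot(OutputStream out) throws IOException {
        Random random = new Random(seed);
        DataOutputStream dos = new DataOutputStream(out);
        dos.writeInt(random.nextInt());   // stands in for the function state
        dos.writeInt(random.nextInt());   // stands in for the operator state
        dos.flush();
    }

    // "Restore": regenerate the expected values from the same seed and compare,
    // so no copy of the snapshotted state has to be kept around by the test.
    void restore(InputStream in) throws IOException {
        Random random = new Random(seed);
        DataInputStream dis = new DataInputStream(in);
        if (dis.readInt() != random.nextInt() || dis.readInt() != random.nextInt()) {
            throw new AssertionError("restored state does not match the seeded expectation");
        }
    }

    public static void main(String[] args) throws IOException {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        new SeededSnapshotSketch(2L).snapshot(buffer);
        new SeededSnapshotSketch(2L).restore(new ByteArrayInputStream(buffer.toByteArray()));
        System.out.println("restore verified against seed 2");
    }
}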