flink-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From zhiji...@apache.org
Subject [flink] branch master updated: [FLINK-16537][network] Implement ResultPartition state recovery for unaligned checkpoint
Date Tue, 07 Apr 2020 15:25:22 GMT
This is an automated email from the ASF dual-hosted git repository.

zhijiang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink.git


The following commit(s) were added to refs/heads/master by this push:
     new 693cb6a  [FLINK-16537][network] Implement ResultPartition state recovery for unaligned
checkpoint
693cb6a is described below

commit 693cb6adc42d75d1db720b45013430a4c6817d4a
Author: Zhijiang <wangzhijiang999@aliyun.com>
AuthorDate: Fri Apr 3 11:08:56 2020 +0800

    [FLINK-16537][network] Implement ResultPartition state recovery for unaligned checkpoint
    
    During state recovery for unaligned checkpoint, the partition state should also be recovered
in addition to the existing operator states.
    
    The ResultPartition would request a buffer from the local pool and then interact with the ChannelStateReader
to fill in the state data.
    The filled buffer would be inserted into the respective ResultSubpartition queue in the normal
way.
    
    It should be guaranteed that the operator cannot process any inputs before finishing all the output
recovery, to avoid mis-ordering issues.
---
 .../checkpoint/channel/ChannelStateReader.java     |   5 +-
 .../network/api/writer/ResultPartitionWriter.java  |   7 ++
 .../network/partition/PipelinedSubpartition.java   |  22 +++-
 .../partition/PipelinedSubpartitionView.java       |   4 +-
 .../io/network/partition/ResultPartition.java      |   8 ++
 .../io/network/partition/ResultSubpartition.java   |   4 +
 ...bleNotifyingResultPartitionWriterDecorator.java |   6 +
 .../io/network/api/writer/RecordWriterTest.java    |  62 ++++++++++
 .../buffer/BufferBuilderAndConsumerTest.java       |  10 +-
 .../partition/MockResultPartitionWriter.java       |   5 +
 .../partition/NoOpBufferAvailablityListener.java   |   2 +-
 .../io/network/partition/ResultPartitionTest.java  | 125 +++++++++++++++++++++
 .../flink/streaming/runtime/tasks/StreamTask.java  |   9 ++
 .../streaming/runtime/tasks/StreamTaskTest.java    |  42 +++++++
 14 files changed, 300 insertions(+), 11 deletions(-)

diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateReader.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateReader.java
index 49321cc..0753e7a 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateReader.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/checkpoint/channel/ChannelStateReader.java
@@ -18,6 +18,7 @@ package org.apache.flink.runtime.checkpoint.channel;
 
 import org.apache.flink.annotation.Internal;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
 
 import java.io.IOException;
 
@@ -42,7 +43,7 @@ public interface ChannelStateReader extends AutoCloseable {
 	 * Put data into the supplied buffer to be injected into
 	 * {@link org.apache.flink.runtime.io.network.partition.ResultSubpartition ResultSubpartition}.
 	 */
-	ReadResult readOutputData(ResultSubpartitionInfo info, Buffer buffer) throws IOException;
+	ReadResult readOutputData(ResultSubpartitionInfo info, BufferBuilder bufferBuilder) throws
IOException;
 
 	@Override
 	void close() throws Exception;
@@ -55,7 +56,7 @@ public interface ChannelStateReader extends AutoCloseable {
 		}
 
 		@Override
-		public ReadResult readOutputData(ResultSubpartitionInfo info, Buffer buffer) {
+		public ReadResult readOutputData(ResultSubpartitionInfo info, BufferBuilder bufferBuilder)
{
 			return ReadResult.NO_MORE_DATA;
 		}
 
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
index 75cd5fb..2c1717d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/api/writer/ResultPartitionWriter.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.io.network.api.writer;
 
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
 import org.apache.flink.runtime.io.AvailabilityProvider;
 import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
 import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
@@ -42,6 +43,12 @@ public interface ResultPartitionWriter extends AutoCloseable, AvailabilityProvid
 	 */
 	void setup() throws IOException;
 
+	/**
+	 * Loads the previous output states with the given reader for unaligned checkpoint.
+	 * It should be done before task processing the inputs.
+	 */
+	void initializeState(ChannelStateReader stateReader) throws IOException, InterruptedException;
+
 	ResultPartitionID getPartitionId();
 
 	int getNumberOfSubpartitions();
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java
index ecf6956..070089d 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java
@@ -19,9 +19,12 @@
 package org.apache.flink.runtime.io.network.partition;
 
 import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader.ReadResult;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
 import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
 import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
 
 import org.slf4j.Logger;
@@ -52,7 +55,7 @@ import static org.apache.flink.util.Preconditions.checkState;
  * {@link PipelinedSubpartitionView#notifyDataAvailable() notification} for any
  * {@link BufferConsumer} present in the queue.
  */
-class PipelinedSubpartition extends ResultSubpartition {
+public class PipelinedSubpartition extends ResultSubpartition {
 
 	private static final Logger LOG = LoggerFactory.getLogger(PipelinedSubpartition.class);
 
@@ -90,6 +93,23 @@ class PipelinedSubpartition extends ResultSubpartition {
 	}
 
 	@Override
+	public void initializeState(ChannelStateReader stateReader) throws IOException, InterruptedException
{
+		for (ReadResult readResult = ReadResult.HAS_MORE_DATA; readResult == ReadResult.HAS_MORE_DATA;)
{
+			BufferBuilder bufferBuilder = parent.getBufferPool().requestBufferBuilderBlocking();
+			BufferConsumer bufferConsumer = bufferBuilder.createBufferConsumer();
+			readResult = stateReader.readOutputData(subpartitionInfo, bufferBuilder);
+
+			// check whether there are some states data filled in this time
+			if (bufferConsumer.isDataAvailable()) {
+				add(bufferConsumer);
+				bufferBuilder.finish();
+			} else {
+				bufferConsumer.close();
+			}
+		}
+	}
+
+	@Override
 	public boolean add(BufferConsumer bufferConsumer) {
 		return add(bufferConsumer, false);
 	}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartitionView.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartitionView.java
index febbfbd..ee837d5 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartitionView.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartitionView.java
@@ -29,7 +29,7 @@ import static org.apache.flink.util.Preconditions.checkNotNull;
 /**
  * View over a pipelined in-memory only subpartition.
  */
-class PipelinedSubpartitionView implements ResultSubpartitionView {
+public class PipelinedSubpartitionView implements ResultSubpartitionView {
 
 	/** The subpartition this view belongs to. */
 	private final PipelinedSubpartition parent;
@@ -39,7 +39,7 @@ class PipelinedSubpartitionView implements ResultSubpartitionView {
 	/** Flag indicating whether this view has been released. */
 	private final AtomicBoolean isReleased;
 
-	PipelinedSubpartitionView(PipelinedSubpartition parent, BufferAvailabilityListener listener)
{
+	public PipelinedSubpartitionView(PipelinedSubpartition parent, BufferAvailabilityListener
listener) {
 		this.parent = checkNotNull(parent);
 		this.availabilityListener = checkNotNull(listener);
 		this.isReleased = new AtomicBoolean();
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
index ccd3fa9..bb925fb 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultPartition.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.io.network.partition;
 
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
 import org.apache.flink.runtime.executiongraph.IntermediateResultPartition;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
@@ -150,6 +151,13 @@ public class ResultPartition implements ResultPartitionWriter, BufferPoolOwner
{
 		partitionManager.registerResultPartition(this);
 	}
 
+	@Override
+	public void initializeState(ChannelStateReader stateReader) throws IOException, InterruptedException
{
+		for (ResultSubpartition subpartition : subpartitions) {
+			subpartition.initializeState(stateReader);
+		}
+	}
+
 	public String getOwningTaskName() {
 		return owningTaskName;
 	}
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultSubpartition.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultSubpartition.java
index d139df0..d0256a1 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultSubpartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/ResultSubpartition.java
@@ -19,6 +19,7 @@
 package org.apache.flink.runtime.io.network.partition;
 
 import org.apache.flink.annotation.VisibleForTesting;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
 import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
@@ -76,6 +77,9 @@ public abstract class ResultSubpartition {
 		parent.onConsumedSubpartition(index);
 	}
 
+	public void initializeState(ChannelStateReader stateReader) throws IOException, InterruptedException
{
+	}
+
 	/**
 	 * Adds the given buffer.
 	 *
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ConsumableNotifyingResultPartitionWriterDecorator.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ConsumableNotifyingResultPartitionWriterDecorator.java
index 8b1d97d..ada45cb 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ConsumableNotifyingResultPartitionWriterDecorator.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/taskmanager/ConsumableNotifyingResultPartitionWriterDecorator.java
@@ -19,6 +19,7 @@
 package org.apache.flink.runtime.taskmanager;
 
 import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
 import org.apache.flink.runtime.deployment.ResultPartitionDeploymentDescriptor;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
@@ -89,6 +90,11 @@ public class ConsumableNotifyingResultPartitionWriterDecorator implements
Result
 	}
 
 	@Override
+	public void initializeState(ChannelStateReader stateReader) throws IOException, InterruptedException
{
+		partitionWriter.initializeState(stateReader);
+	}
+
+	@Override
 	public boolean addBufferConsumer(BufferConsumer bufferConsumer, int subpartitionIndex) throws
IOException {
 		boolean success = partitionWriter.addBufferConsumer(bufferConsumer, subpartitionIndex);
 		if (success) {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java
index 4964d93..867f591 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/api/writer/RecordWriterTest.java
@@ -25,6 +25,7 @@ import org.apache.flink.core.memory.DataOutputView;
 import org.apache.flink.core.memory.MemorySegment;
 import org.apache.flink.core.memory.MemorySegmentFactory;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
 import org.apache.flink.runtime.event.AbstractEvent;
 import org.apache.flink.runtime.io.network.api.CheckpointBarrier;
 import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
@@ -34,6 +35,7 @@ import org.apache.flink.runtime.io.network.api.serialization.RecordSerializer.Se
 import org.apache.flink.runtime.io.network.api.serialization.SpillingAdaptiveSpanningRecordDeserializer;
 import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
+import org.apache.flink.runtime.io.network.buffer.BufferBuilderAndConsumerTest;
 import org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils;
 import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
 import org.apache.flink.runtime.io.network.buffer.BufferPool;
@@ -41,8 +43,15 @@ import org.apache.flink.runtime.io.network.buffer.BufferProvider;
 import org.apache.flink.runtime.io.network.buffer.BufferRecycler;
 import org.apache.flink.runtime.io.network.buffer.NetworkBufferPool;
 import org.apache.flink.runtime.io.network.partition.MockResultPartitionWriter;
+import org.apache.flink.runtime.io.network.partition.NoOpBufferAvailablityListener;
 import org.apache.flink.runtime.io.network.partition.NoOpResultPartitionConsumableNotifier;
+import org.apache.flink.runtime.io.network.partition.PipelinedSubpartition;
+import org.apache.flink.runtime.io.network.partition.PipelinedSubpartitionView;
+import org.apache.flink.runtime.io.network.partition.ResultPartition;
 import org.apache.flink.runtime.io.network.partition.ResultPartitionBuilder;
+import org.apache.flink.runtime.io.network.partition.ResultPartitionTest;
+import org.apache.flink.runtime.io.network.partition.ResultSubpartition;
+import org.apache.flink.runtime.io.network.partition.ResultSubpartitionView;
 import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
 import org.apache.flink.runtime.io.network.util.DeserializationUtils;
 import org.apache.flink.runtime.io.network.util.TestPooledBufferProvider;
@@ -464,6 +473,59 @@ public class RecordWriterTest {
 		}
 	}
 
+	@Test
+	public void testEmitRecordWithPartitionStateRecovery() throws Exception {
+		final int totalBuffers = 10; // enough for both states and normal records
+		final int totalStates = 2;
+		final int[] states = {1, 2, 3, 4};
+		final int[] records = {5, 6, 7, 8};
+		final int bufferSize = states.length * Integer.BYTES;
+
+		final NetworkBufferPool globalPool = new NetworkBufferPool(totalBuffers, bufferSize, 1);
+		final ChannelStateReader stateReader = new ResultPartitionTest.FiniteChannelStateReader(totalStates,
states);
+		final ResultPartition partition = new ResultPartitionBuilder()
+			.setNetworkBufferPool(globalPool)
+			.build();
+		final RecordWriter<IntValue> recordWriter = new RecordWriterBuilder<IntValue>().build(partition);
+
+		try {
+			partition.setup();
+			partition.initializeState(stateReader);
+
+			for (int record: records) {
+				// the record length 4 is also written into buffer for every emit
+				recordWriter.broadcastEmit(new IntValue(record));
+			}
+
+			// every buffer can contain 2 int records with 2 int length(4)
+			final int[][] expectedRecordsInBuffer = {{4, 5, 4, 6}, {4, 7, 4, 8}};
+
+			for (ResultSubpartition subpartition : partition.getAllPartitions()) {
+				// create the view to consume all the buffers with states and records
+				final ResultSubpartitionView view = new PipelinedSubpartitionView(
+					(PipelinedSubpartition) subpartition,
+					new NoOpBufferAvailablityListener());
+
+				int numConsumedBuffers = 0;
+				ResultSubpartition.BufferAndBacklog bufferAndBacklog;
+				while ((bufferAndBacklog = view.getNextBuffer()) != null) {
+					Buffer buffer = bufferAndBacklog.buffer();
+					int[] expected = numConsumedBuffers < totalStates ? states : expectedRecordsInBuffer[numConsumedBuffers
- totalStates];
+					BufferBuilderAndConsumerTest.assertContent(buffer, partition.getBufferPool(), expected);
+
+					buffer.recycleBuffer();
+					numConsumedBuffers++;
+				}
+
+				assertEquals(totalStates + expectedRecordsInBuffer.length, numConsumedBuffers);
+			}
+		} finally {
+			// cleanup
+			globalPool.destroyAllBufferPools();
+			globalPool.destroy();
+		}
+	}
+
 	private void verifyBroadcastBufferOrEventIndependence(boolean broadcastEvent) throws Exception
{
 		@SuppressWarnings("unchecked")
 		ArrayDeque<BufferConsumer>[] queues = new ArrayDeque[]{new ArrayDeque(), new ArrayDeque()};
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/BufferBuilderAndConsumerTest.java
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/BufferBuilderAndConsumerTest.java
index 3975a71..1033c5e 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/BufferBuilderAndConsumerTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/buffer/BufferBuilderAndConsumerTest.java
@@ -164,7 +164,7 @@ public class BufferBuilderAndConsumerTest {
 	public void buildEmptyBuffer() {
 		Buffer buffer = buildSingleBuffer(createBufferBuilder());
 		assertEquals(0, buffer.getSize());
-		assertContent(buffer);
+		assertContent(buffer, FreeingBufferRecycler.INSTANCE);
 	}
 
 	@Test
@@ -240,7 +240,7 @@ public class BufferBuilderAndConsumerTest {
 		assertTrue(bufferConsumer.isFinished());
 	}
 
-	private static ByteBuffer toByteBuffer(int... data) {
+	public static ByteBuffer toByteBuffer(int... data) {
 		ByteBuffer byteBuffer = ByteBuffer.allocate(data.length * Integer.BYTES);
 		byteBuffer.asIntBuffer().put(data);
 		return byteBuffer;
@@ -250,18 +250,18 @@ public class BufferBuilderAndConsumerTest {
 		assertFalse(actualConsumer.isFinished());
 		Buffer buffer = actualConsumer.build();
 		assertFalse(buffer.isRecycled());
-		assertContent(buffer, expected);
+		assertContent(buffer, FreeingBufferRecycler.INSTANCE, expected);
 		assertEquals(expected.length * Integer.BYTES, buffer.getSize());
 		buffer.recycleBuffer();
 	}
 
-	private static void assertContent(Buffer actualBuffer, int... expected) {
+	public static void assertContent(Buffer actualBuffer, BufferRecycler recycler, int... expected)
{
 		IntBuffer actualIntBuffer = actualBuffer.getNioBufferReadable().asIntBuffer();
 		int[] actual = new int[actualIntBuffer.limit()];
 		actualIntBuffer.get(actual);
 		assertArrayEquals(expected, actual);
 
-		assertEquals(FreeingBufferRecycler.INSTANCE, actualBuffer.getRecycler());
+		assertEquals(recycler, actualBuffer.getRecycler());
 	}
 
 	private static BufferBuilder createBufferBuilder() {
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockResultPartitionWriter.java
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockResultPartitionWriter.java
index fd6c1f8..9fd8205 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockResultPartitionWriter.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/MockResultPartitionWriter.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.runtime.io.network.partition;
 
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
 import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
 import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
@@ -39,6 +40,10 @@ public class MockResultPartitionWriter implements ResultPartitionWriter
{
 	}
 
 	@Override
+	public void initializeState(ChannelStateReader stateReader) {
+	}
+
+	@Override
 	public ResultPartitionID getPartitionId() {
 		return partitionId;
 	}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/NoOpBufferAvailablityListener.java
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/NoOpBufferAvailablityListener.java
index 4162975..7fbd43e 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/NoOpBufferAvailablityListener.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/NoOpBufferAvailablityListener.java
@@ -21,7 +21,7 @@ package org.apache.flink.runtime.io.network.partition;
 /**
  * Test implementation of {@link BufferAvailabilityListener}.
  */
-class NoOpBufferAvailablityListener implements BufferAvailabilityListener {
+public class NoOpBufferAvailablityListener implements BufferAvailabilityListener {
 	@Override
 	public void notifyDataAvailable() {
 	}
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java
index 011aa72..f3e512f 100644
--- a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/ResultPartitionTest.java
@@ -19,13 +19,17 @@
 package org.apache.flink.runtime.io.network.partition;
 
 import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
+import org.apache.flink.runtime.checkpoint.channel.InputChannelInfo;
 import org.apache.flink.runtime.checkpoint.channel.ResultSubpartitionInfo;
 import org.apache.flink.runtime.io.disk.FileChannelManager;
 import org.apache.flink.runtime.io.disk.FileChannelManagerImpl;
 import org.apache.flink.runtime.io.network.NettyShuffleEnvironment;
 import org.apache.flink.runtime.io.network.NettyShuffleEnvironmentBuilder;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
 import org.apache.flink.runtime.io.network.buffer.BufferBuilder;
+import org.apache.flink.runtime.io.network.buffer.BufferBuilderAndConsumerTest;
 import org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils;
 import org.apache.flink.runtime.io.network.buffer.BufferConsumer;
 import org.apache.flink.runtime.io.network.buffer.BufferPool;
@@ -42,6 +46,11 @@ import org.junit.Test;
 
 import java.io.IOException;
 import java.util.Collections;
+import java.util.concurrent.Callable;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
 
 import static org.apache.flink.runtime.io.network.buffer.BufferBuilderTestUtils.createFilledFinishedBufferConsumer;
 import static org.apache.flink.runtime.io.network.partition.PartitionTestUtils.createPartition;
@@ -407,4 +416,120 @@ public class ResultPartitionTest {
 			jobId,
 			notifier)[0];
 	}
+
+	@Test
+	public void testInitializeEmptyState() throws Exception {
+		final int totalBuffers = 2;
+		final NetworkBufferPool globalPool = new NetworkBufferPool(totalBuffers, 1, 1);
+		final ResultPartition partition = new ResultPartitionBuilder()
+			.setNetworkBufferPool(globalPool)
+			.build();
+		final ChannelStateReader stateReader = ChannelStateReader.NO_OP;
+		try {
+			partition.setup();
+			partition.initializeState(stateReader);
+
+			for (ResultSubpartition subpartition : partition.getAllPartitions()) {
+				// no buffers are added into the queue for empty states
+				assertEquals(0, subpartition.getTotalNumberOfBuffers());
+			}
+
+			// destroy the local pool to verify that all the requested buffers by partition are recycled
+			partition.getBufferPool().lazyDestroy();
+			assertEquals(totalBuffers, globalPool.getNumberOfAvailableMemorySegments());
+		} finally {
+			// cleanup
+			globalPool.destroyAllBufferPools();
+			globalPool.destroy();
+		}
+	}
+
+	@Test
+	public void testInitializeMoreStateThanBuffer() throws Exception {
+		final int totalBuffers = 2; // the total buffers are less than the requirement from total
states
+		final int totalStates = 5;
+		final int[] states = {1, 2, 3, 4};
+		final int bufferSize = states.length * Integer.BYTES;
+
+		final NetworkBufferPool globalPool = new NetworkBufferPool(totalBuffers, bufferSize, 1);
+		final ChannelStateReader stateReader = new FiniteChannelStateReader(totalStates, states);
+		final ResultPartition partition = new ResultPartitionBuilder()
+			.setNetworkBufferPool(globalPool)
+			.build();
+		final ExecutorService executor = Executors.newFixedThreadPool(1);
+
+		try {
+			final Callable<Void> partitionConsumeTask = () -> {
+				for (ResultSubpartition subpartition : partition.getAllPartitions()) {
+					final ResultSubpartitionView view = new PipelinedSubpartitionView(
+						(PipelinedSubpartition) subpartition,
+						new NoOpBufferAvailablityListener());
+
+					int numConsumedBuffers = 0;
+					while (numConsumedBuffers != totalStates) {
+						ResultSubpartition.BufferAndBacklog bufferAndBacklog = view.getNextBuffer();
+						if (bufferAndBacklog != null) {
+							Buffer buffer = bufferAndBacklog.buffer();
+							BufferBuilderAndConsumerTest.assertContent(buffer, partition.getBufferPool(), states);
+							buffer.recycleBuffer();
+							numConsumedBuffers++;
+						} else {
+							Thread.sleep(5);
+						}
+					}
+				}
+				return null;
+			};
+			Future<Void> result = executor.submit(partitionConsumeTask);
+
+			partition.setup();
+			partition.initializeState(stateReader);
+
+			// wait the partition consume task finish
+			result.get(20, TimeUnit.SECONDS);
+
+			// destroy the local pool to verify that all the requested buffers by partition are recycled
+			partition.getBufferPool().lazyDestroy();
+			assertEquals(totalBuffers, globalPool.getNumberOfAvailableMemorySegments());
+		} finally {
+			// cleanup
+			executor.shutdown();
+			globalPool.destroyAllBufferPools();
+			globalPool.destroy();
+		}
+	}
+
+	/**
+	 * The {@link ChannelStateReader} instance for restoring the specific number of states.
+	 */
+	public static final class FiniteChannelStateReader implements ChannelStateReader {
+		private final int totalStates;
+		private int numRestoredStates;
+		private final int[] states;
+
+		public FiniteChannelStateReader(int totalStates, int[] states) {
+			this.totalStates = totalStates;
+			this.states = states;
+		}
+
+		@Override
+		public ReadResult readInputData(InputChannelInfo info, Buffer buffer) {
+			return ReadResult.NO_MORE_DATA;
+		}
+
+		@Override
+		public ReadResult readOutputData(ResultSubpartitionInfo info, BufferBuilder bufferBuilder)
{
+			bufferBuilder.appendAndCommit(BufferBuilderAndConsumerTest.toByteBuffer(states));
+
+			if (++numRestoredStates < totalStates) {
+				return ReadResult.HAS_MORE_DATA;
+			} else {
+				return ReadResult.NO_MORE_DATA;
+			}
+		}
+
+		@Override
+		public void close() {
+		}
+	}
 }
diff --git a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
index d3a79b6..e699994 100644
--- a/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
+++ b/flink-streaming-java/src/main/java/org/apache/flink/streaming/runtime/tasks/StreamTask.java
@@ -28,6 +28,7 @@ import org.apache.flink.runtime.checkpoint.CheckpointFailureReason;
 import org.apache.flink.runtime.checkpoint.CheckpointMetaData;
 import org.apache.flink.runtime.checkpoint.CheckpointMetrics;
 import org.apache.flink.runtime.checkpoint.CheckpointOptions;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
 import org.apache.flink.runtime.concurrent.FutureUtils;
 import org.apache.flink.runtime.execution.CancelTaskException;
 import org.apache.flink.runtime.execution.Environment;
@@ -434,6 +435,14 @@ public abstract class StreamTask<OUT, OP extends StreamOperator<OUT>>
 			// so that we avoid race conditions in the case that initializeState()
 			// registers a timer, that fires before the open() is called.
 			operatorChain.initializeStateAndOpenOperators(createStreamTaskStateInitializer());
+
+			ResultPartitionWriter[] writers = getEnvironment().getAllWriters();
+			if (writers != null) {
+				//TODO we should get proper state reader from getEnvironment().getTaskStateManager().getChannelStateReader()
+				for (ResultPartitionWriter writer : writers) {
+					writer.initializeState(ChannelStateReader.NO_OP);
+				}
+			}
 		});
 
 		isRunning = true;
diff --git a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
index 3fa9508..c610a4f 100644
--- a/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
+++ b/flink-streaming-java/src/test/java/org/apache/flink/streaming/runtime/tasks/StreamTaskTest.java
@@ -33,6 +33,7 @@ import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
 import org.apache.flink.runtime.checkpoint.StateObjectCollection;
 import org.apache.flink.runtime.checkpoint.SubtaskState;
 import org.apache.flink.runtime.checkpoint.TaskStateSnapshot;
+import org.apache.flink.runtime.checkpoint.channel.ChannelStateReader;
 import org.apache.flink.runtime.concurrent.FutureUtils;
 import org.apache.flink.runtime.concurrent.TestingUncaughtExceptionHandler;
 import org.apache.flink.runtime.execution.CancelTaskException;
@@ -44,6 +45,7 @@ import org.apache.flink.runtime.io.network.NettyShuffleEnvironmentBuilder;
 import org.apache.flink.runtime.io.network.api.writer.AvailabilityTestResultPartitionWriter;
 import org.apache.flink.runtime.io.network.api.writer.RecordWriter;
 import org.apache.flink.runtime.io.network.api.writer.ResultPartitionWriter;
+import org.apache.flink.runtime.io.network.partition.MockResultPartitionWriter;
 import org.apache.flink.runtime.jobgraph.OperatorID;
 import org.apache.flink.runtime.jobgraph.tasks.AbstractInvokable;
 import org.apache.flink.runtime.operators.testutils.DummyEnvironment;
@@ -895,6 +897,30 @@ public class StreamTaskTest extends TestLogger {
 		}
 	}
 
+	@Test
+	public void testInitializeResultPartitionState() throws Exception {
+		int numWriters = 2;
+		RecoveryResultPartition[] partitions = new RecoveryResultPartition[numWriters];
+		for (int i = 0; i < numWriters; i++) {
+			partitions[i] = new RecoveryResultPartition();
+		}
+
+		MockEnvironment mockEnvironment = new MockEnvironmentBuilder().build();
+		mockEnvironment.addOutputs(Arrays.asList(partitions));
+		StreamTask task = new MockStreamTaskBuilder(mockEnvironment).build();
+
+		try {
+			task.beforeInvoke();
+
+			// output recovery should be done before task processing
+			for (RecoveryResultPartition resultPartition : partitions) {
+				assertTrue(resultPartition.isStateInitialized());
+			}
+		} finally {
+			task.cleanUpInvoke();
+		}
+	}
+
 	/**
 	 * Tests that some StreamTask methods are called only in the main task's thread.
 	 * Currently, the main task's thread is the thread that creates the task.
@@ -1723,4 +1749,20 @@ public class StreamTaskTest extends TestLogger {
 			throw new UnsupportedOperationException();
 		}
 	}
+
+	private static class RecoveryResultPartition extends MockResultPartitionWriter {
+		private boolean isStateInitialized;
+
+		RecoveryResultPartition() {
+		}
+
+		@Override
+		public void initializeState(ChannelStateReader stateReader) {
+			isStateInitialized = true;
+		}
+
+		boolean isStateInitialized() {
+			return isStateInitialized;
+		}
+	}
 }


Mime
View raw message