flink-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From u..@apache.org
Subject [4/6] flink git commit: [FLINK-5169] [network] Add tests for channel consumption
Date Fri, 02 Dec 2016 08:42:38 GMT
[FLINK-5169] [network] Add tests for channel consumption

This closes #2882.


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/c0cdc5c4
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/c0cdc5c4
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/c0cdc5c4

Branch: refs/heads/master
Commit: c0cdc5c4ec08e35a8ea319d1bbf2b24e03e24fd3
Parents: d3ac0ad
Author: Stephan Ewen <sewen@apache.org>
Authored: Sun Nov 27 18:15:40 2016 +0100
Committer: Ufuk Celebi <uce@apache.org>
Committed: Thu Dec 1 21:42:49 2016 +0100

----------------------------------------------------------------------
 .../partition/PipelinedSubpartition.java        |   8 +
 .../partition/consumer/LocalInputChannel.java   |   4 +-
 .../partition/consumer/SingleInputGate.java     |   4 +-
 .../partition/consumer/UnionInputGate.java      |   2 +-
 .../partition/InputChannelTestUtils.java        |  89 +++++
 .../partition/InputGateConcurrentTest.java      | 323 +++++++++++++++
 .../partition/InputGateFairnessTest.java        | 395 +++++++++++++++++++
 7 files changed, 820 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/c0cdc5c4/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java
index e9400f0..9e2f5ba 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/PipelinedSubpartition.java
@@ -183,6 +183,14 @@ class PipelinedSubpartition extends ResultSubpartition {
 		return readView;
 	}
 
+	// ------------------------------------------------------------------------
+
+	/**
+	 * Returns the current size of this subpartition's {@code buffers} collection.
+	 * Intended for tests, e.g. to check how evenly several subpartitions are drained.
+	 *
+	 * <p>NOTE(review): reads {@code buffers.size()} without taking any lock here; confirm that
+	 * callers only rely on exact values when no producer is concurrently adding buffers.
+	 */
+	int getCurrentNumberOfBuffers() {
+		return buffers.size();
+	}
+
+	// ------------------------------------------------------------------------
+
 	@Override
 	public String toString() {
 		final long numBuffers;

http://git-wip-us.apache.org/repos/asf/flink/blob/c0cdc5c4/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java
index d5308a8..1936da2 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/LocalInputChannel.java
@@ -64,7 +64,7 @@ public class LocalInputChannel extends InputChannel implements BufferAvailabilit
 
 	private volatile boolean isReleased;
 
-	LocalInputChannel(
+	public LocalInputChannel(
 		SingleInputGate inputGate,
 		int channelIndex,
 		ResultPartitionID partitionId,
@@ -76,7 +76,7 @@ public class LocalInputChannel extends InputChannel implements BufferAvailabilit
 			0, 0, metrics);
 	}
 
-	LocalInputChannel(
+	public LocalInputChannel(
 		SingleInputGate inputGate,
 		int channelIndex,
 		ResultPartitionID partitionId,

http://git-wip-us.apache.org/repos/asf/flink/blob/c0cdc5c4/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
index bcbb2c4..b4d8d2c 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/SingleInputGate.java
@@ -261,7 +261,7 @@ public class SingleInputGate implements InputGate {
 		this.bufferPool = checkNotNull(bufferPool);
 	}
 
-	void setInputChannel(IntermediateResultPartitionID partitionId, InputChannel inputChannel)
{
+	public void setInputChannel(IntermediateResultPartitionID partitionId, InputChannel inputChannel)
{
 		synchronized (requestLock) {
 			if (inputChannels.put(checkNotNull(partitionId), checkNotNull(inputChannel)) == null
 					&& inputChannel.getClass() == UnknownInputChannel.class) {
@@ -546,7 +546,7 @@ public class SingleInputGate implements InputGate {
 			inputChannelsWithData.add(channel);
 
 			if (availableChannels == 0) {
-				inputChannelsWithData.notify();
+				inputChannelsWithData.notifyAll();
 			}
 		}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/c0cdc5c4/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
index e8ccbb4..55c78af 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/io/network/partition/consumer/UnionInputGate.java
@@ -225,7 +225,7 @@ public class UnionInputGate implements InputGate, InputGateListener {
 			inputGatesWithData.add(inputGate);
 
 			if (availableInputGates == 0) {
-				inputGatesWithData.notify();
+				inputGatesWithData.notifyAll();
 			}
 		}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/c0cdc5c4/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputChannelTestUtils.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputChannelTestUtils.java
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputChannelTestUtils.java
new file mode 100644
index 0000000..e292576
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputChannelTestUtils.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition;
+
+import org.apache.flink.runtime.io.network.ConnectionID;
+import org.apache.flink.runtime.io.network.ConnectionManager;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.buffer.BufferProvider;
+import org.apache.flink.runtime.io.network.netty.PartitionRequestClient;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+
+import static org.mockito.Matchers.any;
+import static org.mockito.Matchers.anyInt;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.when;
+
+/**
+ * Utility methods used for testing InputChannels and InputGates.
+ */
+class InputChannelTestUtils {
+
+	/**
+	 * Creates a mocked {@link Buffer} of the given size that reports itself as a data buffer
+	 * and as not recycled. The mock is never actually recycled.
+	 *
+	 * @param size the size reported by {@link Buffer#getSize()}
+	 * @return the mocked buffer
+	 */
+	public static Buffer createMockBuffer(int size) {
+		final Buffer mockBuffer = mock(Buffer.class);
+		when(mockBuffer.isBuffer()).thenReturn(true);
+		when(mockBuffer.getSize()).thenReturn(size);
+		when(mockBuffer.isRecycled()).thenReturn(false);
+
+		return mockBuffer;
+	}
+
+	/**
+	 * Creates a mocked result partition manager that ignores all partition IDs and subpartition
+	 * indexes and hands out read views of the given subpartitions in sequence: the first view
+	 * request gets a view of {@code sources[0]}, the second one of {@code sources[1]}, and so on.
+	 *
+	 * <p>The array is only read when a view is requested, so callers may fill it after
+	 * creating the manager.
+	 *
+	 * @param sources the subpartitions to create read views for, in request order
+	 * @return the mocked result partition manager
+	 */
+	public static ResultPartitionManager createResultPartitionManager(final ResultSubpartition[] sources) throws Exception {
+
+		final Answer<ResultSubpartitionView> viewCreator = new Answer<ResultSubpartitionView>() {
+
+			private int num = 0;
+
+			@Override
+			public ResultSubpartitionView answer(InvocationOnMock invocation) throws Throwable {
+				if (num >= sources.length) {
+					throw new IllegalStateException("Requested more subpartition views than the "
+							+ sources.length + " available sources.");
+				}
+
+				// the fourth argument of createSubpartitionView(...) is the availability listener
+				BufferAvailabilityListener channel = (BufferAvailabilityListener) invocation.getArguments()[3];
+				return sources[num++].createReadView(null, channel);
+			}
+		};
+
+		ResultPartitionManager manager = mock(ResultPartitionManager.class);
+		when(manager.createSubpartitionView(
+				any(ResultPartitionID.class), anyInt(), any(BufferProvider.class), any(BufferAvailabilityListener.class)))
+				.thenAnswer(viewCreator);
+
+		return manager;
+	}
+
+	/**
+	 * Creates a mocked connection manager whose
+	 * {@link ConnectionManager#createPartitionRequestClient(ConnectionID)} always returns the
+	 * same mocked {@link PartitionRequestClient}.
+	 *
+	 * @return the mocked connection manager
+	 */
+	public static ConnectionManager createDummyConnectionManager() throws Exception {
+		final PartitionRequestClient mockClient = mock(PartitionRequestClient.class);
+
+		final ConnectionManager connManager = mock(ConnectionManager.class);
+		when(connManager.createPartitionRequestClient(any(ConnectionID.class))).thenReturn(mockClient);
+
+		return connManager;
+	}
+
+	// ------------------------------------------------------------------------
+
+	/** This class is not meant to be instantiated. */
+	private InputChannelTestUtils() {}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/c0cdc5c4/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
new file mode 100644
index 0000000..6570679
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateConcurrentTest.java
@@ -0,0 +1,323 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.io.network.ConnectionID;
+import org.apache.flink.runtime.io.network.ConnectionManager;
+import org.apache.flink.runtime.io.network.TaskEventDispatcher;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel;
+import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
+import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
+import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
+import org.apache.flink.runtime.taskmanager.TaskActions;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+
+import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createDummyConnectionManager;
+import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createResultPartitionManager;
+import static org.junit.Assert.assertNotNull;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests that a {@link SingleInputGate} delivers every buffer while a producer thread adds
+ * buffers to its input channels and a consumer thread concurrently reads from the gate.
+ * Covers local channels, remote channels, and a mix of both.
+ */
+public class InputGateConcurrentTest {
+
+	/** Upper bound per test, so that a lost wake-up fails the test instead of hanging it. */
+	private static final long TEST_TIMEOUT_MILLIS = 120000L;
+
+	@Test(timeout = TEST_TIMEOUT_MILLIS)
+	public void testConsumptionWithLocalChannels() throws Exception {
+		final int numChannels = 11;
+		final int buffersPerChannel = 1000;
+
+		final ResultPartition resultPartition = mock(ResultPartition.class);
+
+		final PipelinedSubpartition[] partitions = new PipelinedSubpartition[numChannels];
+		final Source[] sources = new Source[numChannels];
+
+		// the manager only reads 'partitions' when a view is requested, so it is filled below
+		final ResultPartitionManager resultPartitionManager = createResultPartitionManager(partitions);
+
+		final SingleInputGate gate = createSingleInputGate(numChannels);
+
+		for (int i = 0; i < numChannels; i++) {
+			LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(),
+					resultPartitionManager, mock(TaskEventDispatcher.class), new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
+			gate.setInputChannel(new IntermediateResultPartitionID(), channel);
+
+			partitions[i] = new PipelinedSubpartition(0, resultPartition);
+			sources[i] = new PipelinedSubpartitionSource(partitions[i]);
+		}
+
+		runProducerAndConsumer(sources, gate, numChannels * buffersPerChannel);
+	}
+
+	@Test(timeout = TEST_TIMEOUT_MILLIS)
+	public void testConsumptionWithRemoteChannels() throws Exception {
+		final int numChannels = 11;
+		final int buffersPerChannel = 1000;
+
+		final ConnectionManager connManager = createDummyConnectionManager();
+		final Source[] sources = new Source[numChannels];
+
+		final SingleInputGate gate = createSingleInputGate(numChannels);
+
+		for (int i = 0; i < numChannels; i++) {
+			RemoteInputChannel channel = new RemoteInputChannel(
+					gate, i, new ResultPartitionID(), mock(ConnectionID.class),
+					connManager, 0, 0, new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
+			gate.setInputChannel(new IntermediateResultPartitionID(), channel);
+
+			sources[i] = new RemoteChannelSource(channel);
+		}
+
+		runProducerAndConsumer(sources, gate, numChannels * buffersPerChannel);
+	}
+
+	@Test(timeout = TEST_TIMEOUT_MILLIS)
+	public void testConsumptionWithMixedChannels() throws Exception {
+		final int numChannels = 61;
+		final int numLocalChannels = 20;
+		final int buffersPerChannel = 1000;
+
+		// randomly decide which channels are local (true) and which are remote (false)
+		List<Boolean> localOrRemote = new ArrayList<>(numChannels);
+		for (int i = 0; i < numChannels; i++) {
+			localOrRemote.add(i < numLocalChannels);
+		}
+		Collections.shuffle(localOrRemote);
+
+		final ConnectionManager connManager = createDummyConnectionManager();
+		final ResultPartition resultPartition = mock(ResultPartition.class);
+
+		final PipelinedSubpartition[] localPartitions = new PipelinedSubpartition[numLocalChannels];
+		final ResultPartitionManager resultPartitionManager = createResultPartitionManager(localPartitions);
+
+		final Source[] sources = new Source[numChannels];
+
+		final SingleInputGate gate = createSingleInputGate(numChannels);
+
+		for (int i = 0, local = 0; i < numChannels; i++) {
+			if (localOrRemote.get(i)) {
+				// local channel
+				PipelinedSubpartition psp = new PipelinedSubpartition(0, resultPartition);
+				localPartitions[local++] = psp;
+				sources[i] = new PipelinedSubpartitionSource(psp);
+
+				LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(),
+						resultPartitionManager, mock(TaskEventDispatcher.class), new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
+				gate.setInputChannel(new IntermediateResultPartitionID(), channel);
+			}
+			else {
+				// remote channel
+				RemoteInputChannel channel = new RemoteInputChannel(
+						gate, i, new ResultPartitionID(), mock(ConnectionID.class),
+						connManager, 0, 0, new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
+				gate.setInputChannel(new IntermediateResultPartitionID(), channel);
+
+				sources[i] = new RemoteChannelSource(channel);
+			}
+		}
+
+		runProducerAndConsumer(sources, gate, numChannels * buffersPerChannel);
+	}
+
+	// ------------------------------------------------------------------------
+	//  test setup helpers
+	// ------------------------------------------------------------------------
+
+	/**
+	 * Creates a {@link SingleInputGate} for the given number of input channels, with mocked
+	 * task actions and dummy metrics.
+	 */
+	private static SingleInputGate createSingleInputGate(int numChannels) {
+		return new SingleInputGate(
+				"Test Task Name",
+				new JobID(),
+				new ExecutionAttemptID(),
+				new IntermediateDataSetID(),
+				0,
+				numChannels,
+				mock(TaskActions.class),
+				new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
+	}
+
+	/**
+	 * Runs a producer that adds {@code numBuffers} buffers to the given sources concurrently with
+	 * a consumer that reads {@code numBuffers} buffers or events from the gate. Waits for both
+	 * threads and rethrows any exception or failed assertion from either of them.
+	 */
+	private static void runProducerAndConsumer(Source[] sources, SingleInputGate gate, int numBuffers) throws Exception {
+		ProducerThread producer = new ProducerThread(sources, numBuffers, 4, 10);
+		ConsumerThread consumer = new ConsumerThread(gate, numBuffers);
+		producer.start();
+		consumer.start();
+
+		// the 'sync()' call checks for exceptions and failed assertions
+		producer.sync();
+		consumer.sync();
+	}
+
+	// ------------------------------------------------------------------------
+	//  testing sources
+	// ------------------------------------------------------------------------
+
+	/** Something a producer can push buffers into, feeding one input channel of the gate. */
+	private static abstract class Source {
+
+		abstract void addBuffer(Buffer buffer) throws Exception;
+	}
+
+	/** Source that adds buffers to a {@link PipelinedSubpartition} (read by a local channel). */
+	private static class PipelinedSubpartitionSource extends Source {
+
+		final PipelinedSubpartition partition;
+
+		PipelinedSubpartitionSource(PipelinedSubpartition partition) {
+			this.partition = partition;
+		}
+
+		@Override
+		void addBuffer(Buffer buffer) throws Exception {
+			partition.add(buffer);
+		}
+	}
+
+	/**
+	 * Source that hands buffers directly to a {@link RemoteInputChannel}, using increasing
+	 * sequence numbers starting at zero.
+	 */
+	private static class RemoteChannelSource extends Source {
+
+		final RemoteInputChannel channel;
+		private int seq = 0;
+
+		RemoteChannelSource(RemoteInputChannel channel) {
+			this.channel = channel;
+		}
+
+		@Override
+		void addBuffer(Buffer buffer) throws Exception {
+			channel.onBuffer(buffer, seq++);
+		}
+	}
+
+	// ------------------------------------------------------------------------
+	//  testing threads
+	// ------------------------------------------------------------------------
+
+	/**
+	 * Thread that records any {@link Throwable} thrown by {@link #go()} so that
+	 * {@link #sync()} can rethrow it in the test's thread.
+	 */
+	private static abstract class CheckedThread extends Thread {
+
+		private volatile Throwable error;
+
+		public abstract void go() throws Exception;
+
+		@Override
+		public void run() {
+			try {
+				go();
+			}
+			catch (Throwable t) {
+				error = t;
+			}
+		}
+
+		/**
+		 * Waits for the thread to finish and rethrows the error it failed with, if any.
+		 */
+		public void sync() throws Exception {
+			join();
+
+			// propagate the error
+			if (error != null) {
+				if (error instanceof Error) {
+					throw (Error) error;
+				}
+				else if (error instanceof Exception) {
+					throw (Exception) error;
+				}
+				else {
+					throw new Exception(error.getMessage(), error);
+				}
+			}
+		}
+	}
+
+	/**
+	 * Adds {@code numTotal} copies of one mock buffer to randomly chosen sources, in chunks of
+	 * 1 to {@code maxChunk} buffers per chosen source, yielding the CPU after roughly every
+	 * {@code yieldAfter} buffers to encourage interleaving with the consumer.
+	 */
+	private static class ProducerThread extends CheckedThread {
+
+		private final Random rnd = new Random();
+		private final Source[] sources;
+		private final int numTotal;
+		private final int maxChunk;
+		private final int yieldAfter;
+
+		ProducerThread(Source[] sources, int numTotal, int maxChunk, int yieldAfter) {
+			this.sources = sources;
+			this.numTotal = numTotal;
+			this.maxChunk = maxChunk;
+			this.yieldAfter = yieldAfter;
+		}
+
+		@Override
+		public void go() throws Exception {
+			final Buffer buffer = InputChannelTestUtils.createMockBuffer(100);
+			int nextYield = numTotal - yieldAfter;
+
+			for (int i = numTotal; i > 0;) {
+				final int nextChannel = rnd.nextInt(sources.length);
+				final int chunk = Math.min(i, rnd.nextInt(maxChunk) + 1);
+
+				final Source next = sources[nextChannel];
+
+				for (int k = chunk; k > 0; --k) {
+					next.addBuffer(buffer);
+				}
+
+				i -= chunk;
+
+				if (i <= nextYield) {
+					nextYield -= yieldAfter;
+					//noinspection CallToThreadYield
+					Thread.yield();
+				}
+			}
+		}
+	}
+
+	/** Reads exactly {@code numBuffers} buffers or events from the gate, none of which may be null. */
+	private static class ConsumerThread extends CheckedThread {
+
+		private final SingleInputGate gate;
+		private final int numBuffers;
+
+		ConsumerThread(SingleInputGate gate, int numBuffers) {
+			this.gate = gate;
+			this.numBuffers = numBuffers;
+		}
+
+		@Override
+		public void go() throws Exception {
+			for (int i = numBuffers; i > 0; --i) {
+				assertNotNull(gate.getNextBufferOrEvent());
+			}
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/c0cdc5c4/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
new file mode 100644
index 0000000..b35612a
--- /dev/null
+++ b/flink-runtime/src/test/java/org/apache/flink/runtime/io/network/partition/InputGateFairnessTest.java
@@ -0,0 +1,395 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.runtime.io.network.partition;
+
+import org.apache.flink.api.common.JobID;
+import org.apache.flink.runtime.executiongraph.ExecutionAttemptID;
+import org.apache.flink.runtime.io.network.ConnectionID;
+import org.apache.flink.runtime.io.network.ConnectionManager;
+import org.apache.flink.runtime.io.network.TaskEventDispatcher;
+import org.apache.flink.runtime.io.network.api.EndOfPartitionEvent;
+import org.apache.flink.runtime.io.network.api.serialization.EventSerializer;
+import org.apache.flink.runtime.io.network.buffer.Buffer;
+import org.apache.flink.runtime.io.network.partition.consumer.BufferOrEvent;
+import org.apache.flink.runtime.io.network.partition.consumer.InputChannel;
+import org.apache.flink.runtime.io.network.partition.consumer.LocalInputChannel;
+import org.apache.flink.runtime.io.network.partition.consumer.RemoteInputChannel;
+import org.apache.flink.runtime.io.network.partition.consumer.SingleInputGate;
+import org.apache.flink.runtime.jobgraph.IntermediateDataSetID;
+import org.apache.flink.runtime.jobgraph.IntermediateResultPartitionID;
+import org.apache.flink.runtime.metrics.groups.TaskIOMetricGroup;
+import org.apache.flink.runtime.operators.testutils.UnregisteredTaskMetricsGroup;
+import org.apache.flink.runtime.taskmanager.TaskActions;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.lang.reflect.Field;
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+
+import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createDummyConnectionManager;
+import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createMockBuffer;
+import static org.apache.flink.runtime.io.network.partition.InputChannelTestUtils.createResultPartitionManager;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+import static org.mockito.Mockito.mock;
+
+/**
+ * Tests that a {@link SingleInputGate} consumes its input channels fairly: after every read,
+ * the numbers of buffers still queued in any two channels differ by at most one.
+ */
+public class InputGateFairnessTest {
+
+	@Test
+	public void testFairConsumptionLocalChannelsPreFilled() throws Exception {
+		final int numChannels = 37;
+		final int buffersPerChannel = 27;
+
+		final ResultPartition resultPartition = mock(ResultPartition.class);
+		final Buffer mockBuffer = createMockBuffer(42);
+
+		// ----- create some source channels and fill them with buffers -----
+
+		final PipelinedSubpartition[] sources = new PipelinedSubpartition[numChannels];
+
+		for (int i = 0; i < numChannels; i++) {
+			PipelinedSubpartition partition = new PipelinedSubpartition(0, resultPartition);
+
+			for (int p = 0; p < buffersPerChannel; p++) {
+				partition.add(mockBuffer);
+			}
+
+			partition.finish();
+			sources[i] = partition;
+		}
+
+		// ----- create reading side -----
+
+		ResultPartitionManager resultPartitionManager = createResultPartitionManager(sources);
+
+		SingleInputGate gate = createFairnessVerifyingInputGate(numChannels);
+
+		for (int i = 0; i < numChannels; i++) {
+			LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(),
+					resultPartitionManager, mock(TaskEventDispatcher.class), new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
+			gate.setInputChannel(new IntermediateResultPartitionID(), channel);
+		}
+
+		// read all the buffers and the EOF event
+		for (int i = numChannels * (buffersPerChannel + 1); i > 0; --i) {
+			assertNotNull(gate.getNextBufferOrEvent());
+			assertFairDistribution(sources);
+		}
+
+		assertNull(gate.getNextBufferOrEvent());
+	}
+
+	@Test
+	public void testFairConsumptionLocalChannels() throws Exception {
+		final int numChannels = 37;
+		final int buffersPerChannel = 27;
+
+		final ResultPartition resultPartition = mock(ResultPartition.class);
+		final Buffer mockBuffer = createMockBuffer(42);
+
+		// ----- create some empty source channels -----
+
+		final PipelinedSubpartition[] sources = new PipelinedSubpartition[numChannels];
+
+		for (int i = 0; i < numChannels; i++) {
+			sources[i] = new PipelinedSubpartition(0, resultPartition);
+		}
+
+		// ----- create reading side -----
+
+		ResultPartitionManager resultPartitionManager = createResultPartitionManager(sources);
+
+		SingleInputGate gate = createFairnessVerifyingInputGate(numChannels);
+
+		for (int i = 0; i < numChannels; i++) {
+			LocalInputChannel channel = new LocalInputChannel(gate, i, new ResultPartitionID(),
+					resultPartitionManager, mock(TaskEventDispatcher.class), new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
+			gate.setInputChannel(new IntermediateResultPartitionID(), channel);
+		}
+
+		// seed one initial buffer
+		sources[12].add(mockBuffer);
+
+		// read buffers; every 2 * numChannels reads, add three more buffers to every channel
+		for (int i = 0; i < numChannels * buffersPerChannel; i++) {
+			assertNotNull(gate.getNextBufferOrEvent());
+			assertFairDistribution(sources);
+
+			if (i % (2 * numChannels) == 0) {
+				// add three buffers to each channel, in random order
+				fillRandom(sources, 3, mockBuffer);
+			}
+		}
+
+		// there is still more in the queues
+	}
+
+	@Test
+	public void testFairConsumptionRemoteChannelsPreFilled() throws Exception {
+		final int numChannels = 37;
+		final int buffersPerChannel = 27;
+
+		final Buffer mockBuffer = createMockBuffer(42);
+
+		// ----- create some source channels and fill them with buffers -----
+
+		SingleInputGate gate = createFairnessVerifyingInputGate(numChannels);
+
+		final ConnectionManager connManager = createDummyConnectionManager();
+
+		final RemoteInputChannel[] channels = new RemoteInputChannel[numChannels];
+
+		for (int i = 0; i < numChannels; i++) {
+			RemoteInputChannel channel = new RemoteInputChannel(
+					gate, i, new ResultPartitionID(), mock(ConnectionID.class),
+					connManager, 0, 0, new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
+
+			channels[i] = channel;
+
+			for (int p = 0; p < buffersPerChannel; p++) {
+				channel.onBuffer(mockBuffer, p);
+			}
+			channel.onBuffer(EventSerializer.toBuffer(EndOfPartitionEvent.INSTANCE), buffersPerChannel);
+
+			gate.setInputChannel(new IntermediateResultPartitionID(), channel);
+		}
+
+		// read all the buffers and the EOF event
+		for (int i = numChannels * (buffersPerChannel + 1); i > 0; --i) {
+			assertNotNull(gate.getNextBufferOrEvent());
+			assertFairDistribution(channels);
+		}
+
+		assertNull(gate.getNextBufferOrEvent());
+	}
+
+	@Test
+	public void testFairConsumptionRemoteChannels() throws Exception {
+		final int numChannels = 37;
+		final int buffersPerChannel = 27;
+
+		final Buffer mockBuffer = createMockBuffer(42);
+
+		// ----- create the input gate and some empty remote channels -----
+
+		SingleInputGate gate = createFairnessVerifyingInputGate(numChannels);
+
+		final ConnectionManager connManager = createDummyConnectionManager();
+
+		final RemoteInputChannel[] channels = new RemoteInputChannel[numChannels];
+		final int[] channelSequenceNums = new int[numChannels];
+
+		for (int i = 0; i < numChannels; i++) {
+			RemoteInputChannel channel = new RemoteInputChannel(
+					gate, i, new ResultPartitionID(), mock(ConnectionID.class),
+					connManager, 0, 0, new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
+
+			channels[i] = channel;
+			gate.setInputChannel(new IntermediateResultPartitionID(), channel);
+		}
+
+		// seed one initial buffer
+		channels[11].onBuffer(mockBuffer, 0);
+		channelSequenceNums[11]++;
+
+		// read buffers; every 2 * numChannels reads, add three more buffers to every channel
+		for (int i = 0; i < numChannels * buffersPerChannel; i++) {
+			assertNotNull(gate.getNextBufferOrEvent());
+			assertFairDistribution(channels);
+
+			if (i % (2 * numChannels) == 0) {
+				// add three buffers to each channel, in random order
+				fillRandom(channels, channelSequenceNums, 3, mockBuffer);
+			}
+		}
+	}
+
+	// ------------------------------------------------------------------------
+	//  Utilities
+	// ------------------------------------------------------------------------
+
+	/**
+	 * Creates a {@link FairnessVerifyingInputGate} for the given number of input channels,
+	 * with mocked task actions and dummy metrics.
+	 */
+	private static SingleInputGate createFairnessVerifyingInputGate(int numChannels) {
+		return new FairnessVerifyingInputGate(
+				"Test Task Name",
+				new JobID(),
+				new ExecutionAttemptID(),
+				new IntermediateDataSetID(),
+				0, numChannels,
+				mock(TaskActions.class),
+				new UnregisteredTaskMetricsGroup.DummyTaskIOMetricGroup());
+	}
+
+	/** Asserts that the numbers of buffers held by the given subpartitions differ by at most one. */
+	private static void assertFairDistribution(PipelinedSubpartition[] partitions) {
+		final int[] sizes = new int[partitions.length];
+		for (int i = 0; i < partitions.length; i++) {
+			sizes[i] = partitions[i].getCurrentNumberOfBuffers();
+		}
+		assertSizesDifferByAtMostOne(sizes);
+	}
+
+	/** Asserts that the numbers of buffers queued in the given channels differ by at most one. */
+	private static void assertFairDistribution(RemoteInputChannel[] channels) {
+		final int[] sizes = new int[channels.length];
+		for (int i = 0; i < channels.length; i++) {
+			sizes[i] = channels[i].getNumberOfQueuedBuffers();
+		}
+		assertSizesDifferByAtMostOne(sizes);
+	}
+
+	/** Asserts that the largest and smallest of the given sizes differ by at most one. */
+	private static void assertSizesDifferByAtMostOne(int[] sizes) {
+		int min = Integer.MAX_VALUE;
+		int max = 0;
+
+		for (int size : sizes) {
+			min = Math.min(min, size);
+			max = Math.max(max, size);
+		}
+
+		assertTrue("unfair consumption: queue sizes range from " + min + " to " + max,
+				max == min || max == min + 1);
+	}
+
+	/**
+	 * Adds {@code numPerPartition} copies of the buffer to every partition, interleaving the
+	 * partitions in random order.
+	 */
+	private static void fillRandom(PipelinedSubpartition[] partitions, int numPerPartition, Buffer buffer) throws Exception {
+		ArrayList<Integer> poss = new ArrayList<>(partitions.length * numPerPartition);
+
+		for (int i = 0; i < partitions.length; i++) {
+			for (int k = 0; k < numPerPartition; k++) {
+				poss.add(i);
+			}
+		}
+
+		Collections.shuffle(poss);
+
+		for (Integer i : poss) {
+			partitions[i].add(buffer);
+		}
+	}
+
+	/**
+	 * Hands {@code numPerPartition} copies of the buffer to every remote channel, interleaving
+	 * the channels in random order. {@code sequenceNumbers[i]} holds the next sequence number
+	 * for channel {@code i} and is advanced for every buffer handed to it.
+	 */
+	private static void fillRandom(
+			RemoteInputChannel[] partitions,
+			int[] sequenceNumbers,
+			int numPerPartition,
+			Buffer buffer) throws Exception {
+
+		ArrayList<Integer> poss = new ArrayList<>(partitions.length * numPerPartition);
+
+		for (int i = 0; i < partitions.length; i++) {
+			for (int k = 0; k < numPerPartition; k++) {
+				poss.add(i);
+			}
+		}
+
+		Collections.shuffle(poss);
+
+		for (int i : poss) {
+			partitions[i].onBuffer(buffer, sequenceNumbers[i]++);
+		}
+	}
+
+	// ------------------------------------------------------------------------
+
+	/**
+	 * Input gate that, before every read, checks that its internal queue of channels with data
+	 * holds no more entries than there are input channels and contains no channel twice.
+	 */
+	private static class FairnessVerifyingInputGate extends SingleInputGate {
+
+		/** The gate's private 'inputChannelsWithData' queue, obtained via reflection. */
+		private final ArrayDeque<InputChannel> channelsWithData;
+
+		/** Scratch set reused by {@link #ensureUnique(Collection)}; empty between calls. */
+		private final HashSet<InputChannel> uniquenessChecker;
+
+		@SuppressWarnings("unchecked")
+		public FairnessVerifyingInputGate(
+				String owningTaskName,
+				JobID jobId,
+				ExecutionAttemptID executionId,
+				IntermediateDataSetID consumedResultId,
+				int consumedSubpartitionIndex,
+				int numberOfInputChannels,
+				TaskActions taskActions,
+				TaskIOMetricGroup metrics) {
+
+			super(owningTaskName, jobId, executionId, consumedResultId, consumedSubpartitionIndex,
+					numberOfInputChannels, taskActions, metrics);
+
+			try {
+				Field f = SingleInputGate.class.getDeclaredField("inputChannelsWithData");
+				f.setAccessible(true);
+				channelsWithData = (ArrayDeque<InputChannel>) f.get(this);
+			}
+			catch (Exception e) {
+				throw new RuntimeException(e);
+			}
+
+			this.uniquenessChecker = new HashSet<>();
+		}
+
+		@Override
+		public BufferOrEvent getNextBufferOrEvent() throws IOException, InterruptedException {
+			synchronized (channelsWithData) {
+				assertTrue("too many input channels", channelsWithData.size() <= getNumberOfInputChannels());
+				ensureUnique(channelsWithData);
+			}
+
+			return super.getNextBufferOrEvent();
+		}
+
+		/** Fails if the given collection contains the same input channel more than once. */
+		private void ensureUnique(Collection<InputChannel> channels) {
+			try {
+				for (InputChannel channel : channels) {
+					if (!uniquenessChecker.add(channel)) {
+						fail("Duplicate channel in input gate: " + channel);
+					}
+				}
+			}
+			finally {
+				// the set is a reused field, so always leave it empty for the next call
+				uniquenessChecker.clear();
+			}
+		}
+	}
+}


Mime
View raw message