flink-issues mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From aljoscha <...@git.apache.org>
Subject [GitHub] flink pull request #5230: [FLINK-8345] Add iterator of keyed state on broadc...
Date Tue, 02 Jan 2018 13:18:20 GMT
Github user aljoscha commented on a diff in the pull request:

    https://github.com/apache/flink/pull/5230#discussion_r159223446
  
    --- Diff: flink-streaming-java/src/test/java/org/apache/flink/streaming/api/operators/co/CoBroadcastWithKeyedOperatorTest.java
---
    @@ -0,0 +1,734 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + * <p>
    + * http://www.apache.org/licenses/LICENSE-2.0
    + * <p>
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.flink.streaming.api.operators.co;
    +
    +import org.apache.flink.api.common.state.ListState;
    +import org.apache.flink.api.common.state.ListStateDescriptor;
    +import org.apache.flink.api.common.state.MapStateDescriptor;
    +import org.apache.flink.api.common.state.ValueStateDescriptor;
    +import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
    +import org.apache.flink.api.common.typeinfo.TypeInformation;
    +import org.apache.flink.api.java.functions.KeySelector;
    +import org.apache.flink.runtime.state.KeyedStateFunction;
    +import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction;
    +import org.apache.flink.streaming.api.watermark.Watermark;
    +import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
    +import org.apache.flink.streaming.runtime.tasks.OperatorStateHandles;
    +import org.apache.flink.streaming.util.AbstractStreamOperatorTestHarness;
    +import org.apache.flink.streaming.util.KeyedTwoInputStreamOperatorTestHarness;
    +import org.apache.flink.streaming.util.TestHarnessUtil;
    +import org.apache.flink.streaming.util.TwoInputStreamOperatorTestHarness;
    +import org.apache.flink.util.Collector;
    +import org.apache.flink.util.Preconditions;
    +
    +import org.junit.Assert;
    +import org.junit.Test;
    +
    +import java.util.ArrayList;
    +import java.util.Comparator;
    +import java.util.HashMap;
    +import java.util.HashSet;
    +import java.util.Iterator;
    +import java.util.List;
    +import java.util.Map;
    +import java.util.Queue;
    +import java.util.Set;
    +import java.util.concurrent.ConcurrentLinkedQueue;
    +import java.util.function.Function;
    +
    +/**
    + * Tests for the {@link CoBroadcastWithKeyedOperator}.
    + */
    +public class CoBroadcastWithKeyedOperatorTest {
    +
    +	/**
    +	 * Tests the iteration over the keyed state on the broadcast side.
    +	 *
    +	 * <p>Elements sent to the non-broadcast input (input 2) are appended by
    +	 * {@link StatefulFunctionWithKeyedStateAccessedOnBroadcast} to a per-key list state. A single
    +	 * element on the broadcast input (input 1) then triggers the verification of that state for
    +	 * every key against {@code expectedState}.
    +	 */
    +	@Test
    +	public void testAccessToKeyedStateIt() throws Exception {
    +		// expected list-state contents per key; the sizes must match the number of elements sent below
    +		final List<String> test1content = new ArrayList<>();
    +		test1content.add("test1");
    +		test1content.add("test1");
    +
    +		final List<String> test2content = new ArrayList<>();
    +		test2content.add("test2");
    +		test2content.add("test2");
    +		test2content.add("test2");
    +		test2content.add("test2");
    +
    +		final List<String> test3content = new ArrayList<>();
    +		test3content.add("test3");
    +		test3content.add("test3");
    +		test3content.add("test3");
    +
    +		final Map<String, List<String>> expectedState = new HashMap<>();
    +		expectedState.put("test1", test1content);
    +		expectedState.put("test2", test2content);
    +		expectedState.put("test3", test3content);
    +
    +		try (
    +				AutoClosableTestHarness<String, Integer, String, String, Integer, String> autoTestHarness
= new AutoClosableTestHarness<>(
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						new IdentityKeySelector<>(),
    +						new StatefulFunctionWithKeyedStateAccessedOnBroadcast(expectedState),
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						BasicTypeInfo.INT_TYPE_INFO)
    +		) {
    +
    +			TwoInputStreamOperatorTestHarness<Integer, String, String> testHarness = autoTestHarness.getTestHarness();
    +
    +			// send elements to the keyed state
    +			testHarness.processElement2(new StreamRecord<>("test1", 12L));
    +			testHarness.processElement2(new StreamRecord<>("test1", 12L));
    +
    +			testHarness.processElement2(new StreamRecord<>("test2", 13L));
    +			testHarness.processElement2(new StreamRecord<>("test2", 13L));
    +			testHarness.processElement2(new StreamRecord<>("test2", 13L));
    +
    +			testHarness.processElement2(new StreamRecord<>("test3", 14L));
    +			testHarness.processElement2(new StreamRecord<>("test3", 14L));
    +			testHarness.processElement2(new StreamRecord<>("test3", 14L));
    +
    +			// fourth "test2" element, sent after the other keys; test2content expects four entries
    +			testHarness.processElement2(new StreamRecord<>("test2", 13L));
    +
    +			// this is the element on the broadcast side that will trigger the verification
    +			// check the StatefulFunctionWithKeyedStateAccessedOnBroadcast#processElementOnBroadcastSide()
    +			testHarness.processElement1(new StreamRecord<>(1, 13L));
    +		}
    +	}
    +
    +	/**
    +	 * Simple {@link KeyedBroadcastProcessFunction} that appends every element arriving on the
    +	 * non-broadcast side to a per-key {@link ListState}. On the broadcast side it iterates over
    +	 * the keyed state of all keys and verifies that each key's list matches the expected contents.
    +	 * It also verifies that exactly the expected keys were visited.
    +	 */
    +	private static class StatefulFunctionWithKeyedStateAccessedOnBroadcast
    +			extends KeyedBroadcastProcessFunction<String, Integer, String, String, Integer, String> {
    +
    +		private static final long serialVersionUID = 7496674620398203933L;
    +
    +		/** Descriptor of the per-key list state filled by {@link #processElement}. */
    +		private final ListStateDescriptor<String> listStateDesc =
    +				new ListStateDescriptor<>("listStateTest", BasicTypeInfo.STRING_TYPE_INFO);
    +
    +		/** Expected list-state contents, indexed by key. */
    +		private final Map<String, List<String>> expectedKeyedStates;
    +
    +		StatefulFunctionWithKeyedStateAccessedOnBroadcast(Map<String, List<String>> expectedKeyedState) {
    +			this.expectedKeyedStates = Preconditions.checkNotNull(expectedKeyedState);
    +		}
    +
    +		@Override
    +		public void processElementOnBroadcastSide(Integer value, KeyedReadWriteContext ctx, Collector<String> out) throws Exception {
    +			// The broadcast element itself is ignored: it only triggers the verification of the
    +			// keyed list state of every key.
    +			final Set<String> visitedKeys = new HashSet<>();
    +
    +			ctx.applyToKeyedState(
    +					listStateDesc,
    +					new KeyedStateFunction<String, ListState<String>>() {
    +						@Override
    +						public void process(String key, ListState<String> state) throws Exception {
    +							visitedKeys.add(key);
    +
    +							final List<String> list = new ArrayList<>();
    +							for (String element : state.get()) {
    +								list.add(element);
    +							}
    +							Assert.assertEquals(expectedKeyedStates.get(key), list);
    +						}
    +					});
    +
    +			// Without this check the test would pass vacuously if no key (or only some keys)
    +			// were handed to the KeyedStateFunction above.
    +			Assert.assertEquals(expectedKeyedStates.keySet(), visitedKeys);
    +		}
    +
    +		@Override
    +		public void processElement(String value, KeyedReadOnlyContext ctx, Collector<String> out) throws Exception {
    +			// append the element to the list state of its key
    +			getRuntimeContext().getListState(listStateDesc).add(value);
    +		}
    +	}
    +
    +	/**
    +	 * Tests that both sides see the current watermark and the element timestamp.
    +	 *
    +	 * <p>It also tests that an event-time timer registered on the non-broadcast side fires once
    +	 * the watermark passes it.
    +	 */
    +	@Test
    +	public void testFunctionWithTimer() throws Exception {
    +
    +		try (
    +				AutoClosableTestHarness<String, Integer, String, String, Integer, String> autoTestHarness
= new AutoClosableTestHarness<>(
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						new IdentityKeySelector<>(),
    +						new FunctionWithTimerOnKeyed(41L),
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						BasicTypeInfo.INT_TYPE_INFO)
    +		) {
    +			TwoInputStreamOperatorTestHarness<Integer, String, String> testHarness = autoTestHarness.getTestHarness();
    +
    +			// broadcast element: emitted with the current watermark (10) and its own timestamp
    +			testHarness.processWatermark1(new Watermark(10L));
    +			testHarness.processWatermark2(new Watermark(10L));
    +			testHarness.processElement1(new StreamRecord<>(5, 12L));
    +
    +			// non-broadcast elements: each one registers an event-time timer at 41 and is emitted
    +			testHarness.processWatermark1(new Watermark(40L));
    +			testHarness.processWatermark2(new Watermark(40L));
    +			testHarness.processElement2(new StreamRecord<>("6", 13L));
    +			testHarness.processElement2(new StreamRecord<>("6", 15L));
    +
    +			// advancing the watermark past 41 on both inputs fires the timer
    +			testHarness.processWatermark1(new Watermark(50L));
    +			testHarness.processWatermark2(new Watermark(50L));
    +
    +			Queue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
    +
    +			// Expect exactly one TIMER:41, emitted before Watermark(50). Both "6" elements
    +			// register the same timer for the same key.
    +			expectedOutput.add(new Watermark(10L));
    +			expectedOutput.add(new StreamRecord<>("BR:5 WM:10 TS:12", 12L));
    +			expectedOutput.add(new Watermark(40L));
    +			expectedOutput.add(new StreamRecord<>("NON-BR:6 WM:40 TS:13", 13L));
    +			expectedOutput.add(new StreamRecord<>("NON-BR:6 WM:40 TS:15", 15L));
    +			expectedOutput.add(new StreamRecord<>("TIMER:41", 41L));
    +			expectedOutput.add(new Watermark(50L));
    +
    +			TestHarnessUtil.assertOutputEquals("Output was not correct.", expectedOutput, testHarness.getOutput());
    +		}
    +	}
    +
    +	/**
    +	 * {@link KeyedBroadcastProcessFunction} that registers a timer and emits
    +	 * for every element the watermark and the timestamp of the element.
    +	 */
    +	private static class FunctionWithTimerOnKeyed extends KeyedBroadcastProcessFunction<String,
Integer, String, String, Integer, String> {
    +
    +		private static final long serialVersionUID = 7496674620398203933L;
    +
    +		private final long timerTS;
    +
    +		FunctionWithTimerOnKeyed(long timerTS) {
    +			this.timerTS = timerTS;
    +		}
    +
    +		@Override
    +		public void processElementOnBroadcastSide(Integer value, KeyedReadWriteContext ctx,
Collector<String> out) throws Exception {
    +			out.collect("BR:" + value + " WM:" + ctx.currentWatermark() + " TS:" + ctx.timestamp());
    +		}
    +
    +		@Override
    +		public void processElement(String value, KeyedReadOnlyContext ctx, Collector<String>
out) throws Exception {
    +			ctx.timerService().registerEventTimeTimer(timerTS);
    +			out.collect("NON-BR:" + value + " WM:" + ctx.currentWatermark() + " TS:" + ctx.timestamp());
    +		}
    +
    +		@Override
    +		public void onTimer(long timestamp, OnTimerContext ctx, Collector<String> out)
throws Exception {
    +			out.collect("TIMER:" + timestamp);
    +		}
    +	}
    +
    +	/**
    +	 * Tests reading the broadcast state from the non-broadcast side and from a timer callback.
    +	 *
    +	 * <p>Broadcast elements are stored under {@code "<value>.key"}. The "trigger" element checks
    +	 * the five entries stored so far and registers a timer at 41. Another broadcast element (51)
    +	 * is added before the watermark reaches 50. The timer output must therefore contain all six
    +	 * entries.
    +	 */
    +	@Test
    +	public void testFunctionWithBroadcastState() throws Exception {
    +
    +		final Map<String, Integer> expectedBroadcastState = new HashMap<>();
    +		expectedBroadcastState.put("5.key", 5);
    +		expectedBroadcastState.put("34.key", 34);
    +		expectedBroadcastState.put("53.key", 53);
    +		expectedBroadcastState.put("12.key", 12);
    +		expectedBroadcastState.put("98.key", 98);
    +
    +		// NOTE: the function keeps a reference to expectedBroadcastState (not a copy). It is only
    +		// extended with "51.key" after all elements have been processed.
    +		try (
    +				AutoClosableTestHarness<String, Integer, String, String, Integer, String> autoTestHarness
= new AutoClosableTestHarness<>(
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						new IdentityKeySelector<>(),
    +						new FunctionWithBroadcastState("key", expectedBroadcastState, 41L),
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						BasicTypeInfo.INT_TYPE_INFO)
    +		) {
    +			TwoInputStreamOperatorTestHarness<Integer, String, String> testHarness = autoTestHarness.getTestHarness();
    +
    +			testHarness.processWatermark1(new Watermark(10L));
    +			testHarness.processWatermark2(new Watermark(10L));
    +
    +			// fill the broadcast state with the five expected entries
    +			testHarness.processElement1(new StreamRecord<>(5, 10L));
    +			testHarness.processElement1(new StreamRecord<>(34, 12L));
    +			testHarness.processElement1(new StreamRecord<>(53, 15L));
    +			testHarness.processElement1(new StreamRecord<>(12, 16L));
    +			testHarness.processElement1(new StreamRecord<>(98, 19L));
    +
    +			// verifies the broadcast state (five entries) and registers the timer at 41
    +			testHarness.processElement2(new StreamRecord<>("trigger", 13L));
    +
    +			// added after the check, but before the timer fires
    +			testHarness.processElement1(new StreamRecord<>(51, 21L));
    +
    +			testHarness.processWatermark1(new Watermark(50L));
    +			testHarness.processWatermark2(new Watermark(50L));
    +
    +			// expected output: Watermark(10), the timer's record, Watermark(50)
    +			Queue<Object> output = testHarness.getOutput();
    +			Assert.assertEquals(3L, output.size());
    +
    +			Object firstRawWm = output.poll();
    +			Assert.assertTrue(firstRawWm instanceof Watermark);
    +			Watermark firstWm = (Watermark) firstRawWm;
    +			Assert.assertEquals(10L, firstWm.getTimestamp());
    +
    +			Object rawOutputElem = output.poll();
    +			Assert.assertTrue(rawOutputElem instanceof StreamRecord);
    +			StreamRecord<?> outputRec = (StreamRecord<?>) rawOutputElem;
    +			Assert.assertTrue(outputRec.getValue() instanceof String);
    +			String outputElem = (String) outputRec.getValue();
    +
    +			// the timer saw the broadcast state including the late "51.key" entry
    +			expectedBroadcastState.put("51.key", 51);
    +			List<Map.Entry<String, Integer>> expectedEntries = new ArrayList<>();
    +			expectedEntries.addAll(expectedBroadcastState.entrySet());
    +			String expected = "TS:41 " + mapToString(expectedEntries);
    +			Assert.assertEquals(expected, outputElem);
    +
    +			Object secondRawWm = output.poll();
    +			Assert.assertTrue(secondRawWm instanceof Watermark);
    +			Watermark secondWm = (Watermark) secondRawWm;
    +			Assert.assertEquals(50L, secondWm.getTimestamp());
    +		}
    +	}
    +
    +	private static class FunctionWithBroadcastState extends KeyedBroadcastProcessFunction<String,
Integer, String, String, Integer, String> {
    +
    +		private static final long serialVersionUID = 7496674620398203933L;
    +
    +		private final String keyPostfix;
    +		private final Map<String, Integer> expectedBroadcastState;
    +		private final long timerTs;
    +
    +		FunctionWithBroadcastState(
    +				final String keyPostfix,
    +				final Map<String, Integer> expectedBroadcastState,
    +				final long timerTs
    +		) {
    +			this.keyPostfix = Preconditions.checkNotNull(keyPostfix);
    +			this.expectedBroadcastState = Preconditions.checkNotNull(expectedBroadcastState);
    +			this.timerTs = timerTs;
    +		}
    +
    +		@Override
    +		public void processElementOnBroadcastSide(Integer value, KeyedReadWriteContext ctx,
Collector<String> out) throws Exception {
    +			// put an element in the broadcast state
    +			final String key = value + "." + keyPostfix;
    +			ctx.putToBroadcast(key, value);
    +		}
    +
    +		@Override
    +		public void processElement(String value, KeyedReadOnlyContext ctx, Collector<String>
out) throws Exception {
    +			Iterator<Map.Entry<String, Integer>> broadcastStateIt = ctx.readOnlyBroadcastIterable().iterator();
    +
    +			for (int i = 0; i < expectedBroadcastState.size(); i++) {
    +				Assert.assertTrue(broadcastStateIt.hasNext());
    +
    +				Map.Entry<String, Integer> entry = broadcastStateIt.next();
    +				Assert.assertTrue(expectedBroadcastState.containsKey(entry.getKey()));
    +				Assert.assertEquals(expectedBroadcastState.get(entry.getKey()), entry.getValue());
    +			}
    +
    +			Assert.assertFalse(broadcastStateIt.hasNext());
    +
    +			ctx.timerService().registerEventTimeTimer(timerTs);
    +		}
    +
    +		@Override
    +		public void onTimer(long timestamp, OnTimerContext ctx, Collector<String> out)
throws Exception {
    +			final Iterator<Map.Entry<String, Integer>> broadcastStateIt = ctx.readOnlyBroadcastIterable().iterator();
    +			final List<Map.Entry<String, Integer>> map = new ArrayList<>();
    +			while (broadcastStateIt.hasNext()) {
    +				map.add(broadcastStateIt.next());
    +			}
    +			final String mapToStr = mapToString(map);
    +			out.collect("TS:" + timestamp + " " + mapToStr);
    +		}
    +	}
    +
    +	@Test
    +	public void testScaleUp() throws Exception {
    +		final Set<String> keysToRegister = new HashSet<>();
    +		keysToRegister.add("test1");
    +		keysToRegister.add("test2");
    +		keysToRegister.add("test3");
    +
    +		final OperatorStateHandles mergedSnapshot;
    +
    +		try (
    +				AutoClosableTestHarness<String, Integer, String, String, Integer, String> autoTestHarness1
= new AutoClosableTestHarness<>(
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						new IdentityKeySelector<>(),
    +						new TestFunctionWithOutput(keysToRegister),
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						BasicTypeInfo.INT_TYPE_INFO,
    +						10,
    +						2,
    +						0);
    +
    +				AutoClosableTestHarness<String, Integer, String, String, Integer, String> autoTestHarness2
= new AutoClosableTestHarness<>(
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						new IdentityKeySelector<>(),
    +						new TestFunctionWithOutput(keysToRegister),
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						BasicTypeInfo.INT_TYPE_INFO,
    +						10,
    +						2,
    +						1)
    +
    +		) {
    +
    +			TwoInputStreamOperatorTestHarness<Integer, String, String> testHarness1 = autoTestHarness1.getTestHarness();
    +			TwoInputStreamOperatorTestHarness<Integer, String, String> testHarness2 = autoTestHarness2.getTestHarness();
    +
    +			// make sure all operators have the same state
    +			testHarness1.processElement1(new StreamRecord<>(3));
    +			testHarness2.processElement1(new StreamRecord<>(3));
    +
    +			mergedSnapshot = AbstractStreamOperatorTestHarness.repackageState(
    +					testHarness1.snapshot(0L, 0L),
    +					testHarness2.snapshot(0L, 0L)
    +			);
    +		}
    +
    +		final Set<String> expected = new HashSet<>(3);
    +		expected.add("test1=3");
    +		expected.add("test2=3");
    +		expected.add("test3=3");
    +
    +		try (
    +				AutoClosableTestHarness<String, Integer, String, String, Integer, String> autoTestHarness1
= new AutoClosableTestHarness<>(
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						new IdentityKeySelector<>(),
    +						new TestFunctionWithOutput(keysToRegister),
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						BasicTypeInfo.INT_TYPE_INFO,
    +						10,
    +						3,
    +						0,
    +						mergedSnapshot);
    +
    +				AutoClosableTestHarness<String, Integer, String, String, Integer, String> autoTestHarness2
= new AutoClosableTestHarness<>(
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						new IdentityKeySelector<>(),
    +						new TestFunctionWithOutput(keysToRegister),
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						BasicTypeInfo.INT_TYPE_INFO,
    +						10,
    +						3,
    +						1,
    +						mergedSnapshot);
    +
    +				AutoClosableTestHarness<String, Integer, String, String, Integer, String> autoTestHarness3
= new AutoClosableTestHarness<>(
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						new IdentityKeySelector<>(),
    +						new TestFunctionWithOutput(keysToRegister),
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						BasicTypeInfo.INT_TYPE_INFO,
    +						10,
    +						3,
    +						2,
    +						mergedSnapshot)
    +			) {
    +
    +			TwoInputStreamOperatorTestHarness<Integer, String, String> testHarness1 = autoTestHarness1.getTestHarness();
    +			TwoInputStreamOperatorTestHarness<Integer, String, String> testHarness2 = autoTestHarness2.getTestHarness();
    +			TwoInputStreamOperatorTestHarness<Integer, String, String> testHarness3 = autoTestHarness3.getTestHarness();
    +
    +			testHarness1.processElement2(new StreamRecord<>("trigger"));
    +			testHarness2.processElement2(new StreamRecord<>("trigger"));
    +			testHarness3.processElement2(new StreamRecord<>("trigger"));
    +
    +			Queue<?> output1 = testHarness1.getOutput();
    +			Queue<?> output2 = testHarness2.getOutput();
    +			Queue<?> output3 = testHarness3.getOutput();
    +
    +			Assert.assertEquals(expected.size(), output1.size());
    +			for (Object o: output1) {
    +				StreamRecord<String> rec = (StreamRecord<String>) o;
    +				Assert.assertTrue(expected.contains(rec.getValue()));
    +			}
    +
    +			Assert.assertEquals(expected.size(), output2.size());
    +			for (Object o: output2) {
    +				StreamRecord<String> rec = (StreamRecord<String>) o;
    +				Assert.assertTrue(expected.contains(rec.getValue()));
    +			}
    +
    +			Assert.assertEquals(expected.size(), output3.size());
    +			for (Object o: output3) {
    +				StreamRecord<String> rec = (StreamRecord<String>) o;
    +				Assert.assertTrue(expected.contains(rec.getValue()));
    +			}
    +		}
    +	}
    +
    +	/**
    +	 * Tests that the broadcast state survives a scale-down.
    +	 *
    +	 * <p>Three subtasks with identical broadcast state are snapshotted and their state is merged.
    +	 * It is then restored into two subtasks, each of which must expose exactly the full broadcast
    +	 * state.
    +	 *
    +	 * <p>NOTE(review): the three trailing int constructor arguments are presumably (max parallelism,
    +	 * parallelism, subtask index); confirm against {@code AutoClosableTestHarness}.
    +	 */
    +	@Test
    +	public void testScaleDown() throws Exception {
    +		final Set<String> keysToRegister = new HashSet<>();
    +		keysToRegister.add("test1");
    +		keysToRegister.add("test2");
    +		keysToRegister.add("test3");
    +
    +		final OperatorStateHandles mergedSnapshot;
    +
    +		try (
    +				AutoClosableTestHarness<String, Integer, String, String, Integer, String> autoTestHarness1 = new AutoClosableTestHarness<>(
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						new IdentityKeySelector<>(),
    +						new TestFunctionWithOutput(keysToRegister),
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						BasicTypeInfo.INT_TYPE_INFO,
    +						10,
    +						3,
    +						0);
    +
    +				AutoClosableTestHarness<String, Integer, String, String, Integer, String> autoTestHarness2 = new AutoClosableTestHarness<>(
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						new IdentityKeySelector<>(),
    +						new TestFunctionWithOutput(keysToRegister),
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						BasicTypeInfo.INT_TYPE_INFO,
    +						10,
    +						3,
    +						1);
    +
    +				AutoClosableTestHarness<String, Integer, String, String, Integer, String> autoTestHarness3 = new AutoClosableTestHarness<>(
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						new IdentityKeySelector<>(),
    +						new TestFunctionWithOutput(keysToRegister),
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						BasicTypeInfo.INT_TYPE_INFO,
    +						10,
    +						3,
    +						2)
    +		) {
    +
    +			TwoInputStreamOperatorTestHarness<Integer, String, String> testHarness1 = autoTestHarness1.getTestHarness();
    +			TwoInputStreamOperatorTestHarness<Integer, String, String> testHarness2 = autoTestHarness2.getTestHarness();
    +			TwoInputStreamOperatorTestHarness<Integer, String, String> testHarness3 = autoTestHarness3.getTestHarness();
    +
    +			// make sure all operators have the same state
    +			testHarness1.processElement1(new StreamRecord<>(3));
    +			testHarness2.processElement1(new StreamRecord<>(3));
    +			testHarness3.processElement1(new StreamRecord<>(3));
    +
    +			mergedSnapshot = AbstractStreamOperatorTestHarness.repackageState(
    +					testHarness1.snapshot(0L, 0L),
    +					testHarness2.snapshot(0L, 0L),
    +					testHarness3.snapshot(0L, 0L)
    +			);
    +		}
    +
    +		final Set<String> expected = new HashSet<>(3);
    +		expected.add("test1=3");
    +		expected.add("test2=3");
    +		expected.add("test3=3");
    +
    +		try (
    +				AutoClosableTestHarness<String, Integer, String, String, Integer, String> autoTestHarness1 = new AutoClosableTestHarness<>(
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						new IdentityKeySelector<>(),
    +						new TestFunctionWithOutput(keysToRegister),
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						BasicTypeInfo.INT_TYPE_INFO,
    +						10,
    +						2,
    +						0,
    +						mergedSnapshot);
    +
    +				AutoClosableTestHarness<String, Integer, String, String, Integer, String> autoTestHarness2 = new AutoClosableTestHarness<>(
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						new IdentityKeySelector<>(),
    +						new TestFunctionWithOutput(keysToRegister),
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						BasicTypeInfo.INT_TYPE_INFO,
    +						10,
    +						2,
    +						1,
    +						mergedSnapshot)
    +		) {
    +
    +			TwoInputStreamOperatorTestHarness<Integer, String, String> testHarness1 = autoTestHarness1.getTestHarness();
    +			TwoInputStreamOperatorTestHarness<Integer, String, String> testHarness2 = autoTestHarness2.getTestHarness();
    +
    +			// each non-broadcast element makes a subtask emit its whole broadcast state
    +			testHarness1.processElement2(new StreamRecord<>("trigger"));
    +			testHarness2.processElement2(new StreamRecord<>("trigger"));
    +
    +			final List<Queue<?>> outputs = new ArrayList<>();
    +			outputs.add(testHarness1.getOutput());
    +			outputs.add(testHarness2.getOutput());
    +
    +			// Checking only "same size and every value is in `expected`" would accept a duplicated
    +			// entry in place of a missing one. Compare the emitted values as a set as well.
    +			for (Queue<?> output : outputs) {
    +				Assert.assertEquals(expected.size(), output.size());
    +
    +				final Set<Object> actual = new HashSet<>();
    +				for (Object element : output) {
    +					actual.add(((StreamRecord<?>) element).getValue());
    +				}
    +				Assert.assertEquals(expected, actual);
    +			}
    +		}
    +	}
    +
    +	private static class TestFunctionWithOutput extends KeyedBroadcastProcessFunction<String,
Integer, String, String, Integer, String> {
    +
    +		private static final long serialVersionUID = 7496674620398203933L;
    +
    +		private final Set<String> keysToRegister;
    +
    +		TestFunctionWithOutput(Set<String> keysToRegister) {
    +			this.keysToRegister = Preconditions.checkNotNull(keysToRegister);
    +		}
    +
    +		@Override
    +		public void processElementOnBroadcastSide(Integer value, KeyedReadWriteContext ctx,
Collector<String> out) throws Exception {
    +			// put an element in the broadcast state
    +			for (String k : keysToRegister) {
    +				ctx.putToBroadcast(k, value);
    +			}
    +		}
    +
    +		@Override
    +		public void processElement(String value, KeyedReadOnlyContext ctx, Collector<String>
out) throws Exception {
    +			for (Map.Entry<String, Integer> entry : ctx.readOnlyBroadcastIterable()) {
    +				out.collect(entry.toString());
    +			}
    +		}
    +	}
    +
    +	/**
    +	 * Tests that keyed state cannot be accessed through the runtime context on the broadcast side.
    +	 *
    +	 * <p>The test passes only if a {@link NullPointerException} with the "No key set" message is
    +	 * raised.
    +	 */
    +	@Test
    +	public void testNoKeyedStateOnBroadcastSide() throws Exception {
    +
    +		boolean exceptionThrown = false;
    +
    +		try (
    +				AutoClosableTestHarness<String, Integer, String, String, Integer, String> autoTestHarness
= new AutoClosableTestHarness<>(
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						new IdentityKeySelector<>(),
    +						new KeyedBroadcastProcessFunction<String, Integer, String, String, Integer,
String>() {
    +
    +							private static final long serialVersionUID = -1725365436500098384L;
    +
    +							private final ValueStateDescriptor<String> valueState = new ValueStateDescriptor<>("any",
BasicTypeInfo.STRING_TYPE_INFO);
    +
    +							@Override
    +							public void processElementOnBroadcastSide(Integer value, KeyedReadWriteContext
ctx, Collector<String> out) throws Exception {
    +								getRuntimeContext().getState(valueState).value(); // this should fail
    +							}
    +
    +							@Override
    +							public void processElement(String value, KeyedReadOnlyContext ctx, Collector<String>
out) throws Exception {
    +								// do nothing
    +							}
    +						},
    +						BasicTypeInfo.STRING_TYPE_INFO,
    +						BasicTypeInfo.INT_TYPE_INFO)
    +		) {
    +			TwoInputStreamOperatorTestHarness<Integer, String, String> testHarness = autoTestHarness.getTestHarness();
    +
    +			testHarness.processWatermark1(new Watermark(10L));
    +			testHarness.processWatermark2(new Watermark(10L));
    +			// broadcast element: triggers the keyed-state access above
    +			testHarness.processElement1(new StreamRecord<>(5, 12L));
    +		} catch (NullPointerException e) {
    +			// This catch also covers exceptions thrown while the try-with-resources closes the
    +			// harness, not only those thrown by processElement1.
    +			Assert.assertEquals("No key set. This method should not be called outside of a keyed
context.", e.getMessage());
    +			exceptionThrown = true;
    +		}
    +
    +		if (!exceptionThrown) {
    +			Assert.fail("No exception thrown");
    +		}
    +	}
    +
    +	/**
    +	 * {@link KeySelector} that uses each element as its own key.
    +	 *
    +	 * @param <T> type of both the elements and their keys
    +	 */
    +	private static class IdentityKeySelector<T> implements KeySelector<T, T> {
    +
    +		private static final long serialVersionUID = 1L;
    +
    +		@Override
    +		public T getKey(T element) throws Exception {
    +			// the element itself is the key
    +			return element;
    +		}
    +	}
    +
    +	/**
    +	 * A wrapper of the test harness that makes sure to close it after the test finishes.
    +	 */
    +	private static class AutoClosableTestHarness<KEY, IN1, IN2, K, V, OUT> implements
AutoCloseable {
    --- End diff --
    
    The harnesses themselves are already `AutoCloseable` (a somewhat newer addition), so this wrapper should not be needed.


---

Mime
View raw message