flink-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From se...@apache.org
Subject [16/26] flink git commit: [FLINK-1943] [gelly] Added GSA compiler and translation tests
Date Tue, 21 Jul 2015 19:10:55 GMT
[FLINK-1943] [gelly] Added GSA compiler and translation tests

This closes #916


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/11b48633
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/11b48633
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/11b48633

Branch: refs/heads/master
Commit: 11b48633dd421fcdf5e3cdced908ce85c1aa399d
Parents: bbbfd22
Author: vasia <vasia@apache.org>
Authored: Wed Jul 8 12:38:27 2015 +0200
Committer: Stephan Ewen <sewen@apache.org>
Committed: Tue Jul 21 17:58:15 2015 +0200

----------------------------------------------------------------------
 flink-staging/flink-gelly/pom.xml               |   7 +
 .../apache/flink/graph/gsa/GSACompilerTest.java | 147 +++++++++++++++++
 .../flink/graph/gsa/GSATranslationTest.java     | 165 +++++++++++++++++++
 3 files changed, 319 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/11b48633/flink-staging/flink-gelly/pom.xml
----------------------------------------------------------------------
diff --git a/flink-staging/flink-gelly/pom.xml b/flink-staging/flink-gelly/pom.xml
index 6536b70..9dce170 100644
--- a/flink-staging/flink-gelly/pom.xml
+++ b/flink-staging/flink-gelly/pom.xml
@@ -52,6 +52,13 @@ under the License.
 			<scope>test</scope>
 		</dependency>
 		<dependency>
+			<groupId>org.apache.flink</groupId>
+			<artifactId>flink-optimizer</artifactId>
+			<version>${project.version}</version>
+			<type>test-jar</type>
+			<scope>test</scope>
+		</dependency>
+		<dependency>
 			<groupId>com.google.guava</groupId>
 			<artifactId>guava</artifactId>
 			<version>${guava.version}</version>

http://git-wip-us.apache.org/repos/asf/flink/blob/11b48633/flink-staging/flink-gelly/src/test/java/org/apache/flink/graph/gsa/GSACompilerTest.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-gelly/src/test/java/org/apache/flink/graph/gsa/GSACompilerTest.java b/flink-staging/flink-gelly/src/test/java/org/apache/flink/graph/gsa/GSACompilerTest.java
new file mode 100644
index 0000000..7a66639
--- /dev/null
+++ b/flink-staging/flink-gelly/src/test/java/org/apache/flink/graph/gsa/GSACompilerTest.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.graph.gsa;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import org.apache.flink.api.common.Plan;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.operators.util.FieldList;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.io.DiscardingOutputFormat;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.graph.Edge;
+import org.apache.flink.graph.Graph;
+import org.apache.flink.graph.Vertex;
+import org.apache.flink.graph.utils.Tuple3ToEdgeMap;
+import org.apache.flink.optimizer.dataproperties.PartitioningProperty;
+import org.apache.flink.optimizer.plan.DualInputPlanNode;
+import org.apache.flink.optimizer.plan.OptimizedPlan;
+import org.apache.flink.optimizer.plan.PlanNode;
+import org.apache.flink.optimizer.plan.SingleInputPlanNode;
+import org.apache.flink.optimizer.plan.SinkPlanNode;
+import org.apache.flink.optimizer.plan.WorksetIterationPlanNode;
+import org.apache.flink.optimizer.util.CompilerTestBase;
+import org.apache.flink.runtime.operators.shipping.ShipStrategyType;
+import org.apache.flink.types.NullValue;
+import org.junit.Test;
+
+/** Verifies the plan that the optimizer produces for a gather-sum-apply (GSA) iteration. */
+public class GSACompilerTest extends CompilerTestBase {
+
+	private static final long serialVersionUID = 1L;
+
+	/** Compiles a GSA connected-components program and checks ship strategies, partitioning and parallelism. */
+	@Test
+	public void testGSACompiler() {
+		try {
+			ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+			env.setParallelism(DEFAULT_PARALLELISM);
+			// compose test program
+			{
+				@SuppressWarnings("unchecked")
+				DataSet<Edge<Long, NullValue>> edges = env.fromElements(new Tuple3<Long, Long, NullValue>(
+						1L, 2L, NullValue.getInstance())).map(new Tuple3ToEdgeMap<Long, NullValue>());
+
+				Graph<Long, Long, NullValue> graph = Graph.fromDataSet(edges, new InitVertices(), env);
+
+				DataSet<Vertex<Long, Long>> result = graph.runGatherSumApplyIteration(
+						new GatherNeighborIds(), new SelectMinId(),
+						new UpdateComponentId(), 100).getVertices();
+				result.output(new DiscardingOutputFormat<Vertex<Long, Long>>());
+			}
+
+			Plan p = env.createProgramPlan("GSA Connected Components");
+			OptimizedPlan op = compileNoStats(p);
+
+			// check the sink
+			SinkPlanNode sink = op.getDataSinks().iterator().next();
+			assertEquals(ShipStrategyType.FORWARD, sink.getInput().getShipStrategy());
+			assertEquals(DEFAULT_PARALLELISM, sink.getParallelism());
+			assertEquals(PartitioningProperty.HASH_PARTITIONED, sink.getGlobalProperties().getPartitioning());
+
+			// check the iteration
+			WorksetIterationPlanNode iteration = (WorksetIterationPlanNode) sink.getInput().getSource();
+			assertEquals(DEFAULT_PARALLELISM, iteration.getParallelism());
+
+			// check the solution set join and the delta
+			PlanNode ssDelta = iteration.getSolutionSetDeltaPlanNode();
+			assertTrue(ssDelta instanceof DualInputPlanNode); // this is only true if the update function preserves the partitioning
+
+			DualInputPlanNode ssJoin = (DualInputPlanNode) ssDelta;
+			assertEquals(DEFAULT_PARALLELISM, ssJoin.getParallelism());
+			assertEquals(ShipStrategyType.PARTITION_HASH, ssJoin.getInput1().getShipStrategy());
+			assertEquals(new FieldList(0), ssJoin.getInput1().getShipStrategyKeys());
+
+			// walk back from the solution set join through the sum reducer and gather mapper to the workset/edges join
+			SingleInputPlanNode sumReducer = (SingleInputPlanNode) ssJoin.getInput1().getSource();
+			SingleInputPlanNode gatherMapper = (SingleInputPlanNode) sumReducer.getInput().getSource();
+			DualInputPlanNode edgeJoin = (DualInputPlanNode) gatherMapper.getInput().getSource();
+			assertEquals(DEFAULT_PARALLELISM, edgeJoin.getParallelism());
+			// input1 is the workset
+			assertEquals(ShipStrategyType.FORWARD, edgeJoin.getInput1().getShipStrategy());
+			// input2 is the edges
+			assertEquals(ShipStrategyType.PARTITION_HASH, edgeJoin.getInput2().getShipStrategy());
+			assertTrue(edgeJoin.getInput2().getTempMode().isCached());
+			assertEquals(new FieldList(0), edgeJoin.getInput2().getShipStrategyKeys());
+		}
+		catch (Exception e) {
+			System.err.println(e.getMessage());
+			e.printStackTrace();
+			fail(e.getMessage());
+		}
+	}
+
+	@SuppressWarnings("serial")
+	private static final class InitVertices implements MapFunction<Long, Long> {
+		@Override
+		public Long map(Long vertexId) {
+			return vertexId; // the initial vertex value is the vertex id itself
+		}
+	}
+
+	@SuppressWarnings("serial")
+	private static final class GatherNeighborIds extends GatherFunction<Long, NullValue, Long> {
+		@Override
+		public Long gather(Neighbor<Long, NullValue> neighbor) {
+			return neighbor.getNeighborValue(); // gather the neighbor's current value
+		}
+	}
+
+	@SuppressWarnings("serial")
+	private static final class SelectMinId extends SumFunction<Long, NullValue, Long> {
+		@Override
+		public Long sum(Long newValue, Long currentValue) {
+			return Math.min(newValue, currentValue); // keep the smallest gathered value
+		}
+	}
+
+	@SuppressWarnings("serial")
+	private static final class UpdateComponentId extends ApplyFunction<Long, Long, Long> {
+		@Override
+		public void apply(Long summedValue, Long origValue) {
+			if (summedValue < origValue) {
+				setResult(summedValue); // only emit a result when it lowers the current value
+			}
+		}
+	}
+}

http://git-wip-us.apache.org/repos/asf/flink/blob/11b48633/flink-staging/flink-gelly/src/test/java/org/apache/flink/graph/gsa/GSATranslationTest.java
----------------------------------------------------------------------
diff --git a/flink-staging/flink-gelly/src/test/java/org/apache/flink/graph/gsa/GSATranslationTest.java b/flink-staging/flink-gelly/src/test/java/org/apache/flink/graph/gsa/GSATranslationTest.java
new file mode 100644
index 0000000..0a7b1c7
--- /dev/null
+++ b/flink-staging/flink-gelly/src/test/java/org/apache/flink/graph/gsa/GSATranslationTest.java
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.graph.gsa;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+import org.apache.flink.api.common.aggregators.LongSumAggregator;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.java.DataSet;
+import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.api.java.io.DiscardingOutputFormat;
+import org.apache.flink.api.java.operators.DeltaIteration;
+import org.apache.flink.api.java.operators.DeltaIterationResultSet;
+import org.apache.flink.api.java.operators.SingleInputUdfOperator;
+import org.apache.flink.api.java.operators.TwoInputUdfOperator;
+import org.apache.flink.api.java.tuple.Tuple3;
+import org.apache.flink.graph.Edge;
+import org.apache.flink.graph.Graph;
+import org.apache.flink.graph.Vertex;
+import org.apache.flink.graph.utils.Tuple3ToEdgeMap;
+import org.apache.flink.types.NullValue;
+import org.junit.Test;
+
+/** Verifies how a GSA iteration configured through {@link GSAConfiguration} is translated into a delta iteration. */
+public class GSATranslationTest {
+
+	/** Checks max iterations, keys, parallelism, name, aggregators, forwarded fields and broadcast sets. */
+	@Test
+	public void testTranslation() {
+		try {
+			final String ITERATION_NAME = "Test Name";
+
+			final String AGGREGATOR_NAME = "AggregatorName";
+
+			final String BC_SET_GATHER_NAME = "gather messages";
+
+			final String BC_SET_SUM_NAME = "sum updates";
+
+			final String BC_SET_APPLY_NAME = "apply updates";
+
+			final int NUM_ITERATIONS = 13;
+
+			final int ITERATION_PARALLELISM = 77;
+
+			ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+			DataSet<Long> bcGather = env.fromElements(1L);
+			DataSet<Long> bcSum = env.fromElements(1L);
+			DataSet<Long> bcApply = env.fromElements(1L);
+
+			DataSet<Vertex<Long, Long>> result;
+
+			// ------------ construct the test program ------------------
+			{
+
+				@SuppressWarnings("unchecked")
+				DataSet<Edge<Long, NullValue>> edges = env.fromElements(new Tuple3<Long, Long, NullValue>(
+						1L, 2L, NullValue.getInstance())).map(new Tuple3ToEdgeMap<Long, NullValue>());
+
+				Graph<Long, Long, NullValue> graph = Graph.fromDataSet(edges, new InitVertices(), env);
+
+				GSAConfiguration parameters = new GSAConfiguration();
+
+				parameters.registerAggregator(AGGREGATOR_NAME, new LongSumAggregator());
+				parameters.setName(ITERATION_NAME);
+				parameters.setParallelism(ITERATION_PARALLELISM);
+				parameters.addBroadcastSetForGatherFunction(BC_SET_GATHER_NAME, bcGather);
+				parameters.addBroadcastSetForSumFunction(BC_SET_SUM_NAME, bcSum);
+				parameters.addBroadcastSetForApplyFunction(BC_SET_APPLY_NAME, bcApply);
+
+				result = graph.runGatherSumApplyIteration(
+						new GatherNeighborIds(), new SelectMinId(),
+						new UpdateComponentId(), NUM_ITERATIONS, parameters).getVertices();
+
+				result.output(new DiscardingOutputFormat<Vertex<Long, Long>>());
+			}
+
+			// ------------- validate the java program ----------------
+
+			assertTrue(result instanceof DeltaIterationResultSet);
+
+			DeltaIterationResultSet<?, ?> resultSet = (DeltaIterationResultSet<?, ?>) result;
+			DeltaIteration<?, ?> iteration = (DeltaIteration<?, ?>) resultSet.getIterationHead();
+
+			// check the basic iteration properties
+			assertEquals(NUM_ITERATIONS, resultSet.getMaxIterations());
+			assertArrayEquals(new int[] {0}, resultSet.getKeyPositions());
+			assertEquals(ITERATION_PARALLELISM, iteration.getParallelism());
+			assertEquals(ITERATION_NAME, iteration.getName());
+
+			assertEquals(AGGREGATOR_NAME, iteration.getAggregators().getAllRegisteredAggregators().iterator().next().getName());
+
+			// validate that field 0 of both solution set join inputs is declared as forwarded to output field 0
+			TwoInputUdfOperator<?, ?, ?, ?> solutionSetJoin = (TwoInputUdfOperator<?, ?, ?, ?>) resultSet.getNextWorkset();
+			assertTrue(solutionSetJoin.getSemanticProperties().getForwardingTargetFields(0, 0).contains(0));
+			assertTrue(solutionSetJoin.getSemanticProperties().getForwardingTargetFields(1, 0).contains(0));
+
+			SingleInputUdfOperator<?, ?, ?> sumReduce = (SingleInputUdfOperator<?, ?, ?>) solutionSetJoin.getInput1();
+			SingleInputUdfOperator<?, ?, ?> gatherMap = (SingleInputUdfOperator<?, ?, ?>) sumReduce.getInput();
+
+			// validate that each broadcast set is attached to the operator of the function it was registered for
+			assertEquals(bcGather, gatherMap.getBroadcastSets().get(BC_SET_GATHER_NAME));
+			assertEquals(bcSum, sumReduce.getBroadcastSets().get(BC_SET_SUM_NAME));
+			assertEquals(bcApply, solutionSetJoin.getBroadcastSets().get(BC_SET_APPLY_NAME));
+		}
+		catch (Exception e) {
+			System.err.println(e.getMessage());
+			e.printStackTrace();
+			fail(e.getMessage());
+		}
+	}
+
+	@SuppressWarnings("serial")
+	private static final class InitVertices implements MapFunction<Long, Long> {
+		@Override
+		public Long map(Long vertexId) {
+			return vertexId; // the initial vertex value is the vertex id itself
+		}
+	}
+
+	@SuppressWarnings("serial")
+	private static final class GatherNeighborIds extends GatherFunction<Long, NullValue, Long> {
+		@Override
+		public Long gather(Neighbor<Long, NullValue> neighbor) {
+			return neighbor.getNeighborValue(); // gather the neighbor's current value
+		}
+	}
+
+	@SuppressWarnings("serial")
+	private static final class SelectMinId extends SumFunction<Long, NullValue, Long> {
+		@Override
+		public Long sum(Long newValue, Long currentValue) {
+			return Math.min(newValue, currentValue); // keep the smallest gathered value
+		}
+	}
+
+	@SuppressWarnings("serial")
+	private static final class UpdateComponentId extends ApplyFunction<Long, Long, Long> {
+		@Override
+		public void apply(Long summedValue, Long origValue) {
+			if (summedValue < origValue) {
+				setResult(summedValue); // only emit a result when it lowers the current value
+			}
+		}
+	}
+}


Mime
View raw message