flink-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From g...@apache.org
Subject [1/2] flink git commit: [FLINK-3806] [gelly] Revert use of DataSet.count()
Date Thu, 02 Jun 2016 14:27:39 GMT
Repository: flink
Updated Branches:
  refs/heads/master b201f8664 -> 65545c2ed


[FLINK-3806] [gelly] Revert use of DataSet.count()

This closes #2036


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/36ad78c0
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/36ad78c0
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/36ad78c0

Branch: refs/heads/master
Commit: 36ad78c0821fdae0a69371c67602dd2a7955e4a8
Parents: b201f86
Author: Greg Hogan <code@greghogan.com>
Authored: Wed May 25 11:06:01 2016 -0400
Committer: Greg Hogan <code@greghogan.com>
Committed: Thu Jun 2 09:11:19 2016 -0400

----------------------------------------------------------------------
 docs/apis/batch/libs/gelly.md                   |  1 -
 .../graph/library/HITSAlgorithmITCase.java      | 28 --------
 .../flink/graph/library/PageRankITCase.java     | 15 ++--
 .../graph/gsa/GatherSumApplyIteration.java      | 31 +++++++--
 .../apache/flink/graph/library/GSAPageRank.java | 52 +++-----------
 .../flink/graph/library/HITSAlgorithm.java      | 54 ++-------------
 .../apache/flink/graph/library/PageRank.java    | 55 ++++-----------
 .../graph/spargel/ScatterGatherIteration.java   | 73 ++++++++++++++------
 .../apache/flink/graph/utils/GraphUtils.java    | 58 ++++++++++++++++
 9 files changed, 171 insertions(+), 196 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/docs/apis/batch/libs/gelly.md
----------------------------------------------------------------------
diff --git a/docs/apis/batch/libs/gelly.md b/docs/apis/batch/libs/gelly.md
index 0d3e594..aadbd44 100644
--- a/docs/apis/batch/libs/gelly.md
+++ b/docs/apis/batch/libs/gelly.md
@@ -1967,7 +1967,6 @@ The constructors take the following parameters:
 
 * `beta`: the damping factor.
 * `maxIterations`: the maximum number of iterations to run.
-* `numVertices`: the number of vertices in the input. If known beforehand, is it advised to provide this argument to speed up execution.
 
 ### GSA PageRank
 

http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/HITSAlgorithmITCase.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/HITSAlgorithmITCase.java b/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/HITSAlgorithmITCase.java
index 019b851..1887725 100644
--- a/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/HITSAlgorithmITCase.java
+++ b/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/HITSAlgorithmITCase.java
@@ -56,20 +56,6 @@ public class HITSAlgorithmITCase extends MultipleProgramsTestBase{
 	}
 
 	@Test
-	public void testHITSWithTenIterationsAndNumOfVertices() throws Exception {
-		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
-
-		Graph<Long, Double, NullValue> graph = Graph.fromDataSet(
-				HITSData.getVertexDataSet(env),
-				HITSData.getEdgeDataSet(env),
-				env);
-
-		List<Vertex<Long, Tuple2<DoubleValue, DoubleValue>>> result = graph.run(new HITSAlgorithm<Long, Double, NullValue>(10, 5)).collect();
-		
-		compareWithDelta(result, 1e-7);
-	}
-
-	@Test
 	public void testHITSWithConvergeThreshold() throws Exception {
 		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
 
@@ -83,20 +69,6 @@ public class HITSAlgorithmITCase extends MultipleProgramsTestBase{
 		compareWithDelta(result, 1e-7);
 	}
 
-	@Test
-	public void testHITSWithConvergeThresholdAndNumOfVertices() throws Exception {
-		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
-
-		Graph<Long, Double, NullValue> graph = Graph.fromDataSet(
-				HITSData.getVertexDataSet(env),
-				HITSData.getEdgeDataSet(env),
-				env);
-
-		List<Vertex<Long, Tuple2<DoubleValue, DoubleValue>>> result = graph.run(new HITSAlgorithm<Long, Double, NullValue>(1e-7, 5)).collect();
-
-		compareWithDelta(result, 1e-7);
-	}
-
 	private void compareWithDelta(List<Vertex<Long, Tuple2<DoubleValue, DoubleValue>>> result, double delta) {
 
 		String resultString = "";

http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/PageRankITCase.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/PageRankITCase.java b/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/PageRankITCase.java
index 034bcd5..e3e8f08 100644
--- a/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/PageRankITCase.java
+++ b/flink-libraries/flink-gelly-examples/src/test/java/org/apache/flink/graph/library/PageRankITCase.java
@@ -18,9 +18,6 @@
 
 package org.apache.flink.graph.library;
 
-import java.util.Arrays;
-import java.util.List;
-
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.java.ExecutionEnvironment;
 import org.apache.flink.graph.Graph;
@@ -32,6 +29,9 @@ import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
 
+import java.util.Arrays;
+import java.util.List;
+
 @RunWith(Parameterized.class)
 public class PageRankITCase extends MultipleProgramsTestBase {
 
@@ -72,9 +72,9 @@ public class PageRankITCase extends MultipleProgramsTestBase {
 		Graph<Long, Double, Double> inputGraph = Graph.fromDataSet(
 				PageRankData.getDefaultEdgeDataSet(env), new InitMapper(), env);
 
-        List<Vertex<Long, Double>> result = inputGraph.run(new PageRank<Long>(0.85, 5, 3))
+        List<Vertex<Long, Double>> result = inputGraph.run(new PageRank<Long>(0.85, 3))
         		.collect();
-        
+
         compareWithDelta(result, 0.01);
 	}
 
@@ -85,14 +85,13 @@ public class PageRankITCase extends MultipleProgramsTestBase {
 		Graph<Long, Double, Double> inputGraph = Graph.fromDataSet(
 				PageRankData.getDefaultEdgeDataSet(env), new InitMapper(), env);
 
-        List<Vertex<Long, Double>> result = inputGraph.run(new GSAPageRank<Long>(0.85, 5, 3))
+        List<Vertex<Long, Double>> result = inputGraph.run(new GSAPageRank<Long>(0.85, 3))
         		.collect();
         
         compareWithDelta(result, 0.01);
 	}
 
-	private void compareWithDelta(List<Vertex<Long, Double>> result,
-																double delta) {
+	private void compareWithDelta(List<Vertex<Long, Double>> result, double delta) {
 
 		String resultString = "";
         for (Vertex<Long, Double> v : result) {

http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GatherSumApplyIteration.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GatherSumApplyIteration.java b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GatherSumApplyIteration.java
index d092086..d1b12f9 100755
--- a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GatherSumApplyIteration.java
+++ b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/gsa/GatherSumApplyIteration.java
@@ -41,9 +41,12 @@ import org.apache.flink.graph.Edge;
 import org.apache.flink.graph.EdgeDirection;
 import org.apache.flink.graph.Graph;
 import org.apache.flink.graph.Vertex;
+import org.apache.flink.graph.utils.GraphUtils;
+import org.apache.flink.types.LongValue;
 import org.apache.flink.util.Collector;
 import org.apache.flink.util.Preconditions;
 
+import java.util.Collection;
 import java.util.Map;
 
 /**
@@ -125,12 +128,11 @@ public class GatherSumApplyIteration<K, VV, EV, M> implements CustomUnaryOperati
 
 		// check whether the numVertices option is set and, if so, compute the total number of vertices
 		// and set it within the gather, sum and apply functions
+
+		DataSet<LongValue> numberOfVertices = null;
 		if (this.configuration != null && this.configuration.isOptNumVertices()) {
 			try {
-				long numberOfVertices = graph.numberOfVertices();
-				gather.setNumberOfVertices(numberOfVertices);
-				sum.setNumberOfVertices(numberOfVertices);
-				apply.setNumberOfVertices(numberOfVertices);
+				numberOfVertices = GraphUtils.count(this.vertexDataSet);
 			} catch (Exception e) {
 				e.printStackTrace();
 			}
@@ -203,6 +205,9 @@ public class GatherSumApplyIteration<K, VV, EV, M> implements CustomUnaryOperati
 			for (Tuple2<String, DataSet<?>> e : this.configuration.getGatherBcastVars()) {
 				gatherMapOperator = gatherMapOperator.withBroadcastSet(e.f1, e.f0);
 			}
+			if (this.configuration.isOptNumVertices()) {
+				gatherMapOperator = gatherMapOperator.withBroadcastSet(numberOfVertices, "number of vertices");
+			}
 		}
 		DataSet<Tuple2<K, M>> gatheredSet = gatherMapOperator;
 
@@ -215,6 +220,9 @@ public class GatherSumApplyIteration<K, VV, EV, M> implements CustomUnaryOperati
 			for (Tuple2<String, DataSet<?>> e : this.configuration.getSumBcastVars()) {
 				sumReduceOperator = sumReduceOperator.withBroadcastSet(e.f1, e.f0);
 			}
+			if (this.configuration.isOptNumVertices()) {
+				sumReduceOperator = sumReduceOperator.withBroadcastSet(numberOfVertices, "number of vertices");
+			}
 		}
 		DataSet<Tuple2<K, M>> summedSet = sumReduceOperator;
 
@@ -231,6 +239,9 @@ public class GatherSumApplyIteration<K, VV, EV, M> implements CustomUnaryOperati
 			for (Tuple2<String, DataSet<?>> e : this.configuration.getApplyBcastVars()) {
 				appliedSet = appliedSet.withBroadcastSet(e.f1, e.f0);
 			}
+			if (this.configuration.isOptNumVertices()) {
+				appliedSet = appliedSet.withBroadcastSet(numberOfVertices, "number of vertices");
+			}
 		}
 
 		// let the operator know that we preserve the key field
@@ -289,6 +300,10 @@ public class GatherSumApplyIteration<K, VV, EV, M> implements CustomUnaryOperati
 
 		@Override
 		public void open(Configuration parameters) throws Exception {
+			if (getRuntimeContext().hasBroadcastVariable("number of vertices")) {
+				Collection<LongValue> numberOfVertices = getRuntimeContext().getBroadcastVariable("number of vertices");
+				this.gatherFunction.setNumberOfVertices(numberOfVertices.iterator().next().getValue());
+			}
 			if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
 				this.gatherFunction.init(getIterationRuntimeContext());
 			}
@@ -327,6 +342,10 @@ public class GatherSumApplyIteration<K, VV, EV, M> implements CustomUnaryOperati
 
 		@Override
 		public void open(Configuration parameters) throws Exception {
+			if (getRuntimeContext().hasBroadcastVariable("number of vertices")) {
+				Collection<LongValue> numberOfVertices = getRuntimeContext().getBroadcastVariable("number of vertices");
+				this.sumFunction.setNumberOfVertices(numberOfVertices.iterator().next().getValue());
+			}
 			if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
 				this.sumFunction.init(getIterationRuntimeContext());
 			}
@@ -365,6 +384,10 @@ public class GatherSumApplyIteration<K, VV, EV, M> implements CustomUnaryOperati
 
 		@Override
 		public void open(Configuration parameters) throws Exception {
+			if (getRuntimeContext().hasBroadcastVariable("number of vertices")) {
+				Collection<LongValue> numberOfVertices = getRuntimeContext().getBroadcastVariable("number of vertices");
+				this.applyFunction.setNumberOfVertices(numberOfVertices.iterator().next().getValue());
+			}
 			if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
 				this.applyFunction.init(getIterationRuntimeContext());
 			}

http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/GSAPageRank.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/GSAPageRank.java b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/GSAPageRank.java
index 99624ca..324f9c3 100644
--- a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/GSAPageRank.java
+++ b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/GSAPageRank.java
@@ -25,6 +25,7 @@ import org.apache.flink.graph.Graph;
 import org.apache.flink.graph.GraphAlgorithm;
 import org.apache.flink.graph.Vertex;
 import org.apache.flink.graph.gsa.ApplyFunction;
+import org.apache.flink.graph.gsa.GSAConfiguration;
 import org.apache.flink.graph.gsa.GatherFunction;
 import org.apache.flink.graph.gsa.Neighbor;
 import org.apache.flink.graph.gsa.SumFunction;
@@ -32,22 +33,17 @@ import org.apache.flink.graph.gsa.SumFunction;
 /**
  * This is an implementation of a simple PageRank algorithm, using a gather-sum-apply iteration.
  * The user can define the damping factor and the maximum number of iterations.
- * If the number of vertices of the input graph is known, it should be provided as a parameter
- * to speed up computation. Otherwise, the algorithm will first execute a job to count the vertices.
- * 
+ *
  * The implementation assumes that each page has at least one incoming and one outgoing link.
  */
 public class GSAPageRank<K> implements GraphAlgorithm<K, Double, Double, DataSet<Vertex<K, Double>>> {
 
 	private double beta;
 	private int maxIterations;
-	private long numberOfVertices;
 
 	/**
 	 * Creates an instance of the GSA PageRank algorithm.
-	 * If the number of vertices of the input graph is known,
-	 * use the {@link GSAPageRank#GSAPageRank(double, long, int)} constructor instead.
-	 * 
+	 *
 	 * The implementation assumes that each page has at least one incoming and one outgoing link.
 	 * 
 	 * @param beta the damping factor
@@ -58,37 +54,19 @@ public class GSAPageRank<K> implements GraphAlgorithm<K, Double, Double, DataSet
 		this.maxIterations = maxIterations;
 	}
 
-	/**
-	 * Creates an instance of the GSA PageRank algorithm.
-	 * If the number of vertices of the input graph is known,
-	 * use the {@link GSAPageRank#GSAPageRank(double, int)} constructor instead.
-	 * 
-	 * The implementation assumes that each page has at least one incoming and one outgoing link.
-	 * 
-	 * @param beta the damping factor
-	 * @param maxIterations the maximum number of iterations
-	 * @param numVertices the number of vertices in the input
-	 */
-	public GSAPageRank(double beta, long numVertices, int maxIterations) {
-		this.beta = beta;
-		this.numberOfVertices = numVertices;
-		this.maxIterations = maxIterations;
-	}
-
 	@Override
 	public DataSet<Vertex<K, Double>> run(Graph<K, Double, Double> network) throws Exception {
 
-		if (numberOfVertices == 0) {
-			numberOfVertices = network.numberOfVertices();
-		}
-
 		DataSet<Tuple2<K, Long>> vertexOutDegrees = network.outDegrees();
 
 		Graph<K, Double, Double> networkWithWeights = network
 				.joinWithEdgesOnSource(vertexOutDegrees, new InitWeights());
 
-		return networkWithWeights.runGatherSumApplyIteration(new GatherRanks(numberOfVertices), new SumRanks(),
-				new UpdateRanks<K>(beta, numberOfVertices), maxIterations)
+		GSAConfiguration parameters = new GSAConfiguration();
+		parameters.setOptNumVertices(true);
+
+		return networkWithWeights.runGatherSumApplyIteration(new GatherRanks(), new SumRanks(),
+				new UpdateRanks<K>(beta), maxIterations, parameters)
 				.getVertices();
 	}
 
@@ -99,18 +77,12 @@ public class GSAPageRank<K> implements GraphAlgorithm<K, Double, Double, DataSet
 	@SuppressWarnings("serial")
 	private static final class GatherRanks extends GatherFunction<Double, Double, Double> {
 
-		long numberOfVertices;
-
-		public GatherRanks(long numberOfVertices) {
-			this.numberOfVertices = numberOfVertices;
-		}
-
 		@Override
 		public Double gather(Neighbor<Double, Double> neighbor) {
 			double neighborRank = neighbor.getNeighborValue();
 
 			if(getSuperstepNumber() == 1) {
-				neighborRank = 1.0 / numberOfVertices;
+				neighborRank = 1.0 / this.getNumberOfVertices();
 			}
 
 			return neighborRank * neighbor.getEdgeValue();
@@ -130,16 +102,14 @@ public class GSAPageRank<K> implements GraphAlgorithm<K, Double, Double, DataSet
 	private static final class UpdateRanks<K> extends ApplyFunction<K, Double, Double> {
 
 		private final double beta;
-		private final long numVertices;
 
-		public UpdateRanks(double beta, long numberOfVertices) {
+		public UpdateRanks(double beta) {
 			this.beta = beta;
-			this.numVertices = numberOfVertices;
 		}
 
 		@Override
 		public void apply(Double rankSum, Double currentValue) {
-			setResult((1-beta)/numVertices + beta * rankSum);
+			setResult((1-beta)/this.getNumberOfVertices() + beta * rankSum);
 		}
 	}
 

http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/HITSAlgorithm.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/HITSAlgorithm.java b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/HITSAlgorithm.java
index 1ea367e..39e9487 100644
--- a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/HITSAlgorithm.java
+++ b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/HITSAlgorithm.java
@@ -42,8 +42,6 @@ import org.apache.flink.util.Preconditions;
  * represented a page that is linked by many different hubs.
  * Each vertex has a value of Tuple2 type, the first field is hub score and the second field is authority score.
  * The implementation sets same score to every vertex and adds the reverse edge to every edge at the beginning. 
- * If the number of vertices of the input graph is known, it should be provided as a parameter
- * to speed up computation. Otherwise, the algorithm will first execute a job to count the vertices.
  * <p>
  *
  * @see <a href="https://en.wikipedia.org/wiki/HITS_algorithm">HITS Algorithm</a>
@@ -54,7 +52,6 @@ public class HITSAlgorithm<K, VV, EV> implements GraphAlgorithm<K, VV, EV, DataS
 	private final static double MINIMUMTHRESHOLD = 1e-9;
 
 	private int maxIterations;
-	private long numberOfVertices;
 	private double convergeThreshold;
 
 	/**
@@ -76,26 +73,6 @@ public class HITSAlgorithm<K, VV, EV> implements GraphAlgorithm<K, VV, EV, DataS
 	}
 
 	/**
-	 * Create an instance of HITS algorithm.
-	 * 
-	 * @param maxIterations    the maximum number of iterations
-	 * @param numberOfVertices the number of vertices in the graph
-	 */
-	public HITSAlgorithm(int maxIterations, long numberOfVertices) {
-		this(maxIterations, MINIMUMTHRESHOLD, numberOfVertices);
-	}
-
-	/**
-	 * Create an instance of HITS algorithm.
-	 * 
-	 * @param convergeThreshold convergence threshold for sum of scores to control whether the iteration should be stopped
-	 * @param numberOfVertices  the number of vertices in the graph
-	 */
-	public HITSAlgorithm(double convergeThreshold, long numberOfVertices) {
-		this(MAXIMUMITERATION, convergeThreshold, numberOfVertices);
-	}
-
-	/**
 	 * Creates an instance of HITS algorithm.
 	 *
 	 * @param maxIterations     the maximum number of iterations
@@ -108,26 +85,8 @@ public class HITSAlgorithm<K, VV, EV> implements GraphAlgorithm<K, VV, EV, DataS
 		this.convergeThreshold = convergeThreshold;
 	}
 
-	/**
-	 * Creates an instance of HITS algorithm.
-	 *
-	 * @param maxIterations     the maximum number of iterations
-	 * @param convergeThreshold convergence threshold for sum of scores to control whether the iteration should be stopped
-	 * @param numberOfVertices  the number of vertices in the graph
-	 */
-	public HITSAlgorithm(int maxIterations, double convergeThreshold, long numberOfVertices) {
-		this(maxIterations, convergeThreshold);
-		Preconditions.checkArgument(numberOfVertices > 0, "Number of vertices must be greater than zero.");
-		this.numberOfVertices = numberOfVertices;
-	}
-
 	@Override
 	public DataSet<Vertex<K, Tuple2<DoubleValue, DoubleValue>>> run(Graph<K, VV, EV> graph) throws Exception {
-
-		if (numberOfVertices == 0) {
-			numberOfVertices = graph.numberOfVertices();
-		}
-
 		Graph<K, Tuple2<DoubleValue, DoubleValue>, Boolean> newGraph = graph
 				.mapEdges(new AuthorityEdgeMapper<K, EV>())
 				.union(graph.reverse().mapEdges(new HubEdgeMapper<K, EV>()))
@@ -135,12 +94,13 @@ public class HITSAlgorithm<K, VV, EV> implements GraphAlgorithm<K, VV, EV, DataS
 
 		ScatterGatherConfiguration parameter = new ScatterGatherConfiguration();
 		parameter.setDirection(EdgeDirection.OUT);
+		parameter.setOptNumVertices(true);
 		parameter.registerAggregator("updatedValueSum", new DoubleSumAggregator());
 		parameter.registerAggregator("authorityValueSum", new DoubleSumAggregator());
 		parameter.registerAggregator("diffValueSum", new DoubleSumAggregator());
 
 		return newGraph
-				.runScatterGatherIteration(new VertexUpdate<K>(maxIterations, convergeThreshold, numberOfVertices),
+				.runScatterGatherIteration(new VertexUpdate<K>(maxIterations, convergeThreshold),
 						new MessageUpdate<K>(maxIterations), maxIterations, parameter)
 				.getVertices();
 	}
@@ -153,15 +113,13 @@ public class HITSAlgorithm<K, VV, EV> implements GraphAlgorithm<K, VV, EV, DataS
 	public static final class VertexUpdate<K> extends VertexUpdateFunction<K, Tuple2<DoubleValue, DoubleValue>, Double> {
 		private int maxIteration;
 		private double convergeThreshold;
-		private long numberOfVertices;
 		private DoubleSumAggregator updatedValueSumAggregator;
 		private DoubleSumAggregator authoritySumAggregator;
 		private DoubleSumAggregator diffSumAggregator;
 
-		public VertexUpdate(int maxIteration, double convergeThreshold, long numberOfVertices) {
+		public VertexUpdate(int maxIteration, double convergeThreshold) {
 			this.maxIteration = maxIteration;
 			this.convergeThreshold = convergeThreshold;
-			this.numberOfVertices = numberOfVertices;
 		}
 
 		@Override
@@ -198,9 +156,9 @@ public class HITSAlgorithm<K, VV, EV> implements GraphAlgorithm<K, VV, EV, DataS
 
 					//in the first iteration, the diff is the authority value of each vertex
 					double previousAuthAverage = 1.0;
-					double diffValueSum = 1.0 * numberOfVertices;
+					double diffValueSum = 1.0 * getNumberOfVertices();
 					if (getSuperstepNumber() > 1) {
-						previousAuthAverage = ((DoubleValue) getPreviousIterationAggregate("authorityValueSum")).getValue() / numberOfVertices;
+						previousAuthAverage = ((DoubleValue) getPreviousIterationAggregate("authorityValueSum")).getValue() / getNumberOfVertices();
 						diffValueSum = ((DoubleValue) getPreviousIterationAggregate("diffValueSum")).getValue();
 					}
 					authoritySumAggregator.aggregate(previousAuthAverage);
@@ -218,7 +176,7 @@ public class HITSAlgorithm<K, VV, EV> implements GraphAlgorithm<K, VV, EV, DataS
 					newHubValue.setValue(updateValue);
 					newAuthorityValue.setValue(newAuthorityValue.getValue() / iterationValueSum);
 					authoritySumAggregator.aggregate(newAuthorityValue.getValue());
-					double previousAuthAverage = ((DoubleValue) getPreviousIterationAggregate("authorityValueSum")).getValue() / numberOfVertices;
+					double previousAuthAverage = ((DoubleValue) getPreviousIterationAggregate("authorityValueSum")).getValue() / getNumberOfVertices();
 
 					// count the diff value of sum of authority scores
 					diffSumAggregator.aggregate((previousAuthAverage - newAuthorityValue.getValue()));

http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/PageRank.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/PageRank.java b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/PageRank.java
index 9890a7c..f83b05b 100644
--- a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/PageRank.java
+++ b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/library/PageRank.java
@@ -27,27 +27,23 @@ import org.apache.flink.graph.GraphAlgorithm;
 import org.apache.flink.graph.Vertex;
 import org.apache.flink.graph.spargel.MessageIterator;
 import org.apache.flink.graph.spargel.MessagingFunction;
+import org.apache.flink.graph.spargel.ScatterGatherConfiguration;
 import org.apache.flink.graph.spargel.VertexUpdateFunction;
 
 /**
  * This is an implementation of a simple PageRank algorithm, using a scatter-gather iteration.
  * The user can define the damping factor and the maximum number of iterations.
- * If the number of vertices of the input graph is known, it should be provided as a parameter
- * to speed up computation. Otherwise, the algorithm will first execute a job to count the vertices.
- * 
+ *
  * The implementation assumes that each page has at least one incoming and one outgoing link.
  */
 public class PageRank<K> implements GraphAlgorithm<K, Double, Double, DataSet<Vertex<K, Double>>> {
 
 	private double beta;
 	private int maxIterations;
-	private long numberOfVertices;
 
 	/**
 	 * Creates an instance of the PageRank algorithm.
-	 * If the number of vertices of the input graph is known,
-	 * use the {@link PageRank#PageRank(double, long, int)} constructor instead.
-	 * 
+	 *
 	 * The implementation assumes that each page has at least one incoming and one outgoing link.
 	 * 
 	 * @param beta the damping factor
@@ -56,40 +52,21 @@ public class PageRank<K> implements GraphAlgorithm<K, Double, Double, DataSet<Ve
 	public PageRank(double beta, int maxIterations) {
 		this.beta = beta;
 		this.maxIterations = maxIterations;
-		this.numberOfVertices = 0;
-	}
-
-	/**
-	 * Creates an instance of the PageRank algorithm.
-	 * If the number of vertices of the input graph is known,
-	 * use the {@link PageRank#PageRank(double, int)} constructor instead.
-	 * 
-	 * The implementation assumes that each page has at least one incoming and one outgoing link.
-	 * 
-	 * @param beta the damping factor
-	 * @param maxIterations the maximum number of iterations
-	 * @param numVertices the number of vertices in the input
-	 */
-	public PageRank(double beta, long numVertices, int maxIterations) {
-		this.beta = beta;
-		this.maxIterations = maxIterations;
-		this.numberOfVertices = numVertices;
 	}
 
 	@Override
 	public DataSet<Vertex<K, Double>> run(Graph<K, Double, Double> network) throws Exception {
 
-		if (numberOfVertices == 0) {
-			numberOfVertices = network.numberOfVertices();
-		}
-
 		DataSet<Tuple2<K, Long>> vertexOutDegrees = network.outDegrees();
 
 		Graph<K, Double, Double> networkWithWeights = network
 				.joinWithEdgesOnSource(vertexOutDegrees, new InitWeights());
 
-		return networkWithWeights.runScatterGatherIteration(new VertexRankUpdater<K>(beta, numberOfVertices),
-				new RankMessenger<K>(numberOfVertices), maxIterations)
+		ScatterGatherConfiguration parameters = new ScatterGatherConfiguration();
+		parameters.setOptNumVertices(true);
+
+		return networkWithWeights.runScatterGatherIteration(new VertexRankUpdater<K>(beta),
+				new RankMessenger<K>(), maxIterations, parameters)
 				.getVertices();
 	}
 
@@ -101,11 +78,9 @@ public class PageRank<K> implements GraphAlgorithm<K, Double, Double, DataSet<Ve
 	public static final class VertexRankUpdater<K> extends VertexUpdateFunction<K, Double, Double> {
 
 		private final double beta;
-		private final long numVertices;
-		
-		public VertexRankUpdater(double beta, long numberOfVertices) {
+
+		public VertexRankUpdater(double beta) {
 			this.beta = beta;
-			this.numVertices = numberOfVertices;
 		}
 
 		@Override
@@ -116,7 +91,7 @@ public class PageRank<K> implements GraphAlgorithm<K, Double, Double, DataSet<Ve
 			}
 
 			// apply the dampening factor / random jump
-			double newRank = (beta * rankSum) + (1 - beta) / numVertices;
+			double newRank = (beta * rankSum) + (1 - beta) / this.getNumberOfVertices();
 			setNewVertexValue(newRank);
 		}
 	}
@@ -129,17 +104,11 @@ public class PageRank<K> implements GraphAlgorithm<K, Double, Double, DataSet<Ve
 	@SuppressWarnings("serial")
 	public static final class RankMessenger<K> extends MessagingFunction<K, Double, Double, Double> {
 
-		private final long numVertices;
-
-		public RankMessenger(long numberOfVertices) {
-			this.numVertices = numberOfVertices;
-		}
-
 		@Override
 		public void sendMessages(Vertex<K, Double> vertex) {
 			if (getSuperstepNumber() == 1) {
 				// initialize vertex ranks
-				vertex.setValue(new Double(1.0 / numVertices));
+				vertex.setValue(1.0 / this.getNumberOfVertices());
 			}
 
 			for (Edge<K, Double> edge : getEdges()) {

http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/spargel/ScatterGatherIteration.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/spargel/ScatterGatherIteration.java b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/spargel/ScatterGatherIteration.java
index 496e36d..165ef1e 100644
--- a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/spargel/ScatterGatherIteration.java
+++ b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/spargel/ScatterGatherIteration.java
@@ -18,18 +18,15 @@
 
 package org.apache.flink.graph.spargel;
 
-import java.util.Iterator;
-import java.util.Map;
-
 import org.apache.flink.api.common.aggregators.Aggregator;
 import org.apache.flink.api.common.functions.FlatJoinFunction;
 import org.apache.flink.api.common.functions.MapFunction;
-import org.apache.flink.api.java.DataSet;
-import org.apache.flink.api.java.operators.DeltaIteration;
 import org.apache.flink.api.common.functions.RichCoGroupFunction;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.operators.CoGroupOperator;
 import org.apache.flink.api.java.operators.CustomUnaryOperation;
+import org.apache.flink.api.java.operators.DeltaIteration;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.tuple.Tuple3;
 import org.apache.flink.api.java.typeutils.ResultTypeQueryable;
@@ -40,9 +37,15 @@ import org.apache.flink.graph.Edge;
 import org.apache.flink.graph.EdgeDirection;
 import org.apache.flink.graph.Graph;
 import org.apache.flink.graph.Vertex;
+import org.apache.flink.graph.utils.GraphUtils;
+import org.apache.flink.types.LongValue;
 import org.apache.flink.util.Collector;
 import org.apache.flink.util.Preconditions;
 
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.Map;
+
 /**
  * This class represents iterative graph computations, programmed in a scatter-gather perspective.
  * It is a special case of <i>Bulk Synchronous Parallel</i> computation.
@@ -151,11 +154,10 @@ public class ScatterGatherIteration<K, VV, Message, EV>
 		// check whether the numVertices option is set and, if so, compute the total number of vertices
 		// and set it within the messaging and update functions
 
+		DataSet<LongValue> numberOfVertices = null;
 		if (this.configuration != null && this.configuration.isOptNumVertices()) {
 			try {
-				long numberOfVertices = graph.numberOfVertices();
-				messagingFunction.setNumberOfVertices(numberOfVertices);
-				updateFunction.setNumberOfVertices(numberOfVertices);
+				numberOfVertices = GraphUtils.count(this.initialVertices);
 			} catch (Exception e) {
 				e.printStackTrace();
 			}
@@ -173,9 +175,9 @@ public class ScatterGatherIteration<K, VV, Message, EV>
 		// check whether the degrees option is set and, if so, compute the in and the out degrees
and
 		// add them to the vertex value
 		if(this.configuration != null && this.configuration.isOptDegrees()) {
-			return createResultVerticesWithDegrees(graph, messagingDirection, messageTypeInfo);
+			return createResultVerticesWithDegrees(graph, messagingDirection, messageTypeInfo, numberOfVertices);
 		} else {
-			return createResultSimpleVertex(messagingDirection, messageTypeInfo);
+			return createResultSimpleVertex(messagingDirection, messageTypeInfo, numberOfVertices);
 		}
 	}
 
@@ -246,6 +248,10 @@ public class ScatterGatherIteration<K, VV, Message, EV>
 
 		@Override
 		public void open(Configuration parameters) throws Exception {
+			if (getRuntimeContext().hasBroadcastVariable("number of vertices")) {
+				Collection<LongValue> numberOfVertices = getRuntimeContext().getBroadcastVariable("number
of vertices");
+				this.vertexUpdateFunction.setNumberOfVertices(numberOfVertices.iterator().next().getValue());
+			}
 			if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
 				this.vertexUpdateFunction.init(getIterationRuntimeContext());
 			}
@@ -368,10 +374,13 @@ public class ScatterGatherIteration<K, VV, Message, EV>
 		
 		@Override
 		public void open(Configuration parameters) throws Exception {
+			if (getRuntimeContext().hasBroadcastVariable("number of vertices")) {
+				Collection<LongValue> numberOfVertices = getRuntimeContext().getBroadcastVariable("number
of vertices");
+				this.messagingFunction.setNumberOfVertices(numberOfVertices.iterator().next().getValue());
+			}
 			if (getIterationRuntimeContext().getSuperstepNumber() == 1) {
 				this.messagingFunction.init(getIterationRuntimeContext());
 			}
-			
 			this.messagingFunction.preSuperstep();
 		}
 		
@@ -459,7 +468,8 @@ public class ScatterGatherIteration<K, VV, Message, EV>
 	 */
 	private CoGroupOperator<?, ?, Tuple2<K, Message>> buildMessagingFunction(
 			DeltaIteration<Vertex<K, VV>, Vertex<K, VV>> iteration,
-			TypeInformation<Tuple2<K, Message>> messageTypeInfo, int whereArg, int equalToArg)
{
+			TypeInformation<Tuple2<K, Message>> messageTypeInfo, int whereArg, int equalToArg,
+			DataSet<LongValue> numberOfVertices) {
 
 		// build the messaging function (co group)
 		CoGroupOperator<?, ?, Tuple2<K, Message>> messages;
@@ -475,6 +485,9 @@ public class ScatterGatherIteration<K, VV, Message, EV>
 			for (Tuple2<String, DataSet<?>> e : this.configuration.getMessagingBcastVars())
{
 				messages = messages.withBroadcastSet(e.f1, e.f0);
 			}
+			if (this.configuration.isOptNumVertices()) {
+				messages = messages.withBroadcastSet(numberOfVertices, "number of vertices");
+			}
 		}
 
 		return messages;
@@ -493,7 +506,8 @@ public class ScatterGatherIteration<K, VV, Message, EV>
 	 */
 	private CoGroupOperator<?, ?, Tuple2<K, Message>> buildMessagingFunctionVerticesWithDegrees(
 			DeltaIteration<Vertex<K, Tuple3<VV, Long, Long>>, Vertex<K, Tuple3<VV,
Long, Long>>> iteration,
-			TypeInformation<Tuple2<K, Message>> messageTypeInfo, int whereArg, int equalToArg)
{
+			TypeInformation<Tuple2<K, Message>> messageTypeInfo, int whereArg, int equalToArg,
+			DataSet<LongValue> numberOfVertices) {
 
 		// build the messaging function (co group)
 		CoGroupOperator<?, ?, Tuple2<K, Message>> messages;
@@ -510,6 +524,9 @@ public class ScatterGatherIteration<K, VV, Message, EV>
 			for (Tuple2<String, DataSet<?>> e : this.configuration.getMessagingBcastVars())
{
 				messages = messages.withBroadcastSet(e.f1, e.f0);
 			}
+			if (this.configuration.isOptNumVertices()) {
+				messages = messages.withBroadcastSet(numberOfVertices, "number of vertices");
+			}
 		}
 
 		return messages;
@@ -546,10 +563,11 @@ public class ScatterGatherIteration<K, VV, Message, EV>
 	 *
 	 * @param messagingDirection
 	 * @param messageTypeInfo
+	 * @param numberOfVertices
 	 * @return the operator
 	 */
 	private DataSet<Vertex<K, VV>> createResultSimpleVertex(EdgeDirection messagingDirection,
-		TypeInformation<Tuple2<K, Message>> messageTypeInfo) {
+		TypeInformation<Tuple2<K, Message>> messageTypeInfo, DataSet<LongValue>
numberOfVertices) {
 
 		DataSet<Tuple2<K, Message>> messages;
 
@@ -561,14 +579,14 @@ public class ScatterGatherIteration<K, VV, Message, EV>
 
 		switch (messagingDirection) {
 			case IN:
-				messages = buildMessagingFunction(iteration, messageTypeInfo, 1, 0);
+				messages = buildMessagingFunction(iteration, messageTypeInfo, 1, 0, numberOfVertices);
 				break;
 			case OUT:
-				messages = buildMessagingFunction(iteration, messageTypeInfo, 0, 0);
+				messages = buildMessagingFunction(iteration, messageTypeInfo, 0, 0, numberOfVertices);
 				break;
 			case ALL:
-				messages = buildMessagingFunction(iteration, messageTypeInfo, 1, 0)
-						.union(buildMessagingFunction(iteration, messageTypeInfo, 0, 0)) ;
+				messages = buildMessagingFunction(iteration, messageTypeInfo, 1, 0, numberOfVertices)
+						.union(buildMessagingFunction(iteration, messageTypeInfo, 0, 0, numberOfVertices))
;
 				break;
 			default:
 				throw new IllegalArgumentException("Illegal edge direction");
@@ -581,6 +599,10 @@ public class ScatterGatherIteration<K, VV, Message, EV>
 		CoGroupOperator<?, ?, Vertex<K, VV>> updates =
 				messages.coGroup(iteration.getSolutionSet()).where(0).equalTo(0).with(updateUdf);
 
+		if (this.configuration != null && this.configuration.isOptNumVertices()) {
+			updates = updates.withBroadcastSet(numberOfVertices, "number of vertices");
+		}
+
 		configureUpdateFunction(updates);
 
 		return iteration.closeWith(updates, updates);
@@ -593,11 +615,12 @@ public class ScatterGatherIteration<K, VV, Message, EV>
 	 * @param graph
 	 * @param messagingDirection
 	 * @param messageTypeInfo
+	 * @param numberOfVertices
 	 * @return the operator
 	 */
 	@SuppressWarnings("serial")
 	private DataSet<Vertex<K, VV>> createResultVerticesWithDegrees(Graph<K, VV,
EV> graph, EdgeDirection messagingDirection,
-			TypeInformation<Tuple2<K, Message>> messageTypeInfo) {
+			TypeInformation<Tuple2<K, Message>> messageTypeInfo, DataSet<LongValue>
numberOfVertices) {
 
 		DataSet<Tuple2<K, Message>> messages;
 
@@ -636,14 +659,14 @@ public class ScatterGatherIteration<K, VV, Message, EV>
 
 		switch (messagingDirection) {
 			case IN:
-				messages = buildMessagingFunctionVerticesWithDegrees(iteration, messageTypeInfo, 1, 0);
+				messages = buildMessagingFunctionVerticesWithDegrees(iteration, messageTypeInfo, 1, 0,
numberOfVertices);
 				break;
 			case OUT:
-				messages = buildMessagingFunctionVerticesWithDegrees(iteration, messageTypeInfo, 0, 0);
+				messages = buildMessagingFunctionVerticesWithDegrees(iteration, messageTypeInfo, 0, 0,
numberOfVertices);
 				break;
 			case ALL:
-				messages = buildMessagingFunctionVerticesWithDegrees(iteration, messageTypeInfo, 1, 0)
-						.union(buildMessagingFunctionVerticesWithDegrees(iteration, messageTypeInfo, 0, 0))
;
+				messages = buildMessagingFunctionVerticesWithDegrees(iteration, messageTypeInfo, 1, 0,
numberOfVertices)
+						.union(buildMessagingFunctionVerticesWithDegrees(iteration, messageTypeInfo, 0, 0,
numberOfVertices)) ;
 				break;
 			default:
 				throw new IllegalArgumentException("Illegal edge direction");
@@ -657,6 +680,10 @@ public class ScatterGatherIteration<K, VV, Message, EV>
 		CoGroupOperator<?, ?, Vertex<K, Tuple3<VV, Long, Long>>> updates =
 				messages.coGroup(iteration.getSolutionSet()).where(0).equalTo(0).with(updateUdf);
 
+		if (this.configuration != null && this.configuration.isOptNumVertices()) {
+			updates = updates.withBroadcastSet(numberOfVertices, "number of vertices");
+		}
+
 		configureUpdateFunction(updates);
 
 		return iteration.closeWith(updates, updates).map(

http://git-wip-us.apache.org/repos/asf/flink/blob/36ad78c0/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/utils/GraphUtils.java
----------------------------------------------------------------------
diff --git a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/utils/GraphUtils.java
b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/utils/GraphUtils.java
index 009d791..264479b 100644
--- a/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/utils/GraphUtils.java
+++ b/flink-libraries/flink-gelly/src/main/java/org/apache/flink/graph/utils/GraphUtils.java
@@ -19,12 +19,18 @@
 package org.apache.flink.graph.utils;
 
 import org.apache.flink.api.common.JobExecutionResult;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.Utils;
 import org.apache.flink.graph.Edge;
 import org.apache.flink.graph.Graph;
 import org.apache.flink.graph.Vertex;
+import org.apache.flink.types.LongValue;
 import org.apache.flink.util.AbstractID;
 
+import static org.apache.flink.api.java.typeutils.ValueTypeInfo.LONG_VALUE_TYPE_INFO;
+
 public class GraphUtils {
 
 	/**
@@ -50,4 +56,56 @@ public class GraphUtils {
 
 		return checksum;
 	}
+
+	/**
+	 * Count the elements of a DataSet lazily: every element is mapped to a LongValue of 1 and the ones are summed.
+	 * Unlike DataSet.count() the result stays a DataSet, so it can feed later operators (e.g. a broadcast set).
+	 * @param input DataSet of elements to be counted
+	 * @param <T> element type
+	 * @return DataSet holding the count (NOTE(review): likely empty rather than 0 when input is empty — confirm)
+	 */
+	public static <T> DataSet<LongValue> count(DataSet<T> input) {
+		return input
+			.map(new MapTo<T, LongValue>(new LongValue(1)))
+				.returns(LONG_VALUE_TYPE_INFO) // explicit type info; presumably needed since MapTo's O is a type variable
+			.reduce(new AddLongValue());
+	}
+
+	/**
+	 * Map every input element to one fixed value, ignoring the element itself.
+	 *
+	 * @param <I> input type (ignored)
+	 * @param <O> output type
+	 */
+	public static class MapTo<I, O>
+	implements MapFunction<I, O> {
+		private final O value; // emitted for every element; never copied
+
+		/**
+		 * Create a mapper that emits the given object for each input element.
+		 *
+		 * @param value the object to emit for each element (the same instance is returned on every call)
+		 */
+		public MapTo(O value) {
+			this.value = value;
+		}
+
+		@Override
+		public O map(I o) throws Exception {
+			return value; // NOTE(review): shared instance; AddLongValue mutates its first arg — check object-reuse safety
+		}
+	}
+
+	/**
+	 * Sum {@link LongValue} elements; the sum is written into the first argument, which is returned (no allocation).
+	 */
+	public static class AddLongValue
+	implements ReduceFunction<LongValue> {
+		@Override
+		public LongValue reduce(LongValue value1, LongValue value2)
+				throws Exception {
+			value1.setValue(value1.getValue() + value2.getValue()); // mutates value1 in place; long overflow wraps silently
+			return value1;
+		}
+	}
 }


Mime
View raw message