flink-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From va...@apache.org
Subject [2/2] flink git commit: [FLINK-3888] allow registering a custom convergence criterion in delta iterations
Date Fri, 21 Oct 2016 10:34:16 GMT
[FLINK-3888] allow registering a custom convergence criterion in delta iterations

- clean up the iterations and aggregators code
- evaluate custom convergence criteria of delta iterations in the CollectionExecutor
- add ITCases for custom convergence criteria in delta iterations

This closes #2606


Project: http://git-wip-us.apache.org/repos/asf/flink/repo
Commit: http://git-wip-us.apache.org/repos/asf/flink/commit/8085aa98
Tree: http://git-wip-us.apache.org/repos/asf/flink/tree/8085aa98
Diff: http://git-wip-us.apache.org/repos/asf/flink/diff/8085aa98

Branch: refs/heads/master
Commit: 8085aa98333a052553f1155f65b5cc2728eb5ff8
Parents: 3ab97ae
Author: vasia <vasia@apache.org>
Authored: Wed Oct 5 13:49:20 2016 +0200
Committer: vasia <vasia@apache.org>
Committed: Fri Oct 21 12:33:50 2016 +0200

----------------------------------------------------------------------
 .../common/aggregators/AggregatorRegistry.java  |  10 +-
 .../common/operators/CollectionExecutor.java    |  11 ++
 .../api/java/operators/DeltaIteration.java      |  32 +++-
 .../plantranslate/JobGraphGenerator.java        |  13 +-
 .../task/IterationSynchronizationSinkTask.java  |  43 +++--
 .../runtime/operators/util/TaskConfig.java      |  54 +++++-
 .../AggregatorConvergenceITCase.java            | 182 +++++++++----------
 .../aggregators/AggregatorsITCase.java          |  82 ++++++---
 8 files changed, 272 insertions(+), 155 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/flink/blob/8085aa98/flink-core/src/main/java/org/apache/flink/api/common/aggregators/AggregatorRegistry.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/aggregators/AggregatorRegistry.java
b/flink-core/src/main/java/org/apache/flink/api/common/aggregators/AggregatorRegistry.java
index 1d5c358..19663d1 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/aggregators/AggregatorRegistry.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/aggregators/AggregatorRegistry.java
@@ -49,18 +49,14 @@ public class AggregatorRegistry {
 		}
 		this.registry.put(name, aggregator);
 	}
-	
-	public Aggregator<?> unregisterAggregator(String name) {
-		return this.registry.remove(name);
-	}
-	
+
 	public Collection<AggregatorWithName<?>> getAllRegisteredAggregators() {
 		ArrayList<AggregatorWithName<?>> list = new ArrayList<AggregatorWithName<?>>(this.registry.size());
 		
 		for (Map.Entry<String, Aggregator<?>> entry : this.registry.entrySet()) {
 			@SuppressWarnings("unchecked")
 			Aggregator<Value> valAgg = (Aggregator<Value>) entry.getValue();
-			list.add(new AggregatorWithName<Value>(entry.getKey(), valAgg));
+			list.add(new AggregatorWithName<>(entry.getKey(), valAgg));
 		}
 		return list;
 	}
@@ -72,7 +68,7 @@ public class AggregatorRegistry {
 			throw new IllegalArgumentException("Name, aggregator, or convergence criterion must not
be null");
 		}
 		
-		Aggregator<?> genAgg = (Aggregator<?>) aggregator;
+		Aggregator<?> genAgg = aggregator;
 		
 		Aggregator<?> previous = this.registry.get(name);
 		if (previous != null && previous != genAgg) {

http://git-wip-us.apache.org/repos/asf/flink/blob/8085aa98/flink-core/src/main/java/org/apache/flink/api/common/operators/CollectionExecutor.java
----------------------------------------------------------------------
diff --git a/flink-core/src/main/java/org/apache/flink/api/common/operators/CollectionExecutor.java
b/flink-core/src/main/java/org/apache/flink/api/common/operators/CollectionExecutor.java
index d9240fe..a6fc17e 100644
--- a/flink-core/src/main/java/org/apache/flink/api/common/operators/CollectionExecutor.java
+++ b/flink-core/src/main/java/org/apache/flink/api/common/operators/CollectionExecutor.java
@@ -412,6 +412,9 @@ public class CollectionExecutor {
 			aggregators.put(a.getName(), a.getAggregator());
 		}
 
+		String convCriterionAggName = iteration.getAggregators().getConvergenceCriterionAggregatorName();
+		ConvergenceCriterion<Value> convCriterion = (ConvergenceCriterion<Value>) iteration.getAggregators().getConvergenceCriterion();
+
 		final int maxIterations = iteration.getMaximumNumberOfIterations();
 
 		for (int superstep = 1; superstep <= maxIterations; superstep++) {
@@ -442,6 +445,14 @@ public class CollectionExecutor {
 				break;
 			}
 
+			// evaluate the aggregator convergence criterion
+			if (convCriterion != null && convCriterionAggName != null) {
+				Value v = aggregators.get(convCriterionAggName).getAggregate();
+				if (convCriterion.isConverged(superstep, v)) {
+					break;
+				}
+			}
+
 			// clear the dynamic results
 			for (Operator<?> o : dynamics) {
 				intermediateResults.remove(o);

http://git-wip-us.apache.org/repos/asf/flink/blob/8085aa98/flink-java/src/main/java/org/apache/flink/api/java/operators/DeltaIteration.java
----------------------------------------------------------------------
diff --git a/flink-java/src/main/java/org/apache/flink/api/java/operators/DeltaIteration.java
b/flink-java/src/main/java/org/apache/flink/api/java/operators/DeltaIteration.java
index d53b499..b97a9de 100644
--- a/flink-java/src/main/java/org/apache/flink/api/java/operators/DeltaIteration.java
+++ b/flink-java/src/main/java/org/apache/flink/api/java/operators/DeltaIteration.java
@@ -26,10 +26,12 @@ import org.apache.flink.api.common.ExecutionConfig;
 import org.apache.flink.api.common.InvalidProgramException;
 import org.apache.flink.api.common.aggregators.Aggregator;
 import org.apache.flink.api.common.aggregators.AggregatorRegistry;
+import org.apache.flink.api.common.aggregators.ConvergenceCriterion;
 import org.apache.flink.api.common.operators.Keys;
 import org.apache.flink.api.common.typeinfo.TypeInformation;
 import org.apache.flink.api.java.DataSet;
 import org.apache.flink.api.java.ExecutionEnvironment;
+import org.apache.flink.types.Value;
 import org.apache.flink.util.Preconditions;
 
 /**
@@ -62,13 +64,13 @@ public class DeltaIteration<ST, WT> {
 	private int parallelism = ExecutionConfig.PARALLELISM_DEFAULT;
 	
 	private boolean solutionSetUnManaged;
-	
-	
+
+
 	public DeltaIteration(ExecutionEnvironment context, TypeInformation<ST> type, DataSet<ST>
solutionSet, DataSet<WT> workset, Keys<ST> keys, int maxIterations) {
 		initialSolutionSet = solutionSet;
 		initialWorkset = workset;
-		solutionSetPlaceholder = new SolutionSetPlaceHolder<ST>(context, solutionSet.getType(),
this);
-		worksetPlaceholder = new WorksetPlaceHolder<WT>(context, workset.getType());
+		solutionSetPlaceholder = new SolutionSetPlaceHolder<>(context, solutionSet.getType(),
this);
+		worksetPlaceholder = new WorksetPlaceHolder<>(context, workset.getType());
 		this.keys = keys;
 		this.maxIterations = maxIterations;
 	}
@@ -210,6 +212,28 @@ public class DeltaIteration<ST, WT> {
 		this.aggregators.registerAggregator(name, aggregator);
 		return this;
 	}
+
+	/**
+	 * Registers an {@link Aggregator} for the iteration together with a {@link ConvergenceCriterion}. For a general description
+	 * of aggregators, see {@link #registerAggregator(String, Aggregator)} and {@link Aggregator}.
+	 * At the end of each iteration, the convergence criterion takes the aggregator's global aggregate value and decides whether
+	 * the iteration should terminate. A typical use case is to have an aggregator that sums up the total error of change
+	 * in an iteration step and to have a convergence criterion that signals termination as soon as the aggregate value
+	 * is below a certain threshold. The implicit termination on an empty workset still applies in addition to this criterion.
+	 *
+	 * @param name The name under which the aggregator is registered.
+	 * @param aggregator The aggregator instance whose global aggregate is passed to the convergence criterion.
+	 * @param convergenceCheck The convergence criterion.
+	 *
+	 * @return The DeltaIteration itself, to allow chaining function calls.
+	 */
+	@PublicEvolving
+	public <X extends Value> DeltaIteration<ST, WT> registerAggregationConvergenceCriterion(
+			String name, Aggregator<X> aggregator, ConvergenceCriterion<X> convergenceCheck)
+	{
+		this.aggregators.registerAggregationConvergenceCriterion(name, aggregator, convergenceCheck);
+		return this;
+	}
 	
 	/**
 	 * Gets the registry for aggregators for the iteration.

http://git-wip-us.apache.org/repos/asf/flink/blob/8085aa98/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java
----------------------------------------------------------------------
diff --git a/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java
b/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java
index 5ab1fbf..4ccfae3 100644
--- a/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java
+++ b/flink-optimizer/src/main/java/org/apache/flink/optimizer/plantranslate/JobGraphGenerator.java
@@ -1513,14 +1513,21 @@ public class JobGraphGenerator implements Visitor<PlanNode>
{
 		
 		String convAggName = aggs.getConvergenceCriterionAggregatorName();
 		ConvergenceCriterion<?> convCriterion = aggs.getConvergenceCriterion();
-		
+
 		if (convCriterion != null || convAggName != null) {
-			throw new CompilerException("Error: Cannot use custom convergence criterion with workset
iteration. Workset iterations have implicit convergence criterion where workset is empty.");
+			if (convCriterion == null) {
+				throw new CompilerException("Error: Convergence criterion aggregator set, but criterion
is null.");
+			}
+			if (convAggName == null) {
+				throw new CompilerException("Error: Aggregator convergence criterion set, but aggregator
is null.");
+			}
+
+			syncConfig.setConvergenceCriterion(convAggName, convCriterion);
 		}
 		
 		headConfig.addIterationAggregator(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, new
LongSumAggregator());
 		syncConfig.addIterationAggregator(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, new
LongSumAggregator());
-		syncConfig.setConvergenceCriterion(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME, new
WorksetEmptyConvergenceCriterion());
+		syncConfig.setImplicitConvergenceCriterion(WorksetEmptyConvergenceCriterion.AGGREGATOR_NAME,
new WorksetEmptyConvergenceCriterion());
 	}
 	
 	private String getDescriptionForUserCode(UserCodeWrapper<?> wrapper) {

http://git-wip-us.apache.org/repos/asf/flink/blob/8085aa98/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationSynchronizationSinkTask.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationSynchronizationSinkTask.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationSynchronizationSinkTask.java
index 66fb45b..11a8cfa 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationSynchronizationSinkTask.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/iterative/task/IterationSynchronizationSinkTask.java
@@ -56,11 +56,15 @@ public class IterationSynchronizationSinkTask extends AbstractInvokable
implemen
 	private SyncEventHandler eventHandler;
 
 	private ConvergenceCriterion<Value> convergenceCriterion;
+
+	private ConvergenceCriterion<Value> implicitConvergenceCriterion;
 	
 	private Map<String, Aggregator<?>> aggregators;
 
 	private String convergenceAggregatorName;
 
+	private String implicitConvergenceAggregatorName;
+
 	private int currentIteration = 1;
 	
 	private int maxNumberOfIterations;
@@ -71,14 +75,14 @@ public class IterationSynchronizationSinkTask extends AbstractInvokable
implemen
 	
 	@Override
 	public void invoke() throws Exception {
-		this.headEventReader = new MutableRecordReader<IntValue>(
+		this.headEventReader = new MutableRecordReader<>(
 				getEnvironment().getInputGate(0),
 				getEnvironment().getTaskManagerInfo().getTmpDirectories());
 
 		TaskConfig taskConfig = new TaskConfig(getTaskConfiguration());
 		
 		// store all aggregators
-		this.aggregators = new HashMap<String, Aggregator<?>>();
+		this.aggregators = new HashMap<>();
 		for (AggregatorWithName<?> aggWithName : taskConfig.getIterationAggregators(getUserCodeClassLoader()))
{
 			aggregators.put(aggWithName.getName(), aggWithName.getAggregator());
 		}
@@ -89,6 +93,13 @@ public class IterationSynchronizationSinkTask extends AbstractInvokable
implemen
 			convergenceAggregatorName = taskConfig.getConvergenceCriterionAggregatorName();
 			Preconditions.checkNotNull(convergenceAggregatorName);
 		}
+
+		// store the default aggregator convergence criterion
+		if (taskConfig.usesImplicitConvergenceCriterion()) {
+			implicitConvergenceCriterion = taskConfig.getImplicitConvergenceCriterion(getUserCodeClassLoader());
+			implicitConvergenceAggregatorName = taskConfig.getImplicitConvergenceCriterionAggregatorName();
+			Preconditions.checkNotNull(implicitConvergenceAggregatorName);
+		}
 		
 		maxNumberOfIterations = taskConfig.getNumberOfIterations();
 		
@@ -102,7 +113,6 @@ public class IterationSynchronizationSinkTask extends AbstractInvokable
implemen
 		
 		while (!terminationRequested()) {
 
-//			notifyMonitor(IterationMonitoring.Event.SYNC_STARTING, currentIteration);
 			if (log.isInfoEnabled()) {
 				log.info(formatLogString("starting iteration [" + currentIteration + "]"));
 			}
@@ -122,7 +132,6 @@ public class IterationSynchronizationSinkTask extends AbstractInvokable
implemen
 
 				requestTermination();
 				sendToAllWorkers(new TerminationEvent());
-//				notifyMonitor(IterationMonitoring.Event.SYNC_FINISHED, currentIteration);
 			} else {
 				if (log.isInfoEnabled()) {
 					log.info(formatLogString("signaling that all workers are done in iteration [" + currentIteration
@@ -136,19 +145,11 @@ public class IterationSynchronizationSinkTask extends AbstractInvokable
implemen
 				for (Aggregator<?> agg : aggregators.values()) {
 					agg.reset();
 				}
-				
-//				notifyMonitor(IterationMonitoring.Event.SYNC_FINISHED, currentIteration);
 				currentIteration++;
 			}
 		}
 	}
 
-//	protected void notifyMonitor(IterationMonitoring.Event event, int currentIteration) {
-//		if (log.isInfoEnabled()) {
-//			log.info(IterationMonitoring.logLine(getEnvironment().getJobID(), event, currentIteration,
1));
-//		}
-//	}
-
 	private boolean checkForConvergence() {
 		if (maxNumberOfIterations == currentIteration) {
 			if (log.isInfoEnabled()) {
@@ -175,6 +176,24 @@ public class IterationSynchronizationSinkTask extends AbstractInvokable
implemen
 				return true;
 			}
 		}
+
+		if (implicitConvergenceAggregatorName != null) {
+			@SuppressWarnings("unchecked")
+			Aggregator<Value> aggregator = (Aggregator<Value>) aggregators.get(implicitConvergenceAggregatorName);
+			if (aggregator == null) {
+				throw new RuntimeException("Error: Aggregator for default convergence criterion was null.");
+			}
+
+			Value aggregate = aggregator.getAggregate();
+
+			if (implicitConvergenceCriterion.isConverged(currentIteration, aggregate)) {
+				if (log.isInfoEnabled()) {
+					log.info(formatLogString("empty workset convergence reached after [" + currentIteration
+							+ "] iterations, terminating..."));
+				}
+				return true;
+			}
+		}
 		
 		return false;
 	}

http://git-wip-us.apache.org/repos/asf/flink/blob/8085aa98/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java
----------------------------------------------------------------------
diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java
b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java
index b598523..71c0405 100644
--- a/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java
+++ b/flink-runtime/src/main/java/org/apache/flink/runtime/operators/util/TaskConfig.java
@@ -37,6 +37,7 @@ import org.apache.flink.api.common.operators.util.UserCodeWrapper;
 import org.apache.flink.api.common.typeutils.TypeComparatorFactory;
 import org.apache.flink.api.common.typeutils.TypePairComparatorFactory;
 import org.apache.flink.api.common.typeutils.TypeSerializerFactory;
+import org.apache.flink.api.java.operators.DeltaIteration;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.configuration.DelegatingConfiguration;
 import org.apache.flink.core.memory.DataInputViewStreamWrapper;
@@ -197,6 +198,10 @@ public class TaskConfig implements Serializable {
 	private static final String ITERATION_CONVERGENCE_CRITERION = "iterative.terminationCriterion";
 	
 	private static final String ITERATION_CONVERGENCE_CRITERION_AGG_NAME = "iterative.terminationCriterion.agg.name";
+
+	private static final String ITERATION_IMPLICIT_CONVERGENCE_CRITERION = "iterative.implicit.terminationCriterion";
+
+	private static final String ITERATION_IMPLICIT_CONVERGENCE_CRITERION_AGG_NAME = "iterative.implicit.terminationCriterion.agg.name";
 	
 	private static final String ITERATION_NUM_AGGREGATORS = "iterative.num-aggs";
 	
@@ -992,16 +997,31 @@ public class TaskConfig implements Serializable {
 		this.config.setString(ITERATION_CONVERGENCE_CRITERION_AGG_NAME, aggregatorName);
 	}
 
+	/**
+	 * Sets the default convergence criterion of a {@link DeltaIteration}
+	 *
+	 * @param aggregatorName
+	 * @param convCriterion
+	 */
+	public void setImplicitConvergenceCriterion(String aggregatorName, ConvergenceCriterion<?>
convCriterion) {
+		try {
+			InstantiationUtil.writeObjectToConfig(convCriterion, this.config, ITERATION_IMPLICIT_CONVERGENCE_CRITERION);
+		} catch (IOException e) {
+			throw new RuntimeException("Error while writing the implicit convergence criterion object
to the task configuration.");
+		}
+		this.config.setString(ITERATION_IMPLICIT_CONVERGENCE_CRITERION_AGG_NAME, aggregatorName);
+	}
+
 	@SuppressWarnings("unchecked")
 	public <T extends Value> ConvergenceCriterion<T> getConvergenceCriterion(ClassLoader
cl) {
-		ConvergenceCriterion<T> convCriterionObj = null;
+		ConvergenceCriterion<T> convCriterionObj;
 		try {
-			convCriterionObj = (ConvergenceCriterion<T>) InstantiationUtil.readObjectFromConfig(
+			convCriterionObj = InstantiationUtil.readObjectFromConfig(
 			this.config, ITERATION_CONVERGENCE_CRITERION, cl);
 		} catch (IOException e) {
-			throw new RuntimeException("Error while reading the covergence criterion object from the
task configuration.");
+			throw new RuntimeException("Error while reading the convergence criterion object from
the task configuration.");
 		} catch (ClassNotFoundException e) {
-			throw new RuntimeException("Error while reading the covergence criterion object from the
task configuration. " +
+			throw new RuntimeException("Error while reading the convergence criterion object from
the task configuration. " +
 					"ConvergenceCriterion class not found.");
 		}
 		if (convCriterionObj == null) {
@@ -1017,6 +1037,32 @@ public class TaskConfig implements Serializable {
 	public String getConvergenceCriterionAggregatorName() {
 		return this.config.getString(ITERATION_CONVERGENCE_CRITERION_AGG_NAME, null);
 	}
+
+	@SuppressWarnings("unchecked")
+	public <T extends Value> ConvergenceCriterion<T> getImplicitConvergenceCriterion(ClassLoader
cl) {
+		ConvergenceCriterion<T> convCriterionObj;
+		try {
+			convCriterionObj = InstantiationUtil.readObjectFromConfig(
+					this.config, ITERATION_IMPLICIT_CONVERGENCE_CRITERION, cl);
+		} catch (IOException e) {
+			throw new RuntimeException("Error while reading the default convergence criterion object
from the task configuration.");
+		} catch (ClassNotFoundException e) {
+			throw new RuntimeException("Error while reading the default convergence criterion object
from the task configuration. " +
+					"ConvergenceCriterion class not found.");
+		}
+		if (convCriterionObj == null) {
+			throw new NullPointerException();
+		}
+		return convCriterionObj;
+	}
+
+	public boolean usesImplicitConvergenceCriterion() {
+		return config.getBytes(ITERATION_IMPLICIT_CONVERGENCE_CRITERION, null) != null;
+	}
+
+	public String getImplicitConvergenceCriterionAggregatorName() {
+		return this.config.getString(ITERATION_IMPLICIT_CONVERGENCE_CRITERION_AGG_NAME, null);
+	}
 	
 	public void setIsSolutionSetUpdate() {
 		this.config.setBoolean(ITERATION_SOLUTION_SET_UPDATE, true);

http://git-wip-us.apache.org/repos/asf/flink/blob/8085aa98/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorConvergenceITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorConvergenceITCase.java
b/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorConvergenceITCase.java
index 941b31b..7bade80 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorConvergenceITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorConvergenceITCase.java
@@ -26,6 +26,7 @@ import org.apache.flink.api.common.aggregators.ConvergenceCriterion;
 import org.apache.flink.api.common.aggregators.LongSumAggregator;
 import org.apache.flink.api.common.functions.RichFlatMapFunction;
 import org.apache.flink.api.common.functions.RichJoinFunction;
+import org.apache.flink.api.java.operators.DeltaIteration;
 import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.configuration.Configuration;
 import org.apache.flink.test.util.JavaProgramTestBase;
@@ -52,47 +53,59 @@ public class AggregatorConvergenceITCase extends MultipleProgramsTestBase
{
 	public AggregatorConvergenceITCase(TestExecutionMode mode) {
 		super(mode);
 	}
-	
+
+	private final List<Tuple2<Long, Long>> verticesInput = Arrays.asList( // (vertex id, initial component id)
+			new Tuple2<>(1L,1L),
+			new Tuple2<>(2L,2L),
+			new Tuple2<>(3L,3L),
+			new Tuple2<>(4L,4L),
+			new Tuple2<>(5L,5L),
+			new Tuple2<>(6L,6L),
+			new Tuple2<>(7L,7L),
+			new Tuple2<>(8L,8L),
+			new Tuple2<>(9L,9L)
+	);
+
+	private final List<Tuple2<Long, Long>> edgesInput = Arrays.asList( // (source vertex id, target vertex id)
+			new Tuple2<>(1L,2L),
+			new Tuple2<>(1L,3L),
+			new Tuple2<>(2L,3L),
+			new Tuple2<>(2L,4L),
+			new Tuple2<>(2L,1L),
+			new Tuple2<>(3L,1L),
+			new Tuple2<>(3L,2L),
+			new Tuple2<>(4L,2L),
+			new Tuple2<>(4L,6L),
+			new Tuple2<>(5L,6L),
+			new Tuple2<>(6L,4L),
+			new Tuple2<>(6L,5L),
+			new Tuple2<>(7L,8L),
+			new Tuple2<>(7L,9L),
+			new Tuple2<>(8L,7L),
+			new Tuple2<>(8L,9L),
+			new Tuple2<>(9L,7L),
+			new Tuple2<>(9L,8L)
+	);
+
+	private final List<Tuple2<Long, Long>> expectedResult = Arrays.asList( // (vertex id, component id) at termination
+			new Tuple2<>(1L,1L),
+			new Tuple2<>(2L,1L),
+			new Tuple2<>(3L,1L),
+			new Tuple2<>(4L,1L),
+			new Tuple2<>(5L,2L),
+			new Tuple2<>(6L,1L),
+			new Tuple2<>(7L,7L),
+			new Tuple2<>(8L,7L),
+			new Tuple2<>(9L,7L)
+	);
+
 	@Test
-	public void testConnectedComponentsWithParametrizableConvergence() {
-		try {
-			List<Tuple2<Long, Long>> verticesInput = Arrays.asList(
-					new Tuple2<Long, Long>(1l,1l),
-					new Tuple2<Long, Long>(2l,2l),
-					new Tuple2<Long, Long>(3l,3l),
-					new Tuple2<Long, Long>(4l,4l),
-					new Tuple2<Long, Long>(5l,5l),
-					new Tuple2<Long, Long>(6l,6l),
-					new Tuple2<Long, Long>(7l,7l),
-					new Tuple2<Long, Long>(8l,8l),
-					new Tuple2<Long, Long>(9l,9l)
-			);
-			
-			List<Tuple2<Long, Long>> edgesInput = Arrays.asList(
-					new Tuple2<Long, Long>(1l,2l),
-					new Tuple2<Long, Long>(1l,3l),
-					new Tuple2<Long, Long>(2l,3l),
-					new Tuple2<Long, Long>(2l,4l),
-					new Tuple2<Long, Long>(2l,1l),
-					new Tuple2<Long, Long>(3l,1l),
-					new Tuple2<Long, Long>(3l,2l),
-					new Tuple2<Long, Long>(4l,2l),
-					new Tuple2<Long, Long>(4l,6l),
-					new Tuple2<Long, Long>(5l,6l),
-					new Tuple2<Long, Long>(6l,4l),
-					new Tuple2<Long, Long>(6l,5l),
-					new Tuple2<Long, Long>(7l,8l),
-					new Tuple2<Long, Long>(7l,9l),
-					new Tuple2<Long, Long>(8l,7l),
-					new Tuple2<Long, Long>(8l,9l),
-					new Tuple2<Long, Long>(9l,7l),
-					new Tuple2<Long, Long>(9l,8l)
-			);
+	public void testConnectedComponentsWithParametrizableConvergence() throws Exception {
 
 			// name of the aggregator that checks for convergence
 			final String UPDATED_ELEMENTS = "updated.elements.aggr";
 
-			// the iteration stops if less than this number os elements change value
+			// the iteration stops if less than this number of elements change value
 			final long convergence_threshold = 3;
 
 			final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
@@ -100,8 +113,7 @@ public class AggregatorConvergenceITCase extends MultipleProgramsTestBase
{
 			DataSet<Tuple2<Long, Long>> initialSolutionSet = env.fromCollection(verticesInput);
 			DataSet<Tuple2<Long, Long>> edges = env.fromCollection(edgesInput);
 
-			IterativeDataSet<Tuple2<Long, Long>> iteration =
-					initialSolutionSet.iterate(10);
+			IterativeDataSet<Tuple2<Long, Long>> iteration = initialSolutionSet.iterate(10);
 
 			// register the convergence criterion
 			iteration.registerAggregationConvergenceCriterion(UPDATED_ELEMENTS,
@@ -117,62 +129,47 @@ public class AggregatorConvergenceITCase extends MultipleProgramsTestBase
{
 
 			List<Tuple2<Long, Long>> result = iteration.closeWith(updatedComponentId).collect();
 			Collections.sort(result, new JavaProgramTestBase.TupleComparator<Tuple2<Long, Long>>());
-
-			List<Tuple2<Long, Long>> expectedResult = Arrays.asList(
-					new Tuple2<Long, Long>(1L,1L),
-					new Tuple2<Long, Long>(2L,1L),
-					new Tuple2<Long, Long>(3L,1L),
-					new Tuple2<Long, Long>(4L,1L),
-					new Tuple2<Long, Long>(5L,2L),
-					new Tuple2<Long, Long>(6L,1L),
-					new Tuple2<Long, Long>(7L,7L),
-					new Tuple2<Long, Long>(8L,7L),
-					new Tuple2<Long, Long>(9L,7L)
-			);
 			
 			assertEquals(expectedResult, result);
-		}
-		catch (Exception e) {
-			e.printStackTrace();
-			fail(e.getMessage());
-		}
+	}
+
+	@Test
+	public void testDeltaConnectedComponentsWithParametrizableConvergence() throws Exception
{
+			// connected components as a delta iteration that stops early via a custom aggregator convergence criterion
+			// name of the aggregator that checks for convergence
+			final String UPDATED_ELEMENTS = "updated.elements.aggr";
+
+			// the iteration stops if fewer than this number of elements change value
+			final long convergence_threshold = 3;
+
+			final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+
+			DataSet<Tuple2<Long, Long>> initialSolutionSet = env.fromCollection(verticesInput);
+			DataSet<Tuple2<Long, Long>> edges = env.fromCollection(edgesInput);
+
+			DeltaIteration<Tuple2<Long, Long>, Tuple2<Long, Long>> iteration =
+					initialSolutionSet.iterateDelta(initialSolutionSet, 10, 0);
+
+			// register the convergence criterion
+			iteration.registerAggregationConvergenceCriterion(UPDATED_ELEMENTS,
+					new LongSumAggregator(), new UpdatedElementsConvergenceCriterion(convergence_threshold));
+
+			DataSet<Tuple2<Long, Long>> verticesWithNewComponents = iteration.getWorkset().join(edges).where(0).equalTo(0)
+					.with(new NeighborWithComponentIDJoin())
+					.groupBy(0).min(1);
+
+			DataSet<Tuple2<Long, Long>> updatedComponentId =
+					verticesWithNewComponents.join(iteration.getSolutionSet()).where(0).equalTo(0)
+							.flatMap(new MinimumIdFilter(UPDATED_ELEMENTS));
+
+			List<Tuple2<Long, Long>> result = iteration.closeWith(updatedComponentId,
updatedComponentId).collect();
+			Collections.sort(result, new JavaProgramTestBase.TupleComparator<Tuple2<Long, Long>>());
+
+			assertEquals(expectedResult, result);
 	}
 	
 	@Test
-	public void testParameterizableAggregator() {
-		try {
-			List<Tuple2<Long, Long>> verticesInput = Arrays.asList(
-				new Tuple2<Long, Long>(1l,1l),
-				new Tuple2<Long, Long>(2l,2l),
-				new Tuple2<Long, Long>(3l,3l),
-				new Tuple2<Long, Long>(4l,4l),
-				new Tuple2<Long, Long>(5l,5l),
-				new Tuple2<Long, Long>(6l,6l),
-				new Tuple2<Long, Long>(7l,7l),
-				new Tuple2<Long, Long>(8l,8l),
-				new Tuple2<Long, Long>(9l,9l)
-			);
-			
-			List<Tuple2<Long, Long>> edgesInput = Arrays.asList(
-					new Tuple2<>(1l,2l),
-					new Tuple2<>(1l,3l),
-					new Tuple2<>(2l,3l),
-					new Tuple2<>(2l,4l),
-					new Tuple2<>(2l,1l),
-					new Tuple2<>(3l,1l),
-					new Tuple2<>(3l,2l),
-					new Tuple2<>(4l,2l),
-					new Tuple2<>(4l,6l),
-					new Tuple2<>(5l,6l),
-					new Tuple2<>(6l,4l),
-					new Tuple2<>(6l,5l),
-					new Tuple2<>(7l,8l),
-					new Tuple2<>(7l,9l),
-					new Tuple2<>(8l,7l),
-					new Tuple2<>(8l,9l),
-					new Tuple2<>(9l,7l),
-					new Tuple2<>(9l,8l)
-			);
+	public void testParameterizableAggregator() throws Exception {
 
 			final int MAX_ITERATIONS = 5;
 			final String AGGREGATOR_NAME = "elements.in.component.aggregator";
@@ -213,7 +210,7 @@ public class AggregatorConvergenceITCase extends MultipleProgramsTestBase
{
 					new Tuple2<>(9L,7L)
 			);
 
-			// checkpogram result
+			// check program result
 			assertEquals(expectedResult, result);
 
 			// check aggregators
@@ -226,11 +223,6 @@ public class AggregatorConvergenceITCase extends MultipleProgramsTestBase
{
 			assertEquals(4, aggr_values[1]);
 			assertEquals(5, aggr_values[2]);
 			assertEquals(6, aggr_values[3]);
-		}
-		catch (Exception e) {
-			e.printStackTrace();
-			fail(e.getMessage());
-		}
 	}
 	
 	// ------------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/flink/blob/8085aa98/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorsITCase.java
----------------------------------------------------------------------
diff --git a/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorsITCase.java
b/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorsITCase.java
index 4c5e955..042617d 100644
--- a/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorsITCase.java
+++ b/flink-tests/src/test/java/org/apache/flink/test/iterative/aggregators/AggregatorsITCase.java
@@ -272,6 +272,44 @@ public class AggregatorsITCase extends MultipleProgramsTestBase {
 				+ "5\n" + "5\n" + "5\n" + "5\n" + "5\n";
 	}
 
+	@Test
+	public void testConvergenceCriterionWithParameterForIterateDelta() throws Exception {
+		/*
+		 * Tests a parameterized aggregator convergence criterion registered on a delta iteration (iterateDelta).
+		 */
+
+		final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
+		env.setParallelism(parallelism);
+		// NOTE(review): TupleMakerMap draws random keys from [0, 100000); a collision would give duplicate solution-set keys
+		DataSet<Tuple2<Integer, Integer>> initialSolutionSet = CollectionDataSets.getIntegerDataSet(env).map(new
TupleMakerMap());
+
+		DeltaIteration<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> iteration
= initialSolutionSet.iterateDelta(
+				initialSolutionSet, MAX_ITERATIONS, 0);
+
+		// register aggregator (NOTE(review): redundant - the call below registers the same instance under the same name)
+		LongSumAggregator aggr = new LongSumAggregator();
+		iteration.registerAggregator(NEGATIVE_ELEMENTS_AGGR, aggr);
+
+		// register convergence criterion
+		iteration.registerAggregationConvergenceCriterion(NEGATIVE_ELEMENTS_AGGR, aggr,
+				new NegativeElementsConvergenceCriterionWithParam(3));
+
+		DataSet<Tuple2<Integer, Integer>> updatedDs = iteration.getWorkset().map(new
AggregateAndSubtractOneDelta());
+
+		DataSet<Tuple2<Integer, Integer>> newElements = updatedDs.join(iteration.getSolutionSet())
+				.where(0).equalTo(0).projectFirst(0, 1);
+
+		DataSet<Tuple2<Integer, Integer>> iterationRes = iteration.closeWith(newElements,
newElements);
+		DataSet<Integer> result = iterationRes.map(new ProjectSecondMapper());
+		result.writeAsText(resultPath);
+
+		env.execute();
+
+		expected = "-3\n" + "-2\n" + "-2\n" + "-1\n" + "-1\n"
+				+ "-1\n" + "0\n" + "0\n" + "0\n" + "0\n"
+				+ "1\n" + "1\n" + "1\n" + "1\n" + "1\n";
+	}
+
 	@SuppressWarnings("serial")
 	public static final class NegativeElementsConvergenceCriterion implements ConvergenceCriterion<LongValue>
{
 
@@ -313,9 +351,9 @@ public class AggregatorsITCase extends MultipleProgramsTestBase {
 
 		@Override
 		public Integer map(Integer value) {
-			Integer newValue = Integer.valueOf(value.intValue() - 1);
+			Integer newValue = value - 1;
 			// count negative numbers
-			if (newValue.intValue() < 0) {
+			if (newValue < 0) {
 				aggr.aggregate(1l);
 			}
 			return newValue;
@@ -334,9 +372,9 @@ public class AggregatorsITCase extends MultipleProgramsTestBase {
 
 		@Override
 		public Integer map(Integer value) {
-			Integer newValue = Integer.valueOf(value.intValue() - 1);
-			// count numbers less then the aggregator parameter
-			if ( newValue.intValue() < aggr.getValue() ) {
+			Integer newValue = value - 1;
+			// count numbers less than the aggregator parameter
+			if ( newValue < aggr.getValue() ) {
 				aggr.aggregate(1l);
 			}
 			return newValue;
@@ -369,8 +407,8 @@ public class AggregatorsITCase extends MultipleProgramsTestBase {
 
 		@Override
 		public Tuple2<Integer, Integer> map(Integer value) {
-			Integer nodeId = Integer.valueOf(rnd.nextInt(100000));
-			return new Tuple2<Integer, Integer>(nodeId, value);
+			Integer nodeId = rnd.nextInt(100000);
+			return new Tuple2<>(nodeId, value);
 		}
 
 	}
@@ -398,7 +436,7 @@ public class AggregatorsITCase extends MultipleProgramsTestBase {
 		@Override
 		public Tuple2<Integer, Integer> map(Tuple2<Integer, Integer> value) {
 			// count the elements that are equal to the superstep number
-			if (value.f1.intValue() == superstep) {
+			if (value.f1 == superstep) {
 				aggr.aggregate(1l);
 			}
 			return value;
@@ -436,48 +474,32 @@ public class AggregatorsITCase extends MultipleProgramsTestBase {
 	}
 
 	@SuppressWarnings("serial")
-	public static final class AggregateMapDeltaWithParam extends RichMapFunction<Tuple2<Integer,
Integer>, Tuple2<Integer, Integer>> {
+	public static final class AggregateAndSubtractOneDelta extends RichMapFunction<Tuple2<Integer,
Integer>, Tuple2<Integer, Integer>> {
 
-		private LongSumAggregatorWithParameter aggr;
+		private LongSumAggregator aggr;
 		private LongValue previousAggr;
 		private int superstep;
 
 		@Override
 		public void open(Configuration conf) {
-
 			aggr = getIterationRuntimeContext().getIterationAggregator(NEGATIVE_ELEMENTS_AGGR);
 			superstep = getIterationRuntimeContext().getSuperstepNumber();
 
 			if (superstep > 1) {
 				previousAggr = getIterationRuntimeContext().getPreviousIterationAggregate(NEGATIVE_ELEMENTS_AGGR);
-
 				// check previous aggregator value
-				switch(superstep) {
-					case 2: {
-						Assert.assertEquals(6, previousAggr.getValue());
-					}
-					case 3: {
-						Assert.assertEquals(5, previousAggr.getValue());
-					}
-					case 4: {
-						Assert.assertEquals(3, previousAggr.getValue());
-					}
-					case 5: {
-						Assert.assertEquals(0, previousAggr.getValue());
-					}
-					default:
-				}
-				Assert.assertEquals(superstep-1, previousAggr.getValue());
+				Assert.assertEquals(superstep - 1, previousAggr.getValue());
 			}
 
 		}
 
 		@Override
 		public Tuple2<Integer, Integer> map(Tuple2<Integer, Integer> value) {
-			// count the elements that are equal to the superstep number
-			if (value.f1.intValue() < aggr.getValue()) {
+			// count the ones
+			if (value.f1 == 1) {
 				aggr.aggregate(1l);
 			}
+			value.f1--;
 			return value;
 		}
 	}


Mime
View raw message