ignite-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From ch...@apache.org
Subject ignite git commit: IGNITE-8668: K-fold cross validation of models
Date Thu, 07 Jun 2018 13:46:27 GMT
Repository: ignite
Updated Branches:
  refs/heads/master 0dc8b4baf -> 17351b434


IGNITE-8668: K-fold cross validation of models

this closes #4143


Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/17351b43
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/17351b43
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/17351b43

Branch: refs/heads/master
Commit: 17351b4342bb5f75df1dae9b535ff1c9626b6bd9
Parents: 0dc8b4b
Author: Anton Dmitriev <dmitrievanthony@gmail.com>
Authored: Thu Jun 7 16:46:16 2018 +0300
Committer: Yury Babak <ybabak@gridgain.com>
Committed: Thu Jun 7 16:46:16 2018 +0300

----------------------------------------------------------------------
 .../CrossValidationScoreCalculatorExample.java  | 147 +++++++++++
 .../examples/ml/selection/cv/package-info.java  |  22 ++
 .../cv/CrossValidationScoreCalculator.java      | 255 +++++++++++++++++++
 .../ignite/ml/selection/cv/package-info.java    |  22 ++
 .../score/AccuracyScoreCalculator.java          |  47 ++++
 .../ml/selection/score/ScoreCalculator.java     |  35 +++
 .../ml/selection/score/TruthWithPrediction.java |  52 ++++
 .../ignite/ml/selection/score/package-info.java |  22 ++
 .../CacheBasedTruthWithPredictionCursor.java    | 124 +++++++++
 .../util/LocalTruthWithPredictionCursor.java    | 137 ++++++++++
 .../score/util/TruthWithPredictionCursor.java   |  29 +++
 .../ml/selection/score/util/package-info.java   |  22 ++
 .../org/apache/ignite/ml/IgniteMLTestSuite.java |   4 +-
 .../ignite/ml/selection/SelectionTestSuite.java |  43 ++++
 .../cv/CrossValidationScoreCalculatorTest.java  |  95 +++++++
 .../score/AccuracyScoreCalculatorTest.java      |  44 ++++
 .../score/TestTruthWithPredictionCursor.java    |  91 +++++++
 ...CacheBasedTruthWithPredictionCursorTest.java |  78 ++++++
 .../LocalTruthWithPredictionCursorTest.java     |  54 ++++
 19 files changed, 1322 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationScoreCalculatorExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationScoreCalculatorExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationScoreCalculatorExample.java
new file mode 100644
index 0000000..3cec830
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationScoreCalculatorExample.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.selection.cv;
+
+import java.util.Arrays;
+import java.util.Random;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.examples.ml.tree.DecisionTreeClassificationTrainerExample;
+import org.apache.ignite.ml.selection.cv.CrossValidationScoreCalculator;
+import org.apache.ignite.ml.selection.score.AccuracyScoreCalculator;
+import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
+import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.thread.IgniteThread;
+
+/**
+ * Run decision tree classification with cross validation.
+ *
+ * @see CrossValidationScoreCalculator
+ */
+public class CrossValidationScoreCalculatorExample {
+    /**
+     * Executes example: fills a cache with generated labeled points and prints per-fold accuracy of a decision tree.
+     *
+     * @param args Command line arguments, none required.
+     */
+    public static void main(String... args) throws InterruptedException {
+        System.out.println(">>> Cross validation score calculator example started.");
+
+        // Start ignite grid.
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            System.out.println(">>> Ignite grid started.");
+
+            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
+                CrossValidationScoreCalculatorExample.class.getSimpleName(), () -> {
+
+                // Create cache with training data.
+                CacheConfiguration<Integer, LabeledPoint> trainingSetCfg = new CacheConfiguration<>();
+                trainingSetCfg.setName("TRAINING_SET");
+                trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
+
+                IgniteCache<Integer, LabeledPoint> trainingSet = ignite.createCache(trainingSetCfg);
+
+                // Fixed seed keeps the generated data set the same from run to run.
+                Random rnd = new Random(0);
+
+                // Fill training data.
+                for (int i = 0; i < 1000; i++)
+                    trainingSet.put(i, generatePoint(rnd));
+
+                // Create classification trainer.
+                DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
+
+                CrossValidationScoreCalculator<DecisionTreeNode, Double, Integer, LabeledPoint> scoreCalculator
+                    = new CrossValidationScoreCalculator<>();
+                // Run cross validation with 4 folds, measuring accuracy on each fold.
+                double[] scores = scoreCalculator.score(
+                    trainer,
+                    new AccuracyScoreCalculator<>(),
+                    ignite,
+                    trainingSet,
+                    (k, v) -> new double[]{v.x, v.y},
+                    (k, v) -> v.lb,
+                    4
+                );
+
+                System.out.println(">>> Accuracy: " + Arrays.toString(scores));
+
+                System.out.println(">>> Cross validation score calculator example completed.");
+            });
+
+            igniteThread.start();
+
+            igniteThread.join();
+        }
+    }
+
+    /**
+     * Generate point with {@code x} in [-0.5, 0.5) and {@code y} in the same interval. If {@code x * y > 0} then label
+     * is 1, otherwise 0.
+     *
+     * @param rnd Random.
+     * @return Point with label.
+     */
+    private static LabeledPoint generatePoint(Random rnd) {
+
+        double x = rnd.nextDouble() - 0.5;
+        double y = rnd.nextDouble() - 0.5;
+
+        return new LabeledPoint(x, y, x * y > 0 ? 1 : 0);
+    }
+
+    /** Point data class. */
+    private static class Point {
+        /** X coordinate. */
+        final double x;
+
+        /** Y coordinate. */
+        final double y;
+
+        /**
+         * Constructs a new instance of point.
+         *
+         * @param x X coordinate.
+         * @param y Y coordinate.
+         */
+        Point(double x, double y) {
+            this.x = x;
+            this.y = y;
+        }
+    }
+
+    /** Labeled point data class. */
+    private static class LabeledPoint extends Point {
+        /** Point label. */
+        final double lb;
+
+        /**
+         * Constructs a new instance of labeled point data.
+         *
+         * @param x X coordinate.
+         * @param y Y coordinate.
+         * @param lb Point label.
+         */
+        LabeledPoint(double x, double y, double lb) {
+            super(x, y);
+            this.lb = lb;
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/package-info.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/package-info.java
new file mode 100644
index 0000000..249d8b8
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * ML cross validation examples.
+ */
+package org.apache.ignite.examples.ml.selection.cv;

http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidationScoreCalculator.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidationScoreCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidationScoreCalculator.java
new file mode 100644
index 0000000..f885c3e
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidationScoreCalculator.java
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.cv;
+
+import java.util.Map;
+import java.util.function.BiFunction;
+import java.util.function.Function;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.lang.IgniteBiPredicate;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.selection.score.ScoreCalculator;
+import org.apache.ignite.ml.selection.score.util.CacheBasedTruthWithPredictionCursor;
+import org.apache.ignite.ml.selection.score.util.LocalTruthWithPredictionCursor;
+import org.apache.ignite.ml.selection.score.util.TruthWithPredictionCursor;
+import org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper;
+import org.apache.ignite.ml.selection.split.mapper.UniformMapper;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+
+/**
+ * Cross validation score calculator. Cross validation is an approach that helps to avoid overfitting. It works the
+ * following way: the training set is split into k smaller sets. The following procedure is followed for each of the k
+ * “folds”:
+ * <ul>
+ *     <li>A model is trained using k-1 of the folds as training data;</li>
+ *     <li>the resulting model is validated on the remaining part of the data (i.e., it is used as a test set to compute
+ *     a performance measure such as accuracy).</li>
+ * </ul>
+ *
+ * @param <M> Type of model.
+ * @param <L> Type of a label (truth or prediction).
+ * @param <K> Type of a key in {@code upstream} data.
+ * @param <V> Type of a value in {@code upstream} data.
+ */
+public class CrossValidationScoreCalculator<M extends Model<double[], L>, L, K, V> {
+    /**
+     * Computes cross-validated metrics on cache data with no extra filter and a {@link SHA256UniformMapper}.
+     *
+     * @param trainer Trainer of the model.
+     * @param scoreCalculator Score calculator.
+     * @param ignite Ignite instance.
+     * @param upstreamCache Ignite cache with {@code upstream} data.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param cv Number of folds.
+     * @return Array of scores of the estimator for each run of the cross validation.
+     */
+    public double[] score(DatasetTrainer<M, L> trainer, ScoreCalculator<L> scoreCalculator, Ignite ignite,
+        IgniteCache<K, V> upstreamCache, IgniteBiFunction<K, V, double[]> featureExtractor,
+        IgniteBiFunction<K, V, L> lbExtractor, int cv) {
+        return score(trainer, scoreCalculator, ignite, upstreamCache, (k, v) -> true, featureExtractor, lbExtractor,
+            new SHA256UniformMapper<>(), cv);
+    }
+
+    /**
+     * Computes cross-validated metrics on cache data with the given filter and a {@link SHA256UniformMapper}.
+     *
+     * @param trainer Trainer of the model.
+     * @param scoreCalculator Base score calculator.
+     * @param ignite Ignite instance.
+     * @param upstreamCache Ignite cache with {@code upstream} data.
+     * @param filter Base {@code upstream} data filter.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param cv Number of folds.
+     * @return Array of scores of the estimator for each run of the cross validation.
+     */
+    public double[] score(DatasetTrainer<M, L> trainer, ScoreCalculator<L> scoreCalculator, Ignite ignite,
+        IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter,
+        IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv) {
+        return score(trainer, scoreCalculator, ignite, upstreamCache, filter, featureExtractor, lbExtractor,
+            new SHA256UniformMapper<>(), cv);
+    }
+
+    /**
+     * Computes cross-validated metrics on cache data with the given filter and the given mapper.
+     *
+     * @param trainer Trainer of the model.
+     * @param scoreCalculator Base score calculator.
+     * @param ignite Ignite instance.
+     * @param upstreamCache Ignite cache with {@code upstream} data.
+     * @param filter Base {@code upstream} data filter.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param mapper Mapper used to map a key-value pair to a point on the segment (0, 1).
+     * @param cv Number of folds.
+     * @return Array of scores of the estimator for each run of the cross validation.
+     */
+    public double[] score(DatasetTrainer<M, L> trainer, ScoreCalculator<L> scoreCalculator,
+        Ignite ignite, IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter,
+        IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor,
+        UniformMapper<K, V> mapper, int cv) {
+        return score(
+            trainer,
+            predicate -> new CacheBasedDatasetBuilder<>(
+                ignite,
+                upstreamCache,
+                (k, v) -> filter.apply(k, v) && predicate.apply(k, v)
+            ),
+            (predicate, mdl) -> new CacheBasedTruthWithPredictionCursor<>(
+                upstreamCache,
+                (k, v) -> filter.apply(k, v) && !predicate.apply(k, v),
+                featureExtractor,
+                lbExtractor,
+                mdl
+            ),
+            featureExtractor,
+            lbExtractor,
+            scoreCalculator,
+            mapper,
+            cv
+        );
+    }
+
+    /**
+     * Computes cross-validated metrics on a local map with no extra filter and a {@link SHA256UniformMapper}.
+     *
+     * @param trainer Trainer of the model.
+     * @param scoreCalculator Base score calculator.
+     * @param upstreamMap Map with {@code upstream} data.
+     * @param parts Number of partitions.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param cv Number of folds.
+     * @return Array of scores of the estimator for each run of the cross validation.
+     */
+    public double[] score(DatasetTrainer<M, L> trainer, ScoreCalculator<L> scoreCalculator, Map<K, V> upstreamMap,
+        int parts, IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv) {
+        return score(trainer, scoreCalculator, upstreamMap, (k, v) -> true, parts, featureExtractor, lbExtractor,
+            new SHA256UniformMapper<>(), cv);
+    }
+
+    /**
+     * Computes cross-validated metrics on a local map with the given filter and a {@link SHA256UniformMapper}.
+     *
+     * @param trainer Trainer of the model.
+     * @param scoreCalculator Base score calculator.
+     * @param upstreamMap Map with {@code upstream} data.
+     * @param filter Base {@code upstream} data filter.
+     * @param parts Number of partitions.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param cv Number of folds.
+     * @return Array of scores of the estimator for each run of the cross validation.
+     */
+    public double[] score(DatasetTrainer<M, L> trainer, ScoreCalculator<L> scoreCalculator, Map<K, V> upstreamMap,
+        IgniteBiPredicate<K, V> filter, int parts, IgniteBiFunction<K, V, double[]> featureExtractor,
+        IgniteBiFunction<K, V, L> lbExtractor, int cv) {
+        return score(trainer, scoreCalculator, upstreamMap, filter, parts, featureExtractor, lbExtractor,
+            new SHA256UniformMapper<>(), cv);
+    }
+
+    /**
+     * Computes cross-validated metrics on a local map with the given filter and the given mapper.
+     *
+     * @param trainer Trainer of the model.
+     * @param scoreCalculator Base score calculator.
+     * @param upstreamMap Map with {@code upstream} data.
+     * @param filter Base {@code upstream} data filter.
+     * @param parts Number of partitions.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param mapper Mapper used to map a key-value pair to a point on the segment (0, 1).
+     * @param cv Number of folds.
+     * @return Array of scores of the estimator for each run of the cross validation.
+     */
+    public double[] score(DatasetTrainer<M, L> trainer, ScoreCalculator<L> scoreCalculator, Map<K, V> upstreamMap,
+        IgniteBiPredicate<K, V> filter, int parts, IgniteBiFunction<K, V, double[]> featureExtractor,
+        IgniteBiFunction<K, V, L> lbExtractor, UniformMapper<K, V> mapper, int cv) {
+        return score(
+            trainer,
+            predicate -> new LocalDatasetBuilder<>(
+                upstreamMap,
+                (k, v) -> filter.apply(k, v) && predicate.apply(k, v),
+                parts
+            ),
+            (predicate, mdl) -> new LocalTruthWithPredictionCursor<>(
+                upstreamMap,
+                (k, v) -> filter.apply(k, v) && !predicate.apply(k, v),
+                featureExtractor,
+                lbExtractor,
+                mdl
+            ),
+            featureExtractor,
+            lbExtractor,
+            scoreCalculator,
+            mapper,
+            cv
+        );
+    }
+
+    /**
+     * Computes cross-validated metrics: for every fold a model is trained on the other data and scored on that fold.
+     *
+     * @param trainer Trainer of the model.
+     * @param datasetBuilderSupplier Supplier of a dataset builder for the given train set filter.
+     * @param testDataIterSupplier Supplier of a test data cursor for the given train set filter and trained model.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @param scoreCalculator Base score calculator.
+     * @param mapper Mapper used to map a key-value pair to a point on the segment (0, 1).
+     * @param cv Number of folds.
+     * @return Array of scores of the estimator for each run of the cross validation.
+     */
+    private double[] score(DatasetTrainer<M, L> trainer, Function<IgniteBiPredicate<K, V>,
+        DatasetBuilder<K, V>> datasetBuilderSupplier,
+        BiFunction<IgniteBiPredicate<K, V>, M, TruthWithPredictionCursor<L>> testDataIterSupplier,
+        IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor,
+        ScoreCalculator<L> scoreCalculator, UniformMapper<K, V> mapper, int cv) {
+
+        double[] scores = new double[cv];
+
+        double foldSize = 1.0 / cv;
+        for (int i = 0; i < cv; i++) {
+            double from = foldSize * i;
+            double to = foldSize * (i + 1);
+            // Train on entries mapped outside of [from, to]; this filter is also handed to the test data supplier.
+            IgniteBiPredicate<K, V> trainSetFilter = (k, v) -> {
+                double pnt = mapper.map(k, v);
+                return pnt < from || pnt > to;
+            };
+
+            DatasetBuilder<K, V> datasetBuilder = datasetBuilderSupplier.apply(trainSetFilter);
+            M mdl = trainer.fit(datasetBuilder, featureExtractor, lbExtractor);
+
+            try (TruthWithPredictionCursor<L> cursor = testDataIterSupplier.apply(trainSetFilter, mdl)) {
+                scores[i] = scoreCalculator.score(cursor.iterator());
+            }
+            catch (Exception e) {
+                throw new RuntimeException(e);
+            }
+        }
+
+        return scores;
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/package-info.java
new file mode 100644
index 0000000..66f129f
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Root package for cross-validation algorithms.
+ */
+package org.apache.ignite.ml.selection.cv;
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/AccuracyScoreCalculator.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/AccuracyScoreCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/AccuracyScoreCalculator.java
new file mode 100644
index 0000000..c9e61f9
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/AccuracyScoreCalculator.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.score;
+
+import java.util.Iterator;
+
+/**
+ * Accuracy score calculator.
+ *
+ * @param <L> Type of a label (truth or prediction).
+ */
+public class AccuracyScoreCalculator<L> implements ScoreCalculator<L> {
+    /** {@inheritDoc} The score is the share of matching pairs; {@code NaN} if {@code iter} supplies no pairs. */
+    @Override public double score(Iterator<TruthWithPrediction<L>> iter) {
+        long totalCnt = 0;
+        long correctCnt = 0;
+
+        while (iter.hasNext()) {
+            TruthWithPrediction<L> e = iter.next();
+
+            L prediction = e.getPrediction();
+            L truth = e.getTruth();
+            // Null-safe comparison: two nulls match, a single null is a miss (no NPE on a null prediction).
+            if (prediction == null ? truth == null : prediction.equals(truth))
+                correctCnt++;
+
+            totalCnt++;
+        }
+
+        return 1.0 * correctCnt / totalCnt;
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/ScoreCalculator.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/ScoreCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/ScoreCalculator.java
new file mode 100644
index 0000000..e792532
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/ScoreCalculator.java
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.score;
+
+import java.util.Iterator;
+
+/**
+ * Base interface for score calculators.
+ *
+ * @param <L> Type of a label (truth or prediction).
+ */
+public interface ScoreCalculator<L> {
+    /**
+     * Calculates score.
+     *
+     * @param iter Iterator that supplies pairs of truth values and predicted values.
+     * @return Score.
+     */
+    public double score(Iterator<TruthWithPrediction<L>> iter);
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/TruthWithPrediction.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/TruthWithPrediction.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/TruthWithPrediction.java
new file mode 100644
index 0000000..1b96b61
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/TruthWithPrediction.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.score;
+
+/**
+ * Pair of a truth value and the value predicted by a model.
+ *
+ * @param <L> Type of a label (truth or prediction).
+ */
+public class TruthWithPrediction<L> {
+    /** Truth value. */
+    private final L truth;
+
+    /** Predicted value. */
+    private final L prediction;
+
+    /**
+     * Constructs a new instance of truth with prediction.
+     *
+     * @param truth Truth value.
+     * @param prediction Predicted value.
+     */
+    public TruthWithPrediction(L truth, L prediction) {
+        this.truth = truth;
+        this.prediction = prediction;
+    }
+
+    /** @return Truth value passed to the constructor. */
+    public L getTruth() {
+        return truth;
+    }
+
+    /** @return Predicted value passed to the constructor. */
+    public L getPrediction() {
+        return prediction;
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/package-info.java
new file mode 100644
index 0000000..9656a59
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Root package for score calculators.
+ */
+package org.apache.ignite.ml.selection.score;
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursor.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursor.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursor.java
new file mode 100644
index 0000000..862c7ab
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursor.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.score.util;
+
+import java.util.Iterator;
+import javax.cache.Cache;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.cache.query.QueryCursor;
+import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.lang.IgniteBiPredicate;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.selection.score.TruthWithPrediction;
+import org.jetbrains.annotations.NotNull;
+
+/**
+ * Cursor over (truth, prediction) pairs built from entries returned by a filtered scan query on an Ignite cache.
+ *
+ * @param <L> Type of a label (truth or prediction).
+ * @param <K> Type of a key in {@code upstream} data.
+ * @param <V> Type of a value in {@code upstream} data.
+ */
+public class CacheBasedTruthWithPredictionCursor<L, K, V> implements TruthWithPredictionCursor<L> {
+    /** Cursor of the scan query opened in the constructor; closed by {@link #close()}. */
+    private final QueryCursor<Cache.Entry<K, V>> cursor;
+
+    /** Extracts the feature vector passed to the model from an entry's key and value. */
+    private final IgniteBiFunction<K, V, double[]> featureExtractor;
+
+    /** Extracts the truth label from an entry's key and value. */
+    private final IgniteBiFunction<K, V, L> lbExtractor;
+
+    /** Model applied to the extracted features to produce the prediction. */
+    private final Model<double[], L> mdl;
+
+    /**
+     * Creates the cursor and opens a scan query cursor on {@code upstreamCache} right away.
+     *
+     * @param upstreamCache Ignite cache with {@code upstream} data.
+     * @param filter Filter set on the scan query; it decides which entries of {@code upstream} data are iterated.
+     * @param featureExtractor Function that turns a key and a value into the model input.
+     * @param lbExtractor Function that turns a key and a value into the truth label.
+     * @param mdl Model used to compute a prediction for every iterated entry.
+     */
+    public CacheBasedTruthWithPredictionCursor(IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter,
+        IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor,
+        Model<double[], L> mdl) {
+        this.cursor = query(upstreamCache, filter);
+        this.featureExtractor = featureExtractor;
+        this.lbExtractor = lbExtractor;
+        this.mdl = mdl;
+    }
+
+    /** Closes the underlying query cursor. */
+    @Override public void close() {
+        cursor.close();
+    }
+
+    /** Wraps {@code cursor.iterator()}. NOTE(review): the query cursor is presumably single-pass - confirm. */
+    @NotNull @Override public Iterator<TruthWithPrediction<L>> iterator() {
+        return new TruthWithPredictionIterator(cursor.iterator());
+    }
+
+    /**
+     * Runs a {@link ScanQuery} with the given filter against the given cache.
+     *
+     * @param upstreamCache Ignite cache with {@code upstream} data.
+     * @param filter Filter for {@code upstream} data.
+     * @return Cursor returned by the cache for the scan query.
+     */
+    private QueryCursor<Cache.Entry<K, V>> query(IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter) {
+        ScanQuery<K, V> qry = new ScanQuery<>();
+        qry.setFilter(filter);
+
+        return upstreamCache.query(qry);
+    }
+
+    /**
+     * Iterator that maps every cache entry to its truth label and the model's prediction for its features.
+     */
+    private class TruthWithPredictionIterator implements Iterator<TruthWithPrediction<L>> {
+        /** Iterator over the cache entries supplied by the query cursor. */
+        private final Iterator<Cache.Entry<K, V>> iter;
+
+        /**
+         * Constructs a new instance of truth with prediction iterator.
+         *
+         * @param iter Iterator over cache entries to be converted.
+         */
+        public TruthWithPredictionIterator(Iterator<Cache.Entry<K, V>> iter) {
+            this.iter = iter;
+        }
+
+        /** Delegates to the base iterator; filtering has already been applied by the scan query. */
+        @Override public boolean hasNext() {
+            return iter.hasNext();
+        }
+
+        /** Takes the next entry, extracts its features and label and pairs the label with the prediction. */
+        @Override public TruthWithPrediction<L> next() {
+            Cache.Entry<K, V> entry = iter.next();
+
+            double[] features = featureExtractor.apply(entry.getKey(), entry.getValue());
+            L lb = lbExtractor.apply(entry.getKey(), entry.getValue());
+
+            return new TruthWithPrediction<>(lb, mdl.apply(features));
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursor.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursor.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursor.java
new file mode 100644
index 0000000..093c6ed
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursor.java
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.score.util;
+
+import java.util.Iterator;
+import java.util.Map;
+import java.util.NoSuchElementException;
+import org.apache.ignite.lang.IgniteBiPredicate;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.selection.score.TruthWithPrediction;
+import org.jetbrains.annotations.NotNull;
+
+/**
+ * Cursor over (truth, prediction) pairs built from the entries of a local {@link Map} that pass a filter.
+ *
+ * @param <L> Type of a label (truth or prediction).
+ * @param <K> Type of a key in {@code upstream} data.
+ * @param <V> Type of a value in {@code upstream} data.
+ */
+public class LocalTruthWithPredictionCursor<L, K, V> implements TruthWithPredictionCursor<L> {
+    /** Map with {@code upstream} data; it is referenced, not copied. */
+    private final Map<K, V> upstreamMap;
+
+    /** Predicate applied during iteration; entries it rejects are skipped. */
+    private final IgniteBiPredicate<K, V> filter;
+
+    /** Extracts the feature vector passed to the model from an entry's key and value. */
+    private final IgniteBiFunction<K, V, double[]> featureExtractor;
+
+    /** Extracts the truth label from an entry's key and value. */
+    private final IgniteBiFunction<K, V, L> lbExtractor;
+
+    /** Model applied to the extracted features to produce the prediction. */
+    private final Model<double[], L> mdl;
+
+    /**
+     * Constructs a new instance of local truth with prediction cursor. Nothing is evaluated until iteration starts.
+     *
+     * @param upstreamMap Map with {@code upstream} data.
+     * @param filter Filter for {@code upstream} data.
+     * @param featureExtractor Function that turns a key and a value into the model input.
+     * @param lbExtractor Function that turns a key and a value into the truth label.
+     * @param mdl Model used to compute a prediction for every accepted entry.
+     */
+    public LocalTruthWithPredictionCursor(Map<K, V> upstreamMap, IgniteBiPredicate<K, V> filter,
+        IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor,
+        Model<double[], L> mdl) {
+        this.upstreamMap = upstreamMap;
+        this.filter = filter;
+        this.featureExtractor = featureExtractor;
+        this.lbExtractor = lbExtractor;
+        this.mdl = mdl;
+    }
+
+    /** No-op: this cursor holds no resource besides the map reference. */
+    @Override public void close() {
+        /* Nothing to release. */
+    }
+
+    /** Returns a new filtering iterator over the map's entry set on every call. */
+    @NotNull @Override public Iterator<TruthWithPrediction<L>> iterator() {
+        return new TruthWithPredictionIterator(upstreamMap.entrySet().iterator());
+    }
+
+    /**
+     * Iterator that skips entries rejected by the filter and pairs each accepted entry's label with a prediction.
+     */
+    private class TruthWithPredictionIterator implements Iterator<TruthWithPrediction<L>> {
+        /** Iterator over all entries of the map, accepted or not. */
+        private final Iterator<Map.Entry<K, V>> iter;
+
+        /** Entry found by {@link #findNext()} and not yet returned; {@code null} when nothing is buffered. */
+        private Map.Entry<K, V> nextEntry;
+
+        /**
+         * Constructs a new instance of truth with prediction iterator.
+         *
+         * @param iter Iterator over the map entries to be filtered and converted.
+         */
+        public TruthWithPredictionIterator(Iterator<Map.Entry<K, V>> iter) {
+            this.iter = iter;
+        }
+
+        /** Buffers the next accepted entry if none is buffered yet and reports whether one is available. */
+        @Override public boolean hasNext() {
+            findNext();
+
+            return nextEntry != null;
+        }
+
+        /** Converts the buffered entry and clears the buffer; throws {@link NoSuchElementException} if none is left. */
+        @Override public TruthWithPrediction<L> next() {
+            if (!hasNext())
+                throw new NoSuchElementException();
+
+            K key = nextEntry.getKey();
+            V val = nextEntry.getValue();
+
+            double[] features = featureExtractor.apply(key, val);
+            L lb = lbExtractor.apply(key, val);
+
+            nextEntry = null;
+
+            return new TruthWithPrediction<>(lb, mdl.apply(features));
+        }
+
+        /**
+         * Advances the base iterator up to the first entry accepted by the filter; no-op while an entry is buffered.
+         */
+        private void findNext() {
+            while (nextEntry == null && iter.hasNext()) {
+                Map.Entry<K, V> entry = iter.next();
+
+                if (filter.apply(entry.getKey(), entry.getValue())) {
+                    this.nextEntry = entry;
+                    break;
+                }
+            }
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/TruthWithPredictionCursor.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/TruthWithPredictionCursor.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/TruthWithPredictionCursor.java
new file mode 100644
index 0000000..525d743
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/TruthWithPredictionCursor.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.score.util;
+
+import org.apache.ignite.ml.selection.score.TruthWithPrediction;
+
+/**
+ * Closeable iterable that supplies pairs of truth and prediction. It hides the difference between reading data from
+ * an Ignite cache and from a local {@link java.util.Map}.
+ *
+ * @param <L> Type of a label (truth or prediction).
+ */
+public interface TruthWithPredictionCursor<L> extends Iterable<TruthWithPrediction<L>>, AutoCloseable {
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/package-info.java
new file mode 100644
index 0000000..9c86317
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Util classes used for score calculation.
+ */
+package org.apache.ignite.ml.selection.score.util;
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
index 0c3408e..9f60c48 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
@@ -25,6 +25,7 @@ import org.apache.ignite.ml.math.MathImplMainTestSuite;
 import org.apache.ignite.ml.nn.MLPTestSuite;
 import org.apache.ignite.ml.preprocessing.PreprocessingTestSuite;
 import org.apache.ignite.ml.regressions.RegressionsTestSuite;
+import org.apache.ignite.ml.selection.SelectionTestSuite;
 import org.apache.ignite.ml.svm.SVMTestSuite;
 import org.apache.ignite.ml.tree.DecisionTreeTestSuite;
 import org.junit.runner.RunWith;
@@ -45,7 +46,8 @@ import org.junit.runners.Suite;
     MLPTestSuite.class,
     DatasetTestSuite.class,
     PreprocessingTestSuite.class,
-    GAGridTestSuite.class
+    GAGridTestSuite.class,
+    SelectionTestSuite.class
 })
 public class IgniteMLTestSuite {
     // No-op.

http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java
new file mode 100644
index 0000000..e8db932
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection;
+
+import org.apache.ignite.ml.selection.cv.CrossValidationScoreCalculatorTest;
+import org.apache.ignite.ml.selection.score.AccuracyScoreCalculatorTest;
+import org.apache.ignite.ml.selection.score.util.CacheBasedTruthWithPredictionCursorTest;
+import org.apache.ignite.ml.selection.score.util.LocalTruthWithPredictionCursorTest;
+import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitterTest;
+import org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapperTest;
+import org.junit.runner.RunWith;
+import org.junit.runners.Suite;
+
+/**
+ * Test suite that runs the tests located in the {@code org.apache.ignite.ml.selection.*} packages.
+ */
+@RunWith(Suite.class)
+@Suite.SuiteClasses({
+    CrossValidationScoreCalculatorTest.class,
+    CacheBasedTruthWithPredictionCursorTest.class,
+    LocalTruthWithPredictionCursorTest.class,
+    AccuracyScoreCalculatorTest.class,
+    SHA256UniformMapperTest.class,
+    TrainTestDatasetSplitterTest.class
+})
+public class SelectionTestSuite {
+    // No-op: the suite is defined entirely by the annotations above.
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationScoreCalculatorTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationScoreCalculatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationScoreCalculatorTest.java
new file mode 100644
index 0000000..3679f1b
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationScoreCalculatorTest.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.cv;
+
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.ml.selection.score.AccuracyScoreCalculator;
+import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
+import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.junit.Test;
+
+import static junit.framework.TestCase.assertTrue;
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests for {@link CrossValidationScoreCalculator}.
+ */
+public class CrossValidationScoreCalculatorTest {
+    /** Labels depend on a single threshold of the key (the only feature): every fold must score within 0.1 of 1. */
+    @Test
+    public void testScoreWithGoodDataset() {
+        Map<Integer, Double> data = new HashMap<>();
+
+        for (int i = 0; i < 1000; i++)
+            data.put(i, i > 500 ? 1.0 : 0.0);
+
+        DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0);
+
+        CrossValidationScoreCalculator<DecisionTreeNode, Double, Integer, Double> scoreCalculator =
+            new CrossValidationScoreCalculator<>();
+
+        int folds = 4;
+
+        double[] scores = scoreCalculator.score(
+            trainer,
+            new AccuracyScoreCalculator<>(),
+            data,
+            1,
+            (k, v) -> new double[]{k},
+            (k, v) -> v,
+            folds
+        );
+
+        assertEquals(folds, scores.length);
+
+        for (int i = 0; i < folds; i++)
+            assertEquals(1, scores[i], 1e-1);
+    }
+
+    /** Labels alternate with the parity of the key (the only feature): every fold must score below 0.6. */
+    @Test
+    public void testScoreWithBadDataset() {
+        Map<Integer, Double> data = new HashMap<>();
+
+        for (int i = 0; i < 1000; i++)
+            data.put(i, i % 2 == 0 ? 1.0 : 0.0);
+
+        DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0);
+
+        CrossValidationScoreCalculator<DecisionTreeNode, Double, Integer, Double> scoreCalculator =
+            new CrossValidationScoreCalculator<>();
+
+        int folds = 4;
+
+        double[] scores = scoreCalculator.score(
+            trainer,
+            new AccuracyScoreCalculator<>(),
+            data,
+            1,
+            (k, v) -> new double[]{k},
+            (k, v) -> v,
+            folds
+        );
+
+        assertEquals(folds, scores.length);
+
+        for (int i = 0; i < folds; i++)
+            assertTrue(scores[i] < 0.6);
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/AccuracyScoreCalculatorTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/AccuracyScoreCalculatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/AccuracyScoreCalculatorTest.java
new file mode 100644
index 0000000..b79bd9a
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/AccuracyScoreCalculatorTest.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.score;
+
+import java.util.Arrays;
+import org.apache.ignite.ml.selection.score.util.TruthWithPredictionCursor;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests for {@link AccuracyScoreCalculator}.
+ */
+public class AccuracyScoreCalculatorTest {
+    /** Three of the four predictions equal the truth, so the expected score is 0.75. */
+    @Test
+    public void testScore() {
+        ScoreCalculator<Integer> scoreCalculator = new AccuracyScoreCalculator<>();
+
+        TruthWithPredictionCursor<Integer> cursor = new TestTruthWithPredictionCursor<>(
+            Arrays.asList(1, 1, 1, 1),
+            Arrays.asList(1, 1, 0, 1)
+        );
+
+        double score = scoreCalculator.score(cursor.iterator());
+
+        assertEquals(0.75, score, 1e-12);
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/TestTruthWithPredictionCursor.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/TestTruthWithPredictionCursor.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/TestTruthWithPredictionCursor.java
new file mode 100644
index 0000000..db4f936
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/TestTruthWithPredictionCursor.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.score;
+
+import java.util.Iterator;
+import java.util.List;
+import org.apache.ignite.ml.selection.score.util.TruthWithPredictionCursor;
+import org.jetbrains.annotations.NotNull;
+
+/**
+ * Test cursor that pairs two parallel lists (truth values and predicted values) element by element.
+ *
+ * @param <L> Type of a label (truth or prediction).
+ */
+public class TestTruthWithPredictionCursor<L> implements TruthWithPredictionCursor<L> {
+    /** List of truth values. */
+    private final List<L> truth;
+
+    /** List of predicted values, positionally aligned with {@link #truth}. */
+    private final List<L> predicted;
+
+    /**
+     * Constructs a new instance of test truth with prediction cursor.
+     *
+     * @param truth List of truth values.
+     * @param predicted List of predicted values.
+     */
+    public TestTruthWithPredictionCursor(List<L> truth, List<L> predicted) {
+        this.truth = truth;
+        this.predicted = predicted;
+    }
+
+    /** No-op: the cursor is backed by the two lists only. */
+    @Override public void close() throws Exception {
+        /* Nothing to release. */
+    }
+
+    /** Returns a new iterator that walks both lists in lockstep. */
+    @NotNull @Override public Iterator<TruthWithPrediction<L>> iterator() {
+        return new TestTruthWithPredictionIterator<>(truth.iterator(), predicted.iterator());
+    }
+
+    /**
+     * Iterator that zips an iterator of truth values with an iterator of predicted values.
+     *
+     * @param <L> Type of a label (truth or prediction).
+     */
+    private static final class TestTruthWithPredictionIterator<L> implements Iterator<TruthWithPrediction<L>> {
+        /** Iterator of truth values. */
+        private final Iterator<L> truthIter;
+
+        /** Iterator of predicted values. */
+        private final Iterator<L> predictedIter;
+
+        /**
+         * Constructs a new instance of test truth with prediction iterator.
+         *
+         * @param truthIter Iterator of truth values.
+         * @param predictedIter Iterator of predicted values.
+         */
+        public TestTruthWithPredictionIterator(Iterator<L> truthIter, Iterator<L> predictedIter) {
+            this.truthIter = truthIter;
+            this.predictedIter = predictedIter;
+        }
+
+        /** Reports {@code true} only while both iterators have elements, so iteration stops at the shorter list. */
+        @Override public boolean hasNext() {
+            return truthIter.hasNext() && predictedIter.hasNext();
+        }
+
+        /** Pairs the next truth value with the next predicted value. */
+        @Override public TruthWithPrediction<L> next() {
+            return new TruthWithPrediction<>(truthIter.next(), predictedIter.next());
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursorTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursorTest.java
new file mode 100644
index 0000000..7eba10f
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursorTest.java
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.score.util;
+
+import java.util.UUID;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.ml.selection.score.TruthWithPrediction;
+import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
+
+/**
+ * Tests for {@link CacheBasedTruthWithPredictionCursor}.
+ */
+public class CacheBasedTruthWithPredictionCursorTest extends GridCommonAbstractTest {
+    /** Number of nodes in grid. */
+    private static final int NODE_COUNT = 4;
+
+    /** Ignite instance used by the test; assigned in {@link #beforeTest()}. */
+    private Ignite ignite;
+
+    /** Starts {@link #NODE_COUNT} grid nodes with indexes from 1 to {@link #NODE_COUNT}. */
+    @Override protected void beforeTestsStarted() throws Exception {
+        for (int i = 1; i <= NODE_COUNT; i++)
+            startGrid(i);
+    }
+
+    /** Stops all grid nodes. */
+    @Override protected void afterTestsStopped() {
+        stopAllGrids();
+    }
+
+    /** {@inheritDoc} */
+    @Override protected void beforeTest() {
+        /* Use the node with the highest index (the one started last) for the test. */
+        ignite = grid(NODE_COUNT);
+        ignite.configuration().setPeerClassLoadingEnabled(true);
+        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+    }
+
+    /** Caches 0..999, keeps even values via the filter and expects 500 pairs with truth equal to prediction. */
+    public void testIterate() {
+        IgniteCache<Integer, Integer> data = ignite.createCache(UUID.randomUUID().toString());
+
+        for (int i = 0; i < 1000; i++)
+            data.put(i, i);
+
+        TruthWithPredictionCursor<Integer> cursor = new CacheBasedTruthWithPredictionCursor<>(
+            data,
+            (k, v) -> v % 2 == 0,
+            (k, v) -> new double[]{v},
+            (k, v) -> v,
+            arr -> (int)arr[0]
+        );
+
+        int cnt = 0;
+        for (TruthWithPrediction<Integer> e : cursor) {
+            assertEquals(e.getPrediction(), e.getTruth());
+            cnt++;
+        }
+        assertEquals(500, cnt);
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursorTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursorTest.java
new file mode 100644
index 0000000..3fc3c83
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursorTest.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.selection.score.util;
+
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.ml.selection.score.TruthWithPrediction;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Tests for {@link LocalTruthWithPredictionCursor}.
+ */
+public class LocalTruthWithPredictionCursorTest {
+    /** Maps 0..999 to themselves, keeps even values via the filter and expects 500 pairs with truth == prediction. */
+    @Test
+    public void testIterate() {
+        Map<Integer, Integer> data = new HashMap<>();
+
+        for (int i = 0; i < 1000; i++)
+            data.put(i, i);
+
+        TruthWithPredictionCursor<Integer> cursor = new LocalTruthWithPredictionCursor<>(
+            data,
+            (k, v) -> v % 2 == 0,
+            (k, v) -> new double[]{v},
+            (k, v) -> v,
+            arr -> (int)arr[0]
+        );
+
+        int cnt = 0;
+        for (TruthWithPrediction<Integer> e : cursor) {
+            assertEquals(e.getPrediction(), e.getTruth());
+            cnt++;
+        }
+        assertEquals(500, cnt);
+    }
+}


Mime
View raw message