From commits-return-118574-archive-asf-public=cust-asf.ponee.io@ignite.apache.org Thu Jun 7 15:46:30 2018 Return-Path: X-Original-To: archive-asf-public@cust-asf.ponee.io Delivered-To: archive-asf-public@cust-asf.ponee.io Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by mx-eu-01.ponee.io (Postfix) with SMTP id 7FE8A180674 for ; Thu, 7 Jun 2018 15:46:28 +0200 (CEST) Received: (qmail 60860 invoked by uid 500); 7 Jun 2018 13:46:27 -0000 Mailing-List: contact commits-help@ignite.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Reply-To: dev@ignite.apache.org Delivered-To: mailing list commits@ignite.apache.org Received: (qmail 60642 invoked by uid 99); 7 Jun 2018 13:46:27 -0000 Received: from git1-us-west.apache.org (HELO git1-us-west.apache.org) (140.211.11.23) by apache.org (qpsmtpd/0.29) with ESMTP; Thu, 07 Jun 2018 13:46:27 +0000 Received: by git1-us-west.apache.org (ASF Mail Server at git1-us-west.apache.org, from userid 33) id 4F7BBE04A3; Thu, 7 Jun 2018 13:46:27 +0000 (UTC) Content-Type: text/plain; charset="us-ascii" MIME-Version: 1.0 Content-Transfer-Encoding: 8bit From: chief@apache.org To: commits@ignite.apache.org Message-Id: X-Mailer: ASF-Git Admin Mailer Subject: ignite git commit: IGNITE-8668: K-fold cross validation of models Date: Thu, 7 Jun 2018 13:46:27 +0000 (UTC) Repository: ignite Updated Branches: refs/heads/master 0dc8b4baf -> 17351b434 IGNITE-8668: K-fold cross validation of models this closes #4143 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/17351b43 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/17351b43 Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/17351b43 Branch: refs/heads/master Commit: 17351b4342bb5f75df1dae9b535ff1c9626b6bd9 Parents: 0dc8b4b Author: Anton Dmitriev Authored: Thu Jun 7 16:46:16 2018 +0300 Committer: Yury Babak Committed: Thu Jun 7 16:46:16 2018 +0300 
---------------------------------------------------------------------- .../CrossValidationScoreCalculatorExample.java | 147 +++++++++++ .../examples/ml/selection/cv/package-info.java | 22 ++ .../cv/CrossValidationScoreCalculator.java | 255 +++++++++++++++++++ .../ignite/ml/selection/cv/package-info.java | 22 ++ .../score/AccuracyScoreCalculator.java | 47 ++++ .../ml/selection/score/ScoreCalculator.java | 35 +++ .../ml/selection/score/TruthWithPrediction.java | 52 ++++ .../ignite/ml/selection/score/package-info.java | 22 ++ .../CacheBasedTruthWithPredictionCursor.java | 124 +++++++++ .../util/LocalTruthWithPredictionCursor.java | 137 ++++++++++ .../score/util/TruthWithPredictionCursor.java | 29 +++ .../ml/selection/score/util/package-info.java | 22 ++ .../org/apache/ignite/ml/IgniteMLTestSuite.java | 4 +- .../ignite/ml/selection/SelectionTestSuite.java | 43 ++++ .../cv/CrossValidationScoreCalculatorTest.java | 95 +++++++ .../score/AccuracyScoreCalculatorTest.java | 44 ++++ .../score/TestTruthWithPredictionCursor.java | 91 +++++++ ...CacheBasedTruthWithPredictionCursorTest.java | 78 ++++++ .../LocalTruthWithPredictionCursorTest.java | 54 ++++ 19 files changed, 1322 insertions(+), 1 deletion(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationScoreCalculatorExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationScoreCalculatorExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationScoreCalculatorExample.java new file mode 100644 index 0000000..3cec830 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationScoreCalculatorExample.java @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation 
(ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.selection.cv; + +import java.util.Arrays; +import java.util.Random; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.Ignition; +import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction; +import org.apache.ignite.configuration.CacheConfiguration; +import org.apache.ignite.examples.ml.tree.DecisionTreeClassificationTrainerExample; +import org.apache.ignite.ml.selection.cv.CrossValidationScoreCalculator; +import org.apache.ignite.ml.selection.score.AccuracyScoreCalculator; +import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; +import org.apache.ignite.ml.tree.DecisionTreeNode; +import org.apache.ignite.thread.IgniteThread; + +/** + * Run decision tree classification with cross validation. + * + * @see CrossValidationScoreCalculator + */ +public class CrossValidationScoreCalculatorExample { + /** + * Executes example. + * + * @param args Command line arguments, none required. + */ + public static void main(String... args) throws InterruptedException { + System.out.println(">>> Cross validation score calculator example started."); + + // Start ignite grid. 
+ try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { + System.out.println(">>> Ignite grid started."); + + IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), + DecisionTreeClassificationTrainerExample.class.getSimpleName(), () -> { + + // Create cache with training data. + CacheConfiguration trainingSetCfg = new CacheConfiguration<>(); + trainingSetCfg.setName("TRAINING_SET"); + trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10)); + + IgniteCache trainingSet = ignite.createCache(trainingSetCfg); + + Random rnd = new Random(0); + + // Fill training data. + for (int i = 0; i < 1000; i++) + trainingSet.put(i, generatePoint(rnd)); + + // Create classification trainer. + DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0); + + CrossValidationScoreCalculator scoreCalculator + = new CrossValidationScoreCalculator<>(); + + double[] scores = scoreCalculator.score( + trainer, + new AccuracyScoreCalculator<>(), + ignite, + trainingSet, + (k, v) -> new double[]{v.x, v.y}, + (k, v) -> v.lb, + 4 + ); + + System.out.println(">>> Accuracy: " + Arrays.toString(scores)); + + System.out.println(">>> Cross validation score calculator example completed."); + }); + + igniteThread.start(); + + igniteThread.join(); + } + } + + /** + * Generate point with {@code x} in (-0.5, 0.5) and {@code y} in the same interval. If {@code x * y > 0} then label + * is 1, otherwise 0. + * + * @param rnd Random. + * @return Point with label. + */ + private static LabeledPoint generatePoint(Random rnd) { + + double x = rnd.nextDouble() - 0.5; + double y = rnd.nextDouble() - 0.5; + + return new LabeledPoint(x, y, x * y > 0 ? 1 : 0); + } + + /** Point data class. */ + private static class Point { + /** X coordinate. */ + final double x; + + /** Y coordinate. */ + final double y; + + /** + * Constructs a new instance of point. + * + * @param x X coordinate. + * @param y Y coordinate. 
+ */ + Point(double x, double y) { + this.x = x; + this.y = y; + } + } + + /** Labeled point data class. */ + private static class LabeledPoint extends Point { + /** Point label. */ + final double lb; + + /** + * Constructs a new instance of labeled point data. + * + * @param x X coordinate. + * @param y Y coordinate. + * @param lb Point label. + */ + LabeledPoint(double x, double y, double lb) { + super(x, y); + this.lb = lb; + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/package-info.java ---------------------------------------------------------------------- diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/package-info.java new file mode 100644 index 0000000..249d8b8 --- /dev/null +++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * + * ML cross validation examples. 
+ */ +package org.apache.ignite.examples.ml.selection.cv; http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidationScoreCalculator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidationScoreCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidationScoreCalculator.java new file mode 100644 index 0000000..f885c3e --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidationScoreCalculator.java @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.selection.cv; + +import java.util.Map; +import java.util.function.BiFunction; +import java.util.function.Function; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.lang.IgniteBiPredicate; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.dataset.DatasetBuilder; +import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder; +import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.selection.score.ScoreCalculator; +import org.apache.ignite.ml.selection.score.util.CacheBasedTruthWithPredictionCursor; +import org.apache.ignite.ml.selection.score.util.LocalTruthWithPredictionCursor; +import org.apache.ignite.ml.selection.score.util.TruthWithPredictionCursor; +import org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper; +import org.apache.ignite.ml.selection.split.mapper.UniformMapper; +import org.apache.ignite.ml.trainers.DatasetTrainer; + +/** + * Cross validation score calculator. Cross validation is an approach that allows to avoid overfitting that is made the + * following way: the training set is split into k smaller sets. The following procedure is followed for each of the k + * “folds”: + *
<ul>
+ * <li>A model is trained using k-1 of the folds as training data;</li>
+ * <li>the resulting model is validated on the remaining part of the data (i.e., it is used as a test set to compute + * a performance measure such as accuracy).</li>
+ * </ul>
+ * + * @param Type of model. + * @param Type of a label (truth or prediction). + * @param Type of a key in {@code upstream} data. + * @param Type of a value in {@code upstream} data. + */ +public class CrossValidationScoreCalculator, L, K, V> { + /** + * Computes cross-validated metrics. + * + * @param trainer Trainer of the model. + * @param scoreCalculator Score calculator. + * @param ignite Ignite instance. + * @param upstreamCache Ignite cache with {@code upstream} data. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @param cv Number of folds. + * @return Array of scores of the estimator for each run of the cross validation. + */ + public double[] score(DatasetTrainer trainer, ScoreCalculator scoreCalculator, Ignite ignite, + IgniteCache upstreamCache, IgniteBiFunction featureExtractor, + IgniteBiFunction lbExtractor, int cv) { + return score(trainer, scoreCalculator, ignite, upstreamCache, (k, v) -> true, featureExtractor, lbExtractor, + new SHA256UniformMapper<>(), cv); + } + + /** + * Computes cross-validated metrics. + * + * @param trainer Trainer of the model. + * @param scoreCalculator Base score calculator. + * @param ignite Ignite instance. + * @param upstreamCache Ignite cache with {@code upstream} data. + * @param filter Base {@code upstream} data filter. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @param cv Number of folds. + * @return Array of scores of the estimator for each run of the cross validation. + */ + public double[] score(DatasetTrainer trainer, ScoreCalculator scoreCalculator, Ignite ignite, + IgniteCache upstreamCache, IgniteBiPredicate filter, + IgniteBiFunction featureExtractor, IgniteBiFunction lbExtractor, int cv) { + return score(trainer, scoreCalculator, ignite, upstreamCache, filter, featureExtractor, lbExtractor, + new SHA256UniformMapper<>(), cv); + } + + /** + * Computes cross-validated metrics. 
+ * + * @param trainer Trainer of the model. + * @param scoreCalculator Base score calculator. + * @param ignite Ignite instance. + * @param upstreamCache Ignite cache with {@code upstream} data. + * @param filter Base {@code upstream} data filter. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @param mapper Mapper used to map a key-value pair to a point on the segment (0, 1). + * @param cv Number of folds. + * @return Array of scores of the estimator for each run of the cross validation. + */ + public double[] score(DatasetTrainer trainer, ScoreCalculator scoreCalculator, + Ignite ignite, IgniteCache upstreamCache, IgniteBiPredicate filter, + IgniteBiFunction featureExtractor, IgniteBiFunction lbExtractor, + UniformMapper mapper, int cv) { + return score( + trainer, + predicate -> new CacheBasedDatasetBuilder<>( + ignite, + upstreamCache, + (k, v) -> filter.apply(k, v) && predicate.apply(k, v) + ), + (predicate, mdl) -> new CacheBasedTruthWithPredictionCursor<>( + upstreamCache, + (k, v) -> filter.apply(k, v) && !predicate.apply(k, v), + featureExtractor, + lbExtractor, + mdl + ), + featureExtractor, + lbExtractor, + scoreCalculator, + mapper, + cv + ); + } + + /** + * Computes cross-validated metrics. + * + * @param trainer Trainer of the model. + * @param scoreCalculator Base score calculator. + * @param upstreamMap Map with {@code upstream} data. + * @param parts Number of partitions. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @param cv Number of folds. + * @return Array of scores of the estimator for each run of the cross validation. 
+ */ + public double[] score(DatasetTrainer trainer, ScoreCalculator scoreCalculator, Map upstreamMap, + int parts, IgniteBiFunction featureExtractor, IgniteBiFunction lbExtractor, int cv) { + return score(trainer, scoreCalculator, upstreamMap, (k, v) -> true, parts, featureExtractor, lbExtractor, + new SHA256UniformMapper<>(), cv); + } + + /** + * Computes cross-validated metrics. + * + * @param trainer Trainer of the model. + * @param scoreCalculator Base score calculator. + * @param upstreamMap Map with {@code upstream} data. + * @param filter Base {@code upstream} data filter. + * @param parts Number of partitions. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @param cv Number of folds. + * @return Array of scores of the estimator for each run of the cross validation. + */ + public double[] score(DatasetTrainer trainer, ScoreCalculator scoreCalculator, Map upstreamMap, + IgniteBiPredicate filter, int parts, IgniteBiFunction featureExtractor, + IgniteBiFunction lbExtractor, int cv) { + return score(trainer, scoreCalculator, upstreamMap, filter, parts, featureExtractor, lbExtractor, + new SHA256UniformMapper<>(), cv); + } + + /** + * Computes cross-validated metrics. + * + * @param trainer Trainer of the model. + * @param scoreCalculator Base score calculator. + * @param upstreamMap Map with {@code upstream} data. + * @param filter Base {@code upstream} data filter. + * @param parts Number of partitions. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @param mapper Mapper used to map a key-value pair to a point on the segment (0, 1). + * @param cv Number of folds. + * @return Array of scores of the estimator for each run of the cross validation. 
+ */ + public double[] score(DatasetTrainer trainer, ScoreCalculator scoreCalculator, Map upstreamMap, + IgniteBiPredicate filter, int parts, IgniteBiFunction featureExtractor, + IgniteBiFunction lbExtractor, UniformMapper mapper, int cv) { + return score( + trainer, + predicate -> new LocalDatasetBuilder<>( + upstreamMap, + (k, v) -> filter.apply(k, v) && predicate.apply(k, v), + parts + ), + (predicate, mdl) -> new LocalTruthWithPredictionCursor<>( + upstreamMap, + (k, v) -> filter.apply(k, v) && !predicate.apply(k, v), + featureExtractor, + lbExtractor, + mdl + ), + featureExtractor, + lbExtractor, + scoreCalculator, + mapper, + cv + ); + } + + /** + * Computes cross-validated metrics. + * + * @param trainer Trainer of the model. + * @param datasetBuilderSupplier Dataset builder supplier. + * @param testDataIterSupplier Test data iterator supplier. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @param scoreCalculator Base score calculator. + * @param mapper Mapper used to map a key-value pair to a point on the segment (0, 1). + * @param cv Number of folds. + * @return Array of scores of the estimator for each run of the cross validation. 
+ */ + private double[] score(DatasetTrainer trainer, Function, + DatasetBuilder> datasetBuilderSupplier, + BiFunction, M, TruthWithPredictionCursor> testDataIterSupplier, + IgniteBiFunction featureExtractor, IgniteBiFunction lbExtractor, + ScoreCalculator scoreCalculator, UniformMapper mapper, int cv) { + + double[] scores = new double[cv]; + + double foldSize = 1.0 / cv; + for (int i = 0; i < cv; i++) { + double from = foldSize * i; + double to = foldSize * (i + 1); + + IgniteBiPredicate trainSetFilter = (k, v) -> { + double pnt = mapper.map(k, v); + return pnt < from || pnt > to; + }; + + DatasetBuilder datasetBuilder = datasetBuilderSupplier.apply(trainSetFilter); + M mdl = trainer.fit(datasetBuilder, featureExtractor, lbExtractor); + + try (TruthWithPredictionCursor cursor = testDataIterSupplier.apply(trainSetFilter, mdl)) { + scores[i] = scoreCalculator.score(cursor.iterator()); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + + return scores; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/package-info.java new file mode 100644 index 0000000..66f129f --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * + * Root package for cross-validation algorithms. + */ +package org.apache.ignite.ml.selection.cv; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/AccuracyScoreCalculator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/AccuracyScoreCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/AccuracyScoreCalculator.java new file mode 100644 index 0000000..c9e61f9 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/AccuracyScoreCalculator.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.selection.score; + +import java.util.Iterator; + +/** + * Accuracy score calculator. + * + * @param Type of a label (truth or prediction). + */ +public class AccuracyScoreCalculator implements ScoreCalculator { + /** {@inheritDoc} */ + @Override public double score(Iterator> iter) { + long totalCnt = 0; + long correctCnt = 0; + + while (iter.hasNext()) { + TruthWithPrediction e = iter.next(); + + L prediction = e.getPrediction(); + L truth = e.getTruth(); + + if (prediction.equals(truth)) + correctCnt++; + + totalCnt++; + } + + return 1.0 * correctCnt / totalCnt; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/ScoreCalculator.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/ScoreCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/ScoreCalculator.java new file mode 100644 index 0000000..e792532 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/ScoreCalculator.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.selection.score; + +import java.util.Iterator; + +/** + * Base interface for score calculators. + * + * @param Type of a label (truth or prediction). + */ +public interface ScoreCalculator { + /** + * Calculates score. + * + * @param iter Iterator that supplies pairs of truth values and predicted values. + * @return Score. + */ + public double score(Iterator> iter); +} http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/TruthWithPrediction.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/TruthWithPrediction.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/TruthWithPrediction.java new file mode 100644 index 0000000..1b96b61 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/TruthWithPrediction.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.selection.score; + +/** + * Pair of truth value and value predicted by model. + * + * @param Type of a label (truth or prediction). 
+ */ +public class TruthWithPrediction { + /** Truth value. */ + private final L truth; + + /** Predicted value. */ + private final L prediction; + + /** + * Constructs a new instance of truth with prediction. + * + * @param truth Truth value. + * @param prediction Predicted value. + */ + public TruthWithPrediction(L truth, L prediction) { + this.truth = truth; + this.prediction = prediction; + } + + /** */ + public L getTruth() { + return truth; + } + + /** */ + public L getPrediction() { + return prediction; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/package-info.java new file mode 100644 index 0000000..9656a59 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * + * Root package for score calculators. 
+ */ +package org.apache.ignite.ml.selection.score; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursor.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursor.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursor.java new file mode 100644 index 0000000..862c7ab --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursor.java @@ -0,0 +1,124 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.ignite.ml.selection.score.util; + +import java.util.Iterator; +import javax.cache.Cache; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.cache.query.QueryCursor; +import org.apache.ignite.cache.query.ScanQuery; +import org.apache.ignite.lang.IgniteBiPredicate; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.selection.score.TruthWithPrediction; +import org.jetbrains.annotations.NotNull; + +/** + * Truth with prediction cursor based on a data stored in Ignite cache. + * + * @param Type of a label (truth or prediction). + * @param Type of a key in {@code upstream} data. + * @param Type of a value in {@code upstream} data. + */ +public class CacheBasedTruthWithPredictionCursor implements TruthWithPredictionCursor { + /** Query cursor. */ + private final QueryCursor> cursor; + + /** Feature extractor. */ + private final IgniteBiFunction featureExtractor; + + /** Label extractor. */ + private final IgniteBiFunction lbExtractor; + + /** Model for inference. */ + private final Model mdl; + + /** + * Constructs a new instance of cache based truth with prediction cursor. + * + * @param upstreamCache Ignite cache with {@code upstream} data. + * @param filter Filter for {@code upstream} data. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @param mdl Model for inference. 
+ */ + public CacheBasedTruthWithPredictionCursor(IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter, + IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, + Model<double[], L> mdl) { + this.cursor = query(upstreamCache, filter); + this.featureExtractor = featureExtractor; + this.lbExtractor = lbExtractor; + this.mdl = mdl; + } + + /** {@inheritDoc} */ + @Override public void close() { + cursor.close(); + } + + /** {@inheritDoc} */ + @NotNull @Override public Iterator<TruthWithPrediction<L>> iterator() { + return new TruthWithPredictionIterator(cursor.iterator()); + } + + /** + * Queries the specified cache using the specified filter. + * + * @param upstreamCache Ignite cache with {@code upstream} data. + * @param filter Filter for {@code upstream} data. + * @return Query cursor. + */ + private QueryCursor<Cache.Entry<K, V>> query(IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter) { + ScanQuery<K, V> qry = new ScanQuery<>(); + qry.setFilter(filter); + + return upstreamCache.query(qry); + } + + /** + * Util iterator that makes predictions using the model. + */ + private class TruthWithPredictionIterator implements Iterator<TruthWithPrediction<L>> { + /** Base iterator. */ + private final Iterator<Cache.Entry<K, V>> iter; + + /** + * Constructs a new instance of truth with prediction iterator. + * + * @param iter Base iterator.
+ */ + public TruthWithPredictionIterator(Iterator<Cache.Entry<K, V>> iter) { + this.iter = iter; + } + + /** {@inheritDoc} */ + @Override public boolean hasNext() { + return iter.hasNext(); + } + + /** {@inheritDoc} */ + @Override public TruthWithPrediction<L> next() { + Cache.Entry<K, V> entry = iter.next(); + + double[] features = featureExtractor.apply(entry.getKey(), entry.getValue()); + L lb = lbExtractor.apply(entry.getKey(), entry.getValue()); + + return new TruthWithPrediction<>(lb, mdl.apply(features)); + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursor.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursor.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursor.java new file mode 100644 index 0000000..093c6ed --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursor.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.ignite.ml.selection.score.util; + +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; +import org.apache.ignite.lang.IgniteBiPredicate; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.math.functions.IgniteBiFunction; +import org.apache.ignite.ml.selection.score.TruthWithPrediction; +import org.jetbrains.annotations.NotNull; + +/** + * Truth with prediction cursor based on a locally stored data. + * + * @param <L> Type of a label (truth or prediction). + * @param <K> Type of a key in {@code upstream} data. + * @param <V> Type of a value in {@code upstream} data. + */ +public class LocalTruthWithPredictionCursor<L, K, V> implements TruthWithPredictionCursor<L> { + /** Map with {@code upstream} data. */ + private final Map<K, V> upstreamMap; + + /** Filter for {@code upstream} data. */ + private final IgniteBiPredicate<K, V> filter; + + /** Feature extractor. */ + private final IgniteBiFunction<K, V, double[]> featureExtractor; + + /** Label extractor. */ + private final IgniteBiFunction<K, V, L> lbExtractor; + + /** Model for inference. */ + private final Model<double[], L> mdl; + + /** + * Constructs a new instance of local truth with prediction cursor. + * + * @param upstreamMap Map with {@code upstream} data. + * @param filter Filter for {@code upstream} data. + * @param featureExtractor Feature extractor. + * @param lbExtractor Label extractor. + * @param mdl Model for inference. + */ + public LocalTruthWithPredictionCursor(Map<K, V> upstreamMap, IgniteBiPredicate<K, V> filter, + IgniteBiFunction<K, V, double[]> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, + Model<double[], L> mdl) { + this.upstreamMap = upstreamMap; + this.filter = filter; + this.featureExtractor = featureExtractor; + this.lbExtractor = lbExtractor; + this.mdl = mdl; + } + + /** {@inheritDoc} */ + @Override public void close() { + /* Do nothing.
*/ + } + + /** {@inheritDoc} */ + @NotNull @Override public Iterator<TruthWithPrediction<L>> iterator() { + return new TruthWithPredictionIterator(upstreamMap.entrySet().iterator()); + } + + /** + * Util iterator that filters map entries and makes predictions using the model. + */ + private class TruthWithPredictionIterator implements Iterator<TruthWithPrediction<L>> { + /** Base iterator. */ + private final Iterator<Map.Entry<K, V>> iter; + + /** Next found entry. */ + private Map.Entry<K, V> nextEntry; + + /** + * Constructs a new instance of truth with prediction iterator. + * + * @param iter Base iterator. + */ + public TruthWithPredictionIterator(Iterator<Map.Entry<K, V>> iter) { + this.iter = iter; + } + + /** {@inheritDoc} */ + @Override public boolean hasNext() { + findNext(); + + return nextEntry != null; + } + + /** {@inheritDoc} */ + @Override public TruthWithPrediction<L> next() { + if (!hasNext()) + throw new NoSuchElementException(); + + K key = nextEntry.getKey(); + V val = nextEntry.getValue(); + + double[] features = featureExtractor.apply(key, val); + L lb = lbExtractor.apply(key, val); + + nextEntry = null; + + return new TruthWithPrediction<>(lb, mdl.apply(features)); + } + + /** + * Finds next entry using the specified filter.
+ */ + private void findNext() { + while (nextEntry == null && iter.hasNext()) { + Map.Entry<K, V> entry = iter.next(); + + if (filter.apply(entry.getKey(), entry.getValue())) { + this.nextEntry = entry; + break; + } + } + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/TruthWithPredictionCursor.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/TruthWithPredictionCursor.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/TruthWithPredictionCursor.java new file mode 100644 index 0000000..525d743 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/TruthWithPredictionCursor.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.selection.score.util; + +import org.apache.ignite.ml.selection.score.TruthWithPrediction; + +/** + * Closeable iterable that supplies pairs of truth and predictions (abstraction that hides a difference between querying + * data from Ignite cache and from local Map).
+ * + * @param <L> Type of a label (truth or prediction). + */ +public interface TruthWithPredictionCursor<L> extends Iterable<TruthWithPrediction<L>>, AutoCloseable { +} http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/package-info.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/package-info.java new file mode 100644 index 0000000..9c86317 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/score/util/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * + * Util classes used for score calculation.
+ */ +package org.apache.ignite.ml.selection.score.util; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java index 0c3408e..9f60c48 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java @@ -25,6 +25,7 @@ import org.apache.ignite.ml.math.MathImplMainTestSuite; import org.apache.ignite.ml.nn.MLPTestSuite; import org.apache.ignite.ml.preprocessing.PreprocessingTestSuite; import org.apache.ignite.ml.regressions.RegressionsTestSuite; +import org.apache.ignite.ml.selection.SelectionTestSuite; import org.apache.ignite.ml.svm.SVMTestSuite; import org.apache.ignite.ml.tree.DecisionTreeTestSuite; import org.junit.runner.RunWith; @@ -45,7 +46,8 @@ import org.junit.runners.Suite; MLPTestSuite.class, DatasetTestSuite.class, PreprocessingTestSuite.class, - GAGridTestSuite.class + GAGridTestSuite.class, + SelectionTestSuite.class }) public class IgniteMLTestSuite { // No-op. http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java new file mode 100644 index 0000000..e8db932 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/SelectionTestSuite.java @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.selection; + +import org.apache.ignite.ml.selection.cv.CrossValidationScoreCalculatorTest; +import org.apache.ignite.ml.selection.score.AccuracyScoreCalculatorTest; +import org.apache.ignite.ml.selection.score.util.CacheBasedTruthWithPredictionCursorTest; +import org.apache.ignite.ml.selection.score.util.LocalTruthWithPredictionCursorTest; +import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitterTest; +import org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapperTest; +import org.junit.runner.RunWith; +import org.junit.runners.Suite; + +/** + * Test suite for all tests located in org.apache.ignite.ml.selection.* package. + */ +@RunWith(Suite.class) +@Suite.SuiteClasses({ + CrossValidationScoreCalculatorTest.class, + CacheBasedTruthWithPredictionCursorTest.class, + LocalTruthWithPredictionCursorTest.class, + AccuracyScoreCalculatorTest.class, + SHA256UniformMapperTest.class, + TrainTestDatasetSplitterTest.class +}) +public class SelectionTestSuite { + // No-op. 
+} http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationScoreCalculatorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationScoreCalculatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationScoreCalculatorTest.java new file mode 100644 index 0000000..3679f1b --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationScoreCalculatorTest.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.selection.cv; + +import java.util.HashMap; +import java.util.Map; +import org.apache.ignite.ml.selection.score.AccuracyScoreCalculator; +import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer; +import org.apache.ignite.ml.tree.DecisionTreeNode; +import org.junit.Test; + +import static junit.framework.TestCase.assertTrue; +import static org.junit.Assert.assertEquals; + +/** + * Tests for {@link CrossValidationScoreCalculator}. 
+ */ +public class CrossValidationScoreCalculatorTest { + /** */ + @Test + public void testScoreWithGoodDataset() { + Map<Integer, Double> data = new HashMap<>(); + + for (int i = 0; i < 1000; i++) + data.put(i, i > 500 ? 1.0 : 0.0); + + DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0); + + CrossValidationScoreCalculator<DecisionTreeNode, Double, Integer, Double> scoreCalculator = + new CrossValidationScoreCalculator<>(); + + int folds = 4; + + double[] scores = scoreCalculator.score( + trainer, + new AccuracyScoreCalculator<>(), + data, + 1, + (k, v) -> new double[]{k}, + (k, v) -> v, + folds + ); + + assertEquals(folds, scores.length); + + for (int i = 0; i < folds; i++) + assertEquals(1, scores[i], 1e-1); + } + + /** */ + @Test + public void testScoreWithBadDataset() { + Map<Integer, Double> data = new HashMap<>(); + + for (int i = 0; i < 1000; i++) + data.put(i, i % 2 == 0 ? 1.0 : 0.0); + + DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0); + + CrossValidationScoreCalculator<DecisionTreeNode, Double, Integer, Double> scoreCalculator = + new CrossValidationScoreCalculator<>(); + + int folds = 4; + + double[] scores = scoreCalculator.score( + trainer, + new AccuracyScoreCalculator<>(), + data, + 1, + (k, v) -> new double[]{k}, + (k, v) -> v, + folds + ); + + assertEquals(folds, scores.length); + + for (int i = 0; i < folds; i++) + assertTrue(scores[i] < 0.6); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/AccuracyScoreCalculatorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/AccuracyScoreCalculatorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/AccuracyScoreCalculatorTest.java new file mode 100644 index 0000000..b79bd9a --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/AccuracyScoreCalculatorTest.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software
Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.selection.score; + +import java.util.Arrays; +import org.apache.ignite.ml.selection.score.util.TruthWithPredictionCursor; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * Tests for {@link AccuracyScoreCalculator}. 
+ */ +public class AccuracyScoreCalculatorTest { + /** */ + @Test + public void testScore() { + ScoreCalculator<Integer> scoreCalculator = new AccuracyScoreCalculator<>(); + + TruthWithPredictionCursor<Integer> cursor = new TestTruthWithPredictionCursor<>( + Arrays.asList(1, 1, 1, 1), + Arrays.asList(1, 1, 0, 1) + ); + + double score = scoreCalculator.score(cursor.iterator()); + + assertEquals(0.75, score, 1e-12); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/TestTruthWithPredictionCursor.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/TestTruthWithPredictionCursor.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/TestTruthWithPredictionCursor.java new file mode 100644 index 0000000..db4f936 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/TestTruthWithPredictionCursor.java @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.ignite.ml.selection.score; + +import java.util.Iterator; +import java.util.List; +import org.apache.ignite.ml.selection.score.util.TruthWithPredictionCursor; +import org.jetbrains.annotations.NotNull; + +/** + * Util truth with prediction cursor to be used in tests. + * + * @param <L> Type of a label (truth or prediction). + */ +public class TestTruthWithPredictionCursor<L> implements TruthWithPredictionCursor<L> { + /** List of truth values. */ + private final List<L> truth; + + /** List of predicted values. */ + private final List<L> predicted; + + /** + * Constructs a new instance of test truth with prediction cursor. + * + * @param truth List of truth values. + * @param predicted List of predicted values. + */ + public TestTruthWithPredictionCursor(List<L> truth, List<L> predicted) { + this.truth = truth; + this.predicted = predicted; + } + + /** {@inheritDoc} */ + @Override public void close() throws Exception { + /* Do nothing. */ + } + + /** {@inheritDoc} */ + @NotNull @Override public Iterator<TruthWithPrediction<L>> iterator() { + return new TestTruthWithPredictionIterator<>(truth.iterator(), predicted.iterator()); + } + + /** + * Util truth with prediction iterator to be used in tests. + * + * @param <L> Type of a label (truth or prediction). + */ + private static final class TestTruthWithPredictionIterator<L> implements Iterator<TruthWithPrediction<L>> { + /** Iterator of truth values. */ + private final Iterator<L> truthIter; + + /** Iterator of predicted values. */ + private final Iterator<L> predictedIter; + + /** + * Constructs a new instance of test truth with prediction iterator. + * + * @param truthIter Iterator of truth values. + * @param predictedIter Iterator of predicted values.
+ */ + public TestTruthWithPredictionIterator(Iterator<L> truthIter, Iterator<L> predictedIter) { + this.truthIter = truthIter; + this.predictedIter = predictedIter; + } + + /** {@inheritDoc} */ + @Override public boolean hasNext() { + return truthIter.hasNext() && predictedIter.hasNext(); + } + + /** {@inheritDoc} */ + @Override public TruthWithPrediction<L> next() { + return new TruthWithPrediction<>(truthIter.next(), predictedIter.next()); + } + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursorTest.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursorTest.java new file mode 100644 index 0000000..7eba10f --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/CacheBasedTruthWithPredictionCursorTest.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.ignite.ml.selection.score.util; + +import java.util.UUID; +import org.apache.ignite.Ignite; +import org.apache.ignite.IgniteCache; +import org.apache.ignite.internal.util.IgniteUtils; +import org.apache.ignite.ml.selection.score.TruthWithPrediction; +import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest; + +/** + * Tests for {@link CacheBasedTruthWithPredictionCursor}. + */ +public class CacheBasedTruthWithPredictionCursorTest extends GridCommonAbstractTest { + /** Number of nodes in grid. */ + private static final int NODE_COUNT = 4; + + /** Ignite instance. */ + private Ignite ignite; + + /** {@inheritDoc} */ + @Override protected void beforeTestsStarted() throws Exception { + for (int i = 1; i <= NODE_COUNT; i++) + startGrid(i); + } + + /** {@inheritDoc} */ + @Override protected void afterTestsStopped() { + stopAllGrids(); + } + + /** {@inheritDoc} */ + @Override protected void beforeTest() { + /* Grid instance. */ + ignite = grid(NODE_COUNT); + ignite.configuration().setPeerClassLoadingEnabled(true); + IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); + } + + /** */ + public void testIterate() { + IgniteCache<Integer, Integer> data = ignite.createCache(UUID.randomUUID().toString()); + + for (int i = 0; i < 1000; i++) + data.put(i, i); + + TruthWithPredictionCursor<Integer> cursor = new CacheBasedTruthWithPredictionCursor<>( + data, + (k, v) -> v % 2 == 0, + (k, v) -> new double[]{v}, + (k, v) -> v, + arr -> (int)arr[0] + ); + + int cnt = 0; + for (TruthWithPrediction<Integer> e : cursor) { + assertEquals(e.getPrediction(), e.getTruth()); + cnt++; + } + assertEquals(500, cnt); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/17351b43/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursorTest.java ---------------------------------------------------------------------- diff --git
a/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursorTest.java new file mode 100644 index 0000000..3fc3c83 --- /dev/null +++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/score/util/LocalTruthWithPredictionCursorTest.java @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.selection.score.util; + +import java.util.HashMap; +import java.util.Map; +import org.apache.ignite.ml.selection.score.TruthWithPrediction; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * Tests for {@link LocalTruthWithPredictionCursor}. 
+ */ +public class LocalTruthWithPredictionCursorTest { + /** */ + @Test + public void testIterate() { + Map<Integer, Integer> data = new HashMap<>(); + + for (int i = 0; i < 1000; i++) + data.put(i, i); + + TruthWithPredictionCursor<Integer> cursor = new LocalTruthWithPredictionCursor<>( + data, + (k, v) -> v % 2 == 0, + (k, v) -> new double[]{v}, + (k, v) -> v, + arr -> (int)arr[0] + ); + + int cnt = 0; + for (TruthWithPrediction<Integer> e : cursor) { + assertEquals(e.getPrediction(), e.getTruth()); + cnt++; + } + assertEquals(500, cnt); + } +}