Return-Path: X-Original-To: archive-asf-public-internal@cust-asf2.ponee.io Delivered-To: archive-asf-public-internal@cust-asf2.ponee.io Received: from cust-asf.ponee.io (cust-asf.ponee.io [163.172.22.183]) by cust-asf2.ponee.io (Postfix) with ESMTP id A269C200D4E for ; Thu, 7 Dec 2017 16:15:10 +0100 (CET) Received: by cust-asf.ponee.io (Postfix) id A1431160BFE; Thu, 7 Dec 2017 15:15:10 +0000 (UTC) Delivered-To: archive-asf-public@cust-asf.ponee.io Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by cust-asf.ponee.io (Postfix) with SMTP id 4DB76160C0C for ; Thu, 7 Dec 2017 16:15:08 +0100 (CET) Received: (qmail 72448 invoked by uid 500); 7 Dec 2017 15:15:07 -0000 Mailing-List: contact commits-help@ignite.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Reply-To: dev@ignite.apache.org Delivered-To: mailing list commits@ignite.apache.org Received: (qmail 72439 invoked by uid 99); 7 Dec 2017 15:15:07 -0000 Received: from git1-us-west.apache.org (HELO git1-us-west.apache.org) (140.211.11.23) by apache.org (qpsmtpd/0.29) with ESMTP; Thu, 07 Dec 2017 15:15:07 +0000 Received: by git1-us-west.apache.org (ASF Mail Server at git1-us-west.apache.org, from userid 33) id 0AA54DFA09; Thu, 7 Dec 2017 15:15:04 +0000 (UTC) Content-Type: text/plain; charset="us-ascii" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit From: chief@apache.org To: commits@ignite.apache.org Date: Thu, 07 Dec 2017 15:15:05 -0000 Message-Id: <553f8400e4b54b6285c917b4b5f42d47@git.apache.org> In-Reply-To: References: X-Mailer: ASF-Git Admin Mailer Subject: [2/2] ignite git commit: IGNITE-6872: Linear regression should implement Model API archived-at: Thu, 07 Dec 2017 15:15:10 -0000 IGNITE-6872: Linear regression should implement Model API This closes #3168 Project: http://git-wip-us.apache.org/repos/asf/ignite/repo Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/c5c512e4 Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/c5c512e4 
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/c5c512e4 Branch: refs/heads/master Commit: c5c512e460140c91fb77b527ff909ddbe3d1fd72 Parents: bbeb205 Author: Oleg Ignatenko Authored: Thu Dec 7 18:14:51 2017 +0300 Committer: Yury Babak Committed: Thu Dec 7 18:14:51 2017 +0300 ---------------------------------------------------------------------- .../decompositions/QRDecompositionExample.java | 82 ++++++ .../DistributedRegressionExample.java | 149 ----------- .../examples/ml/math/trees/MNISTExample.java | 261 ------------------- .../examples/ml/math/trees/package-info.java | 22 -- .../apache/ignite/examples/ml/package-info.java | 22 ++ .../DistributedRegressionExample.java | 149 +++++++++++ .../DistributedRegressionModelExample.java | 134 ++++++++++ .../examples/ml/regression/package-info.java | 22 ++ .../ignite/examples/ml/trees/MNISTExample.java | 261 +++++++++++++++++++ .../ignite/examples/ml/trees/package-info.java | 22 ++ .../ml/math/decompositions/QRDSolver.java | 197 ++++++++++++++ .../ml/math/decompositions/QRDecomposition.java | 54 +--- .../AbstractMultipleLinearRegression.java | 20 ++ .../OLSMultipleLinearRegression.java | 41 +-- .../OLSMultipleLinearRegressionModel.java | 77 ++++++ .../OLSMultipleLinearRegressionModelFormat.java | 46 ++++ .../OLSMultipleLinearRegressionTrainer.java | 62 +++++ .../org/apache/ignite/ml/IgniteMLTestSuite.java | 3 +- .../org/apache/ignite/ml/LocalModelsTest.java | 99 +++++-- .../ignite/ml/math/MathImplLocalTestSuite.java | 2 + .../ml/math/decompositions/QRDSolverTest.java | 87 +++++++ ...tedBlockOLSMultipleLinearRegressionTest.java | 38 ++- ...tributedOLSMultipleLinearRegressionTest.java | 38 ++- .../OLSMultipleLinearRegressionModelTest.java | 53 ++++ .../ml/regressions/RegressionsTestSuite.java | 5 +- 25 files changed, 1371 insertions(+), 575 deletions(-) ---------------------------------------------------------------------- 
http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/examples/src/main/ml/org/apache/ignite/examples/ml/math/decompositions/QRDecompositionExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/math/decompositions/QRDecompositionExample.java b/examples/src/main/ml/org/apache/ignite/examples/ml/math/decompositions/QRDecompositionExample.java new file mode 100644 index 0000000..bed99d1 --- /dev/null +++ b/examples/src/main/ml/org/apache/ignite/examples/ml/math/decompositions/QRDecompositionExample.java @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.examples.ml.math.decompositions; + +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.Tracer; +import org.apache.ignite.ml.math.decompositions.QRDecomposition; +import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix; + +/** + * Example of using {@link QRDecomposition}. + */ +public class QRDecompositionExample { + /** + * Executes example. + * + * @param args Command line arguments, none required. 
+ */ + public static void main(String[] args) { + System.out.println(">>> QR decomposition example started."); + Matrix m = new DenseLocalOnHeapMatrix(new double[][] { + {2.0d, -1.0d, 0.0d}, + {-1.0d, 2.0d, -1.0d}, + {0.0d, -1.0d, 2.0d} + }); + + System.out.println("\n>>> Input matrix:"); + Tracer.showAscii(m); + + QRDecomposition dec = new QRDecomposition(m); + System.out.println("\n>>> Value for full rank in decomposition: [" + dec.hasFullRank() + "]."); + + Matrix q = dec.getQ(); + Matrix r = dec.getR(); + + System.out.println("\n>>> Orthogonal matrix Q:"); + Tracer.showAscii(q); + System.out.println("\n>>> Upper triangular matrix R:"); + Tracer.showAscii(r); + + Matrix qSafeCp = safeCopy(q); + + Matrix identity = qSafeCp.times(qSafeCp.transpose()); + + System.out.println("\n>>> Identity matrix obtained from Q:"); + Tracer.showAscii(identity); + + Matrix recomposed = qSafeCp.times(r); + + System.out.println("\n>>> Recomposed input matrix:"); + Tracer.showAscii(recomposed); + + Matrix sol = dec.solve(new DenseLocalOnHeapMatrix(3, 10)); + + System.out.println("\n>>> Solved matrix:"); + Tracer.showAscii(sol); + + dec.destroy(); + + System.out.println("\n>>> QR decomposition example completed."); + } + + /** */ + private static Matrix safeCopy(Matrix orig) { + return new DenseLocalOnHeapMatrix(orig.rowSize(), orig.columnSize()).assign(orig); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/examples/src/main/ml/org/apache/ignite/examples/ml/math/regression/DistributedRegressionExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/math/regression/DistributedRegressionExample.java b/examples/src/main/ml/org/apache/ignite/examples/ml/math/regression/DistributedRegressionExample.java deleted file mode 100644 index de2c541..0000000 --- a/examples/src/main/ml/org/apache/ignite/examples/ml/math/regression/DistributedRegressionExample.java +++ /dev/null @@ -1,149 
+0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.ignite.examples.ml.math.regression; - -import java.util.Arrays; -import org.apache.ignite.Ignite; -import org.apache.ignite.Ignition; -import org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample; -import org.apache.ignite.ml.math.StorageConstants; -import org.apache.ignite.ml.math.Tracer; -import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix; -import org.apache.ignite.ml.regressions.OLSMultipleLinearRegression; -import org.apache.ignite.thread.IgniteThread; - -/** - * Run linear regression over distributed matrix. - * - * TODO: IGNITE-6222, Currently works only in local mode. - * - * @see OLSMultipleLinearRegression - */ -public class DistributedRegressionExample { - /** Run example. */ - public static void main(String[] args) throws InterruptedException { - System.out.println(); - System.out.println(">>> Linear regression over sparse distributed matrix API usage example started."); - // Start ignite grid. 
- try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) { - System.out.println(">>> Ignite grid started."); - // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread - // because we create ignite cache internally. - IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(), SparseDistributedMatrixExample.class.getSimpleName(), () -> { - - double[] data = { - 8, 78, 284, 9.100000381, 109, - 9.300000191, 68, 433, 8.699999809, 144, - 7.5, 70, 739, 7.199999809, 113, - 8.899999619, 96, 1792, 8.899999619, 97, - 10.19999981, 74, 477, 8.300000191, 206, - 8.300000191, 111, 362, 10.89999962, 124, - 8.800000191, 77, 671, 10, 152, - 8.800000191, 168, 636, 9.100000381, 162, - 10.69999981, 82, 329, 8.699999809, 150, - 11.69999981, 89, 634, 7.599999905, 134, - 8.5, 149, 631, 10.80000019, 292, - 8.300000191, 60, 257, 9.5, 108, - 8.199999809, 96, 284, 8.800000191, 111, - 7.900000095, 83, 603, 9.5, 182, - 10.30000019, 130, 686, 8.699999809, 129, - 7.400000095, 145, 345, 11.19999981, 158, - 9.600000381, 112, 1357, 9.699999809, 186, - 9.300000191, 131, 544, 9.600000381, 177, - 10.60000038, 80, 205, 9.100000381, 127, - 9.699999809, 130, 1264, 9.199999809, 179, - 11.60000038, 140, 688, 8.300000191, 80, - 8.100000381, 154, 354, 8.399999619, 103, - 9.800000191, 118, 1632, 9.399999619, 101, - 7.400000095, 94, 348, 9.800000191, 117, - 9.399999619, 119, 370, 10.39999962, 88, - 11.19999981, 153, 648, 9.899999619, 78, - 9.100000381, 116, 366, 9.199999809, 102, - 10.5, 97, 540, 10.30000019, 95, - 11.89999962, 176, 680, 8.899999619, 80, - 8.399999619, 75, 345, 9.600000381, 92, - 5, 134, 525, 10.30000019, 126, - 9.800000191, 161, 870, 10.39999962, 108, - 9.800000191, 111, 669, 9.699999809, 77, - 10.80000019, 114, 452, 9.600000381, 60, - 10.10000038, 142, 430, 10.69999981, 71, - 10.89999962, 238, 822, 10.30000019, 86, - 9.199999809, 78, 190, 10.69999981, 93, - 8.300000191, 196, 867, 9.600000381, 106, - 
7.300000191, 125, 969, 10.5, 162, - 9.399999619, 82, 499, 7.699999809, 95, - 9.399999619, 125, 925, 10.19999981, 91, - 9.800000191, 129, 353, 9.899999619, 52, - 3.599999905, 84, 288, 8.399999619, 110, - 8.399999619, 183, 718, 10.39999962, 69, - 10.80000019, 119, 540, 9.199999809, 57, - 10.10000038, 180, 668, 13, 106, - 9, 82, 347, 8.800000191, 40, - 10, 71, 345, 9.199999809, 50, - 11.30000019, 118, 463, 7.800000191, 35, - 11.30000019, 121, 728, 8.199999809, 86, - 12.80000019, 68, 383, 7.400000095, 57, - 10, 112, 316, 10.39999962, 57, - 6.699999809, 109, 388, 8.899999619, 94 - }; - - final int nobs = 53; - final int nvars = 4; - - System.out.println(">>> Create new SparseDistributedMatrix inside IgniteThread."); - // Create SparseDistributedMatrix, new cache will be created automagically. - SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(0, 0, - StorageConstants.ROW_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE); - - System.out.println(">>> Create new linear regression object"); - OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression(); - regression.newSampleData(data, nobs, nvars, distributedMatrix); - System.out.println(); - - System.out.println(">>> Estimates the regression parameters b:"); - System.out.println(Arrays.toString(regression.estimateRegressionParameters())); - - System.out.println(">>> Estimates the residuals, ie u = y - X*b:"); - System.out.println(Arrays.toString(regression.estimateResiduals())); - - System.out.println(">>> Standard errors of the regression parameters:"); - System.out.println(Arrays.toString(regression.estimateRegressionParametersStandardErrors())); - - System.out.println(">>> Estimates the variance of the regression parameters, ie Var(b):"); - Tracer.showAscii(regression.estimateRegressionParametersVariance()); - - System.out.println(">>> Estimates the standard error of the regression:"); - System.out.println(regression.estimateRegressionStandardError()); - - System.out.println(">>> 
R-Squared statistic:"); - System.out.println(regression.calculateRSquared()); - - System.out.println(">>> Adjusted R-squared statistic:"); - System.out.println(regression.calculateAdjustedRSquared()); - - System.out.println(">>> Returns the variance of the regressand, ie Var(y):"); - System.out.println(regression.estimateErrorVariance()); - }); - - igniteThread.start(); - - igniteThread.join(); - } - } - -} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/examples/src/main/ml/org/apache/ignite/examples/ml/math/trees/MNISTExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/math/trees/MNISTExample.java b/examples/src/main/ml/org/apache/ignite/examples/ml/math/trees/MNISTExample.java deleted file mode 100644 index 6aaadd9..0000000 --- a/examples/src/main/ml/org/apache/ignite/examples/ml/math/trees/MNISTExample.java +++ /dev/null @@ -1,261 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.ignite.examples.ml.math.trees; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Iterator; -import java.util.Random; -import java.util.function.Function; -import java.util.stream.Stream; -import org.apache.commons.cli.BasicParser; -import org.apache.commons.cli.CommandLine; -import org.apache.commons.cli.CommandLineParser; -import org.apache.commons.cli.Option; -import org.apache.commons.cli.OptionBuilder; -import org.apache.commons.cli.Options; -import org.apache.commons.cli.ParseException; -import org.apache.ignite.Ignite; -import org.apache.ignite.IgniteCache; -import org.apache.ignite.IgniteDataStreamer; -import org.apache.ignite.Ignition; -import org.apache.ignite.cache.CacheAtomicityMode; -import org.apache.ignite.cache.CacheMode; -import org.apache.ignite.cache.CacheWriteSynchronizationMode; -import org.apache.ignite.configuration.CacheConfiguration; -import org.apache.ignite.examples.ExampleNodeStartup; -import org.apache.ignite.internal.util.IgniteUtils; -import org.apache.ignite.lang.IgniteBiTuple; -import org.apache.ignite.ml.Model; -import org.apache.ignite.ml.estimators.Estimators; -import org.apache.ignite.ml.math.Vector; -import org.apache.ignite.ml.math.functions.IgniteTriFunction; -import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector; -import org.apache.ignite.ml.trees.models.DecisionTreeModel; -import org.apache.ignite.ml.trees.trainers.columnbased.BiIndex; -import org.apache.ignite.ml.trees.trainers.columnbased.BiIndexedCacheColumnDecisionTreeTrainerInput; -import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer; -import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.ContinuousSplitCalculators; -import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator; -import org.apache.ignite.ml.trees.trainers.columnbased.regcalcs.RegionCalculators; -import org.apache.ignite.ml.util.MnistUtils; -import 
org.jetbrains.annotations.NotNull; - -/** - *

- * Example of usage of decision trees algorithm for MNIST dataset - * (it can be found here: http://yann.lecun.com/exdb/mnist/).

- *

- * Remote nodes should always be started with special configuration file which - * enables P2P class loading: {@code 'ignite.{sh|bat} examples/config/example-ignite.xml'}.

- *

- * Alternatively you can run {@link ExampleNodeStartup} in another JVM which will start node - * with {@code examples/config/example-ignite.xml} configuration.

- *

- * It is recommended to start at least one node prior to launching this example if you intend - * to run it with default memory settings.

- *

- * This example should with program arguments, for example - * -ts_i /path/to/train-images-idx3-ubyte - * -ts_l /path/to/train-labels-idx1-ubyte - * -tss_i /path/to/t10k-images-idx3-ubyte - * -tss_l /path/to/t10k-labels-idx1-ubyte - * -cfg examples/config/example-ignite.xml.

- *

- * -ts_i specifies path to training set images of MNIST; - * -ts_l specifies path to training set labels of MNIST; - * -tss_i specifies path to test set images of MNIST; - * -tss_l specifies path to test set labels of MNIST; - * -cfg specifies path to a config path.

- */ -public class MNISTExample { - /** Name of parameter specifying path to training set images. */ - private static final String MNIST_TRAINING_IMAGES_PATH = "ts_i"; - - /** Name of parameter specifying path to training set labels. */ - private static final String MNIST_TRAINING_LABELS_PATH = "ts_l"; - - /** Name of parameter specifying path to test set images. */ - private static final String MNIST_TEST_IMAGES_PATH = "tss_i"; - - /** Name of parameter specifying path to test set labels. */ - private static final String MNIST_TEST_LABELS_PATH = "tss_l"; - - /** Name of parameter specifying path of Ignite config. */ - private static final String CONFIG = "cfg"; - - /** Default config path. */ - private static final String DEFAULT_CONFIG = "examples/config/example-ignite.xml"; - - /** - * Launches example. - * - * @param args Program arguments. - */ - public static void main(String[] args) { - String igniteCfgPath; - - CommandLineParser parser = new BasicParser(); - - String trainingImagesPath; - String trainingLabelsPath; - - String testImagesPath; - String testLabelsPath; - - try { - // Parse the command line arguments. 
- CommandLine line = parser.parse(buildOptions(), args); - - trainingImagesPath = line.getOptionValue(MNIST_TRAINING_IMAGES_PATH); - trainingLabelsPath = line.getOptionValue(MNIST_TRAINING_LABELS_PATH); - testImagesPath = line.getOptionValue(MNIST_TEST_IMAGES_PATH); - testLabelsPath = line.getOptionValue(MNIST_TEST_LABELS_PATH); - igniteCfgPath = line.getOptionValue(CONFIG, DEFAULT_CONFIG); - } - catch (ParseException e) { - e.printStackTrace(); - return; - } - - try (Ignite ignite = Ignition.start(igniteCfgPath)) { - IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName()); - - int ptsCnt = 60000; - int featCnt = 28 * 28; - - Stream trainingMnistStream = MnistUtils.mnist(trainingImagesPath, trainingLabelsPath, new Random(123L), ptsCnt); - Stream testMnistStream = MnistUtils.mnist(testImagesPath, testLabelsPath, new Random(123L), 10_000); - - IgniteCache cache = createBiIndexedCache(ignite); - - loadVectorsIntoBiIndexedCache(cache.getName(), trainingMnistStream.iterator(), featCnt + 1, ignite); - - ColumnDecisionTreeTrainer trainer = - new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.GINI.apply(ignite), RegionCalculators.GINI, RegionCalculators.MOST_COMMON, ignite); - - System.out.println(">>> Training started"); - long before = System.currentTimeMillis(); - DecisionTreeModel mdl = trainer.train(new BiIndexedCacheColumnDecisionTreeTrainerInput(cache, new HashMap<>(), ptsCnt, featCnt)); - System.out.println(">>> Training finished in " + (System.currentTimeMillis() - before)); - - IgniteTriFunction, Stream>, Function, Double> mse = Estimators.errorsPercentage(); - Double accuracy = mse.apply(mdl, testMnistStream.map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity()); - System.out.println(">>> Errs percentage: " + accuracy); - } - catch (IOException e) { - e.printStackTrace(); - } - } - - /** - * Build cli options. 
- */ - @NotNull private static Options buildOptions() { - Options options = new Options(); - - Option trsImagesPathOpt = OptionBuilder.withArgName(MNIST_TRAINING_IMAGES_PATH).withLongOpt(MNIST_TRAINING_IMAGES_PATH).hasArg() - .withDescription("Path to the MNIST training set.") - .isRequired(true).create(); - - Option trsLabelsPathOpt = OptionBuilder.withArgName(MNIST_TRAINING_LABELS_PATH).withLongOpt(MNIST_TRAINING_LABELS_PATH).hasArg() - .withDescription("Path to the MNIST training set.") - .isRequired(true).create(); - - Option tssImagesPathOpt = OptionBuilder.withArgName(MNIST_TEST_IMAGES_PATH).withLongOpt(MNIST_TEST_IMAGES_PATH).hasArg() - .withDescription("Path to the MNIST test set.") - .isRequired(true).create(); - - Option tssLabelsPathOpt = OptionBuilder.withArgName(MNIST_TEST_LABELS_PATH).withLongOpt(MNIST_TEST_LABELS_PATH).hasArg() - .withDescription("Path to the MNIST test set.") - .isRequired(true).create(); - - Option configOpt = OptionBuilder.withArgName(CONFIG).withLongOpt(CONFIG).hasArg() - .withDescription("Path to the config.") - .isRequired(false).create(); - - options.addOption(trsImagesPathOpt); - options.addOption(trsLabelsPathOpt); - options.addOption(tssImagesPathOpt); - options.addOption(tssLabelsPathOpt); - options.addOption(configOpt); - - return options; - } - - /** - * Creates cache where data for training is stored. - * - * @param ignite Ignite instance. - * @return cache where data for training is stored. - */ - private static IgniteCache createBiIndexedCache(Ignite ignite) { - CacheConfiguration cfg = new CacheConfiguration<>(); - - // Write to primary. - cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC); - - // Atomic transactions only. - cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC); - - // No eviction. - cfg.setEvictionPolicy(null); - - // No copying of values. - cfg.setCopyOnRead(false); - - // Cache is partitioned. 
- cfg.setCacheMode(CacheMode.PARTITIONED); - - cfg.setBackups(0); - - cfg.setName("TMP_BI_INDEXED_CACHE"); - - return ignite.getOrCreateCache(cfg); - } - - /** - * Loads vectors into cache. - * - * @param cacheName Name of cache. - * @param vectorsIterator Iterator over vectors to load. - * @param vectorSize Size of vector. - * @param ignite Ignite instance. - */ - private static void loadVectorsIntoBiIndexedCache(String cacheName, Iterator vectorsIterator, - int vectorSize, Ignite ignite) { - try (IgniteDataStreamer streamer = - ignite.dataStreamer(cacheName)) { - int sampleIdx = 0; - - streamer.perNodeBufferSize(10000); - - while (vectorsIterator.hasNext()) { - org.apache.ignite.ml.math.Vector next = vectorsIterator.next(); - - for (int i = 0; i < vectorSize; i++) - streamer.addData(new BiIndex(sampleIdx, i), next.getX(i)); - - sampleIdx++; - - if (sampleIdx % 1000 == 0) - System.out.println("Loaded " + sampleIdx + " vectors."); - } - } - } -} \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/examples/src/main/ml/org/apache/ignite/examples/ml/math/trees/package-info.java ---------------------------------------------------------------------- diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/math/trees/package-info.java b/examples/src/main/ml/org/apache/ignite/examples/ml/math/trees/package-info.java deleted file mode 100644 index 9b6867b..0000000 --- a/examples/src/main/ml/org/apache/ignite/examples/ml/math/trees/package-info.java +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * - * Decision trees examples. - */ -package org.apache.ignite.examples.ml.math.trees; http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/examples/src/main/ml/org/apache/ignite/examples/ml/package-info.java ---------------------------------------------------------------------- diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/package-info.java b/examples/src/main/ml/org/apache/ignite/examples/ml/package-info.java new file mode 100644 index 0000000..52778b5 --- /dev/null +++ b/examples/src/main/ml/org/apache/ignite/examples/ml/package-info.java @@ -0,0 +1,22 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * + * Machine learning examples. 
+ */ +package org.apache.ignite.examples.ml; \ No newline at end of file http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/examples/src/main/ml/org/apache/ignite/examples/ml/regression/DistributedRegressionExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/regression/DistributedRegressionExample.java b/examples/src/main/ml/org/apache/ignite/examples/ml/regression/DistributedRegressionExample.java new file mode 100644 index 0000000..3e65527 --- /dev/null +++ b/examples/src/main/ml/org/apache/ignite/examples/ml/regression/DistributedRegressionExample.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.ignite.examples.ml.regression;
+
+import java.util.Arrays;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.math.StorageConstants;
+import org.apache.ignite.ml.math.Tracer;
+import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix;
+import org.apache.ignite.ml.regressions.OLSMultipleLinearRegression;
+import org.apache.ignite.thread.IgniteThread;
+
+/**
+ * Run linear regression over distributed matrix.
+ *
+ * TODO: IGNITE-6222, Currently works only in local mode.
+ *
+ * @see OLSMultipleLinearRegression
+ */
+public class DistributedRegressionExample {
+    /** Runs the example; command line arguments are not used. */
+    public static void main(String[] args) throws InterruptedException {
+        System.out.println();
+        System.out.println(">>> Linear regression over sparse distributed matrix API usage example started.");
+        // Start ignite grid.
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            System.out.println(">>> Ignite grid started.");
+            // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread
+            // because we create ignite cache internally.
+            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
+                DistributedRegressionExample.class.getSimpleName(), () -> {
+
+                double[] data = {
+                    8, 78, 284, 9.100000381, 109,
+                    9.300000191, 68, 433, 8.699999809, 144,
+                    7.5, 70, 739, 7.199999809, 113,
+                    8.899999619, 96, 1792, 8.899999619, 97,
+                    10.19999981, 74, 477, 8.300000191, 206,
+                    8.300000191, 111, 362, 10.89999962, 124,
+                    8.800000191, 77, 671, 10, 152,
+                    8.800000191, 168, 636, 9.100000381, 162,
+                    10.69999981, 82, 329, 8.699999809, 150,
+                    11.69999981, 89, 634, 7.599999905, 134,
+                    8.5, 149, 631, 10.80000019, 292,
+                    8.300000191, 60, 257, 9.5, 108,
+                    8.199999809, 96, 284, 8.800000191, 111,
+                    7.900000095, 83, 603, 9.5, 182,
+                    10.30000019, 130, 686, 8.699999809, 129,
+                    7.400000095, 145, 345, 11.19999981, 158,
+                    9.600000381, 112, 1357, 9.699999809, 186,
+                    9.300000191, 131, 544, 9.600000381, 177,
+                    10.60000038, 80, 205, 9.100000381, 127,
+                    9.699999809, 130, 1264, 9.199999809, 179,
+                    11.60000038, 140, 688, 8.300000191, 80,
+                    8.100000381, 154, 354, 8.399999619, 103,
+                    9.800000191, 118, 1632, 9.399999619, 101,
+                    7.400000095, 94, 348, 9.800000191, 117,
+                    9.399999619, 119, 370, 10.39999962, 88,
+                    11.19999981, 153, 648, 9.899999619, 78,
+                    9.100000381, 116, 366, 9.199999809, 102,
+                    10.5, 97, 540, 10.30000019, 95,
+                    11.89999962, 176, 680, 8.899999619, 80,
+                    8.399999619, 75, 345, 9.600000381, 92,
+                    5, 134, 525, 10.30000019, 126,
+                    9.800000191, 161, 870, 10.39999962, 108,
+                    9.800000191, 111, 669, 9.699999809, 77,
+                    10.80000019, 114, 452, 9.600000381, 60,
+                    10.10000038, 142, 430, 10.69999981, 71,
+                    10.89999962, 238, 822, 10.30000019, 86,
+                    9.199999809, 78, 190, 10.69999981, 93,
+                    8.300000191, 196, 867, 9.600000381, 106,
+                    7.300000191, 125, 969, 10.5, 162,
+                    9.399999619, 82, 499, 7.699999809, 95,
+                    9.399999619, 125, 925, 10.19999981, 91,
+                    9.800000191, 129, 353, 9.899999619, 52,
+                    3.599999905, 84, 288, 8.399999619, 110,
+                    8.399999619, 183, 718, 10.39999962, 69,
+                    10.80000019, 119, 540, 9.199999809, 57,
+                    10.10000038, 180, 668, 13, 106,
+                    9, 82, 347, 8.800000191, 40,
+                    10, 71, 345, 9.199999809, 50,
+                    11.30000019, 118, 463, 7.800000191, 35,
+                    11.30000019, 121, 728, 8.199999809, 86,
+                    12.80000019, 68, 383, 7.400000095, 57,
+                    10, 112, 316, 10.39999962, 57,
+                    6.699999809, 109, 388, 8.899999619, 94
+                };
+
+                final int nobs = 53;
+                final int nvars = 4;
+
+                System.out.println(">>> Create new SparseDistributedMatrix inside IgniteThread.");
+                // Create SparseDistributedMatrix, new cache will be created automatically.
+                SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(0, 0,
+                    StorageConstants.ROW_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE);
+
+                System.out.println(">>> Create new linear regression object");
+                OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();
+                regression.newSampleData(data, nobs, nvars, distributedMatrix);
+                System.out.println();
+
+                System.out.println(">>> Estimates the regression parameters b:");
+                System.out.println(Arrays.toString(regression.estimateRegressionParameters()));
+
+                System.out.println(">>> Estimates the residuals, ie u = y - X*b:");
+                System.out.println(Arrays.toString(regression.estimateResiduals()));
+
+                System.out.println(">>> Standard errors of the regression parameters:");
+                System.out.println(Arrays.toString(regression.estimateRegressionParametersStandardErrors()));
+
+                System.out.println(">>> Estimates the variance of the regression parameters, ie Var(b):");
+                Tracer.showAscii(regression.estimateRegressionParametersVariance());
+
+                System.out.println(">>> Estimates the standard error of the regression:");
+                System.out.println(regression.estimateRegressionStandardError());
+
+                System.out.println(">>> R-Squared statistic:");
+                System.out.println(regression.calculateRSquared());
+
+                System.out.println(">>> Adjusted R-squared statistic:");
+                System.out.println(regression.calculateAdjustedRSquared());
+
+                System.out.println(">>> Estimates the variance of the error:");
+                System.out.println(regression.estimateErrorVariance());
+            });
+
+            igniteThread.start();
+
+            igniteThread.join();
+        }
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/examples/src/main/ml/org/apache/ignite/examples/ml/regression/DistributedRegressionModelExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/regression/DistributedRegressionModelExample.java b/examples/src/main/ml/org/apache/ignite/examples/ml/regression/DistributedRegressionModelExample.java
new file mode 100644
index 0000000..ab1b17d
--- /dev/null
+++ b/examples/src/main/ml/org/apache/ignite/examples/ml/regression/DistributedRegressionModelExample.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.regression;
+
+import org.apache.ignite.Ignite;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.math.StorageConstants;
+import org.apache.ignite.ml.math.Tracer;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix;
+import org.apache.ignite.ml.math.impls.vector.SparseDistributedVector;
+import org.apache.ignite.ml.regressions.OLSMultipleLinearRegressionModel;
+import org.apache.ignite.ml.regressions.OLSMultipleLinearRegressionTrainer;
+import org.apache.ignite.thread.IgniteThread;
+
+/**
+ * Run linear regression model over distributed matrix.
+ *
+ * @see OLSMultipleLinearRegressionModel
+ */
+public class DistributedRegressionModelExample {
+    /** Runs the example; command line arguments are not used. */
+    public static void main(String[] args) throws InterruptedException {
+        System.out.println();
+        System.out.println(">>> Linear regression model over sparse distributed matrix API usage example started.");
+        // Start ignite grid.
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            System.out.println(">>> Ignite grid started.");
+            // Create IgniteThread, we must work with SparseDistributedMatrix inside IgniteThread
+            // because we create ignite cache internally.
+            IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
+                DistributedRegressionModelExample.class.getSimpleName(), () -> {
+                double[] data = {
+                    8, 78, 284, 9.100000381, 109,
+                    9.300000191, 68, 433, 8.699999809, 144,
+                    7.5, 70, 739, 7.199999809, 113,
+                    8.899999619, 96, 1792, 8.899999619, 97,
+                    10.19999981, 74, 477, 8.300000191, 206,
+                    8.300000191, 111, 362, 10.89999962, 124,
+                    8.800000191, 77, 671, 10, 152,
+                    8.800000191, 168, 636, 9.100000381, 162,
+                    10.69999981, 82, 329, 8.699999809, 150,
+                    11.69999981, 89, 634, 7.599999905, 134,
+                    8.5, 149, 631, 10.80000019, 292,
+                    8.300000191, 60, 257, 9.5, 108,
+                    8.199999809, 96, 284, 8.800000191, 111,
+                    7.900000095, 83, 603, 9.5, 182,
+                    10.30000019, 130, 686, 8.699999809, 129,
+                    7.400000095, 145, 345, 11.19999981, 158,
+                    9.600000381, 112, 1357, 9.699999809, 186,
+                    9.300000191, 131, 544, 9.600000381, 177,
+                    10.60000038, 80, 205, 9.100000381, 127,
+                    9.699999809, 130, 1264, 9.199999809, 179,
+                    11.60000038, 140, 688, 8.300000191, 80,
+                    8.100000381, 154, 354, 8.399999619, 103,
+                    9.800000191, 118, 1632, 9.399999619, 101,
+                    7.400000095, 94, 348, 9.800000191, 117,
+                    9.399999619, 119, 370, 10.39999962, 88,
+                    11.19999981, 153, 648, 9.899999619, 78,
+                    9.100000381, 116, 366, 9.199999809, 102,
+                    10.5, 97, 540, 10.30000019, 95,
+                    11.89999962, 176, 680, 8.899999619, 80,
+                    8.399999619, 75, 345, 9.600000381, 92,
+                    5, 134, 525, 10.30000019, 126,
+                    9.800000191, 161, 870, 10.39999962, 108,
+                    9.800000191, 111, 669, 9.699999809, 77,
+                    10.80000019, 114, 452, 9.600000381, 60,
+                    10.10000038, 142, 430, 10.69999981, 71,
+                    10.89999962, 238, 822, 10.30000019, 86,
+                    9.199999809, 78, 190, 10.69999981, 93,
+                    8.300000191, 196, 867, 9.600000381, 106,
+                    7.300000191, 125, 969, 10.5, 162,
+                    9.399999619, 82, 499, 7.699999809, 95,
+                    9.399999619, 125, 925, 10.19999981, 91,
+                    9.800000191, 129, 353, 9.899999619, 52,
+                    3.599999905, 84, 288, 8.399999619, 110,
+                    8.399999619, 183, 718, 10.39999962, 69,
+                    10.80000019, 119, 540, 9.199999809, 57,
+                    10.10000038, 180, 668, 13, 106,
+                    9, 82, 347, 8.800000191, 40,
+                    10, 71, 345, 9.199999809, 50,
+                    11.30000019, 118, 463, 7.800000191, 35,
+                    11.30000019, 121, 728, 8.199999809, 86,
+                    12.80000019, 68, 383, 7.400000095, 57,
+                    10, 112, 316, 10.39999962, 57,
+                    6.699999809, 109, 388, 8.899999619, 94
+                };
+
+                final int nobs = 53;
+                final int nvars = 4;
+
+                System.out.println(">>> Create new SparseDistributedMatrix inside IgniteThread.");
+                // Create SparseDistributedMatrix, new cache will be created automatically.
+                SparseDistributedMatrix distributedMatrix = new SparseDistributedMatrix(0, 0,
+                    StorageConstants.ROW_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE);
+
+                System.out.println(">>> Create new linear regression trainer object.");
+                OLSMultipleLinearRegressionTrainer trainer
+                    = new OLSMultipleLinearRegressionTrainer(0, nobs, nvars, distributedMatrix);
+                System.out.println(">>> Perform the training to get the model.");
+                OLSMultipleLinearRegressionModel mdl = trainer.train(data);
+                System.out.println();
+
+                // One element per observation: the first value of every (nvars + 1)-wide row of the data array.
+                Vector val = new SparseDistributedVector(nobs).assign((i) -> data[i * (nvars + 1)]);
+
+                System.out.println(">>> The input data:");
+                Tracer.showAscii(val);
+
+                System.out.println(">>> Trained model prediction results:");
+                Tracer.showAscii(mdl.predict(val));
+            });
+
+            igniteThread.start();
+
+            igniteThread.join();
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/examples/src/main/ml/org/apache/ignite/examples/ml/regression/package-info.java
----------------------------------------------------------------------
diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/regression/package-info.java b/examples/src/main/ml/org/apache/ignite/examples/ml/regression/package-info.java
new file mode 100644
index 0000000..c89c80c
--- /dev/null
+++ b/examples/src/main/ml/org/apache/ignite/examples/ml/regression/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ *
contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * + * ML regression examples. + */ +package org.apache.ignite.examples.ml.regression; http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/examples/src/main/ml/org/apache/ignite/examples/ml/trees/MNISTExample.java ---------------------------------------------------------------------- diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/trees/MNISTExample.java b/examples/src/main/ml/org/apache/ignite/examples/ml/trees/MNISTExample.java new file mode 100644 index 0000000..6ff121e --- /dev/null +++ b/examples/src/main/ml/org/apache/ignite/examples/ml/trees/MNISTExample.java @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.trees;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Random;
+import java.util.function.Function;
+import java.util.stream.Stream;
+import org.apache.commons.cli.BasicParser;
+import org.apache.commons.cli.CommandLine;
+import org.apache.commons.cli.CommandLineParser;
+import org.apache.commons.cli.Option;
+import org.apache.commons.cli.OptionBuilder;
+import org.apache.commons.cli.Options;
+import org.apache.commons.cli.ParseException;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.IgniteDataStreamer;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.CacheAtomicityMode;
+import org.apache.ignite.cache.CacheMode;
+import org.apache.ignite.cache.CacheWriteSynchronizationMode;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.examples.ExampleNodeStartup;
+import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.lang.IgniteBiTuple;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.estimators.Estimators;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.functions.IgniteTriFunction;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.trees.models.DecisionTreeModel;
+import org.apache.ignite.ml.trees.trainers.columnbased.BiIndex;
+import org.apache.ignite.ml.trees.trainers.columnbased.BiIndexedCacheColumnDecisionTreeTrainerInput;
+import org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer;
+import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.ContinuousSplitCalculators;
+import org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator;
+import org.apache.ignite.ml.trees.trainers.columnbased.regcalcs.RegionCalculators;
+import org.apache.ignite.ml.util.MnistUtils;
+import org.jetbrains.annotations.NotNull;
+
+/**
+ * <p>
+ * Example of usage of decision trees algorithm for MNIST dataset
+ * (it can be found here: http://yann.lecun.com/exdb/mnist/).</p>
+ * <p>
+ * Remote nodes should always be started with special configuration file which
+ * enables P2P class loading: {@code 'ignite.{sh|bat} examples/config/example-ignite.xml'}.</p>
+ * <p>
+ * Alternatively you can run {@link ExampleNodeStartup} in another JVM which will start node
+ * with {@code examples/config/example-ignite.xml} configuration.</p>
+ * <p>
+ * It is recommended to start at least one node prior to launching this example if you intend
+ * to run it with default memory settings.</p>
+ * <p>
+ * This example should be run with program arguments, for example
+ * -ts_i /path/to/train-images-idx3-ubyte
+ * -ts_l /path/to/train-labels-idx1-ubyte
+ * -tss_i /path/to/t10k-images-idx3-ubyte
+ * -tss_l /path/to/t10k-labels-idx1-ubyte
+ * -cfg examples/config/example-ignite.xml.</p>
+ * <p>
+ * -ts_i specifies path to training set images of MNIST;
+ * -ts_l specifies path to training set labels of MNIST;
+ * -tss_i specifies path to test set images of MNIST;
+ * -tss_l specifies path to test set labels of MNIST;
+ * -cfg specifies path to the Ignite configuration file (optional).</p>
+ */
+public class MNISTExample {
+    /** Name of parameter specifying path to training set images. */
+    private static final String MNIST_TRAINING_IMAGES_PATH = "ts_i";
+
+    /** Name of parameter specifying path to training set labels. */
+    private static final String MNIST_TRAINING_LABELS_PATH = "ts_l";
+
+    /** Name of parameter specifying path to test set images. */
+    private static final String MNIST_TEST_IMAGES_PATH = "tss_i";
+
+    /** Name of parameter specifying path to test set labels. */
+    private static final String MNIST_TEST_LABELS_PATH = "tss_l";
+
+    /** Name of parameter specifying path of Ignite config. */
+    private static final String CONFIG = "cfg";
+
+    /** Default config path. */
+    private static final String DEFAULT_CONFIG = "examples/config/example-ignite.xml";
+
+    /**
+     * Launches example.
+     *
+     * @param args Program arguments.
+     */
+    public static void main(String[] args) {
+        String igniteCfgPath;
+
+        CommandLineParser parser = new BasicParser();
+
+        String trainingImagesPath;
+        String trainingLabelsPath;
+
+        String testImagesPath;
+        String testLabelsPath;
+
+        try {
+            // Parse the command line arguments.
+            CommandLine line = parser.parse(buildOptions(), args);
+
+            trainingImagesPath = line.getOptionValue(MNIST_TRAINING_IMAGES_PATH);
+            trainingLabelsPath = line.getOptionValue(MNIST_TRAINING_LABELS_PATH);
+            testImagesPath = line.getOptionValue(MNIST_TEST_IMAGES_PATH);
+            testLabelsPath = line.getOptionValue(MNIST_TEST_LABELS_PATH);
+            igniteCfgPath = line.getOptionValue(CONFIG, DEFAULT_CONFIG);
+        }
+        catch (ParseException e) {
+            e.printStackTrace();
+            return;
+        }
+
+        try (Ignite ignite = Ignition.start(igniteCfgPath)) {
+            IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());
+
+            int ptsCnt = 60000;
+            int featCnt = 28 * 28;
+
+            Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnist(trainingImagesPath, trainingLabelsPath, new Random(123L), ptsCnt);
+            Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnist(testImagesPath, testLabelsPath, new Random(123L), 10_000);
+
+            IgniteCache<BiIndex, Double> cache = createBiIndexedCache(ignite);
+
+            loadVectorsIntoBiIndexedCache(cache.getName(), trainingMnistStream.iterator(), featCnt + 1, ignite);
+
+            ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> trainer =
+                new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.GINI.apply(ignite), RegionCalculators.GINI, RegionCalculators.MOST_COMMON, ignite);
+
+            System.out.println(">>> Training started");
+            long before = System.currentTimeMillis();
+            DecisionTreeModel mdl = trainer.train(new BiIndexedCacheColumnDecisionTreeTrainerInput(cache, new HashMap<>(), ptsCnt, featCnt));
+            System.out.println(">>> Training finished in " + (System.currentTimeMillis() - before));
+
+            IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> errsEstimator = Estimators.errorsPercentage();
+            Double errsPercentage = errsEstimator.apply(mdl, testMnistStream.map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity());
+            System.out.println(">>> Errs percentage: " + errsPercentage);
+        }
+        catch (IOException e) {
+            e.printStackTrace();
+        }
+    }
+
+    /**
+     * Builds the command line options described in the class documentation.
+     */
+    @NotNull private static Options buildOptions() {
+        Options options = new Options();
+
+        Option trsImagesPathOpt = OptionBuilder.withArgName(MNIST_TRAINING_IMAGES_PATH).withLongOpt(MNIST_TRAINING_IMAGES_PATH).hasArg()
+            .withDescription("Path to the MNIST training set images.")
+            .isRequired(true).create();
+
+        Option trsLabelsPathOpt = OptionBuilder.withArgName(MNIST_TRAINING_LABELS_PATH).withLongOpt(MNIST_TRAINING_LABELS_PATH).hasArg()
+            .withDescription("Path to the MNIST training set labels.")
+            .isRequired(true).create();
+
+        Option tssImagesPathOpt = OptionBuilder.withArgName(MNIST_TEST_IMAGES_PATH).withLongOpt(MNIST_TEST_IMAGES_PATH).hasArg()
+            .withDescription("Path to the MNIST test set images.")
+            .isRequired(true).create();
+
+        Option tssLabelsPathOpt = OptionBuilder.withArgName(MNIST_TEST_LABELS_PATH).withLongOpt(MNIST_TEST_LABELS_PATH).hasArg()
+            .withDescription("Path to the MNIST test set labels.")
+            .isRequired(true).create();
+
+        Option configOpt = OptionBuilder.withArgName(CONFIG).withLongOpt(CONFIG).hasArg()
+            .withDescription("Path to the config.")
+            .isRequired(false).create();
+
+        options.addOption(trsImagesPathOpt);
+        options.addOption(trsLabelsPathOpt);
+        options.addOption(tssImagesPathOpt);
+        options.addOption(tssLabelsPathOpt);
+        options.addOption(configOpt);
+
+        return options;
+    }
+
+    /**
+     * Creates cache where data for training is stored.
+     *
+     * @param ignite Ignite instance.
+     * @return cache where data for training is stored.
+     */
+    private static IgniteCache<BiIndex, Double> createBiIndexedCache(Ignite ignite) {
+        CacheConfiguration<BiIndex, Double> cfg = new CacheConfiguration<>();
+
+        // Write to primary.
+        cfg.setWriteSynchronizationMode(CacheWriteSynchronizationMode.PRIMARY_SYNC);
+
+        // Atomic (non-transactional) cache mode.
+        cfg.setAtomicityMode(CacheAtomicityMode.ATOMIC);
+
+        // No eviction.
+        cfg.setEvictionPolicy(null);
+
+        // No copying of values.
+        cfg.setCopyOnRead(false);
+
+        // Cache is partitioned.
+        cfg.setCacheMode(CacheMode.PARTITIONED);
+
+        cfg.setBackups(0);
+
+        cfg.setName("TMP_BI_INDEXED_CACHE");
+
+        return ignite.getOrCreateCache(cfg);
+    }
+
+    /**
+     * Loads vectors into cache.
+     *
+     * @param cacheName Name of cache.
+     * @param vectorsIterator Iterator over vectors to load.
+     * @param vectorSize Size of vector.
+     * @param ignite Ignite instance.
+     */
+    private static void loadVectorsIntoBiIndexedCache(String cacheName, Iterator<? extends Vector> vectorsIterator,
+        int vectorSize, Ignite ignite) {
+        try (IgniteDataStreamer<BiIndex, Double> streamer =
+                 ignite.dataStreamer(cacheName)) {
+            int sampleIdx = 0;
+
+            streamer.perNodeBufferSize(10000);
+
+            while (vectorsIterator.hasNext()) {
+                org.apache.ignite.ml.math.Vector next = vectorsIterator.next();
+
+                for (int i = 0; i < vectorSize; i++)
+                    streamer.addData(new BiIndex(sampleIdx, i), next.getX(i));
+
+                sampleIdx++;
+
+                if (sampleIdx % 1000 == 0)
+                    System.out.println("Loaded " + sampleIdx + " vectors.");
+            }
+        }
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/examples/src/main/ml/org/apache/ignite/examples/ml/trees/package-info.java
----------------------------------------------------------------------
diff --git a/examples/src/main/ml/org/apache/ignite/examples/ml/trees/package-info.java b/examples/src/main/ml/org/apache/ignite/examples/ml/trees/package-info.java
new file mode 100644
index 0000000..d944f60
--- /dev/null
+++ b/examples/src/main/ml/org/apache/ignite/examples/ml/trees/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * + * Decision trees examples. + */ +package org.apache.ignite.examples.ml.trees; http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/QRDSolver.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/QRDSolver.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/QRDSolver.java new file mode 100644 index 0000000..bb591ee --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/QRDSolver.java @@ -0,0 +1,197 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.ignite.ml.math.decompositions;
+
+import java.io.Serializable;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.exceptions.NoDataException;
+import org.apache.ignite.ml.math.exceptions.NullArgumentException;
+import org.apache.ignite.ml.math.exceptions.SingularMatrixException;
+import org.apache.ignite.ml.math.functions.Functions;
+import org.apache.ignite.ml.math.util.MatrixUtil;
+
+import static org.apache.ignite.ml.math.util.MatrixUtil.like;
+
+/**
+ * Least squares solver built from the factors of a QR decomposition. For an {@code m x n} matrix {@code A}
+ * with {@code m >= n}, the QR decomposition is an {@code m x n} orthogonal matrix {@code Q} and an
+ * {@code n x n} upper triangular matrix {@code R} so that {@code A = Q*R}.
+ */
+public class QRDSolver implements Serializable {
+    /** Orthogonal factor {@code Q}. */
+    private final Matrix q;
+
+    /** Upper triangular factor {@code R}. */
+    private final Matrix r;
+
+    /**
+     * Constructs a new QR decomposition solver object.
+     *
+     * @param q An orthogonal matrix.
+     * @param r An upper triangular matrix.
+     */
+    public QRDSolver(Matrix q, Matrix r) {
+        this.q = q;
+        this.r = r;
+    }
+
+    /**
+     * Least squares solution of {@code A*X = B}; returns {@code X}.
+     *
+     * @param mtx A matrix with as many rows as {@code A} and any number of cols.
+     * @return {@code X} that minimizes the two norm of {@code Q*R*X - B}.
+     * @throws IllegalArgumentException if {@code B.rows() != A.rows()}.
+     */
+    public Matrix solve(Matrix mtx) {
+        if (mtx.rowSize() != q.rowSize())
+            throw new IllegalArgumentException("Matrix row dimensions must agree.");
+
+        int cols = mtx.columnSize();
+        Matrix x = like(r, r.columnSize(), cols);
+
+        Matrix qt = q.transpose();
+        Matrix y = qt.times(mtx);
+
+        for (int k = Math.min(r.columnSize(), q.rowSize()) - 1; k >= 0; k--) {
+            // X[k,] = Y[k,] / R[k,k], note that X[k,] starts with 0 so += is same as =
+            x.viewRow(k).map(y.viewRow(k), Functions.plusMult(1 / r.get(k, k)));
+
+            if (k == 0)
+                continue;
+
+            // Y[0:(k-1),] -= R[0:(k-1),k] * X[k,]
+            Vector rCol = r.viewColumn(k).viewPart(0, k);
+
+            for (int c = 0; c < cols; c++)
+                y.viewColumn(c).viewPart(0, k).map(rCol, Functions.plusMult(-x.get(k, c)));
+        }
+
+        return x;
+    }
+
+    /**
+     * Least squares solution of {@code A*X = B}; returns {@code X}.
+     *
+     * @param vec A vector with as many rows as {@code A}.
+     * @return {@code X} that minimizes the two norm of {@code Q*R*X - B}.
+     * @throws NullArgumentException if {@code vec} is {@code null}.
+     * @throws NoDataException if {@code vec} is empty.
+     * @throws IllegalArgumentException if {@code B.rows() != A.rows()}.
+     */
+    public Vector solve(Vector vec) {
+        if (vec == null)
+            throw new NullArgumentException();
+        if (vec.size() == 0)
+            throw new NoDataException();
+        // TODO: IGNITE-5826, Should we copy here?
+        Matrix res = solve(vec.likeMatrix(vec.size(), 1).assignColumn(0, vec));
+
+        return vec.like(res.rowSize()).assign(res.viewColumn(0));
+    }
+
+    /**
+     * <p>Compute the "hat" matrix.
+     * </p>
+     * <p>The hat matrix is defined in terms of the design matrix X
+     * by {@code X (X^T X)^-1 X^T}.
+     * </p>
+     * <p>The implementation here uses the QR decomposition to compute the
+     * hat matrix as {@code Q I_p Q^T} where {@code I_p} is the
+     * p-dimensional identity matrix augmented by 0's. This computational
+     * formula is from "The Hat Matrix in Regression and ANOVA",
+     * David C. Hoaglin and Roy E. Welsch,
+     * <i>The American Statistician</i>, Vol. 32, No. 1 (Feb., 1978), pp. 17-22.
+     * </p>
+     * <p>Here {@code p} is the number of columns of {@code R}; the augmented identity matrix is square,
+     * with as many rows and columns as {@code Q} has columns, and has ones only on its first {@code p}
+     * diagonal entries.</p>
+     *
+     * @return the hat matrix
+     * @throws NullPointerException if this solver was constructed with a {@code null} {@code Q} or {@code R}.
+     */
+    public Matrix calculateHat() {
+        // Augmented identity matrix: square, sized by the number of columns of Q, with ones on the
+        // first p diagonal entries (p = number of columns of R) and zeros elsewhere.
+        Matrix augI = MatrixUtil.like(q, q.columnSize(), q.columnSize());
+
+        int n = augI.columnSize();
+        int p = r.columnSize();
+
+        for (int i = 0; i < n; i++)
+            for (int j = 0; j < n; j++)
+                if (i == j && i < p)
+                    augI.setX(i, j, 1d);
+                else
+                    augI.setX(i, j, 0d);
+
+        // Hat matrix = Q * augI * Q^T.
+        return q.times(augI).times(q.transpose());
+    }
+
+    /**
+     * <p>Calculates the variance-covariance matrix of the regression parameters.
+     * </p>
+     * <p>Var(b) = {@code (X^T X)^-1}
+     * </p>
+     * <p>Uses QR decomposition to reduce {@code (X^T X)^-1}
+     * to {@code (R^T R)^-1}, with only the top p rows of
+     * R included, where p = the length of the beta vector.</p>
+     *
+     * <p>Only the top-left {@code p x p} block of {@code R} is used; {@code p} is expected not to exceed
+     * the number of rows or columns of {@code R}. That block is copied with {@code MatrixUtil.copy}
+     * before it is inverted.</p>
+     *
+     * @param p Size of the beta variance-covariance matrix
+     * @return The beta variance-covariance matrix
+     * @throws SingularMatrixException if the design matrix is singular
+     * @throws NullPointerException if this solver was constructed with a {@code null} {@code R}.
+     */
+    public Matrix calculateBetaVariance(int p) {
+        Matrix rAug = MatrixUtil.copy(r.viewPart(0, p, 0, p));
+        Matrix rInv = rAug.inverse();
+
+        return rInv.times(rInv.transpose());
+    }
+
+    /** {@inheritDoc} */
+    @Override public boolean equals(Object o) {
+        if (this == o)
+            return true;
+        if (o == null || getClass() != o.getClass())
+            return false;
+
+        QRDSolver solver = (QRDSolver)o;
+
+        return q.equals(solver.q) && r.equals(solver.r);
+    }
+
+    /** {@inheritDoc} */
+    @Override public int hashCode() {
+        int res = q.hashCode();
+        res = 31 * res + r.hashCode();
+        return res;
+    }
+
+    /**
+     * Returns a short description: the number of rows of {@code Q} and the number of columns of {@code R}.
+     */
+    @Override public String toString() {
+        return String.format("QRD Solver(%d x %d)", q.rowSize(), r.columnSize());
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/QRDecomposition.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/QRDecomposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/QRDecomposition.java
index 3d0bb5d..c069683 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/QRDecomposition.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/decompositions/QRDecomposition.java
@@ -46,8 +46,6 @@ public class QRDecomposition implements Destroyable {
     private final int rows;
     /** */
     private final int cols;
-    /** */
-    private double threshold;
 
     /**
      * @param v Value to be checked for being an ordinary double.
@@ -89,7 +87,6 @@ public class QRDecomposition implements Destroyable { boolean fullRank = true; r = like(mtx, min, cols); - this.threshold = threshold; for (int i = 0; i < min; i++) { Vector qi = qTmp.viewColumn(i); @@ -129,6 +126,8 @@ public class QRDecomposition implements Destroyable { else q = qTmp; + verifyNonSingularR(threshold); + this.fullRank = fullRank; } @@ -170,32 +169,7 @@ public class QRDecomposition implements Destroyable { * @throws IllegalArgumentException if {@code B.rows() != A.rows()}. */ public Matrix solve(Matrix mtx) { - if (mtx.rowSize() != rows) - throw new IllegalArgumentException("Matrix row dimensions must agree."); - - int cols = mtx.columnSize(); - Matrix r = getR(); - checkSingular(r, threshold, true); - Matrix x = like(mType, this.cols, cols); - - Matrix qt = getQ().transpose(); - Matrix y = qt.times(mtx); - - for (int k = Math.min(this.cols, rows) - 1; k >= 0; k--) { - // X[k,] = Y[k,] / R[k,k], note that X[k,] starts with 0 so += is same as = - x.viewRow(k).map(y.viewRow(k), Functions.plusMult(1 / r.get(k, k))); - - if (k == 0) - continue; - - // Y[0:(k-1),] -= R[0:(k-1),k] * X[k,] - Vector rCol = r.viewColumn(k).viewPart(0, k); - - for (int c = 0; c < cols; c++) - y.viewColumn(c).viewPart(0, k).map(rCol, Functions.plusMult(-x.get(k, c))); - } - - return x; + return new QRDSolver(q, r).solve(mtx); } /** @@ -206,8 +180,7 @@ public class QRDecomposition implements Destroyable { * @throws IllegalArgumentException if {@code B.rows() != A.rows()}. */ public Vector solve(Vector vec) { - Matrix res = solve(vec.likeMatrix(vec.size(), 1).assignColumn(0, vec)); - return vec.like(res.rowSize()).assign(res.viewColumn(0)); + return new QRDSolver(q, r).solve(vec); } /** @@ -220,27 +193,20 @@ public class QRDecomposition implements Destroyable { /** * Check singularity. * - * @param r R matrix. * @param min Singularity threshold. - * @param raise Whether to raise a {@link SingularMatrixException} if any element of the diagonal fails the check. 
- * @return {@code true} if any element of the diagonal is smaller or equal to {@code min}. * @throws SingularMatrixException if the matrix is singular and {@code raise} is {@code true}. */ - private static boolean checkSingular(Matrix r, double min, boolean raise) { - // TODO: IGNITE-5828, Not a very fast approach for distributed matrices. would be nice if we could independently check - // parts on different nodes for singularity and do fold with 'or'. + private void verifyNonSingularR(double min) { + // TODO: IGNITE-5828, Not a very fast approach for distributed matrices. would be nice if we could independently + // check parts on different nodes for singularity and do fold with 'or'. - final int len = r.columnSize(); + final int len = r.columnSize() > r.rowSize() ? r.rowSize() : r.columnSize(); for (int i = 0; i < len; i++) { final double d = r.getX(i, i); if (Math.abs(d) <= min) - if (raise) - throw new SingularMatrixException("Number is too small (%f, while " + - "threshold is %f). Index of diagonal element is (%d, %d)", d, min, i, i); - else - return true; + throw new SingularMatrixException("Number is too small (%f, while " + + "threshold is %f). 
Index of diagonal element is (%d, %d)", d, min, i, i); } - return false; } } http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/main/java/org/apache/ignite/ml/regressions/AbstractMultipleLinearRegression.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/AbstractMultipleLinearRegression.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/AbstractMultipleLinearRegression.java index a2a8f16..5bc92c9 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/AbstractMultipleLinearRegression.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/AbstractMultipleLinearRegression.java @@ -355,4 +355,24 @@ public abstract class AbstractMultipleLinearRegression implements MultipleLinear return yVector.minus(xMatrix.times(b)); } + /** {@inheritDoc} */ + @Override public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + AbstractMultipleLinearRegression that = (AbstractMultipleLinearRegression)o; + + return noIntercept == that.noIntercept && xMatrix.equals(that.xMatrix); + } + + /** {@inheritDoc} */ + @Override public int hashCode() { + int res = xMatrix.hashCode(); + + res = 31 * res + (noIntercept ? 
1 : 0); + + return res; + } } http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegression.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegression.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegression.java index 36d5f2c..aafeae8 100644 --- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegression.java +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegression.java @@ -18,11 +18,11 @@ package org.apache.ignite.ml.regressions; import org.apache.ignite.ml.math.Matrix; import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.decompositions.QRDSolver; import org.apache.ignite.ml.math.decompositions.QRDecomposition; import org.apache.ignite.ml.math.exceptions.MathIllegalArgumentException; import org.apache.ignite.ml.math.exceptions.SingularMatrixException; import org.apache.ignite.ml.math.functions.Functions; -import org.apache.ignite.ml.math.util.MatrixUtil; /** * This class is based on the corresponding class from Apache Common Math lib. 
@@ -51,7 +51,7 @@ import org.apache.ignite.ml.math.util.MatrixUtil; */ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegression { /** Cached QR decomposition of X matrix */ - private QRDecomposition qr = null; + private QRDSolver solver = null; /** Singularity threshold for QR decomposition */ private final double threshold; @@ -94,7 +94,8 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio */ @Override public void newSampleData(double[] data, int nobs, int nvars, Matrix like) { super.newSampleData(data, nobs, nvars, like); - qr = new QRDecomposition(getX(), threshold); + QRDecomposition qr = new QRDecomposition(getX(), threshold); + solver = new QRDSolver(qr.getQ(), qr.getR()); } /** @@ -118,24 +119,7 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio * @throws NullPointerException unless method {@code newSampleData} has been called beforehand. */ public Matrix calculateHat() { - // Create augmented identity matrix - // No try-catch or advertised NotStrictlyPositiveException - NPE above if n < 3 - Matrix q = qr.getQ(); - Matrix augI = MatrixUtil.like(q, q.columnSize(), q.columnSize()); - - int n = augI.columnSize(); - int p = qr.getR().columnSize(); - - for (int i = 0; i < n; i++) - for (int j = 0; j < n; j++) - if (i == j && i < p) - augI.setX(i, j, 1d); - else - augI.setX(i, j, 0d); - - // Compute and return Hat matrix - // No DME advertised - args valid if we get here - return q.times(augI).times(q.transpose()); + return solver.calculateHat(); } /** @@ -226,7 +210,8 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio */ @Override protected void newXSampleData(Matrix x) { super.newXSampleData(x); - qr = new QRDecomposition(getX()); + QRDecomposition qr = new QRDecomposition(getX()); + solver = new QRDSolver(qr.getQ(), qr.getR()); } /** @@ -241,7 +226,7 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio * 
@throws NullPointerException if the data for the model have not been loaded */ @Override protected Vector calculateBeta() { - return qr.solve(getY()); + return solver.solve(getY()); } /** @@ -262,11 +247,11 @@ public class OLSMultipleLinearRegression extends AbstractMultipleLinearRegressio * @throws NullPointerException if the data for the model have not been loaded */ @Override protected Matrix calculateBetaVariance() { - int p = getX().columnSize(); - - Matrix rAug = MatrixUtil.copy(qr.getR().viewPart(0, p, 0, p)); - Matrix rInv = rAug.inverse(); + return solver.calculateBetaVariance(getX().columnSize()); + } - return rInv.times(rInv.transpose()); + /** */ + QRDSolver solver() { + return solver; } } http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java new file mode 100644 index 0000000..76a90fc --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModel.java @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.regressions; + +import org.apache.ignite.ml.Exportable; +import org.apache.ignite.ml.Exporter; +import org.apache.ignite.ml.Model; +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.Vector; +import org.apache.ignite.ml.math.decompositions.QRDSolver; +import org.apache.ignite.ml.math.decompositions.QRDecomposition; + +/** + * Model for linear regression. + */ +public class OLSMultipleLinearRegressionModel implements Model, + Exportable { + /** */ + private final Matrix xMatrix; + /** */ + private final QRDSolver solver; + + /** + * Construct linear regression model. + * + * @param xMatrix See {@link QRDecomposition#QRDecomposition(Matrix)}. + * @param solver Linear regression solver object. + */ + public OLSMultipleLinearRegressionModel(Matrix xMatrix, QRDSolver solver) { + this.xMatrix = xMatrix; + this.solver = solver; + } + + /** {@inheritDoc} */ + @Override public Vector predict(Vector val) { + return xMatrix.times(solver.solve(val)); + } + + /** {@inheritDoc} */ + @Override public

void saveModel(Exporter exporter, P path) { + exporter.save(new OLSMultipleLinearRegressionModelFormat(xMatrix, solver), path); + } + + /** {@inheritDoc} */ + @Override public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + OLSMultipleLinearRegressionModel mdl = (OLSMultipleLinearRegressionModel)o; + + return xMatrix.equals(mdl.xMatrix) && solver.equals(mdl.solver); + } + + /** {@inheritDoc} */ + @Override public int hashCode() { + int res = xMatrix.hashCode(); + res = 31 * res + solver.hashCode(); + return res; + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelFormat.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelFormat.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelFormat.java new file mode 100644 index 0000000..fc44968 --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelFormat.java @@ -0,0 +1,46 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.regressions; + +import java.io.Serializable; +import org.apache.ignite.ml.math.Matrix; +import org.apache.ignite.ml.math.decompositions.QRDSolver; + +/** + * Linear regression model representation. + * + * @see OLSMultipleLinearRegressionModel + */ +public class OLSMultipleLinearRegressionModelFormat implements Serializable { + /** X sample data. */ + private final Matrix xMatrix; + + /** Whether or not the regression model includes an intercept. True means no intercept. */ + private final QRDSolver solver; + + /** */ + public OLSMultipleLinearRegressionModelFormat(Matrix xMatrix, QRDSolver solver) { + this.xMatrix = xMatrix; + this.solver = solver; + } + + /** */ + public OLSMultipleLinearRegressionModel getOLSMultipleLinearRegressionModel() { + return new OLSMultipleLinearRegressionModel(xMatrix, solver); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTrainer.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTrainer.java new file mode 100644 index 0000000..dde0aca --- /dev/null +++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionTrainer.java @@ -0,0 +1,62 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ignite.ml.regressions; + +import org.apache.ignite.ml.Trainer; +import org.apache.ignite.ml.math.Matrix; + +/** + * Trainer for linear regression. + */ +public class OLSMultipleLinearRegressionTrainer implements Trainer { + /** */ + private final double threshold; + + /** */ + private final int nobs; + + /** */ + private final int nvars; + + /** */ + private final Matrix like; + + /** + * Construct linear regression trainer. + * + * @param threshold the singularity threshold for QR decomposition + * @param nobs number of observations (rows) + * @param nvars number of independent variables (columns, not counting y) + * @param like matrix(maybe empty) indicating how data should be stored + */ + public OLSMultipleLinearRegressionTrainer(double threshold, int nobs, int nvars, Matrix like) { + this.threshold = threshold; + this.nobs = nobs; + this.nvars = nvars; + this.like = like; + } + + /** {@inheritDoc} */ + @Override public OLSMultipleLinearRegressionModel train(double[] data) { + OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression(threshold); + + regression.newSampleData(data, nobs, nvars, like); + + return new OLSMultipleLinearRegressionModel(regression.getX(), regression.solver()); + } +} http://git-wip-us.apache.org/repos/asf/ignite/blob/c5c512e4/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java ---------------------------------------------------------------------- diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java 
b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java index 47910c8..7a61bad 100644 --- a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java +++ b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java @@ -32,7 +32,8 @@ import org.junit.runners.Suite; MathImplMainTestSuite.class, RegressionsTestSuite.class, ClusteringTestSuite.class, - DecisionTreesTestSuite.class + DecisionTreesTestSuite.class, + LocalModelsTest.class }) public class IgniteMLTestSuite { // No-op.