ignite-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From ch...@apache.org
Subject [1/3] ignite git commit: IGNITE-7174: Local MLP
Date Fri, 22 Dec 2017 15:07:59 GMT
Repository: ignite
Updated Branches:
  refs/heads/master 661ada687 -> e4f19215d


http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/test/java/org/apache/ignite/ml/nn/SimpleMLPLocalBatchTrainerInput.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/SimpleMLPLocalBatchTrainerInput.java
b/modules/ml/src/test/java/org/apache/ignite/ml/nn/SimpleMLPLocalBatchTrainerInput.java
new file mode 100644
index 0000000..ca5fe07
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/SimpleMLPLocalBatchTrainerInput.java
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn;
+
+import java.util.Random;
+import org.apache.ignite.lang.IgniteBiTuple;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
+import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
+import org.apache.ignite.ml.nn.initializers.RandomInitializer;
+import org.apache.ignite.ml.util.Utils;
+
+/**
+ * Class for local batch training of {@link MultilayerPerceptron}.
+ *
+ * It is constructed from two matrices: one containing inputs of the function to approximate and the other
+ * containing ground truth values of this function for the corresponding inputs. Samples are stored as columns.
+ *
+ * Each batch consists of a fixed number of distinct samples chosen by {@code Utils.selectKDistinct}.
+ */
+public class SimpleMLPLocalBatchTrainerInput implements LocalBatchTrainerInput<MultilayerPerceptron> {
+    /** Multilayer perceptron to be trained. */
+    private final MultilayerPerceptron mlp;
+
+    /** Inputs stored as columns. */
+    private final Matrix inputs;
+
+    /** Ground truths stored as columns. */
+    private final Matrix groundTruth;
+
+    /** Size of batch returned on each step. */
+    private final int batchSize;
+
+    /**
+     * Construct instance of this class.
+     *
+     * @param arch Architecture of multilayer perceptron.
+     * @param rnd Random numbers generator passed to the {@link RandomInitializer} of the perceptron.
+     * @param inputs Inputs stored as columns.
+     * @param groundTruth Ground truth stored as columns.
+     * @param batchSize Size of batch returned on each step.
+     * @throws IllegalArgumentException If {@code inputs} and {@code groundTruth} have different numbers of columns
+     *      or {@code batchSize} is not in {@code [1, inputs.columnSize()]}.
+     */
+    public SimpleMLPLocalBatchTrainerInput(MLPArchitecture arch, Random rnd, Matrix inputs, Matrix groundTruth,
+        int batchSize) {
+        if (inputs.columnSize() != groundTruth.columnSize())
+            throw new IllegalArgumentException("Inputs and ground truth must have the same number of columns.");
+
+        // Samples within a batch are distinct, so a batch cannot be larger than the whole data set.
+        if (batchSize <= 0 || batchSize > inputs.columnSize())
+            throw new IllegalArgumentException("Batch size must be in [1, inputs.columnSize()]: " + batchSize);
+
+        this.mlp = new MultilayerPerceptron(arch, new RandomInitializer(rnd));
+        this.inputs = inputs;
+        this.groundTruth = groundTruth;
+        this.batchSize = batchSize;
+    }
+
+    /** {@inheritDoc} */
+    @Override public IgniteBiTuple<Matrix, Matrix> getBatch() {
+        Matrix vectors = new DenseLocalOnHeapMatrix(inputs.rowSize(), batchSize);
+        Matrix labels = new DenseLocalOnHeapMatrix(groundTruth.rowSize(), batchSize);
+
+        // Indices of the columns (samples) forming this batch.
+        int[] samples = Utils.selectKDistinct(inputs.columnSize(), batchSize);
+
+        for (int i = 0; i < batchSize; i++) {
+            vectors.assignColumn(i, inputs.getCol(samples[i]));
+            labels.assignColumn(i, groundTruth.getCol(samples[i]));
+        }
+
+        return new IgniteBiTuple<>(vectors, labels);
+    }
+
+    /** {@inheritDoc} */
+    @Override public MultilayerPerceptron mdl() {
+        return mlp;
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/Mnist.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/Mnist.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/Mnist.java
new file mode 100644
index 0000000..cf959a5
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/Mnist.java
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.nn.performance;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.util.List;
+import java.util.Properties;
+import java.util.Random;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.ignite.internal.util.typedef.X;
+import org.apache.ignite.lang.IgniteBiTuple;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Tracer;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.VectorUtils;
+import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+import org.apache.ignite.ml.nn.Activators;
+import org.apache.ignite.ml.nn.LossFunctions;
+import org.apache.ignite.ml.nn.MultilayerPerceptron;
+import org.apache.ignite.ml.nn.SimpleMLPLocalBatchTrainerInput;
+import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
+import org.apache.ignite.ml.nn.trainers.local.MLPLocalBatchTrainer;
+import org.apache.ignite.ml.nn.updaters.RPropUpdater;
+import org.apache.ignite.ml.trees.performance.ColumnDecisionTreeTrainerBenchmark;
+import org.apache.ignite.ml.util.MnistUtils;
+import org.junit.Test;
+
+import static org.apache.ignite.ml.math.VectorUtils.num2Vec;
+
+/**
+ * MNIST benchmark of local multilayer perceptron training, intended for hand runs.
+ */
+public class Mnist {
+    /** Name of the property specifying path to training set images. */
+    private static final String PROP_TRAINING_IMAGES = "mnist.training.images";
+
+    /** Name of property specifying path to training set labels. */
+    private static final String PROP_TRAINING_LABELS = "mnist.training.labels";
+
+    /** Name of property specifying path to test set images. */
+    private static final String PROP_TEST_IMAGES = "mnist.test.images";
+
+    /** Name of property specifying path to test set labels. */
+    private static final String PROP_TEST_LABELS = "mnist.test.labels";
+
+    /**
+     * Train a multilayer perceptron on MNIST and print true and predicted labels of the test set.
+     * Paths to MNIST files are read from {@code manualrun/trees/columntrees.manualrun.properties}.
+     *
+     * @throws IOException In case of loading MNIST dataset errors.
+     */
+    @Test
+    public void tstMNIST() throws IOException {
+        int trainingSamplesCnt = 60_000;
+        int testSamplesCnt = 10_000;
+        int featCnt = 28 * 28;
+        int hiddenNeuronsCnt = 100;
+
+        Properties props = loadMNISTProperties();
+
+        Stream<DenseLocalOnHeapVector> trainingMnistStream = MnistUtils.mnist(props.getProperty(PROP_TRAINING_IMAGES),
+            props.getProperty(PROP_TRAINING_LABELS), new Random(123L), trainingSamplesCnt);
+
+        Stream<DenseLocalOnHeapVector> testMnistStream = MnistUtils.mnist(props.getProperty(PROP_TEST_IMAGES),
+            props.getProperty(PROP_TEST_LABELS), new Random(123L), testSamplesCnt);
+
+        IgniteBiTuple<Matrix, Matrix> ds = createDataset(trainingMnistStream, trainingSamplesCnt, featCnt);
+        IgniteBiTuple<Matrix, Matrix> testDs = createDataset(testMnistStream, testSamplesCnt, featCnt);
+
+        MLPArchitecture conf = new MLPArchitecture(featCnt).
+            withAddedLayer(hiddenNeuronsCnt, true, Activators.SIGMOID).
+            withAddedLayer(10, false, Activators.SIGMOID);
+
+        SimpleMLPLocalBatchTrainerInput input = new SimpleMLPLocalBatchTrainerInput(conf, new Random(), ds.get1(),
+            ds.get2(), 2000);
+
+        X.println("Training started");
+        // Start the clock before training so that the reported time actually covers it.
+        long before = System.currentTimeMillis();
+
+        MultilayerPerceptron mdl = new MLPLocalBatchTrainer<>(LossFunctions.MSE,
+            () -> new RPropUpdater(0.1, 1.2, 0.5), 1E-7, 200).train(input);
+
+        X.println("Training finished in " + (System.currentTimeMillis() - before));
+
+        Vector predicted = mdl.apply(testDs.get1()).foldColumns(VectorUtils::vec2Num);
+        Vector truth = testDs.get2().foldColumns(VectorUtils::vec2Num);
+
+        Tracer.showAscii(truth);
+        Tracer.showAscii(predicted);
+    }
+
+    /** Pack first {@code samplesCnt} stream vectors into features and labels matrices, one column per sample. */
+    private IgniteBiTuple<Matrix, Matrix> createDataset(Stream<DenseLocalOnHeapVector> s, int samplesCnt, int featCnt) {
+        Matrix vectors = new DenseLocalOnHeapMatrix(featCnt, samplesCnt);
+        Matrix labels = new DenseLocalOnHeapMatrix(10, samplesCnt);
+        List<DenseLocalOnHeapVector> sc = s.collect(Collectors.toList());
+
+        for (int i = 0; i < samplesCnt; i++) {
+            DenseLocalOnHeapVector v = sc.get(i);
+            vectors.assignColumn(i, v.viewPart(0, featCnt));
+            labels.assignColumn(i, num2Vec((int)v.getX(featCnt), 10));
+        }
+
+        return new IgniteBiTuple<>(vectors, labels);
+    }
+
+    /** Load properties with paths to MNIST dataset parts. */
+    private static Properties loadMNISTProperties() throws IOException {
+        Properties res = new Properties();
+
+        try (InputStream is = ColumnDecisionTreeTrainerBenchmark.class.getClassLoader().
+            getResourceAsStream("manualrun/trees/columntrees.manualrun.properties")) {
+            if (is == null)
+                throw new IOException("MNIST properties resource is not found.");
+
+            res.load(is);
+        }
+
+        return res;
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelTest.java
b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelTest.java
index 37c972c..74d5524 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/OLSMultipleLinearRegressionModelTest.java
@@ -48,6 +48,6 @@ public class OLSMultipleLinearRegressionModelTest {
         OLSMultipleLinearRegressionModel mdl = trainer.train(data);
 
         TestUtils.assertEquals(new double[] {0d, 0d, 0d, 0d, 0d, 0d},
-            val.minus(mdl.predict(val)).getStorage().data(), 1e-13);
+            val.minus(mdl.apply(val)).getStorage().data(), 1e-13);
     }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java
b/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java
index 9e81bea..b090f43 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java
@@ -184,8 +184,8 @@ public class ColumnDecisionTreeTrainerTest extends BaseDecisionTreeTest
{
         byRegion.keySet().forEach(k -> {
             LabeledVectorDouble sp = byRegion.get(k).get(0);
             Tracer.showAscii(sp.features());
-            X.println("Actual and predicted vectors [act=" + sp.label() + " " + ", pred="
+ mdl.predict(sp.features()) + "]");
-            assert mdl.predict(sp.features()) == sp.doubleLabel();
+            X.println("Actual and predicted vectors [act=" + sp.label() + " " + ", pred="
+ mdl.apply(sp.features()) + "]");
+            assert mdl.apply(sp.features()) == sp.doubleLabel();
         });
     }
 }

http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java
b/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java
index 524a8ad..a72dec2 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java
@@ -274,8 +274,8 @@ public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest
{
         byRegion.keySet().forEach(k -> {
             LabeledVectorDouble sp = byRegion.get(k).get(0);
             Tracer.showAscii(sp.features());
-            X.println("Predicted value and label [pred=" + mdl.predict(sp.features()) + ",
label=" + sp.doubleLabel() + "]");
-            assert mdl.predict(sp.features()) == sp.doubleLabel();
+            X.println("Predicted value and label [pred=" + mdl.apply(sp.features()) + ",
label=" + sp.doubleLabel() + "]");
+            assert mdl.apply(sp.features()) == sp.doubleLabel();
         });
     }
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/ml/src/test/resources/manualrun/trees/columntrees.manualrun.properties
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/resources/manualrun/trees/columntrees.manualrun.properties
b/modules/ml/src/test/resources/manualrun/trees/columntrees.manualrun.properties
index d9574f3..2fd77ed 100644
--- a/modules/ml/src/test/resources/manualrun/trees/columntrees.manualrun.properties
+++ b/modules/ml/src/test/resources/manualrun/trees/columntrees.manualrun.properties
@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 # Paths to mnist dataset parts.
-mnist.training.images=/path/to/train-images-idx3-ubyte
-mnist.training.labels=/path/to/train-labels-idx1-ubyte
-mnist.test.images=/path/to/t10k-images-idx3-ubyte
-mnist.test.labels=/path/to/t10k-labels-idx1-ubyte
+mnist.training.images=/path/to/mnist/train-images-idx3-ubyte
+mnist.training.labels=/path/to/mnist/train-labels-idx1-ubyte
+mnist.test.images=/path/to/mnist/t10k-images-idx3-ubyte
+mnist.test.labels=/path/to/mnist/t10k-labels-idx1-ubyte

http://git-wip-us.apache.org/repos/asf/ignite/blob/e4f19215/modules/yardstick/src/main/ml/org/apache/ignite/yardstick/ml/trees/SplitDataGenerator.java
----------------------------------------------------------------------
diff --git a/modules/yardstick/src/main/ml/org/apache/ignite/yardstick/ml/trees/SplitDataGenerator.java
b/modules/yardstick/src/main/ml/org/apache/ignite/yardstick/ml/trees/SplitDataGenerator.java
index f530300..f9117f4 100644
--- a/modules/yardstick/src/main/ml/org/apache/ignite/yardstick/ml/trees/SplitDataGenerator.java
+++ b/modules/yardstick/src/main/ml/org/apache/ignite/yardstick/ml/trees/SplitDataGenerator.java
@@ -143,7 +143,7 @@ class SplitDataGenerator<V extends Vector> {
 
         DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m,
catFeaturesInfo));
 
-        byRegion.keySet().forEach(k -> mdl.predict(byRegion.get(k).get(0).features()));
+        byRegion.keySet().forEach(k -> mdl.apply(byRegion.get(k).get(0).features()));
     }
 
     /**


Mime
View raw message