ignite-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From agoncha...@apache.org
Subject [43/50] [abbrv] ignite git commit: IGNITE-8542: [ML] Add OneVsRest Trainer to handle cases with multiple class labels in dataset.
Date Wed, 28 Nov 2018 11:53:14 GMT
IGNITE-8542: [ML] Add OneVsRest Trainer to handle cases with
multiple class labels in dataset.

This closes #5512


Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/c3fd4a93
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/c3fd4a93
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/c3fd4a93

Branch: refs/heads/ignite-9720
Commit: c3fd4a930cc1a76b4d1fbccc6d764bdfe88da941
Parents: 3885f3f
Author: zaleslaw <zaleslaw.sin@gmail.com>
Authored: Wed Nov 28 01:45:11 2018 +0300
Committer: Yury Babak <ybabak@gridgain.com>
Committed: Wed Nov 28 01:45:11 2018 +0300

----------------------------------------------------------------------
 .../ignite/ml/multiclass/MultiClassModel.java   | 115 +++++++++++++++
 .../ignite/ml/multiclass/OneVsRestTrainer.java  | 147 +++++++++++++++++++
 .../org/apache/ignite/ml/IgniteMLTestSuite.java |   4 +-
 .../ml/multiclass/MultiClassTestSuite.java      |  32 ++++
 .../ml/multiclass/OneVsRestTrainerTest.java     | 126 ++++++++++++++++
 5 files changed, 423 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/ignite/blob/c3fd4a93/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/MultiClassModel.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/MultiClassModel.java
b/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/MultiClassModel.java
new file mode 100644
index 0000000..8520aa9
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/MultiClassModel.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.multiclass;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.TreeMap;
+import org.apache.ignite.ml.Exportable;
+import org.apache.ignite.ml.Exporter;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+/** Multi-class classification model composed of a set of binary classifiers, one per class label. */
+public class MultiClassModel<M extends Model<Vector, Double>> implements Model<Vector,
Double>, Exportable<MultiClassModel>, Serializable {
+    /** */
+    private static final long serialVersionUID = -114986533359917L;
+
+    /** Models keyed by the class label each one is associated with. */
+    private Map<Double, M> models;
+
+    /** */
+    public MultiClassModel() {
+        this.models = new HashMap<>();
+    }
+
+    /**
+     * Adds a binary classifier for the given class label to the set of classifiers.
+     *
+     * @param clsLb The class label for the added model.
+     * @param mdl The model.
+     */
+    public void add(double clsLb, M mdl) {
+        models.put(clsLb, mdl);
+    }
+
+    /**
+     * @param clsLb Class label.
+     * @return Model registered for the class label, or an empty {@code Optional} if there is none.
+     */
+    public Optional<M> getModel(Double clsLb) {
+        return Optional.ofNullable(models.get(clsLb));
+    }
+
+    /** {@inheritDoc} */
+    @Override public Double apply(Vector input) {
+        TreeMap<Double, Double> maxMargins = new TreeMap<>();
+
+        models.forEach((k, v) -> maxMargins.put(v.apply(input), k));
+
+        // Return the class label whose model produced the largest output; equal outputs overwrite each other in the map.
+        return maxMargins.lastEntry().getValue();
+    }
+
+    /** {@inheritDoc} */
+    @Override public <P> void saveModel(Exporter<MultiClassModel, P> exporter,
P path) {
+        exporter.save(this, path);
+    }
+
+    /** {@inheritDoc} */
+    @Override public boolean equals(Object o) {
+        if (this == o)
+            return true;
+
+        if (o == null || getClass() != o.getClass())
+            return false;
+
+        MultiClassModel mdl = (MultiClassModel)o;
+
+        return Objects.equals(models, mdl.models);
+    }
+
+    /** {@inheritDoc} */
+    @Override public int hashCode() {
+        return Objects.hash(models);
+    }
+
+    /** {@inheritDoc} */
+    @Override public String toString() {
+        StringBuilder wholeStr = new StringBuilder();
+
+        models.forEach((clsLb, mdl) ->
+            wholeStr
+                .append("The class with label ")
+                .append(clsLb)
+                .append(" has classifier: ")
+                .append(mdl.toString())
+                .append(System.lineSeparator())
+        );
+
+        return wholeStr.toString();
+    }
+
+    /** {@inheritDoc} */
+    @Override public String toString(boolean pretty) {
+        return toString();
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/c3fd4a93/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java
b/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java
new file mode 100644
index 0000000..7426506
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.multiclass;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.PartitionDataBuilder;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap;
+import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap;
+import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
+
+/**
+ * This is a common heuristic trainer for multi-class labeled models.
+ *
+ * NOTE: The current implementation suffers from unbalanced training over the dataset because the labels are
+ * reassigned from the full range of labels to 0/1 without any weighting.
+ */
+public class OneVsRestTrainer<M extends Model<Vector, Double>>
+    extends SingleLabelDatasetTrainer<MultiClassModel<M>> {
+    /** The common binary classifier whose hyper-parameters are reused for every separate per-class training. */
+    private SingleLabelDatasetTrainer<M> classifier;
+
+    /** */
+    public OneVsRestTrainer(SingleLabelDatasetTrainer<M> classifier) {
+        this.classifier = classifier;
+    }
+
+    /**
+     * Trains model based on the specified data.
+     *
+     * @param datasetBuilder Dataset builder.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @return Model.
+     */
+    @Override public <K, V> MultiClassModel<M> fit(DatasetBuilder<K, V>
datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+
+        return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override public <K, V> MultiClassModel<M> updateModel(MultiClassModel<M>
newMdl,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+
+        List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor);
+
+        if (classes.isEmpty())
+            return getLastTrainedModelOrThrowEmptyDatasetException(newMdl);
+
+        MultiClassModel<M> multiClsMdl = new MultiClassModel<>();
+
+        classes.forEach(clsLb -> {
+            IgniteBiFunction<K, V, Double> lbTransformer = (k, v) -> {
+                Double lb = lbExtractor.apply(k, v);
+
+                if (lb.equals(clsLb))
+                    return 1.0;
+                else
+                    return 0.0;
+            };
+
+            M mdl = Optional.ofNullable(newMdl)
+                .flatMap(multiClassModel -> multiClassModel.getModel(clsLb))
+                .map(learnedModel -> classifier.update(learnedModel, datasetBuilder, featureExtractor,
lbTransformer))
+                .orElseGet(() -> classifier.fit(datasetBuilder, featureExtractor, lbTransformer));
+
+            multiClsMdl.add(clsLb, mdl);
+        });
+
+        return multiClsMdl;
+    }
+
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(MultiClassModel<M> mdl) {
+        return true;
+    }
+
+    /** Iterates over the dataset and collects the distinct class labels. */
+    private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V>
datasetBuilder,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+        assert datasetBuilder != null;
+
+        PartitionDataBuilder<K, V, EmptyContext, LabelPartitionDataOnHeap> partDataBuilder
= new LabelPartitionDataBuilderOnHeap<>(lbExtractor);
+
+        List<Double> res = new ArrayList<>();
+
+        try (Dataset<EmptyContext, LabelPartitionDataOnHeap> dataset = datasetBuilder.build(
+            (upstream, upstreamSize) -> new EmptyContext(),
+            partDataBuilder
+        )) {
+            final Set<Double> clsLabels = dataset.compute(data -> {
+                final Set<Double> locClsLabels = new HashSet<>();
+
+                final double[] lbs = data.getY();
+
+                for (double lb : lbs)
+                    locClsLabels.add(lb);
+
+                return locClsLabels;
+            }, (a, b) -> {
+                if (a == null)
+                    return b == null ? new HashSet<>() : b;
+                if (b == null)
+                    return a;
+                return Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet());
+            });
+
+            if (clsLabels != null)
+                res.addAll(clsLabels);
+
+        }
+        catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+        return res;
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/c3fd4a93/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
index f9645d8..78d6659 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
@@ -26,6 +26,7 @@ import org.apache.ignite.ml.genetic.GAGridTestSuite;
 import org.apache.ignite.ml.inference.InferenceTestSuite;
 import org.apache.ignite.ml.knn.KNNTestSuite;
 import org.apache.ignite.ml.math.MathImplMainTestSuite;
+import org.apache.ignite.ml.multiclass.MultiClassTestSuite;
 import org.apache.ignite.ml.nn.MLPTestSuite;
 import org.apache.ignite.ml.pipeline.PipelineTestSuite;
 import org.apache.ignite.ml.preprocessing.PreprocessingTestSuite;
@@ -61,7 +62,8 @@ import org.junit.runners.Suite;
     StructuresTestSuite.class,
     CommonTestSuite.class,
     InferenceTestSuite.class,
-    BaggingTest.class
+    BaggingTest.class,
+    MultiClassTestSuite.class
 })
 public class IgniteMLTestSuite {
     // No-op.

http://git-wip-us.apache.org/repos/asf/ignite/blob/c3fd4a93/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/MultiClassTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/MultiClassTestSuite.java
b/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/MultiClassTestSuite.java
new file mode 100644
index 0000000..551597f
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/MultiClassTestSuite.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.multiclass;
+
+import org.junit.runner.RunWith;
+import org.junit.runners.Suite;
+
+/**
+ * Test suite for the multiclass package.
+ */
+@RunWith(Suite.class)
+@Suite.SuiteClasses({
+    OneVsRestTrainerTest.class
+})
+public class MultiClassTestSuite {
+    // No-op.
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/c3fd4a93/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java
b/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java
new file mode 100644
index 0000000..9842d92
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.multiclass;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.nn.UpdatesStrategy;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
+import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests for {@link OneVsRestTrainer}.
+ */
+public class OneVsRestTrainerTest extends TrainerTest {
+    /**
+     * Test trainer on 2 linearly separable sets.
+     */
+    @Test
+    public void testTrainWithTheLinearlySeparableCase() {
+        Map<Integer, double[]> cacheMock = new HashMap<>();
+
+        for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
+            cacheMock.put(i, twoLinearlySeparableClasses[i]);
+
+        LogisticRegressionSGDTrainer<?> binaryTrainer = new LogisticRegressionSGDTrainer<>()
+            .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
+                SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
+            .withMaxIterations(1000)
+            .withLocIterations(10)
+            .withBatchSize(100)
+            .withSeed(123L);
+
+        OneVsRestTrainer<LogisticRegressionModel> trainer = new OneVsRestTrainer<>(binaryTrainer);
+
+        MultiClassModel mdl = trainer.fit(
+            cacheMock,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        Assert.assertTrue(mdl.toString().length() > 0);
+        Assert.assertTrue(mdl.toString(true).length() > 0);
+        Assert.assertTrue(mdl.toString(false).length() > 0);
+
+        TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(-100, 0)), PRECISION);
+        TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 0)), PRECISION);
+    }
+
+    /** */
+    @Test
+    public void testUpdate() {
+        Map<Integer, double[]> cacheMock = new HashMap<>();
+
+        for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
+            cacheMock.put(i, twoLinearlySeparableClasses[i]);
+
+        LogisticRegressionSGDTrainer<?> binaryTrainer = new LogisticRegressionSGDTrainer<>()
+            .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
+                SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
+            .withMaxIterations(1000)
+            .withLocIterations(10)
+            .withBatchSize(100)
+            .withSeed(123L);
+
+        OneVsRestTrainer<LogisticRegressionModel> trainer = new OneVsRestTrainer<>(binaryTrainer);
+
+        MultiClassModel originalMdl = trainer.fit(
+            cacheMock,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        MultiClassModel updatedOnSameDS = trainer.update(
+            originalMdl,
+            cacheMock,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        MultiClassModel updatedOnEmptyDS = trainer.update(
+            originalMdl,
+            new HashMap<Integer, double[]>(),
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        List<Vector> vectors = Arrays.asList(
+            VectorUtils.of(-100, 0),
+            VectorUtils.of(100, 0)
+        );
+
+        for (Vector vec : vectors) {
+            TestUtils.assertEquals(originalMdl.apply(vec), updatedOnSameDS.apply(vec), PRECISION);
+            TestUtils.assertEquals(originalMdl.apply(vec), updatedOnEmptyDS.apply(vec), PRECISION);
+        }
+    }
+}


Mime
View raw message