ignite-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From sboi...@apache.org
Subject [09/11] ignite git commit: IGNITE-9282: [ML] Add Naive Bayes classifier
Date Wed, 24 Oct 2018 06:09:35 GMT
IGNITE-9282: [ML] Add Naive Bayes classifier

this closes #4869


Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/e29a8cb9
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/e29a8cb9
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/e29a8cb9

Branch: refs/heads/ignite-9720
Commit: e29a8cb9380fb2c1f6815d40670315919af58d3b
Parents: 86f5437
Author: dehasi <rgaleyev@gmail.com>
Authored: Tue Oct 23 19:11:23 2018 +0300
Committer: Yury Babak <ybabak@gridgain.com>
Committed: Tue Oct 23 19:11:23 2018 +0300

----------------------------------------------------------------------
 .../GaussianNaiveBayesTrainerExample.java       | 113 +++++++++++
 .../examples/ml/naivebayes/package-info.java    |  22 +++
 .../ignite/examples/util/IrisDataset.java       | 129 +++++++++++++
 .../gaussian/GaussianNaiveBayesModel.java       | 111 +++++++++++
 .../gaussian/GaussianNaiveBayesSumsHolder.java  |  55 ++++++
 .../gaussian/GaussianNaiveBayesTrainer.java     | 186 +++++++++++++++++++
 .../ml/naivebayes/gaussian/package-info.java    |  22 +++
 .../ignite/ml/naivebayes/package-info.java      |  22 +++
 .../gaussian/GaussianNaiveBayesModelTest.java   |  50 +++++
 .../gaussian/GaussianNaiveBayesTest.java        |  86 +++++++++
 .../gaussian/GaussianNaiveBayesTrainerTest.java | 182 ++++++++++++++++++
 11 files changed, 978 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/ignite/blob/e29a8cb9/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java
b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java
new file mode 100644
index 0000000..cd8383e
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.naivebayes;
+
+import java.util.Arrays;
+import javax.cache.Cache;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.query.QueryCursor;
+import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.examples.ml.util.TestCache;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel;
+import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesTrainer;
+
+import static org.apache.ignite.examples.util.IrisDataset.irisDatasetFirstAndSecondClasses;
+
+/**
+ * Run naive Bayes classification model based on <a href="https://en.wikipedia.org/wiki/Naive_Bayes_classifier">naive
+ * Bayes classifier</a> algorithm ({@link GaussianNaiveBayesTrainer}) over distributed cache.
+ * <p>
+ * Code in this example launches Ignite grid and fills the cache with test data points (based on the first two
+ * classes of the <a href="https://en.wikipedia.org/wiki/Iris_flower_data_set">Iris dataset</a>).</p>
+ * <p>
+ * After that it trains the naive Bayes classification model based on the specified data.</p>
+ * <p>
+ * Finally, this example loops over the same cached data points, applies the trained model to predict the label,
+ * compares prediction to expected outcome (ground truth), and builds
+ * <a href="https://en.wikipedia.org/wiki/Confusion_matrix">confusion matrix</a>.</p>
+ * <p>
+ * You can change the test data used in this example and re-run it to explore this algorithm further.</p>
+ */
+public class GaussianNaiveBayesTrainerExample {
+    /** Runs the example: fills a cache with Iris rows (label first, features after), trains and evaluates. */
+    public static void main(String[] args) throws InterruptedException {
+        System.out.println();
+        System.out.println(">>> Naive Bayes classification model over partitioned
dataset usage example started.");
+        // Start Ignite grid; try-with-resources stops the node when the example finishes.
+        try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+            System.out.println(">>> Ignite grid started.");
+
+            IgniteCache<Integer, double[]> dataCache = new TestCache(ignite).fillCacheWith(irisDatasetFirstAndSecondClasses);
+
+            System.out.println(">>> Create new naive Bayes classification trainer
object.");
+            GaussianNaiveBayesTrainer trainer = new GaussianNaiveBayesTrainer();
+
+            System.out.println(">>> Perform the training to get the model.");
+            GaussianNaiveBayesModel mdl = trainer.fit(
+                ignite,
+                dataCache,
+                (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+                (k, v) -> v[0]
+            );
+
+            System.out.println(">>> Naive Bayes model: " + mdl);
+
+            int amountOfErrors = 0;
+            int totalAmount = 0;
+
+            // Confusion matrix indexed as [predicted class][actual class]. See https://en.wikipedia.org/wiki/Confusion_matrix
+            int[][] confusionMtx = {{0, 0}, {0, 0}};
+
+            try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new
ScanQuery<>())) {
+                for (Cache.Entry<Integer, double[]> observation : observations) {
+                    double[] val = observation.getValue();
+                    Vector inputs = VectorUtils.of(Arrays.copyOfRange(val, 1, val.length));
+                    double groundTruth = val[0];
+
+                    double prediction = mdl.apply(inputs);
+
+                    totalAmount++;
+                    if (groundTruth != prediction)
+                        amountOfErrors++;
+
+                    int idx1 = (int)prediction;
+                    int idx2 = (int)groundTruth;
+
+                    confusionMtx[idx1][idx2]++;
+
+                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction,
groundTruth);
+                }
+
+                System.out.println(">>> ---------------------------------");
+
+                System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
+                System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
+            }
+
+            System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
+            System.out.println(">>> ---------------------------------");
+
+            System.out.println(">>> Naive bayes model over partitioned dataset usage
example completed.");
+        }
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/e29a8cb9/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/package-info.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/package-info.java
b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/package-info.java
new file mode 100644
index 0000000..7f0420c
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * ML naive Bayes classifier examples.
+ */
+package org.apache.ignite.examples.ml.naivebayes;
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/ignite/blob/e29a8cb9/examples/src/main/java/org/apache/ignite/examples/util/IrisDataset.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/util/IrisDataset.java b/examples/src/main/java/org/apache/ignite/examples/util/IrisDataset.java
new file mode 100644
index 0000000..53080e8
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/util/IrisDataset.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ignite.examples.util;
+
+/** Contains data from the <a href="https://en.wikipedia.org/wiki/Iris_flower_data_set">Iris dataset</a>. */
+public final class IrisDataset {
+
+    /** The 1st and 2nd classes from the Iris dataset. */
+    public static final double[][] irisDatasetFirstAndSecondClasses = {
+        {0, 5.1, 3.5, 1.4, 0.2},
+        {0, 4.9, 3, 1.4, 0.2},
+        {0, 4.7, 3.2, 1.3, 0.2},
+        {0, 4.6, 3.1, 1.5, 0.2},
+        {0, 5, 3.6, 1.4, 0.2},
+        {0, 5.4, 3.9, 1.7, 0.4},
+        {0, 4.6, 3.4, 1.4, 0.3},
+        {0, 5, 3.4, 1.5, 0.2},
+        {0, 4.4, 2.9, 1.4, 0.2},
+        {0, 4.9, 3.1, 1.5, 0.1},
+        {0, 5.4, 3.7, 1.5, 0.2},
+        {0, 4.8, 3.4, 1.6, 0.2},
+        {0, 4.8, 3, 1.4, 0.1},
+        {0, 4.3, 3, 1.1, 0.1},
+        {0, 5.8, 4, 1.2, 0.2},
+        {0, 5.7, 4.4, 1.5, 0.4},
+        {0, 5.4, 3.9, 1.3, 0.4},
+        {0, 5.1, 3.5, 1.4, 0.3},
+        {0, 5.7, 3.8, 1.7, 0.3},
+        {0, 5.1, 3.8, 1.5, 0.3},
+        {0, 5.4, 3.4, 1.7, 0.2},
+        {0, 5.1, 3.7, 1.5, 0.4},
+        {0, 4.6, 3.6, 1, 0.2},
+        {0, 5.1, 3.3, 1.7, 0.5},
+        {0, 4.8, 3.4, 1.9, 0.2},
+        {0, 5, 3, 1.6, 0.2},
+        {0, 5, 3.4, 1.6, 0.4},
+        {0, 5.2, 3.5, 1.5, 0.2},
+        {0, 5.2, 3.4, 1.4, 0.2},
+        {0, 4.7, 3.2, 1.6, 0.2},
+        {0, 4.8, 3.1, 1.6, 0.2},
+        {0, 5.4, 3.4, 1.5, 0.4},
+        {0, 5.2, 4.1, 1.5, 0.1},
+        {0, 5.5, 4.2, 1.4, 0.2},
+        {0, 4.9, 3.1, 1.5, 0.1},
+        {0, 5, 3.2, 1.2, 0.2},
+        {0, 5.5, 3.5, 1.3, 0.2},
+        {0, 4.9, 3.1, 1.5, 0.1},
+        {0, 4.4, 3, 1.3, 0.2},
+        {0, 5.1, 3.4, 1.5, 0.2},
+        {0, 5, 3.5, 1.3, 0.3},
+        {0, 4.5, 2.3, 1.3, 0.3},
+        {0, 4.4, 3.2, 1.3, 0.2},
+        {0, 5, 3.5, 1.6, 0.6},
+        {0, 5.1, 3.8, 1.9, 0.4},
+        {0, 4.8, 3, 1.4, 0.3},
+        {0, 5.1, 3.8, 1.6, 0.2},
+        {0, 4.6, 3.2, 1.4, 0.2},
+        {0, 5.3, 3.7, 1.5, 0.2},
+        {0, 5, 3.3, 1.4, 0.2},
+        {1, 7, 3.2, 4.7, 1.4},
+        {1, 6.4, 3.2, 4.5, 1.5},
+        {1, 6.9, 3.1, 4.9, 1.5},
+        {1, 5.5, 2.3, 4, 1.3},
+        {1, 6.5, 2.8, 4.6, 1.5},
+        {1, 5.7, 2.8, 4.5, 1.3},
+        {1, 6.3, 3.3, 4.7, 1.6},
+        {1, 4.9, 2.4, 3.3, 1},
+        {1, 6.6, 2.9, 4.6, 1.3},
+        {1, 5.2, 2.7, 3.9, 1.4},
+        {1, 5, 2, 3.5, 1},
+        {1, 5.9, 3, 4.2, 1.5},
+        {1, 6, 2.2, 4, 1},
+        {1, 6.1, 2.9, 4.7, 1.4},
+        {1, 5.6, 2.9, 3.6, 1.3},
+        {1, 6.7, 3.1, 4.4, 1.4},
+        {1, 5.6, 3, 4.5, 1.5},
+        {1, 5.8, 2.7, 4.1, 1},
+        {1, 6.2, 2.2, 4.5, 1.5},
+        {1, 5.6, 2.5, 3.9, 1.1},
+        {1, 5.9, 3.2, 4.8, 1.8},
+        {1, 6.1, 2.8, 4, 1.3},
+        {1, 6.3, 2.5, 4.9, 1.5},
+        {1, 6.1, 2.8, 4.7, 1.2},
+        {1, 6.4, 2.9, 4.3, 1.3},
+        {1, 6.6, 3, 4.4, 1.4},
+        {1, 6.8, 2.8, 4.8, 1.4},
+        {1, 6.7, 3, 5, 1.7},
+        {1, 6, 2.9, 4.5, 1.5},
+        {1, 5.7, 2.6, 3.5, 1},
+        {1, 5.5, 2.4, 3.8, 1.1},
+        {1, 5.5, 2.4, 3.7, 1},
+        {1, 5.8, 2.7, 3.9, 1.2},
+        {1, 6, 2.7, 5.1, 1.6},
+        {1, 5.4, 3, 4.5, 1.5},
+        {1, 6, 3.4, 4.5, 1.6},
+        {1, 6.7, 3.1, 4.7, 1.5},
+        {1, 6.3, 2.3, 4.4, 1.3},
+        {1, 5.6, 3, 4.1, 1.3},
+        {1, 5.5, 2.5, 4, 1.3},
+        {1, 5.5, 2.6, 4.4, 1.2},
+        {1, 6.1, 3, 4.6, 1.4},
+        {1, 5.8, 2.6, 4, 1.2},
+        {1, 5, 2.3, 3.3, 1},
+        {1, 5.6, 2.7, 4.2, 1.3},
+        {1, 5.7, 3, 4.2, 1.2},
+        {1, 5.7, 2.9, 4.2, 1.3},
+        {1, 6.2, 2.9, 4.3, 1.3},
+        {1, 5.1, 2.5, 3, 1.1},
+        {1, 5.7, 2.8, 4.1, 1.3},
+    };
+
+    /** */
+    private IrisDataset() {
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/e29a8cb9/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModel.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModel.java
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModel.java
new file mode 100644
index 0000000..985d9fe
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModel.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.naivebayes.gaussian;
+
+import java.io.Serializable;
+import org.apache.ignite.ml.Exportable;
+import org.apache.ignite.ml.Exporter;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+/**
+ * Gaussian naive Bayes model. For an input vector {@code y} and every class {@code C_k, k in [0..K]} it evaluates
+ * {@code p(C_k) * p(y_1|C_k) * ... * p(y_n|C_k)}, where each {@code p(y_j|C_k)} is a normal density built from the
+ * per-class mean and variance of the j-th feature, and returns the label of the most probable class. The common
+ * divisor {@code p(y)} does not influence the comparison and is therefore not computed.
+ */
+public class GaussianNaiveBayesModel implements Model<Vector, Double>, Exportable<GaussianNaiveBayesModel>,
+    Serializable {
+    /** */
+    private static final long serialVersionUID = -127386523291350345L;
+
+    /** Means of features for all classes. kth row contains means for labels[k] class. */
+    private final double[][] means;
+
+    /** Variances of features for all classes. kth row contains variances for labels[k] class. */
+    private final double[][] variances;
+
+    /** Prior probabilities of each class. */
+    private final double[] classProbabilities;
+
+    /** Labels. */
+    private final double[] labels;
+
+    /** Feature sum, squared sum and count per label. */
+    private final GaussianNaiveBayesSumsHolder sumsHolder;
+
+    /**
+     * The arrays are stored by reference (no defensive copies are made).
+     *
+     * @param means Means of features for all classes.
+     * @param variances Variances of features for all classes.
+     * @param classProbabilities Probabilities for all classes.
+     * @param labels Labels.
+     * @param sumsHolder Feature sum, squared sum and count sum per label. This data is used for future model updating.
+     */
+    public GaussianNaiveBayesModel(double[][] means, double[][] variances,
+        double[] classProbabilities, double[] labels, GaussianNaiveBayesSumsHolder sumsHolder) {
+        this.means = means;
+        this.variances = variances;
+        this.classProbabilities = classProbabilities;
+        this.labels = labels;
+        this.sumsHolder = sumsHolder;
+    }
+
+    /** {@inheritDoc} */
+    @Override public <P> void saveModel(Exporter<GaussianNaiveBayesModel, P> exporter, P path) {
+        exporter.save(this, path);
+    }
+
+    /**
+     * Returns the label of the class to which the input most probably belongs.
+     * <p>
+     * Scores are accumulated as sums of logarithms instead of a product of densities: with many features (or a
+     * point far from every class mean) the plain product underflows to {@code 0.0} for all classes, and the first
+     * label would then be returned regardless of the data. The logarithm is monotonic, so whenever the product is
+     * representable the same class wins. If no class gets a score above negative infinity (zero priors or
+     * {@code NaN} scores, e.g. caused by a zero variance) the first label is returned.</p>
+     *
+     * @param vector Feature vector; must contain no more elements than the model has features per class.
+     * @return Label of the most probable class.
+     */
+    @Override public Double apply(Vector vector) {
+        int best = 0;
+        double bestLogProbability = Double.NEGATIVE_INFINITY;
+
+        for (int i = 0; i < classProbabilities.length; i++) {
+            double logProbability = Math.log(classProbabilities[i]);
+
+            for (int j = 0; j < vector.size(); j++)
+                logProbability += logGauss(vector.get(j), means[i][j], variances[i][j]);
+
+            // Strict comparison keeps the first of equally probable classes and skips NaN scores.
+            if (logProbability > bestLogProbability) {
+                best = i;
+                bestLogProbability = logProbability;
+            }
+        }
+
+        return labels[best];
+    }
+
+    /** @return Means of features for all classes (the internal array, not a copy). */
+    public double[][] getMeans() {
+        return means;
+    }
+
+    /** @return Variances of features for all classes (the internal array, not a copy). */
+    public double[][] getVariances() {
+        return variances;
+    }
+
+    /** @return Prior probabilities of all classes (the internal array, not a copy). */
+    public double[] getClassProbabilities() {
+        return classProbabilities;
+    }
+
+    /** @return Feature sums, squared sums and counts per label the model was built from. */
+    public GaussianNaiveBayesSumsHolder getSumsHolder() {
+        return sumsHolder;
+    }
+
+    /**
+     * Natural logarithm of the Gauss (normal) density
+     * {@code exp(-(x - mean)^2 / (2 * variance)) / sqrt(2 * PI * variance)}.
+     *
+     * @param x Point at which the density is evaluated.
+     * @param mean Mean of the distribution.
+     * @param variance Variance of the distribution; a zero variance yields {@code NaN}.
+     * @return Logarithm of the density at {@code x}.
+     */
+    private double logGauss(double x, double mean, double variance) {
+        double diff = x - mean;
+
+        return -0.5 * (Math.log(2. * Math.PI * variance) + diff * diff / variance);
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/e29a8cb9/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesSumsHolder.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesSumsHolder.java
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesSumsHolder.java
new file mode 100644
index 0000000..735bbd1
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesSumsHolder.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.ignite.ml.naivebayes.gaussian;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.ml.math.util.MapUtil;
+
+/**
+ * Service class used to accumulate, per label, the feature sums, squared feature sums and row counts from which
+ * means and variances are calculated.
+ */
+class GaussianNaiveBayesSumsHolder implements Serializable, AutoCloseable {
+    /** Serial version uid. */
+    private static final long serialVersionUID = 1L;
+
+    /** Sum of all values for all features for each label. */
+    Map<Double, double[]> featureSumsPerLbl = new HashMap<>();
+
+    /** Sum of all squared values for all features for each label. */
+    Map<Double, double[]> featureSquaredSumsPerLbl = new HashMap<>();
+
+    /** Rows count for each label. */
+    Map<Double, Integer> featureCountersPerLbl = new HashMap<>();
+
+    /**
+     * Merges the statistics of {@code other} into this holder.
+     * <p>
+     * NOTE(review): arrays of this holder are updated in place by {@link #sum(double[], double[])}; arrays for labels
+     * present only in {@code other} presumably end up shared with it (depends on {@code MapUtil.mergeMaps}) - confirm
+     * before reusing {@code other} afterwards.</p>
+     *
+     * @param other Holder whose sums and counters are added to this one.
+     * @return This holder.
+     * @throws IllegalArgumentException If both holders contain the same label with a different number of features.
+     */
+    GaussianNaiveBayesSumsHolder merge(GaussianNaiveBayesSumsHolder other) {
+        featureSumsPerLbl = MapUtil.mergeMaps(featureSumsPerLbl, other.featureSumsPerLbl, this::sum, HashMap::new);
+        featureSquaredSumsPerLbl = MapUtil.mergeMaps(featureSquaredSumsPerLbl, other.featureSquaredSumsPerLbl,
+            this::sum, HashMap::new);
+        featureCountersPerLbl = MapUtil.mergeMaps(featureCountersPerLbl, other.featureCountersPerLbl,
+            Integer::sum, HashMap::new);
+
+        return this;
+    }
+
+    /**
+     * In-place operation. Adds {@code arr2} to {@code arr1} element by element.
+     *
+     * @param arr1 Array that receives the sums.
+     * @param arr2 Array to add.
+     * @return {@code arr1}.
+     * @throws IllegalArgumentException If the arrays have different lengths, i.e. the same label was seen with a
+     *      different number of features.
+     */
+    private double[] sum(double[] arr1, double[] arr2) {
+        // Fail loudly instead of silently dropping the tail of a longer array or failing with an index error.
+        if (arr1.length != arr2.length) {
+            throw new IllegalArgumentException("Feature count mismatch while merging sums [expected=" + arr1.length +
+                ", actual=" + arr2.length + ']');
+        }
+
+        for (int i = 0; i < arr1.length; i++)
+            arr1[i] += arr2[i];
+
+        return arr1;
+    }
+
+    /** */
+    @Override public void close() {
+        // Do nothing, GC will clean up.
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/e29a8cb9/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java
new file mode 100644
index 0000000..1c1df83
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.naivebayes.gaussian;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
+
+/**
+ * Trainer for the Gaussian naive Bayes classification model. By default the trainer calculates prior probabilities
+ * from the input dataset. Prior probabilities can also be set by {@link #setPriorProbabilities(double[])} or
+ * {@link #withEquiprobableClasses()}. If equiprobable classes are requested, the probabilities of all classes will
+ * be {@code 1/k}, where {@code k} is classes count.
+ */
+public class GaussianNaiveBayesTrainer extends SingleLabelDatasetTrainer<GaussianNaiveBayesModel> {
+    /** Preset prior probabilities. */
+    private double[] priorProbabilities;
+
+    /** Sets equivalent probability for all classes. */
+    private boolean equiprobableClasses;
+
+    /**
+     * Trains model based on the specified data.
+     *
+     * @param datasetBuilder Dataset builder.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @return Model.
+     */
+    @Override public <K, V> GaussianNaiveBayesModel fit(DatasetBuilder<K, V> datasetBuilder,
+        IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+        return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
+    }
+
+    /** {@inheritDoc} */
+    @Override protected boolean checkState(GaussianNaiveBayesModel mdl) {
+        return true;
+    }
+
+    /**
+     * Builds a model from the dataset and, when {@code mdl} carries its sums holder, from the statistics the
+     * previous model was trained on.
+     * <p>
+     * Any exception raised while building or processing the dataset (including the validation failures described
+     * below) is rethrown wrapped into a {@link RuntimeException}.</p>
+     *
+     * @param mdl Model to update, or {@code null} to train from scratch.
+     * @param datasetBuilder Dataset builder.
+     * @param featureExtractor Feature extractor.
+     * @param lbExtractor Label extractor.
+     * @return New model.
+     * @throws RuntimeException If the dataset processing fails, if there is no labeled row at all, or if the preset
+     *      prior probabilities do not match the number of classes.
+     */
+    @Override protected <K, V> GaussianNaiveBayesModel updateModel(GaussianNaiveBayesModel mdl,
+        DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
+        IgniteBiFunction<K, V, Double> lbExtractor) {
+        assert datasetBuilder != null;
+
+        try (Dataset<EmptyContext, GaussianNaiveBayesSumsHolder> dataset = datasetBuilder.build(
+            (upstream, upstreamSize) -> new EmptyContext(),
+            (upstream, upstreamSize, ctx) -> {
+                // Accumulate sums, squared sums and row counts per label for the rows handed to this call.
+                GaussianNaiveBayesSumsHolder res = new GaussianNaiveBayesSumsHolder();
+
+                while (upstream.hasNext()) {
+                    UpstreamEntry<K, V> entity = upstream.next();
+
+                    Vector features = featureExtractor.apply(entity.getKey(), entity.getValue());
+                    Double label = lbExtractor.apply(entity.getKey(), entity.getValue());
+                    int featureCnt = features.size();
+
+                    // Arrays are sized by the first row seen for the label.
+                    double[] toMeans = res.featureSumsPerLbl.computeIfAbsent(label, lb -> new double[featureCnt]);
+                    double[] sqSum = res.featureSquaredSumsPerLbl.computeIfAbsent(label, lb -> new double[featureCnt]);
+
+                    res.featureCountersPerLbl.merge(label, 1, Integer::sum);
+
+                    for (int j = 0; j < featureCnt; j++) {
+                        double x = features.get(j);
+
+                        toMeans[j] += x;
+                        sqSum[j] += x * x;
+                    }
+                }
+
+                return res;
+            }
+        )) {
+            GaussianNaiveBayesSumsHolder sumsHolder = dataset.compute(t -> t, (a, b) -> {
+                if (a == null)
+                    return b == null ? new GaussianNaiveBayesSumsHolder() : b;
+                if (b == null)
+                    return a;
+
+                return a.merge(b);
+            });
+
+            // NOTE(review): the reducer above tolerates null operands, so a null overall result is presumably
+            // possible as well - treat it as "no data" instead of failing with a NullPointerException.
+            if (sumsHolder == null)
+                sumsHolder = new GaussianNaiveBayesSumsHolder();
+
+            if (mdl != null && mdl.getSumsHolder() != null)
+                sumsHolder = sumsHolder.merge(mdl.getSumsHolder());
+
+            List<Double> sortedLabels = new ArrayList<>(sumsHolder.featureCountersPerLbl.keySet());
+            sortedLabels.sort(Double::compareTo);
+
+            // Explicit check: assertions are usually disabled, and get(0) below would fail with an obscure error.
+            if (sortedLabels.isEmpty())
+                throw new IllegalArgumentException("The dataset should contain at least one labeled row");
+
+            int labelCount = sortedLabels.size();
+            int featureCount = sumsHolder.featureSumsPerLbl.get(sortedLabels.get(0)).length;
+
+            if (!equiprobableClasses && priorProbabilities != null && priorProbabilities.length != labelCount) {
+                throw new IllegalArgumentException("Prior probabilities count should match classes count [priors=" +
+                    priorProbabilities.length + ", classes=" + labelCount + ']');
+            }
+
+            double[][] means = new double[labelCount][featureCount];
+            double[][] variances = new double[labelCount][featureCount];
+            double[] classProbabilities = new double[labelCount];
+            double[] labels = new double[labelCount];
+
+            // Summed as long so that the total cannot overflow even though per-label counters are ints.
+            long datasetSize = sumsHolder.featureCountersPerLbl.values().stream().mapToLong(Integer::longValue).sum();
+
+            int lbl = 0;
+            for (Double label : sortedLabels) {
+                int count = sumsHolder.featureCountersPerLbl.get(label);
+                double[] sum = sumsHolder.featureSumsPerLbl.get(label);
+                double[] sqSum = sumsHolder.featureSquaredSumsPerLbl.get(label);
+
+                for (int i = 0; i < featureCount; i++) {
+                    means[lbl][i] = sum[i] / count;
+
+                    // E[x^2] - E[x]^2 may come out slightly negative because of rounding; variance cannot be.
+                    variances[lbl][i] = Math.max(0., (sqSum[i] - sum[i] * sum[i] / count) / count);
+                }
+
+                if (equiprobableClasses)
+                    classProbabilities[lbl] = 1. / labelCount;
+                else if (priorProbabilities != null)
+                    classProbabilities[lbl] = priorProbabilities[lbl];
+                else
+                    classProbabilities[lbl] = (double)count / datasetSize;
+
+                labels[lbl] = label;
+                ++lbl;
+            }
+
+            return new GaussianNaiveBayesModel(means, variances, classProbabilities, labels, sumsHolder);
+        }
+        catch (Exception e) {
+            throw new RuntimeException(e);
+        }
+
+    }
+
+    /**
+     * Sets equal probability for all classes. Discards previously preset prior probabilities.
+     *
+     * @return This trainer.
+     */
+    public GaussianNaiveBayesTrainer withEquiprobableClasses() {
+        resetSettings();
+        equiprobableClasses = true;
+        return this;
+    }
+
+    /**
+     * Sets prior probabilities. Discards the equiprobable classes setting. The array is copied.
+     *
+     * @param priorProbabilities Prior probability for every class, in ascending order of class labels; the length
+     *      must match the number of classes found during training.
+     * @return This trainer.
+     */
+    public GaussianNaiveBayesTrainer setPriorProbabilities(double[] priorProbabilities) {
+        resetSettings();
+        this.priorProbabilities = priorProbabilities.clone();
+        return this;
+    }
+
+    /**
+     * Sets default settings: prior probabilities are calculated from the dataset.
+     *
+     * @return This trainer.
+     */
+    public GaussianNaiveBayesTrainer resetSettings() {
+        equiprobableClasses = false;
+        priorProbabilities = null;
+        return this;
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/e29a8cb9/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/package-info.java
b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/package-info.java
new file mode 100644
index 0000000..4e572cf
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains Gaussian naive Bayes classifier.
+ */
+package org.apache.ignite.ml.naivebayes.gaussian;
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/ignite/blob/e29a8cb9/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/package-info.java
new file mode 100644
index 0000000..fae5387
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains various naive Bayes classifiers.
+ */
+package org.apache.ignite.ml.naivebayes;
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/ignite/blob/e29a8cb9/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModelTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModelTest.java
new file mode 100644
index 0000000..c79c0d7
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModelTest.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.naivebayes.gaussian;
+
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests for {@link GaussianNaiveBayesModel}.
+ */
+public class GaussianNaiveBayesModelTest {
+
+    /** Checks that a model built from fixed means, variances and equal class probabilities returns the second label for (6, 130, 8). */
+    @Test
+    public void testPredictWithTwoClasses() {
+        double first = 1;
+        double second = 2;
+        double[][] means = new double[][] {
+            {5.855, 176.25, 11.25},
+            {5.4175, 132.5, 7.5},
+        };
+        double[][] variances = new double[][] {
+            {3.5033E-2, 1.2292E2, 9.1667E-1},
+            {9.7225E-2, 5.5833E2, 1.6667},
+        };
+        double[] probabilities = new double[] {.5, .5};
+        GaussianNaiveBayesModel mdl = new GaussianNaiveBayesModel(means, variances, probabilities,
new double[] {first, second}, null);
+        Vector observation = VectorUtils.of(6, 130, 8);
+
+        Assert.assertEquals(second, mdl.apply(observation), 0.0001);
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/e29a8cb9/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTest.java
new file mode 100644
index 0000000..504b464
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTest.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.naivebayes.gaussian;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * End-to-end tests that train a Gaussian naive Bayes model on small published data sets and check one prediction each.
+ */
+public class GaussianNaiveBayesTest {
+    /** Precision in test checks. */
+    private static final double PRECISION = 1e-2;
+
+    /**
+     * Trains on the sex classification data set from the Wikipedia naive Bayes article and checks that (6, 130, 8) is classified as female; see https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Sex_classification
+     */
+    @Test
+    public void wikipediaSexClassificationDataset() {
+        Map<Integer, double[]> data = new HashMap<>();
+        double male = 0.;
+        double female = 1.;
+        data.put(0, new double[] {male, 6, 180, 12});
+        data.put(2, new double[] {male, 5.92, 190, 11});
+        data.put(3, new double[] {male, 5.58, 170, 12});
+        data.put(4, new double[] {male, 5.92, 165, 10});
+        data.put(5, new double[] {female, 5, 100, 6});
+        data.put(6, new double[] {female, 5.5, 150, 8});
+        data.put(7, new double[] {female, 5.42, 130, 7});
+        data.put(8, new double[] {female, 5.75, 150, 9});
+        GaussianNaiveBayesTrainer trainer = new GaussianNaiveBayesTrainer();
+        GaussianNaiveBayesModel model = trainer.fit(
+            new LocalDatasetBuilder<>(data, 2),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), // Features: everything after the label.
+            (k, v) -> v[0] // Label: first element of the row.
+        );
+        Vector observation = VectorUtils.of(6, 130, 8);
+
+        Assert.assertEquals(female, model.apply(observation), PRECISION);
+    }
+
+    /** Trains on the data set from the Gaussian NB example in the scikit-learn documentation and checks that (-0.8, -1) gets label 1. */
+    @Test
+    public void scikitLearnExample() {
+        Map<Integer, double[]> data = new HashMap<>();
+        double one = 1.;
+        double two = 2.;
+        data.put(0, new double[] {one, -1, 1});
+        data.put(2, new double[] {one, -2, -1});
+        data.put(3, new double[] {one, -3, -2});
+        data.put(4, new double[] {two, 1, 1});
+        data.put(5, new double[] {two, 2, 1});
+        data.put(6, new double[] {two, 3, 2});
+        GaussianNaiveBayesTrainer trainer = new GaussianNaiveBayesTrainer();
+        GaussianNaiveBayesModel model = trainer.fit(
+            new LocalDatasetBuilder<>(data, 2),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), // Features: everything after the label.
+            (k, v) -> v[0] // Label: first element of the row.
+        );
+        Vector observation = VectorUtils.of(-0.8, -1);
+
+        Assert.assertEquals(one, model.apply(observation), PRECISION);
+    }
+
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/e29a8cb9/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainerTest.java
new file mode 100644
index 0000000..f70f7c2
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainerTest.java
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.naivebayes.gaussian;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+/**
+ * Tests for {@link GaussianNaiveBayesTrainer}.
+ */
+public class GaussianNaiveBayesTrainerTest extends TrainerTest {
+    /** Precision in test checks. */
+    private static final double PRECISION = 1e-2;
+    /** Label of the first class. */
+    private static final double LABEL_1 = 1.;
+    /** Label of the second class. */
+    private static final double LABEL_2 = 2.;
+
+    /** Two-feature rows with the label as the last element: three rows of {@code LABEL_1} and two of {@code LABEL_2}. */
+    private static final Map<Integer, double[]> data = new HashMap<>();
+    /** Only the {@code LABEL_1} rows of {@code data}. */
+    private static final Map<Integer, double[]> singleLabeldata1 = new HashMap<>();
+    /** Only the {@code LABEL_2} rows of {@code data}. */
+    private static final Map<Integer, double[]> singleLabeldata2 = new HashMap<>();
+
+    static {
+        data.put(0, new double[] {1.0, -1.0, LABEL_1});
+        data.put(1, new double[] {-1.0, 2.0, LABEL_1});
+        data.put(2, new double[] {6.0, 1.0, LABEL_1});
+        data.put(3, new double[] {-3.0, 2.0, LABEL_2});
+        data.put(4, new double[] {-5.0, -2.0, LABEL_2});
+
+        singleLabeldata1.put(0, new double[] {1.0, -1.0, LABEL_1});
+        singleLabeldata1.put(1, new double[] {-1.0, 2.0, LABEL_1});
+        singleLabeldata1.put(2, new double[] {6.0, 1.0, LABEL_1});
+
+        singleLabeldata2.put(0, new double[] {-3.0, 2.0, LABEL_2});
+        singleLabeldata2.put(1, new double[] {-5.0, -2.0, LABEL_2});
+    }
+
+    private GaussianNaiveBayesTrainer trainer;
+
+    /** Creates a fresh {@code GaussianNaiveBayesTrainer} with default settings before each test. */
+    @Before
+    public void createTrainer() {
+        trainer = new GaussianNaiveBayesTrainer();
+    }
+
+    /** Trains on {@code twoLinearlySeparableClasses} (label first) and checks that (100, 10) maps to 0 and (10, 100) to 1. */
+    @Test
+    public void testWithLinearlySeparableData() {
+        Map<Integer, double[]> cacheMock = new HashMap<>();
+        for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
+            cacheMock.put(i, twoLinearlySeparableClasses[i]);
+
+        GaussianNaiveBayesModel mdl = trainer.fit(
+            cacheMock,
+            parts,
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+            (k, v) -> v[0]
+        );
+
+        TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 10)), PRECISION);
+        TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), PRECISION);
+    }
+
+    /** Checks that with default settings the class probabilities equal the class frequencies in {@code data} (3/5 and 2/5). */
+    @Test
+    public void testReturnsCorrectLabelProbalities() {
+
+        GaussianNaiveBayesModel model = trainer.fit(
+            new LocalDatasetBuilder<>(data, parts),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+
+        Assert.assertEquals(3. / data.size(), model.getClassProbabilities()[0], PRECISION);
+        Assert.assertEquals(2. / data.size(), model.getClassProbabilities()[1], PRECISION);
+    }
+
+    /** Checks that {@code withEquiprobableClasses()} gives both classes probability 0.5 regardless of their frequencies. */
+    @Test
+    public void testReturnsEquivalentProbalitiesWhenSetEquiprobableClasses_() {
+        GaussianNaiveBayesTrainer trainer = new GaussianNaiveBayesTrainer()
+            .withEquiprobableClasses();
+
+        GaussianNaiveBayesModel model = trainer.fit(
+            new LocalDatasetBuilder<>(data, parts),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+
+        Assert.assertEquals(.5, model.getClassProbabilities()[0], PRECISION);
+        Assert.assertEquals(.5, model.getClassProbabilities()[1], PRECISION);
+    }
+
+    /** Checks that priors passed to {@code setPriorProbabilities} are reported unchanged by the trained model. */
+    @Test
+    public void testReturnsPresetProbalitiesWhenSetPriorProbabilities() {
+        double[] priorProbabilities = new double[] {.35, .65};
+        GaussianNaiveBayesTrainer trainer = new GaussianNaiveBayesTrainer()
+            .setPriorProbabilities(priorProbabilities);
+
+        GaussianNaiveBayesModel model = trainer.fit(
+            new LocalDatasetBuilder<>(data, parts),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+
+        Assert.assertEquals(priorProbabilities[0], model.getClassProbabilities()[0], PRECISION);
+        Assert.assertEquals(priorProbabilities[1], model.getClassProbabilities()[1], PRECISION);
+    }
+
+    /** Checks the per-feature means computed for the {@code LABEL_1} rows: (1 - 1 + 6) / 3 and (-1 + 2 + 1) / 3. */
+    @Test
+    public void testReturnsCorrectMeans() {
+
+        GaussianNaiveBayesModel model = trainer.fit(
+            new LocalDatasetBuilder<>(singleLabeldata1, parts),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+
+        Assert.assertArrayEquals(new double[] {2.0, 2. / 3.}, model.getMeans()[0], PRECISION);
+    }
+
+    /** Checks the per-feature variances for the {@code LABEL_1} rows; expected values are population variances (divided by n). */
+    @Test
+    public void testReturnsCorrectVariances() {
+
+        GaussianNaiveBayesModel model = trainer.fit(
+            new LocalDatasetBuilder<>(singleLabeldata1, parts),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+
+        double[] expectedVars = {8.666666666666666, 1.5555555555555556};
+        Assert.assertArrayEquals(expectedVars, model.getVariances()[0], PRECISION);
+    }
+
+    /** Checks that updating a {@code LABEL_1}-only model with the {@code LABEL_2} rows yields the combined frequencies 3/5 and 2/5. */
+    @Test
+    public void testUpdatigModel() {
+        GaussianNaiveBayesModel model = trainer.fit(
+            new LocalDatasetBuilder<>(singleLabeldata1, parts),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+
+        GaussianNaiveBayesModel updatedModel = trainer.updateModel(model,
+            new LocalDatasetBuilder<>(singleLabeldata2, parts),
+            (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+            (k, v) -> v[2]
+        );
+
+        Assert.assertEquals(3. / data.size(), updatedModel.getClassProbabilities()[0], PRECISION);
+        Assert.assertEquals(2. / data.size(), updatedModel.getClassProbabilities()[1], PRECISION);
+    }
+}


Mime
View raw message