commons-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From l..@apache.org
Subject svn commit: r1460726 - /commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java
Date Mon, 25 Mar 2013 15:47:32 GMT
Author: luc
Date: Mon Mar 25 15:47:31 2013
New Revision: 1460726

URL: http://svn.apache.org/r1460726
Log:
Fixed tests so they no longer rely on equals() of top-level classes; they now compare the component means and covariance matrices directly.

Patch submitted by Jared Becksfort.

JIRA: MATH-817

Modified:
    commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java

Modified: commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java
URL: http://svn.apache.org/viewvc/commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java?rev=1460726&r1=1460725&r2=1460726&view=diff
==============================================================================
--- commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java	(original)
+++ commons/proper/math/trunk/src/test/java/org/apache/commons/math3/distribution/fitting/MultivariateNormalMixtureExpectationMaximizationTest.java	Mon Mar 25 15:47:31 2013
@@ -17,6 +17,7 @@
 package org.apache.commons.math3.distribution.fitting;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 
 import org.apache.commons.math3.distribution.MixtureMultivariateNormalDistribution;
@@ -25,10 +26,11 @@ import org.apache.commons.math3.exceptio
 import org.apache.commons.math3.exception.DimensionMismatchException;
 import org.apache.commons.math3.exception.NotStrictlyPositiveException;
 import org.apache.commons.math3.exception.NumberIsTooSmallException;
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.commons.math3.util.Pair;
 import org.junit.Assert;
 import org.junit.Test;
-import org.junit.Ignore;
 
 /**
  * Test that demonstrates the use of
@@ -36,9 +38,6 @@ import org.junit.Ignore;
  */
 public class MultivariateNormalMixtureExpectationMaximizationTest {
 
-    // TODO reject initial mixes where means/covMats not computable with data
-    // numCols
-
     @Test(expected = NotStrictlyPositiveException.class)
     public void testNonEmptyData() {
         // Should not accept empty data
@@ -144,22 +143,34 @@ public class MultivariateNormalMixtureEx
         fitter.fit(badInitialMix);
     }
 
-    @Ignore@Test
+    @Test
     public void testInitialMixture() {
         // Testing initial mixture estimated from data
-        double[] correctWeights = new double[] { 0.5, 0.5 };
+        final double[] correctWeights = new double[] { 0.5, 0.5 };
 
-        MultivariateNormalDistribution[] correctMVNs = new MultivariateNormalDistribution[2];
+        final double[][] correctMeans = new double[][] {
+            {-0.0021722935000328823, 3.5432892936887908},
+            {5.090902706507635, 8.68540656355283},
+        };
 
-        correctMVNs[0] = new MultivariateNormalDistribution(new double[] {
-                        -0.0021722935000328823, 3.5432892936887908 },
-                        new double[][] {
-                                { 4.537422569229048, 3.5266152281729304 },
-                                { 3.5266152281729304, 6.175448814169779 } });
-        correctMVNs[1] = new MultivariateNormalDistribution(new double[] {
-                        5.090902706507635, 8.68540656355283 }, new double[][] {
-                        { 2.886778573963039, 1.5257474543463154 },
-                        { 1.5257474543463154, 3.3794567673616918 } });
+        final RealMatrix[] correctCovMats = new Array2DRowRealMatrix[2];
+
+        correctCovMats[0] = new Array2DRowRealMatrix(new double[][] {
+                { 4.537422569229048, 3.5266152281729304 },
+                { 3.5266152281729304, 6.175448814169779 } });
+
+        correctCovMats[1] = new Array2DRowRealMatrix( new double[][] {
+                { 2.886778573963039, 1.5257474543463154 },
+                { 1.5257474543463154, 3.3794567673616918 } });
+
+        final MultivariateNormalDistribution[] correctMVNs = new
+                MultivariateNormalDistribution[2];
+
+        correctMVNs[0] = new MultivariateNormalDistribution(correctMeans[0],
+                correctCovMats[0].getData());
+
+        correctMVNs[1] = new MultivariateNormalDistribution(correctMeans[1],
+                correctCovMats[1].getData());
 
         final MixtureMultivariateNormalDistribution initialMix
             = MultivariateNormalMixtureExpectationMaximization.estimate(getTestSamples(), 2);
@@ -169,30 +180,41 @@ public class MultivariateNormalMixtureEx
                 .getComponents()) {
             Assert.assertEquals(correctWeights[i], component.getFirst(),
                     Math.ulp(1d));
-            Assert.assertEquals(correctMVNs[i], component.getSecond());
+            
+            final double[] means = component.getValue().getMeans();
+            Assert.assertTrue(Arrays.equals(correctMeans[i], means));
+            
+            final RealMatrix covMat = component.getValue().getCovariances();
+            Assert.assertEquals(correctCovMats[i], covMat);
             i++;
         }
     }
 
-    @Ignore@Test
+    @Test
     public void testFit() {
         // Test that the loglikelihood, weights, and models are determined and
         // fitted correctly
-        double[][] data = getTestSamples();
-        double correctLogLikelihood = -4.292431006791994;
-        double[] correctWeights = new double[] { 0.2962324189652912, 0.7037675810347089 };
-        MultivariateNormalDistribution[] correctMVNs = new MultivariateNormalDistribution[2];
-        correctMVNs[0] = new MultivariateNormalDistribution(new double[] {
-                        -1.4213112715121132, 1.6924690505757753 },
-                        new double[][] {
-                                { 1.739356907285747, -0.5867644251487614 },
-                                { -0.5867644251487614, 1.0232932029324642 } });
-
-        correctMVNs[1] = new MultivariateNormalDistribution(new double[] {
-                        4.213612224374709, 7.975621325853645 },
-                        new double[][] {
-                                { 4.245384898007161, 2.5797798966382155 },
-                                { 2.5797798966382155, 3.9200272522448367 } });
+        final double[][] data = getTestSamples();
+        final double correctLogLikelihood = -4.292431006791994;
+        final double[] correctWeights = new double[] { 0.2962324189652912, 0.7037675810347089 };
+        
+        final double[][] correctMeans = new double[][]{
+            {-1.4213112715121132, 1.6924690505757753},
+            {4.213612224374709, 7.975621325853645}
+        };
+        
+        final RealMatrix[] correctCovMats = new Array2DRowRealMatrix[2];
+        correctCovMats[0] = new Array2DRowRealMatrix(new double[][] {
+            { 1.739356907285747, -0.5867644251487614 },
+            { -0.5867644251487614, 1.0232932029324642 } }
+                );
+        correctCovMats[1] = new Array2DRowRealMatrix(new double[][] {
+            { 4.245384898007161, 2.5797798966382155 },
+            { 2.5797798966382155, 3.9200272522448367 } });
+        
+        final MultivariateNormalDistribution[] correctMVNs = new MultivariateNormalDistribution[2];
+        correctMVNs[0] = new MultivariateNormalDistribution(correctMeans[0], correctCovMats[0].getData());
+        correctMVNs[1] = new MultivariateNormalDistribution(correctMeans[1], correctCovMats[1].getData());
 
         MultivariateNormalMixtureExpectationMaximization fitter
             = new MultivariateNormalMixtureExpectationMaximization(data);
@@ -209,10 +231,13 @@ public class MultivariateNormalMixtureEx
 
         int i = 0;
         for (Pair<Double, MultivariateNormalDistribution> component : components) {
-            double weight = component.getFirst();
-            MultivariateNormalDistribution mvn = component.getSecond();
+            final double weight = component.getFirst();
+            final MultivariateNormalDistribution mvn = component.getSecond();
+            final double[] mean = mvn.getMeans();
+            final RealMatrix covMat = mvn.getCovariances();
             Assert.assertEquals(correctWeights[i], weight, Math.ulp(1d));
-            Assert.assertEquals(correctMVNs[i], mvn);
+            Assert.assertTrue(Arrays.equals(correctMeans[i], mean));
+            Assert.assertEquals(correctCovMats[i], covMat);
             i++;
         }
     }



Mime
View raw message