mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tdunn...@apache.org
Subject svn commit: r1380428 [2/2] - in /mahout/trunk/math/src: main/java/org/apache/mahout/math/ main/java/org/apache/mahout/math/random/ test/java/org/apache/mahout/math/ test/java/org/apache/mahout/math/random/
Date Tue, 04 Sep 2012 02:18:45 GMT
Added: mahout/trunk/math/src/test/java/org/apache/mahout/math/random/MultinomialTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/random/MultinomialTest.java?rev=1380428&view=auto
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/random/MultinomialTest.java (added)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/random/MultinomialTest.java Tue Sep  4 02:18:44 2012
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Multiset;
+import junit.framework.Assert;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Unit tests for {@link Multinomial}: construction from a multiset, sampling at chosen
+ * probability points, and incremental add / set / delete of weighted items.
+ */
+public class MultinomialTest {
+    /** Tiny offset used to probe {@code sample(double)} just inside an interval boundary. */
+    private static final double EPSILON = 1e-15;
+
+    /** Asks RandomUtils to use its test seed before every test. */
+    @Before
+    public void setUp() {
+        RandomUtils.useTestSeed();
+    }
+
+    /** An empty multiset must be rejected by the constructor. */
+    @Test(expected = IllegalArgumentException.class)
+    public void testNoValues() {
+        Multiset<String> emptySet = HashMultiset.create();
+        new Multinomial<String>(emptySet);
+    }
+
+    /** With a single item, every probed point (0, 0.1 and 1) must return that item. */
+    @Test
+    public void testSingleton() {
+        Multiset<String> oneThing = HashMultiset.create();
+        oneThing.add("one");
+        Multinomial<String> s = new Multinomial<String>(oneThing);
+        assertEquals("one", s.sample(0));
+        assertEquals("one", s.sample(0.1));
+        assertEquals("one", s.sample(1));
+    }
+
+    /**
+     * Five items with equal counts: probes the start, just after the start and just before
+     * the end of each 0.2-wide interval (15 samples in total) and expects all five items to
+     * appear, each about three times.
+     */
+    @Test
+    public void testEvenSplit() {
+        Multiset<String> stuff = HashMultiset.create();
+        for (int i = 0; i < 5; i++) {
+            stuff.add(i + "");
+        }
+        Multinomial<String> s = new Multinomial<String>(stuff);
+
+        Multiset<String> cnt = HashMultiset.create();
+        for (int i = 0; i < 5; i++) {
+            cnt.add(s.sample(i * 0.2));
+            cnt.add(s.sample(i * 0.2 + EPSILON));
+            cnt.add(s.sample((i + 1) * 0.2 - EPSILON));
+        }
+
+        assertEquals(5, cnt.elementSet().size());
+        for (String v : cnt.elementSet()) {
+            // the tolerance allows a boundary probe to land in the neighbouring interval
+            assertEquals(3, cnt.count(v), 1.01);
+        }
+        assertTrue(cnt.contains(s.sample(1)));
+        assertEquals(s.sample(1 - EPSILON), s.sample(1));
+    }
+
+    /**
+     * Uses 17 observations (a prime total) spread unevenly over five labels. The label is
+     * picked by the highest of the low four bits set in {@code i} ("0" when none is set),
+     * giving counts 0:2, 1:1, 2:2, 3:4, 4:8. Three samplers built from the same multiset
+     * must agree at every probed point, and the sampled counts must match reference values.
+     */
+    @Test
+    public void testPrime() {
+        List<String> data = Lists.newArrayList();
+        for (int i = 0; i < 17; i++) {
+            String s = "0";
+            if ((i & 1) != 0) {
+                s = "1";
+            }
+            if ((i & 2) != 0) {
+                s = "2";
+            }
+            if ((i & 4) != 0) {
+                s = "3";
+            }
+            if ((i & 8) != 0) {
+                s = "4";
+            }
+            data.add(s);
+        }
+
+        Multiset<String> stuff = HashMultiset.create();
+
+        for (String x : data) {
+            stuff.add(x);
+        }
+
+        Multinomial<String> s0 = new Multinomial<String>(stuff);
+        Multinomial<String> s1 = new Multinomial<String>(stuff);
+        Multinomial<String> s2 = new Multinomial<String>(stuff);
+
+        Multiset<String> cnt = HashMultiset.create();
+        for (int i = 0; i < 50; i++) {
+            final double p0 = i * 0.02;
+            final double p1 = (i + 1) * 0.02;
+            cnt.add(s0.sample(p0));
+            cnt.add(s0.sample(p0 + EPSILON));
+            cnt.add(s0.sample(p1 - EPSILON));
+
+            assertEquals(s0.sample(p0), s1.sample(p0));
+            assertEquals(s0.sample(p0 + EPSILON), s1.sample(p0 + EPSILON));
+            assertEquals(s0.sample(p1 - EPSILON), s1.sample(p1 - EPSILON));
+
+            assertEquals(s0.sample(p0), s2.sample(p0));
+            assertEquals(s0.sample(p0 + EPSILON), s2.sample(p0 + EPSILON));
+            assertEquals(s0.sample(p1 - EPSILON), s2.sample(p1 - EPSILON));
+        }
+
+        assertEquals(s0.sample(0), s1.sample(0));
+        assertEquals(s0.sample(0 + EPSILON), s1.sample(0 + EPSILON));
+        assertEquals(s0.sample(1 - EPSILON), s1.sample(1 - EPSILON));
+        assertEquals(s0.sample(1), s1.sample(1));
+
+        assertEquals(s0.sample(0), s2.sample(0));
+        assertEquals(s0.sample(0 + EPSILON), s2.sample(0 + EPSILON));
+        assertEquals(s0.sample(1 - EPSILON), s2.sample(1 - EPSILON));
+        assertEquals(s0.sample(1), s2.sample(1));
+
+        assertEquals(5, cnt.elementSet().size());
+        // regression test, really.  These values depend on the original seed and exact algorithm.
+        // the actual values should be within about 2 of these, however, almost regardless of seed
+        Map<String, Integer> ref = ImmutableMap.of("3", 35, "2", 18, "1", 9, "0", 16, "4", 72);
+        for (String v : cnt.elementSet()) {
+            assertEquals(ref.get(v).intValue(), cnt.count(v));
+        }
+
+        assertTrue(cnt.contains(s0.sample(1)));
+        assertEquals(s0.sample(1 - EPSILON), s0.sample(1));
+    }
+
+    /**
+     * Adds ten randomly weighted items one at a time, then checks that sampling is
+     * self-consistent and that every stored weight is returned exactly.
+     */
+    @Test
+    public void testInsert() {
+        // take the generator from RandomUtils so the test seed requested in setUp() applies;
+        // a bare new Random() would make the weights differ on every run
+        Random rand = RandomUtils.getRandom();
+        Multinomial<Integer> table = new Multinomial<Integer>();
+
+        double[] p = new double[10];
+        for (int i = 0; i < 10; i++) {
+            p[i] = rand.nextDouble();
+            table.add(i, p[i]);
+        }
+
+        checkSelfConsistent(table);
+
+        for (int i = 0; i < 10; i++) {
+            assertEquals(p[i], table.getWeight(i), 0);
+        }
+    }
+
+    /**
+     * Builds a ten-item table, removes item 7 with {@code delete} and item 8 by setting its
+     * weight to 0, then raises item 9 to 5.1. After each change the total weight, the
+     * per-item weights and the per-item probabilities are compared with a locally tracked
+     * copy.
+     */
+    @Test
+    public void testDeleteAndUpdate() {
+        // seeded through RandomUtils for repeatability (see setUp())
+        Random rand = RandomUtils.getRandom();
+        Multinomial<Integer> table = new Multinomial<Integer>();
+        assertEquals(0, table.getWeight(), 1e-9);
+
+        double total = 0;
+        double[] p = new double[10];
+        for (int i = 0; i < 10; i++) {
+            p[i] = rand.nextDouble();
+            table.add(i, p[i]);
+            total += p[i];
+            assertEquals(total, table.getWeight(), 1e-9);
+        }
+
+        assertEquals(total, table.getWeight(), 1e-9);
+
+        checkSelfConsistent(table);
+
+        double delta = p[7] + p[8];
+        table.delete(7);
+        p[7] = 0;
+
+        table.set(8, 0);
+        p[8] = 0;
+        total -= delta;
+
+        checkSelfConsistent(table);
+
+        assertEquals(total, table.getWeight(), 1e-9);
+        for (int i = 0; i < 10; i++) {
+            assertEquals(p[i], table.getWeight(i), 0);
+            assertEquals(p[i] / total, table.getProbability(i), 1e-10);
+        }
+
+        table.set(9, 5.1);
+        total -= p[9];
+        p[9] = 5.1;
+        total += 5.1;
+
+        assertEquals(total, table.getWeight(), 1e-9);
+        for (int i = 0; i < 10; i++) {
+            assertEquals(p[i], table.getWeight(i), 0);
+            assertEquals(p[i] / total, table.getProbability(i), 1e-10);
+        }
+
+        checkSelfConsistent(table);
+
+        for (int i = 0; i < 10; i++) {
+            assertEquals(p[i], table.getWeight(i), 0);
+        }
+    }
+
+    /**
+     * Walks the cumulative distribution implied by {@code table.getWeights()} and samples
+     * just below and just above each boundary that precedes a positive weight, plus once
+     * just below the final cumulative value. Asserts that the probabilities sum to 1 and
+     * that every key with positive weight was returned exactly twice.
+     *
+     * @param table the table to check; the hit counter is sized for the keys 0..9 used by
+     *              the callers in this class
+     */
+    private void checkSelfConsistent(Multinomial<Integer> table) {
+        final List<Double> weights = table.getWeights();
+
+        double totalWeight = table.getWeight();
+
+        double p = 0;
+        int[] k = new int[10];
+        for (double weight : weights) {
+            if (weight > 0) {
+                if (p > 0) {
+                    // just below this boundary: should still hit the previous item
+                    k[table.sample(p - 1e-9)]++;
+                }
+                // just above this boundary: should hit the current item
+                k[table.sample(p + 1e-9)]++;
+            }
+            p += weight / totalWeight;
+        }
+        k[table.sample(p - 1e-9)]++;
+        assertEquals(1, p, 1e-9);
+
+        for (int i = 0; i < 10; i++) {
+            if (table.getWeight(i) > 0) {
+                // expected value first, so a failure message reads correctly
+                assertEquals(2, k[i]);
+            }
+        }
+    }
+}

Added: mahout/trunk/math/src/test/java/org/apache/mahout/math/random/NormalTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/random/NormalTest.java?rev=1380428&view=auto
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/random/NormalTest.java (added)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/random/NormalTest.java Tue Sep  4 02:18:44 2012
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import org.apache.commons.math.MathException;
+import org.apache.commons.math.distribution.NormalDistribution;
+import org.apache.commons.math.distribution.NormalDistributionImpl;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.stats.OnlineSummarizer;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Arrays;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Sanity checks for the {@link Normal} sampler: sample mean and standard deviation for an
+ * offset, scaled distribution, and the sample median of the default distribution.
+ */
+public class NormalTest {
+    /** Number of draws per test; odd, so the sorted sample has a single middle element. */
+    private static final int SAMPLE_COUNT = 10001;
+
+    @Before
+    public void setUp() {
+        RandomUtils.useTestSeed();
+    }
+
+    /** Draws from Normal(2, 5) and checks the summarized mean and standard deviation. */
+    @Test
+    public void testOffset() {
+        OnlineSummarizer s = new OnlineSummarizer();
+        Sampler<Double> normal = new Normal(2, 5);
+        int drawn = 0;
+        while (drawn < SAMPLE_COUNT) {
+            s.add(normal.sample());
+            drawn++;
+        }
+
+        // the mean tolerance scales with the observed standard deviation
+        String message = String.format("m = %.3f, sd = %.3f", s.getMean(), s.getSD());
+        assertEquals(message, 2, s.getMean(), 0.04 * s.getSD());
+        assertEquals(5, s.getSD(), 0.12);
+    }
+
+    /** Draws from the default Normal and compares the sample median with the reference median. */
+    @Test
+    public void testSample() throws MathException {
+        double[] draws = new double[SAMPLE_COUNT];
+        Sampler<Double> normal = new Normal();
+        for (int idx = 0; idx < draws.length; idx++) {
+            draws[idx] = normal.sample();
+        }
+        Arrays.sort(draws);
+
+        NormalDistribution reference = new NormalDistributionImpl();
+        double expectedMedian = reference.inverseCumulativeProbability(0.5);
+
+        // the middle element of the sorted sample is the sample median
+        assertEquals("Median", expectedMedian, draws[SAMPLE_COUNT / 2], 0.04);
+    }
+}

Added: mahout/trunk/math/src/test/java/org/apache/mahout/math/random/PoissonSamplerTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/random/PoissonSamplerTest.java?rev=1380428&view=auto
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/random/PoissonSamplerTest.java (added)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/random/PoissonSamplerTest.java Tue Sep  4 02:18:44 2012
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.random;
+
+import org.apache.commons.math.distribution.PoissonDistribution;
+import org.apache.commons.math.distribution.PoissonDistributionImpl;
+import org.apache.mahout.common.RandomUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Compares the empirical histogram produced by {@link PoissonSampler} with the reference
+ * probabilities from commons-math for several rates.
+ */
+public class PoissonSamplerTest {
+    /** Number of draws used to build each histogram. */
+    private static final int SAMPLES = 10000;
+
+    @Before
+    public void setUp() {
+        RandomUtils.useTestSeed();
+    }
+
+    /** Checks the sampler at rates 0.1, 1, 10 and 100. */
+    @Test
+    public void testBasics() {
+        double[] rates = {0.1, 1, 10, 100};
+        for (double alpha : rates) {
+            PoissonSampler sampler = new PoissonSampler(alpha);
+            checkDistribution(sampler, alpha);
+        }
+    }
+
+    /**
+     * Builds a histogram of draws from {@code pd} and asserts that each bin's relative
+     * frequency is within 0.02 of the reference Poisson probability.
+     *
+     * @param pd    sampler under test
+     * @param alpha rate passed to the reference distribution; also sizes the histogram
+     */
+    private void checkDistribution(PoissonSampler pd, double alpha) {
+        // at least 10 bins, otherwise 5 * alpha bins
+        int bins = (int) Math.max(10, 5 * alpha);
+        int[] histogram = new int[bins];
+        for (int draw = 0; draw < SAMPLES; draw++) {
+            int value = pd.sample().intValue();
+            histogram[value]++;
+        }
+
+        PoissonDistribution ref = new PoissonDistributionImpl(alpha);
+        for (int bin = 0; bin < histogram.length; bin++) {
+            double observed = histogram[bin] / (double) SAMPLES;
+            assertEquals(ref.probability(bin), observed, 2e-2);
+        }
+    }
+}



Mime
View raw message