mahout-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From isa...@apache.org
Subject svn commit: r1000807 [2/2] - in /mahout/trunk: core/ core/src/main/java/org/apache/mahout/classifier/sequencelearning/ core/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/ core/src/test/java/org/apache/mahout/classifier/sequencelearnin...
Date Fri, 24 Sep 2010 11:17:14 GMT
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java?rev=1000807&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java
(added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java
Fri Sep 24 11:17:13 2010
@@ -0,0 +1,44 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import junit.framework.Assert;
+
+import org.junit.Test;
+
+/**
+ * Tests for {@link HmmModel}: validity of a model built from state counts only
+ * and stability of its JSON representation.
+ */
+public class HMMModelTest extends HMMTestBase {
+
+  /**
+   * A model constructed from two counts only (described by the original author
+   * as a randomly generated model) must pass {@link HmmUtils#validate}.
+   */
+  @Test
+  public void testRandomModelGeneration() {
+    HmmUtils.validate(new HmmModel(10, 20));
+  }
+
+  /**
+   * Serializing the fixture model, parsing the result and serializing again
+   * must reproduce the first JSON string.
+   */
+  @Test
+  public void testSerialization() {
+    String firstPass = model.toJson();
+    String secondPass = HmmModel.fromJson(firstPass).toJson();
+    // the underlying objects offer no equals methods, so the two models are
+    // compared through their serialized form
+    Assert.assertEquals(firstPass, secondPass);
+  }
+
+}

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java?rev=1000807&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java
(added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java
Fri Sep 24 11:17:13 2010
@@ -0,0 +1,49 @@
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+
+/**
+ * Common fixture for the HMM tests: a small hand-specified model plus an
+ * observation sequence, rebuilt before every test.
+ */
+public class HMMTestBase extends MahoutTestCase {
+
+  /** Fixture model, created and validated in {@link #setUp()}. */
+  protected HmmModel model;
+
+  /**
+   * Observation sequence "O1" "O0" "O2" "O2" "O0" "O0" "O1", written as output
+   * state indices.
+   */
+  protected int[] sequence = {1, 0, 2, 2, 0, 0, 1};
+
+  /**
+   * Creates the fixture model from the following parameters.
+   *
+   * <pre>
+   * hidden states: 4 ("H0", "H1", "H2", "H3")
+   * output states: 3 ("O0", "O1", "O2")
+   *
+   * transition matrix (row = from, column = to):
+   *         H0   H1   H2   H3
+   *    H0  0.5  0.1  0.1  0.3
+   *    H1  0.4  0.4  0.1  0.1
+   *    H2  0.1  0.0  0.8  0.1
+   *    H3  0.1  0.1  0.1  0.7
+   *
+   * output matrix (row = from hidden state, column = to output state):
+   *         O0   O1   O2
+   *    H0  0.8  0.1  0.1
+   *    H1  0.6  0.1  0.3
+   *    H2  0.1  0.8  0.1
+   *    H3  0.0  0.1  0.9
+   *
+   * initial probabilities:
+   *    H0  0.2
+   *    H1  0.1
+   *    H2  0.4
+   *    H3  0.3
+   * </pre>
+   */
+  @Override
+  public void setUp() throws Exception {
+    super.setUp();
+    // model parameters, see the tables in the method comment
+    double[][] transitions = {
+        {0.5, 0.1, 0.1, 0.3},
+        {0.4, 0.4, 0.1, 0.1},
+        {0.1, 0.0, 0.8, 0.1},
+        {0.1, 0.1, 0.1, 0.7}};
+    double[][] emissions = {
+        {0.8, 0.1, 0.1},
+        {0.6, 0.1, 0.3},
+        {0.1, 0.8, 0.1},
+        {0.0, 0.1, 0.9}};
+    double[] initial = {0.2, 0.1, 0.4, 0.3};
+    // assemble the model and attach the state names
+    model = new HmmModel(new DenseMatrix(transitions),
+        new DenseMatrix(emissions), new DenseVector(initial));
+    model.registerHiddenStateNames(new String[] {"H0", "H1", "H2", "H3"});
+    model.registerOutputStateNames(new String[] {"O0", "O1", "O2"});
+    // the fixture itself has to be a valid model
+    HmmUtils.validate(model);
+  }
+
+}

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java?rev=1000807&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java
(added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java
Fri Sep 24 11:17:13 2010
@@ -0,0 +1,160 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import junit.framework.Assert;
+
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+public class HMMTrainerTest extends HMMTestBase {
+
+  @Test
+  public void testViterbiTraining() {
+    // initialize the expected model parameters (from R)
+    // expected transition matrix
+    double transitionE[][] = {{0.3125, 0.0625, 0.3125, 0.3125},
+        {0.25, 0.25, 0.25, 0.25}, {0.5, 0.071429, 0.357143, 0.071429},
+        {0.5, 0.1, 0.1, 0.3}};
+    // initialize the emission matrix
+    double emissionE[][] = {{0.882353, 0.058824, 0.058824},
+        {0.333333, 0.333333, 0.3333333}, {0.076923, 0.846154, 0.076923},
+        {0.111111, 0.111111, 0.777778}};
+
+    // train the given network to the following output sequence
+    int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0};
+
+    HmmModel trained = HmmTrainer.trainViterbi(model, observed, 0.5, 0.1, 10,
+        false);
+
+    // now check whether the model matches our expectations
+    Matrix emissionMatrix = trained.getEmissionMatrix();
+    Matrix transitionMatrix = trained.getTransitionMatrix();
+
+    for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) {
+      for (int j = 0; j < trained.getNrOfHiddenStates(); ++j)
+        Assert.assertEquals(transitionMatrix.getQuick(i, j), transitionE[i][j],
+            0.00001);
+
+      for (int j = 0; j < trained.getNrOfOutputStates(); ++j)
+        Assert.assertEquals(emissionMatrix.getQuick(i, j), emissionE[i][j],
+            0.00001);
+    }
+
+  }
+
+  @Test
+  public void testScaledViterbiTraining() {
+    // initialize the expected model parameters (from R)
+    // expected transition matrix
+    double transitionE[][] = {{0.3125, 0.0625, 0.3125, 0.3125},
+        {0.25, 0.25, 0.25, 0.25}, {0.5, 0.071429, 0.357143, 0.071429},
+        {0.5, 0.1, 0.1, 0.3}};
+    // initialize the emission matrix
+    double emissionE[][] = {{0.882353, 0.058824, 0.058824},
+        {0.333333, 0.333333, 0.3333333}, {0.076923, 0.846154, 0.076923},
+        {0.111111, 0.111111, 0.777778}};
+
+    // train the given network to the following output sequence
+    int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0};
+
+    HmmModel trained = HmmTrainer.trainViterbi(model, observed, 0.5, 0.1, 10,
+        true);
+
+    // now check whether the model matches our expectations
+    Matrix emissionMatrix = trained.getEmissionMatrix();
+    Matrix transitionMatrix = trained.getTransitionMatrix();
+
+    for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) {
+      for (int j = 0; j < trained.getNrOfHiddenStates(); ++j)
+        Assert.assertEquals(transitionMatrix.getQuick(i, j), transitionE[i][j],
+            0.00001);
+
+      for (int j = 0; j < trained.getNrOfOutputStates(); ++j)
+        Assert.assertEquals(emissionMatrix.getQuick(i, j), emissionE[i][j],
+            0.00001);
+    }
+
+  }
+
+  @Test
+  public void testBaumWelchTraining() {
+    // train the given network to the following output sequence
+    int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0};
+
+    // expected values from Matlab HMM package / R HMM package
+    double[] initialExpected = {0, 0, 1.0, 0};
+    double[][] transitionExpected = {{0.2319, 0.0993, 0.0005, 0.6683},
+        {0.0001, 0.3345, 0.6654, 0}, {0.5975, 0, 0.4025, 0},
+        {0.0024, 0.6657, 0, 0.3319}};
+    double[][] emissionExpected = {{0.9995, 0.0004, 0.0001},
+        {0.9943, 0.0036, 0.0021}, {0.0059, 0.9941, 0}, {0, 0, 1}};
+
+    HmmModel trained = HmmTrainer.trainBaumWelch(model, observed, 0.1, 10,
+        false);
+
+    Vector initialProbabilities = trained.getInitialProbabilities();
+    Matrix emissionMatrix = trained.getEmissionMatrix();
+    Matrix transitionMatrix = trained.getTransitionMatrix();
+
+    for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) {
+      Assert.assertEquals(initialProbabilities.get(i), initialExpected[i],
+          0.0001);
+      for (int j = 0; j < trained.getNrOfHiddenStates(); ++j)
+        Assert.assertEquals(transitionMatrix.getQuick(i, j),
+            transitionExpected[i][j], 0.0001);
+      for (int j = 0; j < trained.getNrOfOutputStates(); ++j)
+        Assert.assertEquals(emissionMatrix.getQuick(i, j),
+            emissionExpected[i][j], 0.0001);
+    }
+  }
+
+  @Test
+  public void testScaledBaumWelchTraining() {
+    // train the given network to the following output sequence
+    int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0};
+
+    // expected values from Matlab HMM package / R HMM package
+    double[] initialExpected = {0, 0, 1.0, 0};
+    double[][] transitionExpected = {{0.2319, 0.0993, 0.0005, 0.6683},
+        {0.0001, 0.3345, 0.6654, 0}, {0.5975, 0, 0.4025, 0},
+        {0.0024, 0.6657, 0, 0.3319}};
+    double[][] emissionExpected = {{0.9995, 0.0004, 0.0001},
+        {0.9943, 0.0036, 0.0021}, {0.0059, 0.9941, 0}, {0, 0, 1}};
+
+    HmmModel trained = HmmTrainer
+        .trainBaumWelch(model, observed, 0.1, 10, true);
+
+    Vector initialProbabilities = trained.getInitialProbabilities();
+    Matrix emissionMatrix = trained.getEmissionMatrix();
+    Matrix transitionMatrix = trained.getTransitionMatrix();
+
+    for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) {
+      Assert.assertEquals(initialProbabilities.get(i), initialExpected[i],
+          0.0001);
+      for (int j = 0; j < trained.getNrOfHiddenStates(); ++j)
+        Assert.assertEquals(transitionMatrix.getQuick(i, j),
+            transitionExpected[i][j], 0.0001);
+      for (int j = 0; j < trained.getNrOfOutputStates(); ++j)
+        Assert.assertEquals(emissionMatrix.getQuick(i, j),
+            emissionExpected[i][j], 0.0001);
+    }
+  }
+
+}

Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java?rev=1000807&view=auto
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java
(added)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java
Fri Sep 24 11:17:13 2010
@@ -0,0 +1,160 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.util.Arrays;
+
+import junit.framework.Assert;
+
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+/**
+ * Tests for {@link HmmUtils}: model validation, encoding/decoding of state
+ * sequences, normalization and truncation.
+ */
+public class HMMUtilsTest extends HMMTestBase {
+
+  // matrices/vectors whose rows sum to one, named after their dimensions
+  Matrix legal2_2;
+  Matrix legal2_3;
+  Matrix legal3_3;
+  Vector legal2;
+  // 2x2 matrix whose rows do not sum to one
+  Matrix illegal2_2;
+
+  /**
+   * Builds the fixture model of the base class and the small matrices used by
+   * the validator tests.
+   */
+  @Override
+  public void setUp() throws Exception {
+    super.setUp();
+    legal2_2 = new DenseMatrix(new double[][]{{0.5, 0.5}, {0.3, 0.7}});
+    legal2_3 = new DenseMatrix(new double[][]{{0.2, 0.2, 0.6},
+        {0.3, 0.3, 0.4}});
+    legal3_3 = new DenseMatrix(new double[][]{{0.1, 0.1, 0.8},
+        {0.1, 0.2, 0.7}, {0.2, 0.3, 0.5}});
+    legal2 = new DenseVector(new double[]{0.4, 0.6});
+    illegal2_2 = new DenseMatrix(new double[][]{{1, 2}, {3, 4}});
+  }
+
+  /** A model with matching dimensions and row sums of one validates. */
+  @Test
+  public void testValidatorLegal() {
+    HmmUtils.validate(new HmmModel(legal2_2, legal2_3, legal2));
+  }
+
+  /**
+   * A 3x3 transition matrix combined with a 2x3 emission matrix and a
+   * 2-element initial vector must be rejected with an
+   * {@link IllegalArgumentException}.
+   */
+  @Test
+  public void testValidatorDimensionError() {
+    try {
+      HmmUtils.validate(new HmmModel(legal3_3, legal2_3, legal2));
+    } catch (IllegalArgumentException e) {
+      // success
+      return;
+    }
+    Assert.fail();
+  }
+
+  /**
+   * A transition matrix whose rows do not sum to one must be rejected with an
+   * {@link IllegalArgumentException}.
+   */
+  @Test
+  public void testValidatorIllegelMatrixError() {
+    try {
+      HmmUtils.validate(new HmmModel(illegal2_2, legal2_3, legal2));
+    } catch (IllegalArgumentException e) {
+      // success
+      return;
+    }
+    Assert.fail();
+  }
+
+  /**
+   * Names registered in the fixture model are encoded as their index; names
+   * that are not registered ("H4", "O4") are encoded as the supplied default
+   * of -1.
+   */
+  @Test
+  public void testEncodeStateSequence() {
+    String[] hiddenSequence = {"H1", "H2", "H0", "H3", "H4"};
+    String[] outputSequence = {"O1", "O2", "O4", "O0"};
+    // encode the hidden and the output sequence
+    int[] hiddenSequenceEnc = HmmUtils.encodeStateSequence(model, Arrays
+        .asList(hiddenSequence), false, -1);
+    int[] outputSequenceEnc = HmmUtils.encodeStateSequence(model, Arrays
+        .asList(outputSequence), true, -1);
+    // expected state sequences
+    int[] hiddenSequenceExp = {1, 2, 0, 3, -1};
+    int[] outputSequenceExp = {1, 2, -1, 0};
+    // compare lengths first, so a truncated result cannot pass unnoticed
+    Assert.assertEquals(hiddenSequenceExp.length, hiddenSequenceEnc.length);
+    Assert.assertEquals(outputSequenceExp.length, outputSequenceEnc.length);
+    for (int i = 0; i < hiddenSequenceEnc.length; ++i) {
+      Assert.assertEquals(hiddenSequenceExp[i], hiddenSequenceEnc[i]);
+    }
+    for (int i = 0; i < outputSequenceEnc.length; ++i) {
+      Assert.assertEquals(outputSequenceExp[i], outputSequenceEnc[i]);
+    }
+  }
+
+  /**
+   * Indices known to the fixture model are decoded to their registered name;
+   * the out-of-range index 10 is decoded to the supplied default "unknown".
+   */
+  @Test
+  public void testDecodeStateSequence() {
+    int[] hiddenSequence = {1, 2, 0, 3, 10};
+    int[] outputSequence = {1, 2, 10, 0};
+    // decode the hidden and the output sequence
+    java.util.Vector<String> hiddenSequenceDec = HmmUtils.decodeStateSequence(
+        model, hiddenSequence, false, "unknown");
+    java.util.Vector<String> outputSequenceDec = HmmUtils.decodeStateSequence(
+        model, outputSequence, true, "unknown");
+    // expected state sequences
+    String[] hiddenSequenceExp = {"H1", "H2", "H0", "H3", "unknown"};
+    String[] outputSequenceExp = {"O1", "O2", "unknown", "O0"};
+    // compare lengths first, so surplus decoded entries cannot pass unnoticed
+    Assert.assertEquals(hiddenSequenceExp.length, hiddenSequenceDec.size());
+    Assert.assertEquals(outputSequenceExp.length, outputSequenceDec.size());
+    for (int i = 0; i < hiddenSequenceExp.length; ++i) {
+      Assert.assertEquals(hiddenSequenceExp[i], hiddenSequenceDec.get(i));
+    }
+    for (int i = 0; i < outputSequenceExp.length; ++i) {
+      Assert.assertEquals(outputSequenceExp[i], outputSequenceDec.get(i));
+    }
+  }
+
+  /**
+   * A model built from unnormalized counts must validate after
+   * {@link HmmUtils#normalizeModel}.
+   */
+  @Test
+  public void testNormalizeModel() {
+    DenseVector ip = new DenseVector(new double[]{10, 20});
+    DenseMatrix tr = new DenseMatrix(new double[][]{{10, 10}, {20, 25}});
+    DenseMatrix em = new DenseMatrix(new double[][]{{5, 7}, {10, 15}});
+    // named differently from the inherited fixture field to avoid shadowing it
+    HmmModel unnormalized = new HmmModel(tr, em, ip);
+    HmmUtils.normalizeModel(unnormalized);
+    // the model should be valid now
+    HmmUtils.validate(unnormalized);
+  }
+
+  /**
+   * Truncating a nearly deterministic 3-state model at a threshold of 0.01
+   * must yield a valid model whose only non-zero entries are exactly 1.0.
+   */
+  @Test
+  public void testTruncateModel() {
+    DenseVector ip = new DenseVector(new double[]{0.0001, 0.0001, 0.9998});
+    DenseMatrix tr = new DenseMatrix(new double[][]{
+        {0.9998, 0.0001, 0.0001}, {0.0001, 0.9998, 0.0001},
+        {0.0001, 0.0001, 0.9998}});
+    DenseMatrix em = new DenseMatrix(new double[][]{
+        {0.9998, 0.0001, 0.0001}, {0.0001, 0.9998, 0.0001},
+        {0.0001, 0.0001, 0.9998}});
+    // named differently from the inherited fixture field to avoid shadowing it
+    HmmModel denseModel = new HmmModel(tr, em, ip);
+    // now truncate the model
+    HmmModel sparseModel = HmmUtils.truncateModel(denseModel, 0.01);
+    // first make sure this is a valid model
+    HmmUtils.validate(sparseModel);
+    // now check whether the values are as expected
+    Vector sparseIp = sparseModel.getInitialProbabilities();
+    Matrix sparseTr = sparseModel.getTransitionMatrix();
+    Matrix sparseEm = sparseModel.getEmissionMatrix();
+    for (int i = 0; i < sparseModel.getNrOfHiddenStates(); ++i) {
+      // only the third state keeps initial probability mass
+      if (i == 2) {
+        Assert.assertEquals(1.0, sparseIp.getQuick(i));
+      } else {
+        Assert.assertEquals(0.0, sparseIp.getQuick(i));
+      }
+      // the emission matrix is 3x3 here as well, so one loop covers both
+      for (int j = 0; j < sparseModel.getNrOfHiddenStates(); ++j) {
+        if (i == j) {
+          Assert.assertEquals(1.0, sparseTr.getQuick(i, j));
+          Assert.assertEquals(1.0, sparseEm.getQuick(i, j));
+        } else {
+          Assert.assertEquals(0.0, sparseTr.getQuick(i, j));
+          Assert.assertEquals(0.0, sparseEm.getQuick(i, j));
+        }
+      }
+    }
+  }
+
+}

Added: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java?rev=1000807&view=auto
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java
(added)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sequencelearning/hmm/PosTagger.java
Fri Sep 24 11:17:13 2010
@@ -0,0 +1,278 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sequencelearning.hmm;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.net.URL;
+import java.net.URLConnection;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.mahout.math.Matrix;
+
+/**
+ * Sample program that uses a pre-tagged training data set to train an HMM
+ * model as a part-of-speech (POS) tagger.
+ * <p>
+ * The training data is automatically downloaded from
+ * http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/train.txt and used to
+ * train an HMM model with supervised learning. The model is then tested on the
+ * data set at
+ * http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/test.txt
+ * <p>
+ * Further details regarding the data files can be found at
+ * http://flexcrfs.sourceforge.net/#Case_Study
+ *
+ * @author mheimel
+ */
+public final class PosTagger {
+
+  /**
+   * Logger for this class.
+   */
+  private static final Log LOG = LogFactory.getLog(PosTagger.class);
+
+  /**
+   * Model trained in the example.
+   */
+  private static HmmModel taggingModel;
+
+  /**
+   * Map for storing the IDs for the POS tags (hidden states)
+   */
+  private static Map<String, Integer> tagIDs;
+
+  /**
+   * Counter for the next assigned POS tag ID. Tag IDs start at 0, so unlike
+   * for words no ID is set aside for unknown tags: a tag that is missing from
+   * {@link #tagIDs} is encoded as 0, which is also the ID of the first tag
+   * seen in the training data.
+   */
+  private static int nextTagId;
+
+  /**
+   * Map for storing the IDs for observed words (observed states)
+   */
+  private static Map<String, Integer> wordIDs;
+
+  /**
+   * Counter for the next assigned word ID The value of 0 is reserved for
+   * "unknown word"
+   */
+  private static int nextWordId = 1; // 0 is reserved for "unknown word"
+
+  /**
+   * Used for storing a list of POS tags of read sentences.
+   */
+  private static List<int[]> hiddenSequences;
+
+  /**
+   * Used for storing a list of word tags of read sentences.
+   */
+  private static List<int[]> observedSequences;
+
+  /**
+   * number of read lines
+   */
+  private static int readLines;
+
+  /**
+   * No public constructors for utility classes.
+   */
+  private PosTagger() {
+    // nothing to do here really.
+  }
+
+  /**
+   * Given an URL, this function fetches the data file, parses it, assigns POS
+   * Tag/word IDs and fills the hiddenSequences/observedSequences lists with
+   * data from those files. The data is expected to contain one word per line
+   * in the format {@code word pos-tag np-tag}; an empty line ends the current
+   * sentence. The reader is closed before the method returns, also when
+   * reading fails.
+   *
+   * @param url       Where the data file is stored
+   * @param assignIDs Should IDs for unknown words/tags be assigned? (Needed for
+   *                  training data, not needed for test data)
+   * @throws IOException in case data file cannot be read.
+   */
+  private static void readFromURL(String url, boolean assignIDs) throws IOException {
+    URLConnection connection = (new URL(url)).openConnection();
+    BufferedReader input = new BufferedReader(new InputStreamReader(connection.getInputStream()));
+    try {
+      // initialize the data structure
+      hiddenSequences = new LinkedList<int[]>();
+      observedSequences = new LinkedList<int[]>();
+      readLines = 0;
+
+      // now read line by line of the input file
+      String line;
+      List<Integer> observedSequence = new LinkedList<Integer>();
+      List<Integer> hiddenSequence = new LinkedList<Integer>();
+      while ((line = input.readLine()) != null) {
+        if (line.isEmpty()) {
+          // the current sentence is complete, a new one starts
+          registerSentence(observedSequence, hiddenSequence);
+          continue;
+        }
+        readLines++;
+        // we expect the format [word] [POS tag] [NP tag]
+        String[] tags = line.split(" ");
+        // when analyzing the training set, assign IDs
+        if (assignIDs) {
+          if (!wordIDs.containsKey(tags[0])) {
+            wordIDs.put(tags[0], nextWordId++);
+          }
+          if (!tagIDs.containsKey(tags[1])) {
+            tagIDs.put(tags[1], nextTagId++);
+          }
+        }
+        // determine the IDs
+        Integer wordID = wordIDs.get(tags[0]);
+        Integer tagID = tagIDs.get(tags[1]);
+        // handle unknown values
+        wordID = (wordID == null) ? 0 : wordID;
+        tagID = (tagID == null) ? 0 : tagID;
+        // now construct the current sequence
+        observedSequence.add(wordID);
+        hiddenSequence.add(tagID);
+      }
+      // if there is still something in the pipe, register it
+      if (!observedSequence.isEmpty()) {
+        registerSentence(observedSequence, hiddenSequence);
+      }
+    } finally {
+      // release the connection's stream even if parsing failed
+      input.close();
+    }
+  }
+
+  /**
+   * Appends the given sentence to hiddenSequences/observedSequences as int
+   * arrays and clears both lists so that they can collect the next sentence.
+   */
+  private static void registerSentence(List<Integer> observedSequence, List<Integer> hiddenSequence) {
+    hiddenSequences.add(toIntArray(hiddenSequence));
+    observedSequences.add(toIntArray(observedSequence));
+    observedSequence.clear();
+    hiddenSequence.clear();
+  }
+
+  /**
+   * Copies the list into a new int array; iterates instead of indexing, since
+   * the callers pass linked lists.
+   */
+  private static int[] toIntArray(List<Integer> values) {
+    int[] result = new int[values.size()];
+    int index = 0;
+    for (Integer value : values) {
+      result[index++] = value;
+    }
+    return result;
+  }
+
+  /**
+   * Returns the time elapsed since {@code startMillis} in seconds.
+   */
+  private static double secondsSince(long startMillis) {
+    return (System.currentTimeMillis() - startMillis) / (double) 1000;
+  }
+
+  /**
+   * Reads the training data from the given URL, trains {@link #taggingModel}
+   * on it with supervised learning, raises the emission probability of the
+   * "unknown word" for the NNP tag and registers the tag/word names with the
+   * model.
+   *
+   * @param trainingURL Where the training data file is stored
+   * @throws IOException in case the training data cannot be read.
+   * @throws IllegalStateException if the training data contains no NNP tag.
+   */
+  private static void trainModel(String trainingURL) throws IOException {
+    tagIDs = new HashMap<String, Integer>(44); // we expect 44 distinct tags
+    wordIDs = new HashMap<String, Integer>(19122); // we expect 19122
+    // distinct words
+    // the ID maps start out empty, so the counters have to start over too
+    nextTagId = 0;
+    nextWordId = 1; // 0 is reserved for "unknown word"
+    LOG.info("Reading and parsing training data file from URL: " + trainingURL);
+    long start = System.currentTimeMillis();
+    readFromURL(trainingURL, true);
+    LOG.info("Parsing done in " + secondsSince(start) + " seconds!");
+    // report the map sizes: tag IDs start at 0 and word IDs at 1, so the
+    // counters themselves are not both off by one from the number of entries
+    LOG.info("Read " + readLines + " lines containing "
+        + hiddenSequences.size() + " sentences with a total of "
+        + wordIDs.size() + " distinct words and " + tagIDs.size()
+        + " distinct POS tags.");
+    start = System.currentTimeMillis();
+    taggingModel = HmmTrainer.trainSupervisedSequence(nextTagId, nextWordId,
+        hiddenSequences, observedSequences, 0.05);
+    // we have to adjust the model a bit,
+    // since we assume a higher probability that a given unknown word is NNP
+    // than anything else
+    Integer nnpTag = tagIDs.get("NNP");
+    if (nnpTag == null) {
+      throw new IllegalStateException("Training data from " + trainingURL
+          + " does not contain the POS tag NNP");
+    }
+    Matrix emissions = taggingModel.getEmissionMatrix();
+    int hiddenStates = taggingModel.getNrOfHiddenStates();
+    // column 0 is the "unknown word": small weight for every tag ...
+    for (int i = 0; i < hiddenStates; ++i) {
+      emissions.setQuick(i, 0, 0.1 / (double) hiddenStates);
+    }
+    // ... and a ten times larger weight for NNP
+    emissions.setQuick(nnpTag, 0, 1 / (double) hiddenStates);
+    // re-normalize the emission probabilities
+    HmmUtils.normalizeModel(taggingModel);
+    // now register the names
+    taggingModel.registerHiddenStateNames(tagIDs);
+    taggingModel.registerOutputStateNames(wordIDs);
+    LOG.info("Trained HMM model in " + secondsSince(start) + " seconds!");
+  }
+
+  /**
+   * Reads the test data from the given URL, tags every sentence with
+   * {@link #taggingModel} and logs the fraction of words whose estimated tag
+   * differs from the tag in the data file.
+   *
+   * @param testingURL Where the test data file is stored
+   * @throws IOException in case the test data cannot be read.
+   */
+  private static void testModel(String testingURL) throws IOException {
+    LOG.info("Reading and parsing test data file from URL:" + testingURL);
+    long start = System.currentTimeMillis();
+    readFromURL(testingURL, false);
+    LOG.info("Parsing done in " + secondsSince(start) + " seconds!");
+    LOG.info("Read " + readLines + " lines containing "
+        + hiddenSequences.size() + " sentences.");
+
+    start = System.currentTimeMillis();
+    int errorCount = 0;
+    int totalCount = 0;
+    for (int i = 0; i < observedSequences.size(); ++i) {
+      // fetch the viterbi path as the POS tag for this observed sequence
+      int[] posEstimate = HmmEvaluator.decode(taggingModel, observedSequences
+          .get(i), false);
+      // compare with the expected
+      int[] posExpected = hiddenSequences.get(i);
+      for (int j = 0; j < posExpected.length; ++j) {
+        totalCount++;
+        if (posEstimate[j] != posExpected[j]) {
+          errorCount++;
+        }
+      }
+    }
+    LOG.info("POS tagged test file in " + secondsSince(start) + " seconds!");
+    double errorRate = (double) errorCount / (double) totalCount;
+    LOG.info("Tagged the test file with an error rate of: " + errorRate);
+  }
+
+  /**
+   * Tokenizes the sentence, encodes the tokens as word IDs (0 for unknown
+   * words), decodes the tag sequence with {@link #taggingModel} and returns
+   * the tag names.
+   *
+   * @param sentence The sentence to tag
+   * @return the tag names, one per token
+   */
+  private static java.util.Vector<String> tagSentence(String sentence) {
+    // first, we need to isolate all punctuation characters, so that they
+    // can be recognized
+    sentence = sentence.replaceAll("[,.!?:;\"]", " $0 ");
+    sentence = sentence.replaceAll("''", " '' ");
+    // now we tokenize the sentence
+    String[] tokens = sentence.split("[ ]+");
+    // now generate the observed sequence
+    int[] observedSequence = HmmUtils.encodeStateSequence(taggingModel, Arrays
+        .asList(tokens), true, 0);
+    // POS tag this observedSequence
+    int[] hiddenSequence = HmmEvaluator.decode(taggingModel, observedSequence,
+        false);
+    // and now decode the tag names
+    return HmmUtils.decodeStateSequence(taggingModel, hiddenSequence, false,
+        null);
+  }
+
+  /**
+   * Trains the tagger, evaluates it on the test data and logs the tags found
+   * for an exemplary sentence.
+   *
+   * @param args ignored
+   * @throws IOException in case one of the data files cannot be read.
+   */
+  public static void main(String[] args) throws IOException {
+    // generate the model from URL
+    trainModel("http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/train.txt");
+    testModel("http://www.jaist.ac.jp/~hieuxuan/flexcrfs/CoNLL2000-NP/test.txt");
+    // tag an exemplary sentence
+    String test = "McDonalds is a huge company with many employees .";
+    String[] testWords = test.split(" ");
+    java.util.Vector<String> posTags;
+    posTags = tagSentence(test);
+    for (int i = 0; i < posTags.size(); ++i) {
+      LOG.info(testWords[i] + "[" + posTags.get(i) + "]");
+    }
+  }
+
+}



Mime
View raw message