opennlp-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From co...@apache.org
Subject svn commit: r1490460 - in /opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools: ml/TrainerFactory.java ml/model/TrainUtil.java util/TrainingParameters.java
Date Thu, 06 Jun 2013 22:11:39 GMT
Author: colen
Date: Thu Jun  6 22:11:39 2013
New Revision: 1490460

URL: http://svn.apache.org/r1490460
Log:
OPENNLP-581 First proposal for a TrainerFactory. Again, I only changed TrainUtil to avoid
changing many classes. I am still not happy with the implementation, but would like feedback.

Added:
    opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java   (with props)
Modified:
    opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java
    opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/util/TrainingParameters.java

Added: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java?rev=1490460&view=auto
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java (added)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java Thu Jun  6 22:11:39 2013
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package opennlp.tools.ml;
+
+import java.lang.reflect.Constructor;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import opennlp.tools.ml.maxent.GIS;
+import opennlp.tools.ml.maxent.quasinewton.QNTrainer;
+import opennlp.tools.ml.perceptron.PerceptronTrainer;
+import opennlp.tools.ml.perceptron.SimplePerceptronSequenceTrainer;
+
+public class TrainerFactory {
+
+  // built-in trainers
+  private static final Map<String, Class> BUILTIN_TRAINERS;
+
+  static {
+    Map<String, Class> _trainers = new HashMap<String, Class>();
+    _trainers.put(GIS.MAXENT_VALUE, GIS.class);
+    _trainers.put(QNTrainer.MAXENT_QN_VALUE, QNTrainer.class);
+    _trainers.put(PerceptronTrainer.PERCEPTRON_VALUE, PerceptronTrainer.class);
+    _trainers.put(SimplePerceptronSequenceTrainer.PERCEPTRON_SEQUENCE_VALUE,
+        SimplePerceptronSequenceTrainer.class);
+
+    BUILTIN_TRAINERS = Collections.unmodifiableMap(_trainers);
+  }
+
+  public static boolean isSupportEvent(Map<String, String> trainParams) {
+    if (trainParams.get(AbstractTrainer.TRAINER_TYPE_PARAM) != null) {
+      if(EventTrainer.EVENT_VALUE.equals(trainParams
+            .get(AbstractTrainer.TRAINER_TYPE_PARAM))) {
+        return true;
+      }
+      return false;
+    } else {
+      return true; // default to event train
+    }
+  }
+
+  public static boolean isSupportSequence(Map<String, String> trainParams) {
+    if (SequenceTrainer.SEQUENCE_VALUE.equals(trainParams
+        .get(AbstractTrainer.TRAINER_TYPE_PARAM))) {
+      return true;
+    }
+    return false;
+  }
+
+  public static SequenceTrainer getSequenceTrainer(
+      Map<String, String> trainParams, Map<String, String> reportMap) {
+    String trainerType = getTrainerType(trainParams);
+    if (BUILTIN_TRAINERS.containsKey(trainerType)) {
+      return TrainerFactory.<SequenceTrainer> create(
+          BUILTIN_TRAINERS.get(trainerType), trainParams, reportMap);
+    } else {
+      return TrainerFactory.<SequenceTrainer> create(trainerType, trainParams,
+          reportMap);
+    }
+  }
+
+  public static EventTrainer getEventTrainer(Map<String, String> trainParams,
+      Map<String, String> reportMap) {
+    String trainerType = getTrainerType(trainParams);
+    if(trainerType == null) {
+      // default to MAXENT
+      return new GIS(trainParams, reportMap);
+    }
+    
+    if (BUILTIN_TRAINERS.containsKey(trainerType)) {
+      return TrainerFactory.<EventTrainer> create(
+          BUILTIN_TRAINERS.get(trainerType), trainParams, reportMap);
+    } else {
+      return TrainerFactory.<EventTrainer> create(trainerType, trainParams,
+          reportMap);
+    }
+  }
+
+  private static String getTrainerType(Map<String, String> trainParams) {
+    return trainParams.get(AbstractTrainer.ALGORITHM_PARAM);
+  }
+
+  private static <T> T create(String className,
+      Map<String, String> trainParams, Map<String, String> reportMap) {
+    T theFactory = null;
+
+    try {
+      // TODO: won't work in OSGi!
+      Class<T> trainerClass = (Class<T>) Class.forName(className);
+      theFactory = create(trainerClass, trainParams, reportMap);
+    } catch (Exception e) {
+      String msg = "Could not instantiate the " + className
+          + ". The initialization throw an exception.";
+      System.err.println(msg);
+      e.printStackTrace();
+      throw new IllegalArgumentException(msg, e);
+    }
+    return theFactory;
+  }
+
+  private static <T> T create(Class<T> trainerClass,
+      Map<String, String> trainParams, Map<String, String> reportMap) {
+    T theTrainer = null;
+    if (trainerClass != null) {
+      try {
+        Constructor<T> contructor = trainerClass.getConstructor(Map.class,
+            Map.class);
+        theTrainer = contructor.newInstance(trainParams, reportMap);
+      } catch (Exception e) {
+        String msg = "Could not instantiate the "
+            + trainerClass.getCanonicalName()
+            + ". The initialization throw an exception.";
+        System.err.println(msg);
+        e.printStackTrace();
+        throw new IllegalArgumentException(msg, e);
+      }
+    }
+    return theTrainer;
+  }
+}

Propchange: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/TrainerFactory.java
------------------------------------------------------------------------------
    svn:mime-type = text/plain

Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java?rev=1490460&r1=1490459&r2=1490460&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/ml/model/TrainUtil.java Thu Jun  6 22:11:39 2013
@@ -23,6 +23,8 @@ import java.io.IOException;
 import java.util.Map;
 
 import opennlp.tools.ml.EventTrainer;
+import opennlp.tools.ml.SequenceTrainer;
+import opennlp.tools.ml.TrainerFactory;
 import opennlp.tools.ml.maxent.GIS;
 import opennlp.tools.ml.maxent.quasinewton.QNTrainer;
 import opennlp.tools.ml.perceptron.PerceptronTrainer;
@@ -37,33 +39,14 @@ public class TrainUtil {
   public static final String PERCEPTRON_VALUE = "PERCEPTRON";
   public static final String PERCEPTRON_SEQUENCE_VALUE = "PERCEPTRON_SEQUENCE";
   
-  
   public static final String CUTOFF_PARAM = "Cutoff";
-  private static final int CUTOFF_DEFAULT = 5;
   
   public static final String ITERATIONS_PARAM = "Iterations";
-  private static final int ITERATIONS_DEFAULT = 100;
   
   public static final String DATA_INDEXER_PARAM = "DataIndexer";
   public static final String DATA_INDEXER_ONE_PASS_VALUE = "OnePass";
   public static final String DATA_INDEXER_TWO_PASS_VALUE = "TwoPass";
   
-  
-  private static String getStringParam(Map<String, String> trainParams, String key,
-      String defaultValue, Map<String, String> reportMap) {
-
-    String valueString = trainParams.get(key);
-
-    if (valueString == null)
-      valueString = defaultValue;
-    
-    if (reportMap != null)
-      reportMap.put(key, valueString);
-    
-    return valueString;
-  }
-  
-  
   public static boolean isValid(Map<String, String> trainParams) {
 
     // TODO: Need to validate all parameters correctly ... error prone?!
@@ -108,30 +91,10 @@ public class TrainUtil {
   public static AbstractModel train(EventStream events, Map<String, String> trainParams,
Map<String, String> reportMap) 
       throws IOException {
     
-    if (!isValid(trainParams))
-        throw new IllegalArgumentException("trainParams are not valid!");
-    
-    if(isSequenceTraining(trainParams))
-      throw new IllegalArgumentException("sequence training is not supported by this method!");
-    
-    String algorithmName = getStringParam(trainParams, ALGORITHM_PARAM, MAXENT_VALUE, reportMap);
-    
-    EventTrainer trainer;
-    if(PERCEPTRON_VALUE.equals(algorithmName)) {
-      
-      trainer = new PerceptronTrainer(trainParams, reportMap);
-      
-    } else if(MAXENT_VALUE.equals(algorithmName)) {
-      
-      trainer = new GIS(trainParams, reportMap);
-      
-    } else if(MAXENT_QN_VALUE.equals(algorithmName)) {
-      
-      trainer = new QNTrainer(trainParams, reportMap);
-    
-    } else {
-      trainer = new GIS(trainParams, reportMap); // default to maxent?
+    if(!TrainerFactory.isSupportEvent(trainParams)) {
+      throw new IllegalArgumentException("EventTrain is not supported");
     }
+    EventTrainer trainer = TrainerFactory.getEventTrainer(trainParams, reportMap);
     
     return trainer.train(events);
   }
@@ -147,8 +110,11 @@ public class TrainUtil {
   public static AbstractModel train(SequenceStream events, Map<String, String> trainParams,
       Map<String, String> reportMap) throws IOException {
     
-    SimplePerceptronSequenceTrainer trainer = new SimplePerceptronSequenceTrainer(
-        trainParams, reportMap);
+    if(!TrainerFactory.isSupportSequence(trainParams)) {
+      throw new IllegalArgumentException("EventTrain is not supported");
+    }
+    SequenceTrainer trainer = TrainerFactory.getSequenceTrainer(trainParams, reportMap);
+    
     return trainer.train(events);
   }
 }

Modified: opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/util/TrainingParameters.java
URL: http://svn.apache.org/viewvc/opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/util/TrainingParameters.java?rev=1490460&r1=1490459&r2=1490460&view=diff
==============================================================================
--- opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/util/TrainingParameters.java (original)
+++ opennlp/trunk/opennlp-tools/src/main/java/opennlp/tools/util/TrainingParameters.java Thu Jun  6 22:11:39 2013
@@ -25,9 +25,13 @@ import java.util.HashMap;
 import java.util.Map;
 import java.util.Properties;
 
+import opennlp.tools.ml.EventTrainer;
+
 public class TrainingParameters {
   
+  // TODO: are them duplicated?
   public static final String ALGORITHM_PARAM = "Algorithm";
+  public static final String TRAINER_TYPE_PARAM = "TrainerType";
   
   public static final String ITERATIONS_PARAM = "Iterations";
   public static final String CUTOFF_PARAM = "Cutoff";
@@ -144,6 +148,7 @@ public class TrainingParameters {
   public static final TrainingParameters defaultParams() {
     TrainingParameters mlParams = new TrainingParameters();
     mlParams.put(TrainingParameters.ALGORITHM_PARAM, "MAXENT");
+    mlParams.put(TrainingParameters.TRAINER_TYPE_PARAM, EventTrainer.EVENT_VALUE);
     mlParams.put(TrainingParameters.ITERATIONS_PARAM, Integer.toString(100));
     mlParams.put(TrainingParameters.CUTOFF_PARAM, Integer.toString(5));
 



Mime
View raw message