incubator-ctakes-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From stevenbeth...@apache.org
Subject svn commit: r1414692 - in /incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor: eval/RelationExtractorEvaluation.java pipelines/RelationExtractorTrain.java
Date Wed, 28 Nov 2012 14:03:48 GMT
Author: stevenbethard
Date: Wed Nov 28 14:03:47 2012
New Revision: 1414692

URL: http://svn.apache.org/viewvc?rev=1414692&view=rev
Log:
Removes irrelevant relations from the relation extraction evaluation (e.g., degree_of relations are no longer included when evaluating location_of)

Modified:
    incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/eval/RelationExtractorEvaluation.java
    incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/pipelines/RelationExtractorTrain.java

Modified: incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/eval/RelationExtractorEvaluation.java
URL: http://svn.apache.org/viewvc/incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/eval/RelationExtractorEvaluation.java?rev=1414692&r1=1414691&r2=1414692&view=diff
==============================================================================
--- incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/eval/RelationExtractorEvaluation.java	(original)
+++ incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/eval/RelationExtractorEvaluation.java	Wed Nov 28 14:03:47 2012
@@ -51,6 +51,7 @@ import org.cleartk.eval.Evaluation_ImplB
 import org.cleartk.util.Options_ImplBase;
 import org.kohsuke.args4j.Option;
 import org.uimafit.component.JCasAnnotator_ImplBase;
+import org.uimafit.descriptor.ConfigurationParameter;
 import org.uimafit.factory.AggregateBuilder;
 import org.uimafit.factory.AnalysisEngineFactory;
 import org.uimafit.factory.CollectionReaderFactory;
@@ -103,10 +104,10 @@ public class RelationExtractorEvaluation
     public boolean gridSearch = false;
 
     @Option(
-        name = "--run-degree-of",
-        usage = "if true runs the degree of relation extractor otherwise "
-            + "it uses the normal entity mention pair relation extractor")
-    public boolean runDegreeOf = false;
+        name = "--relations",
+        usage = "determines which relations to evaluate on (separately)",
+        required = false)
+    public List<String> relations = Arrays.asList("degree_of", "location_of");
 
     @Option(
         name = "--test-on-ctakes",
@@ -116,124 +117,127 @@ public class RelationExtractorEvaluation
   }
 
   public static final String GOLD_VIEW_NAME = "GoldView";
-
+  
   public static void main(String[] args) throws Exception {
     Options options = new Options();
     options.parseOptions(args);
-    
+
     // error on invalid option combinations
     if (options.testDirectory != null && options.gridSearch) {
       throw new IllegalArgumentException("grid search can only be run on the train or dev sets");
     }
-    
+
     List<File> trainFiles = Arrays.asList(options.trainDirectory.listFiles());
-    
-    // define the output directory for models
-    File modelsDir = options.runDegreeOf
-        ? new File("target/models/degree_of")
-        : new File("target/models/em_pair");
-
-    // determine class for the classifier annotator
-    Class<? extends RelationExtractorAnnotator> annotatorClass = options.runDegreeOf
-        ? DegreeOfRelationExtractorAnnotator.class
-        : EntityMentionPairRelationExtractorAnnotator.class;
-
-    // determine the type of classifier to be trained
-    Class<? extends DataWriter<String>> dataWriterClass = LIBSVMStringOutcomeDataWriter.class;
-
-    // define the set of possible training parameters
-    List<ParameterSettings> possibleParams = options.runDegreeOf
-        ? getDegreeOfParameterSpace(options.gridSearch)
-        : getEMPairParameterSpace(options.gridSearch);
-
-    // run an evaluation for each set of parameters
-    Map<ParameterSettings, Double> scoredParams = new HashMap<ParameterSettings, Double>();
-    for (ParameterSettings params : possibleParams) {
-      System.err.println(params);
-      System.err.println();
-
-      // define additional configuration parameters for the annotator
-      Object[] additionalParameters = new Object[] {
-          RelationExtractorAnnotator.PARAM_PROBABILITY_OF_KEEPING_A_NEGATIVE_EXAMPLE,
-          params.probabilityOfKeepingANegativeExample,
-          EntityMentionPairRelationExtractorAnnotator.PARAM_CLASSIFY_BOTH_DIRECTIONS,
-  			  params.classifyBothDirections,
-          RelationExtractorAnnotator.PARAM_PRINT_ERRORS,
-          false };
-
-      // define arguments to be passed to the classifier
-      String[] trainingArguments = new String[] {
-          "-t",
-          String.valueOf(params.svmKernelIndex),
-          "-c",
-          String.valueOf(params.svmCost),
-          "-g",
-          String.valueOf(params.svmGamma) };
-
-      // create the evaluation
-      RelationExtractorEvaluation evaluation = new RelationExtractorEvaluation(
-          modelsDir,
-          annotatorClass,
-          dataWriterClass,
-          additionalParameters,
-          trainingArguments,
-          options.testOnCTakes);
-      
-      if (options.devDirectory != null) {
-        if (options.testDirectory != null) {
-          // train on the training set + dev set and evaluate on the test set
-          List<File> allTrainFiles = new ArrayList<File>();
-          allTrainFiles.addAll(trainFiles);
-          allTrainFiles.addAll(Arrays.asList(options.devDirectory.listFiles()));
-          List<File> testFiles = Arrays.asList(options.testDirectory.listFiles());
-          params.stats = evaluation.trainAndTest(allTrainFiles, testFiles);
+
+    for (String relationCategory : options.relations) {
+
+      // define the output directory for models
+      File modelsDir = new File("target/models/" + relationCategory);
+
+      // determine class for the classifier annotator
+      boolean isDegreeOf = relationCategory.equals("degree_of");
+      Class<? extends RelationExtractorAnnotator> annotatorClass = isDegreeOf
+          ? DegreeOfRelationExtractorAnnotator.class
+          : EntityMentionPairRelationExtractorAnnotator.class;
+
+      // determine the type of classifier to be trained
+      Class<? extends DataWriter<String>> dataWriterClass = LIBSVMStringOutcomeDataWriter.class;
+
+      // define the set of possible training parameters
+      List<ParameterSettings> possibleParams = isDegreeOf
+          ? getDegreeOfParameterSpace(options.gridSearch)
+          : getEMPairParameterSpace(options.gridSearch);
+
+      // run an evaluation for each set of parameters
+      Map<ParameterSettings, Double> scoredParams = new HashMap<ParameterSettings, Double>();
+      for (ParameterSettings params : possibleParams) {
+        System.err.println(relationCategory + ": " + params);
+        System.err.println();
+
+        // define additional configuration parameters for the annotator
+        Object[] additionalParameters = new Object[] {
+            RelationExtractorAnnotator.PARAM_PROBABILITY_OF_KEEPING_A_NEGATIVE_EXAMPLE,
+            params.probabilityOfKeepingANegativeExample,
+            EntityMentionPairRelationExtractorAnnotator.PARAM_CLASSIFY_BOTH_DIRECTIONS,
+            params.classifyBothDirections,
+            RelationExtractorAnnotator.PARAM_PRINT_ERRORS,
+            false };
+
+        // define arguments to be passed to the classifier
+        String[] trainingArguments = new String[] {
+            "-t",
+            String.valueOf(params.svmKernelIndex),
+            "-c",
+            String.valueOf(params.svmCost),
+            "-g",
+            String.valueOf(params.svmGamma) };
+
+        // create the evaluation
+        RelationExtractorEvaluation evaluation = new RelationExtractorEvaluation(
+            modelsDir,
+            relationCategory,
+            annotatorClass,
+            dataWriterClass,
+            additionalParameters,
+            trainingArguments,
+            options.testOnCTakes);
+
+        if (options.devDirectory != null) {
+          if (options.testDirectory != null) {
+            // train on the training set + dev set and evaluate on the test set
+            List<File> allTrainFiles = new ArrayList<File>();
+            allTrainFiles.addAll(trainFiles);
+            allTrainFiles.addAll(Arrays.asList(options.devDirectory.listFiles()));
+            List<File> testFiles = Arrays.asList(options.testDirectory.listFiles());
+            params.stats = evaluation.trainAndTest(allTrainFiles, testFiles);
+          } else {
+            // train on the training set and evaluate on the dev set
+            List<File> devFiles = Arrays.asList(options.devDirectory.listFiles());
+            params.stats = evaluation.trainAndTest(trainFiles, devFiles);
+          }
         } else {
-          // train on the training set and evaluate on the dev set
-          List<File> devFiles = Arrays.asList(options.devDirectory.listFiles());
-          params.stats = evaluation.trainAndTest(trainFiles, devFiles);
+          if (options.testDirectory != null) {
+            // train on the training set and evaluate on the test set
+            List<File> testFiles = Arrays.asList(options.testDirectory.listFiles());
+            params.stats = evaluation.trainAndTest(trainFiles, testFiles);
+          } else {
+            // run n-fold cross-validation on the training set
+            List<AnnotationStatistics<String>> foldStats = evaluation.crossValidation(trainFiles, 2);
+            params.stats = AnnotationStatistics.addAll(foldStats);
+          }
         }
-      } else {
-        if (options.testDirectory != null) {
-          // train on the training set and evaluate on the test set
-          List<File> testFiles = Arrays.asList(options.testDirectory.listFiles());
-          params.stats = evaluation.trainAndTest(trainFiles, testFiles);
-        } else {
-          // run n-fold cross-validation on the training set
-          List<AnnotationStatistics<String>> foldStats = evaluation.crossValidation(trainFiles, 2);
-          params.stats = AnnotationStatistics.addAll(foldStats);
+        scoredParams.put(params, params.stats.f1());
+      }
+
+      // print parameters sorted by F1
+      List<ParameterSettings> list = new ArrayList<ParameterSettings>(scoredParams.keySet());
+      Function<ParameterSettings, Double> getCount = Functions.forMap(scoredParams);
+      Collections.sort(list, Ordering.natural().onResultOf(getCount));
+
+      // print performance of each set of parameters
+      if (list.size() > 1) {
+        System.err.println(relationCategory + ": summary:");
+        for (ParameterSettings params : list) {
+          System.err.printf(
+              "F1=%.3f P=%.3f R=%.3f %s\n",
+              params.stats.f1(),
+              params.stats.precision(),
+              params.stats.recall(),
+              params);
         }
+        System.err.println();
       }
-      scoredParams.put(params, params.stats.f1());
-    }
 
-    // print parameters sorted by F1
-    List<ParameterSettings> list = new ArrayList<ParameterSettings>(scoredParams.keySet());
-    Function<ParameterSettings, Double> getCount = Functions.forMap(scoredParams);
-    Collections.sort(list, Ordering.natural().onResultOf(getCount));
-
-    // print performance of each set of parameters
-    if (list.size() > 1) {
-      System.err.println("Summary:");
-      for (ParameterSettings params : list) {
-        System.err.printf(
-            "F1=%.3f P=%.3f R=%.3f %s\n",
-            params.stats.f1(),
-            params.stats.precision(),
-            params.stats.recall(),
-            params);
-      }
-      System.err.println();
-    }
-
-    // print overall best model
-    if (!list.isEmpty()) {
-      ParameterSettings lastParams = list.get(list.size() - 1);
-      System.err.println("Best model:");
-      System.err.print(lastParams.stats);
-      System.err.println(lastParams);
-      System.err.println(lastParams.stats.confusions());
-      System.err.println();
-      System.err.println(lastParams.stats.confusions().toHTML());
+      // print overall best model
+      if (!list.isEmpty()) {
+        ParameterSettings lastParams = list.get(list.size() - 1);
+        System.err.println(relationCategory + ": best model:");
+        System.err.print(lastParams.stats);
+        System.err.println(lastParams);
+        System.err.println(lastParams.stats.confusions());
+        System.err.println();
+        System.err.println(lastParams.stats.confusions().toHTML());
+      }
     }
   }
 
@@ -255,18 +259,22 @@ public class RelationExtractorEvaluation
    */
   public RelationExtractorEvaluation(
       File baseDirectory,
+      String relationCategory,
       Class<? extends RelationExtractorAnnotator> classifierAnnotatorClass,
       Class<? extends DataWriter<String>> dataWriterClass,
       Object[] additionalParameters,
       String[] trainingArguments,
       boolean testOnCTakes) {
     super(baseDirectory);
+    this.relationCategory = relationCategory;
     this.classifierAnnotatorClass = classifierAnnotatorClass;
     this.dataWriterClass = dataWriterClass;
     this.additionalParameters = additionalParameters;
     this.trainingArguments = trainingArguments;
     this.testOnCTakes = testOnCTakes;
   }
+  
+  private String relationCategory;
 
   private Class<? extends RelationExtractorAnnotator> classifierAnnotatorClass;
 
@@ -297,6 +305,12 @@ public class RelationExtractorEvaluation
   @Override
   public void train(CollectionReader collectionReader, File directory) throws Exception {
     AggregateBuilder builder = new AggregateBuilder();
+    // remove all but the relation of interest from the gold annotations
+    builder.add(AnalysisEngineFactory.createPrimitiveDescription(
+        RemoveOtherRelations.class,
+        RemoveOtherRelations.PARAM_RELATION_CATEGORY,
+        this.relationCategory),
+        CAS.NAME_DEFAULT_SOFA, GOLD_VIEW_NAME);
     // replace cTAKES entity mentions and modifiers in the system view with the gold annotations
     builder.add(AnalysisEngineFactory.createPrimitiveDescription(ReplaceCTakesEntityMentionsAndModifiersWithGold.class));
     // add the relation extractor, configured for training mode
@@ -327,6 +341,12 @@ public class RelationExtractorEvaluation
   protected AnnotationStatistics<String> test(CollectionReader collectionReader, File directory)
       throws Exception {
     AggregateBuilder builder = new AggregateBuilder();
+    // remove all but the relation of interest from the gold annotations
+    builder.add(AnalysisEngineFactory.createPrimitiveDescription(
+        RemoveOtherRelations.class,
+        RemoveOtherRelations.PARAM_RELATION_CATEGORY,
+        this.relationCategory),
+        CAS.NAME_DEFAULT_SOFA, GOLD_VIEW_NAME);
     if (this.testOnCTakes) {
       // add the modifier extractor
       File file = new File("desc/analysis_engine/ModifierExtractorAnnotator.xml");
@@ -386,7 +406,7 @@ public class RelationExtractorEvaluation
           getOutcome);
     }
 
-    System.err.println(directory.getName() + ":");
+    System.err.printf("%s: %s:\n", this.relationCategory, directory.getName());
     System.err.print(stats);
     System.err.println(stats.confusions());
     System.err.println();
@@ -662,4 +682,23 @@ public class RelationExtractorEvaluation
       return a == null ? null : String.format("\"%s\"(type=%d)", a.getCoveredText(), a.getTypeID());
     }
   }
+  
+  public static class RemoveOtherRelations extends JCasAnnotator_ImplBase {
+    
+    public static final String PARAM_RELATION_CATEGORY = "RelationCategory";
+    @ConfigurationParameter(name = PARAM_RELATION_CATEGORY)
+    private String relationCategory;
+    
+
+    @Override
+    public void process(JCas jCas) throws AnalysisEngineProcessException {
+      List<BinaryTextRelation> relations = new ArrayList<BinaryTextRelation>();
+      relations.addAll(JCasUtil.select(jCas, BinaryTextRelation.class));
+      for (BinaryTextRelation relation : relations) {
+        if (!relation.getCategory().equals(this.relationCategory)) {
+          relation.removeFromIndexes();
+        }
+      }
+    }
+  }
 }

Modified: incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/pipelines/RelationExtractorTrain.java
URL: http://svn.apache.org/viewvc/incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/pipelines/RelationExtractorTrain.java?rev=1414692&r1=1414691&r2=1414692&view=diff
==============================================================================
--- incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/pipelines/RelationExtractorTrain.java	(original)
+++ incubator/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/pipelines/RelationExtractorTrain.java	Wed Nov 28 14:03:47 2012
@@ -95,6 +95,7 @@ public class RelationExtractorTrain {
   public static AnalysisEngineDescription trainRelationExtractor(
 		  File modelsDir, 
 		  List<File> trainFiles,
+		  String relationCategory,
 		  Class<? extends RelationExtractorAnnotator> annotatorClass, 
 		  Class<? extends DataWriter<String>> dataWriterClass,
 		  ParameterSettings params) throws Exception {
@@ -119,6 +120,7 @@ public class RelationExtractorTrain {
 	  
 	    RelationExtractorEvaluation evaluation = new RelationExtractorEvaluation(
 	    		modelsDir,
+	    		relationCategory,
 	    		annotatorClass,
 	    		dataWriterClass,
 	    		additionalParameters,
@@ -167,9 +169,9 @@ public class RelationExtractorTrain {
     // Train and write models
     AnalysisEngineDescription modifierExtractorDesc = trainModifierExtractor(modelsDirModExtractor, trainFiles);
     writeDesc(options.descDir, "ModifierExtractorAnnotator", modifierExtractorDesc);
-    AnalysisEngineDescription degreeOfRelationExtractorDesc = trainRelationExtractor(modelsDirDegreeOf, trainFiles, DegreeOfRelationExtractorAnnotator.class, dataWriterClass, degreeOfParams);
+    AnalysisEngineDescription degreeOfRelationExtractorDesc = trainRelationExtractor(modelsDirDegreeOf, trainFiles, "degree_of", DegreeOfRelationExtractorAnnotator.class, dataWriterClass, degreeOfParams);
     writeDesc(options.descDir, "DegreeOfRelationExtractorAnnotator", degreeOfRelationExtractorDesc);
-    AnalysisEngineDescription emPairRelationExtractorDesc = trainRelationExtractor(modelsDirEMPair, trainFiles, EntityMentionPairRelationExtractorAnnotator.class, dataWriterClass, emPairParams);
+    AnalysisEngineDescription emPairRelationExtractorDesc = trainRelationExtractor(modelsDirEMPair, trainFiles, "location_of", EntityMentionPairRelationExtractorAnnotator.class, dataWriterClass, emPairParams);
     writeDesc(options.descDir, "EntityMentionPairRelationExtractorAnnotator", emPairRelationExtractorDesc);
 
     // create the aggregate description



Mime
View raw message