ctakes-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tm...@apache.org
Subject svn commit: r1631906 - /ctakes/trunk/ctakes-assertion/src/main/java/org/apache/ctakes/assertion/eval/AssertionEvaluation.java
Date Tue, 14 Oct 2014 22:01:37 GMT
Author: tmill
Date: Tue Oct 14 22:01:37 2014
New Revision: 1631906

URL: http://svn.apache.org/r1631906
Log:
Allow assertion evaluation to train on a portion of data to build learning curves.

Modified:
    ctakes/trunk/ctakes-assertion/src/main/java/org/apache/ctakes/assertion/eval/AssertionEvaluation.java

Modified: ctakes/trunk/ctakes-assertion/src/main/java/org/apache/ctakes/assertion/eval/AssertionEvaluation.java
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-assertion/src/main/java/org/apache/ctakes/assertion/eval/AssertionEvaluation.java?rev=1631906&r1=1631905&r2=1631906&view=diff
==============================================================================
--- ctakes/trunk/ctakes-assertion/src/main/java/org/apache/ctakes/assertion/eval/AssertionEvaluation.java (original)
+++ ctakes/trunk/ctakes-assertion/src/main/java/org/apache/ctakes/assertion/eval/AssertionEvaluation.java Tue Oct 14 22:01:37 2014
@@ -35,10 +35,10 @@ import java.util.Set;
 import java.util.TreeMap;
 
 import org.apache.ctakes.assertion.attributes.features.selection.FeatureSelection;
+import org.apache.ctakes.assertion.medfacts.ConceptConverterAnalysisEngine;
 import org.apache.ctakes.assertion.medfacts.cleartk.AlternateCuePhraseAnnotator;
 import org.apache.ctakes.assertion.medfacts.cleartk.AssertionCleartkAnalysisEngine;
 import org.apache.ctakes.assertion.medfacts.cleartk.AssertionCleartkAnalysisEngine.FEATURE_CONFIG;
-import org.apache.ctakes.assertion.medfacts.cleartk.AssertionComponents;
 import org.apache.ctakes.assertion.medfacts.cleartk.ConditionalCleartkAnalysisEngine;
 import org.apache.ctakes.assertion.medfacts.cleartk.GenericCleartkAnalysisEngine;
 import org.apache.ctakes.assertion.medfacts.cleartk.HistoryCleartkAnalysisEngine;
@@ -48,7 +48,6 @@ import org.apache.ctakes.assertion.medfa
 import org.apache.ctakes.assertion.medfacts.cleartk.UncertaintyCleartkAnalysisEngine;
 import org.apache.ctakes.assertion.pipelines.GoldEntityAndAttributeReaderPipelineForSeedCorpus;
 import org.apache.ctakes.core.ae.DocumentIdPrinterAnalysisEngine;
-import org.apache.ctakes.core.util.CtakesFileNamer;
 import org.apache.ctakes.core.util.DocumentIDAnnotationUtil;
 import org.apache.ctakes.typesystem.type.constants.CONST;
 import org.apache.ctakes.typesystem.type.syntax.BaseToken;
@@ -100,7 +99,7 @@ import org.cleartk.ml.jar.DefaultDataWri
 import org.cleartk.ml.jar.DirectoryDataWriterFactory;
 import org.cleartk.ml.jar.GenericJarClassifierFactory;
 import org.cleartk.ml.jar.JarClassifierBuilder;
-import org.cleartk.ml.libsvm.LibSvmStringOutcomeDataWriter;
+import org.cleartk.ml.liblinear.LibLinearStringOutcomeDataWriter;
 //import org.cleartk.ml.libsvm.tk.TKLibSvmStringOutcomeDataWriter;
 import org.kohsuke.args4j.CmdLineParser;
 import org.kohsuke.args4j.Option;
@@ -282,6 +281,12 @@ private static Logger logger = Logger.ge
     		usage = "Domain adaptation -- for each semicolon-separated directory in train-dir, creates a domain-specific feature space",
     		required = false)
     public boolean feda = false;
+    
+    @Option(
+        name = "--portion",
+        usage = "Learning curve building -- what percentage of the training data to train on.",
+        required = false)
+    public double portionOfDataToUse = 1.0;
   }
   
   protected ArrayList<String> annotationTypes;
@@ -388,14 +393,14 @@ private static Logger logger = Logger.ge
     if(options.kernelParams != null){
       kernelParams = options.kernelParams.split("\\s+");
     }else{
-      kernelParams = new String[]{"-t", "0", "-c", "1"};
+      kernelParams = new String[]{"-c", "1.0"};
     }
     Class<? extends DataWriter<String>> dw = null;
     if(options.featConfig == FEATURE_CONFIG.STK || options.featConfig == FEATURE_CONFIG.PTK){

 //        dw = TKLibSvmStringOutcomeDataWriter.class;
       throw new UnsupportedOperationException("This requires cleartk-2.0 which");
     }
-    dw = LibSvmStringOutcomeDataWriter.class;
+    dw = LibLinearStringOutcomeDataWriter.class;
     
     AssertionEvaluation evaluation = new AssertionEvaluation(
         modelsDir,
@@ -403,29 +408,8 @@ private static Logger logger = Logger.ge
         annotationTypes,
         annotatorClass,
         dw,
-        kernelParams
-//        "-t",
-//        "0",
-//       TKLibSvmStringOutcomeDataWriter.class,
-//        "-c",
-//        "1"
-//        "-t",
-//        "5",
-//        "-C",
-//        "+",
-//        "-L",
-//        "0.4",
-//        "-N",
-//        "3",
-//        "-S",
-//        "0"
-       
-//        "-w0",
-//        "100.0",
-//        "-w1",
-//        "1.0"
-//        "100",
-//        "2"
+        "-c",
+        "1"
         );
     
     // if preprocessing, don't do anything else
@@ -457,6 +441,26 @@ private static Logger logger = Logger.ge
       AssertionEvaluation.printScore(overallStats,  "CROSS FOLD OVERALL");
       
     } 
+    else if (Math.abs(options.portionOfDataToUse - 1.0) > 0.001){
+      int numIters = 5;
+      List<File> testFiles = Arrays.asList(options.testDirectory.listFiles());
+      Map<String, Double> overallStats = new TreeMap<>();
+
+      for(String annotationType : annotationTypes){
+        overallStats.put(annotationType, 0.0);
+      }
+      for(int iter = 0; iter < numIters; iter++){
+        Map<String,AnnotationStatisticsCompact> stats = evaluation.trainAndTest(trainFiles, testFiles);
+        AssertionEvaluation.printScore(stats, "Sample " + iter + " score:");
+        for(String annotationType : stats.keySet()){
+          overallStats.put(annotationType, overallStats.get(annotationType) + stats.get(annotationType).f1("-1"));
+        }
+      }
+      for(String annotationType : annotationTypes){
+        System.out.println("Macro-average F-score for " + annotationType + " is: " + (overallStats.get(annotationType) / numIters));
+      }
+//      AssertionEvaluation.printScore(overallStats, "Learning Curve Proportion Average");
+    }
     
     // run train and test
     else {
@@ -661,6 +665,7 @@ public static void printScore(Map<String
   
   @Override
   public void train(CollectionReader collectionReader, File directory) throws Exception {
+    if(options.noCleartk) return;
     AggregateBuilder builder = new AggregateBuilder();
     
     //builder.add(AnalysisEngineFactory.createEngineDescription(ReplaceCTakesEntityMentionsAndModifiersWithGold.class));
@@ -673,8 +678,8 @@ public static void printScore(Map<String
 //        directory.getPath());
 //    builder.add(assertionDescription);
     
-    AnalysisEngineDescription documentIdPrinterAnnotator = AnalysisEngineFactory.createEngineDescription(DocumentIdPrinterAnalysisEngine.class);
-    builder.add(documentIdPrinterAnnotator);
+//    AnalysisEngineDescription documentIdPrinterAnnotator = AnalysisEngineFactory.createPrimitiveDescription(DocumentIdPrinterAnalysisEngine.class);
+//    builder.add(documentIdPrinterAnnotator);
     
     AnalysisEngineDescription goldCopierIdentifiedAnnotsAnnotator = AnalysisEngineFactory.createEngineDescription(ReferenceIdentifiedAnnotationsSystemToGoldCopier.class);
     builder.add(goldCopierIdentifiedAnnotsAnnotator);
@@ -756,7 +761,9 @@ public static void printScore(Map<String
     				AssertionCleartkAnalysisEngine.PARAM_FEATURE_SELECTION_THRESHOLD,
     				featureSelectionThreshold,
     				AssertionCleartkAnalysisEngine.PARAM_FEATURE_CONFIG,
-    				options.featConfig
+    				options.featConfig,
+    				AssertionCleartkAnalysisEngine.PARAM_PORTION_OF_DATA_TO_USE,
+    				(float) options.portionOfDataToUse
     				);
     	}
 		builder.add(polarityAnnotator);
@@ -912,6 +919,9 @@ public static void printScore(Map<String
     AnalysisEngineDescription assertionAttributeClearerAnnotator = AnalysisEngineFactory.createEngineDescription(ReferenceAnnotationsSystemAssertionClearer.class);
     builder.add(assertionAttributeClearerAnnotator);
     
+    AnalysisEngineDescription documentIdPrinterAnnotator = AnalysisEngineFactory.createEngineDescription(DocumentIdPrinterAnalysisEngine.class);
+    builder.add(documentIdPrinterAnnotator);
+
     if ( options.noCleartk ) {
     	addExternalAttributeAnnotatorsToAggregate(builder);
     } else {
@@ -994,9 +1004,6 @@ public static void printScore(Map<String
         throw new AnalysisEngineProcessException(e);
       }
 
-      String documentId = DocumentIDAnnotationUtil.getDocumentID(jCas);
-      System.out.format("document id: %s%n", documentId);
-      
       Collection<IdentifiedAnnotation> goldEntitiesAndEvents = new ArrayList<IdentifiedAnnotation>();

       if ( !ignoreAnatomicalSites ) {
     	  Collection<EntityMention> goldEntities = JCasUtil.select(goldView, EntityMention.class);
@@ -1493,6 +1500,7 @@ public static class HashableAnnotation i
 
 private void addExternalAttributeAnnotatorsToAggregate(AggregateBuilder builder)
 		throws UIMAException, IOException {
+  builder.add(AnalysisEngineFactory.createEngineDescription(ConceptConverterAnalysisEngine.class));
 	// RUN ALL THE OLD (non-ClearTK) CLASSIFIERS
 	AnalysisEngineDescription oldAssertionAnnotator = AnalysisEngineFactory.createEngineDescription("desc/assertionAnalysisEngine");

 	ConfigurationParameterFactory.addConfigurationParameters(



Mime
View raw message