ctakes-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From c...@apache.org
Subject svn commit: r1504557 - in /ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal: ae/TimeAnnotator.java eval/EvaluationOfTimeSpans.java
Date Thu, 18 Jul 2013 17:27:52 GMT
Author: clin
Date: Thu Jul 18 17:27:52 2013
New Revision: 1504557

URL: http://svn.apache.org/r1504557
Log:
Added feature selection for TimeAnnotator. Feature selection can be turned off by setting
--featureSelectionThreshold 0.
Any positive value of --featureSelectionThreshold enables the feature selection process.

Modified:
    ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/ae/TimeAnnotator.java
    ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/eval/EvaluationOfTimeSpans.java

Modified: ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/ae/TimeAnnotator.java
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/ae/TimeAnnotator.java?rev=1504557&r1=1504556&r2=1504557&view=diff
==============================================================================
--- ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/ae/TimeAnnotator.java
(original)
+++ ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/ae/TimeAnnotator.java
Thu Jul 18 17:27:52 2013
@@ -19,11 +19,15 @@
 package org.apache.ctakes.temporal.ae;
 
 import java.io.File;
+import java.io.IOException;
+import java.net.URI;
 import java.util.ArrayList;
 import java.util.List;
 
 import org.apache.ctakes.temporal.ae.feature.ParseSpanFeatureExtractor;
 import org.apache.ctakes.temporal.ae.feature.TimeWordTypeExtractor;
+import org.apache.ctakes.temporal.ae.feature.selection.Chi2FeatureSelection;
+import org.apache.ctakes.temporal.ae.feature.selection.FeatureSelection;
 import org.apache.ctakes.typesystem.type.syntax.BaseToken;
 import org.apache.ctakes.typesystem.type.textsem.TimeMention;
 import org.apache.ctakes.typesystem.type.textspan.Segment;
@@ -35,7 +39,7 @@ import org.apache.uima.cas.CASException;
 import org.apache.uima.jcas.JCas;
 import org.apache.uima.resource.ResourceInitializationException;
 import org.cleartk.classifier.CleartkAnnotator;
-import org.cleartk.classifier.DataWriter;
+//import org.cleartk.classifier.DataWriter;
 import org.cleartk.classifier.Feature;
 import org.cleartk.classifier.Instance;
 import org.cleartk.classifier.chunking.BIOChunking;
@@ -51,151 +55,203 @@ import org.cleartk.classifier.feature.ex
 import org.cleartk.classifier.jar.DefaultDataWriterFactory;
 import org.cleartk.classifier.jar.DirectoryDataWriterFactory;
 import org.cleartk.classifier.jar.GenericJarClassifierFactory;
+import org.uimafit.descriptor.ConfigurationParameter;
 import org.uimafit.factory.AnalysisEngineFactory;
 import org.uimafit.util.JCasUtil;
 
 public class TimeAnnotator extends TemporalEntityAnnotator_ImplBase {
 
-  public static final String TIMEX_VIEW = "TimexView";
-  
-  public static AnalysisEngineDescription createDataWriterDescription(
-      Class<? extends DataWriter<String>> dataWriterClass,
-      File outputDirectory) throws ResourceInitializationException {
-    return AnalysisEngineFactory.createPrimitiveDescription(
-        TimeAnnotator.class,
-        CleartkAnnotator.PARAM_IS_TRAINING,
-        true,
-        DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
-        dataWriterClass,
-        DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
-        outputDirectory);
-  }
-
-  public static AnalysisEngineDescription createAnnotatorDescription(File modelDirectory)
-      throws ResourceInitializationException {
-    return AnalysisEngineFactory.createPrimitiveDescription(
-        TimeAnnotator.class,
-        CleartkAnnotator.PARAM_IS_TRAINING,
-        false,
-        GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
-        new File(modelDirectory, "model.jar"));
-  }
-
-  protected List<SimpleFeatureExtractor> tokenFeatureExtractors;
-
-  protected List<CleartkExtractor> contextFeatureExtractors;
-  
-//  protected List<SimpleFeatureExtractor> parseFeatureExtractors;
-  protected ParseSpanFeatureExtractor parseExtractor;
-  
-  private BIOChunking<BaseToken, TimeMention> timeChunking;
-
-  @Override
-  public void initialize(UimaContext context) throws ResourceInitializationException {
-    super.initialize(context);
-
-    // define chunking
-    this.timeChunking = new BIOChunking<BaseToken, TimeMention>(BaseToken.class, TimeMention.class);
-
-    CombinedExtractor allExtractors = new CombinedExtractor(
-        new CoveredTextExtractor(),
-        new CharacterCategoryPatternExtractor(PatternType.REPEATS_MERGED),
-        new CharacterCategoryPatternExtractor(PatternType.ONE_PER_CHAR),
-        new TypePathExtractor(BaseToken.class, "partOfSpeech"),
-        new TimeWordTypeExtractor());
-
-//    CombinedExtractor parseExtractors = new CombinedExtractor(
-//        new ParseSpanFeatureExtractor()
-//        );
-    this.tokenFeatureExtractors = new ArrayList<SimpleFeatureExtractor>();
-    this.tokenFeatureExtractors.add(allExtractors);
-
-    this.contextFeatureExtractors = new ArrayList<CleartkExtractor>();
-    this.contextFeatureExtractors.add(new CleartkExtractor(
-        BaseToken.class,
-        allExtractors,
-        new Preceding(3),
-        new Following(3)));
-//    this.parseFeatureExtractors = new ArrayList<ParseSpanFeatureExtractor>();
-//    this.parseFeatureExtractors.add(new ParseSpanFeatureExtractor());
-    parseExtractor = new ParseSpanFeatureExtractor();
-  }
-
-  @Override
-  public void process(JCas jCas, Segment segment) throws AnalysisEngineProcessException {
-    // classify tokens within each sentence
-    for (Sentence sentence : JCasUtil.selectCovered(jCas, Sentence.class, segment)) {
-      List<BaseToken> tokens = JCasUtil.selectCovered(jCas, BaseToken.class, sentence);
-
-      // during training, the list of all outcomes for the tokens
-      List<String> outcomes;
-      if (this.isTraining()) {
-        List<TimeMention> times = JCasUtil.selectCovered(jCas, TimeMention.class, sentence);
-        outcomes = this.timeChunking.createOutcomes(jCas, tokens, times);
-      }
-      // during prediction, the list of outcomes predicted so far
-      else {
-        outcomes = new ArrayList<String>();
-      }
-
-      // extract features for all tokens
-      int tokenIndex = -1;
-      for (BaseToken token : tokens) {
-        ++tokenIndex;
-
-        List<Feature> features = new ArrayList<Feature>();
-        // features from token attributes
-        for (SimpleFeatureExtractor extractor : this.tokenFeatureExtractors) {
-          features.addAll(extractor.extract(jCas, token));
-        }
-        // features from surrounding tokens
-        for (CleartkExtractor extractor : this.contextFeatureExtractors) {
-          features.addAll(extractor.extractWithin(jCas, token, sentence));
-        }
-        // features from previous classifications
-        int nPreviousClassifications = 2;
-        for (int i = nPreviousClassifications; i > 0; --i) {
-          int index = tokenIndex - i;
-          String previousOutcome = index < 0 ? "O" : outcomes.get(index);
-          features.add(new Feature("PreviousOutcome_" + i, previousOutcome));
-        }
-        //add segment ID as a features:
-        features.add(new Feature("SegmentID", segment.getId()));
-        
-        // features from dominating parse tree
-//        for(SimpleFeatureExtractor extractor : this.parseFeatureExtractors){
-        BaseToken startToken = token;
-        for(int i = tokenIndex-1; i >= 0; --i){
-          String outcome = outcomes.get(i);
-          if(outcome.equals("O")){
-            break;
-          }
-          startToken = tokens.get(i);
-        }
-        features.addAll(parseExtractor.extract(jCas, startToken.getBegin(), token.getEnd()));
-//        }
-        // if training, write to data file
-        if (this.isTraining()) {
-          String outcome = outcomes.get(tokenIndex);
-          this.dataWriter.write(new Instance<String>(outcome, features));
-        }
-
-        // if predicting, add prediction to outcomes
-        else {
-          outcomes.add(this.classifier.classify(features));
-        }
-      }
-
-      // during prediction, convert chunk labels to times and add them to the CAS
-      if (!this.isTraining()) {
-        JCas timexCas;
-        try {
-          timexCas = jCas.getView(TIMEX_VIEW);
-        } catch (CASException e) {
-          throw new AnalysisEngineProcessException(e);
-        }
-        this.timeChunking.createChunks(timexCas, tokens, outcomes);
-      }
-    }
-  }
+	public static final String PARAM_FEATURE_SELECTION_THRESHOLD = "WhetherToDoFeatureSelection";
+
+	@ConfigurationParameter(
+			name = PARAM_FEATURE_SELECTION_THRESHOLD,
+			mandatory = false,
+			description = "the Chi-squared threshold at which features should be removed")
+	protected Float featureSelectionThreshold = 0f;
+	
+	public static final String PARAM_FEATURE_SELECTION_URI = "FeatureSelectionURI";
+
+	@ConfigurationParameter(
+			mandatory = false,
+			name = PARAM_FEATURE_SELECTION_URI,
+			description = "provides a URI where the feature selection data will be written")
+	protected URI featureSelectionURI;
+
+	public static final String TIMEX_VIEW = "TimexView";
+
+	public static AnalysisEngineDescription createDataWriterDescription(
+			Class<?> dataWriterClass,
+					File outputDirectory,
+					float featureSelect) throws ResourceInitializationException {
+		return AnalysisEngineFactory.createPrimitiveDescription(
+				TimeAnnotator.class,
+				CleartkAnnotator.PARAM_IS_TRAINING,
+				true,
+				DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
+				dataWriterClass,
+				DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
+				outputDirectory,
+				TimeAnnotator.PARAM_FEATURE_SELECTION_THRESHOLD,
+		        featureSelect);
+	}
+
+	public static AnalysisEngineDescription createAnnotatorDescription(File modelDirectory)
+			throws ResourceInitializationException {
+		return AnalysisEngineFactory.createPrimitiveDescription(
+				TimeAnnotator.class,
+				CleartkAnnotator.PARAM_IS_TRAINING,
+				false,
+				GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
+				new File(modelDirectory, "model.jar"),
+				TimeAnnotator.PARAM_FEATURE_SELECTION_URI,
+				TimeAnnotator.createFeatureSelectionURI(modelDirectory));
+	}
+
+	protected List<SimpleFeatureExtractor> tokenFeatureExtractors;
+
+	protected List<CleartkExtractor> contextFeatureExtractors;
+
+	//  protected List<SimpleFeatureExtractor> parseFeatureExtractors;
+	protected ParseSpanFeatureExtractor parseExtractor;
+
+	private BIOChunking<BaseToken, TimeMention> timeChunking;
+	
+	private FeatureSelection<String> featureSelection;
+
+	private static final String FEATURE_SELECTION_NAME = "SelectNeighborFeatures";
+
+	public static FeatureSelection<String> createFeatureSelection(double threshold) {
+		return new Chi2FeatureSelection<String>(TimeAnnotator.FEATURE_SELECTION_NAME, threshold);
+	}
+	
+	public static URI createFeatureSelectionURI(File outputDirectoryName) {
+		return new File(outputDirectoryName, FEATURE_SELECTION_NAME + "_Chi2_extractor.dat").toURI();
+	}
+
+	@Override
+	public void initialize(UimaContext context) throws ResourceInitializationException {
+		super.initialize(context);
+
+		// define chunking
+		this.timeChunking = new BIOChunking<BaseToken, TimeMention>(BaseToken.class, TimeMention.class);
+
+		CombinedExtractor allExtractors = new CombinedExtractor(
+				new CoveredTextExtractor(),
+				new CharacterCategoryPatternExtractor(PatternType.REPEATS_MERGED),
+				new CharacterCategoryPatternExtractor(PatternType.ONE_PER_CHAR),
+				new TypePathExtractor(BaseToken.class, "partOfSpeech"),
+				new TimeWordTypeExtractor());
+
+		//    CombinedExtractor parseExtractors = new CombinedExtractor(
+		//        new ParseSpanFeatureExtractor()
+		//        );
+		this.tokenFeatureExtractors = new ArrayList<SimpleFeatureExtractor>();
+		this.tokenFeatureExtractors.add(allExtractors);
+
+		this.contextFeatureExtractors = new ArrayList<CleartkExtractor>();
+		this.contextFeatureExtractors.add(new CleartkExtractor(
+				BaseToken.class,
+				allExtractors,
+				new Preceding(3),
+				new Following(3)));
+		//    this.parseFeatureExtractors = new ArrayList<ParseSpanFeatureExtractor>();
+		//    this.parseFeatureExtractors.add(new ParseSpanFeatureExtractor());
+		parseExtractor = new ParseSpanFeatureExtractor();
+
+		//initialize feature selection
+		if (featureSelectionThreshold == 0) {
+			this.featureSelection = null;
+		} else {
+			this.featureSelection = TimeAnnotator.createFeatureSelection(this.featureSelectionThreshold);
+
+			if (this.featureSelectionURI != null) {
+				try {
+					this.featureSelection.load(this.featureSelectionURI);
+				} catch (IOException e) {
+					throw new ResourceInitializationException(e);
+				}
+			}
+		}
+	}
+
+	@Override
+	public void process(JCas jCas, Segment segment) throws AnalysisEngineProcessException {
+		// classify tokens within each sentence
+		for (Sentence sentence : JCasUtil.selectCovered(jCas, Sentence.class, segment)) {
+			List<BaseToken> tokens = JCasUtil.selectCovered(jCas, BaseToken.class, sentence);
+
+			// during training, the list of all outcomes for the tokens
+			List<String> outcomes;
+			if (this.isTraining()) {
+				List<TimeMention> times = JCasUtil.selectCovered(jCas, TimeMention.class, sentence);
+				outcomes = this.timeChunking.createOutcomes(jCas, tokens, times);
+			}
+			// during prediction, the list of outcomes predicted so far
+			else {
+				outcomes = new ArrayList<String>();
+			}
+
+			// extract features for all tokens
+			int tokenIndex = -1;
+			for (BaseToken token : tokens) {
+				++tokenIndex;
+
+				List<Feature> features = new ArrayList<Feature>();
+				// features from token attributes
+				for (SimpleFeatureExtractor extractor : this.tokenFeatureExtractors) {
+					features.addAll(extractor.extract(jCas, token));
+				}
+				// features from surrounding tokens
+				for (CleartkExtractor extractor : this.contextFeatureExtractors) {
+					features.addAll(extractor.extractWithin(jCas, token, sentence));
+				}
+				// features from previous classifications
+				int nPreviousClassifications = 2;
+				for (int i = nPreviousClassifications; i > 0; --i) {
+					int index = tokenIndex - i;
+					String previousOutcome = index < 0 ? "O" : outcomes.get(index);
+					features.add(new Feature("PreviousOutcome_" + i, previousOutcome));
+				}
+				//add segment ID as a features:
+				features.add(new Feature("SegmentID", segment.getId()));
+
+				// features from dominating parse tree
+				//        for(SimpleFeatureExtractor extractor : this.parseFeatureExtractors){
+				BaseToken startToken = token;
+				for(int i = tokenIndex-1; i >= 0; --i){
+					String outcome = outcomes.get(i);
+					if(outcome.equals("O")){
+						break;
+					}
+					startToken = tokens.get(i);
+				}
+				features.addAll(parseExtractor.extract(jCas, startToken.getBegin(), token.getEnd()));
+				//        }
+				
+				// apply feature selection, if necessary
+		        if (this.featureSelection != null) {
+		          features = this.featureSelection.transform(features);
+		        }
+				
+				// if training, write to data file
+				if (this.isTraining()) {
+					String outcome = outcomes.get(tokenIndex);
+					this.dataWriter.write(new Instance<String>(outcome, features));
+				}else {// if predicting, add prediction to outcomes
+					outcomes.add(this.classifier.classify(features));
+				}
+			}
+
+			// during prediction, convert chunk labels to times and add them to the CAS
+			if (!this.isTraining()) {
+				JCas timexCas;
+				try {
+					timexCas = jCas.getView(TIMEX_VIEW);
+				} catch (CASException e) {
+					throw new AnalysisEngineProcessException(e);
+				}
+				this.timeChunking.createChunks(timexCas, tokens, outcomes);
+			}
+		}
+	}
 }

Modified: ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/eval/EvaluationOfTimeSpans.java
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/eval/EvaluationOfTimeSpans.java?rev=1504557&r1=1504556&r2=1504557&view=diff
==============================================================================
--- ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/eval/EvaluationOfTimeSpans.java
(original)
+++ ctakes/trunk/ctakes-temporal/src/main/java/org/apache/ctakes/temporal/eval/EvaluationOfTimeSpans.java
Thu Jul 18 17:27:52 2013
@@ -29,6 +29,7 @@ import org.apache.ctakes.temporal.ae.CRF
 import org.apache.ctakes.temporal.ae.ConstituencyBasedTimeAnnotator;
 import org.apache.ctakes.temporal.ae.MetaTimeAnnotator;
 import org.apache.ctakes.temporal.ae.TimeAnnotator;
+import org.apache.ctakes.temporal.ae.feature.selection.FeatureSelection;
 import org.apache.ctakes.typesystem.type.textsem.TimeMention;
 import org.apache.ctakes.typesystem.type.textspan.Segment;
 import org.apache.uima.analysis_engine.AnalysisEngineDescription;
@@ -37,7 +38,11 @@ import org.apache.uima.jcas.tcas.Annotat
 import org.apache.uima.resource.ResourceInitializationException;
 import org.cleartk.classifier.CleartkAnnotator;
 import org.cleartk.classifier.CleartkSequenceAnnotator;
+import org.cleartk.classifier.Instance;
+//import org.cleartk.classifier.DataWriter;
 import org.cleartk.classifier.crfsuite.CRFSuiteStringOutcomeDataWriter;
+import org.cleartk.classifier.feature.transform.InstanceDataWriter;
+import org.cleartk.classifier.feature.transform.InstanceStream;
 import org.cleartk.classifier.jar.DefaultDataWriterFactory;
 import org.cleartk.classifier.jar.DefaultSequenceDataWriterFactory;
 import org.cleartk.classifier.jar.DirectoryDataWriterFactory;
@@ -53,144 +58,183 @@ import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import com.google.common.collect.Ordering;
 import com.lexicalscope.jewel.cli.CliFactory;
+import com.lexicalscope.jewel.cli.Option;
 
 public class EvaluationOfTimeSpans extends EvaluationOfAnnotationSpans_ImplBase {
 
-  public static void main(String[] args) throws Exception {
-    Options options = CliFactory.parseArguments(Options.class, args);
-    List<Integer> patientSets = options.getPatients().getList();
-    List<Integer> trainItems = THYMEData.getTrainPatientSets(patientSets);
-    List<Integer> devItems = THYMEData.getDevPatientSets(patientSets);
-    
-    // specify the annotator classes to use
-    List<Class<? extends JCasAnnotator_ImplBase>> annotatorClasses = Lists.newArrayList();
-    annotatorClasses.add(BackwardsTimeAnnotator.class);
-    annotatorClasses.add(TimeAnnotator.class);
-    annotatorClasses.add(ConstituencyBasedTimeAnnotator.class);
-    annotatorClasses.add(CRFTimeAnnotator.class);
-    annotatorClasses.add(MetaTimeAnnotator.class);
-    Map<Class<? extends JCasAnnotator_ImplBase>, String[]> annotatorTrainingArguments
= Maps.newHashMap();
-    annotatorTrainingArguments.put(BackwardsTimeAnnotator.class, new String[]{"-c", "0.1"});
-    annotatorTrainingArguments.put(TimeAnnotator.class, new String[]{"-c", "0.1"});
-    annotatorTrainingArguments.put(ConstituencyBasedTimeAnnotator.class, new String[]{"-c",
"0.1"});
-    annotatorTrainingArguments.put(CRFTimeAnnotator.class, new String[]{"-p", "c2=0.1"});
-    annotatorTrainingArguments.put(MetaTimeAnnotator.class, new String[]{"-p", "c2=0.1"});
-    
-    // run one evaluation per annotator class
-    final Map<Class<?>, AnnotationStatistics<?>> annotatorStats = Maps.newHashMap();
-    for (Class<? extends JCasAnnotator_ImplBase> annotatorClass : annotatorClasses)
{
-      EvaluationOfTimeSpans evaluation = new EvaluationOfTimeSpans(
-          new File("target/eval/time-spans"),
-          options.getRawTextDirectory(),
-          options.getXMLDirectory(),
-          options.getXMLFormat(),
-          options.getXMIDirectory(),
-          options.getTreebankDirectory(),
-          annotatorClass,
-          options.getPrintOverlappingSpans(),
-          annotatorTrainingArguments.get(annotatorClass));
-      evaluation.prepareXMIsFor(patientSets);
-      String name = String.format("%s.errors", annotatorClass.getSimpleName());
-      evaluation.setLogging(Level.FINE, new File("target/eval", name));
-      AnnotationStatistics<String> stats = evaluation.trainAndTest(trainItems, devItems);
-      annotatorStats.put(annotatorClass, stats);
-    }
-
-    // allow ordering of models by F1
-    Ordering<Class<? extends JCasAnnotator_ImplBase>> byF1 = Ordering.natural().onResultOf(
-      new Function<Class<? extends JCasAnnotator_ImplBase>, Double>() {
-        @Override
-        public Double apply(
-            Class<? extends JCasAnnotator_ImplBase> annotatorClass) {
-          return annotatorStats.get(annotatorClass).f1();
-        }
-      });
-
-    // print out models, ordered by F1
-    for (Class<?> annotatorClass : byF1.sortedCopy(annotatorClasses)) {
-      System.err.printf("===== %s =====\n", annotatorClass.getSimpleName());
-      System.err.println(annotatorStats.get(annotatorClass));
-    }
-  }
-
-  private Class<? extends JCasAnnotator_ImplBase> annotatorClass;
-
-  private String[] trainingArguments;
-
-  public EvaluationOfTimeSpans(
-      File baseDirectory,
-      File rawTextDirectory,
-      File xmlDirectory,
-      XMLFormat xmlFormat,
-      File xmiDirectory,
-      File treebankDirectory,
-      Class<? extends JCasAnnotator_ImplBase> annotatorClass,
-      boolean printOverlapping,
-      String[] trainingArguments) {
-    super(baseDirectory, rawTextDirectory, xmlDirectory, xmlFormat, xmiDirectory, treebankDirectory,
TimeMention.class);
-    this.annotatorClass = annotatorClass;
-    this.trainingArguments = trainingArguments;
-    this.printOverlapping = printOverlapping;
-  }
-
-  @Override
-  protected AnalysisEngineDescription getDataWriterDescription(File directory)
-      throws ResourceInitializationException {
-    if(MetaTimeAnnotator.class.isAssignableFrom(this.annotatorClass)){
-      return MetaTimeAnnotator.getDataWriterDescription(CRFSuiteStringOutcomeDataWriter.class,
directory);          
-    }else if(CleartkAnnotator.class.isAssignableFrom(this.annotatorClass)){
-      return AnalysisEngineFactory.createPrimitiveDescription(
-          this.annotatorClass,
-          CleartkAnnotator.PARAM_IS_TRAINING,
-          true,
-          DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
-          LIBLINEARStringOutcomeDataWriter.class,
-          DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
-          this.getModelDirectory(directory));
-    }else if(CleartkSequenceAnnotator.class.isAssignableFrom(this.annotatorClass)){
-      return AnalysisEngineFactory.createPrimitiveDescription(
-          this.annotatorClass,
-          CleartkSequenceAnnotator.PARAM_IS_TRAINING,
-          true,
-          DefaultSequenceDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
-          CRFSuiteStringOutcomeDataWriter.class,
-          DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
-          this.getModelDirectory(directory));
-    }else{
-      throw new ResourceInitializationException("Annotator class was not recognized as an
acceptable class!", new Object[]{});
-    }
-  }
-
-  @Override
-  protected void trainAndPackage(File directory) throws Exception {
-    JarClassifierBuilder.trainAndPackage(this.getModelDirectory(directory), this.trainingArguments);
-  }
-
-  @Override
-  protected AnalysisEngineDescription getAnnotatorDescription(File directory)
-      throws ResourceInitializationException {
-    if(MetaTimeAnnotator.class.isAssignableFrom(this.annotatorClass)){
-      return MetaTimeAnnotator.getAnnotatorDescription(directory);
-    }
-    return AnalysisEngineFactory.createPrimitiveDescription(
-        this.annotatorClass,
-        CleartkAnnotator.PARAM_IS_TRAINING,
-        false,
-        GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
-        new File(this.getModelDirectory(directory), "model.jar"));
-  }
-
-  @Override
-  protected Collection<? extends Annotation> getGoldAnnotations(JCas jCas, Segment
segment) {
-    return selectExact(jCas, TimeMention.class, segment);
-  }
-
-  @Override
-  protected Collection<? extends Annotation> getSystemAnnotations(JCas jCas, Segment
segment) {
-    return selectExact(jCas, TimeMention.class, segment);
-  }
-  
-  private File getModelDirectory(File directory) {
-    return new File(directory, this.annotatorClass.getSimpleName());
-  }
+	static interface Options extends Evaluation_ImplBase.Options {
+
+		@Option(longName = "featureSelectionThreshold", defaultValue = "0")
+		public float getFeatureSelectionThreshold();
+	}
+
+	public static void main(String[] args) throws Exception {
+		Options options = CliFactory.parseArguments(Options.class, args);
+		List<Integer> patientSets = options.getPatients().getList();
+		List<Integer> trainItems = THYMEData.getTrainPatientSets(patientSets);
+		List<Integer> devItems = THYMEData.getDevPatientSets(patientSets);
+
+		// specify the annotator classes to use
+		List<Class<? extends JCasAnnotator_ImplBase>> annotatorClasses = Lists.newArrayList();
+		annotatorClasses.add(BackwardsTimeAnnotator.class);
+		annotatorClasses.add(TimeAnnotator.class);
+		annotatorClasses.add(ConstituencyBasedTimeAnnotator.class);
+		annotatorClasses.add(CRFTimeAnnotator.class);
+		annotatorClasses.add(MetaTimeAnnotator.class);
+		Map<Class<? extends JCasAnnotator_ImplBase>, String[]> annotatorTrainingArguments
= Maps.newHashMap();
+		annotatorTrainingArguments.put(BackwardsTimeAnnotator.class, new String[]{"-c", "0.1"});
+		annotatorTrainingArguments.put(TimeAnnotator.class, new String[]{"-c", "0.1"});
+		annotatorTrainingArguments.put(ConstituencyBasedTimeAnnotator.class, new String[]{"-c",
"0.1"});
+		annotatorTrainingArguments.put(CRFTimeAnnotator.class, new String[]{"-p", "c2=0.1"});
+		annotatorTrainingArguments.put(MetaTimeAnnotator.class, new String[]{"-p", "c2=0.1"});
+
+		// run one evaluation per annotator class
+		final Map<Class<?>, AnnotationStatistics<?>> annotatorStats = Maps.newHashMap();
+		for (Class<? extends JCasAnnotator_ImplBase> annotatorClass : annotatorClasses) {
+			EvaluationOfTimeSpans evaluation = new EvaluationOfTimeSpans(
+					new File("target/eval/time-spans"),
+					options.getRawTextDirectory(),
+					options.getXMLDirectory(),
+					options.getXMLFormat(),
+					options.getXMIDirectory(),
+					options.getTreebankDirectory(),
+					options.getFeatureSelectionThreshold(),
+					annotatorClass,
+					options.getPrintOverlappingSpans(),
+					annotatorTrainingArguments.get(annotatorClass));
+			evaluation.prepareXMIsFor(patientSets);
+			String name = String.format("%s.errors", annotatorClass.getSimpleName());
+			evaluation.setLogging(Level.FINE, new File("target/eval", name));
+			AnnotationStatistics<String> stats = evaluation.trainAndTest(trainItems, devItems);
+			annotatorStats.put(annotatorClass, stats);
+		}
+
+		// allow ordering of models by F1
+		Ordering<Class<? extends JCasAnnotator_ImplBase>> byF1 = Ordering.natural().onResultOf(
+				new Function<Class<? extends JCasAnnotator_ImplBase>, Double>() {
+					@Override
+					public Double apply(
+							Class<? extends JCasAnnotator_ImplBase> annotatorClass) {
+						return annotatorStats.get(annotatorClass).f1();
+					}
+				});
+
+		// print out models, ordered by F1
+		for (Class<?> annotatorClass : byF1.sortedCopy(annotatorClasses)) {
+			System.err.printf("===== %s =====\n", annotatorClass.getSimpleName());
+			System.err.println(annotatorStats.get(annotatorClass));
+		}
+	}
+
+	private Class<? extends JCasAnnotator_ImplBase> annotatorClass;
+
+	private String[] trainingArguments;
+	
+	private float featureSelectionThreshold;
+
+	public EvaluationOfTimeSpans(
+			File baseDirectory,
+			File rawTextDirectory,
+			File xmlDirectory,
+			XMLFormat xmlFormat,
+			File xmiDirectory,
+			File treebankDirectory,
+			float featureSelectionThreshold,
+			Class<? extends JCasAnnotator_ImplBase> annotatorClass,
+					boolean printOverlapping,
+					String[] trainingArguments) {
+		super(baseDirectory, rawTextDirectory, xmlDirectory, xmlFormat, xmiDirectory, treebankDirectory,
TimeMention.class);
+		this.annotatorClass = annotatorClass;
+		this.featureSelectionThreshold = featureSelectionThreshold;
+		this.trainingArguments = trainingArguments;
+		this.printOverlapping = printOverlapping;
+	}
+
+	@Override
+	protected AnalysisEngineDescription getDataWriterDescription(File directory)
+			throws ResourceInitializationException {
+		if(MetaTimeAnnotator.class.isAssignableFrom(this.annotatorClass)){
+			return MetaTimeAnnotator.getDataWriterDescription(CRFSuiteStringOutcomeDataWriter.class,
directory);          
+		}else if(CleartkAnnotator.class.isAssignableFrom(this.annotatorClass)){
+			//limit feature selection only to TimeAnnotator
+			if("org.apache.ctakes.temporal.ae.TimeAnnotator".equals(this.annotatorClass.getName())){
+				Class<?> dataWriterClass = this.featureSelectionThreshold > 0f
+				        ? InstanceDataWriter.class
+				        : LIBLINEARStringOutcomeDataWriter.class;
+				return TimeAnnotator.createDataWriterDescription(
+						dataWriterClass,
+						this.getModelDirectory(directory),
+						this.featureSelectionThreshold);
+			}
+			return AnalysisEngineFactory.createPrimitiveDescription(
+					this.annotatorClass,
+					CleartkAnnotator.PARAM_IS_TRAINING,
+					true,
+					DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
+					LIBLINEARStringOutcomeDataWriter.class,
+					DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
+					this.getModelDirectory(directory));
+			
+		}else if(CleartkSequenceAnnotator.class.isAssignableFrom(this.annotatorClass)){
+			return AnalysisEngineFactory.createPrimitiveDescription(
+					this.annotatorClass,
+					CleartkSequenceAnnotator.PARAM_IS_TRAINING,
+					true,
+					DefaultSequenceDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
+					CRFSuiteStringOutcomeDataWriter.class,
+					DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
+					this.getModelDirectory(directory));
+		}else{
+			throw new ResourceInitializationException("Annotator class was not recognized as an acceptable
class!", new Object[]{});
+		}
+	}
+
+	@Override
+	protected void trainAndPackage(File directory) throws Exception {
+		if (this.featureSelectionThreshold > 0 && "org.apache.ctakes.temporal.ae.TimeAnnotator".equals(this.annotatorClass.getName())
) {
+			// Extracting features and writing instances
+			Iterable<Instance<String>> instances = InstanceStream.loadFromDirectory(this.getModelDirectory(directory));
+			// Collect MinMax stats for feature normalization
+			FeatureSelection<String> featureSelection = TimeAnnotator.createFeatureSelection(this.featureSelectionThreshold);
+			featureSelection.train(instances);
+			featureSelection.save(TimeAnnotator.createFeatureSelectionURI(this.getModelDirectory(directory)));
+			// now write in the libsvm format
+			LIBLINEARStringOutcomeDataWriter dataWriter = new LIBLINEARStringOutcomeDataWriter(this.getModelDirectory(directory));
+			for (Instance<String> instance : instances) {
+				dataWriter.write(featureSelection.transform(instance));
+			}
+			dataWriter.finish();
+		}
+		JarClassifierBuilder.trainAndPackage(this.getModelDirectory(directory), this.trainingArguments);
+	}
+
+	@Override
+	protected AnalysisEngineDescription getAnnotatorDescription(File directory)
+			throws ResourceInitializationException {
+		if(MetaTimeAnnotator.class.isAssignableFrom(this.annotatorClass)){
+			return MetaTimeAnnotator.getAnnotatorDescription(directory);
+		}else if("org.apache.ctakes.temporal.ae.TimeAnnotator".equals(this.annotatorClass.getName()
)){
+			return TimeAnnotator.createAnnotatorDescription(this.getModelDirectory(directory));
+		}
+		return AnalysisEngineFactory.createPrimitiveDescription(
+				this.annotatorClass,
+				CleartkAnnotator.PARAM_IS_TRAINING,
+				false,
+				GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
+				new File(this.getModelDirectory(directory), "model.jar"));
+	}
+
+	@Override
+	protected Collection<? extends Annotation> getGoldAnnotations(JCas jCas, Segment segment)
{
+		return selectExact(jCas, TimeMention.class, segment);
+	}
+
+	@Override
+	protected Collection<? extends Annotation> getSystemAnnotations(JCas jCas, Segment
segment) {
+		return selectExact(jCas, TimeMention.class, segment);
+	}
+
+	private File getModelDirectory(File directory) {
+		return new File(directory, this.annotatorClass.getSimpleName());
+	}
 }



Mime
View raw message