ctakes-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tm...@apache.org
Subject svn commit: r1666502 - in /ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval: EvaluationOfEventCoreference.java EvaluationOfMarkableSalience.java EvaluationOfMarkableSpans.java
Date Fri, 13 Mar 2015 16:23:46 GMT
Author: tmill
Date: Fri Mar 13 16:23:46 2015
New Revision: 1666502

URL: http://svn.apache.org/r1666502
Log:
Evaluation code refactored from temporal

Added:
    ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfEventCoreference.java
    ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfMarkableSalience.java
    ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfMarkableSpans.java

Added: ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfEventCoreference.java
URL: http://svn.apache.org/viewvc/ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfEventCoreference.java?rev=1666502&view=auto
==============================================================================
--- ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfEventCoreference.java (added)
+++ ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfEventCoreference.java Fri Mar 13 16:23:46 2015
@@ -0,0 +1,700 @@
+package org.apache.ctakes.coreference.eval;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+import org.apache.ctakes.assertion.medfacts.cleartk.PolarityCleartkAnalysisEngine;
+import org.apache.ctakes.core.resource.FileLocator;
+import org.apache.ctakes.core.util.DocumentIDAnnotationUtil;
+import org.apache.ctakes.coreference.ae.CoreferenceChainScoringOutput;
+import org.apache.ctakes.coreference.ae.DeterministicMarkableAnnotator;
+import org.apache.ctakes.coreference.ae.EventCoreferenceAnnotator;
+import org.apache.ctakes.coreference.ae.MarkableSalienceAnnotator;
+import org.apache.ctakes.coreference.ae.MentionClusterCoreferenceAnnotator;
+import org.apache.ctakes.coreference.ae.PersonChainAnnotator;
+import org.apache.ctakes.dependency.parser.util.DependencyUtility;
+import org.apache.ctakes.relationextractor.eval.RelationExtractorEvaluation.HashableArguments;
+import org.apache.ctakes.temporal.ae.DocTimeRelAnnotator;
+import org.apache.ctakes.temporal.ae.EventAnnotator;
+import org.apache.ctakes.temporal.eval.EvaluationOfEventTimeRelations;
+import org.apache.ctakes.temporal.eval.EvaluationOfTemporalRelations_ImplBase;
+import org.apache.ctakes.temporal.eval.Evaluation_ImplBase;
+import org.apache.ctakes.temporal.eval.EvaluationOfEventTimeRelations.ParameterSettings;
+import org.apache.ctakes.temporal.eval.EvaluationOfTemporalRelations_ImplBase.TempRelOptions;
+import org.apache.ctakes.temporal.eval.Evaluation_ImplBase.Subcorpus;
+import org.apache.ctakes.temporal.eval.Evaluation_ImplBase.XMLFormat;
+import org.apache.ctakes.typesystem.type.constants.CONST;
+import org.apache.ctakes.typesystem.type.relation.BinaryTextRelation;
+import org.apache.ctakes.typesystem.type.relation.CollectionTextRelation;
+import org.apache.ctakes.typesystem.type.relation.CoreferenceRelation;
+import org.apache.ctakes.typesystem.type.relation.RelationArgument;
+import org.apache.ctakes.typesystem.type.syntax.BaseToken;
+import org.apache.ctakes.typesystem.type.syntax.ConllDependencyNode;
+import org.apache.ctakes.typesystem.type.syntax.NewlineToken;
+import org.apache.ctakes.typesystem.type.syntax.WordToken;
+import org.apache.ctakes.typesystem.type.textsem.Markable;
+import org.apache.ctakes.typesystem.type.textspan.Paragraph;
+import org.apache.ctakes.utils.distsem.WordEmbeddings;
+import org.apache.ctakes.utils.distsem.WordVector;
+import org.apache.ctakes.utils.distsem.WordVectorReader;
+import org.apache.log4j.Level;
+import org.apache.log4j.Logger;
+import org.apache.uima.UimaContext;
+import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
+import org.apache.uima.analysis_engine.metadata.FixedFlow;
+import org.apache.uima.analysis_engine.metadata.FlowConstraints;
+import org.apache.uima.cas.CAS;
+import org.apache.uima.cas.CASException;
+import org.apache.uima.collection.CollectionReader;
+import org.apache.uima.fit.component.ViewCreatorAnnotator;
+import org.apache.uima.fit.descriptor.ConfigurationParameter;
+import org.apache.uima.fit.factory.AggregateBuilder;
+import org.apache.uima.fit.factory.AnalysisEngineFactory;
+import org.apache.uima.fit.factory.FlowControllerFactory;
+import org.apache.uima.fit.pipeline.JCasIterator;
+import org.apache.uima.fit.pipeline.SimplePipeline;
+import org.apache.uima.fit.util.JCasUtil;
+import org.apache.uima.flow.FinalStep;
+import org.apache.uima.flow.Flow;
+import org.apache.uima.flow.FlowControllerContext;
+import org.apache.uima.flow.FlowControllerDescription;
+import org.apache.uima.flow.JCasFlow_ImplBase;
+import org.apache.uima.flow.SimpleStep;
+import org.apache.uima.flow.Step;
+import org.apache.uima.jcas.JCas;
+import org.apache.uima.jcas.cas.EmptyFSList;
+import org.apache.uima.jcas.cas.FSArray;
+import org.apache.uima.jcas.cas.FSList;
+import org.apache.uima.jcas.cas.FloatArray;
+import org.apache.uima.jcas.cas.NonEmptyFSList;
+import org.apache.uima.jcas.tcas.Annotation;
+import org.apache.uima.resource.ResourceInitializationException;
+import org.apache.uima.util.FileUtils;
+import org.cleartk.eval.AnnotationStatistics;
+import org.cleartk.ml.jar.JarClassifierBuilder;
+import org.cleartk.ml.liblinear.LibLinearStringOutcomeDataWriter;
+import org.cleartk.ml.libsvm.tk.TkLibSvmStringOutcomeDataWriter;
+import org.cleartk.ml.tksvmlight.model.CompositeKernel.ComboOperator;
+import org.cleartk.util.ViewUriUtil;
+
+import com.google.common.base.Function;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
+import com.lexicalscope.jewel.cli.CliFactory;
+import com.lexicalscope.jewel.cli.Option;
+
+public class EvaluationOfEventCoreference extends EvaluationOfTemporalRelations_ImplBase {
+ 
+
+  /**
+   * Command-line options for the coreference evaluation, layered on top of the
+   * options inherited from {@link TempRelOptions}. Parsed by JewelCLI in {@code main}.
+   */
+  interface CoreferenceOptions extends TempRelOptions{
+    /** Which coreference system to evaluate; {@code -t}, defaults to MENTION_PAIR. */
+    @Option(shortName="t", defaultValue={"MENTION_PAIR"})
+    EVAL_SYSTEM getEvalSystem();
+
+    /** Prefix to which "gold.chains" / "system.chains" are appended for the chain output files. */
+    @Option
+    String getOutputDirectory();
+
+    /** When set, models are written to a temporary directory that is deleted after the run. */
+    @Option
+    boolean getUseTmp();
+
+    /** When set, the training items are also used as the test items. */
+    @Option
+    boolean getTestOnTrain();
+
+    /** When set ({@code --external}), the external perl scorer is run on the chain files. */
+    @Option(longName="external")
+    boolean getUseExternalScorer();
+  }
+  
+  // log4j logger shared by this class and its nested annotators
+  private static Logger logger = Logger.getLogger(EvaluationOfEventCoreference.class);
+  // Passed as the second ParameterSettings argument below.
+  // NOTE(review): presumably the probability of keeping a negative training example
+  // (train() reads params.probabilityOfKeepingANegativeExample) -- confirm against ParameterSettings.
+  public static float COREF_DOWNSAMPLE = 0.5f;
+  // Default parameter settings used by main(); the meaning of the positional numeric
+  // arguments is defined by the ParameterSettings constructor, which is not visible here.
+  protected static ParameterSettings allParams = new ParameterSettings(DEFAULT_BOTH_DIRECTIONS, COREF_DOWNSAMPLE, "tk",
+      1.0, 1.0, "linear", ComboOperator.SUM, 0.1, 0.5);
+
+  /**
+   * Command-line entry point: parses {@link CoreferenceOptions}, trains and tests the
+   * selected coreference system, prints the resulting statistics to stderr and, when
+   * requested, runs the external scorer on the written chain files.
+   *
+   * @param args command-line arguments understood by {@link CoreferenceOptions}
+   * @throws Exception if option parsing, training, testing or scoring fails
+   */
+  public static void main(String[] args) throws Exception {
+    CoreferenceOptions options = CliFactory.parseArguments(CoreferenceOptions.class, args);
+
+    List<Integer> patientSets = options.getPatients().getList();
+    List<Integer> trainItems = getTrainItems(options);
+    List<Integer> testItems = options.getTestOnTrain() ? getTrainItems(options) : getTestItems(options);
+
+    ParameterSettings params = allParams;
+    File workingDir = new File("target/eval/temporal-relations/coreference");
+    if(!workingDir.exists()) workingDir.mkdirs();
+    if(options.getUseTmp()){
+      // create the unique directory in one step instead of createTempFile/delete/mkdir,
+      // which leaves a window in which another process can claim the name
+      workingDir = java.nio.file.Files.createTempDirectory(workingDir.toPath(), "temporal").toFile();
+    }
+    try{
+      EvaluationOfEventCoreference eval = new EvaluationOfEventCoreference(
+          workingDir,
+          options.getRawTextDirectory(),
+          options.getXMLDirectory(),
+          options.getXMLFormat(),
+          options.getSubcorpus(),
+          options.getXMIDirectory(),
+          options.getTreebankDirectory(),
+          options.getPrintErrors(),
+          options.getPrintFormattedRelations(),
+          params,
+          options.getKernelParams(),
+          options.getOutputDirectory());
+
+      if(options.getSkipTrain()){
+        eval.skipTrain = true;
+      }
+      if(options.getSkipDataWriting()){
+        eval.skipWrite = true;
+      }
+      eval.evalType = options.getEvalSystem();
+      eval.prepareXMIsFor(patientSets);
+
+      params.stats = eval.trainAndTest(trainItems, testItems);
+      System.err.println(params.stats);
+    }finally{
+      // remove the temporary model directory even when the evaluation fails
+      if(options.getUseTmp()){
+        FileUtils.deleteRecursive(workingDir);
+      }
+    }
+
+    if(options.getUseExternalScorer()){
+      runExternalScorer(options.getOutputDirectory());
+    }
+  }
+
+  /**
+   * Runs the external perl coreference scorer on {@code gold.chains} and
+   * {@code system.chains} and prints a recall/precision/F1 table, followed by the
+   * CoNLL average when muc, bcub and ceafe were all reported.
+   *
+   * @param outputDirectory prefix to which the chain file names are appended
+   * @throws IOException if the scorer cannot be started or its output cannot be read
+   * @throws InterruptedException if interrupted while waiting for the scorer to exit
+   */
+  private static void runExternalScorer(String outputDirectory) throws IOException, InterruptedException {
+    Pattern patt = Pattern.compile("(?:Coreference|BLANC): Recall: \\([^\\)]*\\) (\\S+)%.*Precision: \\([^\\)]*\\) (\\S+)%.*F1: (\\S+)%");
+    // NOTE: the scorer location is hard-coded to the original author's machine
+    ProcessBuilder builder = new ProcessBuilder(
+        "perl",
+        "/home/tmill/soft/reference-coreference-scorers-read-only/scorer.pl",
+        "all",
+        outputDirectory + "gold.chains",
+        outputDirectory + "system.chains",
+        "none");
+    // forward the scorer's stderr so an unread pipe can never fill up and stall it
+    builder.redirectError(ProcessBuilder.Redirect.INHERIT);
+    Process p = builder.start();
+
+    System.out.println(String.format("%10s%7s%7s%7s", "Metric", "Rec", "Prec", "F1"));
+    Map<String,Double> scores = new HashMap<>();
+    try(BufferedReader reader = new BufferedReader(new InputStreamReader(p.getInputStream()))){
+      String line, metric=null;
+      while((line = reader.readLine()) != null){
+        line = line.trim();
+        if(line.startsWith("METRIC")){
+          // everything after "METRIC", minus the trailing colon if there is one
+          metric = line.substring("METRIC".length()).trim();
+          if(metric.endsWith(":")){
+            metric = metric.substring(0, metric.length()-1);
+          }
+        }else if(line.startsWith("Coreference") || line.startsWith("BLANC")){
+          // the pattern accepts both prefixes, so both kinds of score line must reach it
+          Matcher m = patt.matcher(line);
+          if(m.matches()){
+            System.out.println(String.format("%10s%7.2f%7.2f%7.2f", metric, Double.parseDouble(m.group(1)), Double.parseDouble(m.group(2)), Double.parseDouble(m.group(3))));
+            scores.put(metric, Double.parseDouble(m.group(3)));
+          }
+        }
+      }
+    }
+    int exitCode = p.waitFor();
+    if(exitCode != 0){
+      System.err.println("External scorer exited with status " + exitCode);
+    }
+
+    if(scores.containsKey("muc") && scores.containsKey("bcub") && scores.containsKey("ceafe")){
+      double conll = (scores.get("muc") + scores.get("bcub") + scores.get("ceafe")) / 3.0;
+      System.out.println(String.format("%10s              %7.2f", "Conll", conll));
+    }
+  }
+  
+  // when true, train() returns immediately without writing data or training a model
+  boolean skipTrain=false; 
+  // when true, train() skips the data-writing pipeline and only runs trainAndPackage
+  boolean skipWrite=false;
+  // The coreference systems that can be evaluated. train() and test() only add an
+  // annotator for MENTION_PAIR and MENTION_CLUSTER; BASELINE adds none in the code shown here.
+  public enum EVAL_SYSTEM { BASELINE, MENTION_PAIR, MENTION_CLUSTER };
+  // system selected for this run; assigned by main() from the command-line options
+  EVAL_SYSTEM evalType;
+  
+  // prefix to which "gold.chains" / "system.chains" are appended in test()
+  private String outputDirectory;
+  
+  /**
+   * Builds an evaluation; all arguments except the last two are handed straight to
+   * the superclass constructor.
+   *
+   * @param cmdParams space-separated classifier options, or null to derive them from {@code params}
+   * @param outputDirectory prefix used for the gold/system chain output files
+   */
+  public EvaluationOfEventCoreference(File baseDirectory,
+      File rawTextDirectory, File xmlDirectory,
+      XMLFormat xmlFormat, Subcorpus subcorpus,
+      File xmiDirectory, File treebankDirectory, boolean printErrors,
+      boolean printRelations, ParameterSettings params, String cmdParams, String outputDirectory) {
+    super(baseDirectory, rawTextDirectory, xmlDirectory, xmlFormat, subcorpus, xmiDirectory,
+        treebankDirectory, printErrors, printRelations, params);
+    if(cmdParams != null){
+      this.kernelParams = cmdParams.split(" ");
+    }else{
+      this.kernelParams = null;
+    }
+    this.outputDirectory = outputDirectory;
+  }
+
+  /**
+   * Writes training data for the selected coreference system (unless
+   * {@code skipWrite} is set) and then trains and packages the classifier in
+   * {@code directory}. Does nothing at all when {@code skipTrain} is set.
+   *
+   * @param collectionReader reader over the training documents
+   * @param directory directory receiving the training data and the packaged model
+   * @throws Exception if the pipeline or the classifier training fails
+   */
+  @Override
+  protected void train(CollectionReader collectionReader, File directory)
+      throws Exception {
+    if(skipTrain) return;
+    if(!skipWrite){
+      AggregateBuilder aggregateBuilder = this.getPreprocessorAggregateBuilder();
+      aggregateBuilder.add(PolarityCleartkAnalysisEngine.createAnnotatorDescription());
+      aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(ViewCreatorAnnotator.class, ViewCreatorAnnotator.PARAM_VIEW_NAME, "Baseline"));
+      aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(DocumentIDPrinter.class));
+      aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(ParagraphAnnotator.class));
+      aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(ParagraphVectorAnnotator.class));
+      aggregateBuilder.add(EventAnnotator.createAnnotatorDescription());
+      aggregateBuilder.add(DocTimeRelAnnotator.createAnnotatorDescription("/org/apache/ctakes/temporal/ae/doctimerel/model.jar"));
+      aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(DeterministicMarkableAnnotator.class));
+      aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(RemovePersonMarkables.class));
+      aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(CopyCoreferenceRelations.class, CopyCoreferenceRelations.PARAM_GOLD_VIEW, GOLD_VIEW_NAME));
+      aggregateBuilder.add(MarkableSalienceAnnotator.createAnnotatorDescription("/org/apache/ctakes/temporal/ae/salience/model.jar"));
+      if(this.evalType == EVAL_SYSTEM.MENTION_PAIR){
+        // alternatives tried here: TKSVMlightStringOutcomeDataWriter, LibLinearStringOutcomeDataWriter
+        aggregateBuilder.add(EventCoreferenceAnnotator.createDataWriterDescription(
+            TkLibSvmStringOutcomeDataWriter.class,
+            directory,
+            params.probabilityOfKeepingANegativeExample
+            ));
+      }else if(this.evalType == EVAL_SYSTEM.MENTION_CLUSTER){
+        // alternative tried here: TkLibSvmStringOutcomeDataWriter
+        aggregateBuilder.add(MentionClusterCoreferenceAnnotator.createDataWriterDescription(
+            LibLinearStringOutcomeDataWriter.class,
+            directory,
+            params.probabilityOfKeepingANegativeExample
+            ));
+      }
+      Logger.getLogger(EventCoreferenceAnnotator.class).setLevel(Level.WARN);
+      // the flow controller stops processing documents that have no gold coreference relations
+      FlowControllerDescription corefFlowControl = FlowControllerFactory.createFlowControllerDescription(CorefEvalFlowController.class);
+      aggregateBuilder.setFlowControllerDescription(corefFlowControl);
+
+      SimplePipeline.runPipeline(collectionReader, aggregateBuilder.createAggregate());
+    }
+    String[] optArray;
+
+    if(this.kernelParams == null){
+      List<String> svmOptions = new ArrayList<>();
+      svmOptions.add("-c"); svmOptions.add(""+params.svmCost);        // svm cost
+      svmOptions.add("-t"); svmOptions.add(""+params.svmKernelIndex); // kernel index
+      svmOptions.add("-d"); svmOptions.add("3");                      // degree parameter for polynomial
+      svmOptions.add("-g"); svmOptions.add(""+params.svmGamma);
+      if(params.svmKernelIndex==ParameterSettings.SVM_KERNELS.indexOf("tk")){
+        svmOptions.add("-S"); svmOptions.add(""+params.secondKernelIndex);   // second kernel index (similar to -t) for composite kernel
+        String comboFlag = (params.comboOperator == ComboOperator.SUM ? "+" : params.comboOperator == ComboOperator.PRODUCT ? "*" : params.comboOperator == ComboOperator.TREE_ONLY ? "T" : "V");
+        svmOptions.add("-C"); svmOptions.add(comboFlag);
+        svmOptions.add("-L"); svmOptions.add(""+params.lambda);
+        svmOptions.add("-T"); svmOptions.add(""+params.tkWeight);
+        svmOptions.add("-N"); svmOptions.add("3");   // normalize trees and features
+      }
+      optArray = svmOptions.toArray(new String[]{});
+    }else{
+      // work on a copy: prefixing the shared kernelParams array in place would add
+      // another "-" to every flag each time train() is called again
+      optArray = this.kernelParams.clone();
+      for(int i = 0; i < optArray.length; i+=2){
+        optArray[i] = "-" + optArray[i];
+      }
+    }
+    JarClassifierBuilder.trainAndPackage(directory, optArray);
+  }
+
+  /**
+   * Runs the selected coreference system over the test documents, writes gold and
+   * system chains to the configured output location, and scores system
+   * {@link CoreferenceRelation}s against those in the gold view.
+   *
+   * @param collectionReader reader over the test documents
+   * @param directory directory containing the trained {@code model.jar}
+   * @return statistics accumulated over every document the pipeline returns
+   * @throws Exception if building or running the pipeline fails
+   */
+  @Override
+  protected AnnotationStatistics<String> test(
+      CollectionReader collectionReader, File directory) throws Exception {
+    AggregateBuilder pipeline = this.getPreprocessorAggregateBuilder();
+    pipeline.add(PolarityCleartkAnalysisEngine.createAnnotatorDescription());
+    pipeline.add(AnalysisEngineFactory.createEngineDescription(DocumentIDPrinter.class));
+    pipeline.add(AnalysisEngineFactory.createEngineDescription(ParagraphAnnotator.class));
+    pipeline.add(AnalysisEngineFactory.createEngineDescription(ParagraphVectorAnnotator.class));
+    pipeline.add(AnalysisEngineFactory.createEngineDescription(DeterministicMarkableAnnotator.class));
+    pipeline.add(EventAnnotator.createAnnotatorDescription());
+    pipeline.add(DocTimeRelAnnotator.createAnnotatorDescription("/org/apache/ctakes/temporal/ae/doctimerel/model.jar"));
+    // gold chains are written before person markables are stripped from the system view
+    pipeline.add(AnalysisEngineFactory.createEngineDescription(CoreferenceChainScoringOutput.class,
+        CoreferenceChainScoringOutput.PARAM_OUTPUT_FILENAME,
+        this.outputDirectory + "gold.chains",
+        CoreferenceChainScoringOutput.PARAM_GOLD_VIEW_NAME,
+        GOLD_VIEW_NAME));
+    pipeline.add(AnalysisEngineFactory.createEngineDescription(RemovePersonMarkables.class));
+    pipeline.add(MarkableSalienceAnnotator.createAnnotatorDescription("/org/apache/ctakes/temporal/ae/salience/model.jar"));
+    if(this.evalType == EVAL_SYSTEM.MENTION_PAIR){
+      pipeline.add(EventCoreferenceAnnotator.createAnnotatorDescription(directory.getAbsolutePath() + File.separator + "model.jar"));
+    }else if(this.evalType == EVAL_SYSTEM.MENTION_CLUSTER){
+      pipeline.add(MentionClusterCoreferenceAnnotator.createAnnotatorDescription(directory.getAbsolutePath() + File.separator + "model.jar"));
+    }
+    pipeline.add(PersonChainAnnotator.createAnnotatorDescription());
+    pipeline.add(AnalysisEngineFactory.createEngineDescription(CoreferenceChainScoringOutput.class,
+        CoreferenceChainScoringOutput.PARAM_OUTPUT_FILENAME,
+        this.outputDirectory + "system.chains"));
+
+    pipeline.setFlowControllerDescription(
+        FlowControllerFactory.createFlowControllerDescription(CorefEvalFlowController.class));
+
+    // relations are matched on their argument spans; every relation carries the same outcome label
+    Function<CoreferenceRelation, ?> toSpan = new Function<CoreferenceRelation, HashableArguments>() {
+      public HashableArguments apply(CoreferenceRelation relation) {
+        return new HashableArguments(relation);
+      }
+    };
+    Function<CoreferenceRelation, String> toOutcome = new Function<CoreferenceRelation,String>() {
+      public String apply(CoreferenceRelation relation){
+        return "Coreference";
+      }
+    };
+
+    AnnotationStatistics<String> stats = new AnnotationStatistics<>();
+
+    Iterator<JCas> documents = new JCasIterator(collectionReader, pipeline.createAggregate());
+    while(documents.hasNext()){
+      JCas jCas = documents.next();
+      Collection<CoreferenceRelation> goldRelations = JCasUtil.select(
+          jCas.getView(GOLD_VIEW_NAME),
+          CoreferenceRelation.class);
+      Collection<CoreferenceRelation> systemRelations = JCasUtil.select(
+          jCas.getView(CAS.NAME_DEFAULT_SOFA),
+          CoreferenceRelation.class);
+      stats.add(goldRelations, systemRelations, toSpan, toOutcome);
+      if(this.printErrors){
+        printRelationErrors(goldRelations, systemRelations);
+      }
+    }
+
+    return stats;
+  }
+
+  /**
+   * Prints one line per relation found in either collection, in sorted argument
+   * order, saying whether the system added, dropped, relabeled or matched it.
+   */
+  private void printRelationErrors(Collection<CoreferenceRelation> goldRelations,
+      Collection<CoreferenceRelation> systemRelations) {
+    Map<HashableArguments, BinaryTextRelation> goldByArgs = Maps.newHashMap();
+    for (BinaryTextRelation gold : goldRelations) {
+      goldByArgs.put(new HashableArguments(gold), gold);
+    }
+    Map<HashableArguments, BinaryTextRelation> systemByArgs = Maps.newHashMap();
+    for (BinaryTextRelation system : systemRelations) {
+      systemByArgs.put(new HashableArguments(system), system);
+    }
+    List<HashableArguments> keys = Lists.newArrayList(Sets.union(goldByArgs.keySet(), systemByArgs.keySet()));
+    Collections.sort(keys);
+    for (HashableArguments key : keys) {
+      BinaryTextRelation gold = goldByArgs.get(key);
+      BinaryTextRelation system = systemByArgs.get(key);
+      if (gold == null) {
+        System.out.println("System added: " + formatRelation(system));
+        continue;
+      }
+      if (system == null) {
+        System.out.println("System dropped: " + formatRelation(gold));
+        continue;
+      }
+      if (system.getCategory().equals(gold.getCategory())) {
+        System.out.println("Nailed it! " + formatRelation(system));
+      } else {
+        System.out.printf("System labeled %s for %s\n", system.getCategory(), formatRelation(gold));
+      }
+    }
+  }
+  
+  /**
+   * Orders annotations by begin offset and, for equal begins, by end offset.
+   * Always returns exactly -1, 0 or 1.
+   */
+  public static class AnnotationComparator implements Comparator<Annotation> {
+
+    @Override
+    public int compare(Annotation o1, Annotation o2) {
+      final int begin1 = o1.getBegin();
+      final int begin2 = o2.getBegin();
+      if(begin1 != begin2){
+        return begin1 < begin2 ? -1 : 1;
+      }
+      // same start: the annotation that ends first sorts first
+      final int end1 = o1.getEnd();
+      final int end2 = o2.getEnd();
+      if(end1 == end2){
+        return 0;
+      }
+      return end1 < end2 ? -1 : 1;
+    }
+  }
+  /**
+   * Logs the ID of each document as it passes through the pipeline, falling back
+   * to the file name taken from the view URI when no document ID is set.
+   */
+  public static class DocumentIDPrinter extends org.apache.uima.fit.component.JCasAnnotator_ImplBase {
+    static Logger logger = Logger.getLogger(DocumentIDPrinter.class);
+
+    @Override
+    public void process(JCas jCas) throws AnalysisEngineProcessException {
+      String annotatedId = DocumentIDAnnotationUtil.getDocumentID(jCas);
+      String docId = (annotatedId != null)
+          ? annotatedId
+          : new File(ViewUriUtil.getURI(jCas)).getName();
+      String message = String.format("Processing %s\n", docId);
+      logger.info(message);
+    }
+
+  }
+  
+  /**
+   * Creates {@link Paragraph} annotations from the token stream: every run of
+   * non-newline tokens delimited by one or more {@link NewlineToken}s (or by the
+   * start or end of the document) becomes one paragraph.
+   */
+  public static class ParagraphAnnotator extends org.apache.uima.fit.component.JCasAnnotator_ImplBase {
+
+    @Override
+    public void process(JCas jcas) throws AnalysisEngineProcessException {
+      List<BaseToken> tokens = new ArrayList<>(JCasUtil.select(jcas, BaseToken.class));
+      BaseToken lastToken = null;
+      int parStart = 0;
+
+      for(int i = 0; i < tokens.size(); i++){
+        BaseToken token = tokens.get(i);
+        if(parStart == i && token instanceof NewlineToken){
+          // we've just created a paragraph ending but there were multiple newlines -- don't want to start the
+          // new paragraph until we are past the newlines -- increment the parStart index and move forward
+          parStart++;
+        }else if(lastToken != null && token instanceof NewlineToken){
+          Paragraph par = new Paragraph(jcas, tokens.get(parStart).getBegin(), lastToken.getEnd());
+          par.addToIndexes();
+          parStart = i+1;
+        }
+        lastToken = token;
+      }
+
+      // a document that does not end in a newline still has an open paragraph here;
+      // parStart only stays below tokens.size() when the final token is not a newline
+      if(lastToken != null && parStart < tokens.size()){
+        Paragraph par = new Paragraph(jcas, tokens.get(parStart).getBegin(), lastToken.getEnd());
+        par.addToIndexes();
+      }
+    }
+
+  }
+  
+  /**
+   * Builds one vector per {@link Paragraph} by summing the embeddings of its
+   * lower-cased word tokens and scaling the sum to unit length. The vectors are
+   * stored as {@link FloatArray}s inside an {@link FSArray}, in paragraph order.
+   */
+  public static class ParagraphVectorAnnotator extends org.apache.uima.fit.component.JCasAnnotator_ImplBase {
+    WordEmbeddings words = null;
+
+    /**
+     * Loads the word embeddings from the classpath resource.
+     *
+     * @throws ResourceInitializationException if the embeddings cannot be read
+     */
+    @Override
+    public void initialize(final UimaContext context) throws ResourceInitializationException{
+      // let the base class record the context and run its own initialization first
+      super.initialize(context);
+      // try-with-resources so the embeddings stream is closed once it has been read
+      try (java.io.InputStream vectorStream = FileLocator.getAsStream("org/apache/ctakes/coreference/distsem/mimic_vectors.txt")) {
+        words = WordVectorReader.getEmbeddings(vectorStream);
+      } catch (IOException e) {
+        throw new ResourceInitializationException(e);
+      }
+    }
+
+    @Override
+    public void process(JCas jcas) throws AnalysisEngineProcessException {
+      List<Paragraph> pars = new ArrayList<>(JCasUtil.select(jcas, Paragraph.class));
+      FSArray parVecs = new FSArray(jcas, pars.size());
+      for(int parNum = 0; parNum < pars.size(); parNum++){
+        Paragraph par = pars.get(parNum);
+        float[] parVec = new float[words.getDimensionality()];
+
+        for(BaseToken token : JCasUtil.selectCovered(BaseToken.class, par)){
+          if(token instanceof WordToken){
+            String word = token.getCoveredText().toLowerCase();
+            if(words.containsKey(word)){
+              WordVector wv = words.getVector(word);
+              for(int j = 0; j < parVec.length; j++){
+                parVec[j] += wv.getValue(j);
+              }
+            }
+          }
+        }
+        normalize(parVec);
+        FloatArray vec = new FloatArray(jcas, words.getDimensionality());
+        vec.copyFromArray(parVec, 0, 0, parVec.length);
+        vec.addToIndexes();
+        parVecs.set(parNum, vec);
+      }
+      parVecs.addToIndexes();
+    }
+
+    /**
+     * Scales {@code vec} in place to unit Euclidean length. An all-zero vector
+     * (a paragraph with no known words) is left untouched, because dividing by
+     * its zero norm would turn every component into NaN.
+     */
+    private static final void normalize(float[] vec) {
+      double sum = 0.0;
+      for(int i = 0; i < vec.length; i++){
+        sum += (vec[i]*vec[i]);
+      }
+      sum = Math.sqrt(sum);
+      if(sum == 0.0){
+        return;
+      }
+      for(int i = 0; i < vec.length; i++){
+        vec[i] /= sum;
+      }
+    }
+  }
+  
+  /**
+   * Copies gold coreference annotations into the CAS passed to process(): each gold chain
+   * ({@link CollectionTextRelation}) and each gold {@link CoreferenceRelation} is
+   * re-created there, with gold markables replaced by the markable that shares the
+   * same nominal head node. A chain is dropped as soon as one of its members cannot be mapped.
+   */
+  public static class CopyCoreferenceRelations extends org.apache.uima.fit.component.JCasAnnotator_ImplBase {
+
+    public static final String PARAM_GOLD_VIEW = "GoldViewName";
+    @ConfigurationParameter(name=PARAM_GOLD_VIEW, mandatory=true, description="View containing gold standard annotations")
+    private String goldViewName;
+    
+    @SuppressWarnings("synthetic-access")
+    @Override
+    public void process(JCas jcas) throws AnalysisEngineProcessException {
+      JCas goldView = null;
+      try {
+        goldView = jcas.getView(goldViewName);
+      } catch (CASException e) {
+        e.printStackTrace();
+        throw new AnalysisEngineProcessException(e);
+      }
+      
+      // gold markable -> markable in jcas with the same nominal head node
+      HashMap<Markable,Markable> gold2sys = new HashMap<>();
+      // dependency node -> markables in jcas that cover it
+      Map<ConllDependencyNode,Collection<Markable>> depIndex = JCasUtil.indexCovering(jcas, ConllDependencyNode.class, Markable.class);
+      // remove those with removed markables (person mentions)
+      // NOTE(review): toRemove is never read or written below -- looks like a leftover.
+      List<CollectionTextRelation> toRemove = new ArrayList<>();
+      
+      for(CollectionTextRelation goldChain : JCasUtil.select(goldView, CollectionTextRelation.class)){
+        FSList head = goldChain.getMembers();
+        // sysList is the first node of the copied chain; listEnd tracks its current last node
+        NonEmptyFSList sysList = new NonEmptyFSList(jcas);
+        NonEmptyFSList listEnd = sysList;
+        boolean removeChain = false;
+        
+        // first one is guaranteed to be nonempty otherwise it would not be in cas
+        do{
+          NonEmptyFSList element = (NonEmptyFSList) head;
+          // if this is not first time through move listEnd to end.
+          if(listEnd.getHead() != null){
+            listEnd.setTail(new NonEmptyFSList(jcas));
+            listEnd.addToIndexes();
+            listEnd = (NonEmptyFSList) listEnd.getTail();
+          }
+          Markable goldMarkable = (Markable) element.getHead();
+          // NOTE(review): ">=" also rejects a markable whose end equals the document length;
+          // confirm whether a markable ending exactly at the end of the text should be excluded.
+          if(!(goldMarkable.getBegin() < 0 || goldMarkable.getEnd() >= jcas.getDocumentText().length())){
+            
+          
+            ConllDependencyNode headNode = DependencyUtility.getNominalHeadNode(jcas, goldMarkable);
+
+            // pick the first covering markable whose own nominal head is the same node
+            for(Markable sysMarkable : depIndex.get(headNode)){
+              ConllDependencyNode markNode = DependencyUtility.getNominalHeadNode(jcas, sysMarkable);
+              if(markNode == headNode){
+                gold2sys.put(goldMarkable, sysMarkable);
+                break;
+              }
+            }
+            // NOTE(review): mappedGold is indexed but never put into gold2sys, so the lookup
+            // below still fails and the chain is dropped -- confirm whether a put() is missing.
+            if(!gold2sys.containsKey(goldMarkable)){
+              Markable mappedGold = new Markable(jcas, goldMarkable.getBegin(), goldMarkable.getEnd());
+              mappedGold.addToIndexes();
+            }
+          }else{
+            // Have seen some instances where anafora writes a span that is not possible, log them
+            // so they can be found and fixed:
+            logger.warn(String.format("There is a markable with span [%d, %d] in a document with length %d\n", 
+                goldMarkable.getBegin(), goldMarkable.getEnd(), jcas.getDocumentText().length()));
+          }
+          
+          // add markable to end of list:
+          if(gold2sys.get(goldMarkable) == null){
+            logger.warn(String.format("There is a gold markable [%d, %d] which could not map to a system markable.", 
+                goldMarkable.getBegin(), goldMarkable.getEnd()));
+            // abandon the whole chain; list nodes already added to the indexes above stay there
+            removeChain = true;
+            break;
+          }
+          listEnd.setHead(gold2sys.get(goldMarkable));
+          
+          head = element.getTail();
+        }while(head instanceof NonEmptyFSList);
+        
+        // don't bother copying over -- the gold chain was of person mentions
+        if(!removeChain){      
+          // terminate the copied list and register it as a chain in jcas
+          listEnd.setTail(new EmptyFSList(jcas));
+          listEnd.addToIndexes();
+          listEnd.getTail().addToIndexes();
+          sysList.addToIndexes();
+          CollectionTextRelation sysRel = new CollectionTextRelation(jcas);
+          sysRel.setMembers(sysList);
+          sysRel.addToIndexes();
+        }
+      }
+      
+      // copy only the pairwise relations whose two arguments were both mapped above
+      for(CoreferenceRelation goldRel : JCasUtil.select(goldView, CoreferenceRelation.class)){
+        if((gold2sys.containsKey(goldRel.getArg1().getArgument()) && gold2sys.containsKey(goldRel.getArg2().getArgument()))){
+          CoreferenceRelation sysRel = new CoreferenceRelation(jcas);
+          sysRel.setCategory(goldRel.getCategory());
+          sysRel.setDiscoveryTechnique(CONST.REL_DISCOVERY_TECH_GOLD_ANNOTATION);
+
+          RelationArgument arg1 = new RelationArgument(jcas);
+          arg1.setArgument(gold2sys.get(goldRel.getArg1().getArgument()));
+          sysRel.setArg1(arg1);
+          arg1.addToIndexes();
+
+          RelationArgument arg2 = new RelationArgument(jcas);
+          arg2.setArgument(gold2sys.get(goldRel.getArg2().getArgument()));
+          sysRel.setArg2(arg2);
+          arg2.addToIndexes();         
+          
+          sysRel.addToIndexes();        
+        }
+      }
+    }    
+  }
+  public static class RemovePersonMarkables extends org.apache.uima.fit.component.JCasAnnotator_ImplBase {
+
+    @Override
+    public void process(JCas jcas) throws AnalysisEngineProcessException {
+//      JCas systemView=null, goldView=null;
+//      try{
+//        systemView = jcas.getView(CAS.NAME_DEFAULT_SOFA);
+//        goldView = jcas.getView(GOLD_VIEW_NAME);
+//      }catch(Exception e){
+//        throw new AnalysisEngineProcessException(e);
+//      }
+      List<Markable> toRemove = new ArrayList<>();
+      for(Markable markable : JCasUtil.select(jcas, Markable.class)){
+        List<BaseToken> coveredTokens = JCasUtil.selectCovered(jcas, BaseToken.class, markable);
+        if(coveredTokens.size() == 1 && coveredTokens.get(0).getPartOfSpeech().startsWith("PRP")){
+          toRemove.add(markable);
+        }else if(coveredTokens.size() == 2 && 
+            (coveredTokens.get(0).getCoveredText().startsWith("Mr.") || coveredTokens.get(0).getCoveredText().startsWith("Dr.") ||
+                coveredTokens.get(0).getCoveredText().startsWith("Mrs.") || coveredTokens.get(0).getCoveredText().startsWith("Ms."))){
+          toRemove.add(markable);
+        }else if(markable.getCoveredText().toLowerCase().equals("patient")){
+          toRemove.add(markable);
+        }
+      }
+      
+      for(Markable markable : toRemove){
+        markable.removeFromIndexes();
+      }
+    } 
+  }
+  
+  /**
+   * Flow controller modelled on UIMA's FixedFlowController and its internal Flow
+   * object. It runs the aggregate's fixed flow in order, with one change: right
+   * after {@link DocumentIDPrinter} has run it checks the gold view, and if that
+   * view holds no coreference relations the document is finished immediately so
+   * no time is spent running coreference code on it.
+   */
+  public static class CorefEvalFlowController extends org.apache.uima.flow.JCasFlowController_ImplBase {
+    List<String> mSequence;
+
+    /**
+     * Reads the fixed flow sequence from the aggregate metadata.
+     *
+     * @throws ResourceInitializationException if the aggregate does not declare a fixed flow
+     */
+    @Override
+    public void initialize(FlowControllerContext context)
+        throws ResourceInitializationException {
+      super.initialize(context);
+
+      FlowConstraints flowConstraints = context.getAggregateMetadata().getFlowConstraints();
+      mSequence = new ArrayList<>();
+      if (!(flowConstraints instanceof FixedFlow)) {
+        throw new ResourceInitializationException(ResourceInitializationException.FLOW_CONTROLLER_REQUIRES_FLOW_CONSTRAINTS,
+                new Object[]{this.getClass().getName(), "fixedFlow", context.getAggregateMetadata().getSourceUrlString()});
+      }
+      for (String componentKey : ((FixedFlow) flowConstraints).getFixedFlow()) {
+        mSequence.add(componentKey);
+      }
+    }
+
+    @Override
+    public Flow computeFlow(JCas jcas) throws AnalysisEngineProcessException {
+      return new CorefEvalFlow(jcas, 0);
+    }
+
+    /** Per-document flow: walks mSequence one component at a time. */
+    class CorefEvalFlow extends JCasFlow_ImplBase {
+
+      private JCas jcas;
+      private int currentStep;
+
+      public CorefEvalFlow(JCas jcas, int step){
+        this.jcas = jcas;
+        this.currentStep = step;
+      }
+
+      @Override
+      public Step next() {
+        // every component has run: this document is done
+        if (currentStep >= mSequence.size()) {
+          return new FinalStep();
+        }
+
+        boolean idPrinterJustRan = currentStep > 0
+            && mSequence.get(currentStep-1).equals(DocumentIDPrinter.class.getName());
+        if(idPrinterJustRan){
+          try {
+            JCas goldView = jcas.getView(GOLD_VIEW_NAME);
+            // nothing to score in this document, so stop here
+            if(JCasUtil.select(goldView, CoreferenceRelation.class).isEmpty()){
+              System.out.println("Skipping this document with no coreference relations.");
+              return new FinalStep();
+            }
+          } catch (CASException e) {
+            // no need to stop flow -- just go ahead to default simple step.
+            e.printStackTrace();
+          }
+        }
+
+        // hand the document to the next component in the fixed sequence
+        String nextComponent = mSequence.get(currentStep);
+        currentStep++;
+        return new SimpleStep(nextComponent);
+      }
+    }
+  }
+}

Added: ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfMarkableSalience.java
URL: http://svn.apache.org/viewvc/ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfMarkableSalience.java?rev=1666502&view=auto
==============================================================================
--- ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfMarkableSalience.java (added)
+++ ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfMarkableSalience.java Fri Mar 13 16:23:46 2015
@@ -0,0 +1,239 @@
+package org.apache.ctakes.coreference.eval;
+
+import java.io.File;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import org.apache.ctakes.assertion.medfacts.cleartk.PolarityCleartkAnalysisEngine;
+import org.apache.ctakes.coreference.ae.DeterministicMarkableAnnotator;
+import org.apache.ctakes.coreference.ae.MarkableSalienceAnnotator;
+import org.apache.ctakes.coreference.eval.EvaluationOfEventCoreference.DocumentIDPrinter;
+import org.apache.ctakes.coreference.eval.EvaluationOfEventCoreference.RemovePersonMarkables;
+import org.apache.ctakes.dependency.parser.util.DependencyUtility;
+import org.apache.ctakes.temporal.eval.Evaluation_ImplBase;
+import org.apache.ctakes.typesystem.type.relation.CollectionTextRelation;
+import org.apache.ctakes.typesystem.type.syntax.ConllDependencyNode;
+import org.apache.ctakes.typesystem.type.textsem.Markable;
+import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
+import org.apache.uima.cas.CAS;
+import org.apache.uima.cas.CASException;
+import org.apache.uima.collection.CollectionReader;
+import org.apache.uima.fit.component.ViewCreatorAnnotator;
+import org.apache.uima.fit.descriptor.ConfigurationParameter;
+import org.apache.uima.fit.factory.AggregateBuilder;
+import org.apache.uima.fit.factory.AnalysisEngineFactory;
+import org.apache.uima.fit.pipeline.JCasIterator;
+import org.apache.uima.fit.pipeline.SimplePipeline;
+import org.apache.uima.fit.util.JCasUtil;
+import org.apache.uima.jcas.JCas;
+import org.apache.uima.jcas.cas.FSList;
+import org.apache.uima.jcas.cas.NonEmptyFSList;
+import org.cleartk.eval.AnnotationStatistics;
+import org.cleartk.ml.jar.JarClassifierBuilder;
+import org.cleartk.ml.liblinear.LibLinearBooleanOutcomeDataWriter;
+
+import com.google.common.base.Function;
+import com.lexicalscope.jewel.cli.CliFactory;
+
+public class EvaluationOfMarkableSalience extends Evaluation_ImplBase<AnnotationStatistics<Boolean>> {
+
+  /**
+   * Command-line entry point: parses the options, prepares the XMI files for the
+   * requested patient sets, trains on the training items, tests on the test
+   * items, and prints the resulting statistics and confusion matrix.
+   */
+  public static void main(String[] args) throws Exception {
+    final Options cli = CliFactory.parseArguments(Options.class, args);
+    final List<Integer> patientSets = cli.getPatients().getList();
+    final List<Integer> trainItems = getTrainItems(cli);
+    final List<Integer> testItems = getTestItems(cli);
+
+    final File outputDirectory = new File("target/eval/salience");
+    final EvaluationOfMarkableSalience evaluation = new EvaluationOfMarkableSalience(
+        outputDirectory,
+        cli.getRawTextDirectory(),
+        cli.getXMLDirectory(),
+        cli.getXMLFormat(),
+        cli.getSubcorpus(),
+        cli.getXMIDirectory(),
+        null); // no treebank directory is supplied
+    evaluation.prepareXMIsFor(patientSets);
+
+    final AnnotationStatistics<Boolean> results = evaluation.trainAndTest(trainItems, testItems);
+    System.out.println(results);
+    System.out.println(results.confusions());
+  }
+
+  /**
+   * Creates the evaluation; every argument is forwarded unchanged to the
+   * {@link Evaluation_ImplBase} constructor.
+   */
+  public EvaluationOfMarkableSalience(
+      File baseDirectory,
+      File rawTextDirectory,
+      File xmlDirectory,
+      Evaluation_ImplBase.XMLFormat xmlFormat,
+      Evaluation_ImplBase.Subcorpus subcorpus,
+      File xmiDirectory,
+      File treebankDirectory) {
+    super(
+        baseDirectory,
+        rawTextDirectory,
+        xmlDirectory,
+        xmlFormat,
+        subcorpus,
+        xmiDirectory,
+        treebankDirectory);
+  }
+
+  /**
+   * Runs the training pipeline over the reader (preprocessing, polarity, markable
+   * creation, gold confidence labelling, salience data writer) and then trains and
+   * packages a LibLinear model in {@code directory}.
+   */
+  @Override
+  protected void train(CollectionReader collectionReader, File directory)
+      throws Exception {
+    AggregateBuilder builder = this.getPreprocessorAggregateBuilder();
+    builder.add(PolarityCleartkAnalysisEngine.createAnnotatorDescription());
+    builder.add(AnalysisEngineFactory.createEngineDescription(DocumentIDPrinter.class));
+    builder.add(AnalysisEngineFactory.createEngineDescription(DeterministicMarkableAnnotator.class));
+    builder.add(AnalysisEngineFactory.createEngineDescription(RemovePersonMarkables.class));
+    // label system markables that line up with gold chain members
+    builder.add(
+        AnalysisEngineFactory.createEngineDescription(
+            SetGoldConfidence.class,
+            SetGoldConfidence.PARAM_GOLD_VIEW,
+            GOLD_VIEW_NAME));
+    // write the training instances into the model directory
+    builder.add(
+        AnalysisEngineFactory.createEngineDescription(
+            MarkableSalienceAnnotator.createDataWriterDescription(
+                LibLinearBooleanOutcomeDataWriter.class,
+                directory)));
+    SimplePipeline.runPipeline(collectionReader, builder.createAggregate());
+
+    // s=0 -> logistic regression with L2-norm (gives probabilistic outputs)
+    String[] liblinearArguments = { "-s", "0", "-c", "1", "-w1", "1" };
+    JarClassifierBuilder.trainAndPackage(directory, liblinearArguments);
+  }
+
+  @Override
+  protected AnnotationStatistics<Boolean> test(
+      CollectionReader collectionReader, File directory) throws Exception {
+    AggregateBuilder aggregateBuilder = this.getPreprocessorAggregateBuilder();
+    aggregateBuilder.add(PolarityCleartkAnalysisEngine.createAnnotatorDescription());
+    aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(DocumentIDPrinter.class));
+    aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(DeterministicMarkableAnnotator.class));
+    aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(RemovePersonMarkables.class));
+    
+    aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(ViewCreatorAnnotator.class, ViewCreatorAnnotator.PARAM_VIEW_NAME, "PseudoGold"));
+    aggregateBuilder.add(AnalysisEngineFactory.createEngineDescription(CreatePseudoGoldMarkables.class, CreatePseudoGoldMarkables.PARAM_GOLD_VIEW, GOLD_VIEW_NAME, CreatePseudoGoldMarkables.PARAM_PSEUDO_GOLD_VIEW, "PseudoGold"));
+    aggregateBuilder.add(MarkableSalienceAnnotator.createAnnotatorDescription(directory.getAbsolutePath() + File.separator + "model.jar"));
+    AnnotationStatistics<Boolean> stats = new AnnotationStatistics<>();
+    
+    for(Iterator<JCas> casIter = new JCasIterator(collectionReader, aggregateBuilder.createAggregate()); casIter.hasNext();){
+      JCas jCas = casIter.next();
+      JCas goldView = jCas.getView("PseudoGold");
+      JCas systemView = jCas.getView(CAS.NAME_DEFAULT_SOFA);
+      
+      stats.add(JCasUtil.select(goldView, Markable.class),
+          JCasUtil.select(systemView, Markable.class),
+          AnnotationStatistics.<Markable>annotationToSpan(),
+          mapConfidenceToBoolean());      
+    }
+    
+    
+    return stats;
+  }
+  
+  /**
+   * Annotator that marks system markables as coreferent for training: for every
+   * member of every gold coreference chain it finds the member's nominal head node
+   * and sets the confidence of the first system markable that covers that node and
+   * has the same nominal head to 1.0.
+   */
+  public static class SetGoldConfidence extends org.apache.uima.fit.component.JCasAnnotator_ImplBase {
+
+    public static final String PARAM_GOLD_VIEW = "GoldViewName";
+    @ConfigurationParameter(name=PARAM_GOLD_VIEW, mandatory=true, description="View containing gold standard annotations")
+    private String goldViewName;
+
+    /**
+     * @param jcas the system view whose markables receive the confidence value
+     * @throws AnalysisEngineProcessException if the configured gold view cannot be obtained
+     */
+    @Override
+    public void process(JCas jcas) throws AnalysisEngineProcessException {
+      final JCas goldView;
+      try {
+        goldView = jcas.getView(goldViewName);
+      } catch (CASException e) {
+        // the exception is propagated with its cause; no separate stack trace dump needed
+        throw new AnalysisEngineProcessException(e);
+      }
+
+      // Without document text no gold span can be in bounds, so treat it as length 0.
+      final String documentText = jcas.getDocumentText();
+      final int documentLength = documentText == null ? 0 : documentText.length();
+
+      // dependency node -> system markables covering it
+      Map<ConllDependencyNode,Collection<Markable>> depIndex = JCasUtil.indexCovering(jcas, ConllDependencyNode.class, Markable.class);
+
+      // iterate over every gold coreference chain
+      for(CollectionTextRelation goldChain : JCasUtil.select(goldView, CollectionTextRelation.class)){
+        // Walk the member list; the instanceof test also covers a null or empty
+        // member list, which would otherwise fail on the NonEmptyFSList cast.
+        for(FSList cursor = goldChain.getMembers();
+            cursor instanceof NonEmptyFSList;
+            cursor = ((NonEmptyFSList) cursor).getTail()){
+          Markable goldMarkable = (Markable) ((NonEmptyFSList) cursor).getHead();
+
+          // skip gold markables whose span does not fit the system view's text
+          // NOTE(review): ">=" also skips a markable ending exactly at the end of the
+          // text -- confirm whether that exclusion is intended.
+          if(goldMarkable.getBegin() < 0 || goldMarkable.getEnd() >= documentLength){
+            continue;
+          }
+
+          // get the head of this markable, then check if there are any system markables with the same
+          // head, and if so, that markable is "true" for being coreferent, AKA high confidence.
+          ConllDependencyNode headNode = DependencyUtility.getNominalHeadNode(jcas, goldMarkable);
+          if(headNode == null){
+            // defensive: nothing to align against when no head node was found
+            continue;
+          }
+          Collection<Markable> candidates = depIndex.get(headNode);
+          if(candidates == null){
+            continue;
+          }
+
+          for(Markable sysMarkable : candidates){
+            ConllDependencyNode markNode = DependencyUtility.getNominalHeadNode(jcas, sysMarkable);
+            if(markNode == headNode){
+              sysMarkable.setConfidence(1.0f);
+              break;
+            }
+          }
+        }
+      }
+    }
+  }
+  
+  public static class CreatePseudoGoldMarkables extends org.apache.uima.fit.component.JCasAnnotator_ImplBase {
+
+    public static final String PARAM_PSEUDO_GOLD_VIEW = "PseudoViewName";
+    @ConfigurationParameter(name = PARAM_PSEUDO_GOLD_VIEW)
+    private String fakeGoldName;
+    
+    public static final String PARAM_GOLD_VIEW = "GoldViewName";
+    @ConfigurationParameter(name = PARAM_GOLD_VIEW)
+    private String goldViewName;
+    
+    @Override
+    public void process(JCas jcas) throws AnalysisEngineProcessException {
+      JCas fakeView = null;
+      JCas goldView = null;
+      
+      try{
+        fakeView = jcas.getView(fakeGoldName);
+        goldView = jcas.getView(goldViewName);
+      }catch(CASException e){
+        throw new AnalysisEngineProcessException(e);
+      }
+      // create a set of markables that map to gold
+      Set<Markable> sys = new HashSet<>();
+      Map<ConllDependencyNode,Collection<Markable>> depIndex = JCasUtil.indexCovering(jcas, ConllDependencyNode.class, Markable.class);
+      
+      // iterate over every gold coreference chain
+      for(CollectionTextRelation goldChain : JCasUtil.select(goldView, CollectionTextRelation.class)){
+        FSList head = goldChain.getMembers();
+        
+        // iterate over every gold markable in the chain
+        // first one is guaranteed to be nonempty otherwise it would not be in cas
+        do{
+          NonEmptyFSList element = (NonEmptyFSList) head;
+          Markable goldMarkable = (Markable) element.getHead();
+          if(!(goldMarkable.getBegin() < 0 || goldMarkable.getEnd() >= jcas.getDocumentText().length())){
+            // get the head of this markable, then check if there are any system markables with the same
+            // head, and if so, that markable is "true" for being coreferent, AKA high confidence.
+            ConllDependencyNode headNode = DependencyUtility.getNominalHeadNode(jcas, goldMarkable);
+
+            for(Markable sysMarkable : depIndex.get(headNode)){
+              ConllDependencyNode markNode = DependencyUtility.getNominalHeadNode(jcas, sysMarkable);
+              if(markNode == headNode){
+                sys.add(sysMarkable);
+                break;
+              }
+            }
+          }
+          head = element.getTail();
+        }while(head instanceof NonEmptyFSList);
+      }
+      
+      // add all system markables to psuedo-gold and with confidence based on whether they map
+      for(Markable markable : JCasUtil.select(jcas, Markable.class)){
+        Markable fakeMarkable = new Markable(fakeView, markable.getBegin(), markable.getEnd());
+        
+        if(sys.contains(markable)){
+          fakeMarkable.setConfidence(1.0f);
+        }else{
+          fakeMarkable.setConfidence(0.0f);
+        }
+        fakeMarkable.addToIndexes();
+      } 
+    }
+  }
+  
+  // Builds the function that turns a markable into its boolean outcome: true when
+  // the markable's confidence is above 0.5. As noted originally, this is predicting
+  // non-singletons rather than singletons.
+  public static Function<Markable,Boolean> mapConfidenceToBoolean(){
+    final Function<Markable,Boolean> confidenceAboveHalf = new Function<Markable,Boolean>() {
+      @Override
+      public Boolean apply(Markable markable) {
+        final boolean aboveHalf = markable.getConfidence() > 0.5;
+        return aboveHalf;
+      }
+    };
+    return confidenceAboveHalf;
+  }
+}

Added: ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfMarkableSpans.java
URL: http://svn.apache.org/viewvc/ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfMarkableSpans.java?rev=1666502&view=auto
==============================================================================
--- ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfMarkableSpans.java (added)
+++ ctakes/sandbox/ctakes-coref-cleartk/src/main/java/org/apache/ctakes/coreference/eval/EvaluationOfMarkableSpans.java Fri Mar 13 16:23:46 2015
@@ -0,0 +1,208 @@
+package org.apache.ctakes.coreference.eval;
+
+import java.io.File;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+import java.util.logging.Level;
+
+import org.apache.ctakes.coreference.ae.DeterministicMarkableAnnotator;
+import org.apache.ctakes.coreference.ae.MarkableAnnotator;
+import org.apache.ctakes.temporal.eval.EvaluationOfAnnotationSpans_ImplBase;
+import org.apache.ctakes.temporal.eval.Evaluation_ImplBase;
+import org.apache.ctakes.temporal.eval.THYMEData;
+import org.apache.ctakes.temporal.eval.Evaluation_ImplBase.CopyFromGold;
+import org.apache.ctakes.temporal.eval.Evaluation_ImplBase.Options;
+import org.apache.ctakes.temporal.eval.Evaluation_ImplBase.XMLFormat;
+import org.apache.ctakes.typesystem.type.textsem.Markable;
+import org.apache.ctakes.typesystem.type.textspan.Segment;
+import org.apache.uima.analysis_engine.AnalysisEngineDescription;
+import org.apache.uima.cas.CAS;
+import org.apache.uima.collection.CollectionReader;
+import org.apache.uima.fit.component.JCasAnnotator_ImplBase;
+import org.apache.uima.fit.factory.AggregateBuilder;
+import org.apache.uima.fit.factory.AnalysisEngineFactory;
+import org.apache.uima.fit.pipeline.SimplePipeline;
+import org.apache.uima.jcas.JCas;
+import org.apache.uima.jcas.tcas.Annotation;
+import org.apache.uima.resource.ResourceInitializationException;
+import org.cleartk.eval.AnnotationStatistics;
+import org.cleartk.ml.CleartkAnnotator;
+import org.cleartk.ml.jar.DefaultDataWriterFactory;
+import org.cleartk.ml.jar.DirectoryDataWriterFactory;
+import org.cleartk.ml.jar.GenericJarClassifierFactory;
+import org.cleartk.ml.jar.JarClassifierBuilder;
+import org.cleartk.ml.liblinear.LibLinearStringOutcomeDataWriter;
+
+import com.lexicalscope.jewel.cli.CliFactory;
+import com.lexicalscope.jewel.cli.Option;
+
+public class EvaluationOfMarkableSpans extends EvaluationOfAnnotationSpans_ImplBase {
+  /** Command-line options specific to the markable span evaluation. */
+  interface Options extends Evaluation_ImplBase.Options {
+    /** Whether models are written to a fresh temporary directory under the working directory. */
+    @Option
+    boolean getUseTmp();
+
+    /** Whether to run the train / retrain / test sequence instead of plain train / test. */
+    @Option
+    boolean getPul();
+
+    /** Whether to use the machine-learning annotator instead of the deterministic one. */
+    @Option(shortName="m")
+    boolean getUseMachineLearning();
+  }
+
+
+  /**
+   * Command-line entry point. Splits the requested patient sets into train, dev
+   * and test items, picks the annotator to evaluate, runs the evaluation and
+   * prints the statistics. With the test option, dev items are added to training
+   * and the test items are scored; otherwise the dev items are scored.
+   *
+   * @throws Exception if the working directory cannot be created or the evaluation fails
+   */
+  public static void main(String[] args) throws Exception {
+    Options options = CliFactory.parseArguments(Options.class, args);
+
+    List<Integer> patientSets = options.getPatients().getList();
+    List<Integer> trainItems = THYMEData.getTrainPatientSets(patientSets);
+    List<Integer> devItems = THYMEData.getDevPatientSets(patientSets);
+    List<Integer> testItems = THYMEData.getTestPatientSets(patientSets);
+
+    // createDirectories is a no-op for an existing directory and throws if the
+    // directory cannot be created, instead of silently ignoring the failure.
+    File workingDir = new File("target/eval/markable-spans/");
+    java.nio.file.Files.createDirectories(workingDir.toPath());
+    if(options.getUseTmp()){
+      // Create the temporary model directory in one step rather than creating a
+      // temp file, deleting it and re-creating the name as a directory.
+      workingDir = java.nio.file.Files.createTempDirectory(workingDir.toPath(), "temporal").toFile();
+    }
+
+    List<Integer> allTrain = new ArrayList<>(trainItems);
+    List<Integer> allTest;
+    if(options.getTest()){
+      allTrain.addAll(devItems);
+      allTest = new ArrayList<>(testItems);
+    }else{
+      allTest = new ArrayList<>(devItems);
+    }
+
+    EvaluationOfMarkableSpans eval = new EvaluationOfMarkableSpans(
+        workingDir,
+        options.getRawTextDirectory(),
+        options.getXMLDirectory(),
+        options.getXMLFormat(),
+        options.getXMIDirectory(),
+        options.getTreebankDirectory(),
+        options.getPrintErrors());
+
+    eval.trainingArguments = new String[]{ "-c", "1.0", "-s", "0"};
+    eval.annotatorClass = options.getUseMachineLearning() ? MarkableAnnotator.class : DeterministicMarkableAnnotator.class;
+    String name = String.format("%s.errors", eval.annotatorClass.getSimpleName());
+    eval.setLogging(Level.FINE, new File("target/eval", name));
+
+    AnnotationStatistics<String> stats;
+    if(options.getPul()){
+      stats = eval.trainAndRetrainAndTest(allTrain, allTest);
+    }else{
+      stats = eval.trainAndTest(allTrain, allTest);
+    }
+    System.out.println(stats);
+  }
+  
+  // Arguments handed to JarClassifierBuilder.trainAndPackage when a model is trained.
+  protected String[] trainingArguments;
+  
+  // Annotator under evaluation; main() sets it to either MarkableAnnotator or
+  // DeterministicMarkableAnnotator before running.
+  protected Class<? extends JCasAnnotator_ImplBase> annotatorClass = null;
+  
+  /**
+   * Creates a span evaluation for {@link Markable} annotations. All directory and
+   * format arguments are forwarded to the superclass constructor.
+   *
+   * @param printErrors value stored in the {@code printErrors} field
+   */
+  public EvaluationOfMarkableSpans(
+      File workingDir,
+      File rawTextDirectory,
+      File xmlDirectory,
+      XMLFormat xmlFormat,
+      File xmiDirectory,
+      File treebankDirectory,
+      boolean printErrors) {
+    super(
+        workingDir,
+        rawTextDirectory,
+        xmlDirectory,
+        xmlFormat,
+        xmiDirectory,
+        treebankDirectory,
+        Markable.class);
+    this.printErrors = printErrors;
+  }
+
+  /**
+   * Trains on the training items, retrains on the same items, then tests on the
+   * test items, using the "train_and_test" sub-directory of the base directory.
+   *
+   * @return the statistics produced by the test run
+   */
+  public AnnotationStatistics<String> trainAndRetrainAndTest(List<Integer> trainItems, List<Integer> testItems)
+      throws Exception {
+    final File runDirectory = new File(this.baseDirectory, "train_and_test");
+    runDirectory.mkdirs();
+
+    // each phase reads from its own freshly created collection reader
+    final CollectionReader firstPassReader = getCollectionReader(trainItems);
+    train(firstPassReader, runDirectory);
+
+    final CollectionReader secondPassReader = getCollectionReader(trainItems);
+    retrain(secondPassReader, runDirectory);
+
+    final CollectionReader testReader = getCollectionReader(testItems);
+    return test(testReader, runDirectory);
+  }
+  
+  
+
+  @Override
+  protected void train(CollectionReader collectionReader, File directory)
+      throws Exception {
+    if(this.annotatorClass == MarkableAnnotator.class){
+      super.train(collectionReader, directory);
+    }
+  }
+
+  
+  /**
+   * Second training pass: copies the gold markables, runs the retraining data
+   * writer over the reader, and trains and packages the model again.
+   */
+  protected void retrain(CollectionReader collectionReader, File directory) throws Exception{
+    AggregateBuilder builder = this.getPreprocessorAggregateBuilder();
+    builder.add(CopyFromGold.getDescription(Markable.class));
+    // NOTE(review): the "TimexView" mapping looks inherited from the temporal
+    // evaluation code this class was refactored from -- confirm it is needed here.
+    builder.add(
+        this.getDataRewriterDescription(directory),
+        "TimexView",
+        CAS.NAME_DEFAULT_SOFA);
+    SimplePipeline.runPipeline(collectionReader, builder.createAggregate());
+    this.trainAndPackage(directory);
+  }
+  
+  @Override
+  protected void trainAndPackage(File directory) throws Exception {
+    JarClassifierBuilder.trainAndPackage(getModelDirectory(directory), this.trainingArguments);
+  }
+
+  @Override
+  protected AnalysisEngineDescription getDataWriterDescription(File directory)
+      throws ResourceInitializationException {
+    return AnalysisEngineFactory.createEngineDescription(
+        MarkableAnnotator.class,
+        CleartkAnnotator.PARAM_IS_TRAINING,
+        true,
+        DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
+        LibLinearStringOutcomeDataWriter.class,
+        DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
+        getModelDirectory(directory));        
+  }
+
+  protected AnalysisEngineDescription getDataRewriterDescription(File directory)
+      throws ResourceInitializationException {
+    return AnalysisEngineFactory.createEngineDescription(
+        MarkableAnnotator.class,
+        CleartkAnnotator.PARAM_IS_TRAINING,
+        false,
+        MarkableAnnotator.PARAM_IS_RETRAINING,
+        true,
+        DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME,
+        LibLinearStringOutcomeDataWriter.class,
+        DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY,
+        getModelDirectory(directory),
+        GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
+        new File(getModelDirectory(directory), "model.jar"));        
+  }
+
+  @Override
+  protected AnalysisEngineDescription getAnnotatorDescription(File directory)
+      throws ResourceInitializationException {
+    return AnalysisEngineFactory.createEngineDescription(
+        annotatorClass,
+        CleartkAnnotator.PARAM_IS_TRAINING,
+        false,
+        GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH,
+        new File(getModelDirectory(directory), "model.jar"));
+  }
+
+  @Override
+  protected Collection<? extends Annotation> getGoldAnnotations(JCas jCas,
+      Segment segment) {
+    return selectExact(jCas, Markable.class, segment);
+  }
+
+  @Override
+  protected Collection<? extends Annotation> getSystemAnnotations(JCas jCas,
+      Segment segment) {
+    return selectExact(jCas, Markable.class, segment);
+  }
+
+  private static File getModelDirectory(File directory) {
+    return new File(directory, MarkableAnnotator.class.getSimpleName());
+  }
+
+}



Mime
View raw message