Return-Path: X-Original-To: apmail-incubator-ctakes-commits-archive@minotaur.apache.org Delivered-To: apmail-incubator-ctakes-commits-archive@minotaur.apache.org Received: from mail.apache.org (hermes.apache.org [140.211.11.3]) by minotaur.apache.org (Postfix) with SMTP id 899D8D73C for ; Mon, 29 Oct 2012 21:18:58 +0000 (UTC) Received: (qmail 68594 invoked by uid 500); 29 Oct 2012 21:18:58 -0000 Delivered-To: apmail-incubator-ctakes-commits-archive@incubator.apache.org Received: (qmail 68568 invoked by uid 500); 29 Oct 2012 21:18:58 -0000 Mailing-List: contact ctakes-commits-help@incubator.apache.org; run by ezmlm Precedence: bulk List-Help: List-Unsubscribe: List-Post: List-Id: Reply-To: ctakes-dev@incubator.apache.org Delivered-To: mailing list ctakes-commits@incubator.apache.org Received: (qmail 68560 invoked by uid 99); 29 Oct 2012 21:18:58 -0000 Received: from athena.apache.org (HELO athena.apache.org) (140.211.11.136) by apache.org (qpsmtpd/0.29) with ESMTP; Mon, 29 Oct 2012 21:18:58 +0000 X-ASF-Spam-Status: No, hits=-2000.0 required=5.0 tests=ALL_TRUSTED X-Spam-Check-By: apache.org Received: from [140.211.11.4] (HELO eris.apache.org) (140.211.11.4) by apache.org (qpsmtpd/0.29) with ESMTP; Mon, 29 Oct 2012 21:18:55 +0000 Received: from eris.apache.org (localhost [127.0.0.1]) by eris.apache.org (Postfix) with ESMTP id B950423888CD; Mon, 29 Oct 2012 21:18:11 +0000 (UTC) Content-Type: text/plain; charset="utf-8" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit Subject: svn commit: r1403535 - in /incubator/ctakes/trunk/ctakes-assertion/src/main/java/org/apache/ctakes/assertion: eval/ eval/AssertionEvalBasedOnModifier.java medfacts/cleartk/AssertionCleartkAnalysisEngine.java medfacts/cleartk/TrainAssertionModel.java Date: Mon, 29 Oct 2012 21:18:11 -0000 To: ctakes-commits@incubator.apache.org From: mattcoarr@apache.org X-Mailer: svnmailer-1.0.8-patched Message-Id: <20121029211811.B950423888CD@eris.apache.org> X-Virus-Checked: Checked by ClamAV on apache.org Author: mattcoarr Date: Mon Oct 29 21:18:11 2012 New Revision: 1403535 URL: http://svn.apache.org/viewvc?rev=1403535&view=rev Log: beginnings of the assertion evaluation code Added: incubator/ctakes/trunk/ctakes-assertion/src/main/java/org/apache/ctakes/assertion/eval/ incubator/ctakes/trunk/ctakes-assertion/src/main/java/org/apache/ctakes/assertion/eval/AssertionEvalBasedOnModifier.java incubator/ctakes/trunk/ctakes-assertion/src/main/java/org/apache/ctakes/assertion/medfacts/cleartk/TrainAssertionModel.java Modified: incubator/ctakes/trunk/ctakes-assertion/src/main/java/org/apache/ctakes/assertion/medfacts/cleartk/AssertionCleartkAnalysisEngine.java Added: incubator/ctakes/trunk/ctakes-assertion/src/main/java/org/apache/ctakes/assertion/eval/AssertionEvalBasedOnModifier.java URL: http://svn.apache.org/viewvc/incubator/ctakes/trunk/ctakes-assertion/src/main/java/org/apache/ctakes/assertion/eval/AssertionEvalBasedOnModifier.java?rev=1403535&view=auto ============================================================================== --- incubator/ctakes/trunk/ctakes-assertion/src/main/java/org/apache/ctakes/assertion/eval/AssertionEvalBasedOnModifier.java (added) +++ incubator/ctakes/trunk/ctakes-assertion/src/main/java/org/apache/ctakes/assertion/eval/AssertionEvalBasedOnModifier.java Mon Oct 29 21:18:11 2012 @@ -0,0 +1,722 @@ +package org.apache.ctakes.assertion.eval; + +/* + * Copyright: (c) 2012 Children's Hospital Boston, Regents of the University of Colorado + * + * Except as contained in the copyright notice above, or as used to identify + * MFMER as the author of this software, the trade names, trademarks, service + * marks, or product names of the copyright holder shall not be used in + * advertising, promotion or otherwise in connection with this software without + * prior written authorization of the copyright holder. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * @author Steven Bethard + */ + +import java.io.File; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationTargetException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +import org.apache.log4j.Logger; +import org.apache.uima.analysis_engine.AnalysisEngine; +import org.apache.uima.analysis_engine.AnalysisEngineDescription; +import org.apache.uima.analysis_engine.AnalysisEngineProcessException; +import org.apache.uima.cas.CAS; +import org.apache.uima.cas.CASException; +import org.apache.uima.cas.Feature; +import org.apache.uima.collection.CollectionReader; +import org.apache.uima.jcas.JCas; +import org.apache.uima.jcas.tcas.Annotation; +import org.apache.uima.resource.ResourceInitializationException; +import org.apache.uima.util.CasCopier; +import org.apache.uima.util.Level; +//import org.chboston.cnlp.ctakes.relationextractor.ae.RelationExtractorAnnotator; +//import org.chboston.cnlp.ctakes.relationextractor.eval.RelationExtractorEvaluation; +//import org.chboston.cnlp.ctakes.relationextractor.ae.ModifierExtractorAnnotator; +import org.cleartk.classifier.CleartkAnnotator; +import org.cleartk.classifier.DataWriterFactory; +import org.cleartk.classifier.jar.DirectoryDataWriterFactory; +import org.cleartk.classifier.jar.GenericJarClassifierFactory; +import org.cleartk.classifier.jar.JarClassifierBuilder; +import org.cleartk.classifier.opennlp.DefaultMaxentDataWriterFactory; +import org.cleartk.classifier.opennlp.MaxentStringOutcomeDataWriter; +import org.cleartk.eval.AnnotationStatistics; +import org.cleartk.eval.Evaluation_ImplBase; +import org.cleartk.util.Options_ImplBase; +import org.kohsuke.args4j.Option; +import org.apache.ctakes.assertion.medfacts.cleartk.AssertionCleartkAnalysisEngine; +import org.uimafit.component.JCasAnnotator_ImplBase; +import org.uimafit.factory.AggregateBuilder; +import org.uimafit.factory.AnalysisEngineFactory; +import org.uimafit.factory.CollectionReaderFactory; +import org.uimafit.factory.ConfigurationParameterFactory; +import org.uimafit.factory.TypeSystemDescriptionFactory; +import org.uimafit.pipeline.JCasIterable; +import org.uimafit.pipeline.SimplePipeline; +import org.uimafit.testing.util.HideOutput; +import org.uimafit.util.JCasUtil; + +import com.google.common.base.Function; + +import org.apache.ctakes.typesystem.type.relation.BinaryTextRelation; +import org.apache.ctakes.typesystem.type.relation.RelationArgument; +import org.apache.ctakes.typesystem.type.syntax.BaseToken; +import org.apache.ctakes.typesystem.type.syntax.ContractionToken; +import org.apache.ctakes.typesystem.type.syntax.NewlineToken; +import org.apache.ctakes.typesystem.type.syntax.NumToken; +import org.apache.ctakes.typesystem.type.syntax.PunctuationToken; +import org.apache.ctakes.typesystem.type.syntax.SymbolToken; +import org.apache.ctakes.typesystem.type.syntax.WordToken; +import org.apache.ctakes.typesystem.type.textsem.EntityMention; +import org.apache.ctakes.typesystem.type.textsem.EventMention; +import org.apache.ctakes.typesystem.type.textsem.IdentifiedAnnotation; +import org.apache.ctakes.typesystem.type.textsem.Modifier; +import org.apache.ctakes.typesystem.type.textspan.Sentence; + +public class AssertionEvalBasedOnModifier extends Evaluation_ImplBase { + + private static Logger logger = Logger.getLogger(AssertionEvalBasedOnModifier.class); + + public static class Options extends Options_ImplBase { + @Option( + name = "--train-dir", + usage = "specify the directory contraining the XMI training files (for example, /NLP/Corpus/Relations/mipacq/xmi/train)", + required = true) + public File trainDirectory; + + @Option( + name = "--test-dir", + usage = "specify the directory contraining the XMI testing files (for example, /NLP/Corpus/Relations/mipacq/xmi/test)", + required = false) + public File testDirectory; + + @Option( + name = "--models-dir", + usage = "specify the directory where the models will be placed", + required = true) + public File modelsDirectory; + + + } + + private Class classifierAnnotatorClass; + + private Class> dataWriterFactoryClass; + + + + public static void main(String[] args) throws Exception { + Options options = new Options(); + options.parseOptions(args); + List trainFiles = Arrays.asList(options.trainDirectory.listFiles()); + //File modelsDir = new File("models/modifier"); + File modelsDir = options.modelsDirectory; + + // determine the type of classifier to be trained + Class> dataWriterFactoryClass = DefaultMaxentDataWriterFactory.class; + // TODO Class> dataWriterFactoryClass = DefaultDataWriterFactory.class; + // + // A DataWriterFactory that creates a data writer from the class given by + // PARAM_DATA_WRITER_CLASS_NAME and the directory given by + // DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY. + // + // DefaultMaxentDataWriterFactory is deprecated and says to use DefaultDattaWriterFactory + // with MaxentDataWriter. + + Class annotatorClass = AssertionCleartkAnalysisEngine.class; + + + AssertionEvalBasedOnModifier evaluation = new AssertionEvalBasedOnModifier( + modelsDir, + annotatorClass, + dataWriterFactoryClass + ); + /* + , + "-t", + "0", + "-c", + "1000"); + */ + +// List foldStats = evaluation.crossValidation(trainFiles, 2); +// //AnnotationStatistics overallStats = AnnotationStatistics.addAll(foldStats); +// //AnnotationStatistics overallStats = new AnnotationStatistics(); +// //overallStats.addAll(foldStats); +// AnnotationStatistics overallStats = new AnnotationStatistics(); +// for (AnnotationStatistics singleFoldStats : foldStats) +// { +// overallStats.addAll(singleFoldStats); +// } +// System.err.println("Overall:"); +// System.err.println(overallStats); + + + + if(options.testDirectory == null) { + // run n-fold cross-validation + List foldStats = evaluation.crossValidation(trainFiles, 2); + //AnnotationStatistics overallStats = AnnotationStatistics.addAll(foldStats); + AnnotationStatistics overallStats = new AnnotationStatistics(); + for (AnnotationStatistics singleFoldStats : foldStats) + { + overallStats.addAll(singleFoldStats); + } + + System.err.println("overall:"); + System.err.print(overallStats); + System.err.println(overallStats.confusions()); + System.err.println(); + + } else { + // train on the entire training set and evaluate on the test set + List testFiles = Arrays.asList(options.testDirectory.listFiles()); + + CollectionReader trainCollectionReader = evaluation.getCollectionReader(trainFiles); + evaluation.train(trainCollectionReader, modelsDir); + + CollectionReader testCollectionReader = evaluation.getCollectionReader(testFiles); + AnnotationStatistics stats = evaluation.test(testCollectionReader, modelsDir); + return; + } + + } + + private String[] trainingArguments; + + public AssertionEvalBasedOnModifier( + File directory, + Class classifierAnnotatorClass, + Class> dataWriterFactoryClass, + String... trainingArguments + ) { + super(directory); + + this.classifierAnnotatorClass = classifierAnnotatorClass; + this.dataWriterFactoryClass = dataWriterFactoryClass; + + this.trainingArguments = trainingArguments; + } + + @Override + public CollectionReader getCollectionReader(List items) + throws ResourceInitializationException { + String[] paths = new String[items.size()]; + for (int i = 0; i < paths.length; ++i) { + paths[i] = items.get(i).getPath(); + } + return CollectionReaderFactory.createCollectionReader( + XMIReader.class, + TypeSystemDescriptionFactory.createTypeSystemDescriptionFromPath("../common-type-system/desc/common_type_system.xml"), + XMIReader.PARAM_FILES, + paths); + } + + @Override + public void train(CollectionReader collectionReader, File directory) throws Exception { + AggregateBuilder builder = new AggregateBuilder(); + + //builder.add(AnalysisEngineFactory.createPrimitiveDescription(ReplaceCTakesEntityMentionsAndModifiersWithGold.class)); + +// AnalysisEngineDescription assertionDescription = AssertionCleartkAnalysisEngine.getDescription( +// CleartkAnnotator.PARAM_DATA_WRITER_FACTORY_CLASS_NAME, +// //MultiClassLIBSVMDataWriterFactory.class.getName(), +// MaxentStringOutcomeDataWriter.class.getName(), +// DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY, +// directory.getPath()); +// builder.add(assertionDescription); + + AnalysisEngineDescription goldCopierIdentifiedAnnotsAnnotator = AnalysisEngineFactory.createPrimitiveDescription(ReferenceIdentifiedAnnotationsSystemToGoldCopier.class); + builder.add(goldCopierIdentifiedAnnotsAnnotator); + + AnalysisEngineDescription goldCopierSupportingAnnotsAnnotator = AnalysisEngineFactory.createPrimitiveDescription(ReferenceSupportingAnnotationsSystemToGoldCopier.class); + builder.add(goldCopierSupportingAnnotsAnnotator); + + AnalysisEngineDescription assertionAttributeClearerAnnotator = AnalysisEngineFactory.createPrimitiveDescription(ReferenceAnnotationsSystemAssertionClearer.class); + builder.add(assertionAttributeClearerAnnotator); + + AnalysisEngineDescription assertionAnnotator = AnalysisEngineFactory.createPrimitiveDescription(AssertionCleartkAnalysisEngine.class); //, this.additionalParamemters); + ConfigurationParameterFactory.addConfigurationParameters( + assertionAnnotator, + AssertionCleartkAnalysisEngine.PARAM_GOLD_VIEW_NAME, + AssertionEvalBasedOnModifier.GOLD_VIEW_NAME, + CleartkAnnotator.PARAM_DATA_WRITER_FACTORY_CLASS_NAME, + this.dataWriterFactoryClass.getName(), + DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY, + directory.getPath() + ); + builder.add(assertionAnnotator); + +/* + AnalysisEngineDescription classifierAnnotator = AnalysisEngineFactory.createPrimitiveDescription( + this.classifierAnnotatorClass, + this.additionalParameters); + ConfigurationParameterFactory.addConfigurationParameters( + classifierAnnotator, + RelationExtractorAnnotator.PARAM_GOLD_VIEW_NAME, + RelationExtractorEvaluation.GOLD_VIEW_NAME, + CleartkAnnotator.PARAM_DATA_WRITER_FACTORY_CLASS_NAME, + this.dataWriterFactoryClass.getName(), + DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY, + directory.getPath()); + builder.add(classifierAnnotator); +*/ + + SimplePipeline.runPipeline(collectionReader, builder.createAggregateDescription()); + + HideOutput hider = new HideOutput(); + JarClassifierBuilder.trainAndPackage(directory, this.trainingArguments); + hider.restoreOutput(); + } + + @Override + protected AnnotationStatistics test(CollectionReader collectionReader, File directory) + throws Exception { +// AnalysisEngine classifierAnnotator = AnalysisEngineFactory.createPrimitive(AssertionCleartkAnalysisEngine.getDescription( +// GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH, +// new File(directory, "model.jar").getPath())); + + AggregateBuilder builder = new AggregateBuilder(); + + AnalysisEngineDescription goldCopierAnnotator = AnalysisEngineFactory.createPrimitiveDescription(ReferenceIdentifiedAnnotationsSystemToGoldCopier.class); + builder.add(goldCopierAnnotator); + + AnalysisEngineDescription assertionAttributeClearerAnnotator = AnalysisEngineFactory.createPrimitiveDescription(ReferenceAnnotationsSystemAssertionClearer.class); + builder.add(assertionAttributeClearerAnnotator); + + AnalysisEngineDescription assertionAnnotator = AnalysisEngineFactory.createPrimitiveDescription(AssertionCleartkAnalysisEngine.class); //, this.additionalParamemters); + ConfigurationParameterFactory.addConfigurationParameters( + assertionAnnotator, + AssertionCleartkAnalysisEngine.PARAM_GOLD_VIEW_NAME, + AssertionEvalBasedOnModifier.GOLD_VIEW_NAME, + GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH, + new File(directory, "model.jar").getPath() + ); + builder.add(assertionAnnotator); + + //SimplePipeline.runPipeline(collectionReader, builder.createAggregateDescription()); + AnalysisEngineDescription aggregateDescription = builder.createAggregateDescription(); + AnalysisEngine aggregate = builder.createAggregate(); + + AnnotationStatistics stats = new AnnotationStatistics(); + for (JCas jCas : new JCasIterable(collectionReader, aggregate)) { + JCas goldView; + try { + goldView = jCas.getView(GOLD_VIEW_NAME); + } catch (CASException e) { + throw new AnalysisEngineProcessException(e); + } + Collection goldEntities = new ArrayList(); + goldEntities.addAll(JCasUtil.select(goldView, EntityMention.class)); + goldEntities.addAll(JCasUtil.select(goldView, EventMention.class)); + + Collection systemEntities = new ArrayList(); + systemEntities.addAll(JCasUtil.select(jCas, EntityMention.class)); + systemEntities.addAll(JCasUtil.select(jCas, EventMention.class)); + + stats.add(goldEntities, systemEntities, + AnnotationStatistics.annotationToSpan(), + AnnotationStatistics.annotationToFeatureValue("polarity")); + } + System.err.println(directory.getName() + ":"); + System.err.println(stats); + return stats; + } + + public static final String GOLD_VIEW_NAME = "GoldView"; + + /** + * Class that copies the manual {@link Modifier} annotations to the default CAS. + */ + public static class OnlyGoldAssertions extends JCasAnnotator_ImplBase { + + @Override + public void process(JCas jCas) throws AnalysisEngineProcessException { + JCas goldView; + try { + goldView = jCas.getView(GOLD_VIEW_NAME); + } catch (CASException e) { + throw new AnalysisEngineProcessException(e); + } + + // remove any automatically generated Modifiers + for (EntityMention entityMention : JCasUtil.select(jCas, EntityMention.class)) { + entityMention.removeFromIndexes(); + } + + // copy over the manually annotated Modifiers + for (EntityMention entityMention : JCasUtil.select(goldView, EntityMention.class)) { + EntityMention newEntityMention = new EntityMention(jCas, entityMention.getBegin(), entityMention.getEnd()); + newEntityMention.setTypeID(entityMention.getTypeID()); + newEntityMention.setId(entityMention.getId()); + newEntityMention.setDiscoveryTechnique(entityMention.getDiscoveryTechnique()); + newEntityMention.setConfidence(entityMention.getConfidence()); + newEntityMention.addToIndexes(); + } + } + } + + + /** + * Annotator that removes cTAKES EntityMentions and Modifiers from the system + * view, and copies over the manually annotated EntityMentions and Modifiers + * from the gold view. + * + */ + public static class ReplaceCTakesEntityMentionsAndModifiersWithGold extends + JCasAnnotator_ImplBase + { + + @Override + public void process(JCas jCas) throws AnalysisEngineProcessException + { + JCas goldView, systemView; + try + { + goldView = jCas.getView(GOLD_VIEW_NAME); + systemView = jCas.getView(CAS.NAME_DEFAULT_SOFA); + } catch (CASException e) + { + throw new AnalysisEngineProcessException(e); + } + + // remove cTAKES EntityMentions and Modifiers from system view + List cTakesMentions = new ArrayList(); + cTakesMentions.addAll(JCasUtil.select(systemView, EntityMention.class)); + cTakesMentions.addAll(JCasUtil.select(systemView, Modifier.class)); + for (IdentifiedAnnotation cTakesMention : cTakesMentions) + { + cTakesMention.removeFromIndexes(); + } + + // copy gold EntityMentions and Modifiers to the system view + List goldMentions = new ArrayList(); + goldMentions.addAll(JCasUtil.select(goldView, EntityMention.class)); + goldMentions.addAll(JCasUtil.select(goldView, Modifier.class)); + CasCopier copier = new CasCopier(goldView.getCas(), systemView.getCas()); + for (IdentifiedAnnotation goldMention : goldMentions) + { + Annotation copy = (Annotation) copier.copyFs(goldMention); + Feature sofaFeature = copy.getType().getFeatureByBaseName("sofa"); + copy.setFeatureValue(sofaFeature, systemView.getSofa()); + copy.addToIndexes(); + } + } + } + + public static class ReplaceGoldEntityMentionsAndModifiersWithCTakes extends + JCasAnnotator_ImplBase + { + + @Override + public void process(JCas jCas) throws AnalysisEngineProcessException + { + JCas goldView, systemView; + try + { + goldView = jCas.getView(GOLD_VIEW_NAME); + systemView = jCas.getView(CAS.NAME_DEFAULT_SOFA); + } catch (CASException e) + { + throw new AnalysisEngineProcessException(e); + } + + // remove manual EntityMentions and Modifiers from gold view + List goldMentions = new ArrayList(); + goldMentions.addAll(JCasUtil.select(goldView, EntityMention.class)); + goldMentions.addAll(JCasUtil.select(goldView, Modifier.class)); + for (IdentifiedAnnotation goldMention : goldMentions) + { + goldMention.removeFromIndexes(); + } + + // copy cTAKES EntityMentions and Modifiers to gold view + List cTakesMentions = new ArrayList(); + cTakesMentions.addAll(JCasUtil.select(systemView, EntityMention.class)); + cTakesMentions.addAll(JCasUtil.select(systemView, Modifier.class)); + CasCopier copier = new CasCopier(systemView.getCas(), goldView.getCas()); + for (IdentifiedAnnotation cTakesMention : cTakesMentions) + { + Annotation copy = (Annotation) copier.copyFs(cTakesMention); + Feature sofaFeature = copy.getType().getFeatureByBaseName("sofa"); + copy.setFeatureValue(sofaFeature, goldView.getSofa()); + copy.addToIndexes(); + } + + // replace gold EntityMentions and Modifiers in relations with cTAKES ones + List relations = new ArrayList(); + relations.addAll(JCasUtil.select(goldView, BinaryTextRelation.class)); + for (BinaryTextRelation relation : relations) + { + + // attempt to replace the gold RelationArguments with system ones + int replacedArgumentCount = 0; + for (RelationArgument relArg : Arrays.asList(relation.getArg1(), + relation.getArg2())) + { + Annotation goldArg = relArg.getArgument(); + Class argClass = goldArg.getClass(); + + // find all annotations covered by the gold argument and of the same + // class (these should + // be the ones copied over from the cTAKES output earlier) + List systemArgs = JCasUtil.selectCovered( + goldView, argClass, goldArg); + + // no ctakes annotation found + if (systemArgs.size() == 0) + { + String word = "no"; + String className = argClass.getSimpleName(); + String argText = goldArg.getCoveredText(); + String message = String.format("%s %s for \"%s\"", word, className, + argText); + this.getContext().getLogger().log(Level.FINE, message); + continue; + } + + // if there's exactly one annotation, replace the gold one with that + if (systemArgs.size() == 1) + { + relArg.setArgument(systemArgs.get(0)); + replacedArgumentCount += 1; + } + + else + { + // multiple ctakes arguments found; look for one that matches + // exactly + // e.g. gold: "right breast", ctakes: "right breast", "breast" + for (Annotation systemArg : systemArgs) + { + String goldArgText = goldArg.getCoveredText(); + String systemArgText = systemArg.getCoveredText(); + if (systemArgText.equals(goldArgText)) + { + relArg.setArgument(systemArg); + replacedArgumentCount += 1; + } + } + + if (replacedArgumentCount < 1) + { + // issue a warning message + String word = "multiple"; + String className = argClass.getSimpleName(); + String argText = goldArg.getCoveredText(); + String message = String.format("%s %s for \"%s\"", word, + className, argText); + this.getContext().getLogger().log(Level.FINE, message); + + System.out.println("gold argument: " + goldArg.getCoveredText()); + System.out.println("gold type: " + + ((IdentifiedAnnotation) goldArg).getTypeID()); + for (Annotation systemArg : systemArgs) + { + System.out.println("ctakes argument: " + + systemArg.getCoveredText()); + System.out.println("ctakes type: " + + ((IdentifiedAnnotation) systemArg).getTypeID()); + } + System.out.println(); + } + } + } + + // if replacements were not found for both arguments, remove the + // relation + if (replacedArgumentCount < 2) + { + relation.removeFromIndexes(); + } + } + } + } + + /** + * Class that copies the manual {@link Modifier} annotations to the default CAS. + */ + public static class ReferenceIdentifiedAnnotationsSystemToGoldCopier extends JCasAnnotator_ImplBase { + + @Override + public void process(JCas jCas) throws AnalysisEngineProcessException { + JCas goldView; + try { + goldView = jCas.createView(GOLD_VIEW_NAME); + goldView.setSofaDataString(jCas.getSofaDataString(), jCas.getSofaMimeType()); + //goldView.setDocumentText(jCas.getDocumentText()); + //goldView = jCas.getView(GOLD_VIEW_NAME); + } catch (CASException e) { + throw new AnalysisEngineProcessException(e); + } + +// // remove any automatically generated Modifiers +// for (Modifier modifier : JCasUtil.select(jCas, Modifier.class)) { +// modifier.removeFromIndexes(); +// } + + for (EntityMention oldSystemEntityMention : JCasUtil.select(jCas, EntityMention.class)) + { + EntityMention newGoldEntityMention = new EntityMention(goldView, oldSystemEntityMention.getBegin(), oldSystemEntityMention.getEnd()); + + // copying assertion fields + newGoldEntityMention.setDiscoveryTechnique(oldSystemEntityMention.getDiscoveryTechnique()); + newGoldEntityMention.setUncertainty(oldSystemEntityMention.getUncertainty()); + newGoldEntityMention.setConditional(oldSystemEntityMention.getConditional()); + newGoldEntityMention.setGeneric(oldSystemEntityMention.getGeneric()); + newGoldEntityMention.setPolarity(oldSystemEntityMention.getPolarity()); + newGoldEntityMention.setSubject(oldSystemEntityMention.getSubject()); + + // copying non-assertion fields + newGoldEntityMention.setConfidence(oldSystemEntityMention.getConfidence()); + + newGoldEntityMention.addToIndexes(); + } + + for (EventMention oldSystemEventMention : JCasUtil.select(jCas, EventMention.class)) + { + EventMention newGoldEventMention = new EventMention(goldView, oldSystemEventMention.getBegin(), oldSystemEventMention.getEnd()); + + // copying assertion fields + newGoldEventMention.setDiscoveryTechnique(oldSystemEventMention.getDiscoveryTechnique()); + newGoldEventMention.setUncertainty(oldSystemEventMention.getUncertainty()); + newGoldEventMention.setConditional(oldSystemEventMention.getConditional()); + newGoldEventMention.setGeneric(oldSystemEventMention.getGeneric()); + newGoldEventMention.setPolarity(oldSystemEventMention.getPolarity()); + newGoldEventMention.setSubject(oldSystemEventMention.getSubject()); + + // copying non-assertion fields + newGoldEventMention.setConfidence(oldSystemEventMention.getConfidence()); + + newGoldEventMention.addToIndexes(); + } + + // TODO do we need to copy supporting feature structures (particularly ontology concept array)?? + + } // end of method ReferenceIdentifiedAnnotationsSystemToGoldCopier.process() + } // end of class ReferenceIdentifiedAnnotationsSystemToGoldCopier + + /** + * Class that copies the manual {@link Modifier} annotations to the default CAS. + */ + public static class ReferenceSupportingAnnotationsSystemToGoldCopier extends JCasAnnotator_ImplBase { + + @Override + public void process(JCas jCas) throws AnalysisEngineProcessException { + JCas goldView; + try { + goldView = jCas.getView(GOLD_VIEW_NAME); + } catch (CASException e) { + throw new AnalysisEngineProcessException(e); + } + +// // remove any automatically generated Modifiers +// for (Modifier modifier : JCasUtil.select(jCas, Modifier.class)) { +// modifier.removeFromIndexes(); +// } + + for (Sentence oldSystemSentence : JCasUtil.select(jCas, Sentence.class)) + { + Sentence newGoldSentence = new Sentence(goldView, oldSystemSentence.getBegin(), oldSystemSentence.getEnd()); + + newGoldSentence.addToIndexes(); + } + + for (BaseToken oldSystemToken : JCasUtil.select(jCas, BaseToken.class)) + { + BaseToken newGoldToken = null; //new BaseToken(goldView, oldSystemEventMention.getBegin(), oldSystemEventMention.getEnd()); + + // TODO the following commented out block is an alternative to having the hard coded if..then..else-if..else block for constructing new BaseToken objects +// Constructor constructor = null; +// try +// { +// constructor = oldSystemToken.getClass().getConstructor(JCas.class, int.class, int.class); +// } catch(NoSuchMethodException| SecurityException e) +// { +// logger.error("problem getting constructor for copying BaseToken instance (inside AssertionEvalBasedOnModifier.ReferenceSupportingAnnotationsSystemToGoldcopier.process())"); +// continue; +// } +// try +// { +// newGoldToken = constructor.newInstance(goldView, oldSystemToken.getBegin(), oldSystemToken.getEnd()); +// } catch (InstantiationException | IllegalAccessException +// | IllegalArgumentException | InvocationTargetException e) +// { +// logger.error("problem invoking constructor to copy BaseToken instance (inside AssertionEvalBasedOnModifier.ReferenceSupportingAnnotationsSystemToGoldcopier.process())"); +// continue; +// } + + String oldSystemTokenClass = oldSystemToken.getClass().getName(); + if (oldSystemTokenClass.equals(WordToken.class.getName())) + { + newGoldToken = new WordToken(goldView, oldSystemToken.getBegin(), oldSystemToken.getEnd()); + } else if (oldSystemTokenClass.equals(ContractionToken.class.getName())) + { + newGoldToken = new ContractionToken(goldView, oldSystemToken.getBegin(), oldSystemToken.getEnd()); + } else if (oldSystemTokenClass.equals(NewlineToken.class.getName())) + { + newGoldToken = new NewlineToken(goldView, oldSystemToken.getBegin(), oldSystemToken.getEnd()); + } else if (oldSystemTokenClass.equals(NumToken.class.getName())) + { + newGoldToken = new NumToken(goldView, oldSystemToken.getBegin(), oldSystemToken.getEnd()); + } else if (oldSystemTokenClass.equals(PunctuationToken.class.getName())) + { + newGoldToken = new PunctuationToken(goldView, oldSystemToken.getBegin(), oldSystemToken.getEnd()); + } else if (oldSystemTokenClass.equals(SymbolToken.class.getName())) + { + newGoldToken = new SymbolToken(goldView, oldSystemToken.getBegin(), oldSystemToken.getEnd()); + } else if (oldSystemTokenClass.equals(BaseToken.class.getName())) + { + newGoldToken = new BaseToken(goldView, oldSystemToken.getBegin(), oldSystemToken.getEnd()); + } else + { + newGoldToken = new BaseToken(goldView, oldSystemToken.getBegin(), oldSystemToken.getEnd()); + } + + newGoldToken.setPartOfSpeech(oldSystemToken.getPartOfSpeech()); + newGoldToken.setTokenNumber(oldSystemToken.getTokenNumber()); + + newGoldToken.addToIndexes(); + } + + } // end of method ReferenceSupportingAnnotationsSystemToGoldCopier.process() + } // end of class ReferenceSupportingAnnotationsSystemToGoldCopier + + /** + * Class that copies the manual {@link Modifier} annotations to the default CAS. + */ + public static class ReferenceAnnotationsSystemAssertionClearer extends JCasAnnotator_ImplBase + { + + @Override + public void process(JCas jCas) throws AnalysisEngineProcessException + { + for (EntityMention entityMention : JCasUtil.select(jCas, EntityMention.class)) + { + entityMention.setPolarity(1); + } + for (EventMention eventMention : JCasUtil.select(jCas, EventMention.class)) + { + eventMention.setPolarity(1); + } + } // end method ReferenceAnnotationsSystemAssertionClearer.process() + } // end class ReferenceAnnotationsSystemAssertionClearer + + + +} // end of class AssertionEvalBasedOnModifier Modified: incubator/ctakes/trunk/ctakes-assertion/src/main/java/org/apache/ctakes/assertion/medfacts/cleartk/AssertionCleartkAnalysisEngine.java URL: http://svn.apache.org/viewvc/incubator/ctakes/trunk/ctakes-assertion/src/main/java/org/apache/ctakes/assertion/medfacts/cleartk/AssertionCleartkAnalysisEngine.java?rev=1403535&r1=1403534&r2=1403535&view=diff ============================================================================== --- incubator/ctakes/trunk/ctakes-assertion/src/main/java/org/apache/ctakes/assertion/medfacts/cleartk/AssertionCleartkAnalysisEngine.java (original) +++ incubator/ctakes/trunk/ctakes-assertion/src/main/java/org/apache/ctakes/assertion/medfacts/cleartk/AssertionCleartkAnalysisEngine.java Mon Oct 29 21:18:11 2012 @@ -1,21 +1,3 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ package org.apache.ctakes.assertion.medfacts.cleartk; import java.util.ArrayList; @@ -25,11 +7,15 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import org.apache.log4j.Level; +import org.apache.log4j.Logger; import org.apache.uima.UimaContext; import org.apache.uima.analysis_engine.AnalysisEngineDescription; import org.apache.uima.analysis_engine.AnalysisEngineProcessException; +import org.apache.uima.cas.CASException; import org.apache.uima.jcas.JCas; import org.apache.uima.resource.ResourceInitializationException; +//import org.chboston.cnlp.ctakes.relationextractor.ae.ModifierExtractorAnnotator; import org.cleartk.classifier.CleartkAnnotator; import org.cleartk.classifier.CleartkAnnotatorDescriptionFactory; import org.cleartk.classifier.CleartkSequenceAnnotator; @@ -50,19 +36,44 @@ import org.cleartk.classifier.feature.pr import org.cleartk.classifier.opennlp.DefaultMaxentDataWriterFactory; import org.cleartk.classifier.opennlp.MaxentDataWriterFactory_ImplBase; import org.cleartk.type.test.Token; +import org.uimafit.descriptor.ConfigurationParameter; +import org.uimafit.factory.AnalysisEngineFactory; import org.uimafit.factory.ConfigurationParameterFactory; import org.uimafit.util.JCasUtil; -import org.apache.ctakes.typesystem.type.syntax.BaseToken; -import org.apache.ctakes.typesystem.type.textsem.EntityMention; -import org.apache.ctakes.typesystem.type.textsem.IdentifiedAnnotation; -import org.apache.ctakes.typesystem.type.textspan.Sentence; +import edu.mayo.bmi.uima.core.type.structured.DocumentID; +import edu.mayo.bmi.uima.core.type.syntax.BaseToken; +import edu.mayo.bmi.uima.core.type.textsem.EntityMention; +import edu.mayo.bmi.uima.core.type.textsem.EventMention; +import edu.mayo.bmi.uima.core.type.textsem.IdentifiedAnnotation; +import edu.mayo.bmi.uima.core.type.textspan.Sentence; public class AssertionCleartkAnalysisEngine extends - CleartkSequenceAnnotator + CleartkAnnotator { + Logger logger = Logger.getLogger(AssertionCleartkAnalysisEngine.class); - public static final String PARAM_GOLD_VIEW_NAME = "GoldViewName"; + public static final String PARAM_GOLD_VIEW_NAME = "GoldViewName"; + + public static int relationId; // counter for error logging + + @ConfigurationParameter( + name = PARAM_GOLD_VIEW_NAME, + mandatory = false, + description = "view containing the manual identified annotations (especially EntityMention and EventMention annotations); needed for training") + protected String goldViewName; + + public static final String PARAM_PRINT_ERRORS = "PrintErrors"; + + @ConfigurationParameter( + name = PARAM_PRINT_ERRORS, + mandatory = false, + description = "Print errors true/false", + defaultValue = "false") + boolean printErrors; + + + //private SimpleFeatureExtractor tokenFeatureExtractor; private List> contextFeatureExtractors; private List> tokenContextFeatureExtractors; @@ -70,6 +81,11 @@ public class AssertionCleartkAnalysisEng public void initialize(UimaContext context) throws ResourceInitializationException { super.initialize(context); + + if (this.isTraining() && this.goldViewName == null) { + throw new IllegalArgumentException(PARAM_GOLD_VIEW_NAME + " must be defined during training"); + } + // alias for NGram feature parameters int fromRight = CharacterNGramProliferator.RIGHT_TO_LEFT; @@ -95,6 +111,7 @@ public class AssertionCleartkAnalysisEng //new TypePathExtractor(IdentifiedAnnotation.class, "stem"), new Preceding(2), new Following(2))); + ContextExtractor tokenContextExtractor1 = new ContextExtractor( BaseToken.class, new SpannedTextExtractor(), @@ -140,14 +157,58 @@ public class AssertionCleartkAnalysisEng @Override public void process(JCas jCas) throws AnalysisEngineProcessException { - Map> coveringSentenceMap = JCasUtil.indexCovering(jCas, IdentifiedAnnotation.class, Sentence.class); - Map> tokensCoveredInSentenceMap = JCasUtil.indexCovered(jCas, Sentence.class, BaseToken.class); + DocumentID documentId = JCasUtil.selectSingle(jCas, DocumentID.class); + if (documentId != null) + { + logger.info("processing next doc: " + documentId.getDocumentID()); + } else + { + logger.info("processing next doc (doc id is null)"); + } +// // get gold standard relation instances during testing for error analysis +// if (! this.isTraining() && printErrors) { +// JCas goldView; +// try { +// goldView = jCas.getView("GoldView"); +// } catch(CASException e) { +// throw new AnalysisEngineProcessException(e); +// } +// +// //categoryLookup = createCategoryLookup(goldView); +// } + + JCas identifiedAnnotationView, relationView; + if (this.isTraining()) { + try { + identifiedAnnotationView = relationView = jCas.getView(this.goldViewName); + } catch (CASException e) { + throw new AnalysisEngineProcessException(e); + } + } else { + identifiedAnnotationView = relationView = jCas; + } + + + Map> coveringSentenceMap = JCasUtil.indexCovering(identifiedAnnotationView, IdentifiedAnnotation.class, Sentence.class); + Map> tokensCoveredInSentenceMap = JCasUtil.indexCovered(identifiedAnnotationView, Sentence.class, BaseToken.class); List> instances = new ArrayList>(); // generate a list of training instances for each sentence in the document - Collection entities = JCasUtil.select(jCas, IdentifiedAnnotation.class); + Collection entities = JCasUtil.select(identifiedAnnotationView, IdentifiedAnnotation.class); for (IdentifiedAnnotation entityMention : entities) { + if (!(entityMention instanceof EntityMention || entityMention instanceof EventMention)) + { + continue; + } + if (entityMention.getPolarity() == -1) + { + logger.info(String.format(" - identified annotation: [%d-%d] polarity %d (%s)", + entityMention.getBegin(), + entityMention.getEnd(), + entityMention.getPolarity(), + entityMention.getClass().getName())); + } Instance instance = new Instance(); // // extract all features that require only the entity mention annotation @@ -155,64 +216,85 @@ public class AssertionCleartkAnalysisEng // extract all features that require the token and sentence annotations Collection sentenceList = coveringSentenceMap.get(entityMention); + Sentence sentence = null; if (sentenceList == null || sentenceList.isEmpty()) { String message = "no surrounding sentence found"; Exception runtimeException = new RuntimeException(message); - throw new AnalysisEngineProcessException(runtimeException); + AnalysisEngineProcessException aeException = new AnalysisEngineProcessException(runtimeException); + logger.log(Level.ERROR, message); } else if (sentenceList.size() > 1) { String message = "more than one surrounding sentence found"; Exception runtimeException = new RuntimeException(message); - throw new AnalysisEngineProcessException(runtimeException); + AnalysisEngineProcessException aeException = new AnalysisEngineProcessException(runtimeException); + logger.log(Level.ERROR, message); + } else + { + sentence = sentenceList.iterator().next(); } - Sentence sentence = sentenceList.iterator().next(); - for (ContextExtractor extractor : this.contextFeatureExtractors) { - instance.addAll(extractor.extractWithin(jCas, entityMention, sentence)); + //Sentence sentence = sentenceList.iterator().next(); + + if (sentence != null) + { + for (ContextExtractor extractor : this.contextFeatureExtractors) { + instance.addAll(extractor.extractWithin(identifiedAnnotationView, entityMention, sentence)); + } + } else + { + // TODO extract context features for annotations that don't fall within a sentence + logger.log(Level.WARN, "FIXME/TODO: generate context features for entities that don't fall within a sentence"); } + for (ContextExtractor extractor : this.tokenContextFeatureExtractors) { - instance.addAll(extractor.extract(jCas, entityMention)); + instance.addAll(extractor.extract(identifiedAnnotationView, entityMention)); } for (SimpleFeatureExtractor extractor : this.entityFeatureExtractors) { - instance.addAll(extractor.extract(jCas, entityMention)); + instance.addAll(extractor.extract(identifiedAnnotationView, entityMention)); } if (this.isTraining()) { - String polarity = (entityMention.getPolarity() == 1) ? "present" : "negated"; + String polarity = (entityMention.getPolarity() == -1) ? "negated" : "present"; instance.setOutcome(polarity); - } - - // add the instance to the list - instances.add(instance); - } - - // TODO figure out exactly what should be happening on training/evaluation/decoding - - // for training, write instances to the data write - if (this.isTraining()) { - this.dataWriter.write(instances); - } - // for classification, set the labels as the token POS labels - else - { - Iterator entityIter = entities.iterator(); - for (String label : this.classify(instances)) + if ("negated".equals(polarity)) + { + logger.info("TRAINING: " + polarity); + } + this.dataWriter.write(instance); + } else { + String label = this.classifier.classify(instance.getFeatures()); int polarity = 1; if (label!= null && label.equals("present")) { - polarity = 1; + polarity = 0; } else if (label != null && label.equals("negated")) { polarity = -1; } - entityIter.next().setPolarity(polarity); - } // end for loop - } // end else - + entityMention.setPolarity(polarity); + if ("negated".equals(label)) + { + logger.info(String.format("DECODING/EVAL: %s//%s [%d-%d] (%s)", label, polarity, entityMention.getBegin(), entityMention.getEnd(), entityMention.getClass().getName())); + } + } + + } + } + public static AnalysisEngineDescription getDescription(Object... additionalConfiguration) + throws ResourceInitializationException { + AnalysisEngineDescription desc = AnalysisEngineFactory.createPrimitiveDescription(AssertionCleartkAnalysisEngine.class); + if (additionalConfiguration.length > 0) { + ConfigurationParameterFactory.addConfigurationParameters(desc, additionalConfiguration); + } + return desc; + } + + + /* public static AnalysisEngineDescription getClassifierDescription(String modelFileName) throws ResourceInitializationException { Added: incubator/ctakes/trunk/ctakes-assertion/src/main/java/org/apache/ctakes/assertion/medfacts/cleartk/TrainAssertionModel.java URL: http://svn.apache.org/viewvc/incubator/ctakes/trunk/ctakes-assertion/src/main/java/org/apache/ctakes/assertion/medfacts/cleartk/TrainAssertionModel.java?rev=1403535&view=auto ============================================================================== --- incubator/ctakes/trunk/ctakes-assertion/src/main/java/org/apache/ctakes/assertion/medfacts/cleartk/TrainAssertionModel.java (added) +++ incubator/ctakes/trunk/ctakes-assertion/src/main/java/org/apache/ctakes/assertion/medfacts/cleartk/TrainAssertionModel.java Mon Oct 29 21:18:11 2012 @@ -0,0 +1,378 @@ +package org.apache.ctakes.assertion.medfacts.cleartk; + +import java.io.File; +import java.util.Locale; + +import org.apache.commons.cli.CommandLine; +import org.apache.commons.cli.CommandLineParser; +import org.apache.commons.cli.GnuParser; +import org.apache.commons.cli.HelpFormatter; +import org.apache.commons.cli.Option; +import org.apache.commons.cli.OptionBuilder; +import org.apache.commons.cli.Options; +import org.apache.commons.cli.ParseException; +import org.apache.ctakes.assertion.eval.AssertionEvalBasedOnModifier; +import org.apache.ctakes.assertion.eval.AssertionEvalBasedOnModifier.ReferenceAnnotationsSystemAssertionClearer; +import org.apache.ctakes.assertion.eval.AssertionEvalBasedOnModifier.ReferenceIdentifiedAnnotationsSystemToGoldCopier; +import org.apache.log4j.Logger; +import org.apache.uima.analysis_engine.AnalysisEngineDescription; + +import org.apache.uima.collection.CollectionReader; +import org.apache.uima.collection.CollectionReaderDescription; +import org.cleartk.classifier.CleartkAnnotator; +import org.cleartk.classifier.CleartkAnnotatorDescriptionFactory; +import org.cleartk.classifier.DataWriterFactory; +import org.cleartk.classifier.jar.DirectoryDataWriterFactory; +import org.cleartk.classifier.jar.GenericJarClassifierFactory; +import org.cleartk.classifier.opennlp.DefaultMaxentDataWriterFactory; +import org.cleartk.classifier.opennlp.MaxentDataWriter; +import org.cleartk.classifier.opennlp.MaxentStringOutcomeDataWriter; +import org.cleartk.util.cr.FilesCollectionReader; +import org.cleartk.util.cr.XReader; +import org.uimafit.component.xwriter.XWriter; +import org.uimafit.factory.AggregateBuilder; +import org.uimafit.factory.AnalysisEngineFactory; +import org.uimafit.factory.CollectionReaderFactory; +import org.uimafit.factory.ConfigurationParameterFactory; +import org.uimafit.pipeline.SimplePipeline; +import org.uimafit.testing.util.HideOutput; +import org.junit.Test; +import org.apache.ctakes.assertion.medfacts.AssertionAnalysisEngine; +import edu.mayo.bmi.uima.core.type.syntax.BaseToken; +//import edu.mayo.bmi.uima.core.type.textsem.EntityMention; +import edu.mayo.bmi.uima.core.type.textsem.IdentifiedAnnotation; +import edu.mayo.bmi.uima.core.type.textspan.Sentence; +import org.cleartk.classifier.jar.DefaultDataWriterFactory; +import org.cleartk.examples.pos.ExamplePOSPlainTextWriter; + + +public class TrainAssertionModel { + + public static final String PARAM_NAME_DECODING_OUTPUT_DIRECTORY = "decoding-output-directory"; + + public static final String PARAM_NAME_DECODING_INPUT_DIRECTORY = "decoding-input-directory"; + + public static final String PARAM_NAME_TRAINING_INPUT_DIRECTORY = "training-input-directory"; + + public static final String PARAM_NAME_MODEL_DIRECTORY = "model-directory"; + + protected static final Logger logger = Logger.getLogger(TrainAssertionModel.class.getName()); + + /** + * @param args + */ + /* + public static void main(String[] args) { + // TODO Auto-generated method stub + String trainDir = args[0]; + String outputDir = args[1]; + + try { + CollectionReader reader = FilesCollectionReader.getCollectionReader(trainDir); + AggregateBuilder builder = new AggregateBuilder(); + //builder.add(AnalysisEngineFactory.createAnalysisEngineDescription("desc/AssertionMiniPipelineAnalysisEngine.xml", null)); + //builder.add(AnalysisEngineFactory.createPrimitiveDescription(IdentifiedAnnotation.class)); + //builder.add(AnalysisEngineFactory.createAnalysisEngineDescription("edu.mayo.bmi.uima.core.type.textsem.IdentifiedAnnotation")); + builder.add(AssertionCleartkAnalysisEngine.getWriterDescription(outputDir)); + SimplePipeline.runPipeline(reader, builder.createAggregateDescription()); + org.cleartk.classifier.jar.Train.main(outputDir); + } catch (Exception e) { + System.err.println("Exception: " + e); + e.printStackTrace(); + throw new RuntimeException(e); + } + + + } + */ + + protected String modelOutputDirectory = "/work/medfacts/cleartk/data/train.model"; + + @Test + public void testMaxent() throws Exception { + + String trainingDataDirectory = "/work/medfacts/cleartk/data/train"; + String evaluationDataDirectory = "/work/medfacts/cleartk/data/eval2.input"; + String evaluationOutputDataDirectory = "/work/medfacts/cleartk/data/eval2.output"; + + String maxentModelOutputDirectory = modelOutputDirectory + "/maxent"; + AnalysisEngineDescription dataWriter = AnalysisEngineFactory.createPrimitiveDescription( + AssertionCleartkAnalysisEngine.class, + AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION, + DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME, + MaxentStringOutcomeDataWriter.class.getName(), + DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY, + maxentModelOutputDirectory); + testClassifier( + dataWriter, + maxentModelOutputDirectory, + trainingDataDirectory, + evaluationDataDirectory, + evaluationOutputDataDirectory + ); + +// // Not sure why the _SPLIT is here, but we will throw it out for good measure +// String firstLine = FileUtil.loadListOfStrings(new File(maxentDirectoryName +// + "/2008_Sichuan_earthquake.txt.pos"))[0].trim().replace("_SPLIT", ""); +// checkPOS(firstLine); + } + + public static void main(String args[]) + { + + Options options = new Options(); + + Option modelDirectoryOption = + OptionBuilder + .withLongOpt(TrainAssertionModel.PARAM_NAME_MODEL_DIRECTORY) + .withArgName("DIR") + .hasArg() + .isRequired() + .withDescription("the directory where the model is written to for training, or read from for decoding") + .create(); + options.addOption(modelDirectoryOption); + + Option trainingInputDirectoryOption = + OptionBuilder + .withLongOpt(TrainAssertionModel.PARAM_NAME_TRAINING_INPUT_DIRECTORY) + .withArgName("DIR") + .hasArg() + .isRequired() + .withDescription("directory where input training xmi files are located") + .create(); + options.addOption(trainingInputDirectoryOption); + + Option decodingInputDirectoryOption = + OptionBuilder + .withLongOpt(TrainAssertionModel.PARAM_NAME_DECODING_INPUT_DIRECTORY) + .withArgName("DIR") + .hasArg() + .isRequired() + .withDescription("directory where input xmi files are located for decoding") + .create(); + options.addOption(decodingInputDirectoryOption); + + Option decodingOutputDirectoryOption = + OptionBuilder + .withLongOpt(TrainAssertionModel.PARAM_NAME_DECODING_OUTPUT_DIRECTORY) + .withArgName("DIR") + .hasArg() + .isRequired() + .withDescription("directory where output xmi files that are generated in decoding are placed") + .create(); + options.addOption(decodingOutputDirectoryOption); + + CommandLineParser parser = new GnuParser(); + + boolean invalidInput = false; + + CommandLine commandLine = null; + String modelDirectory = null; + String trainingInputDirectory = null; + String decodingInputDirectory = null; + String decodingOutputDirectory = null; + try + { + commandLine = parser.parse(options, args); + + modelDirectory = commandLine.getOptionValue(TrainAssertionModel.PARAM_NAME_MODEL_DIRECTORY); + trainingInputDirectory = commandLine.getOptionValue(TrainAssertionModel.PARAM_NAME_TRAINING_INPUT_DIRECTORY); + decodingInputDirectory = commandLine.getOptionValue(TrainAssertionModel.PARAM_NAME_DECODING_INPUT_DIRECTORY); + decodingOutputDirectory = commandLine.getOptionValue(TrainAssertionModel.PARAM_NAME_DECODING_OUTPUT_DIRECTORY); + } catch (ParseException e) + { + invalidInput = true; + logger.error("unable to parse command-line arguments", e); + } + + if (modelDirectory == null || modelDirectory.isEmpty() || + trainingInputDirectory == null || trainingInputDirectory.isEmpty() || + decodingInputDirectory == null || decodingInputDirectory.isEmpty() || + decodingOutputDirectory == null || decodingOutputDirectory.isEmpty() + ) + { + logger.error("required parameters not supplied"); + invalidInput = true; + } + + if (invalidInput) + { + HelpFormatter formatter = new HelpFormatter(); + formatter.printHelp(TrainAssertionModel.class.getName(), options, true); + return; + } + + logger.info(String.format( + "%n" + + "model dir: \"%s\"%n" + + "training input dir: \"%s\"%n" + + "decoding input dir: \"%s\"%n" + + "decoding output dir: \"%s\"%n", + modelDirectory, + trainingInputDirectory, + decodingInputDirectory, + decodingOutputDirectory)); + + String maxentModelOutputDirectory = modelDirectory + "/maxent"; + try + { + AnalysisEngineDescription dataWriter = AnalysisEngineFactory.createPrimitiveDescription( + AssertionCleartkAnalysisEngine.class, + AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION, + DefaultDataWriterFactory.PARAM_DATA_WRITER_CLASS_NAME, + MaxentStringOutcomeDataWriter.class.getName(), + DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY, + maxentModelOutputDirectory); + testClassifier( + dataWriter, + maxentModelOutputDirectory, + trainingInputDirectory, + decodingInputDirectory, + decodingOutputDirectory + ); + } catch (Exception e) + { + logger.error("Some exception happened while training or decoding...", e); + return; + } + + } + + + + public static void testClassifier( + AnalysisEngineDescription dataWriter, + String modelOutputDirectory, + String trainingDataInputDirectory, + String decodingInputDirectory, + String decodingOutputDirectory, + String... trainingArgs) throws Exception + { + + CollectionReader trainingCollectionReader = CollectionReaderFactory.createCollectionReader( + XReader.class, + XReader.PARAM_ROOT_FILE, + trainingDataInputDirectory, + XReader.PARAM_XML_SCHEME, + XReader.XMI); + CollectionReader evaluationCollectionReader = CollectionReaderFactory.createCollectionReader( + XReader.class, + XReader.PARAM_ROOT_FILE, + decodingInputDirectory, + XReader.PARAM_XML_SCHEME, + XReader.XMI); + + + AggregateBuilder trainingBuilder = new AggregateBuilder(); + + AnalysisEngineDescription goldCopierAnnotator = AnalysisEngineFactory.createPrimitiveDescription(ReferenceIdentifiedAnnotationsSystemToGoldCopier.class); + trainingBuilder.add(goldCopierAnnotator); + + AnalysisEngineDescription assertionAttributeClearerAnnotator = AnalysisEngineFactory.createPrimitiveDescription(ReferenceAnnotationsSystemAssertionClearer.class); + trainingBuilder.add(assertionAttributeClearerAnnotator); + + Class> dataWriterFactoryClass = DefaultMaxentDataWriterFactory.class; + AnalysisEngineDescription trainingAssertionAnnotator = + AnalysisEngineFactory.createPrimitiveDescription( + AssertionCleartkAnalysisEngine.class, + AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION + ); + ConfigurationParameterFactory.addConfigurationParameters( + trainingAssertionAnnotator, + AssertionCleartkAnalysisEngine.PARAM_GOLD_VIEW_NAME, + AssertionEvalBasedOnModifier.GOLD_VIEW_NAME, + CleartkAnnotator.PARAM_DATA_WRITER_FACTORY_CLASS_NAME, + dataWriterFactoryClass.getName(), + DirectoryDataWriterFactory.PARAM_OUTPUT_DIRECTORY, + modelOutputDirectory + ); + trainingBuilder.add(trainingAssertionAnnotator); + + + +// CollectionReader collectionReader = XReader.getCollectionReader( +// trainingDataDirectory); +// collectionReader.setConfigParameterValue(XReader.PARAM_XML_SCHEME, XReader.XMI); +// collectionReader.reconfigure(); + + logger.info("starting feature generation..."); + SimplePipeline.runPipeline( + trainingCollectionReader, +// FilesCollectionReader.getCollectionReaderWithView( +// "src/test/resources/data/treebank/11597317.tree", +// TreebankConstants.TREEBANK_VIEW) +// , +// TreebankGoldAnnotator.getDescriptionPOSTagsOnly(), +// DefaultSnowballStemmer.getDescription("English"), +// dataWriter); + trainingBuilder.createAggregateDescription()); + logger.info("finished feature generation."); + + String[] args; + if (trainingArgs != null && trainingArgs.length > 0) { + args = new String[trainingArgs.length + 1]; + args[0] = modelOutputDirectory; + System.arraycopy(trainingArgs, 0, args, 1, trainingArgs.length); + } else { + args = new String[] { modelOutputDirectory }; + } + + HideOutput hider = new HideOutput(); + logger.info("starting training..."); + org.cleartk.classifier.jar.Train.main(args); + logger.info("finished training."); + hider.restoreOutput(); + + AggregateBuilder decodingBuilder = new AggregateBuilder(); + + //AnalysisEngineDescription goldCopierAnnotator = AnalysisEngineFactory.createPrimitiveDescription(ReferenceIdentifiedAnnotationsSystemToGoldCopier.class); + decodingBuilder.add(goldCopierAnnotator); + + //AnalysisEngineDescription assertionAttributeClearerAnnotator = AnalysisEngineFactory.createPrimitiveDescription(ReferenceAnnotationsSystemAssertionClearer.class); + decodingBuilder.add(assertionAttributeClearerAnnotator); + + AnalysisEngineDescription decodingAssertionAnnotator = + AnalysisEngineFactory.createPrimitiveDescription( + AssertionCleartkAnalysisEngine.class, + AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION + ); + ConfigurationParameterFactory.addConfigurationParameters( + decodingAssertionAnnotator, + AssertionCleartkAnalysisEngine.PARAM_GOLD_VIEW_NAME, + AssertionEvalBasedOnModifier.GOLD_VIEW_NAME, + GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH, + new File(modelOutputDirectory, "model.jar").getPath() + ); + decodingBuilder.add(decodingAssertionAnnotator); + + //SimplePipeline.runPipeline(collectionReader, builder.createAggregateDescription()); + AnalysisEngineDescription decodingAggregateDescription = decodingBuilder.createAggregateDescription(); + + + +// AnalysisEngineDescription taggerDescription = AnalysisEngineFactory.createPrimitiveDescription( +// AssertionCleartkAnalysisEngine.class, +// GenericJarClassifierFactory.PARAM_CLASSIFIER_JAR_PATH, +// //AssertionComponents.TYPE_SYSTEM_DESCRIPTION, +// modelOutputDirectory + "/model.jar"); + + logger.info("starting decoding..."); + SimplePipeline.runPipeline( + evaluationCollectionReader, +// BreakIteratorAnnotatorFactory.createSentenceAnnotator(Locale.US), +// TokenAnnotator.getDescription(), +// DefaultSnowballStemmer.getDescription("English"), + //taggerDescription, + decodingAggregateDescription, + AnalysisEngineFactory.createPrimitiveDescription( + XWriter.class, + AssertionComponents.CTAKES_CTS_TYPE_SYSTEM_DESCRIPTION, + XWriter.PARAM_OUTPUT_DIRECTORY_NAME, + decodingOutputDirectory, + XWriter.PARAM_XML_SCHEME_NAME, + XWriter.XMI)); + logger.info("finished decoding."); + + } + +}