ctakes-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From dlig...@apache.org
Subject svn commit: r1703204 - in /ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor: ae/features/EmbeddingFeatureExtractor.java data/analysis/Utils.java
Date Tue, 15 Sep 2015 14:33:34 GMT
Author: dligach
Date: Tue Sep 15 14:33:34 2015
New Revision: 1703204

URL: http://svn.apache.org/r1703204
Log:
Added cosine similarity feature; switched vectors from list of floats to list of doubles

Modified:
    ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/ae/features/EmbeddingFeatureExtractor.java
    ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/data/analysis/Utils.java

Modified: ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/ae/features/EmbeddingFeatureExtractor.java
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/ae/features/EmbeddingFeatureExtractor.java?rev=1703204&r1=1703203&r2=1703204&view=diff
==============================================================================
--- ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/ae/features/EmbeddingFeatureExtractor.java	(original)
+++ ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/ae/features/EmbeddingFeatureExtractor.java	Tue Sep 15 14:33:34 2015
@@ -18,62 +18,83 @@
  */
 package org.apache.ctakes.relationextractor.ae.features;
 
-import java.io.File;
-import java.io.IOException;
 import java.util.ArrayList;
 import java.util.List;
 import java.util.Map;
 
 import org.apache.ctakes.relationextractor.data.analysis.Utils;
-import org.apache.ctakes.relationextractor.data.analysis.Utils.Callback;
 import org.apache.ctakes.typesystem.type.textsem.IdentifiedAnnotation;
 import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
 import org.apache.uima.jcas.JCas;
 import org.cleartk.ml.Feature;
 
-import com.google.common.base.Charsets;
-import com.google.common.io.Files;
-
 /**
- * 
+ * Word embedding based features.
+ * OOV words are handled by the average vector which should be 
+ * included with the rest of the vectors and indexed as "oov".
  */
-public class EmbeddingFeatureExtractor implements RelationFeaturesExtractor<IdentifiedAnnotation,IdentifiedAnnotation>
{
+public class EmbeddingFeatureExtractor implements RelationFeaturesExtractor<IdentifiedAnnotation,
IdentifiedAnnotation> {
+
+  private int numberOfDimensions;
+  private Map<String, List<Double>> wordVectors;
+
+  public EmbeddingFeatureExtractor(Map<String, List<Double>> wordVectors) {
+    this.wordVectors = wordVectors;
+    numberOfDimensions = this.wordVectors.get("oov").size();
+  }
 
   @Override
   public List<Feature> extract(JCas jCas, IdentifiedAnnotation arg1, IdentifiedAnnotation
arg2) throws AnalysisEngineProcessException {
-			
-    File word2vec = new File(Utils.embeddingPath);
-    Map<String, List<Float>> wordVectors = null;
-    try {
-      wordVectors = Files.readLines(word2vec, Charsets.UTF_8, new Callback());
-    } catch (IOException e) {
-      e.printStackTrace();
-    }
-    
+
     List<Feature> features = new ArrayList<>();
-    
+
+    // get head words
     String arg1LastWord = Utils.getLastWord(jCas, arg1).toLowerCase();
     String arg2LastWord = Utils.getLastWord(jCas, arg2).toLowerCase();
-    
-    List<Float> arg1Vector = wordVectors.get("oov");
+
+    List<Double> arg1Vector;
     if(wordVectors.containsKey(arg1LastWord)) {
       arg1Vector = wordVectors.get(arg1LastWord);
+    } else {
+      arg1Vector = wordVectors.get("oov");
     }
-    for(int dim = 0; dim < 300; dim++) {
-      String featureName = String.format("arg1_dim_%d", dim);
-      features.add(new Feature(featureName, arg1Vector.get(dim)));
-    }
-    
-    List<Float> arg2Vector = wordVectors.get("oov");
+    List<Double> arg2Vector;
     if(wordVectors.containsKey(arg2LastWord)) {
       arg2Vector = wordVectors.get(arg2LastWord);
+    } else {
+      arg2Vector = wordVectors.get("oov");
+    }
+
+    double similarity = computeCosineSimilarity(arg1Vector, arg2Vector); 
+    features.add(new Feature("arg_cos_sim", similarity));
+
+    for(int dim = 0; dim < numberOfDimensions; dim++) {
+      String featureName = String.format("arg1_dim_%d", dim);
+      features.add(new Feature(featureName, arg1Vector.get(dim)));
     }
-    for(int dim = 0; dim < 300; dim++) {
+    for(int dim = 0; dim < numberOfDimensions; dim++) {
       String featureName = String.format("arg2_dim_%d", dim);
       features.add(new Feature(featureName, arg2Vector.get(dim)));
     }    
-  	
+
     return features;
   }
 
+  /**
+   * Compute cosine similarity between two vectors.
+   */
+  public double computeCosineSimilarity(List<Double> vector1, List<Double> vector2)
{
+
+    double dotProduct = 0.0;
+    double norm1 = 0.0;
+    double norm2 = 0.0;
+
+    for (int dim = 0; dim < numberOfDimensions; dim++) {
+      dotProduct = dotProduct + vector1.get(dim) * vector2.get(dim);
+      norm1 = norm1 + Math.pow(vector1.get(dim), 2);
+      norm2 = norm2 + Math.pow(vector2.get(dim), 2);
+    }
+
+    return dotProduct / (Math.sqrt(norm1) * Math.sqrt(norm2));
+  }
 }

Modified: ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/data/analysis/Utils.java
URL: http://svn.apache.org/viewvc/ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/data/analysis/Utils.java?rev=1703204&r1=1703203&r2=1703204&view=diff
==============================================================================
--- ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/data/analysis/Utils.java	(original)
+++ ctakes/trunk/ctakes-relation-extractor/src/main/java/org/apache/ctakes/relationextractor/data/analysis/Utils.java	Tue Sep 15 14:33:34 2015
@@ -42,7 +42,7 @@ import com.google.common.io.LineProcesso
  */
 public class Utils {
   
-  public static final String embeddingPath = "/Users/dima/Boston/Vectors/Python/sharp-arg-head-word-vectors.txt";
+  public static final String embeddingPath = "/Users/dima/Boston/Vectors/Models/ties-plus-oov.txt";
   
   /**
    * Instantiate an XMI collection reader.
@@ -82,9 +82,9 @@ public class Utils {
   /**
    * Read word embeddings from file.
    */
-  public static class Callback implements LineProcessor <Map<String, List<Float>>>
{
+  public static class Callback implements LineProcessor <Map<String, List<Double>>>
{
     
-    private Map<String, List<Float>> wordToVector;
+    private Map<String, List<Double>> wordToVector;
     
     public Callback() {
       wordToVector = new HashMap<>();
@@ -93,17 +93,17 @@ public class Utils {
     public boolean processLine(String line) throws IOException {
       
       String[] elements = line.split(" "); // e.g. skin -0.024690 0.108761 0.038441 -0.088759
...
-      List<Float> vector = new ArrayList<>();
+      List<Double> vector = new ArrayList<>();
       
       for(int dimension = 1; dimension < elements.length; dimension++) {
-        vector.add(Float.parseFloat(elements[dimension]));
+        vector.add(Double.parseDouble(elements[dimension]));
       }
       
       wordToVector.put(elements[0], vector);
       return true;
     }
     
-    public Map<String, List<Float>> getResult() {
+    public Map<String, List<Double>> getResult() {
       
       return wordToVector;
     }
@@ -112,7 +112,7 @@ public class Utils {
   public static void main(String[] args) throws IOException {
     
     File word2vec = new File(embeddingPath);
-    Map<String, List<Float>> data = Files.readLines(word2vec, Charsets.UTF_8,
new Callback());
+    Map<String, List<Double>> data = Files.readLines(word2vec, Charsets.UTF_8,
new Callback());
     System.out.println(data.get("skin"));
     System.out.println(data.get("oov"));
   }



Mime
View raw message