mahout-commits mailing list archives

From: tdunn...@apache.org
Subject: svn commit: r1002033 - /mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
Date: Tue, 28 Sep 2010 06:20:16 GMT
Author: tdunning
Date: Tue Sep 28 06:20:16 2010
New Revision: 1002033

URL: http://svn.apache.org/viewvc?rev=1002033&view=rev
Log:
Got rid of final declarations to avoid style complaints and to keep the field names from SHOUTING; also added accumulation and printing of overall word counts at the end of main.
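
(For context on the "SHOUTING" remark: by Java convention, static final fields are constants and are written in ALL_CAPS, while non-final fields use lowerCamelCase, which is exactly the rename the diff below performs. A tiny self-contained illustration of the convention, with hypothetical names that are not taken from the commit:)

// NamingSketch.java -- hypothetical example; MAX_RETRIES and retryBudget are
// invented names used only to illustrate the naming convention.
public class NamingSketch {
  private static final int MAX_RETRIES = 3; // constant: ALL_CAPS ("SHOUTING")
  private static int retryBudget = 3;       // plain static field: lowerCamelCase

  public static void main(String[] args) {
    System.out.println(MAX_RETRIES + " vs " + retryBudget);
  }
}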

Modified:
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java?rev=1002033&r1=1002032&r2=1002033&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java Tue Sep 28 06:20:16 2010
@@ -18,9 +18,11 @@
 package org.apache.mahout.classifier.sgd;
 
 import com.google.common.collect.ConcurrentHashMultiset;
+import com.google.common.collect.HashMultiset;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import com.google.common.collect.Multiset;
+import com.google.common.collect.Ordering;
 import org.apache.lucene.analysis.Analyzer;
 import org.apache.lucene.analysis.TokenStream;
 import org.apache.lucene.analysis.standard.StandardAnalyzer;
@@ -116,16 +118,16 @@ public final class TrainNewsGroups {
     new SimpleDateFormat("dd-MMM-yyyy HH:mm:ss")
   };
 
-  private static final Analyzer ANALYZER = new StandardAnalyzer(Version.LUCENE_30);
-  private static final FeatureVectorEncoder ENCODER = new StaticWordValueEncoder("body");
-  private static final FeatureVectorEncoder BIAS = new ConstantValueEncoder("Intercept");
-
-  private TrainNewsGroups() {
-  }
+  private static Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_30);
+  private static FeatureVectorEncoder encoder = new StaticWordValueEncoder("body");
+  private static FeatureVectorEncoder bias = new ConstantValueEncoder("Intercept");
+  private static Multiset<String> overallCounts;
 
   public static void main(String[] args) throws IOException {
     File base = new File(args[0]);
 
+    overallCounts = HashMultiset.create();
+
     int leakType = 0;
     if (args.length > 1) {
       leakType = Integer.parseInt(args[1]);
@@ -133,7 +135,7 @@ public final class TrainNewsGroups {
 
     Dictionary newsGroups = new Dictionary();
 
-    ENCODER.setProbes(2);
+    encoder.setProbes(2);
     AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(20, FEATURES, new L1());
     learningAlgorithm.setInterval(800);
     learningAlgorithm.setAveragingWindow(500);
@@ -215,6 +217,18 @@ public final class TrainNewsGroups {
     learningAlgorithm.close();
     dissect(leakType, newsGroups, learningAlgorithm, files);
     System.out.println("exiting main");
+
+    List<Integer> counts = Lists.newArrayList();
+    System.out.printf("Word counts\n");
+    for (String count : overallCounts.elementSet()) {
+      counts.add(overallCounts.count(count));
+    }
+    Collections.sort(counts, Ordering.natural().reverse());
+    k = 0;
+    for (Integer count : counts) {
+      System.out.printf("%d\t%d\n", k, count);
+      k++;
+    }
   }
 
   private static void dissect(int leakType,
@@ -227,8 +241,8 @@ public final class TrainNewsGroups {
     Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
     ModelDissector md = new ModelDissector();
 
-    ENCODER.setTraceDictionary(traceDictionary);
-    BIAS.setTraceDictionary(traceDictionary);
+    encoder.setTraceDictionary(traceDictionary);
+    bias.setTraceDictionary(traceDictionary);
 
     for (File file : permute(files, rand).subList(0, 500)) {
       String ng = file.getParentFile().getName();
@@ -254,7 +268,7 @@ public final class TrainNewsGroups {
     try {
       String line = reader.readLine();
       Reader dateString = new StringReader(DATE_FORMATS[leakType % 3].format(new Date(date)));
-      countWords(ANALYZER, words, dateString);
+      countWords(analyzer, words, dateString);
       while (line != null && line.length() > 0) {
         boolean countHeader = (
           line.startsWith("From:") || line.startsWith("Subject:") ||
@@ -262,22 +276,22 @@ public final class TrainNewsGroups {
         do {
           Reader in = new StringReader(line);
           if (countHeader) {
-            countWords(ANALYZER, words, in);
+            countWords(analyzer, words, in);
           }
           line = reader.readLine();
         } while (line.startsWith(" "));
       }
       if (leakType < 3) {
-        countWords(ANALYZER, words, reader);
+        countWords(analyzer, words, reader);
       }
     } finally {
       reader.close();
     }
 
     Vector v = new RandomAccessSparseVector(FEATURES);
-    BIAS.addToVector("", 1, v);
+    bias.addToVector("", 1, v);
     for (String word : words.elementSet()) {
-      ENCODER.addToVector(word, Math.log(1 + words.count(word)), v);
+      encoder.addToVector(word, Math.log(1 + words.count(word)), v);
     }
 
     return v;
@@ -290,6 +304,7 @@ public final class TrainNewsGroups {
       String s = ts.getAttribute(TermAttribute.class).term();
       words.add(s);
     }
+    overallCounts.addAll(words);
   }
 
   private static List<File> permute(Iterable<File> files, Random rand) {
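
(The functional addition in this commit, beyond the renames, is the overallCounts reporting: tokens are accumulated in a Guava HashMultiset inside countWords(), and main() prints the per-word counts in descending order. Below is a minimal, self-contained sketch of that pattern, assuming only Guava on the classpath; the class name and sample tokens are invented for illustration and are not part of the commit.)

import com.google.common.collect.HashMultiset;
import com.google.common.collect.Lists;
import com.google.common.collect.Multiset;
import com.google.common.collect.Ordering;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

// WordCountSketch.java -- hypothetical standalone example of the counting and
// reporting pattern added to TrainNewsGroups.main(); the token list is made up.
public class WordCountSketch {
  public static void main(String[] args) {
    Multiset<String> overallCounts = HashMultiset.create();
    overallCounts.addAll(Arrays.asList("from", "subject", "date", "from", "from", "subject"));

    // Gather one count per distinct word, then sort the counts in descending order,
    // mirroring Collections.sort(counts, Ordering.natural().reverse()) in the diff.
    List<Integer> counts = Lists.newArrayList();
    for (String word : overallCounts.elementSet()) {
      counts.add(overallCounts.count(word));
    }
    Collections.sort(counts, Ordering.natural().reverse());

    System.out.printf("Word counts%n");
    int k = 0;
    for (Integer count : counts) {
      System.out.printf("%d\t%d%n", k, count);
      k++;
    }
  }
}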


