mahout-dev mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From Jeff Eastman <j...@windwardsolutions.com>
Subject Re: Fuzzy K Means
Date Thu, 18 Feb 2010 12:18:54 GMT
+1 Looks like what I did too.

Robin Anil wrote:
> I am pasting the patch for SoftCluster here:
>
> Index: core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java
> ===================================================================
> --- core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java (revision 910924)
> +++ core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java (working copy)
> @@ -21,21 +21,14 @@
>  import java.io.DataOutput;
>  import java.io.IOException;
>
> -import org.apache.hadoop.io.Writable;
> +import org.apache.mahout.clustering.ClusterBase;
>  import org.apache.mahout.math.AbstractVector;
> -import org.apache.mahout.math.RandomAccessSparseVector;
>  import org.apache.mahout.math.Vector;
>  import org.apache.mahout.math.VectorWritable;
>  import org.apache.mahout.math.function.SquareRootFunction;
>
> -public class SoftCluster implements Writable {
> -
> -  // this cluster's clusterId
> -  private int clusterId;
> -
> -  // the current center
> -  private Vector center = new RandomAccessSparseVector(0);
> -
> +public class SoftCluster extends ClusterBase{
> +
>    // the current centroid is lazy evaluated and may be null
>    private Vector centroid = null;
>
> @@ -90,7 +83,7 @@
>
>    @Override
>    public void write(DataOutput out) throws IOException {
> -    out.writeInt(clusterId);
> +    out.writeInt(this.getId());
>      out.writeBoolean(converged);
>      Vector vector = computeCentroid();
>      VectorWritable.writeVector(out, vector);
> @@ -98,13 +91,13 @@
>
>    @Override
>    public void readFields(DataInput in) throws IOException {
> -    clusterId = in.readInt();
> +    this.setId(in.readInt());
>      converged = in.readBoolean();
>      VectorWritable temp = new VectorWritable();
>      temp.readFields(in);
> -    center = temp.get();
> +    this.setCenter(temp.get());
>      this.pointProbSum = 0;
> -    this.weightedPointTotal = center.like();
> +    this.weightedPointTotal = getCenter().like();
>    }
>
>    /**
> @@ -112,6 +105,7 @@
>     *
>     * @return the new centroid
>     */
> +  @Override
>    public Vector computeCentroid() {
>      if (pointProbSum == 0) {
>        return weightedPointTotal;
> @@ -132,7 +126,7 @@
>     *          the center point
>     */
>    public SoftCluster(Vector center) {
> -    this.center = center;
> +    setCenter(center);
>      this.pointProbSum = 0;
>
>      this.weightedPointTotal = center.like();
> @@ -145,8 +139,8 @@
>     *          the center point
>     */
>    public SoftCluster(Vector center, int clusterId) {
> -    this.clusterId = clusterId;
> -    this.center = center;
> +    this.setId(clusterId);
> +    this.setCenter(center);
>      this.pointProbSum = 0;
>      this.weightedPointTotal = center.like();
>    }
> @@ -154,7 +148,7 @@
>    /** Construct a new softcluster with the given clusterID */
>    public SoftCluster(String clusterId) {
>
> -    this.clusterId = Integer.parseInt(clusterId.substring(1));
> +    this.setId(Integer.parseInt(clusterId.substring(1)));
>      this.pointProbSum = 0;
>      // this.weightedPointTotal = center.like();
>      this.converged = clusterId.charAt(0) == 'V';
> @@ -162,14 +156,15 @@
>
>    @Override
>    public String toString() {
> -    return getIdentifier() + " - " + center.asFormatString();
> +    return getIdentifier() + " - " + getCenter().asFormatString();
>    }
>
> +  @Override
>    public String getIdentifier() {
>      if (converged) {
> -      return "V" + clusterId;
> +      return "V" + this.getId();
>      } else {
> -      return "C" + clusterId;
> +      return "C" + this.getId();
>      }
>    }
>
> @@ -212,7 +207,7 @@
>      centroid = null;
>      pointProbSum += ptProb;
>      if (weightedPointTotal == null) {
> -      weightedPointTotal = point.clone().times(ptProb);
> +      weightedPointTotal = point.times(ptProb);
>      } else {
>        weightedPointTotal = weightedPointTotal.plus(point.times(ptProb));
>      }
> @@ -234,19 +229,15 @@
>      }
>    }
>
> -  public Vector getCenter() {
> -    return center;
> -  }
> -
>    public double getPointProbSum() {
>      return pointProbSum;
>    }
>
>    /** Compute the centroid and set the center to it. */
>    public void recomputeCenter() {
> -    center = computeCentroid();
> +    this.setCenter(computeCentroid());
>      pointProbSum = 0;
> -    weightedPointTotal = center.like();
> +    weightedPointTotal = getCenter().like();
>    }
>
>    public Vector getWeightedPointTotal() {
> @@ -265,8 +256,9 @@
>      this.converged = converged;
>    }
>
> -  public int getClusterId() {
> -    return clusterId;
> +  @Override
> +  public String asFormatString() {
> +    return formatCluster(this);
>    }
>
>  }
>
>   


Mime
View raw message