mahout-dev mailing list archives

From "Jeff Eastman (JIRA)" <j...@apache.org>
Subject [jira] Commented: (MAHOUT-30) dirichlet process implementation
Date Thu, 13 Nov 2008 04:37:44 GMT

    [ https://issues.apache.org/jira/browse/MAHOUT-30?page=com.atlassian.jira.plugin.system.issuetabpanels:comment-tabpanel&focusedCommentId=12647186#action_12647186 ]

Jeff Eastman commented on MAHOUT-30:
------------------------------------

I refactored again and was able to eliminate materializing the posterior {{data}} sets by
adding {{observe()}} and {{computeParameters()}} operations to {{Model}}. The idea is that
all models begin in their prior state and are asked to observe each sample that is assigned
to them. Then, before {{pdf()}} is called on them in the next iteration, a call to {{computeParameters()}}
finalizes the parameters once and turns the model into a posterior model. I also compute {{counts}}
on the fly to eliminate materializing {{z}} altogether. I hope I didn't throw the baby out
with the bathwater.
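To illustrate why {{z}} need not be materialized: the Dirichlet update only needs per-cluster totals, so each assignment can be counted as soon as it is drawn. The sketch below compares a hypothetical materialize-then-count baseline with the on-the-fly version. {{CountsSketch}}, the {{IntSupplier}} standing in for the multinomial draw, and the names of both methods are assumptions for illustration; the {{alpha0 / k}} seeding mirrors the {{counts.assign(alpha_0 / maxClusters)}} line in {{iterate()}}.

```java
import java.util.Arrays;
import java.util.function.IntSupplier;

/** Hypothetical sketch: two ways to obtain per-cluster membership counts. */
class CountsSketch {

  /** Baseline: store every assignment in z, then count it in a second pass. */
  static double[] countsViaZ(IntSupplier draw, int n, int k, double alpha0) {
    int[] z = new int[n]; // O(n) extra storage per iteration
    for (int i = 0; i < n; i++) {
      z[i] = draw.getAsInt();
    }
    double[] counts = new double[k];
    Arrays.fill(counts, alpha0 / k); // pseudo-count seeding, as in iterate()
    for (int zi : z) {
      counts[zi] += 1;
    }
    return counts;
  }

  /** On the fly: count each assignment as it is drawn; z is never stored. */
  static double[] countsOnTheFly(IntSupplier draw, int n, int k, double alpha0) {
    double[] counts = new double[k];
    Arrays.fill(counts, alpha0 / k);
    for (int i = 0; i < n; i++) {
      counts[draw.getAsInt()] += 1;
    }
    return counts;
  }
}
```

In the actual loop the draw is {{dist.rmultinom(pi)}}, and the chosen model observes the sample in the same pass, so both the counts and the posterior statistics come out of a single sweep over the data.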

Finally, I introduced a {{DirichletState}} bean to hold the models, the Dirichlet distribution
and the mixture, which simplifies the method arguments and, I think, fixes a bug in the earlier refactoring.
The algorithm runs over 10,000 points and produces the following outputs (prior() indicates
a model with no observations, n is the number of observations, m the mean and sd the standard deviation):

Generating 4000 samples m=[1.0, 1.0] sd=3.0
Generating 3000 samples m=[1.0, 0.0] sd=0.1
Generating 3000 samples m=[0.0, 1.0] sd=0.1

* sample[0]= [prior(), normal(n=6604 m=[0.67, 0.63] sd=1.11), normal(n=86 m=[0.77, 2.81] sd=2.15),
prior(), normal(n=242 m=[2.89, 1.67] sd=2.14), normal(n=2532 m=[0.53, 0.55] sd=0.69), normal(n=339
m=[0.99, 1.70] sd=2.18), normal(n=77 m=[0.53, 0.47] sd=0.51), normal(n=119 m=[0.36, 0.47]
sd=2.85), normal(n=1 m=[0.00, 0.00] sd=0.33)]
* sample[1]= [prior(), normal(n=6626 m=[0.62, 0.54] sd=0.91), normal(n=137 m=[0.51, 2.99]
sd=1.56), normal(n=2 m=[0.57, 0.25] sd=0.70), normal(n=506 m=[2.55, 0.93] sd=1.73), normal(n=1573
m=[0.38, 0.60] sd=0.50), normal(n=848 m=[0.81, 1.59] sd=2.11), normal(n=67 m=[0.76, 0.31]
sd=0.45), normal(n=240 m=[0.73, 0.31] sd=2.24), normal(n=1 m=[0.00, 0.00] sd=0.98)]
* sample[2]= [prior(), normal(n=5842 m=[0.67, 0.39] sd=0.73), normal(n=157 m=[0.73, 3.12]
sd=1.14), prior(), normal(n=655 m=[2.32, 0.64] sd=1.60), normal(n=1439 m=[0.00, 1.00] sd=0.33),
normal(n=1439 m=[0.78, 1.53] sd=1.89), normal(n=66 m=[0.96, -0.04] sd=0.24), normal(n=399
m=[0.63, -0.03] sd=1.99), normal(n=3 m=[-0.07, 0.76] sd=0.41)]
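The three groups generated above can be reproduced in spirit with a sketch like the following. It assumes each group is an isotropic 2-D normal (both coordinates drawn independently with the listed sd), which the log lines suggest but do not state; {{SampleDataSketch}} and its methods are hypothetical, and a different random seed will produce different cluster samples from the ones shown.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/** Hypothetical generator for the 10,000-point test set described above. */
class SampleDataSketch {

  /** Append n points drawn from an isotropic normal with mean m and std sd. */
  static void generateSamples(List<double[]> out, int n, double[] m, double sd, Random rng) {
    for (int i = 0; i < n; i++) {
      double[] p = new double[m.length];
      for (int d = 0; d < m.length; d++) {
        p[d] = m[d] + sd * rng.nextGaussian(); // independent coordinates, shared sd
      }
      out.add(p);
    }
  }

  /** 4000 diffuse points around [1, 1] plus two tight groups of 3000 points each. */
  static List<double[]> buildTestData(long seed) {
    Random rng = new Random(seed);
    List<double[]> data = new ArrayList<>();
    generateSamples(data, 4000, new double[] {1.0, 1.0}, 3.0, rng);
    generateSamples(data, 3000, new double[] {1.0, 0.0}, 0.1, rng);
    generateSamples(data, 3000, new double[] {0.0, 1.0}, 0.1, rng);
    return data;
  }
}
```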

{code:title=Model}
/**
 * A model is a probability distribution over observed data points and allows 
 * the probability of any data point to be computed.
 */
public interface Model<Observation> {
  
  /**
   * Observe the given observation, retaining information about it
   * 
   * @param x an Observation assigned to this model
   */
  void observe(Observation x);
  
  /**
   * Compute a new set of posterior parameters based upon the Observations 
   * that have been observed since this model was created
   */
  void computeParameters();

  /**
   * Return the probability density of the observation under this model
   * 
   * @param x an Observation
   * @return the pdf of x under this model's parameters
   */
  double pdf(Observation x);
}
{code}
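As a sketch of how a concrete model might satisfy this contract, here is a hypothetical isotropic 2-D normal model like the normal(n=... m=[...] sd=...) entries printed above. {{observe()}} only accumulates sufficient statistics (count, per-coordinate sums and the sum of squared norms), {{computeParameters()}} turns them into a maximum-likelihood mean and pooled standard deviation, and {{pdf()}} evaluates the density. The constructor's prior mean and sd, the {{MIN_SD}} floor, and the use of plain maximum likelihood rather than a Bayesian posterior update are all assumptions; the interface is repeated so the block compiles on its own.

```java
/** Same contract as the Model interface above, repeated so this sketch is self-contained. */
interface Model<Observation> {
  void observe(Observation x);
  void computeParameters();
  double pdf(Observation x);
}

/** Hypothetical isotropic 2-D normal: one mean vector, one sd shared by both coordinates. */
class IsotropicNormalModel implements Model<double[]> {
  private static final double MIN_SD = 1e-3; // assumed floor so pdf() never divides by zero

  private double[] mean;
  private double sd;

  // sufficient statistics gathered by observe(); the observations themselves are not kept
  private int n;
  private final double[] sum = new double[2];
  private double sumSq; // sum over observations of ||x||^2

  IsotropicNormalModel(double[] priorMean, double priorSd) {
    this.mean = priorMean.clone();
    this.sd = priorSd;
  }

  @Override
  public void observe(double[] x) {
    n++;
    sum[0] += x[0];
    sum[1] += x[1];
    sumSq += x[0] * x[0] + x[1] * x[1];
  }

  @Override
  public void computeParameters() {
    if (n == 0) {
      return; // nothing observed: keep the prior parameters, i.e. "prior()"
    }
    mean = new double[] {sum[0] / n, sum[1] / n};
    // E||x - m||^2 = E||x||^2 - ||m||^2 = 2 * sd^2 for an isotropic 2-D normal
    double var = (sumSq / n - (mean[0] * mean[0] + mean[1] * mean[1])) / 2.0;
    sd = Math.max(Math.sqrt(Math.max(var, 0.0)), MIN_SD);
  }

  @Override
  public double pdf(double[] x) {
    double dx = x[0] - mean[0];
    double dy = x[1] - mean[1];
    double s2 = sd * sd;
    return Math.exp(-(dx * dx + dy * dy) / (2.0 * s2)) / (2.0 * Math.PI * s2);
  }

  int count() { return n; }
  double[] mean() { return mean.clone(); }
  double sd() { return sd; }
}
```

Because the sums are reset only by creating a fresh prior model, this design fits the iterate() pattern of building new prior models each iteration rather than clearing old ones.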

{code:title=DirichletCluster}
  /**
   * Initialize the variables and run the iterations to assign the sample data
   * points to a computed number of clusters
   *
   * @return a List<List<Model<Observation>>> of the observed models
   */
  public List<List<Model<Observation>>> dirichletCluster() {
    DirichletState<Observation> state = initializeState();

    // create a posterior sample list to collect results
    List<List<Model<Observation>>> clusterSamples = new ArrayList<List<Model<Observation>>>();

    // now iterate
    for (int iteration = 0; iteration < maxIterations; iteration++) {
      iterate(state, iteration, clusterSamples);
    }

    return clusterSamples;
  }

  /**
   * Initialize the state of the computation
   * 
   * @return the DirichletState
   */
  private DirichletState<Observation> initializeState() {
    // get initial prior models
    List<Model<Observation>> models = createPriorModels();
    // create the initial distribution.
    DirichletDistribution distribution = new DirichletDistribution(maxClusters,
        alpha_0, dist);
    // mixture parameters are sampled from the Dirichlet distribution. 
    Vector mixture = distribution.sample();
    return new DirichletState<Observation>(models, distribution, mixture);
  }

  /**
   * Create a list of prior models, one per cluster
   *
   * @return a List of maxClusters models sampled from the prior
   */
  private List<Model<Observation>> createPriorModels() {
    List<Model<Observation>> models = new ArrayList<Model<Observation>>();
    for (int k = 0; k < maxClusters; k++) {
      models.add(modelFactory.sampleFromPrior());
    }
    return models;
  }

  /**
   * Perform one iteration of the clustering process, updating the state for the next iteration
   *
   * @param state the DirichletState<Observation> of this iteration
   * @param iteration the int iteration number
   * @param clusterSamples a List<List<Model<Observation>>> that will be
   *        modified in each iteration
   */
  private void iterate(DirichletState<Observation> state, int iteration,
      List<List<Model<Observation>>> clusterSamples) {

    // create new prior models
    List<Model<Observation>> newModels = createPriorModels();

    // initialize vector of membership counts
    Vector counts = new DenseVector(maxClusters);
    counts.assign(alpha_0 / maxClusters);

    // iterate over the samples
    for (int i = 0; i < sampleData.size(); i++) {
      Observation x = sampleData.get(i);
      // compute vector of probabilities x is described by each model
      Vector pi = computeProbabilities(state, x);
      // then pick one cluster by sampling a Multinomial distribution based upon them
      // see: http://en.wikipedia.org/wiki/Multinomial_distribution
      int model = dist.rmultinom(pi);
      // ask the selected model to observe the datum
      newModels.get(model).observe(x);
      // record counts for the model
      counts.set(model, counts.get(model) + 1);
    }

    // compute new model parameters based upon observations
    for (int k = 0; k < maxClusters; k++) {
      newModels.get(k).computeParameters();
    }

    // update the state from the new models and counts
    state.distribution.setAlpha(counts);
    state.mixture = state.distribution.sample();
    state.models = newModels;

    // after the burn-in period, keep the models of every thin-th iteration as a posterior sample
    if (iteration > burnin && iteration % thin == 0) {
      clusterSamples.add(state.models);
    }
  }

  /**
   * Compute a vector of weights proportional to the probability that x is
   * described by each model, using the mixture and the model pdfs. The
   * weights are scaled so that the largest is 1; they are not normalized
   * to sum to 1.
   * 
   * @param state the DirichletState<Observation> of this iteration
   * @param x an Observation
   * @return the Vector of scaled probabilities
   */
  private Vector computeProbabilities(DirichletState<Observation> state,
      Observation x) {
    Vector pi = new DenseVector(maxClusters);
    double max = 0;
    for (int k = 0; k < maxClusters; k++) {
      double p = state.mixture.get(k) * state.models.get(k).pdf(x);
      pi.set(k, p);
      if (max < p) {
        max = p;
      }
    }
    // scale the probabilities by the largest observed value (they are not made to sum to 1)
    pi.assign(new TimesFunction(), 1.0 / max);
    return pi;
  }
{code}
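{{iterate()}} makes two kinds of random draw: a cluster index per sample via {{dist.rmultinom(pi)}}, and new mixture weights via {{state.distribution.sample()}} after {{setAlpha(counts)}}. The sketch below shows one standard way to implement each, as hypothetical stand-ins rather than the classes the patch actually uses: a categorical draw proportional to unnormalized weights (so the 1/max scaling in {{computeProbabilities()}} does not change the outcome, assuming {{rmultinom}} also samples proportionally), and a Dirichlet draw built from normalized Gamma variates using Marsaglia and Tsang's method. Note that if every {{pdf()}} value underflows to 0, {{max}} stays 0 and the 1.0 / max scaling yields NaN weights; this sketch throws an exception in the equivalent situation instead.

```java
import java.util.Random;

/** Hypothetical helpers for the two random draws made in iterate(). */
class GibbsDrawSketch {

  /** Index k with probability w[k] / sum(w); the weights need not sum to 1. */
  static int sampleCategorical(double[] w, Random rng) {
    double total = 0;
    int lastPositive = -1;
    for (int k = 0; k < w.length; k++) {
      if (!(w[k] >= 0)) { // also rejects NaN
        throw new IllegalArgumentException("weights must be non-negative numbers");
      }
      total += w[k];
      if (w[k] > 0) {
        lastPositive = k;
      }
    }
    if (lastPositive < 0 || Double.isInfinite(total)) {
      throw new IllegalArgumentException("weights must have a finite positive sum");
    }
    double u = rng.nextDouble() * total; // uniform in [0, total)
    double cumulative = 0;
    for (int k = 0; k < w.length; k++) {
      cumulative += w[k];
      if (u < cumulative) {
        return k;
      }
    }
    return lastPositive; // fallback for floating-point round-off
  }

  /** Gamma(shape, 1) by Marsaglia and Tsang's method; shape < 1 is boosted via U^(1/shape). */
  static double sampleGamma(double shape, Random rng) {
    if (shape < 1.0) {
      double u = 1.0 - rng.nextDouble(); // in (0, 1]
      return sampleGamma(shape + 1.0, rng) * Math.pow(u, 1.0 / shape);
    }
    double d = shape - 1.0 / 3.0;
    double c = 1.0 / Math.sqrt(9.0 * d);
    while (true) {
      double x = rng.nextGaussian();
      double v = 1.0 + c * x;
      if (v <= 0) {
        continue;
      }
      v = v * v * v;
      double u = 1.0 - rng.nextDouble(); // in (0, 1], so log(u) is finite
      if (Math.log(u) < 0.5 * x * x + d - d * v + d * Math.log(v)) {
        return d * v;
      }
    }
  }

  /** Mixture weights from Dirichlet(alpha): independent Gamma(alpha[k], 1) draws, normalized. */
  static double[] sampleDirichlet(double[] alpha, Random rng) {
    double[] w = new double[alpha.length];
    double total = 0;
    for (int k = 0; k < alpha.length; k++) {
      w[k] = sampleGamma(alpha[k], rng);
      total += w[k];
    }
    for (int k = 0; k < w.length; k++) {
      w[k] /= total;
    }
    return w;
  }
}
```

Because the Dirichlet parameters are the alpha_0 / maxClusters pseudo-counts plus the membership counts, clusters that attracted more samples tend to receive larger mixture weights in the next iteration, while empty clusters keep a small chance of being selected.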


> dirichlet process implementation
> --------------------------------
>
>                 Key: MAHOUT-30
>                 URL: https://issues.apache.org/jira/browse/MAHOUT-30
>             Project: Mahout
>          Issue Type: New Feature
>          Components: Clustering
>            Reporter: Isabel Drost
>         Attachments: MAHOUT-30.patch
>
>
> Copied over from original issue:
> > Further extension can also be made by assuming an infinite mixture model. The implementation
> > is only slightly more difficult and the result is a (nearly) non-parametric clustering algorithm.

-- 
This message is automatically generated by JIRA.
-
You can reply to this email to add a comment to the issue online.

