hivemall-issues mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From myui <...@git.apache.org>
Subject [GitHub] incubator-hivemall pull request #66: [HIVEMALL-91] Implement Online LDA
Date Fri, 14 Apr 2017 08:56:54 GMT
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/66#discussion_r111548963
  
    --- Diff: core/src/main/java/hivemall/topicmodel/LDAUDTF.java ---
    @@ -0,0 +1,553 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.topicmodel;
    +
    +import hivemall.UDTFWithOptions;
    +import hivemall.utils.hadoop.HiveUtils;
    +import hivemall.utils.io.FileUtils;
    +import hivemall.utils.io.NioStatefullSegment;
    +import hivemall.utils.lang.NumberUtils;
    +import hivemall.utils.lang.Primitives;
    +import hivemall.utils.lang.SizeOf;
    +
    +import java.io.File;
    +import java.io.IOException;
    +import java.nio.ByteBuffer;
    +import java.util.ArrayList;
    +import java.util.Arrays;
    +import java.util.List;
    +import java.util.Map;
    +import java.util.SortedMap;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +
    +import org.apache.commons.cli.CommandLine;
    +import org.apache.commons.cli.Options;
    +import org.apache.commons.logging.Log;
    +import org.apache.commons.logging.LogFactory;
    +import org.apache.hadoop.hive.ql.exec.Description;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    +import org.apache.hadoop.hive.ql.metadata.HiveException;
    +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    +import org.apache.hadoop.io.FloatWritable;
    +import org.apache.hadoop.io.IntWritable;
    +import org.apache.hadoop.io.Text;
    +import org.apache.hadoop.mapred.Counters;
    +import org.apache.hadoop.mapred.Reporter;
    +
    +import com.google.common.annotations.VisibleForTesting;
    +
    +@Description(name = "train_lda", value = "_FUNC_(array<string> words[, const string
options])"
    +        + " - Returns a relation consists of <int topic, string word, float score>")
    +public class LDAUDTF extends UDTFWithOptions {
    +    private static final Log logger = LogFactory.getLog(LDAUDTF.class);
    +
    +    // Options: defaults are assigned in the constructor and may be overridden
    +    // by the values parsed in processOptions().
    +    protected int topic; // number of topics (`-topic`)
    +    protected float alpha; // hyperparameter for theta (`-alpha`)
    +    protected float eta; // hyperparameter for beta (`-eta`)
    +    protected int numDoc; // total number of documents (`-num_doc`); negative means "auto"
    +    protected double tau0; // downweights early iterations (`-tau0`)
    +    protected double kappa; // exponential decay rate, i.e., learning rate (`-kappa`)
    +    protected int iterations; // maximum number of iterations (`-iterations`)
    +    protected double delta; // convergence threshold for the expectation step (`-delta`)
    +    protected double eps; // convergence threshold on the perplexity difference (`-epsilon`)
    +    protected int miniBatchSize; // number of documents per model update (`-mini_batch_size`)
    +
    +    // True when the `num_doc` option is not given (i.e., numDoc is negative).
    +    // In that case, the UDTF automatically sets the running `count` value as the
    +    // _D parameter (total number of documents) of the online LDA model.
    +    protected boolean isAutoD;
    +
    +    // number of processed training samples (incremented once per process() call)
    +    protected long count;
    +
    +    // Buffer of documents for the current mini-batch and the number of slots filled so far.
    +    protected String[][] miniBatch;
    +    protected int miniBatchCount;
    +
    +    protected OnlineLDAModel model;
    +
    +    // ObjectInspector for the first argument (the list of words of a document)
    +    protected ListObjectInspector wordCountsOI;
    +
    +    // for iterations
    +    protected NioStatefullSegment fileIO;
    +    protected ByteBuffer inputBuf;
    +
    +    public LDAUDTF() {
    +        this.topic = 10;
    +        this.alpha = 1.f / topic;
    +        this.eta = 1.f / topic;
    +        this.numDoc = -1;
    +        this.tau0 = 64.d;
    +        this.kappa = 0.7;
    +        this.iterations = 1;
    +        this.delta = 1E-5d;
    +        this.eps = 1E-1d;
    +        this.miniBatchSize = 1; // truly online setting
    +    }
    +
    +    /**
    +     * Declares the options accepted in the optional second (constant string) argument.
    +     *
    +     * <p>Each description string is kept as a single literal; a string literal must not
    +     * span physical lines.
    +     *
    +     * @return the option definitions used to parse the options string
    +     */
    +    @Override
    +    protected Options getOptions() {
    +        Options opts = new Options();
    +        opts.addOption("k", "topic", true, "The number of topics [default: 10]");
    +        opts.addOption("alpha", true, "The hyperparameter for theta [default: 1/k]");
    +        opts.addOption("eta", true, "The hyperparameter for beta [default: 1/k]");
    +        opts.addOption("d", "num_doc", true,
    +            "The total number of documents [default: auto]");
    +        opts.addOption("tau", "tau0", true,
    +            "The parameter which downweights early iterations [default: 64.0]");
    +        opts.addOption("kappa", true,
    +            "Exponential decay rate (i.e., learning rate) [default: 0.7]");
    +        opts.addOption("iter", "iterations", true,
    +            "The maximum number of iterations [default: 1]");
    +        opts.addOption("delta", true,
    +            "Check convergence in the expectation step [default: 1E-5]");
    +        opts.addOption("eps", "epsilon", true,
    +            "Check convergence based on the difference of perplexity [default: 1E-1]");
    +        opts.addOption("s", "mini_batch_size", true,
    +            "Repeat model updating per mini-batch [default: 1]");
    +        return opts;
    +    }
    +
    +    /**
    +     * Parses the optional options string (second argument) and stores the resulting
    +     * hyper-parameters in this instance. When no options argument is supplied, the
    +     * defaults assigned in the constructor are left untouched.
    +     *
    +     * @param argOIs the argument inspectors; {@code argOIs[1]}, if present, must be a
    +     *        constant string holding the options
    +     * @return the parsed command line, or {@code null} when fewer than two arguments are given
    +     * @throws UDFArgumentException if an option value is outside its valid range
    +     */
    +    @Override
    +    protected CommandLine processOptions(ObjectInspector[] argOIs)
    +            throws UDFArgumentException {
    +        CommandLine cl = null;
    +
    +        if (argOIs.length >= 2) {
    +            String rawArgs = HiveUtils.getConstString(argOIs[1]);
    +            cl = parseOptions(rawArgs);
    +            this.topic = Primitives.parseInt(cl.getOptionValue("topic"), 10);
    +            // Validate before deriving the default priors: they are computed as 1/k below.
    +            if (topic < 1) {
    +                throw new UDFArgumentException(
    +                    "'-topic' must be greater than or equals to 1: " + topic);
    +            }
    +            this.alpha = Primitives.parseFloat(cl.getOptionValue("alpha"), 1.f / topic);
    +            this.eta = Primitives.parseFloat(cl.getOptionValue("eta"), 1.f / topic);
    +            this.numDoc = Primitives.parseInt(cl.getOptionValue("num_doc"), -1);
    +            this.tau0 = Primitives.parseDouble(cl.getOptionValue("tau0"), 64.d);
    +            if (tau0 <= 0.d) {
    +                throw new UDFArgumentException("'-tau0' must be positive: " + tau0);
    +            }
    +            this.kappa = Primitives.parseDouble(cl.getOptionValue("kappa"), 0.7d);
    +            if (kappa <= 0.5 || kappa > 1.d) {
    +                throw new UDFArgumentException("'-kappa' must be in (0.5, 1.0]: " + kappa);
    +            }
    +            this.iterations = Primitives.parseInt(cl.getOptionValue("iterations"), 1);
    +            if (iterations < 1) {
    +                throw new UDFArgumentException(
    +                    "'-iterations' must be greater than or equals to 1: " + iterations);
    +            }
    +            this.delta = Primitives.parseDouble(cl.getOptionValue("delta"), 1E-5d);
    +            this.eps = Primitives.parseDouble(cl.getOptionValue("epsilon"), 1E-1d);
    +            this.miniBatchSize =
    +                    Primitives.parseInt(cl.getOptionValue("mini_batch_size"), 1);
    +            // The mini-batch buffer is allocated with this size, so it must hold
    +            // at least one document.
    +            if (miniBatchSize < 1) {
    +                throw new UDFArgumentException(
    +                    "'-mini_batch_size' must be greater than or equals to 1: " + miniBatchSize);
    +            }
    +        }
    +
    +        return cl;
    +    }
    +
    +    /**
    +     * Validates the arguments, parses the options, creates the online LDA model and
    +     * declares the output schema.
    +     *
    +     * @param argOIs inspectors for the arguments: the list of words of a document and,
    +     *        optionally, a constant options string
    +     * @return a struct inspector with the columns {@code topic} (int), {@code word}
    +     *         (string) and {@code score} (float)
    +     * @throws UDFArgumentException if no argument is given or an argument/option is invalid
    +     */
    +    @Override
    +    public StructObjectInspector initialize(ObjectInspector[] argOIs)
    +            throws UDFArgumentException {
    +        if (argOIs.length < 1) {
    +            // The options argument is optional, so the function takes one or two arguments.
    +            throw new UDFArgumentException(
    +                "_FUNC_ takes 1 or 2 arguments: array<string> words [, const string options]");
    +        }
    +
    +        this.wordCountsOI = HiveUtils.asListOI(argOIs[0]);
    +        HiveUtils.validateFeatureOI(wordCountsOI.getListElementObjectInspector());
    +
    +        processOptions(argOIs);
    +
    +        this.model = new OnlineLDAModel(topic, alpha, eta, numDoc, tau0, kappa, delta);
    +        this.count = 0L;
    +        // A negative numDoc means `num_doc` was not given.
    +        this.isAutoD = (numDoc < 0);
    +        this.miniBatch = new String[miniBatchSize][];
    +        this.miniBatchCount = 0;
    +
    +        // Output schema: <int topic, string word, float score>
    +        List<String> fieldNames = new ArrayList<String>();
    +        List<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
    +        fieldNames.add("topic");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        fieldNames.add("word");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    +        fieldNames.add("score");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
    +
    +        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    +    }
    +
    +    @Override
    +    public void process(Object[] args) throws HiveException {
    +        int length = wordCountsOI.getListLength(args[0]);
    +        String[] wordCounts = new String[length];
    +        int j = 0;
    +        for (int i = 0; i < length; i++) {
    +            Object o = wordCountsOI.getListElement(args[0], i);
    +            if (o == null) {
    +                continue;
    +            }
    +            String s = o.toString();
    +            wordCounts[j] = s;
    +            j++;
    +        }
    +
    +        count++;
    +        if (isAutoD) {
    +            model.setNumTotalDocs((int) count);
    +        }
    +
    +        recordTrainSampleToTempFile(wordCounts);
    +
    +        miniBatch[miniBatchCount] = wordCounts;
    +        miniBatchCount++;
    +
    +        if (miniBatchCount == miniBatchSize) {
    +            model.train(miniBatch);
    --- End diff --
    
    Null element handling is required for `miniBatch` before it is passed to `model.train(miniBatch)`.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

Mime
View raw message