hivemall-issues mailing list archives

From myui <...@git.apache.org>
Subject [GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Date Thu, 28 Sep 2017 07:54:17 GMT
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543757
  
    --- Diff: core/src/main/java/hivemall/embedding/CBoWModel.java ---
    @@ -0,0 +1,131 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.math.random.PRNG;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +
    +public final class CBoWModel extends AbstractWord2VecModel {
    +    protected CBoWModel(final int dim, final int win, final int neg, final int iter,
    +            final float startingLR, final long numTrainWords, final Int2FloatOpenHashTable S,
    +            final int[] aliasWordId) {
    +        super(dim, win, neg, iter, startingLR, numTrainWords, S, aliasWordId);
    +    }
    +
    +    protected void trainOnDoc(@Nonnull final int[] doc) {
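    +        // cache hot fields in locals to avoid repeated field loads in the inner loops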
    +        final int vecDim = dim;
    +        final int numNegative = neg;
    +        final PRNG _rnd = rnd;
    +        final Int2FloatOpenHashTable _S = S;
    +        final int[] _aliasWordId = aliasWordId;
    +        float label, gradient;
    +
    +        // reuse instance
    +        int windowSize, k, numContext, targetWord, inWord, positiveWord;
    +
    +        updateLearningRate();
    +
    +        int docLength = doc.length;
    +        for (int t = 0; t < iter; t++) {
    +            for (int positiveWordPosition = 0; positiveWordPosition < docLength; positiveWordPosition++) {
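    +                // word2vec-style dynamic window: the effective window size is drawn
    +                // uniformly from [1, win] for each target position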
    +                windowSize = _rnd.nextInt(win) + 1;
    +
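    +                // count the context words actually available, clipping the window
    +                // at the document boundaries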
    +                numContext = windowSize * 2 + Math.min(0, positiveWordPosition - windowSize)
    +                        + Math.min(0, docLength - positiveWordPosition - windowSize - 1);
    +
    +                float[] gradVec = new float[vecDim];
    +                float[] averageVec = new float[vecDim];
    +
    +                // collect context words
    +                for (int contextPosition = positiveWordPosition - windowSize; contextPosition < positiveWordPosition
    +                        + windowSize + 1; contextPosition++) {
    +                    if (contextPosition == positiveWordPosition || contextPosition < 0
    +                            || contextPosition >= docLength) {
    +                        continue;
    +                    }
    +
    +                    inWord = doc[contextPosition];
    +
    +                    // average vector of input word vectors
    +                    if (!inputWeights.containsKey(inWord * vecDim)) {
    +                        initWordWeights(inWord);
    +                    }
    +
    +                    for (int i = 0; i < vecDim; i++) {
    +                        averageVec[i] += inputWeights.get(inWord * vecDim + i) / numContext;
    +                    }
    +                }
    +                positiveWord = doc[positiveWordPosition];
    +                // negative sampling
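    +                // d == 0 trains the observed positive word with label 1; the remaining
    +                // numNegative draws use the alias method: pick a bucket k uniformly,
    +                // keep k with probability S[k], otherwise fall back to its alias word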
    +                for (int d = 0; d < numNegative + 1; d++) {
    +                    if (d == 0) {
    +                        targetWord = positiveWord;
    +                        label = 1.f;
    +                    } else {
    +                        do {
    +                            k = _rnd.nextInt(_S.size());
    +                            if (_S.get(k) > _rnd.nextDouble()) {
    +                                targetWord = k;
    +                            } else {
    +                                targetWord = _aliasWordId[k];
    +                            }
    +                        } while (targetWord == positiveWord);
    +                        label = 0.f;
    +                    }
    +
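    +                    // accumulate the input-side gradient in gradVec and update the
    +                    // output (context) vector of the sampled word in place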
    +                    gradient = grad(label, averageVec, targetWord) * lr;
    +                    for (int i = 0; i < vecDim; i++) {
    +                        gradVec[i] += gradient * contextWeights.get(targetWord * vecDim + i);
    +                        contextWeights.put(targetWord * vecDim + i,
    +                            contextWeights.get(targetWord * vecDim + i) + gradient * averageVec[i]);
    +                    }
    +                }
    +
    +                // update inWord vector
    +                for (int contextPosition = positiveWordPosition - windowSize; contextPosition < positiveWordPosition
    +                        + windowSize + 1; contextPosition++) {
    +                    if (contextPosition == positiveWordPosition || contextPosition < 0
    +                            || contextPosition >= docLength) {
    +                        continue;
    +                    }
    +
    +                    inWord = doc[contextPosition];
    +                    for (int i = 0; i < vecDim; i++) {
    +                        inputWeights.put(inWord * vecDim + i, inputWeights.get(inWord * vecDim + i)
    +                                + gradVec[i]);
    +                    }
    +                }
    +            }
    +        }
    +
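    +        // track progress through the corpus; updateLearningRate() presumably decays
    +        // lr based on wordCount relative to numTrainWords (not shown in this hunk)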
    +        wordCount += docLength * iter;
    +    }
    +
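    +    /**
    +     * Gradient of the negative sampling log-likelihood w.r.t. the dot product of
    +     * the averaged context vector w and the target word c's output vector:
    +     * label - sigmoid(w . contextWeights[c]).
    +     */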
    +    private float grad(final float label, @Nonnull final float[] w, final int c) {
    +        float dotValue = 0.f;
    +        for (int i = 0; i < dim; i++) {
    +            dotValue += w[i] * contextWeights.get(c * dim + i);
    +        }
    +
    +        return (label - sigmoid(dotValue, sigmoidTable));
    --- End diff --
    
    remove the redundant outermost `(` (and its matching `)`) around the return expression.
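    
    For reference, the `return` statement without the redundant parentheses would read:
    
        return label - sigmoid(dotValue, sigmoidTable);
    
    Side note for readers following along: the `S` and `aliasWordId` tables consumed by the
    negative sampling loop above correspond to Walker's alias method over the noise
    distribution (classic word2vec uses unigram counts raised to the 0.75 power, though
    whether this PR does is not shown in this hunk). Below is a minimal, self-contained
    sketch of how such tables could be built; the class name `AliasTableSketch`, the
    `sample` helper, and the plain-array layout are illustrative, not the PR's API:
    
        import java.util.Random;
    
        final class AliasTableSketch {
            final float[] S;      // acceptance probability per bucket
            final int[] aliasId;  // fallback word id per bucket
    
            AliasTableSketch(final double[] probs) { // probs must sum to 1
                final int n = probs.length;
                S = new float[n];
                aliasId = new int[n];
                final double[] scaled = new double[n];
                final int[] small = new int[n], large = new int[n];
                int nSmall = 0, nLarge = 0;
                for (int i = 0; i < n; i++) {
                    scaled[i] = probs[i] * n; // rescale so the average bucket mass is 1
                    if (scaled[i] < 1.d) { small[nSmall++] = i; } else { large[nLarge++] = i; }
                }
                while (nSmall > 0 && nLarge > 0) {
                    final int s = small[--nSmall], l = large[--nLarge];
                    S[s] = (float) scaled[s];
                    aliasId[s] = l;               // bucket s falls back to word l
                    scaled[l] -= 1.d - scaled[s]; // l donates mass to fill bucket s
                    if (scaled[l] < 1.d) { small[nSmall++] = l; } else { large[nLarge++] = l; }
                }
                while (nLarge > 0) { S[large[--nLarge]] = 1.f; } // leftovers saturate to 1
                while (nSmall > 0) { S[small[--nSmall]] = 1.f; }
            }
    
            int sample(final Random rnd) { // same accept/alias test as in the diff
                final int k = rnd.nextInt(S.length);
                return (S[k] > rnd.nextDouble()) ? k : aliasId[k];
            }
        }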


---
