hivemall-issues mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From myui <...@git.apache.org>
Subject [GitHub] incubator-hivemall pull request #116: [WIP][HIVEMALL-118] word2vec
Date Thu, 28 Sep 2017 07:53:57 GMT
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/116#discussion_r141543945
  
    --- Diff: core/src/main/java/hivemall/embedding/Word2VecUDTF.java ---
    @@ -0,0 +1,364 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.embedding;
    +
    +import hivemall.UDTFWithOptions;
    +import hivemall.utils.collections.IMapIterator;
    +import hivemall.utils.collections.maps.Int2FloatOpenHashTable;
    +import hivemall.utils.collections.maps.OpenHashTable;
    +import hivemall.utils.hadoop.HiveUtils;
    +import hivemall.utils.lang.Primitives;
    +
    +import org.apache.commons.cli.CommandLine;
    +import org.apache.commons.cli.Options;
    +import org.apache.hadoop.hive.ql.exec.Description;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    +import org.apache.hadoop.hive.ql.metadata.HiveException;
    +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
    +import org.apache.hadoop.io.FloatWritable;
    +import org.apache.hadoop.io.IntWritable;
    +import org.apache.hadoop.io.Text;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +import java.util.List;
    +import java.util.Arrays;
    +import java.util.ArrayList;
    +
    +@Description(
    +        name = "train_word2vec",
    +        value = "_FUNC_(array<array<float | string>> negative_table, array<int | string> doc [, const string options]) - Returns a prediction model")
    +public class Word2VecUDTF extends UDTFWithOptions {
    +    protected transient AbstractWord2VecModel model;
    +    @Nonnegative
    +    private float startingLR;
    +    @Nonnegative
    +    private long numTrainWords;
    +    private OpenHashTable<String, Integer> word2index;
    +
    +    @Nonnegative
    +    private int dim;
    +    @Nonnegative
    +    private int win;
    +    @Nonnegative
    +    private int neg;
    +    @Nonnegative
    +    private int iter;
    +    private boolean skipgram;
    +    private boolean isStringInput;
    +
    +    private Int2FloatOpenHashTable S;
    +    private int[] aliasWordIds;
    +
    +    private ListObjectInspector negativeTableOI;
    +    private ListObjectInspector negativeTableElementListOI;
    +    private PrimitiveObjectInspector negativeTableElementOI;
    +
    +    private ListObjectInspector docOI;
    +    private PrimitiveObjectInspector wordOI;
    +
    +    @Override
    +    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
    +        final int numArgs = argOIs.length;
    +
    +        if (numArgs != 3) {
    +            throw new UDFArgumentException(getClass().getSimpleName()
    +                    + " takes 3 arguments:  [, constant string options]: "
    +                    + Arrays.toString(argOIs));
    +        }
    +
    +        processOptions(argOIs);
    +
    +        this.negativeTableOI = HiveUtils.asListOI(argOIs[0]);
    +        this.negativeTableElementListOI = HiveUtils.asListOI(negativeTableOI.getListElementObjectInspector());
    +        this.docOI = HiveUtils.asListOI(argOIs[1]);
    +
    +        this.isStringInput = HiveUtils.isStringListOI(argOIs[1]);
    +
    +        if (isStringInput) {
    +            this.negativeTableElementOI = HiveUtils.asStringOI(negativeTableElementListOI.getListElementObjectInspector());
    +            this.wordOI = HiveUtils.asStringOI(docOI.getListElementObjectInspector());
    +        } else {
    +            this.negativeTableElementOI = HiveUtils.asFloatingPointOI(negativeTableElementListOI.getListElementObjectInspector());
    +            this.wordOI = HiveUtils.asIntCompatibleOI(docOI.getListElementObjectInspector());
    +        }
    +
    +        List<String> fieldNames = new ArrayList<>();
    +        List<ObjectInspector> fieldOIs = new ArrayList<>();
    +
    +        fieldNames.add("word");
    +        fieldNames.add("i");
    +        fieldNames.add("wi");
    +
    +        if (isStringInput) {
    +            fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    +        } else {
    +            fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        }
    +
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableFloatObjectInspector);
    +
    +        this.model = null;
    +        this.word2index = null;
    +        this.S = null;
    +        this.aliasWordIds = null;
    +
    +        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    +    }
    +
    +    @Override
    +    public void process(Object[] args) throws HiveException {
    +        if (model == null) {
    +            parseNegativeTable(args[0]);
    +            this.model = createModel();
    +        }
    +
    +        List<?> rawDoc = docOI.getList(args[1]);
    +
    +        // parse rawDoc
    +        final int docLength = rawDoc.size();
    +        final int[] doc = new int[docLength];
    +        if (isStringInput) {
    +            for (int i = 0; i < docLength; i++) {
    +                doc[i] = getWordId(PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI));
    --- End diff --
    
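    `PrimitiveObjectInspectorUtils.getString(rawDoc.get(i), wordOI)` may return null (for example, when the doc array contains a null element), so the result should be null-checked before it is passed to `getWordId`.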


---

Mime
View raw message