hivemall-issues mailing list archives

From helenahm <...@git.apache.org>
Subject [GitHub] incubator-hivemall pull request #93: [WIP][HIVEMALL-126] Maximum Entropy Mod...
Date Wed, 05 Jul 2017 04:19:06 GMT
Github user helenahm commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/93#discussion_r125551895
  
    --- Diff: core/src/main/java/hivemall/smile/classification/MaxEntUDTF.java ---
    @@ -0,0 +1,440 @@
    +package hivemall.smile.classification;
    +
    +import java.io.FileNotFoundException;
    +import java.io.IOException;
    +import java.util.ArrayList;
    +import java.util.Arrays;
    +import java.util.BitSet;
    +import java.util.HashMap;
    +import java.util.List;
    +import java.util.Map;
    +import java.util.concurrent.Callable;
    +import java.util.concurrent.atomic.AtomicInteger;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +import javax.annotation.Nullable;
    +import javax.annotation.concurrent.GuardedBy;
    +
    +import org.apache.commons.cli.CommandLine;
    +import org.apache.commons.cli.Options;
    +import org.apache.commons.logging.Log;
    +import org.apache.commons.logging.LogFactory;
    +import org.apache.hadoop.hive.ql.exec.Description;
    +import org.apache.hadoop.hive.ql.exec.MapredContext;
    +import org.apache.hadoop.hive.ql.exec.MapredContextAccessor;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
    +import org.apache.hadoop.hive.ql.metadata.HiveException;
    +import org.apache.hadoop.hive.ql.udf.UDFType;
    +import org.apache.hadoop.hive.serde2.io.DoubleWritable;
    +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
    +import org.apache.hadoop.io.IntWritable;
    +import org.apache.hadoop.io.Text;
    +import org.apache.hadoop.mapred.Reporter;
    +import org.apache.hadoop.mapred.Counters.Counter;
    +
    +import hivemall.UDTFWithOptions;
    +import hivemall.math.matrix.Matrix;
    +import hivemall.math.matrix.MatrixUtils;
    +import hivemall.math.matrix.builders.CSRMatrixBuilder;
    +import hivemall.math.matrix.builders.MatrixBuilder;
    +import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder;
    +import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
    +import hivemall.math.matrix.ints.DoKIntMatrix;
    +import hivemall.math.matrix.ints.IntMatrix;
    +import hivemall.math.random.PRNG;
    +import hivemall.math.random.RandomNumberGeneratorFactory;
    +import hivemall.math.vector.Vector;
    +import hivemall.math.vector.VectorProcedure;
    +import hivemall.smile.classification.DecisionTree.SplitRule;
    +import hivemall.smile.data.Attribute;
    +import hivemall.smile.tools.MatrixEventStream;
    +import hivemall.smile.tools.SepDelimitedTextGISModelWriter;
    +import hivemall.smile.utils.SmileExtUtils;
    +import hivemall.smile.utils.SmileTaskExecutor;
    +import hivemall.utils.codec.Base91;
    +import hivemall.utils.collections.lists.IntArrayList;
    +import hivemall.utils.hadoop.HiveUtils;
    +import hivemall.utils.hadoop.WritableUtils;
    +import hivemall.utils.lang.Preconditions;
    +import hivemall.utils.lang.Primitives;
    +import hivemall.utils.lang.RandomUtils;
    +
    +import opennlp.maxent.GIS;
    +import opennlp.maxent.io.GISModelWriter;
    +import opennlp.model.AbstractModel;
    +import opennlp.model.Event;
    +import opennlp.model.EventStream;
    +import opennlp.model.OnePassRealValueDataIndexer;
    +
    +@Description(
    +        name = "train_maxent_classifier",
    +        value = "_FUNC_(array<double> features, int label [, const string options])"
    +                + " - Returns a maximum entropy model per subset of data.")
    +@UDFType(deterministic = true, stateful = false)
    +public class MaxEntUDTF extends UDTFWithOptions {
    +    private static final Log logger = LogFactory.getLog(MaxEntUDTF.class);
    +
    +    private ListObjectInspector featureListOI;
    +    private PrimitiveObjectInspector featureElemOI;
    +    private PrimitiveObjectInspector labelOI;
    +
    +    private MatrixBuilder matrixBuilder;
    +    private IntArrayList labels;
    +
    +    private boolean _real;
    +    private Attribute[] _attributes;
    +    private static boolean _USE_SMOOTHING;
    +    private double _SMOOTHING_OBSERVATION;
    +
    +    private int _numTrees = 1;
    +    
    +    @Nullable
    +    private Reporter _progressReporter;
    +    @Nullable
    +    private Counter _treeBuildTaskCounter;
    +    
    +    @Override
    +    protected Options getOptions() {
    +        Options opts = new Options();
    +        opts.addOption("real", "quantitative_feature_presence_indication", true,
    +            "true or false [default: true]");
    +        opts.addOption("smoothing", "smoothing", true,
    +            "Whether to perform smoothing [default: false]");
    +        opts.addOption("constant", "smoothing_constant", true, "real number [default: 0.1]");
    +        opts.addOption("attrs", "attribute_types", true, "Comma separated attribute types "
    +                + "(Q for quantitative variable and C for categorical variable. e.g., [Q,C,Q,C])");
    +        return opts;
    +    }
    +    
    +    @Override
    +    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
    +        boolean real = true;
    +        boolean USE_SMOOTHING = false;
    +        double SMOOTHING_OBSERVATION = 0.1;
    +
    +        Attribute[] attrs = null;
    +
    +        CommandLine cl = null;
    +        if (argOIs.length >= 3) {
    +            String rawArgs = HiveUtils.getConstString(argOIs[2]);
    +            cl = parseOptions(rawArgs);
    +
    +            real = Primitives.parseBoolean(cl.getOptionValue("quantitative_feature_presence_indication"), real);
    +            attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types"));
    +            USE_SMOOTHING = Primitives.parseBoolean(cl.getOptionValue("smoothing"), USE_SMOOTHING);
    +            SMOOTHING_OBSERVATION = Primitives.parseDouble(cl.getOptionValue("smoothing_constant"), SMOOTHING_OBSERVATION);
    +        }
    +
    +        this._real = real;
    +        this._attributes = attrs;
    +        this._USE_SMOOTHING = USE_SMOOTHING;
    +        this._SMOOTHING_OBSERVATION = SMOOTHING_OBSERVATION;
    +
    +        return cl;
    +    }
    +    
    +    @Override
    +    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
    +        if (argOIs.length < 2 || argOIs.length > 3) {
    +            throw new UDFArgumentException(
    +                "_FUNC_ takes 2 ~ 3 arguments: array<double> features, int label [, const string options]: "
    +                        + argOIs.length);
    +        }
    +
    +        ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]);
    +        ObjectInspector elemOI = listOI.getListElementObjectInspector();
    +        this.featureListOI = listOI;
    +        if (HiveUtils.isNumberOI(elemOI)) {
    +            this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
    +            this.matrixBuilder = new CSRMatrixBuilder(8192);
    +        } else {
    +            throw new UDFArgumentException(
    +                "_FUNC_ takes double[] for the first argument: " + listOI.getTypeName());
    +        }
    +        this.labelOI = HiveUtils.asIntCompatibleOI(argOIs[1]);
    +
    +        processOptions(argOIs);
    +
    +        this.labels = new IntArrayList(1024);
    +
    +        final ArrayList<String> fieldNames = new ArrayList<String>(6);
    +        final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(6);
    +
    +        fieldNames.add("model_id");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    +        fieldNames.add("model_weight");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
    +        fieldNames.add("model");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    +        fieldNames.add("attributes");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    +        fieldNames.add("oob_errors");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        fieldNames.add("oob_tests");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +
    +        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    +    }
    +    
    +    @Override
    +    public void process(Object[] args) throws HiveException {
    +        if (args[0] == null) {
    +            throw new HiveException("array<double> features was null");
    +        }
    +        parseFeatures(args[0], matrixBuilder);
    +        int label = PrimitiveObjectInspectorUtils.getInt(args[1], labelOI);
    +        labels.add(label);
    +    }
    +    
    +    private void parseFeatures(@Nonnull final Object argObj, @Nonnull final MatrixBuilder builder) {
    +        final int length = featureListOI.getListLength(argObj);
    +        for (int i = 0; i < length; i++) {
    +            Object o = featureListOI.getListElement(argObj, i);
    +            if (o == null) {
    +                continue;
    +            }
    +            double v = PrimitiveObjectInspectorUtils.getDouble(o, featureElemOI);
    +            builder.nextColumn(i, v);
    +        } 
    +        builder.nextRow();
    +    }
    +    
    +    @Override
    +    public void close() throws HiveException {
    +        this._progressReporter = getReporter();
    +        this._treeBuildTaskCounter = (_progressReporter == null) ? null
    +                : _progressReporter.getCounter("hivemall.smile.MaxEntClassifier$Counter",
    +                    "finishedGISTask");
    +        reportProgress(_progressReporter);
    +
    +        if (!labels.isEmpty()) {
    +            Matrix x = matrixBuilder.buildMatrix();
    +            this.matrixBuilder = null;
    +            int[] y = labels.toArray();
    +            this.labels = null;
    +
    +            // run training
    +            train(x, y);
    +        }
    +
    +        // clean up
    +        this.featureListOI = null;
    +        this.featureElemOI = null;
    +        this.labelOI = null;
    +    }
    +    
    +    private void checkOptions() throws HiveException {
    +        if (!_USE_SMOOTHING && _SMOOTHING_OBSERVATION != 0.1) {
    +            throw new HiveException("Smoothing is disabled, but a smoothing constant was set ["
    +                    + _SMOOTHING_OBSERVATION + "]");
    +        }
    +    }
    +    
    +    /**
    +     * @param x features
    +     * @param y labels
    +     */
    +    private void train(@Nonnull Matrix x, @Nonnull final int[] y) throws HiveException {
    +        final int numExamples = x.numRows();
    +        if (numExamples != y.length) {
    +            throw new HiveException(String.format("The sizes of X and Y don't match: %d != %d",
    +                numExamples, y.length));
    +        }
    +        checkOptions();
    +
    +        int[] labels = SmileExtUtils.classLables(y);
    +        Attribute[] attributes = SmileExtUtils.attributeTypes(_attributes, x);
    +
    +        if (logger.isInfoEnabled()) {
    +            logger.info("real: " + _real + ", smoothing: " + this._USE_SMOOTHING + ", smoothing constant: "
    +                    + _SMOOTHING_OBSERVATION);
    +        }
    +
    +        IntMatrix prediction = new DoKIntMatrix(numExamples, labels.length); // placeholder for out-of-bag prediction
    +        AtomicInteger remainingTasks = new AtomicInteger(_numTrees);
    +        List<TrainingTask> tasks = new ArrayList<TrainingTask>();
    +        for (int i = 0; i < _numTrees; i++) {
    --- End diff --
    
    I left in the ability to divide the training set further into subsets for memory reasons.
    At times it may be faster to subdivide the data that a mapper receives into a number of chunks
    and train one model per training-set chunk. Implementation-wise this is the same approach as
    Random Forests. At the moment _numTrees is set to 1. If it works with 1 all the time then that
    is probably a good default; otherwise it may be better to let people set it. A minimal sketch
    of the chunked-training idea follows.
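    
    To make the idea concrete, here is a minimal, self-contained sketch of per-chunk maximum
    entropy training. It is not the code in this PR; the class and helper names ChunkedMaxEntSketch
    and ChunkEventStream are made up, and it only assumes the OpenNLP API the patch already imports
    (GIS, Event, OnePassRealValueDataIndexer). It splits the rows round-robin into a fixed number of
    chunks, playing the role of _numTrees, and trains one GIS model per chunk:
    
    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.List;
    
    import opennlp.maxent.GIS;
    import opennlp.model.AbstractModel;
    import opennlp.model.Event;
    import opennlp.model.EventStream;
    import opennlp.model.OnePassRealValueDataIndexer;
    
    public class ChunkedMaxEntSketch {
    
        // Minimal EventStream over an in-memory list of events (hypothetical helper,
        // standing in for the PR's MatrixEventStream).
        static final class ChunkEventStream implements EventStream {
            private final List<Event> events;
            private int pos = 0;
    
            ChunkEventStream(List<Event> events) {
                this.events = events;
            }
    
            public boolean hasNext() {
                return pos < events.size();
            }
    
            public Event next() {
                return events.get(pos++);
            }
        }
    
        public static void main(String[] args) throws IOException {
            // Toy data: x[i] is a feature vector, y[i] its class label.
            double[][] x = { {1, 0}, {0, 1}, {1, 1}, {2, 1} };
            int[] y = { 1, 0, 0, 1 };
    
            int numChunks = 2; // plays the role of _numTrees
            List<AbstractModel> models = new ArrayList<AbstractModel>();
    
            for (int c = 0; c < numChunks; c++) {
                List<Event> chunk = new ArrayList<Event>();
                // Round-robin assignment of rows to chunks; the PR may partition differently.
                for (int i = c; i < x.length; i += numChunks) {
                    List<String> context = new ArrayList<String>();
                    List<Float> values = new ArrayList<Float>();
                    for (int j = 0; j < x[i].length; j++) {
                        if (x[i][j] != 0.0) { // sparse: keep only non-zero features
                            context.add("f" + j);
                            values.add((float) x[i][j]);
                        }
                    }
                    float[] vals = new float[values.size()];
                    for (int k = 0; k < vals.length; k++) {
                        vals[k] = values.get(k);
                    }
                    chunk.add(new Event(String.valueOf(y[i]),
                        context.toArray(new String[context.size()]), vals));
                }
                // One maxent model per chunk: 100 GIS iterations, cutoff 0 (keep all features).
                AbstractModel model = GIS.trainModel(100,
                    new OnePassRealValueDataIndexer(new ChunkEventStream(chunk), 0));
                models.add(model);
            }
            System.out.println("trained " + models.size() + " models");
        }
    }
    
    In the UDTF itself the chunks would presumably come from the Matrix built in process()/close(),
    and each trained model would be emitted as one forwarded row (model_id, model_weight, model, ...),
    so downstream aggregation looks the same whether a mapper emits one model or several.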


