hivemall-issues mailing list archives

From myui <...@git.apache.org>
Subject [GitHub] incubator-hivemall pull request #93: [WIP][HIVEMALL-126] Maximum Entropy Mod...
Date Wed, 05 Jul 2017 02:34:00 GMT
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/93#discussion_r125544508
  
    --- Diff: core/src/main/java/hivemall/smile/classification/MaxEntUDTF.java ---
    @@ -0,0 +1,440 @@
    +package hivemall.smile.classification;
    +
    +import java.io.FileNotFoundException;
    +import java.io.IOException;
    +import java.util.ArrayList;
    +import java.util.Arrays;
    +import java.util.BitSet;
    +import java.util.HashMap;
    +import java.util.List;
    +import java.util.Map;
    +import java.util.concurrent.Callable;
    +import java.util.concurrent.atomic.AtomicInteger;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +import javax.annotation.Nullable;
    +import javax.annotation.concurrent.GuardedBy;
    +
    +import org.apache.commons.cli.CommandLine;
    +import org.apache.commons.cli.Options;
    +import org.apache.commons.logging.Log;
    +import org.apache.commons.logging.LogFactory;
    +import org.apache.hadoop.hive.ql.exec.Description;
    +import org.apache.hadoop.hive.ql.exec.MapredContext;
    +import org.apache.hadoop.hive.ql.exec.MapredContextAccessor;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
    +import org.apache.hadoop.hive.ql.metadata.HiveException;
    +import org.apache.hadoop.hive.ql.udf.UDFType;
    +import org.apache.hadoop.hive.serde2.io.DoubleWritable;
    +import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorUtils;
    +import org.apache.hadoop.io.IntWritable;
    +import org.apache.hadoop.io.Text;
    +import org.apache.hadoop.mapred.Reporter;
    +import org.apache.hadoop.mapred.Counters.Counter;
    +
    +import hivemall.UDTFWithOptions;
    +import hivemall.math.matrix.Matrix;
    +import hivemall.math.matrix.MatrixUtils;
    +import hivemall.math.matrix.builders.CSRMatrixBuilder;
    +import hivemall.math.matrix.builders.MatrixBuilder;
    +import hivemall.math.matrix.builders.RowMajorDenseMatrixBuilder;
    +import hivemall.math.matrix.ints.ColumnMajorIntMatrix;
    +import hivemall.math.matrix.ints.DoKIntMatrix;
    +import hivemall.math.matrix.ints.IntMatrix;
    +import hivemall.math.random.PRNG;
    +import hivemall.math.random.RandomNumberGeneratorFactory;
    +import hivemall.math.vector.Vector;
    +import hivemall.math.vector.VectorProcedure;
    +import hivemall.smile.classification.DecisionTree.SplitRule;
    +import hivemall.smile.data.Attribute;
    +import hivemall.smile.tools.MatrixEventStream;
    +import hivemall.smile.tools.SepDelimitedTextGISModelWriter;
    +import hivemall.smile.utils.SmileExtUtils;
    +import hivemall.smile.utils.SmileTaskExecutor;
    +import hivemall.utils.codec.Base91;
    +import hivemall.utils.collections.lists.IntArrayList;
    +import hivemall.utils.hadoop.HiveUtils;
    +import hivemall.utils.hadoop.WritableUtils;
    +import hivemall.utils.lang.Preconditions;
    +import hivemall.utils.lang.Primitives;
    +import hivemall.utils.lang.RandomUtils;
    +
    +import opennlp.maxent.GIS;
    +import opennlp.maxent.io.GISModelWriter;
    +import opennlp.model.AbstractModel;
    +import opennlp.model.Event;
    +import opennlp.model.EventStream;
    +import opennlp.model.OnePassRealValueDataIndexer;
    +
    +@Description(
    +        name = "train_maxent_classifier",
    +        value = "_FUNC_(array<double> features, int label [, const string options])"
    +                + " - Returns a maximum entropy model per subset of data.")
    +@UDFType(deterministic = true, stateful = false)
    +public class MaxEntUDTF extends UDTFWithOptions {
    +    private static final Log logger = LogFactory.getLog(MaxEntUDTF.class);
    +
    +    private ListObjectInspector featureListOI;
    +    private PrimitiveObjectInspector featureElemOI;
    +    private PrimitiveObjectInspector labelOI;
    +
    +    private MatrixBuilder matrixBuilder;
    +    private IntArrayList labels;
    +
    +    private boolean _real;
    +    private Attribute[] _attributes;
    +    private boolean _USE_SMOOTHING;
    +    private double _SMOOTHING_OBSERVATION;
    +
    +    private int _numTrees = 1;
    +    
    +    @Nullable
    +    private Reporter _progressReporter;
    +    @Nullable
    +    private Counter _treeBuildTaskCounter;
    +    
    +    @Override
    +    protected Options getOptions() {
    +        Options opts = new Options();
    +        opts.addOption("real", "quantative_feature_presence_indication", true,
    +            "true or false [default: true]");
    +        opts.addOption("smoothing", "smoothimg", true, "Shall smoothing be performed
[default: false]");
    +        opts.addOption("constant", "smoothing_constant", true, "real number [default:
1.0]");
    +        opts.addOption("attrs", "attribute_types", true, "Comma separated attribute types
"
    +                + "(Q for quantitative variable and C for categorical variable. e.g.,
[Q,C,Q,C])");
    +        return opts;
    +    }
    +    
    +    @Override
    +    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
    +        boolean real = true;
    +        boolean USE_SMOOTHING = false;
    +        double SMOOTHING_OBSERVATION = 0.1;
    +
    +        Attribute[] attrs = null;
    +
    +        CommandLine cl = null;
    +        if (argOIs.length >= 3) {
    +            String rawArgs = HiveUtils.getConstString(argOIs[2]);
    +            cl = parseOptions(rawArgs);
    +
    +            real = Primitives.parseBoolean(cl.getOptionValue("quantitative_feature_presence_indication"), real);
    +            attrs = SmileExtUtils.resolveAttributes(cl.getOptionValue("attribute_types"));
    +            USE_SMOOTHING = Primitives.parseBoolean(cl.getOptionValue("smoothing"), USE_SMOOTHING);
    +            SMOOTHING_OBSERVATION = Primitives.parseDouble(cl.getOptionValue("smoothing_constant"), SMOOTHING_OBSERVATION);
    +        }
    +
    +        this._real = real;
    +        this._attributes = attrs;
    +        this._USE_SMOOTHING = USE_SMOOTHING;
    +        this._SMOOTHING_OBSERVATION = SMOOTHING_OBSERVATION;
    +
    +        return cl;
    +    }
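    +
    +    // --- Editorial sketch, not part of this patch ---------------------------
    +    // Illustrates how the hyperparameter string passed as the third UDTF
    +    // argument maps onto the options registered in getOptions(); the option
    +    // string and this helper method are illustrative only.
    +    private void exampleOptionParsing() throws UDFArgumentException {
    +        // e.g. train_maxent_classifier(features, label,
    +        //        '-real false -smoothing true -constant 0.2 -attrs Q,C,Q')
    +        CommandLine cl = parseOptions("-real false -smoothing true -constant 0.2 -attrs Q,C,Q");
    +        boolean smoothing = Primitives.parseBoolean(cl.getOptionValue("smoothing"), false); // -> true
    +        double constant = Primitives.parseDouble(cl.getOptionValue("smoothing_constant"), 0.1); // -> 0.2
    +    }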
    +    
    +    @Override
    +    public StructObjectInspector initialize(ObjectInspector[] argOIs) throws UDFArgumentException {
    +        if (argOIs.length < 2 || argOIs.length > 3) {
    +            throw new UDFArgumentException(
    +                "_FUNC_ takes 2 ~ 3 arguments: array<double> features, int label
[, const string options]: "
    +                        + argOIs.length);
    +        }
    +
    +        ListObjectInspector listOI = HiveUtils.asListOI(argOIs[0]);
    +        ObjectInspector elemOI = listOI.getListElementObjectInspector();
    +        this.featureListOI = listOI;
    +        if (HiveUtils.isNumberOI(elemOI)) {
    +            this.featureElemOI = HiveUtils.asDoubleCompatibleOI(elemOI);
    +            this.matrixBuilder = new CSRMatrixBuilder(8192);
    +        } else {
    +            throw new UDFArgumentException(
    +                "_FUNC_ takes double[] for the first argument: " + listOI.getTypeName());
    +        }
    +        this.labelOI = HiveUtils.asIntCompatibleOI(argOIs[1]);
    +
    +        processOptions(argOIs);
    +
    +        this.labels = new IntArrayList(1024);
    +
    +        final ArrayList<String> fieldNames = new ArrayList<String>(6);
    +        final ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>(6);
    +
    +        fieldNames.add("model_id");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    +        fieldNames.add("model_weight");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableDoubleObjectInspector);
    +        fieldNames.add("model");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    +        fieldNames.add("attributes");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableStringObjectInspector);
    +        fieldNames.add("oob_errors");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +        fieldNames.add("oob_tests");
    +        fieldOIs.add(PrimitiveObjectInspectorFactory.writableIntObjectInspector);
    +
    +        return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
    +    }
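    +
    +    // --- Editorial sketch, not part of this patch ---------------------------
    +    // Each forwarded row has to line up with the struct declared above; using
    +    // the imported writables, a row could be assembled as follows (the values
    +    // are illustrative only):
    +    private Object[] exampleOutputRow() {
    +        return new Object[] {new Text("1"), // model_id
    +                new DoubleWritable(1.0d), // model_weight
    +                new Text("<Base91-encoded model>"), // model
    +                new Text("Q,C"), // attributes
    +                new IntWritable(0), // oob_errors
    +                new IntWritable(0)}; // oob_tests
    +    }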
    +    
    +    @Override
    +    public void process(Object[] args) throws HiveException {
    +        if (args[0] == null) {
    +            throw new HiveException("array<double> features was null");
    +        }
    +        parseFeatures(args[0], matrixBuilder);
    +        int label = PrimitiveObjectInspectorUtils.getInt(args[1], labelOI);
    +        labels.add(label);
    +    }
    +    
    +    private void parseFeatures(@Nonnull final Object argObj, @Nonnull final MatrixBuilder builder) {
    +        final int length = featureListOI.getListLength(argObj);
    +        for (int i = 0; i < length; i++) {
    +            Object o = featureListOI.getListElement(argObj, i);
    +            if (o == null) {
    +                continue;
    +            }
    +            double v = PrimitiveObjectInspectorUtils.getDouble(o, featureElemOI);
    +            builder.nextColumn(i, v);
    +        } 
    +        builder.nextRow();
    +    }
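    +
    +    // --- Editorial sketch, not part of this patch ---------------------------
    +    // Demonstrates the MatrixBuilder calls used in parseFeatures() above
    +    // (nextColumn/nextRow/buildMatrix, as they appear in this patch): two
    +    // dense rows are appended and materialized as a sparse CSR matrix.
    +    private static Matrix exampleMatrix() {
    +        MatrixBuilder builder = new CSRMatrixBuilder(8192);
    +        builder.nextColumn(0, 1.0d); // row 0 -> {1.0, 0.0, 0.5}
    +        builder.nextColumn(2, 0.5d);
    +        builder.nextRow();
    +        builder.nextColumn(1, 2.0d); // row 1 -> {0.0, 2.0, 0.0}
    +        builder.nextRow();
    +        return builder.buildMatrix();
    +    }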
    +    
    +    @Override
    +    public void close() throws HiveException {
    +        this._progressReporter = getReporter();
    +        this._treeBuildTaskCounter = (_progressReporter == null) ? null
    +                : _progressReporter.getCounter("hivemall.smile.MaxEntClassifier$Counter",
    +                    "finishedGISTask");
    +        reportProgress(_progressReporter);
    +
    +        if (!labels.isEmpty()) {
    +            Matrix x = matrixBuilder.buildMatrix();
    +            this.matrixBuilder = null;
    +            int[] y = labels.toArray();
    +            this.labels = null;
    +
    +            // run training
    +            train(x, y);
    +        }
    +
    +        // clean up
    +        this.featureListOI = null;
    +        this.featureElemOI = null;
    +        this.labelOI = null;
    +    }
    +    
    +    private void checkOptions() throws HiveException {
    +        // 0.1 is the default; any other value implies -constant was set explicitly
    +        if (!_USE_SMOOTHING && _SMOOTHING_OBSERVATION != 0.1) {
    +            throw new HiveException("Smoothing is disabled, but a smoothing constant was given ["
    +                    + _SMOOTHING_OBSERVATION + ']');
    +        }
    +    }
    +    
    +    /**
    +     * @param x features
    +     * @param y labels
    +     */
    +    private void train(@Nonnull Matrix x, @Nonnull final int[] y) throws HiveException {
    +        final int numExamples = x.numRows();
    +        if (numExamples != y.length) {
    +            throw new HiveException(String.format("The sizes of X and Y don't match: %d != %d",
    +                numExamples, y.length));
    +        }
    +        checkOptions();
    +
    +        int[] labels = SmileExtUtils.classLables(y);
    +        Attribute[] attributes = SmileExtUtils.attributeTypes(_attributes, x);
    +
    +        if (logger.isInfoEnabled()) {
    +            logger.info("real: " + _real + ", smoothing: " + this._USE_SMOOTHING + ",
smoothing constant: "
    +                    + _SMOOTHING_OBSERVATION);
    +        }
    +
    +        IntMatrix prediction = new DoKIntMatrix(numExamples, labels.length); // placeholder for out-of-bag prediction
    +        AtomicInteger remainingTasks = new AtomicInteger(_numTrees);
    +        List<TrainingTask> tasks = new ArrayList<TrainingTask>();
    +        for (int i = 0; i < _numTrees; i++) {
    +            tasks.add(new TrainingTask(this, i, attributes, x, y, prediction, remainingTasks));
    +        }
    +
    +        MapredContext mapredContext = MapredContextAccessor.get();
    +        final SmileTaskExecutor executor = new SmileTaskExecutor(mapredContext);
    +        try {
    +            executor.run(tasks);
    +        } catch (Exception ex) {
    +            throw new HiveException(ex);
    +        } finally {
    +            executor.shotdown();
    +        }
    +    }
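    +
    +    // --- Editorial sketch, not part of this patch ---------------------------
    +    // The TrainingTask body is not shown in this hunk. Given the OpenNLP
    +    // imports above, one GIS training call could be outlined as follows; the
    +    // MatrixEventStream constructor and the iteration/cutoff values are
    +    // assumptions, not taken from this patch.
    +    private static AbstractModel exampleGisTraining(Matrix x, int[] y) throws IOException {
    +        EventStream es = new MatrixEventStream(x, y, null); // hypothetical constructor
    +        GIS.SMOOTHING_OBSERVATION = 0.1; // static smoothing knob in opennlp.maxent.GIS
    +        return GIS.trainModel(100 /* iterations */, new OnePassRealValueDataIndexer(es, 0 /* cutoff */));
    +    }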
    +    
    +    /**
    +     * Synchronized because {@link #forward(Object)} should be called from a single thread.
    +     * 
    +     * @param accuracy
    +     */
    +    synchronized void forward(final int taskId, @Nonnull final Text model,
    +            @Nonnull Attribute[] attributes, @Nonnegative final double accuracy,
    +            final int[] y, @Nonnull final IntMatrix prediction, final boolean lastTask)
    +            throws HiveException {
    +        int oobErrors = 0;
    +        int oobTests = 0;
    +        if (lastTask) {
    +            // out-of-bag error estimate
    --- End diff --
    
    Out-of-bag testing is a validation method for Random Forest; it should not reuse examples
    that were part of the training input.
    https://www.quora.com/What-is-the-out-of-bag-error-in-Random-Forests
    
    In my understanding, OOB testing is not needed for a MaxEnt classifier.
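    
    A minimal sketch of what the OOB estimate means in Random Forest (the names
    below are illustrative, not from this patch): each example is scored only by
    the trees whose bootstrap sample did not contain it.
    
        int errors = 0, tested = 0;
        for (int i = 0; i < numExamples; i++) {
            // majority vote over the trees where example i was out-of-bag;
            // hypothetical helper that returns -1 if no tree left i out
            int vote = oobMajorityVote(i);
            if (vote == -1) {
                continue;
            }
            tested++;
            if (vote != y[i]) {
                errors++;
            }
        }
        double oobErrorRate = (tested == 0) ? 0.d : (double) errors / tested;
    
    Since this UDTF trains a single model on all of its input rows (_numTrees is
    fixed at 1), no example is ever held out of a bag, so there is nothing to
    estimate here.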


