hivemall-issues mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From takuti <...@git.apache.org>
Subject [GitHub] incubator-hivemall pull request #107: [HIVEMALL-132] Generalize f1score UDAF...
Date Mon, 21 Aug 2017 08:17:58 GMT
Github user takuti commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/107#discussion_r134144074
  
    --- Diff: core/src/main/java/hivemall/evaluation/FMeasureUDAF.java ---
    @@ -18,118 +18,387 @@
      */
     package hivemall.evaluation;
     
    -import hivemall.utils.hadoop.WritableUtils;
    +import hivemall.UDAFEvaluatorWithOptions;
    +import hivemall.utils.hadoop.HiveUtils;
     
    +import java.util.ArrayList;
    +import java.util.Arrays;
    +import java.util.Collections;
     import java.util.List;
     
    +import hivemall.utils.lang.Primitives;
    +import org.apache.commons.cli.CommandLine;
    +import org.apache.commons.cli.Options;
    +
     import org.apache.hadoop.hive.ql.exec.Description;
    -import org.apache.hadoop.hive.ql.exec.UDAF;
    -import org.apache.hadoop.hive.ql.exec.UDAFEvaluator;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
    +import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
    +import org.apache.hadoop.hive.ql.metadata.HiveException;
    +import org.apache.hadoop.hive.ql.parse.SemanticException;
    +import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
    +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
     import org.apache.hadoop.hive.serde2.io.DoubleWritable;
    -import org.apache.hadoop.io.IntWritable;
    -
    -@SuppressWarnings("deprecation")
    -@Description(name = "f1score",
    -        value = "_FUNC_(array[int], array[int]) - Return a F-measure/F1 score")
    -public final class FMeasureUDAF extends UDAF {
    -
    -    public static class Evaluator implements UDAFEvaluator {
    -
    -        public static class PartialResult {
    -            long tp;
    -            /** tp + fn */
    -            long totalAcutal;
    -            /** tp + fp */
    -            long totalPredicted;
    -
    -            PartialResult() {
    -                this.tp = 0L;
    -                this.totalPredicted = 0L;
    -                this.totalAcutal = 0L;
    -            }
    +import org.apache.hadoop.hive.serde2.objectinspector.*;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
    +import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
    +import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
    +import org.apache.hadoop.io.LongWritable;
     
    -            void updateScore(final List<IntWritable> actual, final List<IntWritable> predicted) {
    -                final int numActual = actual.size();
    -                final int numPredicted = predicted.size();
    -                int countTp = 0;
    -                for (int i = 0; i < numPredicted; i++) {
    -                    IntWritable p = predicted.get(i);
    -                    if (actual.contains(p)) {
    -                        countTp++;
    -                    }
    +import javax.annotation.Nonnull;
    +
    +@Description(
    +        name = "fmeasure",
    +        value = "_FUNC_(array | int | boolean, array | int | boolean, String) - Return a F-measure (f1score is the special with beta=1.)")
    +public final class FMeasureUDAF extends AbstractGenericUDAFResolver {
    +    @Override
    +    public GenericUDAFEvaluator getEvaluator(@Nonnull TypeInfo[] typeInfo) throws SemanticException {
    +        if (typeInfo.length != 2 && typeInfo.length != 3) {
    +            throw new UDFArgumentTypeException(typeInfo.length - 1,
    +                "_FUNC_ takes two or three arguments");
    +        }
    +
    +        boolean isArg1ListOrIntOrBoolean = HiveUtils.isListTypeInfo(typeInfo[0])
    +                || HiveUtils.isIntegerTypeInfo(typeInfo[0])
    +                || HiveUtils.isBooleanTypeInfo(typeInfo[0]);
    +        if (!isArg1ListOrIntOrBoolean) {
    +            throw new UDFArgumentTypeException(0,
    +                "The first argument `array/int/boolean actual` is invalid form: " + typeInfo[0]);
    +        }
    +
    +        boolean isArg2ListOrIntOrBoolean = HiveUtils.isListTypeInfo(typeInfo[1])
    +                || HiveUtils.isIntegerTypeInfo(typeInfo[1])
    +                || HiveUtils.isBooleanTypeInfo(typeInfo[1]);
    +        if (!isArg2ListOrIntOrBoolean) {
    +            throw new UDFArgumentTypeException(1,
    +                "The first argument `array/int/boolean actual` is invalid form: " + typeInfo[1]);
    +        }
    +
    +        if (typeInfo[0] != typeInfo[1]) {
    +            throw new UDFArgumentTypeException(1, "The first argument's `actual` type is "
    +                    + typeInfo[0] + ", but the second argument `predicated`'s type is not match: "
    +                    + typeInfo[1]);
    +        }
    +
    +        return new Evaluator();
    +    }
    +
    +    public static class Evaluator extends UDAFEvaluatorWithOptions {
    +
    +        private ObjectInspector actualOI;
    +        private ObjectInspector predictedOI;
    +        private StructObjectInspector internalMergeOI;
    +
    +        private StructField tpField;
    +        private StructField totalActualField;
    +        private StructField totalPredictedField;
    +        private StructField betaOptionField;
    +        private StructField averageOptionFiled;
    +
    +        private double beta;
    +        private String average;
    +
    +        public Evaluator() {}
    +
    +        @Override
    +        protected Options getOptions() {
    +            Options opts = new Options();
    +            opts.addOption("beta", true, "The weight of precision [default: 1.]");
    +            opts.addOption("average", true, "The way of average calculation [default: micro]");
    +            return opts;
    +        }
    +
    +        @Override
    +        protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException {
    --- End diff --
    
    Generally, we first store the option values in local variables and assign them to the fields at the end of the method, as follows:
    
    ```java
    @Override
    protected CommandLine processOptions(ObjectInspector[] argOIs) throws UDFArgumentException
{
        CommandLine cl = null;
    
        double beta = 1.d;
        String average = "micro";
    
        if (argOIs.length >= 3) {
            String rawArgs = HiveUtils.getConstString(argOIs[2]);
            cl = parseOptions(rawArgs);
    
            beta = Primitives.parseDouble(cl.getOptionValue("beta"), beta);
            if (beta <= 0.d) {
                throw new UDFArgumentException(
                    "The third argument `double beta` must be greater than 0.0: " + beta);
            }
    
            average = cl.getOptionValue("average", average);
            if (!(average.equals("binary") || average.equals("macro") || average.equals("micro")))
{
                throw new UDFArgumentException(
                    "The third argument `String average` must be one of the {binary, micro,
macro}: "
                            + average);
            }
        }
    
        this.beta = beta;
        this.average = average;
    
        return cl;
    }
    ```
    
    (See the other `processOptions` implementations for the same pattern.)
    
    The implementation above declares the default option values explicitly at the beginning, which makes it more readable and makes the defaults easier to change in the future.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

Mime
View raw message