hivemall-issues mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From myui <...@git.apache.org>
Subject [GitHub] incubator-hivemall pull request #167: [HIVEMALL-220] Implement Cofactor
Date Thu, 18 Oct 2018 09:57:16 GMT
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/167#discussion_r226243032
  
    --- Diff: core/src/main/java/hivemall/mf/CofactorModel.java ---
    @@ -0,0 +1,638 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + *   http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing,
    + * software distributed under the License is distributed on an
    + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
    + * KIND, either express or implied.  See the License for the
    + * specific language governing permissions and limitations
    + * under the License.
    + */
    +package hivemall.mf;
    +
    +import hivemall.fm.Feature;
    +import hivemall.utils.math.MathUtils;
    +import hivemall.utils.math.MatrixUtils;
    +import org.apache.commons.math3.linear.ArrayRealVector;
    +import org.apache.commons.math3.linear.Array2DRowRealMatrix;
    +import org.apache.commons.math3.linear.RealMatrix;
    +import org.apache.commons.math3.linear.RealVector;
    +import org.apache.commons.math3.linear.SingularValueDecomposition;
    +
    +import javax.annotation.Nonnegative;
    +import javax.annotation.Nonnull;
    +import javax.annotation.Nullable;
    +import java.util.*;
    +
    +public class CofactorModel {
    +
    +    /**
    +     * How new latent factor vectors are filled.
    +     *
    +     * NOTE(review): maxInitValue and initStdDev are mutable instance fields of the enum
    +     * constants, and enum constants are JVM-wide singletons. Setting them for one model
    +     * therefore changes them for every model (and thread) using the same scheme.
    +     * Consider keeping these values on CofactorModel instead.
    +     */
    +    public enum RankInitScheme {
    +        random /* default */, gaussian;
    +
    +        // Passed to uniformFill for the random scheme
    +        // (presumably the upper bound of the uniform range -- confirm). 0 until set.
    +        @Nonnegative
    +        protected float maxInitValue;
    +        // Passed to gaussianFill for the gaussian scheme
    +        // (presumably the standard deviation -- confirm). 0 until set.
    +        @Nonnegative
    +        protected double initStdDev;
    +
    +        /**
    +         * Maps an option string to a scheme, ignoring case.
    +         *
    +         * NOTE(review): any unrecognized option (e.g. a typo) silently falls back to
    +         * random, which also makes the explicit "random" branch redundant. Consider
    +         * rejecting unknown values with IllegalArgumentException.
    +         *
    +         * @param opt "gaussian" or "random"; may be null
    +         * @return gaussian for "gaussian"; random for "random", for null, and for any
    +         *         other value
    +         */
    +        @Nonnull
    +        public static CofactorModel.RankInitScheme resolve(@Nullable String opt) {
    +            if (opt == null) {
    +                return random;
    +            } else if ("gaussian".equalsIgnoreCase(opt)) {
    +                return gaussian;
    +            } else if ("random".equalsIgnoreCase(opt)) {
    +                return random;
    +            }
    +            return random;
    +        }
    +
    +        // Mutates the shared enum constant (see the NOTE(review) on this enum).
    +        public void setMaxInitValue(float maxInitValue) {
    +            this.maxInitValue = maxInitValue;
    +        }
    +
    +        // Mutates the shared enum constant (see the NOTE(review) on this enum).
    +        public void setInitStdDev(double initStdDev) {
    +            this.initStdDev = initStdDev;
    +        }
    +
    +    }
    +
    +    // NOTE(review): not referenced in the visible part of this diff and looks like a
    +    // dataset-specific size; confirm it is used, or remove it.
    +    private static final int EXPECTED_SIZE = 136861;
    +    // Length of every latent factor vector.
    +    @Nonnegative
    +    protected final int factor;
    +
    +    // rank matrix initialization
    +    protected final RankInitScheme initScheme;
    +
    +    // Starts at 0 in the constructor.
    +    // NOTE(review): @Nonnull has no effect on a primitive field; drop it.
    +    @Nonnull
    +    private double globalBias;
    +
    +    // storing trainable latent factors and weights, keyed by String id.
    +    // Presumably theta = user factors, beta = item factors, gamma = item context
    +    // factors, following the Cofactor reference implementation -- confirm.
    +    private Map<String, RealVector> theta;
    +    private Map<String, RealVector> beta;
    +    private Map<String, Double> betaBias;
    +    private Map<String, RealVector> gamma;
    +    private Map<String, Double> gammaBias;
    +
    +    // precomputed identity matrix
    +    // NOTE(review): not assigned in the constructor shown here -- confirm where it is set.
    +    private RealMatrix identity;
    +
    +    // Random generators created by the constructor with seeds 31 (randU) and 41 (randI).
    +    protected final Random[] randU, randI;
    +
    +    // hyperparameters: c0/c1 are validated by checkHyperparameterC in the constructor;
    +    // the lambdas are presumably regularization weights for theta/beta/gamma -- confirm.
    +    private final float c0, c1;
    +    private final float lambdaTheta, lambdaBeta, lambdaGamma;
    +
    +    /**
    +     * Creates a model whose factor and bias maps are empty and whose global bias is 0.
    +     *
    +     * NOTE(review): @Nonnull on the primitive c0/c1 parameters has no effect. The
    +     * scheme's maxInitValue/initStdDev are not set here; they keep whatever was last set
    +     * on the shared enum constant (0 by default).
    +     *
    +     * @param factor length of every latent factor vector
    +     * @param initScheme how new factor vectors are filled
    +     * @param c0 hyperparameter checked by checkHyperparameterC (not shown), which
    +     *        presumably throws for invalid values -- confirm
    +     * @param c1 hyperparameter checked by checkHyperparameterC (not shown)
    +     * @param lambdaTheta stored as-is (presumably the regularization weight for theta)
    +     * @param lambdaBeta stored as-is (presumably the regularization weight for beta)
    +     * @param lambdaGamma stored as-is (presumably the regularization weight for gamma)
    +     */
    +    public CofactorModel(@Nonnegative int factor, @Nonnull RankInitScheme initScheme,
    +                         @Nonnull float c0, @Nonnull float c1, float lambdaTheta,
    +                         float lambdaBeta, float lambdaGamma) {
    +
    +        // The reference implementation always uses Gaussian rank initialization;
    +        // here the scheme is chosen by the caller.
    +        // https://github.com/dawenl/cofactor/blob/master/src/cofacto.py#L98
    +        this.factor = factor;
    +        this.initScheme = initScheme;
    +        this.globalBias = 0.d;
    +        this.lambdaTheta = lambdaTheta;
    +        this.lambdaBeta = lambdaBeta;
    +        this.lambdaGamma = lambdaGamma;
    +
    +        this.theta = new HashMap<>();
    +        this.beta = new HashMap<>();
    +        this.betaBias = new HashMap<>();
    +        this.gamma = new HashMap<>();
    +        this.gammaBias = new HashMap<>();
    +
    +        // Fixed seeds keep initialization reproducible across runs.
    +        this.randU = newRandoms(factor, 31L);
    +        this.randI = newRandoms(factor, 41L);
    +
    +        checkHyperparameterC(c0);
    +        checkHyperparameterC(c1);
    +        this.c0 = c0;
    +        this.c1 = c1;
    +
    +    }
    +
    +    /**
    +     * Stores a new factor vector of length {@code factor} under {@code key} unless
    +     * {@code weights} already contains that key; existing vectors are left untouched.
    +     * The vector is filled by uniformFill (random scheme) or gaussianFill (gaussian).
    +     *
    +     * NOTE(review): both branches draw from randI regardless of which map is passed.
    +     * randU is never read in the visible code, so user vectors (if theta is passed here)
    +     * would share the item generator; confirm this is intended. The random branch also
    +     * uses only randI[0], while the gaussian branch receives the whole array.
    +     */
    +    private void initFactorVector(String key, Map<String, RealVector> weights)
{
    +        if (weights.containsKey(key)) {
    +            return;
    +        }
    +        RealVector v = new ArrayRealVector(factor);
    +        switch (initScheme) {
    +            case random:
    +                uniformFill(v, randI[0], initScheme.maxInitValue);
    +                break;
    +            case gaussian:
    +                gaussianFill(v, randI, initScheme.initStdDev);
    +                break;
    +            default:
    +                throw new IllegalStateException(
    +                        "Unsupported rank initialization scheme: " + initScheme);
    +
    +        }
    +        weights.put(key, v);
    +    }
    +
    +    /**
    +     * Returns the vector stored under {@code key}, or null if none has been initialized.
    +     */
    +    private static RealVector getFactorVector(String key, Map<String, RealVector>
weights) {
    +        return weights.get(key);
    +    }
    +
    +    /**
    +     * Replaces the vector stored under {@code key}. The key is expected to exist already;
    +     * this is only checked by an assert, so with assertions disabled (the JVM default) a
    +     * missing key is silently inserted.
    +     */
    +    private static void setFactorVector(String key, Map<String, RealVector> weights,
RealVector factorVector) {
    +        assert weights.containsKey(key);
    +        weights.put(key, factorVector);
    +    }
    +
    +    private static double getBias(String key, Map<String, Double> biases) {
    +        if (!biases.containsKey(key)) {
    --- End diff --
    
    Up to three hash lookups in the worst case... A single `get` with a null check is enough:
    
    ```
            // One lookup: Map.get returns null for a missing key, which maps to a bias of 0.
            final Double v = biases.get(key);
            if(v == null) {
                return 0.d;
            }
            return v.doubleValue();
    ```
    
    Or, 
    
    ```java
        /**
         * Returns the bias for {@code key} with a single hash lookup. A missing key yields the
         * map's default return value, so set it to 0 when the map is created.
         * Assumes fastutil's Object2DoubleMap.
         */
        private static double getBias(String key, Object2DoubleMap<String> biases) {
            // biases.defaultReturnValue(0.d); -- set in initialization
            // Use the primitive getDouble: the boxed get(Object) inherited from Map returns
            // null for a missing key, and unboxing that null into double throws an NPE.
            return biases.getDouble(key);
        }
    ```


---

Mime
View raw message