hivemall-issues mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From myui <...@git.apache.org>
Subject [GitHub] incubator-hivemall pull request #14: [WIP] Separate optimizer implementation...
Date Thu, 26 Jan 2017 07:12:46 GMT
Github user myui commented on a diff in the pull request:

    https://github.com/apache/incubator-hivemall/pull/14#discussion_r97716349
  
    --- Diff: core/src/main/java/hivemall/model/NewSpaceEfficientDenseModel.java ---
    @@ -0,0 +1,317 @@
    +/*
    + * Hivemall: Hive scalable Machine Learning Library
    + *
    + * Copyright (C) 2015 Makoto YUI
    + * Copyright (C) 2013-2015 National Institute of Advanced Industrial Science and Technology (AIST)
    + *
    + * Licensed under the Apache License, Version 2.0 (the "License");
    + * you may not use this file except in compliance with the License.
    + * You may obtain a copy of the License at
    + *
    + *         http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +package hivemall.model;
    +
    +import hivemall.model.WeightValue.WeightValueWithCovar;
    +import hivemall.utils.collections.IMapIterator;
    +import hivemall.utils.hadoop.HiveUtils;
    +import hivemall.utils.lang.Copyable;
    +import hivemall.utils.lang.HalfFloat;
    +import hivemall.utils.math.MathUtils;
    +
    +import java.util.Arrays;
    +import javax.annotation.Nonnull;
    +
    +import org.apache.commons.logging.Log;
    +import org.apache.commons.logging.LogFactory;
    +
    +public final class NewSpaceEfficientDenseModel extends AbstractPredictionModel {
    +    private static final Log logger = LogFactory.getLog(NewSpaceEfficientDenseModel.class);
    +
    +    // Allocated length of every per-feature array (ndims + 1 at construction,
    +    // grown by ensureCapacity()).
    +    private int size;
    +    // Weights stored as half-float bit patterns (encoded/decoded via HalfFloat),
    +    // indexed by the parsed feature id.
    +    private short[] weights;
    +    // Half-float covariances filled with HalfFloat.ONE; null when the model was
    +    // constructed without covariance.
    +    private short[] covars;
    +
    +    // optional value for MIX
    +    // Both stay null until configureClock() allocates them; set() increments them
    +    // for touched values.
    +    private short[] clocks;
    +    private byte[] deltaUpdates;
    +
    +    /**
    +     * Creates a dense half-float model without covariance storage.
    +     *
    +     * @param ndims number of feature dimensions (one extra slot is allocated)
    +     */
    +    public NewSpaceEfficientDenseModel(int ndims) {
    +        this(ndims, /* withCovar */ false);
    +    }
    +
    +    public NewSpaceEfficientDenseModel(int ndims, boolean withCovar) {
    +        super();
    +        int size = ndims + 1;
    +        this.size = size;
    +        this.weights = new short[size];
    +        if (withCovar) {
    +            short[] covars = new short[size];
    +            Arrays.fill(covars, HalfFloat.ONE);
    +            this.covars = covars;
    +        } else {
    +            this.covars = null;
    +        }
    +        this.clocks = null;
    +        this.deltaUpdates = null;
    +    }
    +
    +    @Override
    +    protected boolean isDenseModel() {
    +        return true;
    +    }
    +
    +    /** @return whether a covariance array was allocated at construction time. */
    +    @Override
    +    public boolean hasCovariance() {
    +        return !(covars == null);
    +    }
    +
    +    /**
    +     * No-op. This implementation stores only weights, optional covariances and
    +     * MIX clocks, so the requested auxiliary accumulators are ignored.
    +     */
    +    @Override
    +    public void configureParams(boolean sum_of_squared_gradients, boolean sum_of_squared_delta_x,
    +            boolean sum_of_gradients) {
    +        // intentionally empty
    +    }
    +
    +    /** Allocates the MIX clock and delta-update arrays once; later calls do nothing. */
    +    @Override
    +    public void configureClock() {
    +        if (clocks != null) {
    +            return;
    +        }
    +        this.clocks = new short[size];
    +        this.deltaUpdates = new byte[size];
    +    }
    +
    +    /** @return whether configureClock() has allocated the clock arrays. */
    +    @Override
    +    public boolean hasClock() {
    +        final boolean configured = clocks != null;
    +        return configured;
    +    }
    +
    +    /**
    +     * Clears the pending delta-update counter at index {@code feature}.
    +     * The deltaUpdates array only exists after configureClock(); before that this throws.
    +     */
    +    @Override
    +    public void resetDeltaUpdates(int feature) {
    +        final byte[] deltas = deltaUpdates;
    +        deltas[feature] = 0;
    +    }
    +
    +    /** Decodes the half-float weight stored at index {@code i}. */
    +    private float getWeight(final int i) {
    +        final short bits = weights[i];
    +        if (bits == HalfFloat.ZERO) {
    +            // skip decoding for the zero bit pattern
    +            return HalfFloat.ZERO;
    +        }
    +        return HalfFloat.halfFloatToFloat(bits);
    +    }
    +
    +    private float getCovar(final int i) {
    +        return HalfFloat.halfFloatToFloat(covars[i]);
    +    }
    +
    +    private void _setWeight(final int i, final float v) {
    +        if(Math.abs(v) >= HalfFloat.MAX_FLOAT) {
    +            throw new IllegalArgumentException("Acceptable maximum weight is "
    +                    + HalfFloat.MAX_FLOAT + ": " + v);
    +        }
    +        weights[i] = HalfFloat.floatToHalfFloat(v);
    +    }
    +
    +    private void setCovar(final int i, final float v) {
    +        HalfFloat.checkRange(v);
    +        covars[i] = HalfFloat.floatToHalfFloat(v);
    +    }
    +
    +    private void ensureCapacity(final int index) {
    +        if (index >= size) {
    +            int bits = MathUtils.bitsRequired(index);
    +            int newSize = (1 << bits) + 1;
    +            int oldSize = size;
    +            logger.info("Expands internal array size from " + oldSize + " to " + newSize
+ " ("
    +                    + bits + " bits)");
    +            this.size = newSize;
    +            this.weights = Arrays.copyOf(weights, newSize);
    +            if (covars != null) {
    +                this.covars = Arrays.copyOf(covars, newSize);
    +                Arrays.fill(covars, oldSize, newSize, HalfFloat.ONE);
    +            }
    +            if(clocks != null) {
    +                this.clocks = Arrays.copyOf(clocks, newSize);
    +                this.deltaUpdates = Arrays.copyOf(deltaUpdates, newSize);
    +            }
    +        }
    +    }
    +
    +    /**
    +     * Returns a fresh value object for {@code feature}: a WeightValueWithCovar when
    +     * covariances are stored, otherwise a WeightValue; {@code null} beyond the allocated range.
    +     */
    +    @SuppressWarnings("unchecked")
    +    @Override
    +    public <T extends IWeightValue> T get(Object feature) {
    +        final int i = HiveUtils.parseInt(feature);
    +        if (i >= size) {
    +            return null;
    +        }
    +        final float weight = getWeight(i);
    +        final IWeightValue value = (covars == null)
    +                ? new WeightValue(weight)
    +                : new WeightValueWithCovar(weight, getCovar(i));
    +        return (T) value;
    +    }
    +
    +    /**
    +     * Stores {@code value} for {@code feature}, growing arrays as needed, advances the
    +     * MIX clock and delta count for touched values, and then reports the update via onUpdate().
    +     * NOTE(review): a value carrying covariance on a model built without covariance
    +     * reaches setCovar() with a null array; the byte delta wraps after 127 increments.
    +     */
    +    @Override
    +    public <T extends IWeightValue> void set(Object feature, T value) {
    +        final int i = HiveUtils.parseInt(feature);
    +        ensureCapacity(i);
    +
    +        final float weight = value.get();
    +        _setWeight(i, weight);
    +
    +        final boolean hasCovar = value.hasCovariance();
    +        final float covar;
    +        if (hasCovar) {
    +            covar = value.getCovariance();
    +            setCovar(i, covar);
    +        } else {
    +            covar = 1.f;
    +        }
    +
    +        short clock = 0;
    +        int delta = 0;
    +        if (clocks != null && value.isTouched()) {
    +            final short nextClock = (short) (clocks[i] + 1);
    +            clocks[i] = nextClock;
    +            clock = nextClock;
    +            final int nextDelta = deltaUpdates[i] + 1;
    +            assert (nextDelta > 0) : nextDelta;
    +            deltaUpdates[i] = (byte) nextDelta;
    +            delta = nextDelta;
    +        }
    +
    +        onUpdate(i, weight, covar, clock, delta, hasCovar);
    +    }
    +
    +    /**
    +     * Resets {@code feature} to weight 0 and, when present, covariance 1.
    +     * Features beyond the allocated range are ignored.
    +     */
    +    @Override
    +    public void delete(@Nonnull Object feature) {
    +        final int i = HiveUtils.parseInt(feature);
    +        if (i < size) {
    +            _setWeight(i, 0.f);
    +            if (covars != null) {
    +                setCovar(i, 1.f);
    +            }
    +            // clocks and deltaUpdates are deliberately left untouched
    +        }
    +    }
    +
    +    /** @return the weight of {@code feature}, or 0 when it lies beyond the allocated range. */
    +    @Override
    +    public float getWeight(Object feature) {
    +        final int i = HiveUtils.parseInt(feature);
    +        return (i >= size) ? 0f : getWeight(i);
    +    }
    +
    +    /**
    +     * Overwrites only the weight of {@code feature}, growing arrays if needed.
    +     * Unlike set(), it leaves clocks untouched and does not call onUpdate().
    +     */
    +    @Override
    +    public void setWeight(Object feature, float value) {
    +        final int index = HiveUtils.parseInt(feature);
    +        ensureCapacity(index);
    +        _setWeight(index, value);
    +    }
    +
    +    /**
    +     * Returns the covariance of {@code feature}.
    +     *
    +     * <p>Returns 1 when the model has no covariance array or the feature lies beyond
    +     * the allocated range, which matches the 1.f default used by set() and delete().
    +     */
    +    @Override
    +    public float getCovariance(Object feature) {
    +        final int i = HiveUtils.parseInt(feature);
    +        // Without a covariance array, getCovar() would dereference null; treat as 1.
    +        if (covars == null || i >= size) {
    +            return 1f;
    +        }
    +        return getCovar(i);
    +    }
    +
    +    @Override
    +    protected void _set(Object feature, float weight, short clock) {
    +        int i = ((Integer) feature).intValue();
    +        ensureCapacity(i);
    +        _setWeight(i, weight);
    +        clocks[i] = clock;
    +        deltaUpdates[i] = 0;
    +    }
    +
    +    @Override
    +    protected void _set(Object feature, float weight, float covar, short clock) {
    +        int i = ((Integer) feature).intValue();
    +        ensureCapacity(i);
    +        _setWeight(i, weight);
    +        setCovar(i, covar);
    +        clocks[i] = clock;
    +        deltaUpdates[i] = 0;
    +    }
    +
    +    /**
    +     * @return the allocated array length (ndims + 1 initially, grown by ensureCapacity()),
    +     *         not the number of non-zero weights
    +     */
    +    @Override
    +    public int size() {
    +        return this.size;
    +    }
    +
    +    /** @return whether {@code feature} is within range and has a non-zero weight. */
    +    @Override
    +    public boolean contains(Object feature) {
    +        final int i = HiveUtils.parseInt(feature);
    +        return i < size && getWeight(i) != 0.f;
    +    }
    +
    +    /** @return an iterator over every allocated slot, backed by the inner Itr class. */
    +    @SuppressWarnings("unchecked")
    +    @Override
    +    public <K, V extends IWeightValue> IMapIterator<K, V> entries() {
    +        final IMapIterator<?, ?> itr = new Itr();
    +        return (IMapIterator<K, V>) itr;
    +    }
    +
    +    private final class Itr implements IMapIterator<Number, IWeightValue> {
    +
    +        private int cursor;
    +        private final WeightValueWithCovar tmpWeight;
    +
    +        private Itr() {
    +            this.cursor = -1;
    +            this.tmpWeight = new WeightValueWithCovar();
    +        }
    +
    +        @Override
    +        public boolean hasNext() {
    +            return cursor < size;
    +        }
    +
    +        @Override
    +        public int next() {
    +            ++cursor;
    +            if (!hasNext()) {
    +                return -1;
    +            }
    +            return cursor;
    +        }
    +
    +        @Override
    +        public Integer getKey() {
    +            return cursor;
    +        }
    +
    +        @Override
    +        public IWeightValue getValue() {
    +            if (covars == null) {
    --- End diff --
    
    Okay as it is.


---
If your project is set up for it, you can reply to this email and have your
reply appear on GitHub as well. If your project does not have this feature
enabled and wishes so, or if the feature is enabled but not working, please
contact infrastructure at infrastructure@apache.org or file a JIRA ticket
with INFRA.
---

Mime
View raw message