ignite-commits mailing list archives

From ptupit...@apache.org
Subject [3/8] ignite git commit: IGNITE-7350: Distributed MLP cleanup/refactoring
Date Wed, 17 Jan 2018 07:24:13 GMT
http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/NesterovParameterUpdate.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/NesterovParameterUpdate.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/NesterovParameterUpdate.java
new file mode 100644
index 0000000..b494b14
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/NesterovParameterUpdate.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.optimization.updatecalculators;
+
+import java.io.Serializable;
+import java.util.List;
+import java.util.Objects;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+
+/**
+ * Data needed for Nesterov parameters updater.
+ */
+public class NesterovParameterUpdate implements Serializable {
+    /**
+     * Previous step weights updates.
+     */
+    protected Vector prevIterationUpdates;
+
+    /**
+     * Construct NesterovParameterUpdate.
+     *
+     * @param paramsCnt Count of parameters on which update happens.
+     */
+    public NesterovParameterUpdate(int paramsCnt) {
+        prevIterationUpdates = new DenseLocalOnHeapVector(paramsCnt).assign(0);
+    }
+
+    /**
+     * Construct NesterovParameterUpdate.
+     *
+     * @param prevIterationUpdates Previous iteration updates.
+     */
+    public NesterovParameterUpdate(Vector prevIterationUpdates) {
+        this.prevIterationUpdates = prevIterationUpdates;
+    }
+
+    /**
+     * Set previous step parameters updates.
+     *
+     * @param updates Parameters updates.
+     * @return This object with updated parameters updates.
+     */
+    public NesterovParameterUpdate setPreviousUpdates(Vector updates) {
+        prevIterationUpdates = updates;
+        return this;
+    }
+
+    /**
+     * Get previous step parameters updates.
+     *
+     * @return Previous step parameters updates.
+     */
+    public Vector prevIterationUpdates() {
+        return prevIterationUpdates;
+    }
+
+    /**
+     * Get sum of parameters updates.
+     *
+     * @param parameters Parameters to sum.
+     * @return Sum of parameters updates.
+     */
+    public static NesterovParameterUpdate sum(List<NesterovParameterUpdate> parameters) {
+        return parameters.stream().filter(Objects::nonNull).map(NesterovParameterUpdate::prevIterationUpdates)
+            .reduce(Vector::plus).map(NesterovParameterUpdate::new).orElse(null);
+    }
+
+    /**
+     * Get average of parameters updates.
+     *
+     * @param parameters Parameters to average.
+     * @return Average of parameters updates.
+     */
+    public static NesterovParameterUpdate avg(List<NesterovParameterUpdate> parameters) {
+        NesterovParameterUpdate sum = sum(parameters);
+        return sum != null ? sum.setPreviousUpdates(sum.prevIterationUpdates().divide(parameters.size())) : null;
+    }
+}
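
A quick illustration of the aggregation helpers above (a minimal sketch; the wrapper class and the concrete values are illustrative, not part of this commit): sum() element-wise adds the previous-iteration update vectors of all non-null entries, and avg() divides that sum by the size of the given list.

    import java.util.Arrays;
    import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
    import org.apache.ignite.ml.optimization.updatecalculators.NesterovParameterUpdate;

    class NesterovAggregationSketch {
        public static void main(String[] args) {
            NesterovParameterUpdate u1 =
                new NesterovParameterUpdate(new DenseLocalOnHeapVector(new double[] {1.0, 2.0}));
            NesterovParameterUpdate u2 =
                new NesterovParameterUpdate(new DenseLocalOnHeapVector(new double[] {3.0, 4.0}));

            // Element-wise average of the two states: expect [2.0, 3.0].
            NesterovParameterUpdate avg = NesterovParameterUpdate.avg(Arrays.asList(u1, u2));

            System.out.println(avg.prevIterationUpdates());
        }
    }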

http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/NesterovUpdateCalculator.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/NesterovUpdateCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/NesterovUpdateCalculator.java
new file mode 100644
index 0000000..2bee506
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/NesterovUpdateCalculator.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.optimization.updatecalculators;
+
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.optimization.SmoothParametrized;
+
+/**
+ * Class encapsulating Nesterov algorithm for MLP parameters update.
+ */
+public class NesterovUpdateCalculator<M extends SmoothParametrized<M>>
+    implements ParameterUpdateCalculator<M, NesterovParameterUpdate> {
+    /**
+     * Learning rate.
+     */
+    private final double learningRate;
+
+    /**
+     * Loss function.
+     */
+    private IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;
+
+    /**
+     * Momentum constant.
+     */
+    protected double momentum;
+
+    /**
+     * Construct NesterovUpdateCalculator.
+     *
+     * @param learningRate Learning rate.
+     * @param momentum Momentum constant.
+     */
+    public NesterovUpdateCalculator(double learningRate, double momentum) {
+        this.learningRate = learningRate;
+        this.momentum = momentum;
+    }
+
+    /** {@inheritDoc} */
+    @Override public NesterovParameterUpdate calculateNewUpdate(M mdl,
+        NesterovParameterUpdate updaterParameters, int iteration, Matrix inputs, Matrix groundTruth) {
+        Vector prevUpdates = updaterParameters.prevIterationUpdates();
+
+        M newMdl = mdl;
+
+        if (iteration > 0) {
+            Vector curParams = mdl.parameters();
+            newMdl = mdl.withParameters(curParams.minus(prevUpdates.times(momentum)));
+        }
+
+        Vector gradient = newMdl.differentiateByParameters(loss, inputs, groundTruth);
+
+        return new NesterovParameterUpdate(prevUpdates.plus(gradient.times(learningRate)));
+    }
+
+    /** {@inheritDoc} */
+    @Override public NesterovParameterUpdate init(M mdl,
+        IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) {
+        this.loss = loss;
+
+        return new NesterovParameterUpdate(mdl.parametersCount());
+    }
+
+    /** {@inheritDoc} */
+    @Override public <M1 extends M> M1 update(M1 obj, NesterovParameterUpdate update) {
+        Vector parameters = obj.parameters();
+        return (M1)obj.setParameters(parameters.minus(update.prevIterationUpdates()));
+    }
+}
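
Reading calculateNewUpdate and update together, the step this calculator performs (for iteration t > 0, with momentum mu, learning rate eta and accumulated update u_t) can be sketched as:

    w'_t    = w_t - mu * u_t              (look-ahead parameters)
    u_{t+1} = u_t + eta * grad L(w'_t)    (returned by calculateNewUpdate)
    w_{t+1} = w_t - u_{t+1}               (applied by update)

At iteration 0 the look-ahead step is skipped and the gradient is taken at w_0 itself.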

http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/ParameterUpdateCalculator.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/ParameterUpdateCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/ParameterUpdateCalculator.java
new file mode 100644
index 0000000..92f7583
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/ParameterUpdateCalculator.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.optimization.updatecalculators;
+
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+
+/**
+ * Interface for classes encapsulating parameters update logic.
+ *
+ * @param <M> Type of model to be updated.
+ * @param <P> Type of parameters needed for this update calculator.
+ */
+public interface ParameterUpdateCalculator<M, P> {
+    /**
+     * Initializes the update calculator.
+     *
+     * @param mdl Model to be trained.
+     * @param loss Loss function.
+     * @return Initialized parameters.
+     */
+    P init(M mdl, IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss);
+
+    /**
+     * Calculate new update.
+     *
+     * @param mdl Model to be updated.
+     * @param updaterParameters Updater parameters to update.
+     * @param iteration Current trainer iteration.
+     * @param inputs Inputs.
+     * @param groundTruth True values.
+     * @return Updated parameters.
+     */
+    P calculateNewUpdate(M mdl, P updaterParameters, int iteration, Matrix inputs, Matrix groundTruth);
+
+    /**
+     * Update given object with given update parameters.
+     *
+     * @param obj Object to be updated.
+     * @param update Update parameters.
+     * @return Updated object.
+     */
+    <M1 extends M> M1 update(M1 obj, P update);
+}
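
The three methods are intended to be driven in a loop. A minimal sketch of that protocol (the helper below is illustrative; LocalBatchTrainer.train() later in this commit follows exactly this shape):

    static <M, P> M fit(ParameterUpdateCalculator<M, P> calculator, M mdl,
        IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss,
        Matrix inputs, Matrix groundTruth, int maxIterations) {
        // Let the calculator capture the loss and build its initial state.
        P updaterParams = calculator.init(mdl, loss);

        for (int i = 0; i < maxIterations; i++) {
            // Compute the next update from the current model and batch...
            updaterParams = calculator.calculateNewUpdate(mdl, updaterParams, i, inputs, groundTruth);

            // ...and apply it to obtain the next model.
            mdl = calculator.update(mdl, updaterParams);
        }

        return mdl;
    }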

http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/RPropParameterUpdate.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/RPropParameterUpdate.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/RPropParameterUpdate.java
new file mode 100644
index 0000000..fd0a045
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/RPropParameterUpdate.java
@@ -0,0 +1,228 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.optimization.updatecalculators;
+
+import java.io.Serializable;
+import java.util.List;
+import java.util.Objects;
+import java.util.stream.Collectors;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.VectorUtils;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+
+/**
+ * Data needed for RProp updater.
+ * <p>
+ * See <a href="https://paginas.fe.up.pt/~ee02162/dissertacao/RPROP%20paper.pdf">RProp</a>.</p>
+ */
+public class RPropParameterUpdate implements Serializable {
+    /**
+     * Previous iteration parameters updates. In the original paper these are denoted "delta w".
+     */
+    protected Vector prevIterationUpdates;
+
+    /**
+     * Previous iteration model partial derivatives by parameters.
+     */
+    protected Vector prevIterationGradient;
+    /**
+     * Previous iteration parameters deltas. In the original paper these are denoted "delta".
+     */
+    protected Vector deltas;
+
+    /**
+     * Updates mask (values by which the update is multiplied).
+     */
+    protected Vector updatesMask;
+
+    /**
+     * Construct RPropParameterUpdate.
+     *
+     * @param paramsCnt Parameters count.
+     * @param initUpdate Initial update (denoted "delta_0" in the original paper).
+     */
+    RPropParameterUpdate(int paramsCnt, double initUpdate) {
+        prevIterationUpdates = new DenseLocalOnHeapVector(paramsCnt);
+        prevIterationGradient = new DenseLocalOnHeapVector(paramsCnt);
+        deltas = new DenseLocalOnHeapVector(paramsCnt).assign(initUpdate);
+        updatesMask = new DenseLocalOnHeapVector(paramsCnt);
+    }
+
+    /**
+     * Construct instance of this class by given parameters.
+     *
+     * @param prevIterationUpdates Previous iteration parameters updates.
+     * @param prevIterationGradient Previous iteration model partial derivatives by parameters.
+     * @param deltas Previous iteration parameters deltas.
+     * @param updatesMask Updates mask.
+     */
+    public RPropParameterUpdate(Vector prevIterationUpdates, Vector prevIterationGradient,
+        Vector deltas, Vector updatesMask) {
+        this.prevIterationUpdates = prevIterationUpdates;
+        this.prevIterationGradient = prevIterationGradient;
+        this.deltas = deltas;
+        this.updatesMask = updatesMask;
+    }
+
+    /**
+     * Get previous iteration parameters deltas. In the original paper these are denoted "delta".
+     *
+     * @return Parameters deltas.
+     */
+    Vector deltas() {
+        return deltas;
+    }
+
+    /**
+     * Get previous iteration parameters updates. In the original paper these are denoted "delta w".
+     *
+     * @return Parameters updates.
+     */
+    Vector prevIterationUpdates() {
+        return prevIterationUpdates;
+    }
+
+    /**
+     * Set previous iteration parameters updates. In the original paper these are denoted "delta w".
+     *
+     * @param updates New parameters updates value.
+     * @return This object.
+     */
+    private RPropParameterUpdate setPrevIterationUpdates(Vector updates) {
+        prevIterationUpdates = updates;
+
+        return this;
+    }
+
+    /**
+     * Get previous iteration loss function partial derivatives by parameters.
+     *
+     * @return Previous iteration loss function partial derivatives by parameters.
+     */
+    Vector prevIterationGradient() {
+        return prevIterationGradient;
+    }
+
+    /**
+     * Set previous iteration loss function partial derivatives by parameters.
+     *
+     * @param gradient New gradient value.
+     * @return This object.
+     */
+    private RPropParameterUpdate setPrevIterationGradient(Vector gradient) {
+        prevIterationGradient = gradient;
+        return this;
+    }
+
+    /**
+     * Get updates mask (values by which the update is multiplied).
+     *
+     * @return Updates mask (values by which the update is multiplied).
+     */
+    public Vector updatesMask() {
+        return updatesMask;
+    }
+
+    /**
+     * Set updates mask (values by which the update is multiplied).
+     *
+     * @param updatesMask New updatesMask.
+     * @return This object.
+     */
+    public RPropParameterUpdate setUpdatesMask(Vector updatesMask) {
+        this.updatesMask = updatesMask;
+
+        return this;
+    }
+
+    /**
+     * Set previous iteration deltas.
+     *
+     * @param deltas New deltas.
+     * @return This object.
+     */
+    public RPropParameterUpdate setDeltas(Vector deltas) {
+        this.deltas = deltas;
+
+        return this;
+    }
+
+    /**
+     * Sums updates during one training.
+     *
+     * @param updates Updates.
+     * @return Sum of updates during one training.
+     */
+    public static RPropParameterUpdate sumLocal(List<RPropParameterUpdate> updates) {
+        List<RPropParameterUpdate> nonNullUpdates = updates.stream().filter(Objects::nonNull)
+            .collect(Collectors.toList());
+
+        if (nonNullUpdates.isEmpty())
+            return null;
+
+        Vector newDeltas = nonNullUpdates.get(nonNullUpdates.size() - 1).deltas();
+        Vector newGradient = nonNullUpdates.get(nonNullUpdates.size() - 1).prevIterationGradient();
+        Vector totalUpdate = nonNullUpdates.stream().map(pu -> VectorUtils.elementWiseTimes(pu.updatesMask().copy(),
+            pu.prevIterationUpdates())).reduce(Vector::plus).orElse(null);
+
+        return new RPropParameterUpdate(totalUpdate, newGradient, newDeltas,
+            new DenseLocalOnHeapVector(newDeltas.size()).assign(1.0));
+    }
+
+    /**
+     * Sums updates returned by different trainings.
+     *
+     * @param updates Updates.
+     * @return Sum of updates returned by different trainings.
+     */
+    public static RPropParameterUpdate sum(List<RPropParameterUpdate> updates) {
+        Vector totalUpdate = updates.stream().filter(Objects::nonNull)
+            .map(pu -> VectorUtils.elementWiseTimes(pu.updatesMask().copy(), pu.prevIterationUpdates()))
+            .reduce(Vector::plus).orElse(null);
+        Vector totalDelta = updates.stream().filter(Objects::nonNull)
+            .map(RPropParameterUpdate::deltas).reduce(Vector::plus).orElse(null);
+        Vector totalGradient = updates.stream().filter(Objects::nonNull)
+            .map(RPropParameterUpdate::prevIterationGradient).reduce(Vector::plus).orElse(null);
+
+        if (totalUpdate != null)
+            return new RPropParameterUpdate(totalUpdate, totalGradient, totalDelta,
+                new DenseLocalOnHeapVector(Objects.requireNonNull(totalDelta).size()).assign(1.0));
+
+        return null;
+    }
+
+    /**
+     * Averages updates returned by different trainings.
+     *
+     * @param updates Updates.
+     * @return Average of updates returned by different trainings.
+     */
+    public static RPropParameterUpdate avg(List<RPropParameterUpdate> updates) {
+        List<RPropParameterUpdate> nonNullUpdates = updates.stream()
+            .filter(Objects::nonNull).collect(Collectors.toList());
+        int size = nonNullUpdates.size();
+
+        RPropParameterUpdate sum = sum(updates);
+        if (sum != null)
+            return sum.
+                setPrevIterationGradient(sum.prevIterationGradient().divide(size)).
+                setPrevIterationUpdates(sum.prevIterationUpdates().divide(size)).
+                setDeltas(sum.deltas().divide(size));
+
+        return null;
+    }
+}
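
In both aggregators the per-training updates are masked element-wise before being summed. As a sketch, for two states a and b with masks m, updates u, deltas d and gradients g, sum() produces ( .* is the element-wise product):

    totalUpdate   = m_a .* u_a + m_b .* u_b
    totalDelta    = d_a + d_b
    totalGradient = g_a + g_b
    mask          = all ones

sumLocal() differs in that, instead of summing them, it keeps the deltas and the gradient of the last non-null state.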

http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/RPropUpdateCalculator.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/RPropUpdateCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/RPropUpdateCalculator.java
new file mode 100644
index 0000000..80345d9
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/RPropUpdateCalculator.java
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.optimization.updatecalculators;
+
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.VectorUtils;
+import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.util.MatrixUtil;
+import org.apache.ignite.ml.optimization.SmoothParametrized;
+
+/**
+ * Class encapsulating RProp algorithm.
+ * <p>
+ * See <a href="https://paginas.fe.up.pt/~ee02162/dissertacao/RPROP%20paper.pdf">RProp</a>.</p>
+ */
+public class RPropUpdateCalculator<M extends SmoothParametrized> implements ParameterUpdateCalculator<M, RPropParameterUpdate> {
+    /**
+     * Default initial update.
+     */
+    private static final double DFLT_INIT_UPDATE = 0.1;
+
+    /**
+     * Default acceleration rate.
+     */
+    private static final double DFLT_ACCELERATION_RATE = 1.2;
+
+    /**
+     * Default deacceleration rate.
+     */
+    private static final double DFLT_DEACCELERATION_RATE = 0.5;
+
+    /**
+     * Initial update.
+     */
+    private final double initUpdate;
+
+    /**
+     * Acceleration rate.
+     */
+    private final double accelerationRate;
+
+    /**
+     * Deacceleration rate.
+     */
+    private final double deaccelerationRate;
+
+    /**
+     * Maximal value for update.
+     */
+    private static final double UPDATE_MAX = 50.0;
+
+    /**
+     * Minimal value for update.
+     */
+    private static final double UPDATE_MIN = 1E-6;
+
+    /**
+     * Loss function.
+     */
+    protected IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;
+
+    /**
+     * Construct RPropUpdateCalculator.
+     *
+     * @param initUpdate Initial update.
+     * @param accelerationRate Acceleration rate.
+     * @param deaccelerationRate Deacceleration rate.
+     */
+    public RPropUpdateCalculator(double initUpdate, double accelerationRate, double deaccelerationRate) {
+        this.initUpdate = initUpdate;
+        this.accelerationRate = accelerationRate;
+        this.deaccelerationRate = deaccelerationRate;
+    }
+
+    /**
+     * Construct RPropUpdateCalculator with default parameters.
+     */
+    public RPropUpdateCalculator() {
+        this(DFLT_INIT_UPDATE, DFLT_ACCELERATION_RATE, DFLT_DEACCELERATION_RATE);
+    }
+
+    /** {@inheritDoc} */
+    @Override public RPropParameterUpdate calculateNewUpdate(SmoothParametrized mdl, RPropParameterUpdate updaterParams,
+        int iteration, Matrix inputs, Matrix groundTruth) {
+        Vector gradient = mdl.differentiateByParameters(loss, inputs, groundTruth);
+        Vector prevGradient = updaterParams.prevIterationGradient();
+        Vector derSigns;
+
+        if (prevGradient != null)
+            derSigns = VectorUtils.zipWith(prevGradient, gradient, (x, y) -> Math.signum(x * y));
+        else
+            derSigns = gradient.like(gradient.size()).assign(1.0);
+
+        Vector newDeltas = updaterParams.deltas().copy().map(derSigns, (prevDelta, sign) -> {
+            if (sign > 0)
+                return Math.min(prevDelta * accelerationRate, UPDATE_MAX);
+            else if (sign < 0)
+                return Math.max(prevDelta * deaccelerationRate, UPDATE_MIN);
+            else
+                return prevDelta;
+        });
+
+        Vector newPrevIterationUpdates = MatrixUtil.zipWith(gradient, updaterParams.deltas(), (der, delta, i) -> {
+            if (derSigns.getX(i) >= 0)
+                return -Math.signum(der) * delta;
+
+            return updaterParams.prevIterationUpdates().getX(i);
+        });
+
+        Vector updatesMask = MatrixUtil.zipWith(derSigns, updaterParams.prevIterationUpdates(), (sign, upd, i) -> {
+            if (sign < 0)
+                gradient.setX(i, 0.0);
+
+            if (sign >= 0)
+                return 1.0;
+            else
+                return -1.0;
+        });
+
+        return new RPropParameterUpdate(newPrevIterationUpdates, gradient.copy(), newDeltas, updatesMask);
+    }
+
+    /** {@inheritDoc} */
+    @Override public RPropParameterUpdate init(M mdl,
+        IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) {
+        this.loss = loss;
+        return new RPropParameterUpdate(mdl.parametersCount(), initUpdate);
+    }
+
+    /** {@inheritDoc} */
+    @Override public <M1 extends M> M1 update(M1 obj, RPropParameterUpdate update) {
+        Vector updatesToAdd = VectorUtils.elementWiseTimes(update.updatesMask().copy(), update.prevIterationUpdates());
+        return (M1)obj.setParameters(obj.parameters().plus(updatesToAdd));
+    }
+}
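
The per-component adaptation in calculateNewUpdate above follows the classic RProp rule. As a sketch, with sign_i = sgn(g_i(t-1) * g_i(t)):

    delta_i(t) = min(delta_i(t-1) * accelerationRate,   UPDATE_MAX)   if sign_i > 0
               = max(delta_i(t-1) * deaccelerationRate, UPDATE_MIN)   if sign_i < 0
               = delta_i(t-1)                                         otherwise

    update_i(t) = -sgn(g_i(t)) * delta_i(t-1)                         if sign_i >= 0

For sign_i < 0 the mask entry is set to -1, so that applying the masked previous update reverts the last step, and the stored gradient component is zeroed, so the next iteration falls into the sign-zero case.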

http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDParameter.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDParameter.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDParameter.java
new file mode 100644
index 0000000..22fc18a
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDParameter.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.optimization.updatecalculators;
+
+import java.io.Serializable;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
+
+/**
+ * Parameters for {@link SimpleGDUpdateCalculator}.
+ */
+public class SimpleGDParameter implements Serializable {
+    /**
+     * Gradient.
+     */
+    private Vector gradient;
+
+    /**
+     * Learning rate.
+     */
+    private double learningRate;
+
+    /**
+     * Construct instance of this class.
+     *
+     * @param paramsCnt Count of parameters.
+     * @param learningRate Learning rate.
+     */
+    public SimpleGDParameter(int paramsCnt, double learningRate) {
+        gradient = new DenseLocalOnHeapVector(paramsCnt);
+        this.learningRate = learningRate;
+    }
+
+    /**
+     * Construct instance of this class.
+     *
+     * @param gradient Gradient.
+     * @param learningRate Learning rate.
+     */
+    public SimpleGDParameter(Vector gradient, double learningRate) {
+        this.gradient = gradient;
+        this.learningRate = learningRate;
+    }
+
+    /**
+     * Get gradient.
+     *
+     * @return Gradient.
+     */
+    public Vector gradient() {
+        return gradient;
+    }
+
+    /**
+     * Get learning rate.
+     *
+     * @return Learning rate.
+     */
+    public double learningRate() {
+        return learningRate;
+    }
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDUpdateCalculator.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDUpdateCalculator.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDUpdateCalculator.java
new file mode 100644
index 0000000..291e63d
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/SimpleGDUpdateCalculator.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.optimization.updatecalculators;
+
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.optimization.SmoothParametrized;
+
+/**
+ * Simple gradient descent parameters updater.
+ */
+public class SimpleGDUpdateCalculator<M extends SmoothParametrized> implements ParameterUpdateCalculator<M, SimpleGDParameter> {
+    /**
+     * Learning rate.
+     */
+    private double learningRate;
+
+    /**
+     * Loss function.
+     */
+    protected IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;
+
+    /**
+     * Construct SimpleGDUpdateCalculator.
+     *
+     * @param learningRate Learning rate.
+     */
+    public SimpleGDUpdateCalculator(double learningRate) {
+        this.learningRate = learningRate;
+    }
+
+    /** {@inheritDoc} */
+    @Override public SimpleGDParameter init(M mdl,
+        IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) {
+        this.loss = loss;
+        return new SimpleGDParameter(mdl.parametersCount(), learningRate);
+    }
+
+    /** {@inheritDoc} */
+    @Override public SimpleGDParameter calculateNewUpdate(SmoothParametrized mlp, SimpleGDParameter updaterParameters,
+        int iteration, Matrix inputs, Matrix groundTruth) {
+        return new SimpleGDParameter(mlp.differentiateByParameters(loss, inputs, groundTruth), learningRate);
+    }
+
+    /** {@inheritDoc} */
+    @Override public <M1 extends M> M1 update(M1 obj, SimpleGDParameter update) {
+        Vector params = obj.parameters();
+        return (M1)obj.setParameters(params.minus(update.gradient().times(learningRate)));
+    }
+}
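
Taken together, calculateNewUpdate and update above implement plain gradient descent. As a sketch, with eta the learning rate passed to the constructor:

    w_{t+1} = w_t - eta * grad L(w_t)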

http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/package-info.java
new file mode 100644
index 0000000..071dc13
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/updatecalculators/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains update calculators.
+ */
+package org.apache.ignite.ml.optimization.updatecalculators;
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/main/java/org/apache/ignite/ml/optimization/util/SparseDistributedMatrixMapReducer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/util/SparseDistributedMatrixMapReducer.java b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/util/SparseDistributedMatrixMapReducer.java
index 7a5f90b..20f861e 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/optimization/util/SparseDistributedMatrixMapReducer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/optimization/util/SparseDistributedMatrixMapReducer.java
@@ -70,7 +70,7 @@ public class SparseDistributedMatrixMapReducer {
                     for (RowColMatrixKey key : locKeys) {
                         Map<Integer, Double> row = storage.cache().get(key);
 
-                        for (Map.Entry<Integer,Double> cell : row.entrySet())
+                        for (Map.Entry<Integer, Double> cell : row.entrySet())
                             locMatrix.set(idx, cell.getKey(), cell.getValue());
 
                         idx++;

http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/main/java/org/apache/ignite/ml/trainers/Trainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/Trainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/Trainer.java
index 7540d6f..5efdf57 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/Trainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/Trainer.java
@@ -21,7 +21,8 @@ import org.apache.ignite.ml.Model;
 
 /** Trainer interface. */
 public interface Trainer<M extends Model, T> {
-    /** Train the model based on provided data.
+    /**
+     * Train the model based on provided data.
      *
      * @param data Data for training.
      * @return Trained model.

http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/MetaoptimizerDistributedStep.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/MetaoptimizerDistributedStep.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/MetaoptimizerDistributedStep.java
index 67dcf7f..08e1f47 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/MetaoptimizerDistributedStep.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/MetaoptimizerDistributedStep.java
@@ -27,8 +27,17 @@ import org.apache.ignite.ml.trainers.group.chain.EntryAndContext;
 import org.apache.ignite.ml.trainers.group.chain.HasTrainingUUID;
 
 /**
- * Distributed step
- * TODO: IGNITE-7350: add full description.
+ * Distributed step based on {@link Metaoptimizer}.
+ *
+ * @param <L> Type of local context.
+ * @param <K> Type of data in {@link GroupTrainerCacheKey}.
+ * @param <V> Type of values of cache on which training is done.
+ * @param <G> Type of distributed context.
+ * @param <I> Type of data to which data returned by distributed initialization is mapped (see {@link Metaoptimizer}).
+ * @param <O> Type of data to which data returned by data processor is mapped (see {@link Metaoptimizer}).
+ * @param <X> Type of data which is processed in training loop step (see {@link Metaoptimizer}).
+ * @param <Y> Type of data returned by training loop step data processor (see {@link Metaoptimizer}).
+ * @param <D> Type of data returned by initialization (see {@link Metaoptimizer}).
  */
 class MetaoptimizerDistributedStep<L extends HasTrainingUUID, K, V, G, I extends Serializable, O extends Serializable,
     X, Y, D extends Serializable> implements DistributedEntryProcessingStep<L, K, V, G, I, O> {

http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/ComputationsChain.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/ComputationsChain.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/ComputationsChain.java
index 534b5f9..3c3bdab 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/ComputationsChain.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/group/chain/ComputationsChain.java
@@ -53,7 +53,7 @@ import org.apache.ignite.ml.trainers.group.ResultAndUpdates;
  * @param <V> Type of cache values.
  * @param <I> Type of input of this chain.
  * @param <O> Type of output of this chain.
- * // TODO: IGNITE-7350 check if it is possible to integrate with {@link EntryProcessor}.
+ * // TODO: IGNITE-7405 check if it is possible to integrate with {@link EntryProcessor}.
  */
 @FunctionalInterface
 public interface ComputationsChain<L extends HasTrainingUUID, K, V, I, O> {
@@ -229,7 +229,7 @@ public interface ComputationsChain<L extends HasTrainingUUID, K, V, I, O> {
     }
 
     /**
-     * Combine two this chain to other: feed this chain as input to other, pass same context as second argument to both chains
+     * Combine this chain with another: feed this chain as input to the other, pass the same context as the second argument to both chains'
      * process method.
      *
      * @param next Next chain.

http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainer.java
new file mode 100644
index 0000000..ab31f9f
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainer.java
@@ -0,0 +1,178 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.local;
+
+import org.apache.ignite.IgniteLogger;
+import org.apache.ignite.lang.IgniteBiTuple;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.Trainer;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.Vector;
+import org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.apache.ignite.ml.math.util.MatrixUtil;
+import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator;
+
+/**
+ * Batch trainer. This trainer is not distributed on the cluster, but its input can, in principle, read data from
+ * an Ignite cache.
+ */
+public class LocalBatchTrainer<M extends Model<Matrix, Matrix>, P>
+    implements Trainer<M, LocalBatchTrainerInput<M>> {
+    /**
+     * Supplier for updater function.
+     */
+    private final IgniteSupplier<ParameterUpdateCalculator<M, P>> updaterSupplier;
+
+    /**
+     * Error threshold.
+     */
+    private final double errorThreshold;
+
+    /**
+     * Maximal iterations count.
+     */
+    private final int maxIterations;
+
+    /**
+     * Loss function.
+     */
+    private final IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss;
+
+    /**
+     * Logger.
+     */
+    private IgniteLogger log;
+
+    /**
+     * Construct a trainer.
+     *
+     * @param loss Loss function.
+     * @param updaterSupplier Supplier of updater function.
+     * @param errorThreshold Error threshold.
+     * @param maxIterations Maximal iterations count.
+     */
+    public LocalBatchTrainer(IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss,
+        IgniteSupplier<ParameterUpdateCalculator<M, P>> updaterSupplier, double errorThreshold, int maxIterations) {
+        this.loss = loss;
+        this.updaterSupplier = updaterSupplier;
+        this.errorThreshold = errorThreshold;
+        this.maxIterations = maxIterations;
+    }
+
+    /** {@inheritDoc} */
+    @Override public M train(LocalBatchTrainerInput<M> data) {
+        int i = 0;
+        M mdl = data.mdl();
+        double err;
+
+        ParameterUpdateCalculator<M, P> updater = updaterSupplier.get();
+
+        P updaterParams = updater.init(mdl, loss);
+
+        while (i < maxIterations) {
+            IgniteBiTuple<Matrix, Matrix> batch = data.batchSupplier().get();
+            Matrix input = batch.get1();
+            Matrix truth = batch.get2();
+
+            updaterParams = updater.calculateNewUpdate(mdl, updaterParams, i, input, truth);
+
+            // Update mdl with updater parameters.
+            mdl = updater.update(mdl, updaterParams);
+
+            Matrix predicted = mdl.apply(input);
+
+            int batchSize = input.columnSize();
+
+            err = MatrixUtil.zipFoldByColumns(predicted, truth, (predCol, truthCol) ->
+                loss.apply(truthCol).apply(predCol)).sum() / batchSize;
+
+            debug("Error: " + err);
+
+            if (err < errorThreshold)
+                break;
+
+            i++;
+        }
+
+        return mdl;
+    }
+
+    /**
+     * Construct new trainer with the same parameters as this trainer, but with new loss.
+     *
+     * @param loss New loss function.
+     * @return New trainer with the same parameters as this trainer, but with new loss.
+     */
+    public LocalBatchTrainer withLoss(IgniteFunction<Vector, IgniteDifferentiableVectorToDoubleFunction> loss) {
+        return new LocalBatchTrainer<>(loss, updaterSupplier, errorThreshold, maxIterations);
+    }
+
+    /**
+     * Construct new trainer with the same parameters as this trainer, but with new updater supplier.
+     *
+     * @param updaterSupplier New updater supplier.
+     * @return New trainer with the same parameters as this trainer, but with new updater supplier.
+     */
+    public LocalBatchTrainer withUpdater(IgniteSupplier<ParameterUpdateCalculator<M, P>> updaterSupplier) {
+        return new LocalBatchTrainer<>(loss, updaterSupplier, errorThreshold, maxIterations);
+    }
+
+    /**
+     * Construct new trainer with the same parameters as this trainer, but with new error threshold.
+     *
+     * @param errorThreshold New error threshold.
+     * @return New trainer with the same parameters as this trainer, but with new error threshold.
+     */
+    public LocalBatchTrainer withErrorThreshold(double errorThreshold) {
+        return new LocalBatchTrainer<>(loss, updaterSupplier, errorThreshold, maxIterations);
+    }
+
+    /**
+     * Construct new trainer with the same parameters as this trainer, but with new maximal iterations count.
+     *
+     * @param maxIterations New maximal iterations count.
+     * @return New trainer with the same parameters as this trainer, but with new maximal iterations count.
+     */
+    public LocalBatchTrainer withMaxIterations(int maxIterations) {
+        return new LocalBatchTrainer<>(loss, updaterSupplier, errorThreshold, maxIterations);
+    }
+
+    /**
+     * Set logger.
+     *
+     * @param log Logger.
+     * @return This object.
+     */
+    public LocalBatchTrainer setLogger(IgniteLogger log) {
+        this.log = log;
+
+        return this;
+    }
+
+    /**
+     * Output debug message.
+     *
+     * @param msg Message.
+     */
+    private void debug(String msg) {
+        if (log != null && log.isDebugEnabled())
+            log.debug(msg);
+    }
+}
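
An end-to-end sketch of wiring this trainer (hedged: the MSE loss constant, the MultilayerPerceptron model and the input argument are stand-ins for illustration, not part of this diff):

    import org.apache.ignite.ml.nn.MultilayerPerceptron;
    import org.apache.ignite.ml.optimization.LossFunctions;
    import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameter;
    import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
    import org.apache.ignite.ml.trainers.local.LocalBatchTrainer;
    import org.apache.ignite.ml.trainers.local.LocalBatchTrainerInput;

    class LocalTrainingSketch {
        static MultilayerPerceptron fit(LocalBatchTrainerInput<MultilayerPerceptron> input) {
            LocalBatchTrainer<MultilayerPerceptron, SimpleGDParameter> trainer =
                new LocalBatchTrainer<>(
                    LossFunctions.MSE,                             // loss function (assumed available in this module)
                    () -> new SimpleGDUpdateCalculator<MultilayerPerceptron>(0.1), // updater supplier, learning rate 0.1
                    1e-4,                                          // stop once the batch error falls below this...
                    100);                                          // ...or after at most 100 iterations

            return trainer.train(input);
        }
    }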

http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainerInput.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainerInput.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainerInput.java
new file mode 100644
index 0000000..38b7592
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/LocalBatchTrainerInput.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.local;
+
+import org.apache.ignite.lang.IgniteBiTuple;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.math.Matrix;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+
+/**
+ * Interface for classes containing input parameters for {@link LocalBatchTrainer}.
+ */
+public interface LocalBatchTrainerInput<M extends Model<Matrix, Matrix>> {
+    /**
+     * Get supplier of next batch in the form of a matrix of inputs and a matrix of outputs.
+     *
+     * @return Supplier of next batch.
+     */
+    IgniteSupplier<IgniteBiTuple<Matrix, Matrix>> batchSupplier();
+
+    /**
+     * Model to train.
+     *
+     * @return Model to train.
+     */
+    M mdl();
+}
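
A minimal implementation sketch of this interface (illustrative only): an input that serves the same full batch on every call, turning LocalBatchTrainer into plain full-batch training.

    import org.apache.ignite.lang.IgniteBiTuple;
    import org.apache.ignite.ml.math.Matrix;
    import org.apache.ignite.ml.math.functions.IgniteSupplier;
    import org.apache.ignite.ml.nn.MultilayerPerceptron;
    import org.apache.ignite.ml.trainers.local.LocalBatchTrainerInput;

    class FixedBatchInput implements LocalBatchTrainerInput<MultilayerPerceptron> {
        private final Matrix inputs;
        private final Matrix groundTruth;
        private final MultilayerPerceptron mlp;

        FixedBatchInput(Matrix inputs, Matrix groundTruth, MultilayerPerceptron mlp) {
            this.inputs = inputs;
            this.groundTruth = groundTruth;
            this.mlp = mlp;
        }

        /** Always serve the same (inputs, ground truth) pair. */
        @Override public IgniteSupplier<IgniteBiTuple<Matrix, Matrix>> batchSupplier() {
            return () -> new IgniteBiTuple<>(inputs, groundTruth);
        }

        /** Initial model to train. */
        @Override public MultilayerPerceptron mdl() {
            return mlp;
        }
    }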

http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/package-info.java
new file mode 100644
index 0000000..8a15b73
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/local/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains local trainers.
+ */
+package org.apache.ignite.ml.trainers.local;
\ No newline at end of file

http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
index 4472300..206e1e9 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
@@ -60,21 +60,22 @@ public class Utils {
     }
 
     /**
-     * Select k distinct integers from range [0, n) with reservoir sampling: https://en.wikipedia.org/wiki/Reservoir_sampling.
+     * Select k distinct integers from range [0, n) with reservoir sampling:
+     * https://en.wikipedia.org/wiki/Reservoir_sampling.
      *
      * @param n Number specifying exclusive right end of range of integers to pick values from.
      * @param k Count specifying how many integers should be picked.
+     * @param rand RNG.
      * @return Array containing k distinct integers from range [0, n).
      */
-    public static int[] selectKDistinct(int n, int k) {
+    public static int[] selectKDistinct(int n, int k, Random rand) {
         int i;
+        Random r = rand != null ? rand : new Random();
 
         int res[] = new int[k];
         for (i = 0; i < k; i++)
             res[i] = i;
 
-        Random r = new Random();
-
         for (; i < n; i++) {
             int j = r.nextInt(i + 1);
 
@@ -84,4 +85,17 @@ public class Utils {
 
         return res;
     }
+
+    /**
+     * Select k distinct integers from range [0, n) with reservoir sampling:
+     * https://en.wikipedia.org/wiki/Reservoir_sampling.
+     * Equivalent to {@code selectKDistinct(n, k, new Random())}.
+     *
+     * @param n Number specifying exclusive right end of range of integers to pick values from.
+     * @param k Count specifying how many integers should be picked.
+     * @return Array containing k distinct integers from range [0, n).
+     */
+    public static int[] selectKDistinct(int n, int k) {
+        return selectKDistinct(n, k, new Random());
+    }
 }
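
A usage sketch of the new overload (the seed value is arbitrary): passing an explicit Random makes the sampling reproducible, which is what the test changes below rely on when they switch to new Random(123L).

    import java.util.Arrays;
    import java.util.Random;
    import org.apache.ignite.ml.util.Utils;

    class ReservoirSamplingSketch {
        public static void main(String[] args) {
            // Same seed => same 10-element sample from [0, 100).
            int[] a = Utils.selectKDistinct(100, 10, new Random(42));
            int[] b = Utils.selectKDistinct(100, 10, new Random(42));

            System.out.println(Arrays.equals(a, b)); // true
        }
    }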

http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
index 35ffdbc..d5d6d94 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
@@ -21,6 +21,7 @@ import org.apache.ignite.ml.clustering.ClusteringTestSuite;
 import org.apache.ignite.ml.knn.KNNTestSuite;
 import org.apache.ignite.ml.math.MathImplMainTestSuite;
 import org.apache.ignite.ml.nn.MLPTestSuite;
+import org.apache.ignite.ml.optimization.OptimizationTestSuite;
 import org.apache.ignite.ml.regressions.RegressionsTestSuite;
 import org.apache.ignite.ml.trainers.group.TrainersGroupTestSuite;
 import org.apache.ignite.ml.trees.DecisionTreesTestSuite;
@@ -39,7 +40,8 @@ import org.junit.runners.Suite;
     KNNTestSuite.class,
     LocalModelsTest.class,
     MLPTestSuite.class,
-    TrainersGroupTestSuite.class
+    TrainersGroupTestSuite.class,
+    OptimizationTestSuite.class
 })
 public class IgniteMLTestSuite {
     // No-op.

http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPGroupTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPGroupTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPGroupTrainerTest.java
index 7f990c9..151fead 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPGroupTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPGroupTrainerTest.java
@@ -31,7 +31,7 @@ import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
 import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
 import org.apache.ignite.ml.nn.initializers.RandomInitializer;
 import org.apache.ignite.ml.nn.trainers.distributed.MLPGroupUpdateTrainer;
-import org.apache.ignite.ml.nn.updaters.RPropParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
 import org.apache.ignite.ml.structures.LabeledVector;
 import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
 
@@ -93,7 +93,7 @@ public class MLPGroupTrainerTest extends GridCommonAbstractTest {
             }
         }
 
-        int totalCnt = 100;
+        int totalCnt = 20;
         int failCnt = 0;
         double maxFailRatio = 0.3;
         MLPGroupUpdateTrainer<RPropParameterUpdate> trainer = MLPGroupUpdateTrainer.getDefault(ignite).
@@ -104,7 +104,7 @@ public class MLPGroupTrainerTest extends GridCommonAbstractTest {
         for (int i = 0; i < totalCnt; i++) {
 
             MLPGroupUpdateTrainerCacheInput trainerInput = new MLPGroupUpdateTrainerCacheInput(conf,
-                new RandomInitializer(rnd), 6, cache, 4);
+                new RandomInitializer(new Random(123L)), 6, cache, 4, new Random(123L));
 
             MultilayerPerceptron mlp = trainer.train(trainerInput);
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java
index e659e16..b4c14e1 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPLocalTrainerTest.java
@@ -27,10 +27,11 @@ import org.apache.ignite.ml.math.functions.IgniteSupplier;
 import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
 import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
 import org.apache.ignite.ml.nn.trainers.local.MLPLocalBatchTrainer;
-import org.apache.ignite.ml.nn.updaters.NesterovUpdateCalculator;
-import org.apache.ignite.ml.nn.updaters.ParameterUpdateCalculator;
-import org.apache.ignite.ml.nn.updaters.RPropUpdateCalculator;
-import org.apache.ignite.ml.nn.updaters.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.optimization.LossFunctions;
+import org.apache.ignite.ml.optimization.updatecalculators.NesterovUpdateCalculator;
+import org.apache.ignite.ml.optimization.updatecalculators.ParameterUpdateCalculator;
+import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
 import org.junit.Test;
 
 /**

http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java
index d757fcb..555abce 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/MLPTest.java
@@ -25,6 +25,7 @@ import org.apache.ignite.ml.math.functions.IgniteTriFunction;
 import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
 import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
 import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
+import org.apache.ignite.ml.optimization.LossFunctions;
 import org.junit.Assert;
 import org.junit.Test;
 

http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/test/java/org/apache/ignite/ml/nn/SimpleMLPLocalBatchTrainerInput.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/SimpleMLPLocalBatchTrainerInput.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/SimpleMLPLocalBatchTrainerInput.java
index 07a9e74..8bc0a6d 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/SimpleMLPLocalBatchTrainerInput.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/SimpleMLPLocalBatchTrainerInput.java
@@ -24,6 +24,7 @@ import org.apache.ignite.ml.math.functions.IgniteSupplier;
 import org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix;
 import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
 import org.apache.ignite.ml.nn.initializers.RandomInitializer;
+import org.apache.ignite.ml.trainers.local.LocalBatchTrainerInput;
 import org.apache.ignite.ml.util.Utils;
 
 /**

http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistDistributed.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistDistributed.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistDistributed.java
index d9e4060..112aade 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistDistributed.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistDistributed.java
@@ -37,7 +37,7 @@ import org.apache.ignite.ml.nn.MLPGroupUpdateTrainerCacheInput;
 import org.apache.ignite.ml.nn.MultilayerPerceptron;
 import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
 import org.apache.ignite.ml.nn.trainers.distributed.MLPGroupUpdateTrainer;
-import org.apache.ignite.ml.nn.updaters.RPropParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.RPropParameterUpdate;
 import org.apache.ignite.ml.structures.LabeledVector;
 import org.apache.ignite.ml.util.MnistUtils;
 import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;

http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistLocal.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistLocal.java b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistLocal.java
index eab5288..cda0413a 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistLocal.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/nn/performance/MnistLocal.java
@@ -28,12 +28,12 @@ import org.apache.ignite.ml.math.Vector;
 import org.apache.ignite.ml.math.VectorUtils;
 import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
 import org.apache.ignite.ml.nn.Activators;
-import org.apache.ignite.ml.nn.LossFunctions;
+import org.apache.ignite.ml.optimization.LossFunctions;
 import org.apache.ignite.ml.nn.MultilayerPerceptron;
 import org.apache.ignite.ml.nn.SimpleMLPLocalBatchTrainerInput;
 import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
 import org.apache.ignite.ml.nn.trainers.local.MLPLocalBatchTrainer;
-import org.apache.ignite.ml.nn.updaters.RPropUpdateCalculator;
+import org.apache.ignite.ml.optimization.updatecalculators.RPropUpdateCalculator;
 import org.junit.Test;
 
 import static org.apache.ignite.ml.nn.performance.MnistMLPTestUtil.createDataset;

http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/test/java/org/apache/ignite/ml/optimization/OptimizationTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/optimization/OptimizationTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/optimization/OptimizationTestSuite.java
new file mode 100644
index 0000000..0ae6e4c
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/optimization/OptimizationTestSuite.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.optimization;
+
+import org.apache.ignite.ml.optimization.util.SparseDistributedMatrixMapReducerTest;
+import org.junit.runner.RunWith;
+import org.junit.runners.Suite;
+
+/**
+ * Test suite for optimization package tests.
+ */
+@RunWith(Suite.class)
+@Suite.SuiteClasses({
+    GradientDescentTest.class,
+    SparseDistributedMatrixMapReducerTest.class
+})
+public class OptimizationTestSuite {
+}
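The new suite only aggregates the existing optimization tests and carries no logic of its own. Like any JUnit 4 suite it can be run from an IDE, from surefire, or programmatically; a sketch of the programmatic route (the runner class below is illustrative, not part of the commit):

    import org.junit.runner.JUnitCore;
    import org.junit.runner.Result;

    /** Hypothetical helper, not part of the commit: runs the suite and reports failures. */
    public class RunOptimizationSuite {
        public static void main(String[] args) {
            // JUnitCore executes the classes listed in @Suite.SuiteClasses above.
            Result res = JUnitCore.runClasses(
                org.apache.ignite.ml.optimization.OptimizationTestSuite.class);
            System.out.println(res.wasSuccessful() ? "OK" : res.getFailureCount() + " failure(s)");
        }
    }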

http://git-wip-us.apache.org/repos/asf/ignite/blob/6eccf230/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainer.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainer.java
index d5b4ede..0a49fe0 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainer.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/group/TestGroupTrainer.java
@@ -83,7 +83,7 @@ public class TestGroupTrainer extends GroupTrainer<TestGroupTrainerLocalContext,
     /** {@inheritDoc} */
     @Override protected ComputationsChain<TestGroupTrainerLocalContext,
         Double, Integer, Double, Double> trainingLoopStep() {
-        // TODO:IGNITE-7350 here we should explicitly create variable because we cannot infer context type, think about it.
+        // TODO:IGNITE-7405 here we should explicitly create variable because we cannot infer context type, think about it.
         ComputationsChain<TestGroupTrainerLocalContext, Double, Integer, Double, Double> chain = Chains.
             create(new TestTrainingLoopStep());
         return chain.

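The retagged TODO above points at a general Java limitation rather than an Ignite bug: when a generic factory's type parameters do not all occur in its arguments, they can only be pinned by a target type, and a target type is not propagated through a chained call on the factory's result. A minimal, self-contained illustration with toy types (none of this is Ignite code; Chain merely plays the role of ComputationsChain, with L as the local context type):

    import java.util.function.BiFunction;

    /** Toy stand-in for a computations chain; L is the "local context" type. */
    class Chain<L, I, O> {
        static <L, I, O> Chain<L, I, O> create(BiFunction<L, I, O> step) { return new Chain<>(); }
        <O2> Chain<L, I, O2> then(BiFunction<L, O, O2> step) { return new Chain<>(); }
    }

    class InferenceDemo {
        static Chain<String, Integer, Double> build() {
            // Does not compile: the receiver of a chained call is typed
            // bottom-up, so L and I resolve to Object, and the resulting
            // Chain<Object, Object, Double> cannot be returned as
            // Chain<String, Integer, Double>:
            //   return Chain.create((ctx, i) -> 0.0).then((ctx, d) -> d);

            // An explicitly typed local supplies the target type and pins L,
            // which is exactly the workaround the TODO describes.
            Chain<String, Integer, Double> chain = Chain.create((ctx, i) -> 0.0);
            return chain.then((ctx, d) -> d);
        }
    }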
