This is an automated email from the ASF dual-hosted git repository.
haibin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 9ccd647 standard sgd_update (#10614)
9ccd647 is described below
commit 9ccd64787a05c6c04466ecfb0763b70ee8fbc988
Author: Ziyue Huang <zyhuang94@gmail.com>
AuthorDate: Fri Apr 20 08:03:02 2018 +0800
standard sgd_update (#10614)
---
 python/mxnet/optimizer.py               |  2 +-
 src/operator/optimizer_op-inl.h         | 15 +++++++++++++--
 tests/python/unittest/test_optimizer.py |  2 +-
 3 files changed, 15 insertions(+), 4 deletions(-)
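For readers skimming the diff: with lazy_update=False, sgd_update now applies the standard update, in which weight decay touches every row of the weight, while the default lazy path decays only rows present in the row-sparse gradient. A minimal sketch of the two modes in NumPy (illustrative only, not the MXNet kernels; clip_gradient and rescale_grad are omitted):

    import numpy as np

    def sgd_update_sketch(weight, grad_rows, grad_vals, lr, wd, lazy_update=True):
        # grad_rows / grad_vals encode the non-zero rows of a row-sparse gradient.
        if lazy_update:
            # Lazy: decay and gradient step touch only the non-zero rows.
            weight[grad_rows] -= lr * (grad_vals + wd * weight[grad_rows])
        else:
            # Standard: decay every row, then step with wd folded to zero,
            # mirroring the kernel change in optimizer_op-inl.h below.
            weight *= 1 - lr * wd
            weight[grad_rows] -= lr * grad_vals
        return weight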
diff --git a/python/mxnet/optimizer.py b/python/mxnet/optimizer.py
index 6589e77..2f7c51b 100644
--- a/python/mxnet/optimizer.py
+++ b/python/mxnet/optimizer.py
@@ -516,7 +516,7 @@ class SGD(Optimizer):
                 sgd_mom_update(weight, grad, state, out=weight,
                                lr=lr, wd=wd, **kwargs)
             else:
-                sgd_update(weight, grad, out=weight,
+                sgd_update(weight, grad, out=weight, lazy_update=self.lazy_update,
                            lr=lr, wd=wd, **kwargs)
         else:
             if state[0] is not None:
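A hedged usage sketch of how the flag reaches this code path (it assumes the SGD constructor's lazy_update argument, which self.lazy_update above is set from):

    import mxnet as mx

    # momentum == 0.0 routes updates through sgd_update; with
    # lazy_update=False the standard update is applied even when the
    # gradient is row_sparse.
    opt = mx.optimizer.SGD(learning_rate=0.1, wd=0.01, momentum=0.0,
                           lazy_update=False)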
diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h
index 3b6bd57..dfc7bef 100644
--- a/src/operator/optimizer_op-inl.h
+++ b/src/operator/optimizer_op-inl.h
@@ -47,6 +47,7 @@ struct SGDParam : public dmlc::Parameter<SGDParam> {
   float wd;
   float rescale_grad;
   float clip_gradient;
+  bool lazy_update;
   DMLC_DECLARE_PARAMETER(SGDParam) {
     DMLC_DECLARE_FIELD(lr)
     .describe("Learning rate");
@@ -63,6 +64,9 @@ struct SGDParam : public dmlc::Parameter<SGDParam> {
     .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] "
               "If clip_gradient <= 0, gradient clipping is turned off. "
               "grad = max(min(grad, clip_gradient), -clip_gradient).");
+    DMLC_DECLARE_FIELD(lazy_update)
+    .set_default(true)
+    .describe("If true, lazy updates are applied if gradient's stype is row_sparse.");
   }
 };
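At the operator level the new field becomes a sgd_update keyword; a hedged example call (dense weight, row-sparse gradient, the case handled by SGDUpdateDnsRspImpl below):

    import mxnet as mx

    weight = mx.nd.ones((3, 4))
    grad = mx.nd.zeros((3, 4)).tostype('row_sparse')
    # With lazy_update=False every row is decayed by (1 - lr * wd) even
    # though this gradient has no non-zero rows.
    mx.nd.sgd_update(weight, grad, out=weight, lr=0.1, wd=0.01,
                     lazy_update=False)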
@@ -177,7 +181,7 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param,
   Stream<xpu>* s = ctx.get_stream<xpu>();
   CHECK_EQ(grad.storage_type(), kRowSparseStorage);
   // if gradients are zeros, no weights are updated
-  if (!grad.storage_initialized() || req == kNullOp) return;
+  if (req == kNullOp) return;
   CHECK_EQ(req, kWriteInplace) << "kWriteInplace is expected for sparse sgd_mom_update";
   CHECK_GT(weight.shape_.Size(), 0);
@@ -185,6 +189,13 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param,
     MSHADOW_IDX_TYPE_SWITCH(grad.aux_type(rowsparse::kIdx), IType, {
       MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
         DType* weight_data = weight.dptr<DType>();
+        float wd = param.wd;
+        if (!param.lazy_update) {
+          Kernel<op_with_req<mshadow_op::mul, req_type>, xpu>::Launch(s, weight.Size(),
+            weight_data, weight_data, static_cast<DType>(1 - param.lr * param.wd));
+          wd = 0;
+        }
+        if (!grad.storage_initialized()) return;
         const IType* grad_idx = grad.aux_data(rowsparse::kIdx).dptr<IType>();
         const DType* grad_val = grad.data().dptr<DType>();
         const nnvm::dim_t num_rows = grad.aux_shape(rowsparse::kIdx)[0];
@@ -196,7 +207,7 @@ inline void SGDUpdateDnsRspImpl(const SGDParam& param,
       Kernel<SGDDnsRspKernel<req_type, xpu>, xpu>::Launch(s, num_threads, row_length,
         out->dptr<DType>(), weight_data, grad_idx, grad_val,
         static_cast<DType>(param.clip_gradient),
-        static_cast<DType>(param.lr), static_cast<DType>(param.wd),
+        static_cast<DType>(param.lr), static_cast<DType>(wd),
         static_cast<DType>(param.rescale_grad));
       });
     });
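The pre-scaling above works because the standard update factors; per row, with gradient g:

    w_new = w - lr * (g + wd * w)
          = (1 - lr * wd) * w - lr * g

So multiplying every row by (1 - lr * wd) and then running the sparse kernel with wd folded to 0 reproduces the standard update on rows that appear in the gradient and pure weight decay on all other rows, which is why the early return on an uninitialized gradient moves to after the decay.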
diff --git a/tests/python/unittest/test_optimizer.py b/tests/python/unittest/test_optimizer.py
index bbd7845..d1dc31a 100644
--- a/tests/python/unittest/test_optimizer.py
+++ b/tests/python/unittest/test_optimizer.py
@@ -344,7 +344,7 @@ def test_std_sparse_sgd():
     opt1 = PySGD
     opt2 = mx.optimizer.SGD
     shape = (3, 4, 5)
-    mom_options = [{'momentum': 0.9}]
+    mom_options = [{'momentum': 0.0}, {'momentum': 0.9}]
     cg_options = [{}, {'clip_gradient': 0.4}, {'clip_gradient': 0.5}]
     rg_options = [{}, {'rescale_grad': 0.14}, {'rescale_grad': 0.8}]
     wd_options = [{}, {'wd': 0.03}, {'wd': 0.05}, {'wd': 0.07}]
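With the grid widened, momentum 0.0 now exercises the sgd_update path changed above, while 0.9 keeps covering sgd_mom_update. A hedged sketch of how one combination is consumed, assuming the file's existing compare_optimizer helper and numpy imported as np:

    # One combination from the expanded grid (illustrative values):
    kwarg = {'momentum': 0.0, 'clip_gradient': 0.4, 'rescale_grad': 0.14,
             'wd': 0.03}
    compare_optimizer(opt1(**kwarg), opt2(**kwarg), shape, np.float32,
                      g_stype='row_sparse')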
--