This is an automated email from the ASF dual-hosted git repository.
A commit was pushed to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new e42f7e0 [MXNET-173]fix acc metric shape miss match (#10446)
e42f7e0 is described below
commit e42f7e08bb5d8277bd87f81babf42cb5121ec160
Author: Lai Wei <royweilai@gmail.com>
AuthorDate: Tue Apr 10 17:27:36 2018 -0700
[MXNET-173]fix acc metric shape miss match (#10446)
* fix acc metric shape miss match
* add unit test
* fix style
* fix python2 division
---
python/mxnet/metric.py | 7 +++++--
tests/python/unittest/test_metric.py | 11 +++++++++++
2 files changed, 16 insertions(+), 2 deletions(-)
diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py
index d8facbc..152cbc1 100644
--- a/python/mxnet/metric.py
+++ b/python/mxnet/metric.py
@@ -417,11 +417,14 @@ class Accuracy(EvalMetric):
pred_label = ndarray.argmax(pred_label, axis=self.axis)
pred_label = pred_label.asnumpy().astype('int32')
label = label.asnumpy().astype('int32')
+ # flatten before checking shapes to avoid shape miss match
+ label = label.flat
+ pred_label = pred_label.flat
labels, preds = check_label_shapes(label, pred_label)
- self.sum_metric += (pred_label.flat == label.flat).sum()
- self.num_inst += len(pred_label.flat)
+ self.sum_metric += (pred_label == label).sum()
+ self.num_inst += len(pred_label)
@register
diff --git a/tests/python/unittest/test_metric.py b/tests/python/unittest/test_metric.py
index bcb0e2d..1571a0b 100644
--- a/tests/python/unittest/test_metric.py
+++ b/tests/python/unittest/test_metric.py
@@ -54,6 +54,17 @@ def test_acc():
expected_acc = (np.argmax(pred, axis=1) == label).sum().asscalar() / label.size
assert acc == expected_acc
+def test_acc_2d_label():
+ # label maybe provided in 2d arrays in custom data iterator
+ pred = mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6], [0.8, 0.2], [0.3, 0.5], [0.6, 0.4]])
+ label = mx.nd.array([[0, 1, 1], [1, 0, 1]])
+ metric = mx.metric.create('acc')
+ metric.update([label], [pred])
+ _, acc = metric.get()
+ expected_acc = (np.argmax(pred, axis=1).asnumpy() == label.asnumpy().ravel()).sum() / \
+ float(label.asnumpy().ravel().size)
+ assert acc == expected_acc
+
def test_f1():
microF1 = mx.metric.create("f1", average="micro")
macroF1 = mx.metric.F1(average="macro")
--
To stop receiving notification emails like this one, please contact
the@apache.org.
|