mxnet-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From t..@apache.org
Subject [incubator-mxnet] branch master updated: [MXNET-173]fix acc metric shape miss match (#10446)
Date Wed, 11 Apr 2018 00:27:43 GMT
This is an automated email from the ASF dual-hosted git repository.

A committer pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new e42f7e0  [MXNET-173]fix acc metric shape miss match (#10446)
e42f7e0 is described below

commit e42f7e08bb5d8277bd87f81babf42cb5121ec160
Author: Lai Wei <royweilai@gmail.com>
AuthorDate: Tue Apr 10 17:27:36 2018 -0700

    [MXNET-173]fix acc metric shape miss match (#10446)
    
    * fix acc metric shape miss match
    
    * add unit test
    
    * fix style
    
    * fix python2 division
---
 python/mxnet/metric.py               |  7 +++++--
 tests/python/unittest/test_metric.py | 11 +++++++++++
 2 files changed, 16 insertions(+), 2 deletions(-)

diff --git a/python/mxnet/metric.py b/python/mxnet/metric.py
index d8facbc..152cbc1 100644
--- a/python/mxnet/metric.py
+++ b/python/mxnet/metric.py
@@ -417,11 +417,14 @@ class Accuracy(EvalMetric):
                 pred_label = ndarray.argmax(pred_label, axis=self.axis)
             pred_label = pred_label.asnumpy().astype('int32')
             label = label.asnumpy().astype('int32')
+            # flatten before checking shapes to avoid shape miss match
+            label = label.flat
+            pred_label = pred_label.flat
 
             labels, preds = check_label_shapes(label, pred_label)
 
-            self.sum_metric += (pred_label.flat == label.flat).sum()
-            self.num_inst += len(pred_label.flat)
+            self.sum_metric += (pred_label == label).sum()
+            self.num_inst += len(pred_label)
 
 
 @register
diff --git a/tests/python/unittest/test_metric.py b/tests/python/unittest/test_metric.py
index bcb0e2d..1571a0b 100644
--- a/tests/python/unittest/test_metric.py
+++ b/tests/python/unittest/test_metric.py
@@ -54,6 +54,17 @@ def test_acc():
     expected_acc = (np.argmax(pred, axis=1) == label).sum().asscalar() / label.size
     assert acc == expected_acc
 
+def test_acc_2d_label():
+    # label maybe provided in 2d arrays in custom data iterator
+    pred = mx.nd.array([[0.3, 0.7], [0, 1.], [0.4, 0.6], [0.8, 0.2], [0.3, 0.5], [0.6, 0.4]])
+    label = mx.nd.array([[0, 1, 1], [1, 0, 1]])
+    metric = mx.metric.create('acc')
+    metric.update([label], [pred])
+    _, acc = metric.get()
+    expected_acc = (np.argmax(pred, axis=1).asnumpy() == label.asnumpy().ravel()).sum() /
\
+                   float(label.asnumpy().ravel().size)
+    assert acc == expected_acc
+
 def test_f1():
     microF1 = mx.metric.create("f1", average="micro")
     macroF1 = mx.metric.F1(average="macro")

-- 
To stop receiving notification emails like this one, please contact
the@apache.org.

Mime
View raw message