mxnet-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] piiswrong closed pull request #10300: Fix list output names typo
Date Wed, 28 Mar 2018 21:58:49 GMT
piiswrong closed pull request #10300: Fix list output names typo
URL: https://github.com/apache/incubator-mxnet/pull/10300
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc
index 79f2e800ee1..7aafe9d82f7 100644
--- a/src/operator/nn/convolution.cc
+++ b/src/operator/nn/convolution.cc
@@ -468,7 +468,7 @@ There are other options to tune the performance.
   else
     return std::vector<std::string>{"data", "weight", "bias"};
 })
-.set_attr<nnvm::FListInputNames>("FListOutputNames",
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
     [](const NodeAttrs& attrs) {
     return std::vector<std::string>{"output"};
 })
diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc
index 475b6362516..278c130064f 100644
--- a/src/operator/nn/fully_connected.cc
+++ b/src/operator/nn/fully_connected.cc
@@ -267,7 +267,7 @@ This could be used for model inference with `row_sparse` weights trained with `S
     return std::vector<std::string>{"data", "weight"};
   }
 })
-.set_attr<nnvm::FListInputNames>("FListOutputNames",
+.set_attr<nnvm::FListOutputNames>("FListOutputNames",
     [](const NodeAttrs& attrs) {
     return std::vector<std::string>{"output"};
 })
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 2486be04a52..20cc4b511cc 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -24,6 +24,7 @@
 import itertools
 from numpy.testing import assert_allclose, assert_array_equal
 from mxnet.test_utils import *
+from mxnet.base import py_str
 from common import setup_module, with_seed
 import unittest
 
@@ -5399,6 +5400,30 @@ def f(x, a, b, c):
         check_numeric_gradient(quad_sym, [data_np], atol=0.001)
 
 
+def test_op_output_names_monitor():
+    def check_name(op_sym, expected_names):
+        output_names = []
+
+        def get_output_names_callback(name, arr):
+            output_names.append(py_str(name))
+
+        op_exe = op_sym.simple_bind(ctx=mx.current_context(), grad_req='null')
+        op_exe.set_monitor_callback(get_output_names_callback)
+        op_exe.forward()
+        for output_name, expected_name in zip(output_names, expected_names):
+            assert output_name == expected_name
+
+    data = mx.sym.Variable('data', shape=(10, 3, 10, 10))
+    conv_sym = mx.sym.Convolution(data, kernel=(2, 2), num_filter=1, name='conv')
+    check_name(conv_sym, ['conv_output'])
+
+    fc_sym = mx.sym.FullyConnected(data, num_hidden=10, name='fc')
+    check_name(fc_sym, ['fc_output'])
+
+    lrn_sym = mx.sym.LRN(data, nsize=1, name='lrn')
+    check_name(lrn_sym, ['lrn_output', 'lrn_tmp_norm'])
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Mime
View raw message