mxnet-commits mailing list archives

From j...@apache.org
Subject [incubator-mxnet] branch master updated: add export to gluon (#8212)
Date Fri, 13 Oct 2017 17:34:19 GMT
This is an automated email from the ASF dual-hosted git repository.

jxie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new ddc6c39  add export to gluon (#8212)
ddc6c39 is described below

commit ddc6c39ccb9ca692bb445f0cd9ca66d0d1b08218
Author: Eric Junyuan Xie <piiswrong@users.noreply.github.com>
AuthorDate: Fri Oct 13 10:34:16 2017 -0700

    add export to gluon (#8212)
    
    * add export
    
    * fix
    
    * add test
    
    * fix nnvm
    
    * fix
---
 nnvm                                |  2 +-
 python/mxnet/gluon/block.py         | 91 ++++++++++++++++++++++++++-----------
 python/mxnet/gluon/parameter.py     |  4 +-
 src/imperative/cached_op.cc         | 14 +++++-
 tests/python/unittest/test_gluon.py | 37 +++++++++++++++
 5 files changed, 119 insertions(+), 29 deletions(-)

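In short: this commit adds `HybridBlock.export` for saving a hybridized Gluon
model as a symbol/params pair, renames traced inputs to the conventional
`data`/`data0`, `data1`, ... names, teaches `ParameterDict.load` to strip the
`arg:`/`aux:` prefixes that export writes, and makes `CachedOp` validate its
inputs. A minimal sketch of the intended round trip, mirroring the test added
below (the prefix, input shape, and file names are illustrative):

    import mxnet as mx
    from mxnet import gluon

    net = gluon.model_zoo.vision.resnet18_v1(prefix='resnet', pretrained=True)
    net.hybridize()
    net(mx.nd.ones((1, 3, 224, 224)))   # run forward once so the graph is cached

    net.export('resnet')                # writes resnet-symbol.json, resnet-0000.params

    # The exported pair loads into the symbolic Module API (or the C++ frontend).
    mod = mx.mod.Module.load('resnet', 0, label_names=None)
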
diff --git a/nnvm b/nnvm
index 65a1a71..c86afa8 160000
--- a/nnvm
+++ b/nnvm
@@ -1 +1 @@
-Subproject commit 65a1a7104f8dca986c57765012555172239b31b1
+Subproject commit c86afa8f17a44bcd4e6eec41cd49ba87e4f7a635
diff --git a/python/mxnet/gluon/block.py b/python/mxnet/gluon/block.py
index def5d14..fb4ac85 100644
--- a/python/mxnet/gluon/block.py
+++ b/python/mxnet/gluon/block.py
@@ -340,23 +340,13 @@ class HybridBlock(Block):
                 "Block construction instead."
             self._reg_params[name] = value
 
-    def register_child(self, block):
-        if not isinstance(block, HybridBlock):
-            raise ValueError(
-                "Children of HybridBlock must also be HybridBlock, " \
-                "but %s has type %s. If you are using Sequential, " \
-                "please try HybridSequential instead"%(
-                    str(block), str(type(block))))
-        super(HybridBlock, self).register_child(block)
-
-    def hybridize(self, active=True):
-        self._active = active
-        super(HybridBlock, self).hybridize(active)
-
     def _get_graph(self, *args):
         if not self._cached_graph:
             args, self._in_format = _flatten(args)
-            inputs = [symbol.var('input_%d'%i) for i in range(len(args))]
+            if len(args) > 1:
+                inputs = [symbol.var('data%d'%i) for i in range(len(args))]
+            else:
+                inputs = [symbol.var('data')]
             grouped_inputs = _regroup(inputs, self._in_format)[0]
 
             params = {i: j.var() for i, j in self._reg_params.items()}
@@ -368,18 +358,6 @@ class HybridBlock(Block):
 
         return self._cached_graph
 
-    def infer_shape(self, *args):
-        """Infers shape of Parameters from inputs."""
-        inputs, out = self._get_graph(*args)
-        args, _ = _flatten(args)
-        arg_shapes, _, aux_shapes = out.infer_shape(
-            **{i.name: j.shape for i, j in zip(inputs, args)})
-        sdict = {i: j for i, j in zip(out.list_arguments(), arg_shapes)}
-        sdict.update({name : shape for name, shape in \
-                      zip(out.list_auxiliary_states(), aux_shapes)})
-        for i in self.collect_params().values():
-            i.shape = sdict[i.name]
-
     def _build_cache(self, *args):
         inputs, out = self._get_graph(*args)
         self._cached_op = ndarray.CachedOp(out)
@@ -415,6 +393,67 @@ class HybridBlock(Block):
             out = [out]
         return _regroup(out, self._out_format)[0]
 
+    def _clear_cached_op(self):
+        self._cached_graph = ()
+        self._cached_op = None
+
+    def register_child(self, block):
+        if not isinstance(block, HybridBlock):
+            raise ValueError(
+                "Children of HybridBlock must also be HybridBlock, " \
+                "but %s has type %s. If you are using Sequential, " \
+                "please try HybridSequential instead"%(
+                    str(block), str(type(block))))
+        super(HybridBlock, self).register_child(block)
+        self._clear_cached_op()
+
+    def hybridize(self, active=True):
+        self._active = active
+        super(HybridBlock, self).hybridize(active)
+
+    def infer_shape(self, *args):
+        """Infers shape of Parameters from inputs."""
+        inputs, out = self._get_graph(*args)
+        args, _ = _flatten(args)
+        arg_shapes, _, aux_shapes = out.infer_shape(
+            **{i.name: j.shape for i, j in zip(inputs, args)})
+        sdict = {i: j for i, j in zip(out.list_arguments(), arg_shapes)}
+        sdict.update({name : shape for name, shape in \
+                      zip(out.list_auxiliary_states(), aux_shapes)})
+        for i in self.collect_params().values():
+            i.shape = sdict[i.name]
+
+    def export(self, path):
+        """Export HybridBlock to json format that can be loaded by `mxnet.mod.Module`
+        or the C++ interface.
+
+        .. note:: When there are only one input, it will have name `data`. When there
+                  Are more than one inputs, they will be named as `data0`, `data1`, etc.
+
+        Parameters
+        ----------
+        path : str
+            Path to save model. Two files `path-symbol.json` and `path-0000.params`
+            will be created.
+        """
+        if not self._cached_graph:
+            raise RuntimeError(
+                "Please first call block.hybridize() and then run forward with "
+                "this block at least once before calling export.")
+        sym = self._cached_graph[1]
+        sym.save('%s-symbol.json'%path)
+
+        arg_names = set(sym.list_arguments())
+        aux_names = set(sym.list_auxiliary_states())
+        arg_dict = {}
+        for name, param in self.collect_params().items():
+            if name in arg_names:
+                arg_dict['arg:%s'%name] = param._reduce()
+            else:
+                assert name in aux_names
+                arg_dict['aux:%s'%name] = param._reduce()
+        ndarray.save('%s-0000.params'%path, arg_dict)
+
     def forward(self, x, *args):
         """Defines the forward computation. Arguments can be either
         :py:class:`NDArray` or :py:class:`Symbol`."""
diff --git a/python/mxnet/gluon/parameter.py b/python/mxnet/gluon/parameter.py
index 79b5ca3..c73aee2 100644
--- a/python/mxnet/gluon/parameter.py
+++ b/python/mxnet/gluon/parameter.py
@@ -614,7 +614,9 @@ class ParameterDict(object):
                     "restore_prefix is %s but Parameters name %s does not start " \
                     "with %s"%(restore_prefix, name, restore_prefix)
         lprefix = len(restore_prefix)
-        arg_dict = {restore_prefix+k: v for k, v in ndarray.load(filename).items()}
+        loaded = [(k[4:] if k.startswith('arg:') or k.startswith('aux:') else k, v) \
+                  for k, v in ndarray.load(filename).items()]
+        arg_dict = {restore_prefix+k: v for k, v in loaded}
         if not allow_missing:
             for name in self.keys():
                 assert name in arg_dict, \
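
With the `arg:`/`aux:` prefix stripping above, a `.params` file written by
`export` (whose keys look like `arg:<name>` or `aux:<name>`) loads straight
back into a Gluon block; the new test below exercises exactly this round trip:

    import mxnet as mx
    from mxnet import gluon

    net2 = gluon.model_zoo.vision.resnet18_v1(prefix='resnet')
    net2.collect_params().load('resnet-0000.params', ctx=mx.cpu())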
diff --git a/src/imperative/cached_op.cc b/src/imperative/cached_op.cc
index c365371..eb99aab 100644
--- a/src/imperative/cached_op.cc
+++ b/src/imperative/cached_op.cc
@@ -310,12 +310,25 @@ OpStatePtr Imperative::CachedOp::Forward(const std::vector<NDArray*>& inputs,
                                          const std::vector<NDArray*>& outputs) {
   using namespace nnvm;
   using namespace imperative;
+
   bool recording = Imperative::Get()->set_is_recording(false);
   // Initialize
   nnvm::Graph g = GetForwardGraph(recording, inputs);
   const auto& idx = g.indexed_graph();
   size_t num_inputs = idx.input_nodes().size();
 
+  CHECK_EQ(num_inputs, inputs.size())
+      << "CachedOp requires " << num_inputs << " but got " << inputs.size();
+
+  Context default_ctx = inputs[0]->ctx();
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    CHECK_EQ(inputs[i]->ctx(), default_ctx)
+        << "CachedOp requires all inputs to live on the same context. But "
+        << idx[idx.input_nodes()[0]].source->attrs.name << " is on " <<
default_ctx
+        << " while " << idx[idx.input_nodes()[i]].source->attrs.name <<
" is on "
+        << inputs[i]->ctx();
+  }
+
   auto op_state_ptr = OpStatePtr::Create<CachedOpState>();
   auto& cached_op_state = op_state_ptr.get_state<CachedOpState>();
   auto& buff = cached_op_state.buff;
@@ -346,7 +359,6 @@ OpStatePtr Imperative::CachedOp::Forward(const std::vector<NDArray*>& inputs,
     if (ref_count[i] == 0) array_reqs[i] = kNullOp;
   }
 
-  Context default_ctx = inputs[0]->ctx();
   const auto& mem_plan = g.GetAttr<MemoryPlanVector >(
       recording ? "full_mem_plan" : "forward_mem_plan");
   AllocateMemory(g, idx, default_ctx, 0, idx.num_node_entries(),
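
The two checks added here make a hybridized block fail fast with a readable
message instead of failing deeper in the engine. An illustrative trigger,
assuming a GPU is available (the parameters live on the GPU while the input
arrives on the CPU, so the CachedOp's input contexts disagree):

    import mxnet as mx
    from mxnet import gluon

    net = gluon.nn.Dense(10)
    net.initialize(ctx=mx.gpu())
    net.hybridize()
    net(mx.nd.ones((2, 5)))     # cpu input vs. gpu parameters -> context error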
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index 5432e17..60a0630 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -480,6 +480,43 @@ def test_embedding():
     assert (layer.weight.grad()[5:] == 0).asnumpy().all()
 
 
+def test_export():
+    ctx = mx.context.current_context()
+    model = gluon.model_zoo.vision.resnet18_v1(
+        prefix='resnet', ctx=ctx, pretrained=True)
+    model.hybridize()
+    data = mx.nd.random.normal(shape=(1, 3, 224, 224))
+    out = model(data)
+
+    model.export('gluon')
+
+    module = mx.mod.Module.load('gluon', 0, label_names=None, context=ctx)
+    module.bind(data_shapes=[('data', data.shape)])
+    module.forward(mx.io.DataBatch([data], None), is_train=False)
+    mod_out, = module.get_outputs()
+
+    assert_almost_equal(out.asnumpy(), mod_out.asnumpy())
+
+    model2 = gluon.model_zoo.vision.resnet18_v1(prefix='resnet', ctx=ctx)
+    model2.collect_params().load('gluon-0000.params', ctx)
+    out2 = model2(data)
+
+    assert_almost_equal(out.asnumpy(), out2.asnumpy())
+
+
+def test_hybrid_stale_cache():
+    net = mx.gluon.nn.HybridSequential()
+    with net.name_scope():
+        net.add(mx.gluon.nn.Dense(10, weight_initializer='zeros', bias_initializer='ones', flatten=False))
+
+    net.hybridize()
+    net.initialize()
+    net(mx.nd.ones((2,3,5)))
+
+    net.add(mx.gluon.nn.Flatten())
+    assert net(mx.nd.ones((2,3,5))).shape == (2, 30)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()

-- 
To stop receiving notification emails like this one, please contact
['"commits@mxnet.apache.org" <commits@mxnet.apache.org>'].
