singa-commits mailing list archives

From wang...@apache.org
Subject [07/18] incubator-singa git commit: SINGA-371 Implement functional operations in c++ for autograd
Date Thu, 05 Jul 2018 03:10:02 GMT
SINGA-371 Implement functional operations in c++ for autograd

- fix some bugs.
- the conv2d_gpu operation has passed its tests.
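
For reference, a minimal sketch of how the patched conv2d_gpu operation might be exercised from Python (hypothetical usage: the constructor arguments, the CUDA device setup, and calling the operation directly are assumptions for illustration, not part of this commit):

    from singa import autograd, device, tensor

    # hypothetical setup; requires a CUDA-enabled SINGA build
    dev = device.create_cuda_gpu()
    conv = autograd.Conv2d_GPU(3, 16, kernel_size=(3, 3))  # assumed ctor args

    x = tensor.Tensor((8, 3, 32, 32), dev)  # NCHW input batch
    x.gaussian(0.0, 1.0)

    autograd.training = True  # backward() asserts this module-level flag
    y = conv(x)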


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/78e1fc23
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/78e1fc23
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/78e1fc23

Branch: refs/heads/master
Commit: 78e1fc230a14510239728e585103e7c5791c3943
Parents: c57b87a
Author: xuewanqi <xue_wanqi@u.nus.edu>
Authored: Thu Jun 21 03:10:33 2018 +0000
Committer: xuewanqi <xue_wanqi@u.nus.edu>
Committed: Thu Jun 21 03:10:33 2018 +0000

----------------------------------------------------------------------
 python/singa/autograd.py | 20 +++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/78e1fc23/python/singa/autograd.py
----------------------------------------------------------------------
diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index c7e0adb..7ba68f5 100644
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -616,7 +616,7 @@ class Conv2d_GPU(Operation):
 
         self.bias = bias
 
-        inner_params = {'cudnn_prefer': 'fastest', 'workspace_byte_limit': 1024}
+        inner_params = {'cudnn_prefer': 'fastest', 'workspace_MB_limit': 1024}
         # TODO valid value of inner_params check
 
         for kwarg in kwargs:
@@ -627,7 +627,8 @@ class Conv2d_GPU(Operation):
 
         self.convhandle = singa.SetupConv(self.kernel_size[0], self.kernel_size[1],
         			self.padding[0], self.padding[1], self.stride[0], self.stride[1],
-        			self.bias, inner_params['workspace_byte_limit']*1024*1024,
+        			self.in_channels, self.out_channels, self.bias, 
+                                inner_params['workspace_MB_limit']*1024*1024,
         			inner_params['cudnn_prefer'])
         
         w_shape = (self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1])
@@ -650,10 +651,13 @@ class Conv2d_GPU(Operation):
         assert 0 == 0, 'invalid padding.'
     	# TODO valid padding check.
 
-    	if not hasattr (self, cudnnconvhandle):
+    	if not hasattr (self, 'cudnnconvhandle'):
     	    self.cudnnconvhandle = singa.InitCudnn(x.data, self.convhandle)
     	elif x.shape[0] != self.cudnnconvhandle.batchsize:
     	    self.cudnnconvhandle = singa.InitCudnn(x.data, self.convhandle)
+        
+        if training:
+            self.x = x
 
     	self.dev = x.device
 
@@ -665,20 +669,18 @@ class Conv2d_GPU(Operation):
     	return self._do_forward(*xs)[0]
 
     def forward(self, *xs):
-        if training:
-    	    self.x = xs[0]
         return singa.CudnnConvForward(xs[0], xs[1], xs[2], self.convhandle, self.cudnnconvhandle)
 
     def backward(self, dy):
-        assert training is True and hasattr(self, x), 'Please set \'trainging\' as True before do BP. '
+        assert training is True and hasattr(self, 'x'), 'Please set \'trainging\' as True before do BP. '
 
         # todo check device?
         dy.ToDevice(self.dev)
 
-        dx = singa.CudnnConvBackwardx(dy, self.W, self.x, self.cch)
-        dW = singa.CudnnConvBackwardW(dy, self.x, self.W, self.cch)
+        dx = singa.CudnnConvBackwardx(dy, self.W.data, self.x.data, self.cudnnconvhandle)
+        dW = singa.CudnnConvBackwardW(dy, self.x.data, self.W.data, self.cudnnconvhandle)
         if self.bias:
-    	    db = singa.CudnnConvBackwardb(dy, self.b, self.cch)
+    	    db = singa.CudnnConvBackwardb(dy, self.b.data, self.cudnnconvhandle)
     	    return dx, dW, db
         else:
     	    return dx, dW
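
The rename from 'workspace_byte_limit' to 'workspace_MB_limit' matches the conversion already applied in the diff: the configured value is in megabytes and is multiplied out to bytes before being handed to singa.SetupConv. Schematically:

    # the workspace limit is configured in MB and passed on in bytes
    inner_params = {'cudnn_prefer': 'fastest', 'workspace_MB_limit': 1024}
    workspace_bytes = inner_params['workspace_MB_limit'] * 1024 * 1024
    assert workspace_bytes == 1 << 30  # 1024 MB == 1 GiB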

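The hasattr fix above also makes the handle-caching logic work as intended: the cuDNN handle is created lazily and rebuilt only when the input batch size changes. A standalone sketch of that pattern (hypothetical names; make_handle stands in for singa.InitCudnn):

    class HandleCache:
        def __init__(self, make_handle):
            self.make_handle = make_handle  # e.g. wraps singa.InitCudnn
            self.handle = None
            self.batchsize = None

        def get(self, x):
            # (re)build on first use, or when the leading (batch) dimension
            # of the input no longer matches the cached one
            if self.handle is None or x.shape[0] != self.batchsize:
                self.handle = self.make_handle(x)
                self.batchsize = x.shape[0]
            return self.handle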
