singa-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From wang...@apache.org
Subject incubator-singa git commit: SINGA-290 Upgrade to Python 3 [Forced Update!]
Date Tue, 15 Aug 2017 15:16:58 GMT
Repository: incubator-singa
Updated Branches:
  refs/heads/master 69d3513a6 -> 8d03bd8d0 (forced update)


SINGA-290 Upgrade to Python 3

Fix bugs in the kernel/pad settings for 1D conv/pool layers
Fix bugs in the Python unit tests (uneven padding when testing conv/pooling layers)


Project: http://git-wip-us.apache.org/repos/asf/incubator-singa/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-singa/commit/8d03bd8d
Tree: http://git-wip-us.apache.org/repos/asf/incubator-singa/tree/8d03bd8d
Diff: http://git-wip-us.apache.org/repos/asf/incubator-singa/diff/8d03bd8d

Branch: refs/heads/master
Commit: 8d03bd8d0373962cb1da5aee5aa96165f6d46e6b
Parents: 0330af0
Author: Wei Wang <wangwei@comp.nus.edu.sg>
Authored: Tue Aug 15 23:02:46 2017 +0800
Committer: Wei Wang <wangwei@comp.nus.edu.sg>
Committed: Tue Aug 15 23:16:39 2017 +0800

----------------------------------------------------------------------
 python/singa/layer.py     | 12 ++++++++----
 test/python/test_layer.py |  6 +++---
 2 files changed, 11 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8d03bd8d/python/singa/layer.py
----------------------------------------------------------------------
diff --git a/python/singa/layer.py b/python/singa/layer.py
index 8a32279..348bedc 100644
--- a/python/singa/layer.py
+++ b/python/singa/layer.py
@@ -1214,10 +1214,14 @@ def _set_kernel_stride_pad(conf, kernel, stride, border_mode, pad, in_shape):
     if pad is None:
         # TODO(wangwei) check the border mode
         if mode == 'same':
-            out_h = in_shape[1] / conf.stride_h
-            out_w = in_shape[2] / conf.stride_w
-            ph = max((out_h - 1) * conf.stride_h + conf.kernel_h - in_shape[1],
-                     0)
+            if conf.stride_h != 0:
+                out_h = in_shape[1] // conf.stride_h
+                ph = max(
+                    (out_h - 1) * conf.stride_h + conf.kernel_h - in_shape[1],
+                    0)
+            else:
+                ph = 0
+            out_w = in_shape[2] // conf.stride_w
             pw = max((out_w - 1) * conf.stride_w + conf.kernel_w - in_shape[2],
                      0)
             assert ph % 2 == 0 and pw % 2 == 0, 'ph=%d and pw=%d are not even' \

http://git-wip-us.apache.org/repos/asf/incubator-singa/blob/8d03bd8d/test/python/test_layer.py
----------------------------------------------------------------------
diff --git a/test/python/test_layer.py b/test/python/test_layer.py
index ec5becf..b98065f 100644
--- a/test/python/test_layer.py
+++ b/test/python/test_layer.py
@@ -132,7 +132,7 @@ class TestPythonLayer(unittest.TestCase):
         self.check_shape(out_sample_shape, (64, 224,))
 
     def test_max_pooling2D(self):
-        in_sample_shape = (64, 224, 224)
+        in_sample_shape = (64, 225, 225)
         pooling = layer.MaxPooling2D('pool', 3, 2,
                                      input_sample_shape=in_sample_shape)
         out_sample_shape = pooling.get_output_sample_shape()
@@ -146,7 +146,7 @@ class TestPythonLayer(unittest.TestCase):
         self.check_shape(out_sample_shape, (112,))
 
     def test_avg_pooling2D(self):
-        in_sample_shape = (64, 224, 224)
+        in_sample_shape = (64, 225, 225)
         pooling = layer.AvgPooling2D('pool', 3, 2,
                                      input_sample_shape=in_sample_shape)
         out_sample_shape = pooling.get_output_sample_shape()
@@ -234,7 +234,7 @@ class TestPythonLayer(unittest.TestCase):
         t2 = tensor.Tensor((3, 1))
         t1.set_value(1)
         t2.set_value(2)
-        grad,_ = lyr.backward(model_pb2.kTrain, [t1, t2])
+        grad, _ = lyr.backward(model_pb2.kTrain, [t1, t2])
         gnp = tensor.to_numpy(grad)
         self.assertEquals(np.sum(gnp), 12)
 


Mime
View raw message