singa-dev mailing list archives

From GitBox <...@apache.org>
Subject [GitHub] [incubator-singa] chrishkchris commented on a change in pull request #468: Distributted module
Date Mon, 26 Aug 2019 13:29:26 GMT
chrishkchris commented on a change in pull request #468: Distributted module
URL: https://github.com/apache/incubator-singa/pull/468#discussion_r317527518
 
 

 ##########
 File path: python/singa/autograd.py
 ##########
 @@ -1286,25 +1287,26 @@ def set_params(self, **parameters):
 
 
 class _BatchNorm2d(Operation):
-    def __init__(self, handle, name=None):
+    def __init__(self, handle, running_mean, running_var, name=None):
         super(_BatchNorm2d, self).__init__(name)
         self.handle = handle
+        self.running_mean = running_mean.data
+        self.running_var = running_var.data
 
-    def forward(self, x, scale, bias, running_mean, running_var):
-        self.running_mean = running_mean
-        self.running_var = running_var
+    def forward(self, x, scale, bias):
         if training:
 
             if isinstance(self.handle, singa.CudnnBatchNormHandle):
                 y, mean, var = singa.GpuBatchNormForwardTraining(
-                    self.handle, x, scale, bias, running_mean, running_var
+                    self.handle, x, scale, bias, self.running_mean, self.running_var
 
 Review comment:
  Still testing with a real dataset on CPU.
  
  After adding a batch-norm layer to the CNN for MNIST, CPU training looks okay:
   ```
   ubuntu@ip-172-31-31-187:~/incubator-singa/examples/autograd$ python3 mnist_dist_bn.py
   Starting Epoch 0:
   Training loss = 1115.447021, training accuracy = 0.592116
   Evaluation accuracy = 0.649239, Elapsed Time = 91.413475s
   Starting Epoch 1:
   Training loss = 615.564209, training accuracy = 0.783968
   Evaluation accuracy = 0.878105, Elapsed Time = 91.053461s
   Starting Epoch 2:
   Training loss = 444.550018, training accuracy = 0.848286
   Evaluation accuracy = 0.900240, Elapsed Time = 91.064094s
   Starting Epoch 3:
   Training loss = 333.629150, training accuracy = 0.886690
   Evaluation accuracy = 0.857772, Elapsed Time = 90.935190s
   Starting Epoch 4:
   Training loss = 289.389832, training accuracy = 0.902648
   Evaluation accuracy = 0.913462, Elapsed Time = 91.152710s
   Starting Epoch 5:
   Training loss = 263.009583, training accuracy = 0.910836
   Evaluation accuracy = 0.922877, Elapsed Time = 91.171680s
   Starting Epoch 6:
   Training loss = 238.859818, training accuracy = 0.918957
   Evaluation accuracy = 0.933794, Elapsed Time = 91.016456s
   Starting Epoch 7:
   Training loss = 215.822647, training accuracy = 0.927428
   Evaluation accuracy = 0.946615, Elapsed Time = 90.870825s
   Starting Epoch 8:
   Training loss = 202.828430, training accuracy = 0.932080
   Evaluation accuracy = 0.948017, Elapsed Time = 91.014656s
   Starting Epoch 9:
   Training loss = 190.810226, training accuracy = 0.935899
   Evaluation accuracy = 0.949820, Elapsed Time = 91.270044s
   ```
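For context on the diff above: `running_mean` and `running_var` move from `forward()` into `__init__`, so the operation keeps the running statistics itself and `forward()` only takes `x`, `scale` and `bias`. The NumPy sketch below is a rough CPU reference for what a batch-norm training step computes under that layout. It is not SINGA's implementation: the class name `BatchNorm2dSketch`, the momentum convention (`new = 0.9 * old + 0.1 * batch`), `eps=1e-5`, and the use of the biased batch variance are all assumptions made for illustration.

```python
import numpy as np


class BatchNorm2dSketch:
    """Hypothetical NumPy reference for 2D batch norm (not SINGA's API).

    Mirrors the shape of the revised _BatchNorm2d: running statistics are
    given once in __init__ and updated in place, so forward() only needs
    x, scale and bias.
    """

    def __init__(self, running_mean, running_var, momentum=0.9, eps=1e-5):
        # running_* must be float arrays of shape (C,); they are mutated in place.
        self.running_mean = running_mean
        self.running_var = running_var
        self.momentum = momentum  # assumed: new = m * old + (1 - m) * batch
        self.eps = eps

    def forward(self, x, scale, bias, training=True):
        # x is NCHW; statistics are computed per channel over N, H and W.
        if training:
            mean = x.mean(axis=(0, 2, 3))
            var = x.var(axis=(0, 2, 3))  # biased batch variance (simplification)
            self.running_mean *= self.momentum
            self.running_mean += (1 - self.momentum) * mean
            self.running_var *= self.momentum
            self.running_var += (1 - self.momentum) * var
        else:
            # Inference uses the accumulated running statistics.
            mean, var = self.running_mean, self.running_var
        shape = (1, -1, 1, 1)  # broadcast per-channel vectors over NCHW
        x_hat = (x - mean.reshape(shape)) / np.sqrt(var.reshape(shape) + self.eps)
        return scale.reshape(shape) * x_hat + bias.reshape(shape)


# Usage: one training step on random data with 4 channels.
rng = np.random.default_rng(0)
x = rng.normal(loc=2.0, scale=3.0, size=(8, 4, 5, 5))
op = BatchNorm2dSketch(np.zeros(4), np.ones(4))
y = op.forward(x, scale=np.ones(4), bias=np.zeros(4))
```

Calling `forward(..., training=False)` on the same object normalizes with the stored running statistics instead, which is why keeping them on the operation (rather than passing fresh tensors each call) matters for the evaluation numbers.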
   
   

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services
