singa-dev mailing list archives

From GitBox <...@apache.org>
Subject [GitHub] [incubator-singa] chrishkchris commented on a change in pull request #468: Distributted module
Date Mon, 26 Aug 2019 14:14:06 GMT
chrishkchris commented on a change in pull request #468: Distributted module
URL: https://github.com/apache/incubator-singa/pull/468#discussion_r317527518
 
 

 ##########
 File path: python/singa/autograd.py
 ##########
 @@ -1286,25 +1287,26 @@ def set_params(self, **parameters):
 
 
 class _BatchNorm2d(Operation):
-    def __init__(self, handle, name=None):
+    def __init__(self, handle, running_mean, running_var, name=None):
         super(_BatchNorm2d, self).__init__(name)
         self.handle = handle
+        self.running_mean = running_mean.data
+        self.running_var = running_var.data
 
-    def forward(self, x, scale, bias, running_mean, running_var):
-        self.running_mean = running_mean
-        self.running_var = running_var
+    def forward(self, x, scale, bias):
         if training:
 
             if isinstance(self.handle, singa.CudnnBatchNormHandle):
                 y, mean, var = singa.GpuBatchNormForwardTraining(
-                    self.handle, x, scale, bias, running_mean, running_var
+                    self.handle, x, scale, bias, self.running_mean, self.running_var
 
 Review comment:
   After adding two batchnorm layers to the CNN for MNIST, CPU training looks okay: over 10 epochs the
training loss falls from 608.7 to 91.5 and the evaluation accuracy rises from 0.9005 to 0.9752.
   ```
   ubuntu@ip-172-31-36-94:~/incubator-singa/examples/autograd$ python3 mnist_dist_bn.py
   Starting Epoch 0:
   Training loss = 608.656677, training accuracy = 0.785035
   Evaluation accuracy = 0.900541, Elapsed Time = 130.748269s
   Starting Epoch 1:
   Training loss = 259.606445, training accuracy = 0.911720
   Evaluation accuracy = 0.951222, Elapsed Time = 129.687239s
   Starting Epoch 2:
   Training loss = 180.270645, training accuracy = 0.938917
   Evaluation accuracy = 0.965745, Elapsed Time = 129.715867s
   Starting Epoch 3:
   Training loss = 146.975281, training accuracy = 0.950607
   Evaluation accuracy = 0.961138, Elapsed Time = 129.695524s
   Starting Epoch 4:
   Training loss = 130.942749, training accuracy = 0.955576
   Evaluation accuracy = 0.966446, Elapsed Time = 129.814190s
   Starting Epoch 5:
   Training loss = 116.057938, training accuracy = 0.960846
   Evaluation accuracy = 0.964844, Elapsed Time = 129.697126s
   Starting Epoch 6:
   Training loss = 105.867195, training accuracy = 0.963914
   Evaluation accuracy = 0.973758, Elapsed Time = 129.782990s
   Starting Epoch 7:
   Training loss = 102.414818, training accuracy = 0.965498
   Evaluation accuracy = 0.973357, Elapsed Time = 129.847014s
   Starting Epoch 8:
   Training loss = 95.194695, training accuracy = 0.968433
   Evaluation accuracy = 0.973658, Elapsed Time = 129.762709s
   Starting Epoch 9:
   Training loss = 91.524719, training accuracy = 0.969717
   Evaluation accuracy = 0.975160, Elapsed Time = 129.581387s
   ```
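
For readers following the diff, here is a minimal, framework-free sketch of the pattern the change moves to: the running statistics are bound once in the constructor and updated in place during training, instead of being passed to every `forward` call. This is not SINGA's implementation. The class name `BatchNormSketch`, the `momentum` and `eps` values, the update convention, and the explicit `training` argument (the code in the diff reads a module-level `training` flag) are all assumptions made for illustration.

```python
# Hypothetical, pure-Python sketch of a batchnorm op that owns its running
# statistics. Names and defaults are illustrative, not SINGA's API.
class BatchNormSketch:
    def __init__(self, running_mean, running_var, momentum=0.9, eps=1e-5):
        # Bound once at construction, mirroring the new __init__ in the diff.
        self.running_mean = running_mean  # list, one value per channel
        self.running_var = running_var    # list, one value per channel
        self.momentum = momentum          # assumed convention: weight on the old value
        self.eps = eps

    def forward(self, x, scale, bias, training=True):
        # x is a batch given as a list of rows; each row holds one value per channel.
        n = len(x)
        channels = range(len(x[0]))
        if training:
            # Batch statistics per channel (biased variance).
            mean = [sum(row[c] for row in x) / n for c in channels]
            var = [sum((row[c] - mean[c]) ** 2 for row in x) / n for c in channels]
            m = self.momentum
            for c in channels:
                # In-place update, so the caller's lists see the new values.
                self.running_mean[c] = m * self.running_mean[c] + (1 - m) * mean[c]
                self.running_var[c] = m * self.running_var[c] + (1 - m) * var[c]
        else:
            # Evaluation uses the accumulated running statistics.
            mean, var = self.running_mean, self.running_var
        return [
            [scale[c] * (row[c] - mean[c]) / (var[c] + self.eps) ** 0.5 + bias[c]
             for c in channels]
            for row in x
        ]


# Usage: one channel, batch of two samples.
running_mean, running_var = [0.0], [1.0]
op = BatchNormSketch(running_mean, running_var)
y_train = op.forward([[1.0], [3.0]], scale=[1.0], bias=[0.0], training=True)
y_eval = op.forward([[running_mean[0]]], scale=[1.0], bias=[0.0], training=False)
```

With this layout the `forward` signature takes only `x`, `scale`, and `bias`, which matches the signature introduced in the diff; whether the real handle-based kernels update the statistics in place is not shown in the hunk above.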
   
   

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services
