singa-dev mailing list archives

From GitBox <...@apache.org>
Subject [GitHub] [singa] chrishkchris edited a comment on issue #576: The python test case in test_operation.py may need to be updated
Date Wed, 22 Jan 2020 06:20:44 GMT
chrishkchris edited a comment on issue #576: The python test case in test_operation.py may need to be updated
URL: https://github.com/apache/singa/issues/576#issuecomment-577003957
 
 
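   If I am correct, the Reshape failure is caused by an error in `backward`, which reshapes the incoming gradient `dy` to `self.cache` (the output shape) instead of back to the shape of the input `x`: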
   
   ```
   class Reshape(Operation):
   
       def __init__(self,shape):
           super(Reshape, self).__init__()
           if isinstance(shape, tensor.Tensor):
               self.shape = np.asarray(tensor.to_numpy(shape).astype(np.int32)).tolist()
           else:
               self.shape = list(shape)
   
       def forward(self, x):
           _shape = x.shape()
           shape = self.shape
           # handle the shape with 0
           shape = [_shape[i] if i < len(_shape) and shape[i] == 0 else shape[i]
                    for i in range(len(shape))]
           # handle the shape with -1
           hidden_shape = int(np.prod(_shape) // np.abs(np.prod(shape)))
           self.cache=[s if s != -1 else hidden_shape for s in shape]
           return singa.Reshape(x, self.cache)
   
       def backward(self, dy):
           # bug: self.cache is the *output* shape, so dy is never mapped
           # back to the shape of x
           return singa.Reshape(dy, self.cache)
   ```
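A minimal NumPy-only sketch makes the problem concrete. It does not use SINGA, and `resolve_shape` is a hypothetical helper that copies the 0/-1 handling from `forward` above. The gradient `dy` arriving at `backward` already has the output shape, so reshaping it to `self.cache` is a no-op, and the gradient handed to the previous layer does not match the shape of `x`.

```python
import numpy as np


def resolve_shape(in_shape, target):
    """Hypothetical helper mirroring the shape handling in Reshape.forward.

    A 0 copies the matching input dimension; a single -1 is inferred so
    that the total number of elements is preserved.
    """
    shape = [in_shape[i] if i < len(in_shape) and target[i] == 0 else target[i]
             for i in range(len(target))]
    hidden = int(np.prod(in_shape) // np.abs(np.prod(shape)))
    return [s if s != -1 else hidden for s in shape]


x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)
out_shape = resolve_shape(x.shape, [0, -1])  # [2, 12]
y = x.reshape(out_shape)

# The upstream gradient has the same shape as the forward output.
dy = np.ones_like(y)

# Buggy backward: reshape dy to the *output* shape, which changes nothing.
dx_buggy = dy.reshape(out_shape)
print(dx_buggy.shape)  # (2, 12), but x has shape (2, 3, 4)
```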
   
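   I think the class should be changed as follows, so that `forward` saves the input shape in `self._shape` and `backward` reshapes `dy` back to it: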
   ```
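   import numpy as np
   from singa import singa_wrap as singa
   from singa import tensor
   from singa.autograd import Operation

   class Reshape(Operation):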
       def __init__(self,shape):
           super(Reshape, self).__init__()
           if isinstance(shape, tensor.Tensor):
               self.shape = np.asarray(tensor.to_numpy(shape).astype(np.int32)).tolist()
           else:
               self.shape = list(shape)
   
       def forward(self, x):
           self._shape = x.shape()  # remember the input shape for backward
           shape = self.shape
           # handle the shape with 0
           shape = [self._shape[i] if i < len(self._shape) and shape[i] == 0
                    else shape[i] for i in range(len(shape))]
           # handle the shape with -1
           hidden_shape = int(np.prod(self._shape) // np.abs(np.prod(shape)))
           self.cache=[s if s != -1 else hidden_shape for s in shape]
   
           return singa.Reshape(x, self.cache)
   
       def backward(self, dy):
           # dy has the output shape; reshape it back to the input shape of x
           return singa.Reshape(dy, self._shape)
   
   ```
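As a sanity check of the proposed fix, the following NumPy-only sketch mirrors the corrected class. `NumpyReshape` is a hypothetical stand-in, not SINGA's API, and it assumes `ndarray.reshape` behaves like `singa.Reshape` for this purpose. It compares `backward` against a finite-difference estimate for the loss `L = sum(w * y)`.

```python
import numpy as np


class NumpyReshape:
    """Hypothetical NumPy stand-in for the corrected Reshape operation."""

    def __init__(self, shape):
        self.shape = list(shape)

    def forward(self, x):
        self._shape = x.shape  # remember the input shape for backward
        # handle the shape with 0 (copy the matching input dimension)
        shape = [self._shape[i] if i < len(self._shape) and self.shape[i] == 0
                 else self.shape[i] for i in range(len(self.shape))]
        # handle the shape with -1 (infer it from the element count)
        hidden = int(np.prod(self._shape) // np.abs(np.prod(shape)))
        self.cache = [s if s != -1 else hidden for s in shape]
        return x.reshape(self.cache)

    def backward(self, dy):
        # dy has the output shape; lay it out in the input shape again
        return dy.reshape(self._shape)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4))
w = rng.standard_normal((2, 12))  # weights of the loss L = sum(w * y)

op = NumpyReshape([0, -1])
y = op.forward(x)
dx = op.backward(w)  # dL/dy = w, so this is dL/dx


def loss(x_in):
    return float(np.sum(w * NumpyReshape([0, -1]).forward(x_in)))


# Finite-difference estimate of dL/dx for one element of x
eps = 1e-6
x_pert = x.copy()
x_pert[1, 2, 3] += eps
numeric = (loss(x_pert) - loss(x)) / eps

print(dx.shape)  # (2, 3, 4)
print(abs(numeric - dx[1, 2, 3]) < 1e-4)  # True
```

Because a reshape only relabels elements without changing their values, the gradient itself needs no arithmetic; `backward` only has to restore the input layout, which is why caching the input shape is enough.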

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services
