mxnet-commits mailing list archives

From GitBox <...@apache.org>
Subject [GitHub] [incubator-mxnet] wkcn commented on issue #18333: How to implement cross entropy for binary segmentation using symbol only?
Date Wed, 20 May 2020 17:13:40 GMT

wkcn commented on issue #18333:
URL: https://github.com/apache/incubator-mxnet/issues/18333#issuecomment-631608741


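   Hi @John1231983, using `mx.nd.pick` simplifies the code. Instead of multiplying the log-probabilities by a one-hot mask, take `log_softmax` over the class axis and pick the score of the target class at each pixel. The same operators exist as `mx.sym.log_softmax` and `mx.sym.pick` for a symbol-only graph.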
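   For reference, let the logits $x$ have size BxCxHxW and the labels be $t_{b,h,w} \in \{0, \dots, C-1\}$. With the softmax taken over the class axis $c$, both computations in the snippet give the pixel-wise cross entropy, summed and divided by the batch size $B$:

   $$
   L = -\frac{1}{B}\sum_{b,h,w} \log \operatorname{softmax}(x)_{b,\,t_{b,h,w},\,h,\,w}
     = \frac{1}{B}\sum_{b,h,w}\sum_{c=0}^{C-1} \left(-\mathbb{1}[c = t_{b,h,w}]\right) \log \operatorname{softmax}(x)_{b,c,h,w}
   $$

   The right-hand form is the one-hot mask built with `on_value=-1.0`. The left-hand form is what `pick` computes on the negated `log_softmax`.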
   ```python
   import mxnet as mx
   
   B, C, H, W = 2, 2, 3, 4
   
   x = mx.random.uniform(-1, 1, shape=(B, C, H, W))      # logits, size BxCxHxW
   target = mx.random.randint(0, C, shape=(B, 1, H, W))  # labels, size Bx1xHxW
   batch_size = B
   
   # Version 1: softmax over the class axis, then a one-hot mask
   f = mx.nd.softmax(x, axis=1)  # the default axis=-1 would normalize over W, not C
   target_squeeze = mx.nd.squeeze(target, axis=1)  # size BxHxW
   # on_value=-1.0 folds the minus sign of the cross entropy into the mask
   target_onehot = mx.nd.one_hot(target_squeeze, depth=C, on_value=-1.0, off_value=0.0)
   # Transpose from BxHxWxC to BxCxHxW
   target_onehot = mx.nd.transpose(target_onehot, axes=(0, 3, 1, 2))
   f_log = mx.nd.log(f)
   f_sum = mx.nd.sum(target_onehot * f_log) / batch_size
   print(f_sum)
   
   # Version 2: log_softmax + pick (shorter, and avoids taking the log of a softmax)
   lscore = -mx.nd.log_softmax(x, axis=1)
   t_sum = mx.nd.pick(lscore, target_squeeze, axis=1).sum() / batch_size
   print(t_sum)  # equals f_sum up to floating-point rounding
   ```
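As a framework-independent sanity check, the NumPy sketch below reimplements both computations. It is not MXNet code: it assumes `np.take_along_axis` as a stand-in for `mx.nd.pick`, and `log_softmax` here is a hand-written helper, since NumPy has no built-in one.

```python
import numpy as np

def log_softmax(x, axis):
    # Hand-written helper (not a NumPy function); subtracts the max for stability
    m = x.max(axis=axis, keepdims=True)
    return x - m - np.log(np.exp(x - m).sum(axis=axis, keepdims=True))

def ce_onehot(x, target):
    """Cross entropy via an explicit one-hot mask (mirrors the first version)."""
    B, C = x.shape[:2]
    t = target[:, 0]                       # BxHxW
    onehot = np.eye(C)[t]                  # BxHxWxC
    onehot = onehot.transpose(0, 3, 1, 2)  # BxCxHxW
    p = np.exp(log_softmax(x, axis=1))     # softmax over the class axis
    return np.sum(-onehot * np.log(p)) / B

def ce_pick(x, target):
    """Cross entropy by selecting the target-class score (mirrors mx.nd.pick)."""
    B = x.shape[0]
    lscore = -log_softmax(x, axis=1)                     # BxCxHxW
    picked = np.take_along_axis(lscore, target, axis=1)  # Bx1xHxW
    return picked.sum() / B

rng = np.random.default_rng(0)
B, C, H, W = 2, 2, 3, 4
x = rng.uniform(-1, 1, size=(B, C, H, W))
target = rng.integers(0, C, size=(B, 1, H, W))
print(ce_onehot(x, target), ce_pick(x, target))
```

With all-zero logits and C = 2, every pixel assigns probability 0.5 to each class. Both functions then return H·W·log 2 regardless of the labels.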


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


