handar423 opened a new pull request #5695:
URL: https://github.com/apache/incubatortvm/pull/5695
To whom it may concern,
Hello, I am learning TVM and was trying to write a training model with Relay when I
got an error in `dense_grad()` with data of shape `5*4` and weight of shape `3*4`. On closer
inspection, I found that it may be caused by a small bug in `dense_grad()`.
The current `dense_grad()` is:
```python
@register_gradient("nn.dense")
def dense_grad(orig, grad):
    """Returns [grad' @ weight, data @ grad']"""
    data, weight = orig.args
    # Note: `*` is an element-wise (broadcasting) multiply here, not a matrix multiply.
    return [collapse_sum_like(transpose(grad) * weight, data),
            collapse_sum_like(data * transpose(grad), weight)]
```
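In the common case, when we compute the gradient of `dense(A (i * j), weight (k * j))`,
the incoming `grad` has shape `i * k`. In the `dense_grad()` above, the first multiply therefore
receives operands of shape `k * i` and `k * j`, and the second receives `i * j` and `k * i`.
Since `*` is an element-wise (broadcasting) multiply rather than a matrix multiply, the shapes
only line up when `i == j == k` or some of the dimensions are `1`, and even then the element-wise
product followed by `collapse_sum_like` equals the intended matrix product only in special
cases such as `i == 1`.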
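The shape conflict can be reproduced without TVM. The sketch below only tracks shapes and assumes that Relay's element-wise `*` follows NumPy-style broadcasting; `broadcast_shape` and `old_dense_grad_shapes` are hypothetical helpers written for illustration, not TVM functions.

```python
def broadcast_shape(a, b):
    """Return the NumPy-style broadcast of shapes a and b, or None if they conflict."""
    n = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both have the same rank.
    a = (1,) * (n - len(a)) + tuple(a)
    b = (1,) * (n - len(b)) + tuple(b)
    out = []
    for x, y in zip(a, b):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            return None  # dims differ and neither is 1: cannot broadcast
    return tuple(out)


def old_dense_grad_shapes(i, j, k):
    """Result shapes of the two `*` products in the current dense_grad().

    data is (i, j) and weight is (k, j), so grad has the output shape (i, k).
    """
    grad_t = (k, i)                              # transpose(grad)
    first = broadcast_shape(grad_t, (k, j))      # transpose(grad) * weight
    second = broadcast_shape((i, j), grad_t)     # data * transpose(grad)
    return first, second


print(old_dense_grad_shapes(5, 4, 3))   # -> (None, None), the failing case from this report
print(old_dense_grad_shapes(1, 8, 16))  # -> ((16, 8), (16, 8)), batch size 1 happens to work
```

With a batch size of 1 both products broadcast, and summing them back with `collapse_sum_like` happens to reproduce the matrix products, which is why the bug only shows up for larger batches.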
To make the function work for general shapes, maybe we can modify it to:
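```python
@register_gradient("nn.dense")
def dense_grad(orig, grad):
    """Returns [grad @ weight, grad' @ data]"""
    # data: (i, j), weight: (k, j), grad: (i, k)
    data, weight = orig.args
    # nn.dense(x, w) computes x @ w^T, so:
    #   dense(grad, weight')  == grad @ weight  -> (i, j), same as data
    #   dense(grad', data')   == grad' @ data   -> (k, j), same as weight
    return [collapse_sum_like(_nn.dense(grad, transpose(weight)), data),
            collapse_sum_like(_nn.dense(transpose(grad), transpose(data)), weight)]
```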
We replace the element-wise multiply (`*`) with `_nn.dense`, which performs a matrix multiply
(`nn.dense(x, w)` computes `x @ w^T`). Under the shapes assumed above, the first `_nn.dense()`
receives operands of shape `i * k` and `j * k` and returns a result of shape `i * j`, matching
`data`; the second receives `k * i` and `j * i` and returns a result of shape `k * j`, matching `weight`.
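As a sanity check that does not need TVM, here is a minimal pure-Python sketch. It models `nn.dense(x, w)` as `x @ w^T`, drops `collapse_sum_like` (the shapes already match), and compares the proposed rule against central finite differences for the `5*4` / `3*4` shapes from this report. All helper names (`dense`, `proposed_dense_grad`, `numeric_grad`, and so on) are hypothetical stand-ins, not TVM APIs.

```python
import random


def transpose(m):
    """Transpose a matrix stored as a list of rows."""
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    """Ordinary matrix product a @ b."""
    bt = transpose(b)
    return [[sum(x * y for x, y in zip(row, col)) for col in bt] for row in a]


def dense(x, w):
    """Same semantics as Relay's nn.dense: x @ w^T, with w of shape (units, in_dim)."""
    return matmul(x, transpose(w))


def proposed_dense_grad(data, weight, grad):
    """Proposed rule; collapse_sum_like is a no-op here because shapes already match."""
    return (dense(grad, transpose(weight)),           # grad @ weight -> shape of data
            dense(transpose(grad), transpose(data)))  # grad' @ data  -> shape of weight


def loss(data, weight, grad):
    """Scalar whose gradient w.r.t. the dense output is exactly `grad`."""
    out = dense(data, weight)
    return sum(o * g for orow, grow in zip(out, grad) for o, g in zip(orow, grow))


def numeric_grad(f, m, eps=1e-6):
    """Central finite differences of f() w.r.t. each entry of m (mutated, then restored)."""
    g = [[0.0] * len(m[0]) for _ in m]
    for r in range(len(m)):
        for c in range(len(m[0])):
            old = m[r][c]
            m[r][c] = old + eps
            up = f()
            m[r][c] = old - eps
            down = f()
            m[r][c] = old
            g[r][c] = (up - down) / (2 * eps)
    return g


def max_abs_diff(a, b):
    """Largest element-wise absolute difference between two equally shaped matrices."""
    return max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


# Shapes from the report: data (i, j) = (5, 4), weight (k, j) = (3, 4), grad (i, k) = (5, 3).
rng = random.Random(0)
i, j, k = 5, 4, 3
data = [[rng.uniform(-1, 1) for _ in range(j)] for _ in range(i)]
weight = [[rng.uniform(-1, 1) for _ in range(j)] for _ in range(k)]
grad = [[rng.uniform(-1, 1) for _ in range(k)] for _ in range(i)]

d_data, d_weight = proposed_dense_grad(data, weight, grad)
num_data = numeric_grad(lambda: loss(data, weight, grad), data)
num_weight = numeric_grad(lambda: loss(data, weight, grad), weight)

print(max_abs_diff(d_data, num_data), max_abs_diff(d_weight, num_weight))
```

Because the loss is linear in each argument, the finite differences are exact up to floating-point rounding, so both printed differences should be tiny.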
I also added an extra test case to `test_dense_grad()` to check its correctness.
I am just starting to learn TVM, so I apologize if I have missed something obvious. Thank
you very much!

This is an automated message from the Apache Git Service.
