tvm-commits mailing list archives

From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] siju-samuel commented on a change in pull request #5249: [PYTORCH]LayerNorm support added
Date Tue, 07 Apr 2020 03:22:18 GMT
siju-samuel commented on a change in pull request #5249: [PYTORCH]LayerNorm support added
URL: https://github.com/apache/incubator-tvm/pull/5249#discussion_r404512746
 
 

 ##########
 File path: python/tvm/relay/frontend/pytorch.py
 ##########
 @@ -503,6 +503,34 @@ def _impl(inputs, input_types):
                                     scale=scale)
     return _impl
 
+def _get_dims(data):
+    import torch
+    if isinstance(data, _expr.Expr):
+        dims = _infer_shape(data)
+    elif isinstance(data, list):
+        dims = data
+    elif isinstance(data, (torch.Tensor, np.ndarray)):
+        dims = data.shape
+    else:
+        msg = "Data type %s could not be parsed" % type(data)
+        raise AssertionError(msg)
+    return dims
+
+def _layer_norm():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        ndims = len(_get_dims(inputs[1]))
+        assert ndims == 1, "Support only normalization over last one dimension."
+
+        return _op.nn.layer_norm(data,
+                                 gamma=inputs[1],
+                                 beta=inputs[2],
+                                 axis=-1,
 
 Review comment:
  @masahi Thanks a lot for pointing that out. Along with changing the inputs, I also need to set
the center and scale flags to True. In my test case I was not initializing the weights, so it
couldn't catch this mistake.
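A minimal NumPy sketch of the layer-norm arithmetic, shown to illustrate why the test missed the bug. It is not the TVM or PyTorch implementation. It assumes the center/scale flags behave as documented for relay's nn.layer_norm: scale multiplies the normalized tensor by gamma, and center adds beta. The helper name layer_norm is hypothetical. PyTorch's nn.LayerNorm initializes its weight to ones and its bias to zeros, so a test that leaves them untouched makes the affine step an identity. In that case center/scale=False produces the same output as True.

```python
import numpy as np

def layer_norm(x, gamma, beta, axis=-1, epsilon=1e-5, center=True, scale=True):
    """Hypothetical reference: normalize x over `axis`, then optionally apply gamma/beta.

    Assumes flag semantics like relay's nn.layer_norm: `scale` multiplies by
    gamma and `center` adds beta. Uses the biased variance, as PyTorch does.
    """
    mean = np.mean(x, axis=axis, keepdims=True)
    var = np.var(x, axis=axis, keepdims=True)  # ddof=0, i.e. biased variance
    out = (x - mean) / np.sqrt(var + epsilon)
    if scale:
        out = out * gamma
    if center:
        out = out + beta
    return out

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4)).astype("float32")

# Default nn.LayerNorm init (weight = 1, bias = 0): the affine step does nothing,
# so dropping gamma/beta cannot be detected.
ones, zeros = np.ones(4, "float32"), np.zeros(4, "float32")
print(np.allclose(layer_norm(x, ones, zeros),
                  layer_norm(x, ones, zeros, center=False, scale=False)))  # True

# Randomly initialized weight and bias expose the missing center/scale.
gamma = rng.standard_normal(4).astype("float32")
beta = rng.standard_normal(4).astype("float32")
print(np.allclose(layer_norm(x, gamma, beta),
                  layer_norm(x, gamma, beta, center=False, scale=False)))  # False
```

In the same spirit, a frontend test could give the PyTorch module random (non-default) weight and bias before comparing outputs, so that ignoring gamma or beta changes the result.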

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services
