tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From zha...@apache.org
Subject [incubator-tvm] branch master updated: [KERAS]Embedding layer (#5444)
Date Sun, 26 Apr 2020 02:58:10 GMT
This is an automated email from the ASF dual-hosted git repository.

zhaowu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new 1014fef  [KERAS]Embedding layer (#5444)
1014fef is described below

commit 1014fefa54b5f0a359501b6d19ea3b5a52d6dca6
Author: Samuel <siju.samuel@huawei.com>
AuthorDate: Sun Apr 26 08:28:02 2020 +0530

    [KERAS]Embedding layer (#5444)
---
 python/tvm/relay/frontend/keras.py          | 10 +++++++++-
 tests/python/frontend/keras/test_forward.py | 20 +++++++++++++++++++-
 2 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relay/frontend/keras.py b/python/tvm/relay/frontend/keras.py
index bf91bc1..43065be 100644
--- a/python/tvm/relay/frontend/keras.py
+++ b/python/tvm/relay/frontend/keras.py
@@ -207,6 +207,14 @@ def _convert_permute(inexpr, keras_layer, _):
     return _op.transpose(inexpr, axes=(0,) + keras_layer.dims)
 
 
+def _convert_embedding(inexpr, keras_layer, etab):
+    indices = inexpr
+    weightList = keras_layer.get_weights()
+    weight = etab.new_const(weightList[0])
+    out = _op.take(weight, indices.astype('int32'), axis=0)
+
+    return out
+
 def _convert_dense(inexpr, keras_layer, etab):
     weightList = keras_layer.get_weights()
     weight = etab.new_const(weightList[0].transpose([1, 0]))
@@ -893,7 +901,7 @@ _convert_map = {
     'Maximum'                  : _convert_merge,
     'Dot'                      : _convert_merge,
     'Permute'                  : _convert_permute,
-    # 'Embedding'              : _convert_embedding,
+    'Embedding'                : _convert_embedding,
     # 'RepeatVector'           : _convert_repeat_vector,
 
     'InputLayer'               : _default_skip,
diff --git a/tests/python/frontend/keras/test_forward.py b/tests/python/frontend/keras/test_forward.py
index b764137..b4a1816 100644
--- a/tests/python/frontend/keras/test_forward.py
+++ b/tests/python/frontend/keras/test_forward.py
@@ -466,6 +466,24 @@ class TestKeras:
             keras_model = keras.models.Model(data, x)
             verify_keras_frontend(keras_model, layout='NDHWC')
 
+
+    def test_forward_embedding(self, keras):
+        data = keras.layers.Input(shape=(2, 4), dtype="int32")
+        x = keras.layers.Embedding(10, 3)(data)
+        keras_model = keras.models.Model(data, x)
+        verify_keras_frontend(keras_model, need_transpose=False)
+
+        data = keras.layers.Input(shape=(2, 3, 4), dtype="int32")
+        x = keras.layers.Embedding(4, 5)(data)
+        keras_model = keras.models.Model(data, x)
+        verify_keras_frontend(keras_model, need_transpose=False)
+
+        data = keras.layers.Input(shape=(6, 2, 3, 4), dtype="int32")
+        x = keras.layers.Embedding(4, 5)(data)
+        keras_model = keras.models.Model(data, x)
+        verify_keras_frontend(keras_model, need_transpose=False)
+
+
 if __name__ == '__main__':
     for k in [keras, tf_keras]:
         sut = TestKeras()
@@ -497,4 +515,4 @@ if __name__ == '__main__':
         sut.test_forward_pool3d(keras=k)
         sut.test_forward_upsample3d(keras=k)
         sut.test_forward_zero_padding3d(keras=k)
-
+        sut.test_forward_embedding(keras=k)


Mime
View raw message