tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tqc...@apache.org
Subject [incubator-tvm] 01/01: Revert "Conditions updated to cover better user scenarios (#4951)"
Date Tue, 10 Mar 2020 20:58:51 GMT
This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch revert-4951-relay-test
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git

commit 07bac243ef99abc2f3bce835765bf1898b8d0c70
Author: Tianqi Chen <tqchen@users.noreply.github.com>
AuthorDate: Tue Mar 10 13:58:37 2020 -0700

    Revert "Conditions updated to cover better user scenarios (#4951)"
    
    This reverts commit fe74b37ab578e6d3c540b0f6ac187a220ccc028a.
---
 src/relay/ir/alpha_equal.cc                 | 10 ++---
 tests/cpp/relay_pass_alpha_equal.cc         | 67 -----------------------------
 tests/python/relay/test_pass_alpha_equal.py | 32 --------------
 3 files changed, 5 insertions(+), 104 deletions(-)

diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc
index c622599..78688d7 100644
--- a/src/relay/ir/alpha_equal.cc
+++ b/src/relay/ir/alpha_equal.cc
@@ -50,14 +50,14 @@ class AlphaEqualHandler:
    * \return The comparison result.
    */
   bool Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
-    if (!lhs.defined() || !rhs.defined()) return false;
     if (lhs.same_as(rhs)) return true;
-    if (lhs->IsInstance<TypeNode>() || rhs->IsInstance<TypeNode>()) {
-      if (!rhs->IsInstance<TypeNode>() || !lhs->IsInstance<TypeNode>()) return false;
+    if (!lhs.defined() || !rhs.defined()) return false;
+    if (lhs->IsInstance<TypeNode>()) {
+      if (!rhs->IsInstance<TypeNode>()) return false;
       return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs));
     }
-    if (lhs->IsInstance<ExprNode>() || rhs->IsInstance<ExprNode>()) {
-      if (!rhs->IsInstance<ExprNode>() || !lhs->IsInstance<ExprNode>()) return false;
+    if (lhs->IsInstance<ExprNode>()) {
+      if (!rhs->IsInstance<ExprNode>()) return false;
       return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
     }
     if (const auto lhsm = lhs.as<IRModuleNode>()) {
diff --git a/tests/cpp/relay_pass_alpha_equal.cc b/tests/cpp/relay_pass_alpha_equal.cc
deleted file mode 100644
index 0207fca..0000000
--- a/tests/cpp/relay_pass_alpha_equal.cc
+++ /dev/null
@@ -1,67 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#include <gtest/gtest.h>
-#include <tvm/te/operation.h>
-#include <tvm/relay/expr.h>
-#include <tvm/relay/type.h>
-#include <tvm/relay/analysis.h>
-#include <tvm/relay/transform.h>
-
-using namespace tvm;
-
-class TestAlphaEquals {
-  runtime::PackedFunc *_packed_func;
- public:
-  TestAlphaEquals(const char* func_name) {
-    _packed_func = new runtime::PackedFunc();
-    TVMFuncGetGlobal(func_name, reinterpret_cast<TVMFunctionHandle*>(&_packed_func));
-  }
-
-  void UpdatePackedFunc(const char* func_name) {
-    TVMFuncGetGlobal(func_name, reinterpret_cast<TVMFunctionHandle*>(&_packed_func));
-  }
-
-  bool operator()(ObjectRef input_1, ObjectRef input_2) {
-    TVMRetValue rv;
-    std::vector<TVMValue> values(2);
-    std::vector<int> codes(2);
-    runtime::TVMArgsSetter setter(values.data(), codes.data());
-    setter(0, input_1);
-    setter(1, input_2);
-    _packed_func->CallPacked(TVMArgs(values.data(), codes.data(), 2), &rv);
-    return bool(rv);
-  };
-
-};
-
-TEST(Relay, AlphaTestEmptyTypeNodes) {
-  auto x = TypeVar("x", kTypeData);
-  auto y = TypeVar();
-  EXPECT_FALSE(relay::AlphaEqual(x, y));
-
-  TestAlphaEquals test_equals("relay._make._alpha_equal");
-  EXPECT_FALSE(test_equals(x, y));
-}
-
-int main(int argc, char ** argv) {
-  testing::InitGoogleTest(&argc, argv);
-  testing::FLAGS_gtest_death_test_style = "threadsafe";
-  return RUN_ALL_TESTS();
-}
diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py
index ec026be..7e34f48 100644
--- a/tests/python/relay/test_pass_alpha_equal.py
+++ b/tests/python/relay/test_pass_alpha_equal.py
@@ -28,15 +28,6 @@ def alpha_equal(x, y):
     """
     return analysis.alpha_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y)
 
-def alpha_equal_commutative(x, y):
-    """
-    Check for commutative property of equality
-    """
-    xy = analysis.alpha_equal(x, y)
-    yx = analysis.alpha_equal(y, x)
-    assert xy == yx
-    return xy
-
 def test_tensor_type_alpha_equal():
     t1 = relay.TensorType((3, 4), "float32")
     t2 = relay.TensorType((3, 4), "float32")
@@ -228,26 +219,6 @@ def test_constant_alpha_equal():
     assert not alpha_equal(x, y)
     assert alpha_equal(x, relay.const(1))
 
-def test_type_node_alpha_equal():
-    v1 = relay.TypeVar('v1', 6)
-    v2 = relay.TypeVar('v2', 6)
-    assert not alpha_equal(v1, v2)
-
-    v1 = relay.TypeVar('v1', 0)
-    v2 = relay.TypeVar('v2', 6)
-    assert not alpha_equal(v1, v2)
-
-    assert alpha_equal_commutative(v1, v1)
-
-def test_type_node_incompatible_alpha_equal():
-    v1 = relay.TypeVar('v1', 6)
-    v2 = relay.Var("v2")
-    assert not alpha_equal_commutative(v1, v2)
-
-def test_expr_node_incompatible_alpha_equal():
-    v1 = relay.Var("v1")
-    v2 = relay.PatternVar(relay.Var("v2"))
-    assert not alpha_equal_commutative(v1, v2)
 
 def test_var_alpha_equal():
     v1 = relay.Var("v1")
@@ -705,9 +676,6 @@ if __name__ == "__main__":
     test_tensor_type_alpha_equal()
     test_incomplete_type_alpha_equal()
     test_constant_alpha_equal()
-    test_type_node_alpha_equal()
-    test_type_node_incompatible_alpha_equal()
-    test_expr_node_incompatible_alpha_equal()
     test_func_type_alpha_equal()
     test_tuple_type_alpha_equal()
     test_type_relation_alpha_equal()


Mime
View raw message