tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From tqc...@apache.org
Subject [incubator-tvm] branch master updated: [TIR][PASS] Remove legacy HoistIfThenElse (#5944)
Date Sat, 27 Jun 2020 21:56:26 GMT
This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new e99e116  [TIR][PASS] Remove legacy HoistIfThenElse (#5944)
e99e116 is described below

commit e99e11657191fd230f7d109a77ec48c1643a9f25
Author: Tianqi Chen <tqchen@users.noreply.github.com>
AuthorDate: Sat Jun 27 14:56:13 2020 -0700

    [TIR][PASS] Remove legacy HoistIfThenElse (#5944)
    
    This pass has not been migrated to the new transform API,
    and contains potential bugs per https://github.com/apache/incubator-tvm/issues/5559.
    Given that it is not being actively used, this PR remove this pass
    from the collection.
    
    Followup PRs are more than welcomed to land a better version that
    conforms with the new transform API.
---
 src/tir/pass/hoist_if_then_else.cc              | 404 ------------------------
 tests/python/unittest/test_tir_pass_hoist_if.py | 186 -----------
 2 files changed, 590 deletions(-)

diff --git a/src/tir/pass/hoist_if_then_else.cc b/src/tir/pass/hoist_if_then_else.cc
deleted file mode 100644
index d1e24b9..0000000
--- a/src/tir/pass/hoist_if_then_else.cc
+++ /dev/null
@@ -1,404 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file hoist_if_then_else.cc
- */
-#include <tvm/arith/analyzer.h>
-#include <tvm/runtime/registry.h>
-#include <tvm/tir/expr.h>
-#include <tvm/tir/stmt_functor.h>
-
-#include <queue>
-#include <unordered_map>
-#include <unordered_set>
-
-#include "../../arith/interval_set.h"
-#include "../../runtime/thread_storage_scope.h"
-
-namespace tvm {
-namespace tir {
-
-using HoistMap = std::unordered_map<const Object*, std::vector<Stmt>>;
-using VarMap = std::unordered_map<const Object*, std::unordered_set<const Object*>>;
-
-/*
- * This pass tries to hoist IfThenElse stmt out of For loop if condition is loop invariant.
- * For example, given the following block:
- * for (i = 0; i < 3; i++)
- *    for (j = 0; j < 4; j++)
- *        for (k = 0; k < 5; k++)
- *            if (likely(i*2 < 4))
- *                A[3*i+2j+k] = B[7*i+3j+k]
- *
- * We first detect all IfThenElse stmt and find the corresponding loop invariant For stmt.
- * Then we hoist IfThenElse stmt by one For stmt each step:
- *
- * Step 1:
- * for (i = 0; i < 3; i++)
- *     for (j = 0; j < 4; j++)
- *         if (likely(i*2 < 4))
- *             for (k = 0; k < 5; k++)
- *                 A[3*i+2j+k] = B[7*i+3j+k]
- *
- * Step 2:
- * for (i = 0; i < 3; i++)
- *     if (likely(i*2 < 4))
- *         for (j = 0; j < 4; j++)
- *             for (k = 0; k < 5; k++)
- *                 A[3*i+2j+k] = B[7*i+3j+k]
- *
- * In this pass, we only continue detecting possible hoisting chance when visiting For,
- * IfThenElse or AttrStmt Node. For example, for the following block:
- * for (i = 0; i < 3; i++)
- *    for (j = 0; j < 4; j++)
- *        A[i + j] = A[i + j] - 1
- *        for (k = 0; k < 5; k++)
- *            if (likely(i*2 < 4))
- *                A[3*i+2j+k] = B[7*i+3j+k]
- *
- * Only the For with k variable will be considered and the resulting stmt would be:
- * for (i = 0; i < 3; i++)
- *    for (j = 0; j < 4; j++)
- *        A[i + j] = A[i + j] - 1
- *        if (likely(i*2 < 4))
- *            for (k = 0; k < 5; k++)
- *                A[3*i+2j+k] = B[7*i+3j+k]
- *
- * This pass doesn't do hoisting for consecutive IfThenElse stmt. The following
- * block won't be optimized:
- * for (i = 0; i < 3; i++)
- *    for (j = 0; j < 4; j++)
- *        for (k = 0; k < 5; k++)
- *            if (likely(i*2 < 4))
- *                A[3*i+2j+k] = B[7*i+3j+k]
- *            if (likely(j > 2))
- *                A[i+j+k] = B[i+j+k]
- *
- */
-class IfThenElseHoist {
- public:
-  Stmt VisitAndMutate(const Stmt& stmt) {
-    SelectCandidates(stmt);
-    LocateTopFor();
-    return PostOrderMutate(stmt);
-  }
-
- private:
-  void SelectCandidates(const Stmt& stmt);
-  void LocateTopFor();
-  Stmt PostOrderMutate(const Stmt& stmt);
-  size_t GetUpdatedFor(const Stmt& for_stmt, const Stmt& if_stmt);
-  Stmt HoistIf(const Stmt& if_stmt);
-
-  // Map of all For nodes to all child IfThenElse nodes.
-  HoistMap for2if_map_;
-  // Map of all IfThenElse nodes to all For nodes which are loop invariant.
-  HoistMap if2for_map_;
-  // Map of highest loop invariant For to child IfThenElse.
-  HoistMap top_for_var_map_;
-  // Map of original For to list of update For nodes.
-  HoistMap for_tracking_map_;
-  // Map of all IfThenElse nodes to condition variable nodes.
-  VarMap cond_var_map_;
-  // List of For nodes added in post order DFS visiting.
-  std::vector<Stmt> ordered_for_list_;
-};
-
-// Check whether a given IfThenElse stmt is the first one appearing
-// in a For stmt.
-bool is_first_if(const Stmt& for_stmt, const Stmt& if_stmt) {
-  std::vector<const Object*> if_node_list;
-  const ForNode* for_node = for_stmt.as<ForNode>();
-  CHECK(for_node);
-  CHECK(if_stmt.as<IfThenElseNode>());
-
-  PostOrderVisit(for_node->body, [&](const ObjectRef& node) {
-    if (node.as<IfThenElseNode>()) {
-      if_node_list.push_back(node.get());
-    }
-  });
-  return if_node_list.empty() ? false : if_stmt.get() == if_node_list.back();
-}
-
-// Update upper level For node when current For node is modified.
-// With this function we only need to visit and mutate top level For node
-// in the main VisitAndMutate function.
-Stmt update_for(const Stmt& parent_for_stmt, const Stmt& new_if_stmt) {
-  const Object* top_for_node;
-  const ForNode* parent_for_node = parent_for_stmt.as<ForNode>();
-  CHECK(parent_for_node);
-  CHECK(new_if_stmt.as<IfThenElseNode>());
-
-  PostOrderVisit(parent_for_node->body, [&](const ObjectRef& node) {
-    if (node.as<ForNode>()) {
-      top_for_node = node.get();
-    }
-  });
-
-  PackedFunc replace_target_for = PackedFunc([&](TVMArgs args, TVMRetValue* ret) {
-    const ObjectRef& current_for = args[0];
-    if (current_for.get() == top_for_node) {
-      *ret = new_if_stmt;
-    }
-  });
-
-  return IRTransform(parent_for_stmt, nullptr, replace_target_for, Array<String>{"tir.For"});
-}
-
-// Remove IfThenElse node from a For node.
-// A pair of For nodes will be generated.
-std::pair<Stmt, Stmt> RemoveIf(const Stmt& for_stmt, const Stmt& if_stmt) {
-  Stmt then_for;
-  Stmt else_for;
-  CHECK(if_stmt.as<IfThenElseNode>());
-
-  PackedFunc replace_then_case = PackedFunc([&](TVMArgs args, TVMRetValue* ret) {
-    const ObjectRef& node = args[0];
-    if (node == if_stmt) {
-      *ret = node.as<IfThenElseNode>()->then_case;
-    }
-  });
-
-  PackedFunc replace_else_case = PackedFunc([&](TVMArgs args, TVMRetValue* ret) {
-    const ObjectRef& node = args[0];
-    if (node == if_stmt) {
-      *ret = node.as<IfThenElseNode>()->else_case;
-    }
-  });
-
-  then_for = IRTransform(for_stmt, nullptr, replace_then_case, Array<String>{"tir.IfThenElse"});
-  if (if_stmt.as<IfThenElseNode>()->else_case.defined()) {
-    else_for = IRTransform(for_stmt, nullptr, replace_else_case, Array<String>{"tir.IfThenElse"});
-  }
-
-  return std::make_pair(then_for, else_for);
-}
-
-// Locate all For nodes and capture child IfThenElse nodes.
-void IfThenElseHoist::SelectCandidates(const Stmt& stmt) {
-  PostOrderVisit(stmt, [&](const ObjectRef& node) {
-    const ForNode* for_node = node.as<ForNode>();
-    if (!for_node) return;
-
-    std::queue<Stmt> tracker;
-    tracker.push(for_node->body);
-    Stmt for_stmt = Downcast<Stmt, ObjectRef>(node);
-    for2if_map_.insert({for_stmt.get(), std::vector<Stmt>()});
-    while (!tracker.empty()) {
-      Stmt head = tracker.front();
-      tracker.pop();
-      if (head->IsInstance<ForNode>()) {
-        for (const auto& if_stmt : for2if_map_.at(head.get())) {
-          for2if_map_[for_stmt.get()].push_back(if_stmt);
-        }
-      } else if (head->IsInstance<AttrStmtNode>()) {
-        const AttrStmtNode* attr_node = head.as<AttrStmtNode>();
-        tracker.push(attr_node->body);
-      } else if (head->IsInstance<IfThenElseNode>()) {
-        for2if_map_[for_stmt.get()].push_back(head);
-        const IfThenElseNode* if_node = head.as<IfThenElseNode>();
-        tracker.push(if_node->then_case);
-        if (if_node->else_case.defined()) {
-          tracker.push(if_node->else_case);
-        }
-
-        // Record condition variables.
-        if (!cond_var_map_.count(head.get())) {
-          std::unordered_set<const Object*> new_var_set;
-          cond_var_map_.insert({head.get(), new_var_set});
-          PostOrderVisit(if_node->condition, [&](const ObjectRef& cond_node) {
-            if (cond_node.as<VarNode>()) {
-              cond_var_map_[head.get()].insert(cond_node.get());
-            }
-          });
-        }
-      } else {
-        continue;
-      }
-    }
-    ordered_for_list_.emplace_back(Downcast<Stmt, ObjectRef>(node));
-  });
-}
-
-// For each IfThenElse node, find the highest For node which
-// meets loop invariant condition.
-void IfThenElseHoist::LocateTopFor() {
-  std::unordered_map<const Object*, Stmt> if_position_map;
-  std::unordered_set<const Object*> top_for_var_set;
-
-  // Create IfThenElse -> For map.
-  for (const Stmt& for_stmt : ordered_for_list_) {
-    std::vector<Stmt> if_list = for2if_map_[for_stmt.get()];
-    const ForNode* for_node = for_stmt.as<ForNode>();
-    CHECK(for_node);
-    top_for_var_map_.insert({for_node->loop_var.get(), if_list});
-    for (const Stmt& if_stmt : if_list) {
-      const Object* if_node = if_stmt.get();
-      if2for_map_[if_node].push_back(for_stmt);
-    }
-  }
-
-  // Locate the highest For node which is loop invariant.
-  for (const auto& item : if2for_map_) {
-    Stmt top_for;
-    const Object* if_stmt = item.first;
-    std::vector<Stmt> for_list = item.second;
-    for (size_t i = 0; i < for_list.size(); ++i) {
-      const Stmt& for_stmt = for_list.at(i);
-      const ForNode* for_node = for_stmt.as<ForNode>();
-      CHECK(for_node);
-      std::vector<Stmt> new_for_list{for_stmt};
-      for_tracking_map_.insert({for_stmt.get(), new_for_list});
-      if (cond_var_map_[if_stmt].count(for_node->loop_var.get())) {
-        std::vector<Stmt> updated_for_list(for_list.begin(), for_list.begin() + i);
-        if2for_map_[if_stmt] = updated_for_list;
-        break;
-      } else {
-        top_for = for_stmt;
-      }
-    }
-    if (top_for.as<ForNode>()) {
-      if_position_map.insert({if_stmt, top_for});
-    }
-  }
-
-  for (const auto& item : if_position_map) {
-    top_for_var_set.insert(item.second.as<ForNode>()->loop_var.get());
-  }
-
-  std::vector<const Object*> removed_for_var_list;
-  for (const auto& item : top_for_var_map_) {
-    const Object* top_for_var = item.first;
-    std::vector<Stmt> if_list = item.second;
-    if (!top_for_var_set.count(top_for_var)) {
-      removed_for_var_list.push_back(top_for_var);
-    } else {
-      std::vector<Stmt> actual_if_list;
-      for (const Stmt& if_stmt : if_list) {
-        if (if_position_map.count(if_stmt.get())) {
-          actual_if_list.push_back(if_stmt);
-        }
-      }
-      top_for_var_map_[top_for_var] = actual_if_list;
-    }
-  }
-  for (const Object* top_for_var : removed_for_var_list) {
-    top_for_var_map_.erase(top_for_var);
-  }
-}
-
-// When we try to mutate a For node, some child For nodes can have already
-// been mutated. This function is to get the updated For node and further
-// hoisting can be done based on this new node.
-// We keep all For nodes tracing in for_tracking_map_. When we get a
-// hoisted IfThenElse, we match it with tracing For nodes to pick
-// the updated one.
-size_t IfThenElseHoist::GetUpdatedFor(const Stmt& for_stmt, const Stmt& if_stmt)
{
-  std::vector<Stmt> tracked_for_list = for_tracking_map_[for_stmt.get()];
-  size_t updated_for_idx = 0;
-  for (size_t i = 0; i < tracked_for_list.size(); ++i) {
-    const Stmt& current_for = tracked_for_list.at(tracked_for_list.size() - 1 - i);
-    if (is_first_if(current_for, if_stmt)) {
-      updated_for_idx = tracked_for_list.size() - 1 - i;
-      break;
-    }
-  }
-  return updated_for_idx;
-}
-
-// Hoist an IfThenElse node as high as possible.
-// This function iterates on all candidate For nodes. For each For node,
-// it first removes IfThenElse nodes. Then it generates a new IfThenElse
-// node using mutated For nodes.
-Stmt IfThenElseHoist::HoistIf(const Stmt& if_stmt) {
-  Stmt new_if = if_stmt;
-
-  for (size_t i = 0; i < if2for_map_[if_stmt.get()].size(); ++i) {
-    const Stmt& for_stmt = if2for_map_[if_stmt.get()].at(i);
-    size_t updated_for_idx = GetUpdatedFor(for_stmt, new_if);
-    const Stmt& updated_for_node = for_tracking_map_[for_stmt.get()].at(updated_for_idx);
-    auto generated_for_pair = RemoveIf(updated_for_node, new_if);
-    const Stmt& then_for = generated_for_pair.first;
-    const Stmt& else_for = generated_for_pair.second;
-
-    for_tracking_map_[for_stmt.get()].at(updated_for_idx) = then_for;
-
-    if (else_for.get()) {
-      for_tracking_map_[for_stmt.get()].push_back(else_for);
-    }
-
-    const IfThenElseNode* new_if_node = new_if.as<IfThenElseNode>();
-    CHECK(new_if_node);
-    new_if = IfThenElse(new_if_node->condition, then_for, else_for);
-    if (i < if2for_map_[if_stmt.get()].size() - 1) {
-      const Stmt& original_next_for = if2for_map_[if_stmt.get()].at(i + 1);
-      const Stmt& actual_next_for = for_tracking_map_[original_next_for.get()].at(updated_for_idx);
-      Stmt update_for_stmt = update_for(actual_next_for, new_if);
-
-      for_tracking_map_[original_next_for.get()].at(updated_for_idx) = update_for_stmt;
-    }
-  }
-  return new_if;
-}
-
-// Mutate For nodes in post order DFS manner.
-Stmt IfThenElseHoist::PostOrderMutate(const Stmt& stmt) {
-  PackedFunc replace_top_for = PackedFunc([&](TVMArgs args, TVMRetValue* ret) {
-    const ObjectRef& current_for = args[0];
-    const ForNode* for_node = current_for.as<ForNode>();
-    if (!for_node) return;
-
-    if (top_for_var_map_.count(for_node->loop_var.get())) {
-      std::vector<Stmt> new_if_list;
-      for (const Stmt& if_stmt : top_for_var_map_[for_node->loop_var.get()]) {
-        new_if_list.emplace_back(HoistIf(if_stmt));
-      }
-
-      const IfThenElseNode* next_if_node;
-      const IfThenElseNode* current_if_node = new_if_list.back().as<IfThenElseNode>();
-      Stmt new_for = Stmt();
-      for (size_t i = new_if_list.size() - 1; i > 0; --i) {
-        CHECK(current_if_node);
-        const Stmt current_if_stmt = IfThenElse(
-            current_if_node->condition, current_if_node->then_case, current_if_node->else_case);
-        next_if_node = new_if_list[i - 1].as<IfThenElseNode>();
-        CHECK(next_if_node);
-        new_for = IfThenElse(next_if_node->condition, current_if_stmt, next_if_node->else_case);
-        current_if_node = new_for.as<IfThenElseNode>();
-      }
-
-      if (!new_for.get()) {
-        const IfThenElseNode* first_if_node = new_if_list[0].as<IfThenElseNode>();
-        CHECK(first_if_node);
-        new_for = IfThenElse(first_if_node->condition, first_if_node->then_case,
-                             first_if_node->else_case);
-      }
-      *ret = new_for;
-    }
-  });
-  return IRTransform(stmt, nullptr, replace_top_for, Array<String>{"tir.For"});
-}
-
-Stmt HoistIfThenElse(Stmt stmt) { return IfThenElseHoist().VisitAndMutate(stmt); }
-
-TVM_REGISTER_GLOBAL("testing.HoistIfThenElse").set_body_typed(HoistIfThenElse);
-
-}  // namespace tir
-}  // namespace tvm
diff --git a/tests/python/unittest/test_tir_pass_hoist_if.py b/tests/python/unittest/test_tir_pass_hoist_if.py
deleted file mode 100644
index 80e93a7..0000000
--- a/tests/python/unittest/test_tir_pass_hoist_if.py
+++ /dev/null
@@ -1,186 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import tvm
-from tvm import te
-
-
-var_list = []
-
-def verify_structure(stmt, expected_struct):
-    node_dict = {}
-    struct = {}
-    def _extract_vars(op):
-        global var_list
-        if isinstance(op, tvm.tir.Var):
-            var_list.append(op.name)
-
-    def _visit(op):
-        key = op
-        if isinstance(op, tvm.tir.IfThenElse):
-            global var_list
-            tvm.tir.stmt_functor.post_order_visit(op.condition, _extract_vars)
-            val = [(op.then_case, op.else_case), ("tir.IfThenElse", tuple(var_list))]
-            var_list.clear()
-        elif isinstance(op, tvm.tir.For):
-            val = [(op.body,), ("tir.For", op.loop_var.name)]
-        elif isinstance(op, tvm.tir.AttrStmt):
-            val = [(op.body,), ("tir.AttrStmt", op.attr_key, int(op.value))]
-        else:
-            return
-        node_dict[key] = val
-
-    tvm.tir.stmt_functor.post_order_visit(stmt, _visit)
-    for key, val in node_dict.items():
-        struct[val[1]] = tuple(node_dict[child][1] if child in node_dict
-                               else None for child in val[0])
-
-    assert struct == expected_struct, "Structure mismatch: expect %s but got %s" \
-                                      % (expected_struct, struct)
-    var_list.clear()
-
-def test_basic():
-    ib = tvm.tir.ir_builder.create()
-    l = te.var('l')
-    m = te.var('m')
-    n = te.var('n')
-
-    with ib.for_range(0, l, "i") as i:
-        with ib.for_range(0, m, "j") as j:
-            with ib.for_range(0, n, "k") as k:
-                with ib.if_scope(ib.likely(i < 2)):
-                    ib.emit(tvm.tir.Evaluate(m))
-                with ib.else_scope():
-                    ib.emit(tvm.tir.Evaluate(n))
-
-    stmt = ib.get()
-    new_stmt = tvm.testing.HoistIfThenElse(stmt)
-    expected_struct = {('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'k'),),
-                       ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), ('tir.For', 'j')),
-                       ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
-    verify_structure(new_stmt, expected_struct)
-
-def test_no_else():
-    ib = tvm.tir.ir_builder.create()
-    l = te.var('l')
-    m = te.var('m')
-    n = te.var('n')
-
-    with ib.for_range(0, l, "i") as i:
-        with ib.for_range(0, m, "j") as j:
-            with ib.for_range(0, n, "k") as k:
-                with ib.if_scope(ib.likely(i < 2)):
-                    ib.emit(tvm.tir.Evaluate(m))
-
-    stmt = ib.get()
-    new_stmt = tvm.testing.HoistIfThenElse(stmt)
-    expected_struct = {('tir.For', 'k'): (None,), ('tir.For', 'j'): (('tir.For', 'k'),),
-                       ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None),
-                       ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
-    verify_structure(new_stmt, expected_struct)
-
-def test_attr_stmt():
-    ib = tvm.tir.ir_builder.create()
-    dshape = (32, 64)
-    data = ib.pointer("float32", name="data")
-    l = te.var('l')
-    m = te.var('m')
-    n = te.var('n')
-
-    tx = te.thread_axis("threadIdx.x")
-    bx = te.thread_axis("blockIdx.x")
-    ib.scope_attr(tx, "thread_extent", dshape[0])
-    ib.scope_attr(bx, "thread_extent", dshape[1])
-    with ib.for_range(0, l, "i") as i:
-        with ib.for_range(0, m, "j") as j:
-            with ib.for_range(0, n, "k") as k:
-                with ib.if_scope(tvm.tir.any(i < 4, j >= 8)):
-                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 0.5
-                with ib.else_scope():
-                    data[bx * j + tx * j * k] = data[bx * j + tx * j * k]  + 1.0
-
-    stmt = ib.get()
-    new_stmt = tvm.testing.HoistIfThenElse(stmt)
-    expected_struct = {('tir.For', 'k'): (None,), ('tir.IfThenElse', ('i', 'j')): (('tir.For',
'k'), ('tir.For', 'k')),
-                       ('tir.For', 'j'): (('tir.IfThenElse', ('i', 'j')),), ('tir.For', 'i'):
(('tir.For', 'j'),),
-                       ('tir.AttrStmt', 'thread_extent', 64): (('tir.For', 'i'),),
-                       ('tir.AttrStmt', 'thread_extent', 32): (('tir.AttrStmt', 'thread_extent',
64),)}
-    verify_structure(new_stmt, expected_struct)
-
-def test_nested_for():
-    ib = tvm.tir.ir_builder.create()
-    data = ib.pointer("float32", name="data")
-
-
-    with ib.for_range(0, 5, "i") as i:
-        with ib.for_range(0, 10, "j") as j:
-            with ib.if_scope(i >= 3):
-                data[i * 3 + j] = data[i * 3 + j] + 0.5
-                with ib.for_range(0, 15, "k") as k:
-                    with ib.for_range(0, 20, "l") as l:
-                        with ib.if_scope(tvm.tir.any(i < 4, j >= 8)):
-                            data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 2
-                        with ib.else_scope():
-                            data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5
-
-    stmt = ib.get()
-    new_stmt = tvm.testing.HoistIfThenElse(stmt)
-    expected_struct = {('tir.IfThenElse', ('i', 'j')): (None, None), ('tir.For', 'l'): (('tir.IfThenElse',
('i', 'j')),),
-                       ('tir.For', 'k'): (('tir.For', 'l'),), ('tir.For', 'j'): (None,),
('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None),
-                       ('tir.For', 'i'): (('tir.IfThenElse', ('i',)),)}
-    verify_structure(new_stmt, expected_struct)
-
-def test_if_block():
-    ib = tvm.tir.ir_builder.create()
-    data = ib.pointer("float32", name="data")
-    n = te.var("n")
-
-
-    with ib.for_range(0, 5, "i") as i:
-        with ib.for_range(0, 10, "j") as j:
-            with ib.if_scope(i >= 3):
-                data[i * 3 + j] = data[i * 3 + j] + 0.5
-                with ib.for_range(0, 15, "k") as k:
-                    with ib.for_range(0, 20, "l") as l:
-                        with ib.if_scope(tvm.tir.any(i < 4, j >= 8)):
-                            data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 2
-                        with ib.else_scope():
-                            data[i * 3 + j + k + l] = data[i * 3 + j + k + l] * 1.5
-                        with ib.if_scope(j <5):
-                            data[i * 3 + j + k + l] = data[i * 3 + j + k + l] - 1
-
-
-    with ib.for_range(0, 5, "i") as i:
-        with ib.for_range(0, 10, "j") as j:
-                with ib.for_range(0, 15, "k") as k:
-                    with ib.if_scope(n >= 3):
-                        data[i * 3 + j + k] = data[i * 3 + j + k] + 0.6
-
-    stmt = ib.get()
-    new_stmt = tvm.testing.HoistIfThenElse(stmt)
-    expected_struct = {('tir.IfThenElse', ('i', 'j')): (None, None), ('tir.IfThenElse', ('j',)):
(None, None),
-                       ('tir.For', 'l'): (None,), ('tir.For', 'k'): (None,), ('tir.For',
'j'): (('tir.For', 'j'),),
-                       ('tir.IfThenElse', ('i',)): (('tir.For', 'j'), None), ('tir.For',
'i'): (('tir.IfThenElse', ('i',)),),
-                       ('tir.IfThenElse', ('n',)): (('tir.For', 'j'), None)}
-    verify_structure(new_stmt, expected_struct)
-
-
-if __name__ == "__main__":
-    test_basic()
-    test_no_else()
-    test_attr_stmt()
-    test_nested_for()
-    test_if_block()


Mime
View raw message