tvm-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] yzhliu commented on a change in pull request #5121: [TE] reverse-mode autodiff without any optimization
Date Tue, 24 Mar 2020 06:59:23 GMT
yzhliu commented on a change in pull request #5121: [TE] reverse-mode autodiff without any
optimization
URL: https://github.com/apache/incubator-tvm/pull/5121#discussion_r396935671
 
 

 ##########
 File path: src/te/autodiff/jacobian.cc
 ##########
 @@ -0,0 +1,381 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file jacobian.cc
+ * \brief Calculate Jacobian of two tensors dY/dX.
+ *        X must be direct input tensor of Y.
+ *        The result Jacobian shape will be (Y.shape, X.shape)
+ *        The algorithm was initially implemented by Sergei Grechanik (sgrechanik-h)
+ *        in [Automatic differentiation for tensor expressions](#2498)
+ *        and [Zero elimination](#2634)
+ */
+#include <tvm/te/autodiff.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/stmt_functor.h>
+#include <topi/transform.h>
+#include <memory>
+#include "ad_util.h"
+
+namespace tvm {
+namespace te {
+
+#define NOT_IMPLEMENTED \
+  { LOG(FATAL) << "Derivative of this expr is not implemented: " << GetRef<PrimExpr>(op);
throw; }
+
+/*! \brief Differentiate an expression wrt a variable or a tensor element */
+class JacobianMutator : public ExprMutator {
+ public:
+  /*!
+   * \brief Differentiate wrt `input(indices)`.
+   * \param input The input tensor.
+   * \param indices The indices of the element with respect to which to differentiate.
+   */
+  explicit JacobianMutator(Tensor input, Array<PrimExpr> indices)
+    : input_(input), indices_(indices) {}
+  /*!
+   * \brief Differentiate wrt the input variable.
+   * \param input The input variable.
+   */
+  explicit JacobianMutator(Var input) : input_var_(input) {}
+
+  PrimExpr Mutate(PrimExpr e) {
+    if (e.dtype().is_int() || e.dtype().is_uint()) {
+      LOG(WARNING) << "For now we assume that the derivative of any integer expression
is always 0."
+                   << " e = " << e;
+      return make_zero(e.dtype());
+    } else {
+      return ExprMutator::VisitExpr(e);
+    }
+  }
+
+  PrimExpr VisitExpr_(const VarNode* op) {
+    if (input_var_.get() && input_var_.get() == op && op->dtype.is_float())
{
+      return FloatImm(op->dtype, 1.0);
+    } else {
+      return make_zero(op->dtype);
+    }
+  }
+
+  PrimExpr VisitExpr_(const LoadNode* op) NOT_IMPLEMENTED
+  PrimExpr VisitExpr_(const LetNode* op) NOT_IMPLEMENTED
+
+  PrimExpr VisitExpr_(const CallNode* op) {
+    PrimExpr expr = GetRef<PrimExpr>(op);
+    if (op->call_type == CallNode::CallType::Halide) {
+      if (input_.get() && op->func.same_as(input_->op) &&
+          op->value_index == input_->value_index) {
+        // Tensor(indices)
+        CHECK_EQ(indices_.size(), op->args.size());
+        PrimExpr condition = const_true();
+        for (size_t i = 0; i < input_.ndim(); ++i) {
+          condition = AndNode::make(condition, EQNode::make(indices_[i], op->args[i]));
+        }
+        return CastNode::make(op->dtype, condition);
+      } else {
+        return make_zero(op->dtype);
+      }
+    } else if (op->call_type == CallNode::CallType::PureIntrinsic) {
+      static std::unordered_set<std::string> piecewise_const = {"floor", "ceil", "trunc",
"round"};
+      if (op->name == "exp") {
+        return MulNode::make(Mutate(op->args[0]), expr);
+      } else if (op->name == "log") {
+        return DivNode::make(Mutate(op->args[0]), op->args[0]);
+      } else if (op->name == "sigmoid") {
+        return MulNode::make(Mutate(op->args[0]),
+                             MulNode::make(expr, SubNode::make(FloatImm(expr.dtype(), 1.0),
expr)));
+      } else if (op->name == "sqrt") {
+        return DivNode::make(Mutate(op->args[0]),
+                             MulNode::make(expr, FloatImm(expr.dtype(), 2.0)));
+      } else if (op->name == "tanh") {
+        return MulNode::make(Mutate(op->args[0]),
+                             SubNode::make(FloatImm(expr.dtype(), 1.0), MulNode::make(expr,
expr)));
+      } else if (op->name == "pow") {
+        auto x = op->args[0], y = op->args[1];
+        return expr * (Mutate(y)*log(x) + Mutate(x)*y/x);
+      } else if (op->name == "fabs") {
+        auto type = op->args[0].dtype();
+        return MulNode::make(Mutate(op->args[0]),
+                             SelectNode::make(GENode::make(op->args[0], make_zero(type)),
+                                              FloatImm(type, 1.0), FloatImm(type, -1.0)));
+      } else if (op->name == intrinsic::tvm_if_then_else) {
+        Array<PrimExpr> new_args = {op->args[0],
+                                    Mutate(op->args[1]),
+                                    Mutate(op->args[2])};
+        return CallNode::make(op->dtype, op->name, new_args,
+                              op->call_type, op->func, op->value_index);
+      } else if (piecewise_const.count(op->name)) {
+        return FloatImm(expr.dtype(), 0.0);
+      } else {
+        throw dmlc::Error("Derivative of this intrinsic is not implemented: " + op->name);
+      }
+    }
+    NOT_IMPLEMENTED
+  }
+
+  PrimExpr VisitExpr_(const AddNode* op) {
+    return AddNode::make(Mutate(op->a), Mutate(op->b));
+  }
+
+  PrimExpr VisitExpr_(const SubNode* op) {
+    return SubNode::make(Mutate(op->a), Mutate(op->b));
+  }
+
+  PrimExpr VisitExpr_(const MulNode* op) {
+    return AddNode::make(
+        MulNode::make(Mutate(op->a), op->b),
+        MulNode::make(op->a, Mutate(op->b)));
+  }
+
+  PrimExpr VisitExpr_(const DivNode* op) {
+    return DivNode::make(
+        SubNode::make(
+            MulNode::make(Mutate(op->a), op->b),
+            MulNode::make(op->a, Mutate(op->b))),
+        MulNode::make(op->b, op->b));
+  }
+
+  PrimExpr VisitExpr_(const ModNode* op) NOT_IMPLEMENTED
+
+  PrimExpr VisitExpr_(const FloorDivNode* op) {
+    return FloorDivNode::make(
+        SubNode::make(
+            MulNode::make(Mutate(op->a), op->b),
+            MulNode::make(op->a, Mutate(op->b))),
+        MulNode::make(op->b, op->b));
+  }
+
+  PrimExpr VisitExpr_(const FloorModNode* op) NOT_IMPLEMENTED
+
+  PrimExpr VisitExpr_(const MinNode* op) {
+    return SelectNode::make(LENode::make(op->a, op->b),
+        Mutate(op->a), Mutate(op->b));
+  }
+
+  PrimExpr VisitExpr_(const MaxNode* op) {
+    return SelectNode::make(GENode::make(op->a, op->b),
+        Mutate(op->a), Mutate(op->b));
+  }
+
+  PrimExpr VisitExpr_(const EQNode* op) NOT_IMPLEMENTED
+  PrimExpr VisitExpr_(const NENode* op) NOT_IMPLEMENTED
+  PrimExpr VisitExpr_(const LTNode* op) NOT_IMPLEMENTED
+  PrimExpr VisitExpr_(const LENode* op) NOT_IMPLEMENTED
+  PrimExpr VisitExpr_(const GTNode* op) NOT_IMPLEMENTED
+  PrimExpr VisitExpr_(const GENode* op) NOT_IMPLEMENTED
+  PrimExpr VisitExpr_(const AndNode* op) NOT_IMPLEMENTED
+  PrimExpr VisitExpr_(const OrNode* op) NOT_IMPLEMENTED
+
+  PrimExpr VisitExpr_(const ReduceNode* op) {
+    // This case is relatively difficult because a reduction expression
+    // may use an arbitrary combiner.
+    // The resulting reduction expression will return a tuple containing
+    // both derivatives and the original results (in exactly this order).
 
 Review comment:
   Looking into it a bit more, the order actually makes a difference. When the original init value
is different from its derivative init value, and they depend on each other during calculation,
we must calculate the derivative first (using the original's init value); switching the order in TVM
causes the original value to be replaced too early, producing incorrect results.
   
   One example is in the test case,
   ```python
    def fcombine(x, y):
           return x*y
   
       def fidentity(t0):
           return tvm.tir.const(1, t0)
   
       prod = te.comm_reducer(fcombine, fidentity, name='prod')
       B = te.compute((10, 10), lambda i, j: prod(A0[i, k] + A0[k, i], axis=k), name='B')
       check_grad(B, A0)
   ```
   
   Correct result (derivative first):
   ```
   produce B.jacobian {
     for (i, 0, 10) {
       for (j, 0, 10) {
         for (jac_i0, 0, 10) {
           for (jac_i1, 0, 10) {
             B.jacobian.v0[((((i*1000) + (j*100)) + (jac_i0*10)) + jac_i1)] = 0f
             B.jacobian.v1[((((i*1000) + (j*100)) + (jac_i0*10)) + jac_i1)] = 1f
             for (k, 0, 10) {
               B.jacobian.v0[((((i*1000) + (j*100)) + (jac_i0*10)) + jac_i1)] = ((B.jacobian.v0[((((i*1000)
+ (j*100)) + (jac_i0*10)) + jac_i1)]*(A0[((i*10) + k)] + A0[((k*10) + i)])) + ((float32(((jac_i0
== i) && (jac_i1 == k))) + float32(((jac_i0 == k) && (jac_i1 == i))))*B.jacobian.v1[((((i*1000)
+ (j*100)) + (jac_i0*10)) + jac_i1)]))
               B.jacobian.v1[((((i*1000) + (j*100)) + (jac_i0*10)) + jac_i1)] = (B.jacobian.v1[((((i*1000)
+ (j*100)) + (jac_i0*10)) + jac_i1)]*(A0[((i*10) + k)] + A0[((k*10) + i)]))
             }
           }
         }
       }
     }
   }
   Output B.jacobian.v0
   ```
   
   Incorrect result (origin first):
   ```
   produce B.jacobian {
     for (i, 0, 10) {
       for (j, 0, 10) {
         for (jac_i0, 0, 10) {
           for (jac_i1, 0, 10) {
             B.jacobian.v0[((((i*1000) + (j*100)) + (jac_i0*10)) + jac_i1)] = 1f
             B.jacobian.v1[((((i*1000) + (j*100)) + (jac_i0*10)) + jac_i1)] = 0f
             for (k, 0, 10) {
               B.jacobian.v0[((((i*1000) + (j*100)) + (jac_i0*10)) + jac_i1)] = (B.jacobian.v0[((((i*1000)
+ (j*100)) + (jac_i0*10)) + jac_i1)]*(A0[((i*10) + k)] + A0[((k*10) + i)]))
               B.jacobian.v1[((((i*1000) + (j*100)) + (jac_i0*10)) + jac_i1)] = ((B.jacobian.v1[((((i*1000)
+ (j*100)) + (jac_i0*10)) + jac_i1)]*(A0[((i*10) + k)] + A0[((k*10) + i)])) + ((float32(((jac_i0
== i) && (jac_i1 == k))) + float32(((jac_i0 == k) && (jac_i1 == i))))*B.jacobian.v0[((((i*1000)
+ (j*100)) + (jac_i0*10)) + jac_i1)]))
             }
           }
         }
       }
     }
   }
   Output B.jacobian.v1
   ```

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Mime
View raw message