tvm-commits mailing list archives

From GitBox <...@apache.org>
Subject [GitHub] [incubator-tvm] masahi commented on a change in pull request #5676: [DOC] Improve Pattern Language Docs
Date Wed, 27 May 2020 22:35:31 GMT

masahi commented on a change in pull request #5676:
URL: https://github.com/apache/incubator-tvm/pull/5676#discussion_r431481115



##########
File path: docs/langref/relay_pattern.rst
##########
@@ -139,3 +237,124 @@ Domination
 **********
 
 Match child pattern, find a match for the parent pattern, ensuring that the child ultimately dominates the parent (i.e., no nodes outside the pattern use outputs of the parent), and that every node between the child and the pattern matches the path pattern.
+
+Applications
+============
+
+The pattern language provides not only pattern matching but also pattern processing.
+Here we introduce two pattern processing approaches and provide some examples.
+
+Pattern Rewriting
+*****************
+
+If you would like to replace the matched pattern with another subgraph, you can leverage
+the ``rewrite`` transformation. Here is an example of rewriting a series of arithmetic operators
+with a single batch_norm op:
+
+.. code-block:: python
+
+    class BatchnormCallback(DFPatternCallback):
+        # A callback class to rewrite the matched pattern to a batch_norm op.
+        def __init__(self):
+            self.x = wildcard()
+            self.var = wildcard()
+            self.mean = wildcard()
+            self.beta = wildcard()
+            self.gamma = wildcard()
+            self.eps = wildcard()
+            
+            self.pattern = self.gamma * (self.x - self.mean)/is_op("sqrt")(self.var + self.eps) + self.beta
+
+        def callback(self, pre, post, node_map):
+            x = node_map[self.x][0]
+            var = node_map[self.var][0]
+            mean = node_map[self.mean][0]
+            beta = node_map[self.beta][0]
+            gamma = node_map[self.gamma][0]
+            eps = node_map[self.eps][0]
+            return relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = eps.data.asnumpy().item())[0]
+
+    # A graph of arithmetic operators that are functionally equivalent to batch_norm.
+    x = relay.var('x')
+    var = relay.var('var')
+    mean = relay.var('mean')
+    beta = relay.var('beta')
+    gamma = relay.var('gamma')
+    BN = gamma * (x - mean)/relay.op.sqrt(var + relay.const(1e-5)) + beta
+
+    from tvm.relay.dataflow_pattern import rewrite
+    out = rewrite(BatchnormCallback(), BN)
+    assert tvm.ir.structural_equal(out, relay.op.nn.batch_norm(x, gamma, beta, mean, var, epsilon = 1e-5)[0])
+
+The function ``def callback(self, pre, post, node_map)`` will be invoked when the rewriter matches
+``self.pattern``. ``node_map`` is a dictionary mapping from pattern nodes to matched nodes in the graph.
+
+Pattern Partitioning
+********************
+
+If you would like to perform a more complex processing for matched subgraphs and you are not
+satisfy with ``rewrite``, you may consider partitioning the matched subgraphs to a separate

Review comment:
       satisfied
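
For readers following the hunk, the `callback` / `node_map` contract it describes can be illustrated with a small, library-free sketch. This is not TVM code: `Wildcard`, `OpPattern`, `match`, `rewrite`, and `FmaCallback` are hypothetical toy stand-ins, and expressions are plain nested tuples. The only behavior taken from the doc text is that the rewriter invokes `callback(pre, post, node_map)` on a match and that `node_map` maps pattern nodes to matched nodes (stored as lists here, since the diff indexes each entry with `[0]`).

```python
# Toy stand-ins only; none of these names come from TVM.
# An expression is a nested tuple such as ("add", ("var", "x"), ("var", "y")).

class Wildcard:
    """Pattern node that matches any expression."""


class OpPattern:
    """Pattern node that matches a call to `op` whose arguments match `args`."""

    def __init__(self, op, *args):
        self.op = op
        self.args = args


def match(pattern, expr, node_map):
    """Return True if `expr` matches `pattern`, recording matches in `node_map`."""
    if isinstance(pattern, Wildcard):
        seen = node_map.get(pattern)
        if seen is not None:
            # A wildcard reused in the pattern must bind to the same expression.
            return seen[0] == expr
        node_map[pattern] = [expr]
        return True
    if not isinstance(expr, tuple) or expr[0] != pattern.op:
        return False
    if len(expr) - 1 != len(pattern.args):
        return False
    if not all(match(p, e, node_map) for p, e in zip(pattern.args, expr[1:])):
        return False
    node_map[pattern] = [expr]
    return True


def rewrite(cb, expr):
    """Rewrite bottom-up, calling cb.callback wherever cb.pattern matches."""
    if not isinstance(expr, tuple):
        return expr
    post = (expr[0],) + tuple(rewrite(cb, arg) for arg in expr[1:])
    node_map = {}  # fresh map per attempt, so failed matches leave no residue
    if match(cb.pattern, post, node_map):
        return cb.callback(expr, post, node_map)
    return post


class FmaCallback:
    """Fuse a * b + c into a single ("fma", a, b, c) node."""

    def __init__(self):
        self.a = Wildcard()
        self.b = Wildcard()
        self.c = Wildcard()
        self.pattern = OpPattern("add", OpPattern("mul", self.a, self.b), self.c)

    def callback(self, pre, post, node_map):
        # Look up what each wildcard matched, mirroring node_map[self.x][0].
        a = node_map[self.a][0]
        b = node_map[self.b][0]
        c = node_map[self.c][0]
        return ("fma", a, b, c)


expr = ("add", ("mul", ("var", "x"), ("var", "y")), ("var", "z"))
out = rewrite(FmaCallback(), expr)
```

The lookup in `FmaCallback.callback` has the same shape as the `BatchnormCallback` in the diff: the wildcards are stored on `self` in `__init__` so the callback can use them as keys into `node_map`.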




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


