groovy-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From cchamp...@apache.org
Subject [26/37] incubator-groovy git commit: ASTMatcher: First bits of work for wildcard matching
Date Wed, 07 Oct 2015 19:26:46 GMT
ASTMatcher: First bits of work for wildcard matching


Project: http://git-wip-us.apache.org/repos/asf/incubator-groovy/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-groovy/commit/1cd27ea2
Tree: http://git-wip-us.apache.org/repos/asf/incubator-groovy/tree/1cd27ea2
Diff: http://git-wip-us.apache.org/repos/asf/incubator-groovy/diff/1cd27ea2

Branch: refs/heads/master
Commit: 1cd27ea261a9eb712234e996115dcc5d3f039f53
Parents: 4b7db4c
Author: Cedric Champeau <cedric.champeau@gmail.com>
Authored: Fri Oct 17 16:05:26 2014 +0200
Committer: Sergei Egorov <bsideup@gmail.com>
Committed: Mon Sep 28 14:33:11 2015 +0300

----------------------------------------------------------------------
 .../groovy/macro/matcher/ASTMatcher.groovy      | 240 ++++++++++---------
 .../groovy/macro/matcher/ASTMatcherTest.groovy  |  30 +++
 2 files changed, 157 insertions(+), 113 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-groovy/blob/1cd27ea2/subprojects/groovy-macro/src/main/groovy/org/codehaus/groovy/macro/matcher/ASTMatcher.groovy
----------------------------------------------------------------------
diff --git a/subprojects/groovy-macro/src/main/groovy/org/codehaus/groovy/macro/matcher/ASTMatcher.groovy b/subprojects/groovy-macro/src/main/groovy/org/codehaus/groovy/macro/matcher/ASTMatcher.groovy
index 0020aeb..858924b 100644
--- a/subprojects/groovy-macro/src/main/groovy/org/codehaus/groovy/macro/matcher/ASTMatcher.groovy
+++ b/subprojects/groovy-macro/src/main/groovy/org/codehaus/groovy/macro/matcher/ASTMatcher.groovy
@@ -32,6 +32,8 @@ import org.codehaus.groovy.control.SourceUnit
 @CompileStatic
 class ASTMatcher extends ClassCodeVisitorSupport {
 
+    public static final String WILDCARD = "_";
+
     private Object current = null
     private boolean match = true
 
@@ -66,6 +68,14 @@ class ASTMatcher extends ClassCodeVisitorSupport {
         match = match && value
     }
 
+    private static boolean matchByName(String a, String b) {
+        return a.equals(b) || WILDCARD.equals(a) || WILDCARD.equals(b);
+    }
+
+    private static boolean isWildcardExpression(Object exp) {
+        return exp instanceof VariableExpression && WILDCARD.equals(exp.getName());
+    }
+
     /**
      * Locates all nodes in the given AST which match the pattern AST.
      * This operation can cost a lot, because it tries to match a sub-tree
@@ -81,10 +91,14 @@ class ASTMatcher extends ClassCodeVisitorSupport {
         finder.matches
     }
 
-    private void doWithNode(Class expectedClass, Object next, Closure cl) {
+    private void doWithNode(Object search, Object next, Closure cl) {
+        Class expectedClass = search?search.class:Object
         if (expectedClass == null) {
             expectedClass = Object
         }
+        if (isWildcardExpression(search)) {
+            return
+        }
         if (match && (next == null || expectedClass.isAssignableFrom(next.class)))
{
             Object old = current
             current = next
@@ -97,12 +111,12 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitClass(final ClassNode node) {
-        doWithNode(ClassNode, current) {
+        doWithNode(node, current) {
             visitAnnotations(node)
-            doWithNode(PackageNode, ((ClassNode) current).package) {
+            doWithNode(node.package, ((ClassNode) current).package) {
                 visitPackage(node.package)
             }
-            doWithNode(ModuleNode, ((ClassNode) current).module) {
+            doWithNode(node.module, ((ClassNode) current).module) {
                 visitImports(node.module)
             }
 
@@ -123,7 +137,7 @@ class ASTMatcher extends ClassCodeVisitorSupport {
                 def iter = curProps.iterator()
                 // now let's visit the contents of the class
                 for (PropertyNode pn : nodeProps) {
-                    doWithNode(pn.class, iter.next()) {
+                    doWithNode(pn, iter.next()) {
                         visitProperty(pn)
                     }
                 }
@@ -133,7 +147,7 @@ class ASTMatcher extends ClassCodeVisitorSupport {
                 if (nodeFields.size() == curFields.size()) {
                     iter = curFields.iterator()
                     for (FieldNode fn : nodeFields) {
-                        doWithNode(fn.class, iter.next()) {
+                        doWithNode(fn, iter.next()) {
                             visitField(fn)
                         }
                     }
@@ -143,7 +157,7 @@ class ASTMatcher extends ClassCodeVisitorSupport {
                     if (nodeConstructors.size() == curConstructors.size()) {
                         iter = curConstructors.iterator()
                         for (ConstructorNode cn : nodeConstructors) {
-                            doWithNode(cn.class, iter.next()) {
+                            doWithNode(cn, iter.next()) {
                                 visitConstructor(cn)
                             }
                         }
@@ -153,7 +167,7 @@ class ASTMatcher extends ClassCodeVisitorSupport {
                         if (nodeMethods.size() == curMethods.size()) {
                             iter = curMethods.iterator()
                             for (MethodNode mn : nodeMethods) {
-                                doWithNode(mn.class, iter.next()) {
+                                doWithNode(mn, iter.next()) {
                                     visitMethod(mn)
                                 }
                             }
@@ -162,18 +176,18 @@ class ASTMatcher extends ClassCodeVisitorSupport {
                     }
                 }
             }
-            failIfNot(cur.name == node.name)
+            failIfNot(matchByName(cur.name,node.name))
         }
     }
 
     @Override
     protected void visitObjectInitializerStatements(final ClassNode node) {
-        doWithNode(ClassNode, current) {
+        doWithNode(node, current) {
             def initializers = ((ClassNode) current).objectInitializerStatements
             if (initializers.size() == node.objectInitializerStatements.size()) {
                 def iterator = initializers.iterator()
                 for (Statement element : node.objectInitializerStatements) {
-                    doWithNode(element.class, iterator.next()) {
+                    doWithNode(element, iterator.next()) {
                         element.visit(this)
                     }
                 }
@@ -186,7 +200,7 @@ class ASTMatcher extends ClassCodeVisitorSupport {
     @Override
     public void visitPackage(final PackageNode node) {
         if (node) {
-            doWithNode(node.class, current) {
+            doWithNode(node, current) {
                 visitAnnotations(node)
                 node.visit(this)
             }
@@ -196,13 +210,13 @@ class ASTMatcher extends ClassCodeVisitorSupport {
     @Override
     public void visitImports(final ModuleNode node) {
         if (node) {
-            doWithNode(ModuleNode, current) {
+            doWithNode(node, current) {
                 ModuleNode module = (ModuleNode) current
                 def imports = module.imports
                 if (imports.size() == node.imports.size()) {
                     def iter = imports.iterator()
                     for (ImportNode importNode : node.imports) {
-                        doWithNode(importNode.class, iter.next()) {
+                        doWithNode(importNode, iter.next()) {
                             visitAnnotations(importNode)
                             importNode.visit(this)
                         }
@@ -215,7 +229,7 @@ class ASTMatcher extends ClassCodeVisitorSupport {
                 if (imports.size() == node.starImports.size()) {
                     def iter = imports.iterator()
                     for (ImportNode importNode : node.starImports) {
-                        doWithNode(importNode.class, iter.next()) {
+                        doWithNode(importNode, iter.next()) {
                             visitAnnotations(importNode)
                             importNode.visit(this)
                         }
@@ -228,7 +242,7 @@ class ASTMatcher extends ClassCodeVisitorSupport {
                 if (imports.size() == node.staticImports.size()) {
                     def iter = imports.values().iterator()
                     for (ImportNode importNode : node.staticImports.values()) {
-                        doWithNode(importNode.class, iter.next()) {
+                        doWithNode(importNode, iter.next()) {
                             visitAnnotations(importNode)
                             importNode.visit(this)
                         }
@@ -240,7 +254,7 @@ class ASTMatcher extends ClassCodeVisitorSupport {
                 if (imports.size() == node.staticStarImports.size()) {
                     def iter = imports.values().iterator()
                     for (ImportNode importNode : node.staticStarImports.values()) {
-                        doWithNode(importNode.class, iter.next()) {
+                        doWithNode(importNode, iter.next()) {
                             visitAnnotations(importNode)
                             importNode.visit(this)
                         }
@@ -254,7 +268,7 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitAnnotations(final AnnotatedNode node) {
-        doWithNode(AnnotatedNode, current) {
+        doWithNode(node, current) {
             List<AnnotationNode> refAnnotations = node.annotations
             AnnotatedNode cur = (AnnotatedNode) current
             List<AnnotationNode> curAnnotations = cur.annotations
@@ -282,7 +296,7 @@ class ASTMatcher extends ClassCodeVisitorSupport {
                     for (Map.Entry<String, Expression> member : refEntrySet) {
                         def next = entryIt.next()
                         if (next.key == member.key) {
-                            doWithNode(member.value.class, next.value) {
+                            doWithNode(member.value, next.value) {
                                 member.value.visit(this)
                             }
                         } else {
@@ -298,7 +312,7 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     protected void visitClassCodeContainer(final Statement code) {
-        doWithNode(Statement, current) {
+        doWithNode(code, current) {
             if (code) {
                 code.visit(this)
             }
@@ -308,24 +322,24 @@ class ASTMatcher extends ClassCodeVisitorSupport {
     @Override
     @CompileStatic(TypeCheckingMode.SKIP)
     public void visitDeclarationExpression(final DeclarationExpression expression) {
-        doWithNode(DeclarationExpression, current) {
+        doWithNode(expression, current) {
             super.visitDeclarationExpression(expression)
         }
     }
 
     @Override
     protected void visitConstructorOrMethod(final MethodNode node, final boolean isConstructor)
{
-        doWithNode(MethodNode, current) {
+        doWithNode(node, current) {
             visitAnnotations(node)
             def cur = (MethodNode) current
-            doWithNode(Statement, cur.code) {
+            doWithNode(node.code, cur.code) {
                 visitClassCodeContainer(node.code)
             }
             def params = node.parameters
             def curParams = cur.parameters
             if (params.length == curParams.length) {
                 params.eachWithIndex { Parameter entry, int i ->
-                    doWithNode(entry.class, curParams[i]) {
+                    doWithNode(entry, curParams[i]) {
                         visitAnnotations(entry)
                     }
                 }
@@ -337,10 +351,10 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitField(final FieldNode node) {
-        doWithNode(FieldNode, current) {
+        doWithNode(node, current) {
             visitAnnotations(node)
             def fieldNode = (FieldNode) current
-            failIfNot(fieldNode.name == node.name)
+            failIfNot(matchByName(fieldNode.name,node.name))
             failIfNot(fieldNode.originType == node.originType)
             failIfNot(fieldNode.modifiers == node.modifiers)
 
@@ -349,7 +363,7 @@ class ASTMatcher extends ClassCodeVisitorSupport {
             Expression curInit = fieldNode.initialExpression
             if (init) {
                 if (curInit) {
-                    doWithNode(init.class, curInit) {
+                    doWithNode(init, curInit) {
                         init.visit(this)
                     }
                 } else {
@@ -363,19 +377,19 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitProperty(final PropertyNode node) {
-        doWithNode(PropertyNode, current) {
+        doWithNode(node, current) {
             PropertyNode pNode = (PropertyNode) current
             visitAnnotations(node)
 
             Statement statement = node.getterBlock
             Statement curStatement = pNode.getterBlock
-            doWithNode(statement?.class, curStatement) {
+            doWithNode(statement, curStatement) {
                 visitClassCodeContainer(statement)
             }
 
             statement = node.setterBlock
             curStatement = pNode.setterBlock
-            doWithNode(statement?.class, curStatement) {
+            doWithNode(statement, curStatement) {
                 visitClassCodeContainer(statement)
             }
 
@@ -383,7 +397,7 @@ class ASTMatcher extends ClassCodeVisitorSupport {
             Expression curInit = pNode.initialExpression
             if (init) {
                 if (curInit) {
-                    doWithNode(init.class, curInit) {
+                    doWithNode(init, curInit) {
                         init.visit(this)
                     }
                 } else {
@@ -397,7 +411,7 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     void visitExpressionStatement(final ExpressionStatement statement) {
-        doWithNode(statement.expression.class, ((ExpressionStatement) current).expression)
{
+        doWithNode(statement.expression, ((ExpressionStatement) current).expression) {
             visitStatement(statement)
             statement.expression.visit(this)
         }
@@ -405,12 +419,12 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitBlockStatement(BlockStatement block) {
-        doWithNode(BlockStatement, current) {
+        doWithNode(block, current) {
             def statements = ((BlockStatement) current).statements
             if (statements.size() == block.statements.size()) {
                 def iter = statements.iterator()
                 for (Statement statement : block.statements) {
-                    doWithNode(statement.class, iter.next()) {
+                    doWithNode(statement, iter.next()) {
                         statement.visit(this)
                     }
                 }
@@ -422,18 +436,18 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitMethodCallExpression(final MethodCallExpression call) {
-        doWithNode(MethodCallExpression, current) {
+        doWithNode(call, current) {
             def mce = (MethodCallExpression) current
-            doWithNode(call.objectExpression.class, mce.objectExpression) {
+            doWithNode(call.objectExpression, mce.objectExpression) {
                 call.objectExpression.visit(this)
             }
-            doWithNode(call.method.class, mce.method) {
+            doWithNode(call.method, mce.method) {
                 call.method.visit(this)
             }
-            doWithNode(call.arguments.class, mce.arguments) {
+            doWithNode(call.arguments, mce.arguments) {
                 call.arguments.visit(this)
             }
-            failIfNot((call.methodAsString == mce.methodAsString) &&
+            failIfNot(matchByName(call.methodAsString, mce.methodAsString) &&
                     (call.safe == mce.safe) &&
                     (call.spreadSafe == mce.spreadSafe) &&
                     (call.implicitThis == mce.implicitThis))
@@ -448,7 +462,7 @@ class ASTMatcher extends ClassCodeVisitorSupport {
     @Override
     public void visitConstructorCallExpression(final ConstructorCallExpression call) {
         def cur = (ConstructorCallExpression) current
-        doWithNode(call.arguments.class, cur.arguments) {
+        doWithNode(call.arguments, cur.arguments) {
             call.arguments.visit(this)
             failIfNot(call.type == cur.type)
         }
@@ -456,14 +470,14 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitBinaryExpression(final BinaryExpression expression) {
-        doWithNode(BinaryExpression, current) {
+        doWithNode(expression, current) {
             def bin = (BinaryExpression) current
             def leftExpression = expression.getLeftExpression()
-            doWithNode(leftExpression.class, bin.leftExpression) {
+            doWithNode(leftExpression, bin.leftExpression) {
                 leftExpression.visit(this)
             }
             def rightExpression = expression.getRightExpression()
-            doWithNode(rightExpression.class, bin.rightExpression) {
+            doWithNode(rightExpression, bin.rightExpression) {
                 rightExpression.visit(this)
             }
             if (bin.operation.type != expression.operation.type) {
@@ -474,17 +488,17 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitTernaryExpression(final TernaryExpression expression) {
-        doWithNode(TernaryExpression, current) {
+        doWithNode(expression, current) {
             TernaryExpression te = (TernaryExpression) current
-            doWithNode(BooleanExpression, te.booleanExpression) {
+            doWithNode(expression.booleanExpression, te.booleanExpression) {
                 expression.booleanExpression.visit(this)
             }
             def trueExpression = expression.trueExpression
-            doWithNode(trueExpression.class, te.trueExpression) {
+            doWithNode(trueExpression, te.trueExpression) {
                 trueExpression.visit(this)
             }
             def falseExpression = expression.falseExpression
-            doWithNode(falseExpression.class, te.falseExpression) {
+            doWithNode(falseExpression, te.falseExpression) {
                 falseExpression.visit(this)
             }
         }
@@ -492,10 +506,10 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitPostfixExpression(final PostfixExpression expression) {
-        doWithNode(PostfixExpression, current) {
+        doWithNode(expression, current) {
             def origExpr = expression.expression
             def curExpr = (PostfixExpression) current
-            doWithNode(origExpr.class, curExpr.expression) {
+            doWithNode(origExpr, curExpr.expression) {
                 origExpr.visit(this)
                 failIfNot(expression.operation.type == curExpr.operation.type)
             }
@@ -504,10 +518,10 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitPrefixExpression(final PrefixExpression expression) {
-        doWithNode(PrefixExpression, current) {
+        doWithNode(expression, current) {
             def origExpr = expression.expression
             def curExpr = (PrefixExpression) current
-            doWithNode(origExpr.class, curExpr.expression) {
+            doWithNode(origExpr, curExpr.expression) {
                 origExpr.visit(this)
                 failIfNot(expression.operation.type == curExpr.operation.type)
             }
@@ -516,8 +530,8 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitBooleanExpression(final BooleanExpression expression) {
-        doWithNode(BooleanExpression, current) {
-            doWithNode(expression.expression.class, ((BooleanExpression) current).expression)
{
+        doWithNode(expression, current) {
+            doWithNode(expression.expression, ((BooleanExpression) current).expression) {
                 expression.expression.visit(this)
             }
         }
@@ -525,10 +539,10 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitNotExpression(final NotExpression expression) {
-        doWithNode(NotExpression, current) {
+        doWithNode(expression, current) {
             def expr = expression.expression
             def cur = ((NotExpression) current).expression
-            doWithNode(expr.class, cur) {
+            doWithNode(expr, cur) {
                 expr.visit(this)
             }
         }
@@ -536,10 +550,10 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitClosureExpression(final ClosureExpression expression) {
-        doWithNode(ClosureExpression, current) {
+        doWithNode(expression, current) {
             def code = expression.code
             def cl = (ClosureExpression) current
-            doWithNode(code.class, cl.code) {
+            doWithNode(code, cl.code) {
                 code.visit(this)
                 checkParameters(expression.parameters, cl.parameters)
             }
@@ -556,8 +570,8 @@ class ASTMatcher extends ClassCodeVisitorSupport {
                 for (int i = 0; i < nodeParams.length && match; i++) {
                     def n = nodeParams[i]
                     def c = curParams[i]
-                    doWithNode(n.class, c) {
-                        failIfNot((n.name == c.name) &&
+                    doWithNode(n, c) {
+                        failIfNot(matchByName(n.name,c.name) &&
                                 (n.originType == c.originType) &&
                                 (n.type == c.originType))
                     }
@@ -570,8 +584,8 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitTupleExpression(final TupleExpression expression) {
-        doWithNode(TupleExpression, current) {
-            doWithNode(List, ((TupleExpression) current).expressions) {
+        doWithNode(expression, current) {
+            doWithNode(expression.expressions, ((TupleExpression) current).expressions) {
                 visitListOfExpressions(expression.expressions)
             }
         }
@@ -579,9 +593,9 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitListExpression(final ListExpression expression) {
-        doWithNode(ListExpression, current) {
+        doWithNode(expression, current) {
             def exprs = expression.expressions
-            doWithNode(exprs.class, ((ListExpression) current).expressions) {
+            doWithNode(exprs, ((ListExpression) current).expressions) {
                 visitListOfExpressions(exprs)
             }
         }
@@ -589,16 +603,16 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitArrayExpression(final ArrayExpression expression) {
-        doWithNode(ArrayExpression, current) {
+        doWithNode(expression, current) {
             def expressions = expression.expressions
             def size = expression.sizeExpression
             def cur = (ArrayExpression) current
             def curExprs = cur.expressions
             def curSize = cur.sizeExpression
-            doWithNode(expressions.class, curExprs) {
+            doWithNode(expressions, curExprs) {
                 visitListOfExpressions(expressions)
             }
-            doWithNode(size.class, curSize) {
+            doWithNode(size, curSize) {
                 visitListOfExpressions(size)
             }
             failIfNot(expression.elementType == cur.elementType)
@@ -607,10 +621,10 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitMapExpression(final MapExpression expression) {
-        doWithNode(MapExpression, current) {
+        doWithNode(expression, current) {
             def entries = expression.mapEntryExpressions
             def curEntries = ((MapExpression) current).mapEntryExpressions
-            doWithNode(entries.class, curEntries) {
+            doWithNode(entries, curEntries) {
                 visitListOfExpressions(entries)
             }
         }
@@ -618,16 +632,16 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitMapEntryExpression(final MapEntryExpression expression) {
-        doWithNode(MapEntryExpression, current) {
+        doWithNode(expression, current) {
             def key = expression.keyExpression
             def value = expression.valueExpression
             def cur = (MapEntryExpression) current
             def curKey = cur.keyExpression
             def curValue = cur.valueExpression
-            doWithNode(key.class, curKey) {
+            doWithNode(key, curKey) {
                 key.visit(this)
             }
-            doWithNode(value.class, curValue) {
+            doWithNode(value, curValue) {
                 value.visit(this)
             }
         }
@@ -635,16 +649,16 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitRangeExpression(final RangeExpression expression) {
-        doWithNode(RangeExpression, current) {
+        doWithNode(expression, current) {
             def from = expression.from
             def to = expression.to
             def cur = (RangeExpression) current
             def curFrom = cur.from
             def curTo = cur.to
-            doWithNode(from.class, curFrom) {
+            doWithNode(from, curFrom) {
                 from.visit(this)
             }
-            doWithNode(to.class, curTo) {
+            doWithNode(to, curTo) {
                 to.visit(this)
             }
         }
@@ -652,9 +666,9 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitSpreadExpression(final SpreadExpression expression) {
-        doWithNode(SpreadExpression, current) {
+        doWithNode(expression, current) {
             def expr = expression.expression
-            doWithNode(expr.class, ((SpreadExpression) current).expression) {
+            doWithNode(expr, ((SpreadExpression) current).expression) {
                 expr.visit(this)
             }
         }
@@ -667,16 +681,16 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitMethodPointerExpression(final MethodPointerExpression expression) {
-        doWithNode(MethodPointerExpression, current) {
+        doWithNode(expression, current) {
             def cur = (MethodPointerExpression) current
             def expr = expression.expression
             def methodName = expression.methodName
             def curExpr = cur.expression
             def curName = cur.methodName
-            doWithNode(expr.class, curExpr) {
+            doWithNode(expr, curExpr) {
                 expr.visit(this)
             }
-            doWithNode(methodName.class, curName) {
+            doWithNode(methodName, curName) {
                 methodName.visit(this)
             }
         }
@@ -684,9 +698,9 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitUnaryMinusExpression(final UnaryMinusExpression expression) {
-        doWithNode(UnaryMinusExpression, current) {
+        doWithNode(expression, current) {
             def expr = expression.expression
-            doWithNode(expr.class, ((UnaryMinusExpression) current).expression) {
+            doWithNode(expr, ((UnaryMinusExpression) current).expression) {
                 expr.visit(this)
             }
         }
@@ -694,9 +708,9 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitUnaryPlusExpression(final UnaryPlusExpression expression) {
-        doWithNode(UnaryPlusExpression, current) {
+        doWithNode(expression, current) {
             def expr = expression.expression
-            doWithNode(expr.class, ((UnaryPlusExpression) current).expression) {
+            doWithNode(expr, ((UnaryPlusExpression) current).expression) {
                 expr.visit(this)
             }
         }
@@ -704,9 +718,9 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitBitwiseNegationExpression(final BitwiseNegationExpression expression)
{
-        doWithNode(BitwiseNegationExpression, current) {
+        doWithNode(expression, current) {
             def expr = expression.expression
-            doWithNode(expr.class, ((BitwiseNegationExpression) current).expression) {
+            doWithNode(expr, ((BitwiseNegationExpression) current).expression) {
                 expr.visit(this)
             }
         }
@@ -714,9 +728,9 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitCastExpression(final CastExpression expression) {
-        doWithNode(CastExpression, current) {
+        doWithNode(expression, current) {
             def expr = expression.expression
-            doWithNode(expr.class, ((CastExpression) current).expression) {
+            doWithNode(expr, ((CastExpression) current).expression) {
                 expr.visit(this)
             }
             failIfNot(expression.type == ((CastExpression) current).type)
@@ -726,7 +740,7 @@ class ASTMatcher extends ClassCodeVisitorSupport {
     @Override
     @CompileStatic(TypeCheckingMode.SKIP)
     public void visitConstantExpression(final ConstantExpression expression) {
-        doWithNode(ConstantExpression, current) {
+        doWithNode(expression, current) {
             def cur = (ConstantExpression) current
             super.visitConstantExpression(expression)
             failIfNot((expression.type == cur.type) &&
@@ -737,7 +751,7 @@ class ASTMatcher extends ClassCodeVisitorSupport {
     @Override
     @CompileStatic(TypeCheckingMode.SKIP)
     public void visitClassExpression(final ClassExpression expression) {
-        doWithNode(ClassExpression, current) {
+        doWithNode(expression, current) {
             super.visitClassExpression(expression)
             def cexp = (ClassExpression) current
             failIfNot(cexp.type == expression.type)
@@ -746,9 +760,9 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitVariableExpression(final VariableExpression expression) {
-        doWithNode(VariableExpression, current) {
+        doWithNode(expression, current) {
             def curVar = (VariableExpression) current
-            failIfNot((expression.name == curVar.name) &&
+            failIfNot(matchByName(expression.name,curVar.name) &&
                     (expression.type == curVar.type) &&
                     (expression.originType == curVar.originType))
         }
@@ -756,12 +770,12 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitPropertyExpression(final PropertyExpression expression) {
-        doWithNode(PropertyExpression, current) {
+        doWithNode(expression, current) {
             def currentPexp = (PropertyExpression) current
-            doWithNode(expression.objectExpression.class, currentPexp.objectExpression) {
+            doWithNode(expression.objectExpression, currentPexp.objectExpression) {
                 expression.objectExpression.visit(this)
             }
-            doWithNode(expression.property.class, currentPexp.property) {
+            doWithNode(expression.property, currentPexp.property) {
                 expression.property.visit(this)
             }
             failIfNot((expression.propertyAsString == currentPexp.propertyAsString) &&
@@ -773,12 +787,12 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitAttributeExpression(final AttributeExpression expression) {
-        doWithNode(AttributeExpression, current) {
+        doWithNode(expression, current) {
             def currentPexp = (AttributeExpression) current
-            doWithNode(expression.objectExpression.class, currentPexp.objectExpression) {
+            doWithNode(expression.objectExpression, currentPexp.objectExpression) {
                 expression.objectExpression.visit(this)
             }
-            doWithNode(expression.property.class, currentPexp.property) {
+            doWithNode(expression.property, currentPexp.property) {
                 expression.property.visit(this)
             }
             failIfNot((expression.propertyAsString == currentPexp.propertyAsString) &&
@@ -795,16 +809,16 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitGStringExpression(final GStringExpression expression) {
-        doWithNode(GStringExpression, current) {
+        doWithNode(expression, current) {
             def cur = (GStringExpression) current
             def strings = expression.strings
             def values = expression.values
             def curStrings = cur.strings
             def curValues = cur.values
-            doWithNode(strings.class, curStrings) {
+            doWithNode(strings, curStrings) {
                 visitListOfExpressions(strings)
             }
-            doWithNode(values.class, curValues) {
+            doWithNode(values, curValues) {
                 visitListOfExpressions(values)
             }
         }
@@ -821,7 +835,7 @@ class ASTMatcher extends ClassCodeVisitorSupport {
         def iter = currentExprs.iterator()
         for (Expression expression : list) {
             def next = iter.next()
-            doWithNode(expression.class, next) {
+            doWithNode(expression, next) {
                 expression.visit(this)
             }
         }
@@ -834,9 +848,9 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     public void visitClosureListExpression(final ClosureListExpression cle) {
-        doWithNode(ClosureListExpression, current) {
+        doWithNode(cle, current) {
             def exprs = cle.expressions
-            doWithNode(exprs.class, ((ClosureListExpression)current).expressions) {
+            doWithNode(exprs, ((ClosureListExpression)current).expressions) {
                 visitListOfExpressions(exprs)
             }
         }
@@ -849,20 +863,20 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     void visitIfElse(final IfStatement ifElse) {
-        doWithNode(IfStatement, current) {
+        doWithNode(ifElse, current) {
             visitStatement(ifElse)
             def cur = (IfStatement) current
             def bool = ifElse.booleanExpression
             def ifBlock = ifElse.ifBlock
             def elseBlock = ifElse.elseBlock
-            doWithNode(bool.class, cur.booleanExpression) {
+            doWithNode(bool, cur.booleanExpression) {
                 bool.visit(this)
             }
-            doWithNode(ifBlock.class, cur.ifBlock) {
+            doWithNode(ifBlock, cur.ifBlock) {
                 ifBlock.visit(this)
             }
             failIfNot(elseBlock && cur.elseBlock || !elseBlock && !cur.elseBlock)
-            doWithNode(elseBlock.class, cur.elseBlock) {
+            doWithNode(elseBlock, cur.elseBlock) {
                 elseBlock.visit(this)
             }
         }
@@ -870,15 +884,15 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     void visitForLoop(final ForStatement forLoop) {
-        doWithNode(ForStatement, current) {
+        doWithNode(forLoop, current) {
             visitStatement(forLoop)
             def cur = (ForStatement) current
             def col = forLoop.collectionExpression
             def block = forLoop.loopBlock
-            doWithNode(col.class, cur.collectionExpression) {
+            doWithNode(col, cur.collectionExpression) {
                 col.visit(this)
             }
-            doWithNode(block.class, cur.loopBlock) {
+            doWithNode(block, cur.loopBlock) {
                 block.visit(this)
             }
         }
@@ -886,15 +900,15 @@ class ASTMatcher extends ClassCodeVisitorSupport {
 
     @Override
     void visitWhileLoop(final WhileStatement loop) {
-        doWithNode(WhileStatement, current) {
+        doWithNode(loop, current) {
             visitStatement(loop)
             def cur = (WhileStatement) current
             def bool = loop.booleanExpression
             def block = loop.loopBlock
-            doWithNode(bool.class, cur.booleanExpression) {
+            doWithNode(bool, cur.booleanExpression) {
                 bool.visit(this)
             }
-            doWithNode(block.class, cur.loopBlock) {
+            doWithNode(block, cur.loopBlock) {
                 block.visit(this)
             }
         }

http://git-wip-us.apache.org/repos/asf/incubator-groovy/blob/1cd27ea2/subprojects/groovy-macro/src/test/groovy/org/codehaus/groovy/macro/matcher/ASTMatcherTest.groovy
----------------------------------------------------------------------
diff --git a/subprojects/groovy-macro/src/test/groovy/org/codehaus/groovy/macro/matcher/ASTMatcherTest.groovy
b/subprojects/groovy-macro/src/test/groovy/org/codehaus/groovy/macro/matcher/ASTMatcherTest.groovy
index 9c02e65..e1e88a3 100644
--- a/subprojects/groovy-macro/src/test/groovy/org/codehaus/groovy/macro/matcher/ASTMatcherTest.groovy
+++ b/subprojects/groovy-macro/src/test/groovy/org/codehaus/groovy/macro/matcher/ASTMatcherTest.groovy
@@ -574,4 +574,34 @@ class ASTMatcherTest extends GroovyTestCase {
         assert !ASTMatcher.matches(ast1, ast3)
         assert !ASTMatcher.matches(ast1, ast4)
     }
+
+    void testWildcardMatchVariable() {
+        def ast1 = macro { a }
+        def ast2 = macro { _ }
+        def ast3 = macro { b }
+        assert ASTMatcher.matches(ast1, ast2)
+        assert ASTMatcher.matches(ast2, ast3)
+    }
+
+    void testWildcardMatchVariableInBinaryExpression() {
+        def ast1 = macro { a+b }
+        def ast2 = macro { _+_ }
+        def ast3 = macro { _+c }
+        def ast4 = macro { c+_ }
+        def ast5 = macro { a+_ }
+        def ast6 = macro { _+b }
+        assert ASTMatcher.matches(ast1, ast2)
+        assert !ASTMatcher.matches(ast1, ast3)
+        assert !ASTMatcher.matches(ast1, ast4)
+        assert ASTMatcher.matches(ast1, ast5)
+        assert ASTMatcher.matches(ast1, ast6)
+    }
+
+    void testWildcardForSubExpression() {
+        def ast1 = macro { a+foo(b) }
+        def ast2 = macro { _+foo(b) }
+        def ast3 = macro { a+_ }
+        //assert ASTMatcher.matches(ast1, ast2)
+        assert ASTMatcher.matches(ast1, ast3)
+    }
 }



Mime
View raw message