groovy-commits mailing list archives

Site index · List index
Message view « Date » · « Thread »
Top « Date » · « Thread »
From sun...@apache.org
Subject [groovy] branch GROOVY_3_0_X updated: Avoid unnecessary capturing the instance of enclosing class
Date Sat, 07 Dec 2019 18:07:33 GMT
This is an automated email from the ASF dual-hosted git repository.

sunlan pushed a commit to branch GROOVY_3_0_X
in repository https://gitbox.apache.org/repos/asf/groovy.git


The following commit(s) were added to refs/heads/GROOVY_3_0_X by this push:
     new 29c6696  Avoid unnecessary capturing the instance of enclosing class
29c6696 is described below

commit 29c6696fb9addae6b498087333f158b716cb3a76
Author: Daniel Sun <sunlan@apache.org>
AuthorDate: Sun Dec 8 02:03:03 2019 +0800

    Avoid unnecessary capturing the instance of enclosing class
    
    If the lambda expression does not access the instance of the enclosing class (e.g. it accesses
    no instance fields or instance methods), the enclosing class instance need not be captured.
    
    (cherry picked from commit 23b53b4db1720fc365cc18eeb6bc2600a9ef368d)
---
 .../classgen/asm/sc/StaticTypesLambdaWriter.java   |  46 +++-
 .../groovy/control/StaticImportVisitor.java        |  15 +-
 src/test/groovy/transform/stc/LambdaTest.groovy    | 255 +++++++++++++++++++--
 3 files changed, 282 insertions(+), 34 deletions(-)

diff --git a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaWriter.java
b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaWriter.java
index 5005e33..4f7f5d9 100644
--- a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaWriter.java
+++ b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaWriter.java
@@ -19,6 +19,7 @@
 
 package org.codehaus.groovy.classgen.asm.sc;
 
+import org.apache.groovy.util.ObjectHolder;
 import org.codehaus.groovy.GroovyBugError;
 import org.codehaus.groovy.ast.ClassCodeVisitorSupport;
 import org.codehaus.groovy.ast.ClassHelper;
@@ -135,8 +136,8 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun
                 addDeserializeLambdaMethod();
             }
 
-            newGroovyLambdaWrapperAndLoad(lambdaWrapperClassNode, expression);
-            loadEnclosingClassInstance();
+            newGroovyLambdaWrapperAndLoad(lambdaWrapperClassNode, syntheticLambdaMethodNode,
expression);
+            loadEnclosingClassInstance(syntheticLambdaMethodNode);
         }
 
         MethodVisitor mv = controller.getMethodVisitor();
@@ -160,12 +161,12 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements
AbstractFun
         return new Parameter[]{new Parameter(ClassHelper.SERIALIZEDLAMBDA_TYPE, SERIALIZED_LAMBDA_PARAM_NAME)};
     }
 
-    private void loadEnclosingClassInstance() {
+    private void loadEnclosingClassInstance(MethodNode syntheticLambdaMethodNode) {
         MethodVisitor mv = controller.getMethodVisitor();
         OperandStack operandStack = controller.getOperandStack();
         CompileStack compileStack = controller.getCompileStack();
 
-        if (controller.isStaticMethod() || compileStack.isInSpecialConstructorCall()) {
+        if (controller.isStaticMethod() || compileStack.isInSpecialConstructorCall() || !isAccessingInstanceMembers(syntheticLambdaMethodNode))
{
             operandStack.pushConstant(ConstantExpression.NULL);
         } else {
             mv.visitVarInsn(ALOAD, 0);
@@ -173,13 +174,46 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements
AbstractFun
         }
     }
 
-    private void newGroovyLambdaWrapperAndLoad(ClassNode lambdaWrapperClassNode, LambdaExpression
expression) {
+    private boolean isAccessingInstanceMembers(MethodNode syntheticLambdaMethodNode) {
+        ObjectHolder<Boolean> objectHolder = new ObjectHolder<>(false);
+        ClassCodeVisitorSupport classCodeVisitorSupport = new ClassCodeVisitorSupport() {
+            @Override
+            public void visitVariableExpression(VariableExpression expression) {
+                if (expression.isThisExpression()) {
+                    objectHolder.setObject(true);
+                }
+            }
+
+            @Override
+            public void visitMethodCallExpression(MethodCallExpression call) {
+                if (!call.getMethodTarget().isStatic()) {
+                    Expression objectExpression = call.getObjectExpression();
+                    if (objectExpression instanceof VariableExpression && ENCLOSING_THIS.equals(((VariableExpression)
objectExpression).getName())) {
+                        objectHolder.setObject(true);
+                    }
+                }
+
+                super.visitMethodCallExpression(call);
+            }
+
+            @Override
+            protected SourceUnit getSourceUnit() {
+                return null;
+            }
+        };
+
+        classCodeVisitorSupport.visitMethod(syntheticLambdaMethodNode);
+
+        return objectHolder.getObject();
+    }
+
+    private void newGroovyLambdaWrapperAndLoad(ClassNode lambdaWrapperClassNode, MethodNode
syntheticLambdaMethodNode, LambdaExpression expression) {
         MethodVisitor mv = controller.getMethodVisitor();
         String lambdaWrapperClassInternalName = BytecodeHelper.getClassInternalName(lambdaWrapperClassNode);
         mv.visitTypeInsn(NEW, lambdaWrapperClassInternalName);
         mv.visitInsn(DUP);
 
-        loadEnclosingClassInstance();
+        loadEnclosingClassInstance(syntheticLambdaMethodNode);
         controller.getOperandStack().dup();
 
         loadSharedVariables(expression);
diff --git a/src/main/java/org/codehaus/groovy/control/StaticImportVisitor.java b/src/main/java/org/codehaus/groovy/control/StaticImportVisitor.java
index 7d2db8d..e8557a7 100644
--- a/src/main/java/org/codehaus/groovy/control/StaticImportVisitor.java
+++ b/src/main/java/org/codehaus/groovy/control/StaticImportVisitor.java
@@ -100,25 +100,26 @@ public class StaticImportVisitor extends ClassCodeExpressionTransformer
{
 
     public Expression transform(Expression exp) {
         if (exp == null) return null;
-        if (exp.getClass() == VariableExpression.class) {
+        Class<? extends Expression> clazz = exp.getClass();
+        if (clazz == VariableExpression.class) {
             return transformVariableExpression((VariableExpression) exp);
         }
-        if (exp.getClass() == BinaryExpression.class) {
+        if (clazz == BinaryExpression.class) {
             return transformBinaryExpression((BinaryExpression) exp);
         }
-        if (exp.getClass() == PropertyExpression.class) {
+        if (clazz == PropertyExpression.class) {
             return transformPropertyExpression((PropertyExpression) exp);
         }
-        if (exp.getClass() == MethodCallExpression.class) {
+        if (clazz == MethodCallExpression.class) {
             return transformMethodCallExpression((MethodCallExpression) exp);
         }
-        if (exp.getClass() == ClosureExpression.class) {
+        if (exp instanceof ClosureExpression) {
             return transformClosureExpression((ClosureExpression) exp);
         }
-        if (exp.getClass() == ConstructorCallExpression.class) {
+        if (clazz == ConstructorCallExpression.class) {
             return transformConstructorCallExpression((ConstructorCallExpression) exp);
         }
-        if (exp.getClass() == ArgumentListExpression.class) {
+        if (clazz == ArgumentListExpression.class) {
             Expression result = exp.transformExpression(this);
             if (inPropertyExpression) {
                 foundArgs = result;
diff --git a/src/test/groovy/transform/stc/LambdaTest.groovy b/src/test/groovy/transform/stc/LambdaTest.groovy
index 4693e29..3cdb6a1 100644
--- a/src/test/groovy/transform/stc/LambdaTest.groovy
+++ b/src/test/groovy/transform/stc/LambdaTest.groovy
@@ -909,6 +909,26 @@ class LambdaTest extends GroovyTestCase {
         '''
     }
 
+    void testInitializeBlocks() {
+        assertScript '''
+            import java.util.stream.Collectors
+            
+            @groovy.transform.CompileStatic
+            class Test1 {
+                static sl
+                def il
+                static { sl = [1, 2, 3].stream().map(e -> e + 1).toList() }
+                 
+                {
+                    il = [1, 2, 3].stream().map(e -> e + 2).toList()
+                }
+            }
+            
+            assert [2, 3, 4] == Test1.sl
+            assert [3, 4, 5] == new Test1().il
+        '''
+    }
+
     void testSerialize() {
         assertScript '''
         import java.util.function.Function
@@ -916,8 +936,7 @@ class LambdaTest extends GroovyTestCase {
         interface SerializableFunction<T, R> extends Function<T, R>, Serializable
{}
         
         @groovy.transform.CompileStatic
-        class Test1 implements Serializable {
-            private static final long serialVersionUID = -1L;
+        class Test1 {
             def p() {
                     def out = new ByteArrayOutputStream()
                     out.withObjectOutputStream {
@@ -934,12 +953,11 @@ class LambdaTest extends GroovyTestCase {
     }
 
     void testSerializeFailed() {
-        shouldFail(NotSerializableException, '''
+        def errMsg = shouldFail(NotSerializableException, '''
         import java.util.function.Function
         
         @groovy.transform.CompileStatic
-        class Test1 implements Serializable {
-            private static final long serialVersionUID = -1L;
+        class Test1 {
             def p() {
                     def out = new ByteArrayOutputStream()
                     out.withObjectOutputStream {
@@ -953,6 +971,8 @@ class LambdaTest extends GroovyTestCase {
 
         new Test1().p()
         ''')
+
+        assert errMsg.contains('$Lambda$')
     }
 
     void testDeserialize() {
@@ -961,8 +981,7 @@ class LambdaTest extends GroovyTestCase {
         import java.util.function.Function
         
         @groovy.transform.CompileStatic
-        class Test1 implements Serializable {
-            private static final long serialVersionUID = -1L;
+        class Test1 {
             byte[] p() {
                     def out = new ByteArrayOutputStream()
                     out.withObjectOutputStream {
@@ -985,14 +1004,83 @@ class LambdaTest extends GroovyTestCase {
         '''
     }
 
+    void testDeserializeLambdaInInitializeBlock() {
+        assertScript '''
+            package tests.lambda
+            import java.util.function.Function
+            
+            @groovy.transform.CompileStatic
+            class Test1 implements Serializable {
+                private static final long serialVersionUID = -1L;
+                String a = 'a'
+                SerializableFunction<Integer, String> f
+                 
+                {
+                    f = ((Integer e) -> a + e)
+                }
+                
+                byte[] p() {
+                    def out = new ByteArrayOutputStream()
+                    out.withObjectOutputStream {
+                        it.writeObject(f)
+                    }
+                    
+                    return out.toByteArray()
+                }
+                
+                static void main(String[] args) {
+                    new ByteArrayInputStream(new Test1().p()).withObjectInputStream(Test1.class.classLoader)
{
+                        SerializableFunction<Integer, String> f = (SerializableFunction<Integer,
String>) it.readObject()
+                        assert 'a1' == f.apply(1)
+                    }
+                }
+                
+                interface SerializableFunction<T, R> extends Function<T, R>,
Serializable {}
+            }
+        '''
+    }
+
+    void testDeserializeLambdaInInitializeBlockShouldFail() {
+        def errMsg = shouldFail(NotSerializableException, '''
+            package tests.lambda
+            import java.util.function.Function
+            
+            @groovy.transform.CompileStatic
+            class Test1 {
+                String a = 'a'
+                SerializableFunction<Integer, String> f
+                 
+                {
+                    f = ((Integer e) -> a + e)
+                }
+                
+                byte[] p() {
+                    def out = new ByteArrayOutputStream()
+                    out.withObjectOutputStream {
+                        it.writeObject(f)
+                    }
+                    
+                    return out.toByteArray()
+                }
+                
+                static void main(String[] args) {
+                    new Test1().p()
+                }
+                
+                interface SerializableFunction<T, R> extends Function<T, R>,
Serializable {}
+            }
+        ''')
+
+        assert errMsg.contains('tests.lambda.Test1')
+    }
+
 
     void testDeserialize2() {
         assertScript '''
         import java.util.function.Function
         
         @groovy.transform.CompileStatic
-        class Test1 implements Serializable {
-            private static final long serialVersionUID = -1L;
+        class Test1 {
             static byte[] p() {
                     def out = new ByteArrayOutputStream()
                     out.withObjectOutputStream {
@@ -1021,8 +1109,7 @@ class LambdaTest extends GroovyTestCase {
         import java.util.function.Function
         
         @groovy.transform.CompileStatic
-        class Test1 implements Serializable {
-            private static final long serialVersionUID = -1L;
+        class Test1 {
             byte[] p() {
                     def out = new ByteArrayOutputStream()
                     out.withObjectOutputStream {
@@ -1052,8 +1139,7 @@ class LambdaTest extends GroovyTestCase {
         import java.util.function.Function
         
         @groovy.transform.CompileStatic
-        class Test1 implements Serializable {
-            private static final long serialVersionUID = -1L;
+        class Test1 {
             byte[] p() {
                     def out = new ByteArrayOutputStream()
                     String c = 'a'
@@ -1107,7 +1193,7 @@ class LambdaTest extends GroovyTestCase {
         '''
     }
 
-    void testDeserialize6() {
+    void testDeserialize6InstanceFields() {
         assertScript '''
         package tests.lambda
         import java.util.function.Function
@@ -1139,7 +1225,105 @@ class LambdaTest extends GroovyTestCase {
         '''
     }
 
-    void testDeserialize7() {
+    void testDeserialize6InstanceFieldsShouldFail() {
+        def errMsg = shouldFail(NotSerializableException, '''
+        package tests.lambda
+        import java.util.function.Function
+        
+        @groovy.transform.CompileStatic
+        class Test1 {
+            private String c = 'a'
+            
+            byte[] p() {
+                    def out = new ByteArrayOutputStream()
+                    SerializableFunction<Integer, String> f = (Integer e) -> c +
e
+                    out.withObjectOutputStream {
+                        it.writeObject(f)
+                    }
+                    
+                    return out.toByteArray()
+            }
+            
+            static void main(String[] args) {
+                new ByteArrayInputStream(new Test1().p()).withObjectInputStream(Test1.class.classLoader)
{
+                    SerializableFunction<Integer, String> f = (SerializableFunction<Integer,
String>) it.readObject()
+                    assert 'a1' == f.apply(1)
+                }
+            }
+            
+            interface SerializableFunction<T, R> extends Function<T, R>, Serializable
{}
+        }
+        ''')
+
+        assert errMsg.contains('tests.lambda.Test1')
+    }
+
+    void testDeserialize6InstanceMethods() {
+        assertScript '''
+        package tests.lambda
+        import java.util.function.Function
+        
+        @groovy.transform.CompileStatic
+        class Test1 implements Serializable {
+            private static final long serialVersionUID = -1L;
+            private String c() { 'a' }
+            
+            byte[] p() {
+                    def out = new ByteArrayOutputStream()
+                    SerializableFunction<Integer, String> f = (Integer e) -> c()
+ e
+                    out.withObjectOutputStream {
+                        it.writeObject(f)
+                    }
+                    
+                    return out.toByteArray()
+            }
+            
+            static void main(String[] args) {
+                new ByteArrayInputStream(new Test1().p()).withObjectInputStream(Test1.class.classLoader)
{
+                    SerializableFunction<Integer, String> f = (SerializableFunction<Integer,
String>) it.readObject()
+                    assert 'a1' == f.apply(1)
+                }
+            }
+            
+            interface SerializableFunction<T, R> extends Function<T, R>, Serializable
{}
+        }
+        '''
+    }
+
+    void testDeserialize6InstanceMethodsShouldFail() {
+        def errMsg = shouldFail(NotSerializableException, '''
+        package tests.lambda
+        import java.util.function.Function
+        
+        @groovy.transform.CompileStatic
+        class Test1 {
+            private String c() { 'a' }
+            
+            byte[] p() {
+                    def out = new ByteArrayOutputStream()
+                    SerializableFunction<Integer, String> f = (Integer e) -> c()
+ e
+                    out.withObjectOutputStream {
+                        it.writeObject(f)
+                    }
+                    
+                    return out.toByteArray()
+            }
+            
+            static void main(String[] args) {
+                new ByteArrayInputStream(new Test1().p()).withObjectInputStream(Test1.class.classLoader)
{
+                    SerializableFunction<Integer, String> f = (SerializableFunction<Integer,
String>) it.readObject()
+                    assert 'a1' == f.apply(1)
+                }
+            }
+            
+            interface SerializableFunction<T, R> extends Function<T, R>, Serializable
{}
+        }
+        ''')
+
+        assert errMsg.contains('tests.lambda.Test1')
+    }
+
+    void testDeserialize7StaticFields() {
         assertScript '''
         package tests.lambda
         import java.util.function.Function
@@ -1169,6 +1353,37 @@ class LambdaTest extends GroovyTestCase {
         '''
     }
 
+
+    void testDeserialize7StaticMethods() {
+        assertScript '''
+        package tests.lambda
+        import java.util.function.Function
+        
+        @groovy.transform.CompileStatic
+        class Test1 {
+            private static String c() { 'a' }
+            static byte[] p() {
+                    def out = new ByteArrayOutputStream()
+                    SerializableFunction<Integer, String> f = (Integer e) -> c()
+ e
+                    out.withObjectOutputStream {
+                        it.writeObject(f)
+                    }
+                    
+                    return out.toByteArray()
+            }
+            
+            static void main(String[] args) {
+                new ByteArrayInputStream(Test1.p()).withObjectInputStream(Test1.class.classLoader)
{
+                    SerializableFunction<Integer, String> f = (SerializableFunction<Integer,
String>) it.readObject()
+                    assert 'a1' == f.apply(1)
+                }
+            }
+            
+            interface SerializableFunction<T, R> extends Function<T, R>, Serializable
{}
+        }
+        '''
+    }
+
     void testDeserializeNestedLambda() {
         assertScript '''
         import java.util.function.Function
@@ -1176,8 +1391,7 @@ class LambdaTest extends GroovyTestCase {
         interface SerializableFunction<T, R> extends Function<T, R>, Serializable
{}
         
         @groovy.transform.CompileStatic
-        class Test1 implements Serializable {
-            private static final long serialVersionUID = -1L;
+        class Test1 {
             def p() {
                     def out1 = new ByteArrayOutputStream()
                     SerializableFunction<Integer, String> f1 = (Integer e) -> 'a'
+ e
@@ -1233,8 +1447,7 @@ class LambdaTest extends GroovyTestCase {
         interface SerializableFunction<T, R> extends Function<T, R>, Serializable
{}
         
         @groovy.transform.CompileStatic
-        class Test1 implements Serializable {
-            private static final long serialVersionUID = -1L;
+        class Test1 {
             def p() {
                     def out1 = new ByteArrayOutputStream()
                     out1.withObjectOutputStream {
@@ -1290,8 +1503,7 @@ class LambdaTest extends GroovyTestCase {
         interface SerializableFunction<T, R> extends Function<T, R>, Serializable
{}
         
         @groovy.transform.CompileStatic
-        class Test1 implements Serializable {
-            private static final long serialVersionUID = -1L;
+        class Test1 {
             static p() {
                     def out1 = new ByteArrayOutputStream()
                     out1.withObjectOutputStream {
@@ -1395,4 +1607,5 @@ class LambdaTest extends GroovyTestCase {
         }
         '''
     }
+
 }


Mime
View raw message