This is an automated email from the ASF dual-hosted git repository.

sunlan pushed a commit to branch GROOVY-11905 in repository https://gitbox.apache.org/repos/asf/groovy.git
commit b9c0bce3e59539b8c6b9b8b63412f6c779a37574 Author: Daniel Sun <[email protected]> AuthorDate: Sun Apr 5 22:28:54 2026 +0900 GROOVY-11905: Optimize non-capturing lambdas --- .../classgen/asm/sc/StaticTypesLambdaWriter.java | 95 +++-- .../groovy/groovy/transform/stc/LambdaTest.groovy | 443 +++++++++++++++++++++ .../groovy/classgen/asm/TypeAnnotationsTest.groovy | 2 +- 3 files changed, 509 insertions(+), 31 deletions(-) diff --git a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaWriter.java b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaWriter.java index b804702754..ae5f63ddc8 100644 --- a/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaWriter.java +++ b/src/main/java/org/codehaus/groovy/classgen/asm/sc/StaticTypesLambdaWriter.java @@ -72,6 +72,7 @@ import static org.objectweb.asm.Opcodes.ACC_STATIC; import static org.objectweb.asm.Opcodes.ALOAD; import static org.objectweb.asm.Opcodes.CHECKCAST; import static org.objectweb.asm.Opcodes.DUP; +import static org.objectweb.asm.Opcodes.H_INVOKESTATIC; import static org.objectweb.asm.Opcodes.H_INVOKEVIRTUAL; import static org.objectweb.asm.Opcodes.ICONST_0; import static org.objectweb.asm.Opcodes.INVOKESPECIAL; @@ -110,36 +111,61 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun expression.setSerializable(true); } + boolean serializable = expression.isSerializable(); ClassNode lambdaClass = getOrAddLambdaClass(expression, abstractMethod); MethodNode lambdaMethod = lambdaClass.getMethods(DO_CALL).get(0); - boolean canDeserialize = controller.getClassNode().hasMethod(createDeserializeLambdaMethodName(lambdaClass), createDeserializeLambdaMethodParams()); - if (!canDeserialize) { - if (expression.isSerializable()) { - addDeserializeLambdaMethodForEachLambdaExpression(expression, lambdaClass); + Parameter[] lambdaSharedVariables = expression.getNodeMetaData(LAMBDA_SHARED_VARIABLES); + boolean accessingInstanceMembers = 
isAccessingInstanceMembersOfEnclosingClass(lambdaMethod); + // For non-capturing lambdas: make doCall static and use a capture-free invokedynamic, + // so LambdaMetafactory creates a singleton instance — just like Java non-capturing lambdas. + boolean nonCapturing = lambdaSharedVariables.length == 0 && !accessingInstanceMembers; + boolean deserializeLambdaMethodExists = hasDeserializeLambdaMethod(lambdaClass); + + if (nonCapturing) { + lambdaMethod.setModifiers(lambdaMethod.getModifiers() | ACC_STATIC); + } + + if (!deserializeLambdaMethodExists) { + if (serializable) { + addDeserializeLambdaMethodForLambdaExpression(expression, lambdaClass, lambdaMethod); addDeserializeLambdaMethod(); } - newGroovyLambdaWrapperAndLoad(lambdaClass, expression, isAccessingInstanceMembersOfEnclosingClass(lambdaMethod)); + if (!nonCapturing) { + newGroovyLambdaWrapperAndLoad(lambdaClass, expression, accessingInstanceMembers); + } } MethodVisitor mv = controller.getMethodVisitor(); mv.visitInvokeDynamicInsn( abstractMethod.getName(), - createAbstractMethodDesc(functionalType.redirect(), lambdaClass), - createBootstrapMethod(controller.getClassNode().isInterface(), expression.isSerializable()), - createBootstrapMethodArguments(createMethodDescriptor(abstractMethod), H_INVOKEVIRTUAL, lambdaClass, lambdaMethod, lambdaMethod.getParameters(), expression.isSerializable()) + nonCapturing + ? BytecodeHelper.getMethodDescriptor(functionalType.redirect(), Parameter.EMPTY_ARRAY) + : createAbstractMethodDesc(functionalType.redirect(), lambdaClass), + createBootstrapMethod(controller.getClassNode().isInterface(), serializable), + createBootstrapMethodArguments(createMethodDescriptor(abstractMethod), + nonCapturing ? 
H_INVOKESTATIC : H_INVOKEVIRTUAL, + lambdaClass, lambdaMethod, lambdaMethod.getParameters(), serializable) ); - if (expression.isSerializable()) { + if (serializable) { mv.visitTypeInsn(CHECKCAST, "java/io/Serializable"); } - controller.getOperandStack().replace(functionalType.redirect(), 1); + if (nonCapturing) { + controller.getOperandStack().push(functionalType.redirect()); + } else { + controller.getOperandStack().replace(functionalType.redirect(), 1); + } } private static Parameter[] createDeserializeLambdaMethodParams() { return new Parameter[]{new Parameter(SERIALIZEDLAMBDA_TYPE, "serializedLambda")}; } + private boolean hasDeserializeLambdaMethod(final ClassNode lambdaClass) { + return controller.getClassNode().hasMethod(createDeserializeLambdaMethodName(lambdaClass), createDeserializeLambdaMethodParams()); + } + private static boolean isAccessingInstanceMembersOfEnclosingClass(final MethodNode lambdaMethod) { boolean[] result = new boolean[1]; @@ -325,27 +351,36 @@ public class StaticTypesLambdaWriter extends LambdaWriter implements AbstractFun code); } - private void addDeserializeLambdaMethodForEachLambdaExpression(final LambdaExpression expression, final ClassNode lambdaClass) { + private static boolean requiresLambdaInstance(final MethodNode lambdaMethod) { + return 0 == (lambdaMethod.getModifiers() & ACC_STATIC); + } + + private void addDeserializeLambdaMethodForLambdaExpression(final LambdaExpression expression, final ClassNode lambdaClass, final MethodNode lambdaMethod) { ClassNode enclosingClass = controller.getClassNode(); - Statement code = block( - new BytecodeSequence(new BytecodeInstruction() { - @Override - public void visit(final MethodVisitor mv) { - mv.visitVarInsn(ALOAD, 0); - mv.visitInsn(ICONST_0); - mv.visitMethodInsn( - INVOKEVIRTUAL, - "java/lang/invoke/SerializedLambda", - "getCapturedArg", - "(I)Ljava/lang/Object;", - false); - mv.visitTypeInsn(CHECKCAST, BytecodeHelper.getClassInternalName(lambdaClass)); - OperandStack 
operandStack = controller.getOperandStack(); - operandStack.push(lambdaClass); - } - }), - returnS(expression) - ); + Statement code; + if (!requiresLambdaInstance(lambdaMethod)) { + code = block(returnS(expression)); + } else { + code = block( + new BytecodeSequence(new BytecodeInstruction() { + @Override + public void visit(final MethodVisitor mv) { + mv.visitVarInsn(ALOAD, 0); + mv.visitInsn(ICONST_0); + mv.visitMethodInsn( + INVOKEVIRTUAL, + "java/lang/invoke/SerializedLambda", + "getCapturedArg", + "(I)Ljava/lang/Object;", + false); + mv.visitTypeInsn(CHECKCAST, BytecodeHelper.getClassInternalName(lambdaClass)); + OperandStack operandStack = controller.getOperandStack(); + operandStack.push(lambdaClass); + } + }), + returnS(expression) + ); + } enclosingClass.addSyntheticMethod( createDeserializeLambdaMethodName(lambdaClass), diff --git a/src/test/groovy/groovy/transform/stc/LambdaTest.groovy b/src/test/groovy/groovy/transform/stc/LambdaTest.groovy index f69eb1a332..1554ca90db 100644 --- a/src/test/groovy/groovy/transform/stc/LambdaTest.groovy +++ b/src/test/groovy/groovy/transform/stc/LambdaTest.groovy @@ -18,6 +18,8 @@ */ package groovy.transform.stc +import org.codehaus.groovy.classgen.asm.AbstractBytecodeTestCase +import org.junit.jupiter.api.Nested import org.junit.jupiter.api.Test import static groovy.test.GroovyAssert.assertScript @@ -1888,4 +1890,445 @@ final class LambdaTest { assert this.class.classLoader.loadClass('Foo$_bar_lambda1').modifiers == 25 // public(1) + static(8) + final(16) ''' } + + // GROOVY-11905 + @Nested + class NonCapturingLambdaOptimizationTest extends AbstractBytecodeTestCase { + @Test + void testNonCapturingLambdaWithFunctionInStaticMethod() { + assertScript shell, ''' + class C { + static void test() { + assert [2, 3, 4] == [1, 2, 3].stream().map(e -> e + 1).toList() + } + } + C.test() + ''' + } + + @Test + void testNonCapturingLambdaWithFunctionInInstanceMethodWithoutThisAccess() { + assertScript shell, ''' + class C { + void 
test() { + assert [2, 3, 4] == [1, 2, 3].stream().map(e -> e + 1).toList() + } + } + new C().test() + ''' + } + + @Test + void testNonCapturingLambdaWithPredicate() { + assertScript shell, ''' + class C { + static void test() { + assert [2, 4] == [1, 2, 3, 4].stream().filter(e -> e % 2 == 0).toList() + } + } + C.test() + ''' + } + + @Test + void testNonCapturingLambdaWithSupplier() { + assertScript shell, ''' + class C { + static void test() { + Supplier<String> s = () -> 'constant' + assert s.get() == 'constant' + assert 'hello' == Optional.<String>empty().orElseGet(() -> 'hello') + } + } + C.test() + ''' + } + + @Test + void testNonCapturingLambdaWithBiFunction() { + assertScript shell, ''' + class C { + static void test() { + BiFunction<Integer, Integer, Integer> f = (a, b) -> a + b + assert f.apply(3, 4) == 7 + } + } + C.test() + ''' + } + + @Test + void testNonCapturingLambdaWithComparator() { + assertScript shell, ''' + class C { + static void test() { + assert [3, 2, 1] == [1, 2, 3].stream().sorted((a, b) -> b.compareTo(a)).toList() + } + } + C.test() + ''' + } + + @Test + void testNonCapturingLambdaWithPrimitiveParameterType() { + assertScript shell, ''' + class C { + static void test() { + IntUnaryOperator op = (int i) -> i * 2 + assert op.applyAsInt(5) == 10 + } + } + C.test() + ''' + } + + @Test + void testNonCapturingLambdaWithCustomFunctionalInterface() { + assertScript shell, ''' + interface Transformer<I, O> { + O transform(I input) + } + class C { + static void test() { + Transformer<String, Integer> t = (String s) -> s.length() + assert t.transform('hello') == 5 + } + } + C.test() + ''' + } + + @Test + void testNonCapturingLambdaCallingStaticMethodOnly() { + assertScript shell, ''' + class C { + static String prefix() { 'Hi ' } + static void test() { + assert ['Hi 1', 'Hi 2'] == [1, 2].stream().map(e -> C.prefix() + e).toList() + } + } + C.test() + ''' + } + + @Test + void testMultipleNonCapturingLambdasInSameMethod() { + assertScript shell, ''' + 
class C { + static void test() { + Function<Integer, Integer> f = (Integer x) -> x + 1 + Function<Integer, String> g = (Integer x) -> 'v' + x + Predicate<Integer> p = (Integer x) -> x > 2 + assert f.apply(1) == 2 + assert g.apply(1) == 'v1' + assert p.test(3) && !p.test(1) + } + } + C.test() + ''' + } + + @Test + void testNonCapturingLambdaInStaticInitializerBlock() { + assertScript shell, ''' + class C { + static List<Integer> result + static { result = [1, 2, 3].stream().map(e -> e * 2).toList() } + } + assert C.result == [2, 4, 6] + ''' + } + + @Test + void testNonCapturingLambdaInFieldInitializer() { + assertScript shell, ''' + class C { + IntUnaryOperator op = (int i) -> i + 1 + void test() { assert op.applyAsInt(5) == 6 } + } + new C().test() + ''' + } + + @Test + void testNonCapturingLambdaInInterfaceDefaultMethod() { + assertScript shell, ''' + interface Processor { + default List<Integer> process(List<Integer> input) { + input.stream().map(e -> e + 1).toList() + } + } + class C implements Processor {} + assert new C().process([1, 2, 3]) == [2, 3, 4] + ''' + } + + @Test + void testNonCapturingLambdaSingletonIdentity() { + assertScript shell, ''' + class C { + static void test() { + def identities = new HashSet() + for (int i = 0; i < 5; i++) { + Function<Integer, Integer> f = (Integer x) -> x + 1 + identities.add(System.identityHashCode(f)) + } + assert identities.size() == 1 : 'non-capturing lambda should be a singleton' + } + } + C.test() + ''' + } + + @Test + void testCapturingLambdaCreatesDistinctInstances() { + assertScript shell, ''' + class C { + static void test() { + def identities = new HashSet() + for (int i = 0; i < 3; i++) { + int captured = i + Function<Integer, Integer> f = (Integer x) -> x + captured + identities.add(System.identityHashCode(f)) + assert f.apply(10) == 10 + i + } + assert identities.size() == 3 : 'capturing lambda should create different instances' + } + } + C.test() + ''' + } + + @Test + void 
testCapturingLocalVariableStillWorks() { + assertScript shell, ''' + class C { + static void test() { + String x = '#' + assert ['#1', '#2'] == [1, 2].stream().map(e -> x + e).toList() + } + } + C.test() + ''' + } + + @Test + void testAccessingThisStillWorks() { + assertScript shell, ''' + class C { + String prefix = 'Hi ' + void test() { + assert ['Hi 1', 'Hi 2'] == [1, 2].stream().map(e -> this.prefix + e).toList() + } + } + new C().test() + ''' + } + + @Test + void testCallingInstanceMethodStillWorks() { + assertScript shell, ''' + class C { + String greet(int i) { "Hello $i" } + void test() { + assert ['Hello 1', 'Hello 2'] == [1, 2].stream().map(e -> greet(e)).toList() + } + } + new C().test() + ''' + } + + @Test + void testNonCapturingSerializableLambdaCanBeSerialized() { + assertScript shell, ''' + import java.io.* + interface SerFunc<I,O> extends Serializable, Function<I,O> {} + byte[] test() { + try (def out = new ByteArrayOutputStream()) { + out.withObjectOutputStream { + SerFunc<Integer, String> f = ((Integer i) -> 'a' + i) + it.writeObject(f) + } + out.toByteArray() + } + } + assert test().length > 0 + ''' + } + + @Test + void testNonCapturingSerializableLambdaRoundTrips() { + assertScript shell, ''' + package tests.lambda + class C { + static byte[] test() { + def out = new ByteArrayOutputStream() + out.withObjectOutputStream { it -> + SerFunc<Integer, String> f = (Integer i) -> 'a' + i + it.writeObject(f) + } + out.toByteArray() + } + static main(args) { + new ByteArrayInputStream(C.test()).withObjectInputStream(C.classLoader) { + SerFunc<Integer, String> f = (SerFunc<Integer, String>) it.readObject() + assert f.apply(1) == 'a1' + } + } + interface SerFunc<I,O> extends Serializable, Function<I,O> {} + } + ''' + } + + @Test + void testNonCapturingSerializableLambdaSingletonIdentity() { + assertScript shell, ''' + interface SerFunc<I,O> extends Serializable, Function<I,O> {} + class C { + static void test() { + def identities = new HashSet() + for (int 
i = 0; i < 5; i++) { + SerFunc<Integer, Integer> f = (Integer x) -> x + 1 + identities.add(System.identityHashCode(f)) + } + assert identities.size() == 1 : 'non-capturing serializable lambda should be a singleton' + } + } + C.test() + ''' + } + + @Test + void testCapturingSerializableLambdaStillRoundTrips() { + assertScript shell, ''' + package tests.lambda + class C { + byte[] test() { + def out = new ByteArrayOutputStream() + out.withObjectOutputStream { + String s = 'a' + SerFunc<Integer, String> f = (Integer i) -> s + i + it.writeObject(f) + } + out.toByteArray() + } + static main(args) { + new ByteArrayInputStream(C.newInstance().test()).withObjectInputStream(C.classLoader) { + SerFunc<Integer, String> f = (SerFunc<Integer, String>) it.readObject() + assert f.apply(1) == 'a1' + } + } + interface SerFunc<I,O> extends Serializable, Function<I,O> {} + } + ''' + } + + @Test + void testNonCapturingLambdaInStaticMethodUsesStaticDoCall() { + def bytecode = compileStaticBytecode(classNamePattern: 'C\\$_create_lambda1', method: 'doCall', ''' + @CompileStatic + class C { + static IntUnaryOperator create() { + (int i) -> i * 2 + } + } + ''') + assert bytecode.hasStrictSequence([ + 'public static doCall(I)I', + 'L0' + ]) + } + + @Test + void testNonCapturingLambdaInInstanceMethodWithoutThisAccessUsesCaptureFreeInvokeDynamic() { + def bytecode = compileStaticBytecode(classNamePattern: 'C', method: 'create', ''' + @CompileStatic + class C { + IntUnaryOperator create() { + (int i) -> i + 1 + } + } + ''') + assert bytecode.hasSequence([ + 'INVOKEDYNAMIC applyAsInt()Ljava/util/function/IntUnaryOperator;', + 'java/lang/invoke/LambdaMetafactory.metafactory', + 'C$_create_lambda1.doCall(I)I' + ]) + assert !bytecode.hasSequence(['NEW C$_create_lambda1']) + } + + @Test + void testCapturingLambdaRetainsInstanceDoCallAndCapturedReceiver() { + def lambdaBytecode = compileStaticBytecode(classNamePattern: 'C\\$_create_lambda1', method: 'doCall', ''' + @CompileStatic + class C { + 
static IntUnaryOperator create() { + int captured = 1 + IntUnaryOperator op = (int i) -> i + captured + op + } + } + ''') + assert lambdaBytecode.hasSequence(['public doCall(I)I']) + assert !lambdaBytecode.hasSequence(['public static doCall(I)I']) + + def outerBytecode = compileStaticBytecode(classNamePattern: 'C', method: 'create', ''' + @CompileStatic + class C { + static IntUnaryOperator create() { + int captured = 1 + IntUnaryOperator op = (int i) -> i + captured + op + } + } + ''') + assert outerBytecode.hasSequence([ + 'NEW C$_create_lambda1', + 'INVOKEDYNAMIC applyAsInt(LC$_create_lambda1;)Ljava/util/function/IntUnaryOperator;', + 'C$_create_lambda1.doCall(I)I' + ]) + } + + @Test + void testNonCapturingSerializableLambdaDeserializeHelperSkipsCapturedArgLookup() { + def bytecode = compileStaticBytecode(classNamePattern: 'C', method: '$deserializeLambda_C$_create_lambda1$', ''' + @CompileStatic + class C { + static SerFunc<Integer, String> create() { + (Integer i) -> 'a' + i + } + interface SerFunc<I,O> extends Serializable, Function<I,O> {} + } + ''') + assert !bytecode.hasSequence([SERIALIZED_LAMBDA_GET_CAPTURED_ARG]) + } + + @Test + void testCapturingSerializableLambdaDeserializeHelperReadsCapturedArg() { + def bytecode = compileStaticBytecode(classNamePattern: 'C', method: '$deserializeLambda_C$_create_lambda1$', ''' + @CompileStatic + class C { + static SerFunc<Integer, String> create() { + String prefix = 'a' + SerFunc<Integer, String> f = (Integer i) -> prefix + i + f + } + interface SerFunc<I,O> extends Serializable, Function<I,O> {} + } + ''') + assert bytecode.hasSequence([SERIALIZED_LAMBDA_GET_CAPTURED_ARG]) + } + + private compileStaticBytecode(final Map options = [:], final String script) { + compile(options, COMMON_IMPORTS + script) + } + + private static final String COMMON_IMPORTS = '''\ + import groovy.transform.CompileStatic + import java.io.Serializable + import java.util.function.Function + import java.util.function.IntUnaryOperator + 
'''.stripIndent() + private static final String SERIALIZED_LAMBDA_GET_CAPTURED_ARG = 'INVOKEVIRTUAL java/lang/invoke/SerializedLambda.getCapturedArg' + } } diff --git a/src/test/groovy/org/codehaus/groovy/classgen/asm/TypeAnnotationsTest.groovy b/src/test/groovy/org/codehaus/groovy/classgen/asm/TypeAnnotationsTest.groovy index c5270ddec5..cdcf626338 100644 --- a/src/test/groovy/org/codehaus/groovy/classgen/asm/TypeAnnotationsTest.groovy +++ b/src/test/groovy/org/codehaus/groovy/classgen/asm/TypeAnnotationsTest.groovy @@ -299,7 +299,7 @@ final class TypeAnnotationsTest extends AbstractBytecodeTestCase { } ''') assert bytecode.hasStrictSequence([ - 'public doCall(I)I', + 'public static doCall(I)I', '@LTypeAnno1;() : METHOD_FORMAL_PARAMETER 0, null', 'L0' ])
