This is an automated email from the ASF dual-hosted git repository.

sunlan pushed a commit to branch GROOVY-9381_3
in repository https://gitbox.apache.org/repos/asf/groovy.git


The following commit(s) were added to refs/heads/GROOVY-9381_3 by this push:
     new e1abb83d30 GROOVY-9381: tweaks for CS
e1abb83d30 is described below

commit e1abb83d30cf672bef57b3f04d1d96fba3d5ed42
Author: Daniel Sun <[email protected]>
AuthorDate: Fri Mar 27 23:57:18 2026 +0900

    GROOVY-9381: tweaks for CS
---
 src/main/java/groovy/concurrent/AsyncScope.java    |   7 +-
 .../groovy/transform/AsyncTransformHelper.java     |   5 +-
 .../runtime/async/AsyncAwaitSyntaxTest.groovy      | 250 +++++++++++++++++++++
 3 files changed, 258 insertions(+), 4 deletions(-)

diff --git a/src/main/java/groovy/concurrent/AsyncScope.java b/src/main/java/groovy/concurrent/AsyncScope.java
index bf392ae112..48c5d45859 100644
--- a/src/main/java/groovy/concurrent/AsyncScope.java
+++ b/src/main/java/groovy/concurrent/AsyncScope.java
@@ -19,6 +19,8 @@
 package groovy.concurrent;
 
 import groovy.lang.Closure;
+import groovy.transform.stc.ClosureParams;
+import groovy.transform.stc.SimpleType;
 import org.apache.groovy.runtime.async.AsyncSupport;
 import org.apache.groovy.runtime.async.DefaultAsyncScope;
 
@@ -214,7 +216,7 @@ public interface AsyncScope extends AutoCloseable {
      * @throws NullPointerException if {@code body} is {@code null}
      */
     @SuppressWarnings("unchecked")
-    static <T> T withScope(Closure<T> body) {
+    static <T> T withScope(@ClosureParams(value = SimpleType.class, options = "groovy.concurrent.AsyncScope") Closure<T> body) {
         return withScope(AsyncSupport.getExecutor(), body);
     }
 
@@ -232,7 +234,8 @@ public interface AsyncScope extends AutoCloseable {
      *                              is {@code null}
      */
     @SuppressWarnings("unchecked")
-    static <T> T withScope(Executor executor, Closure<T> body) {
+    static <T> T withScope(Executor executor,
+            @ClosureParams(value = SimpleType.class, options = "groovy.concurrent.AsyncScope") Closure<T> body) {
         Objects.requireNonNull(body, "body must not be null");
         try (AsyncScope scope = AsyncScope.create(executor)) {
             return withCurrent(scope, () -> body.call(scope));
diff --git a/src/main/java/org/codehaus/groovy/transform/AsyncTransformHelper.java b/src/main/java/org/codehaus/groovy/transform/AsyncTransformHelper.java
index 8f0d137c16..2b93619f5d 100644
--- a/src/main/java/org/codehaus/groovy/transform/AsyncTransformHelper.java
+++ b/src/main/java/org/codehaus/groovy/transform/AsyncTransformHelper.java
@@ -18,6 +18,7 @@
  */
 package org.codehaus.groovy.transform;
 
+import org.apache.groovy.runtime.async.AsyncSupport;
 import org.codehaus.groovy.ast.ClassHelper;
 import org.codehaus.groovy.ast.ClassNode;
 import org.codehaus.groovy.ast.CodeVisitorSupport;
@@ -83,8 +84,8 @@ public final class AsyncTransformHelper {
 
     // ---- internal constants (all private) --------------------------------
 
-    private static final String ASYNC_SUPPORT_CLASS = "org.apache.groovy.runtime.async.AsyncSupport";
-    private static final ClassNode ASYNC_SUPPORT_TYPE = ClassHelper.make(ASYNC_SUPPORT_CLASS);
+    private static final String ASYNC_SUPPORT_CLASS = AsyncSupport.class.getName();
+    private static final ClassNode ASYNC_SUPPORT_TYPE = ClassHelper.makeWithoutCaching(AsyncSupport.class, false);
     private static final String ASYNC_GEN_PARAM_NAME = "$__asyncGen__";
     private static final String DEFER_SCOPE_VAR = "$__deferScope__";
 
diff --git a/src/test/groovy/org/apache/groovy/runtime/async/AsyncAwaitSyntaxTest.groovy b/src/test/groovy/org/apache/groovy/runtime/async/AsyncAwaitSyntaxTest.groovy
index 211a08418d..92cee0b076 100644
--- a/src/test/groovy/org/apache/groovy/runtime/async/AsyncAwaitSyntaxTest.groovy
+++ b/src/test/groovy/org/apache/groovy/runtime/async/AsyncAwaitSyntaxTest.groovy
@@ -2670,4 +2670,254 @@ class AsyncAwaitSyntaxTest {
         '''
     }
 
+    // =========================================================================
+    // CompileStatic coverage
+    // =========================================================================
+
+    @Test
+    void testCompileStaticScriptAwaitableCombinators() {
+        assertScript '''
+            import groovy.concurrent.AwaitResult
+            import groovy.concurrent.Awaitable
+            import groovy.transform.CompileStatic
+            import java.io.IOException
+            import java.util.List
+
+            @CompileStatic
+            List<Object> gather() {
+                return await Awaitable.all(
+                    Awaitable.of('hero'),
+                    Awaitable.of(42),
+                    Awaitable.of(true)
+                )
+            }
+
+            @CompileStatic
+            String firstWinner() {
+                return await Awaitable.any(
+                    Awaitable.of('fast'),
+                    Awaitable.delay(50).then { 'slow' }
+                )
+            }
+
+            @CompileStatic
+            List<AwaitResult<Object>> settle() {
+                return await Awaitable.allSettled(
+                    Awaitable.of('ok'),
+                    Awaitable.failed(new IOException('boom'))
+                )
+            }
+
+            assert gather() == ['hero', 42, true]
+            assert firstWinner() == 'fast'
+
+            List<AwaitResult<Object>> settled = settle()
+            assert settled.size() == 2
+            assert settled[0].isSuccess()
+            assert settled[0].getValue() == 'ok'
+            assert settled[1].isFailure()
+            assert settled[1].getError().message == 'boom'
+        '''
+    }
+
+    @Test
+    void testCompileStaticAsyncMethodsAndAwait() {
+        assertScript '''
+            import groovy.concurrent.Awaitable
+            import groovy.transform.CompileStatic
+
+            @CompileStatic
+            class QuestMath {
+                async int square(int n) {
+                    return n * n
+                }
+
+                async String describe(int n) {
+                    int value = await square(n)
+                    return "square=${value}"
+                }
+            }
+
+            QuestMath math = new QuestMath()
+            Awaitable<Integer> squared = math.square(6)
+            assert squared.get() == 36
+            assert math.describe(5).get() == 'square=25'
+        '''
+    }
+
+    @Test
+    void testCompileStaticAsyncClosureAndLambda() {
+        assertScript '''
+            import groovy.transform.CompileStatic
+
+            @CompileStatic
+            class QuestClosures {
+                void verify() {
+                    def asyncSquare = async { int n -> n * n }
+                    def asyncDouble = async (int n) -> { n * 2 }
+
+                    assert await(asyncSquare(5)) == 25
+                    assert await(asyncDouble(21)) == 42
+                }
+            }
+
+            new QuestClosures().verify()
+        '''
+    }
+
+    @Test
+    void testCompileStaticAsyncGeneratorAndForAwait() {
+        assertScript '''
+            import groovy.concurrent.AsyncStream
+            import groovy.transform.CompileStatic
+            import java.util.ArrayList
+            import java.util.List
+
+            @CompileStatic
+            class QuestGenerator {
+                async numbers(int max) {
+                    for (int i = 1; i <= max; i += 1) {
+                        yield return i
+                    }
+                }
+
+                async List<Integer> collectOdds(int max) {
+                    List<Integer> values = new ArrayList<>()
+                    for await (Integer value in numbers(max)) {
+                        if ((value % 2) == 1) {
+                            values.add(value)
+                        }
+                    }
+                    return values
+                }
+            }
+
+            QuestGenerator generator = new QuestGenerator()
+            assert generator.numbers(3) instanceof AsyncStream
+            assert generator.collectOdds(5).get() == [1, 3, 5]
+        '''
+    }
+
+    @Test
+    void testCompileStaticForAwaitOverAsyncChannel() {
+        assertScript '''
+            import groovy.concurrent.AsyncChannel
+            import groovy.concurrent.Awaitable
+            import groovy.transform.CompileStatic
+            import java.util.ArrayList
+            import java.util.List
+
+            @CompileStatic
+            class QuestChannel {
+                async void produce(AsyncChannel<String> channel) {
+                    await channel.send('a')
+                    await channel.send('b')
+                    await channel.send('c')
+                    channel.close()
+                }
+
+                async List<String> collect() {
+                    AsyncChannel<String> channel = AsyncChannel.create(3)
+                    Awaitable<Void> producer = produce(channel)
+                    List<String> values = new ArrayList<>()
+                    for await (String item in channel) {
+                        values.add(item.toUpperCase())
+                    }
+                    await producer
+                    return values
+                }
+            }
+
+            assert new QuestChannel().collect().get() == ['A', 'B', 'C']
+        '''
+    }
+
+    @Test
+    void testCompileStaticDeferWithAsyncCleanup() {
+        assertScript '''
+            import groovy.concurrent.Awaitable
+            import groovy.transform.CompileStatic
+            import java.util.ArrayList
+            import java.util.List
+
+            @CompileStatic
+            class QuestCleanup {
+                final List<String> log = new ArrayList<>()
+
+                Awaitable<Void> closeAsync(String name) {
+                    return Awaitable.go {
+                        await Awaitable.delay(5)
+                        log.add('close:' + name)
+                        return null
+                    }
+                }
+
+                async List<String> run() {
+                    defer { closeAsync('outer') }
+                    defer { closeAsync('inner') }
+                    log.add('body')
+                    return log
+                }
+            }
+
+            assert await(new QuestCleanup().run()) == ['body', 'close:inner', 'close:outer']
+        '''
+    }
+
+    @Test
+    void testCompileStaticAsyncScopeWithAwait() {
+        assertScript '''
+            import groovy.concurrent.AsyncScope
+            import groovy.concurrent.Awaitable
+            import groovy.transform.CompileStatic
+            import java.util.Map
+
+            @CompileStatic
+            class QuestDashboard {
+                async String fetchName(int id) {
+                    await Awaitable.delay(5)
+                    return "hero-${id}"
+                }
+
+                async Integer fetchLevel(int id) {
+                    return await Awaitable.of(id + 40)
+                }
+
+                async Map<String, Object> load(int id) {
+                    return AsyncScope.withScope { scope ->
+                        Awaitable<String> nameTask = scope.async { await fetchName(id) }
+                        Awaitable<Integer> levelTask = scope.async { await fetchLevel(id) }
+                        [name: await nameTask, level: await levelTask]
+                    }
+                }
+            }
+
+            assert new QuestDashboard().load(2).get() == [name: 'hero-2', level: 42]
+        '''
+    }
+
+    @Test
+    void testCompileStaticAsyncContextPropagation() {
+        assertScript '''
+            import groovy.concurrent.AsyncContext
+            import groovy.concurrent.Awaitable
+            import groovy.transform.CompileStatic
+
+            @CompileStatic
+            class QuestTrace {
+                async String nested() {
+                    await Awaitable.delay(5)
+                    return (String) AsyncContext.current().get('traceId')
+                }
+
+                async String traced() {
+                    AsyncContext.current().put('traceId', 'trace-42')
+                    return await nested()
+                }
+            }
+
+            assert new QuestTrace().traced().get() == 'trace-42'
+        '''
+    }
+
 }

Reply via email to