This is an automated email from the ASF dual-hosted git repository.

markt pushed a commit to branch 10.1.x
in repository https://gitbox.apache.org/repos/asf/tomcat.git

commit 5554cadc88b9ba72f05675ade0bdb11c2a25fd03
Author: Mark Thomas <ma...@apache.org>
AuthorDate: Fri Aug 4 14:53:56 2023 +0100

    Fix BZ 66841 - ensure AsyncListener.onError is called after an error
    
    https://bz.apache.org/bugzilla/show_bug.cgi?id=66841
---
 java/org/apache/coyote/AbstractProcessor.java    |   4 +-
 java/org/apache/coyote/AsyncStateMachine.java    |  38 ++++-
 java/org/apache/coyote/LocalStrings.properties   |   2 +
 test/org/apache/coyote/http2/TestAsyncError.java | 170 +++++++++++++++++++++++
 webapps/docs/changelog.xml                       |   5 +
 5 files changed, 215 insertions(+), 4 deletions(-)

diff --git a/java/org/apache/coyote/AbstractProcessor.java b/java/org/apache/coyote/AbstractProcessor.java
index ece5a6abda..b63eeae02e 100644
--- a/java/org/apache/coyote/AbstractProcessor.java
+++ b/java/org/apache/coyote/AbstractProcessor.java
@@ -105,7 +105,7 @@ public abstract class AbstractProcessor extends AbstractProcessorLight implement
         }
         // Use the return value to avoid processing more than one async error
         // in a single async cycle.
-        boolean setError = response.setError();
+        response.setError();
         boolean blockIo = this.errorState.isIoAllowed() && !errorState.isIoAllowed();
         this.errorState = this.errorState.getMostSevere(errorState);
         // Don't change the status code for IOException since that is almost
@@ -117,7 +117,7 @@ public abstract class AbstractProcessor extends AbstractProcessorLight implement
         if (t != null) {
             request.setAttribute(RequestDispatcher.ERROR_EXCEPTION, t);
         }
-        if (blockIo && isAsync() && setError) {
+        if (blockIo && isAsync()) {
             if (asyncStateMachine.asyncError()) {
                 processSocketEvent(SocketEvent.ERROR, true);
             }
diff --git a/java/org/apache/coyote/AsyncStateMachine.java b/java/org/apache/coyote/AsyncStateMachine.java
index d2bcae1b0d..c81f93735c 100644
--- a/java/org/apache/coyote/AsyncStateMachine.java
+++ b/java/org/apache/coyote/AsyncStateMachine.java
@@ -185,6 +185,16 @@ class AsyncStateMachine {
      * ends badly: e.g. CVE-2018-8037.
      */
     private final AtomicLong generation = new AtomicLong(0);
+    /*
+     * Error processing should only be triggered once per async generation. These fields track the last generation of
+     * async processing for which error processing was triggered and are used to ensure that the second and subsequent
+     * attempts to trigger async error processing for a given generation are NO-OPs.
+     *
+     * Guarded by this
+     */
+    private long lastErrorGeneration = -1;
+    private long lastErrorGenerationMust = -1;
+
     // Need this to fire listener on complete
     private AsyncContextCallback asyncCtxt = null;
     private final AbstractProcessor processor;
@@ -409,6 +419,30 @@ class AsyncStateMachine {
 
 
     synchronized boolean asyncError() {
+        Request request = processor.getRequest();
+        boolean containerThread = (request != null && request.isRequestThread());
+
+        // Ensure the error processing is only started once per generation
+        if (lastErrorGeneration == getCurrentGeneration()) {
+            if (state == AsyncState.MUST_ERROR && containerThread && lastErrorGenerationMust != getCurrentGeneration()) {
+                // This is the first container thread call after state was set to MUST_ERROR so don't skip
+                lastErrorGenerationMust = getCurrentGeneration();
+            } else {
+                // Duplicate call. Skip.
+                if (log.isDebugEnabled()) {
+                    log.debug(sm.getString("asyncStateMachine.asyncError.skip"));
+                }
+                return false;
+            }
+        } else {
+            // First call for this generation, don't skip.
+            lastErrorGeneration = getCurrentGeneration();
+        }
+
+        if (log.isDebugEnabled()) {
+            log.debug(sm.getString("asyncStateMachine.asyncError.start"));
+        }
+
         clearNonBlockingListeners();
         if (state == AsyncState.STARTING) {
             updateState(AsyncState.MUST_ERROR);
@@ -422,8 +456,8 @@ class AsyncStateMachine {
             updateState(AsyncState.ERROR);
         }
 
-        Request request = processor.getRequest();
-        return request == null || !request.isRequestThread();
+        // Return true for non-container threads to trigger a dispatch
+        return !containerThread;
     }
 
 
diff --git a/java/org/apache/coyote/LocalStrings.properties b/java/org/apache/coyote/LocalStrings.properties
index b708730f45..74adeeaceb 100644
--- a/java/org/apache/coyote/LocalStrings.properties
+++ b/java/org/apache/coyote/LocalStrings.properties
@@ -51,6 +51,8 @@ abstractProtocolHandler.setAttribute=Set attribute [{0}] with value [{1}]
 abstractProtocolHandler.start=Starting ProtocolHandler [{0}]
 abstractProtocolHandler.stop=Stopping ProtocolHandler [{0}]
 
+asyncStateMachine.asyncError.skip=Ignoring call to asyncError() as it has already been called since async processing started
+asyncStateMachine.asyncError.start=Starting to process call to asyncError()
 asyncStateMachine.invalidAsyncState=Calling [{0}] is not valid for a request with Async state [{1}]
 asyncStateMachine.stateChange=Changing async state from [{0}] to [{1}]
 
diff --git a/test/org/apache/coyote/http2/TestAsyncError.java b/test/org/apache/coyote/http2/TestAsyncError.java
new file mode 100644
index 0000000000..176e410ded
--- /dev/null
+++ b/test/org/apache/coyote/http2/TestAsyncError.java
@@ -0,0 +1,170 @@
+/*
+ *  Licensed to the Apache Software Foundation (ASF) under one or more
+ *  contributor license agreements.  See the NOTICE file distributed with
+ *  this work for additional information regarding copyright ownership.
+ *  The ASF licenses this file to You under the Apache License, Version 2.0
+ *  (the "License"); you may not use this file except in compliance with
+ *  the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ *  Unless required by applicable law or agreed to in writing, software
+ *  distributed under the License is distributed on an "AS IS" BASIS,
+ *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ *  See the License for the specific language governing permissions and
+ *  limitations under the License.
+ */
+package org.apache.coyote.http2;
+
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.nio.ByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import jakarta.servlet.AsyncContext;
+import jakarta.servlet.AsyncEvent;
+import jakarta.servlet.AsyncListener;
+import jakarta.servlet.ServletException;
+import jakarta.servlet.http.HttpServlet;
+import jakarta.servlet.http.HttpServletRequest;
+import jakarta.servlet.http.HttpServletResponse;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import org.apache.catalina.Context;
+import org.apache.catalina.Wrapper;
+import org.apache.catalina.startup.Tomcat;
+
+/*
+ * Based on
+ * https://bz.apache.org/bugzilla/show_bug.cgi?id=66841
+ */
+public class TestAsyncError extends Http2TestBase {
+
+    @Test
+    public void testError() throws Exception {
+
+        enableHttp2();
+
+        Tomcat tomcat = getTomcatInstance();
+
+        Context ctxt = tomcat.addContext("", null);
+        Tomcat.addServlet(ctxt, "simple", new SimpleServlet());
+        ctxt.addServletMappingDecoded("/simple", "simple");
+        Wrapper w = Tomcat.addServlet(ctxt, "async", new AsyncErrorServlet());
+        w.setAsyncSupported(true);
+        ctxt.addServletMappingDecoded("/async", "async");
+        tomcat.start();
+
+        openClientConnection();
+        doHttpUpgrade();
+        sendClientPreface();
+        validateHttp2InitialResponse();
+
+        // Send request
+        byte[] frameHeader = new byte[9];
+        ByteBuffer headersPayload = ByteBuffer.allocate(128);
+        buildGetRequest(frameHeader, headersPayload, null, 3, "/async");
+        writeFrame(frameHeader, headersPayload);
+
+        // Read response
+        // Headers
+        parser.readFrame();
+
+        // Read 3 'events'
+        parser.readFrame();
+        parser.readFrame();
+        parser.readFrame();
+
+        // Reset the stream
+        sendRst(3, Http2Error.CANCEL.getCode());
+
+        int count = 0;
+        while (count < 50 && TestListener.getErrorCount() == 0) {
+            count++;
+            Thread.sleep(100);
+        }
+
+        Assert.assertTrue(TestListener.getErrorCount() > 0);
+    }
+
+
+    private static final class AsyncErrorServlet extends HttpServlet {
+
+        private static final long serialVersionUID = 1L;
+
+        @Override
+        protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
+
+            final AsyncContext asyncContext = req.startAsync();
+            TestListener testListener = new TestListener();
+            asyncContext.addListener(testListener);
+
+            MessageGenerator msgGenerator = new MessageGenerator(resp);
+            asyncContext.start(msgGenerator);
+        }
+    }
+
+
+    private static final class MessageGenerator implements Runnable {
+
+        private final HttpServletResponse resp;
+
+        MessageGenerator(HttpServletResponse resp) {
+            this.resp = resp;
+        }
+
+        @Override
+        public void run() {
+            try {
+                resp.setContentType("text/plain");
+                resp.setCharacterEncoding(StandardCharsets.UTF_8);
+                PrintWriter pw = resp.getWriter();
+
+                while (true) {
+                    pw.println("OK");
+                    pw.flush();
+                    if (pw.checkError()) {
+                        throw new IOException();
+                    }
+                    Thread.sleep(1000);
+                }
+            } catch (IOException | InterruptedException e) {
+                // Expect async error handler to handle clean-up
+            }
+        }
+    }
+
+
+    private static final class TestListener implements AsyncListener {
+
+        private static final AtomicInteger errorCount = new AtomicInteger(0);
+
+        public static int getErrorCount() {
+            return errorCount.get();
+        }
+
+        @Override
+        public void onComplete(AsyncEvent event) throws IOException {
+            // NO-OP
+        }
+
+        @Override
+        public void onTimeout(AsyncEvent event) throws IOException {
+            // NO-OP
+        }
+
+        @Override
+        public void onError(AsyncEvent event) throws IOException {
+            errorCount.incrementAndGet();
+            event.getAsyncContext().complete();
+        }
+
+        @Override
+        public void onStartAsync(AsyncEvent event) throws IOException {
+            // NO-OP
+        }
+    }
+}
diff --git a/webapps/docs/changelog.xml b/webapps/docs/changelog.xml
index fd9df27fe5..44c53200cf 100644
--- a/webapps/docs/changelog.xml
+++ b/webapps/docs/changelog.xml
@@ -139,6 +139,11 @@
         <code>PROFILE=SYSTEM</code> instead of producing an error trying to
         parse it. (remm)
       </fix>
+      <fix>
+        <bug>66841</bug>: Ensure that <code>AsyncListener.onError()</code> is
+        called after an error during asynchronous processing with HTTP/2.
+        (markt)
+      </fix>
     </changelog>
   </subsection>
   <subsection name="WebSocket">


---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscr...@tomcat.apache.org
For additional commands, e-mail: dev-h...@tomcat.apache.org

Reply via email to