This is an automated email from the ASF dual-hosted git repository.

markt pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tomcat.git


The following commit(s) were added to refs/heads/main by this push:
     new 00ef1b65d7 More improvements to exception handling during WebSocket 
msg processing
00ef1b65d7 is described below

commit 00ef1b65d71c3cd6f826461a54aacaa332c59196
Author: Mark Thomas <ma...@apache.org>
AuthorDate: Thu Aug 8 17:31:21 2024 +0100

    More improvements to exception handling during WebSocket msg processing
    
    This broadens the previous fix that just addressed issues with Encoders
    and reworks some unit tests that expected an exception during message
    processing to close the connection.
---
 .../tomcat/websocket/pojo/LocalStrings.properties  |  2 +-
 .../websocket/pojo/PojoMessageHandlerBase.java     | 26 ++++++------
 .../pojo/PojoMessageHandlerPartialBase.java        |  6 ++-
 .../pojo/PojoMessageHandlerWholeBase.java          |  6 ++-
 .../tomcat/websocket/TestWsRemoteEndpoint.java     | 28 +++++++------
 .../apache/tomcat/websocket/TesterEchoServer.java  | 46 +++++++++++++++++++---
 webapps/docs/changelog.xml                         |  8 ++--
 7 files changed, 85 insertions(+), 37 deletions(-)

diff --git a/java/org/apache/tomcat/websocket/pojo/LocalStrings.properties 
b/java/org/apache/tomcat/websocket/pojo/LocalStrings.properties
index c9ef340cba..a59548013b 100644
--- a/java/org/apache/tomcat/websocket/pojo/LocalStrings.properties
+++ b/java/org/apache/tomcat/websocket/pojo/LocalStrings.properties
@@ -19,7 +19,7 @@ pojoEndpointBase.onError=No error handling configured for 
[{0}] and the followin
 pojoEndpointBase.onErrorFail=Failed to call onError method of POJO end point 
for POJO of type [{0}]
 pojoEndpointBase.onOpenFail=Failed to call onOpen method of POJO end point for 
POJO of type [{0}]
 
-pojoMessageHandlerBase.encodeFail=Encoding failed for POJO of tyoe [{0}] in 
session [{1}]
+pojoMessageHandlerBase.onMessafeFail=Exception during onMessage call to POJO 
of type [{0}] in session [{1}]
 
 pojoMessageHandlerWhole.decodeIoFail=IO error while decoding message
 pojoMessageHandlerWhole.maxBufferSize=The maximum supported message size for 
this implementation is Integer.MAX_VALUE
diff --git a/java/org/apache/tomcat/websocket/pojo/PojoMessageHandlerBase.java 
b/java/org/apache/tomcat/websocket/pojo/PojoMessageHandlerBase.java
index b52fbbba20..4d96720306 100644
--- a/java/org/apache/tomcat/websocket/pojo/PojoMessageHandlerBase.java
+++ b/java/org/apache/tomcat/websocket/pojo/PojoMessageHandlerBase.java
@@ -17,6 +17,7 @@
 package org.apache.tomcat.websocket.pojo;
 
 import java.io.IOException;
+import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
 import java.nio.ByteBuffer;
 
@@ -114,19 +115,20 @@ public abstract class PojoMessageHandlerBase<T> 
implements WrappedMessageHandler
     }
 
 
-    protected final void handlePojoMethodException(Throwable t) {
-        t = ExceptionUtils.unwrapInvocationTargetException(t);
+    protected final void 
handlePojoMethodInvocationTargetException(InvocationTargetException e) {
+        /*
+         * This is a failure during the execution of onMessage. This does not 
normally need to trigger the failure of
+         * the WebSocket connection.
+         */
+        Throwable t = ExceptionUtils.unwrapInvocationTargetException(e);
+        // Check for JVM wide issues
         ExceptionUtils.handleThrowable(t);
-        if (t instanceof EncodeException) {
-            if (log.isDebugEnabled()) {
-                log.debug(sm.getString("pojoMessageHandlerBase.encodeFail", 
pojo.getClass().getName(), session.getId()),
-                        t);
-            }
-            ((WsSession) session).getLocal().onError(session, t);
-        } else if (t instanceof RuntimeException) {
-            throw (RuntimeException) t;
-        } else {
-            throw new RuntimeException(t.getMessage(), t);
+        // Log at debug level since this is an application issue and the 
application should be handling this.
+        if (log.isDebugEnabled()) {
+            log.debug(sm.getString("pojoMessageHandlerBase.onMessafeFail", 
pojo.getClass().getName(), session.getId()),
+                    t);
         }
+        // Notify the application of the issue so it can handle it.
+        ((WsSession) session).getLocal().onError(session, t);
     }
 }
diff --git 
a/java/org/apache/tomcat/websocket/pojo/PojoMessageHandlerPartialBase.java 
b/java/org/apache/tomcat/websocket/pojo/PojoMessageHandlerPartialBase.java
index 1614fa9fe5..aae708f0a1 100644
--- a/java/org/apache/tomcat/websocket/pojo/PojoMessageHandlerPartialBase.java
+++ b/java/org/apache/tomcat/websocket/pojo/PojoMessageHandlerPartialBase.java
@@ -65,8 +65,10 @@ public abstract class PojoMessageHandlerPartialBase<T> 
extends PojoMessageHandle
         Object result = null;
         try {
             result = method.invoke(pojo, parameters);
-        } catch (IllegalAccessException | InvocationTargetException e) {
-            handlePojoMethodException(e);
+        } catch (InvocationTargetException e) {
+            handlePojoMethodInvocationTargetException(e);
+        } catch (IllegalAccessException e) {
+            throw new RuntimeException(e);
         }
         processResult(result);
     }
diff --git 
a/java/org/apache/tomcat/websocket/pojo/PojoMessageHandlerWholeBase.java 
b/java/org/apache/tomcat/websocket/pojo/PojoMessageHandlerWholeBase.java
index 4ede4632aa..80b1116d41 100644
--- a/java/org/apache/tomcat/websocket/pojo/PojoMessageHandlerWholeBase.java
+++ b/java/org/apache/tomcat/websocket/pojo/PojoMessageHandlerWholeBase.java
@@ -100,8 +100,10 @@ public abstract class PojoMessageHandlerWholeBase<T> 
extends PojoMessageHandlerB
         Object result = null;
         try {
             result = method.invoke(pojo, parameters);
-        } catch (IllegalAccessException | InvocationTargetException e) {
-            handlePojoMethodException(e);
+        } catch (InvocationTargetException e) {
+            handlePojoMethodInvocationTargetException(e);
+        } catch (IllegalAccessException e) {
+            throw new RuntimeException(e);
         }
         processResult(result);
     }
diff --git a/test/org/apache/tomcat/websocket/TestWsRemoteEndpoint.java 
b/test/org/apache/tomcat/websocket/TestWsRemoteEndpoint.java
index 67001bdcd6..db9692af7a 100644
--- a/test/org/apache/tomcat/websocket/TestWsRemoteEndpoint.java
+++ b/test/org/apache/tomcat/websocket/TestWsRemoteEndpoint.java
@@ -218,24 +218,30 @@ public class TestWsRemoteEndpoint extends 
WebSocketBaseTest {
             wsSession = wsContainer.connectToServer(clazz, uri);
         }
 
-        CountDownLatch latch = new CountDownLatch(1);
-        TesterEndpoint tep = (TesterEndpoint) 
wsSession.getUserProperties().get("endpoint");
-        tep.setLatch(latch);
-        AsyncHandler<?> handler;
-        handler = new AsyncText(latch);
-
+        AsyncHandler<?> handler = new AsyncText(null);
         wsSession.addMessageHandler(handler);
 
         // This should trigger the error
-        wsSession.getBasicRemote().sendText("Start");
+        
wsSession.getBasicRemote().sendText(TesterEchoServer.WriterError.MSG_ERROR);
 
-        boolean latchResult = handler.getLatch().await(10, TimeUnit.SECONDS);
-
-        Assert.assertTrue(latchResult);
+        // This should get a PASS/FAIL message
+        
wsSession.getBasicRemote().sendText(TesterEchoServer.WriterError.MSG_COUNT);
 
         @SuppressWarnings("unchecked")
         List<String> messages = (List<String>) handler.getMessages();
 
-        Assert.assertEquals(0, messages.size());
+        // There should be a response - allow up to 15s
+        int count = 0;
+        while (count < 300 && messages.size() == 0) {
+            // 300 * 50 == 15,000ms == 15s
+            try {
+                Thread.sleep(50);
+            } catch (InterruptedException e) {
+                // Ignore
+            }
+            count++;
+        }
+        Assert.assertEquals(1, messages.size());
+        Assert.assertEquals(TesterEchoServer.WriterError.RESULT_PASS, 
messages.get(0));
     }
 }
diff --git a/test/org/apache/tomcat/websocket/TesterEchoServer.java 
b/test/org/apache/tomcat/websocket/TesterEchoServer.java
index f3007aff5a..0c90d8b015 100644
--- a/test/org/apache/tomcat/websocket/TesterEchoServer.java
+++ b/test/org/apache/tomcat/websocket/TesterEchoServer.java
@@ -17,10 +17,13 @@
 package org.apache.tomcat.websocket;
 
 import java.io.IOException;
+import java.io.Writer;
 import java.nio.ByteBuffer;
+import java.util.concurrent.atomic.AtomicInteger;
 
 import jakarta.servlet.ServletContextEvent;
 import jakarta.websocket.DeploymentException;
+import jakarta.websocket.OnError;
 import jakarta.websocket.OnMessage;
 import jakarta.websocket.Session;
 import jakarta.websocket.server.ServerContainer;
@@ -192,12 +195,39 @@ public class TesterEchoServer {
     @ServerEndpoint("/echoWriterError")
     public static class WriterError {
 
+        public static final String MSG_ERROR = "error";
+        public static final String MSG_COUNT = "count";
+        public static final String RESULT_PASS = "PASS";
+        public static final String RESULT_FAIL = "FAIL";
+
+        private AtomicInteger errorCount = new AtomicInteger(0);
+
         @OnMessage
-        public void echoTextMessage(Session session, 
@SuppressWarnings("unused") String msg) {
-            try {
-                session.getBasicRemote().getSendWriter();
-                // Simulate an error
-                throw new RuntimeException();
+        public void echoTextMessage(Session session, String msg) {
+            try (Writer w = session.getBasicRemote().getSendWriter()) {
+                if (MSG_ERROR.equals(msg)) {
+                    // Simulate an error
+                    throw new RuntimeException();
+                } else if (MSG_COUNT.equals(msg)) {
+                    int count = 0;
+                    while (count < 200 && errorCount.get() == 0) {
+                        // 200 * 50 == 10,000ms == 10s
+                        try {
+                            Thread.sleep(50);
+                        } catch (InterruptedException e) {
+                            // Ignore
+                        }
+                        count++;
+                    }
+                    if (errorCount.get() == 1) {
+                        w.write(RESULT_PASS);
+                    } else {
+                        w.write(RESULT_FAIL);
+                    }
+                } else {
+                    // Default is echo
+                   w.write(msg);
+                }
             } catch (IOException e) {
                 // Should not happen
                 try {
@@ -207,6 +237,12 @@ public class TesterEchoServer {
                 }
             }
         }
+
+
+        @OnError
+        public void onError(@SuppressWarnings("unused") Throwable t) {
+            errorCount.incrementAndGet();
+        }
     }
 
     @ServerEndpoint("/")
diff --git a/webapps/docs/changelog.xml b/webapps/docs/changelog.xml
index 2f14ccb16a..109ea3eafb 100644
--- a/webapps/docs/changelog.xml
+++ b/webapps/docs/changelog.xml
@@ -120,10 +120,10 @@
         again before throwing the exception. (markt)
       </fix>
       <fix>
-        An EncodeException being thrown during a message write should not
-        automatically cause the connection to close. The application should
-        handle the exception and make the decision whether or not to close the
-        connection. (markt)
+        An Exception being thrown during message processing (e.g. in a method
+        annotated with <code>@OnMessage</code>) should not automatically cause
+        the connection to close. The application should handle the exception 
and
+        make the decision whether or not to close the connection. (markt)
       </fix>
     </changelog>
   </subsection>


---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscr...@tomcat.apache.org
For additional commands, e-mail: dev-h...@tomcat.apache.org

Reply via email to