This is an automated email from the ASF dual-hosted git repository.

markt pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tomcat.git

commit dccc2644ce701e88b152563473a350ec33a29a81
Author: Mark Thomas <ma...@apache.org>
AuthorDate: Fri Mar 24 17:21:04 2023 +0000

    Further fix for BZ 66508
    
    https://bz.apache.org/bugzilla/show_bug.cgi?id=66508
    
    Avoid deadlock for close messages when
    WsRemoteEndpointImplServer.endMessage() for a previous message is
    processed on a container thread
---
 .../apache/tomcat/util/net/SocketWrapperBase.java  |  5 +-
 .../tomcat/websocket/WsRemoteEndpointImplBase.java | 26 +++++++-
 .../websocket/WsRemoteEndpointImplClient.java      |  9 +++
 java/org/apache/tomcat/websocket/WsSession.java    | 12 +++-
 .../websocket/server/WsHttpUpgradeHandler.java     |  3 +-
 .../server/WsRemoteEndpointImplServer.java         | 77 +++++++++++++++++++++-
 6 files changed, 121 insertions(+), 11 deletions(-)

diff --git a/java/org/apache/tomcat/util/net/SocketWrapperBase.java b/java/org/apache/tomcat/util/net/SocketWrapperBase.java
index ae6ee0e019..1216bd7a3a 100644
--- a/java/org/apache/tomcat/util/net/SocketWrapperBase.java
+++ b/java/org/apache/tomcat/util/net/SocketWrapperBase.java
@@ -31,7 +31,6 @@ import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
-import java.util.concurrent.locks.Lock;
 import java.util.concurrent.locks.ReentrantLock;
 
 import jakarta.servlet.ServletConnection;
@@ -61,7 +60,7 @@ public abstract class SocketWrapperBase<E> {
 
     private E socket;
     private final AbstractEndpoint<E,?> endpoint;
-    private final Lock lock = new ReentrantLock();
+    private final ReentrantLock lock = new ReentrantLock();
 
     protected final AtomicBoolean closed = new AtomicBoolean(false);
 
@@ -158,7 +157,7 @@ public abstract class SocketWrapperBase<E> {
         return endpoint;
     }
 
-    public Lock getLock() {
+    public ReentrantLock getLock() {
         return lock;
     }
 
diff --git a/java/org/apache/tomcat/websocket/WsRemoteEndpointImplBase.java b/java/org/apache/tomcat/websocket/WsRemoteEndpointImplBase.java
index eec3381a85..5dc9298b6e 100644
--- a/java/org/apache/tomcat/websocket/WsRemoteEndpointImplBase.java
+++ b/java/org/apache/tomcat/websocket/WsRemoteEndpointImplBase.java
@@ -33,6 +33,7 @@ import java.util.concurrent.Future;
 import java.util.concurrent.Semaphore;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.locks.ReentrantLock;
 
 import javax.naming.NamingException;
 
@@ -66,7 +67,7 @@ public abstract class WsRemoteEndpointImplBase implements RemoteEndpoint {
     private final IntermediateMessageHandler intermediateMessageHandler = new IntermediateMessageHandler(this);
 
     private Transformation transformation = null;
-    private final Semaphore messagePartInProgress = new Semaphore(1);
+    protected final Semaphore messagePartInProgress = new Semaphore(1);
     private final Queue<MessagePart> messagePartQueue = new ArrayDeque<>();
     private final Object messagePartLock = new Object();
 
@@ -288,9 +289,8 @@ public abstract class WsRemoteEndpointImplBase implements RemoteEndpoint {
             return;
         }
 
-        long timeout = timeoutExpiry - System.currentTimeMillis();
         try {
-            if (!messagePartInProgress.tryAcquire(timeout, TimeUnit.MILLISECONDS)) {
+            if (!acquireMessagePartInProgressSemaphore(opCode, timeoutExpiry)) {
                 String msg = sm.getString("wsRemoteEndpoint.acquireTimeout");
                 wsSession.doClose(new CloseReason(CloseCodes.GOING_AWAY, msg),
                         new CloseReason(CloseCodes.CLOSED_ABNORMALLY, msg), true);
@@ -334,6 +334,23 @@ public abstract class WsRemoteEndpointImplBase implements RemoteEndpoint {
     }
 
 
+    /**
+     * Acquire the semaphore that allows a message part to be written.
+     *
+     * @param opCode        The OPCODE for the message to be written
+     * @param timeoutExpiry The time when the attempt to acquire the semaphore should expire
+     *
+     * @return {@code true} if the semaphore is obtained, otherwise {@code false}.
+     *
+     * @throws InterruptedException If the wait for the semaphore is interrupted
+     */
+    protected boolean acquireMessagePartInProgressSemaphore(byte opCode, long timeoutExpiry)
+            throws InterruptedException {
+        long timeout = timeoutExpiry - System.currentTimeMillis();
+        return messagePartInProgress.tryAcquire(timeout, TimeUnit.MILLISECONDS);
+    }
+
+
     void startMessage(byte opCode, ByteBuffer payload, boolean last, SendHandler handler) {
 
         wsSession.updateLastActiveWrite();
@@ -735,6 +752,9 @@ public abstract class WsRemoteEndpointImplBase implements RemoteEndpoint {
 
     protected abstract void doClose();
 
+    protected abstract ReentrantLock getLock();
+
+
     private static void writeHeader(ByteBuffer headerBuffer, boolean fin, int rsv, byte opCode, boolean masked,
             ByteBuffer payload, byte[] mask, boolean first) {
 
diff --git a/java/org/apache/tomcat/websocket/WsRemoteEndpointImplClient.java b/java/org/apache/tomcat/websocket/WsRemoteEndpointImplClient.java
index d22ac2d3fd..5f0cc67fc3 100644
--- a/java/org/apache/tomcat/websocket/WsRemoteEndpointImplClient.java
+++ b/java/org/apache/tomcat/websocket/WsRemoteEndpointImplClient.java
@@ -21,6 +21,7 @@ import java.nio.ByteBuffer;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
+import java.util.concurrent.locks.ReentrantLock;
 
 import jakarta.websocket.SendHandler;
 import jakarta.websocket.SendResult;
@@ -28,6 +29,7 @@ import jakarta.websocket.SendResult;
 public class WsRemoteEndpointImplClient extends WsRemoteEndpointImplBase {
 
     private final AsyncChannelWrapper channel;
+    private final ReentrantLock lock = new ReentrantLock();
 
     public WsRemoteEndpointImplClient(AsyncChannelWrapper channel) {
         this.channel = channel;
@@ -67,8 +69,15 @@ public class WsRemoteEndpointImplClient extends WsRemoteEndpointImplBase {
         handler.onResult(SENDRESULT_OK);
     }
 
+
     @Override
     protected void doClose() {
         channel.close();
     }
+
+
+    @Override
+    protected ReentrantLock getLock() {
+        return lock;
+    }
 }
diff --git a/java/org/apache/tomcat/websocket/WsSession.java b/java/org/apache/tomcat/websocket/WsSession.java
index 1851e80001..5f947afde7 100644
--- a/java/org/apache/tomcat/websocket/WsSession.java
+++ b/java/org/apache/tomcat/websocket/WsSession.java
@@ -107,7 +107,6 @@ public class WsSession implements Session {
     private volatile MessageHandler binaryMessageHandler = null;
     private volatile MessageHandler.Whole<PongMessage> pongMessageHandler = null;
     private volatile State state = State.OPEN;
-    private final Object stateLock = new Object();
     private final Map<String, Object> userProperties = new ConcurrentHashMap<>();
     private volatile int maxBinaryMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE;
     private volatile int maxTextMessageBufferSize = Constants.DEFAULT_BUFFER_SIZE;
@@ -564,7 +563,8 @@ public class WsSession implements Session {
             return;
         }
 
-        synchronized (stateLock) {
+        wsRemoteEndpoint.getLock().lock();
+        try {
             if (state != State.OPEN) {
                 return;
             }
@@ -594,6 +594,8 @@ public class WsSession implements Session {
                 }
                 fireEndpointOnClose(closeReasonLocal);
             }
+        } finally {
+            wsRemoteEndpoint.getLock().unlock();
         }
 
         IOException ioe = new IOException(sm.getString("wsSession.messageFailed"));
@@ -612,7 +614,8 @@ public class WsSession implements Session {
      */
     public void onClose(CloseReason closeReason) {
 
-        synchronized (stateLock) {
+        wsRemoteEndpoint.getLock().lock();
+        try {
             if (state != State.CLOSED) {
                 try {
                     wsRemoteEndpoint.setBatchingAllowed(false);
@@ -630,9 +633,12 @@ public class WsSession implements Session {
                 // Close the socket
                 wsRemoteEndpoint.close();
             }
+        } finally {
+            wsRemoteEndpoint.getLock().unlock();
         }
     }
 
+
     private void fireEndpointOnClose(CloseReason closeReason) {
 
         // Fire the onClose event
diff --git a/java/org/apache/tomcat/websocket/server/WsHttpUpgradeHandler.java b/java/org/apache/tomcat/websocket/server/WsHttpUpgradeHandler.java
index b180e52a74..7f37101e71 100644
--- a/java/org/apache/tomcat/websocket/server/WsHttpUpgradeHandler.java
+++ b/java/org/apache/tomcat/websocket/server/WsHttpUpgradeHandler.java
@@ -117,7 +117,8 @@ public class WsHttpUpgradeHandler implements InternalHttpUpgradeHandler {
         ClassLoader cl = t.getContextClassLoader();
         t.setContextClassLoader(applicationClassLoader);
         try {
-            wsRemoteEndpointServer = new WsRemoteEndpointImplServer(socketWrapper, upgradeInfo, webSocketContainer);
+            wsRemoteEndpointServer = new WsRemoteEndpointImplServer(socketWrapper, upgradeInfo, webSocketContainer,
+                    connection);
             wsSession = new WsSession(wsRemoteEndpointServer, webSocketContainer, handshakeRequest.getRequestURI(),
                     handshakeRequest.getParameterMap(), handshakeRequest.getQueryString(),
                     handshakeRequest.getUserPrincipal(), httpSessionId, negotiatedExtensions, subProtocol,
diff --git a/java/org/apache/tomcat/websocket/server/WsRemoteEndpointImplServer.java b/java/org/apache/tomcat/websocket/server/WsRemoteEndpointImplServer.java
index 8dd5974328..31779aa32a 100644
--- a/java/org/apache/tomcat/websocket/server/WsRemoteEndpointImplServer.java
+++ b/java/org/apache/tomcat/websocket/server/WsRemoteEndpointImplServer.java
@@ -23,7 +23,9 @@ import java.nio.ByteBuffer;
 import java.nio.channels.CompletionHandler;
 import java.util.concurrent.RejectedExecutionException;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.locks.ReentrantLock;
 
+import jakarta.servlet.http.WebConnection;
 import jakarta.websocket.SendHandler;
 import jakarta.websocket.SendResult;
 
@@ -33,6 +35,7 @@ import org.apache.juli.logging.LogFactory;
 import org.apache.tomcat.util.net.SocketWrapperBase;
 import org.apache.tomcat.util.net.SocketWrapperBase.BlockingMode;
 import org.apache.tomcat.util.res.StringManager;
+import org.apache.tomcat.websocket.Constants;
 import org.apache.tomcat.websocket.Transformation;
 import org.apache.tomcat.websocket.WsRemoteEndpointImplBase;
 
@@ -47,6 +50,7 @@ public class WsRemoteEndpointImplServer extends WsRemoteEndpointImplBase {
 
     private final SocketWrapperBase<?> socketWrapper;
     private final UpgradeInfo upgradeInfo;
+    private final WebConnection connection;
     private final WsWriteTimeout wsWriteTimeout;
     private volatile SendHandler handler = null;
     private volatile ByteBuffer[] buffers = null;
@@ -54,9 +58,10 @@ public class WsRemoteEndpointImplServer extends WsRemoteEndpointImplBase {
     private volatile long timeoutExpiry = -1;
 
     public WsRemoteEndpointImplServer(SocketWrapperBase<?> socketWrapper, UpgradeInfo upgradeInfo,
-            WsServerContainer serverContainer) {
+            WsServerContainer serverContainer, WebConnection connection) {
         this.socketWrapper = socketWrapper;
         this.upgradeInfo = upgradeInfo;
+        this.connection = connection;
         this.wsWriteTimeout = serverContainer.getTimeout();
     }
 
@@ -67,6 +72,70 @@ public class WsRemoteEndpointImplServer extends WsRemoteEndpointImplBase {
     }
 
 
+    /**
+     * {@inheritDoc}
+     * <p>
+     * The close message is a special case. It needs to be blocking else implementing the clean-up that follows the
+     * sending of the close message gets a lot more complicated. On the server, this creates additional complications as
+     * a dead-lock may occur in the following scenario:
+     * <ol>
+     * <li>Application thread writes message using non-blocking</li>
+     * <li>Write does not complete (write logic holds message pending lock)</li>
+     * <li>Socket is added to poller (or equivalent) for write</li>
+     * <li>Client sends close message</li>
+     * <li>Container processes received close message and tries to send close message in response</li>
+     * <li>Container holds socket lock and is blocked waiting for message pending lock</li>
+     * <li>Poller fires write possible event for socket</li>
+     * <li>Container tries to process write possible event but is blocked waiting for socket lock</li>
+     * <li>Processing of the WebSocket connection is dead-locked until the original message write times out</li>
+     * </ol>
+     * The purpose of this method is to break the above dead-lock. It does this by returning control of the processor to
+     * the socket wrapper and releasing the socket lock while waiting for the pending message write to complete.
+     * Normally, that would be a terrible idea as it creates the possibility that the processor is returned to the pool
+     * more than once under various error conditions. In this instance it is safe because these are upgrade processors
+     * (isUpgrade() returns {@code true}) and upgrade processors are never pooled.
+     * <p>
+     * TODO: Despite the complications it creates, it would be worth exploring the possibility of processing a received
+     * close frame in a non-blocking manner.
+     */
+    @Override
+    protected boolean acquireMessagePartInProgressSemaphore(byte opCode, long timeoutExpiry)
+            throws InterruptedException {
+
+        // Only close requires special handling.
+        if (opCode != Constants.OPCODE_CLOSE) {
+            return super.acquireMessagePartInProgressSemaphore(opCode, timeoutExpiry);
+        }
+
+        int socketWrapperLockCount = socketWrapper.getLock().getHoldCount();
+        while (!messagePartInProgress.tryAcquire()) {
+            long timeout = timeoutExpiry - System.currentTimeMillis();
+            if (timeout < 0) {
+                return false;
+            }
+            try {
+                // Release control of the processor
+                socketWrapper.setCurrentProcessor(connection);
+                // Release the per socket lock(s)
+                for (int i = 0; i < socketWrapperLockCount; i++) {
+                    socketWrapper.getLock().unlock();
+                }
+                // Provide opportunity for another thread to obtain the socketWrapper lock
+                Thread.yield();
+            } finally {
+                // Re-obtain the per socket lock(s)
+                for (int i = 0; i < socketWrapperLockCount; i++) {
+                    socketWrapper.getLock().lock();
+                }
+                // Re-take control of the processor
+                socketWrapper.takeCurrentProcessor();
+            }
+        }
+
+        return true;
+    }
+
+
     @Override
     protected void doWrite(SendHandler handler, long blockingWriteTimeoutExpiry, ByteBuffer... buffers) {
         if (socketWrapper.hasAsyncIO()) {
@@ -296,6 +365,12 @@ public class WsRemoteEndpointImplServer extends WsRemoteEndpointImplBase {
     }
 
 
+    @Override
+    protected ReentrantLock getLock() {
+        return socketWrapper.getLock();
+    }
+
+
     private static class OnResultRunnable implements Runnable {
 
         private final SendHandler sh;


---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscr...@tomcat.apache.org
For additional commands, e-mail: dev-h...@tomcat.apache.org

Reply via email to