This is an automated email from the ASF dual-hosted git repository.

markt pushed a commit to branch 8.5.x
in repository https://gitbox.apache.org/repos/asf/tomcat.git

commit 6688878e8a5b94af6bcd48b48c95bfa221643dc1
Author: Mark Thomas <ma...@apache.org>
AuthorDate: Wed Apr 5 17:18:36 2023 +0100

    Further fix for BZ 66508
    
    https://bz.apache.org/bugzilla/show_bug.cgi?id=66508
    
    Avoid deadlock for close messages when
    WsRemoteEndpointImplServer.endMessage() for a previous message is
    processed on a container thread
---
 .../tomcat/websocket/WsRemoteEndpointImplBase.java | 28 ++++++--
 .../websocket/WsRemoteEndpointImplClient.java      |  9 +++
 java/org/apache/tomcat/websocket/WsSession.java    | 12 +++-
 .../websocket/server/WsHttpUpgradeHandler.java     |  3 +-
 .../server/WsRemoteEndpointImplServer.java         | 83 +++++++++++++++++++++-
 5 files changed, 126 insertions(+), 9 deletions(-)

diff --git a/java/org/apache/tomcat/websocket/WsRemoteEndpointImplBase.java 
b/java/org/apache/tomcat/websocket/WsRemoteEndpointImplBase.java
index 07a558c247..16068e6dae 100644
--- a/java/org/apache/tomcat/websocket/WsRemoteEndpointImplBase.java
+++ b/java/org/apache/tomcat/websocket/WsRemoteEndpointImplBase.java
@@ -33,6 +33,7 @@ import java.util.concurrent.Future;
 import java.util.concurrent.Semaphore;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.locks.Lock;
 
 import javax.naming.NamingException;
 import javax.websocket.CloseReason;
@@ -65,7 +66,7 @@ public abstract class WsRemoteEndpointImplBase implements 
RemoteEndpoint {
     private final IntermediateMessageHandler intermediateMessageHandler = new 
IntermediateMessageHandler(this);
 
     private Transformation transformation = null;
-    private final Semaphore messagePartInProgress = new Semaphore(1);
+    protected final Semaphore messagePartInProgress = new Semaphore(1);
     private final Queue<MessagePart> messagePartQueue = new ArrayDeque<>();
     private final Object messagePartLock = new Object();
 
@@ -287,9 +288,8 @@ public abstract class WsRemoteEndpointImplBase implements 
RemoteEndpoint {
             return;
         }
 
-        long timeout = timeoutExpiry - System.currentTimeMillis();
         try {
-            if (!messagePartInProgress.tryAcquire(timeout, 
TimeUnit.MILLISECONDS)) {
+            if (!acquireMessagePartInProgressSemaphore(opCode, timeoutExpiry)) 
{
                 String msg = sm.getString("wsRemoteEndpoint.acquireTimeout");
                 wsSession.doClose(new CloseReason(CloseCodes.GOING_AWAY, msg),
                         new CloseReason(CloseCodes.CLOSED_ABNORMALLY, msg), 
true);
@@ -333,6 +333,23 @@ public abstract class WsRemoteEndpointImplBase implements 
RemoteEndpoint {
     }
 
 
+    /**
+     * Acquire the semaphore that allows a message part to be written.
+     *
+     * @param opCode        The OPCODE for the message to be written
+     * @param timeoutExpiry The time when the attempt to acquire the semaphore 
should expire
+     *
+     * @return {@code true} if the semaphore is obtained, otherwise {@code 
false}.
+     *
+     * @throws InterruptedException If the wait for the semaphore is 
interrupted
+     */
+    protected boolean acquireMessagePartInProgressSemaphore(byte opCode, long 
timeoutExpiry)
+            throws InterruptedException {
+        long timeout = timeoutExpiry - System.currentTimeMillis();
+        return messagePartInProgress.tryAcquire(timeout, 
TimeUnit.MILLISECONDS);
+    }
+
+
     void startMessage(byte opCode, ByteBuffer payload, boolean last, 
SendHandler handler) {
 
         wsSession.updateLastActiveWrite();
@@ -391,7 +408,7 @@ public abstract class WsRemoteEndpointImplBase implements 
RemoteEndpoint {
     }
 
 
-    void endMessage(SendHandler handler, SendResult result) {
+    protected void endMessage(SendHandler handler, SendResult result) {
         boolean doWrite = false;
         MessagePart mpNext = null;
         synchronized (messagePartLock) {
@@ -734,6 +751,9 @@ public abstract class WsRemoteEndpointImplBase implements 
RemoteEndpoint {
 
     protected abstract void doClose();
 
+    protected abstract Lock getLock();
+
+
     private static void writeHeader(ByteBuffer headerBuffer, boolean fin, int 
rsv, byte opCode, boolean masked,
             ByteBuffer payload, byte[] mask, boolean first) {
 
diff --git a/java/org/apache/tomcat/websocket/WsRemoteEndpointImplClient.java 
b/java/org/apache/tomcat/websocket/WsRemoteEndpointImplClient.java
index 2f87e60b52..48a44aef09 100644
--- a/java/org/apache/tomcat/websocket/WsRemoteEndpointImplClient.java
+++ b/java/org/apache/tomcat/websocket/WsRemoteEndpointImplClient.java
@@ -21,6 +21,7 @@ import java.nio.ByteBuffer;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.TimeoutException;
+import java.util.concurrent.locks.ReentrantLock;
 
 import javax.websocket.SendHandler;
 import javax.websocket.SendResult;
@@ -28,6 +29,7 @@ import javax.websocket.SendResult;
 public class WsRemoteEndpointImplClient extends WsRemoteEndpointImplBase {
 
     private final AsyncChannelWrapper channel;
+    private final ReentrantLock lock = new ReentrantLock();
 
     public WsRemoteEndpointImplClient(AsyncChannelWrapper channel) {
         this.channel = channel;
@@ -67,8 +69,15 @@ public class WsRemoteEndpointImplClient extends 
WsRemoteEndpointImplBase {
         handler.onResult(SENDRESULT_OK);
     }
 
+
     @Override
     protected void doClose() {
         channel.close();
     }
+
+
+    @Override
+    protected ReentrantLock getLock() {
+        return lock;
+    }
 }
diff --git a/java/org/apache/tomcat/websocket/WsSession.java 
b/java/org/apache/tomcat/websocket/WsSession.java
index e3167ffa4d..845330a4e5 100644
--- a/java/org/apache/tomcat/websocket/WsSession.java
+++ b/java/org/apache/tomcat/websocket/WsSession.java
@@ -106,7 +106,6 @@ public class WsSession implements Session {
     private volatile MessageHandler binaryMessageHandler = null;
     private volatile MessageHandler.Whole<PongMessage> pongMessageHandler = 
null;
     private volatile State state = State.OPEN;
-    private final Object stateLock = new Object();
     private final Map<String, Object> userProperties = new 
ConcurrentHashMap<>();
     private volatile int maxBinaryMessageBufferSize = 
Constants.DEFAULT_BUFFER_SIZE;
     private volatile int maxTextMessageBufferSize = 
Constants.DEFAULT_BUFFER_SIZE;
@@ -647,7 +646,8 @@ public class WsSession implements Session {
             return;
         }
 
-        synchronized (stateLock) {
+        wsRemoteEndpoint.getLock().lock();
+        try {
             if (state != State.OPEN) {
                 return;
             }
@@ -677,6 +677,8 @@ public class WsSession implements Session {
                 }
                 fireEndpointOnClose(closeReasonLocal);
             }
+        } finally {
+            wsRemoteEndpoint.getLock().unlock();
         }
 
         IOException ioe = new 
IOException(sm.getString("wsSession.messageFailed"));
@@ -695,7 +697,8 @@ public class WsSession implements Session {
      */
     public void onClose(CloseReason closeReason) {
 
-        synchronized (stateLock) {
+        wsRemoteEndpoint.getLock().lock();
+        try {
             if (state != State.CLOSED) {
                 try {
                     wsRemoteEndpoint.setBatchingAllowed(false);
@@ -713,9 +716,12 @@ public class WsSession implements Session {
                 // Close the socket
                 wsRemoteEndpoint.close();
             }
+        } finally {
+            wsRemoteEndpoint.getLock().unlock();
         }
     }
 
+
     private void fireEndpointOnClose(CloseReason closeReason) {
 
         // Fire the onClose event
diff --git a/java/org/apache/tomcat/websocket/server/WsHttpUpgradeHandler.java 
b/java/org/apache/tomcat/websocket/server/WsHttpUpgradeHandler.java
index 3cae80d730..29bccf163a 100644
--- a/java/org/apache/tomcat/websocket/server/WsHttpUpgradeHandler.java
+++ b/java/org/apache/tomcat/websocket/server/WsHttpUpgradeHandler.java
@@ -117,7 +117,8 @@ public class WsHttpUpgradeHandler implements 
InternalHttpUpgradeHandler {
         ClassLoader cl = t.getContextClassLoader();
         t.setContextClassLoader(applicationClassLoader);
         try {
-            wsRemoteEndpointServer = new 
WsRemoteEndpointImplServer(socketWrapper, upgradeInfo, webSocketContainer);
+            wsRemoteEndpointServer = new 
WsRemoteEndpointImplServer(socketWrapper, upgradeInfo, webSocketContainer,
+                    connection);
             wsSession = new WsSession(wsRemoteEndpointServer, 
webSocketContainer, handshakeRequest.getRequestURI(),
                     handshakeRequest.getParameterMap(), 
handshakeRequest.getQueryString(),
                     handshakeRequest.getUserPrincipal(), httpSessionId, 
negotiatedExtensions, subProtocol,
diff --git 
a/java/org/apache/tomcat/websocket/server/WsRemoteEndpointImplServer.java 
b/java/org/apache/tomcat/websocket/server/WsRemoteEndpointImplServer.java
index db270e6b14..bf54fe31c8 100644
--- a/java/org/apache/tomcat/websocket/server/WsRemoteEndpointImplServer.java
+++ b/java/org/apache/tomcat/websocket/server/WsRemoteEndpointImplServer.java
@@ -23,7 +23,10 @@ import java.nio.ByteBuffer;
 import java.nio.channels.CompletionHandler;
 import java.util.concurrent.RejectedExecutionException;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
 
+import javax.servlet.http.WebConnection;
 import javax.websocket.SendHandler;
 import javax.websocket.SendResult;
 
@@ -33,6 +36,7 @@ import org.apache.juli.logging.LogFactory;
 import org.apache.tomcat.util.net.SocketWrapperBase;
 import org.apache.tomcat.util.net.SocketWrapperBase.BlockingMode;
 import org.apache.tomcat.util.res.StringManager;
+import org.apache.tomcat.websocket.Constants;
 import org.apache.tomcat.websocket.Transformation;
 import org.apache.tomcat.websocket.WsRemoteEndpointImplBase;
 
@@ -47,6 +51,7 @@ public class WsRemoteEndpointImplServer extends 
WsRemoteEndpointImplBase {
 
     private final SocketWrapperBase<?> socketWrapper;
     private final UpgradeInfo upgradeInfo;
+    private final WebConnection connection;
     private final WsWriteTimeout wsWriteTimeout;
     private volatile SendHandler handler = null;
     private volatile ByteBuffer[] buffers = null;
@@ -54,9 +59,10 @@ public class WsRemoteEndpointImplServer extends 
WsRemoteEndpointImplBase {
     private volatile long timeoutExpiry = -1;
 
     public WsRemoteEndpointImplServer(SocketWrapperBase<?> socketWrapper, 
UpgradeInfo upgradeInfo,
-            WsServerContainer serverContainer) {
+            WsServerContainer serverContainer, WebConnection connection) {
         this.socketWrapper = socketWrapper;
         this.upgradeInfo = upgradeInfo;
+        this.connection = connection;
         this.wsWriteTimeout = serverContainer.getTimeout();
     }
 
@@ -67,6 +73,75 @@ public class WsRemoteEndpointImplServer extends 
WsRemoteEndpointImplBase {
     }
 
 
+    /**
+     * {@inheritDoc}
+     * <p>
+     * The close message is a special case. It needs to be blocking else 
implementing the clean-up that follows the
+     * sending of the close message gets a lot more complicated. On the 
server, this creates additional complications as
+     * a dead-lock may occur in the following scenario:
+     * <ol>
+     * <li>Application thread writes message using non-blocking</li>
+     * <li>Write does not complete (write logic holds message pending 
lock)</li>
+     * <li>Socket is added to poller (or equivalent) for write
+     * <li>Client sends close message</li>
+     * <li>Container processes received close message and tries to send close 
message in response</li>
+     * <li>Container holds socket lock and is blocked waiting for message 
pending lock</li>
+     * <li>Poller fires write possible event for socket</li>
+     * <li>Container tries to process write possible event but is blocked 
waiting for socket lock</li>
+     * <li>Processing of the WebSocket connection is dead-locked until the 
original message write times out</li>
+     * </ol>
+     * The purpose of this method is to break the above dead-lock. It does 
this by returning control of the processor to
+     * the socket wrapper and releasing the socket lock while waiting for the 
pending message write to complete.
+     * Normally, that would be a terrible idea as it creates the possibility 
that the processor is returned to the pool
+     * more than once under various error conditions. In this instance it is 
safe because these are upgrade processors
+     * (isUpgrade() returns {@code true}) and upgrade processors are never 
pooled.
+     * <p>
+     * TODO: Despite the complications it creates, it would be worth exploring 
the possibility of processing a received
+     * close frame in a non-blocking manner.
+     */
+    @Override
+    protected boolean acquireMessagePartInProgressSemaphore(byte opCode, long 
timeoutExpiry)
+            throws InterruptedException {
+
+        // Only close requires special handling.
+        if (opCode != Constants.OPCODE_CLOSE) {
+            return super.acquireMessagePartInProgressSemaphore(opCode, 
timeoutExpiry);
+        }
+
+        int socketWrapperLockCount;
+        if (socketWrapper.getLock() instanceof ReentrantLock) {
+            socketWrapperLockCount = ((ReentrantLock) 
socketWrapper.getLock()).getHoldCount();
+        } else {
+            socketWrapperLockCount = 1;
+        }
+        while (!messagePartInProgress.tryAcquire()) {
+            long timeout = timeoutExpiry - System.currentTimeMillis();
+            if (timeout < 0) {
+                return false;
+            }
+            try {
+                // Release control of the processor
+                socketWrapper.setCurrentProcessor(connection);
+                // Release the per socket lock(s)
+                for (int i = 0; i < socketWrapperLockCount; i++) {
+                    socketWrapper.getLock().unlock();
+                }
+                // Provide opportunity for another thread to obtain the 
socketWrapper lock
+                Thread.yield();
+            } finally {
+                // Re-obtain the per socket lock(s)
+                for (int i = 0; i < socketWrapperLockCount; i++) {
+                    socketWrapper.getLock().lock();
+                }
+                // Re-take control of the processor
+                socketWrapper.takeCurrentProcessor();
+            }
+        }
+
+        return true;
+    }
+
+
     @Override
     protected void doWrite(final SendHandler handler, final long 
blockingWriteTimeoutExpiry, ByteBuffer... buffers) {
         if (socketWrapper.hasAsyncIO()) {
@@ -296,6 +371,12 @@ public class WsRemoteEndpointImplServer extends 
WsRemoteEndpointImplBase {
     }
 
 
+    @Override
+    protected Lock getLock() {
+        return socketWrapper.getLock();
+    }
+
+
     private static class OnResultRunnable implements Runnable {
 
         private final SendHandler sh;


---------------------------------------------------------------------
To unsubscribe, e-mail: dev-unsubscr...@tomcat.apache.org
For additional commands, e-mail: dev-h...@tomcat.apache.org

Reply via email to