This is an automated email from the ASF dual-hosted git repository.

lgoldstein pushed a commit to branch SSHD-966
in repository https://gitbox.apache.org/repos/asf/mina-sshd.git

commit b0159a4f22f7dcdab9a6b8824f6702547ab13058
Author: Lyor Goldstein <lgoldst...@apache.org>
AuthorDate: Fri May 15 10:24:50 2020 +0300

    [SSHD-966] Use the same lock for session pending-packets synchronization and ChannelOutputStream mutual exclusion
---
 .../sshd/common/channel/ChannelOutputStream.java   | 224 +++++++++++++--------
 .../common/session/helpers/AbstractSession.java    |  98 +++++++--
 2 files changed, 216 insertions(+), 106 deletions(-)

diff --git a/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelOutputStream.java b/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelOutputStream.java
index d0d879b..c2f48f2 100644
--- a/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelOutputStream.java
+++ b/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelOutputStream.java
@@ -26,11 +26,14 @@ import java.util.Objects;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicBoolean;
 
+import org.apache.sshd.common.FactoryManager;
+import org.apache.sshd.common.RuntimeSshException;
 import org.apache.sshd.common.SshConstants;
 import org.apache.sshd.common.SshException;
 import org.apache.sshd.common.channel.exception.SshChannelClosedException;
 import org.apache.sshd.common.io.PacketWriter;
 import org.apache.sshd.common.session.Session;
+import org.apache.sshd.common.session.helpers.AbstractSession;
 import org.apache.sshd.common.util.ValidateUtils;
 import org.apache.sshd.common.util.buffer.Buffer;
 import org.slf4j.Logger;
@@ -103,24 +106,63 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe
     }
 
     @Override
-    public synchronized void write(int w) throws IOException {
-        b[0] = (byte) w;
-        write(b, 0, 1);
+    public void write(int w) throws IOException {
+        try {
+            Channel channel = getChannel();
+            Session session = channel.getSession();
+            ((AbstractSession) session).executeUnderPendingPacketsLock(
+                    getExtraPendingPacketLockWaitTime(1), () -> {
+                        b[0] = (byte) w;
+                        lockedWrite(session, channel, b, 0, 1);
+                        return null;
+                    });
+        } catch (Exception e) {
+            log.error("write(" + this + ") value=0x" + Integer.toHexString(w) 
+ " failed to write", e);
+            if (e instanceof IOException) {
+                throw (IOException) e;
+            } else if (e instanceof RuntimeException) {
+                throw (RuntimeException) e;
+            } else {
+                throw new RuntimeSshException(e);
+            }
+        }
     }
 
     @Override
-    public synchronized void write(byte[] buf, int s, int l) throws 
IOException {
-        Channel channel = getChannel();
+    public void write(byte[] buf, int startOffset, int dataLen) throws 
IOException {
+        try {
+            Channel channel = getChannel();
+            Session session = channel.getSession();
+            ((AbstractSession) session).executeUnderPendingPacketsLock(
+                    getExtraPendingPacketLockWaitTime(dataLen), () -> {
+                        lockedWrite(session, channel, buf, startOffset, 
dataLen);
+                        return null;
+                    });
+        } catch (Exception e) {
+            log.error("write(" + this + ") len=" + dataLen + " failed to 
write", e);
+            if (e instanceof IOException) {
+                throw (IOException) e;
+            } else if (e instanceof RuntimeException) {
+                throw (RuntimeException) e;
+            } else {
+                throw new RuntimeSshException(e);
+            }
+        }
+    }
+
+    protected void lockedWrite(
+            Session session, Channel channel, byte[] buf, int startOffset, int 
dataLen)
+                throws Exception {
         if (!isOpen()) {
             throw new SshChannelClosedException(
                     channel.getId(),
-                    "write(" + this + ") len=" + l + " - channel already 
closed");
+                    "write(" + this + ") len=" + dataLen + " - channel already 
closed");
         }
 
-        Session session = channel.getSession();
+
         boolean debugEnabled = log.isDebugEnabled();
         boolean traceEnabled = log.isTraceEnabled();
-        while (l > 0) {
+        while (dataLen > 0) {
             // The maximum amount we should admit without flushing again
             // is enough to make up one full packet within our allowed
             // window size. We give ourselves a credit equal to the last
@@ -128,7 +170,7 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe
             // out the next packet before we block and wait for space to
             // become available again.
             long minReqLen = Math.min(remoteWindow.getSize() + lastSize, 
remoteWindow.getPacketSize());
-            long l2 = Math.min(l, minReqLen - bufferLength);
+            long l2 = Math.min(dataLen, minReqLen - bufferLength);
             if (l2 <= 0) {
                 if (bufferLength > 0) {
                     flush();
@@ -137,22 +179,22 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe
                     try {
                         long available = 
remoteWindow.waitForSpace(maxWaitTimeout);
                         if (traceEnabled) {
-                            log.trace("write({}) len={} - available={}", this, 
l, available);
+                            log.trace("write({}) len={} - available={}", this, 
dataLen, available);
                         }
                     } catch (IOException e) {
                         log.error("write({}) failed ({}) to wait for space of 
len={}: {}",
-                                this, e.getClass().getSimpleName(), l, 
e.getMessage());
+                                this, e.getClass().getSimpleName(), dataLen, 
e.getMessage());
 
                         if ((e instanceof WindowClosedException) && 
(!closedState.getAndSet(true))) {
                             if (debugEnabled) {
-                                log.debug("write({})[len={}] closing due to 
window closed", this, l);
+                                log.debug("write({})[len={}] closing due to 
window closed", this, dataLen);
                             }
                         }
 
                         throw e;
                     } catch (InterruptedException e) {
                         throw (IOException) new InterruptedIOException(
-                                "Interrupted while waiting for remote space on 
write len=" + l + " to " + this)
+                                "Interrupted while waiting for remote space on 
write len=" + dataLen + " to " + this)
                                         .initCause(e);
                     }
                 }
@@ -162,10 +204,10 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe
 
             ValidateUtils.checkTrue(l2 <= Integer.MAX_VALUE,
                     "Accumulated bytes length exceeds int boundary: %d", l2);
-            buffer.putRawBytes(buf, s, (int) l2);
+            buffer.putRawBytes(buf, startOffset, (int) l2);
             bufferLength += l2;
-            s += l2;
-            l -= l2;
+            startOffset += l2;
+            dataLen -= l2;
         }
 
         if (isNoDelay()) {
@@ -176,67 +218,74 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe
     }
 
     @Override
-    public synchronized void flush() throws IOException {
-        AbstractChannel channel = getChannel();
-        if (!isOpen()) {
-            throw new SshChannelClosedException(
-                    channel.getId(),
-                    "flush(" + this + ") length=" + bufferLength + " - stream 
is already closed");
-        }
-
+    public void flush() throws IOException {
+        Channel channel = getChannel();
+        Session session = channel.getSession();
         try {
-            Session session = channel.getSession();
-            boolean traceEnabled = log.isTraceEnabled();
-            while (bufferLength > 0) {
-                session.resetIdleTimeout();
+            ((AbstractSession) session).executeUnderPendingPacketsLock(
+                    getExtraPendingPacketLockWaitTime(bufferLength),
+                    () -> {
+                        if (!isOpen()) {
+                            throw new SshChannelClosedException(
+                                    channel.getId(),
+                                    "flush(" + this + ") length=" + 
bufferLength + " - stream is already closed");
+                        }
 
-                Buffer buf = buffer;
-                long total = bufferLength;
-                long available;
-                try {
-                    available = remoteWindow.waitForSpace(maxWaitTimeout);
-                    if (traceEnabled) {
-                        log.trace("flush({}) len={}, available={}", this, 
total, available);
-                    }
-                } catch (IOException e) {
-                    log.error("flush({}) failed ({}) to wait for space of 
len={}: {}",
-                            this, e.getClass().getSimpleName(), total, 
e.getMessage());
-                    if (log.isDebugEnabled()) {
-                        log.error("flush(" + this + ") wait for space len=" + 
total + " exception details", e);
-                    }
-                    throw e;
-                }
+                        boolean traceEnabled = log.isTraceEnabled();
+                        while (bufferLength > 0) {
+                            session.resetIdleTimeout();
 
-                long lenToSend = Math.min(available, total);
-                long length = Math.min(lenToSend, 
remoteWindow.getPacketSize());
-                if (length > Integer.MAX_VALUE) {
-                    throw new StreamCorruptedException(
-                            "Accumulated " + 
SshConstants.getCommandMessageName(cmd)
-                                                       + " command bytes size 
(" + length + ") exceeds int boundaries");
-                }
+                            Buffer buf = buffer;
+                            long total = bufferLength;
+                            long available;
+                            try {
+                                available = 
remoteWindow.waitForSpace(maxWaitTimeout);
+                                if (traceEnabled) {
+                                    log.trace("flush({}) len={}, 
available={}", this, total, available);
+                                }
+                            } catch (IOException e) {
+                                log.error("flush({}) failed ({}) to wait for 
space of len={}: {}",
+                                        this, e.getClass().getSimpleName(), 
total, e.getMessage());
+                                if (log.isDebugEnabled()) {
+                                    log.error("flush(" + this + ") wait for 
space len=" + total + " exception details", e);
+                                }
+                                throw e;
+                            }
 
-                int pos = buf.wpos();
-                buf.wpos((cmd == SshConstants.SSH_MSG_CHANNEL_EXTENDED_DATA) ? 
14 : 10);
-                buf.putInt(length);
-                buf.wpos(buf.wpos() + (int) length);
-                if (total == length) {
-                    newBuffer((int) length);
-                } else {
-                    long leftover = total - length;
-                    newBuffer((int) Math.max(leftover, length));
-                    buffer.putRawBytes(buf.array(), pos - (int) leftover, 
(int) leftover);
-                    bufferLength = (int) leftover;
-                }
-                lastSize = (int) length;
+                            long lenToSend = Math.min(available, total);
+                            long length = Math.min(lenToSend, 
remoteWindow.getPacketSize());
+                            if (length > Integer.MAX_VALUE) {
+                                throw new StreamCorruptedException(
+                                        "Accumulated " + 
SshConstants.getCommandMessageName(cmd)
+                                                                   + " command 
bytes size (" + length
+                                                                   + ") 
exceeds int boundaries");
+                            }
 
-                session.resetIdleTimeout();
-                remoteWindow.waitAndConsume(length, maxWaitTimeout);
-                if (traceEnabled) {
-                    log.trace("flush({}) send {} len={}",
-                            channel, SshConstants.getCommandMessageName(cmd), 
length);
-                }
-                packetWriter.writePacket(buf);
-            }
+                            int pos = buf.wpos();
+                            buf.wpos((cmd == 
SshConstants.SSH_MSG_CHANNEL_EXTENDED_DATA) ? 14 : 10);
+                            buf.putInt(length);
+                            buf.wpos(buf.wpos() + (int) length);
+                            if (total == length) {
+                                newBuffer((int) length);
+                            } else {
+                                long leftover = total - length;
+                                newBuffer((int) Math.max(leftover, length));
+                                buffer.putRawBytes(buf.array(), pos - (int) 
leftover, (int) leftover);
+                                bufferLength = (int) leftover;
+                            }
+                            lastSize = (int) length;
+
+                            session.resetIdleTimeout();
+                            remoteWindow.waitAndConsume(length, 
maxWaitTimeout);
+                            if (traceEnabled) {
+                                log.trace("flush({}) send {} len={}",
+                                        channel, 
SshConstants.getCommandMessageName(cmd), length);
+                            }
+                            packetWriter.writePacket(buf);
+                        }
+
+                        return null;
+                    });
         } catch (WindowClosedException e) {
             if (!closedState.getAndSet(true)) {
                 if (log.isDebugEnabled()) {
@@ -257,8 +306,13 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe
         }
     }
 
+    protected long getExtraPendingPacketLockWaitTime(int dataSize) {
+        // TODO see if can do anything better
+        return Math.min(dataSize, 
FactoryManager.DEFAULT_NIO2_MIN_WRITE_TIMEOUT);
+    }
+
     @Override
-    public synchronized void close() throws IOException {
+    public void close() throws IOException {
         if (!isOpen()) {
             return;
         }
@@ -267,20 +321,22 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe
             log.trace("close({}) closing", this);
         }
 
-        try {
-            flush();
-
-            if (isEofOnClose()) {
-                AbstractChannel channel = getChannel();
-                channel.sendEof();
-            }
-        } finally {
+        synchronized (closedState) {
             try {
-                if (!(packetWriter instanceof Channel)) {
-                    packetWriter.close();
+                flush();
+
+                if (isEofOnClose()) {
+                    AbstractChannel channel = getChannel();
+                    channel.sendEof();
                 }
             } finally {
-                closedState.set(true);
+                try {
+                    if (!(packetWriter instanceof Channel)) {
+                        packetWriter.close();
+                    }
+                } finally {
+                    closedState.set(true);
+                }
             }
         }
     }
diff --git a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java
index 47a56b9..849a561 100644
--- a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java
+++ b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java
@@ -35,10 +35,15 @@ import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Queue;
+import java.util.concurrent.Callable;
 import java.util.concurrent.CopyOnWriteArraySet;
 import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.concurrent.atomic.AtomicInteger;
 import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
+import java.util.concurrent.locks.Lock;
+import java.util.concurrent.locks.ReentrantLock;
 import java.util.logging.Level;
 
 import org.apache.sshd.common.Closeable;
@@ -187,6 +192,7 @@ public abstract class AbstractSession extends SessionHelper {
     protected long maxRekyPackets = FactoryManager.DEFAULT_REKEY_PACKETS_LIMIT;
     protected long maxRekeyBytes = FactoryManager.DEFAULT_REKEY_BYTES_LIMIT;
     protected long maxRekeyInterval = FactoryManager.DEFAULT_REKEY_TIME_LIMIT;
+    protected final Lock pendingPacketsLock = new ReentrantLock();
     protected final Queue<PendingWriteFuture> pendingPackets = new 
LinkedList<>();
 
     protected Service currentService;
@@ -656,6 +662,38 @@ public abstract class AbstractSession extends SessionHelper {
         doKexNegotiation();
     }
 
+    /**
+     * Attempts to acquire the pending packets lock and executes the given code while holding it. The max. wait time
+     * is derived from the current number of pending packets
+     *
+     * @param  <V>       The executed code return value
+     * @param  extraWait An extra amount of time (msec.) that the caller is willing to wait beyond the time derived from
+     *                   the number of pending packets. <B>Note:</B> a hardcoded max. value of
+     *                   {@link FactoryManager#DEFAULT_AUTH_TIMEOUT} is imposed on the total calculated time.
+     * @param  executor  The code to execute under lock
+     * @return           The executed code result
+     * @throws Exception If the lock could not be acquired in time, or if thrown by the executor code
+     * @see              <A HREF="https://issues.apache.org/jira/browse/SSHD-966">SSHD-966</A>
+     */
+    public <V> V executeUnderPendingPacketsLock(long extraWait, Callable<? 
extends V> executor) throws Exception {
+        ValidateUtils.checkTrue(extraWait >= 0L, "Invalid extra wait time: 
%d", extraWait);
+        int numPending = pendingPackets.size();
+        long maxWait = numPending * 
FactoryManager.DEFAULT_NIO2_MIN_WRITE_TIMEOUT + extraWait;
+        // in case zero
+        maxWait = Math.max(maxWait, 
FactoryManager.DEFAULT_NIO2_MIN_WRITE_TIMEOUT);
+        // in case lots of pending packets or large extra time
+        maxWait = Math.min(maxWait, FactoryManager.DEFAULT_AUTH_TIMEOUT);
+        if (!pendingPacketsLock.tryLock(maxWait, TimeUnit.MILLISECONDS)) {
+            throw new TimeoutException("Failed to acquire " + numPending + " 
pending packets lock");
+        }
+
+        try {
+            return executor.call();
+        } finally {
+            pendingPacketsLock.unlock();
+        }
+    }
+
     protected void doKexNegotiation() throws Exception {
         if (kexState.compareAndSet(KexState.DONE, KexState.RUN)) {
             sendKexInit();
@@ -669,9 +707,10 @@ public abstract class AbstractSession extends SessionHelper {
         KeyExchangeFactory kexFactory = NamedResource.findByName(
                 kexAlgorithm, String.CASE_INSENSITIVE_ORDER, kexFactories);
         ValidateUtils.checkNotNull(kexFactory, "Unknown negotiated KEX 
algorithm: %s", kexAlgorithm);
-        synchronized (pendingPackets) {
+        executeUnderPendingPacketsLock(0L, () -> {
             kex = kexFactory.createKeyExchange(this);
-        }
+            return kex;
+        });
 
         byte[] v_s = serverVersion.getBytes(StandardCharsets.UTF_8);
         byte[] v_c = clientVersion.getBytes(StandardCharsets.UTF_8);
@@ -707,12 +746,13 @@ public abstract class AbstractSession extends SessionHelper {
 
         signalSessionEvent(SessionListener.Event.KeyEstablished);
 
-        Collection<? extends Map.Entry<? extends 
SshFutureListener<IoWriteFuture>, IoWriteFuture>> pendingWrites;
-        synchronized (pendingPackets) {
-            pendingWrites = sendPendingPackets(pendingPackets);
-            kex = null; // discard and GC since KEX is completed
-            kexState.set(KexState.DONE);
-        }
+        Collection<? extends Map.Entry<? extends 
SshFutureListener<IoWriteFuture>, IoWriteFuture>> pendingWrites
+                = 
executeUnderPendingPacketsLock(FactoryManager.DEFAULT_NIO2_MIN_WRITE_TIMEOUT, 
() -> {
+                    List<Map.Entry<PendingWriteFuture, IoWriteFuture>> result 
= sendPendingPackets(pendingPackets);
+                    kex = null; // discard and GC since KEX is completed
+                    kexState.set(KexState.DONE);
+                    return result;
+                });
 
         int pendingCount = pendingWrites.size();
         if (pendingCount > 0) {
@@ -734,7 +774,7 @@ public abstract class AbstractSession extends SessionHelper {
         }
     }
 
-    protected List<SimpleImmutableEntry<PendingWriteFuture, IoWriteFuture>> 
sendPendingPackets(
+    protected List<Map.Entry<PendingWriteFuture, IoWriteFuture>> 
sendPendingPackets(
             Queue<PendingWriteFuture> packetsQueue)
             throws IOException {
         if (GenericUtils.isEmpty(packetsQueue)) {
@@ -742,7 +782,7 @@ public abstract class AbstractSession extends SessionHelper {
         }
 
         int numPending = packetsQueue.size();
-        List<SimpleImmutableEntry<PendingWriteFuture, IoWriteFuture>> 
pendingWrites = new ArrayList<>(numPending);
+        List<Map.Entry<PendingWriteFuture, IoWriteFuture>> pendingWrites = new 
ArrayList<>(numPending);
         synchronized (encodeLock) {
             for (PendingWriteFuture future = packetsQueue.poll();
                  future != null;
@@ -872,10 +912,11 @@ public abstract class AbstractSession extends SessionHelper {
      * Checks if key-exchange is done - if so, or the packet is related to the key-exchange protocol, then allows the
      * packet to go through, otherwise enqueues it to be sent when key-exchange completed
      *
-     * @param  buffer The {@link Buffer} containing the packet to be sent
-     * @return        A {@link PendingWriteFuture} if enqueued, {@code null} if packet can go through.
+     * @param  buffer      The {@link Buffer} containing the packet to be sent
+     * @return             A {@link PendingWriteFuture} if enqueued, {@code null} if packet can go through.
+     * @throws IOException If failed to enqueue the packet
      */
-    protected PendingWriteFuture enqueuePendingPacket(Buffer buffer) {
+    protected PendingWriteFuture enqueuePendingPacket(Buffer buffer) throws 
IOException {
         if (KexState.DONE.equals(kexState.get())) {
             return null;
         }
@@ -887,20 +928,33 @@ public abstract class AbstractSession extends SessionHelper {
         }
 
         String cmdName = SshConstants.getCommandMessageName(cmd);
+        AtomicInteger numPending = new AtomicInteger();
         PendingWriteFuture future;
-        int numPending;
-        synchronized (pendingPackets) {
-            if (KexState.DONE.equals(kexState.get())) {
-                return null;
-            }
+        try {
+            future = executeUnderPendingPacketsLock(0L, () -> {
+                if (KexState.DONE.equals(kexState.get())) {
+                    return null;
+                }
+
+                PendingWriteFuture pending = new PendingWriteFuture(cmdName, 
buffer);
+                pendingPackets.add(pending);
+                numPending.set(pendingPackets.size());
+                return pending;
+            });
+        } catch (Exception e) {
+            log.error("enqueuePendingPacket(" + this + ")[" + cmdName + "] 
failed to generate future", e);
 
-            future = new PendingWriteFuture(cmdName, buffer);
-            pendingPackets.add(future);
-            numPending = pendingPackets.size();
+            if (e instanceof IOException) {
+                throw (IOException) e;
+            } else if (e instanceof RuntimeException) {
+                throw (RuntimeException) e;
+            } else {
+                throw new RuntimeSshException(e);
+            }
         }
 
         if (log.isDebugEnabled()) {
-            if (numPending == 1) {
+            if (numPending.get() == 1) {
                 log.debug("enqueuePendingPacket({})[{}] Start flagging packets 
as pending until key exchange is done", this,
                         cmdName);
             } else {

Reply via email to