This is an automated email from the ASF dual-hosted git repository.

lgoldstein pushed a commit to branch SSHD-966
in repository https://gitbox.apache.org/repos/asf/mina-sshd.git
commit 25e7e13c95ada68f3def91dd153b44fd0127971b Author: Lyor Goldstein <lgoldst...@apache.org> AuthorDate: Fri May 15 10:24:50 2020 +0300 [SSHD-966] Using same lock to synchronize session pending packets and ChannelOutputStream mutual exclusion --- .../sshd/common/channel/ChannelOutputStream.java | 253 ++++++++++++++------- .../common/session/helpers/AbstractSession.java | 98 ++++++-- 2 files changed, 245 insertions(+), 106 deletions(-) diff --git a/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelOutputStream.java b/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelOutputStream.java index d0d879b..14f7e91 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelOutputStream.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelOutputStream.java @@ -26,11 +26,14 @@ import java.util.Objects; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import org.apache.sshd.common.FactoryManager; +import org.apache.sshd.common.RuntimeSshException; import org.apache.sshd.common.SshConstants; import org.apache.sshd.common.SshException; import org.apache.sshd.common.channel.exception.SshChannelClosedException; import org.apache.sshd.common.io.PacketWriter; import org.apache.sshd.common.session.Session; +import org.apache.sshd.common.session.helpers.AbstractSession; import org.apache.sshd.common.util.ValidateUtils; import org.apache.sshd.common.util.buffer.Buffer; import org.slf4j.Logger; @@ -103,24 +106,62 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe } @Override - public synchronized void write(int w) throws IOException { - b[0] = (byte) w; - write(b, 0, 1); + public void write(int w) throws IOException { + try { + Channel channel = getChannel(); + Session session = channel.getSession(); + ((AbstractSession) session).executeUnderPendingPacketsLock( + getExtraPendingPacketLockWaitTime(1), () -> { + b[0] = (byte) w; + lockedWrite(session, 
channel, b, 0, 1); + return null; + }); + } catch (Exception e) { + log.error("write(" + this + ") value=0x" + Integer.toHexString(w) + " failed to write", e); + if (e instanceof IOException) { + throw (IOException) e; + } else if (e instanceof RuntimeException) { + throw (RuntimeException) e; + } else { + throw new RuntimeSshException(e); + } + } } @Override - public synchronized void write(byte[] buf, int s, int l) throws IOException { - Channel channel = getChannel(); + public void write(byte[] buf, int startOffset, int dataLen) throws IOException { + try { + Channel channel = getChannel(); + Session session = channel.getSession(); + ((AbstractSession) session).executeUnderPendingPacketsLock( + getExtraPendingPacketLockWaitTime(dataLen), () -> { + lockedWrite(session, channel, buf, startOffset, dataLen); + return null; + }); + } catch (Exception e) { + log.error("write(" + this + ") len=" + dataLen + " failed to write", e); + if (e instanceof IOException) { + throw (IOException) e; + } else if (e instanceof RuntimeException) { + throw (RuntimeException) e; + } else { + throw new RuntimeSshException(e); + } + } + } + + protected void lockedWrite( + Session session, Channel channel, byte[] buf, int startOffset, int dataLen) + throws Exception { if (!isOpen()) { throw new SshChannelClosedException( channel.getId(), - "write(" + this + ") len=" + l + " - channel already closed"); + "lockedWrite(" + this + ") len=" + dataLen + " - channel already closed"); } - Session session = channel.getSession(); boolean debugEnabled = log.isDebugEnabled(); boolean traceEnabled = log.isTraceEnabled(); - while (l > 0) { + while (dataLen > 0) { // The maximum amount we should admit without flushing again // is enough to make up one full packet within our allowed // window size. 
We give ourselves a credit equal to the last @@ -128,31 +169,31 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe // out the next packet before we block and wait for space to // become available again. long minReqLen = Math.min(remoteWindow.getSize() + lastSize, remoteWindow.getPacketSize()); - long l2 = Math.min(l, minReqLen - bufferLength); + long l2 = Math.min(dataLen, minReqLen - bufferLength); if (l2 <= 0) { if (bufferLength > 0) { - flush(); + lockedFlush(session, channel); } else { session.resetIdleTimeout(); try { long available = remoteWindow.waitForSpace(maxWaitTimeout); if (traceEnabled) { - log.trace("write({}) len={} - available={}", this, l, available); + log.trace("lockedWrite({}) len={} - available={}", this, dataLen, available); } } catch (IOException e) { - log.error("write({}) failed ({}) to wait for space of len={}: {}", - this, e.getClass().getSimpleName(), l, e.getMessage()); + log.error("lockedWrite({}) failed ({}) to wait for space of len={}: {}", + this, e.getClass().getSimpleName(), dataLen, e.getMessage()); if ((e instanceof WindowClosedException) && (!closedState.getAndSet(true))) { if (debugEnabled) { - log.debug("write({})[len={}] closing due to window closed", this, l); + log.debug("lockedWrite({})[len={}] closing due to window closed", this, dataLen); } } throw e; } catch (InterruptedException e) { throw (IOException) new InterruptedIOException( - "Interrupted while waiting for remote space on write len=" + l + " to " + this) + "Interrupted while waiting for remote space on write len=" + dataLen + " to " + this) .initCause(e); } } @@ -162,81 +203,30 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe ValidateUtils.checkTrue(l2 <= Integer.MAX_VALUE, "Accumulated bytes length exceeds int boundary: %d", l2); - buffer.putRawBytes(buf, s, (int) l2); + buffer.putRawBytes(buf, startOffset, (int) l2); bufferLength += l2; - s += l2; - l -= l2; + startOffset += l2; + dataLen -= 
l2; } if (isNoDelay()) { - flush(); + lockedFlush(session, channel); } else { session.resetIdleTimeout(); } } @Override - public synchronized void flush() throws IOException { - AbstractChannel channel = getChannel(); - if (!isOpen()) { - throw new SshChannelClosedException( - channel.getId(), - "flush(" + this + ") length=" + bufferLength + " - stream is already closed"); - } - + public void flush() throws IOException { + Channel channel = getChannel(); + Session session = channel.getSession(); try { - Session session = channel.getSession(); - boolean traceEnabled = log.isTraceEnabled(); - while (bufferLength > 0) { - session.resetIdleTimeout(); - - Buffer buf = buffer; - long total = bufferLength; - long available; - try { - available = remoteWindow.waitForSpace(maxWaitTimeout); - if (traceEnabled) { - log.trace("flush({}) len={}, available={}", this, total, available); - } - } catch (IOException e) { - log.error("flush({}) failed ({}) to wait for space of len={}: {}", - this, e.getClass().getSimpleName(), total, e.getMessage()); - if (log.isDebugEnabled()) { - log.error("flush(" + this + ") wait for space len=" + total + " exception details", e); - } - throw e; - } - - long lenToSend = Math.min(available, total); - long length = Math.min(lenToSend, remoteWindow.getPacketSize()); - if (length > Integer.MAX_VALUE) { - throw new StreamCorruptedException( - "Accumulated " + SshConstants.getCommandMessageName(cmd) - + " command bytes size (" + length + ") exceeds int boundaries"); - } - - int pos = buf.wpos(); - buf.wpos((cmd == SshConstants.SSH_MSG_CHANNEL_EXTENDED_DATA) ? 
14 : 10); - buf.putInt(length); - buf.wpos(buf.wpos() + (int) length); - if (total == length) { - newBuffer((int) length); - } else { - long leftover = total - length; - newBuffer((int) Math.max(leftover, length)); - buffer.putRawBytes(buf.array(), pos - (int) leftover, (int) leftover); - bufferLength = (int) leftover; - } - lastSize = (int) length; - - session.resetIdleTimeout(); - remoteWindow.waitAndConsume(length, maxWaitTimeout); - if (traceEnabled) { - log.trace("flush({}) send {} len={}", - channel, SshConstants.getCommandMessageName(cmd), length); - } - packetWriter.writePacket(buf); - } + ((AbstractSession) session).executeUnderPendingPacketsLock( + getExtraPendingPacketLockWaitTime(bufferLength), + () -> { + lockedFlush(session, channel); + return null; + }); } catch (WindowClosedException e) { if (!closedState.getAndSet(true)) { if (log.isDebugEnabled()) { @@ -245,8 +235,11 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe } throw e; } catch (Exception e) { + log.error("flush(" + this + ") failed", e); if (e instanceof IOException) { throw (IOException) e; + } else if (e instanceof RuntimeException) { + throw (RuntimeException) e; } else if (e instanceof InterruptedException) { throw (IOException) new InterruptedIOException( "Interrupted while waiting for remote space flush len=" + bufferLength + " to " + this) @@ -257,21 +250,87 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe } } - @Override - public synchronized void close() throws IOException { + protected void lockedFlush(Session session, Channel channel) throws Exception { + boolean traceEnabled = log.isTraceEnabled(); if (!isOpen()) { + if (bufferLength > 0) { + throw new SshChannelClosedException( + channel.getId(), + "lockedFlush(" + this + ") length=" + bufferLength + " - stream is already closed"); + } + + if (traceEnabled) { + log.trace("lockedFlush({}) nothing to flush", this); + } return; } + while (bufferLength > 0) { + 
session.resetIdleTimeout(); + + Buffer buf = buffer; + long total = bufferLength; + long available; + try { + available = remoteWindow.waitForSpace(maxWaitTimeout); + if (traceEnabled) { + log.trace("lockedFlush({}) len={}, available={}", this, total, available); + } + } catch (IOException e) { + log.error("lockedFlush({}) failed ({}) to wait for space of len={}: {}", + this, e.getClass().getSimpleName(), total, e.getMessage()); + if (log.isDebugEnabled()) { + log.error("lockedFlush(" + this + ") wait for space len=" + total + " exception details", e); + } + throw e; + } + + long lenToSend = Math.min(available, total); + long length = Math.min(lenToSend, remoteWindow.getPacketSize()); + if (length > Integer.MAX_VALUE) { + throw new StreamCorruptedException( + "Accumulated " + SshConstants.getCommandMessageName(cmd) + + " command bytes size (" + length + + ") exceeds int boundaries"); + } + + int pos = buf.wpos(); + buf.wpos((cmd == SshConstants.SSH_MSG_CHANNEL_EXTENDED_DATA) ? 14 : 10); + buf.putInt(length); + buf.wpos(buf.wpos() + (int) length); + if (total == length) { + newBuffer((int) length); + } else { + long leftover = total - length; + newBuffer((int) Math.max(leftover, length)); + buffer.putRawBytes(buf.array(), pos - (int) leftover, (int) leftover); + bufferLength = (int) leftover; + } + lastSize = (int) length; + + session.resetIdleTimeout(); + remoteWindow.waitAndConsume(length, maxWaitTimeout); + if (traceEnabled) { + log.trace("lockedFlush({}) send len={}", this, length); + } + packetWriter.writePacket(buf); + } + } + + protected long getExtraPendingPacketLockWaitTime(int dataSize) { + // TODO see if can do anything better + return Math.min(dataSize, FactoryManager.DEFAULT_NIO2_MIN_WRITE_TIMEOUT) + maxWaitTimeout; + } + + protected void lockedClose(Session session, AbstractChannel channel) throws Exception { if (log.isTraceEnabled()) { - log.trace("close({}) closing", this); + log.trace("lockedClose({}) closing", this); } try { - flush(); + 
lockedFlush(session, channel); if (isEofOnClose()) { - AbstractChannel channel = getChannel(); channel.sendEof(); } } finally { @@ -285,6 +344,32 @@ public class ChannelOutputStream extends OutputStream implements java.nio.channe } } + @Override + public void close() throws IOException { + if (!isOpen()) { + return; + } + + AbstractChannel channel = getChannel(); + Session session = channel.getSession(); + try { + ((AbstractSession) session).executeUnderPendingPacketsLock( + getExtraPendingPacketLockWaitTime(bufferLength), + () -> { + lockedClose(session, channel); + return null; + }); + } catch (Exception e) { + if (e instanceof IOException) { + throw (IOException) e; + } else if (e instanceof RuntimeException) { + throw (RuntimeException) e; + } else { + throw new RuntimeSshException(e); + } + } + } + protected void newBuffer(int size) { Channel channel = getChannel(); Session session = channel.getSession(); diff --git a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java index 47a56b9..849a561 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java @@ -35,10 +35,15 @@ import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Queue; +import java.util.concurrent.Callable; import java.util.concurrent.CopyOnWriteArraySet; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; import java.util.logging.Level; import org.apache.sshd.common.Closeable; @@ -187,6 +192,7 @@ public abstract class AbstractSession extends SessionHelper 
{ protected long maxRekyPackets = FactoryManager.DEFAULT_REKEY_PACKETS_LIMIT; protected long maxRekeyBytes = FactoryManager.DEFAULT_REKEY_BYTES_LIMIT; protected long maxRekeyInterval = FactoryManager.DEFAULT_REKEY_TIME_LIMIT; + protected final Lock pendingPacketsLock = new ReentrantLock(); protected final Queue<PendingWriteFuture> pendingPackets = new LinkedList<>(); protected Service currentService; @@ -656,6 +662,38 @@ public abstract class AbstractSession extends SessionHelper { doKexNegotiation(); } + /** + * Attempts to lock the pending packets access and execute the relevant code. Max. wait time is derived from the + * current number of pending packets + * + * @param <V> The executed code return value + * @param extraWait An extra amount of time (msec.) that the caller is willing to wait beyond the time derived from + * the number of pending packets. <B>Note:</B> a hardcoded max. value of + * {@link FactoryManager#DEFAULT_AUTH_TIMEOUT} is imposed on the total calculated time. + * @param executor The code to execute under lock + * @return The executed code result + * @throws Exception If failed to lock or exception thrown by executor code + * @see <A HREF="https://issues.apache.org/jira/browse/SSHD-966">SSHD-966</A> + */ + public <V> V executeUnderPendingPacketsLock(long extraWait, Callable<? 
extends V> executor) throws Exception { + ValidateUtils.checkTrue(extraWait >= 0L, "Invalid extra wait time: %d", extraWait); + int numPending = pendingPackets.size(); + long maxWait = numPending * FactoryManager.DEFAULT_NIO2_MIN_WRITE_TIMEOUT + extraWait; + // in case zero + maxWait = Math.max(maxWait, FactoryManager.DEFAULT_NIO2_MIN_WRITE_TIMEOUT); + // in case lots of pending packets or large extra time + maxWait = Math.min(maxWait, FactoryManager.DEFAULT_AUTH_TIMEOUT); + if (!pendingPacketsLock.tryLock(maxWait, TimeUnit.MILLISECONDS)) { + throw new TimeoutException("Failed to acquire " + numPending + " pending packets lock"); + } + + try { + return executor.call(); + } finally { + pendingPacketsLock.unlock(); + } + } + protected void doKexNegotiation() throws Exception { if (kexState.compareAndSet(KexState.DONE, KexState.RUN)) { sendKexInit(); @@ -669,9 +707,10 @@ public abstract class AbstractSession extends SessionHelper { KeyExchangeFactory kexFactory = NamedResource.findByName( kexAlgorithm, String.CASE_INSENSITIVE_ORDER, kexFactories); ValidateUtils.checkNotNull(kexFactory, "Unknown negotiated KEX algorithm: %s", kexAlgorithm); - synchronized (pendingPackets) { + executeUnderPendingPacketsLock(0L, () -> { kex = kexFactory.createKeyExchange(this); - } + return kex; + }); byte[] v_s = serverVersion.getBytes(StandardCharsets.UTF_8); byte[] v_c = clientVersion.getBytes(StandardCharsets.UTF_8); @@ -707,12 +746,13 @@ public abstract class AbstractSession extends SessionHelper { signalSessionEvent(SessionListener.Event.KeyEstablished); - Collection<? extends Map.Entry<? extends SshFutureListener<IoWriteFuture>, IoWriteFuture>> pendingWrites; - synchronized (pendingPackets) { - pendingWrites = sendPendingPackets(pendingPackets); - kex = null; // discard and GC since KEX is completed - kexState.set(KexState.DONE); - } + Collection<? extends Map.Entry<? 
extends SshFutureListener<IoWriteFuture>, IoWriteFuture>> pendingWrites + = executeUnderPendingPacketsLock(FactoryManager.DEFAULT_NIO2_MIN_WRITE_TIMEOUT, () -> { + List<Map.Entry<PendingWriteFuture, IoWriteFuture>> result = sendPendingPackets(pendingPackets); + kex = null; // discard and GC since KEX is completed + kexState.set(KexState.DONE); + return result; + }); int pendingCount = pendingWrites.size(); if (pendingCount > 0) { @@ -734,7 +774,7 @@ public abstract class AbstractSession extends SessionHelper { } } - protected List<SimpleImmutableEntry<PendingWriteFuture, IoWriteFuture>> sendPendingPackets( + protected List<Map.Entry<PendingWriteFuture, IoWriteFuture>> sendPendingPackets( Queue<PendingWriteFuture> packetsQueue) throws IOException { if (GenericUtils.isEmpty(packetsQueue)) { @@ -742,7 +782,7 @@ public abstract class AbstractSession extends SessionHelper { } int numPending = packetsQueue.size(); - List<SimpleImmutableEntry<PendingWriteFuture, IoWriteFuture>> pendingWrites = new ArrayList<>(numPending); + List<Map.Entry<PendingWriteFuture, IoWriteFuture>> pendingWrites = new ArrayList<>(numPending); synchronized (encodeLock) { for (PendingWriteFuture future = packetsQueue.poll(); future != null; @@ -872,10 +912,11 @@ public abstract class AbstractSession extends SessionHelper { * Checks if key-exchange is done - if so, or the packet is related to the key-exchange protocol, then allows the * packet to go through, otherwise enqueues it to be sent when key-exchange completed * - * @param buffer The {@link Buffer} containing the packet to be sent - * @return A {@link PendingWriteFuture} if enqueued, {@code null} if packet can go through. + * @param buffer The {@link Buffer} containing the packet to be sent + * @return A {@link PendingWriteFuture} if enqueued, {@code null} if packet can go through. 
+ * @throws IOException If failed to enqueue */ - protected PendingWriteFuture enqueuePendingPacket(Buffer buffer) { + protected PendingWriteFuture enqueuePendingPacket(Buffer buffer) throws IOException { if (KexState.DONE.equals(kexState.get())) { return null; } @@ -887,20 +928,33 @@ public abstract class AbstractSession extends SessionHelper { } String cmdName = SshConstants.getCommandMessageName(cmd); + AtomicInteger numPending = new AtomicInteger(); PendingWriteFuture future; - int numPending; - synchronized (pendingPackets) { - if (KexState.DONE.equals(kexState.get())) { - return null; - } + try { + future = executeUnderPendingPacketsLock(0L, () -> { + if (KexState.DONE.equals(kexState.get())) { + return null; + } + + PendingWriteFuture pending = new PendingWriteFuture(cmdName, buffer); + pendingPackets.add(pending); + numPending.set(pendingPackets.size()); + return pending; + }); + } catch (Exception e) { + log.error("enqueuePendingPacket(" + this + ")[" + cmdName + "] failed to generate future", e); - future = new PendingWriteFuture(cmdName, buffer); - pendingPackets.add(future); - numPending = pendingPackets.size(); + if (e instanceof IOException) { + throw (IOException) e; + } else if (e instanceof RuntimeException) { + throw (RuntimeException) e; + } else { + throw new RuntimeSshException(e); + } } if (log.isDebugEnabled()) { - if (numPending == 1) { + if (numPending.get() == 1) { log.debug("enqueuePendingPacket({})[{}] Start flagging packets as pending until key exchange is done", this, cmdName); } else {