This is an automated email from the ASF dual-hosted git repository. lgoldstein pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/mina-sshd.git
commit 098760248c37e628723529c2c49c5a353606ff23 Author: Achim Hügen <achim.hue...@deutschepost.de> AuthorDate: Sat Mar 20 09:07:58 2021 +0200 [SSHD-1123] Add option to chunk data in ChannelAsyncOutputStream if window size is smaller than packet size --- .../common/channel/ChannelAsyncOutputStream.java | 48 ++++++-- .../apache/sshd/server/channel/ChannelSession.java | 19 ++- .../channel/ChannelAsyncOutputStreamTest.java | 131 +++++++++++++++++++++ 3 files changed, 186 insertions(+), 12 deletions(-) diff --git a/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelAsyncOutputStream.java b/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelAsyncOutputStream.java index 8d1701f..685d79e 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelAsyncOutputStream.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelAsyncOutputStream.java @@ -40,9 +40,22 @@ public class ChannelAsyncOutputStream extends AbstractCloseable implements IoOut private final byte cmd; private final AtomicReference<IoWriteFutureImpl> pendingWrite = new AtomicReference<>(); private final Object packetWriteId; + private boolean sendChunkIfRemoteWindowIsSmallerThanPacketSize; public ChannelAsyncOutputStream(Channel channel, byte cmd) { + this(channel, cmd, false); + } + + /** + * @param sendChunkIfRemoteWindowIsSmallerThanPacketSize Determines the chunking behaviour if the remote window + * size is smaller than the packet size. Can be used to + * establish compatibility with certain clients that wait + * until the window size is 0 before adjusting it (see + * SSHD-1123).
Default is false; + */ + public ChannelAsyncOutputStream(Channel channel, byte cmd, boolean sendChunkIfRemoteWindowIsSmallerThanPacketSize) { this.channelInstance = Objects.requireNonNull(channel, "No channel"); + this.sendChunkIfRemoteWindowIsSmallerThanPacketSize = sendChunkIfRemoteWindowIsSmallerThanPacketSize; this.packetWriter = channelInstance.resolveChannelStreamWriter(channel, cmd); this.cmd = cmd; this.packetWriteId = channel.toString() + "[" + SshConstants.getCommandMessageName(cmd) + "]"; @@ -113,15 +126,21 @@ public class ChannelAsyncOutputStream extends AbstractCloseable implements IoOut // send the first chunk as we have enough space in the window length = packetSize; } else { - // do not chunk when the window is smaller than the packet size - length = 0; - // do a defensive copy in case the user reuses the buffer - IoWriteFutureImpl f = new IoWriteFutureImpl(future.getId(), new ByteArrayBuffer(buffer.getCompactData())); - f.addListener(w -> future.setValue(w.getException() != null ? w.getException() : w.isWritten())); - pendingWrite.set(f); - if (log.isTraceEnabled()) { - log.trace("doWriteIfPossible({})[resume={}] waiting for window space {}", - this, resume, remoteWindowSize); + // Window size is even smaller than packet size. Determine how to handle this. + if (isSendChunkIfRemoteWindowIsSmallerThanPacketSize()) { + length = remoteWindowSize; + } else { + // do not chunk when the window is smaller than the packet size + length = 0L; + // do a defensive copy in case the user reuses the buffer + IoWriteFutureImpl f + = new IoWriteFutureImpl(future.getId(), new ByteArrayBuffer(buffer.getCompactData())); + f.addListener(w -> future.setValue(w.getException() != null ? 
w.getException() : w.isWritten())); + pendingWrite.set(f); + if (log.isTraceEnabled()) { + log.trace("doWriteIfPossible({})[resume={}] waiting for window space {}", + this, resume, remoteWindowSize); + } } } } else if (total > packetSize) { @@ -147,7 +166,7 @@ public class ChannelAsyncOutputStream extends AbstractCloseable implements IoOut } } - if (length > 0) { + if (length > 0L) { if (resume) { if (log.isDebugEnabled()) { log.debug("Resuming {} write due to more space ({}) available in the remote window", this, length); @@ -229,4 +248,13 @@ public class ChannelAsyncOutputStream extends AbstractCloseable implements IoOut public String toString() { return getClass().getSimpleName() + "[" + getChannel() + "] cmd=" + SshConstants.getCommandMessageName(cmd & 0xFF); } + + public boolean isSendChunkIfRemoteWindowIsSmallerThanPacketSize() { + return sendChunkIfRemoteWindowIsSmallerThanPacketSize; + } + + public void setSendChunkIfRemoteWindowIsSmallerThanPacketSize(boolean sendChunkIfRemoteWindowIsSmallerThanPacketSize) { + this.sendChunkIfRemoteWindowIsSmallerThanPacketSize = sendChunkIfRemoteWindowIsSmallerThanPacketSize; + } + } diff --git a/sshd-core/src/main/java/org/apache/sshd/server/channel/ChannelSession.java b/sshd-core/src/main/java/org/apache/sshd/server/channel/ChannelSession.java index adae173..93821eb 100644 --- a/sshd-core/src/main/java/org/apache/sshd/server/channel/ChannelSession.java +++ b/sshd-core/src/main/java/org/apache/sshd/server/channel/ChannelSession.java @@ -720,8 +720,12 @@ public class ChannelSession extends AbstractServerChannel { } // If the shell wants to use non-blocking io if (command instanceof AsyncCommandStreamsAware) { - asyncOut = new ChannelAsyncOutputStream(this, SshConstants.SSH_MSG_CHANNEL_DATA); - asyncErr = new ChannelAsyncOutputStream(this, SshConstants.SSH_MSG_CHANNEL_EXTENDED_DATA); + asyncOut = new ChannelAsyncOutputStream( + this, SshConstants.SSH_MSG_CHANNEL_DATA, + isSendChunkIfRemoteWindowIsSmallerThanPacketSize()); 
+ asyncErr = new ChannelAsyncOutputStream( + this, SshConstants.SSH_MSG_CHANNEL_EXTENDED_DATA, + isSendChunkIfRemoteWindowIsSmallerThanPacketSize()); ((AsyncCommandStreamsAware) command).setIoOutputStream(asyncOut); ((AsyncCommandStreamsAware) command).setIoErrorStream(asyncErr); } else { @@ -914,4 +918,15 @@ public class ChannelSession extends AbstractServerChannel { commandExitFuture.setClosed(); } } + + /** + * Hook allowing specializations to vary the chunking behaviour, e.g. depending on the client version. + * + * @return {@code true} if data sent via {@link ChannelAsyncOutputStream} should be chunked when the reported remote + * window size is less than its packet size + * @see ChannelAsyncOutputStream#ChannelAsyncOutputStream(Channel, byte, boolean) + */ + protected boolean isSendChunkIfRemoteWindowIsSmallerThanPacketSize() { + return false; + } } diff --git a/sshd-core/src/test/java/org/apache/sshd/common/channel/ChannelAsyncOutputStreamTest.java b/sshd-core/src/test/java/org/apache/sshd/common/channel/ChannelAsyncOutputStreamTest.java new file mode 100644 index 0000000..34228d8 --- /dev/null +++ b/sshd-core/src/test/java/org/apache/sshd/common/channel/ChannelAsyncOutputStreamTest.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied.
See the License for the + specific language governing permissions and limitations + under the License. + */ +package org.apache.sshd.common.channel; + +import java.io.IOException; +import java.util.Random; + +import org.apache.sshd.common.PropertyResolver; +import org.apache.sshd.common.channel.throttle.ChannelStreamWriter; +import org.apache.sshd.common.io.IoWriteFuture; +import org.apache.sshd.common.session.Session; +import org.apache.sshd.common.util.buffer.Buffer; +import org.apache.sshd.common.util.buffer.ByteArrayBuffer; +import org.apache.sshd.util.test.BaseTestSupport; +import org.apache.sshd.util.test.NoIoTestCase; +import org.junit.Before; +import org.junit.FixMethodOrder; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runners.MethodSorters; +import org.mockito.ArgumentMatchers; +import org.mockito.Mockito; + +/** + * Tests the behaviour of {@link ChannelAsyncOutputStream} regarding the chunking of the data to be sent. + */ +@FixMethodOrder(MethodSorters.NAME_ASCENDING) +@Category({ NoIoTestCase.class }) +public class ChannelAsyncOutputStreamTest extends BaseTestSupport { + + private static final String CLIENT_WITH_COMPATIBILITY_ISSUE = "specialClient"; + private Window remoteWindow; + private ChannelStreamWriter channelStreamWriter; + private AbstractChannel channel; + private Session session; + private IoWriteFuture ioWriteFuture; + + public ChannelAsyncOutputStreamTest() { + super(); + } + + @Before + public void setUp() throws Exception { + channel = Mockito.mock(AbstractChannel.class); + channelStreamWriter = Mockito.mock(ChannelStreamWriter.class); + remoteWindow = new Window(channel, null, true, true); + ioWriteFuture = Mockito.mock(IoWriteFuture.class); + session = Mockito.mock(Session.class); + + Mockito.when(channel.getRemoteWindow()).thenReturn(remoteWindow); + Mockito.when(channel.getSession()).thenReturn(session); +
Mockito.when(channel.resolveChannelStreamWriter(ArgumentMatchers.any(Channel.class), ArgumentMatchers.anyByte())) + .thenReturn(channelStreamWriter); + Mockito.when(channelStreamWriter.writeData(ArgumentMatchers.any())).thenReturn(ioWriteFuture); + + Mockito.when(session.createBuffer(ArgumentMatchers.anyByte(), ArgumentMatchers.anyInt())) + .thenReturn(new ByteArrayBuffer()); + + Mockito.when(session.getClientVersion()).thenReturn(CLIENT_WITH_COMPATIBILITY_ISSUE); + + } + + @Test + public void testCompleteDataSentIfDataFitsIntoPacketAndPacketFitsInRemoteWindow() throws IOException { + ChannelAsyncOutputStream channelAsyncOutputStream = new ChannelAsyncOutputStream(channel, (byte) 0); + checkChangeOfRemoteWindowSizeOnBufferWrite(channelAsyncOutputStream, 40000, 32000, 30000, 40000 - 30000); + } + + /* + * Only a chunk of packet size should be sent if the data is larger than the packet size and the packet + * fits into the remote window + */ + @Test + public void testChunkOfPacketSizeSentIfDataLargerThanPacketSizeAndPacketFitsInRemoteWindow() throws IOException { + ChannelAsyncOutputStream channelAsyncOutputStream = new ChannelAsyncOutputStream(channel, (byte) 0); + checkChangeOfRemoteWindowSizeOnBufferWrite(channelAsyncOutputStream, 40000, 32000, 35000, 40000 - 32000); + } + + @Test + public void testChunkOfPacketSizeSentIfDataLargerThanRemoteWindowAndPacketFitsInRemoteWindow() throws IOException { + ChannelAsyncOutputStream channelAsyncOutputStream = new ChannelAsyncOutputStream(channel, (byte) 0); + checkChangeOfRemoteWindowSizeOnBufferWrite(channelAsyncOutputStream, 40000, 32000, 50000, 40000 - 32000); + } + + @Test + public void testNoChunkingIfRemoteWindowSmallerThanPacketSize() throws IOException { + ChannelAsyncOutputStream channelAsyncOutputStream = new ChannelAsyncOutputStream(channel, (byte) 0); + checkChangeOfRemoteWindowSizeOnBufferWrite(channelAsyncOutputStream, 30000, 32000, 50000, 30000); + } + + @Test + public void
testChunkingIfRemoteWindowSmallerThanPacketSize() throws IOException { + ChannelAsyncOutputStream channelAsyncOutputStream = new ChannelAsyncOutputStream(channel, (byte) 0, true); + checkChangeOfRemoteWindowSizeOnBufferWrite(channelAsyncOutputStream, 30000, 32000, 50000, 0); + } + + private void checkChangeOfRemoteWindowSizeOnBufferWrite( + ChannelAsyncOutputStream channelAsyncOutputStream, int initialWindowSize, int packetSize, int totalDataToSent, + int expectedWindowSize) + throws IOException { + + remoteWindow.init(initialWindowSize, packetSize, PropertyResolver.EMPTY); + Buffer buffer = createBuffer(totalDataToSent); + channelAsyncOutputStream.writeBuffer(buffer); + + assertEquals(expectedWindowSize, remoteWindow.getSize()); + } + + private ByteArrayBuffer createBuffer(int size) { + byte[] randomBytes = new byte[size]; + new Random().nextBytes(randomBytes); + return new ByteArrayBuffer(randomBytes); + } +}