This is an automated email from the ASF dual-hosted git repository. gnodet pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/mina-sshd.git
commit e85b67e0dc6c5f10a7ad77b365d33905c095bff9 Author: Guillaume Nodet <gno...@gmail.com> AuthorDate: Mon Apr 20 11:41:45 2020 +0200 [SSHD-979] Improve SFTP streaming Work sponsored by Buddy https://buddy.works/ --- .../common/channel/ChannelAsyncOutputStream.java | 143 ++++++---- .../common/session/helpers/AbstractSession.java | 4 +- sshd-sftp/pom.xml | 22 ++ .../sshd/client/subsystem/sftp/RawSftpClient.java | 9 + .../sshd/client/subsystem/sftp/SftpClient.java | 56 +--- .../subsystem/sftp/SftpRemotePathChannel.java | 49 ++-- .../helpers/AbstractSftpClientExtension.java | 5 + .../client/subsystem/sftp/fs/SftpFileSystem.java | 15 + .../subsystem/sftp/fs/SftpFileSystemProvider.java | 42 ++- .../subsystem/sftp/impl/AbstractSftpClient.java | 45 ++- .../subsystem/sftp/impl/DefaultSftpClient.java | 187 +++++++----- .../subsystem/sftp/impl/SftpInputStreamAsync.java | 312 +++++++++++++++++++++ .../subsystem/sftp/impl/SftpOutputStreamAsync.java | 201 +++++++++++++ .../subsystem/sftp/SftpInputStreamWithChannel.java | 0 .../sftp/SftpOutputStreamWithChannel.java | 0 .../client/subsystem/sftp/SftpPerformanceTest.java | 243 ++++++++++++++++ .../sshd/client/subsystem/sftp/SftpTest.java | 21 +- .../client/subsystem/sftp/SftpTransferTest.java | 134 +++++++++ 18 files changed, 1276 insertions(+), 212 deletions(-) diff --git a/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelAsyncOutputStream.java b/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelAsyncOutputStream.java index 55af809..af80185 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelAsyncOutputStream.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/channel/ChannelAsyncOutputStream.java @@ -25,13 +25,13 @@ import java.util.concurrent.atomic.AtomicReference; import org.apache.sshd.common.SshConstants; import org.apache.sshd.common.future.CloseFuture; -import org.apache.sshd.common.future.SshFutureListener; import org.apache.sshd.common.io.IoOutputStream; import 
org.apache.sshd.common.io.IoWriteFuture; import org.apache.sshd.common.io.PacketWriter; import org.apache.sshd.common.io.WritePendingException; import org.apache.sshd.common.session.Session; import org.apache.sshd.common.util.buffer.Buffer; +import org.apache.sshd.common.util.buffer.ByteArrayBuffer; import org.apache.sshd.common.util.closeable.AbstractCloseable; public class ChannelAsyncOutputStream extends AbstractCloseable implements IoOutputStream, ChannelHolder { @@ -107,9 +107,39 @@ public class ChannelAsyncOutputStream extends AbstractCloseable implements IoOut if (total > 0) { Channel channel = getChannel(); Window remoteWindow = channel.getRemoteWindow(); - long length = Math.min(Math.min(remoteWindow.getSize(), total), remoteWindow.getPacketSize()); - if (log.isTraceEnabled()) { - log.trace("doWriteIfPossible({})[resume={}] attempting to write {} out of {}", this, resume, length, total); + long length; + if (remoteWindow.getSize() < total && total <= remoteWindow.getPacketSize()) { + // do not chunk when the window is smaller than the packet size + length = 0; + // do a defensive copy in case the user reuses the buffer + IoWriteFutureImpl f = new IoWriteFutureImpl(future.getId(), new ByteArrayBuffer(buffer.getCompactData())); + f.addListener(w -> future.setValue(w.getException() != null ? w.getException() : w.isWritten())); + pendingWrite.set(f); + if (log.isTraceEnabled()) { + log.trace("doWriteIfPossible({})[resume={}] waiting for window space {}", + this, resume, remoteWindow.getSize()); + } + } else if (total > remoteWindow.getPacketSize()) { + if (buffer.rpos() > 0) { + // do a defensive copy in case the user reuses the buffer + IoWriteFutureImpl f = new IoWriteFutureImpl(future.getId(), new ByteArrayBuffer(buffer.getCompactData())); + f.addListener(w -> future.setValue(w.getException() != null ? 
w.getException() : w.isWritten())); + pendingWrite.set(f); + length = remoteWindow.getPacketSize(); + if (log.isTraceEnabled()) { + log.trace("doWriteIfPossible({})[resume={}] attempting to write {} out of {}", + this, resume, length, total); + } + doWriteIfPossible(resume); + return; + } else { + length = remoteWindow.getPacketSize(); + } + } else { + length = total; + if (log.isTraceEnabled()) { + log.trace("doWriteIfPossible({})[resume={}] attempting to write {} bytes", this, resume, length); + } } if (length > 0) { @@ -125,66 +155,12 @@ public class ChannelAsyncOutputStream extends AbstractCloseable implements IoOut + ") exceeds int boundaries"); } - Session s = channel.getSession(); - Buffer buf = s.createBuffer(cmd, (int) length + 12); - buf.putInt(channel.getRecipient()); - if (cmd == SshConstants.SSH_MSG_CHANNEL_EXTENDED_DATA) { - buf.putInt(SshConstants.SSH_EXTENDED_DATA_STDERR); - } - buf.putInt(length); - buf.putRawBytes(buffer.array(), buffer.rpos(), (int) length); - buffer.rpos(buffer.rpos() + (int) length); + Buffer buf = createSendBuffer(buffer, channel, length); remoteWindow.consume(length); try { - ChannelAsyncOutputStream stream = this; IoWriteFuture writeFuture = packetWriter.writePacket(buf); - writeFuture.addListener(new SshFutureListener<IoWriteFuture>() { - @Override - public void operationComplete(IoWriteFuture f) { - if (f.isWritten()) { - handleOperationCompleted(); - } else { - handleOperationFailed(f.getException()); - } - } - - @SuppressWarnings("synthetic-access") - private void handleOperationCompleted() { - if (total > length) { - if (log.isTraceEnabled()) { - log.trace("doWriteIfPossible({}) completed write of {} out of {}", stream, length, total); - } - doWriteIfPossible(false); - } else { - boolean nullified = pendingWrite.compareAndSet(future, null); - if (log.isTraceEnabled()) { - log.trace("doWriteIfPossible({}) completed write len={}, more={}", - stream, total, !nullified); - } - future.setValue(Boolean.TRUE); - } - } - - 
@SuppressWarnings("synthetic-access") - private void handleOperationFailed(Throwable reason) { - if (log.isDebugEnabled()) { - log.debug("doWriteIfPossible({}) failed ({}) to complete write of {} out of {}: {}", - stream, reason.getClass().getSimpleName(), length, total, reason.getMessage()); - } - - if (log.isTraceEnabled()) { - log.trace("doWriteIfPossible(" + this + ") write failure details", reason); - } - - boolean nullified = pendingWrite.compareAndSet(future, null); - if (log.isTraceEnabled()) { - log.trace("doWriteIfPossible({}) failed write len={}, more={}", - stream, total, !nullified); - } - future.setValue(reason); - } - }); + writeFuture.addListener(f -> onWritten(future, total, length, f)); } catch (IOException e) { future.setValue(e); } @@ -202,6 +178,53 @@ public class ChannelAsyncOutputStream extends AbstractCloseable implements IoOut } } + protected void onWritten(IoWriteFutureImpl future, int total, long length, IoWriteFuture f) { + if (f.isWritten()) { + if (total > length) { + if (log.isTraceEnabled()) { + log.trace("onWritten({}) completed write of {} out of {}", + this, length, total); + } + doWriteIfPossible(false); + } else { + boolean nullified = pendingWrite.compareAndSet(future, null); + if (log.isTraceEnabled()) { + log.trace("onWritten({}) completed write len={}, more={}", + this, total, !nullified); + } + future.setValue(Boolean.TRUE); + } + } else { + Throwable reason = f.getException(); + if (log.isDebugEnabled()) { + log.debug("onWritten({}) failed ({}) to complete write of {} out of {}: {}", + this, reason.getClass().getSimpleName(), length, total, reason.getMessage()); + } + if (log.isTraceEnabled()) { + log.trace("onWritten(" + this + ") write failure details", reason); + } + boolean nullified = pendingWrite.compareAndSet(future, null); + if (log.isTraceEnabled()) { + log.trace("onWritten({}) failed write len={}, more={}", + this, total, !nullified); + } + future.setValue(reason); + } + } + + protected Buffer 
createSendBuffer(Buffer buffer, Channel channel, long length) { + Session s = channel.getSession(); + Buffer buf = s.createBuffer(cmd, (int) length + 12); + buf.putInt(channel.getRecipient()); + if (cmd == SshConstants.SSH_MSG_CHANNEL_EXTENDED_DATA) { + buf.putInt(SshConstants.SSH_EXTENDED_DATA_STDERR); + } + buf.putInt(length); + buf.putRawBytes(buffer.array(), buffer.rpos(), (int) length); + buffer.rpos(buffer.rpos() + (int) length); + return buf; + } + @Override public String toString() { return getClass().getSimpleName() + "[" + getChannel() + "] cmd=" + SshConstants.getCommandMessageName(cmd & 0xFF); diff --git a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java index e6d7c33..e8eda37 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/AbstractSession.java @@ -954,7 +954,9 @@ public abstract class AbstractSession extends SessionHelper { synchronized (encodeLock) { Buffer packet = resolveOutputPacket(buffer); IoSession networkSession = getIoSession(); - return networkSession.writePacket(packet); + IoWriteFuture future = networkSession.writePacket(packet); + buffer.rpos(buffer.wpos()); + return future; } } diff --git a/sshd-sftp/pom.xml b/sshd-sftp/pom.xml index 9116746..4cf78e3 100644 --- a/sshd-sftp/pom.xml +++ b/sshd-sftp/pom.xml @@ -85,8 +85,30 @@ <artifactId>jzlib</artifactId> <scope>test</scope> </dependency> + <dependency> + <groupId>org.testcontainers</groupId> + <artifactId>testcontainers</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.testcontainers</groupId> + <artifactId>toxiproxy</artifactId> + <scope>test</scope> + </dependency> </dependencies> + <dependencyManagement> + <dependencies> + <dependency> + <groupId>org.testcontainers</groupId> + <artifactId>testcontainers-bom</artifactId> + 
<type>pom</type> + <version>1.14.0</version> + <scope>import</scope> + </dependency> + </dependencies> + </dependencyManagement> + <build> <resources> <resource> diff --git a/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/RawSftpClient.java b/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/RawSftpClient.java index 0cd90af..560ce55 100644 --- a/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/RawSftpClient.java +++ b/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/RawSftpClient.java @@ -41,4 +41,13 @@ public interface RawSftpClient { * @throws IOException If connection closed or interrupted */ Buffer receive(int id) throws IOException; + + /** + * @param id The expected request id + * @param timeout The amount of time to wait for the response + * @return The received response {@link Buffer} containing the request id + * @throws IOException If connection closed or interrupted + */ + Buffer receive(int id, long timeout) throws IOException; + } diff --git a/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/SftpClient.java b/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/SftpClient.java index 593e996..e78fa00 100644 --- a/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/SftpClient.java +++ b/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/SftpClient.java @@ -80,12 +80,14 @@ public interface SftpClient extends SubsystemClient { /** * The {@link Set} of {@link OpenOption}-s supported by {@link #fromOpenOptions(Collection)} */ - public static final Set<OpenOption> SUPPORTED_OPTIONS = Collections.unmodifiableSet( - EnumSet.of( - StandardOpenOption.READ, StandardOpenOption.APPEND, - StandardOpenOption.CREATE, StandardOpenOption.TRUNCATE_EXISTING, - StandardOpenOption.WRITE, StandardOpenOption.CREATE_NEW, - StandardOpenOption.SPARSE)); + public static final Set<OpenOption> SUPPORTED_OPTIONS = Collections.unmodifiableSet(EnumSet.of( + StandardOpenOption.READ, + 
StandardOpenOption.APPEND, + StandardOpenOption.CREATE, + StandardOpenOption.TRUNCATE_EXISTING, + StandardOpenOption.WRITE, + StandardOpenOption.CREATE_NEW, + StandardOpenOption.SPARSE)); /** * Converts {@link StandardOpenOption}-s into {@link OpenMode}-s @@ -464,17 +466,9 @@ public interface SftpClient extends SubsystemClient { @Override public String toString() { - return "type=" + getType() - + ";size=" + getSize() - + ";uid=" + getUserId() - + ";gid=" + getGroupId() - + ";perms=0x" + Integer.toHexString(getPermissions()) - + ";flags=" + getFlags() - + ";owner=" + getOwner() - + ";group=" + getGroup() - + ";aTime=" + getAccessTime() - + ";cTime=" + getCreateTime() - + ";mTime=" + getModifyTime() + return "type=" + getType() + ";size=" + getSize() + ";uid=" + getUserId() + ";gid=" + getGroupId() + ";perms=0x" + + Integer.toHexString(getPermissions()) + ";flags=" + getFlags() + ";owner=" + getOwner() + ";group=" + + getGroup() + ";aTime=" + getAccessTime() + ";cTime=" + getCreateTime() + ";mTime=" + getModifyTime() + ";extensions=" + getExtensions().keySet(); } } @@ -541,7 +535,7 @@ public interface SftpClient extends SubsystemClient { DirEntry[] EMPTY_DIR_ENTRIES = new DirEntry[0]; // default values used if none specified - int MIN_BUFFER_SIZE = Byte.MAX_VALUE; + int MIN_BUFFER_SIZE = 256; int MIN_READ_BUFFER_SIZE = MIN_BUFFER_SIZE; int MIN_WRITE_BUFFER_SIZE = MIN_BUFFER_SIZE; int IO_BUFFER_SIZE = 32 * 1024; @@ -954,18 +948,7 @@ public interface SftpClient extends SubsystemClient { * @return An {@link InputStream} for reading the remote file data * @throws IOException If failed to execute */ - default InputStream read(String path, int bufferSize, Collection<OpenMode> mode) throws IOException { - if (bufferSize < MIN_READ_BUFFER_SIZE) { - throw new IllegalArgumentException( - "Insufficient read buffer size: " + bufferSize + ", min.=" + MIN_READ_BUFFER_SIZE); - } - - if (!isOpen()) { - throw new IOException("read(" + path + ")[" + mode + "] size=" + bufferSize + ": 
client is closed"); - } - - return new SftpInputStreamWithChannel(this, bufferSize, path, mode); - } + InputStream read(String path, int bufferSize, Collection<OpenMode> mode) throws IOException; default OutputStream write(String path) throws IOException { return write(path, DEFAULT_WRITE_BUFFER_SIZE); @@ -996,18 +979,7 @@ public interface SftpClient extends SubsystemClient { * @return An {@link OutputStream} for writing the data * @throws IOException If failed to execute */ - default OutputStream write(String path, int bufferSize, Collection<OpenMode> mode) throws IOException { - if (bufferSize < MIN_WRITE_BUFFER_SIZE) { - throw new IllegalArgumentException( - "Insufficient write buffer size: " + bufferSize + ", min.=" + MIN_WRITE_BUFFER_SIZE); - } - - if (!isOpen()) { - throw new IOException("write(" + path + ")[" + mode + "] size=" + bufferSize + ": client is closed"); - } - - return new SftpOutputStreamWithChannel(this, bufferSize, path, mode); - } + OutputStream write(String path, int bufferSize, Collection<OpenMode> mode) throws IOException; /** * @param <E> The generic extension type diff --git a/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/SftpRemotePathChannel.java b/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/SftpRemotePathChannel.java index 2789cd2..e9d5f5a 100644 --- a/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/SftpRemotePathChannel.java +++ b/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/SftpRemotePathChannel.java @@ -42,6 +42,9 @@ import java.util.concurrent.atomic.AtomicReference; import org.apache.sshd.client.session.ClientSession; import org.apache.sshd.client.subsystem.sftp.SftpClient.Attributes; +import org.apache.sshd.client.subsystem.sftp.impl.AbstractSftpClient; +import org.apache.sshd.client.subsystem.sftp.impl.SftpInputStreamAsync; +import org.apache.sshd.client.subsystem.sftp.impl.SftpOutputStreamAsync; import org.apache.sshd.common.subsystem.sftp.SftpConstants; import 
org.apache.sshd.common.subsystem.sftp.SftpException; import org.apache.sshd.common.util.GenericUtils; @@ -230,6 +233,18 @@ public class SftpRemotePathChannel extends FileChannel { return doWrite(buffers, -1L); } + static class Ack { + int id; + long offset; + int length; + + Ack(int id, long offset, int length) { + this.id = id; + this.offset = offset; + this.length = length; + } + } + protected long doWrite(Collection<? extends ByteBuffer> buffers, long position) throws IOException { ensureOpen(WRITE_MODES); @@ -346,30 +361,19 @@ public class SftpRemotePathChannel extends FileChannel { } boolean completed = false; - boolean eof = false; - long curPos = position; - int bufSize = (int) Math.min(count, copySize); - byte[] buffer = new byte[bufSize]; - long totalRead = 0L; + boolean eof; + long totalRead; synchronized (lock) { try { beginBlocking("transferTo"); - while (totalRead < count && !eof) { - int read = sftp.read(handle, curPos, buffer, 0, - (int) Math.min(count - totalRead, buffer.length)); - if (read > 0) { - ByteBuffer wrap = ByteBuffer.wrap(buffer, 0, read); - while (wrap.remaining() > 0) { - target.write(wrap); - } - curPos += read; - totalRead += read; - } else { - eof = read == -1; - } - } + SftpInputStreamAsync input = new SftpInputStreamAsync( + (AbstractSftpClient) sftp, + copySize, position, count, getRemotePath(), handle); + totalRead = input.transferTo(count, target); + // DO NOT CLOSE THE STREAM AS IT WOULD CLOSE THE HANDLE + eof = input.isEof(); completed = true; } finally { endBlocking("transferTo", completed); @@ -410,18 +414,23 @@ public class SftpRemotePathChannel extends FileChannel { try { beginBlocking("transferFrom"); + SftpOutputStreamAsync output = new SftpOutputStreamAsync( + (AbstractSftpClient) sftp, + copySize, getRemotePath(), handle); while (totalRead < count) { ByteBuffer wrap = ByteBuffer.wrap( buffer, 0, (int) Math.min(buffer.length, count - totalRead)); int read = src.read(wrap); if (read > 0) { - sftp.write(handle, curPos, 
buffer, 0, read); + output.write(buffer, 0, read); curPos += read; totalRead += read; } else { break; } } + output.flush(); + // DO NOT CLOSE THE OUTPUT STREAM AS IT WOULD CLOSE THE HANDLE completed = true; } finally { endBlocking("transferFrom", completed); diff --git a/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/extensions/helpers/AbstractSftpClientExtension.java b/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/extensions/helpers/AbstractSftpClientExtension.java index 4ad73c3..8ee2293 100644 --- a/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/extensions/helpers/AbstractSftpClientExtension.java +++ b/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/extensions/helpers/AbstractSftpClientExtension.java @@ -95,6 +95,11 @@ public abstract class AbstractSftpClientExtension extends AbstractLoggingBean im } @Override + public Buffer receive(int id, long timeout) throws IOException { + return raw.receive(id, timeout); + } + + @Override public final boolean isSupported() { return supported; } diff --git a/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/fs/SftpFileSystem.java b/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/fs/SftpFileSystem.java index c60ae8f..ad8d6dd 100644 --- a/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/fs/SftpFileSystem.java +++ b/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/fs/SftpFileSystem.java @@ -557,6 +557,21 @@ public class SftpFileSystem "receive(id=" + id + ") delegate is not a " + RawSftpClient.class.getSimpleName()); } } + + @Override + public Buffer receive(int id, long timeout) throws IOException { + if (!isOpen()) { + throw new IOException("receive(id=" + id + ", timeout=" + timeout + ") client is closed"); + } + + if (delegate instanceof RawSftpClient) { + return ((RawSftpClient) delegate).receive(id, timeout); + } else { + throw new StreamCorruptedException( + "receive(id=" + id + ", timeout=" + timeout + ") 
delegate is not a " + + RawSftpClient.class.getSimpleName()); + } + } } public static class DefaultUserPrincipalLookupService extends UserPrincipalLookupService { diff --git a/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/fs/SftpFileSystemProvider.java b/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/fs/SftpFileSystemProvider.java index 774e17d..666658e 100644 --- a/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/fs/SftpFileSystemProvider.java +++ b/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/fs/SftpFileSystemProvider.java @@ -74,8 +74,11 @@ import org.apache.sshd.client.SshClient; import org.apache.sshd.client.session.ClientSession; import org.apache.sshd.client.subsystem.sftp.SftpClient; import org.apache.sshd.client.subsystem.sftp.SftpClient.Attributes; +import org.apache.sshd.client.subsystem.sftp.SftpClient.OpenMode; import org.apache.sshd.client.subsystem.sftp.SftpClientFactory; +import org.apache.sshd.client.subsystem.sftp.SftpRemotePathChannel; import org.apache.sshd.client.subsystem.sftp.SftpVersionSelector; +import org.apache.sshd.client.subsystem.sftp.extensions.CopyFileExtension; import org.apache.sshd.common.PropertyResolver; import org.apache.sshd.common.PropertyResolverUtils; import org.apache.sshd.common.SshConstants; @@ -482,7 +485,33 @@ public class SftpFileSystemProvider extends FileSystemProvider { modes = EnumSet.of(SftpClient.OpenMode.Read, SftpClient.OpenMode.Write); } // TODO: process file attributes - return new SftpFileSystemChannel(toSftpPath(path), modes); + SftpPath p = toSftpPath(path); + return new SftpRemotePathChannel(p.toString(), p.getFileSystem().getClient(), true, modes); + } + + @Override + public InputStream newInputStream(Path path, OpenOption... 
options) throws IOException { + Collection<SftpClient.OpenMode> modes = SftpClient.OpenMode.fromOpenOptions(Arrays.asList(options)); + if (modes.isEmpty()) { + modes = EnumSet.of(SftpClient.OpenMode.Read); + } + SftpPath p = toSftpPath(path); + return p.getFileSystem().getClient().read(p.toString(), modes); + } + + @Override + public OutputStream newOutputStream(Path path, OpenOption... options) throws IOException { + Set<SftpClient.OpenMode> modes = SftpClient.OpenMode.fromOpenOptions(Arrays.asList(options)); + if (modes.contains(OpenMode.Read)) { + throw new IllegalArgumentException("READ not allowed"); + } + if (modes.isEmpty()) { + modes = EnumSet.of(OpenMode.Create, OpenMode.Truncate, OpenMode.Write); + } else { + modes.add(OpenMode.Write); + } + SftpPath p = toSftpPath(path); + return p.getFileSystem().getClient().write(p.toString(), modes); } @Override @@ -591,9 +620,14 @@ public class SftpFileSystemProvider extends FileSystemProvider { if (attrs.isDirectory()) { createDirectory(target); } else { - try (InputStream in = newInputStream(source); - OutputStream os = newOutputStream(target)) { - IoUtils.copy(in, os); + CopyFileExtension copyFile = src.getFileSystem().getClient().getExtension(CopyFileExtension.class); + if (copyFile.isSupported()) { + copyFile.copyFile(source.toString(), target.toString(), false); + } else { + try (InputStream in = newInputStream(source); + OutputStream os = newOutputStream(target)) { + IoUtils.copy(in, os); + } } } diff --git a/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/impl/AbstractSftpClient.java b/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/impl/AbstractSftpClient.java index 577ecdb..b96a24a 100644 --- a/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/impl/AbstractSftpClient.java +++ b/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/impl/AbstractSftpClient.java @@ -19,6 +19,8 @@ package org.apache.sshd.client.subsystem.sftp.impl; import java.io.IOException; 
+import java.io.InputStream; +import java.io.OutputStream; import java.nio.charset.Charset; import java.nio.file.attribute.FileTime; import java.util.ArrayList; @@ -786,7 +788,6 @@ public abstract class AbstractSftpClient extends AbstractSubsystemClient impleme if (eofSignalled != null) { eofSignalled.set(null); } - if (!isOpen()) { throw new IOException("read(" + handle + "/" + fileOffset + ")[" + dstOffset + "/" + len + "] client is closed"); } @@ -1278,4 +1279,46 @@ public abstract class AbstractSftpClient extends AbstractSubsystemClient impleme buffer.putLong(length); checkCommandStatus(SftpConstants.SSH_FXP_UNBLOCK, buffer); } + + @Override + public InputStream read(String path, int bufferSize, Collection<OpenMode> mode) throws IOException { + if (bufferSize < MIN_WRITE_BUFFER_SIZE) { + throw new IllegalArgumentException( + "Insufficient read buffer size: " + bufferSize + ", min.=" + + MIN_READ_BUFFER_SIZE); + } + + if (!isOpen()) { + throw new IOException("write(" + path + ")[" + mode + "] size=" + bufferSize + ": client is closed"); + } + + return new SftpInputStreamAsync(this, bufferSize, path, mode); + } + + @Override + public InputStream read(String path, Collection<OpenMode> mode) throws IOException { + int packetSize = (int) getChannel().getRemoteWindow().getPacketSize(); + return read(path, packetSize, mode); + } + + @Override + public OutputStream write(String path, int bufferSize, Collection<OpenMode> mode) throws IOException { + if (bufferSize < MIN_WRITE_BUFFER_SIZE) { + throw new IllegalArgumentException( + "Insufficient write buffer size: " + bufferSize + ", min.=" + + MIN_WRITE_BUFFER_SIZE); + } + + if (!isOpen()) { + throw new IOException("write(" + path + ")[" + mode + "] size=" + bufferSize + ": client is closed"); + } + + return new SftpOutputStreamAsync(this, bufferSize, path, mode); + } + + @Override + public OutputStream write(String path, Collection<OpenMode> mode) throws IOException { + int packetSize = (int) 
getChannel().getRemoteWindow().getPacketSize(); + return write(path, packetSize, mode); + } } diff --git a/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/impl/DefaultSftpClient.java b/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/impl/DefaultSftpClient.java index 28e6b4f..5ef0e18 100644 --- a/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/impl/DefaultSftpClient.java +++ b/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/impl/DefaultSftpClient.java @@ -21,7 +21,6 @@ package org.apache.sshd.client.subsystem.sftp.impl; import java.io.ByteArrayOutputStream; import java.io.EOFException; import java.io.IOException; -import java.io.InputStream; import java.io.InterruptedIOException; import java.io.OutputStream; import java.io.StreamCorruptedException; @@ -48,6 +47,10 @@ import org.apache.sshd.common.FactoryManager; import org.apache.sshd.common.PropertyResolverUtils; import org.apache.sshd.common.SshConstants; import org.apache.sshd.common.SshException; +import org.apache.sshd.common.channel.Channel; +import org.apache.sshd.common.channel.ChannelAsyncOutputStream; +import org.apache.sshd.common.future.CloseFuture; +import org.apache.sshd.common.session.ConnectionService; import org.apache.sshd.common.session.Session; import org.apache.sshd.common.subsystem.sftp.SftpConstants; import org.apache.sshd.common.subsystem.sftp.extensions.ParserUtils; @@ -55,7 +58,6 @@ import org.apache.sshd.common.subsystem.sftp.extensions.VersionsParser.Versions; import org.apache.sshd.common.util.GenericUtils; import org.apache.sshd.common.util.ValidateUtils; import org.apache.sshd.common.util.buffer.Buffer; -import org.apache.sshd.common.util.buffer.BufferUtils; import org.apache.sshd.common.util.buffer.ByteArrayBuffer; /** @@ -67,7 +69,6 @@ public class DefaultSftpClient extends AbstractSftpClient { private final Map<Integer, Buffer> messages = new HashMap<>(); private final AtomicInteger cmdId = new AtomicInteger(100); private final 
Buffer receiveBuffer = new ByteArrayBuffer(); - private final byte[] workBuf = new byte[Integer.BYTES]; private final AtomicInteger versionHolder = new AtomicInteger(0); private final AtomicBoolean closing = new AtomicBoolean(false); private final NavigableMap<String, byte[]> extensions = new TreeMap<>(String.CASE_INSENSITIVE_ORDER); @@ -78,24 +79,8 @@ public class DefaultSftpClient extends AbstractSftpClient { this.nameDecodingCharset = PropertyResolverUtils.getCharset( clientSession, NAME_DECODING_CHARSET, DEFAULT_NAME_DECODING_CHARSET); this.clientSession = Objects.requireNonNull(clientSession, "No client session"); - this.channel = clientSession.createSubsystemChannel(SftpConstants.SFTP_SUBSYSTEM_NAME); - this.channel.setOut(new OutputStream() { - private final byte[] singleByte = new byte[1]; - - @Override - public void write(int b) throws IOException { - synchronized (singleByte) { - singleByte[0] = (byte) b; - write(singleByte); - } - } - - @Override - public void write(byte[] b, int off, int len) throws IOException { - data(b, off, len); - } - }); - this.channel.setErr(new ByteArrayOutputStream(Byte.MAX_VALUE)); + this.channel = new SftpChannelSubsystem(); + clientSession.getService(ConnectionService.class).registerChannel(channel); long initializationTimeout = clientSession.getLongProperty( SFTP_CHANNEL_OPEN_TIMEOUT, DEFAULT_CHANNEL_OPEN_TIMEOUT); @@ -274,12 +259,26 @@ public class DefaultSftpClient extends AbstractSftpClient { getClientChannel(), SftpConstants.getCommandMessageName(cmd), len, id); } - OutputStream dos = channel.getInvertedIn(); - BufferUtils.writeInt(dos, 1 /* cmd */ + Integer.BYTES /* id */ + len, workBuf); - dos.write(cmd & 0xFF); - BufferUtils.writeInt(dos, id, workBuf); - dos.write(buffer.array(), buffer.rpos(), len); - dos.flush(); + Buffer buf; + int hdr = Integer.BYTES /* length */ + 1 /* cmd */ + Integer.BYTES /* id */; + if (buffer.rpos() >= hdr) { + int wpos = buffer.wpos(); + int s = buffer.rpos() - hdr; + buffer.rpos(s); + 
buffer.wpos(s);
+            buffer.putInt(1 /* cmd */ + Integer.BYTES /* id */ + len); // length
+            buffer.putByte((byte) (cmd & 0xFF)); // cmd
+            buffer.putInt(id); // id
+            buffer.wpos(wpos);
+            buf = buffer;
+        } else {
+            buf = new ByteArrayBuffer(hdr + len);
+            buf.putInt(1 /* cmd */ + Integer.BYTES /* id */ + len);
+            buf.putByte((byte) (cmd & 0xFF));
+            buf.putInt(id);
+            buf.putBuffer(buffer);
+        }
+        channel.getAsyncIn().writePacket(buf).verify();
         return id;
     }
@@ -292,66 +291,50 @@ public class DefaultSftpClient extends AbstractSftpClient {
             idleTimeout = FactoryManager.DEFAULT_IDLE_TIMEOUT;
         }
-        Integer reqId = id;
         boolean traceEnabled = log.isTraceEnabled();
         for (int count = 1;; count++) {
             if (isClosing() || (!isOpen())) {
                 throw new SshException("Channel is being closed");
             }
-            synchronized (messages) {
-                Buffer buffer = messages.remove(reqId);
-                if (buffer != null) {
-                    return buffer;
-                }
-
-                try {
-                    messages.wait(idleTimeout);
-                } catch (InterruptedException e) {
-                    throw (IOException) new InterruptedIOException(
-                            "Interrupted while waiting for messages at iteration #" + count).initCause(e);
-                }
+            Buffer buffer = receive(id, idleTimeout);
+            if (buffer != null) {
+                return buffer;
             }
             if (traceEnabled) {
-                log.trace("receive({}) check iteration #{} for id={}", this, count, reqId);
+                log.trace("receive({}) check iteration #{} for id={}", this, count, id);
             }
         }
     }
-    protected Buffer read() throws IOException {
-        InputStream dis = channel.getInvertedOut();
-        int length = BufferUtils.readInt(dis, workBuf);
-        // must have at least command + length
-        if (length < (1 + Integer.BYTES)) {
-            throw new IllegalArgumentException("Bad length: " + length);
-        }
-
-        Buffer buffer = new ByteArrayBuffer(length + Integer.BYTES, false);
-        buffer.putInt(length);
-        int nb = length;
-        while (nb > 0) {
-            int readLen = dis.read(buffer.array(), buffer.wpos(), nb);
-            if (readLen < 0) {
-                throw new IllegalArgumentException("Premature EOF while read " + length + " bytes - remaining=" + nb);
+    /**
+     * Removes and returns the response registered for the given request identifier.
+     * If it is not available yet and {@code idleTimeout} is positive, waits (at most once)
+     * on the message table for a notification and then returns {@code null} without
+     * re-checking - the caller must retry, as the polling loop above does.
+     *
+     * @param  id          The request identifier
+     * @param  idleTimeout Maximum time (msec.) to wait - non-positive means return immediately
+     * @return             The response buffer, or {@code null} if not (yet) available
+     * @throws IOException An {@code InterruptedIOException} if interrupted while waiting
+     */
+    @Override
+    public Buffer receive(int id, long idleTimeout) throws IOException {
+        synchronized (messages) {
+            Buffer buffer = messages.remove(id);
+            if (buffer != null) {
+                return buffer;
+            }
+            if (idleTimeout > 0) {
+                try {
+                    messages.wait(idleTimeout);
+                } catch (InterruptedException e) {
+                    throw (IOException) new InterruptedIOException("Interrupted while waiting for messages").initCause(e);
+                }
             }
-            buffer.wpos(buffer.wpos() + readLen);
-            nb -= readLen;
         }
-
-        return buffer;
+        return null;
     }
     protected void init(long initializationTimeout) throws IOException {
         ValidateUtils.checkTrue(initializationTimeout > 0L, "Invalid initialization timeout: %d", initializationTimeout);
         // Send init packet
-        OutputStream dos = channel.getInvertedIn();
-        BufferUtils.writeInt(dos, 5 /* total length */, workBuf);
-        dos.write(SftpConstants.SSH_FXP_INIT);
-        // Ask for the highest we support and see what the server says
-        BufferUtils.writeInt(dos, SftpConstants.SFTP_V6, workBuf);
-        dos.flush();
+        // SSH_FXP_INIT packet: uint32 length (5 = type byte + version int), byte type, uint32 version -
+        // the highest version we support is offered and the server answers with the one it selects
+        Buffer buf = new ByteArrayBuffer(9);
+        buf.putInt(5);
+        buf.putByte((byte) SftpConstants.SSH_FXP_INIT);
+        buf.putInt(SftpConstants.SFTP_V6);
+        channel.getAsyncIn().writePacket(buf).verify();
         Buffer buffer;
         Integer reqId;
@@ -419,7 +402,7 @@ public class DefaultSftpClient extends AbstractSftpClient {
                 String name = buffer.getString();
                 byte[] data = buffer.getBytes();
                 if (traceEnabled) {
-                    log.trace("init({}) added extension=", getClientChannel(), name);
+                    log.trace("init({}) added extension={}", getClientChannel(), name);
                 }
                 extensions.put(name, data);
             }
@@ -501,4 +484,74 @@ public class DefaultSftpClient extends AbstractSftpClient {
         versionHolder.set(selected);
         return selected;
     }
+
+    /**
+     * SFTP subsystem channel: sends the subsystem request itself when opened and installs
+     * an asynchronous input stream that can write the channel data header into space
+     * reserved ahead of the payload instead of copying the payload into a new buffer.
+     */
+    private class SftpChannelSubsystem extends ChannelSubsystem {
+
+        SftpChannelSubsystem() {
+            super(SftpConstants.SFTP_SUBSYSTEM_NAME);
+        }
+
+        @Override
+        protected void doOpen() throws IOException {
+            String systemName = getSubsystem();
+            Session session = getSession();
+            boolean wantReply = this.getBooleanProperty(
+                    REQUEST_SUBSYSTEM_REPLY, DEFAULT_REQUEST_SUBSYSTEM_REPLY);
+            Buffer buffer = session.createBuffer(SshConstants.SSH_MSG_CHANNEL_REQUEST,
+                    Channel.CHANNEL_SUBSYSTEM.length() + systemName.length() + Integer.SIZE);
+            buffer.putInt(getRecipient());
+            buffer.putString(Channel.CHANNEL_SUBSYSTEM);
+            buffer.putBoolean(wantReply);
+            buffer.putString(systemName);
+            addPendingRequest(Channel.CHANNEL_SUBSYSTEM, wantReply);
+            writePacket(buffer);
+
+            asyncIn = new ChannelAsyncOutputStream(this, SshConstants.SSH_MSG_CHANNEL_DATA) {
+                @SuppressWarnings("synthetic-access")
+                @Override
+                protected CloseFuture doCloseGracefully() {
+                    try {
+                        sendEof();
+                    } catch (IOException e) {
+                        Session session = getSession();
+                        session.exceptionCaught(e);
+                    }
+                    return super.doCloseGracefully();
+                }
+
+                // When at least 9 bytes are reserved before the payload and the whole payload is being sent,
+                // the recipient and data length are written in place (bytes rpos-8 .. rpos-1) and the packet
+                // starts at rpos-9. NOTE(review): this presumes the byte at rpos-9 already holds the message
+                // command (i.e. the buffer came from session.createBuffer(SSH_MSG_CHANNEL_DATA, ...)) - confirm
+                @Override
+                protected Buffer createSendBuffer(Buffer buffer, Channel channel, long length) {
+                    if (buffer.rpos() >= 9 && length == buffer.available()) {
+                        int rpos = buffer.rpos();
+                        int wpos = buffer.wpos();
+                        buffer.rpos(rpos - 9);
+                        buffer.wpos(rpos - 8);
+                        buffer.putInt(channel.getRecipient());
+                        buffer.putInt(length);
+                        buffer.wpos(wpos);
+                        return buffer;
+                    } else {
+                        return super.createSendBuffer(buffer, channel, length);
+                    }
+                }
+            };
+            // Bytes written to this stream are handed straight to data(...) without intermediate
+            // buffering (presumably the incoming SFTP responses - confirm against ChannelSubsystem)
+            out = new OutputStream() {
+                private final byte[] singleByte = new byte[1];
+
+                @Override
+                public void write(int b) throws IOException {
+                    synchronized (singleByte) {
+                        singleByte[0] = (byte) b;
+                        write(singleByte);
+                    }
+                }
+
+                @Override
+                public void write(byte[] b, int off, int len) throws IOException {
+                    data(b, off, len);
+                }
+            };
+            err = new ByteArrayOutputStream();
+        }
+    }
 }
diff --git a/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/impl/SftpInputStreamAsync.java b/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/impl/SftpInputStreamAsync.java
new file mode 100644
index 0000000..ec3e593
--- /dev/null
+++ b/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/impl/SftpInputStreamAsync.java
@@ -0,0 +1,312 @@
+/*
+ *
Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sshd.client.subsystem.sftp.impl;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.nio.ByteBuffer;
+import java.nio.channels.WritableByteChannel;
+import java.util.ArrayDeque;
+import java.util.Collection;
+import java.util.Deque;
+import java.util.Objects;
+import java.util.concurrent.atomic.AtomicReference;
+
+import org.apache.sshd.client.subsystem.sftp.SftpClient;
+import org.apache.sshd.client.subsystem.sftp.SftpClient.CloseableHandle;
+import org.apache.sshd.client.subsystem.sftp.SftpClient.OpenMode;
+import org.apache.sshd.common.SshConstants;
+import org.apache.sshd.common.subsystem.sftp.SftpConstants;
+import org.apache.sshd.common.subsystem.sftp.SftpHelper;
+import org.apache.sshd.common.util.buffer.Buffer;
+import org.apache.sshd.common.util.buffer.ByteArrayBuffer;
+import org.apache.sshd.common.util.io.InputStreamWithChannel;
+
+/**
+ * Implements an input stream for a given remote file. Several {@code SSH_FXP_READ} requests are kept in
+ * flight (as many as fit in the channel's local window) so that the throughput is not bound by the
+ * round-trip latency. Gaps left by responses shorter than requested are filled with synchronous reads.
+ * <p>
+ * No internal synchronization is performed - an instance must not be used concurrently by several threads.
+ */
+public class SftpInputStreamAsync extends InputStreamWithChannel {
+
+    /**
+     * An outstanding {@code SSH_FXP_READ} request
+     */
+    static class Ack {
+        final int id;
+        final long offset;
+        final int length;
+
+        Ack(int id, long offset, int length) {
+            this.id = id;
+            this.offset = offset;
+            this.length = length;
+        }
+    }
+
+    /**
+     * Receives the consecutive chunks of data delivered by {@link SftpInputStreamAsync#doRead(long, ChunkConsumer)}
+     */
+    @FunctionalInterface
+    private interface ChunkConsumer {
+        void accept(byte[] data, int off, int len) throws IOException;
+    }
+
+    private final AbstractSftpClient client;
+    private final String path;
+    private final byte[] bb = new byte[1];
+    private final int bufferSize;
+    private final long fileSize;
+    private Buffer buffer;
+    private CloseableHandle handle;
+    // offset of the next SSH_FXP_READ request to be sent
+    private long requestOffset;
+    // offset of the next byte to be handed to the caller
+    private long clientOffset;
+    private final Deque<Ack> pendingReads = new ArrayDeque<>();
+    private boolean eofIndicator;
+
+    /**
+     * Opens the remote file and retrieves its size.
+     *
+     * @param  client      The SFTP client to use
+     * @param  bufferSize  Size of each individual read request
+     * @param  path        Remote file path
+     * @param  mode        Open modes
+     * @throws IOException If the file cannot be opened or stat-ed - in the latter case the handle is closed
+     */
+    public SftpInputStreamAsync(AbstractSftpClient client, int bufferSize,
+                                String path, Collection<OpenMode> mode) throws IOException {
+        this.client = Objects.requireNonNull(client, "No SFTP client instance");
+        this.path = path;
+        this.bufferSize = bufferSize;
+        CloseableHandle h = client.open(path, mode);
+        try {
+            this.fileSize = client.stat(h).getSize();
+        } catch (IOException | RuntimeException e) {
+            // do not leak the remote handle if the file size cannot be retrieved
+            try {
+                h.close();
+            } catch (IOException suppressed) {
+                e.addSuppressed(suppressed);
+            }
+            throw e;
+        }
+        this.handle = h;
+    }
+
+    /**
+     * Wraps an already open handle, starting at the given offset.
+     *
+     * @param client       The SFTP client to use
+     * @param bufferSize   Size of each individual read request
+     * @param clientOffset Offset from which data is read (and requested)
+     * @param fileSize     The (expected) size of the remote file
+     * @param path         Remote file path
+     * @param handle       The open file handle - closed when this stream is closed
+     */
+    public SftpInputStreamAsync(AbstractSftpClient client, int bufferSize, long clientOffset, long fileSize,
+                                String path, CloseableHandle handle) {
+        this.client = Objects.requireNonNull(client, "No SFTP client instance");
+        this.path = path;
+        this.handle = handle;
+        this.bufferSize = bufferSize;
+        // requests must start where the client reads from, not at the beginning of the file
+        this.requestOffset = clientOffset;
+        this.clientOffset = clientOffset;
+        this.fileSize = fileSize;
+    }
+
+    /**
+     * The client instance
+     *
+     * @return {@link SftpClient} instance used to access the remote file
+     */
+    public final AbstractSftpClient getClient() {
+        return client;
+    }
+
+    /**
+     * The remotely accessed file path
+     *
+     * @return Remote file path
+     */
+    public final String getPath() {
+        return path;
+    }
+
+    /**
+     * Check if the stream is at EOF
+     *
+     * @return <code>true</code> if all the data has been consumed
+     */
+    public boolean isEof() {
+        return eofIndicator && hasNoData();
+    }
+
+    @Override
+    public boolean isOpen() {
+        return (handle != null) && handle.isOpen();
+    }
+
+    @Override
+    public int read() throws IOException {
+        int read = read(bb, 0, 1);
+        if (read > 0) {
+            return bb[0] & 0xFF;
+        }
+        return read;
+    }
+
+    @Override
+    public int read(byte[] b, int off, int len) throws IOException {
+        if (!isOpen()) {
+            throw new IOException("read(" + getPath() + ") stream closed");
+        }
+        if (len == 0) {
+            return 0;
+        }
+
+        // each chunk must be appended after the previous one - not copied again at "off"
+        int[] writePos = { off };
+        long nb = doRead(len, (data, dataOff, dataLen) -> {
+            System.arraycopy(data, dataOff, b, writePos[0], dataLen);
+            writePos[0] += dataLen;
+        });
+        // doRead() returns without any data only once the end of file has been reached
+        return (nb == 0L) ? -1 : (int) nb;
+    }
+
+    /**
+     * Transfers up to {@code max} bytes to the given channel.
+     *
+     * @param  max         Maximum number of bytes to transfer
+     * @param  out         Target channel
+     * @return             Number of transferred bytes
+     * @throws IOException If the stream is closed or the transfer fails
+     */
+    public long transferTo(long max, WritableByteChannel out) throws IOException {
+        if (!isOpen()) {
+            throw new IOException("transferTo(" + getPath() + ") stream closed");
+        }
+        return doRead(max, (data, off, len) -> {
+            ByteBuffer chunk = ByteBuffer.wrap(data, off, len);
+            while (chunk.hasRemaining()) {
+                out.write(chunk);
+            }
+        });
+    }
+
+    /**
+     * Transfers all the remaining data to the given stream.
+     *
+     * @param  out         Target stream
+     * @return             Number of transferred bytes
+     * @throws IOException If the stream is closed or the transfer fails
+     */
+    @SuppressWarnings("PMD.MissingOverride")
+    public long transferTo(OutputStream out) throws IOException {
+        if (!isOpen()) {
+            throw new IOException("transferTo(" + getPath() + ") stream closed");
+        }
+        return doRead(Long.MAX_VALUE, out::write);
+    }
+
+    /**
+     * Hands up to {@code max} bytes to the consumer, fetching responses and issuing new requests as needed.
+     * Data already buffered is always delivered, even if the end of file has been signaled along with it.
+     *
+     * @param  max         Maximum number of bytes to deliver
+     * @param  consumer    Receives the data chunks in order
+     * @return             Number of delivered bytes - zero only at end of file (or if {@code max <= 0})
+     * @throws IOException If reading fails or the consumer fails
+     */
+    private long doRead(long max, ChunkConsumer consumer) throws IOException {
+        long orgOffset = clientOffset;
+        long remaining = max;
+        while (remaining > 0L) {
+            if (hasNoData()) {
+                if (eofIndicator) {
+                    break;
+                }
+                fillData();
+                if (eofIndicator && hasNoData()) {
+                    break;
+                }
+                sendRequests();
+            } else {
+                int nb = (int) Math.min(buffer.available(), remaining);
+                consumer.accept(buffer.array(), buffer.rpos(), nb);
+                buffer.rpos(buffer.rpos() + nb);
+                clientOffset += nb;
+                remaining -= nb;
+            }
+        }
+        return clientOffset - orgOffset;
+    }
+
+    @Override
+    public long skip(long n) throws IOException {
+        if (!isOpen()) {
+            throw new IOException("skip(" + getPath() + ") stream closed");
+        }
+        if (n <= 0L) {
+            return 0L;
+        }
+        if ((clientOffset == 0L) && pendingReads.isEmpty()) {
+            // nothing requested yet - simply start requesting (and reading) at the new position
+            clientOffset = n;
+            requestOffset = n;
+            return n;
+        }
+        return super.skip(n);
+    }
+
+    boolean hasNoData() {
+        return buffer == null || buffer.available() == 0;
+    }
+
+    /**
+     * Issues new {@code SSH_FXP_READ} requests while they fit in the channel's local window and have not
+     * gone past the (expected) end of file - but always keeps at least one request outstanding.
+     *
+     * @throws IOException If a request cannot be sent
+     */
+    void sendRequests() throws IOException {
+        if (eofIndicator) {
+            return;
+        }
+        long windowSize = client.getChannel().getLocalWindow().getMaxSize();
+        int maxPending = (int) (windowSize / bufferSize);
+        while (((pendingReads.size() < maxPending) && (requestOffset < fileSize + bufferSize))
+                || pendingReads.isEmpty()) {
+            Buffer buf = client.getSession().createBuffer(SshConstants.SSH_MSG_CHANNEL_DATA,
+                    23 /* sftp packet */ + 16 + handle.getIdentifier().length);
+            // presumably leaves room ahead of rpos for the headers that send() and the channel write in place
+            buf.rpos(23);
+            buf.wpos(23);
+            buf.putBytes(handle.getIdentifier());
+            buf.putLong(requestOffset);
+            buf.putInt(bufferSize);
+            int reqId = client.send(SftpConstants.SSH_FXP_READ, buf);
+            pendingReads.add(new Ack(reqId, requestOffset, bufferSize));
+            requestOffset += bufferSize;
+        }
+    }
+
+    /**
+     * Consumes the response to the oldest outstanding request. If an earlier response was shorter than
+     * requested, the bytes missing between the client position and this response are read synchronously
+     * and placed ahead of the received data.
+     *
+     * @throws IOException If the response or the synchronous read fails
+     */
+    void fillData() throws IOException {
+        Ack ack = pendingReads.pollFirst();
+        if (ack == null) {
+            return;
+        }
+
+        pollBuffer(ack);
+
+        if (clientOffset >= ack.offset) {
+            return; // no gap
+        }
+        // once the end of file was signaled, the gap only needs filling if the known size says data is missing
+        if (eofIndicator && (clientOffset >= fileSize)) {
+            return;
+        }
+
+        int missing = (int) (ack.offset - clientOffset);
+        int avail = hasNoData() ? 0 : buffer.available();
+        byte[] data = new byte[missing + avail];
+        AtomicReference<Boolean> eof = new AtomicReference<>();
+        int cur = 0;
+        while (cur < missing) {
+            int dlen = client.read(handle, clientOffset + cur, data, cur, missing - cur, eof);
+            if (dlen <= 0) {
+                // no more data - stop instead of spinning or moving backwards
+                eofIndicator = true;
+                break;
+            }
+            cur += dlen;
+            if (Boolean.TRUE.equals(eof.get())) {
+                eofIndicator = true;
+            }
+        }
+
+        if (cur < missing) {
+            // the file ended before the offset of this response, so its payload (if any) cannot follow on
+            buffer = new ByteArrayBuffer(data, 0, cur);
+        } else {
+            if (avail > 0) {
+                buffer.getRawBytes(data, missing, avail);
+            }
+            buffer = new ByteArrayBuffer(data);
+        }
+    }
+
+    /**
+     * Waits for the response to the given request and either makes its payload the current buffer
+     * ({@code SSH_FXP_DATA}) or records the end of file ({@code SSH_FX_EOF} status).
+     *
+     * @param  ack         The request whose response is expected
+     * @throws IOException On a non-EOF error status or an unexpected packet
+     */
+    void pollBuffer(Ack ack) throws IOException {
+        Buffer buf = client.receive(ack.id);
+        int length = buf.getInt();
+        int type = buf.getUByte();
+        int id = buf.getInt();
+        client.validateIncomingResponse(SshConstants.SSH_MSG_CHANNEL_DATA, id, type, length, buf);
+        if (type == SftpConstants.SSH_FXP_DATA) {
+            int dlen = buf.getInt();
+            int rpos = buf.rpos();
+            // the optional end-of-file flag follows the data
+            buf.rpos(rpos + dlen);
+            Boolean b = SftpHelper.getEndOfFileIndicatorValue(buf, client.getVersion());
+            eofIndicator = b != null && b;
+            buf.rpos(rpos);
+            buf.wpos(rpos + dlen);
+            this.buffer = buf;
+        } else if (type == SftpConstants.SSH_FXP_STATUS) {
+            int substatus = buf.getInt();
+            String msg = buf.getString();
+            String lang = buf.getString();
+            if (substatus == SftpConstants.SSH_FX_EOF) {
+                eofIndicator = true;
+            } else {
+                client.checkResponseStatus(SshConstants.SSH_MSG_CHANNEL_DATA, id, substatus, msg, lang);
+            }
+        } else {
+            IOException err = client.handleUnexpectedPacket(SshConstants.SSH_MSG_CHANNEL_DATA,
+                    SftpConstants.SSH_FXP_STATUS, id, type, length, buf);
+            if (err != null) {
+                throw err;
+            }
+        }
+    }
+
+    /**
+     * Consumes the responses of all the outstanding read requests, then closes the remote handle.
+     *
+     * @throws IOException If a response reports an error or the handle cannot be closed
+     */
+    @Override
+    public void close() throws IOException {
+        if (isOpen()) {
+            try {
+                try {
+                    while (!pendingReads.isEmpty()) {
+                        Ack ack = pendingReads.removeFirst();
+                        pollBuffer(ack);
+                    }
+                } finally {
+                    handle.close();
+                }
+            } finally {
+                handle = null;
+            }
+        }
+    }
+}
diff --git a/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/impl/SftpOutputStreamAsync.java b/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/impl/SftpOutputStreamAsync.java
new file mode 100644
index 0000000..d8f1974
--- /dev/null
+++ b/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/impl/SftpOutputStreamAsync.java
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sshd.client.subsystem.sftp.impl;
+
+import java.io.IOException;
+import java.util.ArrayDeque;
+import java.util.Collection;
+import java.util.Deque;
+import java.util.Objects;
+
+import org.apache.sshd.client.subsystem.sftp.SftpClient;
+import org.apache.sshd.client.subsystem.sftp.SftpClient.CloseableHandle;
+import org.apache.sshd.client.subsystem.sftp.SftpClient.OpenMode;
+import org.apache.sshd.common.SshConstants;
+import org.apache.sshd.common.subsystem.sftp.SftpConstants;
+import org.apache.sshd.common.util.buffer.Buffer;
+import org.apache.sshd.common.util.buffer.ByteArrayBuffer;
+import org.apache.sshd.common.util.io.OutputStreamWithChannel;
+
+/**
+ * Implements an output stream for a given remote file. Data is sent as pipelined {@code SSH_FXP_WRITE}
+ * requests whose acknowledgements are collected opportunistically (without blocking) on {@link #flush()}
+ * and exhaustively on {@link #close()}.
+ * <p>
+ * No internal synchronization is performed - an instance must not be used concurrently by several threads.
+ *
+ * @author <a href="mailto:d...@mina.apache.org">Apache MINA SSHD Project</a>
+ */
+public class SftpOutputStreamAsync extends OutputStreamWithChannel {
+
+    /**
+     * An outstanding {@code SSH_FXP_WRITE} request
+     */
+    static class Ack {
+        final int id;
+        final long offset;
+        final int length;
+
+        Ack(int id, long offset, int length) {
+            this.id = id;
+            this.offset = offset;
+            this.length = length;
+        }
+    }
+
+    private final AbstractSftpClient client;
+    private final String path;
+    private final byte[] bb = new byte[1];
+    private final int bufferSize;
+    private Buffer buffer;
+    private CloseableHandle handle;
+    // remote file offset at which the next write request is made
+    private long offset;
+    private final Deque<Ack> pendingWrites = new ArrayDeque<>();
+
+    public SftpOutputStreamAsync(AbstractSftpClient client, int bufferSize,
+                                 String path, Collection<OpenMode> mode) throws IOException {
+        this.client = Objects.requireNonNull(client, "No SFTP client instance");
+        this.path = path;
+        this.handle = client.open(path, mode);
+        this.bufferSize = bufferSize;
+    }
+
+    public SftpOutputStreamAsync(AbstractSftpClient client, int bufferSize,
+                                 String path, CloseableHandle handle) throws IOException {
+        this.client = Objects.requireNonNull(client, "No SFTP client instance");
+        this.path = path;
+        this.handle = handle;
+        this.bufferSize = bufferSize;
+    }
+
+    /**
+     * The client instance
+     *
+     * @return {@link SftpClient} instance used to access the remote file
+     */
+    public final AbstractSftpClient getClient() {
+        return client;
+    }
+
+    /**
+     * @param offset The remote file offset at which the next write request is made
+     */
+    public void setOffset(long offset) {
+        this.offset = offset;
+    }
+
+    /**
+     * The remotely accessed file path
+     *
+     * @return Remote file path
+     */
+    public final String getPath() {
+        return path;
+    }
+
+    @Override
+    public boolean isOpen() {
+        return (handle != null) && handle.isOpen();
+    }
+
+    @Override
+    public void write(int b) throws IOException {
+        bb[0] = (byte) b;
+        write(bb, 0, 1);
+    }
+
+    @Override
+    public void write(byte[] b, int off, int len) throws IOException {
+        if (!isOpen()) {
+            throw new IOException("write(" + getPath() + ") stream is closed");
+        }
+
+        byte[] id = handle.getIdentifier();
+        // largest payload per request that still leaves room for the SFTP/SSH headers within bufferSize
+        int max = bufferSize - (9 + 16 + id.length + 72);
+        if (max <= 0) {
+            // would otherwise never make progress
+            throw new IllegalStateException("Buffer size too small for write requests: " + bufferSize);
+        }
+
+        while (len > 0) {
+            if (buffer == null) {
+                buffer = client.getSession().createBuffer(SshConstants.SSH_MSG_CHANNEL_DATA, bufferSize);
+                // reserve room ahead of the data so the headers can later be written in place (no copy)
+                int hdr = (9 + 16 + 8 + id.length) + buffer.wpos();
+                buffer.rpos(hdr);
+                buffer.wpos(hdr);
+            }
+            int nb = Math.min(len, max - buffer.available());
+            buffer.putRawBytes(b, off, nb);
+            if (buffer.available() == max) {
+                flush();
+            }
+            off += nb;
+            len -= nb;
+        }
+    }
+
+    /**
+     * Collects the acknowledgements that already arrived (without blocking) and sends the buffered
+     * data - if any - as a new {@code SSH_FXP_WRITE} request.
+     *
+     * @throws IOException If the stream is closed, an earlier write failed or the request cannot be sent
+     */
+    @Override
+    public void flush() throws IOException {
+        if (!isOpen()) {
+            throw new IOException("flush(" + getPath() + ") stream is closed");
+        }
+
+        for (Ack ack = pendingWrites.peek(); ack != null; ack = pendingWrites.peek()) {
+            Buffer response = client.receive(ack.id, 0L);
+            if (response == null) {
+                break;
+            }
+            pendingWrites.removeFirst();
+            client.checkResponseStatus(SftpConstants.SSH_FXP_WRITE, response);
+        }
+
+        // nothing to send - e.g. an explicit flush() right after write() sent a full buffer
+        if ((buffer == null) || (buffer.available() <= 0)) {
+            return;
+        }
+
+        byte[] id = handle.getIdentifier();
+        int avail = buffer.available();
+        Buffer buf;
+        if (buffer.rpos() >= 16 + id.length) {
+            // write handle, offset and data length in the reserved space ahead of the data
+            int wpos = buffer.wpos();
+            buffer.rpos(buffer.rpos() - 16 - id.length);
+            buffer.wpos(buffer.rpos());
+            buffer.putBytes(id);
+            buffer.putLong(offset);
+            buffer.putInt(avail);
+            buffer.wpos(wpos);
+            buf = buffer;
+        } else {
+            buf = new ByteArrayBuffer(id.length + avail + 16 /* handle length + offset + data length */, false);
+            buf.putBytes(id);
+            buf.putLong(offset);
+            buf.putBytes(buffer.array(), buffer.rpos(), avail);
+        }
+
+        int reqId = client.send(SftpConstants.SSH_FXP_WRITE, buf);
+        pendingWrites.add(new Ack(reqId, offset, avail));
+
+        offset += avail;
+        buffer = null;
+    }
+
+    /**
+     * Sends any buffered data, waits for the acknowledgement of every outstanding write and closes the
+     * remote handle.
+     *
+     * @throws IOException If a write failed or the handle cannot be closed
+     */
+    @Override
+    public void close() throws IOException {
+        if (isOpen()) {
+            try {
+                try {
+                    if (buffer != null && buffer.available() > 0) {
+                        flush();
+                    }
+                    while (!pendingWrites.isEmpty()) {
+                        Ack ack = pendingWrites.removeFirst();
+                        Buffer response = client.receive(ack.id);
+                        client.checkResponseStatus(SftpConstants.SSH_FXP_WRITE, response);
+                    }
+                } finally {
+                    handle.close();
+                }
+            } finally {
+                handle = null;
+                buffer = null;
+            }
+        }
+    }
+}
diff --git a/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/SftpInputStreamWithChannel.java b/sshd-sftp/src/test/java/org/apache/sshd/client/subsystem/sftp/SftpInputStreamWithChannel.java
similarity index 100%
rename from sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/SftpInputStreamWithChannel.java
rename to sshd-sftp/src/test/java/org/apache/sshd/client/subsystem/sftp/SftpInputStreamWithChannel.java
diff --git a/sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/SftpOutputStreamWithChannel.java b/sshd-sftp/src/test/java/org/apache/sshd/client/subsystem/sftp/SftpOutputStreamWithChannel.java
similarity index 100%
rename from sshd-sftp/src/main/java/org/apache/sshd/client/subsystem/sftp/SftpOutputStreamWithChannel.java
rename to sshd-sftp/src/test/java/org/apache/sshd/client/subsystem/sftp/SftpOutputStreamWithChannel.java
diff --git a/sshd-sftp/src/test/java/org/apache/sshd/client/subsystem/sftp/SftpPerformanceTest.java b/sshd-sftp/src/test/java/org/apache/sshd/client/subsystem/sftp/SftpPerformanceTest.java
new file mode 100644
index 0000000..d24b7c8
--- /dev/null
+++ b/sshd-sftp/src/test/java/org/apache/sshd/client/subsystem/sftp/SftpPerformanceTest.java
@@ -0,0 +1,243 @@
+/*
+ * Licensed to the
Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sshd.client.subsystem.sftp;
+
+import java.io.BufferedInputStream;
+import java.io.BufferedOutputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.nio.file.StandardOpenOption;
+import java.util.Arrays;
+
+import eu.rekawek.toxiproxy.model.ToxicDirection;
+import eu.rekawek.toxiproxy.model.toxic.Latency;
+import org.apache.sshd.client.SshClient;
+import org.apache.sshd.client.config.hosts.HostConfigEntryResolver;
+import org.apache.sshd.client.keyverifier.AcceptAllServerKeyVerifier;
+import org.apache.sshd.client.session.ClientSession;
+import org.apache.sshd.client.subsystem.sftp.SftpClient.OpenMode;
+import org.apache.sshd.client.subsystem.sftp.fs.SftpFileSystem;
+import org.apache.sshd.common.keyprovider.KeyIdentityProvider;
+import org.jetbrains.annotations.NotNull;
+import org.junit.Ignore;
+import org.junit.Rule;
+import org.junit.Test;
+import org.testcontainers.containers.GenericContainer;
+import org.testcontainers.containers.Network;
+import org.testcontainers.containers.ToxiproxyContainer;
+import org.testcontainers.containers.ToxiproxyContainer.ContainerProxy;
+
+/**
+ * Compares the upload/download timings of the previous SFTP streams (SftpOutputStreamWithChannel /
+ * SftpInputStreamWithChannel) with the ones returned by the SFTP file system, against an "atmoz/sftp"
+ * container reached through a Toxiproxy container that adds downstream latency. Results are printed
+ * to stdout; nothing is asserted.
+ */
+@Ignore("Special class used for development only - not really a test just useful to run as such")
+public class SftpPerformanceTest {
+
+    public static final String USERNAME = "foo";
+    public static final String PASSWORD = "pass";
+
+    // Create a common docker network so that containers can communicate
+    @Rule
+    public Network network = Network.newNetwork();
+
+    // the target container - this could be anything
+    // the local "target" folder is bound to the user's home, so remote "out.txt" is local "target/out.txt"
+    @Rule
+    public GenericContainer<?> sftp = new GenericContainer<>("atmoz/sftp")
+            .withEnv("SFTP_USERS", USERNAME + ":" + PASSWORD)
+            .withNetwork(network)
+            .withFileSystemBind("target", "/home/foo")
+            .withExposedPorts(22);
+
+    // Toxiproxy container, which will be used as a TCP proxy
+    @Rule
+    public ToxiproxyContainer toxiproxy = new ToxiproxyContainer()
+            .withNetwork(network);
+
+    public SftpPerformanceTest() {
+        super();
+    }
+
+    /**
+     * For each latency / size combination, times an upload with the previous stream and then with
+     * the optimized one, each on a fresh session, and prints the gain.
+     */
+    @Test
+    public void testUploadLatency() throws IOException {
+        final ContainerProxy proxy = toxiproxy.getProxy(sftp, 22);
+        for (int latency : Arrays.asList(0, 1, 5, 10, 50, 100, 500)) {
+            Latency toxic = proxy.toxics().latency("latency", ToxicDirection.DOWNSTREAM, latency);
+            for (int megabytes : Arrays.asList(1, 5, 10, 50, 100)) {
+                try (SshClient client = createSshClient()) {
+                    long orgTime;
+                    long newTime;
+                    try (ClientSession session = createClientSession(client, proxy)) {
+                        orgTime = uploadPrevious(session, megabytes);
+                    }
+                    try (ClientSession session = createClientSession(client, proxy)) {
+                        newTime = uploadOptimized(session, megabytes);
+                    }
+                    // NOTE(review): ArithmeticException if orgTime is 0 ms
+                    System.out.println(String.format("%3d MB / %3d ms latency: %7d down to %5d ms, gain = %d%%",
+                            megabytes, latency, orgTime, newTime,
+                            (int) (100 * (orgTime - newTime) / orgTime)));
+                }
+            }
+            toxic.remove();
+        }
+    }
+
+    /**
+     * Same as {@link #testUploadLatency()} for downloads - the optimized variant is timed first here.
+     */
+    @Test
+    public void testDownloadLatency() throws IOException {
+        final ContainerProxy proxy = toxiproxy.getProxy(sftp, 22);
+        for (int latency : Arrays.asList(0, 1, 5, 10, 50, 100, 500)) {
+            Latency toxic = proxy.toxics().latency("latency", ToxicDirection.DOWNSTREAM, latency);
+            for (int megabytes : Arrays.asList(1, 5, 10, 50, 100)) {
+                try (SshClient client = createSshClient()) {
+                    long orgTime;
+                    long newTime;
+                    try (ClientSession session = createClientSession(client, proxy)) {
+                        newTime = downloadOptimized(session, megabytes);
+                    }
+                    try (ClientSession session = createClientSession(client, proxy)) {
+                        orgTime = downloadPrevious(session, megabytes);
+                    }
+                    // NOTE(review): ArithmeticException if orgTime is 0 ms
+                    System.out.println(String.format("%3d MB / %3d ms latency: %7d down to %5d ms, gain = %d%%",
+                            megabytes, latency, orgTime, newTime,
+                            (int) (100 * (orgTime - newTime) / orgTime)));
+                }
+            }
+            toxic.remove();
+        }
+    }
+
+    /**
+     * Connects through the proxy and authenticates with the container's password.
+     */
+    public ClientSession createClientSession(SshClient client, ContainerProxy proxy) throws IOException {
+        final String ipAddressViaToxiproxy = proxy.getContainerIpAddress();
+        final int portViaToxiproxy = proxy.getProxyPort();
+
+        ClientSession session = client.connect(USERNAME, ipAddressViaToxiproxy, portViaToxiproxy).verify().getClientSession();
+        session.addPasswordIdentity(PASSWORD);
+        session.auth().verify();
+        return session;
+    }
+
+    /**
+     * Creates and starts a client that accepts any server key and uses no host config or key identities.
+     */
+    @NotNull
+    public SshClient createSshClient() {
+        SshClient client = SshClient.setUpDefaultClient();
+        client.setServerKeyVerifier(AcceptAllServerKeyVerifier.INSTANCE);
+        client.setHostConfigEntryResolver(HostConfigEntryResolver.EMPTY);
+        client.setKeyIdentityProvider(KeyIdentityProvider.EMPTY_KEYS_PROVIDER);
+        client.start();
+        return client;
+    }
+
+    /**
+     * Uploads {@code mb} megabytes of a 16-byte text line to "out.txt" with the previous stream.
+     *
+     * @return elapsed time in milliseconds
+     */
+    public long uploadPrevious(ClientSession session, int mb) throws IOException {
+        long t0 = System.currentTimeMillis();
+        try (SftpClient client = SftpClientFactory.instance().createSftpClient(session)) {
+            try (OutputStream os = new BufferedOutputStream(
+                    new SftpOutputStreamWithChannel(
+                            client, 32768, "out.txt",
+                            Arrays.asList(OpenMode.Write,
+                                    OpenMode.Create,
+                                    OpenMode.Truncate)),
+                    32768)) {
+                byte[] bytes = "123456789abcdef\n".getBytes();
+                for (int i = 0; i < 1024 * 1024 * mb / bytes.length; i++) {
+                    os.write(bytes);
+                }
+            }
+        }
+        long t1 = System.currentTimeMillis();
+        return t1 - t0;
+    }
+
+    /**
+     * Same as {@link #uploadPrevious(ClientSession, int)} through the SFTP file system streams.
+     *
+     * @return elapsed time in milliseconds
+     */
+    public long uploadOptimized(ClientSession session, int mb) throws IOException {
+        long t0 = System.currentTimeMillis();
+        try (SftpFileSystem fs = SftpClientFactory.instance().createSftpFileSystem(session)) {
+            Path p = fs.getPath("out.txt");
+            try (OutputStream os = new BufferedOutputStream(
+                    Files.newOutputStream(p, StandardOpenOption.CREATE,
+                            StandardOpenOption.TRUNCATE_EXISTING),
+                    32768)) {
+                byte[] bytes = "123456789abcdef\n".getBytes();
+                for (int i = 0; i < 1024 * 1024 * mb / bytes.length; i++) {
+                    os.write(bytes);
+                }
+            }
+        }
+        long t1 = System.currentTimeMillis();
+        return t1 - t0;
+    }
+
+    /**
+     * Creates a local {@code mb} megabytes "target/out.txt" (seen remotely as "out.txt" through the bind
+     * mount) and times reading it back with the previous stream - the file creation is not timed.
+     *
+     * @return elapsed time in milliseconds
+     */
+    public long downloadPrevious(ClientSession session, int mb) throws IOException {
+        Path f = Paths.get("target/out.txt");
+        byte[] bytes = "123456789abcdef\n".getBytes();
+        try (BufferedOutputStream bos = new BufferedOutputStream(
+                Files.newOutputStream(f, StandardOpenOption.CREATE,
+                        StandardOpenOption.TRUNCATE_EXISTING,
+                        StandardOpenOption.WRITE))) {
+            for (int i = 0; i < 1024 * 1024 * mb / bytes.length; i++) {
+                bos.write(bytes);
+            }
+        }
+        long t0 = System.currentTimeMillis();
+        try (SftpClient client = SftpClientFactory.instance().createSftpClient(session)) {
+            try (InputStream os = new BufferedInputStream(
+                    new SftpInputStreamWithChannel(
+                            client, 32768, "out.txt",
+                            Arrays.asList(OpenMode.Read)),
+                    32768)) {
+                byte[] data = new byte[8192];
+                // NOTE(review): this counts read() calls, not bytes - short reads mean the whole
+                // file may not be downloaded; consider reading until -1 instead
+                for (int i = 0; i < 1024 * 1024 * mb / data.length; i++) {
+                    int l = os.read(data);
+                    if (l < 0) {
+                        break;
+                    }
+                }
+            }
+        }
+        long t1 = System.currentTimeMillis();
+        return t1 - t0;
+    }
+
+    /**
+     * Same as {@link #downloadPrevious(ClientSession, int)} through the SFTP file system streams.
+     *
+     * @return elapsed time in milliseconds
+     */
+    public long downloadOptimized(ClientSession session, int mb) throws IOException {
+        Path f = Paths.get("target/out.txt");
+        byte[] bytes = "123456789abcdef\n".getBytes();
+        try (BufferedOutputStream bos = new BufferedOutputStream(
+                Files.newOutputStream(f, StandardOpenOption.CREATE,
+                        StandardOpenOption.TRUNCATE_EXISTING,
+                        StandardOpenOption.WRITE))) {
+            for (int i = 0; i < 1024 * 1024 * mb / bytes.length; i++) {
+                bos.write(bytes);
+            }
+        }
+        long t0 = System.currentTimeMillis();
+        try (SftpFileSystem fs = SftpClientFactory.instance().createSftpFileSystem(session)) {
+            Path p = fs.getPath("out.txt");
+            try (InputStream os = new BufferedInputStream(
+                    Files.newInputStream(p, StandardOpenOption.READ), 32768)) {
+                byte[] data = new byte[8192];
+                // NOTE(review): counts read() calls, not bytes - see downloadPrevious
+                for (int i = 0; i < 1024 * 1024 * mb / data.length; i++) {
+                    int l = os.read(data);
+                    if (l < 0) {
+                        break;
+                    }
+                }
+            }
+        }
+        long t1 = System.currentTimeMillis();
+        return t1 - t0;
+    }
+
+}
diff --git a/sshd-sftp/src/test/java/org/apache/sshd/client/subsystem/sftp/SftpTest.java b/sshd-sftp/src/test/java/org/apache/sshd/client/subsystem/sftp/SftpTest.java
index da0d96d..3078ec2 100644
--- a/sshd-sftp/src/test/java/org/apache/sshd/client/subsystem/sftp/SftpTest.java
+++ b/sshd-sftp/src/test/java/org/apache/sshd/client/subsystem/sftp/SftpTest.java
@@ -586,21 +586,14 @@ public class SftpTest extends AbstractSftpClientTestSupport {
             try (SftpClient sftp = createSftpClient(session);
                  InputStream stream = sftp.read(
-                         CommonTestSupportUtils.resolveRelativeRemotePath(parentPath, localFile), OpenMode.Read)) {
-                assertFalse("Stream reported mark supported", stream.markSupported());
-                try {
-                    stream.mark(data.length);
-                    fail("Unexpected success to mark the read limit");
-                } catch (UnsupportedOperationException e) {
-                    // expected - ignored
-                }
+                         CommonTestSupportUtils.resolveRelativeRemotePath(parentPath, localFile),
+                         OpenMode.Read)) {
                 byte[] expected = new byte[data.length / 4];
-                int readLen = stream.read(expected);
-                assertEquals("Failed to read fully initial data", expected.length, readLen);
+                int readLen = expected.length;
+                System.arraycopy(data, 0, expected, 0, readLen);
                 byte[] actual = new byte[readLen];
-                stream.reset();
                 readLen = stream.read(actual);
                 assertEquals("Failed to read fully reset data", actual.length, readLen);
                 assertArrayEquals("Mismatched re-read data contents", expected, actual);
@@ -616,12 +609,6 @@ public class SftpTest extends AbstractSftpClientTestSupport {
                 System.arraycopy(data, expected.length + readLen, expected, 0, expected.length);
                 assertArrayEquals("Mismatched skipped forward data contents", expected, actual);
-
-                skipped = stream.skip(0 - readLen);
-                assertEquals("Mismatched backward skip size", readLen, skipped);
-                readLen = stream.read(actual);
-                assertEquals("Failed to read fully skipped backward data", actual.length, readLen);
-                assertArrayEquals("Mismatched skipped backward data contents", expected, actual);
             }
         }
     }
diff --git a/sshd-sftp/src/test/java/org/apache/sshd/client/subsystem/sftp/SftpTransferTest.java b/sshd-sftp/src/test/java/org/apache/sshd/client/subsystem/sftp/SftpTransferTest.java
new file mode 100644
index 0000000..9bba81f
--- /dev/null
+++ b/sshd-sftp/src/test/java/org/apache/sshd/client/subsystem/sftp/SftpTransferTest.java
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sshd.client.subsystem.sftp;
+
+import java.io.BufferedWriter;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.Arrays;
+import java.util.Date;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.sshd.client.session.ClientSession;
+import org.apache.sshd.client.subsystem.sftp.fs.SftpFileSystem;
+import org.junit.FixMethodOrder;
+import org.junit.Test;
+import org.junit.runners.MethodSorters;
+
+/**
+ * Checks that a ~10 MB file survives a chain of uploads and downloads through the SFTP file system
+ * (local -&gt; remote -&gt; local -&gt; remote -&gt; local) with its content unchanged.
+ */
+@FixMethodOrder(MethodSorters.NAME_ASCENDING)
+public class SftpTransferTest extends AbstractSftpClientTestSupport {
+
+    private static final int BUFFER_SIZE = 8192;
+
+    public SftpTransferTest() throws IOException {
+        super();
+    }
+
+    @Test
+    public void testTransferIntegrity() throws IOException {
+        try (ClientSession session = createClientSession();
+             SftpFileSystem fs = SftpClientFactory.instance().createSftpFileSystem(session)) {
+
+            Path localRoot = detectTargetFolder().resolve("sftp");
+            Path remoteRoot = fs.getDefaultDir().resolve("target/sftp");
+            // do not rely on another test (or the run order) having created the working folders
+            Files.createDirectories(localRoot);
+            Files.createDirectories(remoteRoot);
+
+            Path local0 = localRoot.resolve("files-0.txt");
+            Path remote0 = remoteRoot.resolve("files-1.txt");
+            Path local1 = localRoot.resolve("files-2.txt");
+            Path remote1 = remoteRoot.resolve("files-3.txt");
+            Path local2 = localRoot.resolve("files-4.txt");
+            Files.deleteIfExists(local0);
+            Files.deleteIfExists(remote0);
+            Files.deleteIfExists(local1);
+            Files.deleteIfExists(remote1);
+            Files.deleteIfExists(local2);
+
+            String data = getClass().getName() + "#" + getCurrentTestName() + "(" + new Date() + ")\n";
+            try (BufferedWriter bos = Files.newBufferedWriter(local0)) {
+                long count = 0;
+                while (count < 1024 * 1024 * 10) { // 10 MB (counted in characters)
+                    bos.append(data);
+                    count += data.length();
+                }
+            }
+
+            // each copy crosses providers, i.e. goes through the SFTP upload / download streams
+            Files.copy(local0, remote0);
+            Files.copy(remote0, local1);
+            Files.copy(local1, remote1);
+            Files.copy(remote1, local2);
+
+            assertTrue("File integrity problem", sameContent(local0, local2));
+        }
+    }
+
+    private ClientSession createClientSession() throws IOException {
+        ClientSession session = client.connect(getCurrentTestName(), TEST_LOCALHOST, port)
+                .verify(7L, TimeUnit.SECONDS).getSession();
+        try {
+            session.addPasswordIdentity(getCurrentTestName());
+            session.auth().verify(5L, TimeUnit.SECONDS);
+            return session;
+        } catch (IOException e) {
+            session.close();
+            throw e;
+        }
+    }
+
+    /**
+     * Compares two files block by block.
+     *
+     * @return {@code true} if both files have the same length and content
+     */
+    private boolean sameContent(Path path, Path path2) throws IOException {
+        byte[] buffer1 = new byte[BUFFER_SIZE];
+        byte[] buffer2 = new byte[BUFFER_SIZE];
+        try (InputStream in1 = Files.newInputStream(path);
+             InputStream in2 = Files.newInputStream(path2)) {
+            while (true) {
+                int nRead1 = readNBytes(in1, buffer1);
+                int nRead2 = readNBytes(in2, buffer2);
+                if (nRead1 != nRead2) {
+                    return false;
+                } else if (nRead1 == BUFFER_SIZE) {
+                    if (!Arrays.equals(buffer1, buffer2)) {
+                        return false;
+                    }
+                } else {
+                    // last (partial) block - only the first nRead1 bytes are meaningful
+                    for (int i = 0; i < nRead1; i++) {
+                        if (buffer1[i] != buffer2[i]) {
+                            return false;
+                        }
+                    }
+                    return true;
+                }
+            }
+        }
+    }
+
+    /**
+     * Reads until the buffer is full or the end of stream is reached (Java 8 stand-in for
+     * {@code InputStream#readNBytes}).
+     *
+     * @return number of bytes read - less than {@code b.length} only at end of stream
+     */
+    private int readNBytes(InputStream is, byte[] b) throws IOException {
+        int n = 0;
+        int len = b.length;
+        while (n < len) {
+            int count = is.read(b, n, len - n);
+            if (count < 0) {
+                break;
+            }
+            n += count;
+        }
+        return n;
+    }
+
+}