This is an automated email from the ASF dual-hosted git repository. lgoldstein pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/mina-sshd.git
The following commit(s) were added to refs/heads/master by this push: new 1498b76 [SSHD-1066] Allow multiple binding to local port tunnel on different addresses 1498b76 is described below commit 1498b762d95c8118f903f89f92dbb79e9617af68 Author: Lyor Goldstein <lgoldst...@apache.org> AuthorDate: Fri Oct 23 18:22:59 2020 +0300 [SSHD-1066] Allow multiple binding to local port tunnel on different addresses --- CHANGES.md | 1 + .../sshd/common/util/net/SshdSocketAddress.java | 95 ++++++++++- .../sshd/client/session/AbstractClientSession.java | 2 +- .../apache/sshd/client/session/ClientSession.java | 16 ++ .../sshd/common/forward/DefaultForwarder.java | 139 ++++++++++------- .../sshd/common/forward/LocalForwardingEntry.java | 173 ++++++++++++++++----- .../forward/PortForwardingInformationProvider.java | 19 +-- .../sshd/common/forward/PortForwardingManager.java | 12 ++ .../sshd/common/forward/TcpipClientChannel.java | 2 +- .../sshd/common/session/helpers/SessionHelper.java | 10 +- ...calForwardingEntryCombinedBoundAddressTest.java | 121 ++++++++++++++ .../common/forward/LocalForwardingEntryTest.java | 97 +++++++++++- .../sshd/common/forward/PortForwardingTest.java | 98 +++++++++++- .../org/apache/sshd/util/test/BaseTestSupport.java | 24 +++ .../sftp/client/AbstractSftpClientTestSupport.java | 20 +-- 15 files changed, 680 insertions(+), 149 deletions(-) diff --git a/CHANGES.md b/CHANGES.md index 001fc9e..e47f90d 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -57,5 +57,6 @@ or `-key-file` command line option. * [SSHD-1058](https://issues.apache.org/jira/browse/SSHD-1058) Improve exception logging strategy. 
* [SSHD-1059](https://issues.apache.org/jira/browse/SSHD-1059) Do not send heartbeat if KEX state not DONE * [SSHD-1063](https://issues.apache.org/jira/browse/SSHD-1063) Fixed known-hosts file server key verifier matching of same host with different ports +* [SSHD-1066](https://issues.apache.org/jira/browse/SSHD-1066) Allow multiple binding to local port tunnel on different addresses * [SSHD-1070](https://issues.apache.org/jira/browse/SSHD-1070) OutOfMemoryError when use async port forwarding diff --git a/sshd-common/src/main/java/org/apache/sshd/common/util/net/SshdSocketAddress.java b/sshd-common/src/main/java/org/apache/sshd/common/util/net/SshdSocketAddress.java index e19f4b7..701e761 100644 --- a/sshd-common/src/main/java/org/apache/sshd/common/util/net/SshdSocketAddress.java +++ b/sshd-common/src/main/java/org/apache/sshd/common/util/net/SshdSocketAddress.java @@ -31,6 +31,7 @@ import java.util.Comparator; import java.util.Enumeration; import java.util.LinkedHashSet; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Set; @@ -191,7 +192,7 @@ public class SshdSocketAddress extends SocketAddress { return true; } else { return (this.getPort() == that.getPort()) - && (GenericUtils.safeCompare(this.getHostName(), that.getHostName(), false) == 0); + && isEquivalentHostName(this.getHostName(), that.getHostName(), false); } } @@ -208,12 +209,12 @@ public class SshdSocketAddress extends SocketAddress { @Override public int hashCode() { - return GenericUtils.hashCode(getHostName(), Boolean.FALSE) + getPort(); + return GenericUtils.hashCode(getHostName(), Boolean.FALSE) + 31 * Integer.hashCode(getPort()); } /** * Returns the first external network address assigned to this machine or null if one is not found. - * + * * @return Inet4Address associated with an external interface DevNote: We actually return InetAddress here, as * Inet4Addresses are final and cannot be mocked. 
*/ @@ -292,7 +293,6 @@ public class SshdSocketAddress extends SocketAddress { } return !isLoopback(addr); - } /** @@ -325,11 +325,22 @@ public class SshdSocketAddress extends SocketAddress { return false; } - if (LOCALHOST_NAME.equals(ip) || LOCALHOST_IPV4.equals(ip)) { + if (LOCALHOST_NAME.equals(ip)) { return true; } - // TODO add support for IPv6 - see SSHD-746 + return isIPv4LoopbackAddress(ip) || isIPv6LoopbackAddress(ip); + } + + public static boolean isIPv4LoopbackAddress(String ip) { + if (GenericUtils.isEmpty(ip)) { + return false; + } + + if (LOCALHOST_IPV4.equals(ip)) { + return true; // most used + } + String[] values = GenericUtils.split(ip, '.'); if (GenericUtils.length(values) != 4) { return false; @@ -352,6 +363,34 @@ public class SshdSocketAddress extends SocketAddress { return true; } + public static boolean isIPv6LoopbackAddress(String ip) { + // TODO add more patterns - e.g., https://tools.ietf.org/id/draft-smith-v6ops-larger-ipv6-loopback-prefix-04.html + return IPV6_LONG_LOCALHOST.equals(ip) || IPV6_SHORT_LOCALHOST.equals(ip); + } + + public static boolean isEquivalentHostName(String h1, String h2, boolean allowWildcard) { + if (GenericUtils.safeCompare(h1, h2, false) == 0) { + return true; + } + + if (allowWildcard) { + return isWildcardAddress(h1) || isWildcardAddress(h2); + } + + return false; + } + + public static boolean isLoopbackAlias(String h1, String h2) { + return (LOCALHOST_NAME.equals(h1) && isLoopback(h2)) + || (LOCALHOST_NAME.equals(h2) && isLoopback(h1)); + } + + public static boolean isWildcardAddress(String addr) { + return IPV4_ANYADDR.equalsIgnoreCase(addr) + || IPV6_LONG_ANY_ADDRESS.equalsIgnoreCase(addr) + || IPV6_SHORT_ANY_ADDRESS.equalsIgnoreCase(addr); + } + public static SshdSocketAddress toSshdSocketAddress(SocketAddress addr) { if (addr == null) { return null; @@ -457,7 +496,7 @@ public class SshdSocketAddress extends SocketAddress { /** * Checks if the address is one of the allocated private blocks - * + * * @param 
addr The address string * @return {@code true} if this is one of the allocated private blocks. <B>Note:</B> it assumes that the * address string is indeed an IPv4 address @@ -533,7 +572,7 @@ public class SshdSocketAddress extends SocketAddress { * <LI>Has at most 3 <U>digits</U></LI> * <LI>Its value is ≤ 255</LI> * </UL> - * + * * @param c The {@link CharSequence} to be validate * @return {@code true} if valid IPv4 address component */ @@ -652,4 +691,44 @@ public class SshdSocketAddress extends SocketAddress { } return true; } + + public static <V> V findByOptionalWildcardAddress(Map<SshdSocketAddress, ? extends V> map, SshdSocketAddress address) { + Map.Entry<SshdSocketAddress, ? extends V> entry = findMatchingOptionalWildcardEntry(map, address); + return (entry == null) ? null : entry.getValue(); + } + + public static <V> V removeByOptionalWildcardAddress(Map<SshdSocketAddress, ? extends V> map, SshdSocketAddress address) { + Map.Entry<SshdSocketAddress, ? extends V> entry = findMatchingOptionalWildcardEntry(map, address); + return (entry == null) ? null : map.remove(entry.getKey()); + } + + public static <V> Map.Entry<SshdSocketAddress, ? extends V> findMatchingOptionalWildcardEntry( + Map<SshdSocketAddress, ? extends V> map, SshdSocketAddress address) { + if (GenericUtils.isEmpty(map) || (address == null)) { + return null; + } + + String hostName = address.getHostName(); + Map.Entry<SshdSocketAddress, ? extends V> candidate = null; + for (Map.Entry<SshdSocketAddress, ? 
extends V> e : map.entrySet()) { + SshdSocketAddress a = e.getKey(); + if (a.getPort() != address.getPort()) { + continue; + } + + String candidateName = a.getHostName(); + if (hostName.equalsIgnoreCase(candidateName)) { + return e; // If found exact match then use it + } + + if (isEquivalentHostName(hostName, candidateName, true)) { + if (candidate != null) { + throw new IllegalStateException("Multiple candidate matches for " + address + ": " + candidate + ", " + e); + } + candidate = e; + } + } + + return candidate; + } } diff --git a/sshd-core/src/main/java/org/apache/sshd/client/session/AbstractClientSession.java b/sshd-core/src/main/java/org/apache/sshd/client/session/AbstractClientSession.java index a1748f4..c7c8d5e 100644 --- a/sshd-core/src/main/java/org/apache/sshd/client/session/AbstractClientSession.java +++ b/sshd-core/src/main/java/org/apache/sshd/client/session/AbstractClientSession.java @@ -328,7 +328,7 @@ public abstract class AbstractClientSession extends AbstractSession implements C } else if (Channel.CHANNEL_SUBSYSTEM.equals(type)) { return createSubsystemChannel(subType); } else { - throw new IllegalArgumentException("Unsupported channel type " + type); + throw new IllegalArgumentException("Unsupported channel type requested: " + type); } } diff --git a/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSession.java b/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSession.java index 075d026..2de7695 100644 --- a/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSession.java +++ b/sshd-core/src/main/java/org/apache/sshd/client/session/ClientSession.java @@ -312,6 +312,22 @@ public interface ClientSession * Starts a local port forwarding and returns a tracker that stops the forwarding when the {@code close()} method is * called. This tracker can be used in a {@code try-with-resource} block to ensure cleanup of the set up forwarding. 
* + * @param localPort The local port - if zero one is allocated + * @param remote The remote address + * @return The tracker instance + * @throws IOException If failed to set up the requested forwarding + * @see #startLocalPortForwarding(SshdSocketAddress, SshdSocketAddress) + */ + default ExplicitPortForwardingTracker createLocalPortForwardingTracker( + int localPort, SshdSocketAddress remote) + throws IOException { + return createLocalPortForwardingTracker(new SshdSocketAddress(localPort), remote); + } + + /** + * Starts a local port forwarding and returns a tracker that stops the forwarding when the {@code close()} method is + * called. This tracker can be used in a {@code try-with-resource} block to ensure cleanup of the set up forwarding. + * * @param local The local address * @param remote The remote address * @return The tracker instance diff --git a/sshd-core/src/main/java/org/apache/sshd/common/forward/DefaultForwarder.java b/sshd-core/src/main/java/org/apache/sshd/common/forward/DefaultForwarder.java index 26a14c7..c1d24f1 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/forward/DefaultForwarder.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/forward/DefaultForwarder.java @@ -19,6 +19,7 @@ package org.apache.sshd.common.forward; import java.io.IOException; +import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.SocketAddress; import java.time.Duration; @@ -26,15 +27,14 @@ import java.util.AbstractMap.SimpleImmutableEntry; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.Comparator; import java.util.EnumSet; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.NavigableSet; import java.util.Objects; import java.util.Set; -import java.util.TreeMap; import java.util.concurrent.CopyOnWriteArraySet; import java.util.concurrent.atomic.AtomicLong; import java.util.function.Consumer; @@ -86,13 +86,13 
@@ public class DefaultForwarder private final Session sessionInstance; private final Object localLock = new Object(); - private final Map<Integer, SshdSocketAddress> localToRemote = new TreeMap<>(Comparator.naturalOrder()); - private final Map<Integer, InetSocketAddress> boundLocals = new TreeMap<>(Comparator.naturalOrder()); + private final Map<SshdSocketAddress, SshdSocketAddress> localToRemote = new HashMap<>(); + private final Map<SshdSocketAddress, InetSocketAddress> boundLocals = new HashMap<>(); private final Object dynamicLock = new Object(); - private final Map<Integer, SshdSocketAddress> remoteToLocal = new TreeMap<>(Comparator.naturalOrder()); - private final Map<Integer, SocksProxy> dynamicLocal = new TreeMap<>(Comparator.naturalOrder()); - private final Map<Integer, InetSocketAddress> boundDynamic = new TreeMap<>(Comparator.naturalOrder()); + private final Map<Integer, SshdSocketAddress> remoteToLocal = new HashMap<>(); + private final Map<Integer, SocksProxy> dynamicLocal = new HashMap<>(); + private final Map<Integer, InetSocketAddress> boundDynamic = new HashMap<>(); private final Set<LocalForwardingEntry> localForwards = new HashSet<>(); private final IoHandlerFactory staticIoHandlerFactory = StaticIoHandler::new; @@ -186,29 +186,32 @@ public class DefaultForwarder throw new IllegalStateException("TcpipForwarder is closed or closing: " + state); } - InetSocketAddress bound = null; - int port; signalEstablishingExplicitTunnel(local, remote, true); + + InetSocketAddress bound = null; + SshdSocketAddress result; try { bound = doBind(local, getLocalIoAcceptor()); - port = bound.getPort(); + int port = bound.getPort(); + result = new SshdSocketAddress(bound.getHostString(), port); + synchronized (localLock) { - SshdSocketAddress prevRemote = localToRemote.get(port); + SshdSocketAddress prevRemote = SshdSocketAddress.findByOptionalWildcardAddress(localToRemote, result); if (prevRemote != null) { throw new IOException( - "Multiple local port forwarding 
addressing on port=" + port + "Multiple local port forwarding addressing on port=" + result + ": current=" + remote + ", previous=" + prevRemote); } - InetSocketAddress prevBound = boundLocals.get(port); + InetSocketAddress prevBound = SshdSocketAddress.findByOptionalWildcardAddress(boundLocals, result); if (prevBound != null) { throw new IOException( - "Multiple local port forwarding bindings on port=" + port + "Multiple local port forwarding bindings on port=" + result + ": current=" + bound + ", previous=" + prevBound); } - localToRemote.put(port, remote); - boundLocals.put(port, bound); + localToRemote.put(result, remote); + boundLocals.put(result, bound); } } catch (IOException | RuntimeException e) { try { @@ -221,7 +224,6 @@ public class DefaultForwarder } try { - SshdSocketAddress result = new SshdSocketAddress(bound.getHostString(), port); if (log.isDebugEnabled()) { log.debug("startLocalPortForwarding(" + local + " -> " + remote + "): " + result); } @@ -239,10 +241,9 @@ public class DefaultForwarder SshdSocketAddress remote; InetSocketAddress bound; - int port = local.getPort(); synchronized (localLock) { - remote = localToRemote.remove(port); - bound = boundLocals.remove(port); + remote = SshdSocketAddress.removeByOptionalWildcardAddress(localToRemote, local); + bound = SshdSocketAddress.removeByOptionalWildcardAddress(boundLocals, local); } unbindLocalForwarding(local, remote, bound); @@ -723,29 +724,29 @@ public class DefaultForwarder } signalEstablishingExplicitTunnel(local, null, true); + SshdSocketAddress result; try { InetSocketAddress bound = doBind(local, getLocalIoAcceptor()); - result = new SshdSocketAddress(bound.getHostString(), bound.getPort()); + result = new SshdSocketAddress(bound); if (log.isDebugEnabled()) { log.debug("localPortForwardingRequested(" + local + "): " + result); } boolean added; + LocalForwardingEntry localEntry = new LocalForwardingEntry(local, result); synchronized (localForwards) { - // NOTE !!! 
it is crucial to use the bound address host name first - added = localForwards - .add(new LocalForwardingEntry(result.getHostName(), local.getHostName(), result.getPort())); + added = localForwards.add(localEntry); } if (!added) { throw new IOException("Failed to add local port forwarding entry for " + local + " -> " + result); } - } catch (IOException | RuntimeException e) { + } catch (IOException | RuntimeException | Error e) { try { localPortForwardingCancelled(local); - } catch (IOException | RuntimeException err) { - e.addSuppressed(e); + } catch (IOException | RuntimeException | Error err) { + e.addSuppressed(err); } signalEstablishedExplicitTunnel(local, null, true, null, e); throw e; @@ -763,7 +764,8 @@ public class DefaultForwarder public synchronized void localPortForwardingCancelled(SshdSocketAddress local) throws IOException { LocalForwardingEntry entry; synchronized (localForwards) { - entry = LocalForwardingEntry.findMatchingEntry(local.getHostName(), local.getPort(), localForwards); + entry = LocalForwardingEntry.findMatchingEntry( + local.getHostName(), local.getPort(), localForwards); if (entry != null) { localForwards.remove(entry); } @@ -774,15 +776,18 @@ public class DefaultForwarder log.debug("localPortForwardingCancelled(" + local + ") unbind " + entry); } - signalTearingDownExplicitTunnel(entry, true, null); + SshdSocketAddress reportedBoundAddress = entry.getCombinedBoundAddress(); + signalTearingDownExplicitTunnel(reportedBoundAddress, true, null); + + SshdSocketAddress boundAddress = entry.getBoundAddress(); try { - localAcceptor.unbind(entry.toInetSocketAddress()); - } catch (RuntimeException e) { - signalTornDownExplicitTunnel(entry, true, null, e); + localAcceptor.unbind(boundAddress.toInetSocketAddress()); + } catch (RuntimeException | Error e) { + signalTornDownExplicitTunnel(reportedBoundAddress, true, null, e); throw e; } - signalTornDownExplicitTunnel(entry, true, null, null); + signalTornDownExplicitTunnel(reportedBoundAddress, 
true, null, null); } else { if (log.isDebugEnabled()) { log.debug("localPortForwardingCancelled(" + local + ") no match/acceptor: " + entry); @@ -993,12 +998,12 @@ public class DefaultForwarder protected InetSocketAddress doBind(SshdSocketAddress address, IoAcceptor acceptor) throws IOException { // TODO find a better way to determine the resulting bind address - what if multi-threaded calls... - Set<SocketAddress> before = acceptor.getBoundAddresses(); + Collection<SocketAddress> before = acceptor.getBoundAddresses(); try { InetSocketAddress bindAddress = address.toInetSocketAddress(); acceptor.bind(bindAddress); - Set<SocketAddress> after = acceptor.getBoundAddresses(); + Collection<SocketAddress> after = acceptor.getBoundAddresses(); if (GenericUtils.size(after) > 0) { after.removeAll(before); } @@ -1009,7 +1014,9 @@ public class DefaultForwarder if (after.size() > 1) { throw new IOException("Multiple local addresses have been bound for " + address + "[" + bindAddress + "]"); } - return (InetSocketAddress) GenericUtils.head(after); + + InetSocketAddress boundAddress = (InetSocketAddress) GenericUtils.head(after); + return boundAddress; } catch (IOException bindErr) { Collection<SocketAddress> after = acceptor.getBoundAddresses(); if (GenericUtils.isEmpty(after)) { @@ -1034,9 +1041,13 @@ public class DefaultForwarder @Override public void sessionCreated(IoSession session) throws Exception { - InetSocketAddress local = (InetSocketAddress) session.getLocalAddress(); - int localPort = local.getPort(); - SshdSocketAddress remote = localToRemote.get(localPort); + InetSocketAddress localAddress = (InetSocketAddress) session.getLocalAddress(); + SshdSocketAddress local = new SshdSocketAddress(localAddress); + SshdSocketAddress remote; + synchronized (localLock) { + remote = SshdSocketAddress.findByOptionalWildcardAddress(localToRemote, local); + } + TcpipClientChannel.Type channelType = (remote == null) ? 
TcpipClientChannel.Type.Forwarded : TcpipClientChannel.Type.Direct; @@ -1048,9 +1059,12 @@ public class DefaultForwarder SocketAddress accepted = session.getAcceptanceAddress(); LocalForwardingEntry localEntry = null; if (accepted instanceof InetSocketAddress) { + InetSocketAddress inetSocketAddress = (InetSocketAddress) accepted; + InetAddress inetAddress = inetSocketAddress.getAddress(); synchronized (localForwards) { localEntry = LocalForwardingEntry.findMatchingEntry( - ((InetSocketAddress) accepted).getHostString(), localPort, localForwards); + inetSocketAddress.getHostString(), inetAddress.isAnyLocalAddress(), local.getPort(), + localForwards); } } @@ -1162,18 +1176,33 @@ public class DefaultForwarder } @Override - public SshdSocketAddress getBoundLocalPortForward(int port) { - ValidateUtils.checkTrue(port > 0, "Invalid local port: %d", port); + public List<SshdSocketAddress> getBoundLocalPortForwards(int port) { + synchronized (localLock) { + return localToRemote.isEmpty() + ? Collections.emptyList() + : localToRemote.keySet() + .stream() + .filter(k -> k.getPort() == port) + .collect(Collectors.toList()); + } + } - Integer portKey = Integer.valueOf(port); - synchronized (localToRemote) { - return localToRemote.get(portKey); + @Override + public boolean isLocalPortForwardingStartedForPort(int port) { + synchronized (localLock) { + return localToRemote.isEmpty() + ? false + : localToRemote.keySet() + .stream() + .filter(e -> e.getPort() == port) + .findAny() + .isPresent(); } } @Override - public List<Map.Entry<Integer, SshdSocketAddress>> getLocalForwardsBindings() { - synchronized (localToRemote) { + public List<Map.Entry<SshdSocketAddress, SshdSocketAddress>> getLocalForwardsBindings() { + synchronized (localLock) { return localToRemote.isEmpty() ? 
Collections.emptyList() : localToRemote.entrySet() @@ -1184,13 +1213,9 @@ public class DefaultForwarder } @Override - public NavigableSet<Integer> getStartedLocalPortForwards() { - synchronized (localToRemote) { - if (localToRemote.isEmpty()) { - return Collections.emptyNavigableSet(); - } - - return GenericUtils.asSortedSet(localToRemote.keySet()); + public List<SshdSocketAddress> getStartedLocalPortForwards() { + synchronized (localLock) { + return localToRemote.isEmpty() ? Collections.emptyList() : new ArrayList<>(localToRemote.keySet()); } } @@ -1219,11 +1244,7 @@ public class DefaultForwarder @Override public NavigableSet<Integer> getStartedRemotePortForwards() { synchronized (remoteToLocal) { - if (remoteToLocal.isEmpty()) { - return Collections.emptyNavigableSet(); - } - - return GenericUtils.asSortedSet(remoteToLocal.keySet()); + return remoteToLocal.isEmpty() ? Collections.emptyNavigableSet() : GenericUtils.asSortedSet(remoteToLocal.keySet()); } } } diff --git a/sshd-core/src/main/java/org/apache/sshd/common/forward/LocalForwardingEntry.java b/sshd-core/src/main/java/org/apache/sshd/common/forward/LocalForwardingEntry.java index 13ed473..033d86a 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/forward/LocalForwardingEntry.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/forward/LocalForwardingEntry.java @@ -21,82 +21,185 @@ package org.apache.sshd.common.forward; import java.net.InetSocketAddress; import java.util.Collection; +import java.util.Objects; import org.apache.sshd.common.util.GenericUtils; -import org.apache.sshd.common.util.ValidateUtils; import org.apache.sshd.common.util.net.SshdSocketAddress; /** * @author <a href="mailto:d...@mina.apache.org">Apache MINA SSHD Project</a> */ -public class LocalForwardingEntry extends SshdSocketAddress { - private static final long serialVersionUID = 423661570180889621L; - private final String alias; +public class LocalForwardingEntry { + private final SshdSocketAddress local; + private 
final SshdSocketAddress bound; + private final SshdSocketAddress combined; - // NOTE !!! it is crucial to use the bound address host name first public LocalForwardingEntry(SshdSocketAddress local, InetSocketAddress bound) { - this(local, new SshdSocketAddress(bound.getHostString(), bound.getPort())); + this(local, new SshdSocketAddress(bound)); } - // NOTE !!! it is crucial to use the bound address host name first public LocalForwardingEntry(SshdSocketAddress local, SshdSocketAddress bound) { - this(bound.getHostName(), local.getHostName(), bound.getPort()); + this.local = Objects.requireNonNull(local, "No local address provided"); + this.bound = Objects.requireNonNull(bound, "No bound address provided"); + this.combined = resolveCombinedBoundAddress(local, bound); } - public LocalForwardingEntry(String hostName, String alias, int port) { - super(hostName, port); - this.alias = ValidateUtils.checkNotNullAndNotEmpty(alias, "No host alias"); + /** + * @return The original requested local address for binding + */ + public SshdSocketAddress getLocalAddress() { + return local; } - public String getAlias() { - return alias; + /** + * @return The actual bound address + */ + public SshdSocketAddress getBoundAddress() { + return bound; } - @Override - protected boolean isEquivalent(SshdSocketAddress that) { - if (super.isEquivalent(that) && (that instanceof LocalForwardingEntry)) { - LocalForwardingEntry entry = (LocalForwardingEntry) that; - if (GenericUtils.safeCompare(this.getAlias(), entry.getAlias(), false) == 0) { - return true; - } - } - - return false; + /** + * A combined address using the following logic: + * <UL> + * <LI>If original requested local binding has a specific port and non-wildcard address then use the local binding + * as-is</LI> + * + * <LI>If original requested local binding has a specific address but no specific port, then combine its address + * with the actual auto-allocated port at binding.</LI> + * + * <LI>If original requested local binding 
has neither a specific address nor a specific port then use the effective + * bound address.</LI> + * <UL> + * + * @return Combined result + */ + public SshdSocketAddress getCombinedBoundAddress() { + return combined; } @Override public boolean equals(Object o) { - return super.equals(o); + if (o == null) { + return false; + } + if (o == this) { + return true; + } + if (getClass() != o.getClass()) { + return false; + } + + LocalForwardingEntry other = (LocalForwardingEntry) o; + return Objects.equals(getCombinedBoundAddress(), other.getCombinedBoundAddress()); } @Override public int hashCode() { - return super.hashCode() + GenericUtils.hashCode(getAlias(), Boolean.FALSE); + return Objects.hashCode(getCombinedBoundAddress()); } @Override public String toString() { - return super.toString() + " - " + getAlias(); + return getClass().getSimpleName() + + "[local=" + getLocalAddress() + + ", bound=" + getBoundAddress() + + ", combined=" + getCombinedBoundAddress() + "]"; + } + + public static SshdSocketAddress resolveCombinedBoundAddress(SshdSocketAddress local, SshdSocketAddress bound) { + int localPort = local.getPort(); + int boundPort = bound.getPort(); + if ((localPort > 0) && (localPort != boundPort)) { + throw new IllegalArgumentException("Mismatched ports for local (" + local + ") vs. bound (" + bound + ") entry"); + } + + if (Objects.equals(local, bound)) { + return local; + } + + String localName = local.getHostName(); + boolean wildcardLocal = SshdSocketAddress.isWildcardAddress(localName); + if (wildcardLocal) { + return bound; + } + + if (localPort > 0) { + return local; // have a specific local address + } + + // Missing the port from local address + return new SshdSocketAddress(localName, boundPort); + } + + public static LocalForwardingEntry findMatchingEntry( + String host, int port, Collection<? 
extends LocalForwardingEntry> entries) { + return findMatchingEntry(host, SshdSocketAddress.isWildcardAddress(host), port, entries); } /** - * @param host The host - ignored if {@code null}/empty - i.e., no match reported - * @param port The port - ignored if non-positive - i.e., no match reported - * @param entries The {@link Collection} of {@link LocalForwardingEntry} to check - ignored if {@code null}/empty - - * i.e., no match reported - * @return The <U>first</U> entry whose host or alias matches the host name - case <U>insensitive</U> - * <B>and</B> has a matching port - {@code null} if no match found + * @param host The host - ignored if {@code null}/empty and not wildcard address match - i.e., no match + * reported + * @param anyLocalAddress Is host the wildcard address - in which case, we try an exact match first for the host, + * and if that fails then only the port is matched + * @param port The port - ignored if non-positive - i.e., no match reported + * @param entries The {@link Collection} of {@link LocalForwardingEntry} to check - ignored if + * {@code null}/empty - i.e., no match reported + * @return The <U>first</U> entry whose local or bound address matches the host name - case + * <U>insensitive</U> <B>and</B> has a matching bound port - {@code null} if no match found */ public static LocalForwardingEntry findMatchingEntry( - String host, int port, Collection<? extends LocalForwardingEntry> entries) { - if (GenericUtils.isEmpty(host) || (port <= 0) || (GenericUtils.isEmpty(entries))) { + String host, boolean anyLocalAddress, int port, Collection<? 
extends LocalForwardingEntry> entries) { + if ((port <= 0) || (GenericUtils.isEmpty(entries))) { return null; } + if (GenericUtils.isEmpty(host) && (!anyLocalAddress)) { + return null; + } + + LocalForwardingEntry candidate = null; for (LocalForwardingEntry e : entries) { - if ((port == e.getPort()) && (host.equalsIgnoreCase(e.getHostName()) || host.equalsIgnoreCase(e.getAlias()))) { + SshdSocketAddress bound = e.getBoundAddress(); + /* + * Note we don't check the local port since it could be zero. + * If it isn't then it must be equal to the bound port (enforced in constructor) + */ + if (port != bound.getPort()) { + continue; + } + + /* + * We first try an exact match - if not found, declare this + * a candidate and return it if host is any local address + */ + + String boundName = bound.getHostName(); + if (SshdSocketAddress.isEquivalentHostName(host, boundName, false)) { + return e; + } + + SshdSocketAddress local = e.getLocalAddress(); + String localName = local.getHostName(); + if (SshdSocketAddress.isEquivalentHostName(host, localName, false)) { return e; } + + if (SshdSocketAddress.isLoopbackAlias(host, boundName) + || SshdSocketAddress.isLoopbackAlias(host, localName)) { + return e; + } + + if (anyLocalAddress) { + if (candidate != null) { + throw new IllegalStateException( + "Multiple candidate matches for " + host + "@" + port + ": " + candidate + ", " + e); + } + candidate = e; + } + } + + if (anyLocalAddress) { + return candidate; } return null; // no match found diff --git a/sshd-core/src/main/java/org/apache/sshd/common/forward/PortForwardingInformationProvider.java b/sshd-core/src/main/java/org/apache/sshd/common/forward/PortForwardingInformationProvider.java index 0d39182..c9c8217 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/forward/PortForwardingInformationProvider.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/forward/PortForwardingInformationProvider.java @@ -23,6 +23,7 @@ import java.util.List; import java.util.Map; 
import java.util.NavigableSet; +import org.apache.sshd.common.util.GenericUtils; import org.apache.sshd.common.util.net.SshdSocketAddress; /** @@ -30,33 +31,33 @@ import org.apache.sshd.common.util.net.SshdSocketAddress; */ public interface PortForwardingInformationProvider { /** - * @return A {@link NavigableSet} <u>snapshot</u> of the currently started local port forwards + * @return A {@link List} <u>snapshot</u> of the currently started local port forward bindings */ - NavigableSet<Integer> getStartedLocalPortForwards(); + List<SshdSocketAddress> getStartedLocalPortForwards(); /** * @param port The port number - * @return The local bound {@link SshdSocketAddress} for the port - {@code null} if none bound + * @return The local bound {@link SshdSocketAddress}-es for the port * @see #isLocalPortForwardingStartedForPort(int) isLocalPortForwardingStartedForPort * @see #getStartedLocalPortForwards() */ - SshdSocketAddress getBoundLocalPortForward(int port); + List<SshdSocketAddress> getBoundLocalPortForwards(int port); /** - * @return A <u>snapshot</u> of the currently bound forwarded local ports as "pairs" of port + bound - * {@link SshdSocketAddress} + * @return A <u>snapshot</u> of the currently bound forwarded local ports as "pairs" of local/remote + * {@link SshdSocketAddress}-es */ - List<Map.Entry<Integer, SshdSocketAddress>> getLocalForwardsBindings(); + List<Map.Entry<SshdSocketAddress, SshdSocketAddress>> getLocalForwardsBindings(); /** * Test if local port forwarding is started * * @param port The local port * @return {@code true} if local port forwarding is started - * @see #getBoundLocalPortForward(int) getBoundLocalPortForward + * @see #getBoundLocalPortForwards(int) getBoundLocalPortForwards */ default boolean isLocalPortForwardingStartedForPort(int port) { - return getBoundLocalPortForward(port) != null; + return GenericUtils.isNotEmpty(getBoundLocalPortForwards(port)); } /** diff --git 
a/sshd-core/src/main/java/org/apache/sshd/common/forward/PortForwardingManager.java b/sshd-core/src/main/java/org/apache/sshd/common/forward/PortForwardingManager.java index 8cc647d..e71eabd 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/forward/PortForwardingManager.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/forward/PortForwardingManager.java @@ -28,6 +28,18 @@ import org.apache.sshd.common.util.net.SshdSocketAddress; */ public interface PortForwardingManager extends PortForwardingInformationProvider { /** + * Start forwarding the given local port on the client to the given address on the server. + * + * @param localPort The local port - if zero then one will be allocated + * @param remote The remote address + * @return The bound {@link SshdSocketAddress} + * @throws IOException If failed to create the requested binding + */ + default SshdSocketAddress startLocalPortForwarding(int localPort, SshdSocketAddress remote) throws IOException { + return startLocalPortForwarding(new SshdSocketAddress(localPort), remote); + } + + /** * Start forwarding the given local address on the client to the given address on the server. 
* * @param local The local address diff --git a/sshd-core/src/main/java/org/apache/sshd/common/forward/TcpipClientChannel.java b/sshd-core/src/main/java/org/apache/sshd/common/forward/TcpipClientChannel.java index 2282b9a..853581b 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/forward/TcpipClientChannel.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/forward/TcpipClientChannel.java @@ -105,7 +105,7 @@ public class TcpipClientChannel extends AbstractClientChannel implements Forward public void updateLocalForwardingEntry(LocalForwardingEntry entry) { Objects.requireNonNull(entry, "No local forwarding entry provided"); - localEntry = new SshdSocketAddress(entry.getAlias(), entry.getPort()); + localEntry = entry.getBoundAddress(); } @Override diff --git a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/SessionHelper.java b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/SessionHelper.java index 6883451..c62719a 100644 --- a/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/SessionHelper.java +++ b/sshd-core/src/main/java/org/apache/sshd/common/session/helpers/SessionHelper.java @@ -1199,7 +1199,7 @@ public abstract class SessionHelper extends AbstractKexFactoryManager implements } @Override - public List<Map.Entry<Integer, SshdSocketAddress>> getLocalForwardsBindings() { + public List<Map.Entry<SshdSocketAddress, SshdSocketAddress>> getLocalForwardsBindings() { Forwarder forwarder = getForwarder(); return (forwarder == null) ? Collections.emptyList() : forwarder.getLocalForwardsBindings(); } @@ -1211,15 +1211,15 @@ public abstract class SessionHelper extends AbstractKexFactoryManager implements } @Override - public NavigableSet<Integer> getStartedLocalPortForwards() { + public List<SshdSocketAddress> getStartedLocalPortForwards() { Forwarder forwarder = getForwarder(); - return (forwarder == null) ? Collections.emptyNavigableSet() : forwarder.getStartedLocalPortForwards(); + return (forwarder == null) ? 
Collections.emptyList() : forwarder.getStartedLocalPortForwards(); } @Override - public SshdSocketAddress getBoundLocalPortForward(int port) { + public List<SshdSocketAddress> getBoundLocalPortForwards(int port) { Forwarder forwarder = getForwarder(); - return (forwarder == null) ? null : forwarder.getBoundLocalPortForward(port); + return (forwarder == null) ? Collections.emptyList() : forwarder.getBoundLocalPortForwards(port); } @Override diff --git a/sshd-core/src/test/java/org/apache/sshd/common/forward/LocalForwardingEntryCombinedBoundAddressTest.java b/sshd-core/src/test/java/org/apache/sshd/common/forward/LocalForwardingEntryCombinedBoundAddressTest.java new file mode 100644 index 0000000..cc5e431 --- /dev/null +++ b/sshd-core/src/test/java/org/apache/sshd/common/forward/LocalForwardingEntryCombinedBoundAddressTest.java @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sshd.common.forward; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; + +import org.apache.sshd.common.util.net.SshdSocketAddress; +import org.apache.sshd.util.test.JUnit4ClassRunnerWithParametersFactory; +import org.apache.sshd.util.test.JUnitTestSupport; +import org.apache.sshd.util.test.NoIoTestCase; +import org.junit.FixMethodOrder; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.MethodSorters; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; +import org.junit.runners.Parameterized.UseParametersRunnerFactory; + +/** + * @author <a href="mailto:d...@mina.apache.org">Apache MINA SSHD Project</a> + */ +@FixMethodOrder(MethodSorters.NAME_ASCENDING) +@RunWith(Parameterized.class) // see https://github.com/junit-team/junit/wiki/Parameterized-tests +@UseParametersRunnerFactory(JUnit4ClassRunnerWithParametersFactory.class) +@Category({ NoIoTestCase.class }) +public class LocalForwardingEntryCombinedBoundAddressTest extends JUnitTestSupport { + private final LocalForwardingEntry entry; + private final SshdSocketAddress expected; + + public LocalForwardingEntryCombinedBoundAddressTest( + SshdSocketAddress local, SshdSocketAddress bound, + SshdSocketAddress expected) { + this.entry = new LocalForwardingEntry(local, bound); + this.expected = expected; + } + + @Parameters(name = "local={0}, bound={1}, expected={2}") + public static List<Object[]> parameters() { + return new ArrayList<Object[]>() { + // Not serializing it + private static final long serialVersionUID = 1L; + + { + SshdSocketAddress bound = new SshdSocketAddress("10.10.10.10", 7365); + addTestCase(bound, bound, bound); + + SshdSocketAddress specificLocal = new SshdSocketAddress("specificLocal", bound.getPort()); + addTestCase(specificLocal, bound, specificLocal); + + SshdSocketAddress noLocalPort = new 
SshdSocketAddress(specificLocal.getHostName(), 0); + addTestCase(noLocalPort, bound, new SshdSocketAddress(specificLocal.getHostName(), bound.getPort())); + + for (String address : new String[] { + "", SshdSocketAddress.IPV4_ANYADDR, + SshdSocketAddress.IPV6_LONG_ANY_ADDRESS, + SshdSocketAddress.IPV6_SHORT_ANY_ADDRESS + }) { + SshdSocketAddress wildcard = new SshdSocketAddress(address, bound.getPort()); + addTestCase(wildcard, bound, bound); + } + } + + private void addTestCase( + SshdSocketAddress local, SshdSocketAddress bound, SshdSocketAddress expected) { + add(new Object[] { local, bound, expected }); + } + }; + } + + @Test + public void testResolvedValue() { + assertEquals(expected, entry.getCombinedBoundAddress()); + } + + @Test + public void testHashCode() { + assertEquals(expected.hashCode(), entry.hashCode()); + } + + @Test + public void testSameInstanceReuse() { + SshdSocketAddress combined = entry.getCombinedBoundAddress(); + SshdSocketAddress local = entry.getLocalAddress(); + SshdSocketAddress bound = entry.getBoundAddress(); + boolean eqLocal = Objects.equals(combined, local); + boolean eqBound = Objects.equals(combined, bound); + if (eqLocal) { + assertSame("Not same local reference", combined, local); + } else if (eqBound) { + assertSame("Not same bound reference", combined, bound); + } else { + assertNotSame("Unexpected same local reference", combined, local); + assertNotSame("Unexpected same bound reference", combined, bound); + } + } + + @Override + public String toString() { + return getClass().getSimpleName() + "[entry=" + entry + ", expected=" + expected + "]"; + } +} diff --git a/sshd-core/src/test/java/org/apache/sshd/common/forward/LocalForwardingEntryTest.java b/sshd-core/src/test/java/org/apache/sshd/common/forward/LocalForwardingEntryTest.java index d88ff82..a61441c 100644 --- a/sshd-core/src/test/java/org/apache/sshd/common/forward/LocalForwardingEntryTest.java +++ 
b/sshd-core/src/test/java/org/apache/sshd/common/forward/LocalForwardingEntryTest.java @@ -25,9 +25,12 @@ import java.util.List; import java.util.stream.Collectors; import java.util.stream.IntStream; +import org.apache.sshd.common.util.net.SshdSocketAddress; import org.apache.sshd.util.test.BaseTestSupport; +import org.apache.sshd.util.test.NoIoTestCase; import org.junit.FixMethodOrder; import org.junit.Test; +import org.junit.experimental.categories.Category; import org.junit.runners.MethodSorters; /** @@ -36,6 +39,7 @@ import org.junit.runners.MethodSorters; * @author <a href="mailto:d...@mina.apache.org">Apache MINA SSHD Project</a> */ @FixMethodOrder(MethodSorters.NAME_ASCENDING) +@Category({ NoIoTestCase.class }) public class LocalForwardingEntryTest extends BaseTestSupport { public LocalForwardingEntryTest() { super(); @@ -43,12 +47,16 @@ public class LocalForwardingEntryTest extends BaseTestSupport { @Test // NOTE: this also checks indirectly SshSocketAddress host comparison case-insensitive public void testCaseInsensitiveMatching() { - LocalForwardingEntry expected = new LocalForwardingEntry(getClass().getSimpleName(), getCurrentTestName(), 7365); - String hostname = expected.getHostName(); - String alias = expected.getAlias(); - int port = expected.getPort(); + SshdSocketAddress local = new SshdSocketAddress(getClass().getSimpleName(), 0); + SshdSocketAddress bound = new SshdSocketAddress(getCurrentTestName(), 7365); + LocalForwardingEntry expected = new LocalForwardingEntry(local, bound); + String hostname = local.getHostName(); + String alias = bound.getHostName(); + int port = bound.getPort(); List<LocalForwardingEntry> entries = IntStream.rangeClosed(1, 4) - .mapToObj(seed -> new LocalForwardingEntry(hostname + "-" + seed, alias + "-" + seed, port + seed)) + .mapToObj(seed -> new LocalForwardingEntry( + new SshdSocketAddress(hostname + "-" + seed, 0), + new SshdSocketAddress(alias + "-" + seed, port + seed))) 
.collect(Collectors.toCollection(ArrayList::new)); entries.add(expected); @@ -63,4 +71,83 @@ public class LocalForwardingEntryTest extends BaseTestSupport { } } } + + @Test + public void testSingleWildcardMatching() { + SshdSocketAddress address = new SshdSocketAddress(getCurrentTestName(), 7365); + LocalForwardingEntry expected = new LocalForwardingEntry(address, address); + int port = address.getPort(); + List<LocalForwardingEntry> entries = IntStream.rangeClosed(1, 4) + .mapToObj(seed -> { + String hostname = address.getHostName(); + SshdSocketAddress other = new SshdSocketAddress(hostname + "-" + seed, port + seed); + return new LocalForwardingEntry(other, other); + }).collect(Collectors.toCollection(ArrayList::new)); + entries.add(expected); + + for (String host : new String[] { + SshdSocketAddress.IPV4_ANYADDR, + SshdSocketAddress.IPV6_LONG_ANY_ADDRESS, + SshdSocketAddress.IPV6_SHORT_ANY_ADDRESS + }) { + LocalForwardingEntry actual = LocalForwardingEntry.findMatchingEntry(host, port, entries); + assertSame("Host=" + host, expected, actual); + } + } + + @Test + public void testLoopbackMatching() { + int port = 7365; + List<LocalForwardingEntry> entries = IntStream.rangeClosed(1, 4) + .mapToObj(seed -> { + String hostname = getCurrentTestName(); + SshdSocketAddress other = new SshdSocketAddress(hostname + "-" + seed, port + seed); + return new LocalForwardingEntry(other, other); + }).collect(Collectors.toCollection(ArrayList::new)); + int numEntries = entries.size(); + for (String host : new String[] { + SshdSocketAddress.LOCALHOST_IPV4, + SshdSocketAddress.IPV6_LONG_LOCALHOST, + SshdSocketAddress.IPV6_SHORT_LOCALHOST + }) { + SshdSocketAddress bound = new SshdSocketAddress(host, port); + LocalForwardingEntry expected = new LocalForwardingEntry(bound, bound); + entries.add(expected); + + LocalForwardingEntry actual + = LocalForwardingEntry.findMatchingEntry(SshdSocketAddress.LOCALHOST_NAME, port, entries); + entries.remove(numEntries); + assertSame("Host=" + 
host, expected, actual); + } + } + + @Test + public void testMultipleWildcardCandidates() { + int port = 7365; + List<LocalForwardingEntry> entries = IntStream.rangeClosed(1, 4) + .mapToObj(seed -> { + String hostname = getCurrentTestName(); + SshdSocketAddress other = new SshdSocketAddress(hostname + "-" + seed, port + seed); + return new LocalForwardingEntry(other, other); + }).collect(Collectors.toCollection(ArrayList::new)); + for (int index = 0; index < 4; index++) { + SshdSocketAddress duplicate = new SshdSocketAddress(getClass().getSimpleName() + "-" + index, port); + entries.add(new LocalForwardingEntry(duplicate, duplicate)); + } + + for (String host : new String[] { + SshdSocketAddress.IPV4_ANYADDR, + SshdSocketAddress.IPV6_LONG_ANY_ADDRESS, + SshdSocketAddress.IPV6_SHORT_ANY_ADDRESS + }) { + try { + LocalForwardingEntry actual = LocalForwardingEntry.findMatchingEntry(host, port, entries); + fail("Unexpected success for host=" + host + ": " + actual); + } catch (IllegalStateException e) { + String msg = e.getMessage(); + assertTrue("Bad exception message: " + msg, + msg.startsWith("Multiple candidate matches for " + host + "@" + port + ":")); + } + } + } } diff --git a/sshd-core/src/test/java/org/apache/sshd/common/forward/PortForwardingTest.java b/sshd-core/src/test/java/org/apache/sshd/common/forward/PortForwardingTest.java index 51bcb80..08c7adf 100644 --- a/sshd-core/src/test/java/org/apache/sshd/common/forward/PortForwardingTest.java +++ b/sshd-core/src/test/java/org/apache/sshd/common/forward/PortForwardingTest.java @@ -18,18 +18,30 @@ */ package org.apache.sshd.common.forward; +import java.io.BufferedReader; import java.io.IOException; import java.io.InputStream; +import java.io.InputStreamReader; import java.io.OutputStream; import java.lang.reflect.Field; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Method; +import java.net.HttpURLConnection; +import java.net.Inet4Address; +import java.net.InetAddress; import 
java.net.InetSocketAddress; +import java.net.NetworkInterface; +import java.net.Proxy; import java.net.Socket; +import java.net.SocketException; +import java.net.URL; import java.nio.charset.StandardCharsets; import java.time.Duration; +import java.util.ArrayList; import java.util.Collection; +import java.util.Enumeration; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; @@ -38,6 +50,7 @@ import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; import com.jcraft.jsch.JSch; import com.jcraft.jsch.JSchException; @@ -55,6 +68,7 @@ import org.apache.sshd.common.session.ConnectionService; import org.apache.sshd.common.util.GenericUtils; import org.apache.sshd.common.util.MapEntryUtils.NavigableMapBuilder; import org.apache.sshd.common.util.ProxyUtils; +import org.apache.sshd.common.util.io.IoUtils; import org.apache.sshd.common.util.net.SshdSocketAddress; import org.apache.sshd.core.CoreModuleProperties; import org.apache.sshd.server.SshServer; @@ -66,6 +80,7 @@ import org.apache.sshd.util.test.CoreTestSupportUtils; import org.apache.sshd.util.test.JSchLogger; import org.apache.sshd.util.test.SimpleUserInfo; import org.junit.AfterClass; +import org.junit.Before; import org.junit.BeforeClass; import org.junit.FixMethodOrder; import org.junit.Test; @@ -77,6 +92,7 @@ import org.slf4j.LoggerFactory; * Port forwarding tests */ @FixMethodOrder(MethodSorters.NAME_ASCENDING) +@SuppressWarnings("checkstyle:MethodCount") public class PortForwardingTest extends BaseTestSupport { public static final int SO_TIMEOUT = (int) TimeUnit.SECONDS.toMillis(13L); @@ -198,7 +214,13 @@ public class PortForwardingTest extends BaseTestSupport { @SuppressWarnings("synthetic-access") @Override public Object invoke(Object proxy, Method method, Object[] args) 
throws Throwable { - Object result = method.invoke(forwarder, args); + Object result; + try { + result = method.invoke(forwarder, args); + } catch (Throwable t) { + throw ProxyUtils.unwrapInvocationThrowable(t); + } + String name = method.getName(); String request = method2req.get(name); if (GenericUtils.length(request) > 0) { @@ -246,7 +268,14 @@ public class PortForwardingTest extends BaseTestSupport { } } - private void waitForForwardingRequest(String expected, Duration timeout) throws InterruptedException { + @Before + public void setUp() { + if (!REQUESTS_QUEUE.isEmpty()) { + REQUESTS_QUEUE.clear(); + } + } + + private static void waitForForwardingRequest(String expected, Duration timeout) throws InterruptedException { for (long remaining = timeout.toMillis(); remaining > 0L;) { long waitStart = System.currentTimeMillis(); String actual = REQUESTS_QUEUE.poll(remaining, TimeUnit.MILLISECONDS); @@ -622,6 +651,8 @@ public class PortForwardingTest extends BaseTestSupport { byte[] buf = new byte[bytes.length + Long.SIZE]; int n = input.read(buf); + assertTrue("No data read from tunnel", n > 0); + String res = new String(buf, 0, n, StandardCharsets.UTF_8); assertEquals("Mismatched data", expected, res); } finally { @@ -672,6 +703,8 @@ public class PortForwardingTest extends BaseTestSupport { output.flush(); int n = input.read(buf); + assertTrue("No data read from tunnel", n > 0); + String res = new String(buf, 0, n, StandardCharsets.UTF_8); assertEquals("Mismatched data at iteration #" + i, expected, res); } @@ -754,6 +787,61 @@ public class PortForwardingTest extends BaseTestSupport { } } + @Test // see SSHD-1066 + public void testLocalBindingOnDifferentInterfaces() throws Exception { + InetSocketAddress addr = (InetSocketAddress) GenericUtils.head(sshd.getBoundAddresses()); + log.info("{} - using bound address={}", getCurrentTestName(), addr); + + List<String> allAddresses = getHostAddresses(); + log.info("{} - test on addresses={}", getCurrentTestName(), 
allAddresses); + + try (ClientSession session = createNativeSession(null)) { + List<ExplicitPortForwardingTracker> trackers = new ArrayList<>(); + try { + for (String host : allAddresses) { + ExplicitPortForwardingTracker tracker = session.createLocalPortForwardingTracker( + new SshdSocketAddress(host, 8080), + new SshdSocketAddress("test.javastack.org", 80)); + trackers.add(tracker); // register for cleanup before anything below can fail + SshdSocketAddress boundAddress = tracker.getBoundAddress(); + log.info("{} - test for binding={}", getCurrentTestName(), boundAddress); + testRemoteURL(new Proxy(Proxy.Type.HTTP, boundAddress.toInetSocketAddress()), + "http://test.javastack.org/"); + } + } finally { + IoUtils.closeQuietly(trackers); + } + } + } + + private static List<String> getHostAddresses() throws SocketException { + List<String> addresses = new ArrayList<>(); + Enumeration<NetworkInterface> eni = NetworkInterface.getNetworkInterfaces(); + while (eni.hasMoreElements()) { + NetworkInterface networkInterface = eni.nextElement(); + Enumeration<InetAddress> eia = networkInterface.getInetAddresses(); + while (eia.hasMoreElements()) { + InetAddress ia = eia.nextElement(); + if (ia instanceof Inet4Address) { + addresses.add(ia.getHostAddress()); + } + } + } + return addresses; + } + + private static void testRemoteURL(Proxy proxy, String url) throws IOException { + HttpURLConnection connection = (HttpURLConnection) new URL(url).openConnection(proxy); + connection.setConnectTimeout((int) DEFAULT_TIMEOUT.toMillis()); + connection.setReadTimeout((int) DEFAULT_TIMEOUT.toMillis()); + String result; + try (InputStream inputStream = connection.getInputStream(); + BufferedReader in = new BufferedReader(new InputStreamReader(inputStream, StandardCharsets.UTF_8))) { + result = in.lines().collect(Collectors.joining(System.lineSeparator())); + } + assertEquals("Unexpected server response", "OK", result); + } + /** * Close the socket inside this JSCH session. 
Use reflection to find it and just close it.
* @@ -848,10 +936,6 @@ public class PortForwardingTest extends BaseTestSupport { client.addPortForwardingEventListener(listener); } - ClientSession session - = client.connect(getCurrentTestName(), TEST_LOCALHOST, sshPort).verify(CONNECT_TIMEOUT).getSession(); - session.addPasswordIdentity(getCurrentTestName()); - session.auth().verify(AUTH_TIMEOUT); - return session; + return createAuthenticatedClientSession(client, sshPort); } } diff --git a/sshd-core/src/test/java/org/apache/sshd/util/test/BaseTestSupport.java b/sshd-core/src/test/java/org/apache/sshd/util/test/BaseTestSupport.java index 94765ef..5bc19cc 100644 --- a/sshd-core/src/test/java/org/apache/sshd/util/test/BaseTestSupport.java +++ b/sshd-core/src/test/java/org/apache/sshd/util/test/BaseTestSupport.java @@ -18,10 +18,12 @@ */ package org.apache.sshd.util.test; +import java.io.IOException; import java.time.Duration; import java.util.Collection; import org.apache.sshd.client.SshClient; +import org.apache.sshd.client.session.ClientSession; import org.apache.sshd.common.helpers.AbstractFactoryManager; import org.apache.sshd.common.io.BuiltinIoServiceFactoryFactories; import org.apache.sshd.common.io.DefaultIoServiceFactoryFactory; @@ -119,6 +121,28 @@ public abstract class BaseTestSupport extends JUnitTestSupport { assumeNotIoServiceProvider(getCurrentTestName(), excluded); } + protected ClientSession createClientSession(SshClient client, int port) throws IOException { + return client.connect(getCurrentTestName(), TEST_LOCALHOST, port) + .verify(CONNECT_TIMEOUT) + .getSession(); + } + + protected ClientSession createAuthenticatedClientSession(SshClient client, int port) throws IOException { + ClientSession session = createClientSession(client, port); + try { + session.addPasswordIdentity(getCurrentTestName()); + session.auth().verify(AUTH_TIMEOUT); + + ClientSession authSession = session; + session = null; // avoid auto-close at finally clause + return authSession; + } finally { + if (session != null) { + 
session.close(); + } + } + } + public static IoServiceFactoryFactory getIoServiceProvider() { DefaultIoServiceFactoryFactory factory = DefaultIoServiceFactoryFactory.getDefaultIoServiceFactoryFactoryInstance(); return factory.getIoServiceProvider(); diff --git a/sshd-sftp/src/test/java/org/apache/sshd/sftp/client/AbstractSftpClientTestSupport.java b/sshd-sftp/src/test/java/org/apache/sshd/sftp/client/AbstractSftpClientTestSupport.java index 4d7b14e..ed3dbad 100644 --- a/sshd-sftp/src/test/java/org/apache/sshd/sftp/client/AbstractSftpClientTestSupport.java +++ b/sshd-sftp/src/test/java/org/apache/sshd/sftp/client/AbstractSftpClientTestSupport.java @@ -87,26 +87,8 @@ public abstract class AbstractSftpClientTestSupport extends BaseTestSupport { sshd.setFileSystemFactory(fileSystemFactory); } - protected ClientSession createClientSession() throws IOException { - return client.connect(getCurrentTestName(), TEST_LOCALHOST, port) - .verify(CONNECT_TIMEOUT) - .getSession(); - } - protected ClientSession createAuthenticatedClientSession() throws IOException { - ClientSession session = createClientSession(); - try { - session.addPasswordIdentity(getCurrentTestName()); - session.auth().verify(AUTH_TIMEOUT); - - ClientSession authSession = session; - session = null; // avoid auto-close at finally clause - return authSession; - } finally { - if (session != null) { - session.close(); - } - } + return createAuthenticatedClientSession(client, port); } protected SftpClient createSingleSessionClient() throws IOException {