This is an automated email from the ASF dual-hosted git repository. lgoldstein pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/mina-sshd.git
The following commit(s) were added to refs/heads/master by this push: new 5e89f97 [SSHD-982] Fix race condition when loading known hosts file 5e89f97 is described below commit 5e89f970f73531c4712c94b76e898b932d046c1e Author: FliegenKLATSCH <ch...@koras.de> AuthorDate: Thu Apr 23 17:16:08 2020 +0300 [SSHD-982] Fix race condition when loading known hosts file --- .../org/apache/sshd/common/util/GenericUtils.java | 24 +++++++++++++++ .../keyverifier/KnownHostsServerKeyVerifier.java | 32 +++++++++++++------- .../KnownHostsServerKeyVerifierTest.java | 34 ++++++++++++++++++++++ 3 files changed, 79 insertions(+), 11 deletions(-) diff --git a/sshd-common/src/main/java/org/apache/sshd/common/util/GenericUtils.java b/sshd-common/src/main/java/org/apache/sshd/common/util/GenericUtils.java index 62a2910..e5da0d5 100644 --- a/sshd-common/src/main/java/org/apache/sshd/common/util/GenericUtils.java +++ b/sshd-common/src/main/java/org/apache/sshd/common/util/GenericUtils.java @@ -42,6 +42,7 @@ import java.util.SortedSet; import java.util.TreeMap; import java.util.TreeSet; import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.BinaryOperator; import java.util.function.Consumer; import java.util.function.Function; @@ -1010,4 +1011,27 @@ public final class GenericUtils { Iterable<? extends Supplier<? extends Iterable<? extends T>>> providers) { return () -> stream(providers).<T> flatMap(s -> stream(s.get())).map(Function.identity()).iterator(); } + + /** + * The delegate Suppliers get() method is called exactly once and the result is cached. + * + * @param delegate The actual Supplier + * @return The memoized Supplier + */ + public static <T> Supplier<T> memoizeLock(Supplier<T> delegate) { + AtomicReference<T> value = new AtomicReference<>(); + return () -> { + T val = value.get(); + if (val == null) { + synchronized (value) { + val = value.get(); + if (val == null) { + val = Objects.requireNonNull(delegate.get()); + value.set(val); + } + } + } + return val; + }; + } } diff --git a/sshd-core/src/main/java/org/apache/sshd/client/keyverifier/KnownHostsServerKeyVerifier.java b/sshd-core/src/main/java/org/apache/sshd/client/keyverifier/KnownHostsServerKeyVerifier.java index 8cdecd7..b0d1f35 100644 --- a/sshd-core/src/main/java/org/apache/sshd/client/keyverifier/KnownHostsServerKeyVerifier.java +++ b/sshd-core/src/main/java/org/apache/sshd/client/keyverifier/KnownHostsServerKeyVerifier.java @@ -38,6 +38,7 @@ import java.util.List; import java.util.Objects; import java.util.TreeSet; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; import org.apache.sshd.client.config.hosts.KnownHostEntry; import org.apache.sshd.client.config.hosts.KnownHostHashValue; @@ -79,7 +80,7 @@ public class KnownHostsServerKeyVerifier /** * Represents an entry in the internal verifier's cache - * + * * @author <a href="mailto:d...@mina.apache.org">Apache MINA SSHD Project</a> */ public static class HostEntryPair { @@ -119,7 +120,8 @@ public class KnownHostsServerKeyVerifier protected final Object updateLock = new Object(); private final ServerKeyVerifier delegate; - private final AtomicReference<Collection<HostEntryPair>> keysHolder = new AtomicReference<>(Collections.emptyList()); + private final AtomicReference<Supplier<? extends Collection<HostEntryPair>>> keysSupplier + = new AtomicReference<>(getKnownHostSupplier(null, getPath())); private ModifiedServerKeyAcceptor modKeyAcceptor; public KnownHostsServerKeyVerifier(ServerKeyVerifier delegate, Path file) { @@ -153,35 +155,43 @@ public class KnownHostsServerKeyVerifier @Override public boolean verifyServerKey(ClientSession clientSession, SocketAddress remoteAddress, PublicKey serverKey) { - Collection<HostEntryPair> knownHosts = getLoadedHostsEntries(); try { if (checkReloadRequired()) { Path file = getPath(); if (exists()) { - knownHosts = reloadKnownHosts(clientSession, file); + updateReloadAttributes(); + keysSupplier.set(GenericUtils.memoizeLock(getKnownHostSupplier(clientSession, file))); } else { if (log.isDebugEnabled()) { log.debug("verifyServerKey({})[{}] missing known hosts file {}", clientSession, remoteAddress, file); } - knownHosts = Collections.emptyList(); + keysSupplier.set(GenericUtils.memoizeLock(Collections::emptyList)); } - - setLoadedHostsEntries(knownHosts); } } catch (Throwable t) { return acceptIncompleteHostKeys(clientSession, remoteAddress, serverKey, t); } + Collection<HostEntryPair> knownHosts = keysSupplier.get().get(); + return acceptKnownHostEntries(clientSession, remoteAddress, serverKey, knownHosts); } - protected Collection<HostEntryPair> getLoadedHostsEntries() { - return keysHolder.get(); + protected Supplier<Collection<HostEntryPair>> getKnownHostSupplier(ClientSession clientSession, Path file) { + return () -> { + try { + return reloadKnownHosts(clientSession, file); + } catch (Exception e) { + log.warn("verifyServerKey({}) Could not reload known hosts file {}", + clientSession, file, e); + return Collections.emptyList(); + } + }; } protected void setLoadedHostsEntries(Collection<HostEntryPair> keys) { - keysHolder.set(keys); + keysSupplier.set(() -> keys); } /** @@ -579,7 +589,7 @@ public class KnownHostsServerKeyVerifier if (delegate.verifyServerKey(clientSession, remoteAddress, serverKey)) { Path file = getPath(); - Collection<HostEntryPair> keys = getLoadedHostsEntries(); + Collection<HostEntryPair> keys = keysSupplier.get().get(); try { updateKnownHostsFile(clientSession, remoteAddress, serverKey, file, keys); } catch (Throwable t) { diff --git a/sshd-core/src/test/java/org/apache/sshd/client/keyverifier/KnownHostsServerKeyVerifierTest.java b/sshd-core/src/test/java/org/apache/sshd/client/keyverifier/KnownHostsServerKeyVerifierTest.java index 9552eee..9482ffd 100644 --- a/sshd-core/src/test/java/org/apache/sshd/client/keyverifier/KnownHostsServerKeyVerifierTest.java +++ b/sshd-core/src/test/java/org/apache/sshd/client/keyverifier/KnownHostsServerKeyVerifierTest.java @@ -94,6 +94,40 @@ public class KnownHostsServerKeyVerifierTest extends BaseTestSupport { } @Test + public void testParallelLoading() { + KnownHostsServerKeyVerifier verifier + = new KnownHostsServerKeyVerifier(AcceptAllServerKeyVerifier.INSTANCE, entriesFile) { + @Override + public ModifiedServerKeyAcceptor getModifiedServerKeyAcceptor() { + return (clientSession, remoteAddress, entry, expected, actual) -> true; // don't care here + } + + @Override + protected boolean acceptKnownHostEntries( + ClientSession clientSession, SocketAddress remoteAddress, PublicKey serverKey, + Collection<HostEntryPair> knownHosts) { + if (GenericUtils.isEmpty(knownHosts)) { + fail("Loaded known_hosts collection is empty!"); + } + return super.acceptKnownHostEntries(clientSession, remoteAddress, serverKey, knownHosts); + } + }; + + ClientFactoryManager manager = Mockito.mock(ClientFactoryManager.class); + Mockito.when(manager.getRandomFactory()).thenReturn(JceRandomFactory.INSTANCE); + + HOST_KEYS.entrySet().parallelStream().forEach(line -> { + KnownHostEntry entry = hostsEntries.get(line.getKey()); + + ClientSession session = Mockito.mock(ClientSession.class); + Mockito.when(session.getFactoryManager()).thenReturn(manager); + + Mockito.when(session.getConnectAddress()).thenReturn(line.getKey()); + assertTrue("Failed to validate server=" + entry, verifier.verifyServerKey(session, line.getKey(), line.getValue())); + }); + } + + @Test public void testNoUpdatesNoNewHostsAuthentication() throws Exception { AtomicInteger delegateCount = new AtomicInteger(0); ServerKeyVerifier delegate = (clientSession, remoteAddress, serverKey) -> {