This is an automated email from the ASF dual-hosted git repository.

lgoldstein pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/mina-sshd.git


The following commit(s) were added to refs/heads/master by this push:
     new e021a6f  [SSHD-954] Improve validation of DH public key values
e021a6f is described below

commit e021a6f4597cb7259162169851bb9e4cb85e1a39
Author: Lyor Goldstein <lgoldst...@apache.org>
AuthorDate: Fri Oct 9 09:17:01 2020 +0300

    [SSHD-954] Improve validation of DH public key values
---
 CHANGES.md                                         |  1 +
 .../org/apache/sshd/common/util/buffer/Buffer.java | 14 ++--
 .../sshd/common/util/buffer/BufferUtils.java       | 13 ++++
 .../java/org/apache/sshd/client/kex/DHGClient.java |  8 +-
 .../org/apache/sshd/client/kex/DHGEXClient.java    | 50 ++++++++++--
 .../org/apache/sshd/common/kex/KeyExchange.java    | 19 +++++
 .../sshd/common/kex/dh/AbstractDHKeyExchange.java  | 88 +++++++++++++++++++++-
 .../org/apache/sshd/server/kex/DHGEXServer.java    | 32 ++++++--
 .../java/org/apache/sshd/server/kex/DHGServer.java |  6 +-
 9 files changed, 202 insertions(+), 29 deletions(-)

diff --git a/CHANGES.md b/CHANGES.md
index c9d2864..001fc9e 100644
--- a/CHANGES.md
+++ b/CHANGES.md
@@ -37,6 +37,7 @@ or `-key-file` command line option.
 ## Behavioral changes and enhancements
 
 * [SSHD-506](https://issues.apache.org/jira/browse/SSHD-506) Added support for AES-GCM ciphers.
+* [SSHD-954](https://issues.apache.org/jira/browse/SSHD-954) Improve validation of DH public key values.
 * [SSHD-1004](https://issues.apache.org/jira/browse/SSHD-1004) Deprecate DES, RC4 and Blowfish ciphers from default setup.
 * [SSHD-1004](https://issues.apache.org/jira/browse/SSHD-1004) Deprecate SHA-1 based key exchanges and signatures from default setup.
 * [SSHD-1004](https://issues.apache.org/jira/browse/SSHD-1004) Deprecate MD5-based and truncated HMAC algorithms from default setup.
diff --git a/sshd-common/src/main/java/org/apache/sshd/common/util/buffer/Buffer.java b/sshd-common/src/main/java/org/apache/sshd/common/util/buffer/Buffer.java
index 8022df3..4a1bfeb 100644
--- a/sshd-common/src/main/java/org/apache/sshd/common/util/buffer/Buffer.java
+++ b/sshd-common/src/main/java/org/apache/sshd/common/util/buffer/Buffer.java
@@ -853,18 +853,18 @@ public abstract class Buffer implements Readable {
         }
     }
 
-    public void putMPInt(BigInteger bi) {
-        putMPInt(bi.toByteArray());
+    /**
+     * Writes a {@link BigInteger} as an SSH {@code mpint} by passing its {@code toByteArray()}
+     * (two's-complement, big-endian) encoding to {@link #putMPInt(byte[])}.
+     * <P>
+     * NOTE(review): {@link #putMPInt(byte[])} prepends a zero byte whenever the first byte has its high
+     * bit set, so a negative value would be written as a positive one. Zero is written as a single
+     * {@code 0x00} byte rather than an empty string (RFC 4251 section 5). Confirm that callers only pass
+     * positive values.
+     * </P>
+     *
+     * @param bigint The value to write
+     */
+    public void putMPInt(BigInteger bigint) {
+        putMPInt(bigint.toByteArray());
     }
 
-    public void putMPInt(byte[] foo) {
-        if ((foo[0] & 0x80) != 0) {
-            putInt(foo.length + 1 /* padding */);
+    /**
+     * Writes the given big-endian bytes as an SSH {@code mpint}. The output is the length (via
+     * {@code putInt}) followed by the bytes. An extra leading zero byte is inserted when the first byte
+     * has its high bit set, so the two's-complement {@code mpint} encoding stays non-negative.
+     * <P>
+     * NOTE(review): an empty array fails on the {@code mpInt[0]} access, so callers must supply at least
+     * one byte.
+     * </P>
+     *
+     * @param mpInt The value bytes to write
+     */
+    public void putMPInt(byte[] mpInt) {
+        if ((mpInt[0] & 0x80) != 0) {
+            putInt(mpInt.length + 1 /* padding */);
             putByte((byte) 0);
         } else {
-            putInt(foo.length);
+            putInt(mpInt.length);
         }
-        putRawBytes(foo);
+        putRawBytes(mpInt);
     }
 
     public void putRawBytes(byte[] d) {
diff --git a/sshd-common/src/main/java/org/apache/sshd/common/util/buffer/BufferUtils.java b/sshd-common/src/main/java/org/apache/sshd/common/util/buffer/BufferUtils.java
index be67be0..ad67a5f 100644
--- a/sshd-common/src/main/java/org/apache/sshd/common/util/buffer/BufferUtils.java
+++ b/sshd-common/src/main/java/org/apache/sshd/common/util/buffer/BufferUtils.java
@@ -22,6 +22,7 @@ import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
 import java.io.StreamCorruptedException;
+import java.math.BigInteger;
 import java.util.function.IntUnaryOperator;
 import java.util.logging.Level;
 
@@ -402,6 +403,18 @@ public final class BufferUtils {
         return l;
     }
 
+    /**
+     * Decodes the raw bytes of an SSH {@code mpint} (RFC 4251 section 5) into a {@link BigInteger}.
+     * <P>
+     * The bytes are read as a big-endian two's-complement value, the same as
+     * {@code new BigInteger(byte[])}. A first byte with the high bit set therefore yields a negative
+     * number. Range checks such as {@code KeyExchange.isValidDHValue} (which rejects any value
+     * {@code <= 1}) then reject it, instead of the decoder failing with an unexpected exception.
+     * </P>
+     *
+     * @param  mpInt The raw value bytes - may be {@code null}/empty
+     * @return       The decoded value - {@code null} if no bytes were provided
+     */
+    public static BigInteger fromMPIntBytes(byte[] mpInt) {
+        if (NumberUtils.isEmpty(mpInt)) {
+            return null;
+        }
+
+        // NOTE: new BigInteger(0, mpInt) must not be used for the high-bit case. A zero signum
+        // with a non-zero magnitude always throws NumberFormatException.
+        return new BigInteger(mpInt);
+    }
+
     /**
      * Writes a 32-bit value in network order (i.e., MSB 1st)
      *
diff --git a/sshd-core/src/main/java/org/apache/sshd/client/kex/DHGClient.java b/sshd-core/src/main/java/org/apache/sshd/client/kex/DHGClient.java
index 559206b..ec65ba1 100644
--- a/sshd-core/src/main/java/org/apache/sshd/client/kex/DHGClient.java
+++ b/sshd-core/src/main/java/org/apache/sshd/client/kex/DHGClient.java
@@ -95,7 +95,8 @@ public class DHGClient extends AbstractDHClientKeyExchange {
         dh = getDH();
         hash = dh.getHash();
         hash.init();
-        e = dh.getE();
+
+        byte[] e = updateE(dh.getE());
 
         Session s = getSession();
         if (log.isDebugEnabled()) {
@@ -127,8 +128,9 @@ public class DHGClient extends AbstractDHClientKeyExchange {
         }
 
         byte[] k_s = buffer.getBytes();
-        f = buffer.getMPIntAsBytes();
+        byte[] f = updateF(buffer);
         byte[] sig = buffer.getBytes();
+
         dh.setF(f);
         k = dh.getK();
 
@@ -166,7 +168,7 @@ public class DHGClient extends AbstractDHClientKeyExchange {
         buffer.putBytes(i_c);
         buffer.putBytes(i_s);
         buffer.putBytes(k_s);
-        buffer.putMPInt(e);
+        buffer.putMPInt(getE());
         buffer.putMPInt(f);
         buffer.putMPInt(k);
         hash.update(buffer.array(), 0, buffer.available());
diff --git a/sshd-core/src/main/java/org/apache/sshd/client/kex/DHGEXClient.java b/sshd-core/src/main/java/org/apache/sshd/client/kex/DHGEXClient.java
index d02d7e6..6c7923c 100644
--- a/sshd-core/src/main/java/org/apache/sshd/client/kex/DHGEXClient.java
+++ b/sshd-core/src/main/java/org/apache/sshd/client/kex/DHGEXClient.java
@@ -37,6 +37,7 @@ import org.apache.sshd.common.signature.Signature;
 import org.apache.sshd.common.util.GenericUtils;
 import org.apache.sshd.common.util.ValidateUtils;
 import org.apache.sshd.common.util.buffer.Buffer;
+import org.apache.sshd.common.util.buffer.BufferUtils;
 import org.apache.sshd.common.util.buffer.ByteArrayBuffer;
 import org.apache.sshd.common.util.security.SecurityUtils;
 import org.apache.sshd.core.CoreModuleProperties;
@@ -52,9 +53,11 @@ public class DHGEXClient extends AbstractDHClientKeyExchange {
     protected int prf;
     protected int max;
     protected AbstractDH dh;
-    protected byte[] p;
     protected byte[] g;
 
+    private byte[] p;
+    private BigInteger pValue;
+
     protected DHGEXClient(DHFactory factory, Session session) {
         super(session);
         this.factory = Objects.requireNonNull(factory, "No factory");
@@ -73,6 +76,34 @@ public class DHGEXClient extends AbstractDHClientKeyExchange {
         return factory.getName();
     }
 
+    /**
+     * @return The raw {@code mpint} bytes of the group prime as last set via {@link #setP(byte[])} -
+     *         {@code null} if none were set yet
+     */
+    protected byte[] getP() {
+        return this.p;
+    }
+
+    /**
+     * @return The group prime decoded from {@link #getP()}. It is decoded on first access and reused
+     *         until the next {@link #setP(byte[])} call. {@code null} if no prime bytes are available.
+     */
+    protected BigInteger getPValue() {
+        BigInteger decoded = pValue;
+        if (decoded == null) {
+            decoded = BufferUtils.fromMPIntBytes(getP());
+            pValue = decoded;
+        }
+        return decoded;
+    }
+
+    /**
+     * Replaces the raw prime bytes and discards any previously decoded value, so that
+     * {@link #getPValue()} decodes the new bytes on its next call.
+     *
+     * @param p The raw {@code mpint} bytes of the group prime
+     */
+    protected void setP(byte[] p) {
+        this.p = p;
+        this.pValue = null; // force lazy re-initialization
+    }
+
+    /**
+     * Checks the current DH 'e' value against the group prime returned by {@link #getPValue()}.
+     *
+     * @throws Exception If {@code validateEValue(BigInteger)} rejects the value
+     */
+    protected void validateEValue() throws Exception {
+        BigInteger prime = getPValue();
+        validateEValue(prime);
+    }
+
+    /**
+     * Checks the current DH 'f' value against the group prime returned by {@link #getPValue()}.
+     *
+     * @throws Exception If {@code validateFValue(BigInteger)} rejects the value
+     */
+    protected void validateFValue() throws Exception {
+        BigInteger prime = getPValue();
+        validateFValue(prime);
+    }
+
     public static KeyExchangeFactory newFactory(DHFactory delegate) {
         return new KeyExchangeFactory() {
             @Override
@@ -137,13 +168,15 @@ public class DHGEXClient extends AbstractDHClientKeyExchange {
         }
 
         if (cmd == SshConstants.SSH_MSG_KEX_DH_GEX_GROUP) {
-            p = buffer.getMPIntAsBytes();
+            setP(buffer.getMPIntAsBytes());
             g = buffer.getMPIntAsBytes();
 
-            dh = getDH(new BigInteger(p), new BigInteger(g));
+            dh = getDH(getPValue(), new BigInteger(g));
             hash = dh.getHash();
             hash.init();
-            e = dh.getE();
+
+            byte[] e = updateE(dh.getE());
+            validateEValue();
 
             if (debugEnabled) {
                 log.debug("next({})[{}] Send SSH_MSG_KEX_DH_GEX_INIT", this, 
session);
@@ -164,8 +197,11 @@ public class DHGEXClient extends AbstractDHClientKeyExchange {
             }
 
             byte[] k_s = buffer.getBytes();
-            f = buffer.getMPIntAsBytes();
+            byte[] f = updateF(buffer);
             byte[] sig = buffer.getBytes();
+
+            validateFValue();
+
             dh.setF(f);
             k = dh.getK();
 
@@ -188,9 +224,9 @@ public class DHGEXClient extends AbstractDHClientKeyExchange {
             buffer.putInt(min);
             buffer.putInt(prf);
             buffer.putInt(max);
-            buffer.putMPInt(p);
+            buffer.putMPInt(getP());
             buffer.putMPInt(g);
-            buffer.putMPInt(e);
+            buffer.putMPInt(getE());
             buffer.putMPInt(f);
             buffer.putMPInt(k);
             hash.update(buffer.array(), 0, buffer.available());
diff --git a/sshd-core/src/main/java/org/apache/sshd/common/kex/KeyExchange.java b/sshd-core/src/main/java/org/apache/sshd/common/kex/KeyExchange.java
index a0ad22e..f50cd36 100644
--- a/sshd-core/src/main/java/org/apache/sshd/common/kex/KeyExchange.java
+++ b/sshd-core/src/main/java/org/apache/sshd/common/kex/KeyExchange.java
@@ -18,6 +18,7 @@
  */
 package org.apache.sshd.common.kex;
 
+import java.math.BigInteger;
 import java.util.Collections;
 import java.util.NavigableMap;
 
@@ -101,4 +102,22 @@ public interface KeyExchange extends NamedResource, SessionHolder<Session> {
             return name;
         }
     }
+
+    /**
+     * Checks that a Diffie-Hellman public value lies strictly between 1 and {@code p-1}.
+     *
+     * @param  value The public value ('e' or 'f') to check
+     * @param  p     The group prime
+     * @return       {@code true} if {@code 1 < value < p-1}, {@code false} otherwise (including when
+     *               either argument is {@code null})
+     * @see          <A HREF="https://tools.ietf.org/html/rfc8268#section-4">RFC 8268 - section 4</A>
+     */
+    static boolean isValidDHValue(BigInteger value, BigInteger p) {
+        if (value == null || p == null) {
+            return false;
+        }
+
+        BigInteger upperBound = p.subtract(BigInteger.ONE);
+        // 1 < value < p-1
+        return value.compareTo(BigInteger.ONE) > 0 && value.compareTo(upperBound) < 0;
+    }
 }
diff --git a/sshd-core/src/main/java/org/apache/sshd/common/kex/dh/AbstractDHKeyExchange.java b/sshd-core/src/main/java/org/apache/sshd/common/kex/dh/AbstractDHKeyExchange.java
index 7f29782..1e77c38 100644
--- a/sshd-core/src/main/java/org/apache/sshd/common/kex/dh/AbstractDHKeyExchange.java
+++ b/sshd-core/src/main/java/org/apache/sshd/common/kex/dh/AbstractDHKeyExchange.java
@@ -19,12 +19,17 @@
 
 package org.apache.sshd.common.kex.dh;
 
+import java.math.BigInteger;
 import java.util.Objects;
 
+import org.apache.sshd.common.SshConstants;
+import org.apache.sshd.common.SshException;
 import org.apache.sshd.common.digest.Digest;
 import org.apache.sshd.common.kex.KeyExchange;
 import org.apache.sshd.common.session.Session;
 import org.apache.sshd.common.util.ValidateUtils;
+import org.apache.sshd.common.util.buffer.Buffer;
+import org.apache.sshd.common.util.buffer.BufferUtils;
 import org.apache.sshd.common.util.logging.AbstractLoggingBean;
 
 /**
@@ -36,11 +41,14 @@ public abstract class AbstractDHKeyExchange extends AbstractLoggingBean implemen
     protected byte[] i_s;
     protected byte[] i_c;
     protected Digest hash;
-    protected byte[] e;
-    protected byte[] f;
     protected byte[] k;
     protected byte[] h;
 
+    private byte[] e;
+    private BigInteger eValue;
+    private byte[] f;
+    private BigInteger fValue;
+
     private final Session session;
 
     protected AbstractDHKeyExchange(Session session) {
@@ -75,6 +83,82 @@ public abstract class AbstractDHKeyExchange extends AbstractLoggingBean implemen
         return k;
     }
 
+    /**
+     * @return The raw {@code mpint} bytes of the DH 'e' value - {@code null} if not set yet
+     */
+    protected byte[] getE() {
+        return this.e;
+    }
+
+    /**
+     * @return The DH 'e' value decoded from {@link #getE()}. It is decoded on first access and reused
+     *         until the next {@link #setE(byte[])} call. {@code null} if no bytes are available.
+     */
+    protected BigInteger getEValue() {
+        BigInteger decoded = eValue;
+        if (decoded == null) {
+            decoded = BufferUtils.fromMPIntBytes(getE());
+            eValue = decoded;
+        }
+        return decoded;
+    }
+
+    /**
+     * Reads the next {@code mpint} from the buffer and records it as the DH 'e' value.
+     *
+     * @param  buffer The {@link Buffer} to read the encoded value from
+     * @return        The raw bytes that were read
+     */
+    protected byte[] updateE(Buffer buffer) {
+        byte[] mpInt = buffer.getMPIntAsBytes();
+        return updateE(mpInt);
+    }
+
+    /**
+     * Records the given bytes as the DH 'e' value via {@link #setE(byte[])}.
+     *
+     * @param  mpInt The raw {@code mpint} bytes
+     * @return       The same bytes that were passed in
+     */
+    protected byte[] updateE(byte[] mpInt) {
+        setE(mpInt);
+        return mpInt;
+    }
+
+    /**
+     * Replaces the raw DH 'e' bytes and discards any previously decoded value, so that
+     * {@link #getEValue()} decodes the new bytes on its next call.
+     *
+     * @param e The raw {@code mpint} bytes
+     */
+    protected void setE(byte[] e) {
+        this.e = e;
+        this.eValue = null; // force lazy re-initialization
+    }
+
+    /**
+     * Verifies that the current DH 'e' value satisfies {@code 1 < e < p-1} - see
+     * <A HREF="https://tools.ietf.org/html/rfc8268#section-4">RFC 8268 - section 4</A>.
+     *
+     * @param  pValue               The group prime - a {@code null} prime causes the value to be rejected
+     * @throws NullPointerException If no 'e' value has been set
+     * @throws SshException         If the value is out of range, raised with
+     *                              {@code SSH2_DISCONNECT_KEY_EXCHANGE_FAILED}
+     */
+    protected void validateEValue(BigInteger pValue) throws Exception {
+        BigInteger value = Objects.requireNonNull(getEValue(), "No DH 'e' value set");
+        if (!KeyExchange.isValidDHValue(value, pValue)) {
+            throw new SshException(
+                    SshConstants.SSH2_DISCONNECT_KEY_EXCHANGE_FAILED,
+                    "Protocol error: invalid DH 'e' value");
+        }
+    }
+
+    /**
+     * @return The raw {@code mpint} bytes of the DH 'f' value - {@code null} if not set yet
+     */
+    protected byte[] getF() {
+        return this.f;
+    }
+
+    /**
+     * @return The DH 'f' value decoded from {@link #getF()}. It is decoded on first access and reused
+     *         until the next {@link #setF(byte[])} call. {@code null} if no bytes are available.
+     */
+    protected BigInteger getFValue() {
+        BigInteger decoded = fValue;
+        if (decoded == null) {
+            decoded = BufferUtils.fromMPIntBytes(getF());
+            fValue = decoded;
+        }
+        return decoded;
+    }
+
+    /**
+     * Reads the next {@code mpint} from the buffer and records it as the DH 'f' value.
+     *
+     * @param  buffer The {@link Buffer} to read the encoded value from
+     * @return        The raw bytes that were read
+     */
+    protected byte[] updateF(Buffer buffer) {
+        byte[] mpInt = buffer.getMPIntAsBytes();
+        return updateF(mpInt);
+    }
+
+    /**
+     * Records the given bytes as the DH 'f' value via {@link #setF(byte[])}.
+     *
+     * @param  mpInt The raw {@code mpint} bytes
+     * @return       The same bytes that were passed in
+     */
+    protected byte[] updateF(byte[] mpInt) {
+        setF(mpInt);
+        return mpInt;
+    }
+
+    /**
+     * Replaces the raw DH 'f' bytes and discards any previously decoded value, so that
+     * {@link #getFValue()} decodes the new bytes on its next call.
+     *
+     * @param f The raw {@code mpint} bytes
+     */
+    protected void setF(byte[] f) {
+        this.f = f;
+        this.fValue = null; // force lazy re-initialization
+    }
+
+    /**
+     * Verifies that the current DH 'f' value satisfies {@code 1 < f < p-1} - see
+     * <A HREF="https://tools.ietf.org/html/rfc8268#section-4">RFC 8268 - section 4</A>.
+     *
+     * @param  pValue               The group prime - a {@code null} prime causes the value to be rejected
+     * @throws NullPointerException If no 'f' value has been set
+     * @throws SshException         If the value is out of range, raised with
+     *                              {@code SSH2_DISCONNECT_KEY_EXCHANGE_FAILED}
+     */
+    protected void validateFValue(BigInteger pValue) throws Exception {
+        BigInteger value = Objects.requireNonNull(getFValue(), "No DH 'f' value set");
+        if (!KeyExchange.isValidDHValue(value, pValue)) {
+            throw new SshException(
+                    SshConstants.SSH2_DISCONNECT_KEY_EXCHANGE_FAILED,
+                    "Protocol error: invalid DH 'f' value");
+        }
+    }
+
     @Override
     public String toString() {
         return getClass().getSimpleName() + "[" + getName() + "]";
diff --git a/sshd-core/src/main/java/org/apache/sshd/server/kex/DHGEXServer.java b/sshd-core/src/main/java/org/apache/sshd/server/kex/DHGEXServer.java
index f6f2c18..66fbf4a 100644
--- a/sshd-core/src/main/java/org/apache/sshd/server/kex/DHGEXServer.java
+++ b/sshd-core/src/main/java/org/apache/sshd/server/kex/DHGEXServer.java
@@ -126,7 +126,12 @@ public class DHGEXServer extends AbstractDHServerKeyExchange {
             }
 
             dh = chooseDH(min, prf, max);
-            f = dh.getE();
+
+            setF(dh.getE());
+
+            BigInteger pValue = dh.getP();
+            validateFValue(pValue);
+
             hash = dh.getHash();
             hash.init();
 
@@ -136,7 +141,7 @@ public class DHGEXServer extends AbstractDHServerKeyExchange {
             }
 
             buffer = 
session.createBuffer(SshConstants.SSH_MSG_KEX_DH_GEX_GROUP);
-            buffer.putMPInt(dh.getP());
+            buffer.putMPInt(pValue);
             buffer.putMPInt(dh.getG());
             session.writePacket(buffer);
 
@@ -157,7 +162,12 @@ public class DHGEXServer extends AbstractDHServerKeyExchange {
             }
 
             dh = chooseDH(min, prf, max);
-            f = dh.getE();
+
+            setF(dh.getE());
+
+            BigInteger pValue = dh.getP();
+            validateFValue(pValue);
+
             hash = dh.getHash();
             hash.init();
 
@@ -166,7 +176,7 @@ public class DHGEXServer extends AbstractDHServerKeyExchange {
                         this, session, min, prf, max);
             }
             buffer = 
session.createBuffer(SshConstants.SSH_MSG_KEX_DH_GEX_GROUP);
-            buffer.putMPInt(dh.getP());
+            buffer.putMPInt(pValue);
             buffer.putMPInt(dh.getG());
             session.writePacket(buffer);
 
@@ -182,11 +192,14 @@ public class DHGEXServer extends AbstractDHServerKeyExchange {
         }
 
         if (cmd == SshConstants.SSH_MSG_KEX_DH_GEX_INIT) {
-            e = buffer.getMPIntAsBytes();
+            byte[] e = updateE(buffer.getMPIntAsBytes());
+            BigInteger pValue = dh.getP();
+            validateEValue(pValue);
+
             dh.setF(e);
+
             k = dh.getK();
 
-            byte[] k_s;
             KeyPair kp = Objects.requireNonNull(session.getHostKey(), "No 
server key pair available");
             String algo = 
session.getNegotiatedKexParameter(KexProposalOption.SERVERKEYS);
             Signature sig = ValidateUtils.checkNotNull(
@@ -196,7 +209,8 @@ public class DHGEXServer extends AbstractDHServerKeyExchange {
 
             buffer = new ByteArrayBuffer();
             buffer.putRawPublicKey(kp.getPublic());
-            k_s = buffer.getCompactData();
+
+            byte[] k_s = buffer.getCompactData();
 
             buffer.clear();
             buffer.putBytes(v_c);
@@ -213,11 +227,13 @@ public class DHGEXServer extends AbstractDHServerKeyExchange {
                 buffer.putInt(max);
             }
 
-            buffer.putMPInt(dh.getP());
+            buffer.putMPInt(pValue);
             buffer.putMPInt(dh.getG());
             buffer.putMPInt(e);
+            byte[] f = getF();
             buffer.putMPInt(f);
             buffer.putMPInt(k);
+
             hash.update(buffer.array(), 0, buffer.available());
             h = hash.digest();
             sig.update(session, h);
diff --git a/sshd-core/src/main/java/org/apache/sshd/server/kex/DHGServer.java b/sshd-core/src/main/java/org/apache/sshd/server/kex/DHGServer.java
index d911f2f..6ac9337 100644
--- a/sshd-core/src/main/java/org/apache/sshd/server/kex/DHGServer.java
+++ b/sshd-core/src/main/java/org/apache/sshd/server/kex/DHGServer.java
@@ -82,7 +82,7 @@ public class DHGServer extends AbstractDHServerKeyExchange {
         dh = factory.create();
         hash = dh.getHash();
         hash.init();
-        f = dh.getE();
+        setF(dh.getE());
     }
 
     @Override
@@ -99,7 +99,7 @@ public class DHGServer extends AbstractDHServerKeyExchange {
                     "Protocol error: expected packet SSH_MSG_KEXDH_INIT, got " 
+ KeyExchange.getSimpleKexOpcodeName(cmd));
         }
 
-        e = buffer.getMPIntAsBytes();
+        byte[] e = updateE(buffer);
         dh.setF(e);
         k = dh.getK();
 
@@ -122,8 +122,10 @@ public class DHGServer extends AbstractDHServerKeyExchange {
         buffer.putBytes(i_s);
         buffer.putBytes(k_s);
         buffer.putMPInt(e);
+        byte[] f = getF();
         buffer.putMPInt(f);
         buffer.putMPInt(k);
+
         hash.update(buffer.array(), 0, buffer.available());
         h = hash.digest();
         sig.update(session, h);

Reply via email to