This is an automated email from the ASF dual-hosted git repository.

twolf pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/mina-sshd.git

commit d53fea9772f1be99ed944ec9e66afc3d9101bf4a
Author: Thomas Wolf <tw...@apache.org>
AuthorDate: Mon Jul 29 19:49:44 2024 +0200

    GH-524: More ChaCha20-Poly1305 optimizations
    
    In the ChaChaEngine, exploit peculiarities of its use in SSH: the
    block counter effectively fits in 32 bits and never overflows, and the
    nonce is the SSH packet sequence number, which is also 32 bits and
    wraps around on overflow. As a result, two of the ints of the engine
    state are always zero, and the handling of nonce and counter can be
    simplified slightly.
    
    In Poly1305Mac, inline the long multiplications. Avoid widening the
    precomputed values (r, s, k) from int to long on every use; store them
    as longs up front. For h, reduce the number of widening conversions
    from 25 to 5 by converting once before the multiplications.
    
    As a side effect, this part of the code is also easier to read.
---
 .../apache/sshd/common/cipher/ChaCha20Cipher.java  | 56 ++++++++++--------
 .../org/apache/sshd/common/mac/Poly1305Mac.java    | 68 +++++++++++-----------
 .../sshd/common/cipher/BuiltinCiphersTest.java     |  7 ++-
 3 files changed, 71 insertions(+), 60 deletions(-)

diff --git 
a/sshd-common/src/main/java/org/apache/sshd/common/cipher/ChaCha20Cipher.java 
b/sshd-common/src/main/java/org/apache/sshd/common/cipher/ChaCha20Cipher.java
index 71eec6e3f..61286a848 100644
--- 
a/sshd-common/src/main/java/org/apache/sshd/common/cipher/ChaCha20Cipher.java
+++ 
b/sshd-common/src/main/java/org/apache/sshd/common/cipher/ChaCha20Cipher.java
@@ -26,7 +26,6 @@ import javax.crypto.AEADBadTagException;
 
 import org.apache.sshd.common.mac.Mac;
 import org.apache.sshd.common.mac.Poly1305Mac;
-import org.apache.sshd.common.util.NumberUtils;
 import org.apache.sshd.common.util.ValidateUtils;
 import org.apache.sshd.common.util.buffer.BufferUtils;
 
@@ -38,11 +37,11 @@ import org.apache.sshd.common.util.buffer.BufferUtils;
 public class ChaCha20Cipher implements Cipher {
     protected final ChaChaEngine headerEngine = new ChaChaEngine();
     protected final ChaChaEngine bodyEngine = new ChaChaEngine();
-    protected final Mac mac = new Poly1305Mac();
+    protected final Mac mac;
     protected Mode mode;
 
     public ChaCha20Cipher() {
-        // empty
+        this.mac = new Poly1305Mac();
     }
 
     @Override
@@ -131,7 +130,7 @@ public class ChaCha20Cipher implements Cipher {
 
     @Override
     public int getKeySize() {
-        return 256;
+        return 512;
     }
 
     protected static class ChaChaEngine {
@@ -142,15 +141,24 @@ public class ChaCha20Cipher implements Cipher {
         private static final int KEY_INTS = KEY_BYTES / Integer.BYTES;
         private static final int COUNTER_OFFSET = 12;
         private static final int NONCE_OFFSET = 14;
-        private static final int NONCE_BYTES = 8;
-        private static final int NONCE_INTS = NONCE_BYTES / Integer.BYTES;
         private static final int[] ENGINE_STATE_HEADER = unpackSigmaString(
                 "expand 32-byte k".getBytes(StandardCharsets.US_ASCII));
 
         protected final int[] engineState = new int[BLOCK_INTS];
         protected final byte[] keyStream = new byte[BLOCK_BYTES];
-        protected final byte[] nonce = new byte[NONCE_BYTES];
+        protected final byte[] nonce = new byte[Integer.BYTES];
         protected long initialNonce;
+        protected long nonceVal;
+
+        // Elements 12 to 15 in the engineState are the counter and the nonce. 
The counter is a 64bit little-
+        // endian value; the nonce is a 64bit big-endian value.
+        //
+        // The counter always starts at zero, is incremented with each full 
block (64 bytes), and in SSH never
+        // overflows 32bits because it counts only inside a single SSH packet. 
The nonce in SSH is the packet
+        // sequence number, which is a 32bit unsigned int that wraps around on 
overflow.
+        //
+        // Therefore, engineState[13] and engineState[14] are always zero. 
engineState[12] is the counter, and
+        // engineState[15] is the packet sequence number in inverse byte order.
 
         protected ChaChaEngine() {
             System.arraycopy(ENGINE_STATE_HEADER, 0, engineState, 0, 4);
@@ -161,21 +169,25 @@ public class ChaCha20Cipher implements Cipher {
         }
 
         protected void initNonce(byte[] nonce) {
-            initialNonce = BufferUtils.getLong(nonce, 0, 
NumberUtils.length(nonce));
-            unpackIntsLE(nonce, 0, NONCE_INTS, engineState, NONCE_OFFSET);
-            System.arraycopy(nonce, 0, this.nonce, 0, NONCE_BYTES);
+            long hiBits = BufferUtils.getUInt(nonce, 0, Integer.BYTES);
+            ValidateUtils.checkState(hiBits == 0, "ChaCha20 nonce is not a 
valid SSH packet sequence number");
+            initialNonce = BufferUtils.getUInt(nonce, Integer.BYTES, 
Integer.BYTES);
+            nonceVal = initialNonce;
+            engineState[NONCE_OFFSET] = 0;
+            engineState[NONCE_OFFSET + 1] = Poly1305Mac.unpackIntLE(nonce, 
Integer.BYTES);
         }
 
         protected void advanceNonce() {
-            long counter = BufferUtils.getLong(nonce, 0, NONCE_BYTES) + 1;
-            ValidateUtils.checkState(counter != initialNonce, "Packet sequence 
number cannot be reused with the same key");
-            BufferUtils.putLong(counter, nonce, 0, NONCE_BYTES);
-            unpackIntsLE(nonce, 0, NONCE_INTS, engineState, NONCE_OFFSET);
+            // SSH packet sequence number wraps around on uint32 overflow.
+            nonceVal = (nonceVal + 1) & 0xFFFF_FFFFL;
+            ValidateUtils.checkState(nonceVal != initialNonce, "Packet 
sequence number cannot be reused with the same key");
+            BufferUtils.putUInt(nonceVal, nonce, 0, Integer.BYTES);
+            engineState[NONCE_OFFSET + 1] = Poly1305Mac.unpackIntLE(nonce, 0);
         }
 
         protected void initCounter(long counter) {
             engineState[COUNTER_OFFSET] = (int) counter;
-            engineState[COUNTER_OFFSET + 1] = (int) (counter >>> Integer.SIZE);
+            engineState[COUNTER_OFFSET + 1] = 0; // Always zero; and counter 
never overflows in SSH.
         }
 
         // one-shot usage
@@ -187,11 +199,7 @@ public class ChaCha20Cipher implements Cipher {
                     out[outOffset++] = (byte) (in[offset++] ^ keyStream[i]);
                 }
                 length -= want;
-                int lo = ++engineState[COUNTER_OFFSET];
-                if (lo == 0) {
-                    // overflow
-                    ++engineState[COUNTER_OFFSET + 1];
-                }
+                ++engineState[COUNTER_OFFSET]; // Never overflows in SSH
             }
         }
 
@@ -216,10 +224,10 @@ public class ChaCha20Cipher implements Cipher {
             int x9 = engine[9];
             int x10 = engine[10];
             int x11 = engine[11];
-            int x12 = engine[12];
-            int x13 = engine[13];
-            int x14 = engine[14];
-            int x15 = engine[15];
+            int x12 = engine[12]; // counter
+            int x13 = engine[13]; // 0
+            int x14 = engine[14]; // 0
+            int x15 = engine[15]; // nonce
 
             for (int i = 0; i < 10; i++) {
                 // Columns
diff --git 
a/sshd-common/src/main/java/org/apache/sshd/common/mac/Poly1305Mac.java 
b/sshd-common/src/main/java/org/apache/sshd/common/mac/Poly1305Mac.java
index 4fa809b2f..554d619e3 100644
--- a/sshd-common/src/main/java/org/apache/sshd/common/mac/Poly1305Mac.java
+++ b/sshd-common/src/main/java/org/apache/sshd/common/mac/Poly1305Mac.java
@@ -36,19 +36,19 @@ public class Poly1305Mac implements Mac {
     public static final int KEY_BYTES = 32;
     private static final int BLOCK_SIZE = 16;
 
-    private int r0;
-    private int r1;
-    private int r2;
-    private int r3;
-    private int r4;
-    private int s1;
-    private int s2;
-    private int s3;
-    private int s4;
-    private int k0;
-    private int k1;
-    private int k2;
-    private int k3;
+    private long r0;
+    private long r1;
+    private long r2;
+    private long r3;
+    private long r4;
+    private long s1;
+    private long s2;
+    private long s3;
+    private long s4;
+    private long k0;
+    private long k1;
+    private long k2;
+    private long k3;
 
     private int h0;
     private int h1;
@@ -91,10 +91,10 @@ public class Poly1305Mac implements Mac {
         s3 = r3 * 5;
         s4 = r4 * 5;
 
-        k0 = unpackIntLE(key, 16);
-        k1 = unpackIntLE(key, 20);
-        k2 = unpackIntLE(key, 24);
-        k3 = unpackIntLE(key, 28);
+        k0 = unpackIntLE(key, 16) & 0xFFFF_FFFFL;
+        k1 = unpackIntLE(key, 20) & 0xFFFF_FFFFL;
+        k2 = unpackIntLE(key, 24) & 0xFFFF_FFFFL;
+        k3 = unpackIntLE(key, 28) & 0xFFFF_FFFFL;
 
         currentBlockOffset = 0;
     }
@@ -186,10 +186,10 @@ public class Poly1305Mac implements Mac {
         h3 = h3 & nb | g3 & b;
         h4 = h4 & nb | g4 & b;
 
-        long f0 = ((h0 | h1 << 26) & 0xFFFF_FFFFL) + (k0 & 0xFFFF_FFFFL);
-        long f1 = ((h1 >>> 6 | h2 << 20) & 0xFFFF_FFFFL) + (k1 & 0xFFFF_FFFFL);
-        long f2 = ((h2 >>> 12 | h3 << 14) & 0xFFFF_FFFFL) + (k2 & 
0xFFFF_FFFFL);
-        long f3 = ((h3 >>> 18 | h4 << 8) & 0xFFFF_FFFFL) + (k3 & 0xFFFF_FFFFL);
+        long f0 = ((h0 | h1 << 26) & 0xFFFF_FFFFL) + k0;
+        long f1 = ((h1 >>> 6 | h2 << 20) & 0xFFFF_FFFFL) + k1;
+        long f2 = ((h2 >>> 12 | h3 << 14) & 0xFFFF_FFFFL) + k2;
+        long f3 = ((h3 >>> 18 | h4 << 8) & 0xFFFF_FFFFL) + k3;
 
         packIntLE((int) f0, out, offset);
         f1 += f0 >>> 32;
@@ -219,16 +219,18 @@ public class Poly1305Mac implements Mac {
             h4 += 1 << 24;
         }
 
-        long tp0 = unsignedProduct(h0, r0) + unsignedProduct(h1, s4) + 
unsignedProduct(h2, s3) + unsignedProduct(h3, s2)
-                   + unsignedProduct(h4, s1);
-        long tp1 = unsignedProduct(h0, r1) + unsignedProduct(h1, r0) + 
unsignedProduct(h2, s4) + unsignedProduct(h3, s3)
-                   + unsignedProduct(h4, s2);
-        long tp2 = unsignedProduct(h0, r2) + unsignedProduct(h1, r1) + 
unsignedProduct(h2, r0) + unsignedProduct(h3, s4)
-                   + unsignedProduct(h4, s3);
-        long tp3 = unsignedProduct(h0, r3) + unsignedProduct(h1, r2) + 
unsignedProduct(h2, r1) + unsignedProduct(h3, r0)
-                   + unsignedProduct(h4, s4);
-        long tp4 = unsignedProduct(h0, r4) + unsignedProduct(h1, r3) + 
unsignedProduct(h2, r2) + unsignedProduct(h3, r1)
-                   + unsignedProduct(h4, r0);
+        // The high bits of h0 to h4 are guaranteed to be zero, so we can just 
let the compiler extend the ints.
+        // No need to do a & 0xFFFF_FFFFL.
+        long l0 = h0;
+        long l1 = h1;
+        long l2 = h2;
+        long l3 = h3;
+        long l4 = h4;
+        long tp0 = l0 * r0 + l1 * s4 + l2 * s3 + l3 * s2 + l4 * s1;
+        long tp1 = l0 * r1 + l1 * r0 + l2 * s4 + l3 * s3 + l4 * s2;
+        long tp2 = l0 * r2 + l1 * r1 + l2 * r0 + l3 * s4 + l4 * s3;
+        long tp3 = l0 * r3 + l1 * r2 + l2 * r1 + l3 * r0 + l4 * s4;
+        long tp4 = l0 * r4 + l1 * r3 + l2 * r2 + l3 * r1 + l4 * r0;
 
         h0 = (int) tp0 & 0x3ffffff;
         tp1 += tp0 >>> 26;
@@ -278,8 +280,4 @@ public class Poly1305Mac implements Mac {
         dst[off++] = (byte) (value >> 16);
         dst[off] = (byte) (value >> 24);
     }
-
-    private static long unsignedProduct(int i1, int i2) {
-        return (i1 & 0xFFFF_FFFFL) * i2;
-    }
 }
diff --git 
a/sshd-core/src/test/java/org/apache/sshd/common/cipher/BuiltinCiphersTest.java 
b/sshd-core/src/test/java/org/apache/sshd/common/cipher/BuiltinCiphersTest.java
index b23213f4e..e4448c057 100644
--- 
a/sshd-core/src/test/java/org/apache/sshd/common/cipher/BuiltinCiphersTest.java
+++ 
b/sshd-core/src/test/java/org/apache/sshd/common/cipher/BuiltinCiphersTest.java
@@ -163,7 +163,12 @@ public class BuiltinCiphersTest extends BaseTestSupport {
         byte[] key = new byte[cipher.getKdfSize()];
         rnd.nextBytes(key);
         byte[] iv = new byte[cipher.getIVSize()];
-        rnd.nextBytes(iv);
+        // ChaCha20 has an SSH packet sequence number as IV! Do not use random 
IVs with ChaCha20!
+        if (cipher.getAlgorithm().startsWith("ChaCha20")) {
+            iv[iv.length - 1] = 42;
+        } else {
+            rnd.nextBytes(iv);
+        }
         cipher.init(Cipher.Mode.Encrypt, key, iv);
 
         byte[] data = new byte[cipher.getCipherBlockSize() + 
cipher.getAuthenticationTagSize()];

Reply via email to