This is an automated email from the ASF dual-hosted git repository.

technoboy pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/pulsar.git


The following commit(s) were added to refs/heads/branch-4.1 by this push:
     new e11ff901679 [fix][client] Fix thread-safety and refactor MessageCryptoBc key management (#25400)
e11ff901679 is described below

commit e11ff901679a72f8efc15b27cc11389086fc71e9
Author: Lari Hotari <[email protected]>
AuthorDate: Thu Mar 26 14:22:48 2026 +0200

    [fix][client] Fix thread-safety and refactor MessageCryptoBc key management (#25400)
---
 distribution/shell/src/assemble/LICENSE.bin.txt    |   1 +
 .../anonymizer/DefaultRoleAnonymizerType.java      |  50 ++--
 .../websocket/proxy/WssClientSideEncryptUtils.java |   2 +-
 pulsar-client-messagecrypto-bc/pom.xml             |   5 +
 .../pulsar/client/impl/crypto/MessageCryptoBc.java | 321 +++++++++++----------
 .../apache/pulsar/common/util/SecurityUtility.java |  23 +-
 6 files changed, 205 insertions(+), 197 deletions(-)

diff --git a/distribution/shell/src/assemble/LICENSE.bin.txt 
b/distribution/shell/src/assemble/LICENSE.bin.txt
index 83765a3c3bd..c4bfdf8a3de 100644
--- a/distribution/shell/src/assemble/LICENSE.bin.txt
+++ b/distribution/shell/src/assemble/LICENSE.bin.txt
@@ -324,6 +324,7 @@ The Apache Software License, Version 2.0
      - jackson-datatype-jdk8-2.18.6.jar
      - jackson-datatype-jsr310-2.18.6.jar
      - jackson-module-parameter-names-2.18.6.jar
+ * Caffeine -- caffeine-3.2.3.jar
  * Conscrypt -- conscrypt-openjdk-uber-2.5.2.jar
  * Gson
     - gson-2.13.2.jar
diff --git 
a/pulsar-broker-common/src/main/java/org/apache/pulsar/common/configuration/anonymizer/DefaultRoleAnonymizerType.java
 
b/pulsar-broker-common/src/main/java/org/apache/pulsar/common/configuration/anonymizer/DefaultRoleAnonymizerType.java
index 3333b69cf22..d30769aa2f1 100644
--- 
a/pulsar-broker-common/src/main/java/org/apache/pulsar/common/configuration/anonymizer/DefaultRoleAnonymizerType.java
+++ 
b/pulsar-broker-common/src/main/java/org/apache/pulsar/common/configuration/anonymizer/DefaultRoleAnonymizerType.java
@@ -18,8 +18,8 @@
  */
 package org.apache.pulsar.common.configuration.anonymizer;
 
+import io.netty.util.concurrent.FastThreadLocal;
 import java.security.MessageDigest;
-import java.security.NoSuchAlgorithmException;
 import java.util.Base64;
 
 public enum DefaultRoleAnonymizerType {
@@ -37,44 +37,44 @@ public enum DefaultRoleAnonymizerType {
    },
    SHA256 {
       private static final String PREFIX = "SHA-256:";
-      private final MessageDigest digest;
-
-      {
-         // Initializing the MessageDigest once for SHA-256
-         try {
-            digest = MessageDigest.getInstance("SHA-256");
-         } catch (NoSuchAlgorithmException e) {
-            throw new RuntimeException("SHA-256 algorithm not found", e);
+      private static final FastThreadLocal<MessageDigest> DIGEST = new 
FastThreadLocal<MessageDigest>() {
+         @Override
+         protected MessageDigest initialValue() throws Exception {
+            return MessageDigest.getInstance("SHA-256");
          }
-      }
+      };
 
       @Override
       public String anonymize(String role) {
-         byte[] hash = digest.digest(role.getBytes());
-         return PREFIX + Base64.getEncoder().encodeToString(hash);
+         try {
+            byte[] hash = DIGEST.get().digest(role.getBytes());
+            return PREFIX + Base64.getEncoder().encodeToString(hash);
+         } catch (Exception e) {
+            throw new RuntimeException("SHA-256 algorithm not found", e);
+         }
       }
    },
    MD5 {
       private static final String PREFIX = "MD5:";
-      private final MessageDigest digest;
-
-      {
-         // Initializing the MessageDigest once for MD5
-         try {
-            // codeql[java/weak-cryptographic-algorithm] - md5 is sufficient 
for this use case&
-            digest = MessageDigest.getInstance("MD5");
-         } catch (NoSuchAlgorithmException e) {
-            throw new RuntimeException("MD5 algorithm not found", e);
+      private static final FastThreadLocal<MessageDigest> DIGEST = new 
FastThreadLocal<MessageDigest>() {
+         @Override
+         protected MessageDigest initialValue() throws Exception {
+            // codeql[java/weak-cryptographic-algorithm] - md5 is sufficient 
for this use case
+            return MessageDigest.getInstance("MD5");
          }
-      }
+      };
 
       @Override
       public String anonymize(String role) {
-         byte[] hash = digest.digest(role.getBytes());
-         return PREFIX + Base64.getEncoder().encodeToString(hash);
+         try {
+            byte[] hash = DIGEST.get().digest(role.getBytes());
+            return PREFIX + Base64.getEncoder().encodeToString(hash);
+         } catch (Exception e) {
+            throw new RuntimeException("MD5 algorithm not found", e);
+         }
       }
    };
 
    private static final String REDACTED_VALUE = "[REDACTED]";
    public abstract String anonymize(String role);
-}
\ No newline at end of file
+}
diff --git 
a/pulsar-broker/src/test/java/org/apache/pulsar/websocket/proxy/WssClientSideEncryptUtils.java
 
b/pulsar-broker/src/test/java/org/apache/pulsar/websocket/proxy/WssClientSideEncryptUtils.java
index 1fb0fd2fd50..7e05eeb7d66 100644
--- 
a/pulsar-broker/src/test/java/org/apache/pulsar/websocket/proxy/WssClientSideEncryptUtils.java
+++ 
b/pulsar-broker/src/test/java/org/apache/pulsar/websocket/proxy/WssClientSideEncryptUtils.java
@@ -131,7 +131,7 @@ public class WssClientSideEncryptUtils {
         try {
             PublicKey pubKey = MessageCryptoBc.loadPublicKey(publicKeyData);
             Cipher dataKeyCipher = loadAndInitCipher(pubKey);
-            return dataKeyCipher.doFinal(msgCrypto.getDataKey().getEncoded());
+            return 
dataKeyCipher.doFinal(msgCrypto.getEncryptionKey().getEncoded());
         } catch (Exception e) {
             log.error("Failed to encrypt data key. {}", e.getMessage());
             throw new PulsarClientException.CryptoException(e.getMessage());
diff --git a/pulsar-client-messagecrypto-bc/pom.xml 
b/pulsar-client-messagecrypto-bc/pom.xml
index 691cebad686..1475103b4ea 100644
--- a/pulsar-client-messagecrypto-bc/pom.xml
+++ b/pulsar-client-messagecrypto-bc/pom.xml
@@ -47,6 +47,11 @@
       </exclusions>
     </dependency>
 
+    <dependency>
+      <groupId>com.github.ben-manes.caffeine</groupId>
+      <artifactId>caffeine</artifactId>
+    </dependency>
+
     <dependency>
       <groupId>${project.groupId}</groupId>
       <artifactId>bouncy-castle-bc</artifactId>
diff --git 
a/pulsar-client-messagecrypto-bc/src/main/java/org/apache/pulsar/client/impl/crypto/MessageCryptoBc.java
 
b/pulsar-client-messagecrypto-bc/src/main/java/org/apache/pulsar/client/impl/crypto/MessageCryptoBc.java
index 62cc00cd61e..ab071ae6558 100644
--- 
a/pulsar-client-messagecrypto-bc/src/main/java/org/apache/pulsar/client/impl/crypto/MessageCryptoBc.java
+++ 
b/pulsar-client-messagecrypto-bc/src/main/java/org/apache/pulsar/client/impl/crypto/MessageCryptoBc.java
@@ -18,9 +18,9 @@
  */
 package org.apache.pulsar.client.impl.crypto;
 
-import com.google.common.cache.CacheBuilder;
-import com.google.common.cache.CacheLoader;
-import com.google.common.cache.LoadingCache;
+import com.github.benmanes.caffeine.cache.Cache;
+import com.github.benmanes.caffeine.cache.Caffeine;
+import io.netty.util.concurrent.FastThreadLocal;
 import java.io.IOException;
 import java.io.Reader;
 import java.io.StringReader;
@@ -28,15 +28,16 @@ import java.nio.ByteBuffer;
 import java.security.InvalidAlgorithmParameterException;
 import java.security.InvalidKeyException;
 import java.security.KeyFactory;
-import java.security.MessageDigest;
 import java.security.NoSuchAlgorithmException;
 import java.security.NoSuchProviderException;
 import java.security.PrivateKey;
 import java.security.PublicKey;
 import java.security.SecureRandom;
 import java.security.Security;
+import java.security.interfaces.ECPrivateKey;
 import java.security.spec.AlgorithmParameterSpec;
 import java.security.spec.InvalidKeySpecException;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
@@ -69,8 +70,7 @@ import org.bouncycastle.asn1.x509.SubjectPublicKeyInfo;
 import org.bouncycastle.asn1.x9.ECNamedCurveTable;
 import org.bouncycastle.asn1.x9.X9ECParameters;
 import org.bouncycastle.cert.X509CertificateHolder;
-import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPrivateKey;
-import org.bouncycastle.jcajce.provider.asymmetric.ec.BCECPublicKey;
+import org.bouncycastle.jce.interfaces.ECPublicKey;
 import org.bouncycastle.jce.provider.BouncyCastleProvider;
 import org.bouncycastle.jce.spec.ECParameterSpec;
 import org.bouncycastle.jce.spec.ECPrivateKeySpec;
@@ -83,7 +83,6 @@ import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter;
 
 @Slf4j
 public class MessageCryptoBc implements MessageCrypto<MessageMetadata, 
MessageMetadata> {
-
     public static final String ECDSA = "ECDSA";
     public static final String RSA = "RSA";
     public static final String ECIES = "ECIES";
@@ -94,31 +93,26 @@ public class MessageCryptoBc implements 
MessageCrypto<MessageMetadata, MessageMe
     public static final String AESGCM = "AES/GCM/NoPadding";
     private static final String AESGCM_PROVIDER_NAME;
 
-    private static KeyGenerator keyGenerator;
     private static final int tagLen = 16 * 8;
-    private byte[] iv = new byte[IV_LEN];
-    private Cipher cipher;
-    MessageDigest digest;
-    private String logCtx;
+    private final String logCtx;
 
     // Data key which is used to encrypt message
     @Getter
-    private SecretKey dataKey;
-    private LoadingCache<ByteBuffer, SecretKey> dataKeyCache;
+    private volatile SecretKey encryptionKey;
+    private final Cache<SecretKeyCacheKey, SecretKeySpec> decryptionKeyCache;
+    private final Cache<String, SecretKey> lastDecryptionKeyCache;
 
     // Map of key name and encrypted gcm key, metadata pair which is sent with 
encrypted message
-    private ConcurrentHashMap<String, EncryptionKeyInfo> encryptedDataKeyMap;
-
+    private final ConcurrentHashMap<String, EncryptionKeyInfo> 
encryptedDataKeyMap;
 
     private static final SecureRandom secureRandom;
     static {
-        SecureRandom rand = null;
+        SecureRandom rand;
         try {
             rand = SecureRandom.getInstance("NativePRNGNonBlocking");
         } catch (NoSuchAlgorithmException nsa) {
             rand = new SecureRandom();
         }
-
         secureRandom = rand;
 
         // Initial seed
@@ -139,79 +133,101 @@ public class MessageCryptoBc implements 
MessageCrypto<MessageMetadata, MessageMe
         }
     }
 
-    public MessageCryptoBc(String logCtx, boolean keyGenNeeded) {
-
-        this.logCtx = logCtx;
-        encryptedDataKeyMap = new ConcurrentHashMap<String, 
EncryptionKeyInfo>();
-        dataKeyCache = CacheBuilder.newBuilder().expireAfterAccess(4, 
TimeUnit.HOURS)
-                .build(new CacheLoader<ByteBuffer, SecretKey>() {
-
-                    @Override
-                    public SecretKey load(ByteBuffer key) {
-                        return null;
-                    }
-
-                });
-
-        try {
-
-            cipher = Cipher.getInstance(AESGCM, AESGCM_PROVIDER_NAME);
-            // If keygen is not needed(e.g: consumer), data key will be 
decrypted from the message
-            if (!keyGenNeeded) {
-                // codeql[java/weak-cryptographic-algorithm] - md5 is 
sufficient for this use case
-                digest = MessageDigest.getInstance("MD5");
+    // Thread-local instances for non-thread-safe JCA classes
+    private static final FastThreadLocal<Cipher> THREAD_LOCAL_CIPHER = new 
FastThreadLocal<Cipher>() {
+        @Override
+        protected Cipher initialValue() throws Exception {
+            return Cipher.getInstance(AESGCM, AESGCM_PROVIDER_NAME);
+        }
+    };
 
-                dataKey = null;
-                return;
-            }
-            keyGenerator = KeyGenerator.getInstance("AES");
+    private static final FastThreadLocal<KeyGenerator> 
THREAD_LOCAL_KEY_GENERATOR =
+            new FastThreadLocal<KeyGenerator>() {
+        @Override
+        protected KeyGenerator initialValue() throws Exception {
+            KeyGenerator kg = KeyGenerator.getInstance("AES");
             int aesKeyLength = Cipher.getMaxAllowedKeyLength("AES");
             if (aesKeyLength <= 128) {
-                log.warn("{} AES Cryptographic strength is limited to {} bits. 
"
+                log.warn("AES Cryptographic strength is limited to {} bits. "
                         + "Consider installing JCE Unlimited Strength 
Jurisdiction Policy Files.",
-                        logCtx, aesKeyLength);
-                keyGenerator.init(aesKeyLength, secureRandom);
+                        aesKeyLength);
+                kg.init(aesKeyLength, secureRandom);
             } else {
-                keyGenerator.init(256, secureRandom);
+                kg.init(256, secureRandom);
             }
+            return kg;
+        }
+    };
 
-        } catch (NoSuchAlgorithmException | NoSuchProviderException | 
NoSuchPaddingException e) {
-
-            cipher = null;
-            log.error("{} MessageCrypto initialization Failed {}", logCtx, 
e.getMessage());
-
+    private static Cipher getAesGcmCipher() throws CryptoException {
+        try {
+            return THREAD_LOCAL_CIPHER.get();
+        } catch (Exception e) {
+            log.error("Failed to get AES-GCM cipher instance. {}", 
e.getMessage());
+            throw new PulsarClientException.CryptoException(e.getMessage());
         }
+    }
 
-        // Generate data key to encrypt messages
-        dataKey = keyGenerator.generateKey();
+    private static KeyGenerator getKeyGenerator() throws CryptoException {
+        try {
+            return THREAD_LOCAL_KEY_GENERATOR.get();
+        } catch (Exception e) {
+            log.error("Failed to get AES key generator instance. {}", 
e.getMessage());
+            throw new PulsarClientException.CryptoException(e.getMessage());
+        }
+    }
 
-        iv = new byte[IV_LEN];
+    private static SecretKey generateEncryptionKey() throws CryptoException {
+        return getKeyGenerator().generateKey();
+    }
 
+    public MessageCryptoBc(String logCtx, boolean keyGenNeeded) {
+        this.logCtx = logCtx;
+        encryptedDataKeyMap = new ConcurrentHashMap<String, 
EncryptionKeyInfo>();
+        decryptionKeyCache = Caffeine.newBuilder()
+                .expireAfterAccess(4, TimeUnit.HOURS)
+                .weigher((SecretKeyCacheKey key, SecretKeySpec value) -> 
key.encryptedKeyBytes.length
+                        + value.getEncoded().length)
+                .maximumWeight(10 * 1024 * 1024) // 10MB upperbound
+                .build();
+        lastDecryptionKeyCache = Caffeine.newBuilder()
+                .expireAfterAccess(4, TimeUnit.HOURS)
+                .maximumWeight(10 * 1024 * 1024) // 10MB upperbound
+                .weigher((String key, SecretKey value) -> key.length() + 
value.getEncoded().length)
+                .build();
+        if (keyGenNeeded) {
+            // Generate data key to encrypt messages
+            try {
+                encryptionKey = generateEncryptionKey();
+            } catch (CryptoException e) {
+                // retain same contract as before
+                if (e.getCause() instanceof RuntimeException) {
+                    throw (RuntimeException) e.getCause();
+                } else {
+                    throw new RuntimeException(e.getCause());
+                }
+            }
+        }
     }
 
     public static PublicKey loadPublicKey(byte[] keyBytes) throws Exception {
-
         Reader keyReader = new StringReader(new String(keyBytes));
-        PublicKey publicKey = null;
+        PublicKey publicKey;
         try (PEMParser pemReader = new PEMParser(keyReader)) {
             Object pemObj = pemReader.readObject();
             JcaPEMKeyConverter pemConverter = new JcaPEMKeyConverter();
-            SubjectPublicKeyInfo keyInfo = null;
+            SubjectPublicKeyInfo keyInfo;
             X9ECParameters ecParam = null;
 
             if (pemObj instanceof ASN1ObjectIdentifier) {
-
                 // make sure this is EC Parameter we're handling. In which case
                 // we'll store it and read the next object which should be our
                 // EC Public Key
-
                 ASN1ObjectIdentifier ecOID = (ASN1ObjectIdentifier) pemObj;
                 ecParam = ECNamedCurveTable.getByOID(ecOID);
                 if (ecParam == null) {
-                    throw new PEMException("Unable to find EC Parameter for 
the given curve oid: "
-                            + ((ASN1ObjectIdentifier) pemObj).getId());
+                    throw new PEMException("Unable to find EC Parameter for 
the given curve oid: " + ecOID.getId());
                 }
-
                 pemObj = pemReader.readObject();
             } else if (pemObj instanceof X9ECParameters) {
                 ecParam = (X9ECParameters) pemObj;
@@ -229,7 +245,7 @@ public class MessageCryptoBc implements 
MessageCrypto<MessageMetadata, MessageMe
                 ECParameterSpec ecSpec = new 
ECParameterSpec(ecParam.getCurve(), ecParam.getG(), ecParam.getN(),
                         ecParam.getH(), ecParam.getSeed());
                 KeyFactory keyFactory = KeyFactory.getInstance(ECDSA, 
BouncyCastleProvider.PROVIDER_NAME);
-                ECPublicKeySpec keySpec = new ECPublicKeySpec(((BCECPublicKey) 
publicKey).getQ(), ecSpec);
+                ECPublicKeySpec keySpec = new ECPublicKeySpec(((ECPublicKey) 
publicKey).getQ(), ecSpec);
                 publicKey = keyFactory.generatePublic(keySpec);
             }
         } catch (IOException | NoSuchAlgorithmException | 
NoSuchProviderException | InvalidKeySpecException e) {
@@ -238,8 +254,7 @@ public class MessageCryptoBc implements 
MessageCrypto<MessageMetadata, MessageMe
         return publicKey;
     }
 
-    private PrivateKey loadPrivateKey(byte[] keyBytes) throws Exception {
-
+    private static PrivateKey loadPrivateKey(byte[] keyBytes) throws Exception 
{
         Reader keyReader = new StringReader(new String(keyBytes));
         PrivateKey privateKey = null;
         try (PEMParser pemReader = new PEMParser(keyReader)) {
@@ -248,31 +263,24 @@ public class MessageCryptoBc implements 
MessageCrypto<MessageMetadata, MessageMe
             Object pemObj = pemReader.readObject();
 
             if (pemObj instanceof ASN1ObjectIdentifier) {
-
                 // make sure this is EC Parameter we're handling. In which case
                 // we'll store it and read the next object which should be our
                 // EC Private Key
-
                 ASN1ObjectIdentifier ecOID = (ASN1ObjectIdentifier) pemObj;
                 ecParam = ECNamedCurveTable.getByOID(ecOID);
                 if (ecParam == null) {
                     throw new PEMException("Unable to find EC Parameter for 
the given curve oid: " + ecOID.getId());
                 }
-
                 pemObj = pemReader.readObject();
-
             } else if (pemObj instanceof X9ECParameters) {
-
                 ecParam = (X9ECParameters) pemObj;
                 pemObj = pemReader.readObject();
             }
 
             if (pemObj instanceof PEMKeyPair) {
-
                 PrivateKeyInfo pKeyInfo = ((PEMKeyPair) 
pemObj).getPrivateKeyInfo();
                 JcaPEMKeyConverter pemConverter = new JcaPEMKeyConverter();
                 privateKey = pemConverter.getPrivateKey(pKeyInfo);
-
             } else if (pemObj instanceof PrivateKeyInfo) {
                 JcaPEMKeyConverter pemConverter = new JcaPEMKeyConverter();
                 privateKey = pemConverter.getPrivateKey((PrivateKeyInfo) 
pemObj);
@@ -280,12 +288,11 @@ public class MessageCryptoBc implements 
MessageCrypto<MessageMetadata, MessageMe
 
             // if our private key is EC type and we have parameters specified
             // then we need to set it accordingly
-
             if (ecParam != null && ECDSA.equals(privateKey.getAlgorithm())) {
                 ECParameterSpec ecSpec = new 
ECParameterSpec(ecParam.getCurve(), ecParam.getG(), ecParam.getN(),
                         ecParam.getH(), ecParam.getSeed());
                 KeyFactory keyFactory = KeyFactory.getInstance(ECDSA, 
BouncyCastleProvider.PROVIDER_NAME);
-                ECPrivateKeySpec keySpec = new 
ECPrivateKeySpec(((BCECPrivateKey) privateKey).getS(), ecSpec);
+                ECPrivateKeySpec keySpec = new 
ECPrivateKeySpec(((ECPrivateKey) privateKey).getS(), ecSpec);
                 privateKey = keyFactory.generatePrivate(keySpec);
             }
 
@@ -306,11 +313,9 @@ public class MessageCryptoBc implements 
MessageCrypto<MessageMetadata, MessageMe
      *
      */
     @Override
-    public synchronized void addPublicKeyCipher(Set<String> keyNames, 
CryptoKeyReader keyReader)
-            throws CryptoException {
-
-        // Generate data key
-        dataKey = keyGenerator.generateKey();
+    public void addPublicKeyCipher(Set<String> keyNames, CryptoKeyReader 
keyReader) throws CryptoException {
+        // Rotate the encryption key each time this method is called
+        encryptionKey = generateEncryptionKey();
 
         for (String key : keyNames) {
             addPublicKeyCipher(key, keyReader);
@@ -318,7 +323,6 @@ public class MessageCryptoBc implements 
MessageCrypto<MessageMetadata, MessageMe
     }
 
     private void addPublicKeyCipher(String keyName, CryptoKeyReader keyReader) 
throws CryptoException {
-
         if (keyName == null || keyReader == null) {
             throw new PulsarClientException.CryptoException("Keyname or 
KeyReader is null");
         }
@@ -336,9 +340,8 @@ public class MessageCryptoBc implements 
MessageCrypto<MessageMetadata, MessageMe
             throw new PulsarClientException.CryptoException(msg);
         }
 
-        Cipher dataKeyCipher = null;
+        Cipher dataKeyCipher;
         byte[] encryptedKey;
-
         try {
             AlgorithmParameterSpec params = null;
             // Encrypt data key using public key
@@ -357,8 +360,7 @@ public class MessageCryptoBc implements 
MessageCrypto<MessageMetadata, MessageMe
             } else {
                 dataKeyCipher.init(Cipher.ENCRYPT_MODE, pubKey);
             }
-            encryptedKey = dataKeyCipher.doFinal(dataKey.getEncoded());
-
+            encryptedKey = dataKeyCipher.doFinal(encryptionKey.getEncoded());
         } catch (IllegalBlockSizeException | BadPaddingException | 
NoSuchAlgorithmException | NoSuchProviderException
                  | NoSuchPaddingException | InvalidKeyException | 
InvalidAlgorithmParameterException e) {
             log.error("{} Failed to encrypt data key {}. {}", logCtx, keyName, 
e.getMessage());
@@ -384,7 +386,6 @@ public class MessageCryptoBc implements 
MessageCrypto<MessageMetadata, MessageMe
      */
     @Override
     public boolean removeKeyCipher(String keyName) {
-
         if (keyName == null) {
             return false;
         }
@@ -404,10 +405,9 @@ public class MessageCryptoBc implements 
MessageCrypto<MessageMetadata, MessageMe
      * @return encryptedData if success
      */
     @Override
-    public synchronized void encrypt(Set<String> encKeys, CryptoKeyReader 
keyReader,
-                                        Supplier<MessageMetadata> 
messageMetadataBuilderSupplier,
-                                     ByteBuffer payload, ByteBuffer outBuffer) 
throws PulsarClientException {
-
+    public void encrypt(Set<String> encKeys, CryptoKeyReader keyReader,
+                        Supplier<MessageMetadata> 
messageMetadataBuilderSupplier,
+                        ByteBuffer payload, ByteBuffer outBuffer) throws 
PulsarClientException {
         MessageMetadata msgMetadata = messageMetadataBuilderSupplier.get();
 
         if (encKeys.isEmpty()) {
@@ -444,11 +444,11 @@ public class MessageCryptoBc implements 
MessageCrypto<MessageMetadata, MessageMe
                 // We should never reach here.
                 log.error("{} Failed to find encrypted Data key for key {}.", 
logCtx, keyName);
             }
-
         }
 
         // Create gcm param
         // TODO: Replace random with counter and periodic refreshing based on 
timer/counter value
+        byte[] iv = new byte[IV_LEN];
         secureRandom.nextBytes(iv);
         GCMParameterSpec gcmParam = new GCMParameterSpec(tagLen, iv);
 
@@ -457,7 +457,8 @@ public class MessageCryptoBc implements 
MessageCrypto<MessageMetadata, MessageMe
 
         try {
             // Encrypt the data
-            cipher.init(Cipher.ENCRYPT_MODE, dataKey, gcmParam);
+            Cipher cipher = getAesGcmCipher();
+            cipher.init(Cipher.ENCRYPT_MODE, encryptionKey, gcmParam);
 
             int maxLength = cipher.getOutputSize(payload.remaining());
             if (outBuffer.remaining() < maxLength) {
@@ -474,9 +475,8 @@ public class MessageCryptoBc implements 
MessageCrypto<MessageMetadata, MessageMe
         }
     }
 
-    private boolean decryptDataKey(String keyName, byte[] encryptedDataKey, 
List<KeyValue> encKeyMeta,
+    private SecretKeySpec tryDecryptDataKey(String keyName, byte[] 
encryptedDataKey, List<KeyValue> encKeyMeta,
             CryptoKeyReader keyReader) {
-
         Map<String, String> keyMeta = new HashMap<String, String>();
         encKeyMeta.forEach(kv -> {
             keyMeta.put(kv.getKey(), kv.getValue());
@@ -491,20 +491,17 @@ public class MessageCryptoBc implements 
MessageCrypto<MessageMetadata, MessageMe
             privateKey = loadPrivateKey(keyInfo.getKey());
             if (privateKey == null) {
                 log.error("{} Failed to load private key {}.", logCtx, 
keyName);
-                return false;
+                return null;
             }
         } catch (Exception e) {
             log.error("{} Failed to decrypt data key {} to decrypt messages 
{}", logCtx, keyName, e.getMessage());
-            return false;
+            return null;
         }
 
         // Decrypt data key to decrypt messages
-        Cipher dataKeyCipher = null;
-        byte[] dataKeyValue = null;
-        byte[] keyDigest = null;
-
         try {
             AlgorithmParameterSpec params = null;
+            Cipher dataKeyCipher;
             // Decrypt data key using private key
             if (RSA.equals(privateKey.getAlgorithm())) {
                 dataKeyCipher = Cipher.getInstance(RSA_TRANS, 
BouncyCastleProvider.PROVIDER_NAME);
@@ -513,36 +510,35 @@ public class MessageCryptoBc implements 
MessageCrypto<MessageMetadata, MessageMe
                 params = createIESParameterSpec();
             } else {
                 log.error("Unsupported key type {} for key {}.", 
privateKey.getAlgorithm(), keyName);
-                return false;
+                return null;
             }
             if (params != null) {
                 dataKeyCipher.init(Cipher.DECRYPT_MODE, privateKey, params);
             } else {
                 dataKeyCipher.init(Cipher.DECRYPT_MODE, privateKey);
             }
-            dataKeyValue = dataKeyCipher.doFinal(encryptedDataKey);
-
-            keyDigest = digest.digest(encryptedDataKey);
+            byte[] dataKeyValue = dataKeyCipher.doFinal(encryptedDataKey);
+            return new SecretKeySpec(dataKeyValue, "AES");
 
         } catch (Exception e) {
             log.error("{} Failed to decrypt data key {} to decrypt messages 
{}", logCtx, keyName, e.getMessage());
-            return false;
+            return null;
         }
-        dataKey = new SecretKeySpec(dataKeyValue, "AES");
-        dataKeyCache.put(ByteBuffer.wrap(keyDigest), dataKey);
-        return true;
     }
 
     private boolean decryptData(SecretKey dataKeySecret, MessageMetadata 
msgMetadata,
                                 ByteBuffer payload, ByteBuffer targetBuffer) {
-
         // unpack iv and encrypted data
-        iv =  msgMetadata.getEncryptionParam();
+        byte[] iv = msgMetadata.getEncryptionParam();
 
         GCMParameterSpec gcmParams = new GCMParameterSpec(tagLen, iv);
         try {
-            cipher.init(Cipher.DECRYPT_MODE, dataKeySecret, gcmParams);
+            // mark the buffers to allow resetting them in case of decryption 
failure
+            payload.mark();
+            targetBuffer.mark();
 
+            Cipher cipher = getAesGcmCipher();
+            cipher.init(Cipher.DECRYPT_MODE, dataKeySecret, gcmParams);
             int maxLength = cipher.getOutputSize(payload.remaining());
             if (targetBuffer.remaining() < maxLength) {
                 throw new IllegalArgumentException("Target buffer size is too 
small");
@@ -551,8 +547,11 @@ public class MessageCryptoBc implements 
MessageCrypto<MessageMetadata, MessageMe
             targetBuffer.flip();
             targetBuffer.limit(decryptedSize);
             return true;
-
         } catch (Exception e) {
+            // reset the buffers so that decryption can be retried with the 
same buffers
+            payload.reset();
+            targetBuffer.reset();
+
             log.error("{} Failed to decrypt message {}", logCtx, 
e.getMessage());
             return false;
         }
@@ -563,34 +562,6 @@ public class MessageCryptoBc implements 
MessageCrypto<MessageMetadata, MessageMe
         return inputLen + Math.max(inputLen, 512);
     }
 
-    private boolean getKeyAndDecryptData(MessageMetadata msgMetadata, 
ByteBuffer payload, ByteBuffer targetBuffer) {
-        List<EncryptionKeys> encKeys = msgMetadata.getEncryptionKeysList();
-
-        // Go through all keys to retrieve data key from cache
-        for (int i = 0; i < encKeys.size(); i++) {
-
-            byte[] msgDataKey = encKeys.get(i).getValue();
-            byte[] keyDigest = digest.digest(msgDataKey);
-            SecretKey storedSecretKey = 
dataKeyCache.getIfPresent(ByteBuffer.wrap(keyDigest));
-            if (storedSecretKey != null) {
-
-                // Taking a small performance hit here if the hash collides. 
When it
-                // returns a different key, decryption fails. At this point, 
we would
-                // call decryptDataKey to refresh the cache and come here 
again to decrypt.
-                if (decryptData(storedSecretKey, msgMetadata, payload, 
targetBuffer)) {
-                    // If decryption succeeded, we can already return
-                    return true;
-                }
-            } else {
-                // First time, entry won't be present in cache
-                log.debug("{} Failed to decrypt data or data key is not in 
cache. Will attempt to refresh", logCtx);
-            }
-
-        }
-
-        return false;
-    }
-
     /*
      * Decrypt the payload using the data key. Keys used to encrypt data key 
can be retrieved from msgMetadata
      *
@@ -605,31 +576,73 @@ public class MessageCryptoBc implements MessageCrypto<MessageMetadata, MessageMe
     @Override
     public boolean decrypt(Supplier<MessageMetadata> messageMetadataSupplier,
                         ByteBuffer payload, ByteBuffer outBuffer, CryptoKeyReader keyReader) {
-
         MessageMetadata msgMetadata = messageMetadataSupplier.get();
-        // If dataKey is present, attempt to decrypt using the existing key
-        if (dataKey != null) {
-            if (getKeyAndDecryptData(msgMetadata, payload, outBuffer)) {
+        String producerName = msgMetadata.hasProducerName() ? msgMetadata.getProducerName() : "__not_set__";
+
+        // Pass 1: Try last used decryption key for this producer
+        SecretKey lastKey = lastDecryptionKeyCache.getIfPresent(producerName);
+        if (lastKey != null) {
+            if (decryptData(lastKey, msgMetadata, payload, outBuffer)) {
                 return true;
+            } else {
+                lastDecryptionKeyCache.invalidate(producerName);
             }
         }
 
-        // dataKey is null or decryption failed. Attempt to regenerate data key
         List<EncryptionKeys> encKeys = msgMetadata.getEncryptionKeysList();
-        EncryptionKeys encKeyInfo = encKeys.stream().filter(kbv -> {
+        // Pass 2: Try cached keys (fast path — no CryptoKeyReader calls)
+        for (EncryptionKeys encKey : encKeys) {
+            SecretKeyCacheKey cacheKey = new SecretKeyCacheKey(encKey.getValue());
+            SecretKey cachedKey = decryptionKeyCache.getIfPresent(cacheKey);
+            if (cachedKey != null) {
+                if (decryptData(cachedKey, msgMetadata, payload, outBuffer)) {
+                    lastDecryptionKeyCache.put(producerName, cachedKey);
+                    return true;
+                }
+            }
+        }
 
-            byte[] encDataKey = kbv.getValue();
-            List<KeyValue> encKeyMeta = kbv.getMetadatasList();
-            return decryptDataKey(kbv.getKey(), encDataKey, encKeyMeta, keyReader);
+        // Pass 3: Decrypt data keys via CryptoKeyReader (slow path)
+        for (EncryptionKeys encKey : encKeys) {
+            SecretKeySpec decryptedKey = tryDecryptDataKey(
+                    encKey.getKey(), encKey.getValue(), encKey.getMetadatasList(), keyReader);
+            if (decryptedKey != null) {
+                SecretKeyCacheKey cacheKey = new SecretKeyCacheKey(encKey.getValue());
+                decryptionKeyCache.put(cacheKey, decryptedKey);
+                if (decryptData(decryptedKey, msgMetadata, payload, outBuffer)) {
+                    lastDecryptionKeyCache.put(producerName, decryptedKey);
+                    return true;
+                }
+            }
+        }
+
+        return false;
+    }
 
-        }).findFirst().orElse(null);
+    // key to be used in the cache
+    private static final class SecretKeyCacheKey {
+        private final byte[] encryptedKeyBytes;
+        private final int hashCode;
 
-        if (encKeyInfo == null || dataKey == null) {
-            // Unable to decrypt data key
-            return false;
+        SecretKeyCacheKey(byte[] encryptedKeyBytes) {
+            this.encryptedKeyBytes = encryptedKeyBytes.clone();
+            this.hashCode = Arrays.hashCode(this.encryptedKeyBytes);
         }
 
-        return getKeyAndDecryptData(msgMetadata, payload, outBuffer);
+        @Override
+        public int hashCode() {
+            return hashCode;
+        }
 
+        @Override
+        public boolean equals(Object o) {
+            if (this == o) {
+                return true;
+            }
+            if (!(o instanceof SecretKeyCacheKey)) {
+                return false;
+            }
+            return Arrays.equals(encryptedKeyBytes, ((SecretKeyCacheKey) o).encryptedKeyBytes);
+        }
     }
 }
diff --git a/pulsar-common/src/main/java/org/apache/pulsar/common/util/SecurityUtility.java b/pulsar-common/src/main/java/org/apache/pulsar/common/util/SecurityUtility.java
index 2b7b1a98463..fa72327dc47 100644
--- a/pulsar-common/src/main/java/org/apache/pulsar/common/util/SecurityUtility.java
+++ b/pulsar-common/src/main/java/org/apache/pulsar/common/util/SecurityUtility.java
@@ -52,7 +52,6 @@ import java.security.spec.InvalidKeySpecException;
 import java.security.spec.KeySpec;
 import java.security.spec.PKCS8EncodedKeySpec;
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Base64;
 import java.util.Collection;
 import java.util.List;
@@ -83,10 +82,7 @@ public class SecurityUtility {
     public static final String BC_NON_FIPS_PROVIDER_CLASS = "org.bouncycastle.jce.provider.BouncyCastleProvider";
     public static final String CONSCRYPT_PROVIDER_CLASS = "org.conscrypt.OpenSSLProvider";
     public static final Provider CONSCRYPT_PROVIDER = loadConscryptProvider();
-    private static final List<KeyFactory> KEY_FACTORIES = Arrays.asList(
-            createKeyFactory("RSA"),
-            createKeyFactory("EC")
-    );
+    private static final List<String> KEY_FACTORY_ALGORITHMS = List.of("RSA", "EC");
 
     // Security.getProvider("BC") / Security.getProvider("BCFIPS").
     // also used to get Factories. e.g. CertificateFactory.getInstance("X.509", "BCFIPS")
@@ -521,12 +517,12 @@ public class SecurityUtility {
                 sb.append(currentLine);
             }
             final KeySpec keySpec = new PKCS8EncodedKeySpec(Base64.getDecoder().decode(sb.toString()));
-            final List<String> failedAlgorithm = new ArrayList<>(KEY_FACTORIES.size());
-            for (KeyFactory kf : KEY_FACTORIES) {
+            final List<String> failedAlgorithm = new ArrayList<>(KEY_FACTORY_ALGORITHMS.size());
+            for (String algorithm : KEY_FACTORY_ALGORITHMS) {
                 try {
-                    return kf.generatePrivate(keySpec);
-                } catch (InvalidKeySpecException ex) {
-                    failedAlgorithm.add(kf.getAlgorithm());
+                    return KeyFactory.getInstance(algorithm).generatePrivate(keySpec);
+                } catch (InvalidKeySpecException | NoSuchAlgorithmException ex) {
+                    failedAlgorithm.add(algorithm);
                 }
             }
             throw new KeyManagementException("The private key algorithm is not supported. attempted: "
@@ -596,11 +592,4 @@ public class SecurityUtility {
         return provider;
     }
 
-    private static KeyFactory createKeyFactory(String algorithm) {
-        try {
-            return KeyFactory.getInstance(algorithm);
-        } catch (Exception e) {
-            throw new IllegalArgumentException(String.format("Illegal key factory algorithm " + algorithm), e);
-        }
-    }
 }


Reply via email to