This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.1 by this push:
     new 361a59dec86 [feature](aes_encrypt) support GCM mode for aes_encrypt 
and aes_decrypt (#40004) (#40672)
361a59dec86 is described below

commit 361a59dec86c97b87fda6becdd0aa99dd192c536
Author: camby <camby...@tencent.com>
AuthorDate: Wed Sep 11 23:28:28 2024 +0800

    [feature](aes_encrypt) support GCM mode for aes_encrypt and aes_decrypt 
(#40004) (#40672)
    
    pick #40004 to branch-2.1
---
 be/src/util/encryption_util.cpp                    | 180 ++++++++++++++++-----
 be/src/util/encryption_util.h                      |  17 +-
 be/src/vec/functions/function_encryption.cpp       | 106 ++++++++----
 .../functions/scalar/AesCryptoFunction.java        |  25 ++-
 .../expressions/functions/scalar/AesDecrypt.java   |  22 ++-
 .../expressions/functions/scalar/AesEncrypt.java   |  22 ++-
 gensrc/script/doris_builtins_functions.py          |   4 +
 .../encryption_digest/test_encryption_function.out |  26 +++
 .../test_encryption_function.groovy                |  32 ++++
 9 files changed, 357 insertions(+), 77 deletions(-)

diff --git a/be/src/util/encryption_util.cpp b/be/src/util/encryption_util.cpp
index ab1ff3e857c..bf6777a5c88 100644
--- a/be/src/util/encryption_util.cpp
+++ b/be/src/util/encryption_util.cpp
@@ -25,6 +25,7 @@
 #include <algorithm>
 #include <cstring>
 #include <string>
+#include <unordered_map>
 
 namespace doris {
 
@@ -80,6 +81,12 @@ const EVP_CIPHER* get_evp_type(const EncryptionMode mode) {
         return EVP_aes_256_ctr();
     case EncryptionMode::AES_256_OFB:
         return EVP_aes_256_ofb();
+    case EncryptionMode::AES_128_GCM:
+        return EVP_aes_128_gcm();
+    case EncryptionMode::AES_192_GCM:
+        return EVP_aes_192_gcm();
+    case EncryptionMode::AES_256_GCM:
+        return EVP_aes_256_gcm();
     case EncryptionMode::SM4_128_CBC:
         return EVP_sm4_cbc();
     case EncryptionMode::SM4_128_ECB:
@@ -95,41 +102,29 @@ const EVP_CIPHER* get_evp_type(const EncryptionMode mode) {
     }
 }
 
-static uint mode_key_sizes[] = {
-        128 /* AES_128_ECB */,
-        192 /* AES_192_ECB */,
-        256 /* AES_256_ECB */,
-        128 /* AES_128_CBC */,
-        192 /* AES_192_CBC */,
-        256 /* AES_256_CBC */,
-        128 /* AES_128_CFB */,
-        192 /* AES_192_CFB */,
-        256 /* AES_256_CFB */,
-        128 /* AES_128_CFB1 */,
-        192 /* AES_192_CFB1 */,
-        256 /* AES_256_CFB1 */,
-        128 /* AES_128_CFB8 */,
-        192 /* AES_192_CFB8 */,
-        256 /* AES_256_CFB8 */,
-        128 /* AES_128_CFB128 */,
-        192 /* AES_192_CFB128 */,
-        256 /* AES_256_CFB128 */,
-        128 /* AES_128_CTR */,
-        192 /* AES_192_CTR */,
-        256 /* AES_256_CTR */,
-        128 /* AES_128_OFB */,
-        192 /* AES_192_OFB */,
-        256 /* AES_256_OFB */,
-        128 /* SM4_128_ECB */,
-        128 /* SM4_128_CBC */,
-        128 /* SM4_128_CFB128 */,
-        128 /* SM4_128_OFB */,
-        128 /* SM4_128_CTR */
-};
+static std::unordered_map<EncryptionMode, uint> mode_key_sizes = { // key size per mode, presumably in bits: create_key() divides the looked-up value by 8
+        {EncryptionMode::AES_128_ECB, 128},    {EncryptionMode::AES_192_ECB, 
192},
+        {EncryptionMode::AES_256_ECB, 256},    {EncryptionMode::AES_128_CBC, 
128},
+        {EncryptionMode::AES_192_CBC, 192},    {EncryptionMode::AES_256_CBC, 
256},
+        {EncryptionMode::AES_128_CFB, 128},    {EncryptionMode::AES_192_CFB, 
192},
+        {EncryptionMode::AES_256_CFB, 256},    {EncryptionMode::AES_128_CFB1, 
128},
+        {EncryptionMode::AES_192_CFB1, 192},   {EncryptionMode::AES_256_CFB1, 
256},
+        {EncryptionMode::AES_128_CFB8, 128},   {EncryptionMode::AES_192_CFB8, 
192},
+        {EncryptionMode::AES_256_CFB8, 256},   
{EncryptionMode::AES_128_CFB128, 128},
+        {EncryptionMode::AES_192_CFB128, 192}, 
{EncryptionMode::AES_256_CFB128, 256},
+        {EncryptionMode::AES_128_CTR, 128},    {EncryptionMode::AES_192_CTR, 
192},
+        {EncryptionMode::AES_256_CTR, 256},    {EncryptionMode::AES_128_OFB, 
128},
+        {EncryptionMode::AES_192_OFB, 192},    {EncryptionMode::AES_256_OFB, 
256},
+        {EncryptionMode::AES_128_GCM, 128},    {EncryptionMode::AES_192_GCM, 
192},
+        {EncryptionMode::AES_256_GCM, 256}, // GCM entries use the same 128/192/256 values as the other AES modes
+
+        {EncryptionMode::SM4_128_ECB, 128},    {EncryptionMode::SM4_128_CBC, 
128},
+        {EncryptionMode::SM4_128_CFB128, 128}, {EncryptionMode::SM4_128_OFB, 
128},
+        {EncryptionMode::SM4_128_CTR, 128}}; // NOTE(review): create_key() reads this with operator[], which inserts a 0 entry for any mode missing here - keep the table in sync with the EncryptionMode enum
 
 static void create_key(const unsigned char* origin_key, uint32_t key_length, 
uint8_t* encrypt_key,
                        EncryptionMode mode) {
-    const uint key_size = mode_key_sizes[int(mode)] / 8;
+    const uint key_size = mode_key_sizes[mode] / 8;
     uint8_t* origin_key_end = ((uint8_t*)origin_key) + key_length; /* origin 
key boundary*/
 
     uint8_t* encrypt_key_end; /* encrypt key boundary */
@@ -172,10 +167,58 @@ static int do_encrypt(EVP_CIPHER_CTX* cipher_ctx, const 
EVP_CIPHER* cipher,
     return ret;
 }
 
+static int do_gcm_encrypt(EVP_CIPHER_CTX* cipher_ctx, const EVP_CIPHER* cipher, // GCM encrypt helper: writes iv || ciphertext || tag into `encrypt`
+                          const unsigned char* source, uint32_t source_length,
+                          const unsigned char* encrypt_key, const unsigned 
char* iv, int iv_length,
+                          unsigned char* encrypt, int* length_ptr, const 
unsigned char* aad, // aad may be nullptr, in which case no additional data is fed to the cipher
+                          uint32_t aad_length) { // returns the status of the last OpenSSL call made; each step treats any value other than 1 as failure
+    int ret = EVP_EncryptInit_ex(cipher_ctx, cipher, nullptr, nullptr, 
nullptr); // first init only selects the cipher; key and IV are supplied by the second init below
+    if (ret != 1) {
+        return ret;
+    }
+    ret = EVP_CIPHER_CTX_ctrl(cipher_ctx, EVP_CTRL_GCM_SET_IVLEN, iv_length, 
nullptr); // the IV length is configured before the key/IV init call
+    if (ret != 1) {
+        return ret;
+    }
+    ret = EVP_EncryptInit_ex(cipher_ctx, nullptr, nullptr, encrypt_key, iv);
+    if (ret != 1) {
+        return ret;
+    }
+    if (aad) { // AAD is passed through EVP_EncryptUpdate with a null output buffer, so it produces no ciphertext bytes
+        int tmp_len = 0;
+        ret = EVP_EncryptUpdate(cipher_ctx, nullptr, &tmp_len, aad, 
aad_length);
+        if (ret != 1) {
+            return ret;
+        }
+    }
+
+    std::memcpy(encrypt, iv, iv_length); // the output starts with the IV; do_gcm_decrypt reads it back from the head of its input
+    encrypt += iv_length;
+
+    int u_len = 0;
+    ret = EVP_EncryptUpdate(cipher_ctx, encrypt, &u_len, source, 
source_length);
+    if (ret != 1) {
+        return ret;
+    }
+    encrypt += u_len; // advance past the bytes just written so the next write lands right after them
+
+    int f_len = 0;
+    ret = EVP_EncryptFinal_ex(cipher_ctx, encrypt, &f_len);
+    if (ret != 1) {
+        return ret;
+    }
+    encrypt += f_len;
+
+    ret = EVP_CIPHER_CTX_ctrl(cipher_ctx, EVP_CTRL_GCM_GET_TAG, 
EncryptionUtil::GCM_TAG_SIZE,
+                              encrypt); // the GCM_TAG_SIZE-byte tag is appended directly after the ciphertext
+    *length_ptr = iv_length + u_len + f_len + EncryptionUtil::GCM_TAG_SIZE; // total bytes written; set even if fetching the tag failed, so callers must check the return value
+    return ret;
+}
+
 int EncryptionUtil::encrypt(EncryptionMode mode, const unsigned char* source,
                             uint32_t source_length, const unsigned char* key, 
uint32_t key_length,
                             const char* iv_str, int iv_input_length, bool 
padding,
-                            unsigned char* encrypt) {
+                            unsigned char* encrypt, const unsigned char* aad, 
uint32_t aad_length) {
     const EVP_CIPHER* cipher = get_evp_type(mode);
     /* The encrypt key to be used for encryption */
     unsigned char encrypt_key[ENCRYPTION_MAX_KEY_LENGTH / 8];
@@ -196,8 +239,16 @@ int EncryptionUtil::encrypt(EncryptionMode mode, const 
unsigned char* source,
     EVP_CIPHER_CTX* cipher_ctx = EVP_CIPHER_CTX_new();
     EVP_CIPHER_CTX_reset(cipher_ctx);
     int length = 0;
-    int ret = do_encrypt(cipher_ctx, cipher, source, source_length, 
encrypt_key,
+    int ret = 0;
+    if (is_gcm_mode(mode)) {
+        ret = do_gcm_encrypt(cipher_ctx, cipher, source, source_length, 
encrypt_key,
+                             reinterpret_cast<unsigned char*>(init_vec), 
iv_length, encrypt,
+                             &length, aad, aad_length);
+    } else {
+        ret = do_encrypt(cipher_ctx, cipher, source, source_length, 
encrypt_key,
                          reinterpret_cast<unsigned char*>(init_vec), padding, 
encrypt, &length);
+    }
+
     EVP_CIPHER_CTX_free(cipher_ctx);
     if (ret == 0) {
         ERR_clear_error();
@@ -230,10 +281,61 @@ static int do_decrypt(EVP_CIPHER_CTX* cipher_ctx, const 
EVP_CIPHER* cipher,
     return ret;
 }
 
+static int do_gcm_decrypt(EVP_CIPHER_CTX* cipher_ctx, const EVP_CIPHER* cipher, // GCM decrypt helper: treats `encrypt` as iv || ciphertext || tag
+                          const unsigned char* encrypt, uint32_t 
encrypt_length,
+                          const unsigned char* encrypt_key, int iv_length,
+                          unsigned char* decrypt_content, int* length_ptr, 
const unsigned char* aad, // aad may be nullptr, in which case no additional data is fed to the cipher
+                          uint32_t aad_length) { // returns -1 for too-short input, otherwise the status of the last OpenSSL call made
+    if (encrypt_length < iv_length + EncryptionUtil::GCM_TAG_SIZE) { // input cannot hold both the IV and the tag
+        return -1;
+    }
+    int ret = EVP_DecryptInit_ex(cipher_ctx, cipher, nullptr, nullptr, 
nullptr); // first init only selects the cipher; key and IV are supplied by the second init below
+    if (ret != 1) {
+        return ret;
+    }
+    ret = EVP_CIPHER_CTX_ctrl(cipher_ctx, EVP_CTRL_GCM_SET_IVLEN, iv_length, 
nullptr);
+    if (ret != 1) {
+        return ret;
+    }
+    ret = EVP_DecryptInit_ex(cipher_ctx, nullptr, nullptr, encrypt_key, 
encrypt); // the IV is taken from the start of the input buffer, not from a separate argument
+    if (ret != 1) {
+        return ret;
+    }
+    encrypt += iv_length; // skip the IV prefix; `encrypt` now points at the ciphertext
+    if (aad) { // AAD is passed through EVP_DecryptUpdate with a null output buffer, so it produces no plaintext bytes
+        int tmp_len = 0;
+        ret = EVP_DecryptUpdate(cipher_ctx, nullptr, &tmp_len, aad, 
aad_length);
+        if (ret != 1) {
+            return ret;
+        }
+    }
+
+    uint32_t real_encrypt_length = encrypt_length - iv_length - 
EncryptionUtil::GCM_TAG_SIZE; // ciphertext length once the IV prefix and tag suffix are removed; the guard at the top keeps this from underflowing
+    int u_len = 0;
+    ret = EVP_DecryptUpdate(cipher_ctx, decrypt_content, &u_len, encrypt, 
real_encrypt_length);
+    if (ret != 1) {
+        return ret;
+    }
+    encrypt += real_encrypt_length; // `encrypt` now points at the trailing tag
+    decrypt_content += u_len;
+
+    void* tag = const_cast<void*>(reinterpret_cast<const void*>(encrypt)); // constness is cast away because the ctrl call below is passed a non-const void*
+    ret = EVP_CIPHER_CTX_ctrl(cipher_ctx, EVP_CTRL_GCM_SET_TAG, 
EncryptionUtil::GCM_TAG_SIZE, tag);
+    if (ret != 1) {
+        return ret;
+    }
+
+    int f_len = 0;
+    ret = EVP_DecryptFinal_ex(cipher_ctx, decrypt_content, &f_len); // per the OpenSSL GCM flow, the tag set above is verified by this call
+    *length_ptr = u_len + f_len; // set regardless of ret, so callers must check the return value before trusting the output
+    return ret;
+}
+
 int EncryptionUtil::decrypt(EncryptionMode mode, const unsigned char* encrypt,
                             uint32_t encrypt_length, const unsigned char* key, 
uint32_t key_length,
                             const char* iv_str, int iv_input_length, bool 
padding,
-                            unsigned char* decrypt_content) {
+                            unsigned char* decrypt_content, const unsigned 
char* aad,
+                            uint32_t aad_length) {
     const EVP_CIPHER* cipher = get_evp_type(mode);
 
     /* The encrypt key to be used for decryption */
@@ -255,9 +357,15 @@ int EncryptionUtil::decrypt(EncryptionMode mode, const 
unsigned char* encrypt,
     EVP_CIPHER_CTX* cipher_ctx = EVP_CIPHER_CTX_new();
     EVP_CIPHER_CTX_reset(cipher_ctx);
     int length = 0;
-    int ret = do_decrypt(cipher_ctx, cipher, encrypt, encrypt_length, 
encrypt_key,
+    int ret = 0;
+    if (is_gcm_mode(mode)) {
+        ret = do_gcm_decrypt(cipher_ctx, cipher, encrypt, encrypt_length, 
encrypt_key, iv_length,
+                             decrypt_content, &length, aad, aad_length);
+    } else {
+        ret = do_decrypt(cipher_ctx, cipher, encrypt, encrypt_length, 
encrypt_key,
                          reinterpret_cast<unsigned char*>(init_vec), padding, 
decrypt_content,
                          &length);
+    }
     EVP_CIPHER_CTX_free(cipher_ctx);
     if (ret > 0) {
         return length;
diff --git a/be/src/util/encryption_util.h b/be/src/util/encryption_util.h
index 8e61a119953..dfd288c3147 100644
--- a/be/src/util/encryption_util.h
+++ b/be/src/util/encryption_util.h
@@ -46,6 +46,9 @@ enum class EncryptionMode {
     AES_128_OFB,
     AES_192_OFB,
     AES_256_OFB,
+    AES_128_GCM,
+    AES_192_GCM,
+    AES_256_GCM,
     SM4_128_ECB,
     SM4_128_CBC,
     SM4_128_CFB128,
@@ -57,13 +60,23 @@ enum EncryptionState { AES_SUCCESS = 0, AES_BAD_DATA = -1 };
 
 class EncryptionUtil {
 public:
+    static bool is_gcm_mode(EncryptionMode mode) { // true only for AES_128_GCM, AES_192_GCM and AES_256_GCM
+        return mode == EncryptionMode::AES_128_GCM || mode == 
EncryptionMode::AES_192_GCM ||
+               mode == EncryptionMode::AES_256_GCM;
+    }
+
+    // https://tools.ietf.org/html/rfc5116#section-5.1
+    static const int GCM_TAG_SIZE = 16;
+
     static int encrypt(EncryptionMode mode, const unsigned char* source, 
uint32_t source_length,
                        const unsigned char* key, uint32_t key_length, const 
char* iv_str,
-                       int iv_input_length, bool padding, unsigned char* 
encrypt);
+                       int iv_input_length, bool padding, unsigned char* 
encrypt,
+                       const unsigned char* aad = nullptr, uint32_t aad_length 
= 0);
 
     static int decrypt(EncryptionMode mode, const unsigned char* encrypt, 
uint32_t encrypt_length,
                        const unsigned char* key, uint32_t key_length, const 
char* iv_str,
-                       int iv_input_length, bool padding, unsigned char* 
decrypt_content);
+                       int iv_input_length, bool padding, unsigned char* 
decrypt_content,
+                       const unsigned char* aad = nullptr, uint32_t aad_length 
= 0);
 };
 
 } // namespace doris
diff --git a/be/src/vec/functions/function_encryption.cpp 
b/be/src/vec/functions/function_encryption.cpp
index f63e9bca1b0..c90b6a1ff60 100644
--- a/be/src/vec/functions/function_encryption.cpp
+++ b/be/src/vec/functions/function_encryption.cpp
@@ -79,7 +79,10 @@ inline StringCaseUnorderedMap<EncryptionMode> aes_mode_map {
         {"AES_256_CTR", EncryptionMode::AES_256_CTR},
         {"AES_128_OFB", EncryptionMode::AES_128_OFB},
         {"AES_192_OFB", EncryptionMode::AES_192_OFB},
-        {"AES_256_OFB", EncryptionMode::AES_256_OFB}};
+        {"AES_256_OFB", EncryptionMode::AES_256_OFB},
+        {"AES_128_GCM", EncryptionMode::AES_128_GCM},
+        {"AES_192_GCM", EncryptionMode::AES_192_GCM},
+        {"AES_256_GCM", EncryptionMode::AES_256_GCM}};
 inline StringCaseUnorderedMap<EncryptionMode> sm4_mode_map {
         {"SM4_128_ECB", EncryptionMode::SM4_128_ECB},
         {"SM4_128_CBC", EncryptionMode::SM4_128_CBC},
@@ -120,7 +123,7 @@ void execute_result_vector(std::vector<const 
ColumnString::Offsets*>& offsets_li
                            std::vector<const ColumnString::Chars*>& 
chars_list, size_t i,
                            EncryptionMode& encryption_mode, const char* 
iv_raw, int iv_length,
                            ColumnString::Chars& result_data, 
ColumnString::Offsets& result_offset,
-                           NullMap& null_map) {
+                           NullMap& null_map, const char* aad, int aad_length) 
{
     int src_size = (*offsets_list[0])[i] - (*offsets_list[0])[i - 1];
     const auto* src_raw =
             reinterpret_cast<const 
char*>(&(*chars_list[0])[(*offsets_list[0])[i - 1]]);
@@ -128,7 +131,8 @@ void execute_result_vector(std::vector<const 
ColumnString::Offsets*>& offsets_li
     const auto* key_raw =
             reinterpret_cast<const 
char*>(&(*chars_list[1])[(*offsets_list[1])[i - 1]]);
     execute_result<Impl, is_encrypt>(src_raw, src_size, key_raw, key_size, i, 
encryption_mode,
-                                     iv_raw, iv_length, result_data, 
result_offset, null_map);
+                                     iv_raw, iv_length, result_data, 
result_offset, null_map, aad,
+                                     aad_length);
 }
 
 template <typename Impl, bool is_encrypt>
@@ -136,19 +140,19 @@ void execute_result_const(const ColumnString::Offsets* 
offsets_column,
                           const ColumnString::Chars* chars_column, StringRef 
key_arg, size_t i,
                           EncryptionMode& encryption_mode, const char* iv_raw, 
int iv_length,
                           ColumnString::Chars& result_data, 
ColumnString::Offsets& result_offset,
-                          NullMap& null_map) {
+                          NullMap& null_map, const char* aad, int aad_length) {
     int src_size = (*offsets_column)[i] - (*offsets_column)[i - 1];
     const auto* src_raw = reinterpret_cast<const 
char*>(&(*chars_column)[(*offsets_column)[i - 1]]);
     execute_result<Impl, is_encrypt>(src_raw, src_size, key_arg.data, 
key_arg.size, i,
                                      encryption_mode, iv_raw, iv_length, 
result_data, result_offset,
-                                     null_map);
+                                     null_map, aad, aad_length);
 }
 
 template <typename Impl, bool is_encrypt>
 void execute_result(const char* src_raw, int src_size, const char* key_raw, 
int key_size, size_t i,
                     EncryptionMode& encryption_mode, const char* iv_raw, int 
iv_length,
                     ColumnString::Chars& result_data, ColumnString::Offsets& 
result_offset,
-                    NullMap& null_map) {
+                    NullMap& null_map, const char* aad, int aad_length) {
     if (src_size == 0) {
         StringOP::push_null_string(i, result_data, result_offset, null_map);
         return;
@@ -156,6 +160,10 @@ void execute_result(const char* src_raw, int src_size, 
const char* key_raw, int
     int cipher_len = src_size;
     if constexpr (is_encrypt) {
         cipher_len += 16;
+        // for output AEAD tag
+        if (EncryptionUtil::is_gcm_mode(encryption_mode)) {
+            cipher_len += EncryptionUtil::GCM_TAG_SIZE;
+        }
     }
     std::unique_ptr<char[]> p;
     p.reset(new char[cipher_len]);
@@ -163,7 +171,7 @@ void execute_result(const char* src_raw, int src_size, 
const char* key_raw, int
 
     ret_code = Impl::execute_impl(encryption_mode, (unsigned char*)src_raw, 
src_size,
                                   (unsigned char*)key_raw, key_size, iv_raw, 
iv_length, true,
-                                  (unsigned char*)p.get());
+                                  (unsigned char*)p.get(), (unsigned 
char*)aad, aad_length);
 
     if (ret_code < 0) {
         StringOP::push_null_string(i, result_data, result_offset, null_map);
@@ -248,7 +256,7 @@ struct EncryptionAndDecryptTwoImpl {
             }
             execute_result_const<Impl, is_encrypt>(offsets_column, 
chars_column, key_arg, i,
                                                    encryption_mode, nullptr, 
0, result_data,
-                                                   result_offset, null_map);
+                                                   result_offset, null_map, 
nullptr, 0);
         }
     }
 
@@ -275,16 +283,22 @@ struct EncryptionAndDecryptTwoImpl {
             }
             execute_result_vector<Impl, is_encrypt>(offsets_list, chars_list, 
i, encryption_mode,
                                                     nullptr, 0, result_data, 
result_offset,
-                                                    null_map);
+                                                    null_map, nullptr, 0);
         }
     }
 };
 
-template <typename Impl, EncryptionMode mode, bool is_encrypt, bool is_sm_mode>
-struct EncryptionAndDecryptFourImpl {
+template <typename Impl, EncryptionMode mode, bool is_encrypt, bool 
is_sm_mode, int arg_num = 4>
+struct EncryptionAndDecryptMultiImpl {
     static DataTypes get_variadic_argument_types_impl() {
-        return {std::make_shared<DataTypeString>(), 
std::make_shared<DataTypeString>(),
-                std::make_shared<DataTypeString>(), 
std::make_shared<DataTypeString>()};
+        if constexpr (arg_num == 5) {
+            return {std::make_shared<DataTypeString>(), 
std::make_shared<DataTypeString>(),
+                    std::make_shared<DataTypeString>(), 
std::make_shared<DataTypeString>(),
+                    std::make_shared<DataTypeString>()};
+        } else {
+            return {std::make_shared<DataTypeString>(), 
std::make_shared<DataTypeString>(),
+                    std::make_shared<DataTypeString>(), 
std::make_shared<DataTypeString>()};
+        }
     }
 
     static Status execute_impl_inner(FunctionContext* context, Block& block,
@@ -292,8 +306,8 @@ struct EncryptionAndDecryptFourImpl {
                                      size_t input_rows_count) {
         auto result_column = ColumnString::create();
         auto result_null_map_column = ColumnUInt8::create(input_rows_count, 0);
-        DCHECK_EQ(4, arguments.size());
-        const size_t argument_size = 4;
+        DCHECK_EQ(arguments.size(), arg_num);
+        const size_t argument_size = arg_num;
         bool col_const[argument_size];
         ColumnPtr argument_columns[argument_size];
         for (int i = 0; i < argument_size; ++i) {
@@ -304,8 +318,13 @@ struct EncryptionAndDecryptFourImpl {
                                                      .convert_to_full_column()
                                            : 
block.get_by_position(arguments[0]).column;
 
-        default_preprocess_parameter_columns(argument_columns, col_const, {1, 
2, 3}, block,
-                                             arguments);
+        if constexpr (arg_num == 5) {
+            default_preprocess_parameter_columns(argument_columns, col_const, 
{1, 2, 3, 4}, block,
+                                                 arguments);
+        } else {
+            default_preprocess_parameter_columns(argument_columns, col_const, 
{1, 2, 3}, block,
+                                                 arguments);
+        }
 
         for (int i = 0; i < argument_size; i++) {
             check_set_nullable(argument_columns[i], result_null_map_column, 
col_const[i]);
@@ -314,11 +333,17 @@ struct EncryptionAndDecryptFourImpl {
         auto& result_offset = result_column->get_offsets();
         result_offset.resize(input_rows_count);
 
-        if (col_const[1] && col_const[2] && col_const[3]) {
+        if ((arg_num == 5) && col_const[1] && col_const[2] && col_const[3] && 
col_const[4]) {
+            vector_const(assert_cast<const 
ColumnString*>(argument_columns[0].get()),
+                         argument_columns[1]->get_data_at(0), 
argument_columns[2]->get_data_at(0),
+                         argument_columns[3]->get_data_at(0), 
input_rows_count, result_data,
+                         result_offset, result_null_map_column->get_data(),
+                         argument_columns[4]->get_data_at(0));
+        } else if ((arg_num == 4) && col_const[1] && col_const[2] && 
col_const[3]) {
             vector_const(assert_cast<const 
ColumnString*>(argument_columns[0].get()),
                          argument_columns[1]->get_data_at(0), 
argument_columns[2]->get_data_at(0),
                          argument_columns[3]->get_data_at(0), 
input_rows_count, result_data,
-                         result_offset, result_null_map_column->get_data());
+                         result_offset, result_null_map_column->get_data(), 
StringRef());
         } else {
             std::vector<const ColumnString::Offsets*> 
offsets_list(argument_size);
             std::vector<const ColumnString::Chars*> chars_list(argument_size);
@@ -338,7 +363,7 @@ struct EncryptionAndDecryptFourImpl {
     static void vector_const(const ColumnString* column, StringRef key_arg, 
StringRef iv_arg,
                              StringRef mode_arg, size_t input_rows_count,
                              ColumnString::Chars& result_data, 
ColumnString::Offsets& result_offset,
-                             NullMap& null_map) {
+                             NullMap& null_map, StringRef aad_arg) {
         EncryptionMode encryption_mode = mode;
         bool all_insert_null = false;
         if (mode_arg.size != 0) {
@@ -363,9 +388,9 @@ struct EncryptionAndDecryptFourImpl {
                 StringOP::push_null_string(i, result_data, result_offset, 
null_map);
                 continue;
             }
-            execute_result_const<Impl, is_encrypt>(offsets_column, 
chars_column, key_arg, i,
-                                                   encryption_mode, 
iv_arg.data, iv_arg.size,
-                                                   result_data, result_offset, 
null_map);
+            execute_result_const<Impl, is_encrypt>(
+                    offsets_column, chars_column, key_arg, i, encryption_mode, 
iv_arg.data,
+                    iv_arg.size, result_data, result_offset, null_map, 
aad_arg.data, aad_arg.size);
         }
     }
 
@@ -403,9 +428,16 @@ struct EncryptionAndDecryptFourImpl {
                 }
             }
 
+            int aad_size = 0;
+            const char* aad = nullptr;
+            if constexpr (arg_num == 5) {
+                aad_size = (*offsets_list[4])[i] - (*offsets_list[4])[i - 1];
+                aad = reinterpret_cast<const 
char*>(&(*chars_list[4])[(*offsets_list[4])[i - 1]]);
+            }
+
             execute_result_vector<Impl, is_encrypt>(offsets_list, chars_list, 
i, encryption_mode,
                                                     iv_raw, iv_size, 
result_data, result_offset,
-                                                    null_map);
+                                                    null_map, aad, aad_size);
         }
     }
 };
@@ -413,18 +445,20 @@ struct EncryptionAndDecryptFourImpl {
 struct EncryptImpl {
     static int execute_impl(EncryptionMode mode, const unsigned char* source,
                             uint32_t source_length, const unsigned char* key, 
uint32_t key_length,
-                            const char* iv, int iv_length, bool padding, 
unsigned char* encrypt) {
+                            const char* iv, int iv_length, bool padding, 
unsigned char* encrypt,
+                            const unsigned char* aad, int aad_length) {
         return EncryptionUtil::encrypt(mode, source, source_length, key, 
key_length, iv, iv_length,
-                                       true, encrypt);
+                                       true, encrypt, aad, aad_length);
     }
 };
 
 struct DecryptImpl {
     static int execute_impl(EncryptionMode mode, const unsigned char* source,
                             uint32_t source_length, const unsigned char* key, 
uint32_t key_length,
-                            const char* iv, int iv_length, bool padding, 
unsigned char* encrypt) {
+                            const char* iv, int iv_length, bool padding, 
unsigned char* encrypt,
+                            const unsigned char* aad, int aad_length) {
         return EncryptionUtil::decrypt(mode, source, source_length, key, 
key_length, iv, iv_length,
-                                       true, encrypt);
+                                       true, encrypt, aad, aad_length);
     }
 };
 
@@ -459,16 +493,24 @@ void register_function_encryption(SimpleFunctionFactory& 
factory) {
             AESDecryptName>>();
 
     factory.register_function<FunctionEncryptionAndDecrypt<
-            EncryptionAndDecryptFourImpl<EncryptImpl, 
EncryptionMode::SM4_128_ECB, true, true>,
+            EncryptionAndDecryptMultiImpl<EncryptImpl, 
EncryptionMode::SM4_128_ECB, true, true>,
             SM4EncryptName>>();
     factory.register_function<FunctionEncryptionAndDecrypt<
-            EncryptionAndDecryptFourImpl<DecryptImpl, 
EncryptionMode::SM4_128_ECB, false, true>,
+            EncryptionAndDecryptMultiImpl<DecryptImpl, 
EncryptionMode::SM4_128_ECB, false, true>,
             SM4DecryptName>>();
     factory.register_function<FunctionEncryptionAndDecrypt<
-            EncryptionAndDecryptFourImpl<EncryptImpl, 
EncryptionMode::AES_128_ECB, true, false>,
+            EncryptionAndDecryptMultiImpl<EncryptImpl, 
EncryptionMode::AES_128_ECB, true, false>,
+            AESEncryptName>>();
+    factory.register_function<FunctionEncryptionAndDecrypt<
+            EncryptionAndDecryptMultiImpl<DecryptImpl, 
EncryptionMode::AES_128_ECB, false, false>,
+            AESDecryptName>>();
+
+    factory.register_function<FunctionEncryptionAndDecrypt<
+            EncryptionAndDecryptMultiImpl<EncryptImpl, 
EncryptionMode::AES_128_GCM, true, false, 5>,
             AESEncryptName>>();
     factory.register_function<FunctionEncryptionAndDecrypt<
-            EncryptionAndDecryptFourImpl<DecryptImpl, 
EncryptionMode::AES_128_ECB, false, false>,
+            EncryptionAndDecryptMultiImpl<DecryptImpl, 
EncryptionMode::AES_128_GCM, false, false,
+                                          5>,
             AESDecryptName>>();
 }
 
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/AesCryptoFunction.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/AesCryptoFunction.java
index a72b84dab1c..3a98fb64ccf 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/AesCryptoFunction.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/AesCryptoFunction.java
@@ -19,6 +19,7 @@ package 
org.apache.doris.nereids.trees.expressions.functions.scalar;
 
 import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
 
 import com.google.common.collect.ImmutableSet;
@@ -52,7 +53,16 @@ public abstract class AesCryptoFunction extends 
CryptoFunction {
             "AES_256_CTR",
             "AES_128_OFB",
             "AES_192_OFB",
-            "AES_256_OFB"
+            "AES_256_OFB",
+            "AES_128_GCM",
+            "AES_192_GCM",
+            "AES_256_GCM"
+    );
+
+    public static final Set<String> AES_GCM_MODES = ImmutableSet.of(
+            "AES_128_GCM",
+            "AES_192_GCM",
+            "AES_256_GCM"
     );
 
     public AesCryptoFunction(String name, Expression... arguments) {
@@ -72,4 +82,17 @@ public abstract class AesCryptoFunction extends 
CryptoFunction {
         }
         return encryptionMode;
     }
+
+    @Override
+    public void checkLegalityAfterRewrite() { // validates the mode argument (child 3) and, when present, whether a 5th argument is allowed
+        if (arity() >= 4 && child(3) instanceof StringLikeLiteral) { // only a literal mode is checked here; a non-literal mode expression passes through this check unvalidated
+            String mode = ((StringLikeLiteral) 
child(3)).getValue().toUpperCase(); // upper-cased so the lookups in the mode sets are case-insensitive
+            if (!AES_MODES.contains(mode)) {
+                throw new AnalysisException("mode " + mode + " is not 
supported");
+            }
+            if (arity() == 5 && !AES_GCM_MODES.contains(mode)) { // the 5th argument (AAD) is accepted only together with a GCM mode
+                throw new AnalysisException("only GCM mode support AAD(the 5th 
arg)");
+            }
+        }
+    }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/AesDecrypt.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/AesDecrypt.java
index 7608cf4e40e..8967b0e5138 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/AesDecrypt.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/AesDecrypt.java
@@ -50,7 +50,16 @@ public class AesDecrypt extends AesCryptoFunction {
                             VarcharType.SYSTEM_DEFAULT,
                             VarcharType.SYSTEM_DEFAULT),
             FunctionSignature.ret(StringType.INSTANCE)
-                    .args(StringType.INSTANCE, StringType.INSTANCE, 
StringType.INSTANCE, StringType.INSTANCE)
+                    .args(StringType.INSTANCE, StringType.INSTANCE, 
StringType.INSTANCE, StringType.INSTANCE),
+            FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
+                    .args(VarcharType.SYSTEM_DEFAULT,
+                            VarcharType.SYSTEM_DEFAULT,
+                            VarcharType.SYSTEM_DEFAULT,
+                            VarcharType.SYSTEM_DEFAULT,
+                            VarcharType.SYSTEM_DEFAULT),
+            FunctionSignature.ret(StringType.INSTANCE)
+                    .args(StringType.INSTANCE, StringType.INSTANCE, 
StringType.INSTANCE, StringType.INSTANCE,
+                            StringType.INSTANCE)
     );
 
     /**
@@ -68,18 +77,25 @@ public class AesDecrypt extends AesCryptoFunction {
         super("aes_decrypt", arg0, arg1, arg2, arg3);
     }
 
+    public AesDecrypt(Expression arg0, Expression arg1, Expression arg2, 
Expression arg3, Expression arg4) {
+        super("aes_decrypt", arg0, arg1, arg2, arg3, arg4);
+    }
+
     /**
      * withChildren.
      */
     @Override
     public AesDecrypt withChildren(List<Expression> children) {
-        Preconditions.checkArgument(children.size() >= 2 && children.size() <= 
4);
+        Preconditions.checkArgument(children.size() >= 2 && children.size() <= 
5);
         if (children.size() == 2) {
             return new AesDecrypt(children.get(0), children.get(1));
         } else if (children().size() == 3) {
             return new AesDecrypt(children.get(0), children.get(1), 
children.get(2));
-        } else {
+        } else if (children().size() == 4) {
             return new AesDecrypt(children.get(0), children.get(1), 
children.get(2), children.get(3));
+        } else {
+            return new AesDecrypt(children.get(0), children.get(1), 
children.get(2), children.get(3),
+                    children.get(4));
         }
     }
 
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/AesEncrypt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/AesEncrypt.java
index 455d6b0dbd5..a70a639785f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/AesEncrypt.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/AesEncrypt.java
@@ -50,7 +50,16 @@ public class AesEncrypt extends AesCryptoFunction {
                             VarcharType.SYSTEM_DEFAULT,
                             VarcharType.SYSTEM_DEFAULT),
             FunctionSignature.ret(StringType.INSTANCE)
-                    .args(StringType.INSTANCE, StringType.INSTANCE, 
StringType.INSTANCE, StringType.INSTANCE)
+                    .args(StringType.INSTANCE, StringType.INSTANCE, 
StringType.INSTANCE, StringType.INSTANCE),
+            FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
+                    .args(VarcharType.SYSTEM_DEFAULT,
+                            VarcharType.SYSTEM_DEFAULT,
+                            VarcharType.SYSTEM_DEFAULT,
+                            VarcharType.SYSTEM_DEFAULT,
+                            VarcharType.SYSTEM_DEFAULT),
+            FunctionSignature.ret(StringType.INSTANCE)
+                    .args(StringType.INSTANCE, StringType.INSTANCE, 
StringType.INSTANCE, StringType.INSTANCE,
+                            StringType.INSTANCE)
     );
 
     /**
@@ -68,18 +77,25 @@ public class AesEncrypt extends AesCryptoFunction {
         super("aes_encrypt", arg0, arg1, arg2, arg3);
     }
 
+    public AesEncrypt(Expression arg0, Expression arg1, Expression arg2, 
Expression arg3, Expression arg4) {
+        super("aes_encrypt", arg0, arg1, arg2, arg3, arg4);
+    }
+
     /**
      * withChildren.
      */
     @Override
     public AesEncrypt withChildren(List<Expression> children) {
-        Preconditions.checkArgument(children.size() >= 2 && children.size() <= 
4);
+        Preconditions.checkArgument(children.size() >= 2 && children.size() <= 
5);
         if (children.size() == 2) {
             return new AesEncrypt(children.get(0), children.get(1));
         } else if (children().size() == 3) {
             return new AesEncrypt(children.get(0), children.get(1), 
children.get(2));
-        } else {
+        } else if (children().size() == 4) {
             return new AesEncrypt(children.get(0), children.get(1), 
children.get(2), children.get(3));
+        } else {
+            return new AesEncrypt(children.get(0), children.get(1), 
children.get(2), children.get(3),
+                    children.get(4));
         }
     }
 
diff --git a/gensrc/script/doris_builtins_functions.py b/gensrc/script/doris_builtins_functions.py
index 018d71385e0..6deede3784d 100644
--- a/gensrc/script/doris_builtins_functions.py
+++ b/gensrc/script/doris_builtins_functions.py
@@ -1919,6 +1919,8 @@ visible_functions = {
         [['aes_decrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR'], 
'ALWAYS_NULLABLE'],
         [['aes_encrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR', 
'VARCHAR'], 'ALWAYS_NULLABLE'],
         [['aes_decrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR', 
'VARCHAR'], 'ALWAYS_NULLABLE'],
+        [['aes_encrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR', 
'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'],
+        [['aes_decrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR', 
'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'],
         [['sm4_encrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR'], 
'ALWAYS_NULLABLE'],
         [['sm4_decrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR'], 
'ALWAYS_NULLABLE'],
         [['sm4_encrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR', 
'VARCHAR'], 'ALWAYS_NULLABLE'],
@@ -1928,6 +1930,8 @@ visible_functions = {
         [['aes_decrypt'], 'STRING', ['STRING', 'STRING', 'STRING'], 
'ALWAYS_NULLABLE'],
         [['aes_encrypt'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING'], 
'ALWAYS_NULLABLE'],
         [['aes_decrypt'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING'], 
'ALWAYS_NULLABLE'],
+        [['aes_encrypt'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING', 
'STRING'], 'ALWAYS_NULLABLE'],
+        [['aes_decrypt'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING', 
'STRING'], 'ALWAYS_NULLABLE'],
         [['sm4_encrypt'], 'STRING', ['STRING', 'STRING', 'STRING'], 
'ALWAYS_NULLABLE'],
         [['sm4_decrypt'], 'STRING', ['STRING', 'STRING', 'STRING'], 
'ALWAYS_NULLABLE'],
         [['sm4_encrypt'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING'], 
'ALWAYS_NULLABLE'],
diff --git a/regression-test/data/nereids_p0/sql_functions/encryption_digest/test_encryption_function.out b/regression-test/data/nereids_p0/sql_functions/encryption_digest/test_encryption_function.out
index 7a91c1dbf72..ec96e6df2b0 100644
--- a/regression-test/data/nereids_p0/sql_functions/encryption_digest/test_encryption_function.out
+++ b/regression-test/data/nereids_p0/sql_functions/encryption_digest/test_encryption_function.out
@@ -56,3 +56,29 @@ text
 -- !sql --
 82ec580fe6d36ae4f81cae3c73f4a5b3b5a09c943172dc9053c69fd8e18dca1e
 
+-- !sql_gcm_1 --
+MTIzNDU2Nzg5MDEyMdXvR41sJqwZ6hnTU8FRTTtXbL8yeChIZA==
+
+-- !sql_gcm_2 --
+Spark SQL
+
+-- !sql_gcm_3 --
+AAAAAAAAAAAAAAAAQiYi+sTLm7KD9UcZ2nlRdYDe/PX4
+
+-- !sql_gcm_4 --
+Spark
+
+-- !sql_gcm_5 --
+1      MTIzNDU2Nzg5MDEyMdXvR41sJqwZ6hnTU8FRTTtXbL8yeChIZA==
+2      AAAAAAAAAAAAAAAAQiYi+sTLm7KD9UcZ2nlRdYDe/PX4
+
+-- !sql_gcm_6 --
+1      Spark SQL
+2      Spark
+
+-- !sql_gcm_7 --
+1      MTIzNDU2Nzg5MDEyMdXvR41sJqwZ6hnTU8FRTTtXbL8yeChIZA==
+
+-- !sql_gcm_8 --
+Spark SQL
+
diff --git a/regression-test/suites/nereids_p0/sql_functions/encryption_digest/test_encryption_function.groovy b/regression-test/suites/nereids_p0/sql_functions/encryption_digest/test_encryption_function.groovy
index 0a2d4261001..f96290faffb 100644
--- a/regression-test/suites/nereids_p0/sql_functions/encryption_digest/test_encryption_function.groovy
+++ b/regression-test/suites/nereids_p0/sql_functions/encryption_digest/test_encryption_function.groovy
@@ -57,4 +57,36 @@ suite("test_encryption_function") {
     qt_sql "SELECT SM3(\"abc\");"
     qt_sql "select sm3(\"abcd\");"
     qt_sql "select sm3sum(\"ab\",\"cd\");"
+
+    qt_sql_gcm_1 "SELECT TO_BASE64(AES_ENCRYPT('Spark SQL', 
'1234567890abcdef', '123456789012', 'aes_128_gcm', 'Some AAD'))"
+    qt_sql_gcm_2 "SELECT 
AES_DECRYPT(FROM_BASE64('MTIzNDU2Nzg5MDEyMdXvR41sJqwZ6hnTU8FRTTtXbL8yeChIZA=='),
 '1234567890abcdef', '', 'aes_128_gcm', 'Some AAD')"
+
+    qt_sql_gcm_3 "select 
to_base64(aes_encrypt('Spark','abcdefghijklmnop12345678ABCDEFGH',unhex('000000000000000000000000'),'aes_256_gcm',
 'This is an AAD mixed into the input'));"
+    qt_sql_gcm_4 "SELECT 
AES_DECRYPT(FROM_BASE64('AAAAAAAAAAAAAAAAQiYi+sTLm7KD9UcZ2nlRdYDe/PX4'), 
'abcdefghijklmnop12345678ABCDEFGH', '', 'aes_256_gcm', 'This is an AAD mixed 
into the input');"
+
+    sql "DROP TABLE IF EXISTS aes_encrypt_decrypt_tbl"
+    sql """
+        CREATE TABLE IF NOT EXISTS aes_encrypt_decrypt_tbl (
+          id int,
+          plain_txt varchar(255),
+          enc_txt varchar(255),
+          k varchar(255),
+          iv varchar(255),
+          mode varchar(255),
+          aad varchar(255)
+        ) DISTRIBUTED BY HASH(id) BUCKETS 1
+        PROPERTIES (
+          "replication_num" = "1"
+        )
+    """
+    sql """ insert into aes_encrypt_decrypt_tbl values(1,'Spark 
SQL','MTIzNDU2Nzg5MDEyMdXvR41sJqwZ6hnTU8FRTTtXbL8yeChIZA==','1234567890abcdef','123456789012','aes_128_gcm','Some
 AAD');"""
+    sql """ insert into aes_encrypt_decrypt_tbl 
values(2,'Spark','AAAAAAAAAAAAAAAAQiYi+sTLm7KD9UcZ2nlRdYDe/PX4','abcdefghijklmnop12345678ABCDEFGH',unhex('000000000000000000000000'),'aes_256_gcm','This
 is an AAD mixed into the input');"""
+    sql """ sync """
+
+    qt_sql_gcm_5 "SELECT id,TO_BASE64(AES_ENCRYPT(plain_txt,k,iv,mode,aad)) 
from aes_encrypt_decrypt_tbl order by id;"
+    qt_sql_gcm_6 "SELECT id,AES_DECRYPT(FROM_BASE64(enc_txt),k,'',mode,aad) 
from aes_encrypt_decrypt_tbl order by id;"
+
+    // test for const opt branch, only first column is not const
+    qt_sql_gcm_7 "SELECT id,TO_BASE64(AES_ENCRYPT(plain_txt, 
'1234567890abcdef', '123456789012', 'aes_128_gcm', 'Some AAD')) from 
aes_encrypt_decrypt_tbl where id=1"
+    qt_sql_gcm_8 "SELECT AES_DECRYPT(FROM_BASE64(enc_txt), '1234567890abcdef', 
'', 'aes_128_gcm', 'Some AAD') from aes_encrypt_decrypt_tbl where id=1"
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org


Reply via email to