This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch branch-2.1 in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-2.1 by this push: new 361a59dec86 [feature](aes_encrypt) support GCM mode for aes_encrypt and aes_decrypt (#40004) (#40672) 361a59dec86 is described below commit 361a59dec86c97b87fda6becdd0aa99dd192c536 Author: camby <camby...@tencent.com> AuthorDate: Wed Sep 11 23:28:28 2024 +0800 [feature](aes_encrypt) support GCM mode for aes_encrypt and aes_decrypt (#40004) (#40672) pick #40004 to branch-2.1 --- be/src/util/encryption_util.cpp | 180 ++++++++++++++++----- be/src/util/encryption_util.h | 17 +- be/src/vec/functions/function_encryption.cpp | 106 ++++++++---- .../functions/scalar/AesCryptoFunction.java | 25 ++- .../expressions/functions/scalar/AesDecrypt.java | 22 ++- .../expressions/functions/scalar/AesEncrypt.java | 22 ++- gensrc/script/doris_builtins_functions.py | 4 + .../encryption_digest/test_encryption_function.out | 26 +++ .../test_encryption_function.groovy | 32 ++++ 9 files changed, 357 insertions(+), 77 deletions(-) diff --git a/be/src/util/encryption_util.cpp b/be/src/util/encryption_util.cpp index ab1ff3e857c..bf6777a5c88 100644 --- a/be/src/util/encryption_util.cpp +++ b/be/src/util/encryption_util.cpp @@ -25,6 +25,7 @@ #include <algorithm> #include <cstring> #include <string> +#include <unordered_map> namespace doris { @@ -80,6 +81,12 @@ const EVP_CIPHER* get_evp_type(const EncryptionMode mode) { return EVP_aes_256_ctr(); case EncryptionMode::AES_256_OFB: return EVP_aes_256_ofb(); + case EncryptionMode::AES_128_GCM: + return EVP_aes_128_gcm(); + case EncryptionMode::AES_192_GCM: + return EVP_aes_192_gcm(); + case EncryptionMode::AES_256_GCM: + return EVP_aes_256_gcm(); case EncryptionMode::SM4_128_CBC: return EVP_sm4_cbc(); case EncryptionMode::SM4_128_ECB: @@ -95,41 +102,29 @@ const EVP_CIPHER* get_evp_type(const EncryptionMode mode) { } } -static uint mode_key_sizes[] = { - 128 /* AES_128_ECB */, - 192 /* AES_192_ECB */, - 256 /* AES_256_ECB */, - 128 /* AES_128_CBC */, - 192 /* AES_192_CBC */, - 256 /* AES_256_CBC */, - 128 /* AES_128_CFB */, - 192 /* AES_192_CFB */, - 256 /* AES_256_CFB */, - 128 /* AES_128_CFB1 */, - 192 /* AES_192_CFB1 */, - 256 /* AES_256_CFB1 */, - 128 /* AES_128_CFB8 */, - 192 /* AES_192_CFB8 */, - 256 /* AES_256_CFB8 */, - 128 /* AES_128_CFB128 */, - 192 /* AES_192_CFB128 */, - 256 /* AES_256_CFB128 */, - 128 /* AES_128_CTR */, - 192 /* AES_192_CTR */, - 256 /* AES_256_CTR */, - 128 /* AES_128_OFB */, - 192 /* AES_192_OFB */, - 256 /* AES_256_OFB */, - 128 /* SM4_128_ECB */, - 128 /* SM4_128_CBC */, - 128 /* SM4_128_CFB128 */, - 128 /* SM4_128_OFB */, - 128 /* SM4_128_CTR */ -}; +static std::unordered_map<EncryptionMode, uint> mode_key_sizes = { + {EncryptionMode::AES_128_ECB, 128}, {EncryptionMode::AES_192_ECB, 192}, + {EncryptionMode::AES_256_ECB, 256}, {EncryptionMode::AES_128_CBC, 128}, + {EncryptionMode::AES_192_CBC, 192}, {EncryptionMode::AES_256_CBC, 256}, + {EncryptionMode::AES_128_CFB, 128}, {EncryptionMode::AES_192_CFB, 192}, + {EncryptionMode::AES_256_CFB, 256}, {EncryptionMode::AES_128_CFB1, 128}, + {EncryptionMode::AES_192_CFB1, 192}, {EncryptionMode::AES_256_CFB1, 256}, + {EncryptionMode::AES_128_CFB8, 128}, {EncryptionMode::AES_192_CFB8, 192}, + {EncryptionMode::AES_256_CFB8, 256}, {EncryptionMode::AES_128_CFB128, 128}, + {EncryptionMode::AES_192_CFB128, 192}, {EncryptionMode::AES_256_CFB128, 256}, + {EncryptionMode::AES_128_CTR, 128}, {EncryptionMode::AES_192_CTR, 192}, + {EncryptionMode::AES_256_CTR, 256}, {EncryptionMode::AES_128_OFB, 128}, + {EncryptionMode::AES_192_OFB, 192}, {EncryptionMode::AES_256_OFB, 256}, + {EncryptionMode::AES_128_GCM, 128}, {EncryptionMode::AES_192_GCM, 192}, + {EncryptionMode::AES_256_GCM, 256}, + + {EncryptionMode::SM4_128_ECB, 128}, {EncryptionMode::SM4_128_CBC, 128}, + {EncryptionMode::SM4_128_CFB128, 128}, {EncryptionMode::SM4_128_OFB, 128}, + {EncryptionMode::SM4_128_CTR, 128}}; static void create_key(const unsigned char* origin_key, uint32_t key_length, uint8_t* encrypt_key, EncryptionMode mode) { - const uint key_size = mode_key_sizes[int(mode)] / 8; + const uint key_size = mode_key_sizes[mode] / 8; uint8_t* origin_key_end = ((uint8_t*)origin_key) + key_length; /* origin key boundary*/ uint8_t* encrypt_key_end; /* encrypt key boundary */ @@ -172,10 +167,58 @@ static int do_encrypt(EVP_CIPHER_CTX* cipher_ctx, const EVP_CIPHER* cipher, return ret; } +static int do_gcm_encrypt(EVP_CIPHER_CTX* cipher_ctx, const EVP_CIPHER* cipher, + const unsigned char* source, uint32_t source_length, + const unsigned char* encrypt_key, const unsigned char* iv, int iv_length, + unsigned char* encrypt, int* length_ptr, const unsigned char* aad, + uint32_t aad_length) { + int ret = EVP_EncryptInit_ex(cipher_ctx, cipher, nullptr, nullptr, nullptr); + if (ret != 1) { + return ret; + } + ret = EVP_CIPHER_CTX_ctrl(cipher_ctx, EVP_CTRL_GCM_SET_IVLEN, iv_length, nullptr); + if (ret != 1) { + return ret; + } + ret = EVP_EncryptInit_ex(cipher_ctx, nullptr, nullptr, encrypt_key, iv); + if (ret != 1) { + return ret; + } + if (aad) { + int tmp_len = 0; + ret = EVP_EncryptUpdate(cipher_ctx, nullptr, &tmp_len, aad, aad_length); + if (ret != 1) { + return ret; + } + } + + std::memcpy(encrypt, iv, iv_length); + encrypt += iv_length; + + int u_len = 0; + ret = EVP_EncryptUpdate(cipher_ctx, encrypt, &u_len, source, source_length); + if (ret != 1) { + return ret; + } + encrypt += u_len; + + int f_len = 0; + ret = EVP_EncryptFinal_ex(cipher_ctx, encrypt, &f_len); + if (ret != 1) { + return ret; + } + encrypt += f_len; + + ret = EVP_CIPHER_CTX_ctrl(cipher_ctx, EVP_CTRL_GCM_GET_TAG, EncryptionUtil::GCM_TAG_SIZE, + encrypt); + *length_ptr = iv_length + u_len + f_len + EncryptionUtil::GCM_TAG_SIZE; + return ret; +} + int EncryptionUtil::encrypt(EncryptionMode mode, const unsigned char* source, uint32_t source_length, const unsigned char* key, uint32_t key_length, const char* iv_str, int iv_input_length, bool padding, - unsigned char* encrypt) { + unsigned char* encrypt, const unsigned char* aad, uint32_t aad_length) { const EVP_CIPHER* cipher = get_evp_type(mode); /* The encrypt key to be used for encryption */ unsigned char encrypt_key[ENCRYPTION_MAX_KEY_LENGTH / 8]; @@ -196,8 +239,16 @@ int EncryptionUtil::encrypt(EncryptionMode mode, const unsigned char* source, EVP_CIPHER_CTX* cipher_ctx = EVP_CIPHER_CTX_new(); EVP_CIPHER_CTX_reset(cipher_ctx); int length = 0; - int ret = do_encrypt(cipher_ctx, cipher, source, source_length, encrypt_key, + int ret = 0; + if (is_gcm_mode(mode)) { + ret = do_gcm_encrypt(cipher_ctx, cipher, source, source_length, encrypt_key, + reinterpret_cast<unsigned char*>(init_vec), iv_length, encrypt, + &length, aad, aad_length); + } else { + ret = do_encrypt(cipher_ctx, cipher, source, source_length, encrypt_key, reinterpret_cast<unsigned char*>(init_vec), padding, encrypt, &length); + } + EVP_CIPHER_CTX_free(cipher_ctx); if (ret == 0) { ERR_clear_error(); @@ -230,10 +281,61 @@ static int do_decrypt(EVP_CIPHER_CTX* cipher_ctx, const EVP_CIPHER* cipher, return ret; } +static int do_gcm_decrypt(EVP_CIPHER_CTX* cipher_ctx, const EVP_CIPHER* cipher, + const unsigned char* encrypt, uint32_t encrypt_length, + const unsigned char* encrypt_key, int iv_length, + unsigned char* decrypt_content, int* length_ptr, const unsigned char* aad, + uint32_t aad_length) { + if (encrypt_length < iv_length + EncryptionUtil::GCM_TAG_SIZE) { + return -1; + } + int ret = EVP_DecryptInit_ex(cipher_ctx, cipher, nullptr, nullptr, nullptr); + if (ret != 1) { + return ret; + } + ret = EVP_CIPHER_CTX_ctrl(cipher_ctx, EVP_CTRL_GCM_SET_IVLEN, iv_length, nullptr); + if (ret != 1) { + return ret; + } + ret = EVP_DecryptInit_ex(cipher_ctx, nullptr, nullptr, encrypt_key, encrypt); + if (ret != 1) { + return ret; + } + encrypt += iv_length; + if (aad) { + int tmp_len = 0; + ret = EVP_DecryptUpdate(cipher_ctx, nullptr, &tmp_len, aad, aad_length); + if (ret != 1) { + return ret; + } + } + + uint32_t real_encrypt_length = encrypt_length - iv_length - EncryptionUtil::GCM_TAG_SIZE; + int u_len = 0; + ret = EVP_DecryptUpdate(cipher_ctx, decrypt_content, &u_len, encrypt, real_encrypt_length); + if (ret != 1) { + return ret; + } + encrypt += real_encrypt_length; + decrypt_content += u_len; + + void* tag = const_cast<void*>(reinterpret_cast<const void*>(encrypt)); + ret = EVP_CIPHER_CTX_ctrl(cipher_ctx, EVP_CTRL_GCM_SET_TAG, EncryptionUtil::GCM_TAG_SIZE, tag); + if (ret != 1) { + return ret; + } + + int f_len = 0; + ret = EVP_DecryptFinal_ex(cipher_ctx, decrypt_content, &f_len); + *length_ptr = u_len + f_len; + return ret; +} + int EncryptionUtil::decrypt(EncryptionMode mode, const unsigned char* encrypt, uint32_t encrypt_length, const unsigned char* key, uint32_t key_length, const char* iv_str, int iv_input_length, bool padding, - unsigned char* decrypt_content) { + unsigned char* decrypt_content, const unsigned char* aad, + uint32_t aad_length) { const EVP_CIPHER* cipher = get_evp_type(mode); /* The encrypt key to be used for decryption */ @@ -255,9 +357,15 @@ int EncryptionUtil::decrypt(EncryptionMode mode, const unsigned char* encrypt, EVP_CIPHER_CTX* cipher_ctx = EVP_CIPHER_CTX_new(); EVP_CIPHER_CTX_reset(cipher_ctx); int length = 0; - int ret = do_decrypt(cipher_ctx, cipher, encrypt, encrypt_length, encrypt_key, + int ret = 0; + if (is_gcm_mode(mode)) { + ret = do_gcm_decrypt(cipher_ctx, cipher, encrypt, encrypt_length, encrypt_key, iv_length, + decrypt_content, &length, aad, aad_length); + } else { + ret = do_decrypt(cipher_ctx, cipher, encrypt, encrypt_length, encrypt_key, reinterpret_cast<unsigned char*>(init_vec), padding, decrypt_content, &length); + } EVP_CIPHER_CTX_free(cipher_ctx); if (ret > 0) { return length; diff --git a/be/src/util/encryption_util.h b/be/src/util/encryption_util.h index 8e61a119953..dfd288c3147 100644 --- a/be/src/util/encryption_util.h +++ b/be/src/util/encryption_util.h @@ -46,6 +46,9 @@ enum class EncryptionMode { AES_128_OFB, AES_192_OFB, AES_256_OFB, + AES_128_GCM, + AES_192_GCM, + AES_256_GCM, SM4_128_ECB, SM4_128_CBC, SM4_128_CFB128, @@ -57,13 +60,23 @@ enum EncryptionState { AES_SUCCESS = 0, AES_BAD_DATA = -1 }; class EncryptionUtil { public: + static bool is_gcm_mode(EncryptionMode mode) { + return mode == EncryptionMode::AES_128_GCM || mode == EncryptionMode::AES_192_GCM || + mode == EncryptionMode::AES_256_GCM; + } + + // https://tools.ietf.org/html/rfc5116#section-5.1 + static const int GCM_TAG_SIZE = 16; + static int encrypt(EncryptionMode mode, const unsigned char* source, uint32_t source_length, const unsigned char* key, uint32_t key_length, const char* iv_str, - int iv_input_length, bool padding, unsigned char* encrypt); + int iv_input_length, bool padding, unsigned char* encrypt, + const unsigned char* aad = nullptr, uint32_t aad_length = 0); static int decrypt(EncryptionMode mode, const unsigned char* encrypt, uint32_t encrypt_length, const unsigned char* key, uint32_t key_length, const char* iv_str, - int iv_input_length, bool padding, unsigned char* decrypt_content); + int iv_input_length, bool padding, unsigned char* decrypt_content, + const unsigned char* aad = nullptr, uint32_t aad_length = 0); }; } // namespace doris diff --git a/be/src/vec/functions/function_encryption.cpp b/be/src/vec/functions/function_encryption.cpp index f63e9bca1b0..c90b6a1ff60 100644 --- a/be/src/vec/functions/function_encryption.cpp +++ b/be/src/vec/functions/function_encryption.cpp @@ -79,7 +79,10 @@ inline StringCaseUnorderedMap<EncryptionMode> aes_mode_map { {"AES_256_CTR", EncryptionMode::AES_256_CTR}, {"AES_128_OFB", EncryptionMode::AES_128_OFB}, {"AES_192_OFB", EncryptionMode::AES_192_OFB}, - {"AES_256_OFB", EncryptionMode::AES_256_OFB}}; + {"AES_256_OFB", EncryptionMode::AES_256_OFB}, + {"AES_128_GCM", EncryptionMode::AES_128_GCM}, + {"AES_192_GCM", EncryptionMode::AES_192_GCM}, + {"AES_256_GCM", EncryptionMode::AES_256_GCM}}; inline StringCaseUnorderedMap<EncryptionMode> sm4_mode_map { {"SM4_128_ECB", EncryptionMode::SM4_128_ECB}, {"SM4_128_CBC", EncryptionMode::SM4_128_CBC}, @@ -120,7 +123,7 @@ void execute_result_vector(std::vector<const ColumnString::Offsets*>& offsets_li std::vector<const ColumnString::Chars*>& chars_list, size_t i, EncryptionMode& encryption_mode, const char* iv_raw, int iv_length, ColumnString::Chars& result_data, ColumnString::Offsets& result_offset, - NullMap& null_map) { + NullMap& null_map, const char* aad, int aad_length) { int src_size = (*offsets_list[0])[i] - (*offsets_list[0])[i - 1]; const auto* src_raw = reinterpret_cast<const char*>(&(*chars_list[0])[(*offsets_list[0])[i - 1]]); @@ -128,7 +131,8 @@ void execute_result_vector(std::vector<const ColumnString::Offsets*>& offsets_li const auto* key_raw = reinterpret_cast<const char*>(&(*chars_list[1])[(*offsets_list[1])[i - 1]]); execute_result<Impl, is_encrypt>(src_raw, src_size, key_raw, key_size, i, encryption_mode, - iv_raw, iv_length, result_data, result_offset, null_map); + iv_raw, iv_length, result_data, result_offset, null_map, aad, + aad_length); } template <typename Impl, bool is_encrypt> @@ -136,19 +140,19 @@ void execute_result_const(const ColumnString::Offsets* offsets_column, const ColumnString::Chars* chars_column, StringRef key_arg, size_t i, EncryptionMode& encryption_mode, const char* iv_raw, int iv_length, ColumnString::Chars& result_data, ColumnString::Offsets& result_offset, - NullMap& null_map) { + NullMap& null_map, const char* aad, int aad_length) { int src_size = (*offsets_column)[i] - (*offsets_column)[i - 1]; const auto* src_raw = reinterpret_cast<const char*>(&(*chars_column)[(*offsets_column)[i - 1]]); execute_result<Impl, is_encrypt>(src_raw, src_size, key_arg.data, key_arg.size, i, encryption_mode, iv_raw, iv_length, result_data, result_offset, - null_map); + null_map, aad, aad_length); } template <typename Impl, bool is_encrypt> void execute_result(const char* src_raw, int src_size, const char* key_raw, int key_size, size_t i, EncryptionMode& encryption_mode, const char* iv_raw, int iv_length, ColumnString::Chars& result_data, ColumnString::Offsets& result_offset, - NullMap& null_map) { + NullMap& null_map, const char* aad, int aad_length) { if (src_size == 0) { StringOP::push_null_string(i, result_data, result_offset, null_map); return; @@ -156,6 +160,10 @@ void execute_result(const char* src_raw, int src_size, const char* key_raw, int int cipher_len = src_size; if constexpr (is_encrypt) { cipher_len += 16; + // for output AEAD tag + if (EncryptionUtil::is_gcm_mode(encryption_mode)) { + cipher_len += EncryptionUtil::GCM_TAG_SIZE; + } } std::unique_ptr<char[]> p; p.reset(new char[cipher_len]); @@ -163,7 +171,7 @@ void execute_result(const char* src_raw, int src_size, const char* key_raw, int ret_code = Impl::execute_impl(encryption_mode, (unsigned char*)src_raw, src_size, (unsigned char*)key_raw, key_size, iv_raw, iv_length, true, - (unsigned char*)p.get()); + (unsigned char*)p.get(), (unsigned char*)aad, aad_length); if (ret_code < 0) { StringOP::push_null_string(i, result_data, result_offset, null_map); @@ -248,7 +256,7 @@ struct EncryptionAndDecryptTwoImpl { } execute_result_const<Impl, is_encrypt>(offsets_column, chars_column, key_arg, i, encryption_mode, nullptr, 0, result_data, - result_offset, null_map); + result_offset, null_map, nullptr, 0); } } @@ -275,16 +283,22 @@ struct EncryptionAndDecryptTwoImpl { } execute_result_vector<Impl, is_encrypt>(offsets_list, chars_list, i, encryption_mode, nullptr, 0, result_data, result_offset, - null_map); + null_map, nullptr, 0); } } }; -template <typename Impl, EncryptionMode mode, bool is_encrypt, bool is_sm_mode> -struct EncryptionAndDecryptFourImpl { +template <typename Impl, EncryptionMode mode, bool is_encrypt, bool is_sm_mode, int arg_num = 4> +struct EncryptionAndDecryptMultiImpl { static DataTypes get_variadic_argument_types_impl() { - return {std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>(), - std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>()}; + if constexpr (arg_num == 5) { + return {std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>(), + std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>(), + std::make_shared<DataTypeString>()}; + } else { + return {std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>(), + std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>()}; + } } static Status execute_impl_inner(FunctionContext* context, Block& block, @@ -292,8 +306,8 @@ struct EncryptionAndDecryptFourImpl { size_t input_rows_count) { auto result_column = ColumnString::create(); auto result_null_map_column = ColumnUInt8::create(input_rows_count, 0); - DCHECK_EQ(4, arguments.size()); - const size_t argument_size = 4; + DCHECK_EQ(arguments.size(), arg_num); + const size_t argument_size = arg_num; bool col_const[argument_size]; ColumnPtr argument_columns[argument_size]; for (int i = 0; i < argument_size; ++i) { @@ -304,8 +318,13 @@ struct EncryptionAndDecryptFourImpl { .convert_to_full_column() : block.get_by_position(arguments[0]).column; - default_preprocess_parameter_columns(argument_columns, col_const, {1, 2, 3}, block, - arguments); + if constexpr (arg_num == 5) { + default_preprocess_parameter_columns(argument_columns, col_const, {1, 2, 3, 4}, block, + arguments); + } else { + default_preprocess_parameter_columns(argument_columns, col_const, {1, 2, 3}, block, + arguments); + } for (int i = 0; i < argument_size; i++) { check_set_nullable(argument_columns[i], result_null_map_column, col_const[i]); @@ -314,11 +333,17 @@ struct EncryptionAndDecryptFourImpl { auto& result_offset = result_column->get_offsets(); result_offset.resize(input_rows_count); - if (col_const[1] && col_const[2] && col_const[3]) { + if ((arg_num == 5) && col_const[1] && col_const[2] && col_const[3] && col_const[4]) { + vector_const(assert_cast<const ColumnString*>(argument_columns[0].get()), + argument_columns[1]->get_data_at(0), argument_columns[2]->get_data_at(0), + argument_columns[3]->get_data_at(0), input_rows_count, result_data, + result_offset, result_null_map_column->get_data(), + argument_columns[4]->get_data_at(0)); + } else if ((arg_num == 4) && col_const[1] && col_const[2] && col_const[3]) { vector_const(assert_cast<const ColumnString*>(argument_columns[0].get()), argument_columns[1]->get_data_at(0), argument_columns[2]->get_data_at(0), argument_columns[3]->get_data_at(0), input_rows_count, result_data, - result_offset, result_null_map_column->get_data()); + result_offset, result_null_map_column->get_data(), StringRef()); } else { std::vector<const ColumnString::Offsets*> offsets_list(argument_size); std::vector<const ColumnString::Chars*> chars_list(argument_size); @@ -338,7 +363,7 @@ struct EncryptionAndDecryptFourImpl { static void vector_const(const ColumnString* column, StringRef key_arg, StringRef iv_arg, StringRef mode_arg, size_t input_rows_count, ColumnString::Chars& result_data, ColumnString::Offsets& result_offset, - NullMap& null_map) { + NullMap& null_map, StringRef aad_arg) { EncryptionMode encryption_mode = mode; bool all_insert_null = false; if (mode_arg.size != 0) { @@ -363,9 +388,9 @@ struct EncryptionAndDecryptFourImpl { StringOP::push_null_string(i, result_data, result_offset, null_map); continue; } - execute_result_const<Impl, is_encrypt>(offsets_column, chars_column, key_arg, i, - encryption_mode, iv_arg.data, iv_arg.size, - result_data, result_offset, null_map); + execute_result_const<Impl, is_encrypt>( + offsets_column, chars_column, key_arg, i, encryption_mode, iv_arg.data, + iv_arg.size, result_data, result_offset, null_map, aad_arg.data, aad_arg.size); } } @@ -403,9 +428,16 @@ struct EncryptionAndDecryptFourImpl { } } + int aad_size = 0; + const char* aad = nullptr; + if constexpr (arg_num == 5) { + aad_size = (*offsets_list[4])[i] - (*offsets_list[4])[i - 1]; + aad = reinterpret_cast<const char*>(&(*chars_list[4])[(*offsets_list[4])[i - 1]]); + } + execute_result_vector<Impl, is_encrypt>(offsets_list, chars_list, i, encryption_mode, iv_raw, iv_size, result_data, result_offset, - null_map); + null_map, aad, aad_size); } } }; @@ -413,18 +445,20 @@ struct EncryptionAndDecryptFourImpl { struct EncryptImpl { static int execute_impl(EncryptionMode mode, const unsigned char* source, uint32_t source_length, const unsigned char* key, uint32_t key_length, - const char* iv, int iv_length, bool padding, unsigned char* encrypt) { + const char* iv, int iv_length, bool padding, unsigned char* encrypt, + const unsigned char* aad, int aad_length) { return EncryptionUtil::encrypt(mode, source, source_length, key, key_length, iv, iv_length, - true, encrypt); + true, encrypt, aad, aad_length); } }; struct DecryptImpl { static int execute_impl(EncryptionMode mode, const unsigned char* source, uint32_t source_length, const unsigned char* key, uint32_t key_length, - const char* iv, int iv_length, bool padding, unsigned char* encrypt) { + const char* iv, int iv_length, bool padding, unsigned char* encrypt, + const unsigned char* aad, int aad_length) { return EncryptionUtil::decrypt(mode, source, source_length, key, key_length, iv, iv_length, - true, encrypt); + true, encrypt, aad, aad_length); } }; @@ -459,16 +493,24 @@ void register_function_encryption(SimpleFunctionFactory& factory) { AESDecryptName>>(); factory.register_function<FunctionEncryptionAndDecrypt< - EncryptionAndDecryptFourImpl<EncryptImpl, EncryptionMode::SM4_128_ECB, true, true>, + EncryptionAndDecryptMultiImpl<EncryptImpl, EncryptionMode::SM4_128_ECB, true, true>, SM4EncryptName>>(); factory.register_function<FunctionEncryptionAndDecrypt< - EncryptionAndDecryptFourImpl<DecryptImpl, EncryptionMode::SM4_128_ECB, false, true>, + EncryptionAndDecryptMultiImpl<DecryptImpl, EncryptionMode::SM4_128_ECB, false, true>, SM4DecryptName>>(); factory.register_function<FunctionEncryptionAndDecrypt< - EncryptionAndDecryptFourImpl<EncryptImpl, EncryptionMode::AES_128_ECB, true, false>, + EncryptionAndDecryptMultiImpl<EncryptImpl, EncryptionMode::AES_128_ECB, true, false>, + AESEncryptName>>(); + factory.register_function<FunctionEncryptionAndDecrypt< + EncryptionAndDecryptMultiImpl<DecryptImpl, EncryptionMode::AES_128_ECB, false, false>, + AESDecryptName>>(); + + factory.register_function<FunctionEncryptionAndDecrypt< + EncryptionAndDecryptMultiImpl<EncryptImpl, EncryptionMode::AES_128_GCM, true, false, 5>, AESEncryptName>>(); factory.register_function<FunctionEncryptionAndDecrypt< - EncryptionAndDecryptFourImpl<DecryptImpl, EncryptionMode::AES_128_ECB, false, false>, + EncryptionAndDecryptMultiImpl<DecryptImpl, EncryptionMode::AES_128_GCM, false, false, + 5>, AESDecryptName>>(); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/AesCryptoFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/AesCryptoFunction.java index a72b84dab1c..3a98fb64ccf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/AesCryptoFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/AesCryptoFunction.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral; import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; import com.google.common.collect.ImmutableSet; @@ -52,7 +53,16 @@ public abstract class AesCryptoFunction extends CryptoFunction { "AES_256_CTR", "AES_128_OFB", "AES_192_OFB", - "AES_256_OFB" + "AES_256_OFB", + "AES_128_GCM", + "AES_192_GCM", + "AES_256_GCM" + ); + + public static final Set<String> AES_GCM_MODES = ImmutableSet.of( + "AES_128_GCM", + "AES_192_GCM", + "AES_256_GCM" ); public AesCryptoFunction(String name, Expression... arguments) { @@ -72,4 +82,17 @@ public abstract class AesCryptoFunction extends CryptoFunction { } return encryptionMode; } + + @Override + public void checkLegalityAfterRewrite() { + if (arity() >= 4 && child(3) instanceof StringLikeLiteral) { + String mode = ((StringLikeLiteral) child(3)).getValue().toUpperCase(); + if (!AES_MODES.contains(mode)) { + throw new AnalysisException("mode " + mode + " is not supported"); + } + if (arity() == 5 && !AES_GCM_MODES.contains(mode)) { + throw new AnalysisException("only GCM mode support AAD(the 5th arg)"); + } + } + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/AesDecrypt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/AesDecrypt.java index 7608cf4e40e..8967b0e5138 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/AesDecrypt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/AesDecrypt.java @@ -50,7 +50,16 @@ public class AesDecrypt extends AesCryptoFunction { VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT), FunctionSignature.ret(StringType.INSTANCE) - .args(StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE) + .args(StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE), + FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT) + .args(VarcharType.SYSTEM_DEFAULT, + VarcharType.SYSTEM_DEFAULT, + VarcharType.SYSTEM_DEFAULT, + VarcharType.SYSTEM_DEFAULT, + VarcharType.SYSTEM_DEFAULT), + FunctionSignature.ret(StringType.INSTANCE) + .args(StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE, + StringType.INSTANCE) ); /** @@ -68,18 +77,25 @@ public class AesDecrypt extends AesCryptoFunction { super("aes_decrypt", arg0, arg1, arg2, arg3); } + public AesDecrypt(Expression arg0, Expression arg1, Expression arg2, Expression arg3, Expression arg4) { + super("aes_decrypt", arg0, arg1, arg2, arg3, arg4); + } + /** * withChildren. */ @Override public AesDecrypt withChildren(List<Expression> children) { - Preconditions.checkArgument(children.size() >= 2 && children.size() <= 4); + Preconditions.checkArgument(children.size() >= 2 && children.size() <= 5); if (children.size() == 2) { return new AesDecrypt(children.get(0), children.get(1)); } else if (children().size() == 3) { return new AesDecrypt(children.get(0), children.get(1), children.get(2)); - } else { + } else if (children().size() == 4) { return new AesDecrypt(children.get(0), children.get(1), children.get(2), children.get(3)); + } else { + return new AesDecrypt(children.get(0), children.get(1), children.get(2), children.get(3), + children.get(4)); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/AesEncrypt.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/AesEncrypt.java index 455d6b0dbd5..a70a639785f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/AesEncrypt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/AesEncrypt.java @@ -50,7 +50,16 @@ public class AesEncrypt extends AesCryptoFunction { VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT), FunctionSignature.ret(StringType.INSTANCE) - .args(StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE) + .args(StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE), + FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT) + .args(VarcharType.SYSTEM_DEFAULT, + VarcharType.SYSTEM_DEFAULT, + VarcharType.SYSTEM_DEFAULT, + VarcharType.SYSTEM_DEFAULT, + VarcharType.SYSTEM_DEFAULT), + FunctionSignature.ret(StringType.INSTANCE) + .args(StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE, + StringType.INSTANCE) ); /** @@ -68,18 +77,25 @@ public class AesEncrypt extends AesCryptoFunction { super("aes_encrypt", arg0, arg1, arg2, arg3); } + public AesEncrypt(Expression arg0, Expression arg1, Expression arg2, Expression arg3, Expression arg4) { + super("aes_encrypt", arg0, arg1, arg2, arg3, arg4); + } + /** * withChildren. */ @Override public AesEncrypt withChildren(List<Expression> children) { - Preconditions.checkArgument(children.size() >= 2 && children.size() <= 4); + Preconditions.checkArgument(children.size() >= 2 && children.size() <= 5); if (children.size() == 2) { return new AesEncrypt(children.get(0), children.get(1)); } else if (children().size() == 3) { return new AesEncrypt(children.get(0), children.get(1), children.get(2)); - } else { + } else if (children().size() == 4) { return new AesEncrypt(children.get(0), children.get(1), children.get(2), children.get(3)); + } else { + return new AesEncrypt(children.get(0), children.get(1), children.get(2), children.get(3), + children.get(4)); } } diff --git a/gensrc/script/doris_builtins_functions.py b/gensrc/script/doris_builtins_functions.py index 018d71385e0..6deede3784d 100644 --- a/gensrc/script/doris_builtins_functions.py +++ b/gensrc/script/doris_builtins_functions.py @@ -1919,6 +1919,8 @@ visible_functions = { [['aes_decrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'], [['aes_encrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'], [['aes_decrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'], + [['aes_encrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR', 'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'], + [['aes_decrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR', 'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'], [['sm4_encrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'], [['sm4_decrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'], [['sm4_encrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'], @@ -1928,6 +1930,8 @@ visible_functions = { [['aes_decrypt'], 'STRING', ['STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'], [['aes_encrypt'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'], [['aes_decrypt'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'], + [['aes_encrypt'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'], + [['aes_decrypt'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'], [['sm4_encrypt'], 'STRING', ['STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'], [['sm4_decrypt'], 'STRING', ['STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'], [['sm4_encrypt'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'], diff --git a/regression-test/data/nereids_p0/sql_functions/encryption_digest/test_encryption_function.out b/regression-test/data/nereids_p0/sql_functions/encryption_digest/test_encryption_function.out index 7a91c1dbf72..ec96e6df2b0 100644 --- a/regression-test/data/nereids_p0/sql_functions/encryption_digest/test_encryption_function.out +++ b/regression-test/data/nereids_p0/sql_functions/encryption_digest/test_encryption_function.out @@ -56,3 +56,29 @@ text -- !sql -- 82ec580fe6d36ae4f81cae3c73f4a5b3b5a09c943172dc9053c69fd8e18dca1e +-- !sql_gcm_1 -- +MTIzNDU2Nzg5MDEyMdXvR41sJqwZ6hnTU8FRTTtXbL8yeChIZA== + +-- !sql_gcm_2 -- +Spark SQL + +-- !sql_gcm_3 -- +AAAAAAAAAAAAAAAAQiYi+sTLm7KD9UcZ2nlRdYDe/PX4 + +-- !sql_gcm_4 -- +Spark + +-- !sql_gcm_5 -- +1 MTIzNDU2Nzg5MDEyMdXvR41sJqwZ6hnTU8FRTTtXbL8yeChIZA== +2 AAAAAAAAAAAAAAAAQiYi+sTLm7KD9UcZ2nlRdYDe/PX4 + +-- !sql_gcm_6 -- +1 Spark SQL +2 Spark + +-- !sql_gcm_7 -- +1 MTIzNDU2Nzg5MDEyMdXvR41sJqwZ6hnTU8FRTTtXbL8yeChIZA== + +-- !sql_gcm_8 -- +Spark SQL + diff --git a/regression-test/suites/nereids_p0/sql_functions/encryption_digest/test_encryption_function.groovy b/regression-test/suites/nereids_p0/sql_functions/encryption_digest/test_encryption_function.groovy index 0a2d4261001..f96290faffb 100644 --- a/regression-test/suites/nereids_p0/sql_functions/encryption_digest/test_encryption_function.groovy +++ b/regression-test/suites/nereids_p0/sql_functions/encryption_digest/test_encryption_function.groovy @@ -57,4 +57,36 @@ suite("test_encryption_function") { qt_sql "SELECT SM3(\"abc\");" qt_sql "select sm3(\"abcd\");" qt_sql "select sm3sum(\"ab\",\"cd\");" + + qt_sql_gcm_1 "SELECT TO_BASE64(AES_ENCRYPT('Spark SQL', '1234567890abcdef', '123456789012', 'aes_128_gcm', 'Some AAD'))" + qt_sql_gcm_2 "SELECT AES_DECRYPT(FROM_BASE64('MTIzNDU2Nzg5MDEyMdXvR41sJqwZ6hnTU8FRTTtXbL8yeChIZA=='), '1234567890abcdef', '', 'aes_128_gcm', 'Some AAD')" + + qt_sql_gcm_3 "select to_base64(aes_encrypt('Spark','abcdefghijklmnop12345678ABCDEFGH',unhex('000000000000000000000000'),'aes_256_gcm', 'This is an AAD mixed into the input'));" + qt_sql_gcm_4 "SELECT AES_DECRYPT(FROM_BASE64('AAAAAAAAAAAAAAAAQiYi+sTLm7KD9UcZ2nlRdYDe/PX4'), 'abcdefghijklmnop12345678ABCDEFGH', '', 'aes_256_gcm', 'This is an AAD mixed into the input');" + + sql "DROP TABLE IF EXISTS aes_encrypt_decrypt_tbl" + sql """ + CREATE TABLE IF NOT EXISTS aes_encrypt_decrypt_tbl ( + id int, + plain_txt varchar(255), + enc_txt varchar(255), + k varchar(255), + iv varchar(255), + mode varchar(255), + aad varchar(255) + ) DISTRIBUTED BY HASH(id) BUCKETS 1 + PROPERTIES ( + "replication_num" = "1" + ) + """ + sql """ insert into aes_encrypt_decrypt_tbl values(1,'Spark SQL','MTIzNDU2Nzg5MDEyMdXvR41sJqwZ6hnTU8FRTTtXbL8yeChIZA==','1234567890abcdef','123456789012','aes_128_gcm','Some AAD');""" + sql """ insert into aes_encrypt_decrypt_tbl values(2,'Spark','AAAAAAAAAAAAAAAAQiYi+sTLm7KD9UcZ2nlRdYDe/PX4','abcdefghijklmnop12345678ABCDEFGH',unhex('000000000000000000000000'),'aes_256_gcm','This is an AAD mixed into the input');""" + sql """ sync """ + + qt_sql_gcm_5 "SELECT id,TO_BASE64(AES_ENCRYPT(plain_txt,k,iv,mode,aad)) from aes_encrypt_decrypt_tbl order by id;" + qt_sql_gcm_6 "SELECT id,AES_DECRYPT(FROM_BASE64(enc_txt),k,'',mode,aad) from aes_encrypt_decrypt_tbl order by id;" + + // test for const opt branch, only first column is not const + qt_sql_gcm_7 "SELECT id,TO_BASE64(AES_ENCRYPT(plain_txt, '1234567890abcdef', '123456789012', 'aes_128_gcm', 'Some AAD')) from aes_encrypt_decrypt_tbl where id=1" + qt_sql_gcm_8 "SELECT AES_DECRYPT(FROM_BASE64(enc_txt), '1234567890abcdef', '', 'aes_128_gcm', 'Some AAD') from aes_encrypt_decrypt_tbl where id=1" } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org