This is an automated email from the ASF dual-hosted git repository. zhangstar333 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new f8aa1bef470 [improve](function) add error msg if exceeded maximum default value in repeat function (#32219) f8aa1bef470 is described below commit f8aa1bef4705444fb78ca901347d1b386a16ae17 Author: zhangstar333 <87313068+zhangstar...@users.noreply.github.com> AuthorDate: Wed Mar 20 21:10:33 2024 +0800 [improve](function) add error msg if exceeded maximum default value in repeat function (#32219) add some error msg from repeat function, so the user could know the count is greater than default value. --- be/src/vec/functions/function_string.h | 39 +++++++++++++++------- be/test/vec/function/function_string_test.cpp | 21 +++++++----- .../datatype_p0/string/test_string_basic.groovy | 5 ++- .../max_msg_size_of_result_receiver.groovy | 14 ++++---- 4 files changed, 51 insertions(+), 28 deletions(-) diff --git a/be/src/vec/functions/function_string.h b/be/src/vec/functions/function_string.h index 75d8c8d4997..9ae686f3398 100644 --- a/be/src/vec/functions/function_string.h +++ b/be/src/vec/functions/function_string.h @@ -32,6 +32,7 @@ #include <ostream> #include <random> #include <sstream> +#include <stdexcept> #include <tuple> #include <utility> #include <vector> @@ -1439,6 +1440,14 @@ public: static FunctionPtr create() { return std::make_shared<FunctionStringRepeat>(); } String get_name() const override { return name; } size_t get_number_of_arguments() const override { return 2; } + std::string error_msg(int default_value, int repeat_value) const { + auto error_msg = fmt::format( + "The second parameter of repeat function exceeded maximum default value, " + "default_value is {}, and now input is {} . you could try change default value " + "greater than value eg: set repeat_max_num = {}.", + default_value, repeat_value, repeat_value + 10); + return error_msg; + } DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { return make_nullable(std::make_shared<DataTypeString>()); @@ -1456,17 +1465,20 @@ public: if (auto* col1 = check_and_get_column<ColumnString>(*argument_ptr[0])) { if (auto* col2 = check_and_get_column<ColumnInt32>(*argument_ptr[1])) { - vector_vector(col1->get_chars(), col1->get_offsets(), col2->get_data(), - res->get_chars(), res->get_offsets(), null_map->get_data(), - context->state()->repeat_max_num()); + RETURN_IF_ERROR(vector_vector(col1->get_chars(), col1->get_offsets(), + col2->get_data(), res->get_chars(), + res->get_offsets(), null_map->get_data(), + context->state()->repeat_max_num())); block.replace_by_position( result, ColumnNullable::create(std::move(res), std::move(null_map))); return Status::OK(); } else if (auto* col2_const = check_and_get_column<ColumnConst>(*argument_ptr[1])) { DCHECK(check_and_get_column<ColumnInt32>(col2_const->get_data_column())); - int repeat = 0; - repeat = std::min<int>(col2_const->get_int(0), context->state()->repeat_max_num()); - + int repeat = col2_const->get_int(0); + if (repeat > context->state()->repeat_max_num()) { + return Status::InvalidArgument( + error_msg(context->state()->repeat_max_num(), repeat)); + } if (repeat <= 0) { null_map->get_data().resize_fill(input_rows_count, 0); res->insert_many_defaults(input_rows_count); @@ -1484,10 +1496,10 @@ public: argument_ptr[0]->get_name(), argument_ptr[1]->get_name()); } - void vector_vector(const ColumnString::Chars& data, const ColumnString::Offsets& offsets, - const ColumnInt32::Container& repeats, ColumnString::Chars& res_data, - ColumnString::Offsets& res_offsets, ColumnUInt8::Container& null_map, - const int repeat_max_num) const { + Status vector_vector(const ColumnString::Chars& data, const ColumnString::Offsets& offsets, + const ColumnInt32::Container& repeats, ColumnString::Chars& res_data, + ColumnString::Offsets& res_offsets, ColumnUInt8::Container& null_map, + const int repeat_max_num) const { size_t input_row_size = offsets.size(); fmt::memory_buffer buffer; @@ -1497,8 +1509,10 @@ public: buffer.clear(); const char* raw_str = reinterpret_cast<const char*>(&data[offsets[i - 1]]); size_t size = offsets[i] - offsets[i - 1]; - int repeat = 0; - repeat = std::min<int>(repeats[i], repeat_max_num); + int repeat = repeats[i]; + if (repeat > repeat_max_num) { + return Status::InvalidArgument(error_msg(repeat_max_num, repeat)); + } if (repeat <= 0) { StringOP::push_empty_string(i, res_data, res_offsets); @@ -1512,6 +1526,7 @@ public: res_data, res_offsets); } } + return Status::OK(); } // TODO: 1. use pmr::vector<char> replace fmt_buffer may speed up the code diff --git a/be/test/vec/function/function_string_test.cpp b/be/test/vec/function/function_string_test.cpp index 39a9dca901c..612a6fff0cc 100644 --- a/be/test/vec/function/function_string_test.cpp +++ b/be/test/vec/function/function_string_test.cpp @@ -174,15 +174,20 @@ TEST(function_string_test, function_string_repeat_test) { std::string func_name = "repeat"; InputTypeSet input_types = {TypeIndex::String, TypeIndex::Int32}; - DataSet data_set = { - {{std::string("a"), 3}, std::string("aaa")}, - {{std::string("hel lo"), 2}, std::string("hel lohel lo")}, - {{std::string("hello word"), -1}, std::string("")}, - {{std::string(""), 1}, std::string("")}, - {{std::string("a"), 1073741825}, std::string("aaaaaaaaaa")}, // ut repeat max num 10 - {{std::string("HELLO,!^%"), 2}, std::string("HELLO,!^%HELLO,!^%")}, - {{std::string("你"), 2}, std::string("你你")}}; + DataSet data_set = {{{std::string("a"), 3}, std::string("aaa")}, + {{std::string("hel lo"), 2}, std::string("hel lohel lo")}, + {{std::string("hello word"), -1}, std::string("")}, + {{std::string(""), 1}, std::string("")}, + {{std::string("HELLO,!^%"), 2}, std::string("HELLO,!^%HELLO,!^%")}, + {{std::string("你"), 2}, std::string("你你")}}; static_cast<void>(check_function<DataTypeString, true>(func_name, input_types, data_set)); + + { + DataSet data_set = {{{std::string("a"), 1073741825}, + std::string("aaaaaaaaaa")}}; // ut repeat max num 10 + Status st = check_function<DataTypeString, true>(func_name, input_types, data_set, true); + EXPECT_NE(Status::OK(), st); + } } TEST(function_string_test, function_string_reverse_test) { diff --git a/regression-test/suites/datatype_p0/string/test_string_basic.groovy b/regression-test/suites/datatype_p0/string/test_string_basic.groovy index 2aa9f9e86e4..36fbddede2d 100644 --- a/regression-test/suites/datatype_p0/string/test_string_basic.groovy +++ b/regression-test/suites/datatype_p0/string/test_string_basic.groovy @@ -129,7 +129,10 @@ suite("test_string_basic") { (2, repeat("test1111", 131072)) """ order_qt_select_str_tb "select k1, md5(v1), length(v1) from ${tbName}" - + test { + sql """SELECT repeat("test1111", 131073 + 100);""" + exception "repeat function exceeded maximum default value" + } sql """drop table if exists test_string_cmp;""" sql """ diff --git a/regression-test/suites/variable_p0/max_msg_size_of_result_receiver.groovy b/regression-test/suites/variable_p0/max_msg_size_of_result_receiver.groovy index e7fead33d90..f9afdd8eadb 100644 --- a/regression-test/suites/variable_p0/max_msg_size_of_result_receiver.groovy +++ b/regression-test/suites/variable_p0/max_msg_size_of_result_receiver.groovy @@ -27,13 +27,14 @@ suite("max_msg_size_of_result_receiver") { ENGINE=OLAP DISTRIBUTED BY HASH(id) PROPERTIES("replication_num"="1") """ - + sql """set repeat_max_num=100000;""" + sql """set max_msg_size_of_result_receiver=90000;""" // so the test of repeat("a", 80000) could pass, and repeat("a", 100000) will be failed sql """ - INSERT INTO ${table_name} VALUES (104, repeat("a", ${MESSAGE_SIZE_BASE * 104})) + INSERT INTO ${table_name} VALUES (104, repeat("a", 80000)) """ sql """ - INSERT INTO ${table_name} VALUES (105, repeat("a", ${MESSAGE_SIZE_BASE * 105})) + INSERT INTO ${table_name} VALUES (105, repeat("a", 100000)) """ def with_exception = false @@ -44,10 +45,9 @@ suite("max_msg_size_of_result_receiver") { } assertEquals(with_exception, false) - try { - sql "SELECT * FROM ${table_name} WHERE id = 105" - } catch (Exception e) { - assertTrue(e.getMessage().contains('MaxMessageSize reached, try increase max_msg_size_of_result_receiver')) + test { + sql """SELECT * FROM ${table_name} WHERE id = 105;""" + exception "MaxMessageSize reached, try increase max_msg_size_of_result_receiver" } try { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org