This is an automated email from the ASF dual-hosted git repository. morningman pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-doris.git
The following commit(s) were added to refs/heads/master by this push: new b26e7e3 [feature](function)(vec) support locate function (#7988) b26e7e3 is described below commit b26e7e3c284bdae98ac69d0e786b3e623e543e38 Author: Pxl <952130...@qq.com> AuthorDate: Sat Feb 12 16:00:37 2022 +0800 [feature](function)(vec) support locate function (#7988) * support function locate in vectorized engine * add ut and fix some bug --- be/src/vec/common/string_ref.h | 7 +- be/src/vec/functions/function_string.cpp | 32 ++++---- be/src/vec/functions/function_string.h | 110 +++++++++++++++++++++++++- be/src/vec/functions/function_totype.h | 8 ++ be/test/vec/function/function_string_test.cpp | 46 +++++++++-- 5 files changed, 174 insertions(+), 29 deletions(-) diff --git a/be/src/vec/common/string_ref.h b/be/src/vec/common/string_ref.h index bd81342..727996e 100644 --- a/be/src/vec/common/string_ref.h +++ b/be/src/vec/common/string_ref.h @@ -27,6 +27,7 @@ #include "gutil/hash/city.h" #include "gutil/hash/hash128to64.h" +#include "udf/udf.h" #include "vec/common/unaligned.h" #include "vec/core/types.h" @@ -53,6 +54,10 @@ struct StringRef { std::string to_string() const { return std::string(data, size); } explicit operator std::string() const { return to_string(); } + + StringVal to_string_val() const { + return StringVal(reinterpret_cast<uint8_t*>(const_cast<char*>(data)), size); + } }; using StringRefs = std::vector<StringRef>; @@ -291,7 +296,7 @@ struct StringRefHash : CRC32Hash {}; struct CRC32Hash { size_t operator()(StringRef /* x */) const { - throw std::logic_error{"Not implemented CRC32Hash without SSE"}; + throw std::logic_error {"Not implemented CRC32Hash without SSE"}; } }; diff --git a/be/src/vec/functions/function_string.cpp b/be/src/vec/functions/function_string.cpp index a89edf5..94ed9c9 100644 --- a/be/src/vec/functions/function_string.cpp +++ b/be/src/vec/functions/function_string.cpp @@ -95,14 +95,7 @@ struct StringUtf8LengthImpl { for (int i = 0; i < size; ++i) { const char* raw_str = reinterpret_cast<const char*>(&data[offsets[i - 1]]); int str_size = offsets[i] - offsets[i - 1] - 1; - - size_t char_len = 0; - for (size_t i = 0, char_size = 0; i < str_size; i += char_size) { - char_size = get_utf8_byte_length((unsigned)(raw_str)[i]); - ++char_len; - } - - res[i] = char_len; + res[i] = get_char_len(StringValue(const_cast<char*>(raw_str), str_size), str_size); } return Status::OK(); } @@ -201,17 +194,19 @@ struct InStrOP { // Hive returns positions starting from 1. int loc = search.search(&str_sv); if (loc > 0) { - size_t char_len = 0; - for (size_t i = 0, char_size = 0; i < loc; i += char_size) { - char_size = get_utf8_byte_length((unsigned)(strl.data())[i]); - ++char_len; - } - loc = char_len; + loc = get_char_len(str_sv, loc); } res = loc + 1; } }; +struct LocateOP { + using ResultDataType = DataTypeInt32; + using ResultPaddedPODArray = PaddedPODArray<Int32>; + static void execute(const std::string_view& strl, const std::string_view& strr, int32_t& res) { + InStrOP::execute(strr, strl, res); + } +}; // LeftDataType and RightDataType are DataTypeString template <typename LeftDataType, typename RightDataType, typename OP> @@ -706,6 +701,9 @@ template <typename LeftDataType, typename RightDataType> using StringInstrImpl = StringFunctionImpl<LeftDataType, RightDataType, InStrOP>; template <typename LeftDataType, typename RightDataType> +using StringLocateImpl = StringFunctionImpl<LeftDataType, RightDataType, LocateOP>; + +template <typename LeftDataType, typename RightDataType> using StringFindInSetImpl = StringFunctionImpl<LeftDataType, RightDataType, FindInSetOp>; // ready for regist function @@ -720,7 +718,7 @@ using FunctionStringEndsWith = using FunctionStringInstr = FunctionBinaryToType<DataTypeString, DataTypeString, StringInstrImpl, NameInstr>; using FunctionStringLocate = - FunctionBinaryToType<DataTypeString, DataTypeString, StringInstrImpl, NameLocate>; + FunctionBinaryToType<DataTypeString, DataTypeString, StringLocateImpl, NameLocate>; using FunctionStringFindInSet = FunctionBinaryToType<DataTypeString, DataTypeString, StringFindInSetImpl, NameFindInSet>; @@ -755,7 +753,6 @@ using FunctionStringLPad = FunctionStringPad<StringLPad>; using FunctionStringRPad = FunctionStringPad<StringRPad>; void register_function_string(SimpleFunctionFactory& factory) { - // factory.register_function<>(); factory.register_function<FunctionStringASCII>(); factory.register_function<FunctionStringLength>(); factory.register_function<FunctionStringUTF8Length>(); @@ -764,7 +761,8 @@ void register_function_string(SimpleFunctionFactory& factory) { factory.register_function<FunctionStringEndsWith>(); factory.register_function<FunctionStringInstr>(); factory.register_function<FunctionStringFindInSet>(); - // factory.register_function<FunctionStringLocate>(); + factory.register_function<FunctionStringLocate>(); + factory.register_function<FunctionStringLocatePos>(); factory.register_function<FunctionReverse>(); factory.register_function<FunctionHexString>(); factory.register_function<FunctionUnHex>(); diff --git a/be/src/vec/functions/function_string.h b/be/src/vec/functions/function_string.h index af58062..934053e 100644 --- a/be/src/vec/functions/function_string.h +++ b/be/src/vec/functions/function_string.h @@ -21,18 +21,21 @@ #include <fmt/format.h> #include <fmt/ranges.h> +#include <cstdint> #include <string_view> #include "exprs/anyval_util.h" #include "exprs/math_functions.h" #include "exprs/string_functions.h" #include "runtime/string_value.hpp" +#include "udf/udf.h" #include "util/md5.h" #include "util/url_parser.h" #include "vec/columns/column_decimal.h" #include "vec/columns/column_nullable.h" #include "vec/columns/column_string.h" #include "vec/columns/columns_number.h" +#include "vec/common/string_ref.h" #include "vec/data_types/data_type_nullable.h" #include "vec/data_types/data_type_number.h" #include "vec/data_types/data_type_string.h" @@ -70,6 +73,25 @@ inline size_t get_char_len(const std::string_view& str, std::vector<size_t>* str return char_len; } +inline size_t get_char_len(const StringVal& str, std::vector<size_t>* str_index) { + size_t char_len = 0; + for (size_t i = 0, char_size = 0; i < str.len; i += char_size) { + char_size = get_utf8_byte_length((unsigned)(str.ptr)[i]); + str_index->push_back(i); + ++char_len; + } + return char_len; +} + +inline size_t get_char_len(const StringValue& str, size_t end_pos) { + size_t char_len = 0; + for (size_t i = 0, char_size = 0; i < std::min(str.len, end_pos); i += char_size) { + char_size = get_utf8_byte_length((unsigned)(str.ptr)[i]); + ++char_len; + } + return char_len; +} + struct StringOP { static void push_empty_string(int index, ColumnString::Chars& chars, ColumnString::Offsets& offsets) { @@ -1079,7 +1101,7 @@ struct MoneyFormatDoubleImpl { static DataTypes get_variadic_argument_types() { return {std::make_shared<DataTypeFloat64>()}; } static void execute(FunctionContext* context, ColumnString* result_column, - const ColumnType* data_column, size_t input_rows_count) { + const ColumnType* data_column, size_t input_rows_count) { for (size_t i = 0; i < input_rows_count; i++) { double value = MathFunctions::my_double_round(data_column->get_element(i), 2, false, false); @@ -1095,7 +1117,7 @@ struct MoneyFormatInt64Impl { static DataTypes get_variadic_argument_types() { return {std::make_shared<DataTypeInt64>()}; } static void execute(FunctionContext* context, ColumnString* result_column, - const ColumnType* data_column, size_t input_rows_count) { + const ColumnType* data_column, size_t input_rows_count) { for (size_t i = 0; i < input_rows_count; i++) { Int64 value = data_column->get_element(i); StringVal str = StringFunctions::do_money_format<Int64, 26>(context, value); @@ -1110,7 +1132,7 @@ struct MoneyFormatInt128Impl { static DataTypes get_variadic_argument_types() { return {std::make_shared<DataTypeInt128>()}; } static void execute(FunctionContext* context, ColumnString* result_column, - const ColumnType* data_column, size_t input_rows_count) { + const ColumnType* data_column, size_t input_rows_count) { for (size_t i = 0; i < input_rows_count; i++) { Int128 value = data_column->get_element(i); StringVal str = StringFunctions::do_money_format<Int128, 52>(context, value); @@ -1127,7 +1149,7 @@ struct MoneyFormatDecimalImpl { } static void execute(FunctionContext* context, ColumnString* result_column, - const ColumnType* data_column, size_t input_rows_count) { + const ColumnType* data_column, size_t input_rows_count) { for (size_t i = 0; i < input_rows_count; i++) { DecimalV2Val value = DecimalV2Val(data_column->get_element(i)); @@ -1142,4 +1164,84 @@ struct MoneyFormatDecimalImpl { } }; +class FunctionStringLocatePos : public IFunction { +public: + static constexpr auto name = "locate"; + static FunctionPtr create() { return std::make_shared<FunctionStringLocatePos>(); } + String get_name() const override { return name; } + size_t get_number_of_arguments() const override { return 3; } + + DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { + return std::make_shared<DataTypeInt32>(); + } + + DataTypes get_variadic_argument_types_impl() const override { + return {std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>(), + std::make_shared<DataTypeInt32>()}; + } + + bool is_variadic() const override { return true; } + + bool use_default_implementation_for_constants() const override { return true; } + + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + size_t result, size_t input_rows_count) override { + auto col_substr = + block.get_by_position(arguments[0]).column->convert_to_full_column_if_const(); + auto col_str = + block.get_by_position(arguments[1]).column->convert_to_full_column_if_const(); + auto col_pos = + block.get_by_position(arguments[2]).column->convert_to_full_column_if_const(); + + ColumnInt32::MutablePtr col_res = ColumnInt32::create(); + + auto& vec_pos = reinterpret_cast<const ColumnInt32*>(col_pos.get())->get_data(); + auto& vec_res = col_res->get_data(); + vec_res.resize(input_rows_count); + + for (int i = 0; i < input_rows_count; ++i) { + vec_res[i] = locate_pos(col_substr->get_data_at(i).to_string_val(), + col_str->get_data_at(i).to_string_val(), vec_pos[i]); + } + + block.replace_by_position(result, std::move(col_res)); + return Status::OK(); + } + +private: + int locate_pos(StringVal substr, StringVal str, int start_pos) { + if (substr.len == 0) { + if (start_pos <= 0) { + return 0; + } else if (start_pos == 1) { + return 1; + } else if (start_pos > str.len) { + return 0; + } else { + return start_pos; + } + } + // Hive returns 0 for *start_pos <= 0, + // but throws an exception for *start_pos > str->len. + // Since returning 0 seems to be Hive's error condition, return 0. + std::vector<size_t> index; + size_t char_len = get_char_len(str, &index); + if (start_pos <= 0 || start_pos > str.len || start_pos > char_len) { + return 0; + } + StringValue substr_sv = StringValue::from_string_val(substr); + StringSearch search(&substr_sv); + // Input start_pos starts from 1. + StringValue adjusted_str(reinterpret_cast<char*>(str.ptr) + index[start_pos - 1], + str.len - index[start_pos - 1]); + int32_t match_pos = search.search(&adjusted_str); + if (match_pos >= 0) { + // Hive returns the position in the original string starting from 1. + return start_pos + get_char_len(adjusted_str, match_pos); + } else { + return 0; + } + } +}; + } // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/functions/function_totype.h b/be/src/vec/functions/function_totype.h index 72e74b3..fef300e 100644 --- a/be/src/vec/functions/function_totype.h +++ b/be/src/vec/functions/function_totype.h @@ -195,9 +195,17 @@ public: static FunctionPtr create() { return std::make_shared<FunctionBinaryToType>(); } String get_name() const override { return name; } size_t get_number_of_arguments() const override { return 2; } + DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { return std::make_shared<ResultDataType>(); } + + DataTypes get_variadic_argument_types_impl() const override { + return {std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>()}; + } + + bool is_variadic() const override { return true; } + bool use_default_implementation_for_constants() const override { return true; } Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, diff --git a/be/test/vec/function/function_string_test.cpp b/be/test/vec/function/function_string_test.cpp index 5617a36..60184fc 100644 --- a/be/test/vec/function/function_string_test.cpp +++ b/be/test/vec/function/function_string_test.cpp @@ -25,8 +25,10 @@ #include "util/encryption_util.h" #include "util/url_coding.h" #include "vec/core/field.h" +#include "vec/core/types.h" namespace doris::vectorized { +using namespace ut_type; TEST(function_string_test, function_string_substr_test) { std::string func_name = "substr"; @@ -440,17 +442,47 @@ TEST(function_string_test, function_instr_test) { InputTypeSet input_types = {TypeIndex::String, TypeIndex::String}; - DataSet data_set = {{{std::string("abcdefg"), std::string("efg")}, 5}, - {{std::string("aa"), std::string("a")}, 1}, - {{std::string("我是"), std::string("是")}, 2}, - {{std::string("abcd"), std::string("e")}, 0}, - {{std::string("abcdef"), std::string("")}, 1}, - {{std::string(""), std::string("")}, 1}, - {{std::string("aaaab"), std::string("bb")}, 0}}; + DataSet data_set = { + {{STRING("abcdefg"), STRING("efg")}, INT(5)}, {{STRING("aa"), STRING("a")}, INT(1)}, + {{STRING("我是"), STRING("是")}, INT(2)}, {{STRING("abcd"), STRING("e")}, INT(0)}, + {{STRING("abcdef"), STRING("")}, INT(1)}, {{STRING(""), STRING("")}, INT(1)}, + {{STRING("aaaab"), STRING("bb")}, INT(0)}}; check_function<DataTypeInt32, true>(func_name, input_types, data_set); } +TEST(function_string_test, function_locate_test) { + std::string func_name = "locate"; + + { + InputTypeSet input_types = {TypeIndex::String, TypeIndex::String}; + + DataSet data_set = {{{STRING("efg"), STRING("abcdefg")}, INT(5)}, + {{STRING("a"), STRING("aa")}, INT(1)}, + {{STRING("是"), STRING("我是")}, INT(2)}, + {{STRING("e"), STRING("abcd")}, INT(0)}, + {{STRING(""), STRING("abcdef")}, INT(1)}, + {{STRING(""), STRING("")}, INT(1)}, + {{STRING("bb"), STRING("aaaab")}, INT(0)}}; + + check_function<DataTypeInt32, true>(func_name, input_types, data_set); + } + + { + InputTypeSet input_types = {TypeIndex::String, TypeIndex::String, TypeIndex::Int32}; + + DataSet data_set = {{{STRING("bar"), STRING("foobarbar"), INT(5)}, INT(7)}, + {{STRING("xbar"), STRING("foobar"), INT(1)}, INT(0)}, + {{STRING(""), STRING("foobar"), INT(2)}, INT(2)}, + {{STRING("A"), STRING("大A写的A"), INT(0)}, INT(0)}, + {{STRING("A"), STRING("大A写的A"), INT(1)}, INT(2)}, + {{STRING("A"), STRING("大A写的A"), INT(2)}, INT(2)}, + {{STRING("A"), STRING("大A写的A"), INT(3)}, INT(5)}}; + + check_function<DataTypeInt32, true>(func_name, input_types, data_set); + } +} + TEST(function_string_test, function_find_in_set_test) { std::string func_name = "find_in_set"; --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org