This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-doris.git


The following commit(s) were added to refs/heads/master by this push:
     new b26e7e3  [feature](function)(vec) support locate function (#7988)
b26e7e3 is described below

commit b26e7e3c284bdae98ac69d0e786b3e623e543e38
Author: Pxl <952130...@qq.com>
AuthorDate: Sat Feb 12 16:00:37 2022 +0800

    [feature](function)(vec) support locate function (#7988)
    
    * support function locate in vectorized engine
    
    * add ut and fix some bug
---
 be/src/vec/common/string_ref.h                |   7 +-
 be/src/vec/functions/function_string.cpp      |  32 ++++----
 be/src/vec/functions/function_string.h        | 110 +++++++++++++++++++++++++-
 be/src/vec/functions/function_totype.h        |   8 ++
 be/test/vec/function/function_string_test.cpp |  46 +++++++++--
 5 files changed, 174 insertions(+), 29 deletions(-)

diff --git a/be/src/vec/common/string_ref.h b/be/src/vec/common/string_ref.h
index bd81342..727996e 100644
--- a/be/src/vec/common/string_ref.h
+++ b/be/src/vec/common/string_ref.h
@@ -27,6 +27,7 @@
 
 #include "gutil/hash/city.h"
 #include "gutil/hash/hash128to64.h"
+#include "udf/udf.h"
 #include "vec/common/unaligned.h"
 #include "vec/core/types.h"
 
@@ -53,6 +54,10 @@ struct StringRef {
     std::string to_string() const { return std::string(data, size); }
 
     explicit operator std::string() const { return to_string(); }
+
+    StringVal to_string_val() const {
+        return StringVal(reinterpret_cast<uint8_t*>(const_cast<char*>(data)), size);
+    }
 };
 
 using StringRefs = std::vector<StringRef>;
@@ -291,7 +296,7 @@ struct StringRefHash : CRC32Hash {};
 
 struct CRC32Hash {
     size_t operator()(StringRef /* x */) const {
-        throw std::logic_error{"Not implemented CRC32Hash without SSE"};
+        throw std::logic_error {"Not implemented CRC32Hash without SSE"};
     }
 };
 
diff --git a/be/src/vec/functions/function_string.cpp b/be/src/vec/functions/function_string.cpp
index a89edf5..94ed9c9 100644
--- a/be/src/vec/functions/function_string.cpp
+++ b/be/src/vec/functions/function_string.cpp
@@ -95,14 +95,7 @@ struct StringUtf8LengthImpl {
         for (int i = 0; i < size; ++i) {
             const char* raw_str = reinterpret_cast<const char*>(&data[offsets[i - 1]]);
             int str_size = offsets[i] - offsets[i - 1] - 1;
-
-            size_t char_len = 0;
-            for (size_t i = 0, char_size = 0; i < str_size; i += char_size) {
-                char_size = get_utf8_byte_length((unsigned)(raw_str)[i]);
-                ++char_len;
-            }
-
-            res[i] = char_len;
+            res[i] = get_char_len(StringValue(const_cast<char*>(raw_str), str_size), str_size);
         }
         return Status::OK();
     }
@@ -201,17 +194,19 @@ struct InStrOP {
         // Hive returns positions starting from 1.
         int loc = search.search(&str_sv);
         if (loc > 0) {
-            size_t char_len = 0;
-            for (size_t i = 0, char_size = 0; i < loc; i += char_size) {
-                char_size = get_utf8_byte_length((unsigned)(strl.data())[i]);
-                ++char_len;
-            }
-            loc = char_len;
+            loc = get_char_len(str_sv, loc);
         }
 
         res = loc + 1;
     }
 };
+struct LocateOP {
+    using ResultDataType = DataTypeInt32;
+    using ResultPaddedPODArray = PaddedPODArray<Int32>;
+    static void execute(const std::string_view& strl, const std::string_view& strr, int32_t& res) {
+        InStrOP::execute(strr, strl, res);
+    }
+};
 
 // LeftDataType and RightDataType are DataTypeString
 template <typename LeftDataType, typename RightDataType, typename OP>
@@ -706,6 +701,9 @@ template <typename LeftDataType, typename RightDataType>
 using StringInstrImpl = StringFunctionImpl<LeftDataType, RightDataType, InStrOP>;
 
 template <typename LeftDataType, typename RightDataType>
+using StringLocateImpl = StringFunctionImpl<LeftDataType, RightDataType, LocateOP>;
+
+template <typename LeftDataType, typename RightDataType>
 using StringFindInSetImpl = StringFunctionImpl<LeftDataType, RightDataType, FindInSetOp>;
 
 // ready for regist function
@@ -720,7 +718,7 @@ using FunctionStringEndsWith =
 using FunctionStringInstr =
         FunctionBinaryToType<DataTypeString, DataTypeString, StringInstrImpl, NameInstr>;
 using FunctionStringLocate =
-        FunctionBinaryToType<DataTypeString, DataTypeString, StringInstrImpl, NameLocate>;
+        FunctionBinaryToType<DataTypeString, DataTypeString, StringLocateImpl, NameLocate>;
 using FunctionStringFindInSet =
         FunctionBinaryToType<DataTypeString, DataTypeString, StringFindInSetImpl, NameFindInSet>;
 
@@ -755,7 +753,6 @@ using FunctionStringLPad = FunctionStringPad<StringLPad>;
 using FunctionStringRPad = FunctionStringPad<StringRPad>;
 
 void register_function_string(SimpleFunctionFactory& factory) {
-    // factory.register_function<>();
     factory.register_function<FunctionStringASCII>();
     factory.register_function<FunctionStringLength>();
     factory.register_function<FunctionStringUTF8Length>();
@@ -764,7 +761,8 @@ void register_function_string(SimpleFunctionFactory& factory) {
     factory.register_function<FunctionStringEndsWith>();
     factory.register_function<FunctionStringInstr>();
     factory.register_function<FunctionStringFindInSet>();
-    //    factory.register_function<FunctionStringLocate>();
+    factory.register_function<FunctionStringLocate>();
+    factory.register_function<FunctionStringLocatePos>();
     factory.register_function<FunctionReverse>();
     factory.register_function<FunctionHexString>();
     factory.register_function<FunctionUnHex>();
diff --git a/be/src/vec/functions/function_string.h b/be/src/vec/functions/function_string.h
index af58062..934053e 100644
--- a/be/src/vec/functions/function_string.h
+++ b/be/src/vec/functions/function_string.h
@@ -21,18 +21,21 @@
 #include <fmt/format.h>
 #include <fmt/ranges.h>
 
+#include <cstdint>
 #include <string_view>
 
 #include "exprs/anyval_util.h"
 #include "exprs/math_functions.h"
 #include "exprs/string_functions.h"
 #include "runtime/string_value.hpp"
+#include "udf/udf.h"
 #include "util/md5.h"
 #include "util/url_parser.h"
 #include "vec/columns/column_decimal.h"
 #include "vec/columns/column_nullable.h"
 #include "vec/columns/column_string.h"
 #include "vec/columns/columns_number.h"
+#include "vec/common/string_ref.h"
 #include "vec/data_types/data_type_nullable.h"
 #include "vec/data_types/data_type_number.h"
 #include "vec/data_types/data_type_string.h"
@@ -70,6 +73,25 @@ inline size_t get_char_len(const std::string_view& str, std::vector<size_t>* str
     return char_len;
 }
 
+inline size_t get_char_len(const StringVal& str, std::vector<size_t>* str_index) {
+    size_t char_len = 0;
+    for (size_t i = 0, char_size = 0; i < str.len; i += char_size) {
+        char_size = get_utf8_byte_length((unsigned)(str.ptr)[i]);
+        str_index->push_back(i);
+        ++char_len;
+    }
+    return char_len;
+}
+
+inline size_t get_char_len(const StringValue& str, size_t end_pos) {
+    size_t char_len = 0;
+    for (size_t i = 0, char_size = 0; i < std::min(str.len, end_pos); i += char_size) {
+        char_size = get_utf8_byte_length((unsigned)(str.ptr)[i]);
+        ++char_len;
+    }
+    return char_len;
+}
+
 struct StringOP {
     static void push_empty_string(int index, ColumnString::Chars& chars,
                                   ColumnString::Offsets& offsets) {
@@ -1079,7 +1101,7 @@ struct MoneyFormatDoubleImpl {
     static DataTypes get_variadic_argument_types() { return {std::make_shared<DataTypeFloat64>()}; }
 
     static void execute(FunctionContext* context, ColumnString* result_column,
-                          const ColumnType* data_column, size_t input_rows_count) {
+                        const ColumnType* data_column, size_t input_rows_count) {
         for (size_t i = 0; i < input_rows_count; i++) {
             double value =
                     MathFunctions::my_double_round(data_column->get_element(i), 2, false, false);
@@ -1095,7 +1117,7 @@ struct MoneyFormatInt64Impl {
     static DataTypes get_variadic_argument_types() { return {std::make_shared<DataTypeInt64>()}; }
 
     static void execute(FunctionContext* context, ColumnString* result_column,
-                          const ColumnType* data_column, size_t input_rows_count) {
+                        const ColumnType* data_column, size_t input_rows_count) {
         for (size_t i = 0; i < input_rows_count; i++) {
             Int64 value = data_column->get_element(i);
             StringVal str = StringFunctions::do_money_format<Int64, 26>(context, value);
@@ -1110,7 +1132,7 @@ struct MoneyFormatInt128Impl {
     static DataTypes get_variadic_argument_types() { return {std::make_shared<DataTypeInt128>()}; }
 
     static void execute(FunctionContext* context, ColumnString* result_column,
-                          const ColumnType* data_column, size_t input_rows_count) {
+                        const ColumnType* data_column, size_t input_rows_count) {
         for (size_t i = 0; i < input_rows_count; i++) {
             Int128 value = data_column->get_element(i);
             StringVal str = StringFunctions::do_money_format<Int128, 52>(context, value);
@@ -1127,7 +1149,7 @@ struct MoneyFormatDecimalImpl {
     }
 
     static void execute(FunctionContext* context, ColumnString* result_column,
-                          const ColumnType* data_column, size_t input_rows_count) {
+                        const ColumnType* data_column, size_t input_rows_count) {
         for (size_t i = 0; i < input_rows_count; i++) {
             DecimalV2Val value = DecimalV2Val(data_column->get_element(i));
 
@@ -1142,4 +1164,84 @@ struct MoneyFormatDecimalImpl {
     }
 };
 
+class FunctionStringLocatePos : public IFunction {
+public:
+    static constexpr auto name = "locate";
+    static FunctionPtr create() { return std::make_shared<FunctionStringLocatePos>(); }
+    String get_name() const override { return name; }
+    size_t get_number_of_arguments() const override { return 3; }
+
+    DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
+        return std::make_shared<DataTypeInt32>();
+    }
+
+    DataTypes get_variadic_argument_types_impl() const override {
+        return {std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>(),
+                std::make_shared<DataTypeInt32>()};
+    }
+
+    bool is_variadic() const override { return true; }
+
+    bool use_default_implementation_for_constants() const override { return true; }
+
+    Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
+                        size_t result, size_t input_rows_count) override {
+        auto col_substr =
+                block.get_by_position(arguments[0]).column->convert_to_full_column_if_const();
+        auto col_str =
+                block.get_by_position(arguments[1]).column->convert_to_full_column_if_const();
+        auto col_pos =
+                block.get_by_position(arguments[2]).column->convert_to_full_column_if_const();
+
+        ColumnInt32::MutablePtr col_res = ColumnInt32::create();
+
+        auto& vec_pos = reinterpret_cast<const ColumnInt32*>(col_pos.get())->get_data();
+        auto& vec_res = col_res->get_data();
+        vec_res.resize(input_rows_count);
+
+        for (int i = 0; i < input_rows_count; ++i) {
+            vec_res[i] = locate_pos(col_substr->get_data_at(i).to_string_val(),
+                                    col_str->get_data_at(i).to_string_val(), vec_pos[i]);
+        }
+
+        block.replace_by_position(result, std::move(col_res));
+        return Status::OK();
+    }
+
+private:
+    int locate_pos(StringVal substr, StringVal str, int start_pos) {
+        if (substr.len == 0) {
+            if (start_pos <= 0) {
+                return 0;
+            } else if (start_pos == 1) {
+                return 1;
+            } else if (start_pos > str.len) {
+                return 0;
+            } else {
+                return start_pos;
+            }
+        }
+        // Hive returns 0 for *start_pos <= 0,
+        // but throws an exception for *start_pos > str->len.
+        // Since returning 0 seems to be Hive's error condition, return 0.
+        std::vector<size_t> index;
+        size_t char_len = get_char_len(str, &index);
+        if (start_pos <= 0 || start_pos > str.len || start_pos > char_len) {
+            return 0;
+        }
+        StringValue substr_sv = StringValue::from_string_val(substr);
+        StringSearch search(&substr_sv);
+        // Input start_pos starts from 1.
+        StringValue adjusted_str(reinterpret_cast<char*>(str.ptr) + index[start_pos - 1],
+                                 str.len - index[start_pos - 1]);
+        int32_t match_pos = search.search(&adjusted_str);
+        if (match_pos >= 0) {
+            // Hive returns the position in the original string starting from 1.
+            return start_pos + get_char_len(adjusted_str, match_pos);
+        } else {
+            return 0;
+        }
+    }
+};
+
 } // namespace doris::vectorized
\ No newline at end of file
diff --git a/be/src/vec/functions/function_totype.h b/be/src/vec/functions/function_totype.h
index 72e74b3..fef300e 100644
--- a/be/src/vec/functions/function_totype.h
+++ b/be/src/vec/functions/function_totype.h
@@ -195,9 +195,17 @@ public:
     static FunctionPtr create() { return std::make_shared<FunctionBinaryToType>(); }
     String get_name() const override { return name; }
     size_t get_number_of_arguments() const override { return 2; }
+
     DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
         return std::make_shared<ResultDataType>();
     }
+
+    DataTypes get_variadic_argument_types_impl() const override {
+        return {std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>()};
+    }
+
+    bool is_variadic() const override { return true; }
+
     bool use_default_implementation_for_constants() const override { return true; }
 
     Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
diff --git a/be/test/vec/function/function_string_test.cpp b/be/test/vec/function/function_string_test.cpp
index 5617a36..60184fc 100644
--- a/be/test/vec/function/function_string_test.cpp
+++ b/be/test/vec/function/function_string_test.cpp
@@ -25,8 +25,10 @@
 #include "util/encryption_util.h"
 #include "util/url_coding.h"
 #include "vec/core/field.h"
+#include "vec/core/types.h"
 
 namespace doris::vectorized {
+using namespace ut_type;
 
 TEST(function_string_test, function_string_substr_test) {
     std::string func_name = "substr";
@@ -440,17 +442,47 @@ TEST(function_string_test, function_instr_test) {
 
     InputTypeSet input_types = {TypeIndex::String, TypeIndex::String};
 
-    DataSet data_set = {{{std::string("abcdefg"), std::string("efg")}, 5},
-                        {{std::string("aa"), std::string("a")}, 1},
-                        {{std::string("我是"), std::string("是")}, 2},
-                        {{std::string("abcd"), std::string("e")}, 0},
-                        {{std::string("abcdef"), std::string("")}, 1},
-                        {{std::string(""), std::string("")}, 1},
-                        {{std::string("aaaab"), std::string("bb")}, 0}};
+    DataSet data_set = {
+            {{STRING("abcdefg"), STRING("efg")}, INT(5)}, {{STRING("aa"), STRING("a")}, INT(1)},
+            {{STRING("我是"), STRING("是")}, INT(2)},     {{STRING("abcd"), STRING("e")}, INT(0)},
+            {{STRING("abcdef"), STRING("")}, INT(1)},     {{STRING(""), STRING("")}, INT(1)},
+            {{STRING("aaaab"), STRING("bb")}, INT(0)}};
 
     check_function<DataTypeInt32, true>(func_name, input_types, data_set);
 }
 
+TEST(function_string_test, function_locate_test) {
+    std::string func_name = "locate";
+
+    {
+        InputTypeSet input_types = {TypeIndex::String, TypeIndex::String};
+
+        DataSet data_set = {{{STRING("efg"), STRING("abcdefg")}, INT(5)},
+                            {{STRING("a"), STRING("aa")}, INT(1)},
+                            {{STRING("是"), STRING("我是")}, INT(2)},
+                            {{STRING("e"), STRING("abcd")}, INT(0)},
+                            {{STRING(""), STRING("abcdef")}, INT(1)},
+                            {{STRING(""), STRING("")}, INT(1)},
+                            {{STRING("bb"), STRING("aaaab")}, INT(0)}};
+
+        check_function<DataTypeInt32, true>(func_name, input_types, data_set);
+    }
+
+    {
+        InputTypeSet input_types = {TypeIndex::String, TypeIndex::String, TypeIndex::Int32};
+
+        DataSet data_set = {{{STRING("bar"), STRING("foobarbar"), INT(5)}, INT(7)},
+                            {{STRING("xbar"), STRING("foobar"), INT(1)}, INT(0)},
+                            {{STRING(""), STRING("foobar"), INT(2)}, INT(2)},
+                            {{STRING("A"), STRING("大A写的A"), INT(0)}, INT(0)},
+                            {{STRING("A"), STRING("大A写的A"), INT(1)}, INT(2)},
+                            {{STRING("A"), STRING("大A写的A"), INT(2)}, INT(2)},
+                            {{STRING("A"), STRING("大A写的A"), INT(3)}, INT(5)}};
+
+        check_function<DataTypeInt32, true>(func_name, input_types, data_set);
+    }
+}
+
 TEST(function_string_test, function_find_in_set_test) {
     std::string func_name = "find_in_set";
 

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to