This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 94eedd8ea4d1891baaa570257956d98c7c03d9c8
Author: koarz <66543806+ko...@users.noreply.github.com>
AuthorDate: Fri Feb 2 10:20:29 2024 +0800

    [Enhancement](function) make SUBSTRING_INDEX function DEPEND_ON_ARGUMENT (#30392)
---
 be/src/agent/be_exec_version_manager.h             |   2 +-
 be/src/vec/functions/function_string.cpp           |   1 +
 be/src/vec/functions/function_string.h             | 182 ++++++++++++++++++---
 .../functions/scalar/SubstringIndex.java           |   4 +-
 gensrc/script/doris_builtins_functions.py          |   4 +-
 5 files changed, 166 insertions(+), 27 deletions(-)

diff --git a/be/src/agent/be_exec_version_manager.h 
b/be/src/agent/be_exec_version_manager.h
index 963437b3d39..1cabc38eba2 100644
--- a/be/src/agent/be_exec_version_manager.h
+++ b/be/src/agent/be_exec_version_manager.h
@@ -64,7 +64,7 @@ private:
  *    c. cleared old version of Version 2.
  *    d. unix_timestamp function support timestamp with float for datetimev2, 
and change nullable mode.
  *    e. change shuffle serialize/deserialize way 
- *    f. the right function outputs NULL when the function contains NULL, 
substr function returns empty if start > str.length.
+ *    f. the right function outputs NULL when its input contains NULL, the substr function returns empty if start > str.length, and the nullable mode of some functions (e.g. substring_index) is changed.
 */
 constexpr inline int BeExecVersionManager::max_be_exec_version = 3;
 constexpr inline int BeExecVersionManager::min_be_exec_version = 0;
diff --git a/be/src/vec/functions/function_string.cpp 
b/be/src/vec/functions/function_string.cpp
index 723c68c3010..6965139a1c8 100644
--- a/be/src/vec/functions/function_string.cpp
+++ b/be/src/vec/functions/function_string.cpp
@@ -1020,6 +1020,7 @@ void register_function_string(SimpleFunctionFactory& 
factory) {
     
factory.register_alternative_function<FunctionSubstringOld<Substr2ImplOld>>();
     factory.register_alternative_function<FunctionLeftOld>();
     factory.register_alternative_function<FunctionRightOld>();
+    factory.register_alternative_function<FunctionSubstringIndexOld>();
 
     factory.register_alias(FunctionLeft::name, "strleft");
     factory.register_alias(FunctionRight::name, "strright");
diff --git a/be/src/vec/functions/function_string.h 
b/be/src/vec/functions/function_string.h
index bb1fbfce767..6fc84074ddb 100644
--- a/be/src/vec/functions/function_string.h
+++ b/be/src/vec/functions/function_string.h
@@ -471,8 +471,6 @@ public:
         return get_variadic_argument_types_impl().size();
     }
 
-    bool use_default_implementation_for_nulls() const override { return true; }
-
     Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
                         size_t result, size_t input_rows_count) const override 
{
         return Impl::execute_impl(context, block, arguments, result, 
input_rows_count);
@@ -622,8 +620,6 @@ public:
 
     bool is_variadic() const override { return true; }
 
-    bool use_default_implementation_for_nulls() const override { return true; }
-
     Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
                         size_t result, size_t input_rows_count) const override 
{
         DCHECK_GE(arguments.size(), 1);
@@ -636,7 +632,7 @@ public:
                 assert_cast<const 
ColumnString&>(*block.get_by_position(arguments[0]).column);
 
         if (arguments.size() > 1) {
-            auto& col = *block.get_by_position(arguments[1]).column;
+            const auto& col = *block.get_by_position(arguments[1]).column;
             auto string_ref = col.get_data_at(0);
             if (string_ref.size > 0) {
                 upper = *string_ref.data;
@@ -644,7 +640,7 @@ public:
         }
 
         if (arguments.size() > 2) {
-            auto& col = *block.get_by_position(arguments[2]).column;
+            const auto& col = *block.get_by_position(arguments[2]).column;
             auto string_ref = col.get_data_at(0);
             if (string_ref.size > 0) {
                 lower = *string_ref.data;
@@ -652,7 +648,7 @@ public:
         }
 
         if (arguments.size() > 3) {
-            auto& col = *block.get_by_position(arguments[3]).column;
+            const auto& col = *block.get_by_position(arguments[3]).column;
             auto string_ref = col.get_data_at(0);
             if (string_ref.size > 0) {
                 number = *string_ref.data;
@@ -721,8 +717,6 @@ public:
 
     bool is_variadic() const override { return true; }
 
-    bool use_default_implementation_for_nulls() const override { return true; }
-
     Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
                         size_t result, size_t input_rows_count) const override 
{
         DCHECK_GE(arguments.size(), 1);
@@ -732,10 +726,10 @@ public:
 
         auto res = ColumnString::create();
         auto col = 
block.get_by_position(arguments[0]).column->convert_to_full_column_if_const();
-        const ColumnString& source_column = assert_cast<const 
ColumnString&>(*col);
+        const auto& source_column = assert_cast<const ColumnString&>(*col);
 
         if (arguments.size() == 2) {
-            auto& col = *block.get_by_position(arguments[1]).column;
+            const auto& col = *block.get_by_position(arguments[1]).column;
             n = col.get_int(0);
         } else if (arguments.size() > 2) {
             return Status::InvalidArgument(
@@ -758,8 +752,8 @@ public:
 private:
     static void vector(const ColumnString& src, int n, ColumnString& result) {
         const auto num_rows = src.size();
-        auto* chars = src.get_chars().data();
-        auto* offsets = src.get_offsets().data();
+        const auto* chars = src.get_chars().data();
+        const auto* offsets = src.get_offsets().data();
         result.get_chars().resize(src.get_chars().size());
         result.get_offsets().resize(src.get_offsets().size());
         memcpy_small_allow_read_write_overflow15(
@@ -800,8 +794,6 @@ public:
         return std::make_shared<DataTypeString>();
     }
 
-    bool use_default_implementation_for_nulls() const override { return true; }
-
     Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
                         size_t result, size_t input_rows_count) const override 
{
         auto int_type = std::make_shared<DataTypeInt32>();
@@ -855,8 +847,6 @@ public:
         return std::make_shared<DataTypeString>();
     }
 
-    bool use_default_implementation_for_nulls() const override { return true; }
-
     Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
                         size_t result, size_t input_rows_count) const override 
{
         auto int_type = std::make_shared<DataTypeInt32>();
@@ -1867,10 +1857,159 @@ public:
     size_t get_number_of_arguments() const override { return 3; }
 
     DataTypePtr get_return_type_impl(const DataTypes& arguments) const 
override {
-        return make_nullable(std::make_shared<DataTypeString>());
+        return std::make_shared<DataTypeString>();
     }
 
-    bool use_default_implementation_for_nulls() const override { return true; }
+    Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
+                        size_t result, size_t input_rows_count) const override 
{
+        DCHECK_EQ(arguments.size(), 3);
+
+        // Result column: exactly one string (possibly empty) is appended per input row.
+        auto res = ColumnString::create();
+
+        auto& res_offsets = res->get_offsets();
+        auto& res_chars = res->get_chars();
+        res_offsets.resize(input_rows_count);
+        ColumnPtr content_column;
+        bool content_const = false;
+        std::tie(content_column, content_const) =
+                unpack_if_const(block.get_by_position(arguments[0]).column);
+
+        const auto* str_col = assert_cast<const 
ColumnString*>(content_column.get());
+
+        [[maybe_unused]] const auto& [delimiter_col, delimiter_const] =
+                unpack_if_const(block.get_by_position(arguments[1]).column);
+        auto delimiter = delimiter_col->get_data_at(0);
+        int32_t delimiter_size = delimiter.size;
+
+        [[maybe_unused]] const auto& [part_num_col, part_const] =
+                unpack_if_const(block.get_by_position(arguments[2]).column);
+        auto part_number = *((int*)part_num_col->get_data_at(0).data);
+
+        if (part_number == 0 || delimiter_size == 0) {
+            for (size_t i = 0; i < input_rows_count; ++i) {
+                StringOP::push_empty_string(i, res_chars, res_offsets);
+            }
+        } else if (part_number > 0) {
+            if (delimiter_size == 1) {
+                // If delimiter is a char, use memchr to split
+                for (size_t i = 0; i < input_rows_count; ++i) {
+                    auto str = str_col->get_data_at(i);
+                    int32_t offset = -1;
+                    int32_t num = 0;
+                    while (num < part_number) {
+                        size_t n = str.size - offset - 1;
+                        const char* pos = reinterpret_cast<const char*>(
+                                memchr(str.data + offset + 1, 
delimiter.data[0], n));
+                        if (pos != nullptr) {
+                            offset = pos - str.data;
+                            num++;
+                        } else {
+                            offset = str.size;
+                            num = (num == 0) ? 0 : num + 1;
+                            break;
+                        }
+                    }
+
+                    if (num == part_number) {
+                        StringOP::push_value_string(
+                                std::string_view {reinterpret_cast<const 
char*>(str.data),
+                                                  (size_t)offset},
+                                i, res_chars, res_offsets);
+                    } else {
+                        StringOP::push_value_string(std::string_view(str.data, 
str.size), i,
+                                                    res_chars, res_offsets);
+                    }
+                }
+            } else {
+                StringRef delimiter_ref(delimiter);
+                StringSearch search(&delimiter_ref);
+                for (size_t i = 0; i < input_rows_count; ++i) {
+                    auto str = str_col->get_data_at(i);
+                    int32_t offset = -delimiter_size;
+                    int32_t num = 0;
+                    while (num < part_number) {
+                        size_t n = str.size - offset - delimiter_size;
+                        // search for the first match of delimiter_ref in str, starting right after the previous match (or at the start of str)
+                        const char* pos = search.search(str.data + offset + 
delimiter_size, n);
+                        if (pos < str.data + str.size) {
+                            offset = pos - str.data;
+                            num++;
+                        } else {
+                            offset = str.size;
+                            num = (num == 0) ? 0 : num + 1;
+                            break;
+                        }
+                    }
+
+                    if (num == part_number) {
+                        StringOP::push_value_string(
+                                std::string_view {reinterpret_cast<const 
char*>(str.data),
+                                                  (size_t)offset},
+                                i, res_chars, res_offsets);
+                    } else {
+                        StringOP::push_value_string(std::string_view(str.data, 
str.size), i,
+                                                    res_chars, res_offsets);
+                    }
+                }
+            }
+        } else {
+            // if part_number is negative
+            part_number = -part_number;
+            for (size_t i = 0; i < input_rows_count; ++i) {
+                auto str = str_col->get_data_at(i);
+                auto str_str = str.to_string();
+                int32_t offset = str.size;
+                int32_t pre_offset = offset;
+                int32_t num = 0;
+                auto substr = str_str;
+                while (num <= part_number && offset >= 0) {
+                    offset = (int)substr.rfind(delimiter, offset);
+                    if (offset != -1) {
+                        if (++num == part_number) {
+                            break;
+                        }
+                        pre_offset = offset;
+                        offset = offset - 1;
+                        substr = str_str.substr(0, pre_offset);
+                    } else {
+                        break;
+                    }
+                }
+                num = (offset == -1 && num != 0) ? num + 1 : num;
+
+                if (num == part_number) {
+                    if (offset == -1) {
+                        StringOP::push_value_string(std::string_view(str.data, 
str.size), i,
+                                                    res_chars, res_offsets);
+                    } else {
+                        StringOP::push_value_string(
+                                std::string_view {str.data + offset + 
delimiter_size,
+                                                  str.size - offset - 
delimiter_size},
+                                i, res_chars, res_offsets);
+                    }
+                } else {
+                    StringOP::push_value_string(std::string_view(str.data, 
str.size), i, res_chars,
+                                                res_offsets);
+                }
+            }
+        }
+
+        block.get_by_position(result).column = std::move(res);
+        return Status::OK();
+    }
+};
+
+class FunctionSubstringIndexOld : public IFunction {
+public:
+    static constexpr auto name = "substring_index";
+    static FunctionPtr create() { return std::make_shared<FunctionSubstringIndexOld>(); } // not FunctionSubstringIndex: the alternative must keep the old nullable variant
+    String get_name() const override { return name; }
+    size_t get_number_of_arguments() const override { return 3; }
+
+    DataTypePtr get_return_type_impl(const DataTypes& arguments) const 
override {
+        return make_nullable(std::make_shared<DataTypeString>());
+    }
 
     Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
                         size_t result, size_t input_rows_count) const override 
{
@@ -1878,7 +2017,6 @@ public:
 
         auto null_map = ColumnUInt8::create(input_rows_count, 0);
         // Create a zero column to simply implement
-        auto const_null_map = ColumnUInt8::create(input_rows_count, 0);
         auto res = ColumnString::create();
 
         auto& res_offsets = res->get_offsets();
@@ -1889,7 +2027,7 @@ public:
         std::tie(content_column, content_const) =
                 unpack_if_const(block.get_by_position(arguments[0]).column);
 
-        if (auto* nullable = check_and_get_column<const 
ColumnNullable>(*content_column)) {
+        if (const auto* nullable = check_and_get_column<const 
ColumnNullable>(*content_column)) {
             // Danger: Here must dispose the null map data first! Because
             // argument_columns[0]=nullable->get_nested_column_ptr(); will 
release the mem
             // of column nullable mem of null map
@@ -1897,7 +2035,7 @@ public:
             content_column = nullable->get_nested_column_ptr();
         }
 
-        auto str_col = assert_cast<const ColumnString*>(content_column.get());
+        const auto* str_col = assert_cast<const 
ColumnString*>(content_column.get());
 
         [[maybe_unused]] const auto& [delimiter_col, delimiter_const] =
                 unpack_if_const(block.get_by_position(arguments[1]).column);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/SubstringIndex.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/SubstringIndex.java
index 5374950135c..bb9e2b749c4 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/SubstringIndex.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/SubstringIndex.java
@@ -20,8 +20,8 @@ package 
org.apache.doris.nereids.trees.expressions.functions.scalar;
 import org.apache.doris.catalog.FunctionSignature;
 import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
 import 
org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
 import org.apache.doris.nereids.trees.expressions.shape.TernaryExpression;
 import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
 import org.apache.doris.nereids.types.IntegerType;
@@ -37,7 +37,7 @@ import java.util.List;
  * ScalarFunction 'substring_index'. This class is generated by 
GenerateFunction.
  */
 public class SubstringIndex extends ScalarFunction
-        implements TernaryExpression, ExplicitlyCastableSignature, 
AlwaysNullable {
+        implements TernaryExpression, ExplicitlyCastableSignature, 
PropagateNullable {
 
     public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
             FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
diff --git a/gensrc/script/doris_builtins_functions.py 
b/gensrc/script/doris_builtins_functions.py
index b31fab21b20..9cce3c3824e 100644
--- a/gensrc/script/doris_builtins_functions.py
+++ b/gensrc/script/doris_builtins_functions.py
@@ -1606,7 +1606,7 @@ visible_functions = {
         [['money_format'], 'VARCHAR', ['DECIMAL128'], ''],
         [['split_by_string'],'ARRAY_VARCHAR',['STRING','STRING'], ''],
         [['split_part'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'INT'], 
'ALWAYS_NULLABLE'],
-        [['substring_index'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'INT'], 
'ALWAYS_NULLABLE'],
+        [['substring_index'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'INT'], 
'DEPEND_ON_ARGUMENT'],
         [['extract_url_parameter'], 'VARCHAR', ['VARCHAR', 'VARCHAR'], ''],
 
         [['sub_replace'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'INT'], 
'ALWAYS_NULLABLE'],
@@ -1661,7 +1661,7 @@ visible_functions = {
         [['money_format'], 'STRING', ['DECIMAL64'], ''],
         [['money_format'], 'STRING', ['DECIMAL128'], ''],
         [['split_part'], 'STRING', ['STRING', 'STRING', 'INT'], 
'ALWAYS_NULLABLE'],
-        [['substring_index'], 'STRING', ['STRING', 'STRING', 'INT'], 
'ALWAYS_NULLABLE']
+        [['substring_index'], 'STRING', ['STRING', 'STRING', 'INT'], 
'DEPEND_ON_ARGUMENT']
     ],
 
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to