This is an automated email from the ASF dual-hosted git repository.

morningman pushed a commit to branch tpch500
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 2121c97465def5b567260e245882c5d5f523a14c
Author: TengJianPing <18241664+jackte...@users.noreply.github.com>
AuthorDate: Mon Dec 18 12:01:55 2023 +0800

    [fix](expr) fix performance problem caused by too many virtual function call (#28508)
---
 be/src/vec/columns/column.h                        |  2 +
 be/src/vec/columns/column_decimal.cpp              | 12 ++++++
 be/src/vec/columns/column_decimal.h                |  2 +
 be/src/vec/columns/column_nullable.cpp             | 10 -----
 be/src/vec/columns/column_nullable.h               |  2 -
 be/src/vec/columns/column_vector.cpp               | 12 ++++++
 be/src/vec/columns/column_vector.h                 |  2 +
 be/src/vec/functions/function.cpp                  | 27 +++++---------
 be/src/vec/functions/function.h                    | 18 +++++++++
 be/src/vec/functions/function_binary_arithmetic.h  | 10 +++++
 be/src/vec/functions/function_helpers.cpp          | 35 ++++++++++++------
 be/src/vec/functions/function_helpers.h            | 12 +++---
 .../data/datatype_p0/decimalv3/fix-overflow.out    | 12 ++++++
 .../datatype_p0/decimalv3/fix-overflow.groovy      | 43 ++++++++++++++++++++++
 14 files changed, 153 insertions(+), 46 deletions(-)

diff --git a/be/src/vec/columns/column.h b/be/src/vec/columns/column.h
index b2cf72b016f..25ad1145fc8 100644
--- a/be/src/vec/columns/column.h
+++ b/be/src/vec/columns/column.h
@@ -697,6 +697,8 @@ public:
     // only used in ColumnNullable replace_column_data
     virtual void replace_column_data_default(size_t self_row = 0) = 0;
 
+    virtual void replace_column_null_data(const uint8_t* __restrict null_map) 
{}
+
     virtual bool is_date_type() const { return is_date; }
     virtual bool is_datetime_type() const { return is_date_time; }
 
diff --git a/be/src/vec/columns/column_decimal.cpp 
b/be/src/vec/columns/column_decimal.cpp
index 07508f8c6a8..1158e5b0d63 100644
--- a/be/src/vec/columns/column_decimal.cpp
+++ b/be/src/vec/columns/column_decimal.cpp
@@ -515,6 +515,18 @@ ColumnPtr ColumnDecimal<T>::index(const IColumn& indexes, 
size_t limit) const {
     return select_index_impl(*this, indexes, limit);
 }
 
+template <typename T>
+void ColumnDecimal<T>::replace_column_null_data(const uint8_t* __restrict 
null_map) {
+    auto s = size();
+    size_t null_count = s - simd::count_zero_num((const int8_t*)null_map, s);
+    if (0 == null_count) {
+        return;
+    }
+    for (size_t i = 0; i < s; ++i) {
+        data[i] = null_map[i] ? T() : data[i];
+    }
+}
+
 template class ColumnDecimal<Decimal32>;
 template class ColumnDecimal<Decimal64>;
 template class ColumnDecimal<Decimal128>;
diff --git a/be/src/vec/columns/column_decimal.h 
b/be/src/vec/columns/column_decimal.h
index dfdfbb0d6b9..9b1ad69bbad 100644
--- a/be/src/vec/columns/column_decimal.h
+++ b/be/src/vec/columns/column_decimal.h
@@ -261,6 +261,8 @@ public:
         data[self_row] = T();
     }
 
+    void replace_column_null_data(const uint8_t* __restrict null_map) override;
+
     void sort_column(const ColumnSorter* sorter, EqualFlags& flags, 
IColumn::Permutation& perms,
                      EqualRange& range, bool last_column) const override;
 
diff --git a/be/src/vec/columns/column_nullable.cpp 
b/be/src/vec/columns/column_nullable.cpp
index 98ce01dd572..ecf330bead3 100644
--- a/be/src/vec/columns/column_nullable.cpp
+++ b/be/src/vec/columns/column_nullable.cpp
@@ -50,16 +50,6 @@ ColumnNullable::ColumnNullable(MutableColumnPtr&& 
nested_column_, MutableColumnP
     _need_update_has_null = true;
 }
 
-void ColumnNullable::update_null_data() {
-    const auto& null_map_data = _get_null_map_data();
-    auto s = size();
-    for (size_t i = 0; i < s; ++i) {
-        if (null_map_data[i]) {
-            nested_column->replace_column_data_default(i);
-        }
-    }
-}
-
 MutableColumnPtr ColumnNullable::get_shrinked_column() {
     return 
ColumnNullable::create(get_nested_column_ptr()->get_shrinked_column(),
                                   get_null_map_column_ptr());
diff --git a/be/src/vec/columns/column_nullable.h 
b/be/src/vec/columns/column_nullable.h
index 8eb1f4eedb0..10b0951ab8b 100644
--- a/be/src/vec/columns/column_nullable.h
+++ b/be/src/vec/columns/column_nullable.h
@@ -83,8 +83,6 @@ public:
         return Base::create(std::forward<Args>(args)...);
     }
 
-    void update_null_data();
-
     MutableColumnPtr get_shrinked_column() override;
 
     const char* get_family_name() const override { return "Nullable"; }
diff --git a/be/src/vec/columns/column_vector.cpp 
b/be/src/vec/columns/column_vector.cpp
index 65b1b6308ee..45d9e8f70b0 100644
--- a/be/src/vec/columns/column_vector.cpp
+++ b/be/src/vec/columns/column_vector.cpp
@@ -557,6 +557,18 @@ ColumnPtr ColumnVector<T>::index(const IColumn& indexes, 
size_t limit) const {
     return select_index_impl(*this, indexes, limit);
 }
 
+template <typename T>
+void ColumnVector<T>::replace_column_null_data(const uint8_t* __restrict 
null_map) {
+    auto s = size();
+    size_t null_count = s - simd::count_zero_num((const int8_t*)null_map, s);
+    if (0 == null_count) {
+        return;
+    }
+    for (size_t i = 0; i < s; ++i) {
+        data[i] = null_map[i] ? T() : data[i];
+    }
+}
+
 /// Explicit template instantiations - to avoid code bloat in headers.
 template class ColumnVector<UInt8>;
 template class ColumnVector<UInt16>;
diff --git a/be/src/vec/columns/column_vector.h 
b/be/src/vec/columns/column_vector.h
index 772162bc879..384a8daa1c7 100644
--- a/be/src/vec/columns/column_vector.h
+++ b/be/src/vec/columns/column_vector.h
@@ -474,6 +474,8 @@ public:
         data[self_row] = T();
     }
 
+    void replace_column_null_data(const uint8_t* __restrict null_map) override;
+
     void sort_column(const ColumnSorter* sorter, EqualFlags& flags, 
IColumn::Permutation& perms,
                      EqualRange& range, bool last_column) const override;
 
diff --git a/be/src/vec/functions/function.cpp 
b/be/src/vec/functions/function.cpp
index b92f5453bfe..6e7f6572ab8 100644
--- a/be/src/vec/functions/function.cpp
+++ b/be/src/vec/functions/function.cpp
@@ -99,21 +99,8 @@ ColumnPtr wrap_in_nullable(const ColumnPtr& src, const 
Block& block, const Colum
         return ColumnNullable::create(src, 
ColumnUInt8::create(input_rows_count, 0));
     }
 
-    bool update_null_data = false;
-    auto full_column = src_not_nullable->convert_to_full_column_if_const();
-    if (const auto* nullable = check_and_get_column<const 
ColumnNullable>(full_column.get())) {
-        const auto& nested_column = nullable->get_nested_column();
-        update_null_data = nested_column.is_numeric() || 
nested_column.is_column_decimal();
-    } else {
-        update_null_data = full_column->is_numeric() || 
full_column->is_column_decimal();
-    }
-    auto result_column = ColumnNullable::create(full_column, 
result_null_map_column);
-    if (update_null_data) {
-        auto* res_nullable_column =
-                
assert_cast<ColumnNullable*>(std::move(*result_column).mutate().get());
-        res_nullable_column->update_null_data();
-    }
-    return result_column;
+    return 
ColumnNullable::create(src_not_nullable->convert_to_full_column_if_const(),
+                                  result_null_map_column);
 }
 
 NullPresence get_null_presence(const Block& block, const ColumnNumbers& args) {
@@ -247,8 +234,14 @@ Status 
PreparedFunctionImpl::default_implementation_for_nulls(
     }
 
     if (null_presence.has_nullable) {
-        auto [temporary_block, new_args, new_result] =
-                create_block_with_nested_columns(block, args, result);
+        bool check_overflow_for_decimal = false;
+        if (context) {
+            check_overflow_for_decimal = context->check_overflow_for_decimal();
+        }
+        auto [temporary_block, new_args, new_result] = 
create_block_with_nested_columns(
+                block, args, result,
+                check_overflow_for_decimal && 
need_replace_null_data_to_default());
+
         RETURN_IF_ERROR(execute_without_low_cardinality_columns(
                 context, temporary_block, new_args, new_result, 
temporary_block.rows(), dry_run));
         block.get_by_position(result).column =
diff --git a/be/src/vec/functions/function.h b/be/src/vec/functions/function.h
index 63cf78c417c..cb8ff34cdbb 100644
--- a/be/src/vec/functions/function.h
+++ b/be/src/vec/functions/function.h
@@ -102,6 +102,12 @@ public:
       */
     virtual bool use_default_implementation_for_constants() const { return 
true; }
 
+    /** If use_default_implementation_for_nulls() is true, whether the nested data of null rows
+      * in nullable arguments must be replaced with the default value before the function runs.
+      * E.g. binary arithmetic exprs need to return true to avoid false overflow errors.
+      */
+    virtual bool need_replace_null_data_to_default() const { return false; }
+
 protected:
     virtual Status execute_impl_dry_run(FunctionContext* context, Block& block,
                                         const ColumnNumbers& arguments, size_t 
result,
@@ -393,6 +399,8 @@ protected:
       */
     virtual bool use_default_implementation_for_nulls() const { return true; }
 
+    virtual bool need_replace_null_data_to_default() const { return false; }
+
     /** If use_default_implementation_for_nulls() is true, than change 
arguments for get_return_type() and build_impl().
       * If function arguments has low cardinality types, convert them to 
ordinary types.
       * get_return_type returns ColumnLowCardinality if at least one argument 
type is ColumnLowCardinality.
@@ -434,6 +442,9 @@ public:
 
     /// Override this functions to change default implementation behavior. See 
details in IMyFunction.
     bool use_default_implementation_for_nulls() const override { return true; }
+
+    bool need_replace_null_data_to_default() const override { return false; }
+
     bool use_default_implementation_for_low_cardinality_columns() const 
override { return true; }
 
     /// all constancy check should use this function to do automatically
@@ -506,6 +517,9 @@ protected:
     bool use_default_implementation_for_nulls() const final {
         return function->use_default_implementation_for_nulls();
     }
+    bool need_replace_null_data_to_default() const final {
+        return function->need_replace_null_data_to_default();
+    }
     bool use_default_implementation_for_constants() const final {
         return function->use_default_implementation_for_constants();
     }
@@ -629,6 +643,10 @@ protected:
     bool use_default_implementation_for_nulls() const override {
         return function->use_default_implementation_for_nulls();
     }
+
+    bool need_replace_null_data_to_default() const override {
+        return function->need_replace_null_data_to_default();
+    }
     bool use_default_implementation_for_low_cardinality_columns() const 
override {
         return 
function->use_default_implementation_for_low_cardinality_columns();
     }
diff --git a/be/src/vec/functions/function_binary_arithmetic.h 
b/be/src/vec/functions/function_binary_arithmetic.h
index 498cea1c1a1..9f653d5c4c8 100644
--- a/be/src/vec/functions/function_binary_arithmetic.h
+++ b/be/src/vec/functions/function_binary_arithmetic.h
@@ -870,6 +870,8 @@ template <template <typename, typename> class Operation, 
typename Name, bool is_
 class FunctionBinaryArithmetic : public IFunction {
     using OpTraits = OperationTraits<Operation>;
 
+    mutable bool need_replace_null_data_to_default_ = false;
+
     template <typename F>
     static bool cast_type(const IDataType* type, F&& f) {
         return cast_type_to_either<DataTypeUInt8, DataTypeInt8, DataTypeInt16, 
DataTypeInt32,
@@ -905,6 +907,10 @@ public:
 
     String get_name() const override { return name; }
 
+    bool need_replace_null_data_to_default() const override {
+        return need_replace_null_data_to_default_;
+    }
+
     size_t get_number_of_arguments() const override { return 2; }
 
     DataTypes get_variadic_argument_types_impl() const override {
@@ -924,6 +930,10 @@ public:
                             typename BinaryOperationTraits<Operation, 
LeftDataType,
                                                            
RightDataType>::ResultDataType;
                     if constexpr (!std::is_same_v<ResultDataType, 
InvalidType>) {
+                        need_replace_null_data_to_default_ =
+                                IsDataTypeDecimal<ResultDataType> ||
+                                (name == "pow" &&
+                                 std::is_floating_point_v<typename 
ResultDataType::FieldType>);
                         if constexpr (IsDataTypeDecimal<LeftDataType> &&
                                       IsDataTypeDecimal<RightDataType>) {
                             type_res = decimal_result_type(left, right, 
OpTraits::is_multiply,
diff --git a/be/src/vec/functions/function_helpers.cpp 
b/be/src/vec/functions/function_helpers.cpp
index 9cb371fb086..c6202e2b088 100644
--- a/be/src/vec/functions/function_helpers.cpp
+++ b/be/src/vec/functions/function_helpers.cpp
@@ -39,9 +39,9 @@
 
 namespace doris::vectorized {
 
-std::tuple<Block, ColumnNumbers> create_block_with_nested_columns(const Block& 
block,
-                                                                  const 
ColumnNumbers& args,
-                                                                  const bool 
need_check_same) {
+std::tuple<Block, ColumnNumbers> create_block_with_nested_columns(
+        const Block& block, const ColumnNumbers& args, const bool 
need_check_same,
+        bool need_replace_null_data_to_default) {
     Block res;
     ColumnNumbers res_args(args.size());
     res.reserve(args.size() + 1);
@@ -70,10 +70,22 @@ std::tuple<Block, ColumnNumbers> 
create_block_with_nested_columns(const Block& b
 
                 if (!col.column) {
                     res.insert({nullptr, nested_type, col.name});
-                } else if (auto* nullable = 
check_and_get_column<ColumnNullable>(*col.column)) {
-                    const auto& nested_col = nullable->get_nested_column_ptr();
-                    res.insert({nested_col, nested_type, col.name});
-                } else if (auto* const_column = 
check_and_get_column<ColumnConst>(*col.column)) {
+                } else if (const auto* nullable =
+                                   
check_and_get_column<ColumnNullable>(*col.column)) {
+                    if (need_replace_null_data_to_default) {
+                        const auto& null_map = nullable->get_null_map_data();
+                        const auto nested_col = 
nullable->get_nested_column_ptr();
+                        // only the nested column needs to be mutated; avoid copying the null map
+                        auto mutable_nested_col = 
(*std::move(nested_col)).mutate();
+                        
mutable_nested_col->replace_column_null_data(null_map.data());
+
+                        res.insert({std::move(mutable_nested_col), 
nested_type, col.name});
+                    } else {
+                        const auto& nested_col = 
nullable->get_nested_column_ptr();
+                        res.insert({nested_col, nested_type, col.name});
+                    }
+                } else if (const auto* const_column =
+                                   
check_and_get_column<ColumnConst>(*col.column)) {
                     const auto& nested_col =
                             
check_and_get_column<ColumnNullable>(const_column->get_data_column())
                                     ->get_nested_column_ptr();
@@ -104,10 +116,11 @@ std::tuple<Block, ColumnNumbers> 
create_block_with_nested_columns(const Block& b
     return {std::move(res), std::move(res_args)};
 }
 
-std::tuple<Block, ColumnNumbers, size_t> 
create_block_with_nested_columns(const Block& block,
-                                                                          
const ColumnNumbers& args,
-                                                                          
size_t result) {
-    auto [res, res_args] = create_block_with_nested_columns(block, args, true);
+std::tuple<Block, ColumnNumbers, size_t> create_block_with_nested_columns(
+        const Block& block, const ColumnNumbers& args, size_t result,
+        bool need_replace_null_data_to_default) {
+    auto [res, res_args] =
+            create_block_with_nested_columns(block, args, true, 
need_replace_null_data_to_default);
     // insert result column in temp block
     res.insert(block.get_by_position(result));
     return {std::move(res), std::move(res_args), res.columns() - 1};
diff --git a/be/src/vec/functions/function_helpers.h 
b/be/src/vec/functions/function_helpers.h
index dce507f6568..f5d343f3678 100644
--- a/be/src/vec/functions/function_helpers.h
+++ b/be/src/vec/functions/function_helpers.h
@@ -97,14 +97,14 @@ Columns convert_const_tuple_to_constant_elements(const 
ColumnConst& column);
 /// Returns the copy of a tmp block and temp args order same as args
 /// in which only args column each column specified in the "arguments"
 /// parameter is replaced with its respective nested column if it is nullable.
-std::tuple<Block, ColumnNumbers> create_block_with_nested_columns(const Block& 
block,
-                                                                  const 
ColumnNumbers& args,
-                                                                  const bool 
need_check_same);
+std::tuple<Block, ColumnNumbers> create_block_with_nested_columns(
+        const Block& block, const ColumnNumbers& args, const bool 
need_check_same,
+        bool need_replace_null_data_to_default = false);
 
 // Same as above and return the new_res loc in tuple
-std::tuple<Block, ColumnNumbers, size_t> 
create_block_with_nested_columns(const Block& block,
-                                                                          
const ColumnNumbers& args,
-                                                                          
size_t result);
+std::tuple<Block, ColumnNumbers, size_t> create_block_with_nested_columns(
+        const Block& block, const ColumnNumbers& args, size_t result,
+        bool need_replace_null_data_to_default = false);
 
 /// Checks argument type at specified index with predicate.
 /// throws if there is no argument at specified index or if predicate returns 
false.
diff --git a/regression-test/data/datatype_p0/decimalv3/fix-overflow.out 
b/regression-test/data/datatype_p0/decimalv3/fix-overflow.out
index ba2250c9b72..26eff80019d 100644
--- a/regression-test/data/datatype_p0/decimalv3/fix-overflow.out
+++ b/regression-test/data/datatype_p0/decimalv3/fix-overflow.out
@@ -6,3 +6,15 @@
 a      \N
 b      0.00
 
+-- !select_fix_overflow_float_null1 --
+\N
+
+-- !select_fix_overflow_int_null1 --
+\N
+
+-- !select_fix_overflow_int_null2 --
+\N
+
+-- !select_fix_overflow_bool_null1 --
+\N
+
diff --git a/regression-test/suites/datatype_p0/decimalv3/fix-overflow.groovy 
b/regression-test/suites/datatype_p0/decimalv3/fix-overflow.groovy
index 0a285189cbc..4fd294e37d5 100644
--- a/regression-test/suites/datatype_p0/decimalv3/fix-overflow.groovy
+++ b/regression-test/suites/datatype_p0/decimalv3/fix-overflow.groovy
@@ -104,4 +104,47 @@ suite("fix-overflow") {
     qt_select_insert """
         select * from fix_overflow_null2 order by 1,2;
     """
+
+    sql """
+        drop table if exists fix_overflow_null3;
+    """
+    sql """
+        create table fix_overflow_null3(k1 decimalv3(38, 6), k2 double, k3 
double) distributed by hash(k1) properties("replication_num"="1");
+    """
+    sql """
+        insert into fix_overflow_null3 values (9.9, -1, null);
+    """
+    qt_select_fix_overflow_float_null1 """
+        select cast(pow(k2+k3, 0.2) as decimalv3(38,6)) from 
fix_overflow_null3;
+    """
+
+    sql """
+        drop table if exists fix_overflow_null4
+    """
+    sql """
+        create table fix_overflow_null4(k1 int, k2 int, k3 decimalv3(38,6)) 
distributed by hash(k1) properties("replication_num"="1");
+    """
+    sql """
+        insert into fix_overflow_null4 values (1, null, 
99999999999999999999999999999999.999999);
+    """
+    qt_select_fix_overflow_int_null1 """
+        select k1 + k2 + k3 from fix_overflow_null4;
+    """
+    qt_select_fix_overflow_int_null2 """
+        select cast( (k1 + k2) as decimalv3(3, 0) ) from fix_overflow_null4;
+    """
+
+    sql """
+        drop table if exists fix_overflow_null5
+    """
+    sql """
+        create table fix_overflow_null5(k1 int, k2 int, k3 decimalv3(38,6))
+            distributed by hash(k1) properties("replication_num"="1");
+    """
+    sql """
+        insert into fix_overflow_null5 values (-1, null, 
99999999999999999999999999999999.999999);
+    """
+    qt_select_fix_overflow_bool_null1 """
+        select (k1 < k2) + k3 from fix_overflow_null5;
+    """
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to