This is an automated email from the ASF dual-hosted git repository.

dataroaring pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 0075a835ab5 branch-3.0: [fix](DECIMAL) error DECIMAL cast to BOOLEAN #44326 (#46276)
0075a835ab5 is described below

commit 0075a835ab54d7d16d84eb092ebd6d72d6ea5752
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Tue Jan 14 18:02:17 2025 +0800

    branch-3.0: [fix](DECIMAL) error DECIMAL cast to BOOLEAN #44326 (#46276)
    
    Cherry-picked from #44326
    
    Co-authored-by: Mryange <yanxuech...@selectdb.com>
---
 be/src/vec/data_types/data_type_decimal.h          | 19 +++++---
 be/src/vec/functions/function_cast.h               | 43 +++++++++++------
 .../correctness/test_cast_decimalv3_as_bool.out    | 17 +++++++
 .../correctness/test_cast_decimalv3_as_bool.groovy | 55 ++++++++++++++++++++++
 .../case_function/test_case_function_null.groovy   | 18 ++++---
 5 files changed, 124 insertions(+), 28 deletions(-)

diff --git a/be/src/vec/data_types/data_type_decimal.h 
b/be/src/vec/data_types/data_type_decimal.h
index c34d88035aa..43557084e52 100644
--- a/be/src/vec/data_types/data_type_decimal.h
+++ b/be/src/vec/data_types/data_type_decimal.h
@@ -461,15 +461,20 @@ void convert_from_decimals(RealTo* dst, const RealFrom* 
src, UInt32 precicion_fr
     MaxFieldType multiplier = 
DataTypeDecimal<MaxFieldType>::get_scale_multiplier(scale_from);
     FromDataType from_data_type(precicion_from, scale_from);
     for (size_t i = 0; i < size; i++) {
-        auto tmp = static_cast<MaxFieldType>(src[i]).value / multiplier.value;
-        if constexpr (narrow_integral) {
-            if (tmp < min_result.value || tmp > max_result.value) {
-                
THROW_DECIMAL_CONVERT_OVERFLOW_EXCEPTION(from_data_type.to_string(src[i]),
-                                                         
from_data_type.get_name(),
-                                                         OrigToDataType 
{}.get_name());
+        // uint8_t is now used as boolean in Doris
+        if constexpr (std::is_same_v<RealTo, UInt8>) {
+            dst[i] = static_cast<MaxFieldType>(src[i]).value != 0;
+        } else {
+            auto tmp = static_cast<MaxFieldType>(src[i]).value / 
multiplier.value;
+            if constexpr (narrow_integral) {
+                if (tmp < min_result.value || tmp > max_result.value) {
+                    
THROW_DECIMAL_CONVERT_OVERFLOW_EXCEPTION(from_data_type.to_string(src[i]),
+                                                             
from_data_type.get_name(),
+                                                             OrigToDataType 
{}.get_name());
+                }
             }
+            dst[i] = tmp;
         }
-        dst[i] = tmp;
     }
 }
 
diff --git a/be/src/vec/functions/function_cast.h 
b/be/src/vec/functions/function_cast.h
index acf63a66229..0cc2e9e2862 100644
--- a/be/src/vec/functions/function_cast.h
+++ b/be/src/vec/functions/function_cast.h
@@ -256,6 +256,21 @@ struct ConvertImpl {
     using FromFieldType = typename FromDataType::FieldType;
     using ToFieldType = typename ToDataType::FieldType;
 
+    // `static_cast_set` wraps `static_cast` to handle a special case.
+    // Doris uses `uint8` to represent boolean values internally, and a direct
+    // `static_cast` to `uint8` can produce values other than 0 and 1.
+    // To address this, `static_cast_set` performs an additional step:
+    //  when the target type is `uint8`, it converts via `static_cast<bool>` so
+    //  the result is always either 0 or 1.
+    static void static_cast_set(ToFieldType& to, const FromFieldType& from) {
+        // uint8_t is now used as boolean in Doris
+        if constexpr (std::is_same_v<uint8_t, ToFieldType>) {
+            to = static_cast<bool>(from);
+        } else {
+            to = static_cast<ToFieldType>(from);
+        }
+    }
+
     template <typename Additions = void*>
     static Status execute(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
                           size_t result, size_t input_rows_count,
@@ -375,8 +390,9 @@ struct ConvertImpl {
                     } else if constexpr (IsDateTimeV2Type<ToDataType>) {
                         DataTypeDateTimeV2::cast_from_date(vec_from[i], 
vec_to[i]);
                     } else {
-                        vec_to[i] =
-                                reinterpret_cast<const 
VecDateTimeValue&>(vec_from[i]).to_int64();
+                        static_cast_set(
+                                vec_to[i],
+                                reinterpret_cast<const 
VecDateTimeValue&>(vec_from[i]).to_int64());
                     }
                 }
             } else if constexpr (IsTimeV2Type<FromDataType>) {
@@ -407,13 +423,16 @@ struct ConvertImpl {
                         }
                     } else {
                         if constexpr (IsDateTimeV2Type<FromDataType>) {
-                            vec_to[i] = reinterpret_cast<const 
DateV2Value<DateTimeV2ValueType>&>(
-                                                vec_from[i])
-                                                .to_int64();
+                            static_cast_set(
+                                    vec_to[i],
+                                    reinterpret_cast<const 
DateV2Value<DateTimeV2ValueType>&>(
+                                            vec_from[i])
+                                            .to_int64());
                         } else {
-                            vec_to[i] = reinterpret_cast<const 
DateV2Value<DateV2ValueType>&>(
-                                                vec_from[i])
-                                                .to_int64();
+                            static_cast_set(vec_to[i],
+                                            reinterpret_cast<const 
DateV2Value<DateV2ValueType>&>(
+                                                    vec_from[i])
+                                                    .to_int64());
                         }
                     }
                 }
@@ -440,16 +459,10 @@ struct ConvertImpl {
                     }
                 } else {
                     for (size_t i = 0; i < size; ++i) {
-                        vec_to[i] = static_cast<ToFieldType>(vec_from[i]);
+                        static_cast_set(vec_to[i], vec_from[i]);
                     }
                 }
             }
-            // TODO: support boolean cast more reasonable
-            if constexpr (std::is_same_v<uint8_t, ToFieldType>) {
-                for (int i = 0; i < size; ++i) {
-                    vec_to[i] = static_cast<bool>(vec_to[i]);
-                }
-            }
 
             block.replace_by_position(result, std::move(col_to));
         } else {
diff --git a/regression-test/data/correctness/test_cast_decimalv3_as_bool.out 
b/regression-test/data/correctness/test_cast_decimalv3_as_bool.out
new file mode 100644
index 00000000000..4f41130b00b
--- /dev/null
+++ b/regression-test/data/correctness/test_cast_decimalv3_as_bool.out
@@ -0,0 +1,17 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !select1 --
+0.000  13131.213132100 0E-16
+0.000  2131231.231000000       2.3323000E-9
+3.141  0E-9    123123.2131231231322130
+
+-- !select2 --
+false  true    false
+false  true    true
+true   false   true
+
+-- !select3 --
+true   1       true    false
+
+-- !select3 --
+true   1       true    false
+
diff --git 
a/regression-test/suites/correctness/test_cast_decimalv3_as_bool.groovy 
b/regression-test/suites/correctness/test_cast_decimalv3_as_bool.groovy
new file mode 100644
index 00000000000..768da493251
--- /dev/null
+++ b/regression-test/suites/correctness/test_cast_decimalv3_as_bool.groovy
@@ -0,0 +1,55 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_cast_decimalv3_as_bool") {
+     sql """ DROP TABLE IF EXISTS cast_decimalv3_as_bool """
+     sql """
+        CREATE TABLE IF NOT EXISTS cast_decimalv3_as_bool (
+            `id` int(11) ,
+            `k1` decimalv3(9,3) ,
+            `k2` decimalv3(18,9) ,
+            `k3` decimalv3(38,16) ,
+        )
+        UNIQUE KEY(`id`)
+        DISTRIBUTED BY HASH(`id`) BUCKETS 10
+        PROPERTIES (
+        "enable_unique_key_merge_on_write" = "true",
+        "replication_num" = "1"
+        );
+    """
+    sql """
+        set enable_nereids_planner=true,enable_fold_constant_by_be = false
+    """
+    sql """
+        INSERT INTO cast_decimalv3_as_bool VALUES
+        (1,0.00001,13131.2131321,0.000000000000000000),
+        (2,0.00000,2131231.231,0.0000000023323),
+        (3,3.141414,0.0000000000,123123.213123123132213);
+    """
+    qt_select1 """
+        select k1,k2,k3 from cast_decimalv3_as_bool order by id
+    """
+    qt_select2 """
+        select cast(k1 as boolean), cast(k2 as boolean) , cast(k3 as boolean) 
from cast_decimalv3_as_bool order by id
+    """ 
+    qt_select3"""
+        select cast(3.00001 as boolean),  cast(cast(3.00001 as  boolean) as 
int),cast(0.001 as boolean),cast(0.000 as boolean);
+    """
+        qt_select3"""
+        select cast(cast(3.00001 as double)as boolean),  
cast(cast(cast(3.00001 as double) as  boolean) as int),cast(cast(0.001 as 
double) as boolean),cast(cast(0.000 as double) as boolean);
+    """
+}
\ No newline at end of file
diff --git 
a/regression-test/suites/query_p0/sql_functions/case_function/test_case_function_null.groovy
 
b/regression-test/suites/query_p0/sql_functions/case_function/test_case_function_null.groovy
index 5138db6e73b..a91c86b5f48 100644
--- 
a/regression-test/suites/query_p0/sql_functions/case_function/test_case_function_null.groovy
+++ 
b/regression-test/suites/query_p0/sql_functions/case_function/test_case_function_null.groovy
@@ -185,10 +185,11 @@ suite("test_case_function_null", 
"query,p0,arrow_flight_sql") {
             c2,
             c1;
     """
-
+    // Behavior change: CAST(0.4716 AS BOOLEAN) used to evaluate to 0 (false) but now evaluates to 1 (true).
+    // The literal is therefore changed to 0.0 so that this case keeps its original expected result.
     qt_sql_case1 """
         SELECT SUM(
-            CASE (((NULL BETWEEN NULL AND NULL)) and (CAST(0.4716 AS BOOLEAN)))
+            CASE (((NULL BETWEEN NULL AND NULL)) and (CAST(0.0 AS BOOLEAN)))
             WHEN ((CAST('-1530390546' AS VARCHAR)) LIKE ('-1678299490'))
             THEN (- (+ case_null2.c0))
             WHEN CASE (NULL IN (NULL))
@@ -197,9 +198,10 @@ suite("test_case_function_null", 
"query,p0,arrow_flight_sql") {
             END)
         FROM case_null2;
     """
-
+    // Behavior change: CAST(0.4716 AS BOOLEAN) used to evaluate to 0 (false) but now evaluates to 1 (true).
+    // The literal is therefore changed to 0.0 so that this case keeps its original expected result.
     qt_sql_case2 """
-        SELECT SUM(CASE (((NULL BETWEEN NULL AND NULL)) and (CAST(0.4716 AS 
BOOLEAN)))
+        SELECT SUM(CASE (((NULL BETWEEN NULL AND NULL)) and (CAST(0.0 AS 
BOOLEAN)))
             WHEN ((CAST('-1530390546' AS VARCHAR)) LIKE ('-1678299490'))
             THEN (- (+ case_null2.c0))
             END)
@@ -209,9 +211,11 @@ suite("test_case_function_null", 
"query,p0,arrow_flight_sql") {
     sql "SET experimental_enable_nereids_planner=true"
     sql "SET enable_fallback_to_original_planner=false"
 
+    // Behavior change: CAST(0.4716 AS BOOLEAN) used to evaluate to 0 (false) but now evaluates to 1 (true).
+    // The literal is therefore changed to 0.0 so that this case keeps its original expected result.
     qt_sql_case1 """
         SELECT SUM(
-            CASE (((NULL BETWEEN NULL AND NULL)) and (CAST(0.4716 AS BOOLEAN)))
+            CASE (((NULL BETWEEN NULL AND NULL)) and (CAST(0.0 AS BOOLEAN)))
             WHEN ((CAST('-1530390546' AS VARCHAR)) LIKE ('-1678299490'))
             THEN (- (+ case_null2.c0))
             WHEN CASE (NULL IN (NULL))
@@ -221,8 +225,10 @@ suite("test_case_function_null", 
"query,p0,arrow_flight_sql") {
         FROM case_null2;
     """
 
+    // Behavior change: CAST(0.4716 AS BOOLEAN) used to evaluate to 0 (false) but now evaluates to 1 (true).
+    // The literal is therefore changed to 0.0 so that this case keeps its original expected result.
     qt_sql_case2 """
-        SELECT SUM(CASE (((NULL BETWEEN NULL AND NULL)) and (CAST(0.4716 AS 
BOOLEAN)))
+        SELECT SUM(CASE (((NULL BETWEEN NULL AND NULL)) and (CAST(0.0 AS 
BOOLEAN)))
             WHEN ((CAST('-1530390546' AS VARCHAR)) LIKE ('-1678299490'))
             THEN (- (+ case_null2.c0))
             END)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to