This is an automated email from the ASF dual-hosted git repository.

kxiao pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.0 by this push:
     new 3d4497b9a4 [FIX](decimalv3)Fix decimalv3 with precision #24241 (#24326)
3d4497b9a4 is described below

commit 3d4497b9a4a282f70cde74aec1c591c424a7fec8
Author: amory <wangqian...@selectdb.com>
AuthorDate: Wed Sep 13 22:19:55 2023 +0800

    [FIX](decimalv3)Fix decimalv3 with precision #24241 (#24326)
---
 be/src/util/string_parser.hpp                      | 14 ++++---
 be/src/vec/data_types/data_type_decimal.cpp        | 18 ++++----
 .../data_types/serde/data_type_decimal_serde.cpp   | 18 ++++----
 be/src/vec/functions/function_cast.h               | 48 ++++++++++++++++------
 be/src/vec/io/io_helper.h                          | 10 +++--
 .../suites/query_p0/cast/test_cast.groovy          | 31 ++++++++++++++
 6 files changed, 101 insertions(+), 38 deletions(-)

diff --git a/be/src/util/string_parser.hpp b/be/src/util/string_parser.hpp
index a59d6bf7de..e96ac4e3c2 100644
--- a/be/src/util/string_parser.hpp
+++ b/be/src/util/string_parser.hpp
@@ -722,13 +722,10 @@ T StringParser::string_to_decimal(const char* s, int len, int type_precision, in
                     return 0;
                 }
                 *result = StringParser::PARSE_SUCCESS;
-                if constexpr (std::is_same_v<T, vectorized::Int128I>) {
-                    value *= get_scale_multiplier<__int128>(type_scale - 
scale);
-                } else {
+
+                if (type_scale > scale) {
                     value *= get_scale_multiplier<T>(type_scale - scale);
                 }
-
-                return is_negative ? T(-value) : T(value);
             }
         }
     }
@@ -762,6 +759,13 @@ T StringParser::string_to_decimal(const char* s, int len, int type_precision, in
     *result = StringParser::PARSE_SUCCESS;
     if (UNLIKELY(precision - scale > type_precision - type_scale)) {
         *result = StringParser::PARSE_OVERFLOW;
+        if constexpr (TYPE_DECIMALV2 != P) {
+            // decimalv3 overflow will return max min value for type precision
+            value = is_negative
+                            ? 
vectorized::min_decimal_value<vectorized::Decimal<T>>(type_precision)
+                            : 
vectorized::max_decimal_value<vectorized::Decimal<T>>(type_precision);
+            return value;
+        }
     } else if (UNLIKELY(scale > type_scale)) {
         *result = StringParser::PARSE_UNDERFLOW;
         int shift = scale - type_scale;
diff --git a/be/src/vec/data_types/data_type_decimal.cpp b/be/src/vec/data_types/data_type_decimal.cpp
index deb559f75b..5aea3bea2d 100644
--- a/be/src/vec/data_types/data_type_decimal.cpp
+++ b/be/src/vec/data_types/data_type_decimal.cpp
@@ -87,15 +87,17 @@ void DataTypeDecimal<T>::to_string(const IColumn& column, size_t row_num,
 template <typename T>
 Status DataTypeDecimal<T>::from_string(ReadBuffer& rb, IColumn* column) const {
     auto& column_data = static_cast<ColumnType&>(*column).get_data();
-    T val = 0;
-    if (!read_decimal_text_impl<DataTypeDecimalSerDe<T>::get_primitive_type(), 
T>(
-                val, rb, precision, scale)) {
-        return Status::InvalidArgument("parse decimal fail, string: '{}', 
primitive type: '{}'",
-                                       std::string(rb.position(), 
rb.count()).c_str(),
-                                       
DataTypeDecimalSerDe<T>::get_primitive_type());
+    T val {};
+    StringParser::ParseResult res =
+            
read_decimal_text_impl<DataTypeDecimalSerDe<T>::get_primitive_type(), T>(
+                    val, rb, precision, scale);
+    if (res == StringParser::PARSE_SUCCESS || res == 
StringParser::PARSE_UNDERFLOW) {
+        column_data.emplace_back(val);
+        return Status::OK();
     }
-    column_data.emplace_back(val);
-    return Status::OK();
+    return Status::InvalidArgument("parse decimal fail, string: '{}', 
primitive type: '{}'",
+                                   std::string(rb.position(), 
rb.count()).c_str(),
+                                   
DataTypeDecimalSerDe<T>::get_primitive_type());
 }
 
 // binary: row_num | value1 | value2 | ...
diff --git a/be/src/vec/data_types/serde/data_type_decimal_serde.cpp b/be/src/vec/data_types/serde/data_type_decimal_serde.cpp
index fdf9d41128..9ffe461f68 100644
--- a/be/src/vec/data_types/serde/data_type_decimal_serde.cpp
+++ b/be/src/vec/data_types/serde/data_type_decimal_serde.cpp
@@ -70,15 +70,17 @@ template <typename T>
 Status DataTypeDecimalSerDe<T>::deserialize_one_cell_from_json(IColumn& 
column, Slice& slice,
                                                                const 
FormatOptions& options) const {
     auto& column_data = assert_cast<ColumnDecimal<T>&>(column).get_data();
-    T val = 0;
-    if (ReadBuffer rb(slice.data, slice.size);
-        !read_decimal_text_impl<get_primitive_type(), T>(val, rb, precision, 
scale)) {
-        return Status::InvalidArgument("parse decimal fail, string: '{}', 
primitive type: '{}'",
-                                       std::string(rb.position(), 
rb.count()).c_str(),
-                                       get_primitive_type());
+    T val = {};
+    ReadBuffer rb(slice.data, slice.size);
+    StringParser::ParseResult res =
+            read_decimal_text_impl<get_primitive_type(), T>(val, rb, 
precision, scale);
+    if (res == StringParser::PARSE_SUCCESS || res == 
StringParser::PARSE_UNDERFLOW) {
+        column_data.emplace_back(val);
+        return Status::OK();
     }
-    column_data.emplace_back(val);
-    return Status::OK();
+    return Status::InvalidArgument("parse decimal fail, string: '{}', 
primitive type: '{}'",
+                                   std::string(rb.position(), 
rb.count()).c_str(),
+                                   get_primitive_type());
 }
 
 template <typename T>
diff --git a/be/src/vec/functions/function_cast.h b/be/src/vec/functions/function_cast.h
index 6a6a348e15..5d4a1716b1 100644
--- a/be/src/vec/functions/function_cast.h
+++ b/be/src/vec/functions/function_cast.h
@@ -117,6 +117,11 @@ inline UInt32 extract_to_decimal_scale(const ColumnWithTypeAndName& named_column
     named_column.column->get(0, field);
     return field.get<UInt32>();
 }
+
+struct PrecisionScaleArg {
+    UInt32 precision;
+    UInt32 scale;
+};
 /** Cast from string or number to Time.
   * In Doris, the underlying storage type of the Time class is Float64.
   */
@@ -270,7 +275,7 @@ struct ConvertImpl {
                     check_and_get_column<ColVecFrom>(named_from.column.get())) 
{
             typename ColVecTo::MutablePtr col_to = nullptr;
             if constexpr (IsDataTypeDecimal<ToDataType>) {
-                UInt32 scale = additions;
+                UInt32 scale = ((PrecisionScaleArg)additions).scale;
                 ToDataType::check_type_scale(scale);
                 col_to = ColVecTo::create(0, scale);
             } else {
@@ -897,25 +902,37 @@ bool try_parse_impl(typename DataType::FieldType& x, ReadBuffer& rb,
     if constexpr (std::is_integral_v<typename DataType::FieldType>) {
         return try_read_int_text(x, rb);
     }
+}
 
+template <typename DataType, typename Additions = void*>
+StringParser::ParseResult try_parse_decimal_impl(typename DataType::FieldType& 
x, ReadBuffer& rb,
+                                                 const cctz::time_zone& 
local_time_zone,
+                                                 ZoneList& time_zone_cache,
+                                                 std::shared_mutex& cache_lock,
+                                                 Additions additions
+                                                 [[maybe_unused]] = 
Additions()) {
     if constexpr (IsDataTypeDecimalV2<DataType>) {
-        UInt32 scale = additions;
-        return try_read_decimal_text<TYPE_DECIMALV2>(x, rb, 
DataType::max_precision(), scale);
+        UInt32 scale = ((PrecisionScaleArg)additions).scale;
+        UInt32 precision = ((PrecisionScaleArg)additions).precision;
+        return try_read_decimal_text<TYPE_DECIMALV2>(x, rb, precision, scale);
     }
 
     if constexpr (std::is_same_v<DataTypeDecimal<Decimal32>, DataType>) {
-        UInt32 scale = additions;
-        return try_read_decimal_text<TYPE_DECIMAL32>(x, rb, 
DataType::max_precision(), scale);
+        UInt32 scale = ((PrecisionScaleArg)additions).scale;
+        UInt32 precision = ((PrecisionScaleArg)additions).precision;
+        return try_read_decimal_text<TYPE_DECIMAL32>(x, rb, precision, scale);
     }
 
     if constexpr (std::is_same_v<DataTypeDecimal<Decimal64>, DataType>) {
-        UInt32 scale = additions;
-        return try_read_decimal_text<TYPE_DECIMAL64>(x, rb, 
DataType::max_precision(), scale);
+        UInt32 scale = ((PrecisionScaleArg)additions).scale;
+        UInt32 precision = ((PrecisionScaleArg)additions).precision;
+        return try_read_decimal_text<TYPE_DECIMAL64>(x, rb, precision, scale);
     }
 
     if constexpr (IsDataTypeDecimal128I<DataType>) {
-        UInt32 scale = additions;
-        return try_read_decimal_text<TYPE_DECIMAL128I>(x, rb, 
DataType::max_precision(), scale);
+        UInt32 scale = ((PrecisionScaleArg)additions).scale;
+        UInt32 precision = ((PrecisionScaleArg)additions).precision;
+        return try_read_decimal_text<TYPE_DECIMAL128I>(x, rb, precision, 
scale);
     }
 }
 
@@ -1346,7 +1363,7 @@ struct ConvertThroughParsing {
         typename ColVecTo::MutablePtr col_to = nullptr;
 
         if constexpr (IsDataTypeDecimal<ToDataType>) {
-            UInt32 scale = additions;
+            UInt32 scale = ((PrecisionScaleArg)additions).scale;
             ToDataType::check_type_scale(scale);
             col_to = ColVecTo::create(size, scale);
         } else {
@@ -1382,9 +1399,13 @@ struct ConvertThroughParsing {
 
             bool parsed;
             if constexpr (IsDataTypeDecimal<ToDataType>) {
-                parsed = try_parse_impl<ToDataType>(
+                
ToDataType::check_type_precision((PrecisionScaleArg(additions).precision));
+                StringParser::ParseResult res = 
try_parse_decimal_impl<ToDataType>(
                         vec_to[i], read_buffer, 
context->state()->timezone_obj(), time_zone_cache,
-                        cache_lock, vec_to.get_scale());
+                        cache_lock, PrecisionScaleArg(additions));
+                parsed = (res == StringParser::PARSE_SUCCESS ||
+                          res == StringParser::PARSE_OVERFLOW ||
+                          res == StringParser::PARSE_UNDERFLOW);
             } else if constexpr (IsDataTypeDateTimeV2<ToDataType>) {
                 auto type = check_and_get_data_type<DataTypeDateTimeV2>(
                         block.get_by_position(result).type.get());
@@ -1625,7 +1646,8 @@ private:
 
                         auto state = ConvertImpl<LeftDataType, RightDataType, 
NameCast>::execute(
                                 context, block, arguments, result, 
input_rows_count,
-                                context->check_overflow_for_decimal(), scale);
+                                context->check_overflow_for_decimal(),
+                                PrecisionScaleArg {precision, scale});
                         if (!state) {
                             throw Exception(state.code(), state.to_string());
                         }
diff --git a/be/src/vec/io/io_helper.h b/be/src/vec/io/io_helper.h
index c8bc6d7c7c..42918ca4c0 100644
--- a/be/src/vec/io/io_helper.h
+++ b/be/src/vec/io/io_helper.h
@@ -380,7 +380,8 @@ bool read_datetime_v2_text_impl(T& x, ReadBuffer& buf, const cctz::time_zone& lo
 }
 
 template <PrimitiveType P, typename T>
-bool read_decimal_text_impl(T& x, ReadBuffer& buf, UInt32 precision, UInt32 
scale) {
+StringParser::ParseResult read_decimal_text_impl(T& x, ReadBuffer& buf, UInt32 
precision,
+                                                 UInt32 scale) {
     static_assert(IsDecimalNumber<T>);
     if constexpr (!std::is_same_v<Decimal128, T>) {
         StringParser::ParseResult result = StringParser::PARSE_SUCCESS;
@@ -389,7 +390,7 @@ bool read_decimal_text_impl(T& x, ReadBuffer& buf, UInt32 precision, UInt32 scal
                 (const char*)buf.position(), buf.count(), precision, scale, 
&result);
         // only to match the is_all_read() check to prevent return null
         buf.position() = buf.end();
-        return result == StringParser::PARSE_SUCCESS || result == 
StringParser::PARSE_UNDERFLOW;
+        return result;
     } else {
         StringParser::ParseResult result = StringParser::PARSE_SUCCESS;
 
@@ -400,7 +401,7 @@ bool read_decimal_text_impl(T& x, ReadBuffer& buf, UInt32 precision, UInt32 scal
         // only to match the is_all_read() check to prevent return null
         buf.position() = buf.end();
 
-        return result == StringParser::PARSE_SUCCESS || result == 
StringParser::PARSE_UNDERFLOW;
+        return result;
     }
 }
 
@@ -450,7 +451,8 @@ bool try_read_float_text(T& x, ReadBuffer& in) {
 }
 
 template <PrimitiveType P, typename T>
-bool try_read_decimal_text(T& x, ReadBuffer& in, UInt32 precision, UInt32 
scale) {
+StringParser::ParseResult try_read_decimal_text(T& x, ReadBuffer& in, UInt32 
precision,
+                                                UInt32 scale) {
     return read_decimal_text_impl<P, T>(x, in, precision, scale);
 }
 
diff --git a/regression-test/suites/query_p0/cast/test_cast.groovy b/regression-test/suites/query_p0/cast/test_cast.groovy
index 59d86eb80e..1063845775 100644
--- a/regression-test/suites/query_p0/cast/test_cast.groovy
+++ b/regression-test/suites/query_p0/cast/test_cast.groovy
@@ -32,6 +32,37 @@ suite('test_cast') {
         result([[869930357, 20200101123445l, ((float) 20200101123445l), 
((double) 20200101123445l)]])
     }
 
+    test {
+        sql " select cast('9999e-1' as DECIMALV3(2, 1)) "
+        result([[9.9]])
+    }
+
+    test {
+        sql " select cast('100000' as DECIMALV3(2, 1)) "
+        result([[9.9]])
+    }
+
+    test {
+        sql " select cast('-9999e-1' as DECIMALV3(2, 1)) "
+        result([[-9.9]])
+    }
+
+
+    test {
+        sql " select cast('100000' as DECIMALV3(2, 1)) "
+        result([[9.9]])
+    }
+
+    test {
+        sql "select cast('0.2147483648e3' as DECIMALV3(2, 1))"
+        result([[9.9]])
+    }
+
+    test {
+        sql "select cast('0.2147483648e-3' as DECIMALV3(2, 1))"
+        result([[0.0]])
+    }
+
     def tbl = "test_cast"
 
     sql """ DROP TABLE IF EXISTS ${tbl}"""


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to