This is an automated email from the ASF dual-hosted git repository. panxiaolei pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 4eb2604789 [Bug](function) fix function define of Retention inconsist and change some static_cast to assert cast (#19455) 4eb2604789 is described below commit 4eb2604789983e640183c5b87bf6bdbbabfb33fb Author: Pxl <pxl...@qq.com> AuthorDate: Mon May 15 11:50:02 2023 +0800 [Bug](function) fix function define of Retention inconsist and change some static_cast to assert cast (#19455) 1. fix function define of `Retention` inconsist, this function return tinyint on `FE` and return uint8 on `BE` 2. make assert_cast support cast to derived 3. change some static cast to assert cast 4. support sum(bool)/avg(bool) --- be/src/olap/rowset/segment_v2/segment_iterator.cpp | 2 +- .../vec/aggregate_functions/aggregate_function.h | 46 +++++++++++----------- .../aggregate_function_approx_count_distinct.h | 6 +-- .../aggregate_functions/aggregate_function_avg.h | 8 ++-- .../aggregate_functions/aggregate_function_bit.h | 4 +- .../aggregate_function_bitmap.h | 16 ++++---- .../aggregate_function_collect.h | 4 +- .../aggregate_functions/aggregate_function_count.h | 2 +- .../aggregate_function_group_concat.h | 8 ++-- .../aggregate_function_histogram.h | 10 ++--- .../aggregate_function_hll_union_agg.h | 2 +- .../aggregate_function_min_max.h | 10 ++--- .../aggregate_function_orthogonal_bitmap.h | 26 ++++++------ .../aggregate_function_product.h | 12 +++--- .../aggregate_function_quantile_state.h | 6 +-- .../aggregate_function_retention.h | 4 +- .../aggregate_function_sequence_match.h | 4 +- .../aggregate_function_stddev.h | 6 +-- .../aggregate_functions/aggregate_function_sum.h | 4 +- .../aggregate_functions/aggregate_function_topn.h | 40 +++++++++---------- .../aggregate_function_window_funnel.h | 6 +-- be/src/vec/common/assert_cast.h | 5 +++ be/src/vec/core/field.h | 2 +- be/src/vec/exec/vaggregation_node.cpp | 4 +- .../vec/exprs/table_function/vexplode_bitmap.cpp | 10 ++--- .../vec/exprs/table_function/vexplode_numbers.cpp | 12 +++--- .../main/java/org/apache/doris/catalog/Type.java | 3 +- .../apache/doris/analysis/FunctionCallExpr.java | 3 +- .../java/org/apache/doris/catalog/FunctionSet.java | 14 ++++++- .../trees/expressions/functions/agg/Retention.java | 3 +- .../apache/doris/analysis/InsertArrayStmtTest.java | 2 +- 31 files changed, 151 insertions(+), 133 deletions(-) diff --git a/be/src/olap/rowset/segment_v2/segment_iterator.cpp b/be/src/olap/rowset/segment_v2/segment_iterator.cpp index 76e443e78e..83a2698f33 100644 --- a/be/src/olap/rowset/segment_v2/segment_iterator.cpp +++ b/be/src/olap/rowset/segment_v2/segment_iterator.cpp @@ -1993,7 +1993,7 @@ void SegmentIterator::_output_index_result_column(uint16_t* sel_rowid_idx, uint1 if (!iter.second.first) { // predicate not in compound query block->get_by_name(iter.first).column = - vectorized::DataTypeUInt8().create_column_const(block->rows(), 1u); + vectorized::DataTypeUInt8().create_column_const(block->rows(), (uint8_t)1); continue; } _build_index_result_column(sel_rowid_idx, select_size, block, iter.first, diff --git a/be/src/vec/aggregate_functions/aggregate_function.h b/be/src/vec/aggregate_functions/aggregate_function.h index 377a84d804..73f3473d43 100644 --- a/be/src/vec/aggregate_functions/aggregate_function.h +++ b/be/src/vec/aggregate_functions/aggregate_function.h @@ -221,7 +221,7 @@ public: const size_t num_rows) const noexcept override { const size_t size_of_data_ = size_of_data(); for (size_t i = 0; i != num_rows; ++i) { - static_cast<const Derived*>(this)->destroy(place + size_of_data_ * i); + assert_cast<const Derived*>(this)->destroy(place + size_of_data_ * i); } } @@ -245,7 +245,7 @@ public: } auto iter = place_rows.begin(); while (iter != place_rows.end()) { - static_cast<const Derived*>(this)->add_many(iter->first, columns, iter->second, + assert_cast<const Derived*>(this)->add_many(iter->first, columns, iter->second, arena); iter++; } @@ -254,7 +254,7 @@ public: } for (size_t i = 0; i < batch_size; ++i) { - static_cast<const Derived*>(this)->add(places[i] + place_offset, columns, i, arena); + assert_cast<const Derived*>(this)->add(places[i] + place_offset, columns, i, arena); } } @@ -262,7 +262,7 @@ public: const IColumn** columns, Arena* arena) const override { for (size_t i = 0; i < batch_size; ++i) { if (places[i]) { - static_cast<const Derived*>(this)->add(places[i] + place_offset, columns, i, arena); + assert_cast<const Derived*>(this)->add(places[i] + place_offset, columns, i, arena); } } } @@ -270,7 +270,7 @@ public: void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns, Arena* arena) const override { for (size_t i = 0; i < batch_size; ++i) { - static_cast<const Derived*>(this)->add(place, columns, i, arena); + assert_cast<const Derived*>(this)->add(place, columns, i, arena); } } //now this is use for sum/count/avg/min/max win function, other win function should override this function in class @@ -281,28 +281,28 @@ public: frame_start = std::max<int64_t>(frame_start, partition_start); frame_end = std::min<int64_t>(frame_end, partition_end); for (int64_t i = frame_start; i < frame_end; ++i) { - static_cast<const Derived*>(this)->add(place, columns, i, arena); + assert_cast<const Derived*>(this)->add(place, columns, i, arena); } } void add_batch_range(size_t batch_begin, size_t batch_end, AggregateDataPtr place, const IColumn** columns, Arena* arena, bool has_null) override { for (size_t i = batch_begin; i <= batch_end; ++i) { - static_cast<const Derived*>(this)->add(place, columns, i, arena); + assert_cast<const Derived*>(this)->add(place, columns, i, arena); } } void insert_result_into_vec(const std::vector<AggregateDataPtr>& places, const size_t offset, IColumn& to, const size_t num_rows) const override { for (size_t i = 0; i != num_rows; ++i) { - static_cast<const Derived*>(this)->insert_result_into(places[i] + offset, to); + assert_cast<const Derived*>(this)->insert_result_into(places[i] + offset, to); } } void serialize_vec(const std::vector<AggregateDataPtr>& places, size_t offset, BufferWritable& buf, const size_t num_rows) const override { for (size_t i = 0; i != num_rows; ++i) { - static_cast<const Derived*>(this)->serialize(places[i] + offset, buf); + assert_cast<const Derived*>(this)->serialize(places[i] + offset, buf); buf.commit(); } } @@ -317,10 +317,10 @@ public: const size_t num_rows, Arena* arena) const override { char place[size_of_data()]; for (size_t i = 0; i != num_rows; ++i) { - static_cast<const Derived*>(this)->create(place); - DEFER({ static_cast<const Derived*>(this)->destroy(place); }); - static_cast<const Derived*>(this)->add(place, columns, i, arena); - static_cast<const Derived*>(this)->serialize(place, buf); + assert_cast<const Derived*>(this)->create(place); + DEFER({ assert_cast<const Derived*>(this)->destroy(place); }); + assert_cast<const Derived*>(this)->add(place, columns, i, arena); + assert_cast<const Derived*>(this)->serialize(place, buf); buf.commit(); } } @@ -334,23 +334,23 @@ public: void serialize_without_key_to_column(ConstAggregateDataPtr __restrict place, MutableColumnPtr& dst) const override { VectorBufferWriter writter(assert_cast<ColumnString&>(*dst)); - static_cast<const Derived*>(this)->serialize(place, writter); + assert_cast<const Derived*>(this)->serialize(place, writter); writter.commit(); } void deserialize_vec(AggregateDataPtr places, const ColumnString* column, Arena* arena, size_t num_rows) const override { - const auto size_of_data = static_cast<const Derived*>(this)->size_of_data(); + const auto size_of_data = assert_cast<const Derived*>(this)->size_of_data(); for (size_t i = 0; i != num_rows; ++i) { try { auto place = places + size_of_data * i; VectorBufferReader buffer_reader(column->get_data_at(i)); - static_cast<const Derived*>(this)->create(place); - static_cast<const Derived*>(this)->deserialize(place, buffer_reader, arena); + assert_cast<const Derived*>(this)->create(place); + assert_cast<const Derived*>(this)->deserialize(place, buffer_reader, arena); } catch (...) { for (int j = 0; j < i; ++j) { auto place = places + size_of_data * j; - static_cast<const Derived*>(this)->destroy(place); + assert_cast<const Derived*>(this)->destroy(place); } throw; } @@ -364,9 +364,9 @@ public: void merge_vec(const AggregateDataPtr* places, size_t offset, ConstAggregateDataPtr rhs, Arena* arena, const size_t num_rows) const override { - const auto size_of_data = static_cast<const Derived*>(this)->size_of_data(); + const auto size_of_data = assert_cast<const Derived*>(this)->size_of_data(); for (size_t i = 0; i != num_rows; ++i) { - static_cast<const Derived*>(this)->merge(places[i] + offset, rhs + size_of_data * i, + assert_cast<const Derived*>(this)->merge(places[i] + offset, rhs + size_of_data * i, arena); } } @@ -374,10 +374,10 @@ public: void merge_vec_selected(const AggregateDataPtr* places, size_t offset, ConstAggregateDataPtr rhs, Arena* arena, const size_t num_rows) const override { - const auto size_of_data = static_cast<const Derived*>(this)->size_of_data(); + const auto size_of_data = assert_cast<const Derived*>(this)->size_of_data(); for (size_t i = 0; i != num_rows; ++i) { if (places[i]) { - static_cast<const Derived*>(this)->merge(places[i] + offset, rhs + size_of_data * i, + assert_cast<const Derived*>(this)->merge(places[i] + offset, rhs + size_of_data * i, arena); } } @@ -420,7 +420,7 @@ public: char deserialized_data[size_of_data()]; AggregateDataPtr deserialized_place = (AggregateDataPtr)deserialized_data; - auto derived = static_cast<const Derived*>(this); + auto derived = assert_cast<const Derived*>(this); derived->create(deserialized_place); DEFER({ derived->destroy(deserialized_place); }); derived->deserialize(deserialized_place, buf, arena); diff --git a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h index 19c0efb13a..03e1cc3df1 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h +++ b/be/src/vec/aggregate_functions/aggregate_function_approx_count_distinct.h @@ -98,12 +98,12 @@ public: void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, Arena*) const override { if constexpr (IsFixLenColumnType<ColumnDataType>::value) { - auto column = static_cast<const ColumnDataType*>(columns[0]); + auto column = assert_cast<const ColumnDataType*>(columns[0]); auto value = column->get_element(row_num); this->data(place).add( HashUtil::murmur_hash64A((char*)&value, sizeof(value), HashUtil::MURMUR_SEED)); } else { - auto value = static_cast<const ColumnDataType*>(columns[0])->get_data_at(row_num); + auto value = assert_cast<const ColumnDataType*>(columns[0])->get_data_at(row_num); uint64_t hash_value = HashUtil::murmur_hash64A(value.data, value.size, HashUtil::MURMUR_SEED); this->data(place).add(hash_value); @@ -127,7 +127,7 @@ public: } void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { - auto& column = static_cast<ColumnInt64&>(to); + auto& column = assert_cast<ColumnInt64&>(to); column.get_data().push_back(this->data(place).get()); } }; diff --git a/be/src/vec/aggregate_functions/aggregate_function_avg.h b/be/src/vec/aggregate_functions/aggregate_function_avg.h index 1528881203..bf9c71b90d 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_avg.h +++ b/be/src/vec/aggregate_functions/aggregate_function_avg.h @@ -82,7 +82,7 @@ struct AggregateFunctionAvgData { // to keep the same result with row vesion; see AggregateFunctions::decimalv2_avg_get_value if constexpr (IsDecimalV2<T> && IsDecimalV2<ResultT>) { DecimalV2Value decimal_val_count(count, 0); - DecimalV2Value decimal_val_sum(static_cast<Int128>(sum)); + DecimalV2Value decimal_val_sum(sum); DecimalV2Value cal_ret = decimal_val_sum / decimal_val_count; Decimal128 ret(cal_ret.value()); return ret; @@ -135,7 +135,7 @@ public: void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, Arena*) const override { - const auto& column = static_cast<const ColVecType&>(*columns[0]); + const auto& column = assert_cast<const ColVecType&>(*columns[0]); if constexpr (IsDecimalNumber<T>) { this->data(place).sum += column.get_data()[row_num].value; } else { @@ -169,7 +169,7 @@ public: } void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { - auto& column = static_cast<ColVecResult&>(to); + auto& column = assert_cast<ColVecResult&>(to); column.get_data().push_back(this->data(place).template result<ResultType>()); } @@ -196,7 +196,7 @@ public: void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst, const size_t num_rows, Arena* arena) const override { auto* src_data = assert_cast<const ColVecType&>(*columns[0]).get_data().data(); - auto& dst_col = static_cast<ColumnFixedLengthObject&>(*dst); + auto& dst_col = assert_cast<ColumnFixedLengthObject&>(*dst); dst_col.set_item_size(sizeof(Data)); dst_col.resize(num_rows); auto* data = dst_col.get_data().data(); diff --git a/be/src/vec/aggregate_functions/aggregate_function_bit.h b/be/src/vec/aggregate_functions/aggregate_function_bit.h index 9538633721..6d2e67b14e 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_bit.h +++ b/be/src/vec/aggregate_functions/aggregate_function_bit.h @@ -114,7 +114,7 @@ public: void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, Arena*) const override { - const auto& column = static_cast<const ColumnVector<T>&>(*columns[0]); + const auto& column = assert_cast<const ColumnVector<T>&>(*columns[0]); this->data(place).add(column.get_data()[row_num]); } @@ -135,7 +135,7 @@ public: } void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { - auto& column = static_cast<ColumnVector<T>&>(to); + auto& column = assert_cast<ColumnVector<T>&>(to); column.get_data().push_back(this->data(place).get()); } }; diff --git a/be/src/vec/aggregate_functions/aggregate_function_bitmap.h b/be/src/vec/aggregate_functions/aggregate_function_bitmap.h index 98d8bb69f4..00d3517fa0 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_bitmap.h +++ b/be/src/vec/aggregate_functions/aggregate_function_bitmap.h @@ -165,14 +165,14 @@ public: void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, Arena*) const override { - const auto& column = static_cast<const ColVecType&>(*columns[0]); + const auto& column = assert_cast<const ColVecType&>(*columns[0]); this->data(place).add(column.get_data()[row_num]); } void add_many(AggregateDataPtr __restrict place, const IColumn** columns, std::vector<int>& rows, Arena*) const override { if constexpr (std::is_same_v<Op, AggregateFunctionBitmapUnionOp>) { - const auto& column = static_cast<const ColVecType&>(*columns[0]); + const auto& column = assert_cast<const ColVecType&>(*columns[0]); std::vector<const BitmapValue*> values; for (int i = 0; i < rows.size(); ++i) { values.push_back(&(column.get_data()[rows[i]])); @@ -197,7 +197,7 @@ public: } void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { - auto& column = static_cast<ColVecResult&>(to); + auto& column = assert_cast<ColVecResult&>(to); column.get_data().push_back( const_cast<AggregateFunctionBitmapData<Op>&>(this->data(place)).get()); } @@ -229,11 +229,11 @@ public: auto& nullable_column = assert_cast<const ColumnNullable&>(*columns[0]); if (!nullable_column.is_null_at(row_num)) { const auto& column = - static_cast<const ColVecType&>(nullable_column.get_nested_column()); + assert_cast<const ColVecType&>(nullable_column.get_nested_column()); this->data(place).add(column.get_data()[row_num]); } } else { - const auto& column = static_cast<const ColVecType&>(*columns[0]); + const auto& column = assert_cast<const ColVecType&>(*columns[0]); this->data(place).add(column.get_data()[row_num]); } } @@ -243,7 +243,7 @@ public: if constexpr (arg_is_nullable && std::is_same_v<ColVecType, ColumnBitmap>) { auto& nullable_column = assert_cast<const ColumnNullable&>(*columns[0]); const auto& column = - static_cast<const ColVecType&>(nullable_column.get_nested_column()); + assert_cast<const ColVecType&>(nullable_column.get_nested_column()); std::vector<const BitmapValue*> values; for (int i = 0; i < rows.size(); ++i) { if (!nullable_column.is_null_at(rows[i])) { @@ -252,7 +252,7 @@ public: } this->data(place).add_batch(values); } else if constexpr (std::is_same_v<ColVecType, ColumnBitmap>) { - const auto& column = static_cast<const ColVecType&>(*columns[0]); + const auto& column = assert_cast<const ColVecType&>(*columns[0]); std::vector<const BitmapValue*> values; for (int i = 0; i < rows.size(); ++i) { values.push_back(&(column.get_data()[rows[i]])); @@ -277,7 +277,7 @@ public: void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { auto& value_data = const_cast<AggFunctionData&>(this->data(place)).get(); - auto& column = static_cast<ColVecResult&>(to); + auto& column = assert_cast<ColVecResult&>(to); column.get_data().push_back(value_data.cardinality()); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_collect.h b/be/src/vec/aggregate_functions/aggregate_function_collect.h index 206109e8b8..4e774b887e 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_collect.h +++ b/be/src/vec/aggregate_functions/aggregate_function_collect.h @@ -248,7 +248,7 @@ struct AggregateFunctionCollectListData<StringRef, HasLimit> { max_size = rhs.max_size; data->insert_range_from(*rhs.data, 0, - std::min(static_cast<size_t>(max_size - size()), rhs.size())); + std::min(assert_cast<size_t>(max_size - size()), rhs.size())); } else { data->insert_range_from(*rhs.data, 0, rhs.size()); } @@ -324,7 +324,7 @@ public: if constexpr (HasLimit::value) { if (data.max_size == -1) { data.max_size = - (UInt64) static_cast<const ColumnInt32*>(columns[1])->get_element(row_num); + (UInt64)assert_cast<const ColumnInt32*>(columns[1])->get_element(row_num); } if (data.size() >= data.max_size) { return; diff --git a/be/src/vec/aggregate_functions/aggregate_function_count.h b/be/src/vec/aggregate_functions/aggregate_function_count.h index 389e0f1df4..0b6ec8f2ba 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_count.h +++ b/be/src/vec/aggregate_functions/aggregate_function_count.h @@ -112,7 +112,7 @@ public: const size_t num_rows, Arena* arena) const override { auto& col = assert_cast<ColumnUInt64&>(*dst); col.resize(num_rows); - col.get_data().assign(num_rows, static_cast<UInt64>(1UL)); + col.get_data().assign(num_rows, assert_cast<UInt64>(1UL)); } void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column, diff --git a/be/src/vec/aggregate_functions/aggregate_function_group_concat.h b/be/src/vec/aggregate_functions/aggregate_function_group_concat.h index 6ddc600452..0e147de464 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_group_concat.h +++ b/be/src/vec/aggregate_functions/aggregate_function_group_concat.h @@ -98,7 +98,7 @@ struct AggregateFunctionGroupConcatImplStr { static const std::string separator; static void add(AggregateFunctionGroupConcatData& __restrict place, const IColumn** columns, size_t row_num) { - place.add(static_cast<const ColumnString&>(*columns[0]).get_data_at(row_num), + place.add(assert_cast<const ColumnString&>(*columns[0]).get_data_at(row_num), StringRef(separator.data(), separator.length())); } }; @@ -106,8 +106,8 @@ struct AggregateFunctionGroupConcatImplStr { struct AggregateFunctionGroupConcatImplStrStr { static void add(AggregateFunctionGroupConcatData& __restrict place, const IColumn** columns, size_t row_num) { - place.add(static_cast<const ColumnString&>(*columns[0]).get_data_at(row_num), - static_cast<const ColumnString&>(*columns[1]).get_data_at(row_num)); + place.add(assert_cast<const ColumnString&>(*columns[0]).get_data_at(row_num), + assert_cast<const ColumnString&>(*columns[1]).get_data_at(row_num)); } }; @@ -147,7 +147,7 @@ public: void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { std::string result = this->data(place).get(); - static_cast<ColumnString&>(to).insert_data(result.c_str(), result.length()); + assert_cast<ColumnString&>(to).insert_data(result.c_str(), result.length()); } }; diff --git a/be/src/vec/aggregate_functions/aggregate_function_histogram.h b/be/src/vec/aggregate_functions/aggregate_function_histogram.h index 4c3f6548f1..a1a8a346e9 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_histogram.h +++ b/be/src/vec/aggregate_functions/aggregate_function_histogram.h @@ -137,10 +137,10 @@ struct AggregateFunctionHistogramData { for (auto i = 0; i < pair_vector.size(); i++) { const auto& element = pair_vector[i]; if constexpr (std::is_same_v<T, std::string>) { - static_cast<ColumnString&>(to).insert_data(element.second.c_str(), + assert_cast<ColumnString&>(to).insert_data(element.second.c_str(), element.second.length()); } else { - static_cast<ColVecType&>(to).get_data().push_back(element.second); + assert_cast<ColVecType&>(to).get_data().push_back(element.second); } } } @@ -192,14 +192,14 @@ public: if (has_input_param) { this->data(place).set_parameters( - static_cast<const ColumnInt32*>(columns[1])->get_element(row_num)); + assert_cast<const ColumnInt32*>(columns[1])->get_element(row_num)); } if constexpr (std::is_same_v<T, std::string>) { this->data(place).add( - static_cast<const ColumnString&>(*columns[0]).get_data_at(row_num)); + assert_cast<const ColumnString&>(*columns[0]).get_data_at(row_num)); } else { - this->data(place).add(static_cast<const ColVecType&>(*columns[0]).get_data()[row_num]); + this->data(place).add(assert_cast<const ColVecType&>(*columns[0]).get_data()[row_num]); } } diff --git a/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.h b/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.h index a45d228882..d4a94a190a 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.h +++ b/be/src/vec/aggregate_functions/aggregate_function_hll_union_agg.h @@ -76,7 +76,7 @@ struct AggregateFunctionHLLData { void reset() { dst_hll.clear(); } void add(const IColumn* column, size_t row_num) { - const auto& sources = static_cast<const ColumnHLL&>(*column); + const auto& sources = assert_cast<const ColumnHLL&>(*column); dst_hll.merge(sources.get_element(row_num)); } }; diff --git a/be/src/vec/aggregate_functions/aggregate_function_min_max.h b/be/src/vec/aggregate_functions/aggregate_function_min_max.h index 2d1f7b7c00..c17aee4348 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_min_max.h +++ b/be/src/vec/aggregate_functions/aggregate_function_min_max.h @@ -355,7 +355,7 @@ public: } } else { if (capacity < rhs_size) { - capacity = static_cast<UInt32>(round_up_to_power_of_two_or_zero(rhs_size)); + capacity = round_up_to_power_of_two_or_zero(rhs_size); large_data = arena->alloc(capacity); } @@ -555,7 +555,7 @@ public: void deserialize_from_column(AggregateDataPtr places, const IColumn& column, Arena* arena, size_t num_rows) const override { if constexpr (Data::IsFixedLength) { - const auto& col = static_cast<const ColumnFixedLengthObject&>(column); + const auto& col = assert_cast<const ColumnFixedLengthObject&>(column); auto* column_data = reinterpret_cast<const Data*>(col.get_data().data()); Data* data = reinterpret_cast<Data*>(places); for (size_t i = 0; i != num_rows; ++i) { @@ -569,7 +569,7 @@ public: void serialize_to_column(const std::vector<AggregateDataPtr>& places, size_t offset, MutableColumnPtr& dst, const size_t num_rows) const override { if constexpr (Data::IsFixedLength) { - auto& dst_column = static_cast<ColumnFixedLengthObject&>(*dst); + auto& dst_column = assert_cast<ColumnFixedLengthObject&>(*dst); dst_column.resize(num_rows); auto* dst_data = reinterpret_cast<Data*>(dst_column.get_data().data()); for (size_t i = 0; i != num_rows; ++i) { @@ -583,7 +583,7 @@ public: void streaming_agg_serialize_to_column(const IColumn** columns, MutableColumnPtr& dst, const size_t num_rows, Arena* arena) const override { if constexpr (Data::IsFixedLength) { - auto& dst_column = static_cast<ColumnFixedLengthObject&>(*dst); + auto& dst_column = assert_cast<ColumnFixedLengthObject&>(*dst); dst_column.resize(num_rows); auto* dst_data = reinterpret_cast<Data*>(dst_column.get_data().data()); for (size_t i = 0; i != num_rows; ++i) { @@ -597,7 +597,7 @@ public: void deserialize_and_merge_from_column(AggregateDataPtr __restrict place, const IColumn& column, Arena* arena) const override { if constexpr (Data::IsFixedLength) { - const auto& col = static_cast<const ColumnFixedLengthObject&>(column); + const auto& col = assert_cast<const ColumnFixedLengthObject&>(column); auto* column_data = reinterpret_cast<const Data*>(col.get_data().data()); const size_t num_rows = column.size(); for (size_t i = 0; i != num_rows; ++i) { diff --git a/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.h b/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.h index 737088de60..0177f1c415 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.h +++ b/be/src/vec/aggregate_functions/aggregate_function_orthogonal_bitmap.h @@ -57,8 +57,8 @@ public: using ColVecData = std::conditional_t<IsNumber<T>, ColumnVector<T>, ColumnString>; void add(const IColumn** columns, size_t row_num) { - const auto& bitmap_col = static_cast<const ColumnBitmap&>(*columns[0]); - const auto& data_col = static_cast<const ColVecData&>(*columns[1]); + const auto& bitmap_col = assert_cast<const ColumnBitmap&>(*columns[0]); + const auto& data_col = assert_cast<const ColVecData&>(*columns[1]); const auto& bitmap_value = bitmap_col.get_element(row_num); if constexpr (IsNumber<T>) { @@ -75,7 +75,7 @@ public: if (first_init) { DCHECK(argument_size > 1); for (int idx = 2; idx < argument_size; ++idx) { - const auto& col = static_cast<const ColVecData&>(*columns[idx]); + const auto& col = assert_cast<const ColVecData&>(*columns[idx]); if constexpr (IsNumber<T>) { bitmap.add_key(col.get_element(row_num)); } @@ -119,7 +119,7 @@ public: } void get(IColumn& to) const { - auto& column = static_cast<ColumnBitmap&>(to); + auto& column = assert_cast<ColumnBitmap&>(to); column.get_data().emplace_back(result); } @@ -157,7 +157,7 @@ public: } void get(IColumn& to) const { - auto& column = static_cast<ColumnVector<Int64>&>(to); + auto& column = assert_cast<ColumnVector<Int64>&>(to); column.get_data().emplace_back(AggOrthBitmapBaseData<T>::bitmap.intersect_count()); } }; @@ -188,7 +188,7 @@ public: } void get(IColumn& to) const { - auto& column = static_cast<ColumnVector<Int64>&>(to); + auto& column = assert_cast<ColumnVector<Int64>&>(to); column.get_data().emplace_back(result ? result : AggOrthBitmapBaseData<T>::bitmap.intersect_count()); } @@ -203,8 +203,8 @@ public: using ColVecData = std::conditional_t<IsNumber<T>, ColumnVector<T>, ColumnString>; void add(const IColumn** columns, size_t row_num) { - const auto& bitmap_col = static_cast<const ColumnBitmap&>(*columns[0]); - const auto& data_col = static_cast<const ColVecData&>(*columns[1]); + const auto& bitmap_col = assert_cast<const ColumnBitmap&>(*columns[0]); + const auto& data_col = assert_cast<const ColVecData&>(*columns[1]); const auto& bitmap_value = bitmap_col.get_element(row_num); std::string update_key = data_col.get_data_at(row_num).to_string(); bitmap_expr_cal.update(update_key, bitmap_value); @@ -213,7 +213,7 @@ public: void init_add_key(const IColumn** columns, size_t row_num, int argument_size) { if (first_init) { DCHECK(argument_size > 1); - const auto& col = static_cast<const ColVecData&>(*columns[2]); + const auto& col = assert_cast<const ColVecData&>(*columns[2]); std::string expr = col.get_data_at(row_num).to_string(); bitmap_expr_cal.bitmap_calculation_init(expr); first_init = false; @@ -251,7 +251,7 @@ public: } void get(IColumn& to) const { - auto& column = static_cast<ColumnBitmap&>(to); + auto& column = assert_cast<ColumnBitmap&>(to); column.get_data().emplace_back(result); } @@ -285,7 +285,7 @@ public: } void get(IColumn& to) const { - auto& column = static_cast<ColumnVector<Int64>&>(to); + auto& column = assert_cast<ColumnVector<Int64>&>(to); column.get_data().emplace_back(result); } @@ -302,7 +302,7 @@ struct OrthBitmapUnionCountData { void init_add_key(const IColumn** columns, size_t row_num, int argument_size) {} void add(const IColumn** columns, size_t row_num) { - const auto& column = static_cast<const ColumnBitmap&>(*columns[0]); + const auto& column = assert_cast<const ColumnBitmap&>(*columns[0]); value |= column.get_data()[row_num]; } void merge(const OrthBitmapUnionCountData& rhs) { result += rhs.result; } @@ -315,7 +315,7 @@ struct OrthBitmapUnionCountData { void read(BufferReadable& buf) { read_binary(result, buf); } void get(IColumn& to) const { - auto& column = static_cast<ColumnVector<Int64>&>(to); + auto& column = assert_cast<ColumnVector<Int64>&>(to); column.get_data().emplace_back(result ? result : value.cardinality()); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_product.h b/be/src/vec/aggregate_functions/aggregate_function_product.h index 572dc19000..ba3f74d626 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_product.h +++ b/be/src/vec/aggregate_functions/aggregate_function_product.h @@ -55,15 +55,15 @@ struct AggregateFunctionProductData<Decimal128> { Decimal128 product {}; void add(Decimal128 value, Decimal128) { - DecimalV2Value decimal_product(static_cast<Int128>(product)); - DecimalV2Value decimal_value(static_cast<Int128>(value)); + DecimalV2Value decimal_product(product); + DecimalV2Value decimal_value(value); DecimalV2Value ret = decimal_product * decimal_value; memcpy(&product, &ret, sizeof(Decimal128)); } void merge(const AggregateFunctionProductData& other, Decimal128) { - DecimalV2Value decimal_product(static_cast<Int128>(product)); - DecimalV2Value decimal_value(static_cast<Int128>(other.product)); + DecimalV2Value decimal_product(product); + DecimalV2Value decimal_value(other.product); DecimalV2Value ret = decimal_product * decimal_value; memcpy(&product, &ret, sizeof(Decimal128)); } @@ -133,7 +133,7 @@ public: void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, Arena*) const override { - const auto& column = static_cast<const ColVecType&>(*columns[0]); + const auto& column = assert_cast<const ColVecType&>(*columns[0]); this->data(place).add(column.get_data()[row_num], multiplier); } @@ -160,7 +160,7 @@ public: } void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { - auto& column = static_cast<ColVecResult&>(to); + auto& column = assert_cast<ColVecResult&>(to); column.get_data().push_back(this->data(place).get()); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_quantile_state.h b/be/src/vec/aggregate_functions/aggregate_function_quantile_state.h index 00f93abb41..e368c58238 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_quantile_state.h +++ b/be/src/vec/aggregate_functions/aggregate_function_quantile_state.h @@ -126,11 +126,11 @@ public: auto& nullable_column = assert_cast<const ColumnNullable&>(*columns[0]); if (!nullable_column.is_null_at(row_num)) { const auto& column = - static_cast<const ColVecType&>(nullable_column.get_nested_column()); + assert_cast<const ColVecType&>(nullable_column.get_nested_column()); this->data(place).add(column.get_data()[row_num]); } } else { - const auto& column = static_cast<const ColVecType&>(*columns[0]); + const auto& column = assert_cast<const ColVecType&>(*columns[0]); this->data(place).add(column.get_data()[row_num]); } } @@ -152,7 +152,7 @@ public: } void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { - auto& column = static_cast<ColVecResult&>(to); + auto& column = assert_cast<ColVecResult&>(to); column.get_data().push_back( const_cast<AggregateFunctionQuantileStateData<Op, InternalType>&>(this->data(place)) .get()); diff --git a/be/src/vec/aggregate_functions/aggregate_function_retention.h b/be/src/vec/aggregate_functions/aggregate_function_retention.h index 8eb8db218a..f595a1ad72 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_retention.h +++ b/be/src/vec/aggregate_functions/aggregate_function_retention.h @@ -55,7 +55,7 @@ struct RetentionState { static constexpr size_t MAX_EVENTS = 32; uint8_t events[MAX_EVENTS] = {0}; - RetentionState() {} + RetentionState() = default; void reset() { for (int64_t i = 0; i < MAX_EVENTS; i++) { @@ -94,7 +94,7 @@ struct RetentionState { } void insert_result_into(IColumn& to, size_t events_size, const uint8_t* events) const { - auto& data_to = static_cast<ColumnUInt8&>(to).get_data(); + auto& data_to = assert_cast<ColumnUInt8&>(to).get_data(); ColumnArray::Offset64 current_offset = data_to.size(); data_to.resize(current_offset + events_size); diff --git a/be/src/vec/aggregate_functions/aggregate_function_sequence_match.h b/be/src/vec/aggregate_functions/aggregate_function_sequence_match.h index 3e7e0b65e9..064c1e9979 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_sequence_match.h +++ b/be/src/vec/aggregate_functions/aggregate_function_sequence_match.h @@ -201,7 +201,7 @@ private: using PatternActions = PODArrayWithStackMemory<PatternAction, 64>; - Derived& derived() { return static_cast<Derived&>(*this); } + Derived& derived() { return assert_cast<Derived&>(*this); } void parse_pattern() { actions.clear(); @@ -606,7 +606,7 @@ public: this->data(place).init(pattern, arg_count); const auto& timestamp = - static_cast<const ColumnVector<NativeType>&>(*columns[1]).get_data()[row_num]; + assert_cast<const ColumnVector<NativeType>&>(*columns[1]).get_data()[row_num]; typename AggregateFunctionSequenceMatchData<DateValueType, NativeType, Derived>::Events events; diff --git a/be/src/vec/aggregate_functions/aggregate_function_stddev.h b/be/src/vec/aggregate_functions/aggregate_function_stddev.h index 2b38badd7b..9578b37fbc 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_stddev.h +++ b/be/src/vec/aggregate_functions/aggregate_function_stddev.h @@ -108,7 +108,7 @@ struct BaseData { } void add(const IColumn* column, size_t row_num) { - const auto& sources = static_cast<const ColumnVector<T>&>(*column); + const auto& sources = assert_cast<const ColumnVector<T>&>(*column); double source_data = sources.get_data()[row_num]; double delta = source_data - mean; @@ -188,7 +188,7 @@ struct BaseDatadecimal { } void add(const IColumn* column, size_t row_num) { - const auto& sources = static_cast<const ColumnDecimal<T>&>(*column); + const auto& sources = assert_cast<const ColumnDecimal<T>&>(*column); Field field = sources[row_num]; auto decimal_field = field.template get<DecimalField<T>>(); int128_t value; @@ -265,7 +265,7 @@ struct SampData : Data { if (this->count == 1 || this->count == 0) { nullable_column.insert_default(); } else { - auto& col = static_cast<ColVecResult&>(nullable_column.get_nested_column()); + auto& col = assert_cast<ColVecResult&>(nullable_column.get_nested_column()); if constexpr (IsDecimalNumber<T>) { col.get_data().push_back(this->get_samp_result().value()); } else { diff --git a/be/src/vec/aggregate_functions/aggregate_function_sum.h b/be/src/vec/aggregate_functions/aggregate_function_sum.h index 4c4dd9de26..b56dca4c97 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_sum.h +++ b/be/src/vec/aggregate_functions/aggregate_function_sum.h @@ -94,7 +94,7 @@ public: void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, Arena*) const override { - const auto& column = static_cast<const ColVecType&>(*columns[0]); + const auto& column = assert_cast<const ColVecType&>(*columns[0]); this->data(place).add(column.get_data()[row_num]); } @@ -115,7 +115,7 @@ public: } void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { - auto& column = static_cast<ColVecResult&>(to); + auto& column = assert_cast<ColVecResult&>(to); column.get_data().push_back(this->data(place).get()); } diff --git a/be/src/vec/aggregate_functions/aggregate_function_topn.h b/be/src/vec/aggregate_functions/aggregate_function_topn.h index b9ecedab19..bf02cc1817 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_topn.h +++ b/be/src/vec/aggregate_functions/aggregate_function_topn.h @@ -196,10 +196,10 @@ struct AggregateFunctionTopNData { for (int i = 0; i < std::min((int)counter_vector.size(), top_num); i++) { const auto& element = counter_vector[i]; if constexpr (std::is_same_v<T, std::string>) { - static_cast<ColumnString&>(to).insert_data(element.second.c_str(), + assert_cast<ColumnString&>(to).insert_data(element.second.c_str(), element.second.length()); } else { - static_cast<ColVecType&>(to).get_data().push_back(element.second); + assert_cast<ColVecType&>(to).get_data().push_back(element.second); } } } @@ -214,17 +214,17 @@ struct AggregateFunctionTopNData { struct AggregateFunctionTopNImplInt { static void add(AggregateFunctionTopNData<std::string>& __restrict place, const IColumn** columns, size_t row_num) { - place.set_paramenters(static_cast<const ColumnInt32*>(columns[1])->get_element(row_num)); - place.add(static_cast<const ColumnString&>(*columns[0]).get_data_at(row_num)); + place.set_paramenters(assert_cast<const ColumnInt32*>(columns[1])->get_element(row_num)); + place.add(assert_cast<const ColumnString&>(*columns[0]).get_data_at(row_num)); } }; struct AggregateFunctionTopNImplIntInt { static void add(AggregateFunctionTopNData<std::string>& __restrict place, const IColumn** columns, size_t row_num) { - place.set_paramenters(static_cast<const ColumnInt32*>(columns[1])->get_element(row_num), - static_cast<const ColumnInt32*>(columns[2])->get_element(row_num)); - place.add(static_cast<const ColumnString&>(*columns[0]).get_data_at(row_num)); + place.set_paramenters(assert_cast<const ColumnInt32*>(columns[1])->get_element(row_num), + assert_cast<const ColumnInt32*>(columns[2])->get_element(row_num)); + place.add(assert_cast<const ColumnString&>(*columns[0]).get_data_at(row_num)); } }; @@ -237,17 +237,17 @@ struct AggregateFunctionTopNImplArray { size_t row_num) { if constexpr (has_default_param) { place.set_paramenters( - static_cast<const ColumnInt32*>(columns[1])->get_element(row_num), - static_cast<const ColumnInt32*>(columns[2])->get_element(row_num)); + assert_cast<const ColumnInt32*>(columns[1])->get_element(row_num), + assert_cast<const ColumnInt32*>(columns[2])->get_element(row_num)); } else { place.set_paramenters( - static_cast<const ColumnInt32*>(columns[1])->get_element(row_num)); + assert_cast<const ColumnInt32*>(columns[1])->get_element(row_num)); } if constexpr (std::is_same_v<T, std::string>) { - place.add(static_cast<const ColumnString&>(*columns[0]).get_data_at(row_num)); + place.add(assert_cast<const ColumnString&>(*columns[0]).get_data_at(row_num)); } else { - T val = static_cast<const ColVecType&>(*columns[0]).get_data()[row_num]; + T val = assert_cast<const ColVecType&>(*columns[0]).get_data()[row_num]; place.add(val); } } @@ -262,19 +262,19 @@ struct AggregateFunctionTopNImplWeight { size_t row_num) { if constexpr (has_default_param) { place.set_paramenters( - static_cast<const ColumnInt32*>(columns[2])->get_element(row_num), - static_cast<const ColumnInt32*>(columns[3])->get_element(row_num)); + assert_cast<const ColumnInt32*>(columns[2])->get_element(row_num), + assert_cast<const ColumnInt32*>(columns[3])->get_element(row_num)); } else { place.set_paramenters( - static_cast<const ColumnInt32*>(columns[2])->get_element(row_num)); + assert_cast<const ColumnInt32*>(columns[2])->get_element(row_num)); } if constexpr (std::is_same_v<T, std::string>) { - auto weight = static_cast<const ColumnVector<Int64>&>(*columns[1]).get_data()[row_num]; - place.add(static_cast<const ColumnString&>(*columns[0]).get_data_at(row_num), weight); + auto weight = assert_cast<const ColumnVector<Int64>&>(*columns[1]).get_data()[row_num]; + place.add(assert_cast<const ColumnString&>(*columns[0]).get_data_at(row_num), weight); } else { - T val = static_cast<const ColVecType&>(*columns[0]).get_data()[row_num]; - auto weight = static_cast<const ColumnVector<Int64>&>(*columns[1]).get_data()[row_num]; + T val = assert_cast<const ColVecType&>(*columns[0]).get_data()[row_num]; + auto weight = assert_cast<const ColumnVector<Int64>&>(*columns[1]).get_data()[row_num]; place.add(val, weight); } } @@ -325,7 +325,7 @@ public: void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { std::string result = this->data(place).get(); - static_cast<ColumnString&>(to).insert_data(result.c_str(), result.length()); + assert_cast<ColumnString&>(to).insert_data(result.c_str(), result.length()); } }; diff --git a/be/src/vec/aggregate_functions/aggregate_function_window_funnel.h b/be/src/vec/aggregate_functions/aggregate_function_window_funnel.h index e44f0204a5..54ebfe7b28 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_window_funnel.h +++ b/be/src/vec/aggregate_functions/aggregate_function_window_funnel.h @@ -203,15 +203,15 @@ public: void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, Arena*) const override { const auto& window = - static_cast<const ColumnVector<Int64>&>(*columns[0]).get_data()[row_num]; + assert_cast<const ColumnVector<Int64>&>(*columns[0]).get_data()[row_num]; // TODO: handle mode in the future. // be/src/olap/row_block2.cpp copy_data_to_column const auto& timestamp = - static_cast<const ColumnVector<NativeType>&>(*columns[2]).get_data()[row_num]; + assert_cast<const ColumnVector<NativeType>&>(*columns[2]).get_data()[row_num]; const int NON_EVENT_NUM = 3; for (int i = NON_EVENT_NUM; i < IAggregateFunction::get_argument_types().size(); i++) { const auto& is_set = - static_cast<const ColumnVector<UInt8>&>(*columns[i]).get_data()[row_num]; + assert_cast<const ColumnVector<UInt8>&>(*columns[i]).get_data()[row_num]; if (is_set) { this->data(place).add( binary_cast<NativeType, DateValueType>(timestamp), i - NON_EVENT_NUM, diff --git a/be/src/vec/common/assert_cast.h b/be/src/vec/common/assert_cast.h index 2ce18f2481..399996f8eb 100644 --- a/be/src/vec/common/assert_cast.h +++ b/be/src/vec/common/assert_cast.h @@ -39,6 +39,11 @@ To assert_cast(From&& from) { try { if constexpr (std::is_pointer_v<To>) { if (typeid(*from) == typeid(std::remove_pointer_t<To>)) return static_cast<To>(from); + if constexpr (std::is_pointer_v<From>) { + if (auto ptr = dynamic_cast<To>(from); ptr != nullptr) { + return ptr; + } + } } else { if (typeid(from) == typeid(To)) return static_cast<To>(from); } diff --git a/be/src/vec/core/field.h b/be/src/vec/core/field.h index 4c4896ce9a..9cc07669ec 100644 --- a/be/src/vec/core/field.h +++ b/be/src/vec/core/field.h @@ -1051,7 +1051,7 @@ struct NearestFieldTypeImpl<signed char> { }; template <> struct NearestFieldTypeImpl<unsigned char> { - using Type = UInt64; + using Type = Int64; }; template <> diff --git a/be/src/vec/exec/vaggregation_node.cpp b/be/src/vec/exec/vaggregation_node.cpp index 18ae7f42f2..20bf48fe47 100644 --- a/be/src/vec/exec/vaggregation_node.cpp +++ b/be/src/vec/exec/vaggregation_node.cpp @@ -678,7 +678,9 @@ Status AggregationNode::_get_without_key_result(RuntimeState* state, Block* bloc if (!column_type->equals(*data_types[i])) { if (!is_array(remove_nullable(column_type))) { DCHECK(column_type->is_nullable()); - DCHECK(!data_types[i]->is_nullable()); + DCHECK(!data_types[i]->is_nullable()) + << " column type: " << column_type->get_name() + << ", data type: " << data_types[i]->get_name(); DCHECK(remove_nullable(column_type)->equals(*data_types[i])) << " column type: " << remove_nullable(column_type)->get_name() << ", data type: " << data_types[i]->get_name(); diff --git a/be/src/vec/exprs/table_function/vexplode_bitmap.cpp b/be/src/vec/exprs/table_function/vexplode_bitmap.cpp index 7ffbb346f8..2f717b7413 100644 --- a/be/src/vec/exprs/table_function/vexplode_bitmap.cpp +++ b/be/src/vec/exprs/table_function/vexplode_bitmap.cpp @@ -76,14 +76,14 @@ void VExplodeBitmapTableFunction::get_value(MutableColumnPtr& column) { column->insert_default(); } else { if (_is_nullable) { - static_cast<ColumnInt64*>( - static_cast<ColumnNullable*>(column.get())->get_nested_column_ptr().get()) + assert_cast<ColumnInt64*>( + assert_cast<ColumnNullable*>(column.get())->get_nested_column_ptr().get()) ->insert_value(**_cur_iter); - static_cast<ColumnUInt8*>( - static_cast<ColumnNullable*>(column.get())->get_null_map_column_ptr().get()) + assert_cast<ColumnUInt8*>( + assert_cast<ColumnNullable*>(column.get())->get_null_map_column_ptr().get()) ->insert_default(); } else { - static_cast<ColumnInt64*>(column.get())->insert_value(**_cur_iter); + assert_cast<ColumnInt64*>(column.get())->insert_value(**_cur_iter); } } } diff --git a/be/src/vec/exprs/table_function/vexplode_numbers.cpp b/be/src/vec/exprs/table_function/vexplode_numbers.cpp index 50bc6cca7d..022c0f13a5 100644 --- a/be/src/vec/exprs/table_function/vexplode_numbers.cpp +++ b/be/src/vec/exprs/table_function/vexplode_numbers.cpp @@ -54,7 +54,7 @@ Status VExplodeNumbersTableFunction::process_init(Block* block) { auto& column_nested = assert_cast<const ColumnConst&>(*_value_column).get_data_column_ptr(); if (column_nested->is_nullable()) { if (!column_nested->is_null_at(0)) { - _cur_size = static_cast<const ColumnNullable*>(column_nested.get()) + _cur_size = assert_cast<const ColumnNullable*>(column_nested.get()) ->get_nested_column() .get_int(0); } @@ -95,14 +95,14 @@ void VExplodeNumbersTableFunction::get_value(MutableColumnPtr& column) { column->insert_default(); } else { if (_is_nullable) { - static_cast<ColumnInt32*>( - static_cast<ColumnNullable*>(column.get())->get_nested_column_ptr().get()) + assert_cast<ColumnInt32*>( + assert_cast<ColumnNullable*>(column.get())->get_nested_column_ptr().get()) ->insert_value(_cur_offset); - static_cast<ColumnUInt8*>( - static_cast<ColumnNullable*>(column.get())->get_null_map_column_ptr().get()) + assert_cast<ColumnUInt8*>( + assert_cast<ColumnNullable*>(column.get())->get_null_map_column_ptr().get()) ->insert_default(); } else { - static_cast<ColumnInt32*>(column.get())->insert_value(_cur_offset); + assert_cast<ColumnInt32*>(column.get())->insert_value(_cur_offset); } } } diff --git a/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java b/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java index 3479230e17..8ab42d8fe7 100644 --- a/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java +++ b/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java @@ -431,7 +431,8 @@ public abstract class Type { } public boolean isFixedPointType() { - return isScalarType(PrimitiveType.TINYINT) || isScalarType(PrimitiveType.SMALLINT) + return isScalarType(PrimitiveType.TINYINT) + || isScalarType(PrimitiveType.SMALLINT) || isScalarType(PrimitiveType.INT) || isScalarType(PrimitiveType.BIGINT) || isScalarType(PrimitiveType.LARGEINT); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java index fa68a3085b..9ec4358bec 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java @@ -819,7 +819,8 @@ public class FunctionCallExpr extends Expr { // SUM and AVG cannot be applied to non-numeric types if ((fnName.getFunction().equalsIgnoreCase("sum") || fnName.getFunction().equalsIgnoreCase("avg")) - && ((!arg.type.isNumericType() && !arg.type.isNull()) || arg.type.isOnlyMetricType())) { + && ((!arg.type.isNumericType() && !arg.type.isNull() && !arg.type.isBoolean()) + || arg.type.isOnlyMetricType())) { throw new AnalysisException(fnName.getFunction() + " requires a numeric parameter: " + this.toSql()); } // DecimalV3 scale lower than DEFAULT_MIN_AVG_DECIMAL128_SCALE should do cast diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java index b633e4ec7c..3b8818bc95 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java @@ -1663,8 +1663,7 @@ public class FunctionSet<T> { // retention vectorization addBuiltin(AggregateFunction.createBuiltin(FunctionSet.RETENTION, Lists.newArrayList(Type.BOOLEAN), - // Type.BOOLEAN will return non-numeric results so we use Type.TINYINT - new ArrayType(Type.TINYINT), + new ArrayType(Type.BOOLEAN), Type.VARCHAR, true, "", @@ -2458,6 +2457,13 @@ public class FunctionSet<T> { null, false, true, false)); // vectorized + addBuiltin(AggregateFunction.createBuiltin(name, + Lists.<Type>newArrayList(Type.BOOLEAN), Type.BIGINT, Type.BIGINT, initNull, + "", + "", + null, null, + "", + null, false, true, false, true)); addBuiltin(AggregateFunction.createBuiltin(name, Lists.<Type>newArrayList(Type.TINYINT), Type.BIGINT, Type.BIGINT, initNull, prefix + "3sumIN9doris_udf9BigIntValES3_EEvPNS2_15FunctionContextERKT_PT0_", @@ -2913,6 +2919,10 @@ public class FunctionSet<T> { false, true, false)); // vectorized avg + addBuiltin(AggregateFunction.createBuiltin("avg", + Lists.<Type>newArrayList(Type.BOOLEAN), Type.DOUBLE, Type.TINYINT, + "", "", "", "", "", "", "", + false, true, false, true)); addBuiltin(AggregateFunction.createBuiltin("avg", Lists.<Type>newArrayList(Type.TINYINT), Type.DOUBLE, Type.TINYINT, "", "", "", "", "", "", "", diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Retention.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Retention.java index 1ae3d08e68..ec3d4b9594 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Retention.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Retention.java @@ -25,7 +25,6 @@ import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSi import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.ArrayType; import org.apache.doris.nereids.types.BooleanType; -import org.apache.doris.nereids.types.TinyIntType; import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.base.Preconditions; @@ -40,7 +39,7 @@ public class Retention extends AggregateFunction implements ExplicitlyCastableSignature, AlwaysNotNullable { public static final List<FunctionSignature> SIGNATURES = ImmutableList.of( - FunctionSignature.ret(ArrayType.of(TinyIntType.INSTANCE)).varArgs(BooleanType.INSTANCE) + FunctionSignature.ret(ArrayType.of(BooleanType.INSTANCE)).varArgs(BooleanType.INSTANCE) ); /** diff --git a/fe/fe-core/src/test/java/org/apache/doris/analysis/InsertArrayStmtTest.java b/fe/fe-core/src/test/java/org/apache/doris/analysis/InsertArrayStmtTest.java index 58a7f7a013..f74fcd3432 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/analysis/InsertArrayStmtTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/analysis/InsertArrayStmtTest.java @@ -146,7 +146,7 @@ public class InsertArrayStmtTest { stmtExecutor = new StmtExecutor(connectContext, insertStmt); stmtExecutor.execute(); QueryState state = connectContext.getState(); - Assert.assertEquals(MysqlStateType.OK, state.getStateType()); + Assert.assertEquals(state.getErrorMessage(), MysqlStateType.OK, state.getStateType()); } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org