This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch branch-2.1 in repository https://gitbox.apache.org/repos/asf/doris.git
commit ccd21a6ea437338371bf026517421a8f16402bf6 Author: Benjaminwei <33219531+tap...@users.noreply.github.com> AuthorDate: Mon Mar 11 21:22:32 2024 +0800 [Improve](InPredict) enhance in predict with array type (#31828) --- be/src/vec/exprs/vin_predicate.cpp | 5 +- ...on_struct_in.cpp => function_collection_in.cpp} | 8 +-- ...nction_struct_in.h => function_collection_in.h} | 65 +++++++++++++--------- be/src/vec/functions/simple_function_factory.h | 4 +- .../nereids/trees/expressions/InPredicate.java | 12 ++++ .../doris/nereids/util/TypeCoercionUtils.java | 10 +++- .../nereids_syntax_p0/inpredicate_with_list.out | 19 +++++++ .../data/query_p0/sql_functions/test_in_expr.out | 3 + .../nereids_syntax_p0/inpredicate_with_list.groovy | 60 ++++++++++++++++++++ .../query_p0/sql_functions/test_in_expr.groovy | 5 +- 10 files changed, 150 insertions(+), 41 deletions(-) diff --git a/be/src/vec/exprs/vin_predicate.cpp b/be/src/vec/exprs/vin_predicate.cpp index 6f57828ef2d..838e12dc9c2 100644 --- a/be/src/vec/exprs/vin_predicate.cpp +++ b/be/src/vec/exprs/vin_predicate.cpp @@ -66,8 +66,9 @@ Status VInPredicate::prepare(RuntimeState* state, const RowDescriptor& desc, // construct the proper function_name std::string head(_is_not_in ? "not_" : ""); std::string real_function_name = head + std::string(function_name); - if (is_struct(remove_nullable(argument_template[0].type))) { - real_function_name = "struct_" + real_function_name; + auto arg_type = remove_nullable(argument_template[0].type); + if (is_struct(arg_type) || is_array(arg_type) || is_map(arg_type)) { + real_function_name = "collection_" + real_function_name; } _function = SimpleFunctionFactory::instance().get_function(real_function_name, argument_template, _data_type); diff --git a/be/src/vec/functions/function_struct_in.cpp b/be/src/vec/functions/function_collection_in.cpp similarity index 79% rename from be/src/vec/functions/function_struct_in.cpp rename to be/src/vec/functions/function_collection_in.cpp index 943e3a80f47..e7d6b56d5c5 100644 --- a/be/src/vec/functions/function_struct_in.cpp +++ b/be/src/vec/functions/function_collection_in.cpp @@ -17,15 +17,15 @@ // This file is copied from // and modified by Doris -#include "vec/functions/function_struct_in.h" +#include "vec/functions/function_collection_in.h" #include "vec/functions/simple_function_factory.h" namespace doris::vectorized { -void register_function_struct_in(SimpleFunctionFactory& factory) { - factory.register_function<FunctionStructIn<false>>(); - factory.register_function<FunctionStructIn<true>>(); +void register_function_collection_in(SimpleFunctionFactory& factory) { + factory.register_function<FunctionCollectionIn<false>>(); + factory.register_function<FunctionCollectionIn<true>>(); } } // namespace doris::vectorized diff --git a/be/src/vec/functions/function_struct_in.h b/be/src/vec/functions/function_collection_in.h similarity index 75% rename from be/src/vec/functions/function_struct_in.h rename to be/src/vec/functions/function_collection_in.h index e3fb9dbe70e..1e86ce25b34 100644 --- a/be/src/vec/functions/function_struct_in.h +++ b/be/src/vec/functions/function_collection_in.h @@ -48,11 +48,11 @@ struct ColumnRowRef { // equals when call set insert, this operator will be used bool operator==(const ColumnRowRef& other) const { - return column->compare_at(row_idx, other.row_idx, *column, 0) == 0; + return column->compare_at(row_idx, other.row_idx, *other.column, 0) == 0; } // compare bool operator<(const ColumnRowRef& other) const { - return column->compare_at(row_idx, other.row_idx, *column, 0) < 0; + return column->compare_at(row_idx, other.row_idx, *other.column, 0) < 0; } // when call set find, will use hash to find @@ -63,18 +63,18 @@ struct ColumnRowRef { } }; -struct StructInState { - ENABLE_FACTORY_CREATOR(StructInState) +struct CollectionInState { + ENABLE_FACTORY_CREATOR(CollectionInState) std::unordered_set<ColumnRowRef, ColumnRowRef> args_set; bool null_in_set = false; }; template <bool negative> -class FunctionStructIn : public IFunction { +class FunctionCollectionIn : public IFunction { public: - static constexpr auto name = negative ? "struct_not_in" : "struct_in"; + static constexpr auto name = negative ? "collection_not_in" : "collection_in"; - static FunctionPtr create() { return std::make_shared<FunctionStructIn>(); } + static FunctionPtr create() { return std::make_shared<FunctionCollectionIn>(); } String get_name() const override { return name; } @@ -98,14 +98,17 @@ public: if (scope == FunctionContext::THREAD_LOCAL) { return Status::OK(); } - std::shared_ptr<StructInState> state = std::make_shared<StructInState>(); + int num_args = context->get_num_args(); + DCHECK(num_args >= 1); + + std::shared_ptr<CollectionInState> state = std::make_shared<CollectionInState>(); context->set_function_state(scope, state); - DCHECK(context->get_num_args() >= 1); + auto* col_desc = context->get_arg_type(0); - DataTypePtr args_type = DataTypeFactory::instance().create_data_type(*col_desc); - MutableColumnPtr column_struct_ptr_args = remove_nullable(args_type)->create_column(); - NullMap null_map(context->get_num_args(), false); - for (int i = 1; i < context->get_num_args(); ++i) { + DataTypePtr args_type = DataTypeFactory::instance().create_data_type(*col_desc, false); + MutableColumnPtr args_column_ptr = args_type->create_column(); + + for (int i = 1; i < num_args; i++) { // FE should make element type consistent and // equalize the length of the elements in struct const auto& const_column_ptr = context->get_constant_col(i); @@ -117,34 +120,33 @@ public: auto* null_col = vectorized::check_and_get_column<vectorized::ColumnNullable>(col); if (null_col->has_null()) { state->null_in_set = true; - null_map[i - 1] = true; } else { - column_struct_ptr_args->insert_from(null_col->get_nested_column(), 0); + args_column_ptr->insert_from(null_col->get_nested_column(), 0); } } else { - column_struct_ptr_args->insert_from(*col, 0); + args_column_ptr->insert_from(*col, 0); } } - ColumnPtr column_ptr = std::move(column_struct_ptr_args); - // make StructRef into set - for (size_t i = 1; i < context->get_num_args(); ++i) { - if (state->null_in_set && null_map[i - 1]) { - continue; - } - state->args_set.insert({column_ptr, i - 1}); + ColumnPtr column_ptr = std::move(args_column_ptr); + // make collection ref into set + int col_size = column_ptr->size(); + for (size_t i = 0; i < col_size; i++) { + state->args_set.insert({column_ptr, i}); } + return Status::OK(); } Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count) const override { - auto in_state = reinterpret_cast<StructInState*>( + auto in_state = reinterpret_cast<CollectionInState*>( context->get_function_state(FunctionContext::FRAGMENT_LOCAL)); if (!in_state) { - return Status::RuntimeError("funciton context for function '{}' must have Set;", + return Status::RuntimeError("function context for function '{}' must have Set;", get_name()); } const auto& args_set = in_state->args_set; + const bool null_in_set = in_state->null_in_set; auto res = ColumnUInt8::create(); ColumnUInt8::Container& vec_res = res->get_data(); vec_res.resize(input_rows_count); @@ -155,15 +157,24 @@ public: const ColumnWithTypeAndName& left_arg = block.get_by_position(arguments[0]); const auto& [materialized_column, col_const] = unpack_if_const(left_arg.column); + auto materialized_column_not_null = materialized_column; + if (materialized_column_not_null->is_nullable()) { + materialized_column_not_null = assert_cast<ColumnPtr>( + vectorized::check_and_get_column<vectorized::ColumnNullable>( + materialized_column_not_null) + ->get_nested_column_ptr()); + } for (size_t i = 0; i < input_rows_count; ++i) { - bool find = args_set.find({materialized_column, i}) != args_set.end(); + bool find = args_set.find({materialized_column_not_null, i}) != args_set.end(); + if constexpr (negative) { vec_res[i] = !find; } else { vec_res[i] = find; } - if (in_state->null_in_set) { + + if (null_in_set) { vec_null_map_to[i] = negative == vec_res[i]; } else { vec_null_map_to[i] = false; diff --git a/be/src/vec/functions/simple_function_factory.h b/be/src/vec/functions/simple_function_factory.h index b1c1b394bff..a44861a2683 100644 --- a/be/src/vec/functions/simple_function_factory.h +++ b/be/src/vec/functions/simple_function_factory.h @@ -65,7 +65,7 @@ void register_function_running_difference(SimpleFunctionFactory& factory); void register_function_date_time_to_string(SimpleFunctionFactory& factory); void register_function_date_time_string_to_string(SimpleFunctionFactory& factory); void register_function_in(SimpleFunctionFactory& factory); -void register_function_struct_in(SimpleFunctionFactory& factory); +void register_function_collection_in(SimpleFunctionFactory& factory); void register_function_if(SimpleFunctionFactory& factory); void register_function_nullif(SimpleFunctionFactory& factory); void register_function_date_time_computation(SimpleFunctionFactory& factory); @@ -246,7 +246,7 @@ public: register_function_time_of_function(instance); register_function_string(instance); register_function_in(instance); - register_function_struct_in(instance); + register_function_collection_in(instance); register_function_if(instance); register_function_nullif(instance); register_function_date_time_computation(instance); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java index f3ae9ce5b27..bcebdca4f5b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java @@ -86,6 +86,18 @@ public class InPredicate extends Expression { } return; } + + if (children().get(0).getDataType().isArrayType()) { + // we should check in value list is all list type + for (int i = 1; i < children().size(); i++) { + if (!children().get(i).getDataType().isArrayType() && !children().get(i).getDataType().isNullType()) { + throw new AnalysisException("in predicate list should compare with struct type list, but got : " + + children().get(i).getDataType().toSql()); + } + } + return; + } + children().forEach(c -> { if (c.getDataType().isObjectType()) { throw new AnalysisException("in predicate could not contains object type: " + this.toSql()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java index 239972fc1d6..34d23aecf56 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java @@ -954,7 +954,8 @@ public class TypeCoercionUtils { if (inPredicate.getOptions().stream().map(Expression::getDataType) .allMatch(dt -> dt.equals(inPredicate.getCompareExpr().getDataType()))) { if (!supportCompare(inPredicate.getCompareExpr().getDataType()) - && !inPredicate.getCompareExpr().getDataType().isStructType()) { + && !inPredicate.getCompareExpr().getDataType().isStructType() && !inPredicate.getCompareExpr() + .getDataType().isArrayType()) { throw new AnalysisException("data type " + inPredicate.getCompareExpr().getDataType() + " could not used in InPredicate " + inPredicate.toSql()); } @@ -970,8 +971,13 @@ public class TypeCoercionUtils { throw new AnalysisException("data type " + optionalCommonType.get() + " is not match " + inPredicate.getCompareExpr().getDataType() + " used in InPredicate"); } + if (inPredicate.getCompareExpr().getDataType().isArrayType() && optionalCommonType.isPresent() + && !optionalCommonType.get().isArrayType()) { + throw new AnalysisException("data type " + optionalCommonType.get() + + " is not match " + inPredicate.getCompareExpr().getDataType() + " used in InPredicate"); + } if (optionalCommonType.isPresent() && !supportCompare(optionalCommonType.get()) - && !optionalCommonType.get().isStructType()) { + && !optionalCommonType.get().isStructType() && !optionalCommonType.get().isArrayType()) { throw new AnalysisException("data type " + optionalCommonType.get() + " could not used in InPredicate " + inPredicate.toSql()); } diff --git a/regression-test/data/nereids_syntax_p0/inpredicate_with_list.out b/regression-test/data/nereids_syntax_p0/inpredicate_with_list.out new file mode 100644 index 00000000000..2aaf57f0cd5 --- /dev/null +++ b/regression-test/data/nereids_syntax_p0/inpredicate_with_list.out @@ -0,0 +1,19 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !select_default -- +1 [1, 2, 3] +2 [4, 5, 6] + +-- !in_predicate_17 -- + +-- !in_predicate_18 -- + +-- !in_predicate_19 -- +2 [4, 5, 6] + +-- !in_predicate_20 -- +2 [4, 5, 6] + +-- !in_predicate_21 -- + +-- !in_predicate_22 -- + diff --git a/regression-test/data/query_p0/sql_functions/test_in_expr.out b/regression-test/data/query_p0/sql_functions/test_in_expr.out index 6cb6bd195f6..43bb4b43bfb 100644 --- a/regression-test/data/query_p0/sql_functions/test_in_expr.out +++ b/regression-test/data/query_p0/sql_functions/test_in_expr.out @@ -73,3 +73,6 @@ d -- !select -- +-- !select -- +[1, 2, 3, 4, 5] \N + diff --git a/regression-test/suites/nereids_syntax_p0/inpredicate_with_list.groovy b/regression-test/suites/nereids_syntax_p0/inpredicate_with_list.groovy new file mode 100644 index 00000000000..58f98621529 --- /dev/null +++ b/regression-test/suites/nereids_syntax_p0/inpredicate_with_list.groovy @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("inpredicate_with_list") { + def tableName = "test_array" + + sql """ DROP TABLE IF EXISTS ${tableName} """ + sql """ + CREATE TABLE IF NOT EXISTS ${tableName} ( + `id` INT NULL, + `c_array` ARRAY<INT> NULL + ) DISTRIBUTED BY HASH(id) PROPERTIES("replication_num" = "1"); + """ + + sql """ INSERT INTO ${tableName} (`id`,`c_array`) VALUES + (1, [1, 2, 3]), + (2, [4, 5, 6]) + """ + try { + qt_select_default """ SELECT * FROM ${tableName} t ORDER BY id; """ + + sql """ set enable_nereids_planner = true;""" + sql """ set enable_fallback_to_original_planner=false;""" + order_qt_in_predicate_17 """ + SELECT * FROM ${tableName} where c_array in ([1,2], [1,3]); + """ + order_qt_in_predicate_18 """ + SELECT * FROM ${tableName} where c_array in (null, [1,3]); + """ + order_qt_in_predicate_19 """ + SELECT * FROM ${tableName} where c_array in ([1,2], [1,3], [4,5,6]); + """ + order_qt_in_predicate_20 """ + SELECT * FROM ${tableName} where c_array in ([1,2], null, [4,5,6]); + """ + order_qt_in_predicate_21 """ + SELECT * FROM ${tableName} where c_array in ([1,2], null, [4,5,3]); + """ + order_qt_in_predicate_22 """ + SELECT * FROM ${tableName} where c_array in ([1,2], [6,5,4], [4,5,3]); + """ + } finally { + try_sql("DROP TABLE IF EXISTS ${tableName}") + } +} + diff --git a/regression-test/suites/query_p0/sql_functions/test_in_expr.groovy b/regression-test/suites/query_p0/sql_functions/test_in_expr.groovy index e3839a65c48..807fdd91513 100644 --- a/regression-test/suites/query_p0/sql_functions/test_in_expr.groovy +++ b/regression-test/suites/query_p0/sql_functions/test_in_expr.groovy @@ -145,10 +145,7 @@ suite("test_in_expr", "query,arrow_flight_sql") { sql """ INSERT INTO `array_in_test` VALUES (1, [1,2,3,4,5]); """ - test { - sql """ select c_array, c_array in (null) from array_in_test; """ - exception "errCode" - } + qt_select """ select c_array, c_array in (null) from array_in_test; """ sql " drop table if exists `json_in_test` " sql """ --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org