This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit ccd21a6ea437338371bf026517421a8f16402bf6
Author: Benjaminwei <33219531+tap...@users.noreply.github.com>
AuthorDate: Mon Mar 11 21:22:32 2024 +0800

    [Improve](InPredict) enhance in predict with array type (#31828)
---
 be/src/vec/exprs/vin_predicate.cpp                 |  5 +-
 ...on_struct_in.cpp => function_collection_in.cpp} |  8 +--
 ...nction_struct_in.h => function_collection_in.h} | 65 +++++++++++++---------
 be/src/vec/functions/simple_function_factory.h     |  4 +-
 .../nereids/trees/expressions/InPredicate.java     | 12 ++++
 .../doris/nereids/util/TypeCoercionUtils.java      | 10 +++-
 .../nereids_syntax_p0/inpredicate_with_list.out    | 19 +++++++
 .../data/query_p0/sql_functions/test_in_expr.out   |  3 +
 .../nereids_syntax_p0/inpredicate_with_list.groovy | 60 ++++++++++++++++++++
 .../query_p0/sql_functions/test_in_expr.groovy     |  5 +-
 10 files changed, 150 insertions(+), 41 deletions(-)

diff --git a/be/src/vec/exprs/vin_predicate.cpp b/be/src/vec/exprs/vin_predicate.cpp
index 6f57828ef2d..838e12dc9c2 100644
--- a/be/src/vec/exprs/vin_predicate.cpp
+++ b/be/src/vec/exprs/vin_predicate.cpp
@@ -66,8 +66,9 @@ Status VInPredicate::prepare(RuntimeState* state, const RowDescriptor& desc,
     // construct the proper function_name
     std::string head(_is_not_in ? "not_" : "");
     std::string real_function_name = head + std::string(function_name);
-    if (is_struct(remove_nullable(argument_template[0].type))) {
-        real_function_name = "struct_" + real_function_name;
+    auto arg_type = remove_nullable(argument_template[0].type);
+    if (is_struct(arg_type) || is_array(arg_type) || is_map(arg_type)) {
+        real_function_name = "collection_" + real_function_name;
     }
     _function = 
SimpleFunctionFactory::instance().get_function(real_function_name,
                                                                
argument_template, _data_type);
diff --git a/be/src/vec/functions/function_struct_in.cpp b/be/src/vec/functions/function_collection_in.cpp
similarity index 79%
rename from be/src/vec/functions/function_struct_in.cpp
rename to be/src/vec/functions/function_collection_in.cpp
index 943e3a80f47..e7d6b56d5c5 100644
--- a/be/src/vec/functions/function_struct_in.cpp
+++ b/be/src/vec/functions/function_collection_in.cpp
@@ -17,15 +17,15 @@
 // This file is copied from
 // and modified by Doris
 
-#include "vec/functions/function_struct_in.h"
+#include "vec/functions/function_collection_in.h"
 
 #include "vec/functions/simple_function_factory.h"
 
 namespace doris::vectorized {
 
-void register_function_struct_in(SimpleFunctionFactory& factory) {
-    factory.register_function<FunctionStructIn<false>>();
-    factory.register_function<FunctionStructIn<true>>();
+void register_function_collection_in(SimpleFunctionFactory& factory) {
+    factory.register_function<FunctionCollectionIn<false>>();
+    factory.register_function<FunctionCollectionIn<true>>();
 }
 
 } // namespace doris::vectorized
diff --git a/be/src/vec/functions/function_struct_in.h b/be/src/vec/functions/function_collection_in.h
similarity index 75%
rename from be/src/vec/functions/function_struct_in.h
rename to be/src/vec/functions/function_collection_in.h
index e3fb9dbe70e..1e86ce25b34 100644
--- a/be/src/vec/functions/function_struct_in.h
+++ b/be/src/vec/functions/function_collection_in.h
@@ -48,11 +48,11 @@ struct ColumnRowRef {
 
     // equals when call set insert, this operator will be used
     bool operator==(const ColumnRowRef& other) const {
-        return column->compare_at(row_idx, other.row_idx, *column, 0) == 0;
+        return column->compare_at(row_idx, other.row_idx, *other.column, 0) == 
0;
     }
     // compare
     bool operator<(const ColumnRowRef& other) const {
-        return column->compare_at(row_idx, other.row_idx, *column, 0) < 0;
+        return column->compare_at(row_idx, other.row_idx, *other.column, 0) < 
0;
     }
 
     // when call set find, will use hash to find
@@ -63,18 +63,18 @@ struct ColumnRowRef {
     }
 };
 
-struct StructInState {
-    ENABLE_FACTORY_CREATOR(StructInState)
+struct CollectionInState {
+    ENABLE_FACTORY_CREATOR(CollectionInState)
     std::unordered_set<ColumnRowRef, ColumnRowRef> args_set;
     bool null_in_set = false;
 };
 
 template <bool negative>
-class FunctionStructIn : public IFunction {
+class FunctionCollectionIn : public IFunction {
 public:
-    static constexpr auto name = negative ? "struct_not_in" : "struct_in";
+    static constexpr auto name = negative ? "collection_not_in" : 
"collection_in";
 
-    static FunctionPtr create() { return std::make_shared<FunctionStructIn>(); 
}
+    static FunctionPtr create() { return 
std::make_shared<FunctionCollectionIn>(); }
 
     String get_name() const override { return name; }
 
@@ -98,14 +98,17 @@ public:
         if (scope == FunctionContext::THREAD_LOCAL) {
             return Status::OK();
         }
-        std::shared_ptr<StructInState> state = 
std::make_shared<StructInState>();
+        int num_args = context->get_num_args();
+        DCHECK(num_args >= 1);
+
+        std::shared_ptr<CollectionInState> state = 
std::make_shared<CollectionInState>();
         context->set_function_state(scope, state);
-        DCHECK(context->get_num_args() >= 1);
+
         auto* col_desc = context->get_arg_type(0);
-        DataTypePtr args_type = 
DataTypeFactory::instance().create_data_type(*col_desc);
-        MutableColumnPtr column_struct_ptr_args = 
remove_nullable(args_type)->create_column();
-        NullMap null_map(context->get_num_args(), false);
-        for (int i = 1; i < context->get_num_args(); ++i) {
+        DataTypePtr args_type = 
DataTypeFactory::instance().create_data_type(*col_desc, false);
+        MutableColumnPtr args_column_ptr = args_type->create_column();
+
+        for (int i = 1; i < num_args; i++) {
             // FE should make element type consistent and
             // equalize the length of the elements in struct
             const auto& const_column_ptr = context->get_constant_col(i);
@@ -117,34 +120,33 @@ public:
                 auto* null_col = 
vectorized::check_and_get_column<vectorized::ColumnNullable>(col);
                 if (null_col->has_null()) {
                     state->null_in_set = true;
-                    null_map[i - 1] = true;
                 } else {
-                    
column_struct_ptr_args->insert_from(null_col->get_nested_column(), 0);
+                    
args_column_ptr->insert_from(null_col->get_nested_column(), 0);
                 }
             } else {
-                column_struct_ptr_args->insert_from(*col, 0);
+                args_column_ptr->insert_from(*col, 0);
             }
         }
-        ColumnPtr column_ptr = std::move(column_struct_ptr_args);
-        // make StructRef into set
-        for (size_t i = 1; i < context->get_num_args(); ++i) {
-            if (state->null_in_set && null_map[i - 1]) {
-                continue;
-            }
-            state->args_set.insert({column_ptr, i - 1});
+        ColumnPtr column_ptr = std::move(args_column_ptr);
+        // make collection ref into set
+        int col_size = column_ptr->size();
+        for (size_t i = 0; i < col_size; i++) {
+            state->args_set.insert({column_ptr, i});
         }
+
         return Status::OK();
     }
 
     Status execute_impl(FunctionContext* context, Block& block, const 
ColumnNumbers& arguments,
                         size_t result, size_t input_rows_count) const override 
{
-        auto in_state = reinterpret_cast<StructInState*>(
+        auto in_state = reinterpret_cast<CollectionInState*>(
                 context->get_function_state(FunctionContext::FRAGMENT_LOCAL));
         if (!in_state) {
-            return Status::RuntimeError("funciton context for function '{}' 
must have Set;",
+            return Status::RuntimeError("function context for function '{}' 
must have Set;",
                                         get_name());
         }
         const auto& args_set = in_state->args_set;
+        const bool null_in_set = in_state->null_in_set;
         auto res = ColumnUInt8::create();
         ColumnUInt8::Container& vec_res = res->get_data();
         vec_res.resize(input_rows_count);
@@ -155,15 +157,24 @@ public:
 
         const ColumnWithTypeAndName& left_arg = 
block.get_by_position(arguments[0]);
         const auto& [materialized_column, col_const] = 
unpack_if_const(left_arg.column);
+        auto materialized_column_not_null = materialized_column;
+        if (materialized_column_not_null->is_nullable()) {
+            materialized_column_not_null = assert_cast<ColumnPtr>(
+                    
vectorized::check_and_get_column<vectorized::ColumnNullable>(
+                            materialized_column_not_null)
+                            ->get_nested_column_ptr());
+        }
 
         for (size_t i = 0; i < input_rows_count; ++i) {
-            bool find = args_set.find({materialized_column, i}) != 
args_set.end();
+            bool find = args_set.find({materialized_column_not_null, i}) != 
args_set.end();
+
             if constexpr (negative) {
                 vec_res[i] = !find;
             } else {
                 vec_res[i] = find;
             }
-            if (in_state->null_in_set) {
+
+            if (null_in_set) {
                 vec_null_map_to[i] = negative == vec_res[i];
             } else {
                 vec_null_map_to[i] = false;
diff --git a/be/src/vec/functions/simple_function_factory.h b/be/src/vec/functions/simple_function_factory.h
index b1c1b394bff..a44861a2683 100644
--- a/be/src/vec/functions/simple_function_factory.h
+++ b/be/src/vec/functions/simple_function_factory.h
@@ -65,7 +65,7 @@ void register_function_running_difference(SimpleFunctionFactory& factory);
 void register_function_date_time_to_string(SimpleFunctionFactory& factory);
 void register_function_date_time_string_to_string(SimpleFunctionFactory& 
factory);
 void register_function_in(SimpleFunctionFactory& factory);
-void register_function_struct_in(SimpleFunctionFactory& factory);
+void register_function_collection_in(SimpleFunctionFactory& factory);
 void register_function_if(SimpleFunctionFactory& factory);
 void register_function_nullif(SimpleFunctionFactory& factory);
 void register_function_date_time_computation(SimpleFunctionFactory& factory);
@@ -246,7 +246,7 @@ public:
             register_function_time_of_function(instance);
             register_function_string(instance);
             register_function_in(instance);
-            register_function_struct_in(instance);
+            register_function_collection_in(instance);
             register_function_if(instance);
             register_function_nullif(instance);
             register_function_date_time_computation(instance);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java
index f3ae9ce5b27..bcebdca4f5b 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/InPredicate.java
@@ -86,6 +86,18 @@ public class InPredicate extends Expression {
             }
             return;
         }
+
+        if (children().get(0).getDataType().isArrayType()) {
+            // we should check in value list is all list type
+            for (int i = 1; i < children().size(); i++) {
+                if (!children().get(i).getDataType().isArrayType() && 
!children().get(i).getDataType().isNullType()) {
+                    throw new AnalysisException("in predicate list should 
compare with struct type list, but got : "
+                            + children().get(i).getDataType().toSql());
+                }
+            }
+            return;
+        }
+
         children().forEach(c -> {
             if (c.getDataType().isObjectType()) {
                 throw new AnalysisException("in predicate could not contains 
object type: " + this.toSql());
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java
index 239972fc1d6..34d23aecf56 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java
@@ -954,7 +954,8 @@ public class TypeCoercionUtils {
         if (inPredicate.getOptions().stream().map(Expression::getDataType)
                 .allMatch(dt -> 
dt.equals(inPredicate.getCompareExpr().getDataType()))) {
             if (!supportCompare(inPredicate.getCompareExpr().getDataType())
-                    && 
!inPredicate.getCompareExpr().getDataType().isStructType()) {
+                    && 
!inPredicate.getCompareExpr().getDataType().isStructType() && 
!inPredicate.getCompareExpr()
+                    .getDataType().isArrayType()) {
                 throw new AnalysisException("data type " + 
inPredicate.getCompareExpr().getDataType()
                         + " could not used in InPredicate " + 
inPredicate.toSql());
             }
@@ -970,8 +971,13 @@ public class TypeCoercionUtils {
             throw new AnalysisException("data type " + optionalCommonType.get()
                     + " is not match " + 
inPredicate.getCompareExpr().getDataType() + " used in InPredicate");
         }
+        if (inPredicate.getCompareExpr().getDataType().isArrayType() && 
optionalCommonType.isPresent()
+                && !optionalCommonType.get().isArrayType()) {
+            throw new AnalysisException("data type " + optionalCommonType.get()
+                    + " is not match " + 
inPredicate.getCompareExpr().getDataType() + " used in InPredicate");
+        }
         if (optionalCommonType.isPresent() && 
!supportCompare(optionalCommonType.get())
-                && !optionalCommonType.get().isStructType()) {
+                && !optionalCommonType.get().isStructType() && 
!optionalCommonType.get().isArrayType()) {
             throw new AnalysisException("data type " + optionalCommonType.get()
                     + " could not used in InPredicate " + inPredicate.toSql());
         }
diff --git a/regression-test/data/nereids_syntax_p0/inpredicate_with_list.out b/regression-test/data/nereids_syntax_p0/inpredicate_with_list.out
new file mode 100644
index 00000000000..2aaf57f0cd5
--- /dev/null
+++ b/regression-test/data/nereids_syntax_p0/inpredicate_with_list.out
@@ -0,0 +1,19 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !select_default --
+1      [1, 2, 3]
+2      [4, 5, 6]
+
+-- !in_predicate_17 --
+
+-- !in_predicate_18 --
+
+-- !in_predicate_19 --
+2      [4, 5, 6]
+
+-- !in_predicate_20 --
+2      [4, 5, 6]
+
+-- !in_predicate_21 --
+
+-- !in_predicate_22 --
+
diff --git a/regression-test/data/query_p0/sql_functions/test_in_expr.out b/regression-test/data/query_p0/sql_functions/test_in_expr.out
index 6cb6bd195f6..43bb4b43bfb 100644
--- a/regression-test/data/query_p0/sql_functions/test_in_expr.out
+++ b/regression-test/data/query_p0/sql_functions/test_in_expr.out
@@ -73,3 +73,6 @@ d
 
 -- !select --
 
+-- !select --
+[1, 2, 3, 4, 5]        \N
+
diff --git a/regression-test/suites/nereids_syntax_p0/inpredicate_with_list.groovy b/regression-test/suites/nereids_syntax_p0/inpredicate_with_list.groovy
new file mode 100644
index 00000000000..58f98621529
--- /dev/null
+++ b/regression-test/suites/nereids_syntax_p0/inpredicate_with_list.groovy
@@ -0,0 +1,60 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("inpredicate_with_list") {
+    def tableName = "test_array"
+
+    sql """ DROP TABLE IF EXISTS ${tableName} """
+    sql """
+    CREATE TABLE IF NOT EXISTS ${tableName} (
+        `id` INT NULL,
+        `c_array` ARRAY<INT> NULL
+        ) DISTRIBUTED BY HASH(id) PROPERTIES("replication_num" = "1");
+    """
+
+    sql """ INSERT INTO ${tableName} (`id`,`c_array`) VALUES
+            (1, [1, 2, 3]),
+            (2, [4, 5, 6])
+        """
+    try {
+        qt_select_default """ SELECT * FROM ${tableName} t ORDER BY id; """
+
+        sql """ set enable_nereids_planner = true;"""
+        sql """ set enable_fallback_to_original_planner=false;"""
+        order_qt_in_predicate_17 """
+            SELECT * FROM ${tableName} where c_array in ([1,2], [1,3]);
+        """
+        order_qt_in_predicate_18 """
+            SELECT * FROM ${tableName} where c_array in (null, [1,3]);
+        """
+        order_qt_in_predicate_19 """
+            SELECT * FROM ${tableName} where c_array in ([1,2], [1,3], 
[4,5,6]);
+        """
+        order_qt_in_predicate_20 """
+            SELECT * FROM ${tableName} where c_array in ([1,2], null, [4,5,6]);
+        """
+        order_qt_in_predicate_21 """
+            SELECT * FROM ${tableName} where c_array in ([1,2], null, [4,5,3]);
+        """
+        order_qt_in_predicate_22 """
+            SELECT * FROM ${tableName} where c_array in ([1,2], [6,5,4], 
[4,5,3]);
+        """
+    } finally {
+        try_sql("DROP TABLE IF EXISTS ${tableName}")
+    }
+}
+
diff --git a/regression-test/suites/query_p0/sql_functions/test_in_expr.groovy b/regression-test/suites/query_p0/sql_functions/test_in_expr.groovy
index e3839a65c48..807fdd91513 100644
--- a/regression-test/suites/query_p0/sql_functions/test_in_expr.groovy
+++ b/regression-test/suites/query_p0/sql_functions/test_in_expr.groovy
@@ -145,10 +145,7 @@ suite("test_in_expr", "query,arrow_flight_sql") {
 
     sql """ INSERT INTO `array_in_test` VALUES (1, [1,2,3,4,5]); """
 
-    test {
-        sql """ select c_array, c_array in (null) from array_in_test; """
-        exception "errCode"
-    }
+    qt_select """ select c_array, c_array in (null) from array_in_test; """
 
     sql " drop table if exists `json_in_test` "
     sql """


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to