This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 3a49156e30 [performance] (vectorization)optimize In Expr (#11826)
3a49156e30 is described below

commit 3a49156e30f3a025e5100d388961a082843387b7
Author: wangbo <wan...@apache.org>
AuthorDate: Wed Aug 17 10:46:37 2022 +0800

    [performance] (vectorization)optimize In Expr (#11826)
    
    
    
    Co-authored-by: Wang Bo <wangb...@meituan.com>
---
 be/src/exprs/aggregate_functions.cpp               |   2 +-
 be/src/exprs/create_predicate_function.h           |   2 +-
 be/src/exprs/hybrid_set.h                          |  94 +++++++++++++++++-
 be/src/vec/exec/join/vhash_join_node.cpp           |   9 +-
 be/src/vec/exec/join/vhash_join_node.h             |   2 +
 be/src/vec/functions/in.h                          |  85 +++++++++++++---
 .../data/query/sql_functions/test_in_expr.out      |  45 +++++++++
 .../suites/query/sql_functions/test_in_expr.groovy | 110 +++++++++++++++++++++
 8 files changed, 330 insertions(+), 19 deletions(-)

diff --git a/be/src/exprs/aggregate_functions.cpp 
b/be/src/exprs/aggregate_functions.cpp
index 992bd42b74..2ae13ffaef 100644
--- a/be/src/exprs/aggregate_functions.cpp
+++ b/be/src/exprs/aggregate_functions.cpp
@@ -1397,7 +1397,7 @@ public:
     static const int STRING_LENGTH_RECORD_LENGTH = 4;
 
 private:
-    StringValueSet _set;
+    StringSet _set;
     // _type is serialized into buffer by one byte
     FunctionContext::Type _type;
 };
diff --git a/be/src/exprs/create_predicate_function.h 
b/be/src/exprs/create_predicate_function.h
index 6795777cc6..5aa0e2347f 100644
--- a/be/src/exprs/create_predicate_function.h
+++ b/be/src/exprs/create_predicate_function.h
@@ -39,7 +39,7 @@ public:
     template <PrimitiveType type>
     static BasePtr get_function() {
         using CppType = typename PrimitiveTypeTraits<type>::CppType;
-        using Set = std::conditional_t<std::is_same_v<CppType, StringValue>, 
StringValueSet,
+        using Set = std::conditional_t<std::is_same_v<CppType, StringValue>, 
StringSet,
                                        HybridSet<type, is_vec>>;
         return new (std::nothrow) Set();
     };
diff --git a/be/src/exprs/hybrid_set.h b/be/src/exprs/hybrid_set.h
index e43b2c29ea..20c9721675 100644
--- a/be/src/exprs/hybrid_set.h
+++ b/be/src/exprs/hybrid_set.h
@@ -145,11 +145,11 @@ private:
     ObjectPool _pool;
 };
 
-class StringValueSet : public HybridSetBase {
+class StringSet : public HybridSetBase {
 public:
-    StringValueSet() = default;
+    StringSet() = default;
 
-    ~StringValueSet() override = default;
+    ~StringSet() override = default;
 
     Status to_vexpr_list(doris::ObjectPool* pool,
                          std::vector<doris::vectorized::VExpr*>* vexpr_list, 
int precision,
@@ -179,7 +179,7 @@ public:
     }
 
     void insert(HybridSetBase* set) override {
-        StringValueSet* string_set = reinterpret_cast<StringValueSet*>(set);
+        StringSet* string_set = reinterpret_cast<StringSet*>(set);
         _set.insert(string_set->_set.begin(), string_set->_set.end());
     }
 
@@ -228,4 +228,90 @@ private:
     ObjectPool _pool;
 };
 
+// note: Two differences from StringSet
+// 1 StringValue has better comparison performance than std::string
+// 2 std::string keeps its own memory, but StringValue just keeps ptr and len, 
so the caller must manage the memory of StringValue
+class StringValueSet : public HybridSetBase { // HybridSetBase implementation whose elements are StringValue entries held in a phmap::flat_hash_set
+public:
+    StringValueSet() = default;
+
+    ~StringValueSet() override = default;
+
+    Status to_vexpr_list(doris::ObjectPool* pool,
+                         std::vector<doris::vectorized::VExpr*>* vexpr_list, 
int precision,
+                         int scale) override { // appends one TYPE_STRING VLiteral per element to vexpr_list; precision and scale are not read in this body
+        HybridSetBase::IteratorBase* it = begin(); // begin() allocates the iterator with std::nothrow
+        DCHECK(it != nullptr);
+        while (it->has_next()) {
+            TExprNode node;
+            const void* v = it->get_value(); // points at a StringValue, see Iterator::get_value()
+            create_texpr_literal_node<TYPE_STRING>(v, &node);
+            vexpr_list->push_back(pool->add(new 
doris::vectorized::VLiteral(node)));
+            it->next();
+        }
+        return Status::OK(); // this body has no other return path
+    };
+
+    void insert(const void* data) override { // data is interpreted as a const StringValue*
+        if (data == nullptr) return; // a null pointer is ignored, nothing is inserted
+
+        const auto* value = reinterpret_cast<const StringValue*>(data);
+        StringValue sv(value->ptr, value->len); // rebuilt from ptr and len only; presumably the bytes are not copied, so they must outlive this set
+        _set.insert(sv);
+    }
+    void insert(void* data, size_t size) override { // data is interpreted as a char buffer of `size` bytes
+        StringValue sv(reinterpret_cast<char*>(data), size);
+        _set.insert(sv);
+    }
+
+    void insert(HybridSetBase* set) override { // inserts every element of another set into this one
+        StringValueSet* string_set = reinterpret_cast<StringValueSet*>(set); // NOTE(review): assumes `set` really is a StringValueSet; nothing here checks it
+        _set.insert(string_set->_set.begin(), string_set->_set.end());
+    }
+
+    int size() override { return _set.size(); } // the container's size is converted to int by the return type
+
+    bool find(void* data) override { // data is interpreted as a StringValue*; true when an equal element is stored
+        auto* value = reinterpret_cast<StringValue*>(data);
+        auto it = _set.find(*value);
+
+        return !(it == _set.end());
+    }
+
+    bool find(void* data, size_t size) override { // same lookup, but for a raw (bytes, length) pair
+        // Wrap the probe bytes in a StringValue (ptr + len) to use as the lookup key.
+        StringValue sv(reinterpret_cast<char*>(data), size);
+        auto it = _set.find(sv);
+        return !(it == _set.end());
+    }
+
+    class Iterator : public IteratorBase { // forward cursor over the [begin, end) range of the underlying flat_hash_set
+    public:
+        Iterator(phmap::flat_hash_set<StringValue>::iterator begin,
+                 phmap::flat_hash_set<StringValue>::iterator end)
+                : _begin(begin), _end(end) {}
+        ~Iterator() override = default;
+        virtual bool has_next() const override { return !(_begin == _end); } // true while the cursor has not reached _end
+        virtual const void* get_value() override { // returns the address of the member _value, which is overwritten on every call
+            _value.ptr = const_cast<char*>(_begin->ptr);
+            _value.len = _begin->len;
+            return &_value;
+        }
+        virtual void next() override { ++_begin; } // no bounds check here; callers are expected to test has_next() first
+
+    private:
+        typename phmap::flat_hash_set<StringValue>::iterator _begin; // current position
+        typename phmap::flat_hash_set<StringValue>::iterator _end; // one-past-the-last position
+        StringValue _value; // scratch copy of the current element handed out by get_value()
+    };
+
+    IteratorBase* begin() override { // creates a new Iterator over the whole set on every call and registers it with _pool
+        return _pool.add(new (std::nothrow) Iterator(_set.begin(), 
_set.end()));
+    }
+
+private:
+    phmap::flat_hash_set<StringValue> _set; // the stored StringValue elements
+    ObjectPool _pool; // receives the iterators created by begin(); presumably releases them with this set (TODO confirm ObjectPool semantics)
+};
+
 } // namespace doris
diff --git a/be/src/vec/exec/join/vhash_join_node.cpp 
b/be/src/vec/exec/join/vhash_join_node.cpp
index e7ae1c3a91..4fc36fd5ba 100644
--- a/be/src/vec/exec/join/vhash_join_node.cpp
+++ b/be/src/vec/exec/join/vhash_join_node.cpp
@@ -879,6 +879,8 @@ Status HashJoinNode::prepare(RuntimeState* state) {
     _build_side_output_timer = ADD_TIMER(probe_phase_profile, 
"ProbeWhenBuildSideOutputTime");
     _probe_side_output_timer = ADD_TIMER(probe_phase_profile, 
"ProbeWhenProbeSideOutputTime");
 
+    _join_filter_timer = ADD_TIMER(runtime_profile(), "JoinFilterTimer");
+
     _push_down_timer = ADD_TIMER(runtime_profile(), "PushDownTime");
     _push_compute_timer = ADD_TIMER(runtime_profile(), "PushDownComputeTime");
     _build_buckets_counter = ADD_COUNTER(runtime_profile(), "BuildBuckets", 
TUnit::UNIT);
@@ -1034,8 +1036,11 @@ Status HashJoinNode::get_next(RuntimeState* state, 
Block* output_block, bool* eo
     }
 
     _add_tuple_is_null_column(&temp_block);
-    RETURN_IF_ERROR(
-            VExprContext::filter_block(_vconjunct_ctx_ptr, &temp_block, 
temp_block.columns()));
+    {
+        SCOPED_TIMER(_join_filter_timer);
+        RETURN_IF_ERROR(
+                VExprContext::filter_block(_vconjunct_ctx_ptr, &temp_block, 
temp_block.columns()));
+    }
     RETURN_IF_ERROR(_build_output_block(&temp_block, output_block));
     _reset_tuple_is_null_column();
     reached_limit(output_block, eos);
diff --git a/be/src/vec/exec/join/vhash_join_node.h 
b/be/src/vec/exec/join/vhash_join_node.h
index 36dec27fbb..48cb54e67a 100644
--- a/be/src/vec/exec/join/vhash_join_node.h
+++ b/be/src/vec/exec/join/vhash_join_node.h
@@ -204,6 +204,8 @@ private:
     RuntimeProfile::Counter* _build_side_output_timer;
     RuntimeProfile::Counter* _probe_side_output_timer;
 
+    RuntimeProfile::Counter* _join_filter_timer;
+
     int64_t _hash_table_rows;
     int64_t _mem_used;
 
diff --git a/be/src/vec/functions/in.h b/be/src/vec/functions/in.h
index d10661fa48..bdf21cda96 100644
--- a/be/src/vec/functions/in.h
+++ b/be/src/vec/functions/in.h
@@ -67,8 +67,15 @@ public:
         }
         auto* state = new InState();
         context->set_function_state(scope, state);
-        state->hybrid_set.reset(
-                
vec_create_set(convert_type_to_primitive(context->get_arg_type(0)->type)));
+        if (context->get_arg_type(0)->type == FunctionContext::Type::TYPE_CHAR 
||
+            context->get_arg_type(0)->type == 
FunctionContext::Type::TYPE_VARCHAR ||
+            context->get_arg_type(0)->type == 
FunctionContext::Type::TYPE_STRING) {
+            // the StringValue's memory is held by FunctionContext, so we can 
use StringValueSet here directly
+            state->hybrid_set.reset(new StringValueSet());
+        } else {
+            state->hybrid_set.reset(
+                    
vec_create_set(convert_type_to_primitive(context->get_arg_type(0)->type)));
+        }
 
         DCHECK(context->get_num_args() >= 1);
         for (int i = 1; i < context->get_num_args(); ++i) {
@@ -109,18 +116,74 @@ public:
         auto materialized_column = 
left_arg.column->convert_to_full_column_if_const();
 
         if (in_state->use_set) {
-            for (size_t i = 0; i < input_rows_count; ++i) {
-                const auto& ref_data = materialized_column->get_data_at(i);
-                if (ref_data.data) {
-                    vec_res[i] = negative ^
-                                 
in_state->hybrid_set->find((void*)ref_data.data, ref_data.size);
-                    if (in_state->null_in_set) {
+            if (materialized_column->is_nullable()) {
+                auto* null_col_ptr = 
vectorized::check_and_get_column<vectorized::ColumnNullable>(
+                        materialized_column);
+                auto& null_bitmap = reinterpret_cast<const 
vectorized::ColumnUInt8&>(
+                                            
null_col_ptr->get_null_map_column())
+                                            .get_data();
+                auto* nested_col_ptr = 
null_col_ptr->get_nested_column_ptr().get();
+                auto search_hash_set = [&](auto* col_ptr) {
+                    for (size_t i = 0; i < input_rows_count; ++i) {
+                        const auto& ref_data = col_ptr->get_data_at(i);
+                        vec_res[i] =
+                                !null_bitmap[i] &&
+                                
in_state->hybrid_set->find((void*)ref_data.data, ref_data.size);
+                        if constexpr (negative) {
+                            vec_res[i] = !vec_res[i];
+                        }
+                    }
+                };
+
+                if (nested_col_ptr->is_column_string()) {
+                    const auto* column_string_ptr =
+                            reinterpret_cast<const 
vectorized::ColumnString*>(nested_col_ptr);
+                    search_hash_set(column_string_ptr);
+                } else {
+                    // todo support other column type
+                    search_hash_set(nested_col_ptr);
+                }
+
+                if (!in_state->null_in_set) {
+                    for (size_t i = 0; i < input_rows_count; ++i) {
+                        vec_null_map_to[i] = null_bitmap[i];
+                    }
+                } else {
+                    for (size_t i = 0; i < input_rows_count; ++i) {
+                        vec_null_map_to[i] = null_bitmap[i] || (negative == 
vec_res[i]);
+                    }
+                }
+
+            } else { // non-nullable
+
+                auto search_hash_set = [&](auto* col_ptr) {
+                    for (size_t i = 0; i < input_rows_count; ++i) {
+                        const auto& ref_data = col_ptr->get_data_at(i);
+                        vec_res[i] =
+                                
in_state->hybrid_set->find((void*)ref_data.data, ref_data.size);
+                        if constexpr (negative) {
+                            vec_res[i] = !vec_res[i];
+                        }
+                    }
+                };
+
+                if (materialized_column->is_column_string()) {
+                    const auto* column_string_ptr =
+                            reinterpret_cast<const vectorized::ColumnString*>(
+                                    materialized_column.get());
+                    search_hash_set(column_string_ptr);
+                } else {
+                    search_hash_set(materialized_column.get());
+                }
+
+                if (in_state->null_in_set) {
+                    for (size_t i = 0; i < input_rows_count; ++i) {
                         vec_null_map_to[i] = negative == vec_res[i];
-                    } else {
-                        vec_null_map_to[i] = false;
                     }
                 } else {
-                    vec_null_map_to[i] = true;
+                    for (size_t i = 0; i < input_rows_count; ++i) {
+                        vec_null_map_to[i] = false;
+                    }
                 }
             }
         } else {
diff --git a/regression-test/data/query/sql_functions/test_in_expr.out 
b/regression-test/data/query/sql_functions/test_in_expr.out
new file mode 100644
index 0000000000..5006d062dc
--- /dev/null
+++ b/regression-test/data/query/sql_functions/test_in_expr.out
@@ -0,0 +1,45 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !select --
+4
+
+-- !select --
+4
+
+-- !select --
+c
+
+-- !select --
+4
+
+-- !select --
+4
+
+-- !select --
+c
+
+-- !select --
+\N
+1
+2
+3
+
+-- !select --
+
+-- !select --
+\N
+a
+b
+d
+
+-- !select --
+1
+2
+3
+
+-- !select --
+
+-- !select --
+a
+b
+d
+
diff --git a/regression-test/suites/query/sql_functions/test_in_expr.groovy 
b/regression-test/suites/query/sql_functions/test_in_expr.groovy
new file mode 100644
index 0000000000..6efc094c3b
--- /dev/null
+++ b/regression-test/suites/query/sql_functions/test_in_expr.groovy
@@ -0,0 +1,110 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_in_expr", "query") {
+    def nullTableName = "in_expr_test_null"
+    def notNullTableName = "in_expr_test_not_null"
+
+    sql """DROP TABLE IF EXISTS ${nullTableName}"""
+    sql """
+            CREATE TABLE ${nullTableName} (
+              `cid` int(11) NULL,
+              `number` int(11) NULL,
+              `addr` varchar(256) NULL
+            ) ENGINE=OLAP
+            DUPLICATE KEY(`cid`)
+            COMMENT 'OLAP'
+            DISTRIBUTED BY HASH(`cid`) BUCKETS 1
+            PROPERTIES (
+            "replication_allocation" = "tag.location.default: 1",
+            "in_memory" = "false",
+            "storage_format" = "V2"
+            )
+        """
+    sql """ insert into ${nullTableName} 
values(100,1,'a'),(101,2,'b'),(102,3,'c'),(103,4,'d'),(104,null,'e'),(105,6, 
null) """
+
+
+    sql """DROP TABLE IF EXISTS ${notNullTableName}"""
+    sql """
+            CREATE TABLE ${notNullTableName} (
+              `cid` int(11) not NULL,
+              `number` int(11) not NULL,
+              `addr` varchar(256) not NULL
+            ) ENGINE=OLAP
+            DUPLICATE KEY(`cid`)
+            COMMENT 'OLAP'
+            DISTRIBUTED BY HASH(`cid`) BUCKETS 1
+            PROPERTIES (
+            "replication_allocation" = "tag.location.default: 1",
+            "in_memory" = "false",
+            "storage_format" = "V2"
+            )
+        """
+
+    sql """ insert into ${notNullTableName} 
values(100,1,'a'),(101,2,'b'),(102,3,'c'),(103,4,'d') """
+
+    sql """ set enable_vectorized_engine = true """
+
+    // 1 in expr
+    // 1.1 nullable
+    // 1.1.1 string + set_not_null
+    qt_select "select t1.number from ${nullTableName} t1 left join 
${nullTableName} t2 on t1.cid=t2.cid where t2.addr in ('d')"
+
+    // 1.1.2 string + null_in_set
+    qt_select "select t1.number from ${nullTableName} t1 left join 
${nullTableName} t2 on t1.cid=t2.cid where t2.addr in ('d', null)"
+
+    // 1.1.3 non-string
+    qt_select "select t1.addr from ${nullTableName} t1 left join 
${nullTableName} t2 on t1.cid=t2.cid where t2.number in (3)"
+
+    // 1.2 not null
+    // 1.2.1 string + set_not_null
+    qt_select "select t1.number from ${notNullTableName} t1 left join 
${notNullTableName} t2 on t1.cid=t2.cid where t2.addr in ('d')"
+
+    // 1.2.2 string + null_in_set
+    qt_select "select t1.number from ${notNullTableName} t1 left join 
${notNullTableName} t2 on t1.cid=t2.cid where t2.addr in ('d', null)"
+
+    // 1.2.3 non-string
+    qt_select "select t1.addr from ${notNullTableName} t1 left join 
${notNullTableName} t2 on t1.cid=t2.cid where t2.number in (3)"
+
+
+
+
+    // 2 not in expr
+    // 2.1 nullable
+    // 2.1.1 string + set_not_null
+    qt_select "select t1.number from ${nullTableName} t1 left join 
${nullTableName} t2 on t1.cid=t2.cid where t2.addr not in ('d') order by 
t1.number"
+
+    // 2.1.2 string + null_in_set
+    qt_select "select t1.number from ${nullTableName} t1 left join 
${nullTableName} t2 on t1.cid=t2.cid where t2.addr not in ('d', null) "
+
+    // 2.1.3 non-string
+    qt_select "select t1.addr from ${nullTableName} t1 left join 
${nullTableName} t2 on t1.cid=t2.cid where t2.number not in (3) order by 
t1.addr "
+
+    // 2.2 not null
+    // 2.2.1 string + set_not_null
+    qt_select "select t1.number from ${notNullTableName} t1 left join 
${notNullTableName} t2 on t1.cid=t2.cid where t2.addr not in ('d') order by 
t1.number "
+
+    // 2.2.2 string + null_in_set
+    qt_select "select t1.number from ${notNullTableName} t1 left join 
${notNullTableName} t2 on t1.cid=t2.cid where t2.addr not in ('d', null)"
+
+    // 2.2.3 non-string
+    qt_select "select t1.addr from ${notNullTableName} t1 left join 
${notNullTableName} t2 on t1.cid=t2.cid where t2.number not in (3) order by 
t1.addr "
+
+    sql """DROP TABLE IF EXISTS ${nullTableName}"""
+    sql """DROP TABLE IF EXISTS ${notNullTableName}"""
+
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to