yiguolei commented on code in PR #10380:
URL: https://github.com/apache/doris/pull/10380#discussion_r906689417


##########
be/src/olap/comparison_predicate.h:
##########
@@ -17,46 +17,475 @@
 
 #pragma once
 
-#include <stdint.h>
-
 #include "olap/column_predicate.h"
+#include "runtime/vectorized_row_batch.h"
+#include "vec/columns/column_dictionary.h"
 
 namespace doris {
 
-class VectorizedRowBatch;
-
-#define COMPARISON_PRED_CLASS_DEFINE(CLASS, PT)                                
                    \
-    template <class T>                                                         
                    \
-    class CLASS : public ColumnPredicate {                                     
                    \
-    public:                                                                    
                    \
-        CLASS(uint32_t column_id, const T& value, bool opposite = false);      
                    \
-        PredicateType type() const override { return PredicateType::PT; }      
                    \
-        virtual void evaluate(VectorizedRowBatch* batch) const override;       
                    \
-        void evaluate(ColumnBlock* block, uint16_t* sel, uint16_t* size) const 
override;           \
-        void evaluate_or(ColumnBlock* block, uint16_t* sel, uint16_t size,     
                    \
-                         bool* flags) const override;                          
                    \
-        void evaluate_and(ColumnBlock* block, uint16_t* sel, uint16_t size,    
                    \
-                          bool* flags) const override;                         
                    \
-        virtual Status evaluate(const Schema& schema,                          
                    \
-                                const std::vector<BitmapIndexIterator*>& 
iterators,                \
-                                uint32_t num_rows, roaring::Roaring* roaring) 
const override;      \
-        uint16_t evaluate(vectorized::IColumn& column, uint16_t* sel,          
                    \
-                          uint16_t size) const override;                       
                    \
-        void evaluate_and(vectorized::IColumn& column, uint16_t* sel, uint16_t 
size,               \
-                          bool* flags) const override;                         
                    \
-        void evaluate_or(vectorized::IColumn& column, uint16_t* sel, uint16_t 
size,                \
-                         bool* flags) const override;                          
                    \
-        void evaluate_vec(vectorized::IColumn& column, uint16_t size, bool* 
flags) const override; \
-                                                                               
                    \
-    private:                                                                   
                    \
-        T _value;                                                              
                    \
-    };
-
-COMPARISON_PRED_CLASS_DEFINE(EqualPredicate, EQ)
-COMPARISON_PRED_CLASS_DEFINE(NotEqualPredicate, NE)
-COMPARISON_PRED_CLASS_DEFINE(LessPredicate, LT)
-COMPARISON_PRED_CLASS_DEFINE(LessEqualPredicate, LE)
-COMPARISON_PRED_CLASS_DEFINE(GreaterPredicate, GT)
-COMPARISON_PRED_CLASS_DEFINE(GreaterEqualPredicate, GE)
+template <class T, PredicateType PT>
+class ComparisonPredicateBase : public ColumnPredicate {
+public:
+    ComparisonPredicateBase(uint32_t column_id, const T& value, bool opposite 
= false)
+            : ColumnPredicate(column_id, opposite), _value(value) {}
+
+    PredicateType type() const override { return PT; }
+
+    void evaluate(VectorizedRowBatch* batch) const override {
+        uint16_t n = batch->size();
+        if (n == 0) {
+            return;
+        }
+        uint16_t* sel = batch->selected();
+        const T* col_vector = reinterpret_cast<const 
T*>(batch->column(_column_id)->col_data());
+        uint16_t new_size = 0;
+        if (batch->column(_column_id)->no_nulls()) {
+            if (batch->selected_in_use()) {
+                for (uint16_t j = 0; j != n; ++j) {
+                    uint16_t i = sel[j];
+                    sel[new_size] = i;
+                    new_size += _operator(col_vector[i], _value);
+                }
+                batch->set_size(new_size);
+            } else {
+                for (uint16_t i = 0; i != n; ++i) {
+                    sel[new_size] = i;
+                    new_size += _operator(col_vector[i], _value);
+                }
+                if (new_size < n) {
+                    batch->set_size(new_size);
+                    batch->set_selected_in_use(true);
+                }
+            }
+        } else {
+            bool* is_null = batch->column(_column_id)->is_null();
+            if (batch->selected_in_use()) {
+                for (uint16_t j = 0; j != n; ++j) {
+                    uint16_t i = sel[j];
+                    sel[new_size] = i;
+                    new_size += (!is_null[i] && _operator(col_vector[i], 
_value));
+                }
+                batch->set_size(new_size);
+            } else {
+                for (uint16_t i = 0; i != n; ++i) {
+                    sel[new_size] = i;
+                    new_size += (!is_null[i] && _operator(col_vector[i], 
_value));
+                }
+                if (new_size < n) {
+                    batch->set_size(new_size);
+                    batch->set_selected_in_use(true);
+                }
+            }
+        }
+    }
+
+    void evaluate(ColumnBlock* block, uint16_t* sel, uint16_t* size) const 
override {
+        uint16_t new_size = 0;
+        if (block->is_nullable()) {
+            for (uint16_t i = 0; i < *size; ++i) {
+                uint16_t idx = sel[i];
+                sel[new_size] = idx;
+                const T* cell_value = reinterpret_cast<const 
T*>(block->cell(idx).cell_ptr());
+                auto result = (!block->cell(idx).is_null() && 
_operator(*cell_value, _value));
+                new_size += _opposite ? !result : result;
+            }
+        } else {
+            for (uint16_t i = 0; i < *size; ++i) {
+                uint16_t idx = sel[i];
+                sel[new_size] = idx;
+                const T* cell_value = reinterpret_cast<const 
T*>(block->cell(idx).cell_ptr());
+                auto result = _operator(*cell_value, _value);
+                new_size += _opposite ? !result : result;
+            }
+        }
+        *size = new_size;
+    }
+
+    void evaluate_or(ColumnBlock* block, uint16_t* sel, uint16_t size, bool* 
flags) const override {
+        if (block->is_nullable()) {
+            for (uint16_t i = 0; i < size; ++i) {
+                if (flags[i]) {
+                    continue;
+                }
+                uint16_t idx = sel[i];
+                const T* cell_value = reinterpret_cast<const 
T*>(block->cell(idx).cell_ptr());
+                auto result = (!block->cell(idx).is_null() && 
_operator(*cell_value, _value));
+                flags[i] = flags[i] | (_opposite ? !result : result);
+            }
+        } else {
+            for (uint16_t i = 0; i < size; ++i) {
+                if (flags[i]) {
+                    continue;
+                }
+                uint16_t idx = sel[i];
+                const T* cell_value = reinterpret_cast<const 
T*>(block->cell(idx).cell_ptr());
+                auto result = _operator(*cell_value, _value);
+                flags[i] = flags[i] | (_opposite ? !result : result);
+            }
+        }
+    }
+
+    void evaluate_and(ColumnBlock* block, uint16_t* sel, uint16_t size,
+                      bool* flags) const override {
+        if (block->is_nullable()) {
+            for (uint16_t i = 0; i < size; ++i) {
+                if (!flags[i]) {
+                    continue;
+                }
+                uint16_t idx = sel[i];
+                const T* cell_value = reinterpret_cast<const 
T*>(block->cell(idx).cell_ptr());
+                auto result = (!block->cell(idx).is_null() && 
_operator(*cell_value, _value));
+                flags[i] = flags[i] & (_opposite ? !result : result);
+            }
+        } else {
+            for (uint16_t i = 0; i < size; ++i) {
+                if (flags[i]) {
+                    continue;
+                }
+                uint16_t idx = sel[i];
+                const T* cell_value = reinterpret_cast<const 
T*>(block->cell(idx).cell_ptr());
+                auto result = _operator(*cell_value, _value);
+                flags[i] = flags[i] & (_opposite ? !result : result);
+            }
+        }
+    }
+
+    Status evaluate(const Schema& schema, const 
std::vector<BitmapIndexIterator*>& iterators,
+                    uint32_t num_rows, roaring::Roaring* bitmap) const 
override {
+        BitmapIndexIterator* iterator = iterators[_column_id];
+        if (iterator == nullptr) {
+            return Status::OK();
+        }
+
+        rowid_t ordinal_limit = iterator->bitmap_nums();
+        if (iterator->has_null_bitmap()) {
+            ordinal_limit--;
+            roaring::Roaring null_bitmap;
+            RETURN_IF_ERROR(iterator->read_null_bitmap(&null_bitmap));
+            *bitmap -= null_bitmap;
+        }
+
+        roaring::Roaring roaring;
+        bool exact_match;
+        Status status = iterator->seek_dictionary(&_value, &exact_match);
+        rowid_t seeked_ordinal = iterator->current_ordinal();
+
+        return _bitmap_compare(status, exact_match, ordinal_limit, 
seeked_ordinal, iterator,
+                               bitmap);
+    }
+
+    uint16_t evaluate(const vectorized::IColumn& column, uint16_t* sel,
+                      uint16_t size) const override {
+        if (column.is_nullable()) {
+            auto* nullable_column_ptr =
+                    
vectorized::check_and_get_column<vectorized::ColumnNullable>(column);
+            auto& nested_column = nullable_column_ptr->get_nested_column();
+            auto& null_map = reinterpret_cast<const vectorized::ColumnUInt8&>(
+                                     
nullable_column_ptr->get_null_map_column())
+                                     .get_data();
+
+            return _base_evaluate<true>(&nested_column, &null_map, sel, size);
+        } else {
+            return _base_evaluate<false>(&column, nullptr, sel, size);
+        }
+    }
+
+    void evaluate_and(const vectorized::IColumn& column, const uint16_t* sel, 
uint16_t size,
+                      bool* flags) const override {
+        _evaluate<true>(column, sel, size, flags);
+    }
+
+    void evaluate_or(const vectorized::IColumn& column, const uint16_t* sel, 
uint16_t size,
+                     bool* flags) const override {
+        _evaluate<false>(column, sel, size, flags);
+    }
+
+    void evaluate_vec(const vectorized::IColumn& column, uint16_t size,
+                      bool* flags) const override {
+        using TReal = std::conditional_t<std::is_same_v<T, uint24_t>, 
uint32_t, T>;
+
+        TReal value_real;
+        if constexpr (std::is_same_v<T, uint24_t>) {
+            value_real = 0;
+            memory_copy(&value_real, _value.get_data(), sizeof(T));
+        } else {
+            value_real = _value;
+        }
+
+        if (column.is_nullable()) {
+            auto* nullable_column_ptr =
+                    
vectorized::check_and_get_column<vectorized::ColumnNullable>(column);
+            auto& nested_column = nullable_column_ptr->get_nested_column();
+            auto& null_map = reinterpret_cast<const vectorized::ColumnUInt8&>(
+                                     
nullable_column_ptr->get_null_map_column())
+                                     .get_data();
+
+            if (nested_column.is_column_dictionary()) {
+                if constexpr (std::is_same_v<T, StringValue>) {
+                    auto* dict_column_ptr =
+                            
vectorized::check_and_get_column<vectorized::ColumnDictI32>(
+                                    nested_column);
+                    auto dict_code = _is_range() ? 
dict_column_ptr->find_code_by_bound(
+                                                           _value, 
_is_greater(), _is_eq())
+                                                 : 
dict_column_ptr->find_code(_value);
+                    auto& data_array = dict_column_ptr->get_data();
+
+                    _base_loop<true>(size, flags, &null_map, data_array, 
dict_code);
+                } else {
+                    LOG(FATAL) << "column_dictionary must use StringValue 
predicate.";
+                }
+            } else {
+                auto& data_array = reinterpret_cast<const 
vectorized::PredicateColumnType<TReal>&>(
+                                           nested_column)
+                                           .get_data();
+
+                _base_loop<true>(size, flags, &null_map, data_array, 
value_real);
+                for (uint16_t i = 0; i < size; i++) {
+                }
+            }
+        } else {
+            if (column.is_column_dictionary()) {
+                if constexpr (std::is_same_v<T, StringValue>) {
+                    auto* dict_column_ptr =
+                            
vectorized::check_and_get_column<vectorized::ColumnDictI32>(column);
+                    auto dict_code = _is_range() ? 
dict_column_ptr->find_code_by_bound(
+                                                           _value, 
_is_greater(), _is_eq())
+                                                 : 
dict_column_ptr->find_code(_value);
+                    auto& data_array = dict_column_ptr->get_data();
+
+                    _base_loop<false>(size, flags, nullptr, data_array, 
dict_code);
+                } else {
+                    LOG(FATAL) << "column_dictionary must use StringValue 
predicate.";
+                }
+            } else {
+                auto& data_array =
+                        
vectorized::check_and_get_column<vectorized::PredicateColumnType<TReal>>(
+                                column)
+                                ->get_data();
+
+                _base_loop<false>(size, flags, nullptr, data_array, 
value_real);
+            }
+        }
+
+        if (_opposite) {
+            for (uint16_t i = 0; i < size; i++) {
+                flags[i] = !flags[i];
+            }
+        }
+    }
+
+private:
+    template <typename LeftT, typename RightT>
+    bool _operator(const LeftT& lhs, const RightT& rhs) const {
+        if constexpr (PT == PredicateType::EQ) {
+            return lhs == rhs;
+        } else if constexpr (PT == PredicateType::NE) {
+            return lhs != rhs;
+        } else if constexpr (PT == PredicateType::LT) {
+            return lhs < rhs;
+        } else if constexpr (PT == PredicateType::LE) {
+            return lhs <= rhs;
+        } else if constexpr (PT == PredicateType::GT) {
+            return lhs > rhs;
+        } else if constexpr (PT == PredicateType::GE) {
+            return lhs >= rhs;
+        }
+    }
+
+    constexpr bool _is_range() const {

Review Comment:
   This check is not very accurate: PredicateType can be any of the following values:
       EQ = 1,
       NE = 2,
       LT = 3,
       LE = 4,
       GT = 5,
       GE = 6,
       IN_LIST = 7,
       NOT_IN_LIST = 8,
       IS_NULL = 9,
       IS_NOT_NULL = 10,
       BF = 11, // BloomFilter
   
   It would be better to check explicitly for the range comparisons <, <=, >, >= (i.e. PredicateType::LT, LE, GT, GE).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to