This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch vector-index-dev in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/vector-index-dev by this push: new 1b671c9b06b [vector search] Step forward on stability and functionality (#51213) 1b671c9b06b is described below commit 1b671c9b06bf965fd01a152340202ffd85890975 Author: zhiqiang <hezhiqi...@selectdb.com> AuthorDate: Tue May 27 11:38:16 2025 +0800 [vector search] Step forward on stability and functionality (#51213) A huge step forward on stability and functionality. ### Functionality 1. Search parameters such as `ef_search` can be passed to the index as session variables. This behavior is the same as in pg-vector and the DuckDB vector search plug-in. 2. Correct processing of ORDER BY ... DESC. Fall back to brute-force search when necessary. 3. Support using inner product as the index metric and ordering by inner_product. 4. When the metric used in the SQL does not match the index metric, fall back to brute-force search. ### Stability 1. More unit tests. 2. Virtual column iterator. 3. According to a custom script, the results of range search, top-n search, and compound search are almost the same as those of native Faiss: the overlap rate of the results is above 90%. The remaining difference of about 10% is introduced by the batch-insert mode of native Faiss. 
--- be/src/olap/rowset/beta_rowset_reader.cpp | 7 + .../olap/rowset/segment_v2/ann_index_iterator.cpp | 2 +- be/src/olap/rowset/segment_v2/ann_index_iterator.h | 8 +- be/src/olap/rowset/segment_v2/ann_index_reader.cpp | 31 ++- be/src/olap/rowset/segment_v2/ann_index_reader.h | 9 +- be/src/olap/rowset/segment_v2/ann_index_writer.cpp | 32 +-- be/src/olap/rowset/segment_v2/ann_index_writer.h | 1 + be/src/olap/rowset/segment_v2/segment_iterator.cpp | 234 +++++++++++++++------ be/src/olap/rowset/segment_v2/segment_iterator.h | 9 +- .../rowset/segment_v2/virtual_column_iterator.cpp | 44 +++- .../rowset/segment_v2/virtual_column_iterator.h | 7 +- be/src/pipeline/exec/olap_scan_operator.cpp | 6 +- be/src/pipeline/exec/operator.cpp | 8 +- be/src/runtime/descriptors.cpp | 20 ++ be/src/runtime/runtime_state.h | 8 +- be/src/vec/core/block.cpp | 1 + be/src/vec/exec/scan/olap_scanner.cpp | 4 +- be/src/vec/exec/scan/olap_scanner.h | 3 + be/src/vec/exprs/ann_range_search_params.h | 21 +- be/src/vec/exprs/vann_topn_predicate.cpp | 17 +- be/src/vec/exprs/vann_topn_predicate.h | 11 +- be/src/vec/exprs/vectorized_fn_call.cpp | 42 +++- be/src/vec/exprs/vectorized_fn_call.h | 6 +- be/src/vec/exprs/vexpr.cpp | 4 +- be/src/vec/exprs/vexpr.h | 2 +- be/src/vec/exprs/vexpr_context.cpp | 4 +- be/src/vec/exprs/vexpr_context.h | 3 +- be/src/vec/exprs/virtual_slot_ref.cpp | 2 +- .../vec/functions/array/function_array_distance.h | 3 + be/src/vec/runtime/vector_search_user_params.cpp | 35 +++ be/src/vec/runtime/vector_search_user_params.h | 31 +++ be/src/vector/faiss_vector_index.cpp | 54 ++++- be/src/vector/faiss_vector_index.h | 24 ++- be/src/vector/vector_index.h | 19 +- .../olap/vector_search/ann_index_reader_test.cpp | 98 +++++++-- .../olap/vector_search/ann_range_search_test.cpp | 39 ++-- .../vector_search/ann_topn_descriptor_test.cpp | 8 +- .../olap/vector_search/faiss_vector_index_test.cpp | 26 +-- be/test/olap/vector_search/vector_search_utils.cpp | 11 + 
be/test/olap/vector_search/vector_search_utils.h | 37 +--- .../vector_search/virtual_column_iterator_test.cpp | 76 ++++++- .../PushDownVirtualColumnsIntoOlapScan.java | 19 +- .../trees/plans/commands/info/IndexDefinition.java | 21 +- .../java/org/apache/doris/qe/SessionVariable.java | 25 +++ gensrc/thrift/PaloInternalService.thrift | 4 + 45 files changed, 807 insertions(+), 269 deletions(-) diff --git a/be/src/olap/rowset/beta_rowset_reader.cpp b/be/src/olap/rowset/beta_rowset_reader.cpp index 66a44e7864e..e12c89d056f 100644 --- a/be/src/olap/rowset/beta_rowset_reader.cpp +++ b/be/src/olap/rowset/beta_rowset_reader.cpp @@ -146,6 +146,9 @@ Status BetaRowsetReader::get_segment_iterators(RowsetReaderContext* read_context _read_options.column_predicates.insert(_read_options.column_predicates.end(), _read_context->predicates->begin(), _read_context->predicates->end()); + LOG_INFO("Rowset reader, read options column predicates size: {}", + _read_options.column_predicates.size()); + for (auto pred : *(_read_context->predicates)) { if (_read_options.col_id_to_predicates.count(pred->column_id()) < 1) { _read_options.col_id_to_predicates.insert( @@ -185,6 +188,10 @@ Status BetaRowsetReader::get_segment_iterators(RowsetReaderContext* read_context _read_options.column_predicates.insert(_read_options.column_predicates.end(), _read_context->value_predicates->begin(), _read_context->value_predicates->end()); + LOG_INFO( + "Rowset reader, read options add value predicates, column predicates size now: " + "{}", + _read_options.column_predicates.size()); for (auto pred : *(_read_context->value_predicates)) { if (_read_options.col_id_to_predicates.count(pred->column_id()) < 1) { _read_options.col_id_to_predicates.insert( diff --git a/be/src/olap/rowset/segment_v2/ann_index_iterator.cpp b/be/src/olap/rowset/segment_v2/ann_index_iterator.cpp index 6a50032e2fb..3b37e3cabcb 100644 --- a/be/src/olap/rowset/segment_v2/ann_index_iterator.cpp +++ 
b/be/src/olap/rowset/segment_v2/ann_index_iterator.cpp @@ -37,7 +37,7 @@ Status AnnIndexIterator::read_from_index(const IndexParam& param) { } Status AnnIndexIterator::range_search(const RangeSearchParams& params, - const CustomSearchParams& custom_params, + const VectorSearchUserParams& custom_params, RangeSearchResult* result) { if (_ann_reader == nullptr) { return Status::Error<ErrorCode::INDEX_INVALID_PARAMETERS>("_ann_reader is null"); diff --git a/be/src/olap/rowset/segment_v2/ann_index_iterator.h b/be/src/olap/rowset/segment_v2/ann_index_iterator.h index 0972c69307e..82a4113cacb 100644 --- a/be/src/olap/rowset/segment_v2/ann_index_iterator.h +++ b/be/src/olap/rowset/segment_v2/ann_index_iterator.h @@ -23,6 +23,7 @@ #include "gutil/integral_types.h" #include "olap/rowset/segment_v2/ann_index_reader.h" #include "olap/rowset/segment_v2/index_iterator.h" +#include "runtime/runtime_state.h" namespace doris::segment_v2 { @@ -30,6 +31,7 @@ struct AnnIndexParam { const float* query_value; const size_t query_value_size; size_t limit; + doris::VectorSearchUserParams _user_params; roaring::Roaring* roaring; std::unique_ptr<std::vector<float>> distance = nullptr; std::unique_ptr<std::vector<uint64_t>> row_ids = nullptr; @@ -48,10 +50,6 @@ struct RangeSearchParams { virtual ~RangeSearchParams() = default; }; -struct CustomSearchParams { - int ef_search = 16; -}; - struct RangeSearchResult { std::shared_ptr<roaring::Roaring> roaring; std::unique_ptr<std::vector<uint64_t>> row_ids; @@ -80,7 +78,7 @@ public: bool has_null() override { return true; } MOCK_FUNCTION Status range_search(const RangeSearchParams& params, - const CustomSearchParams& custom_params, + const VectorSearchUserParams& custom_params, RangeSearchResult* result); private: diff --git a/be/src/olap/rowset/segment_v2/ann_index_reader.cpp b/be/src/olap/rowset/segment_v2/ann_index_reader.cpp index 0222597ea32..64637a72566 100644 --- a/be/src/olap/rowset/segment_v2/ann_index_reader.cpp +++ 
b/be/src/olap/rowset/segment_v2/ann_index_reader.cpp @@ -24,6 +24,7 @@ #include "common/config.h" #include "olap/rowset/segment_v2/index_file_reader.h" #include "olap/rowset/segment_v2/inverted_index_compound_reader.h" +#include "runtime/runtime_state.h" #include "vector/faiss_vector_index.h" #include "vector/vector_index.h" @@ -49,6 +50,9 @@ AnnIndexReader::AnnIndexReader(const TabletIndex* index_meta, auto it = index_properties.find("index_type"); DCHECK(it != index_properties.end()); _index_type = it->second; + it = index_properties.find("metric_type"); + DCHECK(it != index_properties.end()); + _metric_type = VectorIndex::string_to_metric(it->second); } Status AnnIndexReader::new_iterator(const io::IOContext& io_ctx, OlapReaderStatistics* stats, @@ -71,16 +75,27 @@ Status AnnIndexReader::load_index(io::IOContext* io_ctx) { } Status AnnIndexReader::query(io::IOContext* io_ctx, AnnIndexParam* param) { +#ifndef BE_TEST RETURN_IF_ERROR(_index_file_reader->init(config::inverted_index_read_buffer_size, io_ctx)); RETURN_IF_ERROR(load_index(io_ctx)); +#endif DCHECK(_vector_index != nullptr); const float* query_vec = param->query_value; const int limit = param->limit; - IndexSearchParameters index_search_params; IndexSearchResult index_search_result; - index_search_params.roaring = param->roaring; - RETURN_IF_ERROR(_vector_index->ann_topn_search(query_vec, limit, index_search_params, - index_search_result)); + if (_index_type == "hnsw") { + HNSWSearchParameters hnsw_search_params; + hnsw_search_params.roaring = param->roaring; + hnsw_search_params.ef_search = param->_user_params.hnsw_ef_search; + hnsw_search_params.check_relative_distance = + param->_user_params.hnsw_check_relative_distance; + hnsw_search_params.bounded_queue = param->_user_params.hnsw_bounded_queue; + RETURN_IF_ERROR(_vector_index->ann_topn_search(query_vec, limit, hnsw_search_params, + index_search_result)); + } else { + throw Status::NotSupported("Unsupported index type: {}", _index_type); + } + 
DCHECK(index_search_result.roaring != nullptr); DCHECK(index_search_result.distances != nullptr); DCHECK(index_search_result.row_ids != nullptr); @@ -92,17 +107,21 @@ Status AnnIndexReader::query(io::IOContext* io_ctx, AnnIndexParam* param) { } Status AnnIndexReader::range_search(const RangeSearchParams& params, - const CustomSearchParams& custom_params, + const VectorSearchUserParams& custom_params, RangeSearchResult* result, io::IOContext* io_ctx) { +#ifndef BE_TEST RETURN_IF_ERROR(_index_file_reader->init(config::inverted_index_read_buffer_size, io_ctx)); RETURN_IF_ERROR(load_index(io_ctx)); +#endif DCHECK(_vector_index != nullptr); IndexSearchResult search_result; std::unique_ptr<IndexSearchParameters> search_param = nullptr; if (_index_type == "hnsw") { auto hnsw_param = std::make_unique<HNSWSearchParameters>(); - hnsw_param->ef_search = custom_params.ef_search; + hnsw_param->ef_search = custom_params.hnsw_ef_search; + hnsw_param->check_relative_distance = custom_params.hnsw_check_relative_distance; + hnsw_param->bounded_queue = custom_params.hnsw_bounded_queue; search_param = std::move(hnsw_param); } else { throw Status::NotSupported("Unsupported index type: {}", _index_type); diff --git a/be/src/olap/rowset/segment_v2/ann_index_reader.h b/be/src/olap/rowset/segment_v2/ann_index_reader.h index 69ea9c91c0b..a12c0a508e4 100644 --- a/be/src/olap/rowset/segment_v2/ann_index_reader.h +++ b/be/src/olap/rowset/segment_v2/ann_index_reader.h @@ -19,13 +19,13 @@ #include "olap/rowset/segment_v2/index_reader.h" #include "olap/tablet_schema.h" +#include "runtime/runtime_state.h" #include "vector/vector_index.h" namespace doris::segment_v2 { struct AnnIndexParam; struct RangeSearchParams; -struct CustomSearchParams; struct RangeSearchResult; class IndexFileReader; @@ -44,14 +44,16 @@ public: Status query(io::IOContext* io_ctx, AnnIndexParam* param); - Status range_search(const RangeSearchParams& params, const CustomSearchParams& custom_params, - RangeSearchResult* result, 
io::IOContext* io_ctx = nullptr); + Status range_search(const RangeSearchParams& params, + const VectorSearchUserParams& custom_params, RangeSearchResult* result, + io::IOContext* io_ctx = nullptr); uint64_t get_index_id() const override { return _index_meta.index_id(); } Status new_iterator(const io::IOContext& io_ctx, OlapReaderStatistics* stats, RuntimeState* runtime_state, std::unique_ptr<IndexIterator>* iterator) override; + VectorIndex::Metric get_metric_type() const { return _metric_type; } private: TabletIndex _index_meta; @@ -59,6 +61,7 @@ private: std::unique_ptr<VectorIndex> _vector_index; // TODO: Use integer. std::string _index_type; + VectorIndex::Metric _metric_type; }; using AnnIndexReaderPtr = std::shared_ptr<AnnIndexReader>; diff --git a/be/src/olap/rowset/segment_v2/ann_index_writer.cpp b/be/src/olap/rowset/segment_v2/ann_index_writer.cpp index 452d9dfc9a2..0f8c6cd1d7c 100644 --- a/be/src/olap/rowset/segment_v2/ann_index_writer.cpp +++ b/be/src/olap/rowset/segment_v2/ann_index_writer.cpp @@ -19,6 +19,7 @@ #include <cstddef> #include <memory> +#include <string> #include "olap/rowset/segment_v2/inverted_index_fs_directory.h" @@ -57,21 +58,22 @@ Status AnnIndexColumnWriter::init() { _vector_index = nullptr; const auto& properties = _index_meta->properties(); - std::string index_type = get_or_default(properties, INDEX_TYPE, ""); - if (index_type == "hnsw") { - std::shared_ptr<FaissVectorIndex> faiss_index = std::make_shared<FaissVectorIndex>(); - FaissBuildParameter builderParameter; - builderParameter.index_type = FaissBuildParameter::string_to_index_type("hnsw"); - builderParameter.d = std::stoi(get_or_default(properties, DIM, "512")); - builderParameter.m = std::stoi(get_or_default(properties, MAX_DEGREE, "32")); - builderParameter.quantilizer = FaissBuildParameter::string_to_quantilizer( - get_or_default(properties, QUANTILIZER, "flat")); - faiss_index->set_build_params(builderParameter); - _vector_index = faiss_index; - } else { - return 
Status::NotSupported("Unsupported index type: " + index_type); - } - + const std::string index_type = get_or_default(properties, INDEX_TYPE, "hnsw"); + const std::string metric_type = get_or_default(properties, METRIC_TYPE, "l2"); + const std::string quantilizer = get_or_default(properties, QUANTILIZER, "flat"); + FaissBuildParameter builderParameter; + std::shared_ptr<FaissVectorIndex> faiss_index = std::make_shared<FaissVectorIndex>(); + builderParameter.index_type = FaissBuildParameter::string_to_index_type(index_type); + builderParameter.d = std::stoi(get_or_default(properties, DIM, "512")); + builderParameter.m = std::stoi(get_or_default(properties, MAX_DEGREE, "32")); + builderParameter.pq_m = std::stoi(get_or_default(properties, PQ_M, "-1")); // -1 means not set + + builderParameter.metric_type = FaissBuildParameter::string_to_metric_type(metric_type); + builderParameter.quantilizer = FaissBuildParameter::string_to_quantilizer(quantilizer); + + faiss_index->set_build_params(builderParameter); + + _vector_index = faiss_index; return Status::OK(); } diff --git a/be/src/olap/rowset/segment_v2/ann_index_writer.h b/be/src/olap/rowset/segment_v2/ann_index_writer.h index cb8f9316fc9..d674fb12648 100644 --- a/be/src/olap/rowset/segment_v2/ann_index_writer.h +++ b/be/src/olap/rowset/segment_v2/ann_index_writer.h @@ -50,6 +50,7 @@ public: static constexpr const char* INDEX_TYPE = "index_type"; static constexpr const char* METRIC_TYPE = "metric_type"; static constexpr const char* QUANTILIZER = "quantilizer"; + static constexpr const char* PQ_M = "pq_m"; static constexpr const char* DIM = "dim"; static constexpr const char* MAX_DEGREE = "max_degree"; diff --git a/be/src/olap/rowset/segment_v2/segment_iterator.cpp b/be/src/olap/rowset/segment_v2/segment_iterator.cpp index 00a15be3e23..3075af6122c 100644 --- a/be/src/olap/rowset/segment_v2/segment_iterator.cpp +++ b/be/src/olap/rowset/segment_v2/segment_iterator.cpp @@ -24,6 +24,7 @@ #include <algorithm> #include 
<boost/iterator/iterator_facade.hpp> +#include <cstddef> #include <cstdint> #include <iterator> #include <memory> @@ -98,6 +99,7 @@ #include "vec/exprs/vslot_ref.h" #include "vec/functions/array/function_array_index.h" #include "vec/json/path_in_data.h" +#include "vector/vector_index.h" namespace doris { using namespace ErrorCode; @@ -299,6 +301,7 @@ Status SegmentIterator::_init_impl(const StorageReadOptions& opts) { } _col_predicates.emplace_back(predicate); } + LOG_INFO("Segment iterator init, column predicates size: {}", _col_predicates.size()); _tablet_id = opts.tablet_id; // Read options will not change, so that just resize here _block_rowids.resize(_opts.block_row_max); @@ -632,6 +635,32 @@ Status SegmentIterator::_apply_ann_topn_predicate() { !_col_predicates.empty()); return Status::OK(); } + + // Process asc & desc according to the type of metric + auto index_reader = ann_index_iterator->get_reader(); + auto ann_index_reader = dynamic_cast<AnnIndexReader*>(index_reader.get()); + DCHECK(ann_index_reader != nullptr); + if (ann_index_reader->get_metric_type() == VectorIndex::Metric::INNER_PRODUCT) { + if (_ann_topn_descriptor->is_asc()) { + LOG_INFO("asc topn for inner product can not be evaluated by ann index"); + return Status::OK(); + } + } else { + if (!_ann_topn_descriptor->is_asc()) { + LOG_INFO("desc topn for l2/cosine can not be evaluated by ann index"); + return Status::OK(); + } + } + + if (ann_index_reader->get_metric_type() != _ann_topn_descriptor->get_metric_type()) { + LOG_INFO( + "Ann topn metric type {} not match index metric type {}, can not be evaluated by " + "ann index", + VectorIndex::metric_to_string(_ann_topn_descriptor->get_metric_type()), + VectorIndex::metric_to_string(ann_index_reader->get_metric_type())); + return Status::OK(); + } + size_t pre_size = _row_bitmap.cardinality(); size_t dst_col_idx = _ann_topn_descriptor->get_dest_column_idx(); vectorized::IColumn::MutablePtr result_column; @@ -647,6 +676,8 @@ Status 
SegmentIterator::_apply_ann_topn_predicate() { DCHECK(column_iter != nullptr); VirtualColumnIterator* virtual_column_iter = dynamic_cast<VirtualColumnIterator*>(column_iter); DCHECK(virtual_column_iter != nullptr); + LOG_INFO("Virtual column iterator, column_idx {}, is materialized with {} rows", dst_col_idx, + result_row_ids->size()); virtual_column_iter->prepare_materialization(std::move(result_column), std::move(result_row_ids)); return Status::OK(); @@ -936,6 +967,7 @@ Status SegmentIterator::_apply_index_expr() { ++it; } } + // TODO:remove expr root from _remaining_conjunct_roots return Status::OK(); } @@ -1040,6 +1072,10 @@ bool SegmentIterator::_need_read_data(ColumnId cid) { _opts.enable_unique_key_merge_on_write)))) { return true; } + if (this->_vir_cid_to_idx_in_block.contains(cid)) { + return true; + } + // if there is delete predicate, we always need to read data if (_has_delete_predicate(cid)) { return true; @@ -1450,7 +1486,7 @@ Status SegmentIterator::_vec_init_lazy_materialization() { _is_pred_column.resize(_schema->columns().size(), false); // including short/vec/delete pred - std::set<ColumnId> pred_column_ids; + std::set<ColumnId> cols_read_by_column_predicate; _lazy_materialization_read = false; std::set<ColumnId> del_cond_id_set; @@ -1490,7 +1526,7 @@ Status SegmentIterator::_vec_init_lazy_materialization() { for (auto* predicate : _col_predicates) { auto cid = predicate->column_id(); _is_pred_column[cid] = true; - pred_column_ids.insert(cid); + cols_read_by_column_predicate.insert(cid); // check pred using short eval or vec eval if (_can_evaluated_by_vectorized(predicate)) { @@ -1508,7 +1544,7 @@ Status SegmentIterator::_vec_init_lazy_materialization() { // handle delete_condition if (!del_cond_id_set.empty()) { short_cir_pred_col_id_set.insert(del_cond_id_set.begin(), del_cond_id_set.end()); - pred_column_ids.insert(del_cond_id_set.begin(), del_cond_id_set.end()); + cols_read_by_column_predicate.insert(del_cond_id_set.begin(), 
del_cond_id_set.end()); for (auto cid : del_cond_id_set) { _is_pred_column[cid] = true; @@ -1566,7 +1602,7 @@ Status SegmentIterator::_vec_init_lazy_materialization() { // all columns are lazy materialization columns without non predicte column. // If common expr pushdown exists, and expr column is not contained in lazy materialization columns, // add to second read column, which will be read after lazy materialization - if (_schema->column_ids().size() > pred_column_ids.size()) { + if (_schema->column_ids().size() > cols_read_by_column_predicate.size()) { // pred_column_ids maybe empty, so that could not set _lazy_materialization_read = true here // has to check there is at least one predicate column for (auto cid : _schema->column_ids()) { @@ -1574,10 +1610,10 @@ Status SegmentIterator::_vec_init_lazy_materialization() { if (_is_need_vec_eval || _is_need_short_eval) { _lazy_materialization_read = true; } - if (!_is_common_expr_column[cid]) { - _non_predicate_columns.push_back(cid); + if (_is_common_expr_column[cid]) { + _cols_read_by_common_expr.push_back(cid); } else { - _non_predicate_column_ids.push_back(cid); + _cols_not_included_by_any_predicates.push_back(cid); } } } @@ -1586,16 +1622,20 @@ Status SegmentIterator::_vec_init_lazy_materialization() { // Step 4: fill first read columns if (_lazy_materialization_read) { // insert pred cid to first_read_columns - for (auto cid : pred_column_ids) { - _predicate_column_ids.push_back(cid); + for (auto cid : cols_read_by_column_predicate) { + _cols_read_by_column_predicate.push_back(cid); } - } else if (!_is_need_vec_eval && !_is_need_short_eval && - !_is_need_expr_eval) { // no pred exists, just read and output column + } else if (!_is_need_vec_eval && !_is_need_short_eval && !_is_need_expr_eval) { + // no pred exists, just read and output column + // 这代码也很迷惑啊,既然没有任何谓词列,那就不要改变流程啊,就按照正常的输出 non-predicates-columns 就好了啊 + // 为什么要强行把所有的列当作 predicate 列去处理呢 for (int i = 0; i < _schema->num_column_ids(); i++) { auto cid = 
_schema->column_id(i); - _predicate_column_ids.push_back(cid); + _cols_read_by_column_predicate.push_back(cid); } } else { + // 不延迟物化,但是有谓词 + // 说明除了 column_predicates 的列之外,还有其他列需要读 if (_is_need_vec_eval || _is_need_short_eval) { // TODO To refactor, because we suppose lazy materialization is better performance. // pred exits, but we can eliminate lazy materialization @@ -1605,12 +1645,12 @@ Status SegmentIterator::_vec_init_lazy_materialization() { _short_cir_pred_column_ids.end()); pred_id_set.insert(_vec_pred_column_ids.begin(), _vec_pred_column_ids.end()); - DCHECK(_non_predicate_column_ids.empty()); + DCHECK(_cols_read_by_common_expr.empty()); // _non_predicate_column_ids must be empty. Otherwise _lazy_materialization_read must not false. for (int i = 0; i < _schema->num_column_ids(); i++) { auto cid = _schema->column_id(i); if (pred_id_set.find(cid) != pred_id_set.end()) { - _predicate_column_ids.push_back(cid); + _cols_read_by_column_predicate.push_back(cid); } // In the past, if schema columns > pred columns, the _lazy_materialization_read maybe == false, but // we make sure using _lazy_materialization_read= true now, so these logic may never happens. I comment @@ -1624,21 +1664,22 @@ Status SegmentIterator::_vec_init_lazy_materialization() { } else if (_is_need_expr_eval) { DCHECK(!_is_need_vec_eval && !_is_need_short_eval); for (auto cid : _common_expr_columns) { - _predicate_column_ids.push_back(cid); + // 这代码太 track 了,很迷糊啊,完全概念混到一起了 + _cols_read_by_column_predicate.push_back(cid); } } } LOG_INFO( - "Laze materialization end. " + "Laze materialization init end. 
" "lazy_materialization_read: {}, " - "predicate_column_ids: [{}], " - "non_predicate_columns: [{}], " - "non_predicate_column_ids: [{}], " + "_cols_read_by_column_predicate: [{}], " + "_cols_not_included_by_any_predicates: [{}], " + "_cols_read_by_common_expr: [{}], " "columns_to_filter: [{}]", - _lazy_materialization_read, fmt::join(_predicate_column_ids, ","), - fmt::join(_non_predicate_columns, ","), fmt::join(_non_predicate_column_ids, ","), - fmt::join(_columns_to_filter, ",")); + _lazy_materialization_read, fmt::join(_cols_read_by_column_predicate, ","), + fmt::join(_cols_not_included_by_any_predicates, ","), + fmt::join(_cols_read_by_common_expr, ","), fmt::join(_columns_to_filter, ",")); return Status::OK(); } @@ -1806,7 +1847,7 @@ Status SegmentIterator::_init_return_columns(vectorized::Block* block, uint32_t void SegmentIterator::_output_non_pred_columns(vectorized::Block* block) { SCOPED_RAW_TIMER(&_opts.stats->output_col_ns); - for (auto cid : _non_predicate_columns) { + for (auto cid : _cols_not_included_by_any_predicates) { auto loc = _schema_block_id_map[cid]; // if loc > block->columns() means the column is delete column and should // not output by block, so just skip the column. 
@@ -1841,31 +1882,46 @@ Status SegmentIterator::_read_columns_by_index(uint32_t nrows_read_limit, uint32 SCOPED_RAW_TIMER(&_opts.stats->predicate_column_read_ns); nrows_read = _range_iter->read_batch_rowids(_block_rowids.data(), nrows_read_limit); + LOG_INFO("nrows_read from range iterator: {}", nrows_read); bool is_continuous = (nrows_read > 1) && (_block_rowids[nrows_read - 1] - _block_rowids[0] == nrows_read - 1); + std::vector<ColumnId> predicate_column_ids_and_virtual_columns; + predicate_column_ids_and_virtual_columns.reserve(_cols_read_by_column_predicate.size() + + _virtual_column_exprs.size()); + predicate_column_ids_and_virtual_columns.insert(predicate_column_ids_and_virtual_columns.end(), + _cols_read_by_column_predicate.begin(), + _cols_read_by_column_predicate.end()); - for (auto cid : _predicate_column_ids) { - auto& column = _current_return_columns[cid]; - if (_no_need_read_key_data(cid, column, nrows_read)) { - continue; - } - if (_prune_column(cid, column, true, nrows_read)) { - continue; - } + for (const auto& entry : _virtual_column_exprs) { + // virtual column id is not in _predicate_column_ids + predicate_column_ids_and_virtual_columns.push_back(entry.first); + } - DBUG_EXECUTE_IF("segment_iterator._read_columns_by_index", { - auto col_name = _opts.tablet_schema->column(cid).name(); - auto debug_col_name = DebugPoints::instance()->get_debug_param_or_default<std::string>( - "segment_iterator._read_columns_by_index", "column_name", ""); - if (debug_col_name.empty() && col_name != "__DORIS_DELETE_SIGN__") { - return Status::Error<ErrorCode::INTERNAL_ERROR>("does not need to read data, {}", - col_name); + for (auto cid : predicate_column_ids_and_virtual_columns) { + auto& column = _current_return_columns[cid]; + if (!_virtual_column_exprs.contains(cid)) { + if (_no_need_read_key_data(cid, column, nrows_read)) { + continue; } - if (debug_col_name.find(col_name) != std::string::npos) { - return Status::Error<ErrorCode::INTERNAL_ERROR>("does not need 
to read data, {}", - col_name); + if (_prune_column(cid, column, true, nrows_read)) { + continue; } - }) + + DBUG_EXECUTE_IF("segment_iterator._read_columns_by_index", { + auto col_name = _opts.tablet_schema->column(cid).name(); + auto debug_col_name = + DebugPoints::instance()->get_debug_param_or_default<std::string>( + "segment_iterator._read_columns_by_index", "column_name", ""); + if (debug_col_name.empty() && col_name != "__DORIS_DELETE_SIGN__") { + return Status::Error<ErrorCode::INTERNAL_ERROR>( + "does not need to read data, {}", col_name); + } + if (debug_col_name.find(col_name) != std::string::npos) { + return Status::Error<ErrorCode::INTERNAL_ERROR>( + "does not need to read data, {}", col_name); + } + }) + } if (is_continuous) { size_t rows_read = nrows_read; @@ -2321,8 +2377,26 @@ Status SegmentIterator::_next_batch_internal(vectorized::Block* block) { _current_batch_rows_read = 0; RETURN_IF_ERROR(_read_columns_by_index(nrows_read_limit, _current_batch_rows_read)); - if (std::find(_predicate_column_ids.begin(), _predicate_column_ids.end(), - _schema->version_col_idx()) != _predicate_column_ids.end()) { + + // 把从索引物化得到的虚拟列放到 block 中 + for (const auto pair : _vir_cid_to_idx_in_block) { + ColumnId cid = pair.first; + size_t position = pair.second; + block->replace_by_position(position, std::move(_current_return_columns[cid])); + bool is_nothing = false; + if (vectorized::check_and_get_column<vectorized::ColumnNothing>( + block->get_by_position(position).column.get())) { + is_nothing = true; + } + + LOG_INFO( + "SegmentIterator next block replace virtual column, cid {}, position {}, still " + "nothing {}", + cid, position, is_nothing); + } + + if (std::find(_cols_read_by_column_predicate.begin(), _cols_read_by_column_predicate.end(), + _schema->version_col_idx()) != _cols_read_by_column_predicate.end()) { _replace_version_col(_current_batch_rows_read); } @@ -2333,18 +2407,19 @@ Status SegmentIterator::_next_batch_internal(vectorized::Block* block) { // 
Convert all columns in _current_return_columns to schema column RETURN_IF_ERROR(_convert_to_expected_type(_schema->column_ids())); for (int i = 0; i < block->columns() - _vir_cid_to_idx_in_block.size(); i++) { - // TODO: 虚拟列是否需要处理 auto cid = _schema->column_id(i); // todo(wb) abstract make column where if (!_is_pred_column[cid]) { block->replace_by_position(i, std::move(_current_return_columns[cid])); } } + for (auto& pair : _vir_cid_to_idx_in_block) { auto cid = pair.first; auto loc = pair.second; block->replace_by_position(loc, std::move(_current_return_columns[cid])); } + block->clear_column_data(); // clear and release iterators memory footprint in advance _clear_iterators(); @@ -2352,11 +2427,15 @@ Status SegmentIterator::_next_batch_internal(vectorized::Block* block) { } if (!_is_need_vec_eval && !_is_need_short_eval && !_is_need_expr_eval) { - if (_non_predicate_columns.empty()) { + if (_cols_not_included_by_any_predicates.empty()) { return Status::InternalError("_non_predicate_columns is empty"); } - RETURN_IF_ERROR(_convert_to_expected_type(_predicate_column_ids)); - RETURN_IF_ERROR(_convert_to_expected_type(_non_predicate_columns)); + RETURN_IF_ERROR(_convert_to_expected_type(_cols_read_by_column_predicate)); + RETURN_IF_ERROR(_convert_to_expected_type(_cols_not_included_by_any_predicates)); + LOG_INFO( + "No need to evaluate any predicates or filter, output non-predicate columns, " + "block rows {}, selected size {}", + block->rows(), _current_batch_rows_read); _output_non_pred_columns(block); } else { uint16_t selected_size = _current_batch_rows_read; @@ -2379,33 +2458,41 @@ Status SegmentIterator::_next_batch_internal(vectorized::Block* block) { // when lazy materialization enables, _predicate_column_ids = distinct(_short_cir_pred_column_ids + _vec_pred_column_ids) // see _vec_init_lazy_materialization // todo(wb) need to tell input columnids from output columnids - RETURN_IF_ERROR(_output_column_by_sel_idx(block, _predicate_column_ids, + 
RETURN_IF_ERROR(_output_column_by_sel_idx(block, _cols_read_by_column_predicate, _sel_rowid_idx.data(), selected_size)); // step 3.2: read remaining expr column and evaluate it. if (_is_need_expr_eval) { // The predicate column contains the remaining expr column, no need second read. - if (!_non_predicate_column_ids.empty()) { + if (_cols_read_by_common_expr.size() > 0) { SCOPED_RAW_TIMER(&_opts.stats->non_predicate_read_ns); RETURN_IF_ERROR(_read_columns_by_rowids( - _non_predicate_column_ids, _block_rowids, _sel_rowid_idx.data(), + _cols_read_by_common_expr, _block_rowids, _sel_rowid_idx.data(), selected_size, &_current_return_columns)); - if (std::find(_non_predicate_column_ids.begin(), - _non_predicate_column_ids.end(), + if (std::find(_cols_read_by_common_expr.begin(), + _cols_read_by_common_expr.end(), _schema->version_col_idx()) != - _non_predicate_column_ids.end()) { + _cols_read_by_common_expr.end()) { _replace_version_col(selected_size); } - RETURN_IF_ERROR(_convert_to_expected_type(_non_predicate_column_ids)); - for (auto cid : _non_predicate_column_ids) { + RETURN_IF_ERROR(_convert_to_expected_type(_cols_read_by_common_expr)); + for (auto cid : _cols_read_by_common_expr) { auto loc = _schema_block_id_map[cid]; block->replace_by_position(loc, std::move(_current_return_columns[cid])); } + + for (const auto pair : _vir_cid_to_idx_in_block) { + auto cid = pair.first; + auto loc = pair.second; + block->replace_by_position(loc, + std::move(_current_return_columns[cid])); + } } DCHECK(block->columns() > _schema_block_id_map[*_common_expr_columns.begin()]); - // block->rows() takes the size of the first column by default. If the first column is no predicate column, + // block->rows() takes the size of the first column by default. + // If the first column is not predicate column, // it has not been read yet. add a const column that has been read to calculate rows(). 
if (block->rows() == 0) { vectorized::MutableColumnPtr col0 = @@ -2430,17 +2517,23 @@ Status SegmentIterator::_next_batch_internal(vectorized::Block* block) { } } } else if (_is_need_expr_eval) { - RETURN_IF_ERROR(_convert_to_expected_type(_non_predicate_column_ids)); - for (auto cid : _non_predicate_column_ids) { + RETURN_IF_ERROR(_convert_to_expected_type(_cols_read_by_common_expr)); + for (auto cid : _cols_read_by_common_expr) { auto loc = _schema_block_id_map[cid]; block->replace_by_position(loc, std::move(_current_return_columns[cid])); } + + for (const auto pair : _vir_cid_to_idx_in_block) { + auto cid = pair.first; + auto loc = pair.second; + block->replace_by_position(loc, std::move(_current_return_columns[cid])); + } } } else if (_is_need_expr_eval) { - DCHECK(!_predicate_column_ids.empty()); - RETURN_IF_ERROR(_convert_to_expected_type(_predicate_column_ids)); + DCHECK(!_cols_read_by_column_predicate.empty()); + RETURN_IF_ERROR(_convert_to_expected_type(_cols_read_by_column_predicate)); // first read all rows are insert block, initialize sel_rowid_idx to all rows. 
- for (auto cid : _predicate_column_ids) { + for (auto cid : _cols_read_by_column_predicate) { auto loc = _schema_block_id_map[cid]; block->replace_by_position(loc, std::move(_current_return_columns[cid])); } @@ -2482,7 +2575,7 @@ Status SegmentIterator::_next_batch_internal(vectorized::Block* block) { _selected_size = selected_size; } - if (_non_predicate_columns.empty()) { + if (_cols_not_included_by_any_predicates.empty()) { // shrink char_type suffix zero data block->shrink_char_type_column_suffix_zero(_char_type_idx); @@ -2490,16 +2583,17 @@ Status SegmentIterator::_next_batch_internal(vectorized::Block* block) { } // step4: read non_predicate column if (selected_size > 0) { - RETURN_IF_ERROR(_read_columns_by_rowids(_non_predicate_columns, _block_rowids, - _sel_rowid_idx.data(), selected_size, - &_current_return_columns)); - if (std::find(_non_predicate_columns.begin(), _non_predicate_columns.end(), - _schema->version_col_idx()) != _non_predicate_columns.end()) { + RETURN_IF_ERROR(_read_columns_by_rowids(_cols_not_included_by_any_predicates, + _block_rowids, _sel_rowid_idx.data(), + selected_size, &_current_return_columns)); + if (std::find(_cols_not_included_by_any_predicates.begin(), + _cols_not_included_by_any_predicates.end(), _schema->version_col_idx()) != + _cols_not_included_by_any_predicates.end()) { _replace_version_col(selected_size); } } - RETURN_IF_ERROR(_convert_to_expected_type(_non_predicate_columns)); + RETURN_IF_ERROR(_convert_to_expected_type(_cols_not_included_by_any_predicates)); // step5: output columns _output_non_pred_columns(block); } @@ -2836,6 +2930,8 @@ Status SegmentIterator::_materialization_of_virtual_column(vectorized::Block* bl if (vectorized::check_and_get_column<const vectorized::ColumnNothing>( block->get_by_position(idx_in_block).column.get())) { + LOG_INFO("Virtual column is doing materialization, cid {}, column_expr {}", cid, + column_expr->root()->debug_string()); int result_cid = -1; 
RETURN_IF_ERROR(column_expr->execute(block, &result_cid)); diff --git a/be/src/olap/rowset/segment_v2/segment_iterator.h b/be/src/olap/rowset/segment_v2/segment_iterator.h index 378e6e90630..bb868d6278a 100644 --- a/be/src/olap/rowset/segment_v2/segment_iterator.h +++ b/be/src/olap/rowset/segment_v2/segment_iterator.h @@ -396,7 +396,7 @@ private: // whether lazy materialization read should be used. bool _lazy_materialization_read; // columns to read after predicate evaluation and remaining expr execute - std::vector<ColumnId> _non_predicate_columns; + std::vector<ColumnId> _cols_not_included_by_any_predicates; std::set<ColumnId> _common_expr_columns; // remember the rowids we've read for the current row block. // could be a local variable of next_batch(), kept here to reuse vector memory @@ -410,7 +410,7 @@ private: _vec_pred_column_ids; // keep columnId of columns for vectorized predicate evaluation std::vector<ColumnId> _short_cir_pred_column_ids; // keep columnId of columns for short circuit predicate evaluation - std::vector<bool> _is_pred_column; // columns hold _init segmentIter + std::map<uint32_t, bool> _need_read_data_indices; std::vector<bool> _is_common_expr_column; vectorized::MutableColumns _current_return_columns; @@ -422,8 +422,9 @@ private: // first, read predicate columns by various index // second, read non-predicate columns // so we need a field to stand for columns first time to read - std::vector<ColumnId> _predicate_column_ids; - std::vector<ColumnId> _non_predicate_column_ids; + std::vector<ColumnId> _cols_read_by_column_predicate; + std::vector<bool> _is_pred_column; + std::vector<ColumnId> _cols_read_by_common_expr; std::vector<ColumnId> _columns_to_filter; std::vector<ColumnId> _converted_column_ids; std::vector<int> _schema_block_id_map; // map from schema column id to column idx in Block diff --git a/be/src/olap/rowset/segment_v2/virtual_column_iterator.cpp b/be/src/olap/rowset/segment_v2/virtual_column_iterator.cpp index 
767e03f89c3..82cb4631cd3 100644 --- a/be/src/olap/rowset/segment_v2/virtual_column_iterator.cpp +++ b/be/src/olap/rowset/segment_v2/virtual_column_iterator.cpp @@ -49,13 +49,52 @@ void VirtualColumnIterator::prepare_materialization(vectorized::IColumn::Ptr col _filter = doris::vectorized::IColumn::Filter(_size, 0); } +Status VirtualColumnIterator::seek_to_first() { + if (_size < 0) { + // _materialized_column is not set. do nothing. + return Status::OK(); + } + _current_ordinal = 0; + + return Status::OK(); +} + +Status VirtualColumnIterator::seek_to_ordinal(ordinal_t ord_idx) { + if (_size < 0 || + vectorized::check_and_get_column<vectorized::ColumnNothing>(*_materialized_column_ptr)) { + // _materialized_column is not set. do nothing. + return Status::OK(); + } + + if (ord_idx >= _size) { + return Status::InternalError("Seek to ordinal out of range: {} out of {}", ord_idx, _size); + } + + _current_ordinal = ord_idx; + + return Status::OK(); +} + // Next batch implementation Status VirtualColumnIterator::next_batch(size_t* n, vectorized::MutableColumnPtr& dst, bool* has_null) { if (vectorized::check_and_get_column<vectorized::ColumnNothing>(*_materialized_column_ptr)) { return Status::OK(); } + size_t rows_num_to_read = *n; + if (_row_id_to_idx.find(_current_ordinal) == _row_id_to_idx.end()) { + return Status::InternalError("Current ordinal {} not found in row_id_to_idx map", + _current_ordinal); + } + size_t start = _row_id_to_idx[_current_ordinal]; + // Update dst column + dst = _materialized_column_ptr->clone_empty(); + dst->insert_range_from(*_materialized_column_ptr, start, rows_num_to_read); + + LOG_INFO("Virtual column iterators, next_batch, rows reads: {}, dst size: {}", rows_num_to_read, + dst->size()); + _current_ordinal += rows_num_to_read; return Status::OK(); } @@ -75,9 +114,10 @@ Status VirtualColumnIterator::read_by_rowids(const rowid_t* rowids, const size_t // Apply filter to materialized column doris::vectorized::IColumn::Ptr res_col = 
_materialized_column_ptr->filter(_filter, 0); // Update dst column - dst->clear(); - dst->insert_range_from(*res_col, 0, res_col->size()); + dst = res_col->assume_mutable(); + LOG_INFO("Virtual column iterators, read_by_rowids, rowids size: {}, dst size: {}", count, + dst->size()); return Status::OK(); } diff --git a/be/src/olap/rowset/segment_v2/virtual_column_iterator.h b/be/src/olap/rowset/segment_v2/virtual_column_iterator.h index 17f60ec0e7a..f8c5f360716 100644 --- a/be/src/olap/rowset/segment_v2/virtual_column_iterator.h +++ b/be/src/olap/rowset/segment_v2/virtual_column_iterator.h @@ -38,9 +38,9 @@ public: Status init(const ColumnIteratorOptions& opts) override; - Status seek_to_first() override { return Status::OK(); } + Status seek_to_first() override; - Status seek_to_ordinal(ordinal_t ord_idx) override { return Status::OK(); } + Status seek_to_ordinal(ordinal_t ord_idx) override; Status next_batch(size_t* n, vectorized::MutableColumnPtr& dst, bool* has_null) override; @@ -57,9 +57,10 @@ private: vectorized::IColumn::Ptr _materialized_column_ptr; // segment rowid to index in column. 
std::map<uint64_t, uint64_t> _row_id_to_idx; - doris::vectorized::IColumn::Filter _filter; size_t _size = 0; + + ordinal_t _current_ordinal = 0; }; } // namespace doris::segment_v2 \ No newline at end of file diff --git a/be/src/pipeline/exec/olap_scan_operator.cpp b/be/src/pipeline/exec/olap_scan_operator.cpp index cfc05f36dd2..7861d95ec9d 100644 --- a/be/src/pipeline/exec/olap_scan_operator.cpp +++ b/be/src/pipeline/exec/olap_scan_operator.cpp @@ -569,11 +569,13 @@ Status OlapScanLocalState::init(RuntimeState* state, LocalStateInfo& info) { // order by 的表达式需要是一个 slot_ref,并且类型需要是虚拟列 DCHECK(ordering_expr.nodes[0].__isset.slot_ref); DCHECK(ordering_expr.nodes[0].slot_ref.is_virtual_slot); - size_t limit = olap_scan_node.ann_sort_limit; + DCHECK(olap_scan_node.ann_sort_info.is_asc_order.size() == 1); + const bool asc = olap_scan_node.ann_sort_info.is_asc_order[0]; + const size_t limit = olap_scan_node.ann_sort_limit; std::shared_ptr<vectorized::VExprContext> ordering_expr_ctx; RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(ordering_expr, ordering_expr_ctx)); _ann_topn_descriptor = - vectorized::AnnTopNDescriptor::create_shared(limit, ordering_expr_ctx); + vectorized::AnnTopNDescriptor::create_shared(asc, limit, ordering_expr_ctx); } return ScanLocalState<OlapScanLocalState>::init(state, info); diff --git a/be/src/pipeline/exec/operator.cpp b/be/src/pipeline/exec/operator.cpp index e42f581166c..4059de9c2a9 100644 --- a/be/src/pipeline/exec/operator.cpp +++ b/be/src/pipeline/exec/operator.cpp @@ -198,15 +198,15 @@ Status OperatorXBase::init(const TPlanNode& tnode, RuntimeState* /*state*/) { if (tnode.__isset.vconjunct) { vectorized::VExprContextSPtr context; RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(tnode.vconjunct, context)); - LOG_INFO("Conjunct of {} is\n{}", _op_name, - apache::thrift::ThriftDebugString(tnode.vconjunct)); + // LOG_INFO("Conjunct of {} is\n{}", _op_name, + // apache::thrift::ThriftDebugString(tnode.vconjunct)); 
_conjuncts.emplace_back(context); } else if (tnode.__isset.conjuncts) { for (auto& conjunct : tnode.conjuncts) { vectorized::VExprContextSPtr context; RETURN_IF_ERROR(vectorized::VExpr::create_expr_tree(conjunct, context)); - LOG_INFO("Conjunct of {} is\n{}", _op_name, - apache::thrift::ThriftDebugString(conjunct)); + // LOG_INFO("Conjunct of {} is\n{}", _op_name, + // apache::thrift::ThriftDebugString(conjunct)); // // Write the conjunct to a file for debugging // doris::vectorized::write_to_json( // "/mnt/disk4/hezhiqiang/workspace/doris/cmaster/RELEASE/be1", "conjunct.json", diff --git a/be/src/runtime/descriptors.cpp b/be/src/runtime/descriptors.cpp index 315829295ae..3d0d4c03727 100644 --- a/be/src/runtime/descriptors.cpp +++ b/be/src/runtime/descriptors.cpp @@ -66,6 +66,26 @@ SlotDescriptor::SlotDescriptor(const TSlotDescriptor& tdesc) _is_auto_increment(tdesc.__isset.is_auto_increment ? tdesc.is_auto_increment : false), _col_default_value(tdesc.__isset.col_default_value ? tdesc.col_default_value : "") { if (tdesc.__isset.virtual_column_expr) { + // Make sure virtual column is valid. 
+ if (tdesc.virtual_column_expr.nodes.empty()) { + LOG_ERROR("Virtual column expr node is empty, col_name={}, col_unique_id={}", + tdesc.colName, tdesc.col_unique_id); + + throw doris::Exception(doris::ErrorCode::FATAL_ERROR, + "Virtual column expr node is empty, col_name: {}, " + "col_unique_id: {}", + tdesc.colName, tdesc.col_unique_id); + } + const auto& node = tdesc.virtual_column_expr.nodes[0]; + if (node.node_type == TExprNodeType::SLOT_REF) { + LOG_ERROR( + "Virtual column expr node is slot ref, col_name={}, col_unique_id={}, expr: {}", + tdesc.colName, tdesc.col_unique_id, apache::thrift::ThriftDebugString(tdesc)); + throw doris::Exception(doris::ErrorCode::FATAL_ERROR, + "Virtual column expr node is slot ref, col_name: {}, " + "col_unique_id: {}", + tdesc.colName, tdesc.col_unique_id); + } this->virtual_column_expr = std::make_shared<doris::TExpr>(tdesc.virtual_column_expr); } } diff --git a/be/src/runtime/runtime_state.h b/be/src/runtime/runtime_state.h index e4ecf59563c..dd9c25144f5 100644 --- a/be/src/runtime/runtime_state.h +++ b/be/src/runtime/runtime_state.h @@ -49,7 +49,7 @@ #include "runtime/workload_group/workload_group.h" #include "util/debug_util.h" #include "util/runtime_profile.h" -#include "vec/columns/columns_number.h" +#include "vec/runtime/vector_search_user_params.h" namespace doris { class RuntimeFilter; @@ -657,6 +657,12 @@ public: int profile_level() const { return _profile_level; } + VectorSearchUserParams get_vector_search_params() const { + return VectorSearchUserParams(_query_options.hnsw_ef_search, + _query_options.hnsw_check_relative_distance, + _query_options.hnsw_bounded_queue); + } + private: Status create_error_log_file(); diff --git a/be/src/vec/core/block.cpp b/be/src/vec/core/block.cpp index 4263a0659b7..fb7e218e062 100644 --- a/be/src/vec/core/block.cpp +++ b/be/src/vec/core/block.cpp @@ -777,6 +777,7 @@ void Block::update_hash(SipHash& hash) const { } } +// columns_to_filter 实际上是需要进行过滤的 col 的 position void 
Block::filter_block_internal(Block* block, const std::vector<uint32_t>& columns_to_filter, const IColumn::Filter& filter) { size_t count = filter.size() - simd::count_zero_num((int8_t*)filter.data(), filter.size()); diff --git a/be/src/vec/exec/scan/olap_scanner.cpp b/be/src/vec/exec/scan/olap_scanner.cpp index 63e5adff4a0..9c6197cce0e 100644 --- a/be/src/vec/exec/scan/olap_scanner.cpp +++ b/be/src/vec/exec/scan/olap_scanner.cpp @@ -104,6 +104,7 @@ OlapScanner::OlapScanner(pipeline::ScanLocalStateBase* parent, OlapScanner::Para }) { _tablet_reader_params.set_read_source(std::move(params.read_source)); _is_init = false; + _vector_search_params = params.state->get_vector_search_params(); } static std::string read_columns_to_string(TabletSchemaSPtr tablet_schema, @@ -143,11 +144,12 @@ Status OlapScanner::init() { auto* local_state = static_cast<pipeline::OlapScanLocalState*>(_local_state); auto& tablet = _tablet_reader_params.tablet; auto& tablet_schema = _tablet_reader_params.tablet_schema; + for (auto ctx : local_state->_common_expr_ctxs_push_down) { VExprContextSPtr context; RETURN_IF_ERROR(ctx->clone(_state, context)); _common_expr_ctxs_push_down.emplace_back(context); - RETURN_IF_ERROR(context->prepare_ann_range_search()); + RETURN_IF_ERROR(context->prepare_ann_range_search(_vector_search_params)); } for (auto pair : local_state->_slot_id_to_virtual_column_expr) { diff --git a/be/src/vec/exec/scan/olap_scanner.h b/be/src/vec/exec/scan/olap_scanner.h index 0fbeedb16a1..f6895662c88 100644 --- a/be/src/vec/exec/scan/olap_scanner.h +++ b/be/src/vec/exec/scan/olap_scanner.h @@ -36,6 +36,7 @@ #include "olap/tablet.h" #include "olap/tablet_reader.h" #include "olap/tablet_schema.h" +#include "runtime/runtime_state.h" #include "vec/data_types/data_type.h" #include "vec/exec/scan/scanner.h" @@ -118,6 +119,8 @@ public: std::map<size_t, vectorized::DataTypePtr> _vir_col_idx_to_type; std::shared_ptr<vectorized::AnnTopNDescriptor> _ann_topn_descriptor; + + 
VectorSearchUserParams _vector_search_params; }; } // namespace vectorized } // namespace doris diff --git a/be/src/vec/exprs/ann_range_search_params.h b/be/src/vec/exprs/ann_range_search_params.h index 5eedd7d334c..410c4dc14c4 100644 --- a/be/src/vec/exprs/ann_range_search_params.h +++ b/be/src/vec/exprs/ann_range_search_params.h @@ -22,18 +22,21 @@ #include <string> #include "olap/rowset/segment_v2/ann_index_iterator.h" +#include "runtime/runtime_state.h" +#include "vector/vector_index.h" namespace doris::vectorized { -struct AnnRangeSearchParams { +struct RangeSearchRuntimeInfo { bool is_ann_range_search = false; bool is_le_or_lt = true; size_t src_col_idx = 0; int64_t dst_col_idx = -1; double radius = 0.0; - int ef_search = 0; + segment_v2::VectorIndex::Metric metric_type; + doris::VectorSearchUserParams user_params; std::unique_ptr<float[]> query_value; - segment_v2::RangeSearchParams toRangeSearchParams() { + segment_v2::RangeSearchParams to_range_search_params() { segment_v2::RangeSearchParams params; params.query_value = query_value.get(); params.radius = static_cast<float>(radius); @@ -42,17 +45,13 @@ struct AnnRangeSearchParams { return params; } - segment_v2::CustomSearchParams toCustomSearchParams() { - segment_v2::CustomSearchParams params; - params.ef_search = ef_search; - return params; - } - std::string to_string() const { return fmt::format( "is_ann_range_search: {}, is_le_or_lt: {}, src_col_idx: {}, " - "dst_col_idx: {}, radius: {}, ef_search: {}", - is_ann_range_search, is_le_or_lt, src_col_idx, dst_col_idx, radius, ef_search); + "dst_col_idx: {}, metric_type {}, radius: {}, user params: {}", + is_ann_range_search, is_le_or_lt, src_col_idx, dst_col_idx, + segment_v2::VectorIndex::metric_to_string(metric_type), radius, + user_params.to_string()); } }; } // namespace doris::vectorized diff --git a/be/src/vec/exprs/vann_topn_predicate.cpp b/be/src/vec/exprs/vann_topn_predicate.cpp index d68086dcd3d..30352b09782 100644 --- 
a/be/src/vec/exprs/vann_topn_predicate.cpp +++ b/be/src/vec/exprs/vann_topn_predicate.cpp @@ -102,7 +102,16 @@ Status AnnTopNDescriptor::prepare(RuntimeState* state, const RowDescriptor& row_ distance_fn_call->children()[1]->debug_string()); } _query_array = array_literal->get_column_ptr(); + _user_params = state->get_vector_search_params(); + std::set<std::string> distance_func_names = {vectorized::L2Distance::name, + vectorized::InnerProduct::name}; + if (distance_func_names.contains(distance_fn_call->function_name()) == false) { + return Status::InternalError("Ann topn expr expect distance function, got {}", + distance_fn_call->function_name()); + } + + _metric_type = segment_v2::VectorIndex::string_to_metric(distance_fn_call->function_name()); VLOG_DEBUG << "AnnTopNDescriptor: {}" << this->debug_string(); return Status::OK(); } @@ -112,6 +121,9 @@ Status AnnTopNDescriptor::evaluate_vector_ann_search( vectorized::IColumn::MutablePtr& result_column, std::unique_ptr<std::vector<uint64_t>>& row_ids) { DCHECK(ann_index_iterator != nullptr); + segment_v2::AnnIndexIterator* ann_index_iterator_casted = + dynamic_cast<segment_v2::AnnIndexIterator*>(ann_index_iterator); + DCHECK(ann_index_iterator_casted != nullptr); DCHECK(_order_by_expr_ctx != nullptr); DCHECK(_order_by_expr_ctx->root() != nullptr); @@ -135,6 +147,7 @@ Status AnnTopNDescriptor::evaluate_vector_ann_search( .query_value = query_value_f32.get(), .query_value_size = query_value_size, .limit = _limit, + ._user_params = _user_params, .roaring = &roaring, .distance = nullptr, .row_ids = nullptr, @@ -159,7 +172,9 @@ Status AnnTopNDescriptor::evaluate_vector_ann_search( std::string AnnTopNDescriptor::debug_string() const { return "AnnTopNDescriptor: limit=" + std::to_string(_limit) + ", src_col_idx=" + std::to_string(_src_column_idx) + - ", dest_col_idx=" + std::to_string(_dest_column_idx) + + ", dest_col_idx=" + std::to_string(_dest_column_idx) + ", asc=" + std::to_string(_asc) + + ", user_params=" + 
_user_params.to_string() + + ", metric_type=" + segment_v2::VectorIndex::metric_to_string(_metric_type) + ", order_by_expr=" + _order_by_expr_ctx->root()->debug_string(); } } // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/exprs/vann_topn_predicate.h b/be/src/vec/exprs/vann_topn_predicate.h index fb92dfa0a8b..842d054cfcd 100644 --- a/be/src/vec/exprs/vann_topn_predicate.h +++ b/be/src/vec/exprs/vann_topn_predicate.h @@ -17,6 +17,7 @@ #pragma once +#include "runtime/runtime_state.h" #include "vec/columns/column.h" #include "vec/exprs/varray_literal.h" #include "vec/exprs/vcast_expr.h" @@ -32,8 +33,8 @@ class AnnTopNDescriptor { ENABLE_FACTORY_CREATOR(AnnTopNDescriptor); public: - AnnTopNDescriptor(size_t limit, VExprContextSPtr order_by_expr_ctx) - : _limit(limit), _order_by_expr_ctx(order_by_expr_ctx) {}; + AnnTopNDescriptor(bool asc, size_t limit, VExprContextSPtr order_by_expr_ctx) + : _asc(asc), _limit(limit), _order_by_expr_ctx(order_by_expr_ctx) {}; Status prepare(RuntimeState* state, const RowDescriptor& row_desc); @@ -43,14 +44,16 @@ public: roaring::Roaring& row_bitmap, vectorized::IColumn::MutablePtr& result_column, std::unique_ptr<std::vector<uint64_t>>& row_ids); - + segment_v2::VectorIndex::Metric get_metric_type() const { return _metric_type; } std::string debug_string() const; size_t get_src_column_idx() const { return _src_column_idx; } size_t get_dest_column_idx() const { return _dest_column_idx; } + bool is_asc() const { return _asc; } private: + const bool _asc; // limit N const size_t _limit; // order by distance(xxx, [1,2]) @@ -59,7 +62,9 @@ private: std::string _name = "AnnTopNDescriptor"; size_t _src_column_idx = -1; size_t _dest_column_idx = -1; + segment_v2::VectorIndex::Metric _metric_type; IColumn::Ptr _query_array; + doris::VectorSearchUserParams _user_params; }; } // namespace doris::vectorized \ No newline at end of file diff --git a/be/src/vec/exprs/vectorized_fn_call.cpp 
b/be/src/vec/exprs/vectorized_fn_call.cpp index e3408aac960..9ed006195c3 100644 --- a/be/src/vec/exprs/vectorized_fn_call.cpp +++ b/be/src/vec/exprs/vectorized_fn_call.cpp @@ -57,6 +57,7 @@ #include "vec/functions/function_rpc.h" #include "vec/functions/simple_function_factory.h" #include "vec/utils/util.hpp" +#include "vector/vector_index.h" namespace doris { class RowDescriptor; @@ -258,6 +259,10 @@ const std::string& VectorizedFnCall::expr_name() const { return _expr_name; } +std::string VectorizedFnCall::function_name() const { + return _function_name; +} + std::string VectorizedFnCall::debug_string() const { std::stringstream out; out << "VectorizedFn["; @@ -327,7 +332,8 @@ bool VectorizedFnCall::equals(const VExpr& other) { SlotRef */ -Status VectorizedFnCall::prepare_ann_range_search() { +Status VectorizedFnCall::prepare_ann_range_search( + const doris::VectorSearchUserParams& user_params) { std::set<TExprOpcode::type> ops = {TExprOpcode::GE, TExprOpcode::LE, TExprOpcode::LE, TExprOpcode::GT, TExprOpcode::LT}; if (ops.find(this->op()) == ops.end()) { @@ -376,9 +382,13 @@ Status VectorizedFnCall::prepare_ann_range_search() { } // Now left child is a function call, we need to check if it is a distance function - if (function_call->_function_name != L2Distance::name) { + std::set<std::string> distance_functions = {L2Distance::name, InnerProduct::name}; + if (distance_functions.find(function_call->_function_name) == distance_functions.end()) { LOG_INFO("Left child is not a distance function. 
Got {}", function_call->_function_name); return Status::OK(); + } else { + _ann_range_search_params.metric_type = + segment_v2::VectorIndex::string_to_metric(function_call->_function_name); } if (function_call->get_num_children() != 2) { @@ -430,6 +440,7 @@ Status VectorizedFnCall::prepare_ann_range_search() { _ann_range_search_params.query_value[i] = static_cast<Float32>(cf64->get_data()[i]); } _ann_range_search_params.is_ann_range_search = true; + _ann_range_search_params.user_params = user_params; LOG_INFO("Ann range search params: {}", _ann_range_search_params.to_string()); return Status::OK(); } @@ -452,26 +463,37 @@ Status VectorizedFnCall::evaluate_ann_range_search( ColumnId src_col_cid = idx_to_cid[idx_in_block]; DCHECK(src_col_cid < cid_to_index_iterators.size()); - segment_v2::IndexIterator* index_iterators = cid_to_index_iterators[src_col_cid].get(); - if (index_iterators == nullptr) { + segment_v2::IndexIterator* index_iterator = cid_to_index_iterators[src_col_cid].get(); + if (index_iterator == nullptr) { LOG_INFO("No index iterator for column cid {}", src_col_cid); return Status::OK(); } - segment_v2::AnnIndexIterator* ann_index_iterators = - dynamic_cast<segment_v2::AnnIndexIterator*>(index_iterators); - if (ann_index_iterators == nullptr) { + segment_v2::AnnIndexIterator* ann_index_iterator = + dynamic_cast<segment_v2::AnnIndexIterator*>(index_iterator); + if (ann_index_iterator == nullptr) { LOG_INFO("No index iterator for column cid {}", src_col_cid); return Status::OK(); } + DCHECK(ann_index_iterator->get_reader() != nullptr) + << "Ann index iterator should have reader. Column cid: " << src_col_cid; + std::shared_ptr<AnnIndexReader> ann_index_reader = + std::dynamic_pointer_cast<AnnIndexReader>(ann_index_iterator->get_reader()); + DCHECK(ann_index_reader != nullptr) + << "Ann index reader should not be null. Column cid: " << src_col_cid; + // Check if metrics type is match. 
+ if (ann_index_reader->get_metric_type() != _ann_range_search_params.metric_type) { + LOG_INFO("Metric type not match, can not execute range search by index."); + return Status::OK(); + } - RangeSearchParams params = _ann_range_search_params.toRangeSearchParams(); - CustomSearchParams custom_params = _ann_range_search_params.toCustomSearchParams(); + RangeSearchParams params = _ann_range_search_params.to_range_search_params(); params.roaring = &row_bitmap; DCHECK(params.roaring != nullptr); RangeSearchResult result; - RETURN_IF_ERROR(ann_index_iterators->range_search(params, custom_params, &result)); + RETURN_IF_ERROR(ann_index_iterator->range_search(params, _ann_range_search_params.user_params, + &result)); #ifndef NDEBUG if (this->_ann_range_search_params.is_le_or_lt == false) { diff --git a/be/src/vec/exprs/vectorized_fn_call.h b/be/src/vec/exprs/vectorized_fn_call.h index 4f3f76b5436..14d86964ac9 100644 --- a/be/src/vec/exprs/vectorized_fn_call.h +++ b/be/src/vec/exprs/vectorized_fn_call.h @@ -22,6 +22,7 @@ #include <vector> #include "common/status.h" +#include "runtime/runtime_state.h" #include "udf/udf.h" #include "vec/core/column_numbers.h" #include "vec/exprs/ann_range_search_params.h" @@ -60,6 +61,7 @@ public: FunctionContext::FunctionStateScope scope) override; void close(VExprContext* context, FunctionContext::FunctionStateScope scope) override; const std::string& expr_name() const override; + std::string function_name() const; std::string debug_string() const override; bool is_constant() const override { if (!_function->is_use_default_implementation_for_constants() || @@ -82,13 +84,13 @@ public: const std::vector<std::unique_ptr<segment_v2::ColumnIterator>>& column_iterators, roaring::Roaring& row_bitmap) override; - Status prepare_ann_range_search() override; + Status prepare_ann_range_search(const doris::VectorSearchUserParams& params) override; protected: FunctionBasePtr _function; std::string _expr_name; std::string _function_name; - 
AnnRangeSearchParams _ann_range_search_params; + RangeSearchRuntimeInfo _ann_range_search_params; private: Status _do_execute(doris::vectorized::VExprContext* context, doris::vectorized::Block* block, diff --git a/be/src/vec/exprs/vexpr.cpp b/be/src/vec/exprs/vexpr.cpp index 794bd66fac5..f3269dfca36 100644 --- a/be/src/vec/exprs/vexpr.cpp +++ b/be/src/vec/exprs/vexpr.cpp @@ -807,9 +807,9 @@ Status VExpr::evaluate_ann_range_search( return Status::OK(); } -Status VExpr::prepare_ann_range_search() { +Status VExpr::prepare_ann_range_search(const doris::VectorSearchUserParams& params) { for (auto& child : _children) { - RETURN_IF_ERROR(child->prepare_ann_range_search()); + RETURN_IF_ERROR(child->prepare_ann_range_search(params)); } return Status::OK(); } diff --git a/be/src/vec/exprs/vexpr.h b/be/src/vec/exprs/vexpr.h index c78cadb803a..09b91c35114 100644 --- a/be/src/vec/exprs/vexpr.h +++ b/be/src/vec/exprs/vexpr.h @@ -283,7 +283,7 @@ public: const std::vector<std::unique_ptr<segment_v2::ColumnIterator>>& column_iterators, roaring::Roaring& row_bitmap); - virtual Status prepare_ann_range_search(); + virtual Status prepare_ann_range_search(const doris::VectorSearchUserParams& params); bool has_been_executed(); diff --git a/be/src/vec/exprs/vexpr_context.cpp b/be/src/vec/exprs/vexpr_context.cpp index 66c5c1bfc58..886dea256c5 100644 --- a/be/src/vec/exprs/vexpr_context.cpp +++ b/be/src/vec/exprs/vexpr_context.cpp @@ -432,11 +432,11 @@ void VExprContext::_reset_memory_usage(const VExprContextSPtrs& contexts) { [](auto&& context) { context->_memory_usage = 0; }); } -Status VExprContext::prepare_ann_range_search() { +Status VExprContext::prepare_ann_range_search(const doris::VectorSearchUserParams& params) { if (_root == nullptr) { return Status::OK(); } - return _root->prepare_ann_range_search(); + return _root->prepare_ann_range_search(params); } #include "common/compile_check_end.h" diff --git a/be/src/vec/exprs/vexpr_context.h b/be/src/vec/exprs/vexpr_context.h index 
60d3e1a31e0..d43012353e5 100644 --- a/be/src/vec/exprs/vexpr_context.h +++ b/be/src/vec/exprs/vexpr_context.h @@ -28,6 +28,7 @@ #include "common/factory_creator.h" #include "common/status.h" #include "olap/rowset/segment_v2/inverted_index_reader.h" +#include "runtime/runtime_state.h" #include "runtime/types.h" #include "udf/udf.h" #include "vec/core/block.h" @@ -279,7 +280,7 @@ public: [[nodiscard]] size_t get_memory_usage() const { return _memory_usage; } - Status prepare_ann_range_search(); + Status prepare_ann_range_search(const doris::VectorSearchUserParams& params); private: // Close method is called in vexpr context dector, not need call expicility diff --git a/be/src/vec/exprs/virtual_slot_ref.cpp b/be/src/vec/exprs/virtual_slot_ref.cpp index 2455ef40bfa..844da622f31 100644 --- a/be/src/vec/exprs/virtual_slot_ref.cpp +++ b/be/src/vec/exprs/virtual_slot_ref.cpp @@ -91,7 +91,7 @@ Status VirtualSlotRef::prepare(doris::RuntimeState* state, const doris::RowDescr state->desc_tbl().debug_string()); } const TExpr& expr = *slot_desc->get_virtual_column_expr(); - LOG_INFO("Virtual column expr is {}", apache::thrift::ThriftDebugString(expr)); + // LOG_INFO("Virtual column expr is {}", apache::thrift::ThriftDebugString(expr)); // Create a temp_ctx only for create_expr_tree. 
VExprContextSPtr temp_ctx; RETURN_IF_ERROR(VExpr::create_expr_tree(expr, temp_ctx)); diff --git a/be/src/vec/functions/array/function_array_distance.h b/be/src/vec/functions/array/function_array_distance.h index 28b0df28d7f..fcb7a067a07 100644 --- a/be/src/vec/functions/array/function_array_distance.h +++ b/be/src/vec/functions/array/function_array_distance.h @@ -96,6 +96,9 @@ public: Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, uint32_t result, size_t input_rows_count) const override { + LOG_INFO("Function {} is executed with {} rows, stack {}", get_name(), input_rows_count, + doris::get_stack_trace()); + const auto& arg1 = block.get_by_position(arguments[0]); const auto& arg2 = block.get_by_position(arguments[1]); if (!_check_input_type(arg1.type) || !_check_input_type(arg2.type)) { diff --git a/be/src/vec/runtime/vector_search_user_params.cpp b/be/src/vec/runtime/vector_search_user_params.cpp new file mode 100644 index 00000000000..04b8c8b91c4 --- /dev/null +++ b/be/src/vec/runtime/vector_search_user_params.cpp @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "vec/runtime/vector_search_user_params.h" + +#include <fmt/format.h> + +namespace doris { +bool VectorSearchUserParams::operator==(const VectorSearchUserParams& other) const { + return hnsw_ef_search == other.hnsw_ef_search && + hnsw_check_relative_distance == other.hnsw_check_relative_distance && + hnsw_bounded_queue == other.hnsw_bounded_queue; +} + +std::string VectorSearchUserParams::to_string() const { + return fmt::format( + "hnsw_ef_search: {}, hnsw_check_relative_distance: {}, " + "hnsw_bounded_queue: {}", + hnsw_ef_search, hnsw_check_relative_distance, hnsw_bounded_queue); +} +} // namespace doris \ No newline at end of file diff --git a/be/src/vec/runtime/vector_search_user_params.h b/be/src/vec/runtime/vector_search_user_params.h new file mode 100644 index 00000000000..5f886405e06 --- /dev/null +++ b/be/src/vec/runtime/vector_search_user_params.h @@ -0,0 +1,31 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <string> + +namespace doris { +// Constructed from session variables. 
+struct VectorSearchUserParams { + int hnsw_ef_search = 16; + bool hnsw_check_relative_distance = true; + bool hnsw_bounded_queue = true; + + bool operator==(const VectorSearchUserParams& other) const; + + std::string to_string() const; +}; +} // namespace doris \ No newline at end of file diff --git a/be/src/vector/faiss_vector_index.cpp b/be/src/vector/faiss_vector_index.cpp index e5916973ece..f48c71334c5 100644 --- a/be/src/vector/faiss_vector_index.cpp +++ b/be/src/vector/faiss_vector_index.cpp @@ -129,9 +129,48 @@ doris::Status FaissVectorIndex::add(int n, const float* vec) { void FaissVectorIndex::set_build_params(const FaissBuildParameter& params) { _dimension = params.d; if (params.index_type == FaissBuildParameter::IndexType::BruteForce) { - _index = std::make_unique<faiss::IndexFlatL2>(params.d); + if (params.metric_type == FaissBuildParameter::MetricType::L2) { + _index = std::make_unique<faiss::IndexFlatL2>(params.d); + } else if (params.metric_type == FaissBuildParameter::MetricType::IP) { + _index = std::make_unique<faiss::IndexFlatIP>(params.d); + } else { + throw doris::Exception(doris::ErrorCode::INVALID_ARGUMENT, + "Unsupported metric type: {}", + static_cast<int>(params.metric_type)); + } } else if (params.index_type == FaissBuildParameter::IndexType::HNSW) { - _index = std::make_unique<faiss::IndexHNSWFlat>(params.d, params.m); + if (params.quantilizer == FaissBuildParameter::Quantilizer::FLAT) { + if (params.metric_type == FaissBuildParameter::MetricType::L2) { + _index = std::make_unique<faiss::IndexHNSWFlat>(params.d, params.m); + } else if (params.metric_type == FaissBuildParameter::MetricType::IP) { + _index = std::make_unique<faiss::IndexHNSWFlat>(params.d, params.m, + faiss::METRIC_INNER_PRODUCT); + } else { + throw doris::Exception(doris::ErrorCode::INVALID_ARGUMENT, + "Unsupported metric type: {}", + static_cast<int>(params.metric_type)); + } + } else if (params.quantilizer == FaissBuildParameter::Quantilizer::PQ) { + if (params.pq_m <= 
0) { + throw doris::Exception(doris::ErrorCode::INVALID_ARGUMENT, + "pq_m should be greater than 0 for PQ quantilizer"); + } + + if (params.metric_type == FaissBuildParameter::MetricType::L2) { + _index = std::make_unique<faiss::IndexHNSWPQ>(params.d, params.m, params.pq_m); + } else if (params.metric_type == FaissBuildParameter::MetricType::IP) { + _index = std::make_unique<faiss::IndexHNSWPQ>(params.d, params.m, params.pq_m, + faiss::METRIC_INNER_PRODUCT); + } else { + throw doris::Exception(doris::ErrorCode::INVALID_ARGUMENT, + "Unsupported metric type: {}", + static_cast<int>(params.metric_type)); + } + } else { + throw doris::Exception(doris::ErrorCode::INVALID_ARGUMENT, + "Unsupported quantilizer type: {}", + static_cast<int>(params.quantilizer)); + } } else { throw doris::Exception(doris::ErrorCode::INVALID_ARGUMENT, "Unsupported index type: {}", static_cast<int>(params.index_type)); @@ -158,7 +197,16 @@ doris::Status FaissVectorIndex::ann_topn_search(const float* query_vec, int k, std::unique_ptr<faiss::IDSelector> id_sel = nullptr; id_sel = roaring_to_faiss_selector(*params.roaring); faiss::SearchParametersHNSW param; + const HNSWSearchParameters* hnsw_params = + dynamic_cast<const HNSWSearchParameters*>(¶ms); + if (hnsw_params == nullptr) { + return doris::Status::InvalidArgument( + "HNSW search parameters should not be null for HNSW index"); + } param.sel = id_sel.get(); + param.efSearch = hnsw_params->ef_search; + param.check_relative_distance = hnsw_params->check_relative_distance; + param.bounded_queue = hnsw_params->bounded_queue; _index->search(1, query_vec, k, distances, labels, ¶m); } @@ -193,6 +241,8 @@ doris::Status FaissVectorIndex::range_search(const float* query_vec, const float if (hnsw_params != nullptr) { faiss::SearchParametersHNSW param; param.efSearch = hnsw_params->ef_search; + param.check_relative_distance = hnsw_params->check_relative_distance; + param.bounded_queue = hnsw_params->bounded_queue; param.sel = sel ? 
sel.get() : nullptr; _index->range_search(1, query_vec, radius * radius, &native_search_result, ¶m); } else { diff --git a/be/src/vector/faiss_vector_index.h b/be/src/vector/faiss_vector_index.h index 53637344bc1..129a6b26ccb 100644 --- a/be/src/vector/faiss_vector_index.h +++ b/be/src/vector/faiss_vector_index.h @@ -26,13 +26,20 @@ #include <string> #include "common/status.h" +#include "util/metrics.h" #include "vector_index.h" namespace doris::segment_v2 { struct FaissBuildParameter { enum class IndexType { BruteForce, IVF, HNSW }; - enum class Quantilizer { FLAT, SQ, PQ }; + enum class Quantilizer { FLAT, PQ }; + + enum class MetricType { + L2, // Euclidean distance + IP, // Inner product + COSINE // Cosine similarity + }; static IndexType string_to_index_type(const std::string& type) { if (type == "brute_force") { @@ -48,19 +55,30 @@ struct FaissBuildParameter { static Quantilizer string_to_quantilizer(const std::string& type) { if (type == "flat") { return Quantilizer::FLAT; - } else if (type == "sq") { - return Quantilizer::SQ; } else if (type == "pq") { return Quantilizer::PQ; } return Quantilizer::FLAT; // default } + static MetricType string_to_metric_type(const std::string& type) { + if (type == "l2") { + return MetricType::L2; + } else if (type == "ip") { + return MetricType::IP; + } else if (type == "cosine") { + return MetricType::COSINE; + } + return MetricType::L2; // default + } + // HNSW int d = 0; int m = 0; + int pq_m = -1; // Only used for PQ quantilizer IndexType index_type; Quantilizer quantilizer; + MetricType metric_type = MetricType::L2; }; class FaissVectorIndex : public VectorIndex { diff --git a/be/src/vector/vector_index.h b/be/src/vector/vector_index.h index 50b0d59b624..e344f1c06c3 100644 --- a/be/src/vector/vector_index.h +++ b/be/src/vector/vector_index.h @@ -21,7 +21,7 @@ #include <roaring/roaring.hh> #include "common/status.h" -#include "gutil/integral_types.h" +#include "vec/functions/array/function_array_distance.h" namespace 
lucene::store { class Directory; @@ -50,11 +50,13 @@ struct IndexSearchParameters { struct HNSWSearchParameters : public IndexSearchParameters { int ef_search = 16; + bool check_relative_distance = true; + bool bounded_queue = true; }; class VectorIndex { public: - enum class Metric { L2, COSINE, INNER_PRODUCT, UNKNOWN }; + enum class Metric { L2, INNER_PRODUCT, UNKNOWN }; /** Add n vectors of dimension d to the index. * @@ -87,21 +89,17 @@ public: static std::string metric_to_string(Metric metric) { switch (metric) { case Metric::L2: - return "L2"; - case Metric::COSINE: - return "COSINE"; + return vectorized::L2Distance::name; case Metric::INNER_PRODUCT: - return "INNER_PRODUCT"; + return vectorized::InnerProduct::name; default: return "UNKNOWN"; } } static Metric string_to_metric(const std::string& metric) { - if (metric == "l2") { + if (metric == vectorized::L2Distance::name) { return Metric::L2; - } else if (metric == "cosine") { - return Metric::COSINE; - } else if (metric == "inner_product") { + } else if (metric == vectorized::InnerProduct::name) { return Metric::INNER_PRODUCT; } else { return Metric::UNKNOWN; @@ -112,6 +110,7 @@ public: size_t get_dimension() const { return _dimension; } protected: + // When adding vectors to the index, use this variable to check the dimension of the vectors. size_t _dimension = 0; }; diff --git a/be/test/olap/vector_search/ann_index_reader_test.cpp b/be/test/olap/vector_search/ann_index_reader_test.cpp index e951a48c743..5b29f21e90c 100644 --- a/be/test/olap/vector_search/ann_index_reader_test.cpp +++ b/be/test/olap/vector_search/ann_index_reader_test.cpp @@ -15,33 +15,93 @@ // specific language governing permissions and limitations // under the License. 
-#include "olap/rowset/segment_v2/ann_index_reader.h" - +#include <gen_cpp/olap_file.pb.h> #include <gmock/gmock.h> #include <gtest/gtest.h> +#include <iostream> #include <memory> +#include <string> +#include "faiss_vector_index.h" +#include "olap/rowset/segment_v2/ann_index_iterator.h" +#include "olap/tablet_schema.h" +#include "runtime/runtime_state.h" #include "vector_search_utils.h" +using namespace doris::vector_search_utils; + +namespace doris::vectorized { + +TEST_F(VectorSearchTest, AnnIndexReaderRangeSearch) { + size_t iterato = 25; + for (size_t i = 0; i < iterato; ++i) { + std::map<std::string, std::string> index_properties; + index_properties["index_type"] = "hnsw"; + index_properties["metric_type"] = "l2"; + std::unique_ptr<doris::TabletIndex> index_meta = std::make_unique<doris::TabletIndex>(); + index_meta->_properties = index_properties; + auto mock_index_file_reader = std::make_shared<MockIndexFileReader>(); + auto ann_index_reader = std::make_unique<segment_v2::AnnIndexReader>( + index_meta.get(), mock_index_file_reader); + doris::vector_search_utils::IndexType index_type = + doris::vector_search_utils::IndexType::HNSW; + const size_t dim = 128; + const size_t m = 16; + auto doris_faiss_index = doris::vector_search_utils::create_doris_index(index_type, dim, m); + auto native_faiss_index = + doris::vector_search_utils::create_native_index(index_type, dim, m); + const size_t num_vectors = 1000; + auto vectors = doris::vector_search_utils::generate_test_vectors_matrix(num_vectors, dim); + doris::vector_search_utils::add_vectors_to_indexes_serial_mode( + doris_faiss_index.get(), native_faiss_index.get(), vectors); + std::ignore = doris_faiss_index->save(this->_ram_dir.get()); + std::vector<float> query_value = vectors[0]; + const float radius = doris::vector_search_utils::get_radius_from_matrix(query_value.data(), + dim, vectors, 0.3); + + // Make sure all rows are in the roaring + auto roaring = std::make_unique<roaring::Roaring>(); + for (size_t i = 
0; i < num_vectors; ++i) { + roaring->add(i); + } + + doris::segment_v2::RangeSearchParams params; + params.radius = radius; + params.query_value = query_value.data(); + params.roaring = roaring.get(); + doris::VectorSearchUserParams custom_params; + custom_params.hnsw_ef_search = 16; + doris::segment_v2::RangeSearchResult result; + auto doris_faiss_vector_index = std::make_unique<doris::segment_v2::FaissVectorIndex>(); + std::ignore = doris_faiss_vector_index->load(this->_ram_dir.get()); + ann_index_reader->_vector_index = std::move(doris_faiss_vector_index); + std::ignore = ann_index_reader->range_search(params, custom_params, &result, nullptr); -namespace doris::segment_v2 { + ASSERT_TRUE(result.roaring != nullptr); + ASSERT_TRUE(result.distance != nullptr); + ASSERT_TRUE(result.row_ids != nullptr); + std::vector<std::pair<int, float>> doris_search_result_order_by_lables; + for (size_t i = 0; i < result.roaring->cardinality(); ++i) { + doris_search_result_order_by_lables.push_back( + {result.row_ids->at(i), result.distance[i]}); + } -using namespace vector_search_utils; -class AnnIndexReaderTest : public testing::Test {}; + std::sort(doris_search_result_order_by_lables.begin(), + doris_search_result_order_by_lables.end(), + [](const auto& a, const auto& b) { return a.first < b.first; }); -TEST_F(AnnIndexReaderTest, TestLoadIndex) { - MockTabletSchema tablet_schema; - std::shared_ptr<MockIndexFileReader> index_file_reader = - std::make_shared<MockIndexFileReader>(); - auto ann_index_reader = std::make_unique<AnnIndexReader>(&tablet_schema, index_file_reader); + std::vector<std::pair<int, float>> native_search_result_order_by_lables = + doris::vector_search_utils::perform_native_index_range_search( + native_faiss_index.get(), query_value.data(), radius); - EXPECT_TRUE(ann_index_reader->load_index(nullptr).ok()); -} + ASSERT_EQ(result.roaring->cardinality(), native_search_result_order_by_lables.size()); -TEST_F(AnnIndexReaderTest, TestQuery) { - MockTabletSchema 
tablet_schema; - std::shared_ptr<MockIndexFileReader> index_file_reader = - std::make_shared<MockIndexFileReader>(); - auto ann_index_reader = std::make_unique<AnnIndexReader>(&tablet_schema, index_file_reader); -} -} // namespace doris::segment_v2 \ No newline at end of file + for (size_t i = 0; i < native_search_result_order_by_lables.size(); ++i) { + ASSERT_EQ(doris_search_result_order_by_lables[i].first, + native_search_result_order_by_lables[i].first); + ASSERT_FLOAT_EQ(doris_search_result_order_by_lables[i].second, + native_search_result_order_by_lables[i].second); + } + } +}; +} // namespace doris::vectorized \ No newline at end of file diff --git a/be/test/olap/vector_search/ann_range_search_test.cpp b/be/test/olap/vector_search/ann_range_search_test.cpp index 8b9ad847f46..f5599526b38 100644 --- a/be/test/olap/vector_search/ann_range_search_test.cpp +++ b/be/test/olap/vector_search/ann_range_search_test.cpp @@ -28,7 +28,9 @@ #include "common/object_pool.h" #include "olap/rowset/segment_v2/ann_index_iterator.h" +#include "olap/rowset/segment_v2/ann_index_reader.h" #include "olap/rowset/segment_v2/column_reader.h" +#include "olap/rowset/segment_v2/index_file_reader.h" #include "olap/rowset/segment_v2/virtual_column_iterator.h" #include "olap/vector_search/vector_search_utils.h" #include "runtime/descriptors.h" @@ -826,10 +828,11 @@ TEST_F(VectorSearchTest, TestPrepareAnnRangeSearch) { state->set_desc_tbl(desc_tbl_ptr); VExprContextSPtr range_search_ctx; + doris::VectorSearchUserParams user_params; ASSERT_TRUE(vectorized::VExpr::create_expr_tree(texpr, range_search_ctx).ok()); ASSERT_TRUE(range_search_ctx->prepare(state.get(), row_desc).ok()); ASSERT_TRUE(range_search_ctx->open(state.get()).ok()); - ASSERT_TRUE(range_search_ctx->prepare_ann_range_search().ok()); + ASSERT_TRUE(range_search_ctx->prepare_ann_range_search(user_params).ok()); std::shared_ptr<VectorizedFnCall> fn_call = std::dynamic_pointer_cast<VectorizedFnCall>(range_search_ctx->root()); @@ -840,7 
+843,7 @@ TEST_F(VectorSearchTest, TestPrepareAnnRangeSearch) { ASSERT_EQ(fn_call->_ann_range_search_params.radius, 10); doris::segment_v2::RangeSearchParams range_search_params = - fn_call->_ann_range_search_params.toRangeSearchParams(); + fn_call->_ann_range_search_params.to_range_search_params(); EXPECT_EQ(range_search_params.radius, 10.0f); std::vector<int> query_array_groud_truth = {1, 2, 3, 4, 5, 6, 7, 20}; std::vector<int> query_array_f32; @@ -867,11 +870,11 @@ TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch) { ASSERT_TRUE(vectorized::VExpr::create_expr_tree(texpr, range_search_ctx).ok()); ASSERT_TRUE(range_search_ctx->prepare(state.get(), row_desc).ok()); ASSERT_TRUE(range_search_ctx->open(state.get()).ok()); - ASSERT_TRUE(range_search_ctx->prepare_ann_range_search().ok()); - + doris::VectorSearchUserParams user_params; + ASSERT_TRUE(range_search_ctx->prepare_ann_range_search(user_params).ok()); std::shared_ptr<VectorizedFnCall> fn_call = std::dynamic_pointer_cast<VectorizedFnCall>(range_search_ctx->root()); - + ASSERT_EQ(fn_call->_ann_range_search_params.user_params, user_params); ASSERT_TRUE(fn_call->_ann_range_search_params.is_ann_range_search == true); ASSERT_EQ(fn_call->_ann_range_search_params.is_le_or_lt, false); ASSERT_EQ(fn_call->_ann_range_search_params.src_col_idx, 1); @@ -897,6 +900,12 @@ TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch) { dynamic_cast<doris::vector_search_utils::MockAnnIndexIterator*>( cid_to_index_iterators[1].get()); + std::map<std::string, std::string> properties; + properties["index_type"] = "hnsw"; + properties["metric_type"] = "l2_distance"; + auto pair = vector_search_utils::create_tmp_ann_index_reader(properties); + mock_ann_index_iter->_ann_reader = pair.second; + // Explain: // 1. predicate is dist >= 10, so it is not a within range search // 2. 
return 10 results @@ -906,18 +915,11 @@ TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch) { }), testing::_, testing::_)) .WillOnce(testing::Invoke([](const doris::segment_v2::RangeSearchParams& params, - const doris::segment_v2::CustomSearchParams& custom_params, + const doris::VectorSearchUserParams& custom_params, doris::segment_v2::RangeSearchResult* result) { - // size_t num_results = 10; result->roaring = std::make_shared<roaring::Roaring>(); result->row_ids = nullptr; result->distance = nullptr; - // result->row_ids = std::make_unique<std::vector<uint64_t>>(); - // for (size_t i = 0; i < num_results; ++i) { - // result->roaring->add(i * 10); - // result->row_ids->push_back(i * 10); - // } - // result->distance = std::make_unique<float[]>(10); return Status::OK(); })); @@ -960,7 +962,8 @@ TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch2) { ASSERT_TRUE(vectorized::VExpr::create_expr_tree(texpr, range_search_ctx).ok()); ASSERT_TRUE(range_search_ctx->prepare(state.get(), row_desc).ok()); ASSERT_TRUE(range_search_ctx->open(state.get()).ok()); - ASSERT_TRUE(range_search_ctx->prepare_ann_range_search().ok()); + doris::VectorSearchUserParams user_params; + ASSERT_TRUE(range_search_ctx->prepare_ann_range_search(user_params).ok()); std::shared_ptr<VectorizedFnCall> fn_call = std::dynamic_pointer_cast<VectorizedFnCall>(range_search_ctx->root()); @@ -984,11 +987,15 @@ TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch2) { std::vector<std::unique_ptr<segment_v2::ColumnIterator>> column_iterators; column_iterators.resize(4); column_iterators[3] = std::make_unique<doris::segment_v2::VirtualColumnIterator>(); - roaring::Roaring row_bitmap; doris::vector_search_utils::MockAnnIndexIterator* mock_ann_index_iter = dynamic_cast<doris::vector_search_utils::MockAnnIndexIterator*>( cid_to_index_iterators[1].get()); + std::map<std::string, std::string> properties; + properties["index_type"] = "hnsw"; + properties["metric_type"] = "l2_distance"; + auto pair = 
vector_search_utils::create_tmp_ann_index_reader(properties); + mock_ann_index_iter->_ann_reader = pair.second; // Explain: // 1. predicate is dist >= 10, so it is not a within range search @@ -999,7 +1006,7 @@ TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch2) { }), testing::_, testing::_)) .WillOnce(testing::Invoke([](const doris::segment_v2::RangeSearchParams& params, - const doris::segment_v2::CustomSearchParams& custom_params, + const doris::VectorSearchUserParams& custom_params, doris::segment_v2::RangeSearchResult* result) { size_t num_results = 10; result->roaring = std::make_shared<roaring::Roaring>(); diff --git a/be/test/olap/vector_search/ann_topn_descriptor_test.cpp b/be/test/olap/vector_search/ann_topn_descriptor_test.cpp index 7437bd80411..e6bc46c5d97 100644 --- a/be/test/olap/vector_search/ann_topn_descriptor_test.cpp +++ b/be/test/olap/vector_search/ann_topn_descriptor_test.cpp @@ -65,8 +65,8 @@ TEST_F(VectorSearchTest, AnnTopNDescriptorConstructor) { v->set_virtual_column_expr(distanc_calcu_fn_call_ctx->root()); std::shared_ptr<AnnTopNDescriptor> predicate; - predicate = AnnTopNDescriptor::create_shared(limit, virtual_slot_expr_ctx); - ASSERT_TRUE(predicate != nullptr) << "AnnTopNDescriptor::create_shared() failed"; + predicate = AnnTopNDescriptor::create_shared(true, limit, virtual_slot_expr_ctx); + ASSERT_TRUE(predicate != nullptr) << "AnnTopNDescriptor::create_shared(true,) failed"; } TEST_F(VectorSearchTest, AnnTopNDescriptorPrepare) { @@ -86,7 +86,7 @@ TEST_F(VectorSearchTest, AnnTopNDescriptorPrepare) { v->set_virtual_column_expr(distanc_calcu_fn_call_ctx->root()); std::shared_ptr<AnnTopNDescriptor> predicate; - predicate = AnnTopNDescriptor::create_shared(limit, virtual_slot_expr_ctx); + predicate = AnnTopNDescriptor::create_shared(true, limit, virtual_slot_expr_ctx); st = predicate->prepare(&_runtime_state, _row_desc); ASSERT_TRUE(st.ok()) << fmt::format("st: {}, expr {}", st.to_string(), 
predicate->get_order_by_expr_ctx()->root()->debug_string()); @@ -111,7 +111,7 @@ TEST_F(VectorSearchTest, AnnTopNDescriptorEvaluateTopN) { v->set_virtual_column_expr(distanc_calcu_fn_call_ctx->root()); std::shared_ptr<AnnTopNDescriptor> predicate; - predicate = AnnTopNDescriptor::create_shared(limit, virtual_slot_expr_ctx); + predicate = AnnTopNDescriptor::create_shared(true, limit, virtual_slot_expr_ctx); st = predicate->prepare(&_runtime_state, _row_desc); ASSERT_TRUE(st.ok()) << fmt::format("st: {}, expr {}", st.to_string(), predicate->get_order_by_expr_ctx()->root()->debug_string()); diff --git a/be/test/olap/vector_search/faiss_vector_index_test.cpp b/be/test/olap/vector_search/faiss_vector_index_test.cpp index bae49627ef2..684c3bf6cf4 100644 --- a/be/test/olap/vector_search/faiss_vector_index_test.cpp +++ b/be/test/olap/vector_search/faiss_vector_index_test.cpp @@ -300,13 +300,13 @@ TEST_F(VectorSearchTest, CompRangeSearch) { size_t random_n = std::uniform_int_distribution<>(500, 2000)(gen); // Random number of vectors // Step 1: Create and build index - auto index1 = std::make_unique<FaissVectorIndex>(); + auto doris_index = std::make_unique<FaissVectorIndex>(); FaissBuildParameter params; params.d = random_d; params.m = random_m; params.index_type = FaissBuildParameter::IndexType::HNSW; - index1->set_build_params(params); + doris_index->set_build_params(params); const int num_vectors = random_n; std::vector<std::vector<float>> vectors; @@ -316,31 +316,20 @@ TEST_F(VectorSearchTest, CompRangeSearch) { } std::unique_ptr<faiss::Index> native_index = std::make_unique<faiss::IndexHNSWFlat>(params.d, params.m); - doris::vector_search_utils::add_vectors_to_indexes_serial_mode(index1.get(), + doris::vector_search_utils::add_vectors_to_indexes_serial_mode(doris_index.get(), native_index.get(), vectors); std::vector<float> query_vec = vectors.front(); - - std::vector<std::pair<size_t, float>> distances(num_vectors); - for (int i = 0; i < num_vectors; i++) { - double 
sum = 0; - auto& vec = vectors[i]; - for (int j = 0; j < params.d; j++) { - accumulate(vec[j], query_vec[j], sum); - } - distances[i] = std::make_pair(i, finalize(sum)); - } - std::sort(distances.begin(), distances.end(), - [](const auto& a, const auto& b) { return a.second < b.second; }); - - float radius = distances[num_vectors / 4].second; + const float radius = doris::vector_search_utils::get_radius_from_matrix( + query_vec.data(), params.d, vectors, 0.4f); HNSWSearchParameters hnsw_params; hnsw_params.ef_search = 16; // Set efSearch for better accuracy hnsw_params.roaring = nullptr; // No selector for this test hnsw_params.is_le_or_lt = true; IndexSearchResult doris_result; - std::ignore = index1->range_search(query_vec.data(), radius, hnsw_params, doris_result); + std::ignore = + doris_index->range_search(query_vec.data(), radius, hnsw_params, doris_result); faiss::SearchParametersHNSW search_params_native; search_params_native.efSearch = hnsw_params.ef_search; @@ -575,6 +564,7 @@ TEST_F(VectorSearchTest, RangeSearchWithSelector1) { true); ASSERT_EQ(native_results.size(), doris_search_result.roaring->cardinality()); + ASSERT_EQ(doris_search_result.distances != nullptr, true); for (size_t i = 0; i < native_results.size(); i++) { const size_t rowid = native_results[i].first; diff --git a/be/test/olap/vector_search/vector_search_utils.cpp b/be/test/olap/vector_search/vector_search_utils.cpp index 6ee03666bbf..49c60296a8e 100644 --- a/be/test/olap/vector_search/vector_search_utils.cpp +++ b/be/test/olap/vector_search/vector_search_utils.cpp @@ -29,6 +29,7 @@ namespace doris::vector_search_utils { static void accumulate(double x, double y, double& sum) { sum += (x - y) * (x - y); } + static double finalize(double sum) { return sqrt(sum); } @@ -246,4 +247,14 @@ float get_radius_from_matrix(const float* vector, int dim, return radius; } + +std::pair<std::unique_ptr<MockTabletIndex>, std::shared_ptr<segment_v2::AnnIndexReader>> 
+create_tmp_ann_index_reader(std::map<std::string, std::string> properties) { + auto mock_tablet_index = std::make_unique<MockTabletIndex>(); + mock_tablet_index->_properties = properties; + auto mock_index_file_reader = std::make_shared<MockIndexFileReader>(); + auto ann_reader = std::make_shared<segment_v2::AnnIndexReader>(mock_tablet_index.get(), + mock_index_file_reader); + return std::make_pair(std::move(mock_tablet_index), ann_reader); +} } // namespace doris::vector_search_utils \ No newline at end of file diff --git a/be/test/olap/vector_search/vector_search_utils.h b/be/test/olap/vector_search/vector_search_utils.h index 8fd79997819..bd4b02ad0a7 100644 --- a/be/test/olap/vector_search/vector_search_utils.h +++ b/be/test/olap/vector_search/vector_search_utils.h @@ -27,8 +27,9 @@ #include <thrift/protocol/TDebugProtocol.h> #include <thrift/protocol/TJSONProtocol.h> -#include <iostream> #include <memory> +#include <string> +#include <utility> #include "common/object_pool.h" #include "olap/rowset/segment_v2/ann_index_iterator.h" @@ -150,42 +151,22 @@ public: MOCK_METHOD(Status, read_from_index, (const doris::segment_v2::IndexParam& param), (override)); MOCK_METHOD(Status, range_search, (const segment_v2::RangeSearchParams& params, - const segment_v2::CustomSearchParams& custom_params, + const VectorSearchUserParams& custom_params, segment_v2::RangeSearchResult* result), (override)); private: io::IOContext _io_ctx_mock; }; + +class MockAnnIndexReader : public doris::segment_v2::AnnIndexReader {}; + +std::pair<std::unique_ptr<MockTabletIndex>, std::shared_ptr<segment_v2::AnnIndexReader>> +create_tmp_ann_index_reader(std::map<std::string, std::string> properties); + } // namespace doris::vector_search_utils namespace doris::vectorized { -template <typename T> -T read_from_json(const std::string& json_str) { - auto memBufferIn = std::make_shared<apache::thrift::transport::TMemoryBuffer>( - reinterpret_cast<uint8_t*>(const_cast<char*>(json_str.data())), - 
static_cast<uint32_t>(json_str.size())); - auto jsonProtocolIn = std::make_shared<apache::thrift::protocol::TJSONProtocol>(memBufferIn); - T params; - params.read(jsonProtocolIn.get()); - return params; -} - -template <typename T> -void write_to_json(const std::string& path, std::string name, const T& expr) { - auto memBuffer = std::make_shared<apache::thrift::transport::TMemoryBuffer>(); - auto jsonProtocol = std::make_shared<apache::thrift::protocol::TJSONProtocol>(memBuffer); - - expr.write(jsonProtocol.get()); - uint8_t* buf; - uint32_t size; - memBuffer->getBuffer(&buf, &size); - std::string file_path = fmt::format("{}/{}.json", path, name); - std::ofstream ofs(file_path, std::ios::binary); - ofs.write(reinterpret_cast<const char*>(buf), size); - ofs.close(); - std::cout << fmt::format("Serialized JSON written to {}\n", file_path); -} class VectorSearchTest : public ::testing::Test { public: diff --git a/be/test/olap/vector_search/virtual_column_iterator_test.cpp b/be/test/olap/vector_search/virtual_column_iterator_test.cpp index 22f33db3bed..d860a7a7ad8 100644 --- a/be/test/olap/vector_search/virtual_column_iterator_test.cpp +++ b/be/test/olap/vector_search/virtual_column_iterator_test.cpp @@ -45,7 +45,7 @@ TEST_F(VectorSearchTest, TestDefaultConstructor) { } // Test with a materialized int32_t column -TEST_F(VectorSearchTest, TestWithint32_tColumn) { +TEST_F(VectorSearchTest, ReadByRowIdsint32_tColumn) { VirtualColumnIterator iterator; // Create a materialized int32_t column with values [10, 20, 30, 40, 50] @@ -78,7 +78,7 @@ TEST_F(VectorSearchTest, TestWithint32_tColumn) { } // Test with a String column -TEST_F(VectorSearchTest, TestWithStringColumn) { +TEST_F(VectorSearchTest, ReadByRowIdsStringColumn) { VirtualColumnIterator iterator; // Create a materialized String column @@ -114,7 +114,7 @@ TEST_F(VectorSearchTest, TestWithStringColumn) { } // Test with empty rowids array -TEST_F(VectorSearchTest, TestEmptyRowIds) { +TEST_F(VectorSearchTest, 
ReadByRowIdsEmptyRowIds) { VirtualColumnIterator iterator; // Create a materialized int32_t column with values [10, 20, 30, 40, 50] @@ -180,7 +180,7 @@ TEST_F(VectorSearchTest, TestLargeRowset) { } } -TEST_F(VectorSearchTest, TestNoContinueRowIds) { +TEST_F(VectorSearchTest, ReadByRowIdsNoContinueRowIds) { // Create a column with 1000 values (0-999) auto column = ColumnVector<int32_t>::create(); auto labels = std::make_unique<std::vector<uint64_t>>(); @@ -276,4 +276,72 @@ TEST_F(VectorSearchTest, TestNoContinueRowIds) { } } +TEST_F(VectorSearchTest, NextBatchTest1) { + VirtualColumnIterator iterator; + + // 构造一个有100行的int32列,值为0~99 + auto int_column = vectorized::ColumnVector<int32_t>::create(); + auto labels = std::make_unique<std::vector<uint64_t>>(); + for (int i = 0; i < 100; ++i) { + int_column->insert(i); + labels->push_back(i); + } + iterator.prepare_materialization(std::move(int_column), std::move(labels)); + + // 1. seek到第10行,next_batch读取10行 + { + vectorized::MutableColumnPtr dst = vectorized::ColumnVector<int32_t>::create(); + Status st = iterator.seek_to_ordinal(10); + ASSERT_TRUE(st.ok()); + size_t rows_read = 10; + bool has_null = false; + st = iterator.next_batch(&rows_read, dst, &has_null); + ASSERT_TRUE(st.ok()); + ASSERT_EQ(rows_read, 10); + ASSERT_EQ(dst->size(), 10); + for (int i = 0; i < 10; ++i) { + ASSERT_EQ(dst->get_int(i), 10 + i); + } + } + + // 2. seek到第85行,next_batch读取10行(只剩5行可读) + { + vectorized::MutableColumnPtr dst = vectorized::ColumnVector<int32_t>::create(); + Status st = iterator.seek_to_ordinal(85); + ASSERT_TRUE(st.ok()); + size_t rows_read = 10; + bool has_null = false; + st = iterator.next_batch(&rows_read, dst, &has_null); + ASSERT_TRUE(st.ok()); + ASSERT_EQ(rows_read, 10); + ASSERT_EQ(dst->size(), 10); + for (int i = 0; i < 10; ++i) { + EXPECT_EQ(dst->get_int(i), 85 + i); + } + } + + // 3. 
seek到第0行,next_batch读取全部100行 + { + vectorized::MutableColumnPtr dst = vectorized::ColumnVector<int32_t>::create(); + Status st = iterator.seek_to_ordinal(0); + ASSERT_TRUE(st.ok()); + size_t rows_read = 100; + bool has_null = false; + st = iterator.next_batch(&rows_read, dst, &has_null); + ASSERT_TRUE(st.ok()); + ASSERT_EQ(rows_read, 100); + ASSERT_EQ(dst->size(), 100); + for (int i = 0; i < 100; ++i) { + ASSERT_EQ(dst->get_int(i), i); + } + } + + // 4. seek到越界位置(如100),应该报错 + { + vectorized::MutableColumnPtr dst = vectorized::ColumnVector<int32_t>::create(); + Status st = iterator.seek_to_ordinal(100); + ASSERT_EQ(st.ok(), false); + } +} + } // namespace doris::vectorized \ No newline at end of file diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVirtualColumnsIntoOlapScan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVirtualColumnsIntoOlapScan.java index 536f3e79f35..1bfd71e920a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVirtualColumnsIntoOlapScan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVirtualColumnsIntoOlapScan.java @@ -69,23 +69,24 @@ public class PushDownVirtualColumnsIntoOlapScan implements RewriteRuleFactory { // 3. replace filter // 4. 
replace project Map<Expression, Expression> replaceMap = Maps.newHashMap(); + ImmutableList.Builder<NamedExpression> virtualColumnsBuilder = ImmutableList.builder(); for (Expression conjunct : filter.getConjuncts()) { - Set<Expression> l2Distances = conjunct.collect(L2Distance.class::isInstance); - for (Expression l2Distance : l2Distances) { - if (replaceMap.containsKey(l2Distance)) { + // Set<Expression> l2Distances = conjunct.collect(L2Distance.class::isInstance); + // Set<Expression> innerProducts = conjunct.collect(InnerProduct.class::isInstance); + Set<Expression> distanceFunctions = conjunct.collect( + e -> e instanceof L2Distance || e instanceof InnerProduct); + for (Expression distanceFunction : distanceFunctions) { + if (replaceMap.containsKey(distanceFunction)) { continue; } - Alias alias = new Alias(l2Distance); - replaceMap.put(l2Distance, alias.toSlot()); + Alias alias = new Alias(distanceFunction); + replaceMap.put(distanceFunction, alias.toSlot()); + virtualColumnsBuilder.add(alias); } } if (replaceMap.isEmpty()) { return null; } - ImmutableList.Builder<NamedExpression> virtualColumnsBuilder = ImmutableList.builder(); - for (Expression expression : replaceMap.values()) { - virtualColumnsBuilder.add((NamedExpression) expression); - } logicalOlapScan = logicalOlapScan.withVirtualColumns(virtualColumnsBuilder.build()); Set<Expression> conjuncts = ExpressionUtils.replace(filter.getConjuncts(), replaceMap); Plan plan = filter.withConjunctsAndChild(conjuncts, logicalOlapScan); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/IndexDefinition.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/IndexDefinition.java index bc8e533e422..1aeb0770be2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/IndexDefinition.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/commands/info/IndexDefinition.java @@ -136,9 +136,22 @@ public class 
IndexDefinition { public void checkColumn(ColumnDefinition column, KeysType keysType, boolean enableUniqueKeyMergeOnWrite, TInvertedIndexFileStorageFormat invertedIndexFileStorageFormat) throws AnalysisException { + if (indexType == IndexType.ANN) { + String indexColName = column.getName(); + caseSensitivityCols.add(indexColName); + DataType colType = column.getType(); + if (!colType.isArrayType()) { + throw new AnalysisException("ANN index column must be array type, invalid index: " + name); + } + DataType itemType = ((ArrayType) colType).getItemType(); + if (!itemType.isFloatType()) { + throw new AnalysisException("ANN index column item type must be float type, invalid index: " + name); + } + return; + } + if (indexType == IndexType.BITMAP || indexType == IndexType.INVERTED - || indexType == IndexType.BLOOMFILTER || indexType == IndexType.NGRAM_BF - || indexType == IndexType.ANN) { + || indexType == IndexType.BLOOMFILTER || indexType == IndexType.NGRAM_BF) { String indexColName = column.getName(); caseSensitivityCols.add(indexColName); DataType colType = column.getType(); @@ -148,10 +161,6 @@ public class IndexDefinition { + " index. " + "invalid index: " + name); } - if (indexType == IndexType.ANN && !colType.isArrayType()) { - throw new AnalysisException("Ann index column must be array type, invalid index: " + name); - } - // In inverted index format v1, each subcolumn of a variant has its own index file, leading to high IOPS. // when the subcolumn type changes, it may result in missing files, causing link file failure. 
// There are two cases in which the inverted index format v1 is not supported: diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java index 69a453a3b16..ad1aee3a04c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java @@ -739,6 +739,10 @@ public class SessionVariable implements Serializable, Writable { public static final String SQL_CONVERTOR_CONFIG = "sql_convertor_config"; + public static final String HNSW_EF_SEARCH = "hnsw_ef_search"; + public static final String HNSW_CHECK_RELATIVE_DISTANCE = "hnsw_check_relative_distance"; + public static final String HNSW_BOUNDED_QUEUE = "hnsw_bounded_queue"; + /** * If set false, user couldn't submit analyze SQL and FE won't allocate any related resources. */ @@ -2611,6 +2615,22 @@ public class SessionVariable implements Serializable, Writable { return enableESParallelScroll; } + @VariableMgr.VarAttr(name = HNSW_EF_SEARCH, needForward = true, + description = {"HNSW索引的EF搜索参数,控制搜索的精度和速度", + "HNSW index EF search parameter, controls the precision and speed of the search"}) + public int hnswEFSearch = 16; + + @VariableMgr.VarAttr(name = HNSW_CHECK_RELATIVE_DISTANCE, needForward = true, + description = {"是否启用相对距离检查机制,以提升HNSW搜索的准确性", + "Enable relative distance checking to improve HNSW search accuracy"}) + public boolean hnswCheckRelativeDistance = true; + + @VariableMgr.VarAttr(name = HNSW_BOUNDED_QUEUE, needForward = true, + description = {"是否使用有界优先队列来优化HNSW的搜索性能", + "Whether to use a bounded priority queue to optimize HNSW search performance"}) + public boolean hnswBoundedQueue = true; + + // If this fe is in fuzzy mode, then will use initFuzzyModeVariables to generate some variables, // not the default value set in the code. 
@SuppressWarnings("checkstyle:Indentation") @@ -4218,6 +4238,11 @@ public class SessionVariable implements Serializable, Writable { tResult.setMinimumOperatorMemoryRequiredKb(minimumOperatorMemoryRequiredKB); tResult.setExchangeMultiBlocksByteSize(exchangeMultiBlocksByteSize); + + tResult.setHnswEfSearch(hnswEFSearch); + tResult.setHnswCheckRelativeDistance(hnswCheckRelativeDistance); + tResult.setHnswBoundedQueue(hnswBoundedQueue); + return tResult; } diff --git a/gensrc/thrift/PaloInternalService.thrift b/gensrc/thrift/PaloInternalService.thrift index 81e4d1f877c..73c268965e8 100644 --- a/gensrc/thrift/PaloInternalService.thrift +++ b/gensrc/thrift/PaloInternalService.thrift @@ -395,6 +395,10 @@ struct TQueryOptions { 164: optional bool check_orc_init_sargs_success = false 165: optional i32 exchange_multi_blocks_byte_size = 262144 + 166: optional i32 hnsw_ef_search = 16; + 167: optional bool hnsw_check_relative_distance = true; + 168: optional bool hnsw_bounded_queue = true; + // For cloud, to control if the content would be written into file cache // In write path, to control if the content would be written into file cache. // In read path, read from file cache or remote storage when execute query. --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org