This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch vector-index-dev in repository https://gitbox.apache.org/repos/asf/doris.git
commit b012878e0ab7719393469de518df054dcb1248a3 Author: hezhiqiang <hezhiqi...@selectdb.com> AuthorDate: Thu Jun 12 15:10:30 2025 +0800 Fix rebase master compile --- be/src/olap/rowset/segment_v2/segment_iterator.cpp | 2 +- be/src/pipeline/exec/operator.h | 7 +- be/src/pipeline/pipeline_fragment_context.cpp | 3 +- be/src/runtime/exec_env.h | 2 +- be/src/vec/exprs/virtual_slot_ref.cpp | 4 +- be/src/vector/faiss_vector_index.cpp | 59 +++++-- be/src/vector/metric.cpp | 4 +- be/src/vector/metric.h | 2 +- be/src/vector/vector_index.h | 1 + .../olap/vector_search/ann_index_reader_test.cpp | 7 +- .../olap/vector_search/ann_index_smoke_test.cpp | 5 +- .../olap/vector_search/ann_range_search_test.cpp | 69 ++++---- .../vector_search/ann_topn_descriptor_test.cpp | 3 +- .../olap/vector_search/faiss_vector_index_test.cpp | 177 +++++++++++++-------- be/test/olap/vector_search/vector_search_utils.cpp | 41 +++-- be/test/olap/vector_search/vector_search_utils.h | 12 +- .../vector_search/virtual_column_iterator_test.cpp | 74 ++++----- .../org/apache/doris/analysis/DescriptorTable.java | 5 + .../doris/nereids/rules/analysis/BindRelation.java | 10 +- .../rewrite/PushDownVectorTopNIntoOlapScan.java | 6 +- .../trees/plans/logical/LogicalFileScan.java | 50 +++--- .../trees/plans/logical/LogicalHudiScan.java | 42 +++-- .../physical/PhysicalLazyMaterializeOlapScan.java | 6 +- 23 files changed, 342 insertions(+), 249 deletions(-) diff --git a/be/src/olap/rowset/segment_v2/segment_iterator.cpp b/be/src/olap/rowset/segment_v2/segment_iterator.cpp index 36866f9704d..dc9187eda33 100644 --- a/be/src/olap/rowset/segment_v2/segment_iterator.cpp +++ b/be/src/olap/rowset/segment_v2/segment_iterator.cpp @@ -642,7 +642,7 @@ Status SegmentIterator::_apply_ann_topn_predicate() { auto index_reader = ann_index_iterator->get_reader(); auto ann_index_reader = dynamic_cast<AnnIndexReader*>(index_reader.get()); DCHECK(ann_index_reader != nullptr); - if (ann_index_reader->get_metric_type() == Metric::INNER_PRODUCT) { + if (ann_index_reader->get_metric_type() == Metric::IP) { if (_ann_topn_runtime->is_asc()) { LOG_INFO("Asc topn for inner product can not be evaluated by ann index"); return Status::OK(); diff --git a/be/src/pipeline/exec/operator.h b/be/src/pipeline/exec/operator.h index 3de1be0564d..559226d21a6 100644 --- a/be/src/pipeline/exec/operator.h +++ b/be/src/pipeline/exec/operator.h @@ -792,11 +792,16 @@ public: _tuple_ids(tnode.row_tuples), _row_descriptor(descs, tnode.row_tuples, tnode.nullable_tuples), _resource_profile(tnode.resource_profile), - _limit(tnode.limit) { + _limit(tnode.limit) { if (tnode.__isset.output_tuple_id) { + LOG_INFO("Operator {}, node_id {}, output_tuple_id {}", + this->_op_name, tnode.node_id, tnode.output_tuple_id); _output_row_descriptor.reset(new RowDescriptor(descs, {tnode.output_tuple_id}, {true})); } if (!tnode.intermediate_output_tuple_id_list.empty()) { + LOG_INFO("Operator {}, node_id {}, intermediate_output_tuple_id_list: [{}]", + this->_op_name, tnode.node_id, + fmt::join(tnode.intermediate_output_tuple_id_list, ",")); // common subexpression elimination _intermediate_output_row_descriptor.reserve( tnode.intermediate_output_tuple_id_list.size()); diff --git a/be/src/pipeline/pipeline_fragment_context.cpp b/be/src/pipeline/pipeline_fragment_context.cpp index a86bf58ba06..c31f5ddde78 100644 --- a/be/src/pipeline/pipeline_fragment_context.cpp +++ b/be/src/pipeline/pipeline_fragment_context.cpp @@ -107,7 +107,6 @@ #include "runtime/fragment_mgr.h" #include "runtime/runtime_state.h" #include "runtime/stream_load/new_load_stream_mgr.h" -#include "runtime/stream_load/stream_load_context.h" #include "runtime/thread_context.h" #include "runtime_filter/runtime_filter_mgr.h" #include "service/backend_options.h" @@ -118,6 +117,7 @@ #include "vec/common/sort/topn_sorter.h" #include "vec/runtime/vdata_stream_mgr.h" #include "vec/spill/spill_stream.h" +#include "thrift/protocol/TDebugProtocol.h" namespace doris::pipeline { #include "common/compile_check_begin.h" @@ -300,6 +300,7 @@ Status PipelineFragmentContext::prepare(const doris::TPipelineFragmentParams& re DCHECK(request.__isset.desc_tbl); RETURN_IF_ERROR(DescriptorTbl::create(_runtime_state->obj_pool(), request.desc_tbl, &_desc_tbl)); + LOG_INFO("Input desc_tbl: {}", apache::thrift::ThriftDebugString(request.desc_tbl)); } _runtime_state->set_desc_tbl(_desc_tbl); _runtime_state->set_num_per_fragment_instances(request.num_senders); diff --git a/be/src/runtime/exec_env.h b/be/src/runtime/exec_env.h index a6c0924d652..2949e508820 100644 --- a/be/src/runtime/exec_env.h +++ b/be/src/runtime/exec_env.h @@ -31,6 +31,7 @@ #include "olap/memtable_memory_limiter.h" #include "olap/options.h" #include "olap/rowset/segment_v2/index_writer.h" +#include "olap/rowset/segment_v2/tmp_file_dirs.h" #include "olap/tablet_fwd.h" #include "pipeline/pipeline_tracing.h" #include "runtime/cluster_info.h" @@ -65,7 +66,6 @@ class HdfsMgr; namespace segment_v2 { class InvertedIndexSearcherCache; class InvertedIndexQueryCache; -class TmpFileDirs; } // namespace segment_v2 namespace kerberos { diff --git a/be/src/vec/exprs/virtual_slot_ref.cpp b/be/src/vec/exprs/virtual_slot_ref.cpp index 6d320b2b96f..1fb6b23df4f 100644 --- a/be/src/vec/exprs/virtual_slot_ref.cpp +++ b/be/src/vec/exprs/virtual_slot_ref.cpp @@ -86,8 +86,8 @@ Status VirtualSlotRef::prepare(doris::RuntimeState* state, const doris::RowDescr _column_id = desc.get_column_id(_slot_id, context->force_materialize_slot()); if (_column_id < 0) { return Status::Error<ErrorCode::INTERNAL_ERROR>( - "VirtualSlotRef {} has invalid slot id: {}, desc: {}, slot_desc: {}, desc_tbl: {}", - *_column_name, _slot_id, desc.debug_string(), slot_desc->debug_string(), + "VirtualSlotRef {} has invalid slot id: {}.\nslot_desc:\n{},\ndesc:\n{},\ndesc_tbl:\n{}", + *_column_name, _slot_id, slot_desc->debug_string(), desc.debug_string(), state->desc_tbl().debug_string()); } const TExpr& expr = *slot_desc->get_virtual_column_expr(); diff --git a/be/src/vector/faiss_vector_index.cpp b/be/src/vector/faiss_vector_index.cpp index d9eec3d78d1..7cb8b49faef 100644 --- a/be/src/vector/faiss_vector_index.cpp +++ b/be/src/vector/faiss_vector_index.cpp @@ -129,6 +129,18 @@ doris::Status FaissVectorIndex::add(int n, const float* vec) { void FaissVectorIndex::set_build_params(const FaissBuildParameter& params) { _dimension = params.d; + switch (params.metric_type) { + case FaissBuildParameter::MetricType::L2: + _metric = Metric::L2; + break; + case FaissBuildParameter::MetricType::IP: + _metric = Metric::IP; + break; + default: + throw doris::Exception(doris::ErrorCode::INVALID_ARGUMENT, "Unsupported metric type: {}", + static_cast<int>(params.metric_type)); + break; + } if (params.index_type == FaissBuildParameter::IndexType::BruteForce) { if (params.metric_type == FaissBuildParameter::MetricType::L2) { _index = std::make_unique<faiss::IndexFlatL2>(params.d); @@ -217,10 +229,25 @@ doris::Status FaissVectorIndex::ann_topn_search(const float* query_vec, int k, size_t roaring_cardinality = result.roaring->cardinality(); result.distances = std::make_unique<float[]>(roaring_cardinality); result.row_ids = std::make_unique<std::vector<uint64_t>>(); - - for (size_t i = 0; i < roaring_cardinality; ++i) { - result.row_ids->push_back(labels[i]); - result.distances[i] = std::sqrt(distances[i]); // Convert squared distance to actual distance + + if (_metric == Metric::L2) { + // For inner product, we need to convert the distance to the actual distance. + // The distance returned by Faiss is actually the squared distance. + // So we need to take the square root of the squared distance. + for (size_t i = 0; i < roaring_cardinality; ++i) { + result.row_ids->push_back(labels[i]); + result.distances[i] = distances[i]; // Convert squared distance to actual distance + } + } else if (_metric == Metric::IP) { + // For L2, we can use the distance directly. + for (size_t i = 0; i < roaring_cardinality; ++i) { + result.row_ids->push_back(labels[i]); + result.distances[i] = + std::sqrt(distances[i]); // Convert squared distance to actual distance + } + } else { + throw doris::Exception(doris::ErrorCode::INVALID_ARGUMENT, "Unsupported metric type: {}", + static_cast<int>(_metric)); } DCHECK(result.row_ids->size() == result.roaring->cardinality()) @@ -263,11 +290,25 @@ doris::Status FaissVectorIndex::range_search(const float* query_vec, const float std::unique_ptr<float[]> distances_ptr = std::make_unique<float[]>(end - begin); float* distances = distances_ptr.get(); auto roaring = std::make_shared<roaring::Roaring>(); - for (size_t i = begin; i < end; ++i) { - (*row_ids)[i] = native_search_result.labels[i]; - roaring->add(native_search_result.labels[i]); - // TODO: l2_distance and inner_product is different. - distances[i] = sqrt(native_search_result.distances[i]); + if (_metric == Metric::L2) { + // For inner product, we need to convert the distance to the actual distance. + // The distance returned by Faiss is actually the squared distance. + // So we need to take the square root of the squared distance. + for (size_t i = begin; i < end; ++i) { + (*row_ids)[i] = native_search_result.labels[i]; + roaring->add(native_search_result.labels[i]); + distances[i - begin] = sqrt(native_search_result.distances[i]); + } + } else if (_metric == Metric::IP) { + // For L2, we can use the distance directly. + for (size_t i = begin; i < end; ++i) { + (*row_ids)[i] = native_search_result.labels[i]; + roaring->add(native_search_result.labels[i]); + distances[i - begin] = native_search_result.distances[i]; + } + } else { + throw doris::Exception(doris::ErrorCode::INVALID_ARGUMENT, + "Unsupported metric type: {}", static_cast<int>(_metric)); } result.distances = std::move(distances_ptr); diff --git a/be/src/vector/metric.cpp b/be/src/vector/metric.cpp index 9d6a415111e..6f54f13db0c 100644 --- a/be/src/vector/metric.cpp +++ b/be/src/vector/metric.cpp @@ -27,7 +27,7 @@ std::string metric_to_string(Metric metric) { switch (metric) { case Metric::L2: return vectorized::L2Distance::name; - case Metric::INNER_PRODUCT: + case Metric::IP: return vectorized::InnerProduct::name; default: return "UNKNOWN"; @@ -38,7 +38,7 @@ Metric string_to_metric(const std::string& metric) { if (metric == vectorized::L2Distance::name) { return Metric::L2; } else if (metric == vectorized::InnerProduct::name) { - return Metric::INNER_PRODUCT; + return Metric::IP; } else { return Metric::UNKNOWN; } diff --git a/be/src/vector/metric.h b/be/src/vector/metric.h index b6c95bed675..71998aa1ecc 100644 --- a/be/src/vector/metric.h +++ b/be/src/vector/metric.h @@ -20,7 +20,7 @@ #include <string> namespace doris::segment_v2 { -enum class Metric { L2, INNER_PRODUCT, UNKNOWN }; +enum class Metric { L2, IP, UNKNOWN }; std::string metric_to_string(Metric metric); diff --git a/be/src/vector/vector_index.h b/be/src/vector/vector_index.h index 00e0904fdeb..63f9efa97ad 100644 --- a/be/src/vector/vector_index.h +++ b/be/src/vector/vector_index.h @@ -79,6 +79,7 @@ public: protected: // When adding vectors to the index, use this variable to check the dimension of the vectors. size_t _dimension = 0; + Metric _metric = Metric::L2; // Default metric is L2 distance }; } // namespace doris::segment_v2 \ No newline at end of file diff --git a/be/test/olap/vector_search/ann_index_reader_test.cpp b/be/test/olap/vector_search/ann_index_reader_test.cpp index 5b29f21e90c..e0fe926cf65 100644 --- a/be/test/olap/vector_search/ann_index_reader_test.cpp +++ b/be/test/olap/vector_search/ann_index_reader_test.cpp @@ -24,9 +24,8 @@ #include <string> #include "faiss_vector_index.h" -#include "olap/rowset/segment_v2/ann_index_iterator.h" +#include "olap/rowset/segment_v2/ann_index/ann_search_params.h" #include "olap/tablet_schema.h" -#include "runtime/runtime_state.h" #include "vector_search_utils.h" using namespace doris::vector_search_utils; @@ -65,13 +64,13 @@ TEST_F(VectorSearchTest, AnnIndexReaderRangeSearch) { roaring->add(i); } - doris::segment_v2::RangeSearchParams params; + doris::vectorized::RangeSearchParams params; params.radius = radius; params.query_value = query_value.data(); params.roaring = roaring.get(); doris::VectorSearchUserParams custom_params; custom_params.hnsw_ef_search = 16; - doris::segment_v2::RangeSearchResult result; + doris::vectorized::RangeSearchResult result; auto doris_faiss_vector_index = std::make_unique<doris::segment_v2::FaissVectorIndex>(); std::ignore = doris_faiss_vector_index->load(this->_ram_dir.get()); ann_index_reader->_vector_index = std::move(doris_faiss_vector_index); diff --git a/be/test/olap/vector_search/ann_index_smoke_test.cpp b/be/test/olap/vector_search/ann_index_smoke_test.cpp index b106a9bc15e..9f2593956f6 100644 --- a/be/test/olap/vector_search/ann_index_smoke_test.cpp +++ b/be/test/olap/vector_search/ann_index_smoke_test.cpp @@ -26,6 +26,7 @@ #include "faiss_vector_index.h" #include "olap/olap_common.h" +#include "olap/rowset/segment_v2/ann_index/ann_search_params.h" #include "olap/rowset/segment_v2/ann_index_writer.h" #include "olap/rowset/segment_v2/index_file_writer.h" #include "vector_index.h" @@ -136,8 +137,8 @@ TEST_F(AnnIndexTest, SmokeTest) { query_vec[i] = static_cast<float>(i); } - segment_v2::IndexSearchParameters params; - segment_v2::IndexSearchResult result; + vectorized::IndexSearchParameters params; + vectorized::IndexSearchResult result; ASSERT_TRUE(index2->ann_topn_search(query_vec.get(), 1, params, result)); EXPECT_TRUE(result.roaring->cardinality() == 1); EXPECT_TRUE(result.roaring->contains(0)); diff --git a/be/test/olap/vector_search/ann_range_search_test.cpp b/be/test/olap/vector_search/ann_range_search_test.cpp index f73d1df2606..da3b992cba8 100644 --- a/be/test/olap/vector_search/ann_range_search_test.cpp +++ b/be/test/olap/vector_search/ann_range_search_test.cpp @@ -27,18 +27,16 @@ #include <vector> #include "common/object_pool.h" +#include "olap/rowset/segment_v2/ann_index/ann_search_params.h" #include "olap/rowset/segment_v2/ann_index_iterator.h" #include "olap/rowset/segment_v2/ann_index_reader.h" #include "olap/rowset/segment_v2/column_reader.h" -#include "olap/rowset/segment_v2/index_file_reader.h" #include "olap/rowset/segment_v2/virtual_column_iterator.h" #include "olap/vector_search/vector_search_utils.h" #include "runtime/descriptors.h" #include "runtime/runtime_state.h" #include "vec/columns/column.h" #include "vec/columns/column_nothing.h" -#include "vec/columns/columns_number.h" -#include "vec/exprs/vectorized_fn_call.h" #include "vec/exprs/vexpr_fwd.h" #include "vec/functions/functions_comparison.h" @@ -833,22 +831,19 @@ TEST_F(VectorSearchTest, TestPrepareAnnRangeSearch) { ASSERT_TRUE(range_search_ctx->prepare(state.get(), row_desc).ok()); ASSERT_TRUE(range_search_ctx->open(state.get()).ok()); ASSERT_TRUE(range_search_ctx->prepare_ann_range_search(user_params).ok()); - std::shared_ptr<VectorizedFnCall> fn_call = - std::dynamic_pointer_cast<VectorizedFnCall>(range_search_ctx->root()); + ASSERT_TRUE(range_search_ctx->_ann_range_search_runtime.is_ann_range_search == true); + ASSERT_EQ(range_search_ctx->_ann_range_search_runtime.is_le_or_lt, false); + ASSERT_EQ(range_search_ctx->_ann_range_search_runtime.dst_col_idx, 3); + ASSERT_EQ(range_search_ctx->_ann_range_search_runtime.src_col_idx, 1); + ASSERT_EQ(range_search_ctx->_ann_range_search_runtime.radius, 10); - ASSERT_TRUE(fn_call->_ann_range_search_params.is_ann_range_search == true); - ASSERT_EQ(fn_call->_ann_range_search_params.is_le_or_lt, false); - ASSERT_EQ(fn_call->_ann_range_search_params.dst_col_idx, 3); - ASSERT_EQ(fn_call->_ann_range_search_params.src_col_idx, 1); - ASSERT_EQ(fn_call->_ann_range_search_params.radius, 10); - - doris::segment_v2::RangeSearchParams range_search_params = - fn_call->_ann_range_search_params.to_range_search_params(); - EXPECT_EQ(range_search_params.radius, 10.0f); + doris::vectorized::RangeSearchParams ann_range_search_runtime = + range_search_ctx->_ann_range_search_runtime.to_range_search_params(); + EXPECT_EQ(ann_range_search_runtime.radius, 10.0f); std::vector<int> query_array_groud_truth = {1, 2, 3, 4, 5, 6, 7, 20}; std::vector<int> query_array_f32; for (int i = 0; i < query_array_groud_truth.size(); ++i) { - query_array_f32.push_back(static_cast<int>(range_search_params.query_value[i])); + query_array_f32.push_back(static_cast<int>(ann_range_search_runtime.query_value[i])); } for (int i = 0; i < query_array_f32.size(); ++i) { EXPECT_EQ(query_array_f32[i], query_array_groud_truth[i]); @@ -872,14 +867,12 @@ TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch) { ASSERT_TRUE(range_search_ctx->open(state.get()).ok()); doris::VectorSearchUserParams user_params; ASSERT_TRUE(range_search_ctx->prepare_ann_range_search(user_params).ok()); - std::shared_ptr<VectorizedFnCall> fn_call = - std::dynamic_pointer_cast<VectorizedFnCall>(range_search_ctx->root()); - ASSERT_EQ(fn_call->_ann_range_search_params.user_params, user_params); - ASSERT_TRUE(fn_call->_ann_range_search_params.is_ann_range_search == true); - ASSERT_EQ(fn_call->_ann_range_search_params.is_le_or_lt, false); - ASSERT_EQ(fn_call->_ann_range_search_params.src_col_idx, 1); - ASSERT_EQ(fn_call->_ann_range_search_params.dst_col_idx, 3); - ASSERT_EQ(fn_call->_ann_range_search_params.radius, 10); + ASSERT_EQ(range_search_ctx->_ann_range_search_runtime.user_params, user_params); + ASSERT_TRUE(range_search_ctx->_ann_range_search_runtime.is_ann_range_search == true); + ASSERT_EQ(range_search_ctx->_ann_range_search_runtime.is_le_or_lt, false); + ASSERT_EQ(range_search_ctx->_ann_range_search_runtime.src_col_idx, 1); + ASSERT_EQ(range_search_ctx->_ann_range_search_runtime.dst_col_idx, 3); + ASSERT_EQ(range_search_ctx->_ann_range_search_runtime.radius, 10); std::vector<ColumnId> idx_to_cid; idx_to_cid.resize(4); @@ -910,20 +903,20 @@ TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch) { // 1. predicate is dist >= 10, so it is not a within range search // 2. return 10 results EXPECT_CALL(*mock_ann_index_iter, - range_search(testing::Truly([](const doris::segment_v2::RangeSearchParams& params) { + range_search(testing::Truly([](const doris::vectorized::RangeSearchParams& params) { return params.is_le_or_lt == false && params.radius == 10.0f; }), testing::_, testing::_)) - .WillOnce(testing::Invoke([](const doris::segment_v2::RangeSearchParams& params, + .WillOnce(testing::Invoke([](const doris::vectorized::RangeSearchParams& params, const doris::VectorSearchUserParams& custom_params, - doris::segment_v2::RangeSearchResult* result) { + doris::vectorized::RangeSearchResult* result) { result->roaring = std::make_shared<roaring::Roaring>(); result->row_ids = nullptr; result->distance = nullptr; return Status::OK(); })); - ASSERT_TRUE(range_search_ctx->root() + ASSERT_TRUE(range_search_ctx ->evaluate_ann_range_search(cid_to_index_iterators, idx_to_cid, column_iterators, row_bitmap) .ok()); @@ -964,15 +957,11 @@ TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch2) { ASSERT_TRUE(range_search_ctx->open(state.get()).ok()); doris::VectorSearchUserParams user_params; ASSERT_TRUE(range_search_ctx->prepare_ann_range_search(user_params).ok()); - - std::shared_ptr<VectorizedFnCall> fn_call = - std::dynamic_pointer_cast<VectorizedFnCall>(range_search_ctx->root()); - - ASSERT_TRUE(fn_call->_ann_range_search_params.is_ann_range_search == true); - ASSERT_EQ(fn_call->_ann_range_search_params.is_le_or_lt, true); - ASSERT_EQ(fn_call->_ann_range_search_params.src_col_idx, 1); - ASSERT_EQ(fn_call->_ann_range_search_params.dst_col_idx, 3); - ASSERT_EQ(fn_call->_ann_range_search_params.radius, 10); + ASSERT_TRUE(range_search_ctx->_ann_range_search_runtime.is_ann_range_search == true); + ASSERT_EQ(range_search_ctx->_ann_range_search_runtime.is_le_or_lt, true); + ASSERT_EQ(range_search_ctx->_ann_range_search_runtime.src_col_idx, 1); + ASSERT_EQ(range_search_ctx->_ann_range_search_runtime.dst_col_idx, 3); + ASSERT_EQ(range_search_ctx->_ann_range_search_runtime.radius, 10); std::vector<ColumnId> idx_to_cid; idx_to_cid.resize(4); @@ -1001,13 +990,13 @@ TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch2) { // 1. predicate is dist >= 10, so it is not a within range search // 2. return 10 results EXPECT_CALL(*mock_ann_index_iter, - range_search(testing::Truly([](const doris::segment_v2::RangeSearchParams& params) { + range_search(testing::Truly([](const doris::vectorized::RangeSearchParams& params) { return params.is_le_or_lt == true && params.radius == 10.0f; }), testing::_, testing::_)) - .WillOnce(testing::Invoke([](const doris::segment_v2::RangeSearchParams& params, + .WillOnce(testing::Invoke([](const doris::vectorized::RangeSearchParams& params, const doris::VectorSearchUserParams& custom_params, - doris::segment_v2::RangeSearchResult* result) { + doris::vectorized::RangeSearchResult* result) { size_t num_results = 10; result->roaring = std::make_shared<roaring::Roaring>(); result->row_ids = std::make_unique<std::vector<uint64_t>>(); @@ -1019,7 +1008,7 @@ TEST_F(VectorSearchTest, TestEvaluateAnnRangeSearch2) { return Status::OK(); })); - ASSERT_TRUE(range_search_ctx->root() + ASSERT_TRUE(range_search_ctx ->evaluate_ann_range_search(cid_to_index_iterators, idx_to_cid, column_iterators, row_bitmap) .ok()); diff --git a/be/test/olap/vector_search/ann_topn_descriptor_test.cpp b/be/test/olap/vector_search/ann_topn_descriptor_test.cpp index a562c7d7686..376f10d7a79 100644 --- a/be/test/olap/vector_search/ann_topn_descriptor_test.cpp +++ b/be/test/olap/vector_search/ann_topn_descriptor_test.cpp @@ -26,6 +26,7 @@ #include <iostream> #include <memory> +#include "olap/rowset/segment_v2/ann_index/ann_search_params.h" #include "vec/exprs/ann_topn_runtime.h" #include "vec/exprs/virtual_slot_ref.h" #include "vector_search_utils.h" @@ -153,7 +154,7 @@ TEST_F(VectorSearchTest, AnnTopNRuntimeEvaluateTopN) { EXPECT_CALL(*_ann_index_iterator, read_from_index(testing::_)) .Times(1) .WillOnce(testing::Invoke([](const segment_v2::IndexParam& value) { - auto* ann_param = std::get<segment_v2::AnnIndexParam*>(value); + auto* ann_param = std::get<vectorized::AnnIndexParam*>(value); ann_param->distance = std::make_unique<std::vector<float>>(); ann_param->row_ids = std::make_unique<std::vector<uint64_t>>(); for (size_t i = 0; i < 10; ++i) { diff --git a/be/test/olap/vector_search/faiss_vector_index_test.cpp b/be/test/olap/vector_search/faiss_vector_index_test.cpp index 7c70e174f0a..72bc3aefb74 100644 --- a/be/test/olap/vector_search/faiss_vector_index_test.cpp +++ b/be/test/olap/vector_search/faiss_vector_index_test.cpp @@ -28,6 +28,8 @@ #include <string> #include <vector> +#include "olap/rowset/segment_v2/ann_index/ann_search_params.h" +#include "util/metrics.h" #include "vector_index.h" #include "vector_search_utils.h" @@ -170,7 +172,11 @@ TEST_F(VectorSearchTest, CompareResultWithNativeFaiss1) { std::vector<float> native_distances(top_k); std::vector<faiss::idx_t> native_indices(top_k); native_index->search(1, query_vec, top_k, native_distances.data(), native_indices.data()); - + size_t cnt = std::count_if(native_indices.begin(), native_indices.end(), + [](faiss::idx_t idx) { return idx != -1; }); + for (size_t i = 0; i < cnt; ++i) { + native_distances[i] = std::sqrt(native_distances[i]); + } // Step 4: Compare results vector_search_utils::compare_search_results(doris_results, native_distances, native_indices); @@ -220,7 +226,11 @@ TEST_F(VectorSearchTest, CompareResultWithNativeFaiss2) { std::vector<float> native_distances(top_k, -1); std::vector<faiss::idx_t> native_indices(top_k, -1); native_index->search(1, query_vec, top_k, native_distances.data(), native_indices.data()); - + size_t cnt = std::count_if(native_indices.begin(), native_indices.end(), + [](faiss::idx_t idx) { return idx != -1; }); + for (size_t i = 0; i < cnt; ++i) { + native_distances[i] = std::sqrt(native_distances[i]); + } // Step 4: Compare results doris::vector_search_utils::compare_search_results(doris_results, native_distances, native_indices); @@ -289,73 +299,104 @@ TEST_F(VectorSearchTest, SearchAllVectors) { TEST_F(VectorSearchTest, CompRangeSearch) { size_t iterations = 25; + // 支持的metric类型集合 + std::vector<faiss::MetricType> metrics = { + faiss::METRIC_L2, faiss::METRIC_INNER_PRODUCT + // 如有更多metric可继续添加 + }; for (size_t i = 0; i < iterations; ++i) { - // Random parameters for each test iteration - std::random_device rd; - std::mt19937 gen(rd()); - size_t random_d = - std::uniform_int_distribution<>(1, 1024)(gen); // Random dimension from 32 to 256 - size_t random_m = - 4 << std::uniform_int_distribution<>(1, 4)(gen); // Random M (4, 8, 16, 32, 64) - size_t random_n = - std::uniform_int_distribution<>(500, 2000)(gen); // Random number of vectors - // Step 1: Create and build index - auto doris_index = std::make_unique<FaissVectorIndex>(); - - FaissBuildParameter params; - params.d = random_d; - params.m = random_m; - params.index_type = FaissBuildParameter::IndexType::HNSW; - doris_index->set_build_params(params); - - const int num_vectors = random_n; - std::vector<std::vector<float>> vectors; - for (int i = 0; i < num_vectors; i++) { - auto vec = vector_search_utils::generate_random_vector(params.d); - vectors.push_back(vec); - } - std::unique_ptr<faiss::Index> native_index = - std::make_unique<faiss::IndexHNSWFlat>(params.d, params.m); - doris::vector_search_utils::add_vectors_to_indexes_serial_mode(doris_index.get(), - native_index.get(), vectors); - - std::vector<float> query_vec = vectors.front(); - const float radius = doris::vector_search_utils::get_radius_from_matrix( - query_vec.data(), params.d, vectors, 0.4f); - - HNSWSearchParameters hnsw_params; - hnsw_params.ef_search = 16; // Set efSearch for better accuracy - hnsw_params.roaring = nullptr; // No selector for this test - hnsw_params.is_le_or_lt = true; - IndexSearchResult doris_result; - std::ignore = - doris_index->range_search(query_vec.data(), radius, hnsw_params, doris_result); + for (auto metric : metrics) { + // Random parameters for each test iteration + std::random_device rd; + std::mt19937 gen(rd()); + size_t random_d = std::uniform_int_distribution<>(1, 1024)(gen); + size_t random_m = 4 << std::uniform_int_distribution<>(1, 4)(gen); + size_t random_n = std::uniform_int_distribution<>(500, 2000)(gen); + + // Step 1: Create and build index + auto doris_index = std::make_unique<FaissVectorIndex>(); + FaissBuildParameter params; + params.d = random_d; + params.m = random_m; + params.index_type = FaissBuildParameter::IndexType::HNSW; + if (metric == faiss::METRIC_L2) { + params.metric_type = FaissBuildParameter::MetricType::L2; + } else if (metric == faiss::METRIC_INNER_PRODUCT) { + params.metric_type = FaissBuildParameter::MetricType::IP; + } else { + throw std::runtime_error(fmt::format("Unsupported metric type: {}", metric)); + } + doris_index->set_build_params(params); - faiss::SearchParametersHNSW search_params_native; - search_params_native.efSearch = hnsw_params.ef_search; - faiss::RangeSearchResult search_result_native(1, true); - native_index->range_search(1, query_vec.data(), radius * radius, &search_result_native, - &search_params_native); - std::vector<std::pair<int, float>> native_results; - size_t begin = search_result_native.lims[0]; - size_t end = search_result_native.lims[1]; - for (size_t i = begin; i < end; i++) { - native_results.push_back( - {search_result_native.labels[i], search_result_native.distances[i]}); - } + const int num_vectors = random_n; + std::vector<std::vector<float>> vectors; + for (int i = 0; i < num_vectors; i++) { + auto vec = vector_search_utils::generate_random_vector(params.d); + vectors.push_back(vec); + } + // 创建native index时指定metric + std::unique_ptr<faiss::Index> native_index = nullptr; + if (metric == faiss::METRIC_L2) { + native_index = std::make_unique<faiss::IndexHNSWFlat>(params.d, params.m, + faiss::METRIC_L2); + } else if (metric == faiss::METRIC_INNER_PRODUCT) { + native_index = std::make_unique<faiss::IndexHNSWFlat>(params.d, params.m, + faiss::METRIC_INNER_PRODUCT); + } else { + throw std::runtime_error(fmt::format("Unsupported metric type: {}", metric)); + } + doris::vector_search_utils::add_vectors_to_indexes_serial_mode( + doris_index.get(), native_index.get(), vectors); + + std::vector<float> query_vec = vectors.front(); + float radius = 0; + + radius = doris::vector_search_utils::get_radius_from_matrix( + query_vec.data(), params.d, vectors, 0.4f, metric); + + HNSWSearchParameters hnsw_params; + hnsw_params.ef_search = 16; + hnsw_params.roaring = nullptr; + hnsw_params.is_le_or_lt = true; + IndexSearchResult doris_result; + std::ignore = + doris_index->range_search(query_vec.data(), radius, hnsw_params, doris_result); + + faiss::SearchParametersHNSW search_params_native; + search_params_native.efSearch = hnsw_params.ef_search; + faiss::RangeSearchResult search_result_native(1, true); + // 对于L2,radius要平方;对于IP,直接用 + float faiss_radius = (metric == faiss::METRIC_L2) ? radius * radius : radius; + native_index->range_search(1, query_vec.data(), faiss_radius, &search_result_native, + &search_params_native); + + std::vector<std::pair<int, float>> native_results; + size_t begin = search_result_native.lims[0]; + size_t end = search_result_native.lims[1]; + for (size_t i = begin; i < end; i++) { + native_results.push_back( + {search_result_native.labels[i], search_result_native.distances[i]}); + } - // Make sure result is same - ASSERT_NEAR(doris_result.roaring->cardinality(), native_results.size(), 1) - << fmt::format("\nd: {}, m: {}, n: {}", random_d, random_m, random_n); - ASSERT_EQ(doris_result.distances != nullptr, true); - if (doris_result.roaring->cardinality() == native_results.size()) { - for (size_t i = 0; i < native_results.size(); i++) { - const size_t rowid = native_results[i].first; - const float dis = native_results[i].second; - ASSERT_EQ(doris_result.roaring->contains(rowid), true) - << "Row ID mismatch at rank " << i; - ASSERT_FLOAT_EQ(doris_result.distances[i], sqrt(dis)) - << "Distance mismatch at rank " << i; + // Make sure result is same + ASSERT_NEAR(doris_result.roaring->cardinality(), native_results.size(), 1) + << fmt::format("\nd: {}, m: {}, n: {}, metric: {}", random_d, random_m, + random_n, metric); + ASSERT_EQ(doris_result.distances != nullptr, true); + if (doris_result.roaring->cardinality() == native_results.size()) { + for (size_t i = 0; i < native_results.size(); i++) { + const size_t rowid = native_results[i].first; + const float dis = native_results[i].second; + ASSERT_EQ(doris_result.roaring->contains(rowid), true) + << "Row ID mismatch at rank " << i; + if (metric == faiss::METRIC_L2) { + ASSERT_FLOAT_EQ(doris_result.distances[i], sqrt(dis)) + << "Distance mismatch at rank " << i; + } else { + ASSERT_FLOAT_EQ(doris_result.distances[i], dis) + << "Distance mismatch at rank " << i; + } + } } } } @@ -634,7 +675,7 @@ TEST_F(VectorSearchTest, RangeSearchEmptyResult) { // L2 distance between [5,5,5,5,5,5,5,5,5,5] with any other vector is large than 5 and less than 250. // Find the min float radius = 5.0f; - doris::segment_v2::HNSWSearchParameters search_params; + doris::vectorized::HNSWSearchParameters search_params; search_params.ef_search = 1000; // Set efSearch for better accuracy auto doris_search_result = vector_search_utils::perform_doris_index_range_search( index1.get(), query_vec.data(), radius, search_params); @@ -645,7 +686,7 @@ TEST_F(VectorSearchTest, RangeSearchEmptyResult) { ASSERT_EQ(native_search_result.size(), 0); // Search all rows. - doris::segment_v2::HNSWSearchParameters search_params_all_rows; + doris::vectorized::HNSWSearchParameters search_params_all_rows; search_params_all_rows.ef_search = 1000; // Set efSearch for better accuracy search_params_all_rows.is_le_or_lt = true; std::unique_ptr<roaring::Roaring> sel_rows = std::make_unique<roaring::Roaring>(); diff --git a/be/test/olap/vector_search/vector_search_utils.cpp b/be/test/olap/vector_search/vector_search_utils.cpp index 49c60296a8e..33412a87d64 100644 --- a/be/test/olap/vector_search/vector_search_utils.cpp +++ b/be/test/olap/vector_search/vector_search_utils.cpp @@ -23,6 +23,7 @@ #include <memory> #include "faiss_vector_index.h" +#include "olap/rowset/segment_v2/ann_index/ann_search_params.h" #include "vector_index.h" namespace doris::vector_search_utils { @@ -143,7 +144,7 @@ void add_vectors_to_indexes_batch_mode(segment_v2::VectorIndex* doris_index, } // Helper function to print search results for comparison -void print_search_results(const segment_v2::IndexSearchResult& doris_results, +void print_search_results(const vectorized::IndexSearchResult& doris_results, const std::vector<float>& native_distances, const std::vector<faiss::idx_t>& native_indices, int query_idx) { std::cout << "Query vector index: " << query_idx << std::endl; @@ -163,7 +164,7 @@ void print_search_results(const segment_v2::IndexSearchResult& doris_results, } // Helper function to compare search results between Doris and native Faiss -void compare_search_results(const segment_v2::IndexSearchResult& doris_results, +void compare_search_results(const vectorized::IndexSearchResult& doris_results, const std::vector<float>& native_distances, const std::vector<faiss::idx_t>& native_indices, float abs_error) { EXPECT_EQ(doris_results.roaring->cardinality(), @@ -199,10 +200,10 @@ std::vector<std::pair<int, float>> perform_native_index_range_search(faiss::Inde return results; } -std::unique_ptr<doris::segment_v2::IndexSearchResult> perform_doris_index_range_search( +std::unique_ptr<doris::vectorized::IndexSearchResult> perform_doris_index_range_search( segment_v2::VectorIndex* index, const float* query_vector, float radius, - const segment_v2::IndexSearchParameters& params) { - auto result = std::make_unique<doris::segment_v2::IndexSearchResult>(); + const vectorized::IndexSearchParameters& params) { + auto result = std::make_unique<doris::vectorized::IndexSearchResult>(); std::ignore = index->range_search(query_vector, radius, params, *result); return result; } @@ -229,20 +230,36 @@ float get_radius_from_flatten(const float* vector, int dim, float get_radius_from_matrix(const float* vector, int dim, const std::vector<std::vector<float>>& matrix_vectors, - float percentile) { + float percentile, + faiss::MetricType metric_type /* = faiss::METRIC_L2 */) { size_t n = matrix_vectors.size(); std::vector<std::pair<size_t, float>> distances(n); for (size_t i = 0; i < n; i++) { double sum = 0; - for (int j = 0; j < dim; j++) { - accumulate(matrix_vectors[i][j], vector[j], sum); + if (metric_type == faiss::METRIC_L2) { + for (int j = 0; j < dim; j++) { + accumulate(matrix_vectors[i][j], vector[j], sum); + } + distances[i] = std::make_pair(i, finalize(sum)); + } else if (metric_type == faiss::METRIC_INNER_PRODUCT) { + for (int j = 0; j < dim; j++) { + sum += matrix_vectors[i][j] * vector[j]; + } + distances[i] = std::make_pair(i, static_cast<float>(sum)); + } else { + throw std::invalid_argument("Unsupported metric type in get_radius_from_matrix"); } - distances[i] = std::make_pair(i, finalize(sum)); } - std::sort(distances.begin(), distances.end(), - [](const auto& a, const auto& b) { return a.second < b.second; }); - // Use the median distance as the radius + if (metric_type == faiss::METRIC_L2) { + std::sort(distances.begin(), distances.end(), + [](const auto& a, const auto& b) { return a.second < b.second; }); + } else if (metric_type == faiss::METRIC_INNER_PRODUCT) { + std::sort(distances.begin(), distances.end(), + [](const auto& a, const auto& b) { return a.second > b.second; }); + } + // Use the percentile distance as the radius size_t percentile_index = static_cast<size_t>(n * percentile); + if (percentile_index >= n) percentile_index = n - 1; float radius = distances[percentile_index].second; return radius; diff --git a/be/test/olap/vector_search/vector_search_utils.h b/be/test/olap/vector_search/vector_search_utils.h index a7e3d5fdeaf..e28a79ed3a0 100644 --- a/be/test/olap/vector_search/vector_search_utils.h +++ b/be/test/olap/vector_search/vector_search_utils.h @@ -82,7 +82,7 @@ void add_vectors_to_indexes_batch_mode(segment_v2::VectorIndex* doris_index, faiss::Index* native_index, size_t num_vectors, const std::vector<float>& flatten_vectors); -void print_search_results(const segment_v2::IndexSearchResult& doris_results, +void print_search_results(const vectorized::IndexSearchResult& doris_results, const std::vector<float>& native_distances, const std::vector<faiss::idx_t>& native_indices, int query_idx); @@ -92,7 +92,7 @@ float get_radius_from_matrix(const float* vector, int dim, const std::vector<std::vector<float>>& matrix_vectors, float percentile); // Helper function to compare search results between Doris and native Faiss -void compare_search_results(const segment_v2::IndexSearchResult& doris_results, +void compare_search_results(const vectorized::IndexSearchResult& doris_results, const std::vector<float>& native_distances, const std::vector<faiss::idx_t>& native_indices, float abs_error = 1e-5); @@ -103,9 +103,9 @@ std::vector<std::pair<int, float>> perform_native_index_range_search(faiss::Inde const float* query_vector, float radius); -std::unique_ptr<doris::segment_v2::IndexSearchResult> perform_doris_index_range_search( +std::unique_ptr<doris::vectorized::IndexSearchResult> perform_doris_index_range_search( segment_v2::VectorIndex* index, const float* query_vector, float radius, - const segment_v2::IndexSearchParameters& params); + const vectorized::IndexSearchParameters& params); class MockIndexFileReader : public ::doris::segment_v2::IndexFileReader { public: @@ -150,9 +150,9 @@ public: MOCK_METHOD(Status, read_from_index, (const doris::segment_v2::IndexParam& param), (override)); MOCK_METHOD(Status, range_search, - (const segment_v2::RangeSearchParams& params, + (const vectorized::RangeSearchParams& params, const VectorSearchUserParams& custom_params, - segment_v2::RangeSearchResult* result), + vectorized::RangeSearchResult* result), (override)); private: diff --git a/be/test/olap/vector_search/virtual_column_iterator_test.cpp b/be/test/olap/vector_search/virtual_column_iterator_test.cpp index 0c3f4a73169..1631583105a 100644 --- a/be/test/olap/vector_search/virtual_column_iterator_test.cpp +++ b/be/test/olap/vector_search/virtual_column_iterator_test.cpp @@ -50,10 +50,10 @@ TEST_F(VectorSearchTest, ReadByRowIdsint32_tColumn) { VirtualColumnIterator iterator; // Create a materialized int32_t column with values [10, 20, 30, 40, 50] - auto int_column = vectorized::ColumnVector<int32_t>::create(); + auto int_column = vectorized::ColumnVector<TYPE_INT>::create(); std::unique_ptr<std::vector<uint64_t>> labels = std::make_unique<std::vector<uint64_t>>(); for (int i = 0; i < 5; i++) { - int_column->insert(10 * (i + 1)); + int_column->insert_value(10 * (i + 1)); labels->push_back(i); } // Set the materialized column @@ -61,7 +61,7 @@ TEST_F(VectorSearchTest, ReadByRowIdsint32_tColumn) { iterator.prepare_materialization(std::move(int_column), std::move(labels)); // Create destination column - vectorized::MutableColumnPtr dst = vectorized::ColumnVector<int32_t>::create(); + vectorized::MutableColumnPtr dst = vectorized::ColumnVector<TYPE_INT>::create(); // Select rowids 0, 2, 4 (values 10, 30, 50) rowid_t rowids[] = {0, 2, 4}; @@ -84,11 +84,11 @@ TEST_F(VectorSearchTest, ReadByRowIdsStringColumn) { // Create a materialized String column auto string_column = vectorized::ColumnString::create(); - string_column->insert("apple"); - string_column->insert("banana"); - string_column->insert("cherry"); - string_column->insert("date"); - string_column->insert("elderberry"); + string_column->insert_value("apple"); + string_column->insert_value("banana"); + string_column->insert_value("cherry"); + string_column->insert_value("date"); + string_column->insert_value("elderberry"); auto labels = std::make_unique<std::vector<uint64_t>>(); for (int i = 0; i < 5; i++) { labels->push_back(i); @@ -119,10 +119,10 @@ TEST_F(VectorSearchTest, ReadByRowIdsEmptyRowIds) { VirtualColumnIterator iterator; // Create a materialized int32_t column with values [10, 20, 30, 40, 50] - auto int_column = vectorized::ColumnVector<int32_t>::create(); + auto int_column = vectorized::ColumnVector<TYPE_INT>::create(); auto labels = std::make_unique<std::vector<uint64_t>>(); for (int i = 0; i < 5; i++) { - int_column->insert(10 * (i + 1)); + int_column->insert_value(10 * (i + 1)); labels->push_back(i); } @@ -130,7 +130,7 @@ TEST_F(VectorSearchTest, ReadByRowIdsEmptyRowIds) { iterator.prepare_materialization(std::move(int_column), std::move(labels)); // Create destination column - vectorized::MutableColumnPtr dst = vectorized::ColumnVector<int32_t>::create(); + vectorized::MutableColumnPtr dst = vectorized::ColumnVector<TYPE_INT>::create(); // Empty rowids array rowid_t rowids[1]; @@ -149,11 +149,11 @@ TEST_F(VectorSearchTest, TestLargeRowset) { VirtualColumnIterator iterator; // Create a large materialized int32_t column (1000 values) - auto int_column = vectorized::ColumnVector<int32_t>::create(); + auto int_column = vectorized::ColumnVector<TYPE_INT>::create(); auto labels = std::make_unique<std::vector<uint64_t>>(); for (int i = 0; i < 1000; i++) { - int_column->insert(i); + int_column->insert_value(i); labels->push_back(i); } @@ -161,7 +161,7 @@ TEST_F(VectorSearchTest, TestLargeRowset) { iterator.prepare_materialization(std::move(int_column), std::move(labels)); // Create destination column - vectorized::MutableColumnPtr dst = vectorized::ColumnVector<int32_t>::create(); + vectorized::MutableColumnPtr dst = vectorized::ColumnVector<TYPE_INT>::create(); // Select every 100th row (0, 100, 200, ... 900) const int step = 100; @@ -183,12 +183,12 @@ TEST_F(VectorSearchTest, TestLargeRowset) { TEST_F(VectorSearchTest, ReadByRowIdsNoContinueRowIds) { // Create a column with 1000 values (0-999) - auto column = ColumnVector<int32_t>::create(); + auto column = ColumnVector<TYPE_INT>::create(); auto labels = std::make_unique<std::vector<uint64_t>>(); // Generate non-consecutive row IDs by multiplying by 2 (0,2,4,...) for (size_t i = 0; i < 1000; i++) { - column->insert(i); + column->insert_value(i); labels->push_back(i * 2); // Non-consecutive row IDs } @@ -202,7 +202,7 @@ TEST_F(VectorSearchTest, ReadByRowIdsNoContinueRowIds) { } // Create destination column for results - vectorized::MutableColumnPtr dest_col = ColumnVector<int32_t>::create(); + vectorized::MutableColumnPtr dest_col = ColumnVector<TYPE_INT>::create(); // Test with various non-consecutive row IDs { @@ -281,17 +281,17 @@ TEST_F(VectorSearchTest, NextBatchTest1) { VirtualColumnIterator iterator; // 构造一个有100行的int32列,值为0~99 - auto int_column = vectorized::ColumnVector<int32_t>::create(); + auto int_column = vectorized::ColumnVector<TYPE_INT>::create(); auto labels = std::make_unique<std::vector<uint64_t>>(); for (int i = 0; i < 100; ++i) { - int_column->insert(i); + int_column->insert_value(i); labels->push_back(i); } iterator.prepare_materialization(std::move(int_column), std::move(labels)); // 1. seek到第10行,next_batch读取10行 { - vectorized::MutableColumnPtr dst = vectorized::ColumnVector<int32_t>::create(); + vectorized::MutableColumnPtr dst = vectorized::ColumnVector<TYPE_INT>::create(); Status st = iterator.seek_to_ordinal(10); ASSERT_TRUE(st.ok()); size_t rows_read = 10; @@ -307,7 +307,7 @@ TEST_F(VectorSearchTest, NextBatchTest1) { // 2. seek到第85行,next_batch读取10行(只剩5行可读) { - vectorized::MutableColumnPtr dst = vectorized::ColumnVector<int32_t>::create(); + vectorized::MutableColumnPtr dst = vectorized::ColumnVector<TYPE_INT>::create(); Status st = iterator.seek_to_ordinal(85); ASSERT_TRUE(st.ok()); size_t rows_read = 10; @@ -323,7 +323,7 @@ TEST_F(VectorSearchTest, NextBatchTest1) { // 3. seek到第0行,next_batch读取全部100行 { - vectorized::MutableColumnPtr dst = vectorized::ColumnVector<int32_t>::create(); + vectorized::MutableColumnPtr dst = vectorized::ColumnVector<TYPE_INT>::create(); Status st = iterator.seek_to_ordinal(0); ASSERT_TRUE(st.ok()); size_t rows_read = 100; @@ -339,7 +339,7 @@ TEST_F(VectorSearchTest, NextBatchTest1) { // 4. seek到越界位置(如100),应该报错 { - vectorized::MutableColumnPtr dst = vectorized::ColumnVector<int32_t>::create(); + vectorized::MutableColumnPtr dst = vectorized::ColumnVector<TYPE_INT>::create(); Status st = iterator.seek_to_ordinal(100); ASSERT_EQ(st.ok(), false); } @@ -349,12 +349,12 @@ TEST_F(VectorSearchTest, TestPrepare1) { VirtualColumnIterator iterator; // Create a materialized int32_t column with values [10, 20, 30, 40, 50] - auto int_column = vectorized::ColumnVector<int32_t>::create(); - int_column->insert(10); - int_column->insert(20); - int_column->insert(30); - int_column->insert(40); - int_column->insert(50); + auto int_column = vectorized::ColumnVector<TYPE_INT>::create(); + int_column->insert_value(10); + int_column->insert_value(20); + int_column->insert_value(30); + int_column->insert_value(40); + int_column->insert_value(50); auto labels = std::make_unique<std::vector<uint64_t>>(); labels->push_back(100); labels->push_back(11); @@ -375,7 +375,7 @@ TEST_F(VectorSearchTest, TestPrepare1) { auto materialization_col = iterator.get_materialized_column(); auto int_col_m = - assert_cast<const vectorized::ColumnVector<int32_t>*>(materialization_col.get()); + assert_cast<const vectorized::ColumnVector<TYPE_INT>*>(materialization_col.get()); ASSERT_EQ(int_col_m->get_data()[0], 20); ASSERT_EQ(int_col_m->get_data()[1], 40); ASSERT_EQ(int_col_m->get_data()[2], 30); @@ -387,12 +387,12 @@ TEST_F(VectorSearchTest, TestColumnNothing) { VirtualColumnIterator iterator; // Create a materialized int32_t column with values [10, 20, 30, 40, 50] - auto int_column = vectorized::ColumnVector<int32_t>::create(); - int_column->insert(10); - int_column->insert(20); - int_column->insert(30); - int_column->insert(40); - int_column->insert(50); + auto int_column = vectorized::ColumnVector<TYPE_INT>::create(); + int_column->insert_value(10); + int_column->insert_value(20); + int_column->insert_value(30); + int_column->insert_value(40); + int_column->insert_value(50); auto labels = std::make_unique<std::vector<uint64_t>>(); labels->push_back(100); labels->push_back(11); @@ -412,7 +412,7 @@ TEST_F(VectorSearchTest, TestColumnNothing) { ASSERT_TRUE(status.ok()); auto tmp_nothing = vectorized::check_and_get_column<vectorized::ColumnNothing>(*dst); ASSERT_TRUE(tmp_nothing == nullptr); - auto tmp_col_i32 = vectorized::check_and_get_column<vectorized::ColumnVector<int32_t>>( + auto tmp_col_i32 = vectorized::check_and_get_column<vectorized::ColumnVector<TYPE_INT>>( *iterator.get_materialized_column()); ASSERT_TRUE(tmp_col_i32 != nullptr); ASSERT_EQ(dst->size(), 3); diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/DescriptorTable.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/DescriptorTable.java index fb6cc7df0a8..8c6d9e8ff4d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/DescriptorTable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/DescriptorTable.java @@ -71,6 +71,11 @@ public class DescriptorTable { return d; } + /** + * Create a new SlotDescriptor. + * Add it to input TupleDescriptor, store it to slot descriptors map. + * Return the newly created SlotDescriptor. + */ public SlotDescriptor addSlotDescriptor(TupleDescriptor d) { SlotDescriptor result = new SlotDescriptor(slotIdGenerator.getNextId(), d); d.addSlot(result); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java index 9b024de1f4f..56c3ade3371 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindRelation.java @@ -413,17 +413,17 @@ public class BindRelation extends OneAnalysisRuleFactory { } if (hmsTable.getDlaType() == DLAType.HUDI) { LogicalHudiScan hudiScan = new LogicalHudiScan(unboundRelation.getRelationId(), hmsTable, - qualifierWithoutTableName, unboundRelation.getTableSample(), - unboundRelation.getTableSnapshot(), ImmutableList.of(), Optional.empty()); + qualifierWithoutTableName, ImmutableList.of(), Optional.empty(), + unboundRelation.getTableSample(), unboundRelation.getTableSnapshot()); hudiScan = hudiScan.withScanParams( hmsTable, Optional.ofNullable(unboundRelation.getScanParams())); return hudiScan; } else { return new LogicalFileScan(unboundRelation.getRelationId(), (HMSExternalTable) table, qualifierWithoutTableName, + ImmutableList.of(), unboundRelation.getTableSample(), unboundRelation.getTableSnapshot(), - ImmutableList.of(), Optional.ofNullable(unboundRelation.getScanParams())); } case ICEBERG_EXTERNAL_TABLE: @@ -432,9 +432,9 @@ public class BindRelation extends OneAnalysisRuleFactory { case TRINO_CONNECTOR_EXTERNAL_TABLE: case LAKESOUl_EXTERNAL_TABLE: return new LogicalFileScan(unboundRelation.getRelationId(), (ExternalTable) table, - qualifierWithoutTableName, unboundRelation.getTableSample(), + qualifierWithoutTableName, ImmutableList.of(), + unboundRelation.getTableSample(), unboundRelation.getTableSnapshot(), - ImmutableList.of(), Optional.ofNullable(unboundRelation.getScanParams())); case SCHEMA: // schema table's name is case-insensitive, we need save its name in SQL text to get correct case. diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVectorTopNIntoOlapScan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVectorTopNIntoOlapScan.java index adb7f0b5f2b..d8b7ba55daf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVectorTopNIntoOlapScan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownVectorTopNIntoOlapScan.java @@ -103,11 +103,11 @@ public class PushDownVectorTopNIntoOlapScan implements RewriteRuleFactory { return null; } SlotReference leftInput = (SlotReference) left; - if (!leftInput.getColumn().isPresent() || !leftInput.getTable().isPresent()) { + if (!leftInput.getOriginalColumn().isPresent() || !leftInput.getOriginalTable().isPresent()) { return null; } - TableIf table = leftInput.getTable().get(); - Column column = leftInput.getColumn().get(); + TableIf table = leftInput.getOriginalTable().get(); + Column column = leftInput.getOriginalColumn().get(); boolean hasAnnIndexOnColumn = false; for (Index index : table.getTableIndexes().getIndexes()) { if (index.getIndexType() == IndexType.ANN) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFileScan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFileScan.java index d44b1e4f9d7..74566551223 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFileScan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalFileScan.java @@ -53,20 +53,23 @@ public class LogicalFileScan extends LogicalCatalogRelation { protected final Optional<TableScanParams> scanParams; public LogicalFileScan(RelationId id, ExternalTable table, List<String> qualifier, - Optional<TableSample> tableSample, Optional<TableSnapshot> tableSnapshot) { - this(id, table, qualifier, table.initSelectedPartitions(MvccUtil.getSnapshotFromContext(table)), - tableSample, tableSnapshot, ImmutableList.of(), Optional.empty(), Optional.empty()); + Collection<Slot> operativeSlots, + Optional<TableSample> tableSample, Optional<TableSnapshot> tableSnapshot, + Optional<TableScanParams> scanParams) { + this(id, table, qualifier, + table.initSelectedPartitions(MvccUtil.getSnapshotFromContext(table)), + operativeSlots, ImmutableList.of(), + tableSample, tableSnapshot, + scanParams, Optional.empty(), Optional.empty()); } /** * Constructor for LogicalFileScan. */ protected LogicalFileScan(RelationId id, ExternalTable table, List<String> qualifier, - SelectedPartitions selectedPartitions, Optional<TableSample> tableSample, - Optional<TableSnapshot> tableSnapshot, - Collection<Slot> operativeSlots, - Optional<TableScanParams> scanParams, - List<NamedExpression> virtualColumns, + SelectedPartitions selectedPartitions, Collection<Slot> operativeSlots, + List<NamedExpression> virtualColumns, Optional<TableSample> tableSample, + Optional<TableSnapshot> tableSnapshot, Optional<TableScanParams> scanParams, Optional<GroupExpression> groupExpression, Optional<LogicalProperties> logicalProperties) { super(id, PlanType.LOGICAL_FILE_SCAN, table, qualifier, operativeSlots, virtualColumns, groupExpression, logicalProperties); @@ -76,16 +79,6 @@ public class LogicalFileScan extends LogicalCatalogRelation { this.scanParams = scanParams; } - public LogicalFileScan(RelationId id, ExternalTable table, List<String> qualifier, - Optional<TableSample> tableSample, Optional<TableSnapshot> tableSnapshot, - Collection<Slot> operativeSlots, - Optional<TableScanParams> scanParams) { - this(id, table, qualifier, Optional.empty(), Optional.empty(), - table.initSelectedPartitions(MvccUtil.getSnapshotFromContext(table)), - tableSample, tableSnapshot, operativeSlots, scanParams, - Optional.empty(), Optional.empty(),Optional.empty()); - } - public SelectedPartitions getSelectedPartitions() { return selectedPartitions; } @@ -121,29 +114,29 @@ public class LogicalFileScan extends LogicalCatalogRelation { @Override public LogicalFileScan withGroupExpression(Optional<GroupExpression> groupExpression) { return new LogicalFileScan(relationId, (ExternalTable) table, qualifier, - selectedPartitions, tableSample, tableSnapshot, virtualColumns, groupExpression, - Optional.of(getLogicalProperties())); + selectedPartitions, operativeSlots, virtualColumns, tableSample, tableSnapshot, + scanParams, groupExpression, Optional.of(getLogicalProperties())); } @Override public Plan withGroupExprLogicalPropChildren(Optional<GroupExpression> groupExpression, Optional<LogicalProperties> logicalProperties, List<Plan> children) { return new LogicalFileScan(relationId, (ExternalTable) table, qualifier, - selectedPartitions, tableSample, tableSnapshot, virtualColumns, - groupExpression, logicalProperties); + selectedPartitions, operativeSlots, virtualColumns, tableSample, tableSnapshot, + scanParams, groupExpression, logicalProperties); } public LogicalFileScan withSelectedPartitions(SelectedPartitions selectedPartitions) { return new LogicalFileScan(relationId, (ExternalTable) table, qualifier, - selectedPartitions, tableSample, tableSnapshot, virtualColumns, - Optional.empty(), Optional.of(getLogicalProperties())); + selectedPartitions, operativeSlots, virtualColumns, tableSample, tableSnapshot, + scanParams, Optional.empty(), Optional.of(getLogicalProperties())); } @Override public LogicalFileScan withRelationId(RelationId relationId) { return new LogicalFileScan(relationId, (ExternalTable) table, qualifier, - selectedPartitions, tableSample, tableSnapshot, virtualColumns, - Optional.empty(), Optional.empty()); + selectedPartitions, operativeSlots, virtualColumns, tableSample, tableSnapshot, + scanParams, Optional.empty(), Optional.empty()); } @Override @@ -212,9 +205,8 @@ public class LogicalFileScan extends LogicalCatalogRelation { @Override public LogicalFileScan withOperativeSlots(Collection<Slot> operativeSlots) { return new LogicalFileScan(relationId, (ExternalTable) table, qualifier, - groupExpression, Optional.of(getLogicalProperties()), - selectedPartitions, tableSample, tableSnapshot, - operativeSlots, scanParams); + selectedPartitions, operativeSlots, virtualColumns,tableSample, tableSnapshot, + scanParams, groupExpression, Optional.of(getLogicalProperties())); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalHudiScan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalHudiScan.java index 86a36290eed..7f72a897585 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalHudiScan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalHudiScan.java @@ -78,21 +78,19 @@ public class LogicalHudiScan extends LogicalFileScan { List<NamedExpression> virtualColumns, Optional<GroupExpression> groupExpression, Optional<LogicalProperties> logicalProperties) { - super(id, table, qualifier, selectedPartitions, tableSample, tableSnapshot, operativeSlots, - virtualColumns, groupExpression, logicalProperties); + super(id, table, qualifier, selectedPartitions, operativeSlots, virtualColumns, + tableSample, tableSnapshot, scanParams, groupExpression, logicalProperties); Objects.requireNonNull(scanParams, "scanParams should not null"); Objects.requireNonNull(incrementalRelation, "incrementalRelation should not null"); this.incrementalRelation = incrementalRelation; } public LogicalHudiScan(RelationId id, ExternalTable table, List<String> qualifier, - Collection<Slot> operativeSlots, - Optional<TableScanParams> scanParams, + Collection<Slot> operativeSlots, Optional<TableScanParams> scanParams, Optional<TableSample> tableSample, Optional<TableSnapshot> tableSnapshot) { - this(id, table, qualifier, Optional.empty(), Optional.empty(), - ((HMSExternalTable) table).initHudiSelectedPartitions(tableSnapshot), tableSample, tableSnapshot, - Optional.empty(), Optional.empty(), - ImmutableList.of(), Optional.empty(), Optional.empty()); + this(id, table, qualifier, ((HMSExternalTable) table).initHudiSelectedPartitions(tableSnapshot), + tableSample, tableSnapshot, scanParams, Optional.empty(), operativeSlots, ImmutableList.of(), + Optional.empty(), Optional.empty()); } public Optional<TableScanParams> getScanParams() { @@ -142,29 +140,29 @@ public class LogicalHudiScan extends LogicalFileScan { @Override public LogicalHudiScan withGroupExpression(Optional<GroupExpression> groupExpression) { return new LogicalHudiScan(relationId, (ExternalTable) table, qualifier, - selectedPartitions, tableSample, tableSnapshot, scanParams, incrementalRelation, virtualColumns, - groupExpression, Optional.of(getLogicalProperties())); + selectedPartitions, tableSample, tableSnapshot, scanParams, incrementalRelation, + operativeSlots, virtualColumns, groupExpression, Optional.of(getLogicalProperties())); } @Override public Plan withGroupExprLogicalPropChildren(Optional<GroupExpression> groupExpression, Optional<LogicalProperties> logicalProperties, List<Plan> children) { return new LogicalHudiScan(relationId, (ExternalTable) table, qualifier, - selectedPartitions, tableSample, tableSnapshot, scanParams, incrementalRelation, virtualColumns, - groupExpression, logicalProperties); + selectedPartitions, tableSample, tableSnapshot, scanParams, incrementalRelation, + operativeSlots, virtualColumns, groupExpression, logicalProperties); } public LogicalHudiScan withSelectedPartitions(SelectedPartitions selectedPartitions) { return new LogicalHudiScan(relationId, (ExternalTable) table, qualifier, - selectedPartitions, tableSample, tableSnapshot, scanParams, incrementalRelation, virtualColumns, - Optional.empty(), Optional.of(getLogicalProperties())); + selectedPartitions, tableSample, tableSnapshot, scanParams, incrementalRelation, + operativeSlots, virtualColumns, groupExpression, Optional.of(getLogicalProperties())); } @Override public LogicalHudiScan withRelationId(RelationId relationId) { return new LogicalHudiScan(relationId, (ExternalTable) table, qualifier, - selectedPartitions, tableSample, tableSnapshot, scanParams, incrementalRelation, virtualColumns, - Optional.empty(), Optional.empty()); + selectedPartitions, tableSample, tableSnapshot, scanParams, incrementalRelation, + operativeSlots, virtualColumns, groupExpression, Optional.of(getLogicalProperties())); } @Override @@ -175,9 +173,8 @@ public class LogicalHudiScan extends LogicalFileScan { @Override public LogicalFileScan withOperativeSlots(Collection<Slot> operativeSlots) { return new LogicalHudiScan(relationId, (ExternalTable) table, qualifier, - groupExpression, Optional.of(getLogicalProperties()), - selectedPartitions, tableSample, tableSnapshot, scanParams, incrementalRelation, - operativeSlots); + selectedPartitions, tableSample, tableSnapshot, scanParams, incrementalRelation, + operativeSlots, virtualColumns, groupExpression, Optional.of(getLogicalProperties())); } /** @@ -227,9 +224,8 @@ public class LogicalHudiScan extends LogicalFileScan { "Failed to create incremental relation for table: " + table.getFullQualifiers(), e); } } - newScanParams = Optional.ofNullable(scanParams); - return new LogicalHudiScan(relationId, table, qualifier, - selectedPartitions, tableSample, tableSnapshot, newScanParams, newIncrementalRelation, virtualColumns, - Optional.empty(), Optional.empty()); + return new LogicalHudiScan(relationId, (ExternalTable) table, qualifier, + selectedPartitions, tableSample, tableSnapshot, scanParams, newIncrementalRelation, + operativeSlots, virtualColumns, groupExpression, Optional.of(getLogicalProperties())); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalLazyMaterializeOlapScan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalLazyMaterializeOlapScan.java index add4742099b..f7857069dc9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalLazyMaterializeOlapScan.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalLazyMaterializeOlapScan.java @@ -56,7 +56,11 @@ public class PhysicalLazyMaterializeOlapScan extends PhysicalOlapScan { physicalOlapScan.getPhysicalProperties(), physicalOlapScan.getStats(), physicalOlapScan.getTableSample(), - physicalOlapScan.getOperativeSlots()); + physicalOlapScan.getOperativeSlots(), + physicalOlapScan.getVirtualColumns(), + physicalOlapScan.getAnnOrderKeys(), + physicalOlapScan.getAnnLimit() + ); this.scan = physicalOlapScan; this.rowId = rowId; this.lazySlots = ImmutableList.copyOf(lazySlots); --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org