This is an automated email from the ASF dual-hosted git repository. morningman pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 86502b014d [feature-wip](unique-key-merge-on-write)port IntervalTree from kudu (#10511) 86502b014d is described below commit 86502b014d8412dd84b8a5a390114ceda9feba85 Author: zhannngchen <48427519+zhannngc...@users.noreply.github.com> AuthorDate: Tue Jul 5 17:43:01 2022 +0800 [feature-wip](unique-key-merge-on-write)port IntervalTree from kudu (#10511) See DSIP-18: https://cwiki.apache.org/confluence/display/DORIS/DSIP-018%3A+Support+Merge-On-Write+implementation+for+UNIQUE+KEY+data+model This patch is for step 3.1 in scheduling. --- be/src/util/interval_tree-inl.h | 440 ++++++++++++++++++++++++++++++++++++ be/src/util/interval_tree.h | 159 +++++++++++++ be/test/CMakeLists.txt | 1 + be/test/util/interval_tree_test.cpp | 392 ++++++++++++++++++++++++++++++++ 4 files changed, 992 insertions(+) diff --git a/be/src/util/interval_tree-inl.h b/be/src/util/interval_tree-inl.h new file mode 100644 index 0000000000..d322d260d5 --- /dev/null +++ b/be/src/util/interval_tree-inl.h @@ -0,0 +1,440 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+// +// This file is copied from +// https://github.com/apache/kudu/blob/master/src/kudu/util/interval_tree-inl.h +// and modified by Doris +// + +#pragma once + +#include <algorithm> +#include <vector> + +#include "util/interval_tree.h" + +namespace doris { + +template <class Traits> +IntervalTree<Traits>::IntervalTree(const IntervalVector& intervals) : root_(NULL) { + if (!intervals.empty()) { + root_ = CreateNode(intervals); + } +} + +template <class Traits> +IntervalTree<Traits>::~IntervalTree() { + delete root_; +} + +template <class Traits> +template <class QueryPointType> +void IntervalTree<Traits>::FindContainingPoint(const QueryPointType& query, + IntervalVector* results) const { + if (root_) { + root_->FindContainingPoint(query, results); + } +} + +template <class Traits> +template <class Callback, class QueryContainer> +void IntervalTree<Traits>::ForEachIntervalContainingPoints(const QueryContainer& queries, + const Callback& cb) const { + if (root_) { + root_->ForEachIntervalContainingPoints(queries.begin(), queries.end(), cb); + } +} + +template <class Traits> +template <class QueryPointType> +void IntervalTree<Traits>::FindIntersectingInterval(const QueryPointType& lower_bound, + const QueryPointType& upper_bound, + IntervalVector* results) const { + if (root_) { + root_->FindIntersectingInterval(lower_bound, upper_bound, results); + } +} + +template <class Traits> +static bool LessThan(const typename Traits::point_type& a, const typename Traits::point_type& b) { + return Traits::compare(a, b) < 0; +} + +// Select a split point which attempts to evenly divide 'in' into three groups: +// (a) those that are fully left of the split point +// (b) those that overlap the split point. +// (c) those that are fully right of the split point +// These three groups are stored in the output parameters '*left', '*overlapping', +// and '*right', respectively. The selected split point is stored in *split_point. 
//
// For example, the input interval set:
//
//   |------1-------|         |-----2-----|
//       |--3--|    |---4--|    |----5----|
//                     |
// Resulting split:    | Partition point
//                     |
//
// *left: intervals 1 and 3
// *overlapping: interval 4
// *right: intervals 2 and 5
//
// 'in' must not be empty (enforced with CHECK). The output vectors are only
// appended to; they are not cleared first.
template <class Traits>
void IntervalTree<Traits>::Partition(const IntervalVector& in, point_type* split_point,
                                     IntervalVector* left, IntervalVector* overlapping,
                                     IntervalVector* right) {
    CHECK(!in.empty());

    // Pick a split point which is the median of all of the interval boundaries.
    std::vector<point_type> endpoints;
    endpoints.reserve(in.size() * 2);
    for (const interval_type& interval : in) {
        endpoints.push_back(Traits::get_left(interval));
        endpoints.push_back(Traits::get_right(interval));
    }
    std::sort(endpoints.begin(), endpoints.end(), LessThan<Traits>);
    *split_point = endpoints[endpoints.size() / 2];

    // Partition into the groups based on the determined split point.
    // The split point is itself one of the endpoints, so the interval that
    // owns it always lands in 'overlapping'.
    for (const interval_type& interval : in) {
        if (Traits::compare(Traits::get_right(interval), *split_point) < 0) {
            //                 | split point
            //  |------------| |
            //    interval
            left->push_back(interval);
        } else if (Traits::compare(Traits::get_left(interval), *split_point) > 0) {
            //                 | split point
            //                 |    |------------|
            //                         interval
            right->push_back(interval);
        } else {
            //                 | split point
            //                 |
            //          |------------|
            //             interval
            overlapping->push_back(interval);
        }
    }
}

// Recursively builds the subtree for 'intervals' (which must be non-empty, see
// Partition()) and returns a heap-allocated node for it.
template <class Traits>
typename IntervalTree<Traits>::node_type* IntervalTree<Traits>::CreateNode(
        const IntervalVector& intervals) {
    IntervalVector left, right, overlap;
    point_type split_point;

    // First partition the input intervals and select a split point
    Partition(intervals, &split_point, &left, &overlap, &right);

    // Recursively subdivide the intervals which are fully left or fully
    // right of the split point into subtree nodes.
    node_type* left_node = !left.empty() ? CreateNode(left) : NULL;
    node_type* right_node = !right.empty() ? CreateNode(right) : NULL;

    return new node_type(split_point, left_node, overlap, right_node);
}

namespace interval_tree_internal {

// Node in the interval tree.
//
// A node stores the intervals that contain its split point (twice, in two
// different sort orders) and owns up to two child nodes, which its destructor
// deletes.
template <typename Traits>
class ITNode {
private:
    // Import types.
    typedef std::vector<typename Traits::interval_type> IntervalVector;
    typedef typename Traits::interval_type interval_type;
    typedef typename Traits::point_type point_type;

public:
    // 'left' and 'right' may be NULL. 'overlap' is copied.
    ITNode(point_type split_point, ITNode<Traits>* left, const IntervalVector& overlap,
           ITNode<Traits>* right);
    ~ITNode();

    // See IntervalTree::FindContainingPoint(...)
    template <class QueryPointType>
    void FindContainingPoint(const QueryPointType& query, IntervalVector* results) const;

    // See IntervalTree::ForEachIntervalContainingPoints().
    // We use iterators here since as recursion progresses down the tree, we
    // process sub-sequences of the original set of query points.
    template <class Callback, class ItType>
    void ForEachIntervalContainingPoints(ItType begin_queries, ItType end_queries,
                                         const Callback& cb) const;

    // See IntervalTree::FindIntersectingInterval(...)
    template <class QueryPointType>
    void FindIntersectingInterval(const QueryPointType& lower_bound,
                                  const QueryPointType& upper_bound, IntervalVector* results) const;

private:
    // Comparators for sorting lists of intervals.
    static bool SortByAscLeft(const interval_type& a, const interval_type& b);
    static bool SortByDescRight(const interval_type& a, const interval_type& b);

    // Partition point of this node.
    point_type split_point_;

    // Those intervals that overlap with split_point_, in ascending order by their left side.
    IntervalVector overlapping_by_asc_left_;

    // Those intervals that overlap with split_point_, in descending order by their right side.
    IntervalVector overlapping_by_desc_right_;

    // Tree node for intervals fully left of split_point_, or NULL.
    ITNode* left_;

    // Tree node for intervals fully right of split_point_, or NULL.
    ITNode* right_;

    DISALLOW_COPY_AND_ASSIGN(ITNode);
};

// Orders intervals by ascending left endpoint.
template <class Traits>
bool ITNode<Traits>::SortByAscLeft(const interval_type& a, const interval_type& b) {
    return Traits::compare(Traits::get_left(a), Traits::get_left(b)) < 0;
}

// Orders intervals by descending right endpoint.
template <class Traits>
bool ITNode<Traits>::SortByDescRight(const interval_type& a, const interval_type& b) {
    return Traits::compare(Traits::get_right(a), Traits::get_right(b)) > 0;
}

// Takes ownership of 'left' and 'right' (they are deleted by ~ITNode()).
template <class Traits>
ITNode<Traits>::ITNode(typename Traits::point_type split_point, ITNode<Traits>* left,
                       const IntervalVector& overlap, ITNode<Traits>* right)
        : split_point_(std::move(split_point)), left_(left), right_(right) {
    // Store two copies of the set of intervals which overlap the split point:
    // 1) Sorted by ascending left boundary
    overlapping_by_asc_left_.assign(overlap.begin(), overlap.end());
    std::sort(overlapping_by_asc_left_.begin(), overlapping_by_asc_left_.end(), SortByAscLeft);
    // 2) Sorted by descending right boundary
    overlapping_by_desc_right_.assign(overlap.begin(), overlap.end());
    std::sort(overlapping_by_desc_right_.begin(), overlapping_by_desc_right_.end(),
              SortByDescRight);
}

// Recursively frees both subtrees.
template <class Traits>
ITNode<Traits>::~ITNode() {
    if (left_) delete left_;
    if (right_) delete right_;
}

// Batched point lookup: for every query point in [begin_queries, end_queries)
// calls 'cb(query_point, interval)' for each interval of this subtree that
// contains it. The query range must be sorted ascending (checked with DCHECK
// against this node's split point).
template <class Traits>
template <class Callback, class ItType>
void ITNode<Traits>::ForEachIntervalContainingPoints(ItType begin_queries, ItType end_queries,
                                                     const Callback& cb) const {
    if (begin_queries == end_queries) return;

    // decltype(*it) is normally a reference type; 'const QueryPointType&'
    // below then collapses to that same reference.
    typedef decltype(*begin_queries) QueryPointType;
    const auto& partitioner = [&](const QueryPointType& query_point) {
        return Traits::compare(query_point, split_point_) < 0;
    };

    // Partition the query points into those less than the split_point_ and those greater
    // than or equal to the split_point_. Because the input queries are already sorted, we
    // can use 'std::partition_point' instead of 'std::partition'.
    //
    // The resulting 'partition_point' is the first query point in the second group.
    //
    // Complexity: O(log(number of query points))
    DCHECK(std::is_partitioned(begin_queries, end_queries, partitioner));
    auto partition_point = std::partition_point(begin_queries, end_queries, partitioner);

    // Recurse left: any query points left of the split point may intersect
    // with non-overlapping intervals fully-left of our split point.
    if (left_ != NULL) {
        left_->ForEachIntervalContainingPoints(begin_queries, partition_point, cb);
    }

    // NOTE(review): every diagram line below ends in '/', presumably so that no
    // comment line can end in a backslash (which would splice the following
    // source line into the comment).
    //
    // Handle the query points < split_point                                  /
    //                                                                        /
    //      split_point_                                                      /
    //         |                                                              /
    //   [------]         \                                                   /
    //     [-------]       | overlapping_by_asc_left_                         /
    //       [--------]   /                                                   /
    // Q   Q      Q                                                           /
    // ^   ^      \___ not handled (right of split_point_)                    /
    // |   |                                                                  /
    // \___\___ these points will be handled here                             /
    //

    // Lower bound of query points still relevant.
    auto rem_queries = begin_queries;
    for (const interval_type& interval : overlapping_by_asc_left_) {
        const auto& interval_left = Traits::get_left(interval);
        // Find those query points which are right of the left side of the interval.
        // 'first_match' here is the first query point >= interval_left.
        // Complexity: O(log(num_queries))
        //
        // TODO(todd): The non-batched implementation is O(log(num_intervals) * num_queries)
        // whereas this loop ends up O(num_intervals * log(num_queries)). So, for
        // small numbers of queries this is not the fastest way to structure these loops.
        auto first_match = std::partition_point(
                rem_queries, partition_point, [&](const QueryPointType& query_point) {
                    return Traits::compare(query_point, interval_left) < 0;
                });
        // Every point in [first_match, partition_point) is >= the interval's left
        // edge and < split_point_, which the interval is known to contain.
        for (auto it = first_match; it != partition_point; ++it) {
            cb(*it, interval);
        }
        // Since the intervals are sorted in ascending-left order, we can start
        // the search for the next interval at the first match in this interval.
        // (any query point which was left of the current interval will also be left
        // of all future intervals).
        rem_queries = std::move(first_match);
    }

    // Handle the query points >= split_point                                 /
    //                                                                        /
    //      split_point_                                                      /
    //         |                                                              /
    //     [--------]   \                                                     /
    //   [-------]       | overlapping_by_desc_right_                         /
    // [------]         /                                                     /
    //   Q   Q      Q                                                         /
    //   |    \______\___ these points will be handled here                   /
    //   |                                                                    /
    //   \___ not handled (left of split_point_)                              /

    // Upper bound of query points still relevant.
    rem_queries = end_queries;
    for (const interval_type& interval : overlapping_by_desc_right_) {
        const auto& interval_right = Traits::get_right(interval);
        // Find the first query point which is > the right side of the interval.
        auto first_non_match = std::partition_point(
                partition_point, rem_queries, [&](const QueryPointType& query_point) {
                    return Traits::compare(query_point, interval_right) <= 0;
                });
        for (auto it = partition_point; it != first_non_match; ++it) {
            cb(*it, interval);
        }
        // Same logic as above: if a query point was fully right of 'interval',
        // then it will be fully right of all following intervals because they are
        // sorted by descending-right.
        rem_queries = std::move(first_non_match);
    }

    if (right_ != NULL) {
        // Query points equal to split_point_ were fully handled above and cannot
        // be contained in any interval of right_ (those start after the split
        // point), so skip them before recursing.
        while (partition_point != end_queries &&
               Traits::compare(*partition_point, split_point_) == 0) {
            ++partition_point;
        }
        right_->ForEachIntervalContainingPoints(partition_point, end_queries, cb);
    }
}

// Appends to '*results' every interval in this subtree that contains 'query'.
// Intervals are treated as closed (both '<=' / '>=' tests below are inclusive).
template <class Traits>
template <class QueryPointType>
void ITNode<Traits>::FindContainingPoint(const QueryPointType& query,
                                         IntervalVector* results) const {
    int cmp = Traits::compare(query, split_point_);
    if (cmp < 0) {
        // None of the intervals in right_ may intersect this.
        if (left_ != NULL) {
            left_->FindContainingPoint(query, results);
        }

        // Any intervals which start before the query point and overlap the split point
        // must therefore contain the query point.
        auto p = std::partition_point(
                overlapping_by_asc_left_.cbegin(), overlapping_by_asc_left_.cend(),
                [&](const interval_type& interval) {
                    return Traits::compare(Traits::get_left(interval), query) <= 0;
                });
        results->insert(results->end(), overlapping_by_asc_left_.cbegin(), p);
    } else if (cmp > 0) {
        // None of the intervals in left_ may intersect this.
        if (right_ != NULL) {
            right_->FindContainingPoint(query, results);
        }

        // Any intervals which end after the query point and overlap the split point
        // must therefore contain the query point.
        auto p = std::partition_point(
                overlapping_by_desc_right_.cbegin(), overlapping_by_desc_right_.cend(),
                [&](const interval_type& interval) {
                    return Traits::compare(Traits::get_right(interval), query) >= 0;
                });
        results->insert(results->end(), overlapping_by_desc_right_.cbegin(), p);
    } else {
        DCHECK_EQ(cmp, 0);
        // The query is exactly our split point -- in this case we've already got
        // the computed list of overlapping intervals.
        results->insert(results->end(), overlapping_by_asc_left_.begin(),
                        overlapping_by_asc_left_.end());
    }
}

// Appends to '*results' every interval in this subtree that intersects the
// query range. As the comparisons below show, 'lower_bound' is inclusive and
// 'upper_bound' is exclusive, i.e. the query is [lower_bound, upper_bound).
//
// The three-argument Traits::compare() receives POSITIVE_INFINITY for
// comparisons involving 'upper_bound' and NEGATIVE_INFINITY for those involving
// 'lower_bound'; how an absent bound is mapped to that infinity is up to the
// Traits class.
template <class Traits>
template <class QueryPointType>
void ITNode<Traits>::FindIntersectingInterval(const QueryPointType& lower_bound,
                                              const QueryPointType& upper_bound,
                                              IntervalVector* results) const {
    if (Traits::compare(upper_bound, split_point_, POSITIVE_INFINITY) <= 0) {
        // The query's (exclusive) upper bound is at or before the split point, so
        // the query lies fully left of the split point and does not include it.
        // So, it may not overlap with any in 'right_'
        if (left_ != NULL) {
            left_->FindIntersectingInterval(lower_bound, upper_bound, results);
        }

        // Any interval whose left edge is < the query interval's right edge
        // intersect the query interval. 'std::partition_point' returns the first
        // such interval which does not meet that criterion, so we insert all
        // up to that point.
        auto first_greater = std::partition_point(
                overlapping_by_asc_left_.cbegin(), overlapping_by_asc_left_.cend(),
                [&](const interval_type& interval) {
                    return Traits::compare(Traits::get_left(interval), upper_bound,
                                           POSITIVE_INFINITY) < 0;
                });
        results->insert(results->end(), overlapping_by_asc_left_.cbegin(), first_greater);
    } else if (Traits::compare(lower_bound, split_point_, NEGATIVE_INFINITY) > 0) {
        // The interval is fully right of the split point. So, it may not overlap
        // with any in 'left_'.
        if (right_ != NULL) {
            right_->FindIntersectingInterval(lower_bound, upper_bound, results);
        }

        // Any interval whose right edge is >= the query interval's left edge
        // intersect the query interval. 'std::partition_point' returns the first
        // such interval which does not meet that criterion, so we insert all
        // up to that point.
        auto first_lesser = std::partition_point(
                overlapping_by_desc_right_.cbegin(), overlapping_by_desc_right_.cend(),
                [&](const interval_type& interval) {
                    return Traits::compare(Traits::get_right(interval), lower_bound,
                                           NEGATIVE_INFINITY) >= 0;
                });
        results->insert(results->end(), overlapping_by_desc_right_.cbegin(), first_lesser);
    } else {
        // The query interval contains the split point. Therefore all other intervals
        // which also contain the split point are intersecting.
        results->insert(results->end(), overlapping_by_asc_left_.begin(),
                        overlapping_by_asc_left_.end());

        // The query interval may _also_ intersect some in either child.
        if (left_ != NULL) {
            left_->FindIntersectingInterval(lower_bound, upper_bound, results);
        }
        if (right_ != NULL) {
            right_->FindIntersectingInterval(lower_bound, upper_bound, results);
        }
    }
}

} // namespace interval_tree_internal

} // namespace doris diff --git a/be/src/util/interval_tree.h b/be/src/util/interval_tree.h new file mode 100644 index 0000000000..dd978e8354 --- /dev/null +++ b/be/src/util/interval_tree.h @@ -0,0 +1,159 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// +// Implements an Interval Tree.
// See http://en.wikipedia.org/wiki/Interval_tree
// or CLRS for a full description of the data structure.
//
// This file is copied from
// https://github.com/apache/kudu/blob/master/src/kudu/util/interval_tree.h
// and modified by Doris
//
// Callers of this class should also include interval_tree-inl.h for function
// definitions.

#pragma once

#include <glog/logging.h>

#include <vector>

#include "gutil/macros.h"

namespace doris {

namespace interval_tree_internal {
template <class Traits>
class ITNode;
}

// Tells Traits::compare() which infinity an absent query bound stands for.
// FindIntersectingInterval() passes POSITIVE_INFINITY when comparing against the
// query's upper bound and NEGATIVE_INFINITY when comparing against its lower
// bound; in the unit test's Traits an absent bound is a std::nullopt.
enum EndpointIfNone { POSITIVE_INFINITY, NEGATIVE_INFINITY };

// Implements an Interval Tree.
//
// An Interval Tree is a data structure which stores a set of intervals and supports
// efficient searches to determine which intervals in that set overlap a query
// point or interval. These operations are O(lg n + k) where 'n' is the number of
// intervals in the tree and 'k' is the number of results returned for a given query.
//
// This particular implementation is a static tree -- intervals may not be added or
// removed once the tree is instantiated.
//
// This class also assumes that all intervals are "closed" intervals -- the intervals
// are inclusive of their start and end points.
//
// The Traits class should have the following members:
//   Traits::point_type
//     a typedef for what a "point" in the range is
//
//   Traits::interval_type
//     a typedef for an interval
//
//   static point_type get_left(const interval_type &)
//   static point_type get_right(const interval_type &)
//     accessors which fetch the left and right bound of the interval, respectively.
//
//   static int compare(const point_type &a, const point_type &b)
//     return < 0 if a < b, 0 if a == b, > 0 if a > b
//
//   static int compare(const QueryPointType &a, const point_type &b, const EndpointIfNone &type)
//   static int compare(const point_type &a, const QueryPointType &b, const EndpointIfNone &type)
//     only required by FindIntersectingInterval(), which calls compare() with the
//     query bound on either side plus the infinity that bound represents when absent.
//
// See be/test/util/interval_tree_test.cpp for an example Traits class for 'int' ranges.
template <class Traits>
class IntervalTree {
private:
    // Import types from the traits class to make code more readable.
    typedef typename Traits::interval_type interval_type;
    typedef typename Traits::point_type point_type;

    // And some convenience types.
    typedef std::vector<interval_type> IntervalVector;
    typedef interval_tree_internal::ITNode<Traits> node_type;

public:
    // Construct an Interval Tree containing the given set of intervals.
    explicit IntervalTree(const IntervalVector& intervals);

    ~IntervalTree();

    // Find all intervals in the tree which contain the query point.
    // The resulting intervals are added to the 'results' vector.
    // The vector is not cleared first.
    //
    // NOTE: 'QueryPointType' is usually point_type, but can be any other
    // type for which there exists the appropriate Traits::compare(...) method.
    template <class QueryPointType>
    void FindContainingPoint(const QueryPointType& query, IntervalVector* results) const;

    // For each of the query points in the STL container 'queries', find all
    // intervals in the tree which may contain those points. Calls 'cb(point, interval)'
    // for each such interval.
    //
    // The points in the query container must be comparable to 'point_type'
    // using Traits::compare().
    //
    // The implementation sequences the calls to 'cb' with the following guarantees:
    // 1) all of the results corresponding to a given interval will be yielded in at
    //    most two "groups" of calls (i.e. sub-sequences of calls with the same interval).
    // 2) within each "group" of calls, the query points will be in ascending order.
    //
    // For example, the callback sequence may be:
    //
    //  cb(q1, interval_1) -
    //  cb(q2, interval_1)  | first group of interval_1
    //  cb(q6, interval_1)  |
    //  cb(q7, interval_1) -
    //
    //  cb(q2, interval_2) -
    //  cb(q3, interval_2)  | first group of interval_2
    //  cb(q4, interval_2) -
    //
    //  cb(q3, interval_1) -
    //  cb(q4, interval_1)  | second group of interval_1
    //  cb(q5, interval_1) -
    //
    //  cb(q2, interval_3) -
    //  cb(q3, interval_3)  | first group of interval_3
    //  cb(q4, interval_3) -
    //
    //  cb(q5, interval_2) -
    //  cb(q6, interval_2)  | second group of interval_2
    //  cb(q7, interval_2) -
    //
    // REQUIRES: The input points must be pre-sorted or else this will return invalid
    // results.
    template <class Callback, class QueryContainer>
    void ForEachIntervalContainingPoints(const QueryContainer& queries, const Callback& cb) const;

    // Find all intervals in the tree which intersect the given interval.
    // The resulting intervals are added to the 'results' vector.
    // The vector is not cleared first.
    //
    // Unlike the stored intervals, the query range is half-open:
    // 'lower_bound' is inclusive and 'upper_bound' is exclusive.
    template <class QueryPointType>
    void FindIntersectingInterval(const QueryPointType& lower_bound,
                                  const QueryPointType& upper_bound, IntervalVector* results) const;

private:
    // Chooses a split point for 'in' and distributes the intervals into
    // '*left', '*overlapping' and '*right' relative to it.
    static void Partition(const IntervalVector& in, point_type* split_point, IntervalVector* left,
                          IntervalVector* overlapping, IntervalVector* right);

    // Create a node containing the given intervals, recursively splitting down the tree.
    static node_type* CreateNode(const IntervalVector& intervals);

    // Root of the tree; stays null when the tree was built from an empty vector.
    node_type* root_;

    DISALLOW_COPY_AND_ASSIGN(IntervalTree);
};

} // namespace doris diff --git a/be/test/CMakeLists.txt b/be/test/CMakeLists.txt index 6b8648b9ba..6ee7bccaed 100644 --- a/be/test/CMakeLists.txt +++ b/be/test/CMakeLists.txt @@ -319,6 +319,7 @@ set(UTIL_TEST_FILES util/array_parser_test.cpp util/quantile_state_test.cpp util/hdfs_storage_backend_test.cpp + util/interval_tree_test.cpp ) set(VEC_TEST_FILES vec/aggregate_functions/agg_test.cpp diff --git a/be/test/util/interval_tree_test.cpp b/be/test/util/interval_tree_test.cpp new file mode 100644 index 0000000000..4fb8ed4197 --- /dev/null +++ b/be/test/util/interval_tree_test.cpp @@ -0,0 +1,392 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+// +// This file is copied from +// https://github.com/apache/kudu/blob/master/src/kudu/util/interval_tree-test.cc +// and modified by Doris + +#include "util/interval_tree.h" + +#include <glog/logging.h> +#include <gtest/gtest.h> + +#include <algorithm> +#include <map> +#include <memory> +#include <ostream> +#include <string> +#include <tuple> // IWYU pragma: keep +#include <utility> +#include <vector> + +#include "gutil/stringprintf.h" +#include "gutil/strings/substitute.h" +#include "testutil/test_util.h" +#include "util/interval_tree-inl.h" + +using std::pair; +using std::string; +using std::vector; +using strings::Substitute; + +namespace doris { + +// Test harness. +class TestIntervalTree : public testing::Test {}; + +// Simple interval class for integer intervals. +struct IntInterval { + IntInterval(int left, int right, int id = -1) : left(left), right(right), id(id) {} + + // std::nullopt means infinity. + // [left, right] is closed interval. + // [lower, upper) is half-open interval, so the upper is exclusive. + bool Intersects(const std::optional<int>& lower, const std::optional<int>& upper) const { + if (lower == std::nullopt && upper == std::nullopt) { + // [left, right] + // | | + // [-OO, +OO) + } else if (lower == std::nullopt) { + // [left, right] + // | + // [-OO, upper) + if (*upper <= this->left) return false; + } else if (upper == std::nullopt) { + // [left, right] / + // \ / + // [lower, +OO) / + if (*lower > this->right) return false; + } else { + // [left, right] / + // \ / + // [lower, upper) / + if (*lower > this->right) return false; + // [left, right] / + // | / + // [lower, upper) / + if (*upper <= this->left) return false; + } + return true; + } + + string ToString() const { return strings::Substitute("[$0, $1]($2) ", left, right, id); } + + int left, right, id; +}; + +// A wrapper around an int which can be compared with IntTraits::compare() +// but also can keep a counter of how many times it has been compared. 
Used +// for TestBigO below. +struct CountingQueryPoint { + explicit CountingQueryPoint(int v) : val(v), count(new int(0)) {} + + int val; + std::shared_ptr<int> count; +}; + +// Traits definition for intervals made up of ints on either end. +struct IntTraits { + typedef int point_type; + typedef IntInterval interval_type; + static point_type get_left(const IntInterval& x) { return x.left; } + static point_type get_right(const IntInterval& x) { return x.right; } + static int compare(int a, int b) { + if (a < b) return -1; + if (a > b) return 1; + return 0; + } + + static int compare(const CountingQueryPoint& q, int b) { + (*q.count)++; + return compare(q.val, b); + } + static int compare(int a, const CountingQueryPoint& b) { return -compare(b, a); } + + static int compare(const std::optional<int>& a, const int b, const EndpointIfNone& type) { + if (a == std::nullopt) { + return ((POSITIVE_INFINITY == type) ? 1 : -1); + } + + return compare(*a, b); + } + + static int compare(const int a, const std::optional<int>& b, const EndpointIfNone& type) { + return -compare(b, a, type); + } +}; + +// Compare intervals in an arbitrary but consistent way - this is only +// used for verifying that the two algorithms come up with the same results. +// It's not necessary to define this to use an interval tree. +static bool CompareIntervals(const IntInterval& a, const IntInterval& b) { + return std::make_tuple(a.left, a.right, a.id) < std::make_tuple(b.left, b.right, b.id); +} + +// Stringify a list of int intervals, for easy test error reporting. +static string Stringify(const vector<IntInterval>& intervals) { + string ret; + bool first = true; + for (const IntInterval& interval : intervals) { + if (!first) { + ret.append(","); + } + ret.append(interval.ToString()); + } + return ret; +} + +// Find any intervals in 'intervals' which contain 'query_point' by brute force. 
+static void FindContainingBruteForce(const vector<IntInterval>& intervals, int query_point, + vector<IntInterval>* results) { + for (const IntInterval& i : intervals) { + if (query_point >= i.left && query_point <= i.right) { + results->push_back(i); + } + } +} + +// Find any intervals in 'intervals' which intersect 'query_interval' by brute force. +static void FindIntersectingBruteForce(const vector<IntInterval>& intervals, + const std::optional<int>& lower, + const std::optional<int>& upper, + vector<IntInterval>* results) { + for (const IntInterval& i : intervals) { + if (i.Intersects(lower, upper)) { + results->push_back(i); + } + } +} + +// Verify that IntervalTree::FindContainingPoint yields the same results as the naive +// brute-force O(n) algorithm. +static void VerifyFindContainingPoint(const vector<IntInterval>& all_intervals, + const IntervalTree<IntTraits>& tree, int query_point) { + vector<IntInterval> results; + tree.FindContainingPoint(query_point, &results); + std::sort(results.begin(), results.end(), CompareIntervals); + + vector<IntInterval> brute_force; + FindContainingBruteForce(all_intervals, query_point, &brute_force); + std::sort(brute_force.begin(), brute_force.end(), CompareIntervals); + + SCOPED_TRACE(Stringify(all_intervals) + StringPrintf(" {q=%d}", query_point)); + EXPECT_EQ(Stringify(brute_force), Stringify(results)); +} + +// Verify that IntervalTree::FindIntersectingInterval yields the same results as the naive +// brute-force O(n) algorithm. 
+static void VerifyFindIntersectingInterval(const vector<IntInterval>& all_intervals, + const IntervalTree<IntTraits>& tree, + const IntInterval& query_interval) { + const auto& Process = [&](const std::optional<int>& lower, const std::optional<int>& upper) { + vector<IntInterval> results; + tree.FindIntersectingInterval(lower, upper, &results); + std::sort(results.begin(), results.end(), CompareIntervals); + + vector<IntInterval> brute_force; + FindIntersectingBruteForce(all_intervals, lower, upper, &brute_force); + std::sort(brute_force.begin(), brute_force.end(), CompareIntervals); + EXPECT_EQ(Stringify(brute_force), Stringify(results)); + }; + + { + // [lower, upper) + std::optional<int> lower = query_interval.left; + std::optional<int> upper = query_interval.right; + SCOPED_TRACE(Stringify(all_intervals) + StringPrintf(" {q=[%d, %d)}", *lower, *upper)); + Process(lower, upper); + } + + { + // [-OO, upper) + std::optional<int> lower = std::nullopt; + std::optional<int> upper = query_interval.right; + SCOPED_TRACE(Stringify(all_intervals) + StringPrintf(" {q=[-OO, %d)}", *upper)); + Process(lower, upper); + } + + { + // [lower, +OO) + std::optional<int> lower = query_interval.left; + std::optional<int> upper = std::nullopt; + SCOPED_TRACE(Stringify(all_intervals) + StringPrintf(" {q=[%d, +OO)}", *lower)); + Process(lower, upper); + } + + { + // [-OO, +OO) + std::optional<int> lower = query_interval.left; + std::optional<int> upper = std::nullopt; + SCOPED_TRACE(Stringify(all_intervals) + StringPrintf(" {q=[-OO, +OO)}")); + Process(lower, upper); + } +} + +static vector<IntInterval> CreateRandomIntervals(int n = 100) { + vector<IntInterval> intervals; + for (int i = 0; i < n; i++) { + int l = rand_rng_int(0, 100); // NOLINT(runtime/threadsafe_fn) + int r = l + rand_rng_int(0, 20); // NOLINT(runtime/threadsafe_fn) + intervals.emplace_back(l, r, i); + } + return intervals; +} + +TEST_F(TestIntervalTree, TestBasic) { + vector<IntInterval> intervals; + 
intervals.emplace_back(1, 2, 1); + intervals.emplace_back(3, 4, 2); + intervals.emplace_back(1, 4, 3); + IntervalTree<IntTraits> t(intervals); + + for (int i = 0; i <= 5; i++) { + VerifyFindContainingPoint(intervals, t, i); + + for (int j = i; j <= 5; j++) { + VerifyFindIntersectingInterval(intervals, t, IntInterval(i, j, 0)); + } + } +} + +TEST_F(TestIntervalTree, TestRandomized) { + // Generate 100 random intervals spanning 0-200 and build an interval tree from them. + vector<IntInterval> intervals = CreateRandomIntervals(); + IntervalTree<IntTraits> t(intervals); + + // Test that we get the correct result on every possible query. + for (int i = -1; i < 201; i++) { + VerifyFindContainingPoint(intervals, t, i); + } + + // Test that we get the correct result for random intervals + for (int i = 0; i < 100; i++) { + int l = rand_rng_int(0, 100); // NOLINT(runtime/threadsafe_fn) + int r = rand_rng_int(l, l + 100); // NOLINT(runtime/threadsafe_fn) + VerifyFindIntersectingInterval(intervals, t, IntInterval(l, r)); + } +} + +TEST_F(TestIntervalTree, TestEmpty) { + vector<IntInterval> empty; + IntervalTree<IntTraits> t(empty); + + VerifyFindContainingPoint(empty, t, 1); + VerifyFindIntersectingInterval(empty, t, IntInterval(1, 2, 0)); +} + +TEST_F(TestIntervalTree, TestBigO) { +#ifndef NDEBUG + LOG(WARNING) << "big-O results are not valid if DCHECK is enabled"; + return; +#endif + LOG(INFO) << "num_int\tnum_q\tresults\tsimple\tbatch"; + for (int num_intervals = 1; num_intervals < 2000; num_intervals *= 2) { + vector<IntInterval> intervals = CreateRandomIntervals(num_intervals); + IntervalTree<IntTraits> t(intervals); + for (int num_queries = 1; num_queries < 2000; num_queries *= 2) { + vector<CountingQueryPoint> queries; + for (int i = 0; i < num_queries; i++) { + queries.emplace_back(rand_rng_int(0, 100)); + } + std::sort(queries.begin(), queries.end(), + [](const CountingQueryPoint& a, const CountingQueryPoint& b) { + return a.val < b.val; + }); + + // Test using batch 
algorithm. + int num_results_batch = 0; + t.ForEachIntervalContainingPoints( + queries, [&](CountingQueryPoint query_point, const IntInterval& interval) { + num_results_batch++; + }); + int num_comparisons_batch = 0; + for (const auto& q : queries) { + num_comparisons_batch += *q.count; + *q.count = 0; + } + + // Test using one-by-one queries. + int num_results_simple = 0; + for (auto& q : queries) { + vector<IntInterval> tmp_intervals; + t.FindContainingPoint(q, &tmp_intervals); + num_results_simple += tmp_intervals.size(); + } + int num_comparisons_simple = 0; + for (const auto& q : queries) { + num_comparisons_simple += *q.count; + } + ASSERT_EQ(num_results_simple, num_results_batch); + + LOG(INFO) << num_intervals << "\t" << num_queries << "\t" << num_results_simple << "\t" + << num_comparisons_simple << "\t" << num_comparisons_batch; + } + } +} + +TEST_F(TestIntervalTree, TestMultiQuery) { + const int kNumQueries = 1; + vector<IntInterval> intervals = CreateRandomIntervals(10); + IntervalTree<IntTraits> t(intervals); + + // Generate random queries. + vector<int> queries; + for (int i = 0; i < kNumQueries; i++) { + queries.push_back(rand_rng_int(0, 100)); + } + std::sort(queries.begin(), queries.end()); + + vector<pair<string, int>> results_simple; + for (int q : queries) { + vector<IntInterval> tmp_intervals; + t.FindContainingPoint(q, &tmp_intervals); + for (const auto& interval : tmp_intervals) { + results_simple.emplace_back(interval.ToString(), q); + } + } + + vector<pair<string, int>> results_batch; + t.ForEachIntervalContainingPoints(queries, [&](int query_point, const IntInterval& interval) { + results_batch.emplace_back(interval.ToString(), query_point); + }); + + // Check the property that, when the batch query points are in sorted order, + // the results are grouped by interval, and within each interval, sorted by + // query point. Each interval may have at most two groups. 
+ std::optional<pair<string, int>> prev = std::nullopt; + std::map<string, int> intervals_seen; + for (int i = 0; i < results_batch.size(); i++) { + const auto& cur = results_batch[i]; + // If it's another query point hitting the same interval, + // make sure the query points are returned in order. + if (prev && prev->first == cur.first) { + EXPECT_GE(cur.second, prev->second) << prev->first; + } else { + // It's the start of a new interval's data. Make sure that we don't + // see the same interval twice. + EXPECT_LE(++intervals_seen[cur.first], 2) + << "Saw more than two groups for interval " << cur.first; + } + prev = cur; + } + + std::sort(results_simple.begin(), results_simple.end()); + std::sort(results_batch.begin(), results_batch.end()); + ASSERT_EQ(results_simple, results_batch); +} + +} // namespace doris --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org