This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-doris.git
The following commit(s) were added to refs/heads/master by this push: new a75e4a1 Window funnel (#8485) a75e4a1 is described below commit a75e4a1469f2a9e0c1e1ab63c787b0d60c70163b Author: dataroaring <98214048+dataroar...@users.noreply.github.com> AuthorDate: Sat Apr 2 22:08:50 2022 +0800 Window funnel (#8485) Add new feature window funnel --- be/src/exprs/aggregate_functions.cpp | 207 ++++++++++ be/src/exprs/aggregate_functions.h | 10 + be/src/vec/CMakeLists.txt | 1 + .../aggregate_function_simple_factory.cpp | 5 +- .../aggregate_function_window_funnel.cpp | 37 ++ .../aggregate_function_window_funnel.h | 214 ++++++++++ be/test/exprs/CMakeLists.txt | 2 +- be/test/exprs/window_funnel_test.cpp | 431 +++++++++++++++++++++ be/test/vec/aggregate_functions/CMakeLists.txt | 1 + .../aggregate_functions/vec_window_funnel_test.cpp | 425 ++++++++++++++++++++ .../java/org/apache/doris/analysis/Analyzer.java | 1 - .../apache/doris/analysis/FunctionCallExpr.java | 72 +++- .../apache/doris/catalog/AggregateFunction.java | 2 +- .../java/org/apache/doris/catalog/FunctionSet.java | 31 ++ .../org/apache/doris/analysis/AggregateTest.java | 93 ++++- 15 files changed, 1508 insertions(+), 24 deletions(-) diff --git a/be/src/exprs/aggregate_functions.cpp b/be/src/exprs/aggregate_functions.cpp index 42661e8..09d9b1b 100644 --- a/be/src/exprs/aggregate_functions.cpp +++ b/be/src/exprs/aggregate_functions.cpp @@ -2254,6 +2254,212 @@ void AggregateFunctions::offset_fn_update(FunctionContext* ctx, const T& src, co *dst = src; } +// Refer to AggregateFunctionWindowFunnel.h in https://github.com/ClickHouse/ClickHouse.git +struct WindowFunnelState { + std::vector<std::pair<DateTimeValue, int>> events; + int max_event_level; + bool sorted; + int64_t window; + + WindowFunnelState() { + sorted = true; + max_event_level = 0; + window = 0; + } + + void add(DateTimeValue& timestamp, int event_idx, int event_num) { + max_event_level = event_num; + if (sorted && events.size() > 0) { + if 
(events.back().first == timestamp) { + sorted = events.back().second <= event_idx; + } else { + sorted = events.back().first < timestamp; + } + } + events.emplace_back(timestamp, event_idx); + } + + void sort() { + if (sorted) { + return; + } + std::stable_sort(events.begin(), events.end()); + } + + int get_event_level() { + std::vector<std::optional<DateTimeValue>> events_timestamp(max_event_level); + for (int64_t i = 0; i < events.size(); i++) { + int& event_idx = events[i].second; + DateTimeValue& timestamp = events[i].first; + if (event_idx == 0) { + events_timestamp[0] = timestamp; + continue; + } + if (events_timestamp[event_idx - 1].has_value()) { + DateTimeValue& first_timestamp = events_timestamp[event_idx - 1].value(); + DateTimeValue last_timestamp = first_timestamp; + TimeInterval interval(SECOND, window, false); + last_timestamp.date_add_interval(interval, SECOND); + + if (timestamp <= last_timestamp) { + events_timestamp[event_idx] = first_timestamp; + if (event_idx + 1 == max_event_level) { + // Usually, max event level is small. + return max_event_level; + } + } + } + } + + for (int64_t i = events_timestamp.size() - 1; i >= 0; i--) { + if (events_timestamp[i].has_value()) { + return i + 1; + } + } + + return 0; + } + + void merge(WindowFunnelState *other) { + if (other->events.empty()) { + return; + } + + int64_t orig_size = events.size(); + events.insert(std::end(events), std::begin(other->events), std::end(other->events)); + const auto begin = std::begin(events); + const auto middle = std::next(events.begin(), orig_size); + const auto end = std::end(events); + if (!other->sorted) { + std::stable_sort(middle, end); + } + + if (!sorted) { + std::stable_sort(begin, middle); + } + std::inplace_merge(begin, middle, end); + max_event_level = max_event_level > 0 ? max_event_level : other->max_event_level; + window = window > 0 ? 
window : other->window; + + sorted = true; + } + + int64_t serialized_size() { + return sizeof(int) + sizeof(int64_t) + sizeof(uint64_t) + + events.size() * (sizeof(int64_t) + sizeof(int)); + } + + void serialize(uint8_t *buf) { + memcpy(buf, &max_event_level, sizeof(int)); + buf += sizeof(int); + memcpy(buf, &window, sizeof(int64_t)); + buf += sizeof(int64_t); + + uint64_t event_num = events.size(); + memcpy(buf, &event_num, sizeof(uint64_t)); + buf += sizeof(uint64_t); + for (int64_t i = 0; i < events.size(); i++) { + int64_t timestamp = events[i].first; + int event_idx = events[i].second; + memcpy(buf, &timestamp, sizeof(int64_t)); + buf += sizeof(int64_t); + memcpy(buf, &event_idx, sizeof(int)); + buf += sizeof(int); + } + } + + void deserialize(uint8_t *buf) { + uint64_t size; + + memcpy(&max_event_level, buf, sizeof(int)); + buf += sizeof(int); + memcpy(&window, buf, sizeof(int64_t)); + buf += sizeof(int64_t); + memcpy(&size, buf, sizeof(uint64_t)); + buf += sizeof(uint64_t); + for (int64_t i = 0; i < size; i++) { + int64_t timestamp; + int event_idx; + + memcpy(&timestamp, buf, sizeof(int64_t)); + buf += sizeof(int64_t); + memcpy(&event_idx, buf, sizeof(int)); + buf += sizeof(int); + DateTimeValue time_value; + time_value.from_date_int64(timestamp); + add(time_value, event_idx, max_event_level); + } + } +}; + +void AggregateFunctions::window_funnel_init(FunctionContext* ctx, StringVal* dst) { + dst->is_null = false; + dst->len = sizeof(WindowFunnelState); + WindowFunnelState* state = new WindowFunnelState(); + dst->ptr = (uint8_t*)state; + // constant args at index 0 and 1 + DCHECK(ctx->is_arg_constant(0)); + BigIntVal* window = reinterpret_cast<BigIntVal*>(ctx->get_constant_arg(0)); + state->window = window->val; + // TODO handle mode in the future +} + +void AggregateFunctions::window_funnel_update(FunctionContext* ctx, const BigIntVal& window, + const StringVal& mode, const DateTimeVal& timestamp, + int num_cond, const BooleanVal* conds, StringVal* dst) { +
DCHECK(dst->ptr != nullptr); + DCHECK_EQ(sizeof(WindowFunnelState), dst->len); + + if (timestamp.is_null) { + return; + } + + WindowFunnelState* state = reinterpret_cast<WindowFunnelState*>(dst->ptr); + for (int i = 0; i < num_cond; i++) { + if (conds[i].is_null) { + continue; + } + if (conds[i].val) { + DateTimeValue time_value = DateTimeValue::from_datetime_val(timestamp); + state->add(time_value, i, num_cond); + } + } +} + +StringVal AggregateFunctions::window_funnel_serialize(FunctionContext* ctx, + const StringVal& src) { + WindowFunnelState* state = reinterpret_cast<WindowFunnelState*>(src.ptr); + int64_t serialized_size = state->serialized_size(); + StringVal result(ctx, sizeof(double) + serialized_size); + state->serialize(result.ptr); + + delete state; + return result; +} + +void AggregateFunctions::window_funnel_merge(FunctionContext* ctx, const StringVal& src, + StringVal* dst) { + DCHECK(dst->ptr != nullptr); + DCHECK_EQ(sizeof(WindowFunnelState), dst->len); + WindowFunnelState* dst_state = reinterpret_cast<WindowFunnelState*>(dst->ptr); + + WindowFunnelState* src_state = new WindowFunnelState; + + src_state->deserialize(src.ptr); + dst_state->merge(src_state); + delete src_state; +} + +IntVal AggregateFunctions::window_funnel_finalize(FunctionContext* ctx, const StringVal& src) { + DCHECK(!src.is_null); + + WindowFunnelState* state = reinterpret_cast<WindowFunnelState*>(src.ptr); + state->sort(); + int val = state->get_event_level(); + delete state; + return doris_udf::IntVal(val); +} + // Stamp out the templates for the types we need. 
template void AggregateFunctions::init_zero_null<BigIntVal>(FunctionContext*, BigIntVal* dst); template void AggregateFunctions::init_zero_null<LargeIntVal>(FunctionContext*, LargeIntVal* dst); @@ -2729,4 +2935,5 @@ template void AggregateFunctions::percentile_approx_update<doris_udf::DoubleVal> template void AggregateFunctions::percentile_approx_update<doris_udf::DoubleVal>( FunctionContext* ctx, const doris_udf::DoubleVal&, const doris_udf::DoubleVal&, const doris_udf::DoubleVal&, doris_udf::StringVal*); + } // namespace doris diff --git a/be/src/exprs/aggregate_functions.h b/be/src/exprs/aggregate_functions.h index b3b19ab..b010692 100644 --- a/be/src/exprs/aggregate_functions.h +++ b/be/src/exprs/aggregate_functions.h @@ -347,6 +347,16 @@ public: static void offset_fn_update(doris_udf::FunctionContext*, const T& src, const doris_udf::BigIntVal&, const T&, T* dst); + // windowFunnel + static void window_funnel_init(FunctionContext* ctx, StringVal* dst); + static void window_funnel_update(FunctionContext* ctx, const BigIntVal& window, + const StringVal& mode, const DateTimeVal& timestamp, + int num_cond, const BooleanVal* conds, StringVal* dst); + static void window_funnel_merge(FunctionContext* ctx, const StringVal& src, + StringVal* dst); + static StringVal window_funnel_serialize(FunctionContext* ctx, const StringVal& src); + static IntVal window_funnel_finalize(FunctionContext* ctx, const StringVal& src); + // todo(kks): keep following HLL methods only for backward compatibility, we should remove these methods // when doris 0.12 release static void hll_init(doris_udf::FunctionContext*, doris_udf::StringVal* slot); diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt index e952d98..9dc5089 100644 --- a/be/src/vec/CMakeLists.txt +++ b/be/src/vec/CMakeLists.txt @@ -20,6 +20,7 @@ set(LIBRARY_OUTPUT_PATH "${BUILD_DIR}/src/vec") set(EXECUTABLE_OUTPUT_PATH "${BUILD_DIR}/src/vec") set(VEC_FILES + aggregate_functions/aggregate_function_window_funnel.cpp 
aggregate_functions/aggregate_function_avg.cpp aggregate_functions/aggregate_function_count.cpp aggregate_functions/aggregate_function_distinct.cpp diff --git a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp index 3be7d18..fcf333c 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp +++ b/be/src/vec/aggregate_functions/aggregate_function_simple_factory.cpp @@ -43,6 +43,7 @@ void register_aggregate_function_topn(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_approx_count_distinct(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_group_concat(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_percentile(AggregateFunctionSimpleFactory& factory); +void register_aggregate_function_window_funnel(AggregateFunctionSimpleFactory& factory); void register_aggregate_function_percentile_approx(AggregateFunctionSimpleFactory& factory); AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { static std::once_flag oc; @@ -62,6 +63,8 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { register_aggregate_function_approx_count_distinct(instance); register_aggregate_function_group_concat(instance); register_aggregate_function_percentile(instance); + register_aggregate_function_percentile_approx(instance); + register_aggregate_function_window_funnel(instance); // if you only register function with no nullable, and wants to add nullable automatically, you should place function above this line register_aggregate_function_combinator_null(instance); @@ -75,4 +78,4 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() { return instance; } -} // namespace doris::vectorized \ No newline at end of file +} // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_window_funnel.cpp 
b/be/src/vec/aggregate_functions/aggregate_function_window_funnel.cpp new file mode 100644 index 0000000..a889299 --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_window_funnel.cpp @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "vec/aggregate_functions/aggregate_function_window_funnel.h" + +#include "common/logging.h" +#include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/aggregate_functions/factory_helpers.h" +#include "vec/aggregate_functions/helpers.h" + +namespace doris::vectorized { + +AggregateFunctionPtr create_aggregate_function_window_funnel(const std::string& name, + const DataTypes& argument_types, + const Array& parameters, + const bool result_is_nullable) { + return std::make_shared<AggregateFunctionWindowFunnel>(argument_types); +} + +void register_aggregate_function_window_funnel(AggregateFunctionSimpleFactory& factory) { + factory.register_function("window_funnel", create_aggregate_function_window_funnel, false); +} +} // namespace doris::vectorized diff --git a/be/src/vec/aggregate_functions/aggregate_function_window_funnel.h b/be/src/vec/aggregate_functions/aggregate_function_window_funnel.h new file mode 100644 index 0000000..f4364ee --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_window_funnel.h @@ -0,0 +1,214 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +// This file is copied from +// https://github.com/ClickHouse/ClickHouse/blob/master/AggregateFunctionWindowFunnel.h +// and modified by Doris + +#pragma once + +#include "common/logging.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/columns/columns_number.h" +#include "vec/data_types/data_type_decimal.h" +#include "vec/io/var_int.h" + +namespace doris::vectorized { + +struct WindowFunnelState { + std::vector<std::pair<VecDateTimeValue, int>> events; + int max_event_level; + bool sorted; + int64_t window; + + WindowFunnelState() { + sorted = true; + max_event_level = 0; + window = 0; + } + + void reset() { + sorted = true; + max_event_level = 0; + window = 0; + events.shrink_to_fit(); + } + + void add(const VecDateTimeValue& timestamp, int event_idx, int event_num, int64_t win) { + window = win; + max_event_level = event_num; + if (sorted && events.size() > 0) { + if (events.back().first == timestamp) { + sorted = events.back().second <= event_idx; + } else { + sorted = events.back().first < timestamp; + } + } + events.emplace_back(timestamp, event_idx); + } + + void sort() { + if (sorted) { + return; + } + std::stable_sort(events.begin(), events.end()); + } + + int get() const { + std::vector<std::optional<VecDateTimeValue>> events_timestamp(max_event_level); + for (int64_t i = 0; i < events.size(); i++) { + const int& event_idx = events[i].second; + const VecDateTimeValue& timestamp = events[i].first; + if (event_idx == 0) { + events_timestamp[0] = timestamp; + continue; + } + if (events_timestamp[event_idx - 1].has_value()) { + const VecDateTimeValue& first_timestamp = events_timestamp[event_idx - 1].value(); + VecDateTimeValue last_timestamp = first_timestamp; + TimeInterval interval(SECOND, window, false); + last_timestamp.date_add_interval(interval, SECOND); + + if (timestamp <= last_timestamp) { + events_timestamp[event_idx] = first_timestamp; + if (event_idx + 1 == max_event_level) { + // Usually, max event level is small. 
+ return max_event_level; + } + } + } + } + + for (int64_t i = events_timestamp.size() - 1; i >= 0; i--) { + if (events_timestamp[i].has_value()) { + return i + 1; + } + } + + return 0; + } + + void merge(const WindowFunnelState& other) { + if (other.events.empty()) { + return; + } + + int64_t orig_size = events.size(); + events.insert(std::end(events), std::begin(other.events), std::end(other.events)); + const auto begin = std::begin(events); + const auto middle = std::next(events.begin(), orig_size); + const auto end = std::end(events); + if (!other.sorted) { + std::stable_sort(middle, end); + } + + if (!sorted) { + std::stable_sort(begin, middle); + } + std::inplace_merge(begin, middle, end); + max_event_level = max_event_level > 0 ? max_event_level : other.max_event_level; + window = window > 0 ? window : other.window; + + sorted = true; + } + + void write(BufferWritable &out) const { + write_var_int(max_event_level, out); + write_var_int(window, out); + write_var_int(events.size(), out); + + for (int64_t i = 0; i < events.size(); i++) { + int64_t timestamp = events[i].first; + int event_idx = events[i].second; + write_var_int(timestamp, out); + write_var_int(event_idx, out); + } + } + + void read(BufferReadable& in) { + int64_t event_level; + read_var_int(event_level, in); + max_event_level = (int)event_level; + read_var_int(window, in); + int64_t size = 0; + read_var_int(size, in); + for (int64_t i = 0; i < size; i++) { + int64_t timestamp; + int64_t event_idx; + + read_var_int(timestamp, in); + read_var_int(event_idx, in); + VecDateTimeValue time_value(timestamp); + add(time_value, (int)event_idx, max_event_level, window); + } + } +}; + +class AggregateFunctionWindowFunnel + : public IAggregateFunctionDataHelper<WindowFunnelState, + AggregateFunctionWindowFunnel> { +public: + AggregateFunctionWindowFunnel(const DataTypes& argument_types_) + : IAggregateFunctionDataHelper<WindowFunnelState, + AggregateFunctionWindowFunnel>(argument_types_, {}) { + } + + 
String get_name() const override { return "window_funnel"; } + + DataTypePtr get_return_type() const override { + return std::make_shared<DataTypeInt32>(); + } + + void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); } + + void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + Arena*) const override { + const auto& window = static_cast<const ColumnVector<Int64>&>(*columns[0]).get_data()[row_num]; + // TODO: handle mode in the future. + // be/src/olap/row_block2.cpp copy_data_to_column + const auto& timestamp = static_cast<const ColumnVector<VecDateTimeValue>&>(*columns[2]).get_data()[row_num]; + const int NON_EVENT_NUM = 3; + for (int i = NON_EVENT_NUM; i < get_argument_types().size(); i++) { + const auto& is_set = static_cast<const ColumnVector<UInt8>&>(*columns[i]).get_data()[row_num]; + if (is_set) { + this->data(place).add(timestamp, i - NON_EVENT_NUM, + get_argument_types().size() - NON_EVENT_NUM, window); + } + } + } + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, + Arena*) const override { + this->data(place).merge(this->data(rhs)); + } + + void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { + this->data(place).write(buf); + } + + void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, + Arena*) const override { + this->data(place).read(buf); + } + + void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { + this->data(const_cast<AggregateDataPtr>(place)).sort(); + assert_cast<ColumnInt32&>(to).get_data().push_back(data(place).get()); + } +}; + +} // namespace doris::vectorized diff --git a/be/test/exprs/CMakeLists.txt b/be/test/exprs/CMakeLists.txt index 4205be2..814ed58 100644 --- a/be/test/exprs/CMakeLists.txt +++ b/be/test/exprs/CMakeLists.txt @@ -40,4 +40,4 @@ ADD_BE_TEST(runtime_filter_test) ADD_BE_TEST(bloom_filter_predicate_test) 
ADD_BE_TEST(array_functions_test) ADD_BE_TEST(quantile_function_test) - +ADD_BE_TEST(window_funnel_test) diff --git a/be/test/exprs/window_funnel_test.cpp b/be/test/exprs/window_funnel_test.cpp new file mode 100644 index 0000000..f8f190a --- /dev/null +++ b/be/test/exprs/window_funnel_test.cpp @@ -0,0 +1,431 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include <gtest/gtest.h> + +#include "common/logging.h" +#include "exprs/aggregate_functions.h" +#include "runtime/datetime_value.h" +#include "testutil/function_utils.h" + +namespace doris { + +class WindowFunnelTest : public testing::Test { +public: + WindowFunnelTest() {} +}; + +TEST_F(WindowFunnelTest, testMax4SortedNoMerge) { + FunctionUtils* futil = new FunctionUtils(); + doris_udf::FunctionContext* context = futil->get_fn_ctx(); + + const int NUM_CONDS = 4; + for (int i = -1; i < NUM_CONDS + 4; i++) { + StringVal stringVal1; + BigIntVal window(i); + StringVal mode("default"); + std::vector<doris_udf::AnyVal*> constant_args; + constant_args.emplace_back(&window); + constant_args.emplace_back(&mode); + context->impl()->set_constant_args(std::move(constant_args)); + + AggregateFunctions::window_funnel_init(context, &stringVal1); + + DateTimeVal timestamp; + DateTimeValue time_value; + time_value.set_time(2020, 2, 28, 0, 0, 1, 0); + time_value.to_datetime_val(×tamp); + BooleanVal conds[NUM_CONDS] = {true, false, false, false}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, NUM_CONDS, + conds, &stringVal1); + + time_value.set_time(2020, 2, 28, 0, 0, 2, 0); + time_value.to_datetime_val(×tamp); + BooleanVal conds1[NUM_CONDS] = {false, true, false, false}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, NUM_CONDS, + conds1, &stringVal1); + + time_value.set_time(2020, 2, 28, 0, 0, 3, 0); + time_value.to_datetime_val(×tamp); + BooleanVal conds2[NUM_CONDS] = {false, false, true, false}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, NUM_CONDS, + conds2, &stringVal1); + + time_value.set_time(2020, 2, 28, 0, 0, 4, 0); + time_value.to_datetime_val(×tamp); + BooleanVal conds3[NUM_CONDS] = {false, false, false, true}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, NUM_CONDS, + conds3, &stringVal1); + + IntVal v = 
AggregateFunctions::window_funnel_finalize(context, stringVal1); + LOG(INFO) << "event num: " << NUM_CONDS << " window: " << window.val; + ASSERT_EQ(v.val, i < 0 ? 1 : (i < NUM_CONDS ? i + 1 : NUM_CONDS)); + } + delete futil; +} + +TEST_F(WindowFunnelTest, testMax4SortedMerge) { + FunctionUtils* futil = new FunctionUtils(); + doris_udf::FunctionContext* context = futil->get_fn_ctx(); + + const int NUM_CONDS = 4; + for (int i = -1; i < NUM_CONDS + 4; i++) { + StringVal stringVal1; + BigIntVal window(i); + StringVal mode("default"); + std::vector<doris_udf::AnyVal*> constant_args; + constant_args.emplace_back(&window); + constant_args.emplace_back(&mode); + context->impl()->set_constant_args(std::move(constant_args)); + + AggregateFunctions::window_funnel_init(context, &stringVal1); + + DateTimeVal timestamp; + DateTimeValue time_value; + time_value.set_time(2020, 2, 28, 0, 0, 1, 0); + time_value.to_datetime_val(×tamp); + + BooleanVal conds[NUM_CONDS] = {true, false, false, false}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, NUM_CONDS, + conds, &stringVal1); + + time_value.set_time(2020, 2, 28, 0, 0, 2, 0); + time_value.to_datetime_val(×tamp); + BooleanVal conds1[NUM_CONDS] = {false, true, false, false}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, NUM_CONDS, + conds1, &stringVal1); + + time_value.set_time(2020, 2, 28, 0, 0, 3, 0); + time_value.to_datetime_val(×tamp); + BooleanVal conds2[NUM_CONDS] = {false, false, true, false}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, NUM_CONDS, + conds2, &stringVal1); + + time_value.set_time(2020, 2, 28, 0, 0, 4, 0); + time_value.to_datetime_val(×tamp); + BooleanVal conds3[NUM_CONDS] = {false, false, false, true}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, NUM_CONDS, + conds3, &stringVal1); + + StringVal s = AggregateFunctions::window_funnel_serialize(context, stringVal1); + + StringVal stringVal2; 
+ AggregateFunctions::window_funnel_init(context, &stringVal2); + AggregateFunctions::window_funnel_merge(context, s, &stringVal2); + IntVal v = AggregateFunctions::window_funnel_finalize(context, stringVal2); + LOG(INFO) << "event num: " << NUM_CONDS << " window: " << window.val; + ASSERT_EQ(v.val, i < 0 ? 1 : (i < NUM_CONDS ? i + 1 : NUM_CONDS)); + } + delete futil; +} + +TEST_F(WindowFunnelTest, testMax4ReverseSortedNoMerge) { + FunctionUtils* futil = new FunctionUtils(); + doris_udf::FunctionContext* context = futil->get_fn_ctx(); + + const int NUM_CONDS = 4; + for (int i = -1; i < NUM_CONDS + 4; i++) { + StringVal stringVal1; + BigIntVal window(i); + StringVal mode("default"); + std::vector<doris_udf::AnyVal*> constant_args; + constant_args.emplace_back(&window); + constant_args.emplace_back(&mode); + context->impl()->set_constant_args(std::move(constant_args)); + + AggregateFunctions::window_funnel_init(context, &stringVal1); + + DateTimeVal timestamp; + DateTimeValue time_value; + time_value.set_time(2020, 2, 28, 0, 0, 3, 0); + time_value.to_datetime_val(×tamp); + + BooleanVal conds[NUM_CONDS] = {true, false, false, false}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, NUM_CONDS, + conds, &stringVal1); + + time_value.set_time(2020, 2, 28, 0, 0, 2, 0); + time_value.to_datetime_val(×tamp); + BooleanVal conds1[NUM_CONDS] = {false, true, false, false}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, NUM_CONDS, + conds1, &stringVal1); + + time_value.set_time(2020, 2, 28, 0, 0, 1, 0); + time_value.to_datetime_val(×tamp); + BooleanVal conds2[NUM_CONDS] = {false, false, true, false}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, NUM_CONDS, + conds2, &stringVal1); + + time_value.set_time(2020, 2, 28, 0, 0, 0, 0); + time_value.to_datetime_val(×tamp); + BooleanVal conds3[NUM_CONDS] = {false, false, false, true}; + AggregateFunctions::window_funnel_update(context, window, mode, 
timestamp, NUM_CONDS, + conds3, &stringVal1); + + IntVal v = AggregateFunctions::window_funnel_finalize(context, stringVal1); + LOG(INFO) << "event num: " << NUM_CONDS << " window: " << window.val; + ASSERT_EQ(v.val, 1); + } + delete futil; +} + +TEST_F(WindowFunnelTest, testMax4ReverseSortedMerge) { + FunctionUtils* futil = new FunctionUtils(); + doris_udf::FunctionContext* context = futil->get_fn_ctx(); + + const int NUM_CONDS = 4; + for (int i = -1; i < NUM_CONDS + 4; i++) { + StringVal stringVal1; + BigIntVal window(i); + StringVal mode("default"); + std::vector<doris_udf::AnyVal*> constant_args; + constant_args.emplace_back(&window); + constant_args.emplace_back(&mode); + context->impl()->set_constant_args(std::move(constant_args)); + + AggregateFunctions::window_funnel_init(context, &stringVal1); + + DateTimeVal timestamp; + DateTimeValue time_value; + time_value.set_time(2020, 2, 28, 0, 0, 3, 0); + time_value.to_datetime_val(×tamp); + + BooleanVal conds[NUM_CONDS] = {true, false, false, false}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, NUM_CONDS, + conds, &stringVal1); + + time_value.set_time(2020, 2, 28, 0, 0, 2, 0); + time_value.to_datetime_val(×tamp); + BooleanVal conds1[NUM_CONDS] = {false, true, false, false}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, NUM_CONDS, + conds1, &stringVal1); + + time_value.set_time(2020, 2, 28, 0, 0, 1, 0); + time_value.to_datetime_val(×tamp); + BooleanVal conds2[NUM_CONDS] = {false, false, true, false}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, NUM_CONDS, + conds2, &stringVal1); + + time_value.set_time(2020, 2, 28, 0, 0, 0, 0); + time_value.to_datetime_val(×tamp); + BooleanVal conds3[NUM_CONDS] = {false, false, false, true}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, NUM_CONDS, + conds3, &stringVal1); + + StringVal s = AggregateFunctions::window_funnel_serialize(context, stringVal1); + + 
StringVal stringVal2; + AggregateFunctions::window_funnel_init(context, &stringVal2); + AggregateFunctions::window_funnel_merge(context, s, &stringVal2); + IntVal v = AggregateFunctions::window_funnel_finalize(context, stringVal2); + LOG(INFO) << "event num: " << NUM_CONDS << " window: " << window.val; + ASSERT_EQ(v.val, 1); + } + delete futil; +} + +TEST_F(WindowFunnelTest, testMax4DuplicateSortedNoMerge) { + FunctionUtils* futil = new FunctionUtils(); + doris_udf::FunctionContext* context = futil->get_fn_ctx(); + + const int NUM_CONDS = 4; + for (int i = -1; i < NUM_CONDS + 4; i++) { + StringVal stringVal1; + BigIntVal window(i); + StringVal mode("default"); + std::vector<doris_udf::AnyVal*> constant_args; + constant_args.emplace_back(&window); + constant_args.emplace_back(&mode); + context->impl()->set_constant_args(std::move(constant_args)); + + AggregateFunctions::window_funnel_init(context, &stringVal1); + + DateTimeVal timestamp; + DateTimeValue time_value; + time_value.set_time(2020, 2, 28, 0, 0, 0, 0); + time_value.to_datetime_val(×tamp); + + BooleanVal conds[NUM_CONDS] = {true, false, false, false}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, NUM_CONDS, + conds, &stringVal1); + + time_value.set_time(2020, 2, 28, 0, 0, 1, 0); + time_value.to_datetime_val(×tamp); + BooleanVal conds1[NUM_CONDS] = {false, true, false, false}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, NUM_CONDS, + conds1, &stringVal1); + + time_value.set_time(2020, 2, 28, 0, 0, 2, 0); + time_value.to_datetime_val(×tamp); + BooleanVal conds2[NUM_CONDS] = {true, false, false, false}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, NUM_CONDS, + conds2, &stringVal1); + + time_value.set_time(2020, 2, 28, 0, 0, 3, 0); + time_value.to_datetime_val(×tamp); + BooleanVal conds3[NUM_CONDS] = {false, false, false, false}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, NUM_CONDS, + 
conds3, &stringVal1); + + IntVal v = AggregateFunctions::window_funnel_finalize(context, stringVal1); + LOG(INFO) << "event num: " << NUM_CONDS << " window: " << window.val; + ASSERT_EQ(v.val, i < 0 ? 1 : (i < 2 ? i + 1 : 2)); + } + delete futil; +} + +TEST_F(WindowFunnelTest, testMax4DuplicateSortedMerge) { + FunctionUtils* futil = new FunctionUtils(); + doris_udf::FunctionContext* context = futil->get_fn_ctx(); + + const int NUM_CONDS = 4; + for (int i = -1; i < NUM_CONDS + 4; i++) { + StringVal stringVal1; + BigIntVal window(i); + StringVal mode("default"); + std::vector<doris_udf::AnyVal*> constant_args; + constant_args.emplace_back(&window); + constant_args.emplace_back(&mode); + context->impl()->set_constant_args(std::move(constant_args)); + + AggregateFunctions::window_funnel_init(context, &stringVal1); + + DateTimeVal timestamp; + DateTimeValue time_value; + time_value.set_time(2020, 2, 28, 0, 0, 0, 0); + time_value.to_datetime_val(×tamp); + + BooleanVal conds[NUM_CONDS] = {true, false, false, false}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, NUM_CONDS, + conds, &stringVal1); + + time_value.set_time(2020, 2, 28, 0, 0, 1, 0); + time_value.to_datetime_val(×tamp); + BooleanVal conds1[NUM_CONDS] = {false, true, false, false}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, NUM_CONDS, + conds1, &stringVal1); + + time_value.set_time(2020, 2, 28, 0, 0, 2, 0); + time_value.to_datetime_val(×tamp); + BooleanVal conds2[NUM_CONDS] = {true, false, false, false}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, NUM_CONDS, + conds2, &stringVal1); + + time_value.set_time(2020, 2, 28, 0, 0, 3, 0); + time_value.to_datetime_val(×tamp); + BooleanVal conds3[NUM_CONDS] = {false, false, false, false}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, NUM_CONDS, + conds3, &stringVal1); + + StringVal s = AggregateFunctions::window_funnel_serialize(context, 
stringVal1); + + StringVal stringVal2; + AggregateFunctions::window_funnel_init(context, &stringVal2); + AggregateFunctions::window_funnel_merge(context, s, &stringVal2); + IntVal v = AggregateFunctions::window_funnel_finalize(context, stringVal2); + LOG(INFO) << "event num: " << NUM_CONDS << " window: " << window.val; + ASSERT_EQ(v.val, i < 0 ? 1 : (i < 2 ? i + 1 : 2)); + } + delete futil; +} + +TEST_F(WindowFunnelTest, testNoMatchedEvent) { + FunctionUtils* futil = new FunctionUtils(); + doris_udf::FunctionContext* context = futil->get_fn_ctx(); + + StringVal stringVal1; + BigIntVal window(0); + StringVal mode("default"); + std::vector<doris_udf::AnyVal*> constant_args; + constant_args.emplace_back(&window); + constant_args.emplace_back(&mode); + context->impl()->set_constant_args(std::move(constant_args)); + + AggregateFunctions::window_funnel_init(context, &stringVal1); + + DateTimeVal timestamp; + DateTimeValue time_value; + time_value.set_time(2020, 2, 28, 0, 0, 0, 0); + time_value.to_datetime_val(×tamp); + + BooleanVal conds[4] = {false, false, false, false}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, 4, + conds, &stringVal1); + + IntVal v = AggregateFunctions::window_funnel_finalize(context, stringVal1); + ASSERT_EQ(v.val, 0); + delete futil; +} + +TEST_F(WindowFunnelTest, testNoEvent) { + FunctionUtils* futil = new FunctionUtils(); + doris_udf::FunctionContext* context = futil->get_fn_ctx(); + + StringVal stringVal1; + BigIntVal window(0); + StringVal mode("default"); + std::vector<doris_udf::AnyVal*> constant_args; + constant_args.emplace_back(&window); + constant_args.emplace_back(&mode); + context->impl()->set_constant_args(std::move(constant_args)); + + AggregateFunctions::window_funnel_init(context, &stringVal1); + + IntVal v = AggregateFunctions::window_funnel_finalize(context, stringVal1); + ASSERT_EQ(v.val, 0); + + StringVal stringVal2; + AggregateFunctions::window_funnel_init(context, &stringVal2); + + v = 
AggregateFunctions::window_funnel_finalize(context, stringVal2); + ASSERT_EQ(v.val, 0); + + delete futil; +} + +TEST_F(WindowFunnelTest, testInputNull) { + FunctionUtils* futil = new FunctionUtils(); + doris_udf::FunctionContext* context = futil->get_fn_ctx(); + + BigIntVal window(0); + StringVal mode("default"); + std::vector<doris_udf::AnyVal*> constant_args; + constant_args.emplace_back(&window); + constant_args.emplace_back(&mode); + context->impl()->set_constant_args(std::move(constant_args)); + + StringVal stringVal1; + AggregateFunctions::window_funnel_init(context, &stringVal1); + + DateTimeVal timestamp = DateTimeVal::null(); + BooleanVal conds[4] = {false, false, false, false}; + AggregateFunctions::window_funnel_update(context, window, mode, timestamp, 4, + conds, &stringVal1); + + + IntVal v = AggregateFunctions::window_funnel_finalize(context, stringVal1); + ASSERT_EQ(v.val, 0); + + delete futil; +} + +} // namespace doris + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/be/test/vec/aggregate_functions/CMakeLists.txt b/be/test/vec/aggregate_functions/CMakeLists.txt index 5c7c278..13a1315 100644 --- a/be/test/vec/aggregate_functions/CMakeLists.txt +++ b/be/test/vec/aggregate_functions/CMakeLists.txt @@ -20,3 +20,4 @@ set(EXECUTABLE_OUTPUT_PATH "${BUILD_DIR}/test/vec/aggregate_functions") ADD_BE_TEST(agg_test) ADD_BE_TEST(agg_min_max_test) +ADD_BE_TEST(vec_window_funnel_test) diff --git a/be/test/vec/aggregate_functions/vec_window_funnel_test.cpp b/be/test/vec/aggregate_functions/vec_window_funnel_test.cpp new file mode 100644 index 0000000..8bcd255 --- /dev/null +++ b/be/test/vec/aggregate_functions/vec_window_funnel_test.cpp @@ -0,0 +1,425 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include <gtest/gtest.h> + +#include "common/logging.h" +#include "gtest/gtest.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/aggregate_functions/aggregate_function_simple_factory.h" +#include "vec/aggregate_functions/aggregate_function_topn.h" +#include "vec/columns/column_vector.h" +#include "vec/data_types/data_type.h" +#include "vec/data_types/data_type_number.h" +#include "vec/data_types/data_type_string.h" + +namespace doris::vectorized { + +void register_aggregate_function_window_funnel(AggregateFunctionSimpleFactory& factory); + +class WindowFunnelTest : public testing::Test { +public: + AggregateFunctionPtr agg_function; + + WindowFunnelTest() {} + + void SetUp() { + AggregateFunctionSimpleFactory factory = AggregateFunctionSimpleFactory::instance(); + DataTypes data_types = { + std::make_shared<DataTypeInt64>(), + std::make_shared<DataTypeString>(), + std::make_shared<DataTypeDateTime>(), + std::make_shared<DataTypeUInt8>(), + std::make_shared<DataTypeUInt8>(), + std::make_shared<DataTypeUInt8>(), + std::make_shared<DataTypeUInt8>(), + }; + Array array; + agg_function = factory.get("window_funnel", data_types, array, false); + ASSERT_NE(agg_function, nullptr); + } + + void TearDown() { + } +}; + +TEST_F(WindowFunnelTest, testEmpty) { + std::unique_ptr<char[]> memory(new char[agg_function->size_of_data()]); + AggregateDataPtr place = 
memory.get(); + agg_function->create(place); + + ColumnString buf; + VectorBufferWriter buf_writer(buf); + agg_function->serialize(place, buf_writer); + buf_writer.commit(); + LOG(INFO) << "buf size : " << buf.size(); + VectorBufferReader buf_reader(buf.get_data_at(0)); + agg_function->deserialize(place, buf_reader, nullptr); + + std::unique_ptr<char[]> memory2(new char[agg_function->size_of_data()]); + AggregateDataPtr place2 = memory2.get(); + agg_function->create(place2); + + agg_function->merge(place, place2, nullptr); + ColumnVector<Int32> column_result; + agg_function->insert_result_into(place, column_result); + ASSERT_EQ(column_result.get_data()[0], 0); + + ColumnVector<Int32> column_result2; + agg_function->insert_result_into(place2, column_result2); + ASSERT_EQ(column_result2.get_data()[0], 0); + + agg_function->destroy(place); + agg_function->destroy(place2); +} + +TEST_F(WindowFunnelTest, testSerialize) { + const int NUM_CONDS = 4; + auto column_mode = ColumnString::create(); + for (int i = 0; i < NUM_CONDS; i++) { + column_mode->insert("mode"); + } + + auto column_timestamp = ColumnVector<Int64>::create(); + for (int i = 0; i < NUM_CONDS; i++) { + VecDateTimeValue time_value; + time_value.set_time(2022, 2, 28, 0, 0, i); + column_timestamp->insert_data((char *)&time_value, 0); + } + auto column_event1 = ColumnVector<UInt8>::create(); + column_event1->insert(1); + column_event1->insert(0); + column_event1->insert(0); + column_event1->insert(0); + + auto column_event2 = ColumnVector<UInt8>::create(); + column_event2->insert(0); + column_event2->insert(1); + column_event2->insert(0); + column_event2->insert(0); + + auto column_event3 = ColumnVector<UInt8>::create(); + column_event3->insert(0); + column_event3->insert(0); + column_event3->insert(1); + column_event3->insert(0); + + auto column_event4 = ColumnVector<UInt8>::create(); + column_event4->insert(0); + column_event4->insert(0); + column_event4->insert(0); + column_event4->insert(1); + + auto 
column_window = ColumnVector<Int64>::create(); + for (int i = 0; i < NUM_CONDS; i++) { + column_window->insert(2); + } + + std::unique_ptr<char[]> memory(new char[agg_function->size_of_data()]); + AggregateDataPtr place = memory.get(); + agg_function->create(place); + const IColumn* column[7] = { column_window.get(), column_mode.get(), + column_timestamp.get(), + column_event1.get(), column_event2.get(), + column_event3.get(), column_event4.get() }; + for (int i = 0; i < NUM_CONDS; i++) { + agg_function->add(place, column, i, nullptr); + } + + ColumnString buf; + VectorBufferWriter buf_writer(buf); + agg_function->serialize(place, buf_writer); + buf_writer.commit(); + + std::unique_ptr<char[]> memory2(new char[agg_function->size_of_data()]); + AggregateDataPtr place2 = memory2.get(); + agg_function->create(place2); + + VectorBufferReader buf_reader(buf.get_data_at(0)); + agg_function->deserialize(place2, buf_reader, nullptr); + + ColumnVector<Int32> column_result; + agg_function->insert_result_into(place, column_result); + ASSERT_EQ(column_result.get_data()[0], 3); + agg_function->destroy(place); + + ColumnVector<Int32> column_result2; + agg_function->insert_result_into(place2, column_result2); + ASSERT_EQ(column_result2.get_data()[0], 3); + agg_function->destroy(place2); +} + + +TEST_F(WindowFunnelTest, testMax4SortedNoMerge) { + const int NUM_CONDS = 4; + auto column_mode = ColumnString::create(); + for (int i = 0; i < NUM_CONDS; i++) { + column_mode->insert("mode"); + } + auto column_timestamp = ColumnVector<Int64>::create(); + for (int i = 0; i < NUM_CONDS; i++) { + VecDateTimeValue time_value; + time_value.set_time(2022, 2, 28, 0, 0, i); + column_timestamp->insert_data((char *)&time_value, 0); + } + auto column_event1 = ColumnVector<UInt8>::create(); + column_event1->insert(1); + column_event1->insert(0); + column_event1->insert(0); + column_event1->insert(0); + + auto column_event2 = ColumnVector<UInt8>::create(); + column_event2->insert(0); + 
column_event2->insert(1); + column_event2->insert(0); + column_event2->insert(0); + + auto column_event3 = ColumnVector<UInt8>::create(); + column_event3->insert(0); + column_event3->insert(0); + column_event3->insert(1); + column_event3->insert(0); + + auto column_event4 = ColumnVector<UInt8>::create(); + column_event4->insert(0); + column_event4->insert(0); + column_event4->insert(0); + column_event4->insert(1); + + for(int win = -1; win < NUM_CONDS + 1; win++) { + auto column_window = ColumnVector<Int64>::create(); + for (int i = 0; i < NUM_CONDS; i++) { + column_window->insert(win); + } + + std::unique_ptr<char[]> memory(new char[agg_function->size_of_data()]); + AggregateDataPtr place = memory.get(); + agg_function->create(place); + const IColumn* column[7] = { column_window.get(), column_mode.get(), + column_timestamp.get(), + column_event1.get(), column_event2.get(), + column_event3.get(), column_event4.get() }; + for (int i = 0; i < NUM_CONDS; i++) { + agg_function->add(place, column, i, nullptr); + } + + ColumnVector<Int32> column_result; + agg_function->insert_result_into(place, column_result); + ASSERT_EQ(column_result.get_data()[0], win < 0 ? 1 : (win < NUM_CONDS ? 
win + 1 : NUM_CONDS)); + agg_function->destroy(place); + } +} + +TEST_F(WindowFunnelTest, testMax4SortedMerge) { + const int NUM_CONDS = 4; + auto column_mode = ColumnString::create(); + for (int i = 0; i < NUM_CONDS; i++) { + column_mode->insert("mode"); + } + auto column_timestamp = ColumnVector<Int64>::create(); + for (int i = 0; i < NUM_CONDS; i++) { + VecDateTimeValue time_value; + time_value.set_time(2022, 2, 28, 0, 0, i); + column_timestamp->insert_data((char *)&time_value, 0); + } + auto column_event1 = ColumnVector<UInt8>::create(); + column_event1->insert(1); + column_event1->insert(0); + column_event1->insert(0); + column_event1->insert(0); + + auto column_event2 = ColumnVector<UInt8>::create(); + column_event2->insert(0); + column_event2->insert(1); + column_event2->insert(0); + column_event2->insert(0); + + auto column_event3 = ColumnVector<UInt8>::create(); + column_event3->insert(0); + column_event3->insert(0); + column_event3->insert(1); + column_event3->insert(0); + + auto column_event4 = ColumnVector<UInt8>::create(); + column_event4->insert(0); + column_event4->insert(0); + column_event4->insert(0); + column_event4->insert(1); + + for(int win = -1; win < NUM_CONDS + 1; win++) { + auto column_window = ColumnVector<Int64>::create(); + for (int i = 0; i < NUM_CONDS; i++) { + column_window->insert(win); + } + + std::unique_ptr<char[]> memory(new char[agg_function->size_of_data()]); + AggregateDataPtr place = memory.get(); + agg_function->create(place); + const IColumn* column[7] = { column_window.get(), column_mode.get(), + column_timestamp.get(), + column_event1.get(), column_event2.get(), + column_event3.get(), column_event4.get() }; + for (int i = 0; i < NUM_CONDS; i++) { + agg_function->add(place, column, i, nullptr); + } + + std::unique_ptr<char[]> memory2(new char[agg_function->size_of_data()]); + AggregateDataPtr place2 = memory2.get(); + agg_function->create(place2); + + agg_function->merge(place2, place, nullptr); + ColumnVector<Int32> 
column_result; + agg_function->insert_result_into(place2, column_result); + ASSERT_EQ(column_result.get_data()[0], win < 0 ? 1 : (win < NUM_CONDS ? win + 1 : NUM_CONDS)); + agg_function->destroy(place); + agg_function->destroy(place2); + } +} + +TEST_F(WindowFunnelTest, testMax4ReverseSortedNoMerge) { + const int NUM_CONDS = 4; + auto column_mode = ColumnString::create(); + for (int i = 0; i < NUM_CONDS; i++) { + column_mode->insert("mode"); + } + auto column_timestamp = ColumnVector<Int64>::create(); + for (int i = 0; i < NUM_CONDS; i++) { + VecDateTimeValue time_value; + time_value.set_time(2022, 2, 28, 0, 0, NUM_CONDS - i); + column_timestamp->insert_data((char *)&time_value, 0); + } + auto column_event1 = ColumnVector<UInt8>::create(); + column_event1->insert(0); + column_event1->insert(0); + column_event1->insert(0); + column_event1->insert(1); + + auto column_event2 = ColumnVector<UInt8>::create(); + column_event2->insert(0); + column_event2->insert(0); + column_event2->insert(1); + column_event2->insert(0); + + auto column_event3 = ColumnVector<UInt8>::create(); + column_event3->insert(0); + column_event3->insert(1); + column_event3->insert(0); + column_event3->insert(0); + + auto column_event4 = ColumnVector<UInt8>::create(); + column_event4->insert(1); + column_event4->insert(0); + column_event4->insert(0); + column_event4->insert(0); + + for(int win = -1; win < NUM_CONDS + 1; win++) { + auto column_window = ColumnVector<Int64>::create(); + for (int i = 0; i < NUM_CONDS; i++) { + column_window->insert(win); + } + + std::unique_ptr<char[]> memory(new char[agg_function->size_of_data()]); + AggregateDataPtr place = memory.get(); + agg_function->create(place); + const IColumn* column[7] = { column_window.get(), column_mode.get(), + column_timestamp.get(), + column_event1.get(), column_event2.get(), + column_event3.get(), column_event4.get() }; + for (int i = 0; i < NUM_CONDS; i++) { + agg_function->add(place, column, i, nullptr); + } + + LOG(INFO) << "win " << 
win; + ColumnVector<Int32> column_result; + agg_function->insert_result_into(place, column_result); + ASSERT_EQ(column_result.get_data()[0], win < 0 ? 1 : (win < NUM_CONDS ? win + 1 : NUM_CONDS)); + agg_function->destroy(place); + } +} + +TEST_F(WindowFunnelTest, testMax4ReverseSortedMerge) { + const int NUM_CONDS = 4; + auto column_mode = ColumnString::create(); + for (int i = 0; i < NUM_CONDS; i++) { + column_mode->insert("mode"); + } + auto column_timestamp = ColumnVector<Int64>::create(); + for (int i = 0; i < NUM_CONDS; i++) { + VecDateTimeValue time_value; + time_value.set_time(2022, 2, 28, 0, 0, NUM_CONDS - i); + column_timestamp->insert_data((char *)&time_value, 0); + } + auto column_event1 = ColumnVector<UInt8>::create(); + column_event1->insert(0); + column_event1->insert(0); + column_event1->insert(0); + column_event1->insert(1); + + auto column_event2 = ColumnVector<UInt8>::create(); + column_event2->insert(0); + column_event2->insert(0); + column_event2->insert(1); + column_event2->insert(0); + + auto column_event3 = ColumnVector<UInt8>::create(); + column_event3->insert(0); + column_event3->insert(1); + column_event3->insert(0); + column_event3->insert(0); + + auto column_event4 = ColumnVector<UInt8>::create(); + column_event4->insert(1); + column_event4->insert(0); + column_event4->insert(0); + column_event4->insert(0); + + for(int win = -1; win < NUM_CONDS + 1; win++) { + auto column_window = ColumnVector<Int64>::create(); + for (int i = 0; i < NUM_CONDS; i++) { + column_window->insert(win); + } + + std::unique_ptr<char[]> memory(new char[agg_function->size_of_data()]); + AggregateDataPtr place = memory.get(); + agg_function->create(place); + const IColumn* column[7] = { column_window.get(), column_mode.get(), + column_timestamp.get(), + column_event1.get(), column_event2.get(), + column_event3.get(), column_event4.get() }; + for (int i = 0; i < NUM_CONDS; i++) { + agg_function->add(place, column, i, nullptr); + } + + std::unique_ptr<char[]> 
memory2(new char[agg_function->size_of_data()]); + AggregateDataPtr place2 = memory2.get(); + agg_function->create(place2); + + agg_function->merge(place2, place, NULL); + ColumnVector<Int32> column_result; + agg_function->insert_result_into(place2, column_result); + ASSERT_EQ(column_result.get_data()[0], win < 0 ? 1 : (win < NUM_CONDS ? win + 1 : NUM_CONDS)); + agg_function->destroy(place); + agg_function->destroy(place2); + } +} + +} // namespace doris::vectorized + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/Analyzer.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/Analyzer.java index 9bcb5a1..8b5ebab 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/Analyzer.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/Analyzer.java @@ -731,7 +731,6 @@ public class Analyzer { result.setMultiRef(true); return result; } - result = addSlotDescriptor(tupleDescriptor); Column col = new Column(colName, type); result.setColumn(col); diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java index c0e7922..6115350 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java @@ -125,7 +125,7 @@ public class FunctionCallExpr extends Expr { } public FunctionCallExpr(String fnName, FunctionParams params) { - this(new FunctionName(fnName), params); + this(new FunctionName(fnName), params, false); } public FunctionCallExpr(FunctionName fnName, FunctionParams params) { @@ -140,8 +140,8 @@ public class FunctionCallExpr extends Expr { this.isMergeAggFn = isMergeAggFn; if (params.exprs() != null) { children.addAll(params.exprs()); - originChildSize = children.size(); } + originChildSize = children.size(); } // Constructs the same 
agg function with new params. @@ -236,29 +236,21 @@ public class FunctionCallExpr extends Expr { && fnParams.isStar() == o.fnParams.isStar(); } - @Override - public String toSqlImpl() { - Expr expr; - if (originStmtFnExpr != null) { - expr = originStmtFnExpr; - } else { - expr = this; - } + private String paramsToSql(FunctionParams params) { StringBuilder sb = new StringBuilder(); - sb.append(((FunctionCallExpr) expr).fnName).append("("); - if (((FunctionCallExpr) expr).fnParams.isStar()) { + sb.append("("); + + if (params.isStar()) { sb.append("*"); } - if (((FunctionCallExpr) expr).fnParams.isDistinct()) { + if (params.isDistinct()) { sb.append("DISTINCT "); } - boolean isJsonFunction = false; - int len = children.size(); + int len = params.exprs().size(); List<String> result = Lists.newArrayList(); if (fnName.getFunction().equalsIgnoreCase("json_array") || fnName.getFunction().equalsIgnoreCase("json_object")) { len = len - 1; - isJsonFunction = true; } if (fnName.getFunction().equalsIgnoreCase("aes_decrypt") || fnName.getFunction().equalsIgnoreCase("aes_encrypt") || @@ -273,11 +265,27 @@ public class FunctionCallExpr extends Expr { fnName.getFunction().equalsIgnoreCase("sm4_encrypt"))) { result.add("\'***\'"); } else { - result.add(children.get(i).toSql()); + result.add(params.exprs().get(i).toSql()); } } sb.append(Joiner.on(", ").join(result)).append(")"); - if (fnName.getFunction().equalsIgnoreCase("json_quote") || isJsonFunction) { + return sb.toString(); + } + + @Override + public String toSqlImpl() { + Expr expr; + if (originStmtFnExpr != null) { + expr = originStmtFnExpr; + } else { + expr = this; + } + StringBuilder sb = new StringBuilder(); + sb.append(((FunctionCallExpr) expr).fnName); + sb.append(paramsToSql(fnParams)); + if (fnName.getFunction().equalsIgnoreCase("json_quote") || + fnName.getFunction().equalsIgnoreCase("json_array") || + fnName.getFunction().equalsIgnoreCase("json_object")) { return forJSON(sb.toString()); } return sb.toString(); @@ 
-784,6 +792,34 @@ public class FunctionCallExpr extends Expr { fn = getBuiltinFunction(analyzer, fnName.getFunction(), new Type[]{compatibleType}, Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF); + } else if (fnName.getFunction().equalsIgnoreCase(FunctionSet.WINDOW_FUNNEL)) { + if (fnParams.exprs() == null || fnParams.exprs().size() < 4) { + throw new AnalysisException("The " + fnName + " function must have at least four params"); + } + + if (!children.get(0).type.isIntegerType()) { + throw new AnalysisException("The window params of " + fnName + " function must be integer"); + } + if (!children.get(1).type.isStringType()) { + throw new AnalysisException("The mode params of " + fnName + " function must be integer"); + } + if (!children.get(2).type.isDateType()) { + throw new AnalysisException("The 3rd param of " + fnName + " function must be DATE or DATETIME"); + } + + Type[] childTypes = new Type[children.size()]; + for (int i = 0; i < 3; i++) { + childTypes[i] = children.get(i).type; + } + for (int i = 3; i < children.size(); i++) { + if (children.get(i).type != Type.BOOLEAN) { + throw new AnalysisException("The 4th and subsequent params of " + fnName + " function must be boolean"); + } + childTypes[i] = children.get(i).type; + } + + fn = getBuiltinFunction(analyzer, fnName.getFunction(), childTypes, + Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF); } else { // now first find table function in table function sets if (isTableFnCall) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java index a695a9c..bc59188 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java @@ -49,7 +49,7 @@ public class AggregateFunction extends Function { private static final Logger LOG = LogManager.getLogger(AggregateFunction.class); public static ImmutableSet<String> 
NOT_NULLABLE_AGGREGATE_FUNCTION_NAME_SET = - ImmutableSet.of("row_number", "rank", "dense_rank", "multi_distinct_count", "multi_distinct_sum", "hll_union_agg", "hll_union", "bitmap_union", "bitmap_intersect", FunctionSet.COUNT, "approx_count_distinct", "ndv", FunctionSet.BITMAP_UNION_INT, FunctionSet.BITMAP_UNION_COUNT, "ndv_no_finalize"); + ImmutableSet.of("row_number", "rank", "dense_rank", "multi_distinct_count", "multi_distinct_sum", "hll_union_agg", "hll_union", "bitmap_union", "bitmap_intersect", FunctionSet.COUNT, "approx_count_distinct", "ndv", FunctionSet.BITMAP_UNION_INT, FunctionSet.BITMAP_UNION_COUNT, "ndv_no_finalize", FunctionSet.WINDOW_FUNNEL); public static ImmutableSet<String> ALWAYS_NULLABLE_AGGREGATE_FUNCTION_NAME_SET = ImmutableSet.of("stddev_samp", "variance_samp", "var_samp", "percentile_approx"); diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java index fb05ed1..0bbc83f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java @@ -1218,6 +1218,7 @@ public class FunctionSet<min_initIN9doris_udf12DecimalV2ValEEEvPNS2_15FunctionCo public static final String COUNT = "count"; + public static final String WINDOW_FUNNEL = "window_funnel"; // Populate all the aggregate builtins in the catalog. // null symbols indicate the function does not need that step of the evaluation. // An empty symbol indicates a TODO for the BE to implement the function. 
@@ -1251,6 +1252,36 @@ public class FunctionSet<min_initIN9doris_udf12DecimalV2ValEEEvPNS2_15FunctionCo prefix + "17count_star_removeEPN9doris_udf15FunctionContextEPNS1_9BigIntValE", null, false, true, true, true)); + // windowFunnel + addBuiltin(AggregateFunction.createBuiltin(FunctionSet.WINDOW_FUNNEL, + Lists.newArrayList(Type.BIGINT, Type.STRING, Type.DATETIME, Type.BOOLEAN), + Type.INT, + Type.VARCHAR, + true, + prefix + "18window_funnel_initEPN9doris_udf15FunctionContextEPNS1_9StringValE", + prefix + "20window_funnel_updateEPN9doris_udf15FunctionContextERKNS1_9BigIntValERKNS1_9StringValERKNS1_11DateTimeValEiPKNS1_10BooleanValEPS7_", + prefix + "19window_funnel_mergeEPN9doris_udf15FunctionContextERKNS1_9StringValEPS4_", + prefix + "23window_funnel_serializeEPN9doris_udf15FunctionContextERKNS1_9StringValE", + "", + "", + prefix + "22window_funnel_finalizeEPN9doris_udf15FunctionContextERKNS1_9StringValE", + true, false, true)); + + // Vectorization does not need symbol any more, we should clean it in the future. + addBuiltin(AggregateFunction.createBuiltin(FunctionSet.WINDOW_FUNNEL, + Lists.newArrayList(Type.BIGINT, Type.STRING, Type.DATETIME, Type.BOOLEAN), + Type.INT, + Type.VARCHAR, + true, + "", + "", + "", + "", + "", + "", + "", + true, false, true, true)); + for (Type t : Type.getSupportedTypes()) { if (t.isNull()) { continue; // NULL is handled through type promotion. diff --git a/fe/fe-core/src/test/java/org/apache/doris/analysis/AggregateTest.java b/fe/fe-core/src/test/java/org/apache/doris/analysis/AggregateTest.java index 3f0028c..35225dc 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/analysis/AggregateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/analysis/AggregateTest.java @@ -46,7 +46,7 @@ public class AggregateTest { dorisAssert = new DorisAssert(); dorisAssert.withDatabase(DB_NAME).useDatabase(DB_NAME); String createTableSQL = "create table " + DB_NAME + "." 
+ TABLE_NAME + " (empid int, name varchar, " + - "deptno int, salary int, commission int) " + "deptno int, salary int, commission int, time DATETIME) " + "distributed by hash(empid) buckets 3 properties('replication_num' = '1');"; dorisAssert.withTable(createTableSQL); } @@ -93,8 +93,97 @@ public class AggregateTest { } while (false); } + @Test + public void testWindowFunnelAnalysisException() throws Exception { + ConnectContext ctx = UtFrameUtils.createDefaultCtx(); + + // normal. + { + String query = "select empid, window_funnel(1, 'default', time, empid = 1, empid = 2) from " + + DB_NAME + "." + TABLE_NAME + " group by empid"; + try { + UtFrameUtils.parseAndAnalyzeStmt(query, ctx); + } catch (Exception e) { + Assert.fail("must be AnalysisException."); + } + } + + // less argument. + do { + String query = "select empid, window_funnel(1, 'default', time) from " + + DB_NAME + "." + TABLE_NAME + " group by empid"; + try { + UtFrameUtils.parseAndAnalyzeStmt(query, ctx); + } catch (AnalysisException e) { + Assert.assertTrue(e.getMessage().contains("function must have at least four params")); + break; + } catch (Exception e) { + Assert.fail("must be AnalysisException."); + } + Assert.fail("must be AnalysisException."); + } while(false); + + // argument with wrong type. + do { + String query = "select empid, window_funnel('xx', 'default', time, empid = 1) from " + + DB_NAME + "." + TABLE_NAME + " group by empid"; + try { + UtFrameUtils.parseAndAnalyzeStmt(query, ctx); + } catch (AnalysisException e) { + Assert.assertTrue(e.getMessage().contains("The window param of window_funnel function must be integer")); + break; + } catch (Exception e) { + Assert.fail("must be AnalysisException."); + } + Assert.fail("must be AnalysisException."); + } while(false); + + do { + String query = "select empid, window_funnel(1, 1, time, empid = 1) from " + + DB_NAME + "." 
+ TABLE_NAME + " group by empid"; + try { + UtFrameUtils.parseAndAnalyzeStmt(query, ctx); + } catch (AnalysisException e) { + Assert.assertTrue(e.getMessage().contains("The mode param of window_funnel function must be string")); + break; + } catch (Exception e) { + Assert.fail("must be AnalysisException."); + } + Assert.fail("must be AnalysisException."); + } while(false); + + + do { + String query = "select empid, window_funnel(1, '1', empid, '1') from " + + DB_NAME + "." + TABLE_NAME + " group by empid"; + try { + UtFrameUtils.parseAndAnalyzeStmt(query, ctx); + } catch (AnalysisException e) { + Assert.assertTrue(e.getMessage().contains("The 3rd param of window_funnel function must be DATE or DATETIME")); + break; + } catch (Exception e) { + Assert.fail("must be AnalysisException."); + } + Assert.fail("must be AnalysisException."); + } while(false); + + do { + String query = "select empid, window_funnel(1, '1', time, '1') from " + + DB_NAME + "." + TABLE_NAME + " group by empid"; + try { + UtFrameUtils.parseAndAnalyzeStmt(query, ctx); + } catch (AnalysisException e) { + Assert.assertTrue(e.getMessage().contains("The 4th and subsequent params of window_funnel function must be boolean")); + break; + } catch (Exception e) { + Assert.fail("must be AnalysisException."); + } + Assert.fail("must be AnalysisException."); + } while(false); + } + @AfterClass public static void afterClass() throws Exception { UtFrameUtils.cleanDorisFeDir(baseDir); } -} \ No newline at end of file +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org