This is an automated email from the ASF dual-hosted git repository. lihaopeng pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 002c76e06f [vectorized](udaf) support udaf function work with window function (#19962) 002c76e06f is described below commit 002c76e06f47ab7fbc7654cccdf58b3e24d91856 Author: zhangstar333 <87313068+zhangstar...@users.noreply.github.com> AuthorDate: Thu May 25 14:38:47 2023 +0800 [vectorized](udaf) support udaf function work with window function (#19962) --- .../aggregate_function_java_udaf.h | 35 ++++++++++++++--- .../ecosystem/udf/java-user-defined-function.md | 7 ++++ .../ecosystem/udf/java-user-defined-function.md | 7 ++++ .../java/org/apache/doris/udf/BaseExecutor.java | 1 + .../java/org/apache/doris/udf/UdafExecutor.java | 45 +++++++++++++++++++--- .../data/javaudf_p0/test_javaudaf_mysum_int.out | 33 ++++++++++++++++ .../main/java/org/apache/doris/udf/MySumInt.java | 4 ++ .../javaudf_p0/test_javaudaf_mysum_int.groovy | 4 +- 8 files changed, 125 insertions(+), 11 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h index d913c8b32d..478e7d79c2 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h +++ b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h @@ -47,6 +47,7 @@ const char* UDAF_EXECUTOR_ADD_SIGNATURE = "(ZJJ)V"; const char* UDAF_EXECUTOR_SERIALIZE_SIGNATURE = "(J)[B"; const char* UDAF_EXECUTOR_MERGE_SIGNATURE = "(J[B)V"; const char* UDAF_EXECUTOR_RESULT_SIGNATURE = "(JJ)Z"; +const char* UDAF_EXECUTOR_RESET_SIGNATURE = "(J)V"; // Calling Java method about those signature means: "(argument-types)return-type" // https://www.iitk.ac.in/esc101/05Aug/tutorial/native1.1/implementing/method.html @@ -219,6 +220,13 @@ public: return JniUtil::GetJniExceptionMsg(env); } + Status reset(int64_t place) { + JNIEnv* env = nullptr; + RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf reset function"); + env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_reset_id, place); + return JniUtil::GetJniExceptionMsg(env); + } + void read(BufferReadable& buf) { read_binary(serialize_data, buf); } Status destroy() { @@ -375,6 +383,7 @@ private: RETURN_IF_ERROR(register_id("<init>", UDAF_EXECUTOR_CTOR_SIGNATURE, executor_ctor_id)); RETURN_IF_ERROR(register_id("add", UDAF_EXECUTOR_ADD_SIGNATURE, executor_add_id)); + RETURN_IF_ERROR(register_id("reset", UDAF_EXECUTOR_RESET_SIGNATURE, executor_reset_id)); RETURN_IF_ERROR(register_id("close", UDAF_EXECUTOR_CLOSE_SIGNATURE, executor_close_id)); RETURN_IF_ERROR(register_id("merge", UDAF_EXECUTOR_MERGE_SIGNATURE, executor_merge_id)); RETURN_IF_ERROR( @@ -397,6 +406,7 @@ private: jmethodID executor_merge_id; jmethodID executor_serialize_id; jmethodID executor_result_id; + jmethodID executor_reset_id; jmethodID executor_close_id; jmethodID executor_destroy_id; @@ -502,11 +512,26 @@ public: } } - // TODO: reset function should be implement also in struct data - void reset(AggregateDataPtr /*place*/) const override { - LOG(WARNING) << " shouldn't going reset function, there maybe some error about function " - << _fn.name.function_name; - throw doris::Exception(ErrorCode::INTERNAL_ERROR, "shouldn't going reset function"); + void add_range_single_place(int64_t partition_start, int64_t partition_end, int64_t frame_start, + int64_t frame_end, AggregateDataPtr place, const IColumn** columns, + Arena* arena) const override { + frame_start = std::max<int64_t>(frame_start, partition_start); + frame_end = std::min<int64_t>(frame_end, partition_end); + int64_t places_address[1]; + places_address[0] = reinterpret_cast<int64_t>(place); + Status st = + this->data(_exec_place) + .add(places_address, true, columns, frame_start, frame_end, argument_types); + if (UNLIKELY(st != Status::OK())) { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); + } + } + + void reset(AggregateDataPtr place) const override { + Status st = this->data(_exec_place).reset(reinterpret_cast<int64_t>(place)); + if (UNLIKELY(st != Status::OK())) { + throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); + } } void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, diff --git a/docs/en/docs/ecosystem/udf/java-user-defined-function.md b/docs/en/docs/ecosystem/udf/java-user-defined-function.md index 4a0f2e0d5e..c6b57fed3c 100644 --- a/docs/en/docs/ecosystem/udf/java-user-defined-function.md +++ b/docs/en/docs/ecosystem/udf/java-user-defined-function.md @@ -130,6 +130,13 @@ public class SimpleDemo { /* here could do some destroy work if needed */ } + /*Not Required*/ + public void reset(State state) { + /*if you want this udaf function can work with window function.*/ + /*Must impl this, it will be reset to init state after calculate every window frame*/ + state.sum = 0; + } + /*required*/ //first argument is State, then other types your input public void add(State state, Integer val) throws Exception { diff --git a/docs/zh-CN/docs/ecosystem/udf/java-user-defined-function.md b/docs/zh-CN/docs/ecosystem/udf/java-user-defined-function.md index 85080f4bf5..a30fcca614 100644 --- a/docs/zh-CN/docs/ecosystem/udf/java-user-defined-function.md +++ b/docs/zh-CN/docs/ecosystem/udf/java-user-defined-function.md @@ -130,6 +130,13 @@ public class SimpleDemo { /* here could do some destroy work if needed */ } + /*Not Required*/ + public void reset(State state) { + /*if you want this udaf function can work with window function.*/ + /*Must impl this, it will be reset to init state after calculate every window frame*/ + state.sum = 0; + } + /*required*/ //first argument is State, then other types your input public void add(State state, Integer val) throws Exception { diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java b/fe/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java index dde6a0b084..f1d2e3a9a7 100644 --- a/fe/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java +++ b/fe/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java @@ -46,6 +46,7 @@ public abstract class BaseExecutor { public static final String UDAF_CREATE_FUNCTION = "create"; public static final String UDAF_DESTROY_FUNCTION = "destroy"; public static final String UDAF_ADD_FUNCTION = "add"; + public static final String UDAF_RESET_FUNCTION = "reset"; public static final String UDAF_SERIALIZE_FUNCTION = "serialize"; public static final String UDAF_DESERIALIZE_FUNCTION = "deserialize"; public static final String UDAF_MERGE_FUNCTION = "merge"; diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java b/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java index 1cba22a120..76e9a09c3d 100644 --- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java +++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java @@ -71,10 +71,21 @@ public class UdafExecutor extends BaseExecutor { try { long idx = rowStart; do { - Long curPlace = UdfUtils.UNSAFE.getLong(null, UdfUtils.UNSAFE.getLong(null, inputPlacesPtr) + 8L * idx); + Long curPlace = null; + if (isSinglePlace) { + curPlace = UdfUtils.UNSAFE.getLong(null, UdfUtils.UNSAFE.getLong(null, inputPlacesPtr)); + } else { + curPlace = UdfUtils.UNSAFE.getLong(null, UdfUtils.UNSAFE.getLong(null, inputPlacesPtr) + 8L * idx); + } Object[] inputArgs = new Object[argTypes.length + 1]; - stateObjMap.putIfAbsent(curPlace, createAggState()); - inputArgs[0] = stateObjMap.get(curPlace); + Object state = stateObjMap.get(curPlace); + if (state != null) { + inputArgs[0] = state; + } else { + Object newState = createAggState(); + stateObjMap.put(curPlace, newState); + inputArgs[0] = newState; + } do { Object[] inputObjects = allocateInputObjects(idx, 1); for (int i = 0; i < argTypes.length; ++i) { @@ -134,6 +145,23 @@ public class UdafExecutor extends BaseExecutor { } } + /* + * invoke reset function and reset the state to init. + */ + public void reset(long place) throws UdfRuntimeException { + try { + Object[] args = new Object[1]; + args[0] = stateObjMap.get((Long) place); + if (args[0] == null) { + return; + } + allMethods.get(UDAF_RESET_FUNCTION).invoke(udf, args); + } catch (Exception e) { + LOG.warn("invoke reset function meet some error: " + e.getCause().toString()); + throw new UdfRuntimeException("UDAF failed to reset: ", e); + } + } + /** * invoke merge function and it's have done deserialze. * here call deserialize first, and call merge. @@ -147,8 +175,14 @@ public class UdafExecutor extends BaseExecutor { allMethods.get(UDAF_DESERIALIZE_FUNCTION).invoke(udf, args); args[1] = args[0]; Long curPlace = place; - stateObjMap.putIfAbsent(curPlace, createAggState()); - args[0] = stateObjMap.get(curPlace); + Object state = stateObjMap.get(curPlace); + if (state != null) { + args[0] = state; + } else { + Object newState = createAggState(); + stateObjMap.put(curPlace, newState); + args[0] = newState; + } allMethods.get(UDAF_MERGE_FUNCTION).invoke(udf, args); } catch (Exception e) { LOG.warn("invoke merge function meet some error: " + e.getCause().toString()); @@ -226,6 +260,7 @@ public class UdafExecutor extends BaseExecutor { case UDAF_CREATE_FUNCTION: case UDAF_MERGE_FUNCTION: case UDAF_SERIALIZE_FUNCTION: + case UDAF_RESET_FUNCTION: case UDAF_DESERIALIZE_FUNCTION: { allMethods.put(methods[idx].getName(), methods[idx]); break; diff --git a/regression-test/data/javaudf_p0/test_javaudaf_mysum_int.out b/regression-test/data/javaudf_p0/test_javaudaf_mysum_int.out index 47c14ac114..d5d4fc44aa 100644 --- a/regression-test/data/javaudf_p0/test_javaudaf_mysum_int.out +++ b/regression-test/data/javaudf_p0/test_javaudaf_mysum_int.out @@ -31,3 +31,36 @@ 2 6 6 9 9 9 +-- !select5 -- +1 +2 +0 +1 +2 +0 +1 +2 +9 + +-- !select6 -- +1 +2 +0 +1 +2 +0 +1 +2 +9 + +-- !select7 -- +1 +2 +0 +1 +2 +0 +1 +2 +9 + diff --git a/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumInt.java b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumInt.java index ca23de7e9a..3037740e79 100644 --- a/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumInt.java +++ b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumInt.java @@ -31,6 +31,10 @@ public class MySumInt { public void destroy(State state) { } + public void reset(State state) { + state.counter = 0; + } + public void add(State state, Integer val) { if (val == null) return; state.counter += val; diff --git a/regression-test/suites/javaudf_p0/test_javaudaf_mysum_int.groovy b/regression-test/suites/javaudf_p0/test_javaudaf_mysum_int.groovy index ad3f2cc532..a07233c43e 100644 --- a/regression-test/suites/javaudf_p0/test_javaudaf_mysum_int.groovy +++ b/regression-test/suites/javaudf_p0/test_javaudaf_mysum_int.groovy @@ -72,7 +72,9 @@ suite("test_javaudaf_mysum_int") { qt_select4 """ select user_id, udaf_my_sum_int(user_id), sum(user_id) from ${tableName} group by user_id order by user_id; """ - + qt_select5 """ select udaf_my_sum_int(user_id) over(partition by char_col) from test_javaudaf_mysum_int order by char_col; """ + qt_select6 """ select udaf_my_sum_int(user_id) over(partition by char_col order by string_col) from test_javaudaf_mysum_int order by char_col; """ + qt_select7 """ select udaf_my_sum_int(user_id) over(partition by char_col order by string_col rows between 1 preceding and 1 following ) from test_javaudaf_mysum_int order by char_col; """ } finally { try_sql("DROP FUNCTION IF EXISTS udaf_my_sum_int(int);") try_sql("DROP TABLE IF EXISTS ${tableName}") --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org