This is an automated email from the ASF dual-hosted git repository.

lihaopeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 002c76e06f [vectorized](udaf) support udaf function work with window function (#19962)
002c76e06f is described below

commit 002c76e06f47ab7fbc7654cccdf58b3e24d91856
Author: zhangstar333 <87313068+zhangstar...@users.noreply.github.com>
AuthorDate: Thu May 25 14:38:47 2023 +0800

    [vectorized](udaf) support udaf function work with window function (#19962)
---
 .../aggregate_function_java_udaf.h                 | 35 ++++++++++++++---
 .../ecosystem/udf/java-user-defined-function.md    |  7 ++++
 .../ecosystem/udf/java-user-defined-function.md    |  7 ++++
 .../java/org/apache/doris/udf/BaseExecutor.java    |  1 +
 .../java/org/apache/doris/udf/UdafExecutor.java    | 45 +++++++++++++++++++---
 .../data/javaudf_p0/test_javaudaf_mysum_int.out    | 33 ++++++++++++++++
 .../main/java/org/apache/doris/udf/MySumInt.java   |  4 ++
 .../javaudf_p0/test_javaudaf_mysum_int.groovy      |  4 +-
 8 files changed, 125 insertions(+), 11 deletions(-)

diff --git a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h
index d913c8b32d..478e7d79c2 100644
--- a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h
+++ b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h
@@ -47,6 +47,7 @@ const char* UDAF_EXECUTOR_ADD_SIGNATURE = "(ZJJ)V";
 const char* UDAF_EXECUTOR_SERIALIZE_SIGNATURE = "(J)[B";
 const char* UDAF_EXECUTOR_MERGE_SIGNATURE = "(J[B)V";
 const char* UDAF_EXECUTOR_RESULT_SIGNATURE = "(JJ)Z";
+const char* UDAF_EXECUTOR_RESET_SIGNATURE = "(J)V";
 // Calling Java method about those signature means: "(argument-types)return-type"
 // https://www.iitk.ac.in/esc101/05Aug/tutorial/native1.1/implementing/method.html
 
@@ -219,6 +220,13 @@ public:
         return JniUtil::GetJniExceptionMsg(env);
     }
 
+    Status reset(int64_t place) {
+        JNIEnv* env = nullptr;
+        RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf 
reset function");
+        env->CallNonvirtualVoidMethod(executor_obj, executor_cl, 
executor_reset_id, place);
+        return JniUtil::GetJniExceptionMsg(env);
+    }
+
     void read(BufferReadable& buf) { read_binary(serialize_data, buf); }
 
     Status destroy() {
@@ -375,6 +383,7 @@ private:
 
         RETURN_IF_ERROR(register_id("<init>", UDAF_EXECUTOR_CTOR_SIGNATURE, 
executor_ctor_id));
         RETURN_IF_ERROR(register_id("add", UDAF_EXECUTOR_ADD_SIGNATURE, 
executor_add_id));
+        RETURN_IF_ERROR(register_id("reset", UDAF_EXECUTOR_RESET_SIGNATURE, 
executor_reset_id));
         RETURN_IF_ERROR(register_id("close", UDAF_EXECUTOR_CLOSE_SIGNATURE, 
executor_close_id));
         RETURN_IF_ERROR(register_id("merge", UDAF_EXECUTOR_MERGE_SIGNATURE, 
executor_merge_id));
         RETURN_IF_ERROR(
@@ -397,6 +406,7 @@ private:
     jmethodID executor_merge_id;
     jmethodID executor_serialize_id;
     jmethodID executor_result_id;
+    jmethodID executor_reset_id;
     jmethodID executor_close_id;
     jmethodID executor_destroy_id;
 
@@ -502,11 +512,26 @@ public:
         }
     }
 
-    // TODO: reset function should be implement also in struct data
-    void reset(AggregateDataPtr /*place*/) const override {
-        LOG(WARNING) << " shouldn't going reset function, there maybe some 
error about function "
-                     << _fn.name.function_name;
-        throw doris::Exception(ErrorCode::INTERNAL_ERROR, "shouldn't going 
reset function");
+    void add_range_single_place(int64_t partition_start, int64_t 
partition_end, int64_t frame_start,
+                                int64_t frame_end, AggregateDataPtr place, 
const IColumn** columns,
+                                Arena* arena) const override {
+        frame_start = std::max<int64_t>(frame_start, partition_start);
+        frame_end = std::min<int64_t>(frame_end, partition_end);
+        int64_t places_address[1];
+        places_address[0] = reinterpret_cast<int64_t>(place);
+        Status st =
+                this->data(_exec_place)
+                        .add(places_address, true, columns, frame_start, 
frame_end, argument_types);
+        if (UNLIKELY(st != Status::OK())) {
+            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
+        }
+    }
+
+    void reset(AggregateDataPtr place) const override {
+        Status st = 
this->data(_exec_place).reset(reinterpret_cast<int64_t>(place));
+        if (UNLIKELY(st != Status::OK())) {
+            throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string());
+        }
     }
 
     void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
diff --git a/docs/en/docs/ecosystem/udf/java-user-defined-function.md b/docs/en/docs/ecosystem/udf/java-user-defined-function.md
index 4a0f2e0d5e..c6b57fed3c 100644
--- a/docs/en/docs/ecosystem/udf/java-user-defined-function.md
+++ b/docs/en/docs/ecosystem/udf/java-user-defined-function.md
@@ -130,6 +130,13 @@ public class SimpleDemo  {
         /* here could do some destroy work if needed */
     }
 
+    /*Not Required*/
+    public void reset(State state) {
+        /*if you want this udaf function can work with window function.*/
+        /*Must impl this, it will be reset to init state after calculate every window frame*/
+        state.sum = 0;
+    }
+
     /*required*/
     //first argument is State, then other types your input
     public void add(State state, Integer val) throws Exception {
diff --git a/docs/zh-CN/docs/ecosystem/udf/java-user-defined-function.md b/docs/zh-CN/docs/ecosystem/udf/java-user-defined-function.md
index 85080f4bf5..a30fcca614 100644
--- a/docs/zh-CN/docs/ecosystem/udf/java-user-defined-function.md
+++ b/docs/zh-CN/docs/ecosystem/udf/java-user-defined-function.md
@@ -130,6 +130,13 @@ public class SimpleDemo  {
         /* here could do some destroy work if needed */
     }
 
+    /*Not Required*/
+    public void reset(State state) {
+        /*if you want this udaf function can work with window function.*/
+        /*Must impl this, it will be reset to init state after calculate every window frame*/
+        state.sum = 0;
+    }
+
     /*required*/
     //first argument is State, then other types your input
     public void add(State state, Integer val) throws Exception {
diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java b/fe/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java
index dde6a0b084..f1d2e3a9a7 100644
--- a/fe/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java
+++ b/fe/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java
@@ -46,6 +46,7 @@ public abstract class BaseExecutor {
     public static final String UDAF_CREATE_FUNCTION = "create";
     public static final String UDAF_DESTROY_FUNCTION = "destroy";
     public static final String UDAF_ADD_FUNCTION = "add";
+    public static final String UDAF_RESET_FUNCTION = "reset";
     public static final String UDAF_SERIALIZE_FUNCTION = "serialize";
     public static final String UDAF_DESERIALIZE_FUNCTION = "deserialize";
     public static final String UDAF_MERGE_FUNCTION = "merge";
diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java b/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java
index 1cba22a120..76e9a09c3d 100644
--- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java
+++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java
@@ -71,10 +71,21 @@ public class UdafExecutor extends BaseExecutor {
         try {
             long idx = rowStart;
             do {
-                Long curPlace = UdfUtils.UNSAFE.getLong(null, 
UdfUtils.UNSAFE.getLong(null, inputPlacesPtr) + 8L * idx);
+                Long curPlace = null;
+                if (isSinglePlace) {
+                    curPlace = UdfUtils.UNSAFE.getLong(null, 
UdfUtils.UNSAFE.getLong(null, inputPlacesPtr));
+                } else {
+                    curPlace = UdfUtils.UNSAFE.getLong(null, 
UdfUtils.UNSAFE.getLong(null, inputPlacesPtr) + 8L * idx);
+                }
                 Object[] inputArgs = new Object[argTypes.length + 1];
-                stateObjMap.putIfAbsent(curPlace, createAggState());
-                inputArgs[0] = stateObjMap.get(curPlace);
+                Object state = stateObjMap.get(curPlace);
+                if (state != null) {
+                    inputArgs[0] = state;
+                } else {
+                    Object newState = createAggState();
+                    stateObjMap.put(curPlace, newState);
+                    inputArgs[0] = newState;
+                }
                 do {
                     Object[] inputObjects = allocateInputObjects(idx, 1);
                     for (int i = 0; i < argTypes.length; ++i) {
@@ -134,6 +145,23 @@ public class UdafExecutor extends BaseExecutor {
         }
     }
 
+    /*
+     * invoke reset function and reset the state to init.
+     */
+    public void reset(long place) throws UdfRuntimeException {
+        try {
+            Object[] args = new Object[1];
+            args[0] = stateObjMap.get((Long) place);
+            if (args[0] == null) {
+                return;
+            }
+            allMethods.get(UDAF_RESET_FUNCTION).invoke(udf, args);
+        } catch (Exception e) {
+            LOG.warn("invoke reset function meet some error: " + 
e.getCause().toString());
+            throw new UdfRuntimeException("UDAF failed to reset: ", e);
+        }
+    }
+
     /**
      * invoke merge function and it's have done deserialze.
      * here call deserialize first, and call merge.
@@ -147,8 +175,14 @@ public class UdafExecutor extends BaseExecutor {
             allMethods.get(UDAF_DESERIALIZE_FUNCTION).invoke(udf, args);
             args[1] = args[0];
             Long curPlace = place;
-            stateObjMap.putIfAbsent(curPlace, createAggState());
-            args[0] = stateObjMap.get(curPlace);
+            Object state = stateObjMap.get(curPlace);
+            if (state != null) {
+                args[0] = state;
+            } else {
+                Object newState = createAggState();
+                stateObjMap.put(curPlace, newState);
+                args[0] = newState;
+            }
             allMethods.get(UDAF_MERGE_FUNCTION).invoke(udf, args);
         } catch (Exception e) {
             LOG.warn("invoke merge function meet some error: " + 
e.getCause().toString());
@@ -226,6 +260,7 @@ public class UdafExecutor extends BaseExecutor {
                     case UDAF_CREATE_FUNCTION:
                     case UDAF_MERGE_FUNCTION:
                     case UDAF_SERIALIZE_FUNCTION:
+                    case UDAF_RESET_FUNCTION:
                     case UDAF_DESERIALIZE_FUNCTION: {
                         allMethods.put(methods[idx].getName(), methods[idx]);
                         break;
diff --git a/regression-test/data/javaudf_p0/test_javaudaf_mysum_int.out b/regression-test/data/javaudf_p0/test_javaudaf_mysum_int.out
index 47c14ac114..d5d4fc44aa 100644
--- a/regression-test/data/javaudf_p0/test_javaudaf_mysum_int.out
+++ b/regression-test/data/javaudf_p0/test_javaudaf_mysum_int.out
@@ -31,3 +31,36 @@
 2      6       6
 9      9       9
 
+-- !select5 --
+1
+2
+0
+1
+2
+0
+1
+2
+9
+
+-- !select6 --
+1
+2
+0
+1
+2
+0
+1
+2
+9
+
+-- !select7 --
+1
+2
+0
+1
+2
+0
+1
+2
+9
+
diff --git a/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumInt.java b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumInt.java
index ca23de7e9a..3037740e79 100644
--- a/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumInt.java
+++ b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumInt.java
@@ -31,6 +31,10 @@ public class MySumInt {
     public void destroy(State state) {
     }
 
+    public void reset(State state) {
+        state.counter = 0;
+    }
+
     public void add(State state, Integer val) {
         if (val == null) return;
         state.counter += val;
diff --git a/regression-test/suites/javaudf_p0/test_javaudaf_mysum_int.groovy b/regression-test/suites/javaudf_p0/test_javaudaf_mysum_int.groovy
index ad3f2cc532..a07233c43e 100644
--- a/regression-test/suites/javaudf_p0/test_javaudaf_mysum_int.groovy
+++ b/regression-test/suites/javaudf_p0/test_javaudaf_mysum_int.groovy
@@ -72,7 +72,9 @@ suite("test_javaudaf_mysum_int") {
 
         qt_select4 """ select user_id, udaf_my_sum_int(user_id), sum(user_id) 
from ${tableName} group by user_id order by user_id; """
         
-
+        qt_select5 """ select udaf_my_sum_int(user_id) over(partition by 
char_col) from test_javaudaf_mysum_int order by char_col; """
+        qt_select6 """ select udaf_my_sum_int(user_id) over(partition by 
char_col order by string_col) from test_javaudaf_mysum_int order by char_col; 
"""
+        qt_select7 """ select udaf_my_sum_int(user_id) over(partition by 
char_col order by string_col rows between 1 preceding and 1 following ) from 
test_javaudaf_mysum_int order by char_col; """
     } finally {
         try_sql("DROP FUNCTION IF EXISTS udaf_my_sum_int(int);")
         try_sql("DROP TABLE IF EXISTS ${tableName}")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org

Reply via email to