This is an automated email from the ASF dual-hosted git repository. panxiaolei pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new c07e2ada43 [imporve](udaf) refactor java-udaf executor by using for loop (#21713) c07e2ada43 is described below commit c07e2ada43e870c17e4308477eae3979fff9a3e1 Author: zhangstar333 <87313068+zhangstar...@users.noreply.github.com> AuthorDate: Fri Jul 14 11:37:19 2023 +0800 [imporve](udaf) refactor java-udaf executor by using for loop (#21713) refactor java-udaf executor by using for loop --- .../aggregate_function_java_udaf.h | 119 ++++++---- fe/be-java-extensions/java-udf/pom.xml | 6 + .../java/org/apache/doris/udf/BaseExecutor.java | 181 ++++++++++++++ .../java/org/apache/doris/udf/UdafExecutor.java | 89 ++++++- .../main/java/org/apache/doris/udf/UdfConvert.java | 262 +++++++++++---------- .../java/org/apache/doris/udf/UdfExecutor.java | 165 +------------ 6 files changed, 485 insertions(+), 337 deletions(-) diff --git a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h index fa0c4efd9d..6fe4742064 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h +++ b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h @@ -128,63 +128,80 @@ public: return Status::OK(); } - Status add(const int64_t places_address[], bool is_single_place, const IColumn** columns, - size_t row_num_start, size_t row_num_end, const DataTypes& argument_types) { + Status add(int64_t places_address, bool is_single_place, const IColumn** columns, + int row_num_start, int row_num_end, const DataTypes& argument_types, + int place_offset) { JNIEnv* env = nullptr; RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf add function"); + jclass obj_class = env->FindClass("[Ljava/lang/Object;"); + jobjectArray arg_objects = env->NewObjectArray(argument_size, obj_class, nullptr); + int64_t nullmap_address = 0; + for (int arg_idx = 0; arg_idx < argument_size; ++arg_idx) { + bool arg_column_nullable = false; auto data_col = columns[arg_idx]; if (auto* nullable = check_and_get_column<const ColumnNullable>(*columns[arg_idx])) { + arg_column_nullable = true; + auto null_col = nullable->get_null_map_column_ptr(); data_col = nullable->get_nested_column_ptr(); - auto null_col = check_and_get_column<ColumnVector<UInt8>>( - nullable->get_null_map_column_ptr()); - input_nulls_buffer_ptr.get()[arg_idx] = - reinterpret_cast<int64_t>(null_col->get_data().data()); - } else { - input_nulls_buffer_ptr.get()[arg_idx] = -1; + nullmap_address = reinterpret_cast<int64_t>( + check_and_get_column<ColumnVector<UInt8>>(null_col)->get_data().data()); } - if (data_col->is_column_string()) { - const ColumnString* str_col = check_and_get_column<ColumnString>(data_col); - input_values_buffer_ptr.get()[arg_idx] = - reinterpret_cast<int64_t>(str_col->get_chars().data()); - input_offsets_ptrs.get()[arg_idx] = - reinterpret_cast<int64_t>(str_col->get_offsets().data()); - } else if (data_col->is_numeric() || data_col->is_column_decimal()) { - input_values_buffer_ptr.get()[arg_idx] = - reinterpret_cast<int64_t>(data_col->get_raw_data().data); + // convert argument column data into java type + jobjectArray arr_obj = nullptr; + if (data_col->is_numeric() || data_col->is_column_decimal()) { + arr_obj = (jobjectArray)env->CallObjectMethod( + executor_obj, executor_convert_basic_argument_id, arg_idx, + arg_column_nullable, row_num_start, row_num_end, nullmap_address, + reinterpret_cast<int64_t>(data_col->get_raw_data().data), 0); + } else if (data_col->is_column_string()) { + const ColumnString* str_col = assert_cast<const ColumnString*>(data_col); + arr_obj = (jobjectArray)env->CallObjectMethod( + executor_obj, executor_convert_basic_argument_id, arg_idx, + arg_column_nullable, row_num_start, row_num_end, nullmap_address, + reinterpret_cast<int64_t>(str_col->get_chars().data()), + reinterpret_cast<int64_t>(str_col->get_offsets().data())); } else if (data_col->is_column_array()) { const ColumnArray* array_col = assert_cast<const ColumnArray*>(data_col); - input_offsets_ptrs.get()[arg_idx] = reinterpret_cast<int64_t>( - array_col->get_offsets_column().get_raw_data().data); const ColumnNullable& array_nested_nullable = assert_cast<const ColumnNullable&>(array_col->get_data()); auto data_column_null_map = array_nested_nullable.get_null_map_column_ptr(); auto data_column = array_nested_nullable.get_nested_column_ptr(); - input_array_nulls_buffer_ptr.get()[arg_idx] = reinterpret_cast<int64_t>( + auto offset_address = reinterpret_cast<int64_t>( + array_col->get_offsets_column().get_raw_data().data); + auto nested_nullmap_address = reinterpret_cast<int64_t>( check_and_get_column<ColumnVector<UInt8>>(data_column_null_map) ->get_data() .data()); - - //need pass FE, nullamp and offset, chars + int64_t nested_data_address = 0, nested_offset_address = 0; + // array type need pass address: [nullmap_address], offset_address, nested_nullmap_address, nested_data_address/nested_char_address,nested_offset_address if (data_column->is_column_string()) { const ColumnString* col = assert_cast<const ColumnString*>(data_column.get()); - input_values_buffer_ptr.get()[arg_idx] = - reinterpret_cast<int64_t>(col->get_chars().data()); - input_array_string_offsets_ptrs.get()[arg_idx] = - reinterpret_cast<int64_t>(col->get_offsets().data()); + nested_data_address = reinterpret_cast<int64_t>(col->get_chars().data()); + nested_offset_address = reinterpret_cast<int64_t>(col->get_offsets().data()); } else { - input_values_buffer_ptr.get()[arg_idx] = + nested_data_address = reinterpret_cast<int64_t>(data_column->get_raw_data().data); } + arr_obj = (jobjectArray)env->CallObjectMethod( + executor_obj, executor_convert_array_argument_id, arg_idx, + arg_column_nullable, row_num_start, row_num_end, nullmap_address, + offset_address, nested_nullmap_address, nested_data_address, + nested_offset_address); } else { return Status::InvalidArgument( strings::Substitute("Java UDAF doesn't support type is $0 now !", argument_types[arg_idx]->get_name())); } + env->SetObjectArrayElement(arg_objects, arg_idx, arr_obj); + env->DeleteLocalRef(arr_obj); } - *input_place_ptrs = reinterpret_cast<int64_t>(places_address); - env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_add_id, is_single_place, - row_num_start, row_num_end); + RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env)); + // invoke add batch + env->CallObjectMethod(executor_obj, executor_add_batch_id, is_single_place, row_num_start, + row_num_end, places_address, place_offset, arg_objects); + env->DeleteLocalRef(arg_objects); + env->DeleteLocalRef(obj_class); return JniUtil::GetJniExceptionMsg(env); } @@ -392,6 +409,12 @@ private: register_id("getValue", UDAF_EXECUTOR_RESULT_SIGNATURE, executor_result_id)); RETURN_IF_ERROR( register_id("destroy", UDAF_EXECUTOR_DESTROY_SIGNATURE, executor_destroy_id)); + RETURN_IF_ERROR(register_id("convertBasicArguments", "(IZIIJJJ)[Ljava/lang/Object;", + executor_convert_basic_argument_id)); + RETURN_IF_ERROR(register_id("convertArrayArguments", "(IZIIJJJJJ)[Ljava/lang/Object;", + executor_convert_array_argument_id)); + RETURN_IF_ERROR( + register_id("addBatch", "(ZIIJI[Ljava/lang/Object;)V", executor_add_batch_id)); return Status::OK(); } @@ -403,12 +426,15 @@ private: jmethodID executor_ctor_id; jmethodID executor_add_id; + jmethodID executor_add_batch_id; jmethodID executor_merge_id; jmethodID executor_serialize_id; jmethodID executor_result_id; jmethodID executor_reset_id; jmethodID executor_close_id; jmethodID executor_destroy_id; + jmethodID executor_convert_basic_argument_id; + jmethodID executor_convert_array_argument_id; std::unique_ptr<int64_t[]> input_values_buffer_ptr; std::unique_ptr<int64_t[]> input_nulls_buffer_ptr; @@ -481,11 +507,10 @@ public: void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, Arena*) const override { - int64_t places_address[1]; - places_address[0] = reinterpret_cast<int64_t>(place); - Status st = - this->data(_exec_place) - .add(places_address, true, columns, row_num, row_num + 1, argument_types); + int64_t places_address = reinterpret_cast<int64_t>(place); + Status st = this->data(_exec_place) + .add(places_address, true, columns, row_num, row_num + 1, + argument_types, 0); if (UNLIKELY(st != Status::OK())) { throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); } @@ -493,25 +518,20 @@ public: void add_batch(size_t batch_size, AggregateDataPtr* places, size_t place_offset, const IColumn** columns, Arena* /*arena*/, bool /*agg_many*/) const override { - int64_t places_address[batch_size]; - for (size_t i = 0; i < batch_size; ++i) { - places_address[i] = reinterpret_cast<int64_t>(places[i] + place_offset); - } + int64_t places_address = reinterpret_cast<int64_t>(places); Status st = this->data(_exec_place) - .add(places_address, false, columns, 0, batch_size, argument_types); + .add(places_address, false, columns, 0, batch_size, argument_types, + place_offset); if (UNLIKELY(st != Status::OK())) { throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); } } - // TODO: Here we calling method by jni, And if we get a thrown from FE, - // But can't let user known the error, only return directly and output error to log file. void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns, Arena* /*arena*/) const override { - int64_t places_address[1]; - places_address[0] = reinterpret_cast<int64_t>(place); + int64_t places_address = reinterpret_cast<int64_t>(place); Status st = this->data(_exec_place) - .add(places_address, true, columns, 0, batch_size, argument_types); + .add(places_address, true, columns, 0, batch_size, argument_types, 0); if (UNLIKELY(st != Status::OK())) { throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); } @@ -522,11 +542,10 @@ public: Arena* arena) const override { frame_start = std::max<int64_t>(frame_start, partition_start); frame_end = std::min<int64_t>(frame_end, partition_end); - int64_t places_address[1]; - places_address[0] = reinterpret_cast<int64_t>(place); - Status st = - this->data(_exec_place) - .add(places_address, true, columns, frame_start, frame_end, argument_types); + int64_t places_address = reinterpret_cast<int64_t>(place); + Status st = this->data(_exec_place) + .add(places_address, true, columns, frame_start, frame_end, + argument_types, 0); if (UNLIKELY(st != Status::OK())) { throw doris::Exception(ErrorCode::INTERNAL_ERROR, st.to_string()); } diff --git a/fe/be-java-extensions/java-udf/pom.xml b/fe/be-java-extensions/java-udf/pom.xml index bf05aeafc2..67921aa2cf 100644 --- a/fe/be-java-extensions/java-udf/pom.xml +++ b/fe/be-java-extensions/java-udf/pom.xml @@ -41,6 +41,12 @@ under the License. <artifactId>java-common</artifactId> <version>${project.version}</version> </dependency> + <!-- https://mvnrepository.com/artifact/com.esotericsoftware/reflectasm --> + <dependency> + <groupId>com.esotericsoftware</groupId> + <artifactId>reflectasm</artifactId> + <version>1.11.9</version> + </dependency> </dependencies> <build> diff --git a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java index 3dbe10ca27..ef405197d6 100644 --- a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java +++ b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java @@ -25,12 +25,14 @@ import org.apache.doris.common.jni.utils.UdfUtils; import org.apache.doris.common.jni.utils.UdfUtils.JavaUdfDataType; import org.apache.doris.thrift.TJavaUdfExecutorCtorParams; +import com.google.common.base.Preconditions; import org.apache.log4j.Logger; import org.apache.thrift.TDeserializer; import org.apache.thrift.TException; import org.apache.thrift.protocol.TBinaryProtocol; import java.io.IOException; +import java.lang.reflect.Array; import java.math.BigDecimal; import java.math.BigInteger; import java.math.RoundingMode; @@ -1021,4 +1023,183 @@ public abstract class BaseExecutor { protected void updateOutputOffset(long offset) { } + + public Object[] convertBasicArg(boolean isUdf, int argIdx, boolean isNullable, int rowStart, int rowEnd, + long nullMapAddr, long columnAddr, long strOffsetAddr) { + switch (argTypes[argIdx]) { + case BOOLEAN: + return UdfConvert.convertBooleanArg(isNullable, rowStart, rowEnd, nullMapAddr, columnAddr); + case TINYINT: + return UdfConvert.convertTinyIntArg(isNullable, rowStart, rowEnd, nullMapAddr, columnAddr); + case SMALLINT: + return UdfConvert.convertSmallIntArg(isNullable, rowStart, rowEnd, nullMapAddr, columnAddr); + case INT: + return UdfConvert.convertIntArg(isNullable, rowStart, rowEnd, nullMapAddr, columnAddr); + case BIGINT: + return UdfConvert.convertBigIntArg(isNullable, rowStart, rowEnd, nullMapAddr, columnAddr); + case LARGEINT: + return UdfConvert.convertLargeIntArg(isNullable, rowStart, rowEnd, nullMapAddr, columnAddr); + case FLOAT: + return UdfConvert.convertFloatArg(isNullable, rowStart, rowEnd, nullMapAddr, columnAddr); + case DOUBLE: + return UdfConvert.convertDoubleArg(isNullable, rowStart, rowEnd, nullMapAddr, columnAddr); + case CHAR: + case VARCHAR: + case STRING: + return UdfConvert + .convertStringArg(isNullable, rowStart, rowEnd, nullMapAddr, columnAddr, strOffsetAddr); + case DATE: // udaf maybe argClass[i + argClassOffset] need add +1 + return UdfConvert + .convertDateArg(isUdf ? argClass[argIdx] : argClass[argIdx + 1], isNullable, rowStart, rowEnd, + nullMapAddr, columnAddr); + case DATETIME: + return UdfConvert + .convertDateTimeArg(isUdf ? argClass[argIdx] : argClass[argIdx + 1], isNullable, rowStart, + rowEnd, nullMapAddr, columnAddr); + case DATEV2: + return UdfConvert + .convertDateV2Arg(isUdf ? argClass[argIdx] : argClass[argIdx + 1], isNullable, rowStart, rowEnd, + nullMapAddr, columnAddr); + case DATETIMEV2: + return UdfConvert + .convertDateTimeV2Arg(isUdf ? argClass[argIdx] : argClass[argIdx + 1], isNullable, rowStart, + rowEnd, nullMapAddr, columnAddr); + case DECIMALV2: + case DECIMAL128: + return UdfConvert + .convertDecimalArg(argTypes[argIdx].getScale(), 16L, isNullable, rowStart, rowEnd, nullMapAddr, + columnAddr); + case DECIMAL32: + return UdfConvert + .convertDecimalArg(argTypes[argIdx].getScale(), 4L, isNullable, rowStart, rowEnd, nullMapAddr, + columnAddr); + case DECIMAL64: + return UdfConvert + .convertDecimalArg(argTypes[argIdx].getScale(), 8L, isNullable, rowStart, rowEnd, nullMapAddr, + columnAddr); + default: { + LOG.info("Not support type: " + argTypes[argIdx].toString()); + Preconditions.checkState(false, "Not support type: " + argTypes[argIdx].toString()); + break; + } + } + return null; + } + + public Object[] convertArrayArg(int argIdx, boolean isNullable, int rowStart, int rowEnd, long nullMapAddr, + long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr) { + Object[] argument = (Object[]) Array.newInstance(ArrayList.class, rowEnd - rowStart); + for (int row = rowStart; row < rowEnd; ++row) { + long offsetStart = UdfUtils.UNSAFE.getLong(null, offsetsAddr + 8L * (row - 1)); + long offsetEnd = UdfUtils.UNSAFE.getLong(null, offsetsAddr + 8L * (row)); + int currentRowNum = (int) (offsetEnd - offsetStart); + switch (argTypes[argIdx].getItemType().getPrimitiveType()) { + case BOOLEAN: { + argument[row - rowStart] = UdfConvert + .convertArrayBooleanArg(row, currentRowNum, offsetStart, isNullable, nullMapAddr, + nestedNullMapAddr, dataAddr); + break; + } + case TINYINT: { + argument[row - rowStart] = UdfConvert + .convertArrayTinyIntArg(row, currentRowNum, offsetStart, isNullable, nullMapAddr, + nestedNullMapAddr, dataAddr); + break; + } + case SMALLINT: { + argument[row - rowStart] = UdfConvert + .convertArraySmallIntArg(row, currentRowNum, offsetStart, isNullable, nullMapAddr, + nestedNullMapAddr, dataAddr); + break; + } + case INT: { + argument[row - rowStart] = UdfConvert + .convertArrayIntArg(row, currentRowNum, offsetStart, isNullable, nullMapAddr, + nestedNullMapAddr, dataAddr); + break; + } + case BIGINT: { + argument[row - rowStart] = UdfConvert + .convertArrayBigIntArg(row, currentRowNum, offsetStart, isNullable, nullMapAddr, + nestedNullMapAddr, dataAddr); + break; + } + case LARGEINT: { + argument[row - rowStart] = UdfConvert + .convertArrayLargeIntArg(row, currentRowNum, offsetStart, isNullable, nullMapAddr, + nestedNullMapAddr, dataAddr); + break; + } + case FLOAT: { + argument[row - rowStart] = UdfConvert + .convertArrayFloatArg(row, currentRowNum, offsetStart, isNullable, nullMapAddr, + nestedNullMapAddr, dataAddr); + break; + } + case DOUBLE: { + argument[row - rowStart] = UdfConvert + .convertArrayDoubleArg(row, currentRowNum, offsetStart, isNullable, nullMapAddr, + nestedNullMapAddr, dataAddr); + break; + } + case CHAR: + case VARCHAR: + case STRING: { + argument[row - rowStart] = UdfConvert + .convertArrayStringArg(row, currentRowNum, offsetStart, isNullable, nullMapAddr, + nestedNullMapAddr, dataAddr, strOffsetAddr); + break; + } + case DATE: { + argument[row - rowStart] = UdfConvert + .convertArrayDateArg(row, currentRowNum, offsetStart, isNullable, nullMapAddr, + nestedNullMapAddr, dataAddr); + break; + } + case DATETIME: { + argument[row - rowStart] = UdfConvert + .convertArrayDateTimeArg(row, currentRowNum, offsetStart, isNullable, nullMapAddr, + nestedNullMapAddr, dataAddr); + break; + } + case DATEV2: { + argument[row - rowStart] = UdfConvert + .convertArrayDateV2Arg(row, currentRowNum, offsetStart, isNullable, nullMapAddr, + nestedNullMapAddr, dataAddr); + break; + } + case DATETIMEV2: { + argument[row - rowStart] = UdfConvert + .convertArrayDateTimeV2Arg(row, currentRowNum, offsetStart, isNullable, + nullMapAddr, nestedNullMapAddr, dataAddr); + break; + } + case DECIMALV2: + case DECIMAL128: { + argument[row - rowStart] = UdfConvert + .convertArrayDecimalArg(argTypes[argIdx].getScale(), 16L, row, currentRowNum, + offsetStart, isNullable, nullMapAddr, nestedNullMapAddr, dataAddr); + break; + } + case DECIMAL32: { + argument[row - rowStart] = UdfConvert + .convertArrayDecimalArg(argTypes[argIdx].getScale(), 4L, row, currentRowNum, + offsetStart, isNullable, nullMapAddr, nestedNullMapAddr, dataAddr); + break; + } + case DECIMAL64: { + argument[row - rowStart] = UdfConvert + .convertArrayDecimalArg(argTypes[argIdx].getScale(), 8L, row, currentRowNum, + offsetStart, isNullable, nullMapAddr, nestedNullMapAddr, dataAddr); + break; + } + default: { + LOG.info("Not support: " + argTypes[argIdx]); + Preconditions.checkState(false, "Not support type " + argTypes[argIdx].toString()); + break; + } + } + } + return argument; + } } diff --git a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java index e2d8ab1b75..a0736b5a72 100644 --- a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java +++ b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java @@ -24,6 +24,7 @@ import org.apache.doris.common.jni.utils.UdfUtils; import org.apache.doris.common.jni.utils.UdfUtils.JavaUdfDataType; import org.apache.doris.thrift.TJavaUdfExecutorCtorParams; +import com.esotericsoftware.reflectasm.MethodAccess; import com.google.common.base.Joiner; import com.google.common.collect.Lists; import org.apache.log4j.Logger; @@ -36,6 +37,7 @@ import java.lang.reflect.Constructor; import java.lang.reflect.Method; import java.net.MalformedURLException; import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; /** @@ -49,6 +51,8 @@ public class UdafExecutor extends BaseExecutor { private HashMap<String, Method> allMethods; private HashMap<Long, Object> stateObjMap; private Class retClass; + private int addIndex; + private MethodAccess methodAccess; /** * Constructor to create an object. @@ -66,6 +70,84 @@ public class UdafExecutor extends BaseExecutor { super.close(); } + public Object[] convertBasicArguments(int argIdx, boolean isNullable, int rowStart, int rowEnd, long nullMapAddr, + long columnAddr, long strOffsetAddr) { + return convertBasicArg(false, argIdx, isNullable, rowStart, rowEnd, nullMapAddr, columnAddr, strOffsetAddr); + } + + public Object[] convertArrayArguments(int argIdx, boolean isNullable, int rowStart, int rowEnd, long nullMapAddr, + long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr) { + return convertArrayArg(argIdx, isNullable, rowStart, rowEnd, nullMapAddr, offsetsAddr, nestedNullMapAddr, + dataAddr, strOffsetAddr); + } + + public void addBatch(boolean isSinglePlace, int rowStart, int rowEnd, long placeAddr, int offset, Object[] column) + throws UdfRuntimeException { + if (isSinglePlace) { + addBatchSingle(rowStart, rowEnd, placeAddr, column); + } else { + addBatchPlaces(rowStart, rowEnd, placeAddr, offset, column); + } + } + + public void addBatchSingle(int rowStart, int rowEnd, long placeAddr, Object[] column) throws UdfRuntimeException { + try { + Long curPlace = placeAddr; + Object[] inputArgs = new Object[argTypes.length + 1]; + Object state = stateObjMap.get(curPlace); + if (state != null) { + inputArgs[0] = state; + } else { + Object newState = createAggState(); + stateObjMap.put(curPlace, newState); + inputArgs[0] = newState; + } + + Object[][] inputs = (Object[][]) column; + for (int i = 0; i < (rowEnd - rowStart); ++i) { + for (int j = 0; j < column.length; ++j) { + inputArgs[j + 1] = inputs[j][i]; + } + methodAccess.invoke(udf, addIndex, inputArgs); + } + } catch (Exception e) { + LOG.warn("invoke add function meet some error: " + e.getCause().toString()); + throw new UdfRuntimeException("UDAF failed to addBatchSingle: ", e); + } + } + + public void addBatchPlaces(int rowStart, int rowEnd, long placeAddr, int offset, Object[] column) + throws UdfRuntimeException { + try { + Object[][] inputs = (Object[][]) column; + ArrayList<Object> placeState = new ArrayList<>(rowEnd - rowStart); + for (int row = rowStart; row < rowEnd; ++row) { + Long curPlace = UdfUtils.UNSAFE.getLong(null, placeAddr + (8L * row)) + offset; + Object state = stateObjMap.get(curPlace); + if (state != null) { + placeState.add(state); + } else { + Object newState = createAggState(); + stateObjMap.put(curPlace, newState); + placeState.add(newState); + } + } + //spilt into two for loop + + Object[] inputArgs = new Object[argTypes.length + 1]; + for (int row = 0; row < (rowEnd - rowStart); ++row) { + inputArgs[0] = placeState.get(row); + for (int j = 0; j < column.length; ++j) { + inputArgs[j + 1] = inputs[j][row]; + } + methodAccess.invoke(udf, addIndex, inputArgs); + } + } catch (Exception e) { + LOG.warn("invoke add function meet some error: " + Arrays.toString(e.getStackTrace())); + throw new UdfRuntimeException("UDAF failed to addBatchPlaces: ", e); + } + } + /** * invoke add function, add row in loop [rowStart, rowEnd). */ @@ -224,10 +306,10 @@ public class UdafExecutor extends BaseExecutor { protected long getCurrentOutputOffset(long row, boolean isArrayType) { if (isArrayType) { return Integer.toUnsignedLong( - UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * (row - 1))); + UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * (row - 1))); } else { return Integer.toUnsignedLong( - UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * (row - 1))); + UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * (row - 1))); } } @@ -251,6 +333,7 @@ public class UdafExecutor extends BaseExecutor { loader = ClassLoader.getSystemClassLoader(); } Class<?> c = Class.forName(className, true, loader); + methodAccess = MethodAccess.get(c); Constructor<?> ctor = c.getConstructor(); udf = ctor.newInstance(); Method[] methods = c.getDeclaredMethods(); @@ -281,7 +364,7 @@ public class UdafExecutor extends BaseExecutor { } case UDAF_ADD_FUNCTION: { allMethods.put(methods[idx].getName(), methods[idx]); - + addIndex = methodAccess.getIndex(UDAF_ADD_FUNCTION); argClass = methods[idx].getParameterTypes(); if (argClass.length != parameterTypes.length + 1) { LOG.debug("add function parameterTypes length not equal " + argClass.length + " " diff --git a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfConvert.java b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfConvert.java index 4519b23a54..fb2ead5a3f 100644 --- a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfConvert.java +++ b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfConvert.java @@ -37,263 +37,269 @@ import java.util.Arrays; public class UdfConvert { private static final Logger LOG = Logger.getLogger(UdfConvert.class); - public static Object[] convertBooleanArg(boolean isNullable, int numRows, long nullMapAddr, long columnAddr) { - Boolean[] argument = new Boolean[numRows]; + public static Object[] convertBooleanArg(boolean isNullable, int rowsStart, int rowsEnd, long nullMapAddr, + long columnAddr) { + Boolean[] argument = new Boolean[rowsEnd - rowsStart]; if (isNullable) { - for (int i = 0; i < numRows; ++i) { + for (int i = rowsStart; i < rowsEnd; ++i) { if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) { - argument[i] = UdfUtils.UNSAFE.getBoolean(null, columnAddr + i); + argument[i - rowsStart] = UdfUtils.UNSAFE.getBoolean(null, columnAddr + i); } // else is the current row is null } } else { - for (int i = 0; i < numRows; ++i) { - argument[i] = UdfUtils.UNSAFE.getBoolean(null, columnAddr + i); + for (int i = rowsStart; i < rowsEnd; ++i) { + argument[i - rowsStart] = UdfUtils.UNSAFE.getBoolean(null, columnAddr + i); } } return argument; } - public static Object[] convertTinyIntArg(boolean isNullable, int numRows, long nullMapAddr, long columnAddr) { - Byte[] argument = new Byte[numRows]; + public static Object[] convertTinyIntArg(boolean isNullable, int rowsStart, int rowsEnd, long nullMapAddr, + long columnAddr) { + Byte[] argument = new Byte[rowsEnd - rowsStart]; if (isNullable) { - for (int i = 0; i < numRows; ++i) { + for (int i = rowsStart; i < rowsEnd; ++i) { if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) { - argument[i] = UdfUtils.UNSAFE.getByte(null, columnAddr + i); + argument[i - rowsStart] = UdfUtils.UNSAFE.getByte(null, columnAddr + i); } // else is the current row is null } } else { - for (int i = 0; i < numRows; ++i) { - argument[i] = UdfUtils.UNSAFE.getByte(null, columnAddr + i); + for (int i = rowsStart; i < rowsEnd; ++i) { + argument[i - rowsStart] = UdfUtils.UNSAFE.getByte(null, columnAddr + i); } } return argument; } - public static Object[] convertSmallIntArg(boolean isNullable, int numRows, long nullMapAddr, long columnAddr) { - Short[] argument = new Short[numRows]; + public static Object[] convertSmallIntArg(boolean isNullable, int rowsStart, int rowsEnd, long nullMapAddr, + long columnAddr) { + Short[] argument = new Short[rowsEnd - rowsStart]; if (isNullable) { - for (int i = 0; i < numRows; ++i) { + for (int i = rowsStart; i < rowsEnd; ++i) { if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) { - argument[i] = UdfUtils.UNSAFE.getShort(null, columnAddr + (i * 2L)); + argument[i - rowsStart] = UdfUtils.UNSAFE.getShort(null, columnAddr + (i * 2L)); } // else is the current row is null } } else { - for (int i = 0; i < numRows; ++i) { - argument[i] = UdfUtils.UNSAFE.getShort(null, columnAddr + (i * 2L)); + for (int i = rowsStart; i < rowsEnd; ++i) { + argument[i - rowsStart] = UdfUtils.UNSAFE.getShort(null, columnAddr + (i * 2L)); } } return argument; } - public static Object[] convertIntArg(boolean isNullable, int numRows, long nullMapAddr, long columnAddr) { - Integer[] argument = new Integer[numRows]; + public static Object[] convertIntArg(boolean isNullable, int rowsStart, int rowsEnd, long nullMapAddr, + long columnAddr) { + Integer[] argument = new Integer[rowsEnd - rowsStart]; if (isNullable) { - for (int i = 0; i < numRows; ++i) { + for (int i = rowsStart; i < rowsEnd; ++i) { if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) { - argument[i] = UdfUtils.UNSAFE.getInt(null, columnAddr + (i * 4L)); + argument[i - rowsStart] = UdfUtils.UNSAFE.getInt(null, columnAddr + (i * 4L)); } // else is the current row is null } } else { - for (int i = 0; i < numRows; ++i) { - argument[i] = UdfUtils.UNSAFE.getInt(null, columnAddr + (i * 4L)); + for (int i = rowsStart; i < rowsEnd; ++i) { + argument[i - rowsStart] = UdfUtils.UNSAFE.getInt(null, columnAddr + (i * 4L)); } } return argument; } - public static Object[] convertBigIntArg(boolean isNullable, int numRows, long nullMapAddr, long columnAddr) { - Long[] argument = new Long[numRows]; + public static Object[] convertBigIntArg(boolean isNullable, int rowsStart, int rowsEnd, long nullMapAddr, + long columnAddr) { + Long[] argument = new Long[rowsEnd - rowsStart]; if (isNullable) { - for (int i = 0; i < numRows; ++i) { + for (int i = rowsStart; i < rowsEnd; ++i) { if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) { - argument[i] = UdfUtils.UNSAFE.getLong(null, columnAddr + (i * 8L)); + argument[i - rowsStart] = UdfUtils.UNSAFE.getLong(null, columnAddr + (i * 8L)); } // else is the current row is null } } else { - for (int i = 0; i < numRows; ++i) { - argument[i] = UdfUtils.UNSAFE.getLong(null, columnAddr + (i * 8L)); + for (int i = rowsStart; i < rowsEnd; ++i) { + argument[i - rowsStart] = UdfUtils.UNSAFE.getLong(null, columnAddr + (i * 8L)); } } return argument; } - public static Object[] convertFloatArg(boolean isNullable, int numRows, long nullMapAddr, long columnAddr) { - Float[] argument = new Float[numRows]; + public static Object[] convertFloatArg(boolean isNullable, int rowsStart, int rowsEnd, long nullMapAddr, + long columnAddr) { + Float[] argument = new Float[rowsEnd - rowsStart]; if (isNullable) { - for (int i = 0; i < numRows; ++i) { + for (int i = rowsStart; i < rowsEnd; ++i) { if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) { - argument[i] = UdfUtils.UNSAFE.getFloat(null, columnAddr + (i * 4L)); + argument[i - rowsStart] = UdfUtils.UNSAFE.getFloat(null, columnAddr + (i * 4L)); } // else is the current row is null } } else { - for (int i = 0; i < numRows; ++i) { - argument[i] = UdfUtils.UNSAFE.getFloat(null, columnAddr + (i * 4L)); + for (int i = rowsStart; i < rowsEnd; ++i) { + argument[i - rowsStart] = UdfUtils.UNSAFE.getFloat(null, columnAddr + (i * 4L)); } } return argument; } - public static Object[] convertDoubleArg(boolean isNullable, int numRows, long nullMapAddr, long columnAddr) { - Double[] argument = new Double[numRows]; + public static Object[] convertDoubleArg(boolean isNullable, int rowsStart, int rowsEnd, long nullMapAddr, + long columnAddr) { + Double[] argument = new Double[rowsEnd - rowsStart]; if (isNullable) { - for (int i = 0; i < numRows; ++i) { + for (int i = rowsStart; i < rowsEnd; ++i) { if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) { - argument[i] = UdfUtils.UNSAFE.getDouble(null, columnAddr + (i * 8L)); + argument[i - rowsStart] = UdfUtils.UNSAFE.getDouble(null, columnAddr + (i * 8L)); } // else is the current row is null } } else { - for (int i = 0; i < numRows; ++i) { - argument[i] = UdfUtils.UNSAFE.getDouble(null, columnAddr + (i * 8L)); + for (int i = rowsStart; i < rowsEnd; ++i) { + argument[i - rowsStart] = UdfUtils.UNSAFE.getDouble(null, columnAddr + (i * 8L)); } } return argument; } - public static Object[] convertDateArg(Class argTypeClass, boolean isNullable, int numRows, long nullMapAddr, - long columnAddr) { - Object[] argument = (Object[]) Array.newInstance(argTypeClass, numRows); + public static Object[] convertDateArg(Class argTypeClass, boolean isNullable, int rowsStart, int rowsEnd, + long nullMapAddr, long columnAddr) { + Object[] argument = (Object[]) Array.newInstance(argTypeClass, rowsEnd - rowsStart); if (isNullable) { - for (int i = 0; i < numRows; ++i) { + for (int i = rowsStart; i < rowsEnd; ++i) { if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) { long value = UdfUtils.UNSAFE.getLong(null, columnAddr + (i * 8L)); - argument[i] = UdfUtils.convertDateToJavaDate(value, argTypeClass); + argument[i - rowsStart] = UdfUtils.convertDateToJavaDate(value, argTypeClass); } // else is the current row is null } } else { - for (int i = 0; i < numRows; ++i) { + for (int i = rowsStart; i < rowsEnd; ++i) { long value = UdfUtils.UNSAFE.getLong(null, columnAddr + (i * 8L)); - argument[i] = UdfUtils.convertDateToJavaDate(value, argTypeClass); + argument[i - rowsStart] = UdfUtils.convertDateToJavaDate(value, argTypeClass); } } return argument; } - public static Object[] convertDateTimeArg(Class argTypeClass, boolean isNullable, int numRows, long nullMapAddr, - long columnAddr) { - Object[] argument = (Object[]) Array.newInstance(argTypeClass, numRows); + public static Object[] convertDateTimeArg(Class argTypeClass, boolean isNullable, int rowsStart, int rowsEnd, + long nullMapAddr, long columnAddr) { + Object[] argument = (Object[]) Array.newInstance(argTypeClass, rowsEnd - rowsStart); if (isNullable) { - for (int i = 0; i < numRows; ++i) { + for (int i = rowsStart; i < rowsEnd; ++i) { if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) { long value = UdfUtils.UNSAFE.getLong(null, columnAddr + (i * 8L)); - argument[i] = UdfUtils + argument[i - rowsStart] = UdfUtils .convertDateTimeToJavaDateTime(value, argTypeClass); } // else is the current row is null } } else { - for (int i = 0; i < numRows; ++i) { + for (int i = rowsStart; i < rowsEnd; ++i) { long value = UdfUtils.UNSAFE.getLong(null, columnAddr + (i * 8L)); - argument[i] = UdfUtils.convertDateTimeToJavaDateTime(value, argTypeClass); + argument[i - rowsStart] = UdfUtils.convertDateTimeToJavaDateTime(value, argTypeClass); } } return argument; } - public static Object[] convertDateV2Arg(Class argTypeClass, boolean isNullable, int numRows, long nullMapAddr, - long columnAddr) { - Object[] argument = (Object[]) Array.newInstance(argTypeClass, numRows); + public static Object[] convertDateV2Arg(Class argTypeClass, boolean isNullable, int rowsStart, int rowsEnd, + long nullMapAddr, long columnAddr) { + Object[] argument = (Object[]) Array.newInstance(argTypeClass, rowsEnd - rowsStart); if (isNullable) { - for (int i = 0; i < numRows; ++i) { + for (int i = rowsStart; i < rowsEnd; ++i) { if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) { int value = UdfUtils.UNSAFE.getInt(null, columnAddr + (i * 4L)); - argument[i] = UdfUtils.convertDateV2ToJavaDate(value, argTypeClass); + argument[i - rowsStart] = UdfUtils.convertDateV2ToJavaDate(value, argTypeClass); } // else is the current row is null } } else { - for (int i = 0; i < numRows; ++i) { + for (int i = rowsStart; i < rowsEnd; ++i) { int value = UdfUtils.UNSAFE.getInt(null, columnAddr + (i * 4L)); - argument[i] = UdfUtils.convertDateV2ToJavaDate(value, argTypeClass); + argument[i - rowsStart] = UdfUtils.convertDateV2ToJavaDate(value, argTypeClass); } } return argument; } - public static Object[] convertDateTimeV2Arg(Class argTypeClass, boolean isNullable, int numRows, long nullMapAddr, - long columnAddr) { - Object[] argument = (Object[]) Array.newInstance(argTypeClass, numRows); + public static Object[] convertDateTimeV2Arg(Class argTypeClass, boolean isNullable, int rowsStart, int rowsEnd, + long nullMapAddr, long columnAddr) { + Object[] argument = (Object[]) Array.newInstance(argTypeClass, rowsEnd - rowsStart); if (isNullable) { - for (int i = 0; i < numRows; ++i) { + for (int i = rowsStart; i < rowsEnd; ++i) { if (UdfUtils.UNSAFE.getByte(null, nullMapAddr + i) == 0) { long value = UdfUtils.UNSAFE.getLong(columnAddr + (i * 8L)); - argument[i] = UdfUtils + argument[i - rowsStart] = UdfUtils .convertDateTimeV2ToJavaDateTime(value, argTypeClass); } // else is the current row is null } } else { - for (int i = 0; i < numRows; ++i) { + for (int i = rowsStart; i < rowsEnd; ++i) { long value = UdfUtils.UNSAFE.getLong(null, columnAddr + (i * 8L)); - argument[i] = UdfUtils + argument[i - rowsStart] = UdfUtils .convertDateTimeV2ToJavaDateTime(value, argTypeClass); } } return argument; } - public static Object[] convertLargeIntArg(boolean isNullable, int numRows, long nullMapAddr, long columnAddr) { - BigInteger[] argument = new BigInteger[numRows]; + public static Object[] convertLargeIntArg(boolean isNullable, int rowsStart, int rowsEnd, long nullMapAddr, + long columnAddr) { + BigInteger[] argument = new BigInteger[rowsEnd - rowsStart]; byte[] bytes = new byte[16]; if (isNullable) { - for (int i = 0; i < numRows; ++i) { + for (int i = rowsStart; i < rowsEnd; ++i) { if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) { UdfUtils.copyMemory(null, columnAddr + (i * 16L), bytes, UdfUtils.BYTE_ARRAY_OFFSET, 16); - argument[i] = new BigInteger(UdfUtils.convertByteOrder(bytes)); + argument[i - rowsStart] = new BigInteger(UdfUtils.convertByteOrder(bytes)); } // else is the current row is null } } else { - for (int i = 0; i < numRows; ++i) { + for (int i = rowsStart; i < rowsEnd; ++i) { UdfUtils.copyMemory(null, columnAddr + (i * 16L), bytes, UdfUtils.BYTE_ARRAY_OFFSET, 16); - argument[i] = new BigInteger(UdfUtils.convertByteOrder(bytes)); + argument[i - rowsStart] = new BigInteger(UdfUtils.convertByteOrder(bytes)); } } return argument; } - public static Object[] convertDecimalArg(int scale, long typeLen, boolean isNullable, int numRows, long nullMapAddr, - long columnAddr) { - BigDecimal[] argument = new BigDecimal[numRows]; + public static Object[] convertDecimalArg(int scale, long typeLen, boolean isNullable, int rowsStart, int rowsEnd, + long nullMapAddr, long columnAddr) { + BigDecimal[] argument = new BigDecimal[rowsEnd - rowsStart]; byte[] bytes = new byte[(int) typeLen]; if (isNullable) { - for (int i = 0; i < numRows; ++i) { + for (int i = rowsStart; i < rowsEnd; ++i) { if (UdfUtils.UNSAFE.getByte(nullMapAddr + i) == 0) { UdfUtils.copyMemory(null, columnAddr + (i * typeLen), bytes, UdfUtils.BYTE_ARRAY_OFFSET, typeLen); BigInteger bigInteger = new BigInteger(UdfUtils.convertByteOrder(bytes)); - argument[i] = new BigDecimal(bigInteger, scale); //show to pass scale info + argument[i - rowsStart] = new BigDecimal(bigInteger, scale); //show to pass scale info } // else is the current row is null } } else { - for (int i = 0; i < numRows; ++i) { + for (int i = rowsStart; i < rowsEnd; ++i) { UdfUtils.copyMemory(null, columnAddr + (i * typeLen), bytes, UdfUtils.BYTE_ARRAY_OFFSET, typeLen); BigInteger bigInteger = new BigInteger(UdfUtils.convertByteOrder(bytes)); - argument[i] = new BigDecimal(bigInteger, scale); + argument[i - rowsStart] = new BigDecimal(bigInteger, scale); } } return argument; } - public static Object[] convertStringArg(boolean isNullable, int numRows, long nullMapAddr, + public static Object[] convertStringArg(boolean isNullable, int rowsStart, int rowsEnd, long nullMapAddr, long charsAddr, long offsetsAddr) { - String[] argument = new String[numRows]; + String[] argument = new String[rowsEnd - rowsStart]; Preconditions.checkState(UdfUtils.UNSAFE.getInt(null, offsetsAddr + 4L * (0 - 1)) == 0, "offsetsAddr[-1] should be 0;"); - + final int totalLen = UdfUtils.UNSAFE.getInt(null, offsetsAddr + (rowsEnd - 1) * 4L); + byte[] bytes = new byte[totalLen]; + UdfUtils.copyMemory(null, charsAddr, bytes, UdfUtils.BYTE_ARRAY_OFFSET, totalLen); if (isNullable) { - for (int row = 0; row < numRows; ++row) { + for (int row = rowsStart; row < rowsEnd; ++row) { if (UdfUtils.UNSAFE.getByte(nullMapAddr + row) == 0) { - int offset = UdfUtils.UNSAFE.getInt(null, offsetsAddr + row * 4L); - int numBytes = offset - UdfUtils.UNSAFE.getInt(null, offsetsAddr + 4L * (row - 1)); - long base = charsAddr + offset - numBytes; - byte[] bytes = new byte[numBytes]; - UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, numBytes); - argument[row] = new String(bytes, StandardCharsets.UTF_8); + int prevOffset = UdfUtils.UNSAFE.getInt(null, offsetsAddr + 4L * (row - 1)); + int currOffset = UdfUtils.UNSAFE.getInt(null, offsetsAddr + row * 4L); + argument[row - rowsStart] = new String(bytes, prevOffset, currOffset - prevOffset, + StandardCharsets.UTF_8); } // else is the current row is null } } else { - for (int row = 0; row < numRows; ++row) { - int offset = UdfUtils.UNSAFE.getInt(null, offsetsAddr + row * 4L); - int numBytes = offset - UdfUtils.UNSAFE.getInt(null, offsetsAddr + 4L * (row - 1)); - long base = charsAddr + offset - numBytes; - byte[] bytes = new byte[numBytes]; - UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, numBytes); - argument[row] = new String(bytes, StandardCharsets.UTF_8); + for (int row = rowsStart; row < rowsEnd; ++row) { + int prevOffset = UdfUtils.UNSAFE.getInt(null, offsetsAddr + 4L * (row - 1)); + int currOffset = UdfUtils.UNSAFE.getInt(null, offsetsAddr + 4L * row); + argument[row - rowsStart] = new String(bytes, prevOffset, currOffset - prevOffset, + StandardCharsets.UTF_8); } } return argument; @@ -1314,7 +1320,7 @@ public class UdfConvert { } //////////////////////////////////////////convertArray/////////////////////////////////////////////////////////// - public static void convertArrayBooleanArg(Object[] argument, int row, int currentRowNum, long offsetStart, + public static ArrayList<Boolean> convertArrayBooleanArg(int row, int currentRowNum, long offsetStart, boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) { ArrayList<Boolean> data = null; if (isNullable) { @@ -1340,10 +1346,10 @@ public class UdfConvert { } } // for loop } // end for all current row - argument[row] = data; + return data; } - public static void convertArrayTinyIntArg(Object[] argument, int row, int currentRowNum, long offsetStart, + public static ArrayList<Byte> convertArrayTinyIntArg(int row, int currentRowNum, long offsetStart, boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) { ArrayList<Byte> data = null; if (isNullable) { @@ -1369,10 +1375,10 @@ public class UdfConvert { } } // for loop } // end for all current row - argument[row] = data; + return data; } - public static void convertArraySmallIntArg(Object[] argument, int row, int currentRowNum, long offsetStart, + public static ArrayList<Short> convertArraySmallIntArg(int row, int currentRowNum, long offsetStart, boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) { ArrayList<Short> data = null; if (isNullable) { @@ -1398,10 +1404,10 @@ public class UdfConvert { } } // for loop } // end for all current row - argument[row] = data; + return data; } - public static void convertArrayIntArg(Object[] argument, int row, int currentRowNum, long offsetStart, + public static ArrayList<Integer> convertArrayIntArg(int row, int currentRowNum, long offsetStart, boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) { ArrayList<Integer> data = null; if (isNullable) { @@ -1427,10 +1433,10 @@ public class UdfConvert { } } // for loop } // end for all current row - argument[row] = data; + return data; } - public static void convertArrayBigIntArg(Object[] argument, int row, int currentRowNum, long offsetStart, + public static ArrayList<Long> convertArrayBigIntArg(int row, int currentRowNum, long offsetStart, boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) { ArrayList<Long> data = null; if (isNullable) { @@ -1456,10 +1462,10 @@ public class UdfConvert { } } // for loop } // end for all current row - argument[row] = data; + return data; } - public static void convertArrayFloatArg(Object[] argument, int row, int currentRowNum, long offsetStart, + public static ArrayList<Float> convertArrayFloatArg(int row, int currentRowNum, long offsetStart, boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) { ArrayList<Float> data = null; if (isNullable) { @@ -1485,10 +1491,10 @@ public class UdfConvert { } } // for loop } // end for all current row - argument[row] = data; + return data; } - public static void convertArrayDoubleArg(Object[] argument, int row, int currentRowNum, long offsetStart, + public static ArrayList<Double> convertArrayDoubleArg(int row, int currentRowNum, long offsetStart, boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) { ArrayList<Double> data = null; if (isNullable) { @@ -1514,10 +1520,10 @@ public class UdfConvert { } } // for loop } // end for all current row - argument[row] = data; + return data; } - public static void convertArrayDateArg(Object[] argument, int row, int currentRowNum, long offsetStart, + public static ArrayList<LocalDate> convertArrayDateArg(int row, int currentRowNum, long offsetStart, boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) { ArrayList<LocalDate> data = null; if (isNullable) { @@ -1549,10 +1555,10 @@ public class UdfConvert { } } // for loop } // end for all current row - argument[row] = data; + return data; } - public static void convertArrayDateTimeArg(Object[] argument, int row, int currentRowNum, long offsetStart, + public static ArrayList<LocalDateTime> convertArrayDateTimeArg(int row, int currentRowNum, long offsetStart, boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) { ArrayList<LocalDateTime> data = null; if (isNullable) { @@ -1582,10 +1588,10 @@ public class UdfConvert { } } // for loop } // end for all current row - argument[row] = data; + return data; } - public static void convertArrayDateV2Arg(Object[] argument, int row, int currentRowNum, long offsetStart, + public static ArrayList<LocalDate> convertArrayDateV2Arg(int row, int currentRowNum, long offsetStart, boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) { ArrayList<LocalDate> data = null; if (isNullable) { @@ -1613,10 +1619,10 @@ public class UdfConvert { } } // for loop } // end for all current row - argument[row] = data; + return data; } - public static void convertArrayDateTimeV2Arg(Object[] argument, int row, int currentRowNum, long offsetStart, + public static ArrayList<LocalDateTime> convertArrayDateTimeV2Arg(int row, int currentRowNum, long offsetStart, boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) { ArrayList<LocalDateTime> data = null; if (isNullable) { @@ -1646,10 +1652,10 @@ public class UdfConvert { } } // for loop } // end for all current row - argument[row] = data; + return data; } - public static void convertArrayLargeIntArg(Object[] argument, int row, int currentRowNum, long offsetStart, + public static ArrayList<BigInteger> convertArrayLargeIntArg(int row, int currentRowNum, long offsetStart, boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) { ArrayList<BigInteger> data = null; byte[] bytes = new byte[16]; @@ -1678,10 +1684,10 @@ public class UdfConvert { } } // for loop } // end for all current row - argument[row] = data; + return data; } - public static void convertArrayDecimalArg(int scale, long typeLen, Object[] argument, int row, int currentRowNum, + public static ArrayList<BigDecimal> convertArrayDecimalArg(int scale, long typeLen, int row, int currentRowNum, long offsetStart, boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr) { ArrayList<BigDecimal> data = null; @@ -1713,10 +1719,10 @@ public class UdfConvert { } } // for loop } // end for all current row - argument[row] = data; + return data; } - public static void convertArrayStringArg(Object[] argument, int row, int currentRowNum, long offsetStart, + public static ArrayList<String> convertArrayStringArg(int row, int currentRowNum, long offsetStart, boolean isNullable, long nullMapAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr) { ArrayList<String> data = null; if (isNullable) { @@ -1755,6 +1761,6 @@ public class UdfConvert { } } } - argument[row] = data; + return data; } } diff --git a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java index f1993ec488..50528d007b 100644 --- a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java +++ b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java @@ -24,6 +24,7 @@ import org.apache.doris.common.jni.utils.UdfUtils; import org.apache.doris.common.jni.utils.UdfUtils.JavaUdfDataType; import org.apache.doris.thrift.TJavaUdfExecutorCtorParams; +import com.esotericsoftware.reflectasm.MethodAccess; import com.google.common.base.Joiner; import com.google.common.base.Preconditions; import com.google.common.collect.Lists; @@ -50,6 +51,8 @@ public class UdfExecutor extends BaseExecutor { private long rowIdx; private long batchSizePtr; + private int evaluateIndex; + private MethodAccess methodAccess; /** * Create a UdfExecutor, using parameters from a serialized thrift object. Used by @@ -113,166 +116,14 @@ public class UdfExecutor extends BaseExecutor { public Object[] convertBasicArguments(int argIdx, boolean isNullable, int numRows, long nullMapAddr, long columnAddr, long strOffsetAddr) { - switch (argTypes[argIdx]) { - case BOOLEAN: - return UdfConvert.convertBooleanArg(isNullable, numRows, nullMapAddr, columnAddr); - case TINYINT: - return UdfConvert.convertTinyIntArg(isNullable, numRows, nullMapAddr, columnAddr); - case SMALLINT: - return UdfConvert.convertSmallIntArg(isNullable, numRows, nullMapAddr, columnAddr); - case INT: - return UdfConvert.convertIntArg(isNullable, numRows, nullMapAddr, columnAddr); - case BIGINT: - return UdfConvert.convertBigIntArg(isNullable, numRows, nullMapAddr, columnAddr); - case LARGEINT: - return UdfConvert.convertLargeIntArg(isNullable, numRows, nullMapAddr, columnAddr); - case FLOAT: - return UdfConvert.convertFloatArg(isNullable, numRows, nullMapAddr, columnAddr); - case DOUBLE: - return UdfConvert.convertDoubleArg(isNullable, numRows, nullMapAddr, columnAddr); - case CHAR: - case VARCHAR: - case STRING: - return UdfConvert.convertStringArg(isNullable, numRows, nullMapAddr, columnAddr, strOffsetAddr); - case DATE: // udaf maybe argClass[i + argClassOffset] need add +1 - return UdfConvert.convertDateArg(argClass[argIdx], isNullable, numRows, nullMapAddr, columnAddr); - case DATETIME: - return UdfConvert.convertDateTimeArg(argClass[argIdx], isNullable, numRows, nullMapAddr, columnAddr); - case DATEV2: - return UdfConvert.convertDateV2Arg(argClass[argIdx], isNullable, numRows, nullMapAddr, columnAddr); - case DATETIMEV2: - return UdfConvert.convertDateTimeV2Arg(argClass[argIdx], isNullable, numRows, nullMapAddr, columnAddr); - case DECIMALV2: - case DECIMAL128: - return UdfConvert.convertDecimalArg(argTypes[argIdx].getScale(), 16L, isNullable, numRows, nullMapAddr, - columnAddr); - case DECIMAL32: - return UdfConvert.convertDecimalArg(argTypes[argIdx].getScale(), 4L, isNullable, numRows, nullMapAddr, - columnAddr); - case DECIMAL64: - return UdfConvert.convertDecimalArg(argTypes[argIdx].getScale(), 8L, isNullable, numRows, nullMapAddr, - columnAddr); - default: { - LOG.info("Not support type: " + argTypes[argIdx].toString()); - Preconditions.checkState(false, "Not support type: " + argTypes[argIdx].toString()); - break; - } - } - return null; + return convertBasicArg(true, argIdx, isNullable, 0, numRows, nullMapAddr, columnAddr, strOffsetAddr); } public Object[] convertArrayArguments(int argIdx, boolean isNullable, int numRows, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr) { - Object[] argument = (Object[]) Array.newInstance(ArrayList.class, numRows); - for (int row = 0; row < numRows; ++row) { - long offsetStart = UdfUtils.UNSAFE.getLong(null, offsetsAddr + 8L * (row - 1)); - long offsetEnd = UdfUtils.UNSAFE.getLong(null, offsetsAddr + 8L * (row)); - int currentRowNum = (int) (offsetEnd - offsetStart); - switch (argTypes[argIdx].getItemType().getPrimitiveType()) { - case BOOLEAN: { - UdfConvert - .convertArrayBooleanArg(argument, row, currentRowNum, offsetStart, isNullable, nullMapAddr, - nestedNullMapAddr, dataAddr); - break; - } - case TINYINT: { - UdfConvert - .convertArrayTinyIntArg(argument, row, currentRowNum, offsetStart, isNullable, nullMapAddr, - nestedNullMapAddr, dataAddr); - break; - } - case SMALLINT: { - UdfConvert - .convertArraySmallIntArg(argument, row, currentRowNum, offsetStart, isNullable, nullMapAddr, - nestedNullMapAddr, dataAddr); - break; - } - case INT: { - UdfConvert.convertArrayIntArg(argument, row, currentRowNum, offsetStart, isNullable, nullMapAddr, - nestedNullMapAddr, dataAddr); - break; - } - case BIGINT: { - UdfConvert.convertArrayBigIntArg(argument, row, currentRowNum, offsetStart, isNullable, nullMapAddr, - nestedNullMapAddr, dataAddr); - break; - } - case LARGEINT: { - UdfConvert - .convertArrayLargeIntArg(argument, row, currentRowNum, offsetStart, isNullable, nullMapAddr, - nestedNullMapAddr, dataAddr); - break; - } - case FLOAT: { - UdfConvert.convertArrayFloatArg(argument, row, currentRowNum, offsetStart, isNullable, nullMapAddr, - nestedNullMapAddr, dataAddr); - break; - } - case DOUBLE: { - UdfConvert.convertArrayDoubleArg(argument, row, currentRowNum, offsetStart, isNullable, nullMapAddr, - nestedNullMapAddr, dataAddr); - break; - } - case CHAR: - case VARCHAR: - case STRING: { - UdfConvert.convertArrayStringArg(argument, row, currentRowNum, offsetStart, isNullable, nullMapAddr, - nestedNullMapAddr, dataAddr, strOffsetAddr); - break; - } - case DATE: { - UdfConvert.convertArrayDateArg(argument, row, currentRowNum, offsetStart, isNullable, nullMapAddr, - nestedNullMapAddr, dataAddr); - break; - } - case DATETIME: { - UdfConvert - .convertArrayDateTimeArg(argument, row, currentRowNum, offsetStart, isNullable, nullMapAddr, - nestedNullMapAddr, dataAddr); - break; - } - case DATEV2: { - UdfConvert.convertArrayDateV2Arg(argument, row, currentRowNum, offsetStart, isNullable, nullMapAddr, - nestedNullMapAddr, dataAddr); - break; - } - case DATETIMEV2: { - UdfConvert.convertArrayDateTimeV2Arg(argument, row, currentRowNum, offsetStart, isNullable, - nullMapAddr, - nestedNullMapAddr, dataAddr); - break; - } - case DECIMALV2: - case DECIMAL128: { - UdfConvert.convertArrayDecimalArg(argTypes[argIdx].getScale(), 16L, argument, row, currentRowNum, - offsetStart, isNullable, - nullMapAddr, - nestedNullMapAddr, dataAddr); - break; - } - case DECIMAL32: { - UdfConvert.convertArrayDecimalArg(argTypes[argIdx].getScale(), 4L, argument, row, currentRowNum, - offsetStart, isNullable, - nullMapAddr, - nestedNullMapAddr, dataAddr); - break; - } - case DECIMAL64: { - UdfConvert.convertArrayDecimalArg(argTypes[argIdx].getScale(), 8L, argument, row, currentRowNum, - offsetStart, isNullable, - nullMapAddr, - nestedNullMapAddr, dataAddr); - break; - } - default: { - LOG.info("Not support: " + argTypes[argIdx]); - Preconditions.checkState(false, "Not support type " + argTypes[argIdx].toString()); - break; - } - } - } - return argument; + return convertArrayArg(argIdx, isNullable, 0, numRows, nullMapAddr, offsetsAddr, nestedNullMapAddr, dataAddr, + strOffsetAddr); } /** @@ -287,7 +138,7 @@ public class UdfExecutor extends BaseExecutor { for (int j = 0; j < column.length; ++j) { parameters[j] = inputs[j][i]; } - result[i] = method.invoke(udf, parameters); + result[i] = methodAccess.invoke(udf, evaluateIndex, parameters); } return result; } catch (Exception e) { @@ -581,6 +432,7 @@ public class UdfExecutor extends BaseExecutor { loader = ClassLoader.getSystemClassLoader(); } Class<?> c = Class.forName(className, true, loader); + methodAccess = MethodAccess.get(c); Constructor<?> ctor = c.getConstructor(); udf = ctor.newInstance(); Method[] methods = c.getMethods(); @@ -597,6 +449,7 @@ public class UdfExecutor extends BaseExecutor { continue; } method = m; + evaluateIndex = methodAccess.getIndex(UDF_FUNCTION_NAME); Pair<Boolean, JavaUdfDataType> returnType; if (argClass.length == 0 && parameterTypes.length == 0) { // Special case where the UDF doesn't take any input args --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org