This is an automated email from the ASF dual-hosted git repository. lihaopeng pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-doris.git
The following commit(s) were added to refs/heads/master by this push: new 4c24586865 [Vectorized][UDF] support java-udaf (#9930) 4c24586865 is described below commit 4c24586865907936cfebe001c65978503fb0b21d Author: zhangstar333 <87313068+zhangstar...@users.noreply.github.com> AuthorDate: Wed Jun 15 10:53:44 2022 +0800 [Vectorized][UDF] support java-udaf (#9930) --- be/src/util/jni-util.h | 6 + be/src/vec/CMakeLists.txt | 1 + .../aggregate_function_java_udaf.h | 346 ++++++++++++++ be/src/vec/exec/vaggregation_node.cpp | 9 +- be/src/vec/exprs/vectorized_agg_fn.cpp | 14 +- be/src/vec/functions/function_java_udf.cpp | 5 +- be/src/vec/functions/function_java_udf.h | 6 - .../ecosystem/udf/java-user-defined-function.md | 102 +++- .../ecosystem/udf/java-user-defined-function.md | 99 +++- .../apache/doris/analysis/CreateFunctionStmt.java | 164 ++++++- .../apache/doris/catalog/AggregateFunction.java | 30 ++ .../java/org/apache/doris/udf/UdafExecutor.java | 528 +++++++++++++++++++++ .../java/org/apache/doris/udf/UdfExecutor.java | 161 ++----- .../main/java/org/apache/doris/udf/UdfUtils.java | 120 +++++ gensrc/thrift/Types.thrift | 2 + 15 files changed, 1415 insertions(+), 178 deletions(-) diff --git a/be/src/util/jni-util.h b/be/src/util/jni-util.h index 9a8c1cb859..ccbd26871b 100644 --- a/be/src/util/jni-util.h +++ b/be/src/util/jni-util.h @@ -60,6 +60,12 @@ public: static jclass jni_util_class() { return jni_util_cl_; } static jmethodID throwable_to_stack_trace_id() { return throwable_to_stack_trace_id_; } + static const int32_t INITIAL_RESERVED_BUFFER_SIZE = 1024; + // TODO: we need a heuristic strategy to increase buffer size for variable-size output. + static inline int32_t IncreaseReservedBufferSize(int n) { + return INITIAL_RESERVED_BUFFER_SIZE << n; + } + private: static Status GetJNIEnvSlowPath(JNIEnv** env); diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt index 7862172c2c..5c4d5c7b36 100644 --- a/be/src/vec/CMakeLists.txt +++ b/be/src/vec/CMakeLists.txt @@ -39,6 +39,7 @@ set(VEC_FILES aggregate_functions/aggregate_function_group_concat.cpp aggregate_functions/aggregate_function_percentile_approx.cpp aggregate_functions/aggregate_function_simple_factory.cpp + aggregate_functions/aggregate_function_java_udaf.h columns/collator.cpp columns/column.cpp columns/column_array.cpp diff --git a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h new file mode 100644 index 0000000000..8594cd30bf --- /dev/null +++ b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h @@ -0,0 +1,346 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#ifdef LIBJVM + +#include <jni.h> +#include <unistd.h> + +#include <cstdint> +#include <memory> + +#include "common/status.h" +#include "gen_cpp/Exprs_types.h" +#include "runtime/user_function_cache.h" +#include "util/jni-util.h" +#include "vec/aggregate_functions/aggregate_function.h" +#include "vec/columns/column_string.h" +#include "vec/common/exception.h" +#include "vec/common/string_ref.h" +#include "vec/core/block.h" +#include "vec/core/column_numbers.h" +#include "vec/core/field.h" +#include "vec/core/types.h" +#include "vec/data_types/data_type_string.h" +#include "vec/io/io_helper.h" + +namespace doris::vectorized { + +const char* UDAF_EXECUTOR_CLASS = "org/apache/doris/udf/UdafExecutor"; +const char* UDAF_EXECUTOR_CTOR_SIGNATURE = "([B)V"; +const char* UDAF_EXECUTOR_CLOSE_SIGNATURE = "()V"; +const char* UDAF_EXECUTOR_ADD_SIGNATURE = "(JJ)V"; +const char* UDAF_EXECUTOR_SERIALIZE_SIGNATURE = "()[B"; +const char* UDAF_EXECUTOR_MERGE_SIGNATURE = "([B)V"; +const char* UDAF_EXECUTOR_RESULT_SIGNATURE = "(J)Z"; +// Calling Java method about those signture means: "(argument-types)return-type" +// https://www.iitk.ac.in/esc101/05Aug/tutorial/native1.1/implementing/method.html + +struct AggregateJavaUdafData { +public: + AggregateJavaUdafData() = default; + AggregateJavaUdafData(int64_t num_args) { + argument_size = num_args; + input_values_buffer_ptr.reset(new int64_t[num_args]); + input_nulls_buffer_ptr.reset(new int64_t[num_args]); + input_offsets_ptrs.reset(new int64_t[num_args]); + output_value_buffer.reset(new int64_t); + output_null_value.reset(new int64_t); + output_offsets_ptr.reset(new int64_t); + output_intermediate_state_ptr.reset(new int64_t); + } + + ~AggregateJavaUdafData() { + JNIEnv* env; + Status status; + RETURN_IF_STATUS_ERROR(status, JniUtil::GetJNIEnv(&env)); + env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_close_id); + RETURN_IF_STATUS_ERROR(status, JniUtil::GetJniExceptionMsg(env)); + env->DeleteGlobalRef(executor_obj); + } + + Status init_udaf(const TFunction& fn) { + JNIEnv* env = nullptr; + RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf init_udaf function"); + RETURN_IF_ERROR(JniUtil::GetGlobalClassRef(env, UDAF_EXECUTOR_CLASS, &executor_cl)); + RETURN_NOT_OK_STATUS_WITH_WARN(register_func_id(env), + "Java-Udaf register_func_id function"); + + // Add a scoped cleanup jni reference object. This cleans up local refs made below. + JniLocalFrame jni_frame; + { + std::string local_location; + auto function_cache = UserFunctionCache::instance(); + RETURN_IF_ERROR(function_cache->get_jarpath(fn.id, fn.hdfs_location, fn.checksum, + &local_location)); + TJavaUdfExecutorCtorParams ctor_params; + ctor_params.__set_fn(fn); + ctor_params.__set_location(local_location); + ctor_params.__set_input_offsets_ptrs((int64_t)input_offsets_ptrs.get()); + ctor_params.__set_input_buffer_ptrs((int64_t)input_values_buffer_ptr.get()); + ctor_params.__set_input_nulls_ptrs((int64_t)input_nulls_buffer_ptr.get()); + ctor_params.__set_output_buffer_ptr((int64_t)output_value_buffer.get()); + + ctor_params.__set_output_null_ptr((int64_t)output_null_value.get()); + ctor_params.__set_output_offsets_ptr((int64_t)output_offsets_ptr.get()); + ctor_params.__set_output_intermediate_state_ptr( + (int64_t)output_intermediate_state_ptr.get()); + + jbyteArray ctor_params_bytes; + + // Pushed frame will be popped when jni_frame goes out-of-scope. + RETURN_IF_ERROR(jni_frame.push(env)); + RETURN_IF_ERROR(SerializeThriftMsg(env, &ctor_params, &ctor_params_bytes)); + executor_obj = env->NewObject(executor_cl, executor_ctor_id, ctor_params_bytes); + } + RETURN_ERROR_IF_EXC(env); + RETURN_IF_ERROR(JniUtil::LocalToGlobalRef(env, executor_obj, &executor_obj)); + return Status::OK(); + } + + Status add(const IColumn** columns, size_t row_num_start, size_t row_num_end, + const DataTypes& argument_types) { + JNIEnv* env = nullptr; + RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf add function"); + for (int arg_idx = 0; arg_idx < argument_size; ++arg_idx) { + auto data_col = columns[arg_idx]; + if (auto* nullable = check_and_get_column<const ColumnNullable>(*columns[arg_idx])) { + data_col = nullable->get_nested_column_ptr(); + auto null_col = check_and_get_column<ColumnVector<UInt8>>( + nullable->get_null_map_column_ptr()); + input_nulls_buffer_ptr.get()[arg_idx] = + reinterpret_cast<int64_t>(null_col->get_data().data()); + } else { + input_nulls_buffer_ptr.get()[arg_idx] = -1; + } + if (data_col->is_column_string()) { + const ColumnString* str_col = check_and_get_column<ColumnString>(data_col); + input_values_buffer_ptr.get()[arg_idx] = + reinterpret_cast<int64_t>(str_col->get_chars().data()); + input_offsets_ptrs.get()[arg_idx] = + reinterpret_cast<int64_t>(str_col->get_offsets().data()); + } else if (data_col->is_numeric() || data_col->is_column_decimal()) { + input_values_buffer_ptr.get()[arg_idx] = + reinterpret_cast<int64_t>(data_col->get_raw_data().data); + } else { + return Status::InvalidArgument( + strings::Substitute("Java UDAF doesn't support type is $0 now !", + argument_types[arg_idx]->get_name())); + } + } + env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_add_id, row_num_start, + row_num_end); + return JniUtil::GetJniExceptionMsg(env); + } + + Status merge(const AggregateJavaUdafData& rhs) { + JNIEnv* env = nullptr; + RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf merge function"); + serialize_data = rhs.serialize_data; + long len = serialize_data.length(); + jbyteArray arr = env->NewByteArray(len); + env->SetByteArrayRegion(arr, 0, len, reinterpret_cast<jbyte*>(serialize_data.data())); + env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_merge_id, arr); + return JniUtil::GetJniExceptionMsg(env); + } + + Status write(BufferWritable& buf) { + JNIEnv* env = nullptr; + RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf write function"); + // TODO: Here get a byte[] from FE serialize, and then allocate the same length bytes to + // save it in BE, Because i'm not sure there is a way to use the byte[] not allocate again. + jbyteArray arr = (jbyteArray)(env->CallNonvirtualObjectMethod(executor_obj, executor_cl, + executor_serialize_id)); + int len = env->GetArrayLength(arr); + serialize_data.resize(len); + env->GetByteArrayRegion(arr, 0, len, reinterpret_cast<jbyte*>(serialize_data.data())); + write_binary(serialize_data, buf); + return JniUtil::GetJniExceptionMsg(env); + } + + void read(BufferReadable& buf) { read_binary(serialize_data, buf); } + + Status get(IColumn& to, const DataTypePtr& result_type) const { + to.insert_default(); + JNIEnv* env = nullptr; + RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf get value function"); + if (result_type->is_nullable()) { + auto& nullable = assert_cast<ColumnNullable&>(to); + *output_null_value = + reinterpret_cast<int64_t>(nullable.get_null_map_column().get_raw_data().data); + auto& data_col = nullable.get_nested_column(); + +#ifndef EVALUATE_JAVA_UDAF +#define EVALUATE_JAVA_UDAF \ + if (data_col.is_column_string()) { \ + const ColumnString* str_col = check_and_get_column<ColumnString>(data_col); \ + ColumnString::Chars& chars = const_cast<ColumnString::Chars&>(str_col->get_chars()); \ + ColumnString::Offsets& offsets = \ + const_cast<ColumnString::Offsets&>(str_col->get_offsets()); \ + int increase_buffer_size = 0; \ + *output_value_buffer = reinterpret_cast<int64_t>(chars.data()); \ + *output_offsets_ptr = reinterpret_cast<int64_t>(offsets.data()); \ + *output_intermediate_state_ptr = chars.size(); \ + jboolean res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, \ + executor_result_id, to.size() - 1); \ + while (res != JNI_TRUE) { \ + int32_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \ + increase_buffer_size++; \ + chars.reserve(chars.size() + buffer_size); \ + chars.resize(chars.size() + buffer_size); \ + *output_intermediate_state_ptr = chars.size(); \ + res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, executor_result_id, \ + to.size() - 1); \ + } \ + } else if (data_col.is_numeric() || data_col.is_column_decimal()) { \ + *output_value_buffer = reinterpret_cast<int64_t>(data_col.get_raw_data().data); \ + env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, executor_result_id, \ + to.size() - 1); \ + } else { \ + return Status::InvalidArgument(strings::Substitute( \ + "Java UDAF doesn't support return type is $0 now !", result_type->get_name())); \ + } +#endif + EVALUATE_JAVA_UDAF; + } else { + *output_null_value = -1; + *output_value_buffer = reinterpret_cast<int64_t>(to.get_raw_data().data); + auto& data_col = to; + EVALUATE_JAVA_UDAF; + env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, executor_result_id, + to.size() - 1); + } + return JniUtil::GetJniExceptionMsg(env); + } + +private: + Status register_func_id(JNIEnv* env) { + auto register_id = [&](const char* func_name, const char* func_sign, jmethodID& func_id) { + func_id = env->GetMethodID(executor_cl, func_name, func_sign); + Status s = JniUtil::GetJniExceptionMsg(env); + if (!s.ok()) { + return Status::InternalError( + strings::Substitute("Java-Udaf register_func_id meet error and error is $0", + s.get_error_msg())); + } + return s; + }; + + RETURN_IF_ERROR(register_id("<init>", UDAF_EXECUTOR_CTOR_SIGNATURE, executor_ctor_id)); + RETURN_IF_ERROR(register_id("add", UDAF_EXECUTOR_ADD_SIGNATURE, executor_add_id)); + RETURN_IF_ERROR(register_id("close", UDAF_EXECUTOR_CLOSE_SIGNATURE, executor_close_id)); + RETURN_IF_ERROR(register_id("merge", UDAF_EXECUTOR_MERGE_SIGNATURE, executor_merge_id)); + RETURN_IF_ERROR( + register_id("serialize", UDAF_EXECUTOR_SERIALIZE_SIGNATURE, executor_serialize_id)); + RETURN_IF_ERROR( + register_id("getValue", UDAF_EXECUTOR_RESULT_SIGNATURE, executor_result_id)); + return Status::OK(); + } + +private: + jclass executor_cl; + jobject executor_obj; + jmethodID executor_ctor_id; + + jmethodID executor_add_id; + jmethodID executor_merge_id; + jmethodID executor_serialize_id; + jmethodID executor_result_id; + jmethodID executor_close_id; + + std::unique_ptr<int64_t[]> input_values_buffer_ptr; + std::unique_ptr<int64_t[]> input_nulls_buffer_ptr; + std::unique_ptr<int64_t[]> input_offsets_ptrs; + std::unique_ptr<int64_t> output_value_buffer; + std::unique_ptr<int64_t> output_null_value; + std::unique_ptr<int64_t> output_offsets_ptr; + std::unique_ptr<int64_t> output_intermediate_state_ptr; + + int argument_size = 0; + std::string serialize_data; +}; + +class AggregateJavaUdaf final + : public IAggregateFunctionDataHelper<AggregateJavaUdafData, AggregateJavaUdaf> { +public: + AggregateJavaUdaf(const TFunction& fn, const DataTypes& argument_types, const Array& parameters, + const DataTypePtr& return_type) + : IAggregateFunctionDataHelper(argument_types, parameters), + _fn(fn), + _return_type(return_type) {} + ~AggregateJavaUdaf() = default; + + static AggregateFunctionPtr create(const TFunction& fn, const DataTypes& argument_types, + const Array& parameters, const DataTypePtr& return_type) { + return std::make_shared<AggregateJavaUdaf>(fn, argument_types, parameters, return_type); + } + + void create(AggregateDataPtr __restrict place) const override { + new (place) Data(argument_types.size()); + Status status = Status::OK(); + RETURN_IF_STATUS_ERROR(status, data(place).init_udaf(_fn)); + } + + String get_name() const override { return _fn.name.function_name; } + + DataTypePtr get_return_type() const override { return _return_type; } + + // TODO: here calling add operator maybe only hava done one row, this performance may be poorly + // so it's possible to maintain a hashtable in FE, the key is place address, value is the object + // then we can calling add_bacth function and calculate the whole batch at once, + // and avoid calling jni multiple times. + void add(AggregateDataPtr __restrict place, const IColumn** columns, size_t row_num, + Arena*) const override { + this->data(place).add(columns, row_num, row_num + 1, argument_types); + } + + // TODO: Here we calling method by jni, And if we get a thrown from FE, + // But can't let user known the error, only return directly and output error to log file. + void add_batch_single_place(size_t batch_size, AggregateDataPtr place, const IColumn** columns, + Arena* arena) const override { + this->data(place).add(columns, 0, batch_size, argument_types); + } + + void reset(AggregateDataPtr place) const override {} + + void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs, + Arena*) const override { + this->data(place).merge(this->data(rhs)); + } + + void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override { + this->data(const_cast<AggregateDataPtr&>(place)).write(buf); + } + + void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf, + Arena*) const override { + this->data(place).read(buf); + } + + void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override { + this->data(place).get(to, _return_type); + } + +private: + TFunction _fn; + DataTypePtr _return_type; +}; + +} // namespace doris::vectorized +#endif \ No newline at end of file diff --git a/be/src/vec/exec/vaggregation_node.cpp b/be/src/vec/exec/vaggregation_node.cpp index c9017321b9..48071da7b8 100644 --- a/be/src/vec/exec/vaggregation_node.cpp +++ b/be/src/vec/exec/vaggregation_node.cpp @@ -272,8 +272,6 @@ Status AggregationNode::prepare(RuntimeState* state) { _agg_data.without_key = reinterpret_cast<AggregateDataPtr>( _mem_pool->allocate(_total_size_of_aggregate_states)); - _create_agg_status(_agg_data.without_key); - if (_is_merge) { _executor.execute = std::bind<Status>(&AggregationNode::_merge_without_key, this, std::placeholders::_1); @@ -345,7 +343,12 @@ Status AggregationNode::open(RuntimeState* state) { // Streaming preaggregations do all processing in GetNext(). if (_is_streaming_preagg) return Status::OK(); - + // move _create_agg_status to open not in during prepare, + // because during prepare and open thread is not the same one, + // this could cause unable to get JVM + if (_probe_expr_ctxs.empty()) { + _create_agg_status(_agg_data.without_key); + } bool eos = false; Block block; while (!eos) { diff --git a/be/src/vec/exprs/vectorized_agg_fn.cpp b/be/src/vec/exprs/vectorized_agg_fn.cpp index 0ee0d7cf16..bf6deae4ad 100644 --- a/be/src/vec/exprs/vectorized_agg_fn.cpp +++ b/be/src/vec/exprs/vectorized_agg_fn.cpp @@ -20,13 +20,13 @@ #include "fmt/format.h" #include "fmt/ranges.h" #include "runtime/descriptors.h" +#include "vec/aggregate_functions/aggregate_function_java_udaf.h" #include "vec/aggregate_functions/aggregate_function_simple_factory.h" #include "vec/columns/column_nullable.h" #include "vec/core/materialize_block.h" #include "vec/data_types/data_type_factory.hpp" #include "vec/data_types/data_type_nullable.h" #include "vec/exprs/vexpr.h" - namespace doris::vectorized { AggFnEvaluator::AggFnEvaluator(const TExprNode& desc) @@ -87,8 +87,16 @@ Status AggFnEvaluator::prepare(RuntimeState* state, const RowDescriptor& desc, M child_expr_name.emplace_back(_input_exprs_ctxs[i]->root()->expr_name()); } - _function = AggregateFunctionSimpleFactory::instance().get( - _fn.name.function_name, argument_types, params, _data_type->is_nullable()); + if (_fn.binary_type == TFunctionBinaryType::JAVA_UDF) { +#ifdef LIBJVM + _function = AggregateJavaUdaf::create(_fn, argument_types, params, _data_type); +#else + return Status::InternalError("Java UDAF is disabled since no libjvm is found!"); +#endif + } else { + _function = AggregateFunctionSimpleFactory::instance().get( + _fn.name.function_name, argument_types, params, _data_type->is_nullable()); + } if (_function == nullptr) { return Status::InternalError( fmt::format("Agg Function {} is not implemented", _fn.name.function_name)); diff --git a/be/src/vec/functions/function_java_udf.cpp b/be/src/vec/functions/function_java_udf.cpp index f22321e69d..ce7d4579d1 100644 --- a/be/src/vec/functions/function_java_udf.cpp +++ b/be/src/vec/functions/function_java_udf.cpp @@ -160,7 +160,7 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, ColumnString::Offsets& offsets = \ const_cast<ColumnString::Offsets&>(str_col->get_offsets()); \ int increase_buffer_size = 0; \ - int32_t buffer_size = JavaFunctionCall::IncreaseReservedBufferSize(increase_buffer_size); \ + int32_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \ chars.reserve(buffer_size); \ chars.resize(buffer_size); \ offsets.reserve(num_rows); \ @@ -173,8 +173,7 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, nullptr); \ while (jni_ctx->output_intermediate_state_ptr->row_idx < num_rows) { \ increase_buffer_size++; \ - int32_t buffer_size = \ - JavaFunctionCall::IncreaseReservedBufferSize(increase_buffer_size); \ + int32_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); \ chars.resize(buffer_size); \ *(jni_ctx->output_value_buffer) = reinterpret_cast<int64_t>(chars.data()); \ jni_ctx->output_intermediate_state_ptr->buffer_size = buffer_size; \ diff --git a/be/src/vec/functions/function_java_udf.h b/be/src/vec/functions/function_java_udf.h index 4c90d2dd4d..db400a2482 100644 --- a/be/src/vec/functions/function_java_udf.h +++ b/be/src/vec/functions/function_java_udf.h @@ -123,12 +123,6 @@ private: static void SetInputNullsBufferElement(JniContext* jni_ctx, int index, uint8_t value); static uint8_t* GetInputValuesBufferAtOffset(JniContext* jni_ctx, int offset); }; - - static const int32_t INITIAL_RESERVED_BUFFER_SIZE = 1024; - // TODO: we need a heuristic strategy to increase buffer size for variable-size output. - static inline int32_t IncreaseReservedBufferSize(int n) { - return INITIAL_RESERVED_BUFFER_SIZE << n; - } }; } // namespace vectorized diff --git a/docs/en/docs/ecosystem/udf/java-user-defined-function.md b/docs/en/docs/ecosystem/udf/java-user-defined-function.md index 9ccc87aa73..aff3c7874f 100644 --- a/docs/en/docs/ecosystem/udf/java-user-defined-function.md +++ b/docs/en/docs/ecosystem/udf/java-user-defined-function.md @@ -36,19 +36,12 @@ Java UDF provides users with a Java interface written in UDF to facilitate the e * Performance: Compared with native UDF, Java UDF will bring additional JNI overhead, but through batch execution, we have minimized the JNI overhead as much as possible. * Vectorized engine: Java UDF is only supported on vectorized engine now. -## Write UDF functions - -This section mainly introduces how to develop a Java UDF. Samples for the Java version are provided under `samples/doris-demo/java-udf-demo/` for your reference, Check it out [here](https://github.com/apache/incubator-doris/tree/master/samples/doris-demo/java-udf-demo) - -To use Java UDF, the main entry of UDF must be the `evaluate` function. This is consistent with other engines such as Hive. In the example of `AddOne`, we have completed the operation of adding an integer as the UDF. - -It is worth mentioning that this example is not only the Java UDF supported by Doris, but also the UDF supported by Hive, that's to say, for users, Hive UDF can be directly migrated to Doris. +### Type correspondence -#### Type correspondence - -|UDF Type|Argument Type| +|Type|UDF Argument Type| |----|---------| -|TinyInt|TinyIntVal| +|Bool|Boolean| +|TinyInt|Byte| |SmallInt|Short| |Int|Integer| |BigInt|Long| @@ -61,9 +54,15 @@ It is worth mentioning that this example is not only the Java UDF supported by D |Varchar|String| |Decimal|BigDecimal| -## Create UDF +## Write UDF functions + +This section mainly introduces how to develop a Java UDF. Samples for the Java version are provided under `samples/doris-demo/java-udf-demo/` for your reference, Check it out [here](https://github.com/apache/incubator-doris/tree/master/samples/doris-demo/java-udf-demo) + +To use Java UDF, the main entry of UDF must be the `evaluate` function. This is consistent with other engines such as Hive. In the example of `AddOne`, we have completed the operation of adding an integer as the UDF. -Currently, UDAF and UDTF are not supported. +It is worth mentioning that this example is not only the Java UDF supported by Doris, but also the UDF supported by Hive, that's to say, for users, Hive UDF can be directly migrated to Doris. + +## Create UDF ```sql CREATE FUNCTION @@ -87,6 +86,83 @@ CREATE FUNCTION java_udf_add_one(int) RETURNS int PROPERTIES ( ); ``` +## Create UDAF +<br/> +When using Java code to write UDAF, there are some functions that must be implemented (mark required) and an inner class State, which will be explained with a specific example below. +The following SimpleDemo will implement a simple function similar to sum, the input parameter is INT, and the output parameter is INT + +```JAVA +package org.apache.doris.udf; + +public class SimpleDemo { + //Need an inner class to store data + /*required*/ + public static class State { + /*some variables if you need */ + public int sum = 0; + } + + /*required*/ + public State create() { + /* here could do some init work if needed */ + return new State(); + } + + /*required*/ + public void destroy(State state) { + /* here could do some destroy work if needed */ + } + + /*required*/ + //first argument is State, then other types your input + public void add(State state, Integer val) { + /* here doing update work when input data*/ + if (val != null) { + state.sum += val; + } + } + + /*required*/ + public void serialize(State state, DataOutputStream out) { + /* serialize some data into buffer */ + out.writeInt(state.sum); + } + + /*required*/ + public void deserialize(State state, DataInputStream in) { + /* deserialize get data from buffer before you put */ + int val = in.readInt(); + state.sum = val; + } + + /*required*/ + public void merge(State state, State rhs) { + /* merge data from state */ + state.sum += rhs.sum; + } + + /*required*/ + //return Type you defined + public Integer getValue(State state) { + /* return finally result */ + return state.sum; + } +} + +``` + +```sql +CREATE AGGREGATE FUNCTION simple_sum(INT) RETURNS INT PROPERTIES ( + "file"="file:///pathTo/java-udaf.jar", + "symbol"="org.apache.doris.udf.SimpleDemo", + "type"="JAVA_UDF" +); +``` + +Currently, UDTF are not supported. + +<br/> + ## Use UDF Users must have the `SELECT` permission of the corresponding database to use UDF/UDAF. diff --git a/docs/zh-CN/docs/ecosystem/udf/java-user-defined-function.md b/docs/zh-CN/docs/ecosystem/udf/java-user-defined-function.md index 8306e84217..8cb870cc0b 100644 --- a/docs/zh-CN/docs/ecosystem/udf/java-user-defined-function.md +++ b/docs/zh-CN/docs/ecosystem/udf/java-user-defined-function.md @@ -36,18 +36,12 @@ Java UDF 为用户提供UDF编写的Java接口,以方便用户使用Java语言 * 性能:相比于 Native UDF,Java UDF会带来额外的JNI开销,不过通过批式执行的方式,我们已经尽可能的将JNI开销降到最低。 * 向量化引擎:Java UDF当前只支持向量化引擎。 -## 编写 UDF 函数 - -本小节主要介绍如何开发一个 Java UDF。在 `samples/doris-demo/java-udf-demo/` 下提供了示例,可供参考,查看点击[这里](https://github.com/apache/incubator-doris/tree/master/samples/doris-demo/java-udf-demo) - -使用Java代码编写UDF,UDF的主入口必须为 `evaluate` 函数。这一点与Hive等其他引擎保持一致。在本示例中,我们编写了 `AddOne` UDF来完成对整型输入进行加一的操作。 -值得一提的是,本例不只是Doris支持的Java UDF,同时还是Hive支持的UDF,也就是说,对于用户来讲,Hive UDF是可以直接迁移至Doris的。 - -#### 类型对应关系 +### 类型对应关系 -|UDF Type|Argument Type| +|Type|UDF Argument Type| |----|---------| -|TinyInt|TinyIntVal| +|Bool|Boolean| +|TinyInt|Byte| |SmallInt|Short| |Int|Integer| |BigInt|Long| @@ -60,10 +54,14 @@ Java UDF 为用户提供UDF编写的Java接口,以方便用户使用Java语言 |Varchar|String| |Decimal|BigDecimal| +## 编写 UDF 函数 -## 创建 UDF +本小节主要介绍如何开发一个 Java UDF。在 `samples/doris-demo/java-udf-demo/` 下提供了示例,可供参考,查看点击[这里](https://github.com/apache/incubator-doris/tree/master/samples/doris-demo/java-udf-demo) -目前暂不支持 UDAF 和 UDTF +使用Java代码编写UDF,UDF的主入口必须为 `evaluate` 函数。这一点与Hive等其他引擎保持一致。在本示例中,我们编写了 `AddOne` UDF来完成对整型输入进行加一的操作。 +值得一提的是,本例不只是Doris支持的Java UDF,同时还是Hive支持的UDF,也就是说,对于用户来讲,Hive UDF是可以直接迁移至Doris的。 + +## 创建 UDF ```sql CREATE FUNCTION @@ -87,6 +85,83 @@ CREATE FUNCTION java_udf_add_one(int) RETURNS int PROPERTIES ( ); ``` +## 编写 UDAF 函数 +<br/> + +在使用Java代码编写UDAF时,有一些必须实现的函数(标记required)和一个内部类State,下面将以一个具体的实例来说明 +下面的SimpleDemo将实现一个类似的sum的简单函数,输入参数INT,输出参数是INT +```JAVA +package org.apache.doris.udf; + +public class SimpleDemo { + //Need an inner class to store data + /*required*/ + public static class State { + /*some variables if you need */ + public int sum = 0; + } + + /*required*/ + public State create() { + /* here could do some init work if needed */ + return new State(); + } + + /*required*/ + public void destroy(State state) { + /* here could do some destroy work if needed */ + } + + /*required*/ + //first argument is State, then other types your input + public void add(State state, Integer val) { + /* here doing update work when input data*/ + if (val != null) { + state.sum += val; + } + } + + /*required*/ + public void serialize(State state, DataOutputStream out) { + /* serialize some data into buffer */ + out.writeInt(state.sum); + } + + /*required*/ + public void deserialize(State state, DataInputStream in) { + /* deserialize get data from buffer before you put */ + int val = in.readInt(); + state.sum = val; + } + + /*required*/ + public void merge(State state, State rhs) { + /* merge data from state */ + state.sum += rhs.sum; + } + + /*required*/ + //return Type you defined + public Integer getValue(State state) { + /* return finally result */ + return state.sum; + } +} + +``` + +```sql +CREATE AGGREGATE FUNCTION simple_sum(int) RETURNS int PROPERTIES ( + "file"="file:///pathTo/java-udaf.jar", + "symbol"="org.apache.doris.udf.SimpleDemo", + "type"="JAVA_UDF" +); +``` + +目前还暂不支持UDTF + +<br/> + ## 使用 UDF 用户使用 UDF 必须拥有对应数据库的 `SELECT` 权限。 diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java index 89f5603fd7..5cbbc5102a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java @@ -64,6 +64,7 @@ import java.net.URL; import java.net.URLClassLoader; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; +import java.util.HashMap; import java.time.LocalDate; import java.time.LocalDateTime; import java.util.List; @@ -72,7 +73,6 @@ import java.util.Set; // create a user define function public class CreateFunctionStmt extends DdlStmt { - private final static Logger LOG = LogManager.getLogger(CreateFunctionStmt.class); @Deprecated public static final String OBJECT_FILE_KEY = "object_file"; public static final String FILE_KEY = "file"; @@ -89,6 +89,14 @@ public class CreateFunctionStmt extends DdlStmt { public static final String REMOVE_KEY = "remove_fn"; public static final String BINARY_TYPE = "type"; public static final String EVAL_METHOD_KEY = "evaluate"; + public static final String CREATE_METHOD_NAME = "create"; + public static final String DESTROY_METHOD_NAME = "destroy"; + public static final String ADD_METHOD_NAME = "add"; + public static final String SERIALIZE_METHOD_NAME = "serialize"; + public static final String MERGE_METHOD_NAME = "merge"; + public static final String GETVALUE_METHOD_NAME = "getValue"; + public static final String STATE_CLASS_NAME = "State"; + private static final Logger LOG = LogManager.getLogger(CreateFunctionStmt.class); private final FunctionName functionName; private final boolean isAggregate; @@ -243,21 +251,22 @@ public class CreateFunctionStmt extends DdlStmt { builder.name(functionName).argsType(argsDef.getArgTypes()).retType(returnType.getType()). hasVarArgs(argsDef.isVariadic()).intermediateType(intermediateType.getType()).location(URI.create(userFile)); String initFnSymbol = properties.get(INIT_KEY); - if (initFnSymbol == null) { + if (initFnSymbol == null && !(binaryType == TFunctionBinaryType.JAVA_UDF)) { throw new AnalysisException("No 'init_fn' in properties"); } String updateFnSymbol = properties.get(UPDATE_KEY); - if (updateFnSymbol == null) { + if (updateFnSymbol == null && !(binaryType == TFunctionBinaryType.JAVA_UDF)) { throw new AnalysisException("No 'update_fn' in properties"); } String mergeFnSymbol = properties.get(MERGE_KEY); - if (mergeFnSymbol == null) { + if (mergeFnSymbol == null && !(binaryType == TFunctionBinaryType.JAVA_UDF)) { throw new AnalysisException("No 'merge_fn' in properties"); } String serializeFnSymbol = properties.get(SERIALIZE_KEY); String finalizeFnSymbol = properties.get(FINALIZE_KEY); String getValueFnSymbol = properties.get(GET_VALUE_KEY); String removeFnSymbol = properties.get(REMOVE_KEY); + String symbol = properties.get(SYMBOL_KEY); if (binaryType == TFunctionBinaryType.RPC && !userFile.contains("://")) { checkRPCUdf(initFnSymbol); checkRPCUdf(updateFnSymbol); @@ -274,12 +283,19 @@ public class CreateFunctionStmt extends DdlStmt { if (removeFnSymbol != null) { checkRPCUdf(removeFnSymbol); } + } else if (binaryType == TFunctionBinaryType.JAVA_UDF) { + if (Strings.isNullOrEmpty(symbol)) { + throw new AnalysisException("No 'symbol' in properties of java-udaf"); + } + analyzeJavaUdaf(symbol); } - function = builder.initFnSymbol(initFnSymbol) - .updateFnSymbol(updateFnSymbol).mergeFnSymbol(mergeFnSymbol) + function = builder.initFnSymbol(initFnSymbol).updateFnSymbol(updateFnSymbol).mergeFnSymbol(mergeFnSymbol) .serializeFnSymbol(serializeFnSymbol).finalizeFnSymbol(finalizeFnSymbol) - .getValueFnSymbol(getValueFnSymbol).removeFnSymbol(removeFnSymbol) - .build(); + .getValueFnSymbol(getValueFnSymbol).removeFnSymbol(removeFnSymbol).symbolName(symbol).build(); + + URI location = URI.create(userFile); + function.setLocation(location); + function.setBinaryType(binaryType); function.setChecksum(checksum); } @@ -308,6 +324,138 @@ public class CreateFunctionStmt extends DdlStmt { function.setChecksum(checksum); } + private void analyzeJavaUdaf(String clazz) throws AnalysisException { + HashMap<String, Method> allMethods = new HashMap<>(); + + try { + URL[] urls = {new URL("jar:" + userFile + "!/")}; + URLClassLoader cl = URLClassLoader.newInstance(urls); + Class udfClass = cl.loadClass(clazz); + String udfClassName = udfClass.getCanonicalName(); + String stateClassName = udfClassName + "$" + STATE_CLASS_NAME; + Class stateClass = cl.loadClass(stateClassName); + + for (Method m : udfClass.getMethods()) { + if (!m.getDeclaringClass().equals(udfClass)) { + continue; + } + String name = m.getName(); + if (allMethods.containsKey(name)) { + throw new AnalysisException( + String.format("UDF class '%s' has multiple methods with name '%s' ", udfClassName, name)); + } + allMethods.put(name, m); + } + + if (allMethods.get(CREATE_METHOD_NAME) == null) { + throw new AnalysisException( + String.format("No method '%s' in class '%s'!", CREATE_METHOD_NAME, udfClassName)); + } else { + checkMethodNonStaticAndPublic(CREATE_METHOD_NAME, allMethods.get(CREATE_METHOD_NAME), udfClassName); + checkArgumentCount(allMethods.get(CREATE_METHOD_NAME), 0, udfClassName); + checkReturnJavaType(udfClassName, allMethods.get(CREATE_METHOD_NAME), stateClass); + } + + if (allMethods.get(DESTROY_METHOD_NAME) == null) { + throw new AnalysisException( + String.format("No method '%s' in class '%s'!", DESTROY_METHOD_NAME, udfClassName)); + } else { + checkMethodNonStaticAndPublic(DESTROY_METHOD_NAME, allMethods.get(DESTROY_METHOD_NAME), udfClassName); + checkArgumentCount(allMethods.get(DESTROY_METHOD_NAME), 1, udfClassName); + checkReturnJavaType(udfClassName, allMethods.get(DESTROY_METHOD_NAME), void.class); + } + + if (allMethods.get(ADD_METHOD_NAME) == null) { + throw new AnalysisException( + String.format("No method '%s' in class '%s'!", ADD_METHOD_NAME, udfClassName)); + } else { + checkMethodNonStaticAndPublic(ADD_METHOD_NAME, allMethods.get(ADD_METHOD_NAME), udfClassName); + checkArgumentCount(allMethods.get(ADD_METHOD_NAME), argsDef.getArgTypes().length + 1, udfClassName); + checkReturnJavaType(udfClassName, allMethods.get(ADD_METHOD_NAME), void.class); + for (int i = 0; i < argsDef.getArgTypes().length; i++) { + Parameter p = allMethods.get(ADD_METHOD_NAME).getParameters()[i + 1]; + checkUdfType(udfClass, allMethods.get(ADD_METHOD_NAME), argsDef.getArgTypes()[i], p.getType(), + p.getName()); + } + } + + if (allMethods.get(SERIALIZE_METHOD_NAME) == null) { + throw new AnalysisException( + String.format("No method '%s' in class '%s'!", SERIALIZE_METHOD_NAME, udfClassName)); + } else { + checkMethodNonStaticAndPublic(SERIALIZE_METHOD_NAME, allMethods.get(SERIALIZE_METHOD_NAME), + udfClassName); + checkArgumentCount(allMethods.get(SERIALIZE_METHOD_NAME), 2, udfClassName); + checkReturnJavaType(udfClassName, allMethods.get(SERIALIZE_METHOD_NAME), void.class); + } + + if (allMethods.get(MERGE_METHOD_NAME) == null) { + throw new AnalysisException( + String.format("No method '%s' in class '%s'!", MERGE_METHOD_NAME, udfClassName)); + } else { + checkMethodNonStaticAndPublic(MERGE_METHOD_NAME, allMethods.get(MERGE_METHOD_NAME), udfClassName); + checkArgumentCount(allMethods.get(MERGE_METHOD_NAME), 2, udfClassName); + checkReturnJavaType(udfClassName, allMethods.get(MERGE_METHOD_NAME), void.class); + } + + if (allMethods.get(GETVALUE_METHOD_NAME) == null) { + throw new AnalysisException( + String.format("No method '%s' in class '%s'!", GETVALUE_METHOD_NAME, udfClassName)); + } else { + checkMethodNonStaticAndPublic(GETVALUE_METHOD_NAME, allMethods.get(GETVALUE_METHOD_NAME), udfClassName); + checkArgumentCount(allMethods.get(GETVALUE_METHOD_NAME), 1, udfClassName); + checkReturnUdfType(udfClass, allMethods.get(GETVALUE_METHOD_NAME), returnType.getType()); + } + + if (!Modifier.isPublic(stateClass.getModifiers()) || !Modifier.isStatic(stateClass.getModifiers())) { + throw new AnalysisException( + String.format("UDAF '%s' should have one public & static 'State' class to Construction data ", + udfClassName)); + } + } catch (MalformedURLException e) { + throw new AnalysisException("Failed to load file: " + userFile); + } catch (ClassNotFoundException e) { + throw new AnalysisException("Class [" + clazz + "] or inner class [State] not found in file :" + userFile); + } + } + + private void checkMethodNonStaticAndPublic(String methoName, Method method, String udfClassName) + throws AnalysisException { + if (Modifier.isStatic(method.getModifiers())) { + throw new AnalysisException( + String.format("Method '%s' in class '%s' should be non-static", methoName, udfClassName)); + } + if (!Modifier.isPublic(method.getModifiers())) { + throw new AnalysisException( + String.format("Method '%s' in class '%s' should be public", methoName, udfClassName)); + } + } + + private void checkArgumentCount(Method method, int argumentCount, String udfClassName) throws AnalysisException { + if (method.getParameters().length != argumentCount) { + throw new AnalysisException( + String.format("The number of parameters for method '%s' in class '%s' should be %d", + method.getName(), udfClassName, argumentCount)); + } + } + + private void checkReturnJavaType(String udfClassName, Method method, Class expType) throws AnalysisException { + checkJavaType(udfClassName, method, expType, method.getReturnType(), "return"); + } + + private void checkJavaType(String udfClassName, Method method, Class expType, Class ptype, String pname) + throws AnalysisException { + if (!expType.equals(ptype)) { + throw new AnalysisException( + String.format("UDF class '%s' method '%s' parameter %s[%s] expect type %s", udfClassName, + method.getName(), pname, ptype.getCanonicalName(), expType.getCanonicalName())); + } + } + + private void checkReturnUdfType(Class clazz, Method method, Type expType) throws AnalysisException { + checkUdfType(clazz, method, expType, method.getReturnType(), "return"); + } + private void analyzeJavaUdf(String clazz) throws AnalysisException { try { URL[] urls = {new URL("jar:" + userFile + "!/")}; diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java index 8e1c0631ac..98d0488614 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/AggregateFunction.java @@ -90,6 +90,9 @@ public class AggregateFunction extends Function { // empty input in BE). private boolean returnsNonNullOnEmpty; + //use for java-udaf to point the class of user define + private String symbolName; + // only used for serialization protected AggregateFunction() { } @@ -317,6 +320,7 @@ public class AggregateFunction extends Function { String mergeFnSymbol; String removeFnSymbol; String getValueFnSymbol; + String symbolName; private AggregateFunctionBuilder(TFunctionBinaryType binaryType) { this.binaryType = binaryType; @@ -395,12 +399,18 @@ public class AggregateFunction extends Function { return this; } + public AggregateFunctionBuilder symbolName(String symbol) { + this.symbolName = symbol; + return this; + } + public AggregateFunction build() { AggregateFunction fn = new AggregateFunction(name, argTypes, retType, hasVarArgs, intermediateType, location, initFnSymbol, updateFnSymbol, mergeFnSymbol, serializeFnSymbol, finalizeFnSymbol, getValueFnSymbol, removeFnSymbol); fn.setBinaryType(binaryType); + fn.symbolName = symbolName; return fn; } } @@ -433,6 +443,10 @@ public class AggregateFunction extends Function { return finalizeFnSymbol; } + public String getSymbolName() { + return symbolName; + } + public boolean ignoresDistinct() { return ignoresDistinct; } @@ -485,6 +499,10 @@ public class AggregateFunction extends Function { finalizeFnSymbol = fn; } + public void setSymbolName(String fn) { + symbolName = fn; + } + public void setIntermediateType(Type t) { intermediateType = t; } @@ -511,6 +529,9 @@ public class AggregateFunction extends Function { if (getFinalizeFnSymbol() != null) { sb.append(",\n \"FINALIZE_FN\"=\"" + getFinalizeFnSymbol() + "\""); } + if (getSymbolName() != null) { + sb.append(",\n \"SYMBOL\"=\"" + getSymbolName() + "\""); + } sb.append(",\n \"OBJECT_FILE\"=") .append("\"" + (getLocation() == null ? "" : getLocation().toString()) + "\""); @@ -539,6 +560,9 @@ public class AggregateFunction extends Function { if (finalizeFnSymbol != null) { aggFn.setFinalizeFnSymbol(finalizeFnSymbol); } + if (symbolName != null) { + aggFn.setSymbol(symbolName); + } if (intermediateType != null) { aggFn.setIntermediateType(intermediateType.toThrift()); } else { @@ -568,6 +592,7 @@ public class AggregateFunction extends Function { IOUtils.writeOptionString(output, getValueFnSymbol); IOUtils.writeOptionString(output, removeFnSymbol); IOUtils.writeOptionString(output, finalizeFnSymbol); + IOUtils.writeOptionString(output, symbolName); output.writeBoolean(ignoresDistinct); output.writeBoolean(isAnalyticFn); @@ -588,6 +613,8 @@ public class AggregateFunction extends Function { getValueFnSymbol = IOUtils.readOptionStringOrNull(input); removeFnSymbol = IOUtils.readOptionStringOrNull(input); finalizeFnSymbol = IOUtils.readOptionStringOrNull(input); + symbolName = IOUtils.readOptionStringOrNull(input); + ignoresDistinct = input.readBoolean(); isAnalyticFn = input.readBoolean(); isAggregateFn = input.readBoolean(); @@ -612,6 +639,9 @@ public class AggregateFunction extends Function { if (removeFnSymbol != null) { properties.put(CreateFunctionStmt.REMOVE_KEY, removeFnSymbol); } + if (symbolName != null) { + properties.put(CreateFunctionStmt.SYMBOL_KEY, symbolName); + } return new Gson().toJson(properties); } } diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java b/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java new file mode 100644 index 0000000000..ba59cee308 --- /dev/null +++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java @@ -0,0 +1,528 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.udf; + +import org.apache.doris.catalog.Type; +import org.apache.doris.common.Pair; +import org.apache.doris.thrift.TJavaUdfExecutorCtorParams; +import org.apache.doris.udf.UdfExecutor.JavaUdfDataType; + +import com.google.common.base.Joiner; +import com.google.common.collect.Lists; +import org.apache.log4j.Logger; +import org.apache.thrift.TDeserializer; +import org.apache.thrift.TException; +import org.apache.thrift.protocol.TBinaryProtocol; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.lang.reflect.Constructor; +import java.lang.reflect.Method; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.net.MalformedURLException; +import java.net.URLClassLoader; +import java.nio.charset.StandardCharsets; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; + +/** + * udaf executor. + */ +public class UdafExecutor { + public static final String UDAF_CREATE_FUNCTION = "create"; + public static final String UDAF_DESTORY_FUNCTION = "destroy"; + public static final String UDAF_ADD_FUNCTION = "add"; + public static final String UDAF_SERIALIZE_FUNCTION = "serialize"; + public static final String UDAF_DESERIALIZE_FUNCTION = "deserialize"; + public static final String UDAF_MERGE_FUNCTION = "merge"; + public static final String UDAF_RESULT_FUNCTION = "getValue"; + private static final Logger LOG = Logger.getLogger(UdfExecutor.class); + private static final TBinaryProtocol.Factory PROTOCOL_FACTORY = new TBinaryProtocol.Factory(); + private final long inputBufferPtrs; + private final long inputNullsPtrs; + private final long inputOffsetsPtrs; + private final long outputBufferPtr; + private final long outputNullPtr; + private final long outputOffsetsPtr; + private final long outputIntermediateStatePtr; + private Object udaf; + private HashMap<String, Method> allMethods; + private URLClassLoader classLoader; + private JavaUdfDataType[] argTypes; + private JavaUdfDataType retType; + private Object stateObj; + + /** + * Constructor to create an object. + */ + public UdafExecutor(byte[] thriftParams) throws Exception { + TJavaUdfExecutorCtorParams request = new TJavaUdfExecutorCtorParams(); + TDeserializer deserializer = new TDeserializer(PROTOCOL_FACTORY); + try { + deserializer.deserialize(request, thriftParams); + } catch (TException e) { + throw new InternalException(e.getMessage()); + } + Type[] parameterTypes = new Type[request.fn.arg_types.size()]; + for (int i = 0; i < request.fn.arg_types.size(); ++i) { + parameterTypes[i] = Type.fromThrift(request.fn.arg_types.get(i)); + } + inputBufferPtrs = request.input_buffer_ptrs; + inputNullsPtrs = request.input_nulls_ptrs; + inputOffsetsPtrs = request.input_offsets_ptrs; + + outputBufferPtr = request.output_buffer_ptr; + outputNullPtr = request.output_null_ptr; + outputOffsetsPtr = request.output_offsets_ptr; + outputIntermediateStatePtr = request.output_intermediate_state_ptr; + allMethods = new HashMap<>(); + String className = request.fn.aggregate_fn.symbol; + String jarFile = request.location; + Type funcRetType = UdfUtils.fromThrift(request.fn.ret_type, 0).first; + init(jarFile, className, funcRetType, parameterTypes); + stateObj = create(); + } + + /** + * close and invoke destroy function. + */ + public void close() { + if (classLoader != null) { + try { + destroy(); + classLoader.close(); + } catch (Exception e) { + // Log and ignore. + LOG.debug("Error closing the URLClassloader.", e); + } + } + // We are now un-usable (because the class loader has been + // closed), so null out allMethods and classLoader. + allMethods = null; + classLoader = null; + } + + @Override + protected void finalize() throws Throwable { + close(); + super.finalize(); + } + + /** + * invoke add function, add row in loop [rowStart, rowEnd). + */ + public void add(long rowStart, long rowEnd) throws UdfRuntimeException { + try { + Object[] inputArgs = new Object[argTypes.length + 1]; + inputArgs[0] = stateObj; + for (long row = rowStart; row < rowEnd; ++row) { + Object[] inputObjects = allocateInputObjects(row); + for (int i = 0; i < argTypes.length; ++i) { + if (UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) == -1 + || UdfUtils.UNSAFE.getByte(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) + row) + == 0) { + inputArgs[i + 1] = inputObjects[i]; + } else { + inputArgs[i + 1] = null; + } + } + allMethods.get(UDAF_ADD_FUNCTION).invoke(udaf, inputArgs); + } + } catch (Exception e) { + throw new UdfRuntimeException("UDAF failed to add: ", e); + } + } + + /** + * invoke user create function to get obj. + */ + public Object create() throws UdfRuntimeException { + try { + return allMethods.get(UDAF_CREATE_FUNCTION).invoke(udaf, null); + } catch (Exception e) { + throw new UdfRuntimeException("UDAF failed to create: ", e); + } + } + + /** + * invoke destroy before colse. + */ + public void destroy() throws UdfRuntimeException { + try { + allMethods.get(UDAF_DESTORY_FUNCTION).invoke(udaf, stateObj); + } catch (Exception e) { + throw new UdfRuntimeException("UDAF failed to destroy: ", e); + } + } + + /** + * invoke serialize function and return byte[] to backends. + */ + public byte[] serialize() throws UdfRuntimeException { + try { + Object[] args = new Object[2]; + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + args[0] = stateObj; + args[1] = new DataOutputStream(baos); + allMethods.get(UDAF_SERIALIZE_FUNCTION).invoke(udaf, args); + return baos.toByteArray(); + } catch (Exception e) { + throw new UdfRuntimeException("UDAF failed to serialize: ", e); + } + } + + /** + * invoke merge function and it's have done deserialze. + * here call deserialize first, and call merge. + */ + public void merge(byte[] data) throws UdfRuntimeException { + try { + Object[] args = new Object[2]; + ByteArrayInputStream bins = new ByteArrayInputStream(data); + args[0] = create(); + args[1] = new DataInputStream(bins); + allMethods.get(UDAF_DESERIALIZE_FUNCTION).invoke(udaf, args); + args[1] = args[0]; + args[0] = stateObj; + allMethods.get(UDAF_MERGE_FUNCTION).invoke(udaf, args); + } catch (Exception e) { + throw new UdfRuntimeException("UDAF failed to merge: ", e); + } + } + + /** + * invoke getValue to return finally result. + */ + public boolean getValue(long row) throws UdfRuntimeException { + try { + return storeUdfResult(allMethods.get(UDAF_RESULT_FUNCTION).invoke(udaf, stateObj), row); + } catch (Exception e) { + throw new UdfRuntimeException("UDAF failed to result", e); + } + } + + private boolean storeUdfResult(Object obj, long row) throws UdfRuntimeException { + if (obj == null) { + //if result is null, because we have insert default before, so return true directly when row == 0 + //others because we hava resize the buffer, so maybe be insert value is not correct + if (row != 0) { + long offset = Integer.toUnsignedLong( + UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * row)); + UdfUtils.UNSAFE.putChar(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + offset - 1, + UdfUtils.END_OF_STRING); + } + return true; + } + if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) != -1) { + UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputNullPtr) + row, (byte) 0); + } + switch (retType) { + case BOOLEAN: { + boolean val = (boolean) obj; + UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), + val ? (byte) 1 : 0); + return true; + } + case TINYINT: { + UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), + (byte) obj); + return true; + } + case SMALLINT: { + UdfUtils.UNSAFE.putShort(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), + (short) obj); + return true; + } + case INT: { + UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), + (int) obj); + return true; + } + case BIGINT: { + UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), + (long) obj); + return true; + } + case FLOAT: { + UdfUtils.UNSAFE.putFloat(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), + (float) obj); + return true; + } + case DOUBLE: { + UdfUtils.UNSAFE.putDouble(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), + (double) obj); + return true; + } + case DATE: { + LocalDate date = (LocalDate) obj; + long time = UdfUtils.convertDateTimeToLong(date.getYear(), date.getMonthValue(), date.getDayOfMonth(), + 0, 0, 0, true); + UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); + return true; + } + case DATETIME: { + LocalDateTime date = (LocalDateTime) obj; + long time = UdfUtils.convertDateTimeToLong(date.getYear(), date.getMonthValue(), date.getDayOfMonth(), + date.getHour(), date.getMinute(), date.getSecond(), false); + UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); + return true; + } + case LARGEINT: { + BigInteger data = (BigInteger) obj; + byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray()); + + //here value is 16 bytes, so if result data greater than the maximum of 16 bytes + //it will return a wrong num to backend; + byte[] value = new byte[16]; + //check data is negative + if (data.signum() == -1) { + Arrays.fill(value, (byte) -1); + } + for (int index = 0; index < Math.min(bytes.length, value.length); ++index) { + value[index] = bytes[index]; + } + + UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, + UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), value.length); + return true; + } + case DECIMALV2: { + BigInteger data = ((BigDecimal) obj).unscaledValue(); + byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray()); + //TODO: here is maybe overflow also, and may find a better way to handle + byte[] value = new byte[16]; + if (data.signum() == -1) { + Arrays.fill(value, (byte) -1); + } + + for (int index = 0; index < Math.min(bytes.length, value.length); ++index) { + value[index] = bytes[index]; + } + + UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, + UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), value.length); + return true; + } + case CHAR: + case VARCHAR: + case STRING: + long bufferSize = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr); + byte[] bytes = ((String) obj).getBytes(StandardCharsets.UTF_8); + + long offset = Integer.toUnsignedLong( + UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * row)); + if (offset + bytes.length > bufferSize) { + return false; + } + offset += bytes.length; + UdfUtils.UNSAFE.putChar(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + offset - 1, + UdfUtils.END_OF_STRING); + UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * row, + Integer.parseUnsignedInt(String.valueOf(offset))); + UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null, + UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + offset - bytes.length - 1, bytes.length); + return true; + default: + throw new UdfRuntimeException("Unsupported return type: " + retType); + } + } + + private Object[] allocateInputObjects(long row) throws UdfRuntimeException { + Object[] inputObjects = new Object[argTypes.length]; + + for (int i = 0; i < argTypes.length; ++i) { + switch (argTypes[i]) { + case BOOLEAN: + inputObjects[i] = UdfUtils.UNSAFE.getBoolean(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row); + break; + case TINYINT: + inputObjects[i] = UdfUtils.UNSAFE.getByte(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row); + break; + case SMALLINT: + inputObjects[i] = UdfUtils.UNSAFE.getShort(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 2L * row); + break; + case INT: + inputObjects[i] = UdfUtils.UNSAFE.getInt(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 4L * row); + break; + case BIGINT: + inputObjects[i] = UdfUtils.UNSAFE.getLong(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row); + break; + case FLOAT: + inputObjects[i] = UdfUtils.UNSAFE.getFloat(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 4L * row); + break; + case DOUBLE: + inputObjects[i] = UdfUtils.UNSAFE.getDouble(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row); + break; + case DATE: { + long data = UdfUtils.UNSAFE.getLong(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row); + inputObjects[i] = UdfUtils.convertToDate(data); + break; + } + case DATETIME: { + long data = UdfUtils.UNSAFE.getLong(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row); + inputObjects[i] = UdfUtils.convertToDateTime(data); + break; + } + case LARGEINT: { + long base = UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + + 16L * row; + byte[] bytes = new byte[16]; + UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, 16); + + inputObjects[i] = new BigInteger(UdfUtils.convertByteOrder(bytes)); + break; + } + case DECIMALV2: { + long base = UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + + 16L * row; + byte[] bytes = new byte[16]; + UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, 16); + + BigInteger value = new BigInteger(UdfUtils.convertByteOrder(bytes)); + inputObjects[i] = new BigDecimal(value, 9); + break; + } + case CHAR: + case VARCHAR: + case STRING: + long offset = Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + + 4L * row)); + long numBytes = row == 0 ? offset - 1 : offset - Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, + UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * (row + - 1))) - 1; + long base = row == 0 ? UdfUtils.UNSAFE.getLong(null, + UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + : UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + offset + - numBytes - 1; + byte[] bytes = new byte[(int) numBytes]; + UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, numBytes); + inputObjects[i] = new String(bytes, StandardCharsets.UTF_8); + break; + default: + throw new UdfRuntimeException("Unsupported argument type: " + argTypes[i]); + } + } + return inputObjects; + } + + private void init(String jarPath, String udfPath, Type funcRetType, Type... parameterTypes) + throws UdfRuntimeException { + ArrayList<String> signatures = Lists.newArrayList(); + try { + ClassLoader loader; + if (jarPath != null) { + ClassLoader parent = getClass().getClassLoader(); + classLoader = UdfUtils.getClassLoader(jarPath, parent); + loader = classLoader; + } else { + loader = ClassLoader.getSystemClassLoader(); + } + Class<?> c = Class.forName(udfPath, true, loader); + Constructor<?> ctor = c.getConstructor(); + udaf = ctor.newInstance(); + Method[] methods = c.getDeclaredMethods(); + int idx = 0; + for (idx = 0; idx < methods.length; ++idx) { + signatures.add(methods[idx].toGenericString()); + switch (methods[idx].getName()) { + case UDAF_DESTORY_FUNCTION: + case UDAF_CREATE_FUNCTION: + case UDAF_MERGE_FUNCTION: + case UDAF_SERIALIZE_FUNCTION: + case UDAF_DESERIALIZE_FUNCTION: { + allMethods.put(methods[idx].getName(), methods[idx]); + break; + } + case UDAF_RESULT_FUNCTION: { + allMethods.put(methods[idx].getName(), methods[idx]); + Pair<Boolean, JavaUdfDataType> returnType = UdfUtils.setReturnType(funcRetType, + methods[idx].getReturnType()); + if (!returnType.first) { + LOG.debug("result function set return parameterTypes has error"); + } else { + retType = returnType.second; + } + break; + } + case UDAF_ADD_FUNCTION: { + allMethods.put(methods[idx].getName(), methods[idx]); + + Class<?>[] methodTypes = methods[idx].getParameterTypes(); + if (methodTypes.length != parameterTypes.length + 1) { + LOG.debug("add function parameterTypes length not equal " + methodTypes.length + " " + + parameterTypes.length + " " + methods[idx].getName()); + } + if (!(parameterTypes.length == 0)) { + Pair<Boolean, JavaUdfDataType[]> inputType = UdfUtils.setArgTypes(parameterTypes, + methodTypes, true); + if (!inputType.first) { + LOG.debug("add function set arg parameterTypes has error"); + } else { + argTypes = inputType.second; + } + } else { + // Special case where the UDF doesn't take any input args + argTypes = new JavaUdfDataType[0]; + } + break; + } + default: + break; + } + } + if (idx == methods.length) { + return; + } + StringBuilder sb = new StringBuilder(); + sb.append("Unable to find evaluate function with the correct signature: ").append(udfPath + ".evaluate(") + .append(Joiner.on(", ").join(parameterTypes)).append(")\n").append("UDF contains: \n ") + .append(Joiner.on("\n ").join(signatures)); + throw new UdfRuntimeException(sb.toString()); + + } catch (MalformedURLException e) { + throw new UdfRuntimeException("Unable to load jar.", e); + } catch (SecurityException e) { + throw new UdfRuntimeException("Unable to load function.", e); + } catch (ClassNotFoundException e) { + throw new UdfRuntimeException("Unable to find class.", e); + } catch (NoSuchMethodException e) { + throw new UdfRuntimeException("Unable to find constructor with no arguments.", e); + } catch (IllegalArgumentException e) { + throw new UdfRuntimeException("Unable to call UDAF constructor with no arguments.", e); + } catch (Exception e) { + throw new UdfRuntimeException("Unable to call create UDAF instance.", e); + } + } +} diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java index 1d2560102a..6ab54339ef 100644 --- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java +++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java @@ -18,25 +18,23 @@ package org.apache.doris.udf; import org.apache.doris.catalog.Type; +import org.apache.doris.common.Pair; import org.apache.doris.thrift.TJavaUdfExecutorCtorParams; import org.apache.doris.thrift.TPrimitiveType; import com.google.common.base.Joiner; -import com.google.common.base.Preconditions; import com.google.common.collect.Lists; import org.apache.log4j.Logger; import org.apache.thrift.TDeserializer; import org.apache.thrift.TException; import org.apache.thrift.protocol.TBinaryProtocol; -import java.io.File; import java.io.IOException; import java.lang.reflect.Constructor; import java.lang.reflect.Method; import java.math.BigDecimal; import java.math.BigInteger; import java.net.MalformedURLException; -import java.net.URL; import java.net.URLClassLoader; import java.nio.charset.StandardCharsets; import java.time.LocalDate; @@ -362,25 +360,21 @@ public class UdfExecutor { case DATE: { LocalDate date = (LocalDate) obj; long time = - convertDateTimeToLong(date.getYear(), date.getMonthValue(), date.getDayOfMonth(), 0, 0, 0, - true); - UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), - time); + UdfUtils.convertDateTimeToLong(date.getYear(), date.getMonthValue(), date.getDayOfMonth(), 0, 0, + 0, true); + UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); return true; } case DATETIME: { LocalDateTime date = (LocalDateTime) obj; - long time = - convertDateTimeToLong(date.getYear(), date.getMonthValue(), date.getDayOfMonth(), - date.getHour(), - date.getMinute(), date.getSecond(), false); - UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), - time); + long time = UdfUtils.convertDateTimeToLong(date.getYear(), date.getMonthValue(), date.getDayOfMonth(), + date.getHour(), date.getMinute(), date.getSecond(), false); + UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); return true; } case LARGEINT: { BigInteger data = (BigInteger) obj; - byte[] bytes = convertByteOrder(data.toByteArray()); + byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray()); //here value is 16 bytes, so if result data greater than the maximum of 16 bytes //it will return a wrong num to backend; @@ -399,7 +393,7 @@ public class UdfExecutor { } case DECIMALV2: { BigInteger data = ((BigDecimal) obj).unscaledValue(); - byte[] bytes = convertByteOrder(data.toByteArray()); + byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray()); //TODO: here is maybe overflow also, and may find a better way to handle byte[] value = new byte[16]; if (data.signum() == -1) { @@ -474,13 +468,13 @@ public class UdfExecutor { case DATE: { long data = UdfUtils.UNSAFE.getLong(null, UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row); - inputObjects[i] = convertToDate(data); + inputObjects[i] = UdfUtils.convertToDate(data); break; } case DATETIME: { long data = UdfUtils.UNSAFE.getLong(null, UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + 8L * row); - inputObjects[i] = convertToDateTime(data); + inputObjects[i] = UdfUtils.convertToDateTime(data); break; } case LARGEINT: { @@ -489,7 +483,7 @@ public class UdfExecutor { byte[] bytes = new byte[16]; UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, 16); - inputObjects[i] = new BigInteger(convertByteOrder(bytes)); + inputObjects[i] = new BigInteger(UdfUtils.convertByteOrder(bytes)); break; } case DECIMALV2: { @@ -498,7 +492,7 @@ public class UdfExecutor { byte[] bytes = new byte[16]; UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, 16); - BigInteger value = new BigInteger(convertByteOrder(bytes)); + BigInteger value = new BigInteger(UdfUtils.convertByteOrder(bytes)); inputObjects[i] = new BigDecimal(value, 9); break; } @@ -525,58 +519,16 @@ public class UdfExecutor { } } - private URLClassLoader getClassLoader(String jarPath) throws MalformedURLException { - URL url = new File(jarPath).toURI().toURL(); - return URLClassLoader.newInstance(new URL[] {url}, getClass().getClassLoader()); - } - - /** - * Sets the return type of a Java UDF. Returns true if the return type is compatible - * with the return type from the function definition. Throws an UdfRuntimeException - * if the return type is not supported. - */ - private boolean setReturnType(Type retType, Class<?> udfReturnType) - throws InternalException { - if (!JavaUdfDataType.isSupported(retType)) { - throw new InternalException("Unsupported return type: " + retType.toSql()); - } - JavaUdfDataType javaType = JavaUdfDataType.getType(udfReturnType); - // Check if the evaluate method return type is compatible with the return type from - // the function definition. This happens when both of them map to the same primitive - // type. - if (retType.getPrimitiveType().toThrift() != javaType.getPrimitiveType()) { - return false; - } - this.retType = javaType; - return true; - } - - /** - * Sets the argument types of a Java UDF. Returns true if the argument types specified - * in the UDF are compatible with the argument types of the evaluate() function loaded - * from the associated JAR file. - */ - private boolean setArgTypes(Type[] parameterTypes, Class<?>[] udfArgTypes) { - Preconditions.checkNotNull(argTypes); - for (int i = 0; i < udfArgTypes.length; ++i) { - argTypes[i] = JavaUdfDataType.getType(udfArgTypes[i]); - if (argTypes[i].getPrimitiveType() - != parameterTypes[i].getPrimitiveType().toThrift()) { - return false; - } - } - return true; - } - - private void init(String jarPath, String udfPath, - Type retType, Type... parameterTypes) throws UdfRuntimeException { + private void init(String jarPath, String udfPath, Type funcRetType, Type... parameterTypes) + throws UdfRuntimeException { ArrayList<String> signatures = Lists.newArrayList(); try { LOG.debug("Loading UDF '" + udfPath + "' from " + jarPath); ClassLoader loader; if (jarPath != null) { // Save for cleanup. - classLoader = getClassLoader(jarPath); + ClassLoader parent = getClass().getClassLoader(); + classLoader = UdfUtils.getClassLoader(jarPath, parent); loader = classLoader; } else { loader = ClassLoader.getSystemClassLoader(); @@ -584,7 +536,6 @@ public class UdfExecutor { Class<?> c = Class.forName(udfPath, true, loader); Constructor<?> ctor = c.getConstructor(); udf = ctor.newInstance(); - argTypes = new JavaUdfDataType[parameterTypes.length]; Method[] methods = c.getMethods(); for (Method m : methods) { // By convention, the udf must contain the function "evaluate" @@ -599,19 +550,30 @@ public class UdfExecutor { continue; } method = m; + Pair<Boolean, JavaUdfDataType> returnType; if (methodTypes.length == 0 && parameterTypes.length == 0) { // Special case where the UDF doesn't take any input args - if (!setReturnType(retType, m.getReturnType())) { + returnType = UdfUtils.setReturnType(funcRetType, m.getReturnType()); + if (!returnType.first) { continue; + } else { + retType = returnType.second; } + argTypes = new JavaUdfDataType[0]; LOG.debug("Loaded UDF '" + udfPath + "' from " + jarPath); return; } - if (!setReturnType(retType, m.getReturnType())) { + returnType = UdfUtils.setReturnType(funcRetType, m.getReturnType()); + if (!returnType.first) { continue; + } else { + retType = returnType.second; } - if (!setArgTypes(parameterTypes, methodTypes)) { + Pair<Boolean, JavaUdfDataType[]> inputType = UdfUtils.setArgTypes(parameterTypes, methodTypes, false); + if (!inputType.first) { continue; + } else { + argTypes = inputType.second; } LOG.debug("Loaded UDF '" + udfPath + "' from " + jarPath); return; @@ -641,65 +603,4 @@ public class UdfExecutor { throw new UdfRuntimeException("Unable to call create UDF instance.", e); } } - - // input is a 64bit num from backend, and then get year, month, day, hour, minus, second by the order of bits - // return a new LocalDateTime data to evaluate method; - private LocalDateTime convertToDateTime(long date) { - int year = (int) (date >> 48); - int yearMonth = (int) (date >> 40); - int yearMonthDay = (int) (date >> 32); - - int month = (yearMonth & 0XFF); - int day = (yearMonthDay & 0XFF); - - int hourMinuteSecond = (int) (date % (1 << 31)); - int minuteTypeNeg = (hourMinuteSecond % (1 << 16)); - - int hour = (hourMinuteSecond >> 24); - int minute = ((hourMinuteSecond >> 16) & 0XFF); - int second = (minuteTypeNeg >> 4); - //here don't need those bits are type = ((minus_type_neg >> 1) & 0x7); - - LocalDateTime value = LocalDateTime.of(year, month, day, hour, minute, second); - return value; - } - - private LocalDate convertToDate(long date) { - int year = (int) (date >> 48); - int yearMonth = (int) (date >> 40); - int yearMonthDay = (int) (date >> 32); - - int month = (yearMonth & 0XFF); - int day = (yearMonthDay & 0XFF); - LocalDate value = LocalDate.of(year, month, day); - return value; - } - - //input is the second, minute, hours, day , month and year respectively - //and then combining all num to a 64bit value return to backend; - private long convertDateTimeToLong(int year, int month, int day, int hour, int minute, int second, boolean isDate) { - long time = 0; - time = time + year; - time = (time << 8) + month; - time = (time << 8) + day; - time = (time << 8) + hour; - time = (time << 8) + minute; - time = (time << 12) + second; - int type = isDate ? 2 : 3; - time = (time << 3) + type; - //this bit is int neg = 0; - time = (time << 1); - return time; - } - - // Change the order of the bytes, Because JVM is Big-Endian , x86 is Little-Endian - private byte[] convertByteOrder(byte[] bytes) { - int length = bytes.length; - for (int i = 0; i < length / 2; ++i) { - byte temp = bytes[i]; - bytes[i] = bytes[length - 1 - i]; - bytes[length - 1 - i] = temp; - } - return bytes; - } } diff --git a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfUtils.java b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfUtils.java index da2d0c9324..616416ef3a 100644 --- a/fe/java-udf/src/main/java/org/apache/doris/udf/UdfUtils.java +++ b/fe/java-udf/src/main/java/org/apache/doris/udf/UdfUtils.java @@ -25,13 +25,20 @@ import org.apache.doris.thrift.TPrimitiveType; import org.apache.doris.thrift.TScalarType; import org.apache.doris.thrift.TTypeDesc; import org.apache.doris.thrift.TTypeNode; +import org.apache.doris.udf.UdfExecutor.JavaUdfDataType; import com.google.common.base.Preconditions; import sun.misc.Unsafe; +import java.io.File; import java.lang.reflect.Field; +import java.net.MalformedURLException; +import java.net.URL; +import java.net.URLClassLoader; import java.security.AccessController; import java.security.PrivilegedAction; +import java.time.LocalDate; +import java.time.LocalDateTime; public class UdfUtils { public static final Unsafe UNSAFE; @@ -112,4 +119,117 @@ public class UdfUtils { } } + + public static URLClassLoader getClassLoader(String jarPath, ClassLoader parent) throws MalformedURLException { + URL url = new File(jarPath).toURI().toURL(); + return URLClassLoader.newInstance(new URL[] {url}, parent); + } + + /** + * Sets the return type of a Java UDF. Returns true if the return type is compatible + * with the return type from the function definition. Throws an UdfRuntimeException + * if the return type is not supported. + */ + public static Pair<Boolean, JavaUdfDataType> setReturnType(Type retType, Class<?> udfReturnType) + throws InternalException { + if (!JavaUdfDataType.isSupported(retType)) { + throw new InternalException("Unsupported return type: " + retType.toSql()); + } + JavaUdfDataType javaType = JavaUdfDataType.getType(udfReturnType); + // Check if the evaluate method return type is compatible with the return type from + // the function definition. This happens when both of them map to the same primitive + // type. + if (retType.getPrimitiveType().toThrift() != javaType.getPrimitiveType()) { + return new Pair<Boolean, JavaUdfDataType>(false, javaType); + } + return new Pair<Boolean, JavaUdfDataType>(true, javaType); + } + + /** + * Sets the argument types of a Java UDF or UDAF. Returns true if the argument types specified + * in the UDF are compatible with the argument types of the evaluate() function loaded + * from the associated JAR file. + */ + public static Pair<Boolean, JavaUdfDataType[]> setArgTypes(Type[] parameterTypes, Class<?>[] udfArgTypes, + boolean isUdaf) { + JavaUdfDataType[] inputArgTypes = new JavaUdfDataType[parameterTypes.length]; + int firstPos = isUdaf ? 1 : 0; + for (int i = 0; i < parameterTypes.length; ++i) { + inputArgTypes[i] = JavaUdfDataType.getType(udfArgTypes[i + firstPos]); + if (inputArgTypes[i].getPrimitiveType() != parameterTypes[i].getPrimitiveType().toThrift()) { + return new Pair<Boolean, JavaUdfDataType[]>(false, inputArgTypes); + } + } + return new Pair<Boolean, JavaUdfDataType[]>(true, inputArgTypes); + } + + /** + * input is a 64bit num from backend, and then get year, month, day, hour, minus, second by the order of bits. + */ + public static LocalDateTime convertToDateTime(long date) { + int year = (int) (date >> 48); + int yearMonth = (int) (date >> 40); + int yearMonthDay = (int) (date >> 32); + + int month = (yearMonth & 0XFF); + int day = (yearMonthDay & 0XFF); + + int hourMinuteSecond = (int) (date % (1 << 31)); + int minuteTypeNeg = (hourMinuteSecond % (1 << 16)); + + int hour = (hourMinuteSecond >> 24); + int minute = ((hourMinuteSecond >> 16) & 0XFF); + int second = (minuteTypeNeg >> 4); + //here don't need those bits are type = ((minus_type_neg >> 1) & 0x7); + + LocalDateTime value = LocalDateTime.of(year, month, day, hour, minute, second); + return value; + } + + /** + * a 64bit num convertToDate. + */ + public static LocalDate convertToDate(long date) { + int year = (int) (date >> 48); + int yearMonth = (int) (date >> 40); + int yearMonthDay = (int) (date >> 32); + + int month = (yearMonth & 0XFF); + int day = (yearMonthDay & 0XFF); + LocalDate value = LocalDate.of(year, month, day); + return value; + } + + /** + * input is the second, minute, hours, day , month and year respectively. + * and then combining all num to a 64bit value return to backend; + */ + public static long convertDateTimeToLong(int year, int month, int day, int hour, int minute, int second, + boolean isDate) { + long time = 0; + time = time + year; + time = (time << 8) + month; + time = (time << 8) + day; + time = (time << 8) + hour; + time = (time << 8) + minute; + time = (time << 12) + second; + int type = isDate ? 2 : 3; + time = (time << 3) + type; + //this bit is int neg = 0; + time = (time << 1); + return time; + } + + /** + * Change the order of the bytes, Because JVM is Big-Endian , x86 is Little-Endian. + */ + public static byte[] convertByteOrder(byte[] bytes) { + int length = bytes.length; + for (int i = 0; i < length / 2; ++i) { + byte temp = bytes[i]; + bytes[i] = bytes[length - 1 - i]; + bytes[length - 1 - i] = temp; + } + return bytes; + } } diff --git a/gensrc/thrift/Types.thrift b/gensrc/thrift/Types.thrift index 13af910597..1d6bc8c683 100644 --- a/gensrc/thrift/Types.thrift +++ b/gensrc/thrift/Types.thrift @@ -309,6 +309,8 @@ struct TAggregateFunction { 8: optional string get_value_fn_symbol 9: optional string remove_fn_symbol 10: optional bool is_analytic_only_fn = false + // used for java-udaf to point user defined class + 11: optional string symbol } // Represents a function in the Catalog. --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org