tqchen commented on code in PR #286: URL: https://github.com/apache/tvm-ffi/pull/286#discussion_r2590195571
########## include/tvm/ffi/extra/overload.h: ########## @@ -0,0 +1,501 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file tvm/ffi/extra/overload.h + * \brief Registry of reflection metadata, supporting function overloading. + */ +#ifndef TVM_FFI_EXTRA_OVERLOAD_H +#define TVM_FFI_EXTRA_OVERLOAD_H + +#include <tvm/ffi/any.h> +#include <tvm/ffi/c_api.h> +#include <tvm/ffi/container/map.h> +#include <tvm/ffi/container/variant.h> +#include <tvm/ffi/function.h> +#include <tvm/ffi/function_details.h> +#include <tvm/ffi/optional.h> +#include <tvm/ffi/reflection/registry.h> +#include <tvm/ffi/string.h> +#include <tvm/ffi/type_traits.h> + +#include <cstddef> +#include <cstdint> +#include <sstream> +#include <string> +#include <type_traits> +#include <unordered_map> +#include <utility> + +namespace tvm { +namespace ffi { + +namespace details { + +struct OverloadBase { + public: + // Try Call function pointer type, return the fail index + using FnPtr = bool (*)(OverloadBase*, const AnyView*, int32_t, Any*); + + explicit OverloadBase(int32_t num_args, std::optional<std::string> name) + : num_args_(num_args), + name_(name ? std::move(*name) : ""), + name_ptr_(name ? 
&this->name_ : nullptr) {} + + virtual void Register(std::unique_ptr<OverloadBase> overload) = 0; + virtual FnPtr GetTryCallPtr() = 0; + virtual void GetMismatchMessage(std::ostringstream& os, const AnyView* args, + int32_t num_args) = 0; + + virtual ~OverloadBase() = default; + OverloadBase(const OverloadBase&) = delete; + OverloadBase& operator=(const OverloadBase&) = delete; + + public: + static constexpr int32_t kAllMatched = -1; + + // a fast cache for last matched arg index + // on 64-bit platform, this is packed in the same 8 byte with num_args_ + int32_t last_mismatch_index_{kAllMatched}; + + // some constant helper args + const int32_t num_args_; + const std::string name_; + const std::string* const name_ptr_; +}; + +template <typename T> +struct CaptureTupleAux; + +template <typename... Args> +struct CaptureTupleAux<std::tuple<Args...>> { + using type = std::tuple<std::optional<std::decay_t<Args>>...>; +}; + +template <typename Callable> +struct TypedOverload : OverloadBase { + public: + static_assert(std::is_same_v<Callable, std::decay_t<Callable>>, "Callable must be value type"); + + using FuncInfo = details::FunctionInfo<Callable>; + using PackedArgs = typename FuncInfo::ArgType; + using Ret = typename FuncInfo::RetType; + using CaptureTuple = typename CaptureTupleAux<PackedArgs>::type; + using OverloadBase::name_; + using OverloadBase::name_ptr_; + using typename OverloadBase::FnPtr; + + static constexpr auto kNumArgs = FuncInfo::num_args; + static constexpr auto kSeq = std::make_index_sequence<kNumArgs>{}; + + explicit TypedOverload(const Callable& f, std::optional<std::string> name = std::nullopt) + : OverloadBase(kNumArgs, std::move(name)), f_(f) {} + explicit TypedOverload(Callable&& f, std::optional<std::string> name = std::nullopt) + : OverloadBase(kNumArgs, std::move(name)), f_(std::move(f)) {} + + bool TryCall(const AnyView* args, int32_t num_args, Any* rv) { + if (num_args != kNumArgs) return false; + CaptureTuple captures{}; + if 
(!TrySetAux(kSeq, captures, args)) return false; + // now all captures are set + if constexpr (std::is_same_v<Ret, void>) { + CallAux(kSeq, captures); + return true; + } else { + *rv = CallAux(kSeq, captures); + return true; + } + } + + void Register(std::unique_ptr<OverloadBase> overload) override { + TVM_FFI_ICHECK(false) << "This should never be called."; + } + + FnPtr GetTryCallPtr() final { + // lambda without a capture can be converted to function pointer + return [](OverloadBase* base, const AnyView* args, int32_t num_args, Any* rv) -> bool { + return static_cast<TypedOverload<Callable>*>(base)->TryCall(args, num_args, rv); + }; + } + + void GetMismatchMessage(std::ostringstream& os, const AnyView* args, int32_t num_args) final { + FGetFuncSignature f_sig = FuncInfo::Sig; + if (num_args != kNumArgs) { + os << "Mismatched number of arguments when calling: `" << name_ << " " + << (f_sig == nullptr ? "" : (*f_sig)()) << "`. Expected " << kNumArgs << " arguments"; + } else { + GetMismatchMessageAux<0>(os, args, num_args); + } + } + + private: + template <std::size_t I> + void GetMismatchMessageAux(std::ostringstream& os, const AnyView* args, int32_t num_args) { + if constexpr (I < kNumArgs) { + if (this->last_mismatch_index_ == static_cast<int32_t>(I)) { + TVMFFIAny any_data = args[I].CopyToTVMFFIAny(); + FGetFuncSignature f_sig = FuncInfo::Sig; + using Type = std::decay_t<std::tuple_element_t<I, PackedArgs>>; + os << "Mismatched type on argument #" << I << " when calling: `" << name_ << " " + << (f_sig == nullptr ? "" : (*f_sig)()) << "`. Expected `" << Type2Str<Type>::v() + << "` but got `" << TypeTraits<Type>::GetMismatchTypeInfo(&any_data) << '`'; + } else { + GetMismatchMessageAux<I + 1>(os, args, num_args); + } + } + // end of recursion + } + + template <std::size_t... 
I> + Ret CallAux(std::index_sequence<I...>, CaptureTuple& tuple) { + /// NOTE: this works for T, const T, const T&, T&& argument types + return f_(static_cast<std::tuple_element_t<I, PackedArgs>>(std::move(*std::get<I>(tuple)))...); + } + + template <std::size_t... I> + bool TrySetAux(std::index_sequence<I...>, CaptureTuple& tuple, const AnyView* args) { + return (TrySetOne<I>(tuple, args) && ...); + } + + template <std::size_t I> + bool TrySetOne(CaptureTuple& tuple, const AnyView* args) { + using Type = std::decay_t<std::tuple_element_t<I, PackedArgs>>; + auto& capture = std::get<I>(tuple); + if constexpr (std::is_same_v<Type, AnyView>) { + capture = args[I]; + return true; + } else if constexpr (std::is_same_v<Type, Any>) { + capture = Any(args[I]); + return true; + } else { + capture = args[I].template try_cast<Type>(); + if (capture.has_value()) return true; + // slow path: record the last mismatch index + this->last_mismatch_index_ = static_cast<int32_t>(I); + return false; + } + } + + protected: + Callable f_; +}; + +template <typename Callable> +inline auto CreateNewOverload(Callable&& f, std::string name) { + using Type = TypedOverload<std::decay_t<Callable>>; + return std::make_unique<Type>(std::forward<Callable>(f), std::move(name)); +} + +template <typename Callable> +struct OverloadedFunction : TypedOverload<Callable> { + public: + using TypedBase = TypedOverload<Callable>; + using OverloadBase::name_; + using OverloadBase::name_ptr_; + using TypedBase::GetTryCallPtr; + using TypedBase::kNumArgs; + using TypedBase::kSeq; + using TypedBase::TypedBase; // constructors + using typename OverloadBase::FnPtr; + using typename TypedBase::Ret; + + void Register(std::unique_ptr<OverloadBase> overload) final { + const auto fptr = overload->GetTryCallPtr(); + overloads_.emplace_back(std::move(overload), fptr); + } + + void operator()(const AnyView* args, int32_t num_args, Any* rv) { + // fast path: only add a little overhead when no overloads + if 
(overloads_.size() == 0) { + return unpack_call<Ret>(kSeq, name_ptr_, f_, args, num_args, rv); + } + + // this can be inlined by compiler, don't worry + if (this->TryCall(args, num_args, rv)) return; + + // virtual calls cannot be inlined, so we fast check the num_args first + // we also de-virtualize the fptr to reduce one more indirection + for (const auto& [overload, fptr] : overloads_) { + if (overload->num_args_ != num_args) continue; + if (fptr(overload.get(), args, num_args, rv)) return; + } + + this->handle_overload_failure(args, num_args); + } + + private: + void handle_overload_failure(const AnyView* args, int32_t num_args) { Review Comment: style: the codebase uses CamelCase function names per the Google C++ Style Guide, so this should be `HandleOverloadFailure`. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected] --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
