gemini-code-assist[bot] commented on code in PR #286:
URL: https://github.com/apache/tvm-ffi/pull/286#discussion_r2561254626


##########
include/tvm/ffi/extra/overload.h:
##########
@@ -0,0 +1,449 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ffi/reflection/registry.h

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
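   The file path in this comment is incorrect; it appears to be copy-pasted 
from `tvm/ffi/reflection/registry.h` and should be updated to this header's 
own path. The `\brief` line that follows ("Registry of reflection metadata.") 
was copied along with it and should also be updated to describe this file.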
   
   ```c
    * \file tvm/ffi/extra/overload.h
   ```



##########
include/tvm/ffi/extra/overload.h:
##########
@@ -0,0 +1,449 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file tvm/ffi/reflection/registry.h
+ * \brief Registry of reflection metadata.
+ */
+#ifndef TVM_FFI_EXTRA_OVERLOAD_H
+#define TVM_FFI_EXTRA_OVERLOAD_H
+
+#include <tvm/ffi/any.h>
+#include <tvm/ffi/c_api.h>
+#include <tvm/ffi/container/map.h>
+#include <tvm/ffi/container/variant.h>
+#include <tvm/ffi/function.h>
+#include <tvm/ffi/function_details.h>
+#include <tvm/ffi/optional.h>
+#include <tvm/ffi/reflection/registry.h>
+#include <tvm/ffi/string.h>
+#include <tvm/ffi/type_traits.h>
+
+#include <string>
+#include <type_traits>
+#include <unordered_map>
+#include <utility>
+
+namespace tvm {
+namespace ffi {
+
+namespace details {
+
+/*!
+ * \brief Type-erased base class for a single overload candidate.
+ *
+ * Holds the candidate's arity and an optional display name. Concrete subclasses
+ * supply a plain function pointer (GetTryCallPtr) that tries to run the candidate
+ * on a packed argument list and reports whether the arguments matched.
+ */
+struct OverloadBase {
+ public:
+  // "Try call" function pointer: (self, args, num_args, rv). Returns true if this
+  // overload accepted the arguments and ran; false lets the caller try another one.
+  using FnPtr = bool (*)(OverloadBase*, const AnyView*, int32_t, Any*);
+
+  // Moving *name out leaves the optional engaged, so the second `name ?` test still
+  // reports whether a name was supplied. name_ptr_ is nullptr when no name was given.
+  explicit OverloadBase(int32_t num_args, std::optional<std::string> name)
+      : num_args_(num_args),
+        name_(name ? std::move(*name) : ""),
+        name_ptr_(name ? &this->name_ : nullptr) {}
+
+  // Attach another overload candidate (only the dispatching subclass supports this).
+  virtual void Register(std::unique_ptr<OverloadBase> overload) = 0;
+  // Return a plain function pointer so dispatch loops avoid a virtual call per attempt.
+  virtual FnPtr GetTryCallPtr() = 0;
+
+  virtual ~OverloadBase() = default;
+  // Non-copyable (and, with no move members declared, non-movable): name_ptr_
+  // points into this object, so the object must never be relocated.
+  OverloadBase(const OverloadBase&) = delete;
+  OverloadBase& operator=(const OverloadBase&) = delete;
+
+ public:
+  // Helper fields read by the overload dispatcher.
+  const int32_t num_args_;   // number of arguments this overload accepts
+  const std::string name_;   // display name ("" when none was given)
+  const std::string* const name_ptr_ = nullptr;  // &name_ if named, else nullptr
+};
+
+/*!
+ * \brief Maps a std::tuple of parameter types to a tuple of std::optional slots.
+ *
+ * Each parameter type is decayed (references and cv-qualifiers stripped) and wrapped
+ * in std::optional, e.g. std::tuple<const int&, std::string&&> becomes
+ * std::tuple<std::optional<int>, std::optional<std::string>>. Empty optionals let the
+ * whole tuple be value-initialized up front and filled one slot at a time as each
+ * argument converts. Only std::tuple is supported: the primary template is declared
+ * but never defined.
+ */
+template <typename T>
+struct CaptureTupleAux;
+
+template <typename... Args>
+struct CaptureTupleAux<std::tuple<Args...>> {
+  using type = std::tuple<std::optional<std::decay_t<Args>>...>;
+};
+
+/*!
+ * \brief A single overload candidate that wraps a concrete callable.
+ *
+ * TryCall first checks the argument count, then tries to convert each packed
+ * argument to the callable's decayed parameter type. If the count differs or
+ * any conversion fails, it returns false without invoking the callable, so the
+ * dispatcher can move on to the next candidate.
+ *
+ * \tparam Callable Value (already decayed) type of the wrapped function object.
+ */
+template <typename Callable>
+struct TypedOverload : OverloadBase {
+ public:
+  static_assert(std::is_same_v<Callable, std::decay_t<Callable>>, "Callable must be value type");
+
+  using FuncInfo = details::FunctionInfo<Callable>;
+  using PackedArgs = typename FuncInfo::ArgType;
+  using Ret = typename FuncInfo::RetType;
+  // One std::optional slot per parameter, filled as the arguments convert.
+  using CaptureTuple = typename CaptureTupleAux<PackedArgs>::type;
+  using OverloadBase::name_;
+  using OverloadBase::name_ptr_;
+  using typename OverloadBase::FnPtr;
+
+  static constexpr auto kNumArgs = FuncInfo::num_args;
+  static constexpr auto kSeq = std::make_index_sequence<kNumArgs>{};
+
+  // Copy- or move-construct the stored callable; the name is passed on to OverloadBase.
+  explicit TypedOverload(const Callable& f, std::optional<std::string> name = std::nullopt)
+      : OverloadBase(kNumArgs, std::move(name)), f_(f) {}
+  explicit TypedOverload(Callable&& f, std::optional<std::string> name = std::nullopt)
+      : OverloadBase(kNumArgs, std::move(name)), f_(std::move(f)) {}
+
+  /*!
+   * \brief Invoke the wrapped callable if the packed arguments match its signature.
+   * \param args Packed arguments.
+   * \param num_args Number of packed arguments.
+   * \param rv Receives the result; not written when Ret is void.
+   * \return false on an arity mismatch or a failed argument conversion (the callable
+   *         is not invoked); true once the callable has run.
+   */
+  bool TryCall(const AnyView* args, int32_t num_args, Any* rv) {
+    if (num_args != kNumArgs) return false;
+    CaptureTuple captures{};
+    if (!TrySetAux(kSeq, captures, args)) return false;
+    // now all captures are set
+    if constexpr (std::is_same_v<Ret, void>) {
+      CallAux(kSeq, captures);
+      return true;
+    } else {
+      *rv = CallAux(kSeq, captures);
+      return true;
+    }
+  }
+
+  // A leaf overload cannot hold other overloads. Registration is implemented only by
+  // the dispatching subclass, so reaching this is an internal error.
+  void Register(std::unique_ptr<OverloadBase> overload) override {
+    TVM_FFI_ICHECK(false) << "This should never be called.";
+  }
+
+  FnPtr GetTryCallPtr() final {
+    // lambda without a capture can be converted to function pointer
+    // The pointer static_casts its OverloadBase* argument back to TypedOverload<Callable>,
+    // so it must only be passed objects of exactly this type (or a subclass).
+    return [](OverloadBase* base, const AnyView* args, int32_t num_args, Any* rv) -> bool {
+      return static_cast<TypedOverload<Callable>*>(base)->TryCall(args, num_args, rv);
+    };
+  }
+
+ private:
+  // Precondition: every slot in `tuple` is engaged (TrySetAux returned true).
+  // Each value is moved out of its slot and cast to the declared parameter type.
+  template <std::size_t... I>
+  Ret CallAux(std::index_sequence<I...>, CaptureTuple& tuple) {
+    /// NOTE: this works for T, const T, const T&, T&& argument types
+    return f_(static_cast<std::tuple_element_t<I, PackedArgs>>(std::move(*std::get<I>(tuple)))...);
+  }
+
+  // Fill the slots left to right; the && fold short-circuits at the first failure.
+  template <std::size_t... I>
+  bool TrySetAux(std::index_sequence<I...>, CaptureTuple& tuple, const AnyView* args) {
+    return (TrySetOne<I>(tuple, args) && ...);
+  }
+
+  // Convert args[I] into slot I. AnyView and Any parameters accept any argument;
+  // every other type goes through try_cast and fails when that yields no value.
+  template <std::size_t I>
+  bool TrySetOne(CaptureTuple& tuple, const AnyView* args) {
+    using Type = std::decay_t<std::tuple_element_t<I, PackedArgs>>;
+    auto& capture = std::get<I>(tuple);
+    if constexpr (std::is_same_v<Type, AnyView>) {
+      capture = args[I];
+      return true;
+    } else if constexpr (std::is_same_v<Type, Any>) {
+      capture = Any(args[I]);
+      return true;
+    } else {
+      capture = args[I].template try_cast<Type>();
+      return capture.has_value();
+    }
+  }
+
+ protected:
+  // Wrapped callable; protected so a derived dispatcher can call it directly.
+  Callable f_;
+};
+
+/*!
+ * \brief Wrap a callable in a heap-allocated, named TypedOverload.
+ * \param f Callable to wrap; forwarded, so rvalues are moved into the overload and
+ *          lvalues are copied.
+ * \param name Display name. It is always supplied here, so the resulting overload's
+ *             name_ptr_ is non-null.
+ * \return std::unique_ptr owning a TypedOverload<std::decay_t<Callable>>.
+ */
+template <typename Callable>
+inline auto CreateNewOverload(Callable&& f, std::string name) {
+  using Type = TypedOverload<std::decay_t<Callable>>;
+  return std::make_unique<Type>(std::forward<Callable>(f), std::move(name));
+}
+
+template <typename Callable>
+struct OverloadedFunction : TypedOverload<Callable> {
+ public:
+  using TypedBase = TypedOverload<Callable>;
+  using OverloadBase::name_;
+  using OverloadBase::name_ptr_;
+  using TypedBase::GetTryCallPtr;
+  using TypedBase::kNumArgs;
+  using TypedBase::kSeq;
+  using TypedBase::TypedBase;  // constructors
+  using typename OverloadBase::FnPtr;
+  using typename TypedBase::Ret;
+
+  void Register(std::unique_ptr<OverloadBase> overload) final {
+    const auto fptr = overload->GetTryCallPtr();
+    overloads_.emplace_back(std::move(overload), fptr);
+  }
+
+  void operator()(const AnyView* args, int32_t num_args, Any* rv) {
+    // fast path: only add a little overhead when no overloads
+    if (overloads_.size() == 0) {
+      return unpack_call<Ret>(kSeq, name_ptr_, f_, args, num_args, rv);
+    }
+
+    // this can be inlined by compiler, don't worry
+    if (this->TryCall(args, num_args, rv)) return;
+
+    // virtual calls cannot be inlined, so we fast check the num_args first
+    // we also de-virtualize the fptr to reduce one more indirection
+    for (const auto& [overload, fptr] : overloads_) {
+      if (overload->num_args_ != num_args) continue;
+      if (fptr(overload.get(), args, num_args, rv)) return;
+    }
+
+    /// TODO: better error message
+    TVM_FFI_THROW(TypeError) << "No matching overload found when calling: `" 
<< name_ << "` with "
+                             << num_args << " arguments.";

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The current error message for a failed overload resolution is quite generic. 
To improve debuggability, consider enhancing the error message to include more 
context, such as the argument types that were provided and a list of the 
available overloads with their signatures. This would make it much easier for 
users to identify why a call failed to match any overload.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to