gemini-code-assist[bot] commented on code in PR #131:
URL: https://github.com/apache/tvm-ffi/pull/131#discussion_r2433412955
##########
include/tvm/ffi/container/tensor.h:
##########
@@ -359,66 +359,39 @@ class Tensor : public ObjectRef {
std::forward<ExtraArgs>(extra_args)...));
}
/*!
- * \brief Create a Tensor from a DLPackManagedTensorAllocator
+ * \brief Create a Tensor from the TVMFFIEnvTensorAlloc API
*
- * This function can be used together with TVMFFIEnvSetTensorAllocator
- * in the extra/c_env_api.h to create Tensor from the thread-local
- * environment allocator.
+ * This function can be used together with TVMFFIEnvSetDLPackManagedTensorAllocator
+ * in the extra/c_env_api.h to create a Tensor from the thread-local environment allocator.
+ * We explicitly pass TVMFFIEnvTensorAlloc to maintain explicit dependency on extra/c_env_api.h
*
* \code
*
- * ffi::Tensor tensor = ffi::Tensor::FromDLPackAlloc(
- * TVMFFIEnvGetTensorAllocator(), shape, dtype, device
+ * ffi::Tensor tensor = ffi::Tensor::FromEnvAlloc(
+ * TVMFFIEnvTensorAlloc, shape, dtype, device
* );
+ *
* \endcode
*
- * \param allocator The DLPack allocator.
+ * \param env_alloc TVMFFIEnvTensorAlloc function pointer.
* \param shape The shape of the Tensor.
* \param dtype The data type of the Tensor.
* \param device The device of the Tensor.
* \return The created Tensor.
+ *
+ * \sa TVMFFIEnvTensorAlloc
*/
-  static Tensor FromDLPackAlloc(DLPackManagedTensorAllocator allocator, ffi::ShapeView shape,
-                                DLDataType dtype, DLDevice device) {
- if (allocator == nullptr) {
- TVM_FFI_THROW(RuntimeError)
- << "FromDLPackAlloc: allocator is nullptr, "
- << "likely because TVMFFIEnvSetTensorAllocator has not been called.";
- }
+  static Tensor FromEnvAlloc(int (*env_alloc)(DLTensor*, TVMFFIObjectHandle*), ffi::ShapeView shape,
+                             DLDataType dtype, DLDevice device) {
+ TVMFFIObjectHandle out;
DLTensor prototype;
prototype.device = device;
prototype.dtype = dtype;
prototype.shape = const_cast<int64_t*>(shape.data());
prototype.ndim = static_cast<int>(shape.size());
Review Comment:

It's good practice to initialize all members of the `DLTensor` prototype to
avoid using uninitialized values. While the current implementation of
`TVMFFIEnvTensorAlloc` might not use fields like `strides`, `byte_offset`, or
`data`, it's safer to zero-initialize the struct. You can achieve this
concisely by using value-initialization (`{}`).
```c
DLTensor prototype{};
prototype.device = device;
prototype.dtype = dtype;
prototype.shape = const_cast<int64_t*>(shape.data());
prototype.ndim = static_cast<int>(shape.size());
```
##########
src/ffi/extra/env_context.cc:
##########
@@ -106,16 +107,43 @@ TVMFFIStreamHandle TVMFFIEnvGetStream(int32_t device_type, int32_t device_id) {
TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIEnvGetStream);
}
-int TVMFFIEnvSetTensorAllocator(DLPackManagedTensorAllocator allocator, int write_to_global_context,
-                                DLPackManagedTensorAllocator* opt_out_original_allocator) {
+int TVMFFIEnvSetDLPackManagedTensorAllocator(
+ DLPackManagedTensorAllocator allocator, int write_to_global_context,
+ DLPackManagedTensorAllocator* opt_out_original_allocator) {
TVM_FFI_SAFE_CALL_BEGIN();
tvm::ffi::EnvContext::ThreadLocal()->SetDLPackManagedTensorAllocator(
allocator, write_to_global_context, opt_out_original_allocator);
TVM_FFI_SAFE_CALL_END();
}
-DLPackManagedTensorAllocator TVMFFIEnvGetTensorAllocator() {
+DLPackManagedTensorAllocator TVMFFIEnvGetDLPackManagedTensorAllocator() {
TVM_FFI_LOG_EXCEPTION_CALL_BEGIN();
  return tvm::ffi::EnvContext::ThreadLocal()->GetDLPackManagedTensorAllocator();
- TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIEnvGetTensorAllocator);
+ TVM_FFI_LOG_EXCEPTION_CALL_END(TVMFFIEnvGetDLPackManagedTensorAllocator);
+}
+
+void TVMFFIEnvTensorAllocSetError(void* error_ctx, const char* kind, const char* message) {
+ TVMFFIErrorSetRaisedFromCStr(kind, message);
+}
+
+int TVMFFIEnvTensorAlloc(DLTensor* prototype, TVMFFIObjectHandle* out) {
+ TVM_FFI_SAFE_CALL_BEGIN();
+ DLPackManagedTensorAllocator dlpack_alloc =
+ tvm::ffi::EnvContext::ThreadLocal()->GetDLPackManagedTensorAllocator();
+ DLManagedTensorVersioned* dlpack_tensor = nullptr;
+  int ret = (*dlpack_alloc)(const_cast<DLTensor*>(prototype), &dlpack_tensor, nullptr,
+                            TVMFFIEnvTensorAllocSetError);
+ if (ret != 0) return ret;
Review Comment:

This block has a few issues that should be addressed:
1. **Potential null pointer dereference (Critical):** `dlpack_alloc` can be
`nullptr` if no allocator is set in the environment, which would cause a crash
when it's dereferenced.
2. **Potential null pointer dereference (Critical):** Even if
`dlpack_alloc` returns `0` (success), it might still leave `dlpack_tensor` as
`nullptr`. The code would then crash when accessing `dlpack_tensor->dl_tensor`.
A null check for `dlpack_tensor` after the call is necessary.
3. **Redundant `const_cast` (Medium):** The `prototype` variable is of type
`DLTensor*`, which already matches the expected argument type of
`dlpack_alloc`. The `const_cast` is redundant and can be removed for clarity.
The suggested change addresses all these points for robustness and
readability.
```c
  DLPackManagedTensorAllocator dlpack_alloc =
      tvm::ffi::EnvContext::ThreadLocal()->GetDLPackManagedTensorAllocator();
if (dlpack_alloc == nullptr) {
TVM_FFI_THROW(RuntimeError)
<< "TVMFFIEnvTensorAlloc: allocator is nullptr, likely because "
<< "TVMFFIEnvSetDLPackManagedTensorAllocator has not been called.";
}
DLManagedTensorVersioned* dlpack_tensor = nullptr;
  int ret = (*dlpack_alloc)(prototype, &dlpack_tensor, nullptr,
                            TVMFFIEnvTensorAllocSetError);
if (ret != 0) return ret;
if (dlpack_tensor == nullptr) {
TVM_FFI_THROW(RuntimeError)
        << "TVMFFIEnvTensorAlloc: allocator returned a null tensor on success.";
}
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]