gemini-code-assist[bot] commented on code in PR #253:
URL: https://github.com/apache/tvm-ffi/pull/253#discussion_r2511800167


##########
tests/scripts/benchmark_dlpack.py:
##########
@@ -337,17 +337,30 @@ def load_torch_get_current_cuda_stream() -> 
Callable[[int], int]:
     """Create a faster get_current_cuda_stream for torch through cpp 
extension."""
     from torch.utils import cpp_extension  # noqa: PLC0415
 
-    source = """
-    #include <c10/cuda/CUDAStream.h>
-
-    int64_t get_current_cuda_stream(int device_id) {
-        at::cuda::CUDAStream stream = 
at::cuda::getCurrentCUDAStream(device_id);
-        // fast invariant, default stream is always 0
-        if (stream.id() == 0) return 0;
-        // convert to cudaStream_t
-        return reinterpret_cast<int64_t>(static_cast<cudaStream_t>(stream));
-    }
-    """
+    if torch.version.cuda is not None:
+        source = """
+        #include <c10/cuda/CUDAStream.h>
+
+        int64_t get_current_cuda_stream(int device_id) {
+            at::cuda::CUDAStream stream = 
at::cuda::getCurrentCUDAStream(device_id);
+            // fast invariant, default stream is always 0
+            if (stream.id() == 0) return 0;
+            // convert to cudaStream_t
+            return 
reinterpret_cast<int64_t>(static_cast<cudaStream_t>(stream));
+        }
+        """
+    elif torch.version.hip is not None:
+        source = """
+        #include <c10/hip/HIPStream.h>
+
+        int64_t get_current_cuda_stream(int device_id) {
+            at::hip::HIPStream stream = 
at::hip::getCurrentHIPStream(device_id);
+            // fast invariant, default stream is always 0
+            if (stream.id() == 0) return 0;
+            // convert to cudaStream_t

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   This comment appears to be a copy-paste error from the CUDA block. It should 
refer to `hipStream_t`.
   
   ```suggestion
               // convert to hipStream_t
   ```



##########
python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py:
##########
@@ -759,7 +774,10 @@ def main() -> None:  # noqa: PLR0912, PLR0915
 
         if args.build_with_cuda:
             cflags.append("-DBUILD_WITH_CUDA")
-        include_paths.extend(get_torch_include_paths(args.build_with_cuda))
+        if args.build_with_rocm:
+            cflags.extend(torch.utils.cpp_extension.COMMON_HIP_FLAGS)
+            cflags.append("-DBUILD_WITH_ROCM")

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   Since `build_with_cuda` and `build_with_rocm` are mutually exclusive (as 
checked on lines 727-728), using an `elif` here would make the logic clearer 
and explicitly show that only one of these blocks can be executed.
   
   ```suggestion
           if args.build_with_cuda:
               cflags.append("-DBUILD_WITH_CUDA")
           elif args.build_with_rocm:
               cflags.extend(torch.utils.cpp_extension.COMMON_HIP_FLAGS)
               cflags.append("-DBUILD_WITH_ROCM")
   ```



##########
tests/scripts/benchmark_dlpack.py:
##########
@@ -337,17 +337,30 @@ def load_torch_get_current_cuda_stream() -> 
Callable[[int], int]:
     """Create a faster get_current_cuda_stream for torch through cpp 
extension."""
     from torch.utils import cpp_extension  # noqa: PLC0415
 
-    source = """
-    #include <c10/cuda/CUDAStream.h>
-
-    int64_t get_current_cuda_stream(int device_id) {
-        at::cuda::CUDAStream stream = 
at::cuda::getCurrentCUDAStream(device_id);
-        // fast invariant, default stream is always 0
-        if (stream.id() == 0) return 0;
-        // convert to cudaStream_t
-        return reinterpret_cast<int64_t>(static_cast<cudaStream_t>(stream));
-    }
-    """
+    if torch.version.cuda is not None:
+        source = """
+        #include <c10/cuda/CUDAStream.h>
+
+        int64_t get_current_cuda_stream(int device_id) {
+            at::cuda::CUDAStream stream = 
at::cuda::getCurrentCUDAStream(device_id);
+            // fast invariant, default stream is always 0
+            if (stream.id() == 0) return 0;
+            // convert to cudaStream_t
+            return 
reinterpret_cast<int64_t>(static_cast<cudaStream_t>(stream));
+        }
+        """
+    elif torch.version.hip is not None:
+        source = """
+        #include <c10/hip/HIPStream.h>
+
+        int64_t get_current_cuda_stream(int device_id) {
+            at::hip::HIPStream stream = 
at::hip::getCurrentHIPStream(device_id);
+            // fast invariant, default stream is always 0
+            if (stream.id() == 0) return 0;
+            // convert to cudaStream_t
+            return reinterpret_cast<int64_t>(static_cast<hipStream_t>(stream));
+        }
+        """

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   There's significant code duplication between the CUDA and ROCm source 
strings. This can be avoided by using C++ preprocessor macros to abstract the 
backend-specific types and function calls, making the code more maintainable. 
You could define the `source` string once as a template, and then use 
`extra_cflags` to pass the appropriate `-D` flag (`-DTVM_FFI_USE_CUDA` or 
`-DTVM_FFI_USE_ROCM`) to the compiler.



##########
tests/python/test_tensor.py:
##########
@@ -112,3 +112,23 @@ def __tvm_ffi_object__(self) -> tvm_ffi.Tensor:
     fecho = tvm_ffi.get_global_func("testing.echo")
     z = fecho(y)
     assert z.__chandle__() == x.__chandle__()
+
+
+@pytest.mark.skipif(
+    torch is None or torch.version.hip is None, reason="ROCm is not enabled in PyTorch"
+)
+)
+def test_tensor_from_pytorch_rocm() -> None:
+    assert torch is not None
+
+    @tvm_ffi.register_global_func("testing.check_device", override=True)
+    def _check_device(x: tvm_ffi.Tensor) -> bool:

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The type hint for the return value of `_check_device` is `bool`, but the 
function returns `x.device.type`, which is a string representing the device 
type (e.g., 'rocm'). The type hint should be corrected to `str`.
   
   ```suggestion
       def _check_device(x: tvm_ffi.Tensor) -> str:
   ```



##########
addons/torch_c_dlpack_ext/build_aot_wheels.sh:
##########
@@ -83,6 +83,7 @@ function build_libs() {
         mkdir "$tvm_ffi"/lib -p
         python -m tvm_ffi.utils._build_optional_torch_c_dlpack --output-dir "$tvm_ffi"/lib
         python -m tvm_ffi.utils._build_optional_torch_c_dlpack --output-dir "$tvm_ffi"/lib --build-with-cuda
+        python -m tvm_ffi.utils._build_optional_torch_c_dlpack --output-dir "$tvm_ffi"/lib --build-with-rocm

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   These three commands are very similar. To improve maintainability and reduce 
duplication, you could use a loop to iterate over the different build flags.
   
   ```suggestion
           for flag in "" "--build-with-cuda" "--build-with-rocm"; do
               python -m tvm_ffi.utils._build_optional_torch_c_dlpack --output-dir "$tvm_ffi"/lib $flag
           done
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to