gemini-code-assist[bot] commented on code in PR #198:
URL: https://github.com/apache/tvm-ffi/pull/198#discussion_r2468656197


##########
python/tvm_ffi/cython/dtype.pxi:
##########
@@ -205,24 +205,28 @@ else:
 
 if ml_dtypes is not None:
     MLDTYPES_DTYPE_TO_DTYPE = {
-        numpy.dtype(ml_dtypes.int2): DLDataType(0, 2, 1),
         numpy.dtype(ml_dtypes.int4): DLDataType(0, 4, 1),
-        numpy.dtype(ml_dtypes.uint2): DLDataType(1, 2, 1),
         numpy.dtype(ml_dtypes.uint4): DLDataType(1, 4, 1),
         numpy.dtype(ml_dtypes.bfloat16): DLDataType(4, 16, 1),
-        numpy.dtype(ml_dtypes.float8_e3m4): DLDataType(7, 8, 1),
-        numpy.dtype(ml_dtypes.float8_e4m3): DLDataType(8, 8, 1),
         numpy.dtype(ml_dtypes.float8_e4m3b11fnuz): DLDataType(9, 8, 1),
         numpy.dtype(ml_dtypes.float8_e4m3fn): DLDataType(10, 8, 1),
         numpy.dtype(ml_dtypes.float8_e4m3fnuz): DLDataType(11, 8, 1),
         numpy.dtype(ml_dtypes.float8_e5m2): DLDataType(12, 8, 1),
         numpy.dtype(ml_dtypes.float8_e5m2fnuz): DLDataType(13, 8, 1),
-        numpy.dtype(ml_dtypes.float8_e8m0fnu): DLDataType(14, 8, 1),
-        numpy.dtype(ml_dtypes.float6_e2m3fn): DLDataType(15, 6, 1),
-        numpy.dtype(ml_dtypes.float6_e3m2fn): DLDataType(16, 6, 1),
-        numpy.dtype(ml_dtypes.float4_e2m1fn): DLDataType(17, 4, 1),
     }
 
+    if hasattr(ml_dtypes, "int2"):  # ml_dtypes >= 0.5.0
+        MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.int2] = DLDataType(0, 2, 1)
+        MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.uint2] = DLDataType(1, 2, 1)
+
+        MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.float8_e3m4] = DLDataType(7, 8, 1)
+        MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.float8_e4m3] = DLDataType(8, 8, 1)
+        MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.float8_e8m0fnu] = DLDataType(14, 8, 1)
+        MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.float6_e2m3fn] = DLDataType(15, 6, 1)
+        MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.float6_e3m2fn] = DLDataType(16, 6, 1)
+        MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.float4_e2m1fn] = DLDataType(17, 4, 1)
+
+
 if numpy is not None:
     NUMPY_DTYPE_TO_DTYPE = {

Review Comment:
   ![critical](https://www.gstatic.com/codereviewagent/critical.svg)
   
   If `ml_dtypes` is not None, the code defines `MLDTYPES_DTYPE_TO_DTYPE` as a 
dictionary. Then, conditionally based on `hasattr(ml_dtypes, "int2")`, it adds 
more key-value pairs to the dictionary. However, if `ml_dtypes` is None, 
`MLDTYPES_DTYPE_TO_DTYPE` will not be defined, and when numpy is not None, the 
code will raise a `NameError` because `MLDTYPES_DTYPE_TO_DTYPE` is used in 
`NUMPY_DTYPE_TO_DTYPE = { ... , **MLDTYPES_DTYPE_TO_DTYPE, ... }`.
   
   To fix this, `MLDTYPES_DTYPE_TO_DTYPE` should be initialized as an empty 
dictionary even when `ml_dtypes` is None.



##########
python/tvm_ffi/cython/dtype.pxi:
##########
@@ -205,24 +205,28 @@ else:
 
 if ml_dtypes is not None:
     MLDTYPES_DTYPE_TO_DTYPE = {
-        numpy.dtype(ml_dtypes.int2): DLDataType(0, 2, 1),
         numpy.dtype(ml_dtypes.int4): DLDataType(0, 4, 1),
-        numpy.dtype(ml_dtypes.uint2): DLDataType(1, 2, 1),
         numpy.dtype(ml_dtypes.uint4): DLDataType(1, 4, 1),
         numpy.dtype(ml_dtypes.bfloat16): DLDataType(4, 16, 1),
-        numpy.dtype(ml_dtypes.float8_e3m4): DLDataType(7, 8, 1),
-        numpy.dtype(ml_dtypes.float8_e4m3): DLDataType(8, 8, 1),
         numpy.dtype(ml_dtypes.float8_e4m3b11fnuz): DLDataType(9, 8, 1),
         numpy.dtype(ml_dtypes.float8_e4m3fn): DLDataType(10, 8, 1),

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The initial dictionary definition for `MLDTYPES_DTYPE_TO_DTYPE` only 
includes a subset of the data types. This could lead to incomplete support for 
`ml_dtypes` if other data types are used before the hasattr check. It's better 
to initialize the dictionary with all possible types, and then conditionally 
update it based on the `ml_dtypes` version.
   
   Consider initializing all possible types and then conditionally updating the 
dictionary based on the `ml_dtypes` version.



##########
python/tvm_ffi/cython/dtype.pxi:
##########
@@ -205,24 +205,28 @@ else:
 
 if ml_dtypes is not None:
     MLDTYPES_DTYPE_TO_DTYPE = {
-        numpy.dtype(ml_dtypes.int2): DLDataType(0, 2, 1),
         numpy.dtype(ml_dtypes.int4): DLDataType(0, 4, 1),
-        numpy.dtype(ml_dtypes.uint2): DLDataType(1, 2, 1),
         numpy.dtype(ml_dtypes.uint4): DLDataType(1, 4, 1),
         numpy.dtype(ml_dtypes.bfloat16): DLDataType(4, 16, 1),
-        numpy.dtype(ml_dtypes.float8_e3m4): DLDataType(7, 8, 1),
-        numpy.dtype(ml_dtypes.float8_e4m3): DLDataType(8, 8, 1),
         numpy.dtype(ml_dtypes.float8_e4m3b11fnuz): DLDataType(9, 8, 1),
         numpy.dtype(ml_dtypes.float8_e4m3fn): DLDataType(10, 8, 1),
         numpy.dtype(ml_dtypes.float8_e4m3fnuz): DLDataType(11, 8, 1),
         numpy.dtype(ml_dtypes.float8_e5m2): DLDataType(12, 8, 1),
         numpy.dtype(ml_dtypes.float8_e5m2fnuz): DLDataType(13, 8, 1),
-        numpy.dtype(ml_dtypes.float8_e8m0fnu): DLDataType(14, 8, 1),
-        numpy.dtype(ml_dtypes.float6_e2m3fn): DLDataType(15, 6, 1),
-        numpy.dtype(ml_dtypes.float6_e3m2fn): DLDataType(16, 6, 1),
-        numpy.dtype(ml_dtypes.float4_e2m1fn): DLDataType(17, 4, 1),
     }
 
+    if hasattr(ml_dtypes, "int2"):  # ml_dtypes >= 0.5.0
+        MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.int2] = DLDataType(0, 2, 1)
+        MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.uint2] = DLDataType(1, 2, 1)
+
+        MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.float8_e3m4] = DLDataType(7, 8, 1)
+        MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.float8_e4m3] = DLDataType(8, 8, 1)
+        MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.float8_e8m0fnu] = DLDataType(14, 8, 1)
+        MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.float6_e2m3fn] = DLDataType(15, 6, 1)
+        MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.float6_e3m2fn] = DLDataType(16, 6, 1)
+        MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.float4_e2m1fn] = DLDataType(17, 4, 1)

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
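   The current implementation adds the new types with raw scalar-type keys, e.g. 
`MLDTYPES_DTYPE_TO_DTYPE[ml_dtypes.int2] = DLDataType(0, 2, 1)`. Every entry in 
the initial dictionary literal is keyed by `numpy.dtype(...)` instead. A scalar 
type and its `numpy.dtype` compare equal but generally do not hash equal, so any 
lookup that uses a `numpy.dtype` key will miss these new entries. (A missing 
attribute on `ml_dtypes < 0.5` is not the problem here: it would raise 
`AttributeError` rather than `NameError`, and the `hasattr` guard already skips 
this block when `int2` is absent.) It's better to use 
`MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.int2)] = DLDataType(0, 2, 1)` so 
the keys are consistent with the rest of the dictionary.
   
   Consider wrapping each key in `numpy.dtype(...)`, as below.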
   
   ```
           MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.int2)] = DLDataType(0, 2, 1)
           MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.uint2)] = DLDataType(1, 2, 1)
   
           MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float8_e3m4)] = DLDataType(7, 8, 1)
           MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float8_e4m3)] = DLDataType(8, 8, 1)
           MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float8_e8m0fnu)] = DLDataType(14, 8, 1)
           MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float6_e2m3fn)] = DLDataType(15, 6, 1)
           MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float6_e3m2fn)] = DLDataType(16, 6, 1)
           MLDTYPES_DTYPE_TO_DTYPE[numpy.dtype(ml_dtypes.float4_e2m1fn)] = DLDataType(17, 4, 1)
   ```



##########
python/tvm_ffi/cython/dtype.pxi:
##########
@@ -205,24 +205,28 @@ else:
 
 if ml_dtypes is not None:
     MLDTYPES_DTYPE_TO_DTYPE = {
-        numpy.dtype(ml_dtypes.int2): DLDataType(0, 2, 1),
         numpy.dtype(ml_dtypes.int4): DLDataType(0, 4, 1),
-        numpy.dtype(ml_dtypes.uint2): DLDataType(1, 2, 1),
         numpy.dtype(ml_dtypes.uint4): DLDataType(1, 4, 1),
         numpy.dtype(ml_dtypes.bfloat16): DLDataType(4, 16, 1),
-        numpy.dtype(ml_dtypes.float8_e3m4): DLDataType(7, 8, 1),
-        numpy.dtype(ml_dtypes.float8_e4m3): DLDataType(8, 8, 1),
         numpy.dtype(ml_dtypes.float8_e4m3b11fnuz): DLDataType(9, 8, 1),
         numpy.dtype(ml_dtypes.float8_e4m3fn): DLDataType(10, 8, 1),
         numpy.dtype(ml_dtypes.float8_e4m3fnuz): DLDataType(11, 8, 1),
         numpy.dtype(ml_dtypes.float8_e5m2): DLDataType(12, 8, 1),
         numpy.dtype(ml_dtypes.float8_e5m2fnuz): DLDataType(13, 8, 1),
-        numpy.dtype(ml_dtypes.float8_e8m0fnu): DLDataType(14, 8, 1),
-        numpy.dtype(ml_dtypes.float6_e2m3fn): DLDataType(15, 6, 1),
-        numpy.dtype(ml_dtypes.float6_e3m2fn): DLDataType(16, 6, 1),
-        numpy.dtype(ml_dtypes.float4_e2m1fn): DLDataType(17, 4, 1),
     }
 
+    if hasattr(ml_dtypes, "int2"):  # ml_dtypes >= 0.5.0

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The comment `# ml_dtypes >= 0.5.0` is helpful, but it would be more robust 
to check the actual version number of `ml_dtypes` instead of just checking for 
the existence of the `int2` attribute. This would provide more accurate version 
detection and prevent potential issues if the `int2` attribute is present in 
versions earlier than 0.5.0.
   
   Consider checking the actual version number of `ml_dtypes`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to