gemini-code-assist[bot] commented on code in PR #18453:
URL: https://github.com/apache/tvm/pull/18453#discussion_r2529277596
##########
src/arith/const_fold.h:
##########
@@ -349,53 +349,53 @@ inline ffi::Optional<PrimExpr>
TryConstFold<tir::Max>(PrimExpr a, PrimExpr b) {
template <>
inline ffi::Optional<PrimExpr> TryConstFold<tir::GT>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value);
+ if (pa && pb) return IntImm(DataType::Bool(), pa->value > pb->value);
+ if (fa && fb) return IntImm(DataType::Bool(), fa->value > fb->value);
});
return std::nullopt;
}
template <>
inline ffi::Optional<PrimExpr> TryConstFold<tir::GE>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value);
+ if (pa && pb) return IntImm(DataType::Bool(), pa->value >= pb->value);
+ if (fa && fb) return IntImm(DataType::Bool(), fa->value >= fb->value);
});
return std::nullopt;
}
template <>
inline ffi::Optional<PrimExpr> TryConstFold<tir::LT>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value);
+ if (pa && pb) return IntImm(DataType::Bool(), pa->value < pb->value);
+ if (fa && fb) return IntImm(DataType::Bool(), fa->value < fb->value);
});
return std::nullopt;
}
template <>
inline ffi::Optional<PrimExpr> TryConstFold<tir::LE>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value);
+ if (pa && pb) return IntImm(DataType::Bool(1), pa->value <= pb->value);
+ if (fa && fb) return IntImm(DataType::Bool(1), fa->value <= fb->value);
});
return std::nullopt;
}
template <>
inline ffi::Optional<PrimExpr> TryConstFold<tir::EQ>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value);
+ if (pa && pb) return IntImm(DataType::Bool(1), pa->value == pb->value);
+ if (fa && fb) return IntImm(DataType::Bool(1), fa->value == fb->value);
});
return std::nullopt;
}
template <>
inline ffi::Optional<PrimExpr> TryConstFold<tir::NE>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value != pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value != fb->value);
+ if (pa && pb) return IntImm(DataType::Bool(1), pa->value != pb->value);
+ if (fa && fb) return IntImm(DataType::Bool(1), fa->value != fb->value);
Review Comment:

For consistency, please use `DataType::Bool()` instead of
`DataType::Bool(1)`.
```suggestion
if (pa && pb) return IntImm(DataType::Bool(), pa->value != pb->value);
if (fa && fb) return IntImm(DataType::Bool(), fa->value != fb->value);
```
##########
src/arith/const_fold.h:
##########
@@ -349,53 +349,53 @@ inline ffi::Optional<PrimExpr>
TryConstFold<tir::Max>(PrimExpr a, PrimExpr b) {
template <>
inline ffi::Optional<PrimExpr> TryConstFold<tir::GT>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value);
+ if (pa && pb) return IntImm(DataType::Bool(), pa->value > pb->value);
+ if (fa && fb) return IntImm(DataType::Bool(), fa->value > fb->value);
});
return std::nullopt;
}
template <>
inline ffi::Optional<PrimExpr> TryConstFold<tir::GE>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value);
+ if (pa && pb) return IntImm(DataType::Bool(), pa->value >= pb->value);
+ if (fa && fb) return IntImm(DataType::Bool(), fa->value >= fb->value);
});
return std::nullopt;
}
template <>
inline ffi::Optional<PrimExpr> TryConstFold<tir::LT>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value);
+ if (pa && pb) return IntImm(DataType::Bool(), pa->value < pb->value);
+ if (fa && fb) return IntImm(DataType::Bool(), fa->value < fb->value);
});
return std::nullopt;
}
template <>
inline ffi::Optional<PrimExpr> TryConstFold<tir::LE>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value);
+ if (pa && pb) return IntImm(DataType::Bool(1), pa->value <= pb->value);
+ if (fa && fb) return IntImm(DataType::Bool(1), fa->value <= fb->value);
Review Comment:

For consistency with other comparison operators (`GT`, `GE`, `LT`) and
`Not`, please use `DataType::Bool()` instead of `DataType::Bool(1)`. While they
are functionally equivalent, using `DataType::Bool()` is more concise and
consistent across the file. This also applies to `EQ` and `NE` operators in
this file.
```suggestion
if (pa && pb) return IntImm(DataType::Bool(), pa->value <= pb->value);
if (fa && fb) return IntImm(DataType::Bool(), fa->value <= fb->value);
```
##########
src/arith/const_fold.h:
##########
@@ -349,53 +349,53 @@ inline ffi::Optional<PrimExpr>
TryConstFold<tir::Max>(PrimExpr a, PrimExpr b) {
template <>
inline ffi::Optional<PrimExpr> TryConstFold<tir::GT>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value > pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value > fb->value);
+ if (pa && pb) return IntImm(DataType::Bool(), pa->value > pb->value);
+ if (fa && fb) return IntImm(DataType::Bool(), fa->value > fb->value);
});
return std::nullopt;
}
template <>
inline ffi::Optional<PrimExpr> TryConstFold<tir::GE>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value >= pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value >= fb->value);
+ if (pa && pb) return IntImm(DataType::Bool(), pa->value >= pb->value);
+ if (fa && fb) return IntImm(DataType::Bool(), fa->value >= fb->value);
});
return std::nullopt;
}
template <>
inline ffi::Optional<PrimExpr> TryConstFold<tir::LT>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value < pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value < fb->value);
+ if (pa && pb) return IntImm(DataType::Bool(), pa->value < pb->value);
+ if (fa && fb) return IntImm(DataType::Bool(), fa->value < fb->value);
});
return std::nullopt;
}
template <>
inline ffi::Optional<PrimExpr> TryConstFold<tir::LE>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value <= pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value <= fb->value);
+ if (pa && pb) return IntImm(DataType::Bool(1), pa->value <= pb->value);
+ if (fa && fb) return IntImm(DataType::Bool(1), fa->value <= fb->value);
});
return std::nullopt;
}
template <>
inline ffi::Optional<PrimExpr> TryConstFold<tir::EQ>(PrimExpr a, PrimExpr b) {
TVM_ARITH_CONST_PROPAGATION({
- if (pa && pb) return IntImm(DataType::UInt(1), pa->value == pb->value);
- if (fa && fb) return IntImm(DataType::UInt(1), fa->value == fb->value);
+ if (pa && pb) return IntImm(DataType::Bool(1), pa->value == pb->value);
+ if (fa && fb) return IntImm(DataType::Bool(1), fa->value == fb->value);
Review Comment:

For consistency, please use `DataType::Bool()` instead of
`DataType::Bool(1)`.
```suggestion
if (pa && pb) return IntImm(DataType::Bool(), pa->value == pb->value);
if (fa && fb) return IntImm(DataType::Bool(), fa->value == fb->value);
```
##########
include/tvm/runtime/data_type.h:
##########
@@ -138,7 +139,9 @@ class DataType {
/*! \return whether type is a scalar type. */
bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; }
/*! \return whether type is a scalar type. */
- bool is_bool() const { return code() == DataType::kUInt && bits() == 1; }
+ bool is_bool() const { return code() == DataType::kBool; }
Review Comment:

The Doxygen comment for `is_bool()` appears to be a copy-paste from
`is_scalar()`. It should be updated to reflect that it checks for a boolean
type.
```c
/*! \return whether type is a bool type. */
bool is_bool() const { return code() == DataType::kBool; }
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]