gemini-code-assist[bot] commented on code in PR #18421:
URL: https://github.com/apache/tvm/pull/18421#discussion_r2532632309
##########
src/target/source/codegen_webgpu.cc:
##########
@@ -668,12 +668,19 @@ void CodeGenWebGPU::VisitStmt_(const AllocateNode* op) {
void CodeGenWebGPU::VisitStmt_(const ForNode* op) {
std::string extent = PrintExpr(op->extent);
+ std::string step = op->step.has_value()? PrintExpr(*op->step) : "";
std::string vid = AllocVarID(op->loop_var.get());
ICHECK(is_zero(op->min));
PrintIndent();
stream << "for (var " << vid << " : ";
PrintType(op->loop_var.dtype(), stream);
- stream << " = 0; " << vid << " < " << extent << "; " << vid << "++) {\n";
+ stream << " = 0; " << vid << " < " << extent << "; " << vid;
+ if (step.empty()) {
+ stream << "++";
+ } else {
+ stream << " += " << step;
+ }
+ stream << ") {\n";
int for_scope = BeginScope();
PrintStmt(op->body);
this->EndScope(for_scope);
Review Comment:
This change adds support for a non-trivial `step` in the WebGPU codegen,
but non-zero `min` is still not handled. The `ICHECK(is_zero(op->min));`
shown in the diff above is still present, so any loop whose `min` is not
zero will still fail the check and stop code generation.
Removing this `ICHECK` alone is not sufficient. The emitted loop still
hard-codes `= 0` as its initializer and compares against `extent` instead
of `min + extent`. The suggestion below removes the check and derives both
the initial value and the upper bound from `min`.
```suggestion
void CodeGenWebGPU::VisitStmt_(const ForNode* op) {
std::string begin_str = PrintExpr(op->min);
PrimExpr end = is_zero(op->min) ? op->extent :
arith::Analyzer().Simplify(op->min + op->extent);
std::string end_str = PrintExpr(end);
std::string step_str = op->step.has_value() ? PrintExpr(*op->step) : "";
std::string vid = AllocVarID(op->loop_var.get());
PrintIndent();
stream << "for (var " << vid << " : ";
PrintType(op->loop_var.dtype(), stream);
stream << " = " << begin_str << "; " << vid << " < " << end_str << "; " <<
vid;
if (step_str.empty()) {
stream << "++";
} else {
stream << " += " << step_str;
}
stream << ") {\n";
int for_scope = BeginScope();
```
```suggestion
void CodeGenWebGPU::VisitStmt_(const ForNode* op) {
std::string begin_str = PrintExpr(op->min);
PrimExpr end = is_zero(op->min) ? op->extent :
arith::Analyzer().Simplify(op->min + op->extent);
std::string end_str = PrintExpr(end);
std::string step_str = op->step.has_value() ? PrintExpr(*op->step) : "";
std::string vid = AllocVarID(op->loop_var.get());
PrintIndent();
stream << "for (var " << vid << " : ";
PrintType(op->loop_var.dtype(), stream);
stream << " = " << begin_str << "; " << vid << " < " << end_str << "; " <<
vid;
if (step_str.empty()) {
stream << "++";
} else {
stream << " += " << step_str;
}
stream << ") {\n";
int for_scope = BeginScope();
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]