gemini-code-assist[bot] commented on code in PR #18421:
URL: https://github.com/apache/tvm/pull/18421#discussion_r2532632309


##########
src/target/source/codegen_webgpu.cc:
##########
@@ -668,12 +668,19 @@ void CodeGenWebGPU::VisitStmt_(const AllocateNode* op) {
 
 void CodeGenWebGPU::VisitStmt_(const ForNode* op) {
   std::string extent = PrintExpr(op->extent);
+  std::string step = op->step.has_value()? PrintExpr(*op->step) : "";
   std::string vid = AllocVarID(op->loop_var.get());
   ICHECK(is_zero(op->min));
   PrintIndent();
   stream << "for (var " << vid << " : ";
   PrintType(op->loop_var.dtype(), stream);
-  stream << " = 0; " << vid << " < " << extent << "; " << vid << "++) {\n";
+  stream << " = 0; " << vid << " < " << extent << "; " << vid;
+  if (step.empty()) {
+    stream << "++";
+  } else {
+    stream << " += " << step; 
+  }
+  stream << ") {\n";
   int for_scope = BeginScope();
   PrintStmt(op->body);
   this->EndScope(for_scope);

Review Comment:
   It looks like you've made good progress on supporting a non-trivial `step` 
in the WebGPU codegen! However, loops with a non-zero `min` are still 
rejected: the `ICHECK(is_zero(op->min));` on line 673 is still present, and it 
will abort codegen for any loop whose `min` is not zero. As a result, this 
path cannot yet handle non-zero `min`.
   
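   Please remove this `ICHECK` to enable support for non-zero `min` values. 
Removing it alone is not enough, though. The emitted loop header still starts 
at `0` and compares against `extent`, so it must also start at `op->min` and 
stop at `op->min + op->extent`, as the suggestion below does.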
   
   ```suggestion
   void CodeGenWebGPU::VisitStmt_(const ForNode* op) {
     std::string begin_str = PrintExpr(op->min);
     PrimExpr end = is_zero(op->min) ? op->extent : 
arith::Analyzer().Simplify(op->min + op->extent);
     std::string end_str = PrintExpr(end);
     std::string step_str = op->step.has_value() ? PrintExpr(*op->step) : "";
     std::string vid = AllocVarID(op->loop_var.get());
     PrintIndent();
     stream << "for (var " << vid << " : ";
     PrintType(op->loop_var.dtype(), stream);
     stream << " = " << begin_str << "; " << vid << " < " << end_str << "; " << 
vid;
     if (step_str.empty()) {
       stream << "++";
     } else {
       stream << " += " << step_str;
     }
     stream << ") {\n";
     int for_scope = BeginScope();
   ```
   
   ```suggestion
   void CodeGenWebGPU::VisitStmt_(const ForNode* op) {
     std::string begin_str = PrintExpr(op->min);
     PrimExpr end = is_zero(op->min) ? op->extent : 
arith::Analyzer().Simplify(op->min + op->extent);
     std::string end_str = PrintExpr(end);
     std::string step_str = op->step.has_value() ? PrintExpr(*op->step) : "";
     std::string vid = AllocVarID(op->loop_var.get());
     PrintIndent();
     stream << "for (var " << vid << " : ";
     PrintType(op->loop_var.dtype(), stream);
     stream << " = " << begin_str << "; " << vid << " < " << end_str << "; " << 
vid;
     if (step_str.empty()) {
       stream << "++";
     } else {
       stream << " += " << step_str;
     }
     stream << ") {\n";
     int for_scope = BeginScope();
   
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to