gemini-code-assist[bot] commented on code in PR #18421:
URL: https://github.com/apache/tvm/pull/18421#discussion_r2533514663
##########
src/tir/transforms/canonicalize_loop.cc:
##########
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tir/transforms/canonicalize_loop.cc
+ * \brief Canonicalize all loops to start from zero and step one.
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/ffi/reflection/registry.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <utility>
+
+namespace tvm {
+namespace tir {
+
+class LoopCanonicalizer : public StmtExprMutator {
+ public:
+  LoopCanonicalizer() = default;
+
+ private:
+  Stmt VisitStmt_(const ForNode* op) final {
+    if (is_zero(op->min) && op->HasTrivialStep()) {
+      return StmtExprMutator::VisitStmt_(op);
+    }
+    arith::Analyzer analyzer;
+    const auto* loop_var = op->loop_var.get();
+    PrimExpr step = op->step.value_or(make_const(loop_var->dtype, 1));
+
+    // report warning for negative step, since it would be a forever loop
+    if (!analyzer.CanProveGreaterEqual(step, 1)) {
+      // TODO(tvm): prove dynamic shaped step
+      LOG(FATAL) << "Loop step for " << op->loop_var << " may not be positive: " << step;
+    }
+
+    new_iter_info_[loop_var] = std::make_pair(step, op->min);
+    auto n = CopyOnWrite(op);
+    n->body = VisitStmt(op->body);
+    n->min = make_zero(loop_var->dtype);
+    n->extent = analyzer.Simplify(ceildiv(op->extent - op->min + 1, step));

Review Comment:
The new loop extent appears to be computed incorrectly. According to the `ForNode` documentation, a loop iterates over `[min, min + extent)`. With a positive step, the number of iterations is therefore `ceildiv(extent, step)`. The current formula, `ceildiv(op->extent - op->min + 1, step)`, has two problems. First, it subtracts `min` from `extent`, which is already a length measured from `min`. Second, it adds a spurious `+ 1`.

For example, `for i in range(2, 12, 3)` has `min = 2` and `extent = 10`, and should run 4 iterations (i = 2, 5, 8, 11). `ceildiv(10, 3)` gives the correct count of 4. The current code gives `ceildiv(10 - 2 + 1, 3) = ceildiv(9, 3) = 3`, which drops the last iteration. The extent should be `ceildiv(op->extent, step)`.
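The following host-side Python sketch checks the arithmetic above by comparing both trip-count formulas against Python's `range`. It is not part of the PR. It makes three assumptions:

- The loop follows the `ForNode` semantics quoted in the comment.
- The step is positive.
- The canonicalized loop recovers the original variable as `i = min + j * step`. This is what the `(step, min)` pair stored in `new_iter_info_` suggests, but the substitution itself is not shown in the diff.

All helper names are hypothetical.

```python
# Host-side sketch (not part of the PR): compare the two trip-count formulas
# for a stepped loop. Assumes TVM's ForNode semantics, where the loop visits
# min, min + step, ... while the value is < min + extent, and a positive step.


def ceildiv(a, b):
    """Ceiling division for a positive divisor b."""
    return -(-a // b)


def reference_iters(min_, extent, step):
    """Values the original loop visits."""
    return list(range(min_, min_ + extent, step))


def canonical_iters(min_, step, trip_count):
    """Values visited after canonicalization, assuming i = min + j * step
    for j in [0, trip_count) (hypothetical model of the rewrite)."""
    return [min_ + j * step for j in range(trip_count)]


def proposed_trip_count(extent, step):
    # Formula suggested in the review: ceildiv(extent, step)
    return ceildiv(extent, step)


def current_trip_count(min_, extent, step):
    # Formula currently in the PR: ceildiv(extent - min + 1, step)
    return ceildiv(extent - min_ + 1, step)


# The reviewer's example: range(2, 12, 3) -> min = 2, extent = 10, step = 3.
print(reference_iters(2, 10, 3))  # [2, 5, 8, 11]
print(proposed_trip_count(10, 3), current_trip_count(2, 10, 3))  # 4 3

# Exhaustive check over small inputs: the proposed formula reproduces the
# original iteration space for every (min, extent, step) combination tried.
for m in range(8):
    for e in range(20):
        for s in range(1, 6):
            n = proposed_trip_count(e, s)
            assert canonical_iters(m, s, n) == reference_iters(m, e, s)
```

The current formula already diverges on the reviewer's example, returning 3 instead of 4. Under the stated assumptions, the proposed formula passes the exhaustive check.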
```suggestion
    n->extent = analyzer.Simplify(ceildiv(op->extent, step));
```

##########
tests/python/codegen/test_target_codegen_cuda.py:
##########
@@ -877,5 +877,37 @@ def main(A: T.Buffer((16, 16), "float32"), B: T.Buffer((16, 16), "float32")):
     assert "return;" in cuda_code


+@tvm.testing.requires_gpu
+@tvm.testing.requires_cuda
+def test_cuda_loop_step():
+    @T.prim_func
+    def cuda_loop_step(
+        A: T.Buffer((1024,), "float32"),
+        B: T.Buffer((1024,), "float32"),
+        C: T.Buffer((1024,), "float32"),
+    ):
+        # Each thread computes a strided subset of the i loop: start = tx*3, step = 96 (3 * 32 threads)
+        for bx in T.thread_binding(1, "blockIdx.x"):
+            for tx in T.thread_binding(96, "threadIdx.x"):
+                for i in T.serial(tx, 4096, step=96):

Review Comment:
The loop's stop value of `4096` causes out-of-bounds memory accesses, because `A`, `B`, and `C` each hold only 1024 elements. The test expects every element to be computed, so the loop should cover indices 0 through 1023. With 96 threads, each starting at `tx` and striding by 96, a stop value of 1024 visits every index in that range exactly once.

```suggestion
                for i in T.serial(tx, 1024, step=96):
```
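The coverage argument in the comment on the test can be checked on the host with the Python sketch below. It is a hypothetical model, not part of the PR. It assumes `T.serial(tx, stop, step=96)` visits the same values as Python's `range(tx, stop, 96)`. It also stands in for the 96 `threadIdx.x` threads with a plain loop. Helper names are illustrative only.

```python
# Host-side sketch (not part of the PR): model the test's thread layout to
# see which buffer indices are touched. Assumes T.serial(tx, stop, step=96)
# visits the same values as Python's range(tx, stop, 96), with 96 threads
# bound to threadIdx.x and buffers of 1024 elements.

BUF_SIZE = 1024
NUM_THREADS = 96
STEP = 96


def touched_indices(stop):
    """Map each visited index to the number of times any thread visits it."""
    counts = {}
    for tx in range(NUM_THREADS):  # stands in for threadIdx.x
        for i in range(tx, stop, STEP):
            counts[i] = counts.get(i, 0) + 1
    return counts


def check(stop):
    """Return (every in-bounds index visited exactly once, count of out-of-bounds indices)."""
    counts = touched_indices(stop)
    covered_once = all(counts.get(i, 0) == 1 for i in range(BUF_SIZE))
    out_of_bounds = sum(1 for i in counts if i >= BUF_SIZE)
    return covered_once, out_of_bounds


print(check(4096))  # (True, 3072): in-bounds coverage is fine, but 3072 indices lie past the end
print(check(1024))  # (True, 0)
```

Both stop values cover 0 through 1023 exactly once. Only the stop value of 1024 avoids touching indices beyond the 1024-element buffers.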
