gemini-code-assist[bot] commented on code in PR #18513:
URL: https://github.com/apache/tvm/pull/18513#discussion_r2564648992
##########
src/tir/schedule/primitive/compute_inline.cc:
##########
@@ -1160,20 +1286,54 @@ Block
ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reducti
var_map[epilogue_data_vars[i]] = reduction_data_vars[i];
}
- // 2. Change init to epilogue value: D[vi, vj] = C[vi, vj]
- BufferStore new_init_store(epilogue_output_buffer_,
Substitute(epilogue_addend_, var_map),
- Substitute(epilogue_output_indices_, var_map));
+ // 2. Change init to epilogue value based on epilogue type
+ BufferStore new_init_store;
+ if (epilogue_type_ == EpilogueType::BiasReLU) {
+ // For ReLU, init should be max(C[vi, vj], 0) to match per-iteration ReLU
semantics
+ PrimExpr init_value = Substitute(epilogue_addend_, var_map);
+ PrimExpr zero = tir::make_zero(init_value.dtype());
+ new_init_store = BufferStore(epilogue_output_buffer_, Max(init_value,
zero),
+ Substitute(epilogue_output_indices_,
var_map));
+ } else if (epilogue_type_ == EpilogueType::Clipping) {
+ // For Clipping, init should be min(max(init_value, lower), upper)
+ // Since init is typically 0, this becomes min(max(0, lower), upper)
+ PrimExpr init_value = tir::make_zero(epilogue_output_buffer_->dtype);
+ PrimExpr clipped_init = Min(Max(init_value, Substitute(clipping_lower_,
var_map)),
+ Substitute(clipping_upper_, var_map));
+ new_init_store = BufferStore(epilogue_output_buffer_, clipped_init,
+ Substitute(epilogue_output_indices_,
var_map));
+ } else {
+ // Bias: D[vi, vj] = C[vi, vj]
+ new_init_store = BufferStore(epilogue_output_buffer_,
Substitute(epilogue_addend_, var_map),
+ Substitute(epilogue_output_indices_,
var_map));
+ }
new_block->init = new_init_store;
// 3. Replace output buffer from temp to D in body
class BufferReplacer : public StmtExprMutator {
public:
- BufferReplacer(Buffer old_buf, Buffer new_buf) : old_buffer_(old_buf),
new_buffer_(new_buf) {}
+ BufferReplacer(Buffer old_buf, Buffer new_buf, EpilogueType epilogue_type,
DataType dtype,
+ PrimExpr clipping_lower = PrimExpr(), PrimExpr
clipping_upper = PrimExpr())
+ : old_buffer_(old_buf),
+ new_buffer_(new_buf),
+ epilogue_type_(epilogue_type),
+ dtype_(dtype),
+ clipping_lower_(clipping_lower),
+ clipping_upper_(clipping_upper) {}
Stmt VisitStmt_(const BufferStoreNode* op) final {
BufferStore store =
Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
if (store->buffer.same_as(old_buffer_)) {
- return BufferStore(new_buffer_, store->value, store->indices);
+ PrimExpr new_value = store->value;
+ // For ReLU, apply max per iteration to match per-iteration ReLU
semantics
+ if (epilogue_type_ == EpilogueType::BiasReLU) {
+ PrimExpr zero = tir::make_zero(dtype_);
+ new_value = Max(new_value, zero);
+ } else if (epilogue_type_ == EpilogueType::Clipping) {
+ // For Clipping, apply min(max(value, lower), upper) per iteration
+ new_value = Min(Max(new_value, clipping_lower_), clipping_upper_);
+ }
Review Comment:

Fusing non-linear epilogues such as `Clipping` and `BiasReLU` in this way
changes the semantics of the computation. The original program applies the
clipping/ReLU once, after the reduction has finished (e.g., `clip(sum(A))`),
while the fused version applies it on every reduction iteration (e.g.,
`acc = clip(acc + a_i)`). These two are not numerically equivalent in general.
For example, with clipping to `[-5, 5]` and inputs `[6, -2]`:
- Post-reduction clipping: `clip(6 + (-2)) = clip(4) = 4`.
- Per-iteration clipping: `acc = 0; acc = clip(0 + 6) = 5; acc = clip(5 + (-2)) = 3`.
  The result is `3`, not `4`.
This is a significant change to the contract of `fuse_reduction_epilogue`.
While it might be an intended optimization for some models, it can lead to
unexpected correctness issues if users are not aware of it.
This behavior should be explicitly and clearly documented in the
`fuse_reduction_epilogue` docstring to prevent misuse. The current docstring is
outdated and only mentions simple addition. It should be updated to include all
supported patterns and a prominent warning about the semantic change for
non-linear epilogues.
##########
src/tir/schedule/primitive/compute_inline.cc:
##########
@@ -1090,10 +1109,117 @@ bool
ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) {
// Ensure exactly one operand is from the reduction buffer
if (a_is_target != b_is_target) {
epilogue_addend_ = a_is_target ? add->b : add->a;
+ epilogue_type_ = EpilogueType::Bias;
return true;
}
}
+ // Pattern 2: min(max(temp[i,j], lower), upper) or max(min(temp[i,j],
upper), lower) (Clipping)
+ // Handle all commutative variants of min/max at each level.
+
+ // Helper to check if an expression is a load from the reduction buffer, and
+ // return the other operand as `other` if so.
+ auto match_buffer_in_commutative_op = [this](const PrimExpr& a, const
PrimExpr& b,
+ PrimExpr* other) -> bool {
+ if (const auto* load_a = a.as<BufferLoadNode>()) {
+ if (load_a->buffer.same_as(inlined_buffer_)) {
+ *other = b;
+ return true;
+ }
+ }
+ if (const auto* load_b = b.as<BufferLoadNode>()) {
+ if (load_b->buffer.same_as(inlined_buffer_)) {
+ *other = a;
+ return true;
+ }
+ }
+ return false;
+ };
+
+ // Check for min(max(temp, lower), upper) and commutative variants
+ if (const auto* min_node = value.as<MinNode>()) {
+ const MaxNode* max_node = nullptr;
+ PrimExpr upper;
+ // Try both (a, b) as possible positions of the inner max
+ if ((max_node = min_node->a.as<MaxNode>())) {
+ upper = min_node->b;
+ } else if ((max_node = min_node->b.as<MaxNode>())) {
+ upper = min_node->a;
+ }
+ if (max_node != nullptr) {
+ PrimExpr lower;
+ if (match_buffer_in_commutative_op(max_node->a, max_node->b, &lower)) {
+ clipping_lower_ = lower;
+ clipping_upper_ = upper;
+ epilogue_type_ = EpilogueType::Clipping;
+ return true;
+ }
+ }
+ }
+
+ // Check for max(min(temp[i,j], upper), lower) and commutative variants
+ if (const auto* max_node = value.as<MaxNode>()) {
+ const MinNode* min_node = nullptr;
+ PrimExpr lower;
+ // Try both (a, b) as possible positions of the inner min
+ if ((min_node = max_node->a.as<MinNode>())) {
+ lower = max_node->b;
+ } else if ((min_node = max_node->b.as<MinNode>())) {
+ lower = max_node->a;
+ }
+ if (min_node != nullptr) {
+ PrimExpr upper;
+ if (match_buffer_in_commutative_op(min_node->a, min_node->b, &upper)) {
+ clipping_lower_ = lower;
+ clipping_upper_ = upper;
+ epilogue_type_ = EpilogueType::Clipping;
+ return true;
+ }
+ }
+ }
+
+ // Pattern 3: max(temp[i,j] + C[i,j], 0) or max(C[i,j] + temp[i,j], 0)
(BiasReLU)
+ if (const auto* max_node = value.as<MaxNode>()) {
+ // Check if either operand is zero (ReLU: max(x, 0) or max(0, x))
+ // Support both integer and float zero constants.
+ const PrimExpr* add_candidate = nullptr;
+ bool is_zero_const = false;
+ auto is_zero_expr = [](const PrimExpr& expr) -> bool {
+ if (tir::is_zero(expr)) {
+ return true;
+ }
+ if (const auto* float_imm = expr.as<FloatImmNode>()) {
+ return float_imm->value == 0.0;
+ }
+ return false;
Review Comment:

The implementation of `is_zero_expr` is redundant. The call to
`tir::is_zero(expr)` already checks for both integer and floating-point zeros.
You can simplify this lambda.
```cpp
return tir::is_zero(expr);
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]