gemini-code-assist[bot] commented on code in PR #18515:
URL: https://github.com/apache/tvm/pull/18515#discussion_r2566949412


##########
src/tir/schedule/primitive/compute_inline.cc:
##########
@@ -1060,26 +1085,35 @@ bool 
ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue
     return false;
   }
 
-  // 4. Analyze epilogue pattern: D[i,j] = temp[i,j] + C[i,j]
+  // 4. Analyze epilogue pattern: D[i,j] = temp[i,j] + C[i,j] or
+  //    D[i,j] = min(max(temp[i,j], lower), upper)
   if (!AnalyzeEpiloguePattern(inlined_store_->value)) {
-    // Failure: epilogue is not a simple addition pattern
+    // Failure: epilogue is not a supported pattern (Bias, BiasReLU, or 
Clipping)
     return false;
   }
 
-  // 5. Check if producer is a reduction block
+  // 5. For Clipping pattern, verify temp appears exactly once
+  if (epilogue_type_ == EpilogueType::Clipping) {
+    if (loads.size() != 1) {
+      // Failure: temp must appear exactly once in clipping pattern
+      return false;
+    }
+  }

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The check `loads.size() != 1` is currently performed only for the `Clipping` 
epilogue type. However, the `Bias` and `BiasReLU` patterns (`temp + C` and 
`max(temp + C, 0)`) also assume that the reduction result buffer (`temp`) is 
read exactly once. If `temp` appears more than once, the fusion logic may 
produce incorrect results. For example, `temp + C + temp` parses as 
`(temp + C) + temp`, so it still matches the `Bias` pattern, but its addend 
(`temp + C`) itself reads `temp`. It would be safer to enforce this single-use 
constraint for all supported epilogue patterns to ensure correctness and 
predictability.
   
   ```c
     // 5. Verify temp appears exactly once in the epilogue pattern
     if (loads.size() != 1) {
       // Failure: The reduction result (temp) must be used exactly once in the
       // epilogue expression for fusion.
       return false;
     }
   ```



##########
src/tir/schedule/primitive/compute_inline.cc:
##########
@@ -1090,10 +1124,118 @@ bool 
ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) {
     // Ensure exactly one operand is from the reduction buffer
     if (a_is_target != b_is_target) {
       epilogue_addend_ = a_is_target ? add->b : add->a;
+      epilogue_type_ = EpilogueType::Bias;
       return true;
     }
   }
 
+  // Pattern 2: min(max(temp[i,j], lower), upper) or max(min(temp[i,j], 
upper), lower) (Clipping)
+  // Handle all commutative variants of min/max at each level.
+
+  // Helper to check if an expression is a load from the reduction buffer, and
+  // return the other operand as `other` if so.
+  auto match_buffer_in_commutative_op = [this](const PrimExpr& a, const 
PrimExpr& b,
+                                               PrimExpr* other) -> bool {
+    if (const auto* load_a = a.as<BufferLoadNode>()) {
+      if (load_a->buffer.same_as(inlined_buffer_)) {
+        *other = b;
+        return true;
+      }
+    }
+    if (const auto* load_b = b.as<BufferLoadNode>()) {
+      if (load_b->buffer.same_as(inlined_buffer_)) {
+        *other = a;
+        return true;
+      }
+    }
+    return false;
+  };
+
+  // Check for min(max(temp, lower), upper) and commutative variants
+  if (const auto* min_node = value.as<MinNode>()) {
+    const MaxNode* max_node = nullptr;
+    PrimExpr upper;
+    // Try both (a, b) as possible positions of the inner max
+    if ((max_node = min_node->a.as<MaxNode>())) {
+      upper = min_node->b;
+    } else if ((max_node = min_node->b.as<MaxNode>())) {
+      upper = min_node->a;
+    }
+    if (max_node != nullptr) {
+      PrimExpr lower;
+      if (match_buffer_in_commutative_op(max_node->a, max_node->b, &lower)) {
+        clipping_lower_ = lower;
+        clipping_upper_ = upper;
+        epilogue_type_ = EpilogueType::Clipping;
+        return true;
+      }
+    }
+  }
+
+  // Check for max(min(temp[i,j], upper), lower) and commutative variants
+  if (const auto* max_node = value.as<MaxNode>()) {
+    const MinNode* min_node = nullptr;
+    PrimExpr lower;
+    // Try both (a, b) as possible positions of the inner min
+    if ((min_node = max_node->a.as<MinNode>())) {
+      lower = max_node->b;
+    } else if ((min_node = max_node->b.as<MinNode>())) {
+      lower = max_node->a;
+    }
+    if (min_node != nullptr) {
+      PrimExpr upper;
+      if (match_buffer_in_commutative_op(min_node->a, min_node->b, &upper)) {
+        clipping_lower_ = lower;
+        clipping_upper_ = upper;
+        epilogue_type_ = EpilogueType::Clipping;
+        return true;
+      }
+    }
+  }
+
+  // Pattern 3: max(temp[i,j] + C[i,j], 0) or max(C[i,j] + temp[i,j], 0) 
(BiasReLU)
+  // Also handle max(0, temp[i,j] + C[i,j]) or max(0, C[i,j] + temp[i,j])
+  if (const auto* max_node = value.as<MaxNode>()) {
+    // Check if either operand is zero (ReLU: max(x, 0) or max(0, x))
+    // Support both integer and float zero constants.
+    const PrimExpr* add_candidate = nullptr;
+    bool is_zero_const = false;
+    auto is_zero_expr = [](const PrimExpr& expr) -> bool {
+      if (tir::is_zero(expr)) {
+        return true;
+      }
+      if (const auto* float_imm = expr.as<FloatImmNode>()) {
+        return float_imm->value == 0.0;
+      }
+      return false;
+    };
+
+    if (is_zero_expr(max_node->a)) {
+      is_zero_const = true;
+      add_candidate = &max_node->b;
+    } else if (is_zero_expr(max_node->b)) {
+      is_zero_const = true;
+      add_candidate = &max_node->a;
+    }
+
+    if (is_zero_const && add_candidate != nullptr) {
+      if (const auto* add = add_candidate->as<AddNode>()) {
+        const auto* load_a = add->a.as<BufferLoadNode>();
+        const auto* load_b = add->b.as<BufferLoadNode>();
+
+        bool a_is_target = load_a && load_a->buffer.same_as(inlined_buffer_);
+        bool b_is_target = load_b && load_b->buffer.same_as(inlined_buffer_);
+
+        // Ensure exactly one operand is from the reduction buffer
+        if (a_is_target != b_is_target) {
+          epilogue_addend_ = a_is_target ? add->b : add->a;
+          epilogue_type_ = EpilogueType::BiasReLU;
+          return true;
+        }
+      }
+    }

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The current implementation of the `BiasReLU` pattern matches `max(temp + C, 
0)` but does not handle a bias-free ReLU, i.e. `max(temp, 0)` with no `+ C` 
term. In that case `add_candidate` is a `BufferLoadNode` of `temp`, so the 
subsequent `add_candidate->as<AddNode>()` check fails. The expression does not 
match the `Clipping` pattern either, because there is no inner `min`. Supporting 
bias-free ReLU is a common use case and would make this primitive more general. 
You could extend the logic to handle the case where `add_candidate` is a direct 
`BufferLoadNode` from the reduction buffer, treating it as a `BiasReLU` with a 
zero bias.
   
   ```c
       if (is_zero_const && add_candidate != nullptr) {
         if (const auto* add = add_candidate->as<AddNode>()) {
           const auto* load_a = add->a.as<BufferLoadNode>();
           const auto* load_b = add->b.as<BufferLoadNode>();
   
           bool a_is_target = load_a && load_a->buffer.same_as(inlined_buffer_);
           bool b_is_target = load_b && load_b->buffer.same_as(inlined_buffer_);
   
           // Ensure exactly one operand is from the reduction buffer
           if (a_is_target != b_is_target) {
             epilogue_addend_ = a_is_target ? add->b : add->a;
             epilogue_type_ = EpilogueType::BiasReLU;
             return true;
           }
         } else if (const auto* load = add_candidate->as<BufferLoadNode>()) {
           if (load->buffer.same_as(inlined_buffer_)) {
             epilogue_addend_ = tir::make_zero(load->dtype);
             epilogue_type_ = EpilogueType::BiasReLU;
             return true;
           }
         }
       }
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to