gemini-code-assist[bot] commented on code in PR #18510:
URL: https://github.com/apache/tvm/pull/18510#discussion_r2564020407


##########
src/tir/schedule/primitive/compute_inline.cc:
##########
@@ -1090,10 +1109,80 @@ bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) {
     // Ensure exactly one operand is from the reduction buffer
     if (a_is_target != b_is_target) {
       epilogue_addend_ = a_is_target ? add->b : add->a;
+      epilogue_type_ = EpilogueType::Bias;
       return true;
     }
   }
 
+  // Pattern 2: min(max(temp[i,j], lower), upper) or max(min(temp[i,j], upper), lower) (Clipping)
+  // Check for min(max(temp, lower), upper) pattern
+  if (const auto* min_node = value.as<MinNode>()) {
+    if (const auto* max_node = min_node->a.as<MaxNode>()) {
+      const auto* load = max_node->a.as<BufferLoadNode>();
+      if (load && load->buffer.same_as(inlined_buffer_)) {
+        // Pattern: min(max(temp, lower), upper)
+        clipping_lower_ = max_node->b;
+        clipping_upper_ = min_node->b;
+        epilogue_type_ = EpilogueType::Clipping;
+        return true;
+      }
+    }
+    // Check for min(temp, upper) where temp might be wrapped in max
+    if (const auto* max_node = min_node->b.as<MaxNode>()) {
+      const auto* load = max_node->a.as<BufferLoadNode>();
+      if (load && load->buffer.same_as(inlined_buffer_)) {
+        // Pattern: min(max(temp, lower), upper) - but check if min is outer
+        // This case is already handled above, but we check for max(min(temp, upper), lower)
+        clipping_lower_ = max_node->b;
+        clipping_upper_ = min_node->a;
+        epilogue_type_ = EpilogueType::Clipping;
+        return true;
+      }
+    }
+  }
+  // Check for max(min(temp[i,j], upper), lower) pattern
+  if (const auto* max_node = value.as<MaxNode>()) {
+    if (const auto* min_node = max_node->a.as<MinNode>()) {
+      const auto* load = min_node->a.as<BufferLoadNode>();
+      if (load && load->buffer.same_as(inlined_buffer_)) {
+        // Pattern: max(min(temp, upper), lower)
+        clipping_lower_ = max_node->b;
+        clipping_upper_ = min_node->b;
+        epilogue_type_ = EpilogueType::Clipping;
+        return true;
+      }
+    }
+  }
+
+  // Pattern 3: max(temp[i,j] + C[i,j], 0) or max(C[i,j] + temp[i,j], 0) (BiasReLU)
+  if (const auto* max_node = value.as<MaxNode>()) {
+    // Check if second operand is zero (ReLU: max(x, 0))
+    // Support both integer and float zero constants
+    bool is_zero_const = false;
+    if (tir::is_zero(max_node->b)) {
+      is_zero_const = true;
+    } else if (const auto* float_imm = max_node->b.as<FloatImmNode>()) {
+      is_zero_const = (float_imm->value == 0.0);
+    }
+    if (is_zero_const) {
+      // Check if first operand is AddNode
+      if (const auto* add = max_node->a.as<AddNode>()) {
+        const auto* load_a = add->a.as<BufferLoadNode>();
+        const auto* load_b = add->b.as<BufferLoadNode>();
+
+        bool a_is_target = load_a && load_a->buffer.same_as(inlined_buffer_);
+        bool b_is_target = load_b && load_b->buffer.same_as(inlined_buffer_);
+
+        // Ensure exactly one operand is from the reduction buffer
+        if (a_is_target != b_is_target) {
+          epilogue_addend_ = a_is_target ? add->b : add->a;
+          epilogue_type_ = EpilogueType::BiasReLU;
+          return true;
+        }
+      }
+    }
+  }

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The pattern matching for `BiasReLU` does not handle the commutativity of the 
`max` operator. It only checks for `max(x, 0)` where `0` is the second operand 
(`max_node->b`). It will miss the pattern `max(0, x)`. You should check both 
operands of the `max` node to see if one is zero and the other is an `AddNode`.



##########
tests/python/tir-schedule/test_tir_schedule_fuse_reduction_epilogue_clipping.py:
##########
@@ -0,0 +1,187 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
+"""Test cases for fuse_reduction_epilogue with clipping pattern (min(max(temp, lower), upper))."""
+
+import numpy as np
+import pytest
+import tvm
+import tvm.testing
+from tvm import tir
+from tvm.script import tir as T
+from tvm.tir.schedule.testing import (
+    verify_trace_roundtrip,
+    assert_structural_equal_ignore_global_symbol,
+)
+
+
+@T.prim_func
+def matmul_clipping_before(
+    A: T.Buffer((16, 16), "float32"),
+    B: T.Buffer((16, 16), "float32"),
+    D: T.Buffer((16, 16), "float32"),
+) -> None:
+    """Original function with separate reduction and clipping epilogue blocks."""
+    temp = T.alloc_buffer((16, 16), dtype="float32")
+    for i, j, k in T.grid(16, 16, 16):
+        with T.block("matmul"):
+            vi, vj, vk = T.axis.remap("SSR", [i, j, k])
+            with T.init():
+                temp[vi, vj] = T.float32(0)
+            temp[vi, vj] = temp[vi, vj] + A[vi, vk] * B[vk, vj]
+    for i, j in T.grid(16, 16):
+        with T.block("clip"):
+            vi, vj = T.axis.remap("SS", [i, j])
+            D[vi, vj] = T.min(T.max(temp[vi, vj], T.float32(0)), T.float32(10))
+
+
+def test_matmul_clipping():
+    """Test that clipping pattern is correctly fused into reduction block."""
+    sch = tir.Schedule(matmul_clipping_before, debug_mask="all")
+    sch.fuse_reduction_epilogue("matmul", "clip")
+    mod = sch.mod["main"]
+    # The expected IR should have clipping in init, but due to parsing issues,
+    # we verify the structure programmatically instead
+    # Expected: init = T.min(T.max(T.float32(0.0), T.float32(0.0)), T.float32(10.0))
+    # For now, just verify fusion succeeded and the body has clipping
+    verify_trace_roundtrip(sch=sch, mod=matmul_clipping_before)

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   The test for clipping fusion only covers one specific pattern 
`T.min(T.max(temp, lower), upper)`. To ensure the pattern matching logic is 
robust, it would be beneficial to add more test cases for equivalent 
commutative variations of the clipping pattern. This would help catch issues 
with pattern matching not handling commutativity correctly.
   
   For example, you could add tests for patterns like:
   - `T.min(upper, T.max(temp, lower))`
   - `T.min(T.max(lower, temp), upper)`
   - `T.max(T.min(temp, upper), lower)`
   - `T.max(lower, T.min(temp, upper))`
   
   A good way to do this would be to use `pytest.mark.parametrize`.



##########
src/tir/schedule/primitive/compute_inline.cc:
##########
@@ -1090,10 +1109,80 @@ bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) {
     // Ensure exactly one operand is from the reduction buffer
     if (a_is_target != b_is_target) {
       epilogue_addend_ = a_is_target ? add->b : add->a;
+      epilogue_type_ = EpilogueType::Bias;
       return true;
     }
   }
 
+  // Pattern 2: min(max(temp[i,j], lower), upper) or max(min(temp[i,j], upper), lower) (Clipping)
+  // Check for min(max(temp, lower), upper) pattern
+  if (const auto* min_node = value.as<MinNode>()) {
+    if (const auto* max_node = min_node->a.as<MaxNode>()) {
+      const auto* load = max_node->a.as<BufferLoadNode>();
+      if (load && load->buffer.same_as(inlined_buffer_)) {
+        // Pattern: min(max(temp, lower), upper)
+        clipping_lower_ = max_node->b;
+        clipping_upper_ = min_node->b;
+        epilogue_type_ = EpilogueType::Clipping;
+        return true;
+      }
+    }
+    // Check for min(temp, upper) where temp might be wrapped in max
+    if (const auto* max_node = min_node->b.as<MaxNode>()) {
+      const auto* load = max_node->a.as<BufferLoadNode>();
+      if (load && load->buffer.same_as(inlined_buffer_)) {
+        // Pattern: min(max(temp, lower), upper) - but check if min is outer
+        // This case is already handled above, but we check for max(min(temp, upper), lower)
+        clipping_lower_ = max_node->b;
+        clipping_upper_ = min_node->a;
+        epilogue_type_ = EpilogueType::Clipping;
+        return true;
+      }
+    }
+  }
+  // Check for max(min(temp[i,j], upper), lower) pattern
+  if (const auto* max_node = value.as<MaxNode>()) {
+    if (const auto* min_node = max_node->a.as<MinNode>()) {
+      const auto* load = min_node->a.as<BufferLoadNode>();
+      if (load && load->buffer.same_as(inlined_buffer_)) {
+        // Pattern: max(min(temp, upper), lower)
+        clipping_lower_ = max_node->b;
+        clipping_upper_ = min_node->b;
+        epilogue_type_ = EpilogueType::Clipping;
+        return true;
+      }
+    }
+  }

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The clipping pattern matching logic does not fully handle the commutative 
property of `min` and `max` operators.
   
   Specifically:
   1. For `min(max(temp, ...), ...)` and `max(min(temp, ...), ...)` patterns, 
the code only checks if `temp` is the first operand of the inner `max`/`min` 
node. It will miss patterns like `min(max(lower, temp), upper)`.
   2. For the `max(min(...), ...)` pattern, the code only checks if the `min` 
node is the first operand of `max`. It will miss `max(lower, min(temp, upper))`.
   
   This will cause fusion to fail for valid and equivalent patterns. To make 
the pattern matching more robust, it should be updated to check both operands 
of `min` and `max` nodes at each level.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to