kimm240 opened a new pull request, #18636:
URL: https://github.com/apache/tvm/pull/18636
## Major Changes for Generalization
### 1. Pattern Matching Removal
**Removed Items:**
- `EpilogueType` enum (Bias, BiasReLU, Clipping)
- `AnalyzeEpiloguePattern()` function (~130 lines)
- Pattern-specific branching logic
**Current Approach:**
- Directly process the entire epilogue expression without pattern matching
### 2. Store Entire Epilogue Expression
**Changes:**
- Store the entire epilogue expression in `epilogue_expression_`
- Use the expression directly without pattern analysis
```cpp
// Store the epilogue expression and reduction buffer load
epilogue_expression_ = inlined_store_->value;
reduction_buffer_load_ = loads[0];
```
### 3. Generalized Init Transformation
**Approach:**
- Replace the reduction buffer load with the reduction's identity element (0)
- Apply this substitution to the entire epilogue expression to obtain the init value
```cpp
InitSubstituter init_subst(inlined_buffer_, identity_elem);
PrimExpr init_epilogue = init_subst(epilogue_expression_);
// Simplify: 0 + C[vi, vj] -> C[vi, vj]
```
**Examples:**
- `temp + C` → `0 + C` → `C` (simplify)
- `max(temp + C, 0)` → `max(0 + C, 0)` → `max(C, 0)`
- `min(max(temp, lower), upper)` → `min(max(0, lower), upper)`
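The three examples above can be reproduced with a small standalone sketch. It does not use TVM: `Node`, `Leaf`, `Bin`, `InitValue`, and `Show` are hypothetical names for a toy expression tree standing in for `PrimExpr` and `InitSubstituter`, and the only simplification modeled is folding `0 + x` and `x + 0`.

```cpp
#include <cassert>
#include <memory>
#include <string>

// Toy expression tree; a hypothetical stand-in for TVM's PrimExpr.
struct Node;
using Expr = std::shared_ptr<const Node>;

struct Node {
  std::string op;    // "temp", "var", "const", "add", "max", "min"
  std::string name;  // used by "var"
  int value;         // used by "const"
  Expr lhs, rhs;     // used by binary ops
};

Expr Leaf(const std::string& op, const std::string& name = "", int value = 0) {
  return std::make_shared<Node>(Node{op, name, value, nullptr, nullptr});
}

Expr Bin(const std::string& op, Expr a, Expr b) {
  return std::make_shared<Node>(Node{op, "", 0, a, b});
}

bool IsZero(const Expr& e) { return e->op == "const" && e->value == 0; }

// Replace every load of the reduction buffer ("temp") with the identity
// element 0, folding `0 + x` and `x + 0` to `x` on the way back up.
Expr InitValue(const Expr& e) {
  if (e->op == "temp") return Leaf("const", "", 0);
  if (!e->lhs) return e;  // "var" or "const" leaf: nothing to substitute
  Expr a = InitValue(e->lhs);
  Expr b = InitValue(e->rhs);
  if (e->op == "add" && IsZero(a)) return b;
  if (e->op == "add" && IsZero(b)) return a;
  return Bin(e->op, a, b);
}

// Render an expression as text for inspection.
std::string Show(const Expr& e) {
  if (e->op == "temp") return "temp";
  if (e->op == "var") return e->name;
  if (e->op == "const") return std::to_string(e->value);
  if (e->op == "add") return Show(e->lhs) + " + " + Show(e->rhs);
  return e->op + "(" + Show(e->lhs) + ", " + Show(e->rhs) + ")";
}
```

For instance, `Show(InitValue(Bin("add", Leaf("temp"), Leaf("var", "C"))))` yields `C`, matching the first example.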
### 4. Generalized Update Transformation
**Core Logic: `GeneralizedEpilogueApplier`**
- Replace the reduction buffer load with the reduction update
- If the parent node is an `Add` and its other operand is not a reduction buffer load → treat that operand as a bias addend and remove it, since the init value already accounts for it
- Otherwise → apply the expression as-is
```cpp
class GeneralizedEpilogueApplier : public ExprMutator {
// Replace reduction buffer load with reduction update
// Automatically detect and remove bias addend in Add nodes
// Automatically support other activation functions
};
```
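Why the bias addend is dropped from the update can be seen in a plain C++ sketch of the bias-add case (`temp + C`). It is an illustration under assumptions, not the code the pass generates: `Matrix`, `MatmulThenBias`, and `FusedMatmulBias` are hypothetical names, and the sketch covers only this one epilogue.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Matrix = std::vector<std::vector<int>>;

// Unfused reference: reduce into temp (init 0), then run the epilogue
// D = temp + C as a separate step.
Matrix MatmulThenBias(const Matrix& A, const Matrix& B, const Matrix& C) {
  const std::size_t n = A.size(), k = B.size(), m = B[0].size();
  Matrix D(n, std::vector<int>(m, 0));
  for (std::size_t i = 0; i < n; ++i) {
    for (std::size_t j = 0; j < m; ++j) {
      int temp = 0;
      for (std::size_t r = 0; r < k; ++r) temp += A[i][r] * B[r][j];
      D[i][j] = temp + C[i][j];
    }
  }
  return D;
}

// Fused form: the init value is the epilogue with temp replaced by 0
// (here simply C), and the update omits the bias so that C is added once.
Matrix FusedMatmulBias(const Matrix& A, const Matrix& B, const Matrix& C) {
  const std::size_t n = A.size(), k = B.size(), m = B[0].size();
  Matrix D(n, std::vector<int>(m, 0));
  for (std::size_t i = 0; i < n; ++i) {
    for (std::size_t j = 0; j < m; ++j) {
      D[i][j] = C[i][j];  // init: 0 + C simplifies to C
      for (std::size_t r = 0; r < k; ++r) {
        D[i][j] = D[i][j] + A[i][r] * B[r][j];  // update: no "+ C" here
      }
    }
  }
  return D;
}
```

If the update kept the `+ C` term, the bias would be added once per reduction step instead of once overall, which is why the applier removes it.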
## Results and Verification
### Existing Tests Pass
All existing tests pass, maintaining backward compatibility:
- `test_fuse_reduction_epilogue_basic`
- `test_fuse_reduction_epilogue_fp32`
- `test_fuse_reduction_epilogue_numerical_correctness`
- `test_fuse_reduction_epilogue_multiple_epilogue`
- `test_matmul_bias_relu`
- `test_matmul_bias_relu_correctness_unified`
- `test_matmul_clipping`
- `test_matmul_clipping_correctness_unified`
- Tests for the other commutative variants

Total: all 15 tests pass.