kimm240 commented on code in PR #18418:
URL: https://github.com/apache/tvm/pull/18418#discussion_r2540572389
##########
src/tir/schedule/primitive/compute_inline.cc:
##########
@@ -984,6 +984,472 @@ void ReverseComputeInline(ScheduleState self, const StmtSRef& consumer_block_sre
ReverseComputeInlineImpl(self, consumer_block_sref);
}
+/*!
+ * \brief Helper to fuse epilogue block into reduction block
+ * Analyzes epilogue pattern and transforms reduction init/update
+ */
+class ReductionEpilogueFuser : public BaseInliner {
+ public:
+ explicit ReductionEpilogueFuser(const Buffer& reduction_buffer, const BlockNode* reduction_block,
+ const BlockRealize& epilogue_block_realize,
+ const StmtSRef& scope_root_sref, const IRModule& mod)
+ : BaseInliner(reduction_buffer, epilogue_block_realize->block, scope_root_sref),
+ reduction_block_(reduction_block),
+ epilogue_block_(epilogue_block_realize->block.get()),
+ mod_(mod) {}
+
+ bool BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize);
+
+ // Step 2: Create single fused reduction block
+ Block CreateFusedReductionBlock(const BlockNode* reduction_block,
+ const BlockRealizeNode* reduction_realize);
+
+ private:
+ bool AnalyzeEpiloguePattern(const PrimExpr& value);
+ bool IsReductionBlock(const BlockNode* block);
+ void ExtractEpilogueInfo();
+ // Helper function to extract BufferLoad nodes from BufferStore
+ static std::vector<const BufferLoadNode*> ExtractBufferLoad(const Buffer& buffer,
+ const BufferStoreNode* from) {
+ struct Extractor : public ExprVisitor {
+ void VisitExpr_(const BufferLoadNode* load) final {
+ if (load->buffer.get() == buffer) {
+ result.push_back(load);
+ }
+ ExprVisitor::VisitExpr_(load);
+ }
+ const BufferNode* buffer;
+ std::vector<const BufferLoadNode*> result;
+ } extractor;
+ extractor.buffer = buffer.get();
+ for (const PrimExpr& expr : from->indices) {
+ extractor(expr);
+ }
+ extractor(from->value);
+ return std::move(extractor.result);
+ }
+
+ const BlockNode* reduction_block_;
+ const BlockNode* epilogue_block_;
+ const IRModule& mod_;
+ PrimExpr epilogue_addend_{nullptr}; // C[vi, vj] in D = temp + C
+ Buffer epilogue_output_buffer_{nullptr}; // Output buffer D
+ ffi::Array<PrimExpr> epilogue_output_indices_{nullptr}; // Indices of D[vi, vj]
+ BufferRegion epilogue_output_region_{nullptr}; // Write region of D
+ Buffer epilogue_addend_buffer_{nullptr}; // Addend buffer C
+ BufferRegion epilogue_addend_region_{nullptr}; // Read region of C
+};
+
+bool ReductionEpilogueFuser::BodyPatternAllowFusion(const BlockRealize& epilogue_block_realize) {
+ // 1. Validate predicate
+ if (!is_one(epilogue_block_realize->predicate)) {
+ // Failure: Predicate in epilogue block is not supported
+ return false;
+ }
+
+ // 2. Check if epilogue body is BufferStore
+ if (inlined_store_ == nullptr) {
+ // Failure: epilogue block body is not BufferStore
+ return false;
+ }
+
+ // 3. Check if epilogue reads from reduction buffer
+ std::vector<const BufferLoadNode*> loads = ExtractBufferLoad(inlined_buffer_, inlined_store_);
+ if (loads.size() == 0) {
+ // Failure: no BufferLoad from the reduction buffer
+ return false;
+ }
+
+ // 4. Analyze epilogue pattern: D[i,j] = temp[i,j] + C[i,j]
+ if (!AnalyzeEpiloguePattern(inlined_store_->value)) {
+ // Failure: epilogue is not a simple addition pattern
+ return false;
+ }
+
+ // 5. Check if producer is a reduction block
+ if (!IsReductionBlock(reduction_block_)) {
+ // Failure: producer is not a reduction block
+ return false;
+ }
+
+ // 6. Extract epilogue information (output buffer, indices, regions, etc.)
+ ExtractEpilogueInfo();
+
+ return true;
+}
+
+bool ReductionEpilogueFuser::AnalyzeEpiloguePattern(const PrimExpr& value) {
+ // Pattern: temp[i,j] + C[i,j] or C[i,j] + temp[i,j]
+ if (const auto* add = value.as<AddNode>()) {
+ const auto* load_a = add->a.as<BufferLoadNode>();
+ const auto* load_b = add->b.as<BufferLoadNode>();
+
+ bool a_is_target = load_a && load_a->buffer.same_as(inlined_buffer_);
+ bool b_is_target = load_b && load_b->buffer.same_as(inlined_buffer_);
+
+ // Ensure exactly one operand is from the reduction buffer
+ if (a_is_target != b_is_target) {
+ epilogue_addend_ = a_is_target ? add->b : add->a;
+ return true;
+ }
+ }
+
+ return false;
+}
+
+bool ReductionEpilogueFuser::IsReductionBlock(const BlockNode* block) {
+ // Check if block has reduction iter vars
+ for (const IterVar& iter : block->iter_vars) {
+ if (iter->iter_type == kCommReduce) {
+ return true;
+ }
+ }
+ return false;
+}
+
+void ReductionEpilogueFuser::ExtractEpilogueInfo() {
+ // Extract epilogue output buffer and indices
+ epilogue_output_buffer_ = inlined_store_->buffer;
+ epilogue_output_indices_ = inlined_store_->indices;
+
+ // Extract epilogue output region from epilogue block writes
+ for (const BufferRegion& write : epilogue_block_->writes) {
+ if (write->buffer.same_as(epilogue_output_buffer_)) {
+ epilogue_output_region_ = write;
+ break;
+ }
+ }
+
+ // Extract epilogue addend buffer and region from epilogue_addend_
+ if (const auto* load = epilogue_addend_.as<BufferLoadNode>()) {
+ epilogue_addend_buffer_ = load->buffer;
+ // Find the read region from epilogue block reads
+ for (const BufferRegion& read : epilogue_block_->reads) {
+ if (read->buffer.same_as(epilogue_addend_buffer_)) {
+ epilogue_addend_region_ = read;
+ break;
+ }
+ }
+ }
+}
+
+Block ReductionEpilogueFuser::CreateFusedReductionBlock(const BlockNode* reduction_block,
+ const BlockRealizeNode* reduction_realize) {
+ ObjectPtr<BlockNode> new_block = ffi::make_object<BlockNode>(*reduction_block);
+
+ // 1. Keep all iter vars (data parallel + reduction)
+ new_block->iter_vars = reduction_block->iter_vars;
Review Comment:
@wrongtest-intellif I've addressed your feedback and pushed the changes.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]