llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir-sparse Author: Peiming Liu (PeimingLiu) <details> <summary>Changes</summary> Stacked PRs: * #<!-- -->105567 * __->__#<!-- -->105566 * #<!-- -->105565 --- --- --- ### [mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation. --- Full diff: https://github.com/llvm/llvm-project/pull/105566.diff 1 Files Affected: - (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp (+36-82) ``````````diff diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp index d6c0da4a9e457..f7fcabb0220b5 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp @@ -244,88 +244,41 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> { std::unique_ptr<SparseIterator> it = iterSpace.extractIterator(rewriter, loc); - if (it->iteratableByFor()) { - auto [lo, hi] = it->genForCond(rewriter, loc); - Value step = constantIndex(rewriter, loc, 1); - SmallVector<Value> ivs; - for (ValueRange inits : adaptor.getInitArgs()) - llvm::append_range(ivs, inits); - scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, lo, hi, step, ivs); - - Block *loopBody = op.getBody(); - OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes()); - if (failed(typeConverter->convertSignatureArgs( - loopBody->getArgumentTypes(), bodyTypeMapping))) - return failure(); - rewriter.applySignatureConversion(loopBody, bodyTypeMapping); - - rewriter.eraseBlock(forOp.getBody()); - Region &dstRegion = forOp.getRegion(); - rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end()); - - auto yieldOp = - llvm::cast<sparse_tensor::YieldOp>(forOp.getBody()->getTerminator()); - - rewriter.setInsertionPointToEnd(forOp.getBody()); - // replace sparse_tensor.yield with scf.yield. 
- rewriter.create<scf::YieldOp>(loc, yieldOp.getResults()); - rewriter.eraseOp(yieldOp); - - const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); - rewriter.replaceOp(op, forOp.getResults(), resultMapping); - } else { - SmallVector<Value> ivs; - // TODO: put iterator at the end of argument list to be consistent with - // coiterate operation. - llvm::append_range(ivs, it->getCursor()); - for (ValueRange inits : adaptor.getInitArgs()) - llvm::append_range(ivs, inits); - - assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; })); - - TypeRange types = ValueRange(ivs).getTypes(); - auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs); - SmallVector<Location> l(types.size(), op.getIterator().getLoc()); - - // Generates loop conditions. - Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l); - rewriter.setInsertionPointToStart(before); - ValueRange bArgs = before->getArguments(); - auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs); - assert(remArgs.size() == adaptor.getInitArgs().size()); - rewriter.create<scf::ConditionOp>(loc, whileCond, before->getArguments()); - - // Generates loop body. - Block *loopBody = op.getBody(); - OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes()); - if (failed(typeConverter->convertSignatureArgs( - loopBody->getArgumentTypes(), bodyTypeMapping))) - return failure(); - rewriter.applySignatureConversion(loopBody, bodyTypeMapping); - - Region &dstRegion = whileOp.getAfter(); - // TODO: handle uses of coordinate! - rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end()); - ValueRange aArgs = whileOp.getAfterArguments(); - auto yieldOp = llvm::cast<sparse_tensor::YieldOp>( - whileOp.getAfterBody()->getTerminator()); - - rewriter.setInsertionPointToEnd(whileOp.getAfterBody()); + SmallVector<Value> ivs; + for (ValueRange inits : adaptor.getInitArgs()) + llvm::append_range(ivs, inits); + + // Type conversion on iterate op block. 
+ OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes()); + if (failed(typeConverter->convertSignatureArgs( + op.getBody()->getArgumentTypes(), blockTypeMapping))) + return rewriter.notifyMatchFailure( + op, "failed to convert iterate region argurment types"); + rewriter.applySignatureConversion(op.getBody(), blockTypeMapping); + + Block *block = op.getBody(); + ValueRange ret = genLoopWithIterator( + rewriter, loc, it.get(), ivs, /*iterFirst=*/true, + [block](PatternRewriter &rewriter, Location loc, Region &loopBody, + SparseIterator *it, ValueRange reduc) -> SmallVector<Value> { + SmallVector<Value> blockArgs(it->getCursor()); + // TODO: Also append coordinates if used. + // blockArgs.push_back(it->deref(rewriter, loc)); + llvm::append_range(blockArgs, reduc); + + Block *dstBlock = &loopBody.getBlocks().front(); + rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(), + blockArgs); + auto yield = llvm::cast<sparse_tensor::YieldOp>(dstBlock->back()); + // We cannot use ValueRange as the operation holding the values will + // be destroyed. + SmallVector<Value> result(yield.getResults()); + rewriter.eraseOp(yield); + return result; + }); - aArgs = it->linkNewScope(aArgs); - ValueRange nx = it->forward(rewriter, loc); - SmallVector<Value> yields; - llvm::append_range(yields, nx); - llvm::append_range(yields, yieldOp.getResults()); - - // replace sparse_tensor.yield with scf.yield.
- rewriter.eraseOp(yieldOp); - rewriter.create<scf::YieldOp>(loc, yields); - const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); - rewriter.replaceOp( - op, whileOp.getResults().drop_front(it->getCursor().size()), - resultMapping); - } + const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); + rewriter.replaceOp(op, ret, resultMapping); return success(); } }; @@ -366,9 +319,10 @@ class SparseCoIterateOpConverter Block *block = ®ion.getBlocks().front(); OneToNTypeMapping blockTypeMapping(block->getArgumentTypes()); if (failed(typeConverter->convertSignatureArgs(block->getArgumentTypes(), - blockTypeMapping))) + blockTypeMapping))) { return rewriter.notifyMatchFailure( op, "failed to convert coiterate region argurment types"); + } rewriter.applySignatureConversion(block, blockTypeMapping); } `````````` </details> https://github.com/llvm/llvm-project/pull/105566 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits