llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-mlir-sparse Author: Peiming Liu (PeimingLiu) <details> <summary>Changes</summary> Stacked PRs: * __->__#<!-- -->105567 * #<!-- -->105566 * #<!-- -->105565 --- --- --- ### [mlir][sparse] unify block arguments order between iterate/coiterate operations. --- Full diff: https://github.com/llvm/llvm-project/pull/105567.diff 3 Files Affected: - (modified) mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td (+3-4) - (modified) mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp (+17-14) - (modified) mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp (+11-25) ``````````diff diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index 20512f972e67cd..96a61419a541f7 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -1644,7 +1644,7 @@ def IterateOp : SparseTensor_Op<"iterate", return getIterSpace().getType().getSpaceDim(); } BlockArgument getIterator() { - return getRegion().getArguments().front(); + return getRegion().getArguments().back(); } std::optional<BlockArgument> getLvlCrd(Level lvl) { if (getCrdUsedLvls()[lvl]) { @@ -1654,9 +1654,8 @@ def IterateOp : SparseTensor_Op<"iterate", return std::nullopt; } Block::BlockArgListType getCrds() { - // The first block argument is iterator, the remaining arguments are - // referenced coordinates. - return getRegion().getArguments().slice(1, getCrdUsedLvls().count()); + // User-provided iteration arguments -> coords -> iterator. + return getRegion().getArguments().slice(getNumRegionIterArgs(), getCrdUsedLvls().count()); } unsigned getNumRegionIterArgs() { return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count(); diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 16856b958d4f13..b21bc1a93036c4 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -2228,9 +2228,10 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, parser.getNameLoc(), "mismatch in number of sparse iterators and sparse spaces"); - if (failed(parseUsedCoordList(parser, state, blockArgs))) + SmallVector<OpAsmParser::Argument> coords; + if (failed(parseUsedCoordList(parser, state, coords))) return failure(); - size_t numCrds = blockArgs.size(); + size_t numCrds = coords.size(); // Parse "iter_args(%arg = %init, ...)" bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args")); @@ -2238,6 +2239,8 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, if (parser.parseAssignmentList(blockArgs, initArgs)) return failure(); + blockArgs.append(coords); + SmallVector<Type> iterSpaceTps; // parse ": sparse_tensor.iter_space -> ret" if (parser.parseColon() || parser.parseTypeList(iterSpaceTps)) @@ -2267,7 +2270,7 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, if (hasIterArgs) { // Strip off leading args that used for coordinates. - MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds); + MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds); if (args.size() != initArgs.size() || args.size() != state.types.size()) { return parser.emitError( parser.getNameLoc(), @@ -2448,18 +2451,18 @@ void IterateOp::build(OpBuilder &builder, OperationState &odsState, odsState.addTypes(initArgs.getTypes()); Block *bodyBlock = builder.createBlock(bodyRegion); - // First argument, sparse iterator - bodyBlock->addArgument( - llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(), - odsState.location); + // Starts with a list of user-provided loop arguments. + for (Value v : initArgs) + bodyBlock->addArgument(v.getType(), v.getLoc()); - // Followed by a list of used coordinates. + // Follows by a list of used coordinates. for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++) bodyBlock->addArgument(builder.getIndexType(), odsState.location); - // Followed by a list of user-provided loop arguments. - for (Value v : initArgs) - bodyBlock->addArgument(v.getType(), v.getLoc()); + // Ends with sparse iterator + bodyBlock->addArgument( + llvm::cast<IterSpaceType>(iterSpace.getType()).getIteratorType(), + odsState.location); } ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) { @@ -2473,9 +2476,9 @@ ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) { return parser.emitError(parser.getNameLoc(), "expected only one iterator/iteration space"); - iters.append(iterArgs); + iterArgs.append(iters); Region *body = result.addRegion(); - if (parser.parseRegion(*body, iters)) + if (parser.parseRegion(*body, iterArgs)) return failure(); IterateOp::ensureTerminator(*body, parser.getBuilder(), result.location); @@ -2580,7 +2583,7 @@ MutableArrayRef<OpOperand> IterateOp::getInitsMutable() { } Block::BlockArgListType IterateOp::getRegionIterArgs() { - return getRegion().getArguments().take_back(getNumRegionIterArgs()); + return getRegion().getArguments().take_front(getNumRegionIterArgs()); } std::optional<MutableArrayRef<OpOperand>> IterateOp::getYieldedValuesMutable() { diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp index f7fcabb0220b50..71a229bea990c0 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp @@ -111,7 +111,7 @@ genCoIterateBranchNest(PatternRewriter &rewriter, Location loc, CoIterateOp op, static ValueRange genLoopWithIterator( PatternRewriter &rewriter, Location loc, SparseIterator *it, - ValueRange reduc, bool iterFirst, + ValueRange reduc, function_ref<SmallVector<Value>(PatternRewriter &rewriter, Location loc, Region &loopBody, SparseIterator *it, ValueRange reduc)> @@ -138,15 +138,9 @@ static ValueRange genLoopWithIterator( } return forOp.getResults(); } - SmallVector<Value> ivs; - // TODO: always put iterator SSA values at the end of argument list to be - // consistent with coiterate operation. - if (!iterFirst) - llvm::append_range(ivs, it->getCursor()); - // Appends the user-provided values. - llvm::append_range(ivs, reduc); - if (iterFirst) - llvm::append_range(ivs, it->getCursor()); + + SmallVector<Value> ivs(reduc); + llvm::append_range(ivs, it->getCursor()); TypeRange types = ValueRange(ivs).getTypes(); auto whileOp = rewriter.create<scf::WhileOp>(loc, types, ivs); @@ -164,12 +158,8 @@ static ValueRange genLoopWithIterator( Region &dstRegion = whileOp.getAfter(); Block *after = rewriter.createBlock(&dstRegion, {}, types, l); ValueRange aArgs = whileOp.getAfterArguments(); - if (iterFirst) { - aArgs = it->linkNewScope(aArgs); - } else { - aArgs = aArgs.take_front(reduc.size()); - it->linkNewScope(aArgs.drop_front(reduc.size())); - } + it->linkNewScope(aArgs.drop_front(reduc.size())); + aArgs = aArgs.take_front(reduc.size()); rewriter.setInsertionPointToStart(after); SmallVector<Value> ret = bodyBuilder(rewriter, loc, dstRegion, it, aArgs); @@ -177,12 +167,8 @@ static ValueRange genLoopWithIterator( // Forward loops SmallVector<Value> yields; - ValueRange nx = it->forward(rewriter, loc); - if (iterFirst) - llvm::append_range(yields, nx); llvm::append_range(yields, ret); - if (!iterFirst) - llvm::append_range(yields, nx); + llvm::append_range(yields, it->forward(rewriter, loc)); rewriter.create<scf::YieldOp>(loc, yields); } return whileOp.getResults().drop_front(it->getCursor().size()); @@ -258,13 +244,13 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern<IterateOp> { Block *block = op.getBody(); ValueRange ret = genLoopWithIterator( - rewriter, loc, it.get(), ivs, /*iterFirst=*/true, + rewriter, loc, it.get(), ivs, [block](PatternRewriter &rewriter, Location loc, Region &loopBody, SparseIterator *it, ValueRange reduc) -> SmallVector<Value> { - SmallVector<Value> blockArgs(it->getCursor()); + SmallVector<Value> blockArgs(reduc); // TODO: Also appends coordinates if used. // blockArgs.push_back(it->deref(rewriter, loc)); - llvm::append_range(blockArgs, reduc); + llvm::append_range(blockArgs, it->getCursor()); Block *dstBlock = &loopBody.getBlocks().front(); rewriter.inlineBlockBefore(block, dstBlock, dstBlock->end(), @@ -404,7 +390,7 @@ class SparseCoIterateOpConverter Block *block = &r.getBlocks().front(); ValueRange curResult = genLoopWithIterator( - rewriter, loc, validIters.front(), userReduc, /*iterFirst=*/false, + rewriter, loc, validIters.front(), userReduc, /*bodyBuilder=*/ [block](PatternRewriter &rewriter, Location loc, Region &dstRegion, SparseIterator *it, `````````` </details> https://github.com/llvm/llvm-project/pull/105567 _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits