[llvm-branch-commits] [mlir] [mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation. (PR #105566)
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/105566 [mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation. >From 1a32495b27dfd003408dd5b4f12f3db7f0b73b5a Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Thu, 15 Aug 2024 18:10:25 + Subject: [PATCH] [mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation. --- .../Transforms/SparseIterationToScf.cpp | 118 ++ 1 file changed, 36 insertions(+), 82 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp index d6c0da4a9e4573..f7fcabb0220b50 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp @@ -244,88 +244,41 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern { std::unique_ptr it = iterSpace.extractIterator(rewriter, loc); -if (it->iteratableByFor()) { - auto [lo, hi] = it->genForCond(rewriter, loc); - Value step = constantIndex(rewriter, loc, 1); - SmallVector ivs; - for (ValueRange inits : adaptor.getInitArgs()) -llvm::append_range(ivs, inits); - scf::ForOp forOp = rewriter.create(loc, lo, hi, step, ivs); - - Block *loopBody = op.getBody(); - OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes()); - if (failed(typeConverter->convertSignatureArgs( - loopBody->getArgumentTypes(), bodyTypeMapping))) -return failure(); - rewriter.applySignatureConversion(loopBody, bodyTypeMapping); - - rewriter.eraseBlock(forOp.getBody()); - Region &dstRegion = forOp.getRegion(); - rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end()); - - auto yieldOp = - llvm::cast(forOp.getBody()->getTerminator()); - - rewriter.setInsertionPointToEnd(forOp.getBody()); - // replace sparse_tensor.yield with scf.yield. - rewriter.create(loc, yieldOp.getResults()); - rewriter.eraseOp(yieldOp); - - const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); - rewriter.replaceOp(op, forOp.getResults(), resultMapping); -} else { - SmallVector ivs; - // TODO: put iterator at the end of argument list to be consistent with - // coiterate operation. - llvm::append_range(ivs, it->getCursor()); - for (ValueRange inits : adaptor.getInitArgs()) -llvm::append_range(ivs, inits); - - assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; })); - - TypeRange types = ValueRange(ivs).getTypes(); - auto whileOp = rewriter.create(loc, types, ivs); - SmallVector l(types.size(), op.getIterator().getLoc()); - - // Generates loop conditions. - Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l); - rewriter.setInsertionPointToStart(before); - ValueRange bArgs = before->getArguments(); - auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs); - assert(remArgs.size() == adaptor.getInitArgs().size()); - rewriter.create(loc, whileCond, before->getArguments()); - - // Generates loop body. - Block *loopBody = op.getBody(); - OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes()); - if (failed(typeConverter->convertSignatureArgs( - loopBody->getArgumentTypes(), bodyTypeMapping))) -return failure(); - rewriter.applySignatureConversion(loopBody, bodyTypeMapping); - - Region &dstRegion = whileOp.getAfter(); - // TODO: handle uses of coordinate! - rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end()); - ValueRange aArgs = whileOp.getAfterArguments(); - auto yieldOp = llvm::cast( - whileOp.getAfterBody()->getTerminator()); - - rewriter.setInsertionPointToEnd(whileOp.getAfterBody()); +SmallVector ivs; +for (ValueRange inits : adaptor.getInitArgs()) + llvm::append_range(ivs, inits); + +// Type conversion on iterate op block. +OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes()); +if (failed(typeConverter->convertSignatureArgs( +op.getBody()->getArgumentTypes(), blockTypeMapping))) + return rewriter.notifyMatchFailure( + op, "failed to convert iterate region argurment types"); +rewriter.applySignatureConversion(op.getBody(), blockTypeMapping); + +Block *block = op.getBody(); +ValueRange ret = genLoopWithIterator( +rewriter, loc, it.get(), ivs, /*iterFirst=*/true, +[block](PatternRewriter &rewriter, Location loc, Region &loopBody, +SparseIterator *it, ValueRange reduc) -> SmallVector { + SmallVector blockArgs(it->getCursor()); + // TODO: Also appends coordinates if used. + // blockArgs.push_back(it->deref(rewriter, loc)); + llvm::a
[llvm-branch-commits] [mlir] [mlir][sparse] unify block arguments order between iterate/coiterate operations. (PR #105567)
https://github.com/PeimingLiu created https://github.com/llvm/llvm-project/pull/105567 [mlir][sparse] unify block arguments order between iterate/coiterate operations. >From 6fd099fb7039f8fda37d50f1c44cd7afd62cafb7 Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Thu, 15 Aug 2024 21:10:37 + Subject: [PATCH] [mlir][sparse] unify block arguments order between iterate/coiterate operations. --- .../SparseTensor/IR/SparseTensorOps.td| 7 ++-- .../SparseTensor/IR/SparseTensorDialect.cpp | 31 .../Transforms/SparseIterationToScf.cpp | 36 ++- 3 files changed, 31 insertions(+), 43 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index 20512f972e67cd..96a61419a541f7 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -1644,7 +1644,7 @@ def IterateOp : SparseTensor_Op<"iterate", return getIterSpace().getType().getSpaceDim(); } BlockArgument getIterator() { - return getRegion().getArguments().front(); + return getRegion().getArguments().back(); } std::optional getLvlCrd(Level lvl) { if (getCrdUsedLvls()[lvl]) { @@ -1654,9 +1654,8 @@ def IterateOp : SparseTensor_Op<"iterate", return std::nullopt; } Block::BlockArgListType getCrds() { - // The first block argument is iterator, the remaining arguments are - // referenced coordinates. - return getRegion().getArguments().slice(1, getCrdUsedLvls().count()); + // User-provided iteration arguments -> coords -> iterator. + return getRegion().getArguments().slice(getNumRegionIterArgs(), getCrdUsedLvls().count()); } unsigned getNumRegionIterArgs() { return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count(); diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 16856b958d4f13..b21bc1a93036c4 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -2228,9 +2228,10 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, parser.getNameLoc(), "mismatch in number of sparse iterators and sparse spaces"); - if (failed(parseUsedCoordList(parser, state, blockArgs))) + SmallVector coords; + if (failed(parseUsedCoordList(parser, state, coords))) return failure(); - size_t numCrds = blockArgs.size(); + size_t numCrds = coords.size(); // Parse "iter_args(%arg = %init, ...)" bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args")); @@ -2238,6 +2239,8 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, if (parser.parseAssignmentList(blockArgs, initArgs)) return failure(); + blockArgs.append(coords); + SmallVector iterSpaceTps; // parse ": sparse_tensor.iter_space -> ret" if (parser.parseColon() || parser.parseTypeList(iterSpaceTps)) @@ -2267,7 +2270,7 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, if (hasIterArgs) { // Strip off leading args that used for coordinates. -MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds); +MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds); if (args.size() != initArgs.size() || args.size() != state.types.size()) { return parser.emitError( parser.getNameLoc(), @@ -2448,18 +2451,18 @@ void IterateOp::build(OpBuilder &builder, OperationState &odsState, odsState.addTypes(initArgs.getTypes()); Block *bodyBlock = builder.createBlock(bodyRegion); - // First argument, sparse iterator - bodyBlock->addArgument( - llvm::cast(iterSpace.getType()).getIteratorType(), - odsState.location); + // Starts with a list of user-provided loop arguments. + for (Value v : initArgs) +bodyBlock->addArgument(v.getType(), v.getLoc()); - // Followed by a list of used coordinates. + // Follows by a list of used coordinates. for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++) bodyBlock->addArgument(builder.getIndexType(), odsState.location); - // Followed by a list of user-provided loop arguments. - for (Value v : initArgs) -bodyBlock->addArgument(v.getType(), v.getLoc()); + // Ends with sparse iterator + bodyBlock->addArgument( + llvm::cast(iterSpace.getType()).getIteratorType(), + odsState.location); } ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) { @@ -2473,9 +2476,9 @@ ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) { return parser.emitError(parser.getNameLoc(), "expected only one iterator/iteration space"); - iters.append(iterArgs); + iterArgs.append(iters); Region *body = result.addRegion();
[llvm-branch-commits] [mlir] [mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation. (PR #105566)
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/105566 >From 937bcd814688e7c6f88ef27b7586254006e0d050 Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Thu, 15 Aug 2024 18:10:25 + Subject: [PATCH] [mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation. stack-info: PR: https://github.com/llvm/llvm-project/pull/105566, branch: users/PeimingLiu/stack/2 --- .../Transforms/SparseIterationToScf.cpp | 118 ++ 1 file changed, 36 insertions(+), 82 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp index d6c0da4a9e4573..f7fcabb0220b50 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp @@ -244,88 +244,41 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern { std::unique_ptr it = iterSpace.extractIterator(rewriter, loc); -if (it->iteratableByFor()) { - auto [lo, hi] = it->genForCond(rewriter, loc); - Value step = constantIndex(rewriter, loc, 1); - SmallVector ivs; - for (ValueRange inits : adaptor.getInitArgs()) -llvm::append_range(ivs, inits); - scf::ForOp forOp = rewriter.create(loc, lo, hi, step, ivs); - - Block *loopBody = op.getBody(); - OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes()); - if (failed(typeConverter->convertSignatureArgs( - loopBody->getArgumentTypes(), bodyTypeMapping))) -return failure(); - rewriter.applySignatureConversion(loopBody, bodyTypeMapping); - - rewriter.eraseBlock(forOp.getBody()); - Region &dstRegion = forOp.getRegion(); - rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end()); - - auto yieldOp = - llvm::cast(forOp.getBody()->getTerminator()); - - rewriter.setInsertionPointToEnd(forOp.getBody()); - // replace sparse_tensor.yield with scf.yield. - rewriter.create(loc, yieldOp.getResults()); - rewriter.eraseOp(yieldOp); - - const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); - rewriter.replaceOp(op, forOp.getResults(), resultMapping); -} else { - SmallVector ivs; - // TODO: put iterator at the end of argument list to be consistent with - // coiterate operation. - llvm::append_range(ivs, it->getCursor()); - for (ValueRange inits : adaptor.getInitArgs()) -llvm::append_range(ivs, inits); - - assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; })); - - TypeRange types = ValueRange(ivs).getTypes(); - auto whileOp = rewriter.create(loc, types, ivs); - SmallVector l(types.size(), op.getIterator().getLoc()); - - // Generates loop conditions. - Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l); - rewriter.setInsertionPointToStart(before); - ValueRange bArgs = before->getArguments(); - auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs); - assert(remArgs.size() == adaptor.getInitArgs().size()); - rewriter.create(loc, whileCond, before->getArguments()); - - // Generates loop body. - Block *loopBody = op.getBody(); - OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes()); - if (failed(typeConverter->convertSignatureArgs( - loopBody->getArgumentTypes(), bodyTypeMapping))) -return failure(); - rewriter.applySignatureConversion(loopBody, bodyTypeMapping); - - Region &dstRegion = whileOp.getAfter(); - // TODO: handle uses of coordinate! - rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end()); - ValueRange aArgs = whileOp.getAfterArguments(); - auto yieldOp = llvm::cast( - whileOp.getAfterBody()->getTerminator()); - - rewriter.setInsertionPointToEnd(whileOp.getAfterBody()); +SmallVector ivs; +for (ValueRange inits : adaptor.getInitArgs()) + llvm::append_range(ivs, inits); + +// Type conversion on iterate op block. +OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes()); +if (failed(typeConverter->convertSignatureArgs( +op.getBody()->getArgumentTypes(), blockTypeMapping))) + return rewriter.notifyMatchFailure( + op, "failed to convert iterate region argurment types"); +rewriter.applySignatureConversion(op.getBody(), blockTypeMapping); + +Block *block = op.getBody(); +ValueRange ret = genLoopWithIterator( +rewriter, loc, it.get(), ivs, /*iterFirst=*/true, +[block](PatternRewriter &rewriter, Location loc, Region &loopBody, +SparseIterator *it, ValueRange reduc) -> SmallVector { + SmallVector blockArgs(it->getCursor()); + // TODO: Also appends coordinates if used. + // blockArgs.push_back(it->deref(rewriter, loc)); +
[llvm-branch-commits] [mlir] [mlir][sparse] unify block arguments order between iterate/coiterate operations. (PR #105567)
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/105567 >From 3f83d7a1eadc1101fb96707ecd348925e5aaed70 Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Thu, 15 Aug 2024 21:10:37 + Subject: [PATCH] [mlir][sparse] unify block arguments order between iterate/coiterate operations. stack-info: PR: https://github.com/llvm/llvm-project/pull/105567, branch: users/PeimingLiu/stack/3 --- .../SparseTensor/IR/SparseTensorOps.td| 7 ++-- .../SparseTensor/IR/SparseTensorDialect.cpp | 31 .../Transforms/SparseIterationToScf.cpp | 36 ++- 3 files changed, 31 insertions(+), 43 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index 20512f972e67cd..96a61419a541f7 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -1644,7 +1644,7 @@ def IterateOp : SparseTensor_Op<"iterate", return getIterSpace().getType().getSpaceDim(); } BlockArgument getIterator() { - return getRegion().getArguments().front(); + return getRegion().getArguments().back(); } std::optional getLvlCrd(Level lvl) { if (getCrdUsedLvls()[lvl]) { @@ -1654,9 +1654,8 @@ def IterateOp : SparseTensor_Op<"iterate", return std::nullopt; } Block::BlockArgListType getCrds() { - // The first block argument is iterator, the remaining arguments are - // referenced coordinates. - return getRegion().getArguments().slice(1, getCrdUsedLvls().count()); + // User-provided iteration arguments -> coords -> iterator. + return getRegion().getArguments().slice(getNumRegionIterArgs(), getCrdUsedLvls().count()); } unsigned getNumRegionIterArgs() { return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count(); diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 16856b958d4f13..b21bc1a93036c4 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -2228,9 +2228,10 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, parser.getNameLoc(), "mismatch in number of sparse iterators and sparse spaces"); - if (failed(parseUsedCoordList(parser, state, blockArgs))) + SmallVector coords; + if (failed(parseUsedCoordList(parser, state, coords))) return failure(); - size_t numCrds = blockArgs.size(); + size_t numCrds = coords.size(); // Parse "iter_args(%arg = %init, ...)" bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args")); @@ -2238,6 +2239,8 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, if (parser.parseAssignmentList(blockArgs, initArgs)) return failure(); + blockArgs.append(coords); + SmallVector iterSpaceTps; // parse ": sparse_tensor.iter_space -> ret" if (parser.parseColon() || parser.parseTypeList(iterSpaceTps)) @@ -2267,7 +2270,7 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, if (hasIterArgs) { // Strip off leading args that used for coordinates. -MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds); +MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds); if (args.size() != initArgs.size() || args.size() != state.types.size()) { return parser.emitError( parser.getNameLoc(), @@ -2448,18 +2451,18 @@ void IterateOp::build(OpBuilder &builder, OperationState &odsState, odsState.addTypes(initArgs.getTypes()); Block *bodyBlock = builder.createBlock(bodyRegion); - // First argument, sparse iterator - bodyBlock->addArgument( - llvm::cast(iterSpace.getType()).getIteratorType(), - odsState.location); + // Starts with a list of user-provided loop arguments. + for (Value v : initArgs) +bodyBlock->addArgument(v.getType(), v.getLoc()); - // Followed by a list of used coordinates. + // Follows by a list of used coordinates. for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++) bodyBlock->addArgument(builder.getIndexType(), odsState.location); - // Followed by a list of user-provided loop arguments. - for (Value v : initArgs) -bodyBlock->addArgument(v.getType(), v.getLoc()); + // Ends with sparse iterator + bodyBlock->addArgument( + llvm::cast(iterSpace.getType()).getIteratorType(), + odsState.location); } ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) { @@ -2473,9 +2476,9 @@ ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) { return parser.emitError(parser.getNameLoc(), "expected only one iterator/iteration space"); - iters.append(iterArgs); + iterArgs.append(iters); Region *body = r
[llvm-branch-commits] [mlir] [mlir][sparse] unify block arguments order between iterate/coiterate operations. (PR #105567)
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/105567 >From 3f83d7a1eadc1101fb96707ecd348925e5aaed70 Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Thu, 15 Aug 2024 21:10:37 + Subject: [PATCH] [mlir][sparse] unify block arguments order between iterate/coiterate operations. stack-info: PR: https://github.com/llvm/llvm-project/pull/105567, branch: users/PeimingLiu/stack/3 --- .../SparseTensor/IR/SparseTensorOps.td| 7 ++-- .../SparseTensor/IR/SparseTensorDialect.cpp | 31 .../Transforms/SparseIterationToScf.cpp | 36 ++- 3 files changed, 31 insertions(+), 43 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index 20512f972e67cd..96a61419a541f7 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -1644,7 +1644,7 @@ def IterateOp : SparseTensor_Op<"iterate", return getIterSpace().getType().getSpaceDim(); } BlockArgument getIterator() { - return getRegion().getArguments().front(); + return getRegion().getArguments().back(); } std::optional getLvlCrd(Level lvl) { if (getCrdUsedLvls()[lvl]) { @@ -1654,9 +1654,8 @@ def IterateOp : SparseTensor_Op<"iterate", return std::nullopt; } Block::BlockArgListType getCrds() { - // The first block argument is iterator, the remaining arguments are - // referenced coordinates. - return getRegion().getArguments().slice(1, getCrdUsedLvls().count()); + // User-provided iteration arguments -> coords -> iterator. + return getRegion().getArguments().slice(getNumRegionIterArgs(), getCrdUsedLvls().count()); } unsigned getNumRegionIterArgs() { return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count(); diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 16856b958d4f13..b21bc1a93036c4 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -2228,9 +2228,10 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, parser.getNameLoc(), "mismatch in number of sparse iterators and sparse spaces"); - if (failed(parseUsedCoordList(parser, state, blockArgs))) + SmallVector coords; + if (failed(parseUsedCoordList(parser, state, coords))) return failure(); - size_t numCrds = blockArgs.size(); + size_t numCrds = coords.size(); // Parse "iter_args(%arg = %init, ...)" bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args")); @@ -2238,6 +2239,8 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, if (parser.parseAssignmentList(blockArgs, initArgs)) return failure(); + blockArgs.append(coords); + SmallVector iterSpaceTps; // parse ": sparse_tensor.iter_space -> ret" if (parser.parseColon() || parser.parseTypeList(iterSpaceTps)) @@ -2267,7 +2270,7 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, if (hasIterArgs) { // Strip off leading args that used for coordinates. -MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds); +MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds); if (args.size() != initArgs.size() || args.size() != state.types.size()) { return parser.emitError( parser.getNameLoc(), @@ -2448,18 +2451,18 @@ void IterateOp::build(OpBuilder &builder, OperationState &odsState, odsState.addTypes(initArgs.getTypes()); Block *bodyBlock = builder.createBlock(bodyRegion); - // First argument, sparse iterator - bodyBlock->addArgument( - llvm::cast(iterSpace.getType()).getIteratorType(), - odsState.location); + // Starts with a list of user-provided loop arguments. + for (Value v : initArgs) +bodyBlock->addArgument(v.getType(), v.getLoc()); - // Followed by a list of used coordinates. + // Follows by a list of used coordinates. for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++) bodyBlock->addArgument(builder.getIndexType(), odsState.location); - // Followed by a list of user-provided loop arguments. - for (Value v : initArgs) -bodyBlock->addArgument(v.getType(), v.getLoc()); + // Ends with sparse iterator + bodyBlock->addArgument( + llvm::cast(iterSpace.getType()).getIteratorType(), + odsState.location); } ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) { @@ -2473,9 +2476,9 @@ ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) { return parser.emitError(parser.getNameLoc(), "expected only one iterator/iteration space"); - iters.append(iterArgs); + iterArgs.append(iters); Region *body = r
[llvm-branch-commits] [mlir] [mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation. (PR #105566)
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/105566 >From 937bcd814688e7c6f88ef27b7586254006e0d050 Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Thu, 15 Aug 2024 18:10:25 + Subject: [PATCH] [mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation. stack-info: PR: https://github.com/llvm/llvm-project/pull/105566, branch: users/PeimingLiu/stack/2 --- .../Transforms/SparseIterationToScf.cpp | 118 ++ 1 file changed, 36 insertions(+), 82 deletions(-) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp index d6c0da4a9e4573..f7fcabb0220b50 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp @@ -244,88 +244,41 @@ class SparseIterateOpConverter : public OneToNOpConversionPattern { std::unique_ptr it = iterSpace.extractIterator(rewriter, loc); -if (it->iteratableByFor()) { - auto [lo, hi] = it->genForCond(rewriter, loc); - Value step = constantIndex(rewriter, loc, 1); - SmallVector ivs; - for (ValueRange inits : adaptor.getInitArgs()) -llvm::append_range(ivs, inits); - scf::ForOp forOp = rewriter.create(loc, lo, hi, step, ivs); - - Block *loopBody = op.getBody(); - OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes()); - if (failed(typeConverter->convertSignatureArgs( - loopBody->getArgumentTypes(), bodyTypeMapping))) -return failure(); - rewriter.applySignatureConversion(loopBody, bodyTypeMapping); - - rewriter.eraseBlock(forOp.getBody()); - Region &dstRegion = forOp.getRegion(); - rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end()); - - auto yieldOp = - llvm::cast(forOp.getBody()->getTerminator()); - - rewriter.setInsertionPointToEnd(forOp.getBody()); - // replace sparse_tensor.yield with scf.yield. - rewriter.create(loc, yieldOp.getResults()); - rewriter.eraseOp(yieldOp); - - const OneToNTypeMapping &resultMapping = adaptor.getResultMapping(); - rewriter.replaceOp(op, forOp.getResults(), resultMapping); -} else { - SmallVector ivs; - // TODO: put iterator at the end of argument list to be consistent with - // coiterate operation. - llvm::append_range(ivs, it->getCursor()); - for (ValueRange inits : adaptor.getInitArgs()) -llvm::append_range(ivs, inits); - - assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; })); - - TypeRange types = ValueRange(ivs).getTypes(); - auto whileOp = rewriter.create(loc, types, ivs); - SmallVector l(types.size(), op.getIterator().getLoc()); - - // Generates loop conditions. - Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l); - rewriter.setInsertionPointToStart(before); - ValueRange bArgs = before->getArguments(); - auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs); - assert(remArgs.size() == adaptor.getInitArgs().size()); - rewriter.create(loc, whileCond, before->getArguments()); - - // Generates loop body. - Block *loopBody = op.getBody(); - OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes()); - if (failed(typeConverter->convertSignatureArgs( - loopBody->getArgumentTypes(), bodyTypeMapping))) -return failure(); - rewriter.applySignatureConversion(loopBody, bodyTypeMapping); - - Region &dstRegion = whileOp.getAfter(); - // TODO: handle uses of coordinate! - rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end()); - ValueRange aArgs = whileOp.getAfterArguments(); - auto yieldOp = llvm::cast( - whileOp.getAfterBody()->getTerminator()); - - rewriter.setInsertionPointToEnd(whileOp.getAfterBody()); +SmallVector ivs; +for (ValueRange inits : adaptor.getInitArgs()) + llvm::append_range(ivs, inits); + +// Type conversion on iterate op block. +OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes()); +if (failed(typeConverter->convertSignatureArgs( +op.getBody()->getArgumentTypes(), blockTypeMapping))) + return rewriter.notifyMatchFailure( + op, "failed to convert iterate region argurment types"); +rewriter.applySignatureConversion(op.getBody(), blockTypeMapping); + +Block *block = op.getBody(); +ValueRange ret = genLoopWithIterator( +rewriter, loc, it.get(), ivs, /*iterFirst=*/true, +[block](PatternRewriter &rewriter, Location loc, Region &loopBody, +SparseIterator *it, ValueRange reduc) -> SmallVector { + SmallVector blockArgs(it->getCursor()); + // TODO: Also appends coordinates if used. + // blockArgs.push_back(it->deref(rewriter, loc)); +
[llvm-branch-commits] [mlir] [mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation. (PR #105566)
https://github.com/PeimingLiu edited https://github.com/llvm/llvm-project/pull/105566 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][sparse] unify block arguments order between iterate/coiterate operations. (PR #105567)
https://github.com/PeimingLiu edited https://github.com/llvm/llvm-project/pull/105567 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][sparse] unify block arguments order between iterate/coiterate operations. (PR #105567)
https://github.com/PeimingLiu updated https://github.com/llvm/llvm-project/pull/105567 >From 58bae5cff0b813347512a67a89e3abf6637ad0a9 Mon Sep 17 00:00:00 2001 From: Peiming Liu Date: Thu, 15 Aug 2024 21:10:37 + Subject: [PATCH] [mlir][sparse] unify block arguments order between iterate/coiterate operations. stack-info: PR: https://github.com/llvm/llvm-project/pull/105567, branch: users/PeimingLiu/stack/3 --- .../SparseTensor/IR/SparseTensorOps.td| 7 ++-- .../SparseTensor/IR/SparseTensorDialect.cpp | 31 .../Transforms/SparseIterationToScf.cpp | 36 ++- 3 files changed, 31 insertions(+), 43 deletions(-) diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index 20512f972e67cd..96a61419a541f7 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -1644,7 +1644,7 @@ def IterateOp : SparseTensor_Op<"iterate", return getIterSpace().getType().getSpaceDim(); } BlockArgument getIterator() { - return getRegion().getArguments().front(); + return getRegion().getArguments().back(); } std::optional getLvlCrd(Level lvl) { if (getCrdUsedLvls()[lvl]) { @@ -1654,9 +1654,8 @@ def IterateOp : SparseTensor_Op<"iterate", return std::nullopt; } Block::BlockArgListType getCrds() { - // The first block argument is iterator, the remaining arguments are - // referenced coordinates. - return getRegion().getArguments().slice(1, getCrdUsedLvls().count()); + // User-provided iteration arguments -> coords -> iterator. + return getRegion().getArguments().slice(getNumRegionIterArgs(), getCrdUsedLvls().count()); } unsigned getNumRegionIterArgs() { return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count(); diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 16856b958d4f13..b21bc1a93036c4 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -2228,9 +2228,10 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, parser.getNameLoc(), "mismatch in number of sparse iterators and sparse spaces"); - if (failed(parseUsedCoordList(parser, state, blockArgs))) + SmallVector coords; + if (failed(parseUsedCoordList(parser, state, coords))) return failure(); - size_t numCrds = blockArgs.size(); + size_t numCrds = coords.size(); // Parse "iter_args(%arg = %init, ...)" bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args")); @@ -2238,6 +2239,8 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, if (parser.parseAssignmentList(blockArgs, initArgs)) return failure(); + blockArgs.append(coords); + SmallVector iterSpaceTps; // parse ": sparse_tensor.iter_space -> ret" if (parser.parseColon() || parser.parseTypeList(iterSpaceTps)) @@ -2267,7 +2270,7 @@ parseSparseIterateLoop(OpAsmParser &parser, OperationState &state, if (hasIterArgs) { // Strip off leading args that used for coordinates. -MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds); +MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds); if (args.size() != initArgs.size() || args.size() != state.types.size()) { return parser.emitError( parser.getNameLoc(), @@ -2448,18 +2451,18 @@ void IterateOp::build(OpBuilder &builder, OperationState &odsState, odsState.addTypes(initArgs.getTypes()); Block *bodyBlock = builder.createBlock(bodyRegion); - // First argument, sparse iterator - bodyBlock->addArgument( - llvm::cast(iterSpace.getType()).getIteratorType(), - odsState.location); + // Starts with a list of user-provided loop arguments. + for (Value v : initArgs) +bodyBlock->addArgument(v.getType(), v.getLoc()); - // Followed by a list of used coordinates. + // Follows by a list of used coordinates. for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++) bodyBlock->addArgument(builder.getIndexType(), odsState.location); - // Followed by a list of user-provided loop arguments. - for (Value v : initArgs) -bodyBlock->addArgument(v.getType(), v.getLoc()); + // Ends with sparse iterator + bodyBlock->addArgument( + llvm::cast(iterSpace.getType()).getIteratorType(), + odsState.location); } ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) { @@ -2473,9 +2476,9 @@ ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) { return parser.emitError(parser.getNameLoc(), "expected only one iterator/iteration space"); - iters.append(iterArgs); + iterArgs.append(iters); Region *body = r
[llvm-branch-commits] [mlir] [mlir][SparseTensor] Fix type conversion rule (PR #140350)
PeimingLiu wrote: Thx! https://github.com/llvm/llvm-project/pull/140350 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
[llvm-branch-commits] [mlir] [mlir][SparseTensor] Fix type conversion rule (PR #140350)
https://github.com/PeimingLiu approved this pull request. https://github.com/llvm/llvm-project/pull/140350 ___ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits