[llvm-branch-commits] [mlir] [mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation. (PR #105566)

2024-08-21 Thread Peiming Liu via llvm-branch-commits

https://github.com/PeimingLiu created 
https://github.com/llvm/llvm-project/pull/105566

[mlir][sparse] refactoring sparse_tensor.iterate lowering pattern 
implementation.

>From 1a32495b27dfd003408dd5b4f12f3db7f0b73b5a Mon Sep 17 00:00:00 2001
From: Peiming Liu 
Date: Thu, 15 Aug 2024 18:10:25 +
Subject: [PATCH] [mlir][sparse] refactoring sparse_tensor.iterate lowering
 pattern implementation.

---
 .../Transforms/SparseIterationToScf.cpp   | 118 ++
 1 file changed, 36 insertions(+), 82 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp 
b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index d6c0da4a9e4573..f7fcabb0220b50 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -244,88 +244,41 @@ class SparseIterateOpConverter : public 
OneToNOpConversionPattern {
 std::unique_ptr it =
 iterSpace.extractIterator(rewriter, loc);
 
-if (it->iteratableByFor()) {
-  auto [lo, hi] = it->genForCond(rewriter, loc);
-  Value step = constantIndex(rewriter, loc, 1);
-  SmallVector ivs;
-  for (ValueRange inits : adaptor.getInitArgs())
-llvm::append_range(ivs, inits);
-  scf::ForOp forOp = rewriter.create(loc, lo, hi, step, ivs);
-
-  Block *loopBody = op.getBody();
-  OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
-  if (failed(typeConverter->convertSignatureArgs(
-  loopBody->getArgumentTypes(), bodyTypeMapping)))
-return failure();
-  rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
-
-  rewriter.eraseBlock(forOp.getBody());
-  Region &dstRegion = forOp.getRegion();
-  rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
-
-  auto yieldOp =
-  llvm::cast(forOp.getBody()->getTerminator());
-
-  rewriter.setInsertionPointToEnd(forOp.getBody());
-  // replace sparse_tensor.yield with scf.yield.
-  rewriter.create(loc, yieldOp.getResults());
-  rewriter.eraseOp(yieldOp);
-
-  const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
-  rewriter.replaceOp(op, forOp.getResults(), resultMapping);
-} else {
-  SmallVector ivs;
-  // TODO: put iterator at the end of argument list to be consistent with
-  // coiterate operation.
-  llvm::append_range(ivs, it->getCursor());
-  for (ValueRange inits : adaptor.getInitArgs())
-llvm::append_range(ivs, inits);
-
-  assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
-
-  TypeRange types = ValueRange(ivs).getTypes();
-  auto whileOp = rewriter.create(loc, types, ivs);
-  SmallVector l(types.size(), op.getIterator().getLoc());
-
-  // Generates loop conditions.
-  Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
-  rewriter.setInsertionPointToStart(before);
-  ValueRange bArgs = before->getArguments();
-  auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
-  assert(remArgs.size() == adaptor.getInitArgs().size());
-  rewriter.create(loc, whileCond, 
before->getArguments());
-
-  // Generates loop body.
-  Block *loopBody = op.getBody();
-  OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
-  if (failed(typeConverter->convertSignatureArgs(
-  loopBody->getArgumentTypes(), bodyTypeMapping)))
-return failure();
-  rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
-
-  Region &dstRegion = whileOp.getAfter();
-  // TODO: handle uses of coordinate!
-  rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
-  ValueRange aArgs = whileOp.getAfterArguments();
-  auto yieldOp = llvm::cast(
-  whileOp.getAfterBody()->getTerminator());
-
-  rewriter.setInsertionPointToEnd(whileOp.getAfterBody());
+SmallVector ivs;
+for (ValueRange inits : adaptor.getInitArgs())
+  llvm::append_range(ivs, inits);
+
+// Type conversion on iterate op block.
+OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes());
+if (failed(typeConverter->convertSignatureArgs(
+op.getBody()->getArgumentTypes(), blockTypeMapping)))
+  return rewriter.notifyMatchFailure(
+  op, "failed to convert iterate region argument types");
+rewriter.applySignatureConversion(op.getBody(), blockTypeMapping);
+
+Block *block = op.getBody();
+ValueRange ret = genLoopWithIterator(
+rewriter, loc, it.get(), ivs, /*iterFirst=*/true,
+[block](PatternRewriter &rewriter, Location loc, Region &loopBody,
+SparseIterator *it, ValueRange reduc) -> SmallVector {
+  SmallVector blockArgs(it->getCursor());
+  // TODO: Also appends coordinates if used.
+  // blockArgs.push_back(it->deref(rewriter, loc));
+  llvm::a

[llvm-branch-commits] [mlir] [mlir][sparse] unify block arguments order between iterate/coiterate operations. (PR #105567)

2024-08-21 Thread Peiming Liu via llvm-branch-commits

https://github.com/PeimingLiu created 
https://github.com/llvm/llvm-project/pull/105567

[mlir][sparse] unify block arguments order between iterate/coiterate operations.

>From 6fd099fb7039f8fda37d50f1c44cd7afd62cafb7 Mon Sep 17 00:00:00 2001
From: Peiming Liu 
Date: Thu, 15 Aug 2024 21:10:37 +
Subject: [PATCH] [mlir][sparse] unify block arguments order between
 iterate/coiterate operations.

---
 .../SparseTensor/IR/SparseTensorOps.td|  7 ++--
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 31 
 .../Transforms/SparseIterationToScf.cpp   | 36 ++-
 3 files changed, 31 insertions(+), 43 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td 
b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 20512f972e67cd..96a61419a541f7 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1644,7 +1644,7 @@ def IterateOp : SparseTensor_Op<"iterate",
   return getIterSpace().getType().getSpaceDim();
 }
 BlockArgument getIterator() {
-  return getRegion().getArguments().front();
+  return getRegion().getArguments().back();
 }
 std::optional getLvlCrd(Level lvl) {
   if (getCrdUsedLvls()[lvl]) {
@@ -1654,9 +1654,8 @@ def IterateOp : SparseTensor_Op<"iterate",
   return std::nullopt;
 }
 Block::BlockArgListType getCrds() {
-  // The first block argument is iterator, the remaining arguments are
-  // referenced coordinates.
-  return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
+  // User-provided iteration arguments -> coords -> iterator.
+  return getRegion().getArguments().slice(getNumRegionIterArgs(), 
getCrdUsedLvls().count());
 }
 unsigned getNumRegionIterArgs() {
   return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp 
b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 16856b958d4f13..b21bc1a93036c4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2228,9 +2228,10 @@ parseSparseIterateLoop(OpAsmParser &parser, 
OperationState &state,
 parser.getNameLoc(),
 "mismatch in number of sparse iterators and sparse spaces");
 
-  if (failed(parseUsedCoordList(parser, state, blockArgs)))
+  SmallVector coords;
+  if (failed(parseUsedCoordList(parser, state, coords)))
 return failure();
-  size_t numCrds = blockArgs.size();
+  size_t numCrds = coords.size();
 
   // Parse "iter_args(%arg = %init, ...)"
   bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
@@ -2238,6 +2239,8 @@ parseSparseIterateLoop(OpAsmParser &parser, 
OperationState &state,
 if (parser.parseAssignmentList(blockArgs, initArgs))
   return failure();
 
+  blockArgs.append(coords);
+
   SmallVector iterSpaceTps;
   // parse ": sparse_tensor.iter_space -> ret"
   if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
@@ -2267,7 +2270,7 @@ parseSparseIterateLoop(OpAsmParser &parser, 
OperationState &state,
 
   if (hasIterArgs) {
 // Strip off leading args that used for coordinates.
-MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
+MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
 if (args.size() != initArgs.size() || args.size() != state.types.size()) {
   return parser.emitError(
   parser.getNameLoc(),
@@ -2448,18 +2451,18 @@ void IterateOp::build(OpBuilder &builder, 
OperationState &odsState,
   odsState.addTypes(initArgs.getTypes());
   Block *bodyBlock = builder.createBlock(bodyRegion);
 
-  // First argument, sparse iterator
-  bodyBlock->addArgument(
-  llvm::cast(iterSpace.getType()).getIteratorType(),
-  odsState.location);
+  // Starts with a list of user-provided loop arguments.
+  for (Value v : initArgs)
+bodyBlock->addArgument(v.getType(), v.getLoc());
 
-  // Followed by a list of used coordinates.
+  // Followed by a list of used coordinates.
   for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
 bodyBlock->addArgument(builder.getIndexType(), odsState.location);
 
-  // Followed by a list of user-provided loop arguments.
-  for (Value v : initArgs)
-bodyBlock->addArgument(v.getType(), v.getLoc());
+  // Ends with sparse iterator
+  bodyBlock->addArgument(
+  llvm::cast(iterSpace.getType()).getIteratorType(),
+  odsState.location);
 }
 
 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -2473,9 +2476,9 @@ ParseResult IterateOp::parse(OpAsmParser &parser, 
OperationState &result) {
 return parser.emitError(parser.getNameLoc(),
 "expected only one iterator/iteration space");
 
-  iters.append(iterArgs);
+  iterArgs.append(iters);
   Region *body = result.addRegion();

[llvm-branch-commits] [mlir] [mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation. (PR #105566)

2024-08-21 Thread Peiming Liu via llvm-branch-commits

https://github.com/PeimingLiu updated 
https://github.com/llvm/llvm-project/pull/105566

>From 937bcd814688e7c6f88ef27b7586254006e0d050 Mon Sep 17 00:00:00 2001
From: Peiming Liu 
Date: Thu, 15 Aug 2024 18:10:25 +
Subject: [PATCH] [mlir][sparse] refactoring sparse_tensor.iterate lowering
 pattern implementation.

stack-info: PR: https://github.com/llvm/llvm-project/pull/105566, branch: 
users/PeimingLiu/stack/2
---
 .../Transforms/SparseIterationToScf.cpp   | 118 ++
 1 file changed, 36 insertions(+), 82 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp 
b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index d6c0da4a9e4573..f7fcabb0220b50 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -244,88 +244,41 @@ class SparseIterateOpConverter : public 
OneToNOpConversionPattern {
 std::unique_ptr it =
 iterSpace.extractIterator(rewriter, loc);
 
-if (it->iteratableByFor()) {
-  auto [lo, hi] = it->genForCond(rewriter, loc);
-  Value step = constantIndex(rewriter, loc, 1);
-  SmallVector ivs;
-  for (ValueRange inits : adaptor.getInitArgs())
-llvm::append_range(ivs, inits);
-  scf::ForOp forOp = rewriter.create(loc, lo, hi, step, ivs);
-
-  Block *loopBody = op.getBody();
-  OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
-  if (failed(typeConverter->convertSignatureArgs(
-  loopBody->getArgumentTypes(), bodyTypeMapping)))
-return failure();
-  rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
-
-  rewriter.eraseBlock(forOp.getBody());
-  Region &dstRegion = forOp.getRegion();
-  rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
-
-  auto yieldOp =
-  llvm::cast(forOp.getBody()->getTerminator());
-
-  rewriter.setInsertionPointToEnd(forOp.getBody());
-  // replace sparse_tensor.yield with scf.yield.
-  rewriter.create(loc, yieldOp.getResults());
-  rewriter.eraseOp(yieldOp);
-
-  const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
-  rewriter.replaceOp(op, forOp.getResults(), resultMapping);
-} else {
-  SmallVector ivs;
-  // TODO: put iterator at the end of argument list to be consistent with
-  // coiterate operation.
-  llvm::append_range(ivs, it->getCursor());
-  for (ValueRange inits : adaptor.getInitArgs())
-llvm::append_range(ivs, inits);
-
-  assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
-
-  TypeRange types = ValueRange(ivs).getTypes();
-  auto whileOp = rewriter.create(loc, types, ivs);
-  SmallVector l(types.size(), op.getIterator().getLoc());
-
-  // Generates loop conditions.
-  Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
-  rewriter.setInsertionPointToStart(before);
-  ValueRange bArgs = before->getArguments();
-  auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
-  assert(remArgs.size() == adaptor.getInitArgs().size());
-  rewriter.create(loc, whileCond, 
before->getArguments());
-
-  // Generates loop body.
-  Block *loopBody = op.getBody();
-  OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
-  if (failed(typeConverter->convertSignatureArgs(
-  loopBody->getArgumentTypes(), bodyTypeMapping)))
-return failure();
-  rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
-
-  Region &dstRegion = whileOp.getAfter();
-  // TODO: handle uses of coordinate!
-  rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
-  ValueRange aArgs = whileOp.getAfterArguments();
-  auto yieldOp = llvm::cast(
-  whileOp.getAfterBody()->getTerminator());
-
-  rewriter.setInsertionPointToEnd(whileOp.getAfterBody());
+SmallVector ivs;
+for (ValueRange inits : adaptor.getInitArgs())
+  llvm::append_range(ivs, inits);
+
+// Type conversion on iterate op block.
+OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes());
+if (failed(typeConverter->convertSignatureArgs(
+op.getBody()->getArgumentTypes(), blockTypeMapping)))
+  return rewriter.notifyMatchFailure(
+  op, "failed to convert iterate region argument types");
+rewriter.applySignatureConversion(op.getBody(), blockTypeMapping);
+
+Block *block = op.getBody();
+ValueRange ret = genLoopWithIterator(
+rewriter, loc, it.get(), ivs, /*iterFirst=*/true,
+[block](PatternRewriter &rewriter, Location loc, Region &loopBody,
+SparseIterator *it, ValueRange reduc) -> SmallVector {
+  SmallVector blockArgs(it->getCursor());
+  // TODO: Also appends coordinates if used.
+  // blockArgs.push_back(it->deref(rewriter, loc));
+ 

[llvm-branch-commits] [mlir] [mlir][sparse] unify block arguments order between iterate/coiterate operations. (PR #105567)

2024-08-21 Thread Peiming Liu via llvm-branch-commits

https://github.com/PeimingLiu updated 
https://github.com/llvm/llvm-project/pull/105567

>From 3f83d7a1eadc1101fb96707ecd348925e5aaed70 Mon Sep 17 00:00:00 2001
From: Peiming Liu 
Date: Thu, 15 Aug 2024 21:10:37 +
Subject: [PATCH] [mlir][sparse] unify block arguments order between
 iterate/coiterate operations.

stack-info: PR: https://github.com/llvm/llvm-project/pull/105567, branch: 
users/PeimingLiu/stack/3
---
 .../SparseTensor/IR/SparseTensorOps.td|  7 ++--
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 31 
 .../Transforms/SparseIterationToScf.cpp   | 36 ++-
 3 files changed, 31 insertions(+), 43 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td 
b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 20512f972e67cd..96a61419a541f7 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1644,7 +1644,7 @@ def IterateOp : SparseTensor_Op<"iterate",
   return getIterSpace().getType().getSpaceDim();
 }
 BlockArgument getIterator() {
-  return getRegion().getArguments().front();
+  return getRegion().getArguments().back();
 }
 std::optional getLvlCrd(Level lvl) {
   if (getCrdUsedLvls()[lvl]) {
@@ -1654,9 +1654,8 @@ def IterateOp : SparseTensor_Op<"iterate",
   return std::nullopt;
 }
 Block::BlockArgListType getCrds() {
-  // The first block argument is iterator, the remaining arguments are
-  // referenced coordinates.
-  return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
+  // User-provided iteration arguments -> coords -> iterator.
+  return getRegion().getArguments().slice(getNumRegionIterArgs(), 
getCrdUsedLvls().count());
 }
 unsigned getNumRegionIterArgs() {
   return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp 
b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 16856b958d4f13..b21bc1a93036c4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2228,9 +2228,10 @@ parseSparseIterateLoop(OpAsmParser &parser, 
OperationState &state,
 parser.getNameLoc(),
 "mismatch in number of sparse iterators and sparse spaces");
 
-  if (failed(parseUsedCoordList(parser, state, blockArgs)))
+  SmallVector coords;
+  if (failed(parseUsedCoordList(parser, state, coords)))
 return failure();
-  size_t numCrds = blockArgs.size();
+  size_t numCrds = coords.size();
 
   // Parse "iter_args(%arg = %init, ...)"
   bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
@@ -2238,6 +2239,8 @@ parseSparseIterateLoop(OpAsmParser &parser, 
OperationState &state,
 if (parser.parseAssignmentList(blockArgs, initArgs))
   return failure();
 
+  blockArgs.append(coords);
+
   SmallVector iterSpaceTps;
   // parse ": sparse_tensor.iter_space -> ret"
   if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
@@ -2267,7 +2270,7 @@ parseSparseIterateLoop(OpAsmParser &parser, 
OperationState &state,
 
   if (hasIterArgs) {
 // Strip off leading args that used for coordinates.
-MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
+MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
 if (args.size() != initArgs.size() || args.size() != state.types.size()) {
   return parser.emitError(
   parser.getNameLoc(),
@@ -2448,18 +2451,18 @@ void IterateOp::build(OpBuilder &builder, 
OperationState &odsState,
   odsState.addTypes(initArgs.getTypes());
   Block *bodyBlock = builder.createBlock(bodyRegion);
 
-  // First argument, sparse iterator
-  bodyBlock->addArgument(
-  llvm::cast(iterSpace.getType()).getIteratorType(),
-  odsState.location);
+  // Starts with a list of user-provided loop arguments.
+  for (Value v : initArgs)
+bodyBlock->addArgument(v.getType(), v.getLoc());
 
-  // Followed by a list of used coordinates.
+  // Followed by a list of used coordinates.
   for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
 bodyBlock->addArgument(builder.getIndexType(), odsState.location);
 
-  // Followed by a list of user-provided loop arguments.
-  for (Value v : initArgs)
-bodyBlock->addArgument(v.getType(), v.getLoc());
+  // Ends with sparse iterator
+  bodyBlock->addArgument(
+  llvm::cast(iterSpace.getType()).getIteratorType(),
+  odsState.location);
 }
 
 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -2473,9 +2476,9 @@ ParseResult IterateOp::parse(OpAsmParser &parser, 
OperationState &result) {
 return parser.emitError(parser.getNameLoc(),
 "expected only one iterator/iteration space");
 
-  iters.append(iterArgs);
+  iterArgs.append(iters);
   Region *body = r

[llvm-branch-commits] [mlir] [mlir][sparse] unify block arguments order between iterate/coiterate operations. (PR #105567)

2024-08-21 Thread Peiming Liu via llvm-branch-commits

https://github.com/PeimingLiu updated 
https://github.com/llvm/llvm-project/pull/105567

>From 3f83d7a1eadc1101fb96707ecd348925e5aaed70 Mon Sep 17 00:00:00 2001
From: Peiming Liu 
Date: Thu, 15 Aug 2024 21:10:37 +
Subject: [PATCH] [mlir][sparse] unify block arguments order between
 iterate/coiterate operations.

stack-info: PR: https://github.com/llvm/llvm-project/pull/105567, branch: 
users/PeimingLiu/stack/3
---
 .../SparseTensor/IR/SparseTensorOps.td|  7 ++--
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 31 
 .../Transforms/SparseIterationToScf.cpp   | 36 ++-
 3 files changed, 31 insertions(+), 43 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td 
b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 20512f972e67cd..96a61419a541f7 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1644,7 +1644,7 @@ def IterateOp : SparseTensor_Op<"iterate",
   return getIterSpace().getType().getSpaceDim();
 }
 BlockArgument getIterator() {
-  return getRegion().getArguments().front();
+  return getRegion().getArguments().back();
 }
 std::optional getLvlCrd(Level lvl) {
   if (getCrdUsedLvls()[lvl]) {
@@ -1654,9 +1654,8 @@ def IterateOp : SparseTensor_Op<"iterate",
   return std::nullopt;
 }
 Block::BlockArgListType getCrds() {
-  // The first block argument is iterator, the remaining arguments are
-  // referenced coordinates.
-  return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
+  // User-provided iteration arguments -> coords -> iterator.
+  return getRegion().getArguments().slice(getNumRegionIterArgs(), 
getCrdUsedLvls().count());
 }
 unsigned getNumRegionIterArgs() {
   return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp 
b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 16856b958d4f13..b21bc1a93036c4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2228,9 +2228,10 @@ parseSparseIterateLoop(OpAsmParser &parser, 
OperationState &state,
 parser.getNameLoc(),
 "mismatch in number of sparse iterators and sparse spaces");
 
-  if (failed(parseUsedCoordList(parser, state, blockArgs)))
+  SmallVector coords;
+  if (failed(parseUsedCoordList(parser, state, coords)))
 return failure();
-  size_t numCrds = blockArgs.size();
+  size_t numCrds = coords.size();
 
   // Parse "iter_args(%arg = %init, ...)"
   bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
@@ -2238,6 +2239,8 @@ parseSparseIterateLoop(OpAsmParser &parser, 
OperationState &state,
 if (parser.parseAssignmentList(blockArgs, initArgs))
   return failure();
 
+  blockArgs.append(coords);
+
   SmallVector iterSpaceTps;
   // parse ": sparse_tensor.iter_space -> ret"
   if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
@@ -2267,7 +2270,7 @@ parseSparseIterateLoop(OpAsmParser &parser, 
OperationState &state,
 
   if (hasIterArgs) {
 // Strip off leading args that used for coordinates.
-MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
+MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
 if (args.size() != initArgs.size() || args.size() != state.types.size()) {
   return parser.emitError(
   parser.getNameLoc(),
@@ -2448,18 +2451,18 @@ void IterateOp::build(OpBuilder &builder, 
OperationState &odsState,
   odsState.addTypes(initArgs.getTypes());
   Block *bodyBlock = builder.createBlock(bodyRegion);
 
-  // First argument, sparse iterator
-  bodyBlock->addArgument(
-  llvm::cast(iterSpace.getType()).getIteratorType(),
-  odsState.location);
+  // Starts with a list of user-provided loop arguments.
+  for (Value v : initArgs)
+bodyBlock->addArgument(v.getType(), v.getLoc());
 
-  // Followed by a list of used coordinates.
+  // Followed by a list of used coordinates.
   for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
 bodyBlock->addArgument(builder.getIndexType(), odsState.location);
 
-  // Followed by a list of user-provided loop arguments.
-  for (Value v : initArgs)
-bodyBlock->addArgument(v.getType(), v.getLoc());
+  // Ends with sparse iterator
+  bodyBlock->addArgument(
+  llvm::cast(iterSpace.getType()).getIteratorType(),
+  odsState.location);
 }
 
 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -2473,9 +2476,9 @@ ParseResult IterateOp::parse(OpAsmParser &parser, 
OperationState &result) {
 return parser.emitError(parser.getNameLoc(),
 "expected only one iterator/iteration space");
 
-  iters.append(iterArgs);
+  iterArgs.append(iters);
   Region *body = r

[llvm-branch-commits] [mlir] [mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation. (PR #105566)

2024-08-21 Thread Peiming Liu via llvm-branch-commits

https://github.com/PeimingLiu updated 
https://github.com/llvm/llvm-project/pull/105566

>From 937bcd814688e7c6f88ef27b7586254006e0d050 Mon Sep 17 00:00:00 2001
From: Peiming Liu 
Date: Thu, 15 Aug 2024 18:10:25 +
Subject: [PATCH] [mlir][sparse] refactoring sparse_tensor.iterate lowering
 pattern implementation.

stack-info: PR: https://github.com/llvm/llvm-project/pull/105566, branch: 
users/PeimingLiu/stack/2
---
 .../Transforms/SparseIterationToScf.cpp   | 118 ++
 1 file changed, 36 insertions(+), 82 deletions(-)

diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp 
b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
index d6c0da4a9e4573..f7fcabb0220b50 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseIterationToScf.cpp
@@ -244,88 +244,41 @@ class SparseIterateOpConverter : public 
OneToNOpConversionPattern {
 std::unique_ptr it =
 iterSpace.extractIterator(rewriter, loc);
 
-if (it->iteratableByFor()) {
-  auto [lo, hi] = it->genForCond(rewriter, loc);
-  Value step = constantIndex(rewriter, loc, 1);
-  SmallVector ivs;
-  for (ValueRange inits : adaptor.getInitArgs())
-llvm::append_range(ivs, inits);
-  scf::ForOp forOp = rewriter.create(loc, lo, hi, step, ivs);
-
-  Block *loopBody = op.getBody();
-  OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
-  if (failed(typeConverter->convertSignatureArgs(
-  loopBody->getArgumentTypes(), bodyTypeMapping)))
-return failure();
-  rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
-
-  rewriter.eraseBlock(forOp.getBody());
-  Region &dstRegion = forOp.getRegion();
-  rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
-
-  auto yieldOp =
-  llvm::cast(forOp.getBody()->getTerminator());
-
-  rewriter.setInsertionPointToEnd(forOp.getBody());
-  // replace sparse_tensor.yield with scf.yield.
-  rewriter.create(loc, yieldOp.getResults());
-  rewriter.eraseOp(yieldOp);
-
-  const OneToNTypeMapping &resultMapping = adaptor.getResultMapping();
-  rewriter.replaceOp(op, forOp.getResults(), resultMapping);
-} else {
-  SmallVector ivs;
-  // TODO: put iterator at the end of argument list to be consistent with
-  // coiterate operation.
-  llvm::append_range(ivs, it->getCursor());
-  for (ValueRange inits : adaptor.getInitArgs())
-llvm::append_range(ivs, inits);
-
-  assert(llvm::all_of(ivs, [](Value v) { return v != nullptr; }));
-
-  TypeRange types = ValueRange(ivs).getTypes();
-  auto whileOp = rewriter.create(loc, types, ivs);
-  SmallVector l(types.size(), op.getIterator().getLoc());
-
-  // Generates loop conditions.
-  Block *before = rewriter.createBlock(&whileOp.getBefore(), {}, types, l);
-  rewriter.setInsertionPointToStart(before);
-  ValueRange bArgs = before->getArguments();
-  auto [whileCond, remArgs] = it->genWhileCond(rewriter, loc, bArgs);
-  assert(remArgs.size() == adaptor.getInitArgs().size());
-  rewriter.create(loc, whileCond, 
before->getArguments());
-
-  // Generates loop body.
-  Block *loopBody = op.getBody();
-  OneToNTypeMapping bodyTypeMapping(loopBody->getArgumentTypes());
-  if (failed(typeConverter->convertSignatureArgs(
-  loopBody->getArgumentTypes(), bodyTypeMapping)))
-return failure();
-  rewriter.applySignatureConversion(loopBody, bodyTypeMapping);
-
-  Region &dstRegion = whileOp.getAfter();
-  // TODO: handle uses of coordinate!
-  rewriter.inlineRegionBefore(op.getRegion(), dstRegion, dstRegion.end());
-  ValueRange aArgs = whileOp.getAfterArguments();
-  auto yieldOp = llvm::cast(
-  whileOp.getAfterBody()->getTerminator());
-
-  rewriter.setInsertionPointToEnd(whileOp.getAfterBody());
+SmallVector ivs;
+for (ValueRange inits : adaptor.getInitArgs())
+  llvm::append_range(ivs, inits);
+
+// Type conversion on iterate op block.
+OneToNTypeMapping blockTypeMapping(op.getBody()->getArgumentTypes());
+if (failed(typeConverter->convertSignatureArgs(
+op.getBody()->getArgumentTypes(), blockTypeMapping)))
+  return rewriter.notifyMatchFailure(
+  op, "failed to convert iterate region argument types");
+rewriter.applySignatureConversion(op.getBody(), blockTypeMapping);
+
+Block *block = op.getBody();
+ValueRange ret = genLoopWithIterator(
+rewriter, loc, it.get(), ivs, /*iterFirst=*/true,
+[block](PatternRewriter &rewriter, Location loc, Region &loopBody,
+SparseIterator *it, ValueRange reduc) -> SmallVector {
+  SmallVector blockArgs(it->getCursor());
+  // TODO: Also appends coordinates if used.
+  // blockArgs.push_back(it->deref(rewriter, loc));
+ 

[llvm-branch-commits] [mlir] [mlir][sparse] refactoring sparse_tensor.iterate lowering pattern implementation. (PR #105566)

2024-08-22 Thread Peiming Liu via llvm-branch-commits

https://github.com/PeimingLiu edited 
https://github.com/llvm/llvm-project/pull/105566
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [mlir][sparse] unify block arguments order between iterate/coiterate operations. (PR #105567)

2024-08-22 Thread Peiming Liu via llvm-branch-commits

https://github.com/PeimingLiu edited 
https://github.com/llvm/llvm-project/pull/105567
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [mlir][sparse] unify block arguments order between iterate/coiterate operations. (PR #105567)

2024-08-23 Thread Peiming Liu via llvm-branch-commits

https://github.com/PeimingLiu updated 
https://github.com/llvm/llvm-project/pull/105567

>From 58bae5cff0b813347512a67a89e3abf6637ad0a9 Mon Sep 17 00:00:00 2001
From: Peiming Liu 
Date: Thu, 15 Aug 2024 21:10:37 +
Subject: [PATCH] [mlir][sparse] unify block arguments order between
 iterate/coiterate operations.

stack-info: PR: https://github.com/llvm/llvm-project/pull/105567, branch: 
users/PeimingLiu/stack/3
---
 .../SparseTensor/IR/SparseTensorOps.td|  7 ++--
 .../SparseTensor/IR/SparseTensorDialect.cpp   | 31 
 .../Transforms/SparseIterationToScf.cpp   | 36 ++-
 3 files changed, 31 insertions(+), 43 deletions(-)

diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td 
b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
index 20512f972e67cd..96a61419a541f7 100644
--- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
+++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td
@@ -1644,7 +1644,7 @@ def IterateOp : SparseTensor_Op<"iterate",
   return getIterSpace().getType().getSpaceDim();
 }
 BlockArgument getIterator() {
-  return getRegion().getArguments().front();
+  return getRegion().getArguments().back();
 }
 std::optional getLvlCrd(Level lvl) {
   if (getCrdUsedLvls()[lvl]) {
@@ -1654,9 +1654,8 @@ def IterateOp : SparseTensor_Op<"iterate",
   return std::nullopt;
 }
 Block::BlockArgListType getCrds() {
-  // The first block argument is iterator, the remaining arguments are
-  // referenced coordinates.
-  return getRegion().getArguments().slice(1, getCrdUsedLvls().count());
+  // User-provided iteration arguments -> coords -> iterator.
+  return getRegion().getArguments().slice(getNumRegionIterArgs(), 
getCrdUsedLvls().count());
 }
 unsigned getNumRegionIterArgs() {
   return getRegion().getArguments().size() - 1 - getCrdUsedLvls().count();
diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp 
b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
index 16856b958d4f13..b21bc1a93036c4 100644
--- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
+++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp
@@ -2228,9 +2228,10 @@ parseSparseIterateLoop(OpAsmParser &parser, 
OperationState &state,
 parser.getNameLoc(),
 "mismatch in number of sparse iterators and sparse spaces");
 
-  if (failed(parseUsedCoordList(parser, state, blockArgs)))
+  SmallVector coords;
+  if (failed(parseUsedCoordList(parser, state, coords)))
 return failure();
-  size_t numCrds = blockArgs.size();
+  size_t numCrds = coords.size();
 
   // Parse "iter_args(%arg = %init, ...)"
   bool hasIterArgs = succeeded(parser.parseOptionalKeyword("iter_args"));
@@ -2238,6 +2239,8 @@ parseSparseIterateLoop(OpAsmParser &parser, 
OperationState &state,
 if (parser.parseAssignmentList(blockArgs, initArgs))
   return failure();
 
+  blockArgs.append(coords);
+
   SmallVector iterSpaceTps;
   // parse ": sparse_tensor.iter_space -> ret"
   if (parser.parseColon() || parser.parseTypeList(iterSpaceTps))
@@ -2267,7 +2270,7 @@ parseSparseIterateLoop(OpAsmParser &parser, 
OperationState &state,
 
   if (hasIterArgs) {
 // Strip off leading args that used for coordinates.
-MutableArrayRef args = MutableArrayRef(blockArgs).drop_front(numCrds);
+MutableArrayRef args = MutableArrayRef(blockArgs).drop_back(numCrds);
 if (args.size() != initArgs.size() || args.size() != state.types.size()) {
   return parser.emitError(
   parser.getNameLoc(),
@@ -2448,18 +2451,18 @@ void IterateOp::build(OpBuilder &builder, 
OperationState &odsState,
   odsState.addTypes(initArgs.getTypes());
   Block *bodyBlock = builder.createBlock(bodyRegion);
 
-  // First argument, sparse iterator
-  bodyBlock->addArgument(
-  llvm::cast(iterSpace.getType()).getIteratorType(),
-  odsState.location);
+  // Starts with a list of user-provided loop arguments.
+  for (Value v : initArgs)
+bodyBlock->addArgument(v.getType(), v.getLoc());
 
-  // Followed by a list of used coordinates.
+  // Followed by a list of used coordinates.
   for (unsigned i = 0, e = crdUsedLvls.count(); i < e; i++)
 bodyBlock->addArgument(builder.getIndexType(), odsState.location);
 
-  // Followed by a list of user-provided loop arguments.
-  for (Value v : initArgs)
-bodyBlock->addArgument(v.getType(), v.getLoc());
+  // Ends with sparse iterator
+  bodyBlock->addArgument(
+  llvm::cast(iterSpace.getType()).getIteratorType(),
+  odsState.location);
 }
 
 ParseResult IterateOp::parse(OpAsmParser &parser, OperationState &result) {
@@ -2473,9 +2476,9 @@ ParseResult IterateOp::parse(OpAsmParser &parser, 
OperationState &result) {
 return parser.emitError(parser.getNameLoc(),
 "expected only one iterator/iteration space");
 
-  iters.append(iterArgs);
+  iterArgs.append(iters);
   Region *body = r

[llvm-branch-commits] [mlir] [mlir][SparseTensor] Fix type conversion rule (PR #140350)

2025-05-17 Thread Peiming Liu via llvm-branch-commits

PeimingLiu wrote:

Thx!

https://github.com/llvm/llvm-project/pull/140350
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits


[llvm-branch-commits] [mlir] [mlir][SparseTensor] Fix type conversion rule (PR #140350)

2025-05-17 Thread Peiming Liu via llvm-branch-commits

https://github.com/PeimingLiu approved this pull request.


https://github.com/llvm/llvm-project/pull/140350
___
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits