Author: Qinkun Bao
Date: 2025-06-24T10:09:20-04:00
New Revision: b8672c3278bf3ee83e8c44053d03558632ba46e0

URL: 
https://github.com/llvm/llvm-project/commit/b8672c3278bf3ee83e8c44053d03558632ba46e0
DIFF: 
https://github.com/llvm/llvm-project/commit/b8672c3278bf3ee83e8c44053d03558632ba46e0.diff

LOG: Revert "[mlir][mesh] adding option for traversal order in sharding propagatio…"

This reverts commit 43e1a5a411d972fe06a1afb86ffd5ba21fd2a376.
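
The reverted commit had added a TraversalOrder enum and a "traversal" option to
the sharding-propagation pass (forward, backward, forward-backward,
backward-forward, defaulting to backward-forward), together with three tests
covering the non-default orders. As a rough illustration, the now-removed tests
drove the option like this (the input file name here is only a placeholder):

    mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=backward}))" input.mlir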

Added: 
    

Modified: 
    mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
    mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
    mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
    mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
    mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp

Removed: 
    mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
    mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
    mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir


################################################################################
diff --git a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
index c4d512b60bc51..1dc178586e918 100644
--- a/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
+++ b/mlir/include/mlir/Dialect/Mesh/IR/MeshOps.h
@@ -206,6 +206,9 @@ Type shardType(Type type, MeshOp mesh, MeshSharding sharding);
 // Use newShardOp if it is not null. Otherwise create a new one.
 // May insert resharding if required.
 // Potentially updates newShardOp.
+void maybeInsertTargetShardingAnnotation(MeshSharding sharding,
+                                         OpOperand &operand, OpBuilder &builder,
+                                         ShardOp &newShardOp);
 void maybeInsertTargetShardingAnnotation(MeshSharding sharding, OpResult result,
                                          OpBuilder &builder);
 void maybeInsertSourceShardingAnnotation(MeshSharding sharding,

diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
index a2424d43a8ba9..83399d10beaae 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.h
@@ -19,18 +19,6 @@ class FuncOp;
 
 namespace mesh {
 
-/// This enum controls the traversal order for the sharding propagation.
-enum class TraversalOrder {
-  /// Forward traversal.
-  Forward,
-  /// Backward traversal.
-  Backward,
-  /// Forward then backward traversal.
-  ForwardBackward,
-  /// Backward then forward traversal.
-  BackwardForward
-};
-
 //===----------------------------------------------------------------------===//
 // Passes
 //===----------------------------------------------------------------------===//

diff --git a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
index 11ec7e78cd5e6..06ebf151e7d64 100644
--- a/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td
@@ -24,21 +24,6 @@ def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionO
     operation, and the operations themselves are added with sharding option
     attributes.
   }];
-  let options = [
-    Option<"traversal", "traversal",
-           "mlir::mesh::TraversalOrder", /*default=*/"mlir::mesh::TraversalOrder::BackwardForward",
-           "Traversal order to use for sharding propagation:",
-            [{::llvm::cl::values(
-              clEnumValN(mlir::mesh::TraversalOrder::Forward, "forward",
-              "Forward only traversal."),
-              clEnumValN(mlir::mesh::TraversalOrder::Backward, "backward",
-              "backward only traversal."),
-              clEnumValN(mlir::mesh::TraversalOrder::ForwardBackward, "forward-backward",
-              "forward-backward traversal."),
-              clEnumValN(mlir::mesh::TraversalOrder::BackwardForward, "backward-forward",
-              "backward-forward traversal.")
-            )}]>,
-  ];
   let dependentDialects = [
     "mesh::MeshDialect"
   ];

diff --git a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
index b8cc91da722f0..0a01aaf776e7d 100644
--- a/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
+++ b/mlir/lib/Dialect/Mesh/IR/MeshOps.cpp
@@ -298,12 +298,13 @@ Type mesh::shardType(Type type, MeshOp mesh, MeshSharding sharding) {
   return type;
 }
 
-static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
-                                                    Value &operandValue,
-                                                    Operation *operandOp,
-                                                    OpBuilder &builder,
-                                                    ShardOp &newShardOp) {
+void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
+                                                     OpOperand &operand,
+                                                     OpBuilder &builder,
+                                                     ShardOp &newShardOp) {
   OpBuilder::InsertionGuard insertionGuard(builder);
+  Value operandValue = operand.get();
+  Operation *operandOp = operand.getOwner();
   builder.setInsertionPointAfterValue(operandValue);
   ShardOp shardOp = dyn_cast<ShardOp>(operandOp);
   if (shardOp && sharding == shardOp.getSharding() &&
@@ -322,8 +323,9 @@ static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
         builder.create<ShardOp>(operandValue.getLoc(), operandValue, shardingOp,
                                 /*annotate_for_users*/ false);
   }
-  operandValue.replaceUsesWithIf(
-      newShardOp, [operandOp, operandValue](OpOperand &use) {
+  IRRewriter rewriter(builder);
+  rewriter.replaceUsesWithIf(
+      operandValue, newShardOp, [operandOp, operandValue](OpOperand &use) {
         return use.getOwner() == operandOp && use.get() == operandValue;
       });
 
@@ -334,20 +336,15 @@ static void maybeInsertTargetShardingAnnotationImpl(MeshSharding sharding,
   auto newShardOp2 = builder.create<ShardOp>(operandValue.getLoc(), newShardOp,
                                              newShardOp.getSharding(),
                                              /*annotate_for_users*/ true);
-  newShardOp.getResult().replaceAllUsesExcept(newShardOp2, newShardOp2);
+  rewriter.replaceAllUsesExcept(newShardOp, newShardOp2, newShardOp2);
 }
 
 void mlir::mesh::maybeInsertTargetShardingAnnotation(MeshSharding sharding,
                                                      OpResult result,
                                                      OpBuilder &builder) {
   ShardOp newShardOp;
-  SmallVector<std::pair<Value, Operation *>> uses;
-  for (auto &use : result.getUses()) {
-    uses.emplace_back(use.get(), use.getOwner());
-  }
-  for (auto &[operandValue, operandOp] : uses) {
-    maybeInsertTargetShardingAnnotationImpl(sharding, operandValue, operandOp,
-                                            builder, newShardOp);
+  for (auto &use : llvm::make_early_inc_range(result.getUses())) {
+    maybeInsertTargetShardingAnnotation(sharding, use, builder, newShardOp);
   }
 }
 

diff --git a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
index 6751fafaf1776..4452dd65fce9d 100644
--- a/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
+++ b/mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp
@@ -362,9 +362,6 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
 //===----------------------------------------------------------------------===//
 struct ShardingPropagation
     : public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
-
-  using ShardingPropagationBase<ShardingPropagation>::ShardingPropagationBase;
-
   void runOnOperation() override {
     FunctionOpInterface funcOp = getOperation();
     MLIRContext *ctx = funcOp.getContext();
@@ -385,31 +382,18 @@ struct ShardingPropagation
             shardingOp.printLoopTypesAndIndexingMaps(llvm::dbgs());
         });
 
-    auto traverse = [&](auto &&range, OpBuilder &builder,
-                        const char *order) -> bool {
-      for (Operation &op : range) {
-        if (failed(visitOp(&op, builder))) {
-          signalPassFailure();
-          return true;
-        }
-      }
-      LLVM_DEBUG(DBGS() << "After " << order << " order propagation:\n"
-                        << funcOp << "\n");
-      LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
-      return false;
-    };
-
-    // 1. Propagate in reversed order.
-    if (traversal == TraversalOrder::Backward ||
-        traversal == TraversalOrder::BackwardForward)
-      traverse(llvm::reverse(block), builder, "backward");
-
-    // 2. Propagate in original order.
-    if (traversal != TraversalOrder::Backward)
-      traverse(block, builder, "forward");
-
-    // 3. Propagate in backward order if needed.
-    if (traversal == TraversalOrder::ForwardBackward)
-      traverse(llvm::reverse(block), builder, "backward");
+    // 1. propagate in reversed order
+    for (Operation &op : llvm::make_early_inc_range(llvm::reverse(block)))
+      if (failed(visitOp(&op, builder)))
+        return signalPassFailure();
+
+    LLVM_DEBUG(DBGS() << "After reversed order propagation:\n"
+                      << funcOp << "\n");
+    LLVM_DEBUG(assert(succeeded(mlir::verify(funcOp))));
+
+    // 2. propagate in original order
+    for (Operation &op : llvm::make_early_inc_range(block))
+      if (failed(visitOp(&op, builder)))
+        return signalPassFailure();
   }
 };

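With the revert in place, the ShardingPropagation pass again runs in a fixed
order: once backward over the block and then once forward, matching the
backward-forward default of the removed option. A pipeline therefore invokes
the pass without options, for example (input.mlir is only a placeholder):

    mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" input.mlir
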
diff --git a/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
deleted file mode 100644
index 4223d01d65111..0000000000000
--- a/mlir/test/Dialect/Mesh/backward-sharding-propagation.mlir
+++ /dev/null
@@ -1,26 +0,0 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=backward}))" %s | FileCheck %s
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-module {
-  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
-  func.func @test_forward() -> tensor<6x6xi32> {
-    %c1_i32 = arith.constant 1 : i32
-    // CHECK: tensor.empty()
-    %0 = tensor.empty() : tensor<6x6xi32>
-    %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
-    // CHECK-COUNT-2: mesh.shard
-    %sharding_annotated = mesh.shard %0 to %sharding : tensor<6x6xi32>
-    %1 = linalg.fill ins(%c1_i32 : i32) outs(%sharding_annotated : tensor<6x6xi32>) -> tensor<6x6xi32>
-    // CHECK: tensor.empty()
-    // CHECK-NOT: mesh.shard @
-    %2 = tensor.empty() : tensor<6x6xi32>
-    %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%1, %1
-        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
-    ^bb0(%in: i32, %in_2: i32, %out: i32):
-      %9 = arith.addi %in, %in_2 : i32
-      linalg.yield %9 : i32
-    } -> tensor<6x6xi32>
-    // CHECK: return
-    return %3 : tensor<6x6xi32>
-  }
-}

diff --git a/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
deleted file mode 100644
index dd2eee2f7def8..0000000000000
--- a/mlir/test/Dialect/Mesh/forward-backward-sharding-propagation.mlir
+++ /dev/null
@@ -1,27 +0,0 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward-backward}))" %s | FileCheck %s
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-module {
-  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
-  func.func @test_forward() -> tensor<6x6xi32> {
-    %c1_i32 = arith.constant 1 : i32
-    // CHECK: tensor.empty()
-    %0 = tensor.empty() : tensor<6x6xi32>
-    // CHECK-COUNT-3: mesh.sharding @mesh split_axes = {{\[\[0}}]]
-    %sharding_row = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
-    %annotated_row = mesh.shard %0 to %sharding_row : tensor<6x6xi32>
-    %1 = linalg.fill ins(%c1_i32 : i32) outs(%annotated_row : tensor<6x6xi32>) -> tensor<6x6xi32>
-    %2 = tensor.empty() : tensor<6x6xi32>
-    // CHECK-COUNT-4: mesh.sharding @mesh split_axes = {{\[\[1}}]]
-    %3 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%2, %1
-        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%2 : tensor<6x6xi32>) {
-    ^bb0(%in: i32, %in_2: i32, %out: i32):
-      %9 = arith.addi %in, %in_2 : i32
-      linalg.yield %9 : i32
-    } -> tensor<6x6xi32>
-    %sharding_col = mesh.sharding @mesh split_axes = [[1]] : !mesh.sharding
-    %annotated_col = mesh.shard %3 to %sharding_col : tensor<6x6xi32>
-    // CHECK: return
-    return %annotated_col : tensor<6x6xi32>
-  }
-}

diff --git a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir b/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
deleted file mode 100644
index 98e9931b8de94..0000000000000
--- a/mlir/test/Dialect/Mesh/forward-sharding-propagation.mlir
+++ /dev/null
@@ -1,49 +0,0 @@
-// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation{traversal=forward}))" %s | FileCheck %s
-
-#map = affine_map<(d0, d1) -> (d0, d1)>
-module attributes {dlti.map = #dlti.map<"MPI:Implementation" = "mpich", "MPI:comm_world_rank" = 0 : i32>} {
-  mesh.mesh @mesh(shape = 1) {sym_visibility = "private"}
-  func.func @test_forward() -> (tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>) attributes {llvm.emit_c_interface} {
-    %c1_i32 = arith.constant 1 : i32
-    // CHECK: [[v3:%.*]] = tensor.empty() : tensor<6x6xi32>
-    %0 = tensor.empty() : tensor<6x6xi32>
-    // CHECK: [[v1:%.*]] = linalg.fill ins
-    // CHECK: [[vsharding_0:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
-    // CHECK: [[vsharding_annotated_1:%.*]] = mesh.shard [[v1]] to [[vsharding_0]] : tensor<6x6xi32>
-    %1 = linalg.fill ins(%c1_i32 : i32) outs(%0 : tensor<6x6xi32>) -> tensor<6x6xi32>
-    %sharding = mesh.sharding @mesh split_axes = [[0]] : !mesh.sharding
-    %sharding_annotated = mesh.shard %1 to %sharding : tensor<6x6xi32>
-    // CHECK: [[v2:%.*]] = tensor.empty() : tensor<6x6xi32>
-    // CHECK: [[vsharding_2:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
-    // CHECK: [[vsharding_annotated_3:%.*]] = mesh.shard [[vsharding_annotated_1]] to [[vsharding_2]] annotate_for_users : tensor<6x6xi32>
-    %3 = tensor.empty() : tensor<6x6xi32>
-    // CHECK: [[vsharding_4:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
-    // CHECK: [[vsharding_annotated_5:%.*]] = mesh.shard [[v2]] to [[vsharding_4]] annotate_for_users : tensor<6x6xi32>
-    // CHECK: [[v3:%.*]] = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]}
-    // CHECK-SAME: ins([[vsharding_annotated_3]], [[vsharding_annotated_3]] : tensor<6x6xi32>, tensor<6x6xi32>) outs([[vsharding_annotated_5]] : tensor<6x6xi32>) {
-    // CHECK: [[vsharding_6:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}0]] : !mesh.sharding
-    // CHECK: [[vsharding_annotated_7:%.*]] = mesh.shard [[v3]] to [[vsharding_6]] : tensor<6x6xi32>
-    %4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel", "parallel"]} ins(%sharding_annotated, %sharding_annotated
-        : tensor<6x6xi32>, tensor<6x6xi32>) outs(%3 : tensor<6x6xi32>) {
-    ^bb0(%in: i32, %in_2: i32, %out: i32):
-      %9 = arith.addi %in, %in_2 : i32
-      linalg.yield %9 : i32
-    } -> tensor<6x6xi32>
-    %c0_i32 = arith.constant 0 : i32
-    %6 = tensor.empty() : tensor<i32>
-    %7 = linalg.fill ins(%c0_i32 : i32) outs(%6 : tensor<i32>) -> tensor<i32>
-    // CHECK: [[vreduced:%.*]] = linalg.reduce ins
-    // CHECK: [[vsharding_12:%.*]] = mesh.sharding @mesh split_axes = [] partial =  sum [0] : !mesh.sharding
-    // CHECK: [[vsharding_annotated_13:%.*]] = mesh.shard [[vreduced]] to [[vsharding_12]] : tensor<i32>
-    %reduced = linalg.reduce ins(%4 : tensor<6x6xi32>) outs(%7 : tensor<i32>) dimensions = [0, 1] 
-      (%in: i32, %init: i32) {
-        %9 = arith.addi %in, %init : i32
-        linalg.yield %9 : i32
-      }
-    // CHECK: [[vsharding_14:%.*]] = mesh.sharding @mesh split_axes = {{\[\[}}]] : !mesh.sharding
-    %sharding_0 = mesh.sharding @mesh split_axes = [[]] : !mesh.sharding
-    // CHECK: [[vsharding_annotated_15:%.*]] = mesh.shard [[vsharding_annotated_13]] to [[vsharding_14]] annotate_for_users : tensor<i32>
-    %sharding_annotated_1 = mesh.shard %reduced to %sharding_0 annotate_for_users : tensor<i32>
-    return %sharding_annotated, %4, %sharding_annotated_1 : tensor<6x6xi32>, tensor<6x6xi32>, tensor<i32>
-  }
-}

