llvmbot wrote:

<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-gpu

Author: Nicolas Vasilache (nicolasvasilache)

<details>
<summary>Changes</summary>

…lOp and use it to implement warp specialization.

This revision adds DeviceMaskingAttrInterface and extends 
DeviceMappingArrayAttr to accept a union of DeviceMappingAttrInterface and 
DeviceMaskingAttrInterface.

The first implementation is in the form of a GPUMappingMaskAttr, which can 
additionally be passed in the `mapping` attribute of scf.forall to specify a 
mask of the compute resources that should be active.

Support is added to GPUTransformOps to take advantage of this information and 
lower to block/warpgroup/warp/thread specialization when mapped to linear ids.

---

Patch is 35.49 KiB, truncated to 20.00 KiB below, full version: 
https://github.com/llvm/llvm-project/pull/146943.diff


12 Files Affected:

- (modified) mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td (+18) 
- (modified) mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h (+10-5) 
- (modified) mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td (+44-1) 
- (modified) mlir/include/mlir/Dialect/SCF/IR/SCFOps.td (+12) 
- (modified) mlir/lib/Dialect/GPU/CMakeLists.txt (+1) 
- (modified) mlir/lib/Dialect/GPU/IR/GPUDialect.cpp (+45) 
- (modified) mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp (+39-19) 
- (modified) mlir/lib/Dialect/GPU/TransformOps/Utils.cpp (+73-27) 
- (modified) mlir/lib/Dialect/SCF/IR/SCF.cpp (+37-6) 
- (modified) mlir/test/Dialect/GPU/transform-gpu-failing.mlir (+61) 
- (modified) mlir/test/Dialect/GPU/transform-gpu.mlir (+81) 
- (modified) mlir/test/Dialect/SCF/invalid.mlir (+18) 


``````````diff
diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td 
b/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td
index 63f228ca3157f..e8540027e7b77 100644
--- a/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td
+++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td
@@ -252,6 +252,24 @@ def GPULaneMappingAttr
   }];
 }
 
+def GPUMappingMaskAttr : GPU_Attr<"GPUMappingMask", "mask", [
+  DeclareAttrInterfaceMethods<DeviceMaskingAttrInterface> ] >  {
+  let parameters = (ins "uint64_t":$mask);
+  let assemblyFormat = "`<` params `>`";
+  let description = [{
+    Attribute describing how to filter the processing units that a
+    region is mapped to.
+
+    In the first implementation the masking is a bitfield that specifies for
+    each processing unit whether it is active or not.
+
+    In the future, we may want to implement this as a symbol to refer to
+    dynamically defined values.
+
+    Extending op semantics with an operand is deemed too intrusive at this 
time.
+  }];
+}
+
 def GPUMemorySpaceMappingAttr : GPU_Attr<"GPUMemorySpaceMapping", 
"memory_space", [
   DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] >  {
   let parameters = (ins
diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h 
b/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
index de512ded59fec..0a11b8f8d3fa0 100644
--- a/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
+++ b/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h
@@ -78,7 +78,8 @@ struct GpuIdBuilder {
 /// If `useLinearMapping` is true, the `idBuilder` method returns nD values
 /// used for indexing rewrites as well as 1D sizes for predicate generation.
 struct GpuBlockIdBuilder : public GpuIdBuilder {
-  GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping = false);
+  GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping = false,
+                    DeviceMaskingAttrInterface mask = nullptr);
 };
 
 /// Builder for warpgroup ids used to map scf.forall to reindexed warpgroups.
@@ -88,7 +89,8 @@ struct GpuBlockIdBuilder : public GpuIdBuilder {
 /// used for indexing rewrites as well as 1D sizes for predicate generation.
 struct GpuWarpgroupIdBuilder : public GpuIdBuilder {
   GpuWarpgroupIdBuilder(MLIRContext *ctx, int64_t warpSize,
-                        bool useLinearMapping = false);
+                        bool useLinearMapping = false,
+                        DeviceMaskingAttrInterface mask = nullptr);
   int64_t warpSize = 32;
   /// In the future this may be configured by the transformation.
   static constexpr int64_t kNumWarpsPerGroup = 4;
@@ -101,7 +103,8 @@ struct GpuWarpgroupIdBuilder : public GpuIdBuilder {
 /// used for indexing rewrites as well as 1D sizes for predicate generation.
 struct GpuWarpIdBuilder : public GpuIdBuilder {
   GpuWarpIdBuilder(MLIRContext *ctx, int64_t warpSize,
-                   bool useLinearMapping = false);
+                   bool useLinearMapping = false,
+                   DeviceMaskingAttrInterface mask = nullptr);
   int64_t warpSize = 32;
 };
 
@@ -111,7 +114,8 @@ struct GpuWarpIdBuilder : public GpuIdBuilder {
 /// If `useLinearMapping` is true, the `idBuilder` method returns nD values
 /// used for indexing rewrites as well as 1D sizes for predicate generation.
 struct GpuThreadIdBuilder : public GpuIdBuilder {
-  GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping = false);
+  GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping = false,
+                     DeviceMaskingAttrInterface mask = nullptr);
 };
 
 /// Builder for lane id.
@@ -119,7 +123,8 @@ struct GpuThreadIdBuilder : public GpuIdBuilder {
 /// as 1D sizes for predicate generation.
 /// This `useLinearMapping` case is the only supported case.
 struct GpuLaneIdBuilder : public GpuIdBuilder {
-  GpuLaneIdBuilder(MLIRContext *ctx, int64_t warpSize, bool unused);
+  GpuLaneIdBuilder(MLIRContext *ctx, int64_t warpSize, bool unused,
+                   DeviceMaskingAttrInterface mask = nullptr);
   int64_t warpSize = 32;
 };
 
diff --git a/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td 
b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
index 96db2a40cf58e..353aaf05bee0c 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td
@@ -60,8 +60,51 @@ def DeviceMappingAttrInterface : 
AttrInterface<"DeviceMappingAttrInterface"> {
   ];
 }
 
+def DeviceMaskingAttrInterface : AttrInterface<"DeviceMaskingAttrInterface"> {
+  let cppNamespace = "::mlir";
+  let description = [{
+    Attribute interface describing how to filter the processing units that a
+    region is mapped to.
+
+    A popcount can be applied to determine the logical linear index that a
+    physical processing unit is responsible for.
+  }];
+
+ let methods = [
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the logical active id for a given physical id.
+        Expects a physicalLinearMappingId of I64Type.
+      }],
+      /*retTy=*/"Value",
+      /*methodName=*/"getLogicalLinearMappingId",
+      /*args=*/(ins "OpBuilder&":$builder, "Value":$physicalLinearMappingId)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the dynamic condition determining whether a given physical id is
+        active under the mask.
+        Expects a physicalLinearMappingId of I64Type.
+      }],
+      /*retTy=*/"Value",
+      /*methodName=*/"getIsActiveIdPredicate",
+      /*args=*/(ins "OpBuilder&":$builder, "Value":$physicalLinearMappingId)
+    >,
+    InterfaceMethod<
+      /*desc=*/[{
+        Return the maximal number of physical ids supported.
+        This is to account for temporary implementation limitations (e.g. i64)
+        and fail gracefully with actionable error messages.
+      }],
+      /*retTy=*/"int64_t",
+      /*methodName=*/"getMaxNumPhysicalIds",
+      /*args=*/(ins)
+    >,
+  ];
+}
+
 def DeviceMappingArrayAttr :
-  TypedArrayAttrBase<DeviceMappingAttrInterface,
+  TypedArrayAttrBase<AnyAttrOf<[DeviceMappingAttrInterface, 
DeviceMaskingAttrInterface]>,
   "Device Mapping array attribute"> { }
 
 #endif // MLIR_DEVICEMAPPINGINTERFACE
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td 
b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 8b14cef7437d4..2d15544e871b3 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -611,6 +611,18 @@ def ForallOp : SCF_Op<"forall", [
     /// Returns operations within scf.forall.in_parallel whose destination
     /// operand is the block argument `bbArg`.
     SmallVector<Operation*> getCombiningOps(BlockArgument bbArg);
+
+    /// Returns the subset of DeviceMappingArrayAttrs of type
+    /// DeviceMappingAttrInterface.
+    SmallVector<DeviceMappingAttrInterface> getDeviceMappingAttrs();
+
+    /// Returns the single DeviceMaskingAttrInterface in the mapping, if any.
+    /// If more than one DeviceMaskingAttrInterface is specified, returns
+    /// failure. If none is present (or there is no mapping), returns nullptr.
+    FailureOr<DeviceMaskingAttrInterface> getDeviceMaskingAttr();
+
+    /// Returns true if the mapping specified for this forall op is linear.
+    bool usesLinearMapping();
   }];
 }
 
diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt 
b/mlir/lib/Dialect/GPU/CMakeLists.txt
index c8c53374d676b..4862d1f722785 100644
--- a/mlir/lib/Dialect/GPU/CMakeLists.txt
+++ b/mlir/lib/Dialect/GPU/CMakeLists.txt
@@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRGPUDialect
   MLIRFunctionInterfaces
   MLIRInferIntRangeInterface
   MLIRIR
+  MLIRMathDialect
   MLIRMemRefDialect
   MLIRSideEffectInterfaces
   MLIRSupport
diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp 
b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
index 56631f1aac084..9d74c23c24cc8 100644
--- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
+++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp
@@ -14,6 +14,7 @@
 
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h"
+#include "mlir/Dialect/Math/IR/Math.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
@@ -120,6 +121,50 @@ int64_t GPULaneMappingAttr::getRelativeIndex() const {
              : getMappingId();
 }
 
+int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; }
+
+///                 8       4       0
+/// Example mask  : 0 0 0 1 1 0 1 0 0
+///
+/// Active physical (resp. logical) ids are 2 (0), 4 (1) and 5 (2).
+/// Logical id for e.g. 5 (2) constructs filter ((1 << 5) - 1).
+///
+/// Example mask  : 0 0 0 1 1 0 1 0 0
+/// Example filter: 0 0 0 0 1 1 1 1 1
+/// Intersection  : 0 0 0 0 1 0 1 0 0
+/// PopCnt        : 2
+Value GPUMappingMaskAttr::getLogicalLinearMappingId(
+    OpBuilder &b, Value physicalLinearMappingId) const {
+  Location loc = physicalLinearMappingId.getLoc();
+  Value mask = b.create<arith::ConstantOp>(loc, 
b.getI64IntegerAttr(getMask()));
+  Value one = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
+  Value filter = b.create<arith::ShLIOp>(loc, one, physicalLinearMappingId);
+  filter = b.create<arith::SubIOp>(loc, filter, one);
+  Value filteredId = b.create<arith::AndIOp>(loc, mask, filter);
+  return b.create<math::CtPopOp>(loc, filteredId);
+}
+
+///                 8       4       0
+/// Example mask  : 0 0 0 1 1 0 1 0 0
+///
+/// Active physical (resp. logical) ids are 2 (0), 4 (1) and 5 (2).
+/// Active-id predicate for e.g. physical id 5 constructs filter (1 << 5).
+///
+/// Example mask  : 0 0 0 1 1 0 1 0 0
+/// Example filter: 0 0 0 1 0 0 0 0 0
+/// Intersection  : 0 0 0 1 0 0 0 0 0
+/// Cmp           : 1
+Value GPUMappingMaskAttr::getIsActiveIdPredicate(
+    OpBuilder &b, Value physicalLinearMappingId) const {
+  Location loc = physicalLinearMappingId.getLoc();
+  Value mask = b.create<arith::ConstantOp>(loc, 
b.getI64IntegerAttr(getMask()));
+  Value one = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1));
+  Value filter = b.create<arith::ShLIOp>(loc, one, physicalLinearMappingId);
+  Value filtered = b.create<arith::AndIOp>(loc, mask, filter);
+  Value zero = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(0));
+  return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, filtered, 
zero);
+}
+
 int64_t GPUMemorySpaceMappingAttr::getMappingId() const {
   return static_cast<int64_t>(getAddressSpace());
 }
diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp 
b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
index 63f87d9b5877e..a8eaa20928b7f 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp
@@ -351,16 +351,25 @@ 
checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp,
     seen.insert(map);
   }
 
-  auto isLinear = [](Attribute a) {
-    return cast<DeviceMappingAttrInterface>(a).isLinearMapping();
+  auto isLinear = [](DeviceMappingAttrInterface attr) {
+    return attr.isLinearMapping();
   };
-  if (llvm::any_of(forallOp.getMapping()->getValue(), isLinear) &&
-      !llvm::all_of(forallOp.getMapping()->getValue(), isLinear)) {
+  if (llvm::any_of(forallOp.getDeviceMappingAttrs(), isLinear) &&
+      !llvm::all_of(forallOp.getDeviceMappingAttrs(), isLinear)) {
     return definiteFailureHelper(
         transformOp, forallOp,
         "cannot mix linear and non-linear mapping modes");
   }
 
+  FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
+      forallOp.getDeviceMaskingAttr();
+  if (succeeded(maybeMaskingAttr) && *maybeMaskingAttr &&
+      !forallOp.usesLinearMapping()) {
+    return definiteFailureHelper(
+        transformOp, forallOp,
+        "device masking is only available in linear mapping mode");
+  }
+
   return DiagnosedSilenceableFailure::success();
 }
 
@@ -381,9 +390,7 @@ verifyGpuMapping(std::optional<TransformOpInterface> 
transformOp,
   if (forallOp.getNumResults() > 0)
     return definiteFailureHelper(transformOp, forallOp,
                                  "only bufferized scf.forall can be mapped");
-  bool useLinearMapping = cast<DeviceMappingAttrInterface>(
-                              forallOp.getMapping()->getValue().front())
-                              .isLinearMapping();
+  bool useLinearMapping = forallOp.usesLinearMapping();
   // TODO: This would be more natural with support for Optional<EnumParameter>
   // in GPUDeviceMappingAttr.
   int64_t maxNumMappingsSupported =
@@ -682,12 +689,17 @@ DiagnosedSilenceableFailure 
transform::MapForallToBlocks::applyToOne(
 
   // The BlockIdBuilder adapts to whatever is thrown at it.
   bool useLinearMapping = false;
-  if (topLevelForallOp.getMapping()) {
-    auto mappingAttr = cast<DeviceMappingAttrInterface>(
-        topLevelForallOp.getMapping()->getValue().front());
-    useLinearMapping = mappingAttr.isLinearMapping();
-  }
-  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping);
+  if (topLevelForallOp.getMapping())
+    useLinearMapping = topLevelForallOp.usesLinearMapping();
+
+  FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
+      topLevelForallOp.getDeviceMaskingAttr();
+  assert(succeeded(maybeMaskingAttr) && "unexpected failed maybeMaskingAttr");
+  assert((!*maybeMaskingAttr || useLinearMapping) &&
+         "masking requires linear mapping");
+
+  GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping,
+                                      *maybeMaskingAttr);
 
   diag = mlir::transform::gpu::mapForallToBlocksImpl(
       rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder);
@@ -744,8 +756,7 @@ static DiagnosedSilenceableFailure
 getThreadIdBuilder(std::optional<TransformOpInterface> transformOp,
                    scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes,
                    int64_t warpSize, GpuIdBuilder &gpuIdBuilder) {
-  auto mappingAttr = cast<DeviceMappingAttrInterface>(
-      forallOp.getMapping()->getValue().front());
+  auto mappingAttr = forallOp.getDeviceMappingAttrs().front();
   bool useLinearMapping = mappingAttr.isLinearMapping();
 
   // Sanity checks that may result in runtime verification errors.
@@ -768,21 +779,30 @@ getThreadIdBuilder(std::optional<TransformOpInterface> 
transformOp,
   if (!diag.succeeded())
     return diag;
 
+  FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr =
+      forallOp.getDeviceMaskingAttr();
+  assert(succeeded(maybeMaskingAttr) && "unexpected failed maybeMaskingAttr");
+  assert((!*maybeMaskingAttr || useLinearMapping) &&
+         "masking requires linear mapping");
+
   // Start mapping.
   MLIRContext *ctx = forallOp.getContext();
   gpuIdBuilder =
       TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr)
           .Case([&](GPUWarpgroupMappingAttr) {
-            return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping);
+            return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping,
+                                         *maybeMaskingAttr);
           })
           .Case([&](GPUWarpMappingAttr) {
-            return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping);
+            return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping,
+                                    *maybeMaskingAttr);
           })
           .Case([&](GPUThreadMappingAttr) {
-            return GpuThreadIdBuilder(ctx, useLinearMapping);
+            return GpuThreadIdBuilder(ctx, useLinearMapping, 
*maybeMaskingAttr);
           })
           .Case([&](GPULaneMappingAttr) {
-            return GpuLaneIdBuilder(ctx, warpSize, useLinearMapping);
+            return GpuLaneIdBuilder(ctx, warpSize, useLinearMapping,
+                                    *maybeMaskingAttr);
           })
           .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder {
             llvm_unreachable("unknown mapping attribute");
diff --git a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp 
b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
index 795d643c05912..d1969dbc82997 100644
--- a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
+++ b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp
@@ -44,7 +44,7 @@ using namespace mlir::transform::gpu;
 #define DEBUG_TYPE "gpu-transforms"
 
 #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
-#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 #define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ")
 
 /// Build predicates to filter execution by only the activeIds. Along each
@@ -120,10 +120,23 @@ static Value buildLinearId(RewriterBase &rewriter, 
Location loc,
 /// it in the basis of `forallMappingSizes`. The linear id builder returns an
 /// n-D vector of ids for indexing and 1-D size + id for predicate generation.
 template <typename ThreadOrBlockIdOp>
-static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) {
-  auto res = [multiplicity](RewriterBase &rewriter, Location loc,
-                            ArrayRef<int64_t> forallMappingSizes,
-                            ArrayRef<int64_t> originalBasis) {
+static GpuIdBuilderFnType
+commonLinearIdBuilderFn(int64_t multiplicity = 1,
+                        DeviceMaskingAttrInterface mask = nullptr) {
+  auto res = [multiplicity, mask](RewriterBase &rewriter, Location loc,
+                                  ArrayRef<int64_t> forallMappingSizes,
+                                  ArrayRef<int64_t> originalBasis) {
+    // 0. Early-exit mask case.
+    if (mask) {
+      if (computeProduct(originalBasis) >
+          mask.getMaxNumPhysicalIds() * multiplicity) {
+        return IdBuilderResult{
+            /*errorMsg=*/std::string(
+                "mask representation too short to capture all physical ids: ") 
+
+            std::to_string(mask.getMaxNumPhysicalIds())};
+      }
+    }
+
     // 1. Compute linearId.
     SmallVector<OpFoldResult> originalBasisOfr =
         getAsIndexOpFoldResult(rewriter.getContext(), originalBasis);
@@ -132,9 +145,25 @@ static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t 
multiplicity = 1) {
 
     // 2. Compute scaledLinearId.
     AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext());
-    OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply(
+    OpFoldResult scaledLinearIdOfr = affine::makeComposedFoldedAffineApply(
         rewriter, loc, d0.floorDiv(multiplicity), {physicalLinearId});
 
+    // 2.b. Adjust with mask if needed.
+    Value scaledLinearIdI64;
+    Value scaledLinearId =
+        getValueOrCreateConstantIndexOp(rewriter, loc, scaledLinearIdOfr);
+    if (mask) {
+      scaledLinearId =
+          getValueOrCreateConstantIndexOp(rewriter, loc, scaledLinearIdOfr);
+      scaledLinearIdI64 = rewriter.create<arith::IndexCastUIOp>(
+          loc, rewriter.getI64Type(), scaledLinearId);
+      Value logicalLinearIdI64 =
+          mask.getLogicalLinearMappingId(rewriter, scaledLinearIdI64);
+      scaledLinearId = rewriter.create<arith::IndexCastUIOp>(
+          loc, rewriter.getIndexType(), logicalLinearIdI64);
+      LDBG("------adjusting linearId with mask: " << scaledLinearId);
+    }
+
     // 3. Compute remapped indices.
     SmallVector<Value> ids;
     // Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in
@@ -148,15 +177,23 @@ static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t 
multiplicity = 1) {
           affine::makeComposedAffineApply(rewriter, loc, e, {scaledLinearId}));
     }
 
-    // 4. Handle predicates using physicalLinearId.
     std::string errorMsg;
     SmallVector<Value> predicateOps;
-    FailureOr<SmallVector<Value>> maybePredicateOps =
-        buildPredicates(rewriter, loc, physicalLinearId,
-                        computeProduct(forallMappingSizes) * multiplicity,
-                        computeProduct(originalBasis), errorMsg);
-    if (succeeded(maybePredicateOps))
-      predicateOps = *maybePredicateOps;
+    // 4. If mask present, it takes precedence to determine predication.
+    if (mask) {
+      Value isActiveIdPredicate =
+          mask.getIsActiveIdPredicate(rewriter, scaledLinearIdI64);
+      LDBG("------adjusting predicate with mask: " << isActiveIdPredicate);
+      predicateOps.push_back(isActiveIdPredicate);
+    } else {
+      // 4.b. Otherwise, handle predicates using physicalLinearId.
+      FailureOr<SmallVector<Value>> maybePredicateOps =
+          buildPredicates(rewriter, loc, physicalLinearId,
+                          computeProduct(forallMappingSizes) * multiplicity,
+  ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/146943
_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to