https://github.com/nicolasvasilache created https://github.com/llvm/llvm-project/pull/146943
…lOp and use it to implement warp specialization. This revision adds DeviceMaskingAttrInterface and extends DeviceMappingArrayAttr to accept a union of DeviceMappingAttrInterface and DeviceMaskingAttrInterface. The first implementation is if the form of a GPUMappingMaskAttr, which can be additionally passed to the scf.forall.mapping attribute to specify a mask on compute resources that should be active. Support is added to GPUTransformOps to take advantage of this information and lower to block/warpgroup/warp/thread specialization when mapped to linear ids. >From 02e425b30966f4781fe07d8cf595a1e2b0d41aa3 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache <nico.vasila...@amd.com> Date: Thu, 3 Jul 2025 21:26:53 +0200 Subject: [PATCH] [mlir][SCF][GPU] Add DeviceMaskingAttrInterface support to scf::ForallOp and use it to implement warp specialization. This revision adds DeviceMaskingAttrInterface and extends DeviceMappingArrayAttr to accept a union of DeviceMappingAttrInterface and DeviceMaskingAttrInterface. The first implementation is if the form of a GPUMappingMaskAttr, which can be additionally passed to the scf.forall.mapping attribute to specify a mask on compute resources that should be active. Support is added to GPUTransformOps to take advantage of this information and lower to block/warpgroup/warp/thread specialization when mapped to linear ids. Co-authored-by: Oleksandr "Alex" Zinenko <g...@ozinenko.com> --- .../Dialect/GPU/IR/GPUDeviceMappingAttr.td | 18 ++++ .../mlir/Dialect/GPU/TransformOps/Utils.h | 15 ++- .../Dialect/SCF/IR/DeviceMappingInterface.td | 45 +++++++- mlir/include/mlir/Dialect/SCF/IR/SCFOps.td | 12 +++ mlir/lib/Dialect/GPU/CMakeLists.txt | 1 + mlir/lib/Dialect/GPU/IR/GPUDialect.cpp | 45 ++++++++ .../GPU/TransformOps/GPUTransformOps.cpp | 58 ++++++---- mlir/lib/Dialect/GPU/TransformOps/Utils.cpp | 100 +++++++++++++----- mlir/lib/Dialect/SCF/IR/SCF.cpp | 43 ++++++-- .../Dialect/GPU/transform-gpu-failing.mlir | 61 +++++++++++ mlir/test/Dialect/GPU/transform-gpu.mlir | 81 ++++++++++++++ mlir/test/Dialect/SCF/invalid.mlir | 18 ++++ 12 files changed, 439 insertions(+), 58 deletions(-) diff --git a/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td b/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td index 63f228ca3157f..e8540027e7b77 100644 --- a/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td +++ b/mlir/include/mlir/Dialect/GPU/IR/GPUDeviceMappingAttr.td @@ -252,6 +252,24 @@ def GPULaneMappingAttr }]; } +def GPUMappingMaskAttr : GPU_Attr<"GPUMappingMask", "mask", [ + DeclareAttrInterfaceMethods<DeviceMaskingAttrInterface> ] > { + let parameters = (ins "uint64_t":$mask); + let assemblyFormat = "`<` params `>`"; + let description = [{ + Attribute describing how to filter the processing units that a + region is mapped to. + + In the first implementation the masking is a bitfield that specifies for + each processing unit whether it is active or not. + + In the future, we may want to implement this as a symbol to refer to + dynamically defined values. + + Extending op semantics with an operand is deemed too intrusive at this time. 
+ }]; +} + def GPUMemorySpaceMappingAttr : GPU_Attr<"GPUMemorySpaceMapping", "memory_space", [ DeclareAttrInterfaceMethods<DeviceMappingAttrInterface> ] > { let parameters = (ins diff --git a/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h b/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h index de512ded59fec..0a11b8f8d3fa0 100644 --- a/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h +++ b/mlir/include/mlir/Dialect/GPU/TransformOps/Utils.h @@ -78,7 +78,8 @@ struct GpuIdBuilder { /// If `useLinearMapping` is true, the `idBuilder` method returns nD values /// used for indexing rewrites as well as 1D sizes for predicate generation. struct GpuBlockIdBuilder : public GpuIdBuilder { - GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping = false); + GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping = false, + DeviceMaskingAttrInterface mask = nullptr); }; /// Builder for warpgroup ids used to map scf.forall to reindexed warpgroups. @@ -88,7 +89,8 @@ struct GpuBlockIdBuilder : public GpuIdBuilder { /// used for indexing rewrites as well as 1D sizes for predicate generation. struct GpuWarpgroupIdBuilder : public GpuIdBuilder { GpuWarpgroupIdBuilder(MLIRContext *ctx, int64_t warpSize, - bool useLinearMapping = false); + bool useLinearMapping = false, + DeviceMaskingAttrInterface mask = nullptr); int64_t warpSize = 32; /// In the future this may be configured by the transformation. static constexpr int64_t kNumWarpsPerGroup = 4; @@ -101,7 +103,8 @@ struct GpuWarpgroupIdBuilder : public GpuIdBuilder { /// used for indexing rewrites as well as 1D sizes for predicate generation. struct GpuWarpIdBuilder : public GpuIdBuilder { GpuWarpIdBuilder(MLIRContext *ctx, int64_t warpSize, - bool useLinearMapping = false); + bool useLinearMapping = false, + DeviceMaskingAttrInterface mask = nullptr); int64_t warpSize = 32; }; @@ -111,7 +114,8 @@ struct GpuWarpIdBuilder : public GpuIdBuilder { /// If `useLinearMapping` is true, the `idBuilder` method returns nD values /// used for indexing rewrites as well as 1D sizes for predicate generation. struct GpuThreadIdBuilder : public GpuIdBuilder { - GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping = false); + GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping = false, + DeviceMaskingAttrInterface mask = nullptr); }; /// Builder for lane id. @@ -119,7 +123,8 @@ struct GpuThreadIdBuilder : public GpuIdBuilder { /// as 1D sizes for predicate generation. /// This `useLinearMapping` case is the only supported case. struct GpuLaneIdBuilder : public GpuIdBuilder { - GpuLaneIdBuilder(MLIRContext *ctx, int64_t warpSize, bool unused); + GpuLaneIdBuilder(MLIRContext *ctx, int64_t warpSize, bool unused, + DeviceMaskingAttrInterface mask = nullptr); int64_t warpSize = 32; }; diff --git a/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td index 96db2a40cf58e..353aaf05bee0c 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td +++ b/mlir/include/mlir/Dialect/SCF/IR/DeviceMappingInterface.td @@ -60,8 +60,51 @@ def DeviceMappingAttrInterface : AttrInterface<"DeviceMappingAttrInterface"> { ]; } +def DeviceMaskingAttrInterface : AttrInterface<"DeviceMaskingAttrInterface"> { + let cppNamespace = "::mlir"; + let description = [{ + Attribute interface describing how to filter the processing units that a + region is mapped to. + + A popcount can be applied to determine the logical linear index that a + physical processing unit is responsible for. 
+ }]; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Return the logical active id for a given physical id. + Expects a physicalLinearMappingId of I64Type. + }], + /*retTy=*/"Value", + /*methodName=*/"getLogicalLinearMappingId", + /*args=*/(ins "OpBuilder&":$builder, "Value":$physicalLinearMappingId) + >, + InterfaceMethod< + /*desc=*/[{ + Return the dynamic condition determining whether a given physical id is + active under the mask. + Expects a physicalLinearMappingId of I64Type. + }], + /*retTy=*/"Value", + /*methodName=*/"getIsActiveIdPredicate", + /*args=*/(ins "OpBuilder&":$builder, "Value":$physicalLinearMappingId) + >, + InterfaceMethod< + /*desc=*/[{ + Return the maximal number of pysical ids supported. + This is to account for temporary implementation limitations (e.g. i64) + and fail gracefully with actionnable error messages. + }], + /*retTy=*/"int64_t", + /*methodName=*/"getMaxNumPhysicalIds", + /*args=*/(ins) + >, + ]; +} + def DeviceMappingArrayAttr : - TypedArrayAttrBase<DeviceMappingAttrInterface, + TypedArrayAttrBase<AnyAttrOf<[DeviceMappingAttrInterface, DeviceMaskingAttrInterface]>, "Device Mapping array attribute"> { } #endif // MLIR_DEVICEMAPPINGINTERFACE diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td index 8b14cef7437d4..2d15544e871b3 100644 --- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td +++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td @@ -611,6 +611,18 @@ def ForallOp : SCF_Op<"forall", [ /// Returns operations within scf.forall.in_parallel whose destination /// operand is the block argument `bbArg`. SmallVector<Operation*> getCombiningOps(BlockArgument bbArg); + + /// Returns the subset of DeviceMappingArrayAttrs of type + /// DeviceMappingAttrInterface. + SmallVector<DeviceMappingAttrInterface> getDeviceMappingAttrs(); + + /// Returns the at most one DeviceMaskingAttrInterface in the mapping. + /// If more than one DeviceMaskingAttrInterface is specified, returns + /// failure. If no mapping is present, returns nullptr. + FailureOr<DeviceMaskingAttrInterface> getDeviceMaskingAttr(); + + /// Returns true if the mapping specified for this forall op is linear. + bool usesLinearMapping(); }]; } diff --git a/mlir/lib/Dialect/GPU/CMakeLists.txt b/mlir/lib/Dialect/GPU/CMakeLists.txt index c8c53374d676b..4862d1f722785 100644 --- a/mlir/lib/Dialect/GPU/CMakeLists.txt +++ b/mlir/lib/Dialect/GPU/CMakeLists.txt @@ -20,6 +20,7 @@ add_mlir_dialect_library(MLIRGPUDialect MLIRFunctionInterfaces MLIRInferIntRangeInterface MLIRIR + MLIRMathDialect MLIRMemRefDialect MLIRSideEffectInterfaces MLIRSupport diff --git a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp index 56631f1aac084..9d74c23c24cc8 100644 --- a/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp +++ b/mlir/lib/Dialect/GPU/IR/GPUDialect.cpp @@ -14,6 +14,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Bufferization/IR/BufferDeallocationOpInterface.h" +#include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Attributes.h" #include "mlir/IR/Builders.h" @@ -120,6 +121,50 @@ int64_t GPULaneMappingAttr::getRelativeIndex() const { : getMappingId(); } +int64_t GPUMappingMaskAttr::getMaxNumPhysicalIds() const { return 64; } + +/// 8 4 0 +/// Example mask : 0 0 0 1 1 0 1 0 0 +/// +/// Active physical (resp. logical) is 2 (0), 4 (1) and 5 (2). +/// Logical id for e.g. 5 (2) constructs filter (1 << 5 - 1). 
+/// +/// Example mask : 0 0 0 1 1 0 1 0 0 +/// Example filter: 0 0 0 0 1 1 1 1 1 +/// Intersection : 0 0 0 0 1 0 1 0 0 +/// PopCnt : 2 +Value GPUMappingMaskAttr::getLogicalLinearMappingId( + OpBuilder &b, Value physicalLinearMappingId) const { + Location loc = physicalLinearMappingId.getLoc(); + Value mask = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(getMask())); + Value one = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1)); + Value filter = b.create<arith::ShLIOp>(loc, one, physicalLinearMappingId); + filter = b.create<arith::SubIOp>(loc, filter, one); + Value filteredId = b.create<arith::AndIOp>(loc, mask, filter); + return b.create<math::CtPopOp>(loc, filteredId); +} + +/// 8 4 0 +/// Example mask : 0 0 0 1 1 0 1 0 0 +/// +/// Active physical (resp. logical) is 2 (0), 4 (1) and 5 (2). +/// Logical id for e.g. 5 (2) constructs filter (1 << 5). +/// +/// Example mask : 0 0 0 1 1 0 1 0 0 +/// Example filter: 0 0 0 1 0 0 0 0 0 +/// Intersection : 0 0 0 1 0 0 0 0 0 +/// Cmp : 1 +Value GPUMappingMaskAttr::getIsActiveIdPredicate( + OpBuilder &b, Value physicalLinearMappingId) const { + Location loc = physicalLinearMappingId.getLoc(); + Value mask = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(getMask())); + Value one = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(1)); + Value filter = b.create<arith::ShLIOp>(loc, one, physicalLinearMappingId); + Value filtered = b.create<arith::AndIOp>(loc, mask, filter); + Value zero = b.create<arith::ConstantOp>(loc, b.getI64IntegerAttr(0)); + return b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::ne, filtered, zero); +} + int64_t GPUMemorySpaceMappingAttr::getMappingId() const { return static_cast<int64_t>(getAddressSpace()); } diff --git a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp index 63f87d9b5877e..a8eaa20928b7f 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp @@ -351,16 +351,25 @@ checkMappingAttributeTypes(std::optional<TransformOpInterface> transformOp, seen.insert(map); } - auto isLinear = [](Attribute a) { - return cast<DeviceMappingAttrInterface>(a).isLinearMapping(); + auto isLinear = [](DeviceMappingAttrInterface attr) { + return attr.isLinearMapping(); }; - if (llvm::any_of(forallOp.getMapping()->getValue(), isLinear) && - !llvm::all_of(forallOp.getMapping()->getValue(), isLinear)) { + if (llvm::any_of(forallOp.getDeviceMappingAttrs(), isLinear) && + !llvm::all_of(forallOp.getDeviceMappingAttrs(), isLinear)) { return definiteFailureHelper( transformOp, forallOp, "cannot mix linear and non-linear mapping modes"); } + FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr = + forallOp.getDeviceMaskingAttr(); + if (succeeded(maybeMaskingAttr) && *maybeMaskingAttr && + !forallOp.usesLinearMapping()) { + return definiteFailureHelper( + transformOp, forallOp, + "device masking is only available in linear mapping mode"); + } + return DiagnosedSilenceableFailure::success(); } @@ -381,9 +390,7 @@ verifyGpuMapping(std::optional<TransformOpInterface> transformOp, if (forallOp.getNumResults() > 0) return definiteFailureHelper(transformOp, forallOp, "only bufferized scf.forall can be mapped"); - bool useLinearMapping = cast<DeviceMappingAttrInterface>( - forallOp.getMapping()->getValue().front()) - .isLinearMapping(); + bool useLinearMapping = forallOp.usesLinearMapping(); // TODO: This would be more natural with support for Optional<EnumParameter> // in GPUDeviceMappingAttr. 
int64_t maxNumMappingsSupported = @@ -682,12 +689,17 @@ DiagnosedSilenceableFailure transform::MapForallToBlocks::applyToOne( // The BlockIdBuilder adapts to whatever is thrown at it. bool useLinearMapping = false; - if (topLevelForallOp.getMapping()) { - auto mappingAttr = cast<DeviceMappingAttrInterface>( - topLevelForallOp.getMapping()->getValue().front()); - useLinearMapping = mappingAttr.isLinearMapping(); - } - GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping); + if (topLevelForallOp.getMapping()) + useLinearMapping = topLevelForallOp.usesLinearMapping(); + + FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr = + topLevelForallOp.getDeviceMaskingAttr(); + assert(succeeded(maybeMaskingAttr) && "unexpected failed maybeMaskingAttr"); + assert((!*maybeMaskingAttr || useLinearMapping) && + "masking requires linear mapping"); + + GpuBlockIdBuilder gpuBlockIdBuilder(getContext(), useLinearMapping, + *maybeMaskingAttr); diag = mlir::transform::gpu::mapForallToBlocksImpl( rewriter, transformOp, topLevelForallOp, gridDims, gpuBlockIdBuilder); @@ -744,8 +756,7 @@ static DiagnosedSilenceableFailure getThreadIdBuilder(std::optional<TransformOpInterface> transformOp, scf::ForallOp forallOp, ArrayRef<int64_t> blockSizes, int64_t warpSize, GpuIdBuilder &gpuIdBuilder) { - auto mappingAttr = cast<DeviceMappingAttrInterface>( - forallOp.getMapping()->getValue().front()); + auto mappingAttr = forallOp.getDeviceMappingAttrs().front(); bool useLinearMapping = mappingAttr.isLinearMapping(); // Sanity checks that may result in runtime verification errors. @@ -768,21 +779,30 @@ getThreadIdBuilder(std::optional<TransformOpInterface> transformOp, if (!diag.succeeded()) return diag; + FailureOr<DeviceMaskingAttrInterface> maybeMaskingAttr = + forallOp.getDeviceMaskingAttr(); + assert(succeeded(maybeMaskingAttr) && "unexpected failed maybeMaskingAttr"); + assert((!*maybeMaskingAttr || useLinearMapping) && + "masking requires linear mapping"); + // Start mapping. 
MLIRContext *ctx = forallOp.getContext(); gpuIdBuilder = TypeSwitch<DeviceMappingAttrInterface, GpuIdBuilder>(mappingAttr) .Case([&](GPUWarpgroupMappingAttr) { - return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping); + return GpuWarpgroupIdBuilder(ctx, warpSize, useLinearMapping, + *maybeMaskingAttr); }) .Case([&](GPUWarpMappingAttr) { - return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping); + return GpuWarpIdBuilder(ctx, warpSize, useLinearMapping, + *maybeMaskingAttr); }) .Case([&](GPUThreadMappingAttr) { - return GpuThreadIdBuilder(ctx, useLinearMapping); + return GpuThreadIdBuilder(ctx, useLinearMapping, *maybeMaskingAttr); }) .Case([&](GPULaneMappingAttr) { - return GpuLaneIdBuilder(ctx, warpSize, useLinearMapping); + return GpuLaneIdBuilder(ctx, warpSize, useLinearMapping, + *maybeMaskingAttr); }) .Default([&](DeviceMappingAttrInterface) -> GpuIdBuilder { llvm_unreachable("unknown mapping attribute"); diff --git a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp index 795d643c05912..d1969dbc82997 100644 --- a/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp +++ b/mlir/lib/Dialect/GPU/TransformOps/Utils.cpp @@ -44,7 +44,7 @@ using namespace mlir::transform::gpu; #define DEBUG_TYPE "gpu-transforms" #define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") -#define LDBG(X) LLVM_DEBUG(DBGS() << (X) << "\n") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") #define DBGS_ALIAS() (llvm::dbgs() << '[' << DEBUG_TYPE_ALIAS << "] ") /// Build predicates to filter execution by only the activeIds. Along each @@ -120,10 +120,23 @@ static Value buildLinearId(RewriterBase &rewriter, Location loc, /// it in the basis of `forallMappingSizes`. The linear id builder returns an /// n-D vector of ids for indexing and 1-D size + id for predicate generation. template <typename ThreadOrBlockIdOp> -static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) { - auto res = [multiplicity](RewriterBase &rewriter, Location loc, - ArrayRef<int64_t> forallMappingSizes, - ArrayRef<int64_t> originalBasis) { +static GpuIdBuilderFnType +commonLinearIdBuilderFn(int64_t multiplicity = 1, + DeviceMaskingAttrInterface mask = nullptr) { + auto res = [multiplicity, mask](RewriterBase &rewriter, Location loc, + ArrayRef<int64_t> forallMappingSizes, + ArrayRef<int64_t> originalBasis) { + // 0. Early-exit mask case. + if (mask) { + if (computeProduct(originalBasis) > + mask.getMaxNumPhysicalIds() * multiplicity) { + return IdBuilderResult{ + /*errorMsg=*/std::string( + "mask representation too short to capture all physical ids: ") + + std::to_string(mask.getMaxNumPhysicalIds())}; + } + } + // 1. Compute linearId. SmallVector<OpFoldResult> originalBasisOfr = getAsIndexOpFoldResult(rewriter.getContext(), originalBasis); @@ -132,9 +145,25 @@ static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) { // 2. Compute scaledLinearId. AffineExpr d0 = getAffineDimExpr(0, rewriter.getContext()); - OpFoldResult scaledLinearId = affine::makeComposedFoldedAffineApply( + OpFoldResult scaledLinearIdOfr = affine::makeComposedFoldedAffineApply( rewriter, loc, d0.floorDiv(multiplicity), {physicalLinearId}); + // 2.b. Adjust with mask if needed. 
+ Value scaledLinearIdI64; + Value scaledLinearId = + getValueOrCreateConstantIndexOp(rewriter, loc, scaledLinearIdOfr); + if (mask) { + scaledLinearId = + getValueOrCreateConstantIndexOp(rewriter, loc, scaledLinearIdOfr); + scaledLinearIdI64 = rewriter.create<arith::IndexCastUIOp>( + loc, rewriter.getI64Type(), scaledLinearId); + Value logicalLinearIdI64 = + mask.getLogicalLinearMappingId(rewriter, scaledLinearIdI64); + scaledLinearId = rewriter.create<arith::IndexCastUIOp>( + loc, rewriter.getIndexType(), logicalLinearIdI64); + LDBG("------adjusting linearId with mask: " << scaledLinearId); + } + // 3. Compute remapped indices. SmallVector<Value> ids; // Sizes in [0 .. n] -> [n .. 0] order to properly compute strides in @@ -148,15 +177,23 @@ static GpuIdBuilderFnType commonLinearIdBuilderFn(int64_t multiplicity = 1) { affine::makeComposedAffineApply(rewriter, loc, e, {scaledLinearId})); } - // 4. Handle predicates using physicalLinearId. std::string errorMsg; SmallVector<Value> predicateOps; - FailureOr<SmallVector<Value>> maybePredicateOps = - buildPredicates(rewriter, loc, physicalLinearId, - computeProduct(forallMappingSizes) * multiplicity, - computeProduct(originalBasis), errorMsg); - if (succeeded(maybePredicateOps)) - predicateOps = *maybePredicateOps; + // 4. If mask present, it takes precedence to determine predication. + if (mask) { + Value isActiveIdPredicate = + mask.getIsActiveIdPredicate(rewriter, scaledLinearIdI64); + LDBG("------adjusting predicate with mask: " << isActiveIdPredicate); + predicateOps.push_back(isActiveIdPredicate); + } else { + // 4.b. Otherwise, handle predicates using physicalLinearId. + FailureOr<SmallVector<Value>> maybePredicateOps = + buildPredicates(rewriter, loc, physicalLinearId, + computeProduct(forallMappingSizes) * multiplicity, + computeProduct(originalBasis), errorMsg); + if (succeeded(maybePredicateOps)) + predicateOps = *maybePredicateOps; + } return IdBuilderResult{/*errorMsg=*/errorMsg, /*mappingIdOps=*/ids, @@ -271,58 +308,67 @@ GpuIdBuilder::GpuIdBuilder(MLIRContext *ctx, bool useLinearMapping, } } -GpuBlockIdBuilder::GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping) +GpuBlockIdBuilder::GpuBlockIdBuilder(MLIRContext *ctx, bool useLinearMapping, + DeviceMaskingAttrInterface mask) : GpuIdBuilder(ctx, useLinearMapping, [](MLIRContext *ctx, MappingId id) { return GPUBlockMappingAttr::get(ctx, id); }) { + assert((!mask || useLinearMapping) && "mask requires linear mapping"); idBuilder = useLinearMapping - ? commonLinearIdBuilderFn<BlockIdOp>(/*multiplicity=*/1) + ? commonLinearIdBuilderFn<BlockIdOp>(/*multiplicity=*/1, mask) : common3DIdBuilderFn<BlockIdOp>(/*multiplicity=*/1); } GpuWarpgroupIdBuilder::GpuWarpgroupIdBuilder(MLIRContext *ctx, int64_t warpSize, - bool useLinearMapping) + bool useLinearMapping, + DeviceMaskingAttrInterface mask) : GpuIdBuilder(ctx, useLinearMapping, [](MLIRContext *ctx, MappingId id) { return GPUWarpgroupMappingAttr::get(ctx, id); }), warpSize(warpSize) { + assert((!mask || useLinearMapping) && "mask requires linear mapping"); idBuilder = useLinearMapping ? 
commonLinearIdBuilderFn<ThreadIdOp>( - /*multiplicity=*/kNumWarpsPerGroup * warpSize) + /*multiplicity=*/kNumWarpsPerGroup * warpSize, mask) : common3DIdBuilderFn<ThreadIdOp>( /*multiplicity=*/kNumWarpsPerGroup * warpSize); } GpuWarpIdBuilder::GpuWarpIdBuilder(MLIRContext *ctx, int64_t warpSize, - bool useLinearMapping) + bool useLinearMapping, + DeviceMaskingAttrInterface mask) : GpuIdBuilder(ctx, useLinearMapping, [](MLIRContext *ctx, MappingId id) { return GPUWarpMappingAttr::get(ctx, id); }), warpSize(warpSize) { - idBuilder = - useLinearMapping - ? commonLinearIdBuilderFn<ThreadIdOp>(/*multiplicity=*/warpSize) - : common3DIdBuilderFn<ThreadIdOp>(/*multiplicity=*/warpSize); + assert((!mask || useLinearMapping) && "mask requires linear mapping"); + idBuilder = useLinearMapping + ? commonLinearIdBuilderFn<ThreadIdOp>( + /*multiplicity=*/warpSize, mask) + : common3DIdBuilderFn<ThreadIdOp>(/*multiplicity=*/warpSize); } -GpuThreadIdBuilder::GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping) +GpuThreadIdBuilder::GpuThreadIdBuilder(MLIRContext *ctx, bool useLinearMapping, + DeviceMaskingAttrInterface mask) : GpuIdBuilder(ctx, useLinearMapping, [](MLIRContext *ctx, MappingId id) { return GPUThreadMappingAttr::get(ctx, id); }) { - idBuilder = useLinearMapping - ? commonLinearIdBuilderFn<ThreadIdOp>(/*multiplicity=*/1) - : common3DIdBuilderFn<ThreadIdOp>(/*multiplicity=*/1); + idBuilder = + useLinearMapping + ? commonLinearIdBuilderFn<ThreadIdOp>(/*multiplicity=*/1, mask) + : common3DIdBuilderFn<ThreadIdOp>(/*multiplicity=*/1); } GpuLaneIdBuilder::GpuLaneIdBuilder(MLIRContext *ctx, int64_t warpSize, - bool unused) + bool unused, DeviceMaskingAttrInterface mask) : GpuIdBuilder(ctx, /*useLinearMapping=*/true, [](MLIRContext *ctx, MappingId id) { return GPULaneMappingAttr::get(ctx, id); }), warpSize(warpSize) { + assert(!mask && "mask NYI for lanes, unclear it should be at all"); idBuilder = laneIdBuilderFn(/*periodicity=*/warpSize); } diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index 79012dbd32f80..5a3bd984530db 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -1175,13 +1175,11 @@ LogicalResult ForallOp::verify() { return emitOpError("type mismatch between ") << i << "-th output and corresponding block argument"; if (getMapping().has_value() && !getMapping()->empty()) { - if (static_cast<int64_t>(getMapping()->size()) != numLoops) + if (getDeviceMappingAttrs().size() != numLoops) return emitOpError() << "mapping attribute size must match op rank"; - for (auto map : getMapping()->getValue()) { - if (!isa<DeviceMappingAttrInterface>(map)) - return emitOpError() - << getMappingAttrName() << " is not device mapping attribute"; - } + if (failed(getDeviceMaskingAttr())) + return emitOpError() << getMappingAttrName() + << " supports at most one device masking attribute"; } // Verify mixed static/dynamic control variables. 
@@ -1435,6 +1433,39 @@ SmallVector<Operation *> ForallOp::getCombiningOps(BlockArgument bbArg) { return storeOps; } +SmallVector<DeviceMappingAttrInterface> ForallOp::getDeviceMappingAttrs() { + SmallVector<DeviceMappingAttrInterface> res; + if (!getMapping()) + return res; + for (auto attr : getMapping()->getValue()) { + auto m = dyn_cast<DeviceMappingAttrInterface>(attr); + if (m) + res.push_back(m); + } + return res; +} + +FailureOr<DeviceMaskingAttrInterface> ForallOp::getDeviceMaskingAttr() { + DeviceMaskingAttrInterface res; + if (!getMapping()) + return res; + for (auto attr : getMapping()->getValue()) { + auto m = dyn_cast<DeviceMaskingAttrInterface>(attr); + if (m && res) + return failure(); + if (m) + res = m; + } + return res; +} + +bool ForallOp::usesLinearMapping() { + SmallVector<DeviceMappingAttrInterface> ifaces = getDeviceMappingAttrs(); + if (ifaces.empty()) + return false; + return ifaces.front().isLinearMapping(); +} + std::optional<SmallVector<Value>> ForallOp::getLoopInductionVars() { return SmallVector<Value>{getBody()->getArguments().take_front(getRank())}; } diff --git a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir index 8d7a1aa2a55fc..bc052a0230a8e 100644 --- a/mlir/test/Dialect/GPU/transform-gpu-failing.mlir +++ b/mlir/test/Dialect/GPU/transform-gpu-failing.mlir @@ -405,6 +405,67 @@ module attributes {transform.with_named_sequence} { // ----- +func.func @masking_mapping_attribute_requires_linear_mapping( + %x: memref<32xf32>, %y: memref<32xf32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<32xf32> { + %one = arith.constant 1 : index + %c9 = arith.constant 9 : index + %c7 = arith.constant 7 : index + %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one) + threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one) + { + scf.forall (%i) in (%c7) { + %4 = memref.load %x[%i] : memref<32xf32> + %5 = memref.load %y[%i] : memref<32xf32> + %6 = math.fma %alpha, %4, %5 : f32 + memref.store %6, %y[%i] : memref<32xf32> + } { mapping = [#gpu.warp<x>, #gpu.mask<0x33>] } + gpu.terminator + } + + return %y : memref<32xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-error @below {{device masking is only available in linear mapping mode}} + transform.gpu.map_nested_forall_to_threads %funcop block_dims = [1, 1, 1] : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + +func.func @masking_mapping_attribute_requires_linear_mapping( + %x: memref<32xf32>, %y: memref<32xf32>, %t: memref<32 x f32>, %alpha : f32, %stream : !gpu.async.token) -> memref<32xf32> { + %one = arith.constant 1 : index + %c99 = arith.constant 99 : index + %name = gpu.launch async[%stream] blocks(%arg3, %arg4, %arg5) in (%arg9 = %one, %arg10 = %one, %arg11 = %one) + threads(%arg6, %arg7, %arg8) in (%arg12 = %one, %arg13 = %one, %arg14 = %one) + { + scf.forall (%i) in (%c99) { + %4 = memref.load %x[%i] : memref<32xf32> + %5 = memref.load %y[%i] : memref<32xf32> + %6 = math.fma %alpha, %4, %5 : f32 + memref.store %6, %y[%i] : memref<32xf32> + } { mapping = [#gpu.thread<linear_dim_0>, #gpu.mask<0xff>] } + gpu.terminator + } + + return %y : memref<32xf32> +} + +module attributes 
{transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %funcop = transform.structured.match ops{["gpu.launch"]} in %arg0 : (!transform.any_op) -> !transform.any_op + // expected-error @below {{mask representation too short to capture all physical ids: 64}} + transform.gpu.map_nested_forall_to_threads %funcop block_dims = [128, 1, 1] : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + +// ----- + func.func public @not_a_block_mapping_attribute(%arg0: memref<32x32xf32>, %arg1: memref<32x32xf32>, %arg2: memref<32x32xf32>) { scf.forall (%arg3, %arg4) in (1, 1) { linalg.matmul ins(%arg0, %arg1 : memref<32x32xf32>, memref<32x32xf32>) outs(%arg2 : memref<32x32xf32>) diff --git a/mlir/test/Dialect/GPU/transform-gpu.mlir b/mlir/test/Dialect/GPU/transform-gpu.mlir index fe5d451408355..a9cd45a192c0f 100644 --- a/mlir/test/Dialect/GPU/transform-gpu.mlir +++ b/mlir/test/Dialect/GPU/transform-gpu.mlir @@ -754,3 +754,84 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +#map = affine_map<(d0) -> (d0 * 128)> +#map1 = affine_map<(d0) -> (d0 * 32)> + +// CHECK-DAG: #[[$MAPB:.*]] = affine_map<(d0) -> (d0 * 128)> +// CHECK-DAG: #[[$MAP_LIN_W:.*]] = affine_map<(d0, d1) -> ((d0 + d1 * 73) floordiv 32)> +// CHECK-DAG: #[[$MAP_W0:.*]] = affine_map<(d0) -> (d0 * 32 - (d0 floordiv 2) * 64)> +// CHECK-DAG: #[[$MAP_W1:.*]] = affine_map<(d0) -> ((d0 floordiv 2) * 32)> + +// CHECK-LABEL: func.func @simple_fill( +func.func @simple_fill(%arg0: memref<128xf32>) -> memref<128xf32> { + %c0 = arith.constant 0 : index + %cst = arith.constant dense<0.000000e+00> : vector<32xf32> +// CHECK-DAG: %[[C0_i64:.*]] = arith.constant 0 : i64 +// CHECK-DAG: %[[C1_i64:.*]] = arith.constant 1 : i64 +/// 0x2f1 is 753 +// CHECK-DAG: %[[C753_i64:.*]] = arith.constant 753 : i64 + +// CHECK: gpu.launch + scf.forall (%arg1) in (1) { +// CHECK: %[[BIDX:.*]] = gpu.block_id x +// CHECK: %[[BLX:.*]] = affine.apply #[[$MAPB]](%[[BIDX]]) + %0 = affine.apply #map(%arg1) + %subview = memref.subview %arg0[%0] [128] [1] : memref<128xf32> to memref<128xf32, strided<[1], offset: ?>> + + // %arg2 and %arg3 map to lanes [0, 6) and are turned into epxressions + // involving threadIdx.x/y by the map_nested_forall_to_threads + // transformation. This results in a if (linear_thread_id < 6) conditional. 
+ scf.forall (%arg2, %arg3) in (2, 3) { + // CHECK: %[[TIDX:.*]] = gpu.thread_id x + // CHECK: %[[TIDY:.*]] = gpu.thread_id y + + // CHECK: %[[LIN_W:.*]] = affine.apply #[[$MAP_LIN_W]](%[[TIDX]], %[[TIDY]]) + // + // Compute the active warps below using the mask + popcnt + // CHECK: %[[LIN_W_i64:.*]] = arith.index_castui %[[LIN_W]] : index to i64 + // CHECK: %[[TWO_POW_W:.*]] = arith.shli %[[C1_i64]], %[[LIN_W_i64]] : i64 + // CHECK: %[[FILTER_TILL_W:.*]] = arith.subi %[[TWO_POW_W]], %[[C1_i64]] : i64 + // CHECK: %[[ACTIVE_TILL_W:.*]] = arith.andi %[[FILTER_TILL_W]], %[[C753_i64]] : i64 + // CHECK: %[[LOGICAL_ID_W_i64:.*]] = math.ctpop %[[ACTIVE_TILL_W]] : i64 + // CHECK: %[[LOGICAL_ID_W:.*]] = arith.index_castui %[[LOGICAL_ID_W_i64]] : i64 to index + // + // Dynamically compute whether this warp is active below using the mask + popcnt + // CHECK: %[[IS_ACTIVE_W_MASK:.*]] = arith.andi %[[TWO_POW_W]], %[[C753_i64]] : i64 + // CHECK: %[[IS_ACTIVE_W:.*]] = arith.cmpi ne, %[[IS_ACTIVE_W_MASK]], %[[C0_i64]] : i64 + // CHECK: scf.if %[[IS_ACTIVE_W]] { + + // CHECK: %[[W0:.*]] = affine.apply #[[$MAP_W0]](%[[LOGICAL_ID_W]]) + // CHECK: %[[W1:.*]] = affine.apply #[[$MAP_W1]](%[[LOGICAL_ID_W]]) + // CHECK: memref.subview %{{.*}}[%[[W0]]] [%[[W1]]] + %1 = affine.apply #map1(%arg2) + %2 = affine.apply #map1(%arg3) + %subview_0 = memref.subview %subview[%1] [%2] [1] : memref<128xf32, strided<[1], offset: ?>> to memref<?xf32, strided<[1], offset: ?>> + vector.transfer_write %cst, %subview_0[%c0] {in_bounds = [true]} : vector<32xf32>, memref<?xf32, strided<[1], offset: ?>> + + // This could be obtained e.g. if a previous transformation mapped this loop + // to lanes. This can aslo be written by hand as valid IR. + // This additionally uses the hex mask: 0x 10 1111 0001 + } {mapping = [#gpu.warp<linear_dim_0>, #gpu.warp<linear_dim_1>, #gpu.mask<0x2f1>]} + + memref.copy %subview, %subview : memref<128xf32, strided<[1], offset: ?>> to memref<128xf32, strided<[1], offset: ?>> + } {mapping = [#gpu.block<x>]} + return %arg0 : memref<128xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%module_op: !transform.any_op {transform.readonly}) { + %func = transform.structured.match ops{["func.func"]} in %module_op + : (!transform.any_op) -> !transform.any_op + %gpu_launch = transform.gpu.map_forall_to_blocks %func generate_gpu_launch + : (!transform.any_op) -> !transform.any_op + + // This transformation maps scf.forall ivs to a particular mapping of thread + // ids (laneid, threadid, warpid or warpgroupid). 
+ transform.gpu.map_nested_forall_to_threads %gpu_launch block_dims = [73, 5, 1] + : (!transform.any_op) -> !transform.any_op + transform.yield + } +} diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir index b944852ceba3f..bb7958083e55c 100644 --- a/mlir/test/Dialect/SCF/invalid.mlir +++ b/mlir/test/Dialect/SCF/invalid.mlir @@ -684,6 +684,24 @@ func.func @forall_wrong_terminator_op() -> () { // ----- +func.func @at_most_one_masking_attribute(%in: tensor<100xf32>, %out: tensor<100xf32>) { + %c1 = arith.constant 1 : index + %num_threads = arith.constant 100 : index + + // expected-error @below {{"mapping" supports at most one device masking attribute}} + %result = scf.forall (%thread_idx) in (%num_threads) shared_outs(%o = %out) -> (tensor<100xf32>) { + %1 = tensor.extract_slice %in[%thread_idx][1][1] : tensor<100xf32> to tensor<1xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %1 into %o[%thread_idx][1][1] : + tensor<1xf32> into tensor<100xf32> + } + } { mapping = [#gpu.thread<x>, #gpu.mask<0x1>, #gpu.mask<0x2>] } + + return +} + +// ----- + func.func @switch_wrong_case_count(%arg0: index) { // expected-error @below {{'scf.index_switch' op has 0 case regions but 1 case values}} "scf.index_switch"(%arg0) ({ _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits
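For reference, the bit manipulation that GPUMappingMaskAttr lowers to (see getLogicalLinearMappingId and getIsActiveIdPredicate in GPUDialect.cpp above) can be modeled on the host by a few lines of C++. This is only an illustrative sketch of the semantics, not part of the patch: the patch materializes the same arithmetic as arith/math ops on i64 values, and the 64-id limit reported by getMaxNumPhysicalIds comes from the uint64_t mask parameter. The mask value 0x2f1 below is the one used in the transform-gpu.mlir test; the sketch requires C++20 for std::popcount.

#include <bit>
#include <cstdint>
#include <cstdio>

// Logical linear id of a physical id: count of active ids strictly below it.
// Mirrors GPUMappingMaskAttr::getLogicalLinearMappingId (shli, subi, andi, ctpop).
// physicalId is assumed to be < 64 (getMaxNumPhysicalIds).
uint64_t logicalLinearId(uint64_t mask, uint64_t physicalId) {
  uint64_t filter = (uint64_t{1} << physicalId) - 1;
  return std::popcount(mask & filter);
}

// Whether a physical id participates at all under the mask.
// Mirrors GPUMappingMaskAttr::getIsActiveIdPredicate (shli, andi, cmpi ne 0).
bool isActiveId(uint64_t mask, uint64_t physicalId) {
  return (mask & (uint64_t{1} << physicalId)) != 0;
}

int main() {
  uint64_t mask = 0x2f1; // 753: physical ids 0, 4, 5, 6, 7 and 9 are active.
  for (uint64_t id = 0; id < 12; ++id)
    std::printf("physical %2llu -> active=%d logical=%llu\n",
                static_cast<unsigned long long>(id), isActiveId(mask, id),
                static_cast<unsigned long long>(logicalLinearId(mask, id)));
}

Running this prints logical ids 0 through 5 for the six active physical ids (0, 4, 5, 6, 7 and 9) and active=0 for all others, which matches the CHECK lines in the simple_fill test: the 2x3 linear warp mapping executes on the six masked-in warps while every other warp is predicated off by the scf.if generated from the mask.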