https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/135940
>From d137ec06b1b846232a77b78472c522183b872152 Mon Sep 17 00:00:00 2001 From: Matthias Springer <msprin...@nvidia.com> Date: Wed, 16 Apr 2025 10:02:41 +0200 Subject: [PATCH] [mlir][memref][NFC] Simplify `constifyIndexValues` --- mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp | 151 ++++++++--------------- 1 file changed, 49 insertions(+), 102 deletions(-) diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index 63f5251398716..e773236b30c68 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -88,101 +88,30 @@ SmallVector<OpFoldResult> memref::getMixedSizes(OpBuilder &builder, // Utility functions for propagating static information //===----------------------------------------------------------------------===// -/// Helper function that infers the constant values from a list of \p values, -/// a \p memRefTy, and another helper function \p getAttributes. -/// The inferred constant values replace the related `OpFoldResult` in -/// \p values. +/// Helper function that sets values[i] to constValues[i] if the latter is a +/// static value, as indicated by ShapedType::kDynamic. /// -/// \note This function shouldn't be used directly, instead, use the -/// `getConstifiedMixedXXX` methods from the related operations. -/// -/// \p getAttributes retuns a list of potentially constant values, as determined -/// by \p isDynamic, from the given \p memRefTy. The returned list must have as -/// many elements as \p values or be empty. -/// -/// E.g., consider the following example: -/// ``` -/// memref.reinterpret_cast %base to <...> strides: [2, %dyn_stride] : -/// memref<f32> to memref<?x?xf32, strided<[?, 1], offset: ?>> -/// ``` -/// `ReinterpretCastOp::getMixedStrides()` will return `[2, %dyn_stride]`. 
-/// Now using this helper function with: -/// - `values == [2, %dyn_stride]`, -/// - `memRefTy == memref<?x?xf32, strided<[?, 1], offset: ?>>` -/// - `getAttributes == getConstantStrides` (i.e., a wrapper around -/// `getStridesAndOffset`), and -/// - `isDynamic == ShapedType::isDynamic` -/// Will yield: `values == [2, 1]` -static void constifyIndexValues( - SmallVectorImpl<OpFoldResult> &values, MemRefType memRefTy, - MLIRContext *ctxt, - llvm::function_ref<SmallVector<int64_t>(MemRefType)> getAttributes, - llvm::function_ref<bool(int64_t)> isDynamic) { - SmallVector<int64_t> constValues = getAttributes(memRefTy); - Builder builder(ctxt); - for (const auto &it : llvm::enumerate(constValues)) { - int64_t constValue = it.value(); - if (!isDynamic(constValue)) - values[it.index()] = builder.getIndexAttr(constValue); - } - for (OpFoldResult &ofr : values) { - if (auto attr = dyn_cast<Attribute>(ofr)) { - // FIXME: We shouldn't need to do that, but right now, the static indices - // are created with the wrong type: `i64` instead of `index`. - // As a result, if we were to keep the attribute as is, we may fail to see - // that two attributes are equal because one would have the i64 type and - // the other the index type. - // The alternative would be to create constant indices with getI64Attr in - // this and the previous loop, but it doesn't logically make sense (we are - // dealing with indices here) and would only strenghten the inconsistency - // around how static indices are created (some places use getI64Attr, - // others use getIndexAttr). - // The workaround here is to stick to the IndexAttr type for all the - // values, hence we recreate the attribute even when it is already static - // to make sure the type is consistent. - ofr = builder.getIndexAttr(llvm::cast<IntegerAttr>(attr).getInt()); +/// If constValues[i] is dynamic, tries to extract a constant value from +/// values[i] to allow for additional folding opportunities.
Also converts all +/// existing attributes to index attributes. (They may be i64 attributes.) +static void constifyIndexValues(SmallVectorImpl<OpFoldResult> &values, + ArrayRef<int64_t> constValues) { + assert(constValues.size() == values.size() && + "incorrect number of const values"); + for (int64_t i = 0, e = constValues.size(); i < e; ++i) { + Builder builder(values[i].getContext()); + if (!ShapedType::isDynamic(constValues[i])) { + // Constant value is known, use it directly. + values[i] = builder.getIndexAttr(constValues[i]); continue; } - std::optional<int64_t> maybeConstant = - getConstantIntValue(cast<Value>(ofr)); - if (maybeConstant) - ofr = builder.getIndexAttr(*maybeConstant); + if (std::optional<int64_t> cst = getConstantIntValue(values[i])) { + // Try to extract a constant or convert an existing attribute to index. + values[i] = builder.getIndexAttr(*cst); + } } } -/// Wrapper around `getShape` that conforms to the function signature -/// expected for `getAttributes` in `constifyIndexValues`. -static SmallVector<int64_t> getConstantSizes(MemRefType memRefTy) { - ArrayRef<int64_t> sizes = memRefTy.getShape(); - return SmallVector<int64_t>(sizes); -} - -/// Wrapper around `getStridesAndOffset` that returns only the offset and -/// conforms to the function signature expected for `getAttributes` in -/// `constifyIndexValues`. -static SmallVector<int64_t> getConstantOffset(MemRefType memrefType) { - SmallVector<int64_t> strides; - int64_t offset; - LogicalResult hasStaticInformation = - memrefType.getStridesAndOffset(strides, offset); - if (failed(hasStaticInformation)) - return SmallVector<int64_t>(); - return SmallVector<int64_t>(1, offset); -} - -/// Wrapper around `getStridesAndOffset` that returns only the strides and -/// conforms to the function signature expected for `getAttributes` in -/// `constifyIndexValues`.
-static SmallVector<int64_t> getConstantStrides(MemRefType memrefType) { - SmallVector<int64_t> strides; - int64_t offset; - LogicalResult hasStaticInformation = - memrefType.getStridesAndOffset(strides, offset); - if (failed(hasStaticInformation)) - return SmallVector<int64_t>(); - return strides; -} - //===----------------------------------------------------------------------===// // AllocOp / AllocaOp //===----------------------------------------------------------------------===// @@ -1445,24 +1374,34 @@ ExtractStridedMetadataOp::fold(FoldAdaptor adaptor, SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedSizes() { SmallVector<OpFoldResult> values = getAsOpFoldResult(getSizes()); - constifyIndexValues(values, getSource().getType(), getContext(), - getConstantSizes, ShapedType::isDynamic); + constifyIndexValues(values, getSource().getType().getShape()); return values; } SmallVector<OpFoldResult> ExtractStridedMetadataOp::getConstifiedMixedStrides() { SmallVector<OpFoldResult> values = getAsOpFoldResult(getStrides()); - constifyIndexValues(values, getSource().getType(), getContext(), - getConstantStrides, ShapedType::isDynamic); + SmallVector<int64_t> staticValues; + int64_t unused; + LogicalResult status = + getSource().getType().getStridesAndOffset(staticValues, unused); + (void)status; + assert(succeeded(status) && "could not get strides from type"); + constifyIndexValues(values, staticValues); return values; } OpFoldResult ExtractStridedMetadataOp::getConstifiedMixedOffset() { OpFoldResult offsetOfr = getAsOpFoldResult(getOffset()); SmallVector<OpFoldResult> values(1, offsetOfr); - constifyIndexValues(values, getSource().getType(), getContext(), - getConstantOffset, ShapedType::isDynamic); + SmallVector<int64_t> staticValues, unused; + int64_t offset; + LogicalResult status = + getSource().getType().getStridesAndOffset(unused, offset); + (void)status; + assert(succeeded(status) && "could not get offset from type"); + 
staticValues.push_back(offset); + constifyIndexValues(values, staticValues); return values[0]; } @@ -1975,15 +1914,18 @@ OpFoldResult ReinterpretCastOp::fold(FoldAdaptor /*operands*/) { SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedSizes() { SmallVector<OpFoldResult> values = getMixedSizes(); - constifyIndexValues(values, getType(), getContext(), getConstantSizes, - ShapedType::isDynamic); + constifyIndexValues(values, getType().getShape()); return values; } SmallVector<OpFoldResult> ReinterpretCastOp::getConstifiedMixedStrides() { SmallVector<OpFoldResult> values = getMixedStrides(); - constifyIndexValues(values, getType(), getContext(), getConstantStrides, - ShapedType::isDynamic); + SmallVector<int64_t> staticValues; + int64_t unused; + LogicalResult status = getType().getStridesAndOffset(staticValues, unused); + (void)status; + assert(succeeded(status) && "could not get strides from type"); + constifyIndexValues(values, staticValues); return values; } @@ -1991,8 +1933,13 @@ OpFoldResult ReinterpretCastOp::getConstifiedMixedOffset() { SmallVector<OpFoldResult> values = getMixedOffsets(); assert(values.size() == 1 && "reinterpret_cast must have one and only one offset"); - constifyIndexValues(values, getType(), getContext(), getConstantOffset, - ShapedType::isDynamic); + SmallVector<int64_t> staticValues, unused; + int64_t offset; + LogicalResult status = getType().getStridesAndOffset(unused, offset); + (void)status; + assert(succeeded(status) && "could not get offset from type"); + staticValues.push_back(offset); + constifyIndexValues(values, staticValues); return values[0]; } @@ -2062,7 +2009,7 @@ struct ReinterpretCastOpExtractStridedMetadataFolder // Second, check the sizes. if (!llvm::equal(extractStridedMetadata.getConstifiedMixedSizes(), op.getConstifiedMixedSizes())) - return false; + return false; // Finally, check the offset. 
assert(op.getMixedOffsets().size() == 1 && _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits