https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/114155
>From 5c02edc9f35d4c35b2c25bc3dba4d10531e2a4ab Mon Sep 17 00:00:00 2001 From: Matthias Springer <msprin...@nvidia.com> Date: Wed, 30 Oct 2024 00:58:32 +0100 Subject: [PATCH] [mlir][bufferization] Remove remaining dialect conversion-based infra parts This commit removes the last remaining components of the dialect conversion-based bufferization passes. Note for LLVM integration: If you depend on these components, migrate to One-Shot Bufferize or copy them to your codebase. Depends on #114154. --- .../Bufferization/Transforms/Bufferize.h | 23 ------ .../mlir/Dialect/Func/Transforms/Passes.h | 4 - .../Bufferization/Transforms/BufferUtils.cpp | 6 +- .../Bufferization/Transforms/Bufferize.cpp | 73 ------------------- 4 files changed, 4 insertions(+), 102 deletions(-) diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h index ebed2c354bfca5..2f495d304b4a56 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h @@ -38,24 +38,6 @@ struct BufferizationStatistics { int64_t numTensorOutOfPlace = 0; }; -/// A helper type converter class that automatically populates the relevant -/// materializations and type conversions for bufferization. -class BufferizeTypeConverter : public TypeConverter { -public: - BufferizeTypeConverter(); -}; - -/// Marks ops used by bufferization for type conversion materializations as -/// "legal" in the given ConversionTarget. -/// -/// This function should be called by all bufferization passes using -/// BufferizeTypeConverter so that materializations work properly. One exception -/// is bufferization passes doing "full" conversions, where it can be desirable -/// for even the materializations to remain illegal so that they are eliminated, -/// such as via the patterns in -/// populateEliminateBufferizeMaterializationsPatterns. -void populateBufferizeMaterializationLegality(ConversionTarget &target); - /// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`. /// /// Note: This function does not resolve read-after-write conflicts. Use this @@ -81,11 +63,6 @@ LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options, LogicalResult bufferizeBlockSignature(Block *block, RewriterBase &rewriter, const BufferizationOptions &options); -/// Return `BufferizationOptions` such that the `bufferizeOp` behaves like the -/// old (deprecated) partial, dialect conversion-based bufferization passes. A -/// copy will be inserted before every buffer write. -BufferizationOptions getPartialBufferizationOptions(); - } // namespace bufferization } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Func/Transforms/Passes.h b/mlir/include/mlir/Dialect/Func/Transforms/Passes.h index 02fc9e1d934390..0248f068320c54 100644 --- a/mlir/include/mlir/Dialect/Func/Transforms/Passes.h +++ b/mlir/include/mlir/Dialect/Func/Transforms/Passes.h @@ -18,10 +18,6 @@ #include "mlir/Pass/Pass.h" namespace mlir { -namespace bufferization { -class BufferizeTypeConverter; -} // namespace bufferization - class RewritePatternSet; namespace func { diff --git a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp index 8fffdbf664c3f4..b11803da19ef98 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/BufferUtils.cpp @@ -11,6 +11,8 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Bufferization/Transforms/BufferUtils.h" + +#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h" @@ -138,8 +140,8 @@ bufferization::getGlobalFor(arith::ConstantOp constantOp, uint64_t alignment, alignment > 0 ? IntegerAttr::get(globalBuilder.getI64Type(), alignment) : IntegerAttr(); - BufferizeTypeConverter typeConverter; - auto memrefType = cast<MemRefType>(typeConverter.convertType(type)); + auto memrefType = + cast<MemRefType>(getMemRefTypeWithStaticIdentityLayout(type)); if (memorySpace) memrefType = MemRefType::Builder(memrefType).setMemorySpace(memorySpace); auto global = globalBuilder.create<memref::GlobalOp>( diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index 62ce2583f4fa1d..6f0cdfa20f7be5 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -37,65 +37,6 @@ namespace bufferization { using namespace mlir; using namespace mlir::bufferization; -//===----------------------------------------------------------------------===// -// BufferizeTypeConverter -//===----------------------------------------------------------------------===// - -static Value materializeToTensor(OpBuilder &builder, TensorType type, - ValueRange inputs, Location loc) { - assert(inputs.size() == 1); - assert(isa<BaseMemRefType>(inputs[0].getType())); - return builder.create<bufferization::ToTensorOp>(loc, type, inputs[0]); -} - -/// Registers conversions into BufferizeTypeConverter -BufferizeTypeConverter::BufferizeTypeConverter() { - // Keep all types unchanged. - addConversion([](Type type) { return type; }); - // Convert RankedTensorType to MemRefType. - addConversion([](RankedTensorType type) -> Type { - return MemRefType::get(type.getShape(), type.getElementType()); - }); - // Convert UnrankedTensorType to UnrankedMemRefType. - addConversion([](UnrankedTensorType type) -> Type { - return UnrankedMemRefType::get(type.getElementType(), 0); - }); - addArgumentMaterialization(materializeToTensor); - addSourceMaterialization(materializeToTensor); - addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type, - ValueRange inputs, Location loc) -> Value { - assert(inputs.size() == 1 && "expected exactly one input"); - - if (auto inputType = dyn_cast<MemRefType>(inputs[0].getType())) { - // MemRef to MemRef cast. - assert(inputType != type && "expected different types"); - // Unranked to ranked and ranked to unranked casts must be explicit. - auto rankedDestType = dyn_cast<MemRefType>(type); - if (!rankedDestType) - return nullptr; - BufferizationOptions options; - options.bufferAlignment = 0; - FailureOr<Value> replacement = - castOrReallocMemRefValue(builder, inputs[0], rankedDestType, options); - if (failed(replacement)) - return nullptr; - return *replacement; - } - - if (isa<TensorType>(inputs[0].getType())) { - // Tensor to MemRef cast. - return builder.create<bufferization::ToMemrefOp>(loc, type, inputs[0]); - } - - llvm_unreachable("only tensor/memref input types supported"); - }); -} - -void mlir::bufferization::populateBufferizeMaterializationLegality( - ConversionTarget &target) { - target.addLegalOp<bufferization::ToTensorOp, bufferization::ToMemrefOp>(); -} - namespace { static LayoutMapOption parseLayoutMapOption(const std::string &s) { @@ -545,17 +486,3 @@ bufferization::bufferizeBlockSignature(Block *block, RewriterBase &rewriter, return success(); } - -BufferizationOptions bufferization::getPartialBufferizationOptions() { - BufferizationOptions options; - options.allowUnknownOps = true; - options.copyBeforeWrite = true; - options.enforceAliasingInvariants = false; - options.unknownTypeConverterFn = [](Value value, Attribute memorySpace, - const BufferizationOptions &options) { - return getMemRefTypeWithStaticIdentityLayout( - cast<TensorType>(value.getType()), memorySpace); - }; - options.opFilter.allowDialect<BufferizationDialect>(); - return options; -} _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits