https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/133286
>From 44cfa133cbaae27620c911d15d985a5b51f1f1aa Mon Sep 17 00:00:00 2001 From: Matthias Springer <msprin...@nvidia.com> Date: Thu, 27 Mar 2025 18:42:56 +0100 Subject: [PATCH 1/2] [mlir][LLVM] Delete `LLVMFixedVectorType` --- mlir/docs/Dialects/LLVM.md | 8 +- mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h | 1 - mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td | 46 +++------ mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 54 +++++------ mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 2 +- mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp | 18 ++-- mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 95 ++++++------------- mlir/lib/Target/LLVMIR/TypeToLLVM.cpp | 10 +- mlir/test/Dialect/LLVMIR/types-invalid.mlir | 19 ---- mlir/test/Dialect/LLVMIR/types.mlir | 2 + 10 files changed, 79 insertions(+), 176 deletions(-) diff --git a/mlir/docs/Dialects/LLVM.md b/mlir/docs/Dialects/LLVM.md index fadc81b567b4e..81c358244d96e 100644 --- a/mlir/docs/Dialects/LLVM.md +++ b/mlir/docs/Dialects/LLVM.md @@ -327,11 +327,9 @@ multiple of some fixed size in case of _scalable_ vectors, and the element type. Vectors cannot be nested and only 1D vectors are supported. Scalable vectors are still considered 1D. -LLVM dialect uses built-in vector types for _fixed_-size vectors of built-in -types, and provides additional types for fixed-sized vectors of LLVM dialect -types (`LLVMFixedVectorType`) and scalable vectors of any types -(`LLVMScalableVectorType`). These two additional types share the following -syntax: +The LLVM dialect uses built-in vector types for _fixed_-size vectors of built-in +types, and provides additional types for scalable vectors of any types +(`LLVMScalableVectorType`): ``` llvm-vec-type ::= `!llvm.vec<` (`?` `x`)? 
integer-literal `x` type `>` diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h index bca0feb45aab2..9d238fc746b8f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.h @@ -67,7 +67,6 @@ namespace LLVM { } DEFINE_TRIVIAL_LLVM_TYPE(LLVMVoidType, "llvm.void"); -DEFINE_TRIVIAL_LLVM_TYPE(LLVMPPCFP128Type, "llvm.ppc_fp128"); DEFINE_TRIVIAL_LLVM_TYPE(LLVMTokenType, "llvm.token"); DEFINE_TRIVIAL_LLVM_TYPE(LLVMLabelType, "llvm.label"); DEFINE_TRIVIAL_LLVM_TYPE(LLVMMetadataType, "llvm.metadata"); diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td index 3386003cb61fb..fe12ab99b9141 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td @@ -11,6 +11,7 @@ include "mlir/Dialect/LLVMIR/LLVMOpBase.td" include "mlir/IR/AttrTypeBase.td" +include "mlir/IR/BuiltinTypeInterfaces.td" include "mlir/IR/BuiltinTypes.td" include "mlir/Interfaces/DataLayoutInterfaces.td" include "mlir/Interfaces/MemorySlotInterfaces.td" @@ -288,38 +289,6 @@ def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [ ]; } -//===----------------------------------------------------------------------===// -// LLVMFixedVectorType -//===----------------------------------------------------------------------===// - -def LLVMFixedVectorType : LLVMType<"LLVMFixedVector", "vec"> { - let summary = "LLVM fixed vector type"; - let description = [{ - LLVM dialect vector type that supports all element types that are supported - in LLVM vectors but that are not supported by the builtin MLIR vector type. - E.g., LLVMFixedVectorType supports LLVM pointers as element type. 
- }]; - - let typeName = "llvm.fixed_vec"; - - let parameters = (ins "Type":$elementType, "unsigned":$numElements); - let assemblyFormat = [{ - `<` $numElements `x` custom<PrettyLLVMType>($elementType) `>` - }]; - - let genVerifyDecl = 1; - - let builders = [ - TypeBuilderWithInferredContext<(ins "Type":$elementType, - "unsigned":$numElements)> - ]; - - let extraClassDeclaration = [{ - /// Checks if the given type can be used in a vector type. - static bool isValidElementType(Type type); - }]; -} - //===----------------------------------------------------------------------===// // LLVMScalableVectorType //===----------------------------------------------------------------------===// @@ -400,4 +369,17 @@ def LLVMX86AMXType : LLVMType<"LLVMX86AMX", "x86_amx"> { }]; } +//===----------------------------------------------------------------------===// +// LLVMPPCFP128Type +//===----------------------------------------------------------------------===// + +def LLVMPPCFP128Type : LLVMType<"LLVMPPCFP128", "ppc_fp128", + [DeclareTypeInterfaceMethods<FloatTypeInterface, ["getFloatSemantics"]>]> { + let summary = "128 bit FP type with IBM double-double semantics"; + let description = [{ + A 128 bit floating-point type with IBM double-double semantics. + See S_PPCDoubleDouble in APFloat.h for details. 
+ }]; +} + #endif // LLVMTYPES_TD diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 18a70cc64628f..29701ffc89b19 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -686,8 +686,6 @@ static Type extractVectorElementType(Type type) { return vectorType.getElementType(); if (auto scalableVectorType = llvm::dyn_cast<LLVMScalableVectorType>(type)) return scalableVectorType.getElementType(); - if (auto fixedVectorType = llvm::dyn_cast<LLVMFixedVectorType>(type)) - return fixedVectorType.getElementType(); return type; } @@ -724,20 +722,19 @@ static void destructureIndices(Type currType, ArrayRef<GEPArg> indices, if (rawConstantIndices.size() == 1 || !currType) continue; - currType = - TypeSwitch<Type, Type>(currType) - .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType, - LLVMArrayType>([](auto containerType) { - return containerType.getElementType(); - }) - .Case([&](LLVMStructType structType) -> Type { - int64_t memberIndex = rawConstantIndices.back(); - if (memberIndex >= 0 && static_cast<size_t>(memberIndex) < - structType.getBody().size()) - return structType.getBody()[memberIndex]; - return nullptr; - }) - .Default(Type(nullptr)); + currType = TypeSwitch<Type, Type>(currType) + .Case<VectorType, LLVMScalableVectorType, LLVMArrayType>( + [](auto containerType) { + return containerType.getElementType(); + }) + .Case([&](LLVMStructType structType) -> Type { + int64_t memberIndex = rawConstantIndices.back(); + if (memberIndex >= 0 && static_cast<size_t>(memberIndex) < + structType.getBody().size()) + return structType.getBody()[memberIndex]; + return nullptr; + }) + .Default(Type(nullptr)); } } @@ -838,11 +835,11 @@ verifyStructIndices(Type baseGEPType, unsigned indexPos, return verifyStructIndices(elementTypes[gepIndex], indexPos + 1, indices, emitOpError); }) - .Case<VectorType, LLVMScalableVectorType, LLVMFixedVectorType, - 
LLVMArrayType>([&](auto containerType) -> LogicalResult { - return verifyStructIndices(containerType.getElementType(), indexPos + 1, - indices, emitOpError); - }) + .Case<VectorType, LLVMScalableVectorType, LLVMArrayType>( + [&](auto containerType) -> LogicalResult { + return verifyStructIndices(containerType.getElementType(), + indexPos + 1, indices, emitOpError); + }) .Default([&](auto otherType) -> LogicalResult { return emitOpError() << "type " << otherType << " cannot be indexed (index #" @@ -3108,16 +3105,14 @@ OpFoldResult LLVM::ZeroOp::fold(FoldAdaptor) { //===----------------------------------------------------------------------===// /// Compute the total number of elements in the given type, also taking into -/// account nested types. Supported types are `VectorType`, `LLVMArrayType` and -/// `LLVMFixedVectorType`. Everything else is treated as a scalar. +/// account nested types. Supported types are `VectorType` and `LLVMArrayType`. +/// Everything else is treated as a scalar. 
static int64_t getNumElements(Type t) { if (auto vecType = dyn_cast<VectorType>(t)) return vecType.getNumElements() * getNumElements(vecType.getElementType()); if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t)) return arrayType.getNumElements() * getNumElements(arrayType.getElementType()); - if (auto vecType = dyn_cast<LLVMFixedVectorType>(t)) - return vecType.getNumElements() * getNumElements(vecType.getElementType()); assert(!isa<LLVM::LLVMScalableVectorType>(t) && "number of elements of a scalable vector type is unknown"); return 1; @@ -3135,8 +3130,6 @@ static bool hasScalableVectorType(Type t) { } if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t)) return hasScalableVectorType(arrayType.getElementType()); - if (auto vecType = dyn_cast<LLVMFixedVectorType>(t)) - return hasScalableVectorType(vecType.getElementType()); return false; } @@ -3216,8 +3209,7 @@ LogicalResult LLVM::ConstantOp::verify() { << "scalable vector type requires a splat attribute"; return success(); } - if (!isa<VectorType, LLVM::LLVMArrayType, LLVM::LLVMFixedVectorType>( - getType())) + if (!isa<VectorType, LLVM::LLVMArrayType>(getType())) return emitOpError() << "expected vector or array type"; // The number of elements of the attribute and the type must match. int64_t attrNumElements; @@ -3466,8 +3458,7 @@ LogicalResult LLVM::BitcastOp::verify() { if (!resultType) return success(); - auto isVector = - llvm::IsaPred<VectorType, LLVMScalableVectorType, LLVMFixedVectorType>; + auto isVector = llvm::IsaPred<VectorType, LLVMScalableVectorType>; // Due to bitcast requiring both operands to be of the same size, it is not // possible for only one of the two to be a pointer of vectors. 
@@ -3883,7 +3874,6 @@ void LLVMDialect::initialize() { // clang-format off addTypes<LLVMVoidType, - LLVMPPCFP128Type, LLVMTokenType, LLVMLabelType, LLVMMetadataType>(); diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index 51dcb071f9c18..c5a1502c8cbe8 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -137,7 +137,7 @@ static bool isSupportedTypeForConversion(Type type) { // LLVM vector types are only used for either pointers or target specific // types. These types cannot be casted in the general case, thus the memory // optimizations do not support them. - if (isa<LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType>(type)) + if (isa<LLVM::LLVMScalableVectorType>(type)) return false; if (auto vectorType = dyn_cast<VectorType>(type)) { diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp index d700dc52d42d2..edfc5adeb424e 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp @@ -40,8 +40,7 @@ static StringRef getTypeKeyword(Type type) { .Case<LLVMMetadataType>([&](Type) { return "metadata"; }) .Case<LLVMFunctionType>([&](Type) { return "func"; }) .Case<LLVMPointerType>([&](Type) { return "ptr"; }) - .Case<LLVMFixedVectorType, LLVMScalableVectorType>( - [&](Type) { return "vec"; }) + .Case<LLVMScalableVectorType>([&](Type) { return "vec"; }) .Case<LLVMArrayType>([&](Type) { return "array"; }) .Case<LLVMStructType>([&](Type) { return "struct"; }) .Case<LLVMTargetExtType>([&](Type) { return "target"; }) @@ -104,9 +103,9 @@ void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) { printer << getTypeKeyword(type); llvm::TypeSwitch<Type>(type) - .Case<LLVMPointerType, LLVMArrayType, LLVMFixedVectorType, - LLVMScalableVectorType, LLVMFunctionType, LLVMTargetExtType, - LLVMStructType>([&](auto type) { type.print(printer); }); + 
.Case<LLVMPointerType, LLVMArrayType, LLVMScalableVectorType, + LLVMFunctionType, LLVMTargetExtType, LLVMStructType>( + [&](auto type) { type.print(printer); }); } //===----------------------------------------------------------------------===// @@ -143,14 +142,11 @@ static Type parseVectorType(AsmParser &parser) { } bool isScalable = dims.size() == 2; - if (isScalable) - return parser.getChecked<LLVMScalableVectorType>(loc, elementType, dims[1]); - if (elementType.isSignlessIntOrFloat()) { - parser.emitError(typePos) - << "cannot use !llvm.vec for built-in primitives, use 'vector' instead"; + if (!isScalable) { + parser.emitError(dimPos) << "expected scalable vector"; return Type(); } - return parser.getChecked<LLVMFixedVectorType>(loc, elementType, dims[0]); + return parser.getChecked<LLVMScalableVectorType>(loc, elementType, dims[1]); } /// Attempts to set the body of an identified structure type. Reports a parsing diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index 403756765268e..b008659c7e958 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -658,7 +658,7 @@ LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries, } //===----------------------------------------------------------------------===// -// Vector types. +// LLVMScalableVectorType. //===----------------------------------------------------------------------===// /// Verifies that the type about to be constructed is well-formed. 
@@ -675,35 +675,6 @@ verifyVectorConstructionInvariants(function_ref<InFlightDiagnostic()> emitError, return success(); } -LLVMFixedVectorType LLVMFixedVectorType::get(Type elementType, - unsigned numElements) { - assert(elementType && "expected non-null subtype"); - return Base::get(elementType.getContext(), elementType, numElements); -} - -LLVMFixedVectorType -LLVMFixedVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError, - Type elementType, unsigned numElements) { - assert(elementType && "expected non-null subtype"); - return Base::getChecked(emitError, elementType.getContext(), elementType, - numElements); -} - -bool LLVMFixedVectorType::isValidElementType(Type type) { - return llvm::isa<LLVMPPCFP128Type>(type); -} - -LogicalResult -LLVMFixedVectorType::verify(function_ref<InFlightDiagnostic()> emitError, - Type elementType, unsigned numElements) { - return verifyVectorConstructionInvariants<LLVMFixedVectorType>( - emitError, elementType, numElements); -} - -//===----------------------------------------------------------------------===// -// LLVMScalableVectorType. -//===----------------------------------------------------------------------===// - LLVMScalableVectorType LLVMScalableVectorType::get(Type elementType, unsigned minNumElements) { assert(elementType && "expected non-null subtype"); @@ -762,6 +733,14 @@ bool LLVM::LLVMTargetExtType::supportsMemOps() const { return false; } +//===----------------------------------------------------------------------===// +// LLVMPPCFP128Type +//===----------------------------------------------------------------------===// + +const llvm::fltSemantics &LLVMPPCFP128Type::getFloatSemantics() const { + return APFloat::PPCDoubleDouble(); +} + //===----------------------------------------------------------------------===// // Utility functions. 
//===----------------------------------------------------------------------===// @@ -783,7 +762,6 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) { LLVMPointerType, LLVMStructType, LLVMTokenType, - LLVMFixedVectorType, LLVMScalableVectorType, LLVMTargetExtType, LLVMVoidType, @@ -832,7 +810,6 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) { }) // clang-format off .Case< - LLVMFixedVectorType, LLVMScalableVectorType, LLVMArrayType >([&](auto containerType) { @@ -880,7 +857,7 @@ bool mlir::LLVM::isCompatibleFloatingPointType(Type type) { } bool mlir::LLVM::isCompatibleVectorType(Type type) { - if (llvm::isa<LLVMFixedVectorType, LLVMScalableVectorType>(type)) + if (llvm::isa<LLVMScalableVectorType>(type)) return true; if (auto vecType = llvm::dyn_cast<VectorType>(type)) { @@ -897,7 +874,7 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) { Type mlir::LLVM::getVectorElementType(Type type) { return llvm::TypeSwitch<Type, Type>(type) - .Case<LLVMFixedVectorType, LLVMScalableVectorType, VectorType>( + .Case<LLVMScalableVectorType, VectorType>( [](auto ty) { return ty.getElementType(); }) .Default([](Type) -> Type { llvm_unreachable("incompatible with LLVM vector type"); @@ -911,9 +888,6 @@ llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) { return llvm::ElementCount::getScalable(ty.getNumElements()); return llvm::ElementCount::getFixed(ty.getNumElements()); }) - .Case([](LLVMFixedVectorType ty) { - return llvm::ElementCount::getFixed(ty.getNumElements()); - }) .Case([](LLVMScalableVectorType ty) { return llvm::ElementCount::getScalable(ty.getMinNumElements()); }) @@ -923,30 +897,28 @@ llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) { } bool mlir::LLVM::isScalableVectorType(Type vectorType) { - assert((llvm::isa<LLVMFixedVectorType, LLVMScalableVectorType, VectorType>( - vectorType)) && + assert((llvm::isa<LLVMScalableVectorType, VectorType>(vectorType)) && "expected LLVM-compatible vector type"); - return 
!llvm::isa<LLVMFixedVectorType>(vectorType) && - (llvm::isa<LLVMScalableVectorType>(vectorType) || - llvm::cast<VectorType>(vectorType).isScalable()); + return llvm::isa<LLVMScalableVectorType>(vectorType) || + llvm::cast<VectorType>(vectorType).isScalable(); } Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements, bool isScalable) { - bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType); - bool useBuiltIn = VectorType::isValidElementType(elementType); - (void)useBuiltIn; - assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible fixed-vector type " - "to be either builtin or LLVM dialect type"); - if (useLLVM) { - if (isScalable) - return LLVMScalableVectorType::get(elementType, numElements); - return LLVMFixedVectorType::get(elementType, numElements); + if (!isScalable) { + // Non-scalable vectors always use the MLIR vector type. + assert(VectorType::isValidElementType(elementType) && + "incompatible element type"); + return VectorType::get(numElements, elementType, {false}); } - // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as - // scalable/non-scalable. - return VectorType::get(numElements, elementType, {isScalable}); + // This is a scalable vector. 
+ if (VectorType::isValidElementType(elementType)) + return VectorType::get(numElements, elementType, {true}); + assert(LLVMScalableVectorType::isValidElementType(elementType) && + "neither the MLIR vector type nor LLVMScalableVectorType is " + "compatible with the specified element type"); + return LLVMScalableVectorType::get(elementType, numElements); } Type mlir::LLVM::getVectorType(Type elementType, @@ -959,13 +931,8 @@ Type mlir::LLVM::getVectorType(Type elementType, } Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) { - bool useLLVM = LLVMFixedVectorType::isValidElementType(elementType); - bool useBuiltIn = VectorType::isValidElementType(elementType); - (void)useBuiltIn; - assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible fixed-vector type " - "to be either builtin or LLVM dialect type"); - if (useLLVM) - return LLVMFixedVectorType::get(elementType, numElements); + assert(VectorType::isValidElementType(elementType) && + "incompatible element type"); return VectorType::get(numElements, elementType); } @@ -1000,12 +967,6 @@ llvm::TypeSize mlir::LLVM::getPrimitiveTypeSizeInBits(Type type) { }) .Case<LLVMPPCFP128Type>( [](Type) { return llvm::TypeSize::getFixed(128); }) - .Case<LLVMFixedVectorType>([](LLVMFixedVectorType t) { - llvm::TypeSize elementSize = - getPrimitiveTypeSizeInBits(t.getElementType()); - return llvm::TypeSize(elementSize.getFixedValue() * t.getNumElements(), - elementSize.isScalable()); - }) .Case<VectorType>([](VectorType t) { assert(isCompatibleVectorType(t) && "unexpected incompatible with LLVM vector type"); diff --git a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp index c7a533eddce84..285766357eae7 100644 --- a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp +++ b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp @@ -72,8 +72,8 @@ class TypeToLLVMIRTranslatorImpl { }) .Case<LLVM::LLVMArrayType, IntegerType, LLVM::LLVMFunctionType, LLVM::LLVMPointerType, LLVM::LLVMStructType, - 
LLVM::LLVMFixedVectorType, LLVM::LLVMScalableVectorType, - VectorType, LLVM::LLVMTargetExtType>( + LLVM::LLVMScalableVectorType, VectorType, + LLVM::LLVMTargetExtType>( [this](auto type) { return this->translate(type); }) .Default([](Type t) -> llvm::Type * { llvm_unreachable("unknown LLVM dialect type"); @@ -143,12 +143,6 @@ class TypeToLLVMIRTranslatorImpl { type.getNumElements()); } - /// Translates the given fixed-vector type. - llvm::Type *translate(LLVM::LLVMFixedVectorType type) { - return llvm::FixedVectorType::get(translateType(type.getElementType()), - type.getNumElements()); - } - /// Translates the given scalable-vector type. llvm::Type *translate(LLVM::LLVMScalableVectorType type) { return llvm::ScalableVectorType::get(translateType(type.getElementType()), diff --git a/mlir/test/Dialect/LLVMIR/types-invalid.mlir b/mlir/test/Dialect/LLVMIR/types-invalid.mlir index 76fb6780d8668..2d4a63234a7fb 100644 --- a/mlir/test/Dialect/LLVMIR/types-invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/types-invalid.mlir @@ -139,20 +139,6 @@ func.func @unscalable_vector() { // ----- -func.func @zero_vector() { - // expected-error @+1 {{the number of vector elements must be positive}} - "some.op"() : () -> !llvm.vec<0 x ptr> -} - -// ----- - -func.func @nested_vector() { - // expected-error @+1 {{invalid vector element type}} - "some.op"() : () -> !llvm.vec<2 x vector<2xi32>> -} - -// ----- - func.func @scalable_void_vector() { // expected-error @+1 {{invalid vector element type}} "some.op"() : () -> !llvm.vec<?x4 x void> @@ -170,11 +156,6 @@ func.func private @unexpected_type() -> !llvm.f32 // ----- -// expected-error @below {{cannot use !llvm.vec for built-in primitives, use 'vector' instead}} -func.func private @llvm_vector_primitive() -> !llvm.vec<4 x f32> - -// ----- - func.func private @target_ext_invalid_order() { // expected-error @+1 {{failed to parse parameter list for target extension type}} "some.op"() : () -> !llvm.target<"target1", 5, i32, 1> diff --git 
a/mlir/test/Dialect/LLVMIR/types.mlir b/mlir/test/Dialect/LLVMIR/types.mlir index 184205bb0b1e7..bbdef72ece391 100644 --- a/mlir/test/Dialect/LLVMIR/types.mlir +++ b/mlir/test/Dialect/LLVMIR/types.mlir @@ -78,6 +78,8 @@ func.func @vec() { "some.op"() : () -> !llvm.vec<? x 8 x f16> // CHECK: vector<4x!llvm.ptr> "some.op"() : () -> vector<4x!llvm.ptr> + // CHECK: vector<4x!llvm.ppc_fp128> + "some.op"() : () -> vector<4x!llvm.ppc_fp128> return } >From e9b67d13c897974d95fa07014b67fb4a6b604fc9 Mon Sep 17 00:00:00 2001 From: Matthias Springer <msprin...@nvidia.com> Date: Thu, 27 Mar 2025 19:05:18 +0100 Subject: [PATCH 2/2] delete scalable vec type --- mlir/docs/Dialects/LLVM.md | 17 +--- mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td | 32 ------- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 17 +--- mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp | 6 -- mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp | 42 +-------- mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp | 92 ++----------------- mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp | 4 +- mlir/lib/Target/LLVMIR/TypeToLLVM.cpp | 9 +- mlir/test/Dialect/LLVMIR/mem2reg.mlir | 2 +- mlir/test/Dialect/LLVMIR/types-invalid.mlir | 28 ------ mlir/test/Dialect/LLVMIR/types.mlir | 8 +- mlir/test/Target/LLVMIR/Import/intrinsic.ll | 4 +- mlir/test/Target/LLVMIR/llvmir-types.mlir | 4 +- 13 files changed, 29 insertions(+), 236 deletions(-) diff --git a/mlir/docs/Dialects/LLVM.md b/mlir/docs/Dialects/LLVM.md index 81c358244d96e..d0509e036682f 100644 --- a/mlir/docs/Dialects/LLVM.md +++ b/mlir/docs/Dialects/LLVM.md @@ -327,18 +327,7 @@ multiple of some fixed size in case of _scalable_ vectors, and the element type. Vectors cannot be nested and only 1D vectors are supported. Scalable vectors are still considered 1D. -The LLVM dialect uses built-in vector types for _fixed_-size vectors of built-in -types, and provides additional types for scalable vectors of any types -(`LLVMScalableVectorType`): - -``` - llvm-vec-type ::= `!llvm.vec<` (`?` `x`)? 
integer-literal `x` type `>` -``` - -Note that the sets of element types supported by built-in and LLVM dialect -vector types are mutually exclusive, e.g., the built-in vector type does not -accept `!llvm.ptr` and the LLVM dialect fixed-width vector type does not -accept `i32`. +The LLVM dialect uses the built-in vector type for all vectors. The following functions are provided to operate on any kind of the vector types compatible with the LLVM dialect: @@ -358,8 +347,8 @@ compatible with the LLVM dialect: ```mlir vector<42 x i32> // Vector of 42 32-bit integers. -!llvm.vec<42 x ptr> // Vector of 42 pointers. -!llvm.vec<? x 4 x i32> // Scalable vector of 32-bit integers with +vector<42 x !llvm.ptr> // Vector of 42 pointers. +vector<[4] x i32> // Scalable vector of 32-bit integers with // size divisible by 4. !llvm.array<2 x vector<2 x i32>> // Array of 2 vectors of 2 32-bit integers. -!llvm.array<2 x vec<2 x ptr>> // Array of 2 vectors of 2 pointers. +!llvm.array<2 x vector<2 x !llvm.ptr>> // Array of 2 vectors of 2 pointers. diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td index fe12ab99b9141..df2ecf93ebcda 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td @@ -289,38 +289,6 @@ def LLVMPointerType : LLVMType<"LLVMPointer", "ptr", [ ]; } -//===----------------------------------------------------------------------===// -// LLVMScalableVectorType -//===----------------------------------------------------------------------===// - -def LLVMScalableVectorType : LLVMType<"LLVMScalableVector", "vec"> { - let summary = "LLVM scalable vector type"; - let description = [{ - LLVM dialect scalable vector type, represents a sequence of elements of - unknown length that is known to be divisible by some constant. These - elements can be processed as one in SIMD context.
- }]; - - let typeName = "llvm.scalable_vec"; - - let parameters = (ins "Type":$elementType, "unsigned":$minNumElements); - let assemblyFormat = [{ - `<` `?` `x` $minNumElements `x` ` ` custom<PrettyLLVMType>($elementType) `>` - }]; - - let genVerifyDecl = 1; - - let builders = [ - TypeBuilderWithInferredContext<(ins "Type":$elementType, - "unsigned":$minNumElements)> - ]; - - let extraClassDeclaration = [{ - /// Checks if the given type can be used in a vector type. - static bool isValidElementType(Type type); - }]; -} - //===----------------------------------------------------------------------===// // LLVMTargetExtType //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 29701ffc89b19..a25029392e1e9 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -684,8 +684,6 @@ GEPIndicesAdaptor<ValueRange> GEPOp::getIndices() { static Type extractVectorElementType(Type type) { if (auto vectorType = llvm::dyn_cast<VectorType>(type)) return vectorType.getElementType(); - if (auto scalableVectorType = llvm::dyn_cast<LLVMScalableVectorType>(type)) - return scalableVectorType.getElementType(); return type; } @@ -723,10 +721,9 @@ static void destructureIndices(Type currType, ArrayRef<GEPArg> indices, continue; currType = TypeSwitch<Type, Type>(currType) - .Case<VectorType, LLVMScalableVectorType, LLVMArrayType>( - [](auto containerType) { - return containerType.getElementType(); - }) + .Case<VectorType, LLVMArrayType>([](auto containerType) { + return containerType.getElementType(); + }) .Case([&](LLVMStructType structType) -> Type { int64_t memberIndex = rawConstantIndices.back(); if (memberIndex >= 0 && static_cast<size_t>(memberIndex) < @@ -835,7 +832,7 @@ verifyStructIndices(Type baseGEPType, unsigned indexPos, return verifyStructIndices(elementTypes[gepIndex], indexPos + 1, indices, 
emitOpError); }) - .Case<VectorType, LLVMScalableVectorType, LLVMArrayType>( + .Case<VectorType, LLVMArrayType>( [&](auto containerType) -> LogicalResult { return verifyStructIndices(containerType.getElementType(), indexPos + 1, indices, emitOpError); @@ -3113,16 +3110,12 @@ static int64_t getNumElements(Type t) { if (auto arrayType = dyn_cast<LLVM::LLVMArrayType>(t)) return arrayType.getNumElements() * getNumElements(arrayType.getElementType()); - assert(!isa<LLVM::LLVMScalableVectorType>(t) && - "number of elements of a scalable vector type is unknown"); return 1; } /// Check if the given type is a scalable vector type or a vector/array type /// that contains a nested scalable vector type. static bool hasScalableVectorType(Type t) { - if (isa<LLVM::LLVMScalableVectorType>(t)) - return true; if (auto vecType = dyn_cast<VectorType>(t)) { if (vecType.isScalable()) return true; @@ -3458,7 +3451,7 @@ LogicalResult LLVM::BitcastOp::verify() { if (!resultType) return success(); - auto isVector = llvm::IsaPred<VectorType, LLVMScalableVectorType>; + auto isVector = llvm::IsaPred<VectorType>; // Due to bitcast requiring both operands to be of the same size, it is not // possible for only one of the two to be a pointer of vectors. diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp index c5a1502c8cbe8..8640ef28a9e56 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMMemorySlot.cpp @@ -134,12 +134,6 @@ static bool isSupportedTypeForConversion(Type type) { if (isa<LLVM::LLVMStructType, LLVM::LLVMArrayType>(type)) return false; - // LLVM vector types are only used for either pointers or target specific - // types. These types cannot be casted in the general case, thus the memory - // optimizations do not support them. - if (isa<LLVM::LLVMScalableVectorType>(type)) - return false; - if (auto vectorType = dyn_cast<VectorType>(type)) { // Vectors of pointers cannot be casted. 
if (isa<LLVM::LLVMPointerType>(vectorType.getElementType())) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp index edfc5adeb424e..319bb90d9b601 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypeSyntax.cpp @@ -40,7 +40,6 @@ static StringRef getTypeKeyword(Type type) { .Case<LLVMMetadataType>([&](Type) { return "metadata"; }) .Case<LLVMFunctionType>([&](Type) { return "func"; }) .Case<LLVMPointerType>([&](Type) { return "ptr"; }) - .Case<LLVMScalableVectorType>([&](Type) { return "vec"; }) .Case<LLVMArrayType>([&](Type) { return "array"; }) .Case<LLVMStructType>([&](Type) { return "struct"; }) .Case<LLVMTargetExtType>([&](Type) { return "target"; }) @@ -103,9 +102,8 @@ void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) { printer << getTypeKeyword(type); llvm::TypeSwitch<Type>(type) - .Case<LLVMPointerType, LLVMArrayType, LLVMScalableVectorType, - LLVMFunctionType, LLVMTargetExtType, LLVMStructType>( - [&](auto type) { type.print(printer); }); + .Case<LLVMPointerType, LLVMArrayType, LLVMFunctionType, LLVMTargetExtType, + LLVMStructType>([&](auto type) { type.print(printer); }); } //===----------------------------------------------------------------------===// @@ -114,41 +112,6 @@ void mlir::LLVM::detail::printType(Type type, AsmPrinter &printer) { static ParseResult dispatchParse(AsmParser &parser, Type &type); -/// Parses an LLVM dialect vector type. -/// llvm-type ::= `vec<` `? x`? integer `x` llvm-type `>` -/// Supports both fixed and scalable vectors. 
-static Type parseVectorType(AsmParser &parser) { - SmallVector<int64_t, 2> dims; - SMLoc dimPos, typePos; - Type elementType; - SMLoc loc = parser.getCurrentLocation(); - if (parser.parseLess() || parser.getCurrentLocation(&dimPos) || - parser.parseDimensionList(dims, /*allowDynamic=*/true) || - parser.getCurrentLocation(&typePos) || - dispatchParse(parser, elementType) || parser.parseGreater()) - return Type(); - - // We parsed a generic dimension list, but vectors only support two forms: - // - single non-dynamic entry in the list (fixed vector); - // - two elements, the first dynamic (indicated by ShapedType::kDynamic) - // and the second - // non-dynamic (scalable vector). - if (dims.empty() || dims.size() > 2 || - ((dims.size() == 2) ^ (ShapedType::isDynamic(dims[0]))) || - (dims.size() == 2 && ShapedType::isDynamic(dims[1]))) { - parser.emitError(dimPos) - << "expected '? x <integer> x <type>' or '<integer> x <type>'"; - return Type(); - } - - bool isScalable = dims.size() == 2; - if (!isScalable) { - parser.emitError(dimPos) << "expected scalable vector"; - return Type(); - } - return parser.getChecked<LLVMScalableVectorType>(loc, elementType, dims[1]); -} - /// Attempts to set the body of an identified structure type. Reports a parsing /// error at `subtypesLoc` in case of failure. 
static LLVMStructType trySetStructBody(LLVMStructType type, @@ -307,7 +270,6 @@ static Type dispatchParse(AsmParser &parser, bool allowAny = true) { .Case("metadata", [&] { return LLVMMetadataType::get(ctx); }) .Case("func", [&] { return LLVMFunctionType::parse(parser); }) .Case("ptr", [&] { return LLVMPointerType::parse(parser); }) - .Case("vec", [&] { return parseVectorType(parser); }) .Case("array", [&] { return LLVMArrayType::parse(parser); }) .Case("struct", [&] { return LLVMStructType::parse(parser); }) .Case("target", [&] { return LLVMTargetExtType::parse(parser); }) diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp index b008659c7e958..7e9da7aacddba 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMTypes.cpp @@ -150,8 +150,7 @@ generatedTypePrinter(Type def, AsmPrinter &printer); bool LLVMArrayType::isValidElementType(Type type) { return !llvm::isa<LLVMVoidType, LLVMLabelType, LLVMMetadataType, - LLVMFunctionType, LLVMTokenType, LLVMScalableVectorType>( - type); + LLVMFunctionType, LLVMTokenType>(type); } LLVMArrayType LLVMArrayType::get(Type elementType, uint64_t numElements) { @@ -657,53 +656,6 @@ LogicalResult LLVMStructType::verifyEntries(DataLayoutEntryListRef entries, return mlir::success(); } -//===----------------------------------------------------------------------===// -// LLVMScalableVectorType. -//===----------------------------------------------------------------------===// - -/// Verifies that the type about to be constructed is well-formed. 
-template <typename VecTy> -static LogicalResult -verifyVectorConstructionInvariants(function_ref<InFlightDiagnostic()> emitError, - Type elementType, unsigned numElements) { - if (numElements == 0) - return emitError() << "the number of vector elements must be positive"; - - if (!VecTy::isValidElementType(elementType)) - return emitError() << "invalid vector element type"; - - return success(); -} - -LLVMScalableVectorType LLVMScalableVectorType::get(Type elementType, - unsigned minNumElements) { - assert(elementType && "expected non-null subtype"); - return Base::get(elementType.getContext(), elementType, minNumElements); -} - -LLVMScalableVectorType -LLVMScalableVectorType::getChecked(function_ref<InFlightDiagnostic()> emitError, - Type elementType, unsigned minNumElements) { - assert(elementType && "expected non-null subtype"); - return Base::getChecked(emitError, elementType.getContext(), elementType, - minNumElements); -} - -bool LLVMScalableVectorType::isValidElementType(Type type) { - if (auto intType = llvm::dyn_cast<IntegerType>(type)) - return intType.isSignless(); - - return isCompatibleFloatingPointType(type) || - llvm::isa<LLVMPointerType>(type); -} - -LogicalResult -LLVMScalableVectorType::verify(function_ref<InFlightDiagnostic()> emitError, - Type elementType, unsigned numElements) { - return verifyVectorConstructionInvariants<LLVMScalableVectorType>( - emitError, elementType, numElements); -} - //===----------------------------------------------------------------------===// // LLVMTargetExtType. 
//===----------------------------------------------------------------------===// @@ -762,7 +714,6 @@ bool mlir::LLVM::isCompatibleOuterType(Type type) { LLVMPointerType, LLVMStructType, LLVMTokenType, - LLVMScalableVectorType, LLVMTargetExtType, LLVMVoidType, LLVMX86AMXType @@ -810,7 +761,6 @@ static bool isCompatibleImpl(Type type, DenseSet<Type> &compatibleTypes) { }) // clang-format off .Case< - LLVMScalableVectorType, LLVMArrayType >([&](auto containerType) { return isCompatible(containerType.getElementType()); @@ -857,9 +807,6 @@ bool mlir::LLVM::isCompatibleFloatingPointType(Type type) { } bool mlir::LLVM::isCompatibleVectorType(Type type) { - if (llvm::isa<LLVMScalableVectorType>(type)) - return true; - if (auto vecType = llvm::dyn_cast<VectorType>(type)) { if (vecType.getRank() != 1) return false; @@ -874,8 +821,7 @@ bool mlir::LLVM::isCompatibleVectorType(Type type) { Type mlir::LLVM::getVectorElementType(Type type) { return llvm::TypeSwitch<Type, Type>(type) - .Case<LLVMScalableVectorType, VectorType>( - [](auto ty) { return ty.getElementType(); }) + .Case<VectorType>([](auto ty) { return ty.getElementType(); }) .Default([](Type) -> Type { llvm_unreachable("incompatible with LLVM vector type"); }); @@ -888,37 +834,22 @@ llvm::ElementCount mlir::LLVM::getVectorNumElements(Type type) { return llvm::ElementCount::getScalable(ty.getNumElements()); return llvm::ElementCount::getFixed(ty.getNumElements()); }) - .Case([](LLVMScalableVectorType ty) { - return llvm::ElementCount::getScalable(ty.getMinNumElements()); - }) .Default([](Type) -> llvm::ElementCount { llvm_unreachable("incompatible with LLVM vector type"); }); } bool mlir::LLVM::isScalableVectorType(Type vectorType) { - assert((llvm::isa<LLVMScalableVectorType, VectorType>(vectorType)) && + assert(llvm::isa<VectorType>(vectorType) && "expected LLVM-compatible vector type"); - return llvm::isa<LLVMScalableVectorType>(vectorType) || - llvm::cast<VectorType>(vectorType).isScalable(); + return 
llvm::cast<VectorType>(vectorType).isScalable(); } Type mlir::LLVM::getVectorType(Type elementType, unsigned numElements, bool isScalable) { - if (!isScalable) { - // Non-scalable vectors always use the MLIR vector type. - assert(VectorType::isValidElementType(elementType) && - "incompatible element type"); - return VectorType::get(numElements, elementType, {false}); - } - - // This is a scalable vector. - if (VectorType::isValidElementType(elementType)) - return VectorType::get(numElements, elementType, {true}); - assert(LLVMScalableVectorType::isValidElementType(elementType) && - "neither the MLIR vector type nor LLVMScalableVectorType is " - "compatible with the specified element type"); - return LLVMScalableVectorType::get(elementType, numElements); + assert(VectorType::isValidElementType(elementType) && + "incompatible element type"); + return VectorType::get(numElements, elementType, {isScalable}); } Type mlir::LLVM::getVectorType(Type elementType, @@ -937,15 +868,6 @@ Type mlir::LLVM::getFixedVectorType(Type elementType, unsigned numElements) { } Type mlir::LLVM::getScalableVectorType(Type elementType, unsigned numElements) { - bool useLLVM = LLVMScalableVectorType::isValidElementType(elementType); - bool useBuiltIn = VectorType::isValidElementType(elementType); - (void)useBuiltIn; - assert((useLLVM ^ useBuiltIn) && "expected LLVM-compatible scalable-vector " - "type to be either builtin or LLVM dialect " - "type"); - if (useLLVM) - return LLVMScalableVectorType::get(elementType, numElements); - // LLVM vectors are always 1-D, hence only 1 bool is required to mark it as // scalable/non-scalable. 
return VectorType::get(numElements, elementType, /*scalableDims=*/true); diff --git a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp index ea990ca7aefbe..bc9765fff2953 100644 --- a/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp +++ b/mlir/lib/Target/LLVMIR/TypeFromLLVM.cpp @@ -130,8 +130,8 @@ class TypeFromLLVMIRTranslatorImpl { /// Translates the given scalable-vector type. Type translate(llvm::ScalableVectorType *type) { - return LLVM::LLVMScalableVectorType::get( - translateType(type->getElementType()), type->getMinNumElements()); + return LLVM::getScalableVectorType(translateType(type->getElementType()), + type->getMinNumElements()); } /// Translates the given target extension type. diff --git a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp index 285766357eae7..af78dcd16f792 100644 --- a/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp +++ b/mlir/lib/Target/LLVMIR/TypeToLLVM.cpp @@ -71,8 +71,7 @@ class TypeToLLVMIRTranslatorImpl { return llvm::Type::getX86_AMXTy(context); }) .Case<LLVM::LLVMArrayType, IntegerType, LLVM::LLVMFunctionType, - LLVM::LLVMPointerType, LLVM::LLVMStructType, - LLVM::LLVMScalableVectorType, VectorType, + LLVM::LLVMPointerType, LLVM::LLVMStructType, VectorType, LLVM::LLVMTargetExtType>( [this](auto type) { return this->translate(type); }) .Default([](Type t) -> llvm::Type * { @@ -143,12 +142,6 @@ class TypeToLLVMIRTranslatorImpl { type.getNumElements()); } - /// Translates the given scalable-vector type. - llvm::Type *translate(LLVM::LLVMScalableVectorType type) { - return llvm::ScalableVectorType::get(translateType(type.getElementType()), - type.getMinNumElements()); - } - /// Translates the given target extension type. 
llvm::Type *translate(LLVM::LLVMTargetExtType type) { SmallVector<llvm::Type *> typeParams; diff --git a/mlir/test/Dialect/LLVMIR/mem2reg.mlir b/mlir/test/Dialect/LLVMIR/mem2reg.mlir index 3c13eacde4856..56634cff87aa9 100644 --- a/mlir/test/Dialect/LLVMIR/mem2reg.mlir +++ b/mlir/test/Dialect/LLVMIR/mem2reg.mlir @@ -1033,7 +1033,7 @@ llvm.func @scalable_vector() -> i16 { llvm.func @scalable_llvm_vector() -> i16 { %0 = llvm.mlir.constant(1 : i32) : i32 // CHECK: llvm.alloca - %1 = llvm.alloca %0 x !llvm.vec<? x 4 x ppc_fp128> : (i32) -> !llvm.ptr + %1 = llvm.alloca %0 x vector<[4] x !llvm.ppc_fp128> : (i32) -> !llvm.ptr %2 = llvm.load %1 : !llvm.ptr -> i16 llvm.return %2 : i16 } diff --git a/mlir/test/Dialect/LLVMIR/types-invalid.mlir b/mlir/test/Dialect/LLVMIR/types-invalid.mlir index 2d4a63234a7fb..04710fa6f2396 100644 --- a/mlir/test/Dialect/LLVMIR/types-invalid.mlir +++ b/mlir/test/Dialect/LLVMIR/types-invalid.mlir @@ -118,34 +118,6 @@ func.func @identified_struct_with_void() { // ----- -func.func @dynamic_vector() { - // expected-error @+1 {{expected '? x <integer> x <type>' or '<integer> x <type>'}} - "some.op"() : () -> !llvm.vec<? x ptr> -} - -// ----- - -func.func @dynamic_scalable_vector() { - // expected-error @+1 {{expected '? x <integer> x <type>' or '<integer> x <type>'}} - "some.op"() : () -> !llvm.vec<?x? x ptr> -} - -// ----- - -func.func @unscalable_vector() { - // expected-error @+1 {{expected '? 
x <integer> x <type>' or '<integer> x <type>'}} - "some.op"() : () -> !llvm.vec<4x4 x ptr> -} - -// ----- - -func.func @scalable_void_vector() { - // expected-error @+1 {{invalid vector element type}} - "some.op"() : () -> !llvm.vec<?x4 x void> -} - -// ----- - // expected-error @+1 {{unexpected type, expected keyword}} func.func private @unexpected_type() -> !llvm.tensor<*xf32> diff --git a/mlir/test/Dialect/LLVMIR/types.mlir b/mlir/test/Dialect/LLVMIR/types.mlir index bbdef72ece391..b87c3dd6f2d7a 100644 --- a/mlir/test/Dialect/LLVMIR/types.mlir +++ b/mlir/test/Dialect/LLVMIR/types.mlir @@ -72,10 +72,10 @@ func.func @vec() { "some.op"() : () -> vector<4xi32> // CHECK: vector<4xf32> "some.op"() : () -> vector<4xf32> - // CHECK: !llvm.vec<? x 4 x i32> - "some.op"() : () -> !llvm.vec<? x 4 x i32> - // CHECK: !llvm.vec<? x 8 x f16> - "some.op"() : () -> !llvm.vec<? x 8 x f16> + // CHECK: vector<[4]xi32> + "some.op"() : () -> vector<[4] x i32> + // CHECK: vector<[8]xf16> + "some.op"() : () -> vector<[8] x f16> // CHECK: vector<4x!llvm.ptr> "some.op"() : () -> vector<4x!llvm.ptr> // CHECK: vector<4x!llvm.ppc_fp128> diff --git a/mlir/test/Target/LLVMIR/Import/intrinsic.ll b/mlir/test/Target/LLVMIR/Import/intrinsic.ll index ecc9fdc91d62e..9723333dede5f 100644 --- a/mlir/test/Target/LLVMIR/Import/intrinsic.ll +++ b/mlir/test/Target/LLVMIR/Import/intrinsic.ll @@ -873,7 +873,7 @@ define void @invariant_group(ptr %0) { ; CHECK-LABEL: llvm.func @vector_insert define void @vector_insert(<vscale x 4 x float> %0, <4 x float> %1) { - ; CHECK: llvm.intr.vector.insert %{{.*}}, %{{.*}}[4] : vector<4xf32> into !llvm.vec<? 
x 4 x f32> + ; CHECK: llvm.intr.vector.insert %{{.*}}, %{{.*}}[4] : vector<4xf32> into vector<[4]xf32> %3 = call <vscale x 4 x float> @llvm.vector.insert.nxv4f32.v4f32(<vscale x 4 x float> %0, <4 x float> %1, i64 4); ret void } @@ -889,7 +889,7 @@ define void @vector_extract(<vscale x 4 x float> %0) { define void @vector_deinterleave2(<4 x double> %0, <vscale x 8 x i32> %1) { ; CHECK: "llvm.intr.vector.deinterleave2"(%{{.*}}) : (vector<4xf64>) -> !llvm.struct<(vector<2xf64>, vector<2xf64>)> %3 = call { <2 x double>, <2 x double> } @llvm.vector.deinterleave2.v4f64(<4 x double> %0); - ; CHECK: "llvm.intr.vector.deinterleave2"(%{{.*}}) : (!llvm.vec<? x 8 x i32>) -> !llvm.struct<(vec<? x 4 x i32>, vec<? x 4 x i32>)> + ; CHECK: "llvm.intr.vector.deinterleave2"(%{{.*}}) : (vector<[8]xi32>) -> !llvm.struct<(vector<[4]xi32>, vector<[4]xi32>)> %4 = call { <vscale x 4 x i32>, <vscale x 4 x i32> } @llvm.vector.deinterleave2.nxv8i32(<vscale x 8 x i32> %1); ret void } diff --git a/mlir/test/Target/LLVMIR/llvmir-types.mlir b/mlir/test/Target/LLVMIR/llvmir-types.mlir index 33e1c7e6382ae..5278e1492bb72 100644 --- a/mlir/test/Target/LLVMIR/llvmir-types.mlir +++ b/mlir/test/Target/LLVMIR/llvmir-types.mlir @@ -81,9 +81,9 @@ llvm.func @return_v4_float() -> vector<4xf32> // CHECK: declare <vscale x 4 x float> @return_vs_4_float() llvm.func @return_vs_4_float() -> vector<[4]xf32> // CHECK: declare <vscale x 4 x i32> @return_vs_4_i32() -llvm.func @return_vs_4_i32() -> !llvm.vec<?x4 x i32> +llvm.func @return_vs_4_i32() -> vector<[4]xi32> // CHECK: declare <vscale x 8 x half> @return_vs_8_half() -llvm.func @return_vs_8_half() -> !llvm.vec<?x8 x f16> +llvm.func @return_vs_8_half() -> vector<[8]xf16> // CHECK: declare <4 x ptr> @return_v_4_pi8() llvm.func @return_v_4_pi8() -> vector<4x!llvm.ptr> _______________________________________________ llvm-branch-commits mailing list llvm-branch-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits