Author: Cassandra Beckley Date: 2025-05-27T11:40:54-04:00 New Revision: 5a4571133af78e365e6e7b271688b9ceaa653e67
URL: https://github.com/llvm/llvm-project/commit/5a4571133af78e365e6e7b271688b9ceaa653e67 DIFF: https://github.com/llvm/llvm-project/commit/5a4571133af78e365e6e7b271688b9ceaa653e67.diff LOG: [HLSL] Implement `SpirvType` and `SpirvOpaqueType` (#134034) This implements the design proposed by [Representing SpirvType in Clang's Type System](https://github.com/llvm/wg-hlsl/pull/181). It creates `HLSLInlineSpirvType` as a new `Type` subclass, and `__hlsl_spirv_type` as a new builtin type template to create such a type. This new type is lowered to the `spirv.Type` target extension type, as described in [Target Extension Types for Inline SPIR-V and Decorated Types](https://github.com/llvm/wg-hlsl/blob/main/proposals/0017-inline-spirv-and-decorated-types.md). Added: clang/lib/Headers/hlsl/hlsl_spirv.h clang/test/AST/HLSL/Inputs/pch_spirv_type.hlsl clang/test/AST/HLSL/ast-dump-SpirvType.hlsl clang/test/AST/HLSL/pch_spirv_type.hlsl clang/test/CodeGenHLSL/inline-spirv/SpirvType.alignment.hlsl clang/test/CodeGenHLSL/inline-spirv/SpirvType.hlsl clang/test/SemaHLSL/inline-spirv/SpirvType.dx.error.hlsl clang/test/SemaHLSL/inline-spirv/SpirvType.incomplete.hlsl clang/test/SemaHLSL/inline-spirv/SpirvType.literal.error.hlsl Modified: clang/include/clang-c/Index.h clang/include/clang/AST/ASTContext.h clang/include/clang/AST/ASTNodeTraverser.h clang/include/clang/AST/PropertiesBase.td clang/include/clang/AST/RecursiveASTVisitor.h clang/include/clang/AST/Type.h clang/include/clang/AST/TypeLoc.h clang/include/clang/AST/TypeProperties.td clang/include/clang/Basic/BuiltinTemplates.td clang/include/clang/Basic/DiagnosticSemaKinds.td clang/include/clang/Basic/TypeNodes.td clang/include/clang/Serialization/ASTRecordReader.h clang/include/clang/Serialization/ASTRecordWriter.h clang/include/clang/Serialization/TypeBitCodes.def clang/lib/AST/ASTContext.cpp clang/lib/AST/ASTImporter.cpp clang/lib/AST/ASTStructuralEquivalence.cpp clang/lib/AST/ExprConstant.cpp clang/lib/AST/ItaniumMangle.cpp 
clang/lib/AST/MicrosoftMangle.cpp clang/lib/AST/Type.cpp clang/lib/AST/TypePrinter.cpp clang/lib/CodeGen/CGDebugInfo.cpp clang/lib/CodeGen/CGDebugInfo.h clang/lib/CodeGen/CodeGenFunction.cpp clang/lib/CodeGen/CodeGenTypes.cpp clang/lib/CodeGen/ItaniumCXXABI.cpp clang/lib/CodeGen/Targets/SPIR.cpp clang/lib/Headers/CMakeLists.txt clang/lib/Headers/hlsl.h clang/lib/Sema/SemaExpr.cpp clang/lib/Sema/SemaLookup.cpp clang/lib/Sema/SemaTemplate.cpp clang/lib/Sema/SemaTemplateDeduction.cpp clang/lib/Sema/SemaType.cpp clang/lib/Sema/TreeTransform.h clang/lib/Serialization/ASTReader.cpp clang/lib/Serialization/ASTWriter.cpp clang/test/AST/HLSL/vector-alias.hlsl clang/tools/libclang/CIndex.cpp clang/tools/libclang/CXType.cpp clang/utils/TableGen/ClangBuiltinTemplatesEmitter.cpp Removed: ################################################################################ diff --git a/clang/include/clang-c/Index.h b/clang/include/clang-c/Index.h index d30d15e53802a..e4cb4327fbaac 100644 --- a/clang/include/clang-c/Index.h +++ b/clang/include/clang-c/Index.h @@ -3034,7 +3034,8 @@ enum CXTypeKind { /* HLSL Types */ CXType_HLSLResource = 179, - CXType_HLSLAttributedResource = 180 + CXType_HLSLAttributedResource = 180, + CXType_HLSLInlineSpirv = 181 }; /** diff --git a/clang/include/clang/AST/ASTContext.h b/clang/include/clang/AST/ASTContext.h index 1fdc488a76507..2831256425702 100644 --- a/clang/include/clang/AST/ASTContext.h +++ b/clang/include/clang/AST/ASTContext.h @@ -260,6 +260,7 @@ class ASTContext : public RefCountedBase<ASTContext> { DependentBitIntTypes; mutable llvm::FoldingSet<BTFTagAttributedType> BTFTagAttributedTypes; llvm::FoldingSet<HLSLAttributedResourceType> HLSLAttributedResourceTypes; + llvm::FoldingSet<HLSLInlineSpirvType> HLSLInlineSpirvTypes; mutable llvm::FoldingSet<CountAttributedType> CountAttributedTypes; @@ -1808,6 +1809,10 @@ class ASTContext : public RefCountedBase<ASTContext> { QualType Wrapped, QualType Contained, const 
HLSLAttributedResourceType::Attributes &Attrs); + QualType getHLSLInlineSpirvType(uint32_t Opcode, uint32_t Size, + uint32_t Alignment, + ArrayRef<SpirvOperand> Operands); + QualType getSubstTemplateTypeParmType(QualType Replacement, Decl *AssociatedDecl, unsigned Index, UnsignedOrNone PackIndex, diff --git a/clang/include/clang/AST/ASTNodeTraverser.h b/clang/include/clang/AST/ASTNodeTraverser.h index 7bb435146f752..01bc12ce33eff 100644 --- a/clang/include/clang/AST/ASTNodeTraverser.h +++ b/clang/include/clang/AST/ASTNodeTraverser.h @@ -450,6 +450,24 @@ class ASTNodeTraverser if (!Contained.isNull()) Visit(Contained); } + void VisitHLSLInlineSpirvType(const HLSLInlineSpirvType *T) { + for (auto &Operand : T->getOperands()) { + using SpirvOperandKind = SpirvOperand::SpirvOperandKind; + + switch (Operand.getKind()) { + case SpirvOperandKind::ConstantId: + case SpirvOperandKind::Literal: + break; + + case SpirvOperandKind::TypeId: + Visit(Operand.getResultType()); + break; + + default: + llvm_unreachable("Invalid SpirvOperand kind!"); + } + } + } void VisitSubstTemplateTypeParmType(const SubstTemplateTypeParmType *) {} void VisitSubstTemplateTypeParmPackType(const SubstTemplateTypeParmPackType *T) { diff --git a/clang/include/clang/AST/PropertiesBase.td b/clang/include/clang/AST/PropertiesBase.td index 111a3e44f2fd5..8317b6a874fa3 100644 --- a/clang/include/clang/AST/PropertiesBase.td +++ b/clang/include/clang/AST/PropertiesBase.td @@ -148,6 +148,7 @@ def UnsignedOrNone : PropertyType; def UnaryTypeTransformKind : EnumPropertyType<"UnaryTransformType::UTTKind">; def VectorKind : EnumPropertyType<"VectorKind">; def TypeCoupledDeclRefInfo : PropertyType; +def HLSLSpirvOperand : PropertyType<"SpirvOperand"> { let PassByReference = 1; } def ExceptionSpecInfo : PropertyType<"FunctionProtoType::ExceptionSpecInfo"> { let BufferElementTypes = [ QualType ]; diff --git a/clang/include/clang/AST/RecursiveASTVisitor.h b/clang/include/clang/AST/RecursiveASTVisitor.h index 
23a8c4f1f7380..a11157c006f92 100644 --- a/clang/include/clang/AST/RecursiveASTVisitor.h +++ b/clang/include/clang/AST/RecursiveASTVisitor.h @@ -1154,6 +1154,14 @@ DEF_TRAVERSE_TYPE(BTFTagAttributedType, DEF_TRAVERSE_TYPE(HLSLAttributedResourceType, { TRY_TO(TraverseType(T->getWrappedType())); }) +DEF_TRAVERSE_TYPE(HLSLInlineSpirvType, { + for (auto &Operand : T->getOperands()) { + if (Operand.isConstant() || Operand.isType()) { + TRY_TO(TraverseType(Operand.getResultType())); + } + } +}) + DEF_TRAVERSE_TYPE(ParenType, { TRY_TO(TraverseType(T->getInnerType())); }) DEF_TRAVERSE_TYPE(MacroQualifiedType, @@ -1457,6 +1465,9 @@ DEF_TRAVERSE_TYPELOC(BTFTagAttributedType, DEF_TRAVERSE_TYPELOC(HLSLAttributedResourceType, { TRY_TO(TraverseTypeLoc(TL.getWrappedLoc())); }) +DEF_TRAVERSE_TYPELOC(HLSLInlineSpirvType, + { TRY_TO(TraverseType(TL.getType())); }) + DEF_TRAVERSE_TYPELOC(ElaboratedType, { if (TL.getQualifierLoc()) { TRY_TO(TraverseNestedNameSpecifierLoc(TL.getQualifierLoc())); diff --git a/clang/include/clang/AST/Type.h b/clang/include/clang/AST/Type.h index 9f098edfc08ae..3d84f999567ca 100644 --- a/clang/include/clang/AST/Type.h +++ b/clang/include/clang/AST/Type.h @@ -2691,6 +2691,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase { bool isHLSLSpecificType() const; // Any HLSL specific type bool isHLSLBuiltinIntangibleType() const; // Any HLSL builtin intangible type bool isHLSLAttributedResourceType() const; + bool isHLSLInlineSpirvType() const; bool isHLSLResourceRecord() const; bool isHLSLIntangibleType() const; // Any HLSL intangible type (builtin, array, class) @@ -6364,6 +6365,143 @@ class HLSLAttributedResourceType : public Type, public llvm::FoldingSetNode { findHandleTypeOnResource(const Type *RT); }; +/// Instances of this class represent operands to a SPIR-V type instruction. +class SpirvOperand { +public: + enum SpirvOperandKind : unsigned char { + Invalid, ///< Uninitialized. 
+ ConstantId, ///< Integral value to represent as a SPIR-V OpConstant + ///< instruction ID. + Literal, ///< Integral value to represent as an immediate literal. + TypeId, ///< Type to represent as a SPIR-V type ID. + + Max, + }; + +private: + SpirvOperandKind Kind = Invalid; + + QualType ResultType; + llvm::APInt Value; // Signedness of constants is represented by ResultType. + +public: + SpirvOperand() : Kind(Invalid), ResultType(), Value() {} + + SpirvOperand(SpirvOperandKind Kind, QualType ResultType, llvm::APInt Value) + : Kind(Kind), ResultType(ResultType), Value(Value) {} + + // Copy construction/assignment and destruction are left implicit: every + // member (an enum, a QualType, and an APInt) has its own correct value + // semantics, so the compiler-generated versions suffice. + + /// Operands are equal when their kind, result type, and value all match; + /// values of different bit widths are never equal. + bool operator==(const SpirvOperand &Other) const { + if (Kind != Other.Kind || ResultType != Other.ResultType) + return false; + // APInt::operator== requires both sides to have the same bit width, so + // treat a width mismatch as inequality rather than comparing directly. + return Value.getBitWidth() == Other.Value.getBitWidth() && + Value == Other.Value; + } + + bool operator!=(const SpirvOperand &Other) const { return !(*this == Other); } + + SpirvOperandKind getKind() const { return Kind; } + + bool isValid() const { return Kind != Invalid && Kind < Max; } + bool isConstant() const { return Kind == ConstantId; } + bool isLiteral() const { return Kind == Literal; } + bool isType() const { return Kind == TypeId; } + + llvm::APInt getValue() const { + assert((isConstant() || isLiteral()) && + "This is not an operand with a value!"); + return Value; + } + + QualType getResultType() const { + assert((isConstant() || isType()) && + "This is not an operand with a result type!"); + return ResultType; + } + + static SpirvOperand createConstant(QualType ResultType, llvm::APInt Val) { + return SpirvOperand(ConstantId, ResultType, Val); + } + + static SpirvOperand createLiteral(llvm::APInt Val) { + return SpirvOperand(Literal, QualType(), Val); + } + + static SpirvOperand createType(QualType T) { + return SpirvOperand(TypeId, T, llvm::APSInt()); + } + + void
Profile(llvm::FoldingSetNodeID &ID) const { + ID.AddInteger(Kind); + ID.AddPointer(ResultType.getAsOpaquePtr()); + Value.Profile(ID); + } +}; + +/// Represents an arbitrary, user-specified SPIR-V type instruction. +class HLSLInlineSpirvType final + : public Type, + public llvm::FoldingSetNode, + private llvm::TrailingObjects<HLSLInlineSpirvType, SpirvOperand> { + friend class ASTContext; // ASTContext creates these + friend TrailingObjects; + +private: + uint32_t Opcode; + uint32_t Size; + uint32_t Alignment; + size_t NumOperands; + + HLSLInlineSpirvType(uint32_t Opcode, uint32_t Size, uint32_t Alignment, + ArrayRef<SpirvOperand> Operands) + : Type(HLSLInlineSpirv, QualType(), TypeDependence::None), Opcode(Opcode), + Size(Size), Alignment(Alignment), NumOperands(Operands.size()) { + for (size_t I = 0; I < NumOperands; I++) { + // Since Operands are stored as a trailing object, they have not been + // initialized yet. Call the constructor manually. + auto *Operand = + new (&getTrailingObjects<SpirvOperand>()[I]) SpirvOperand(); + *Operand = Operands[I]; + } + } + +public: + uint32_t getOpcode() const { return Opcode; } + uint32_t getSize() const { return Size; } + uint32_t getAlignment() const { return Alignment; } + ArrayRef<SpirvOperand> getOperands() const { + return {getTrailingObjects<SpirvOperand>(), NumOperands}; + } + + bool isSugared() const { return false; } + QualType desugar() const { return QualType(this, 0); } + + void Profile(llvm::FoldingSetNodeID &ID) { + Profile(ID, Opcode, Size, Alignment, getOperands()); + } + + static void Profile(llvm::FoldingSetNodeID &ID, uint32_t Opcode, + uint32_t Size, uint32_t Alignment, + ArrayRef<SpirvOperand> Operands) { + ID.AddInteger(Opcode); + ID.AddInteger(Size); + ID.AddInteger(Alignment); + for (auto &Operand : Operands) + Operand.Profile(ID); + } + + static bool classof(const Type *T) { + return T->getTypeClass() == HLSLInlineSpirv; + } +}; + class TemplateTypeParmType : public Type, public llvm::FoldingSetNode 
{ friend class ASTContext; // ASTContext creates these @@ -8495,13 +8633,18 @@ inline bool Type::isHLSLBuiltinIntangibleType() const { } inline bool Type::isHLSLSpecificType() const { - return isHLSLBuiltinIntangibleType() || isHLSLAttributedResourceType(); + return isHLSLBuiltinIntangibleType() || isHLSLAttributedResourceType() || + isHLSLInlineSpirvType(); } inline bool Type::isHLSLAttributedResourceType() const { return isa<HLSLAttributedResourceType>(this); } +inline bool Type::isHLSLInlineSpirvType() const { + return isa<HLSLInlineSpirvType>(this); +} + inline bool Type::isTemplateTypeParmType() const { return isa<TemplateTypeParmType>(CanonicalType); } diff --git a/clang/include/clang/AST/TypeLoc.h b/clang/include/clang/AST/TypeLoc.h index 92661b8b13fe0..53c7ea8c65df2 100644 --- a/clang/include/clang/AST/TypeLoc.h +++ b/clang/include/clang/AST/TypeLoc.h @@ -973,6 +973,25 @@ class HLSLAttributedResourceTypeLoc } }; +struct HLSLInlineSpirvTypeLocInfo { + SourceLocation Loc; +}; // Loc is the location returned by HLSLInlineSpirvTypeLoc::getSpirvTypeLoc().
+ +class HLSLInlineSpirvTypeLoc + : public ConcreteTypeLoc<UnqualTypeLoc, HLSLInlineSpirvTypeLoc, + HLSLInlineSpirvType, HLSLInlineSpirvTypeLocInfo> { +public: + SourceLocation getSpirvTypeLoc() const { return getLocalData()->Loc; } + void setSpirvTypeLoc(SourceLocation loc) const { getLocalData()->Loc = loc; } + + SourceRange getLocalSourceRange() const { + return SourceRange(getSpirvTypeLoc(), getSpirvTypeLoc()); + } + void initializeLocal(ASTContext &Context, SourceLocation loc) { + setSpirvTypeLoc(loc); + } +}; + struct ObjCObjectTypeLocInfo { SourceLocation TypeArgsLAngleLoc; SourceLocation TypeArgsRAngleLoc; diff --git a/clang/include/clang/AST/TypeProperties.td b/clang/include/clang/AST/TypeProperties.td index f4b8ce0994ba8..c8dc083df7e10 100644 --- a/clang/include/clang/AST/TypeProperties.td +++ b/clang/include/clang/AST/TypeProperties.td @@ -701,6 +701,24 @@ let Class = HLSLAttributedResourceType in { }]>; } +let Class = HLSLInlineSpirvType in { + def : Property<"opcode", UInt32> { + let Read = [{ node->getOpcode() }]; + } + def : Property<"size", UInt32> { + let Read = [{ node->getSize() }]; + } + def : Property<"alignment", UInt32> { + let Read = [{ node->getAlignment() }]; + } + def : Property<"operands", Array<HLSLSpirvOperand>> { + let Read = [{ node->getOperands() }]; + } + def : Creator<[{ + return ctx.getHLSLInlineSpirvType(opcode, size, alignment, operands); + }]>; +} + let Class = DependentAddressSpaceType in { def : Property<"pointeeType", QualType> { let Read = [{ node->getPointeeType() }]; diff --git a/clang/include/clang/Basic/BuiltinTemplates.td b/clang/include/clang/Basic/BuiltinTemplates.td index d46ce063d2f7e..5b9672b395955 100644 --- a/clang/include/clang/Basic/BuiltinTemplates.td +++ b/clang/include/clang/Basic/BuiltinTemplates.td @@ -28,25 +28,37 @@ class BuiltinNTTP<string type_name> : TemplateArg<""> { } def SizeT : BuiltinNTTP<"size_t"> {} +def Uint32T: BuiltinNTTP<"uint32_t"> {} class BuiltinTemplate<list<TemplateArg> 
template_head> { list<TemplateArg> TemplateHead = template_head; } +class CPlusPlusBuiltinTemplate<list<TemplateArg> template_head> : BuiltinTemplate<template_head>; + +class HLSLBuiltinTemplate<list<TemplateArg> template_head> : BuiltinTemplate<template_head>; + // template <template <class T, T... Ints> IntSeq, class T, T N> -def __make_integer_seq : BuiltinTemplate< +def __make_integer_seq : CPlusPlusBuiltinTemplate< [Template<[Class<"T">, NTTP<"T", "Ints", /*is_variadic=*/1>], "IntSeq">, Class<"T">, NTTP<"T", "N">]>; // template <size_t, class... T> -def __type_pack_element : BuiltinTemplate< +def __type_pack_element : CPlusPlusBuiltinTemplate< [SizeT, Class<"T", /*is_variadic=*/1>]>; // template <template <class... Args> BaseTemplate, // template <class TypeMember> HasTypeMember, // class HasNoTypeMember // class... Ts> -def __builtin_common_type : BuiltinTemplate< +def __builtin_common_type : CPlusPlusBuiltinTemplate< [Template<[Class<"Args", /*is_variadic=*/1>], "BaseTemplate">, Template<[Class<"TypeMember">], "HasTypeMember">, Class<"HasNoTypeMember">, Class<"Ts", /*is_variadic=*/1>]>; + +// template <uint32_t Opcode, +// uint32_t Size, +// uint32_t Alignment, +// typename ...Operands> +def __hlsl_spirv_type : HLSLBuiltinTemplate< +[Uint32T, Uint32T, Uint32T, Class<"Operands", /*is_variadic=*/1>]>; diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index b63cc8a11b136..11bd7a8edfd72 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -12998,6 +12998,9 @@ def err_hlsl_expect_arg_const_int_one_or_neg_one: Error< def err_invalid_hlsl_resource_type: Error< "invalid __hlsl_resource_t type attributes">; +def err_hlsl_spirv_only: Error<"%0 is only available for the SPIR-V target">; +def err_hlsl_vk_literal_must_contain_constant: Error<"the argument to vk::Literal must be a vk::integral_constant">; + // Layout randomization diagnostics. 
def err_non_designated_init_used : Error< "a randomized struct can only be initialized with a designated initializer">; diff --git a/clang/include/clang/Basic/TypeNodes.td b/clang/include/clang/Basic/TypeNodes.td index 7e550ca2992f3..567b8a5ca5a4d 100644 --- a/clang/include/clang/Basic/TypeNodes.td +++ b/clang/include/clang/Basic/TypeNodes.td @@ -94,6 +94,7 @@ def ElaboratedType : TypeNode<Type>, NeverCanonical; def AttributedType : TypeNode<Type>, NeverCanonical; def BTFTagAttributedType : TypeNode<Type>, NeverCanonical; def HLSLAttributedResourceType : TypeNode<Type>; +def HLSLInlineSpirvType : TypeNode<Type>; def TemplateTypeParmType : TypeNode<Type>, AlwaysDependent, LeafType; def SubstTemplateTypeParmType : TypeNode<Type>, NeverCanonical; def SubstTemplateTypeParmPackType : TypeNode<Type>, AlwaysDependent; diff --git a/clang/include/clang/Serialization/ASTRecordReader.h b/clang/include/clang/Serialization/ASTRecordReader.h index 141804185083f..da3f504ff27df 100644 --- a/clang/include/clang/Serialization/ASTRecordReader.h +++ b/clang/include/clang/Serialization/ASTRecordReader.h @@ -214,6 +214,8 @@ class ASTRecordReader TypeCoupledDeclRefInfo readTypeCoupledDeclRefInfo(); + SpirvOperand readHLSLSpirvOperand(); + /// Read a declaration name, advancing Idx. 
// DeclarationName readDeclarationName(); (inherited) DeclarationNameLoc readDeclarationNameLoc(DeclarationName Name); diff --git a/clang/include/clang/Serialization/ASTRecordWriter.h b/clang/include/clang/Serialization/ASTRecordWriter.h index e1fb239a9ce49..07f7e8d919d8b 100644 --- a/clang/include/clang/Serialization/ASTRecordWriter.h +++ b/clang/include/clang/Serialization/ASTRecordWriter.h @@ -151,6 +151,20 @@ class ASTRecordWriter writeBool(Info.isDeref()); } + void writeHLSLSpirvOperand(SpirvOperand Op) { + QualType ResultType; + llvm::APInt Value; + + if (Op.isConstant() || Op.isType()) + ResultType = Op.getResultType(); + if (Op.isConstant() || Op.isLiteral()) + Value = Op.getValue(); + + Record->push_back(Op.getKind()); + writeQualType(ResultType); + writeAPInt(Value); + } + /// Emit a source range. void AddSourceRange(SourceRange Range, LocSeq *Seq = nullptr) { return Writer->AddSourceRange(Range, *Record, Seq); diff --git a/clang/include/clang/Serialization/TypeBitCodes.def b/clang/include/clang/Serialization/TypeBitCodes.def index 3c78b87805010..b8cde2e370960 100644 --- a/clang/include/clang/Serialization/TypeBitCodes.def +++ b/clang/include/clang/Serialization/TypeBitCodes.def @@ -68,5 +68,6 @@ TYPE_BIT_CODE(PackIndexing, PACK_INDEXING, 56) TYPE_BIT_CODE(CountAttributed, COUNT_ATTRIBUTED, 57) TYPE_BIT_CODE(ArrayParameter, ARRAY_PARAMETER, 58) TYPE_BIT_CODE(HLSLAttributedResource, HLSLRESOURCE_ATTRIBUTED, 59) +TYPE_BIT_CODE(HLSLInlineSpirv, HLSL_INLINE_SPIRV, 60) #undef TYPE_BIT_CODE diff --git a/clang/lib/AST/ASTContext.cpp b/clang/lib/AST/ASTContext.cpp index b5417fcf20ddd..2cd9023ea964a 100644 --- a/clang/lib/AST/ASTContext.cpp +++ b/clang/lib/AST/ASTContext.cpp @@ -2493,6 +2493,19 @@ TypeInfo ASTContext::getTypeInfoImpl(const Type *T) const { return getTypeInfo( cast<HLSLAttributedResourceType>(T)->getWrappedType().getTypePtr()); + case Type::HLSLInlineSpirv: { + const auto *ST = cast<HLSLInlineSpirvType>(T); + // Size is specified in bytes, convert 
to bits + Width = ST->getSize() * 8; + Align = ST->getAlignment(); + if (Width == 0 && Align == 0) { + // We are defaulting to laying out opaque SPIR-V types as 32-bit ints. + Width = 32; + Align = 32; + } + break; + } + case Type::Atomic: { // Start with the base type information. TypeInfo Info = getTypeInfo(cast<AtomicType>(T)->getValueType()); @@ -3507,6 +3520,7 @@ static void encodeTypeForFunctionPointerAuth(const ASTContext &Ctx, return; } case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: llvm_unreachable("should never get here"); break; case Type::DeducedTemplateSpecialization: @@ -4228,6 +4242,7 @@ QualType ASTContext::getVariableArrayDecayedType(QualType type) const { case Type::DependentBitInt: case Type::ArrayParameter: case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: llvm_unreachable("type should never be variably-modified"); // These types can be variably-modified but should never need to @@ -5486,6 +5501,31 @@ QualType ASTContext::getHLSLAttributedResourceType( return QualType(Ty, 0); } + +QualType ASTContext::getHLSLInlineSpirvType(uint32_t Opcode, uint32_t Size, + uint32_t Alignment, + ArrayRef<SpirvOperand> Operands) { + llvm::FoldingSetNodeID ID; + HLSLInlineSpirvType::Profile(ID, Opcode, Size, Alignment, Operands); + + void *InsertPos = nullptr; + HLSLInlineSpirvType *Ty = + HLSLInlineSpirvTypes.FindNodeOrInsertPos(ID, InsertPos); + if (Ty) + return QualType(Ty, 0); + + void *Mem = Allocate( + HLSLInlineSpirvType::totalSizeToAlloc<SpirvOperand>(Operands.size()), + alignof(HLSLInlineSpirvType)); + + Ty = new (Mem) HLSLInlineSpirvType(Opcode, Size, Alignment, Operands); + + Types.push_back(Ty); + HLSLInlineSpirvTypes.InsertNode(Ty, InsertPos); + + return QualType(Ty, 0); +} + /// Retrieve a substitution-result type. 
QualType ASTContext::getSubstTemplateTypeParmType(QualType Replacement, Decl *AssociatedDecl, @@ -9457,6 +9497,7 @@ void ASTContext::getObjCEncodingForTypeImpl(QualType T, std::string &S, return; case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: llvm_unreachable("unexpected type"); case Type::ArrayParameter: @@ -11904,6 +11945,20 @@ QualType ASTContext::mergeTypes(QualType LHS, QualType RHS, bool OfBlockPointer, return LHS; return {}; } + case Type::HLSLInlineSpirv: { + const auto *LHSTy = LHS->castAs<HLSLInlineSpirvType>(); + const auto *RHSTy = RHS->castAs<HLSLInlineSpirvType>(); + + // ArrayRef equality compares lengths before elements, so operand lists of + // different sizes are rejected instead of being indexed out of bounds. + if (LHSTy->getOpcode() == RHSTy->getOpcode() && + LHSTy->getSize() == RHSTy->getSize() && + LHSTy->getAlignment() == RHSTy->getAlignment() && + LHSTy->getOperands() == RHSTy->getOperands()) + return LHS; + + return {}; + } } llvm_unreachable("Invalid Type::Class!"); @@ -13922,6 +13977,7 @@ static QualType getCommonNonSugarTypeNode(ASTContext &Ctx, const Type *X, SUGAR_FREE_TYPE(SubstTemplateTypeParmPack) SUGAR_FREE_TYPE(UnresolvedUsing) SUGAR_FREE_TYPE(HLSLAttributedResource) + SUGAR_FREE_TYPE(HLSLInlineSpirv) #undef SUGAR_FREE_TYPE #define NON_UNIQUE_TYPE(Class) UNEXPECTED_TYPE(Class, "non-unique") NON_UNIQUE_TYPE(TypeOfExpr) @@ -14262,6 +14318,7 @@ static QualType getCommonSugarTypeNode(ASTContext &Ctx, const Type *X, CANONICAL_TYPE(FunctionProto) CANONICAL_TYPE(IncompleteArray) CANONICAL_TYPE(HLSLAttributedResource) + CANONICAL_TYPE(HLSLInlineSpirv) CANONICAL_TYPE(LValueReference) CANONICAL_TYPE(ObjCInterface) CANONICAL_TYPE(ObjCObject) diff --git a/clang/lib/AST/ASTImporter.cpp b/clang/lib/AST/ASTImporter.cpp index b481ad5df667e..d275f71a97352 100644 --- a/clang/lib/AST/ASTImporter.cpp +++ b/clang/lib/AST/ASTImporter.cpp @@ -1825,6 +1825,45 @@ ExpectedType clang::ASTNodeImporter::VisitHLSLAttributedResourceType( ToWrappedType, ToContainedType, ToAttrs); }
+ExpectedType clang::ASTNodeImporter::VisitHLSLInlineSpirvType( + const clang::HLSLInlineSpirvType *T) { + Error Err = Error::success(); + + uint32_t ToOpcode = T->getOpcode(); + uint32_t ToSize = T->getSize(); + uint32_t ToAlignment = T->getAlignment(); + + llvm::SmallVector<SpirvOperand> ToOperands; + + for (auto &Operand : T->getOperands()) { + using SpirvOperandKind = SpirvOperand::SpirvOperandKind; + + switch (Operand.getKind()) { + case SpirvOperandKind::ConstantId: + ToOperands.push_back(SpirvOperand::createConstant( + importChecked(Err, Operand.getResultType()), Operand.getValue())); + break; + case SpirvOperandKind::Literal: + ToOperands.push_back(SpirvOperand::createLiteral(Operand.getValue())); + break; + case SpirvOperandKind::TypeId: + ToOperands.push_back(SpirvOperand::createType( + importChecked(Err, Operand.getResultType()))); + break; + default: + llvm_unreachable("Invalid SpirvOperand kind"); + } + } + + // Check Err even when there are no operands: an llvm::Error must be + // checked before it is destroyed, including in the success state. + if (Err) + return std::move(Err); + + return Importer.getToContext().getHLSLInlineSpirvType( + ToOpcode, ToSize, ToAlignment, ToOperands); +} + ExpectedType clang::ASTNodeImporter::VisitConstantMatrixType( const clang::ConstantMatrixType *T) { ExpectedType ToElementTypeOrErr = import(T->getElementType()); diff --git a/clang/lib/AST/ASTStructuralEquivalence.cpp b/clang/lib/AST/ASTStructuralEquivalence.cpp index 499854a75abc6..47c8812ad4dd6 100644 --- a/clang/lib/AST/ASTStructuralEquivalence.cpp +++ b/clang/lib/AST/ASTStructuralEquivalence.cpp @@ -1157,6 +1157,33 @@ static bool IsStructurallyEquivalent(StructuralEquivalenceContext &Context, return false; break; + case Type::HLSLInlineSpirv: { + const auto *Spirv1 = cast<HLSLInlineSpirvType>(T1); + const auto *Spirv2 = cast<HLSLInlineSpirvType>(T2); + if (Spirv1->getOpcode() != Spirv2->getOpcode() || + Spirv1->getSize() != Spirv2->getSize() || + Spirv1->getAlignment() != Spirv2->getAlignment() || + Spirv1->getOperands().size() != Spirv2->getOperands().size()) + return false; + // The two types may belong to different ASTContexts, so operand result + // types are compared structurally instead of with SpirvOperand::operator==. + for (auto [Op1, Op2] : + llvm::zip(Spirv1->getOperands(), Spirv2->getOperands())) { + if (Op1.getKind() != Op2.getKind()) + return false; + if (Op1.isConstant() || Op1.isLiteral()) { + llvm::APInt V1 = Op1.getValue(), V2 = Op2.getValue(); + if (V1.getBitWidth() != V2.getBitWidth() || V1 != V2) + return false; + } + if ((Op1.isConstant() || Op1.isType()) && + !IsStructurallyEquivalent(Context, Op1.getResultType(), + Op2.getResultType())) + return false; + } + break; + } +
case Type::Paren: if (!IsStructurallyEquivalent(Context, cast<ParenType>(T1)->getInnerType(), cast<ParenType>(T2)->getInnerType())) diff --git a/clang/lib/AST/ExprConstant.cpp b/clang/lib/AST/ExprConstant.cpp index 39fc714402728..c7488ea0cd0f6 100644 --- a/clang/lib/AST/ExprConstant.cpp +++ b/clang/lib/AST/ExprConstant.cpp @@ -12452,6 +12452,7 @@ GCCTypeClass EvaluateBuiltinClassifyType(QualType T, case Type::ObjCObjectPointer: case Type::Pipe: case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: // Classify all other types that don't fit into the regular // classification the same way. return GCCTypeClass::None; diff --git a/clang/lib/AST/ItaniumMangle.cpp b/clang/lib/AST/ItaniumMangle.cpp index 33a8728728574..17c07333f4107 100644 --- a/clang/lib/AST/ItaniumMangle.cpp +++ b/clang/lib/AST/ItaniumMangle.cpp @@ -2461,6 +2461,7 @@ bool CXXNameMangler::mangleUnresolvedTypeOrSimpleId(QualType Ty, case Type::Attributed: case Type::BTFTagAttributed: case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: case Type::Auto: case Type::DeducedTemplateSpecialization: case Type::PackExpansion: @@ -4692,6 +4693,44 @@ void CXXNameMangler::mangleType(const HLSLAttributedResourceType *T) { mangleType(T->getWrappedType()); } +void CXXNameMangler::mangleType(const HLSLInlineSpirvType *T) { + SmallString<20> TypeNameStr; + llvm::raw_svector_ostream TypeNameOS(TypeNameStr); + + TypeNameOS << "spirv_type"; + + TypeNameOS << "_" << T->getOpcode(); + TypeNameOS << "_" << T->getSize(); + TypeNameOS << "_" << T->getAlignment(); + + mangleVendorType(TypeNameStr); + + for (auto &Operand : T->getOperands()) { + using SpirvOperandKind = SpirvOperand::SpirvOperandKind; + + switch (Operand.getKind()) { + case SpirvOperandKind::ConstantId: + mangleVendorQualifier("_Const");
+ mangleIntegerLiteral(Operand.getResultType(), + llvm::APSInt(Operand.getValue())); + break; + case SpirvOperandKind::Literal: + mangleVendorQualifier("_Lit"); + mangleIntegerLiteral(Context.getASTContext().IntTy, + llvm::APSInt(Operand.getValue())); + break; + case SpirvOperandKind::TypeId: + mangleVendorQualifier("_Type"); + mangleType(Operand.getResultType()); + break; + default: + llvm_unreachable("Invalid SpirvOperand kind"); + break; + } + // Each operand's kind is already encoded by its vendor qualifier above. + } +} + void CXXNameMangler::mangleIntegerLiteral(QualType T, const llvm::APSInt &Value) { // <expr-primary> ::= L <type> <value number> E # integer literal @@ -4705,7 +4744,6 @@ void CXXNameMangler::mangleIntegerLiteral(QualType T, mangleNumber(Value); } Out << 'E'; - } void CXXNameMangler::mangleMemberExprBase(const Expr *Base, bool IsArrow) { diff --git a/clang/lib/AST/MicrosoftMangle.cpp b/clang/lib/AST/MicrosoftMangle.cpp index add737b762ccc..290521a9bd531 100644 --- a/clang/lib/AST/MicrosoftMangle.cpp +++ b/clang/lib/AST/MicrosoftMangle.cpp @@ -3768,6 +3768,11 @@ void MicrosoftCXXNameMangler::mangleType(const HLSLAttributedResourceType *T, llvm_unreachable("HLSL uses Itanium name mangling"); } +void MicrosoftCXXNameMangler::mangleType(const HLSLInlineSpirvType *T, + Qualifiers, SourceRange Range) { + llvm_unreachable("HLSL uses Itanium name mangling"); +} + // <this-adjustment> ::= <no-adjustment> | <static-adjustment> | // <virtual-adjustment> // <no-adjustment> ::= A # private near diff --git a/clang/lib/AST/Type.cpp b/clang/lib/AST/Type.cpp index df084dd9149a4..35a5f8ec32ab0 100644 --- a/clang/lib/AST/Type.cpp +++ b/clang/lib/AST/Type.cpp @@ -4764,6 +4764,8 @@ static CachedProperties computeCachedProperties(const Type *T) { return Cache::get(cast<PipeType>(T)->getElementType()); case Type::HLSLAttributedResource: return Cache::get(cast<HLSLAttributedResourceType>(T)->getWrappedType()); + case Type::HLSLInlineSpirv: + return CachedProperties(Linkage::External, false); }
llvm_unreachable("unhandled type class"); @@ -4862,6 +4864,10 @@ LinkageInfo LinkageComputer::computeTypeLinkageInfo(const Type *T) { return computeTypeLinkageInfo(cast<HLSLAttributedResourceType>(T) ->getContainedType() ->getCanonicalTypeInternal()); + case Type::HLSLInlineSpirv: + // Operand types do not contribute to linkage here; this matches + // computeCachedProperties above, which also reports external linkage. + return LinkageInfo::external(); } llvm_unreachable("unhandled type class"); @@ -5049,6 +5055,7 @@ bool Type::canHaveNullability(bool ResultIfUnknown) const { case Type::DependentBitInt: case Type::ArrayParameter: case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: return false; } llvm_unreachable("bad type kind!"); diff --git a/clang/lib/AST/TypePrinter.cpp b/clang/lib/AST/TypePrinter.cpp index cba1a2d98d660..4793ef38c2c46 100644 --- a/clang/lib/AST/TypePrinter.cpp +++ b/clang/lib/AST/TypePrinter.cpp @@ -247,6 +247,7 @@ bool TypePrinter::canPrefixQualifiers(const Type *T, case Type::DependentBitInt: case Type::BTFTagAttributed: case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: CanPrefixQualifiers = true; break; @@ -2139,6 +2140,53 @@ void TypePrinter::printHLSLAttributedResourceAfter( } } +void TypePrinter::printHLSLInlineSpirvBefore(const HLSLInlineSpirvType *T, + raw_ostream &OS) { + OS << "__hlsl_spirv_type<" << T->getOpcode(); + + OS << ", " << T->getSize(); + OS << ", " << T->getAlignment(); + + for (auto &Operand : T->getOperands()) { + using SpirvOperandKind = SpirvOperand::SpirvOperandKind; + + OS << ", "; + switch (Operand.getKind()) { + case SpirvOperandKind::ConstantId: { + QualType ConstantType = Operand.getResultType(); + OS << "vk::integral_constant<"; + printBefore(ConstantType, OS); + printAfter(ConstantType, OS); + OS << ", "; + OS << Operand.getValue(); + OS << ">"; + break;
+ } + case SpirvOperandKind::Literal: + OS << "vk::Literal<vk::integral_constant<uint, "; + OS << Operand.getValue(); + OS << ">>"; + break; + case SpirvOperandKind::TypeId: { + QualType Type = Operand.getResultType(); + printBefore(Type, OS); + printAfter(Type, OS); + break; + } + default: + llvm_unreachable("Invalid SpirvOperand kind!"); + break; + } + } + + OS << ">"; +} + +void TypePrinter::printHLSLInlineSpirvAfter(const HLSLInlineSpirvType *T, + raw_ostream &OS) { + // nothing to do +} + void TypePrinter::printObjCInterfaceBefore(const ObjCInterfaceType *T, raw_ostream &OS) { OS << T->getDecl()->getName(); diff --git a/clang/lib/CodeGen/CGDebugInfo.cpp b/clang/lib/CodeGen/CGDebugInfo.cpp index 21896c94c86b4..d5662b194a116 100644 --- a/clang/lib/CodeGen/CGDebugInfo.cpp +++ b/clang/lib/CodeGen/CGDebugInfo.cpp @@ -3638,6 +3638,12 @@ llvm::DIType *CGDebugInfo::CreateType(const HLSLAttributedResourceType *Ty, return getOrCreateType(Ty->getWrappedType(), U); } +llvm::DIType *CGDebugInfo::CreateType(const HLSLInlineSpirvType *Ty, + llvm::DIFile *U) { + // Debug information unneeded. 
+ return nullptr; +} + llvm::DIType *CGDebugInfo::CreateEnumType(const EnumType *Ty) { const EnumDecl *ED = Ty->getDecl(); @@ -3991,6 +3997,8 @@ llvm::DIType *CGDebugInfo::CreateTypeNode(QualType Ty, llvm::DIFile *Unit) { return CreateType(cast<TemplateSpecializationType>(Ty), Unit); case Type::HLSLAttributedResource: return CreateType(cast<HLSLAttributedResourceType>(Ty), Unit); + case Type::HLSLInlineSpirv: + return CreateType(cast<HLSLInlineSpirvType>(Ty), Unit); case Type::CountAttributed: case Type::Auto: diff --git a/clang/lib/CodeGen/CGDebugInfo.h b/clang/lib/CodeGen/CGDebugInfo.h index 79d031acbf19e..ec27fb04f3d9c 100644 --- a/clang/lib/CodeGen/CGDebugInfo.h +++ b/clang/lib/CodeGen/CGDebugInfo.h @@ -210,6 +210,7 @@ class CGDebugInfo { llvm::DIType *CreateType(const FunctionType *Ty, llvm::DIFile *F); llvm::DIType *CreateType(const HLSLAttributedResourceType *Ty, llvm::DIFile *F); + llvm::DIType *CreateType(const HLSLInlineSpirvType *Ty, llvm::DIFile *F); /// Get structure or union type. llvm::DIType *CreateType(const RecordType *Tyg); diff --git a/clang/lib/CodeGen/CodeGenFunction.cpp b/clang/lib/CodeGen/CodeGenFunction.cpp index 0356952f4f291..e2357563f7d56 100644 --- a/clang/lib/CodeGen/CodeGenFunction.cpp +++ b/clang/lib/CodeGen/CodeGenFunction.cpp @@ -283,6 +283,7 @@ TypeEvaluationKind CodeGenFunction::getEvaluationKind(QualType type) { case Type::Pipe: case Type::BitInt: case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: return TEK_Scalar; // Complexes. 
@@ -2473,6 +2474,7 @@ void CodeGenFunction::EmitVariablyModifiedType(QualType type) { case Type::ObjCInterface: case Type::ObjCObjectPointer: case Type::BitInt: + case Type::HLSLInlineSpirv: llvm_unreachable("type class is never variably-modified!"); case Type::Elaborated: diff --git a/clang/lib/CodeGen/CodeGenTypes.cpp b/clang/lib/CodeGen/CodeGenTypes.cpp index 843733ba6604d..36c5f2ba944c2 100644 --- a/clang/lib/CodeGen/CodeGenTypes.cpp +++ b/clang/lib/CodeGen/CodeGenTypes.cpp @@ -765,6 +765,7 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) { break; } case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: ResultType = CGM.getHLSLRuntime().convertHLSLSpecificType(Ty); break; } @@ -877,6 +878,10 @@ bool CodeGenTypes::isZeroInitializable(QualType T) { if (const MemberPointerType *MPT = T->getAs<MemberPointerType>()) return getCXXABI().isZeroInitializable(MPT); + // HLSL Inline SPIR-V types are non-zero-initializable. + if (T->getAs<HLSLInlineSpirvType>()) + return false; + // Everything else is okay. 
return true; } diff --git a/clang/lib/CodeGen/ItaniumCXXABI.cpp b/clang/lib/CodeGen/ItaniumCXXABI.cpp index 5018a6b39b000..e8114746ca339 100644 --- a/clang/lib/CodeGen/ItaniumCXXABI.cpp +++ b/clang/lib/CodeGen/ItaniumCXXABI.cpp @@ -3962,6 +3962,7 @@ void ItaniumRTTIBuilder::BuildVTablePointer(const Type *Ty, break; case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: llvm_unreachable("HLSL doesn't support virtual functions"); } @@ -4237,6 +4238,7 @@ llvm::Constant *ItaniumRTTIBuilder::BuildTypeInfo( break; case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: llvm_unreachable("HLSL doesn't support RTTI"); } diff --git a/clang/lib/CodeGen/Targets/SPIR.cpp b/clang/lib/CodeGen/Targets/SPIR.cpp index f35c124f50aa0..cb190b32abdb1 100644 --- a/clang/lib/CodeGen/Targets/SPIR.cpp +++ b/clang/lib/CodeGen/Targets/SPIR.cpp @@ -377,14 +377,99 @@ llvm::Type *CommonSPIRTargetCodeGenInfo::getOpenCLType(CodeGenModule &CGM, return nullptr; } +// Gets a spirv.IntegralConstant or spirv.Literal. If IntegralType is present, +// returns an IntegralConstant, otherwise returns a Literal. 
+static llvm::Type *getInlineSpirvConstant(CodeGenModule &CGM, + llvm::Type *IntegralType, + llvm::APInt Value) { + llvm::LLVMContext &Ctx = CGM.getLLVMContext(); + + // Convert the APInt value to an array of uint32_t words + llvm::SmallVector<uint32_t> Words; + + while (Value.ugt(0)) { + uint32_t Word = Value.trunc(32).getZExtValue(); + Value.lshrInPlace(32); + + Words.push_back(Word); + } + if (Words.size() == 0) + Words.push_back(0); + + if (IntegralType) + return llvm::TargetExtType::get(Ctx, "spirv.IntegralConstant", + {IntegralType}, Words); + return llvm::TargetExtType::get(Ctx, "spirv.Literal", {}, Words); +} + +static llvm::Type *getInlineSpirvType(CodeGenModule &CGM, + const HLSLInlineSpirvType *SpirvType) { + llvm::LLVMContext &Ctx = CGM.getLLVMContext(); + + llvm::SmallVector<llvm::Type *> Operands; + + for (auto &Operand : SpirvType->getOperands()) { + using SpirvOperandKind = SpirvOperand::SpirvOperandKind; + + llvm::Type *Result = nullptr; + switch (Operand.getKind()) { + case SpirvOperandKind::ConstantId: { + llvm::Type *IntegralType = + CGM.getTypes().ConvertType(Operand.getResultType()); + llvm::APInt Value = Operand.getValue(); + + Result = getInlineSpirvConstant(CGM, IntegralType, Value); + break; + } + case SpirvOperandKind::Literal: { + llvm::APInt Value = Operand.getValue(); + Result = getInlineSpirvConstant(CGM, nullptr, Value); + break; + } + case SpirvOperandKind::TypeId: { + QualType TypeOperand = Operand.getResultType(); + if (auto *RT = TypeOperand->getAs<RecordType>()) { + auto *RD = RT->getDecl(); + assert(RD->isCompleteDefinition() && + "Type completion should have been required in Sema"); + + const FieldDecl *HandleField = RD->findFirstNamedDataMember(); + if (HandleField) { + QualType ResourceType = HandleField->getType(); + if (ResourceType->getAs<HLSLAttributedResourceType>()) { + TypeOperand = ResourceType; + } + } + } + Result = CGM.getTypes().ConvertType(TypeOperand); + break; + } + default: + 
llvm_unreachable("HLSLInlineSpirvType had invalid operand!"); + break; + } + + assert(Result); + Operands.push_back(Result); + } + + return llvm::TargetExtType::get(Ctx, "spirv.Type", Operands, + {SpirvType->getOpcode(), SpirvType->getSize(), + SpirvType->getAlignment()}); +} + llvm::Type *CommonSPIRTargetCodeGenInfo::getHLSLType( CodeGenModule &CGM, const Type *Ty, const SmallVector<int32_t> *Packoffsets) const { + llvm::LLVMContext &Ctx = CGM.getLLVMContext(); + + if (auto *SpirvType = dyn_cast<HLSLInlineSpirvType>(Ty)) + return getInlineSpirvType(CGM, SpirvType); + auto *ResType = dyn_cast<HLSLAttributedResourceType>(Ty); if (!ResType) return nullptr; - llvm::LLVMContext &Ctx = CGM.getLLVMContext(); const HLSLAttributedResourceType::Attributes &ResAttrs = ResType->getAttrs(); switch (ResAttrs.ResourceClass) { case llvm::dxil::ResourceClass::UAV: diff --git a/clang/lib/Headers/CMakeLists.txt b/clang/lib/Headers/CMakeLists.txt index 449feb012481f..53219dc932587 100644 --- a/clang/lib/Headers/CMakeLists.txt +++ b/clang/lib/Headers/CMakeLists.txt @@ -91,6 +91,7 @@ set(hlsl_subdir_files hlsl/hlsl_intrinsic_helpers.h hlsl/hlsl_intrinsics.h hlsl/hlsl_detail.h + hlsl/hlsl_spirv.h ) set(hlsl_files ${hlsl_h} diff --git a/clang/lib/Headers/hlsl.h b/clang/lib/Headers/hlsl.h index b494b4d0f78bb..684d29d5ed55b 100644 --- a/clang/lib/Headers/hlsl.h +++ b/clang/lib/Headers/hlsl.h @@ -27,6 +27,10 @@ #endif #include "hlsl/hlsl_intrinsics.h" +#ifdef __spirv__ +#include "hlsl/hlsl_spirv.h" +#endif // __spirv__ + #if defined(__clang__) #pragma clang diagnostic pop #endif diff --git a/clang/lib/Headers/hlsl/hlsl_spirv.h b/clang/lib/Headers/hlsl/hlsl_spirv.h new file mode 100644 index 0000000000000..711da2fea46a4 --- /dev/null +++ b/clang/lib/Headers/hlsl/hlsl_spirv.h @@ -0,0 +1,28 @@ +//===----- hlsl_spirv.h - HLSL definitions for SPIR-V target --------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 
+// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef _HLSL_HLSL_SPIRV_H_ +#define _HLSL_HLSL_SPIRV_H_ + +namespace hlsl { +namespace vk { +template <typename T, T v> struct integral_constant { + static constexpr T value = v; +}; + +template <typename T> struct Literal {}; + +template <uint Opcode, uint Size, uint Alignment, typename... Operands> +using SpirvType = __hlsl_spirv_type<Opcode, Size, Alignment, Operands...>; + +template <uint Opcode, typename... Operands> +using SpirvOpaqueType = __hlsl_spirv_type<Opcode, 0, 0, Operands...>; +} // namespace vk +} // namespace hlsl + +#endif // _HLSL_HLSL_SPIRV_H_ \ No newline at end of file diff --git a/clang/lib/Sema/SemaExpr.cpp b/clang/lib/Sema/SemaExpr.cpp index 5aa33939a7817..198b7da9c35d8 100644 --- a/clang/lib/Sema/SemaExpr.cpp +++ b/clang/lib/Sema/SemaExpr.cpp @@ -4562,6 +4562,7 @@ static void captureVariablyModifiedType(ASTContext &Context, QualType T, case Type::ObjCTypeParam: case Type::Pipe: case Type::BitInt: + case Type::HLSLInlineSpirv: llvm_unreachable("type class is never variably-modified!"); case Type::Elaborated: T = cast<ElaboratedType>(Ty)->getNamedType(); diff --git a/clang/lib/Sema/SemaLookup.cpp b/clang/lib/Sema/SemaLookup.cpp index 55428af73011c..eef134b158438 100644 --- a/clang/lib/Sema/SemaLookup.cpp +++ b/clang/lib/Sema/SemaLookup.cpp @@ -927,13 +927,25 @@ bool Sema::LookupBuiltin(LookupResult &R) { NameKind == Sema::LookupRedeclarationWithLinkage) { IdentifierInfo *II = R.getLookupName().getAsIdentifierInfo(); if (II) { - if (getLangOpts().CPlusPlus && NameKind == Sema::LookupOrdinaryName) { -#define BuiltinTemplate(BIName) \ + if (NameKind == Sema::LookupOrdinaryName) { + if (getLangOpts().CPlusPlus) { +#define BuiltinTemplate(BIName) +#define CPlusPlusBuiltinTemplate(BIName) \ if (II == getASTContext().get##BIName##Name()) { 
\ R.addDecl(getASTContext().get##BIName##Decl()); \ return true; \ } #include "clang/Basic/BuiltinTemplates.inc" + } + if (getLangOpts().HLSL) { +#define BuiltinTemplate(BIName) +#define HLSLBuiltinTemplate(BIName) \ + if (II == getASTContext().get##BIName##Name()) { \ + R.addDecl(getASTContext().get##BIName##Decl()); \ + return true; \ + } +#include "clang/Basic/BuiltinTemplates.inc" + } } // Check if this is an OpenCL Builtin, and if so, insert its overloads. @@ -3270,6 +3282,11 @@ addAssociatedClassesAndNamespaces(AssociatedLookup &Result, QualType Ty) { case Type::HLSLAttributedResource: T = cast<HLSLAttributedResourceType>(T)->getWrappedType().getTypePtr(); + break; + + // Inline SPIR-V types are treated as fundamental types. + case Type::HLSLInlineSpirv: + break; } if (Queue.empty()) diff --git a/clang/lib/Sema/SemaTemplate.cpp b/clang/lib/Sema/SemaTemplate.cpp index 41c3f81639c82..10e7823542f0b 100644 --- a/clang/lib/Sema/SemaTemplate.cpp +++ b/clang/lib/Sema/SemaTemplate.cpp @@ -3234,6 +3234,59 @@ static QualType builtinCommonTypeImpl(Sema &S, TemplateName BaseTemplate, } } +static bool isInVkNamespace(const RecordType *RT) { + DeclContext *DC = RT->getDecl()->getDeclContext(); + if (!DC) + return false; + + NamespaceDecl *ND = dyn_cast<NamespaceDecl>(DC); + if (!ND) + return false; + + return ND->getQualifiedNameAsString() == "hlsl::vk"; +} + +static SpirvOperand checkHLSLSpirvTypeOperand(Sema &SemaRef, + QualType OperandArg, + SourceLocation Loc) { + if (auto *RT = OperandArg->getAs<RecordType>()) { + bool Literal = false; + SourceLocation LiteralLoc; + if (isInVkNamespace(RT) && RT->getDecl()->getName() == "Literal") { + auto SpecDecl = dyn_cast<ClassTemplateSpecializationDecl>(RT->getDecl()); + assert(SpecDecl); + + const TemplateArgumentList &LiteralArgs = SpecDecl->getTemplateArgs(); + QualType ConstantType = LiteralArgs[0].getAsType(); + RT = ConstantType->getAs<RecordType>(); + Literal = true; + LiteralLoc = SpecDecl->getSourceRange().getBegin(); + 
} + + if (RT && isInVkNamespace(RT) && + RT->getDecl()->getName() == "integral_constant") { + auto SpecDecl = dyn_cast<ClassTemplateSpecializationDecl>(RT->getDecl()); + assert(SpecDecl); + + const TemplateArgumentList &ConstantArgs = SpecDecl->getTemplateArgs(); + + QualType ConstantType = ConstantArgs[0].getAsType(); + llvm::APInt Value = ConstantArgs[1].getAsIntegral(); + + if (Literal) + return SpirvOperand::createLiteral(Value); + return SpirvOperand::createConstant(ConstantType, Value); + } else if (Literal) { + SemaRef.Diag(LiteralLoc, diag::err_hlsl_vk_literal_must_contain_constant); + return SpirvOperand(); + } + } + if (SemaRef.RequireCompleteType(Loc, OperandArg, + diag::err_call_incomplete_argument)) + return SpirvOperand(); + return SpirvOperand::createType(OperandArg); +} + static QualType checkBuiltinTemplateIdType(Sema &SemaRef, BuiltinTemplateDecl *BTD, ArrayRef<TemplateArgument> Converted, @@ -3334,6 +3387,36 @@ checkBuiltinTemplateIdType(Sema &SemaRef, BuiltinTemplateDecl *BTD, QualType HasNoTypeMember = Converted[2].getAsType(); return HasNoTypeMember; } + + case BTK__hlsl_spirv_type: { + assert(Converted.size() == 4); + + if (!Context.getTargetInfo().getTriple().isSPIRV()) { + SemaRef.Diag(TemplateLoc, diag::err_hlsl_spirv_only) << BTD; + } + + if (llvm::any_of(Converted, [](auto &C) { return C.isDependent(); })) + return QualType(); + + uint64_t Opcode = Converted[0].getAsIntegral().getZExtValue(); + uint64_t Size = Converted[1].getAsIntegral().getZExtValue(); + uint64_t Alignment = Converted[2].getAsIntegral().getZExtValue(); + + ArrayRef<TemplateArgument> OperandArgs = Converted[3].getPackAsArray(); + + llvm::SmallVector<SpirvOperand> Operands; + + for (auto &OperandTA : OperandArgs) { + QualType OperandArg = OperandTA.getAsType(); + auto Operand = checkHLSLSpirvTypeOperand(SemaRef, OperandArg, + TemplateArgs[3].getLocation()); + if (!Operand.isValid()) + return QualType(); + Operands.push_back(Operand); + } + + return 
Context.getHLSLInlineSpirvType(Opcode, Size, Alignment, Operands); + } } llvm_unreachable("unexpected BuiltinTemplateDecl!"); } @@ -6251,6 +6334,15 @@ bool UnnamedLocalNoLinkageFinder::VisitHLSLAttributedResourceType( return Visit(T->getWrappedType()); } +bool UnnamedLocalNoLinkageFinder::VisitHLSLInlineSpirvType( + const HLSLInlineSpirvType *T) { + for (auto &Operand : T->getOperands()) + if (Operand.isConstant() && Operand.isLiteral()) + if (Visit(Operand.getResultType())) + return true; + return false; +} + bool Sema::CheckTemplateArgument(TypeSourceInfo *ArgInfo) { assert(ArgInfo && "invalid TypeSourceInfo"); QualType Arg = ArgInfo->getType(); diff --git a/clang/lib/Sema/SemaTemplateDeduction.cpp b/clang/lib/Sema/SemaTemplateDeduction.cpp index 75ae04b27d06a..51bfca9da4231 100644 --- a/clang/lib/Sema/SemaTemplateDeduction.cpp +++ b/clang/lib/Sema/SemaTemplateDeduction.cpp @@ -2474,6 +2474,7 @@ static TemplateDeductionResult DeduceTemplateArgumentsByTypeMatch( case Type::Pipe: case Type::ArrayParameter: case Type::HLSLAttributedResource: + case Type::HLSLInlineSpirv: // No template argument deduction for these types return TemplateDeductionResult::Success; @@ -6993,6 +6994,7 @@ MarkUsedTemplateParameters(ASTContext &Ctx, QualType T, case Type::UnresolvedUsing: case Type::Pipe: case Type::BitInt: + case Type::HLSLInlineSpirv: #define TYPE(Class, Base) #define ABSTRACT_TYPE(Class, Base) #define DEPENDENT_TYPE(Class, Base) diff --git a/clang/lib/Sema/SemaType.cpp b/clang/lib/Sema/SemaType.cpp index 49d10f5502084..338b81fe89748 100644 --- a/clang/lib/Sema/SemaType.cpp +++ b/clang/lib/Sema/SemaType.cpp @@ -5888,6 +5888,7 @@ namespace { Visit(TL.getWrappedLoc()); fillHLSLAttributedResourceTypeLoc(TL, State); } + void VisitHLSLInlineSpirvTypeLoc(HLSLInlineSpirvTypeLoc TL) {} void VisitMacroQualifiedTypeLoc(MacroQualifiedTypeLoc TL) { Visit(TL.getInnerLoc()); TL.setExpansionLoc( diff --git a/clang/lib/Sema/TreeTransform.h b/clang/lib/Sema/TreeTransform.h index 
335e21d927b76..7629e84f5b848 100644 --- a/clang/lib/Sema/TreeTransform.h +++ b/clang/lib/Sema/TreeTransform.h @@ -7683,6 +7683,13 @@ QualType TreeTransform<Derived>::TransformHLSLAttributedResourceType( return Result; } +template <typename Derived> +QualType TreeTransform<Derived>::TransformHLSLInlineSpirvType( + TypeLocBuilder &TLB, HLSLInlineSpirvTypeLoc TL) { + // No transformations needed. + return TL.getType(); +} + template<typename Derived> QualType TreeTransform<Derived>::TransformParenType(TypeLocBuilder &TLB, diff --git a/clang/lib/Serialization/ASTReader.cpp b/clang/lib/Serialization/ASTReader.cpp index c113fd7cbb911..1ecfb92aac71b 100644 --- a/clang/lib/Serialization/ASTReader.cpp +++ b/clang/lib/Serialization/ASTReader.cpp @@ -7349,6 +7349,10 @@ void TypeLocReader::VisitHLSLAttributedResourceTypeLoc( // Nothing to do. } +void TypeLocReader::VisitHLSLInlineSpirvTypeLoc(HLSLInlineSpirvTypeLoc TL) { + // Nothing to do. +} + void TypeLocReader::VisitTemplateTypeParmTypeLoc(TemplateTypeParmTypeLoc TL) { TL.setNameLoc(readSourceLocation()); } @@ -9821,6 +9825,15 @@ TypeCoupledDeclRefInfo ASTRecordReader::readTypeCoupledDeclRefInfo() { return TypeCoupledDeclRefInfo(readDeclAs<ValueDecl>(), readBool()); } +SpirvOperand ASTRecordReader::readHLSLSpirvOperand() { + auto Kind = readInt(); + auto ResultType = readQualType(); + auto Value = readAPInt(); + SpirvOperand Op(SpirvOperand::SpirvOperandKind(Kind), ResultType, Value); + assert(Op.isValid()); + return Op; +} + void ASTRecordReader::readQualifierInfo(QualifierInfo &Info) { Info.QualifierLoc = readNestedNameSpecifierLoc(); unsigned NumTPLists = readInt(); diff --git a/clang/lib/Serialization/ASTWriter.cpp b/clang/lib/Serialization/ASTWriter.cpp index 1b3d3c22aa9f5..cc9916a75d4b4 100644 --- a/clang/lib/Serialization/ASTWriter.cpp +++ b/clang/lib/Serialization/ASTWriter.cpp @@ -604,6 +604,10 @@ void TypeLocWriter::VisitHLSLAttributedResourceTypeLoc( // Nothing to do. 
} +void TypeLocWriter::VisitHLSLInlineSpirvTypeLoc(HLSLInlineSpirvTypeLoc TL) { + // Nothing to do. +} + void TypeLocWriter::VisitTemplateTypeParmTypeLoc(TemplateTypeParmTypeLoc TL) { addSourceLocation(TL.getNameLoc()); } diff --git a/clang/test/AST/HLSL/Inputs/pch_spirv_type.hlsl b/clang/test/AST/HLSL/Inputs/pch_spirv_type.hlsl new file mode 100644 index 0000000000000..326c75dbc3bbe --- /dev/null +++ b/clang/test/AST/HLSL/Inputs/pch_spirv_type.hlsl @@ -0,0 +1,6 @@ + +float2 foo(float2 a, float2 b) { + return a + b; +} + +vk::SpirvOpaqueType</* OpTypeArray */ 28, RWBuffer<float>, vk::integral_constant<uint, 4>> buffers; diff --git a/clang/test/AST/HLSL/ast-dump-SpirvType.hlsl b/clang/test/AST/HLSL/ast-dump-SpirvType.hlsl new file mode 100644 index 0000000000000..a4a147622c6cc --- /dev/null +++ b/clang/test/AST/HLSL/ast-dump-SpirvType.hlsl @@ -0,0 +1,27 @@ +// RUN: %clang_cc1 -finclude-default-header -triple spirv-unknown-vulkan-compute -x hlsl -ast-dump -o - %s | FileCheck %s + +// CHECK: TypedefDecl 0x{{.+}} <{{.+}}:4:1, col:83> col:83 referenced AType 'vk::SpirvOpaqueType<123, RWBuffer<float>, vk::integral_constant<uint, 4>>':'__hlsl_spirv_type<123, 0, 0, RWBuffer<float>, vk::integral_constant<unsigned int, 4>>' +typedef vk::SpirvOpaqueType<123, RWBuffer<float>, vk::integral_constant<uint, 4>> AType; +// CHECK: TypedefDecl 0x{{.+}} <{{.+}}:6:1, col:133> col:133 referenced BType 'vk::SpirvType<12, 2, 4, vk::integral_constant<uint64_t, 4886718345L>, float, vk::Literal<vk::integral_constant<uint, 456>>>':'__hlsl_spirv_type<12, 2, 4, vk::integral_constant<unsigned long, 4886718345>, float, vk::Literal<vk::integral_constant<uint, 456>>>' +typedef vk::SpirvType<12, 2, 4, vk::integral_constant<uint64_t, 0x123456789>, float, vk::Literal<vk::integral_constant<uint, 456>>> BType; + +// CHECK: VarDecl 0x{{.+}} <{{.+}}:9:1, col:7> col:7 AValue 'hlsl_constant AType':'hlsl_constant __hlsl_spirv_type<123, 0, 0, RWBuffer<float>, vk::integral_constant<unsigned int, 4>>' +AType 
AValue; +// CHECK: VarDecl 0x{{.+}} <{{.+}}:11:1, col:7> col:7 BValue 'hlsl_constant BType':'hlsl_constant __hlsl_spirv_type<12, 2, 4, vk::integral_constant<unsigned long, 4886718345>, float, vk::Literal<vk::integral_constant<uint, 456>>>' +BType BValue; + +// CHECK: VarDecl 0x{{.+}} <{{.+}}:14:1, col:80> col:80 CValue 'hlsl_constant vk::SpirvOpaqueType<123, vk::Literal<vk::integral_constant<uint, 305419896>>>':'hlsl_constant __hlsl_spirv_type<123, 0, 0, vk::Literal<vk::integral_constant<uint, 305419896>>>' +vk::SpirvOpaqueType<123, vk::Literal<vk::integral_constant<uint, 0x12345678>>> CValue; + +// CHECK: TypeAliasDecl 0x{{.+}} <{{.+}}:18:1, col:72> col:7 Array 'vk::SpirvOpaqueType<28, T, vk::integral_constant<uint, L>>':'__hlsl_spirv_type<28U, 0, 0, T, vk::integral_constant<uint, L>>' +template <class T, uint L> +using Array = vk::SpirvOpaqueType<28, T, vk::integral_constant<uint, L>>; + +// CHECK: VarDecl 0x{{.+}} <{{.+}}:21:1, col:16> col:16 DValue 'hlsl_constant Array<uint, 5>':'hlsl_constant __hlsl_spirv_type<28, 0, 0, uint, vk::integral_constant<unsigned int, 5>>' +Array<uint, 5> DValue; + +[numthreads(1, 1, 1)] +void main() { +// CHECK: VarDecl 0x{{.+}} <col:5, col:11> col:11 EValue 'AType':'__hlsl_spirv_type<123, 0, 0, RWBuffer<float>, vk::integral_constant<unsigned int, 4>>' + AType EValue; +} diff --git a/clang/test/AST/HLSL/pch_spirv_type.hlsl b/clang/test/AST/HLSL/pch_spirv_type.hlsl new file mode 100644 index 0000000000000..045f89a1b8461 --- /dev/null +++ b/clang/test/AST/HLSL/pch_spirv_type.hlsl @@ -0,0 +1,17 @@ +// RUN: %clang_cc1 -triple spirv-unknown-vulkan-library -x hlsl \ +// RUN: -finclude-default-header -emit-pch -o %t %S/Inputs/pch_spirv_type.hlsl +// RUN: %clang_cc1 -triple spirv-unknown-vulkan-library -x hlsl \ +// RUN: -finclude-default-header -include-pch %t -ast-dump-all %s \ +// RUN: | FileCheck %s + +// Make sure PCH works by using function declared in PCH header and declare a SpirvType in current file. 
+// CHECK:FunctionDecl 0x[[FOO:[0-9a-f]+]] <{{.*}}:2:1, line:4:1> line:2:8 imported used foo 'float2 (float2, float2)' +// CHECK:VarDecl 0x{{[0-9a-f]+}} <{{.*}}:10:1, col:92> col:92 buffers2 'hlsl_constant vk::SpirvOpaqueType<28, RWBuffer<float>, vk::integral_constant<uint, 4>>':'hlsl_constant __hlsl_spirv_type<28, 0, 0, RWBuffer<float>, vk::integral_constant<unsigned int, 4>>' +vk::SpirvOpaqueType</* OpTypeArray */ 28, RWBuffer<float>, vk::integral_constant<uint, 4>> buffers2; + +float2 bar(float2 a, float2 b) { +// CHECK:CallExpr 0x{{[0-9a-f]+}} <col:10, col:18> 'float2':'vector<float, 2>' +// CHECK-NEXT:ImplicitCastExpr 0x{{[0-9a-f]+}} <col:10> 'float2 (*)(float2, float2)' <FunctionToPointerDecay> +// CHECK-NEXT:`-DeclRefExpr 0x{{[0-9a-f]+}} <col:10> 'float2 (float2, float2)' lvalue Function 0x[[FOO]] 'foo' 'float2 (float2, float2)' + return foo(a, b); +} diff --git a/clang/test/AST/HLSL/vector-alias.hlsl b/clang/test/AST/HLSL/vector-alias.hlsl index 58d80e9b4a4e4..e1f78e6abdca8 100644 --- a/clang/test/AST/HLSL/vector-alias.hlsl +++ b/clang/test/AST/HLSL/vector-alias.hlsl @@ -1,53 +1,52 @@ -// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s - -// CHECK: NamespaceDecl {{.*}} implicit hlsl -// CHECK-NEXT: TypeAliasTemplateDecl {{.*}} implicit vector -// CHECK-NEXT: TemplateTypeParmDecl {{.*}} class depth 0 index 0 element -// CHECK-NEXT: TemplateArgument type 'float' -// CHECK-NEXT: BuiltinType {{.*}} 'float' -// CHECK-NEXT: NonTypeTemplateParmDecl {{.*}} 'int' depth 0 index 1 element_count -// CHECK-NEXT: TemplateArgument expr -// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 4 -// CHECK-NEXT: TypeAliasDecl {{.*}} implicit vector 'vector<element, element_count>' -// CHECK-NEXT: DependentSizedExtVectorType {{.*}} 'vector<element, element_count>' dependent -// CHECK-NEXT: TemplateTypeParmType {{.*}} 'element' dependent depth 0 index 0 -// CHECK-NEXT: TemplateTypeParm {{.*}} 'element' -// CHECK-NEXT: DeclRefExpr {{.*}} 'int' 
lvalue -// NonTypeTemplateParm {{.*}} 'element_count' 'int' - -// Make sure we got a using directive at the end. -// CHECK: UsingDirectiveDecl {{.*}} Namespace {{.*}} 'hlsl' - -[numthreads(1,1,1)] -int entry() { - // Verify that the alias is generated inside the hlsl namespace. - hlsl::vector<float, 2> Vec2 = {1.0, 2.0}; - - // CHECK: DeclStmt - // CHECK-NEXT: VarDecl {{.*}} Vec2 'hlsl::vector<float, 2>':'vector<float, 2>' cinit - - // Verify that you don't need to specify the namespace. - vector<int, 2> Vec2a = {1, 2}; - - // CHECK: DeclStmt - // CHECK-NEXT: VarDecl {{.*}} Vec2a 'vector<int, 2>' cinit - - // Build a bigger vector. - vector<double, 4> Vec4 = {1.0, 2.0, 3.0, 4.0}; - - // CHECK: DeclStmt - // CHECK-NEXT: VarDecl {{.*}} used Vec4 'vector<double, 4>' cinit - - // Verify that swizzles still work. - vector<double, 3> Vec3 = Vec4.xyz; - - // CHECK: DeclStmt {{.*}} - // CHECK-NEXT: VarDecl {{.*}} Vec3 'vector<double, 3>' cinit - - // Verify that the implicit arguments generate the correct type. 
- vector<> ImpVec4 = {1.0, 2.0, 3.0, 4.0}; - - // CHECK: DeclStmt - // CHECK-NEXT: VarDecl {{.*}} ImpVec4 'vector<>':'vector<float, 4>' cinit - return 1; -} +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s +// CHECK: NamespaceDecl {{.*}} implicit hlsl +// CHECK: TypeAliasTemplateDecl {{.*}} implicit vector +// CHECK-NEXT: TemplateTypeParmDecl {{.*}} class depth 0 index 0 element +// CHECK-NEXT: TemplateArgument type 'float' +// CHECK-NEXT: BuiltinType {{.*}} 'float' +// CHECK-NEXT: NonTypeTemplateParmDecl {{.*}} 'int' depth 0 index 1 element_count +// CHECK-NEXT: TemplateArgument expr +// CHECK-NEXT: IntegerLiteral {{.*}} 'int' 4 +// CHECK-NEXT: TypeAliasDecl {{.*}} implicit vector 'vector<element, element_count>' +// CHECK-NEXT: DependentSizedExtVectorType {{.*}} 'vector<element, element_count>' dependent +// CHECK-NEXT: TemplateTypeParmType {{.*}} 'element' dependent depth 0 index 0 +// CHECK-NEXT: TemplateTypeParm {{.*}} 'element' +// CHECK-NEXT: DeclRefExpr {{.*}} 'int' lvalue +// NonTypeTemplateParm {{.*}} 'element_count' 'int' + +// Make sure we got a using directive at the end. +// CHECK: UsingDirectiveDecl {{.*}} Namespace {{.*}} 'hlsl' + +[numthreads(1,1,1)] +int entry() { + // Verify that the alias is generated inside the hlsl namespace. + hlsl::vector<float, 2> Vec2 = {1.0, 2.0}; + + // CHECK: DeclStmt + // CHECK-NEXT: VarDecl {{.*}} Vec2 'hlsl::vector<float, 2>':'vector<float, 2>' cinit + + // Verify that you don't need to specify the namespace. + vector<int, 2> Vec2a = {1, 2}; + + // CHECK: DeclStmt + // CHECK-NEXT: VarDecl {{.*}} Vec2a 'vector<int, 2>' cinit + + // Build a bigger vector. + vector<double, 4> Vec4 = {1.0, 2.0, 3.0, 4.0}; + + // CHECK: DeclStmt + // CHECK-NEXT: VarDecl {{.*}} used Vec4 'vector<double, 4>' cinit + + // Verify that swizzles still work. 
+ vector<double, 3> Vec3 = Vec4.xyz; + + // CHECK: DeclStmt {{.*}} + // CHECK-NEXT: VarDecl {{.*}} Vec3 'vector<double, 3>' cinit + + // Verify that the implicit arguments generate the correct type. + vector<> ImpVec4 = {1.0, 2.0, 3.0, 4.0}; + + // CHECK: DeclStmt + // CHECK-NEXT: VarDecl {{.*}} ImpVec4 'vector<>':'vector<float, 4>' cinit + return 1; +} diff --git a/clang/test/CodeGenHLSL/inline-spirv/SpirvType.alignment.hlsl b/clang/test/CodeGenHLSL/inline-spirv/SpirvType.alignment.hlsl new file mode 100644 index 0000000000000..41cdd7d21bcbe --- /dev/null +++ b/clang/test/CodeGenHLSL/inline-spirv/SpirvType.alignment.hlsl @@ -0,0 +1,16 @@ +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ +// RUN: spirv-unknown-vulkan-compute %s -emit-llvm -disable-llvm-passes \ +// RUN: -o - | FileCheck %s + +using Int = vk::SpirvType</* OpTypeInt */ 21, 4, 64, vk::Literal<vk::integral_constant<uint, 8>>, vk::Literal<vk::integral_constant<bool, false>>>; + +// CHECK: %struct.S = type <{ i32, target("spirv.Type", target("spirv.Literal", 8), target("spirv.Literal", 0), 21, 4, 64), [4 x i8] }> +struct S { + int a; + Int b; +}; + +[numthreads(1,1,1)] +void main() { + S value; +} diff --git a/clang/test/CodeGenHLSL/inline-spirv/SpirvType.hlsl b/clang/test/CodeGenHLSL/inline-spirv/SpirvType.hlsl new file mode 100644 index 0000000000000..5b58e436bbedb --- /dev/null +++ b/clang/test/CodeGenHLSL/inline-spirv/SpirvType.hlsl @@ -0,0 +1,68 @@ +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ +// RUN: spirv-unknown-vulkan-compute %s -emit-llvm -disable-llvm-passes \ +// RUN: -o - | FileCheck %s + +template<class T, uint64_t Size> +using Array = vk::SpirvOpaqueType</* OpTypeArray */ 28, T, vk::integral_constant<uint64_t, Size>>; + +template<uint64_t Size> +using ArrayBuffer = Array<RWBuffer<float>, Size>; + +typedef vk::SpirvType</* OpTypeInt */ 21, 4, 32, vk::Literal<vk::integral_constant<uint, 32>>, vk::Literal<vk::integral_constant<bool, false>>> Int; + +typedef 
Array<Int, 5> ArrayInt; + +// CHECK: %struct.S = type { target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0), target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32) } +struct S { + ArrayBuffer<4> b; + Int i; +}; + +// CHECK: define spir_func target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0) @_Z14getArrayBufferu17spirv_type_28_0_0U5_TypeN4hlsl8RWBufferIfEEU6_ConstLm4E(target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0) %v) #0 +ArrayBuffer<4> getArrayBuffer(ArrayBuffer<4> v) { + return v; +} + +// CHECK: define spir_func target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32) @_Z6getIntu18spirv_type_21_4_32U4_LitLi32EU4_LitLi0E(target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32) %v) #0 +Int getInt(Int v) { + return v; +} + +// TODO: uncomment and test once CBuffer handles are implemented for SPIR-V +// ArrayBuffer<4> g_buffers; +// Int g_word; + +[numthreads(1, 1, 1)] +void main() { + // CHECK: [[buffers:%.*]] = alloca target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0), align 4 + ArrayBuffer<4> buffers; + + // CHECK: [[longBuffers:%.*]] = alloca target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 591751049, 1), 28, 0, 0), align 4 + ArrayBuffer<0x123456789> longBuffers; + + // CHECK: [[word:%.*]] = alloca target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32), align 4 + Int word; + + // CHECK: [[words:%.*]] = alloca [4 x target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32)], align 4 + Int words[4]; + + // CHECK: [[words2:%.*]] = alloca target("spirv.Type", 
target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32), target("spirv.IntegralConstant", i64, 5), 28, 0, 0), align 4 + ArrayInt words2; + + // CHECK: [[value:%.*]] = alloca %struct.S, align 1 + S value; + + // CHECK: [[buffers2:%.*]] = alloca target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0), align 4 + // CHECK: [[word2:%.*]] = alloca target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32), align 4 + + + // CHECK: [[loaded:%[0-9]+]] = load target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0), ptr [[buffers]], align 4 + // CHECK: %call1 = call spir_func target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0) @_Z14getArrayBufferu17spirv_type_28_0_0U5_TypeN4hlsl8RWBufferIfEEU6_ConstLm4E(target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0) [[loaded]]) + // CHECK: store target("spirv.Type", target("spirv.Image", float, 5, 2, 0, 0, 2, 0), target("spirv.IntegralConstant", i64, 4), 28, 0, 0) %call1, ptr [[buffers2]], align 4 + ArrayBuffer<4> buffers2 = getArrayBuffer(buffers); + + // CHECK: [[loaded:%[0-9]+]] = load target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32), ptr [[word]], align 4 + // CHECK: %call2 = call spir_func target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32) @_Z6getIntu18spirv_type_21_4_32U4_LitLi32EU4_LitLi0E(target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32) [[loaded]]) + // CHECK: store target("spirv.Type", target("spirv.Literal", 32), target("spirv.Literal", 0), 21, 4, 32) %call2, ptr [[word2]], align 4 + Int word2 = getInt(word); +} diff --git a/clang/test/SemaHLSL/inline-spirv/SpirvType.dx.error.hlsl 
b/clang/test/SemaHLSL/inline-spirv/SpirvType.dx.error.hlsl new file mode 100644 index 0000000000000..a94566a0b13c4 --- /dev/null +++ b/clang/test/SemaHLSL/inline-spirv/SpirvType.dx.error.hlsl @@ -0,0 +1,10 @@ +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ +// RUN: dxil-pc-shadermodel6.0-compute %s \ +// RUN: -fsyntax-only -verify + +typedef vk::SpirvType<12, 2, 4, float> InvalidType; // expected-error {{use of undeclared identifier 'vk'}} + +[numthreads(1, 1, 1)] +void main() { + __hlsl_spirv_type<12, 2, 4, float> InvalidValue; // expected-error {{'__hlsl_spirv_type' is only available for the SPIR-V target}} +} diff --git a/clang/test/SemaHLSL/inline-spirv/SpirvType.incomplete.hlsl b/clang/test/SemaHLSL/inline-spirv/SpirvType.incomplete.hlsl new file mode 100644 index 0000000000000..5980023ac6086 --- /dev/null +++ b/clang/test/SemaHLSL/inline-spirv/SpirvType.incomplete.hlsl @@ -0,0 +1,14 @@ +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ +// RUN: spirv-unknown-vulkan-compute %s -fsyntax-only -verify + +struct S; // expected-note {{forward declaration of 'S'}} + +// expected-error@hlsl/hlsl_spirv.h:24 {{argument type 'S' is incomplete}} + +typedef vk::SpirvOpaqueType</* OpTypeArray */ 28, S, vk::integral_constant<uint, 4>> ArrayOfS; // #1 +// expected-note@#1 {{in instantiation of template type alias 'SpirvOpaqueType' requested here}} + +[numthreads(1, 1, 1)] +void main() { + ArrayOfS buffers; +} diff --git a/clang/test/SemaHLSL/inline-spirv/SpirvType.literal.error.hlsl b/clang/test/SemaHLSL/inline-spirv/SpirvType.literal.error.hlsl new file mode 100644 index 0000000000000..331f19a427384 --- /dev/null +++ b/clang/test/SemaHLSL/inline-spirv/SpirvType.literal.error.hlsl @@ -0,0 +1,11 @@ +// RUN: %clang_cc1 -finclude-default-header -x hlsl -triple \ +// RUN: spirv-unknown-vulkan-compute %s -fsyntax-only -verify + +// expected-error@hlsl/hlsl_spirv.h:18 {{the argument to vk::Literal must be a vk::integral_constant}} + +typedef 
vk::SpirvOpaqueType<28, vk::Literal<float>> T; // #1 +// expected-note@#1 {{in instantiation of template type alias 'SpirvOpaqueType' requested here}} + +[numthreads(1, 1, 1)] +void main() { +} diff --git a/clang/tools/libclang/CIndex.cpp b/clang/tools/libclang/CIndex.cpp index 1d9d1027521bc..a7ff7527a3ea0 100644 --- a/clang/tools/libclang/CIndex.cpp +++ b/clang/tools/libclang/CIndex.cpp @@ -1794,6 +1794,11 @@ bool CursorVisitor::VisitHLSLAttributedResourceTypeLoc( return Visit(TL.getWrappedLoc()); } +bool CursorVisitor::VisitHLSLInlineSpirvTypeLoc(HLSLInlineSpirvTypeLoc TL) { + // Nothing to do. + return false; +} + bool CursorVisitor::VisitFunctionTypeLoc(FunctionTypeLoc TL, bool SkipResultType) { if (!SkipResultType && Visit(TL.getReturnLoc())) diff --git a/clang/tools/libclang/CXType.cpp b/clang/tools/libclang/CXType.cpp index ffa942d10669c..586d7edf93343 100644 --- a/clang/tools/libclang/CXType.cpp +++ b/clang/tools/libclang/CXType.cpp @@ -653,6 +653,7 @@ CXString clang_getTypeKindSpelling(enum CXTypeKind K) { TKIND(Attributed); TKIND(BTFTagAttributed); TKIND(HLSLAttributedResource); + TKIND(HLSLInlineSpirv); TKIND(BFloat16); #define IMAGE_TYPE(ImgType, Id, SingletonId, Access, Suffix) TKIND(Id); #include "clang/Basic/OpenCLImageTypes.def" diff --git a/clang/utils/TableGen/ClangBuiltinTemplatesEmitter.cpp b/clang/utils/TableGen/ClangBuiltinTemplatesEmitter.cpp index e2eb65091bc5a..e4e6dbec7b763 100644 --- a/clang/utils/TableGen/ClangBuiltinTemplatesEmitter.cpp +++ b/clang/utils/TableGen/ClangBuiltinTemplatesEmitter.cpp @@ -11,6 +11,7 @@ //===----------------------------------------------------------------------===// #include "TableGenBackends.h" +#include "llvm/ADT/StringSet.h" #include "llvm/TableGen/Error.h" #include "llvm/TableGen/TableGenBackend.h" @@ -21,11 +22,14 @@ using namespace llvm; static std::string TemplateNameList; static std::string CreateBuiltinTemplateParameterList; +static llvm::StringSet BuiltinClasses; + namespace { struct ParserState { 
size_t UniqueCounter = 0; size_t CurrentDepth = 0; bool EmittedSizeTInfo = false; + bool EmittedUint32TInfo = false; }; std::pair<std::string, std::string> @@ -61,7 +65,7 @@ ParseTemplateParameterList(ParserState &PS, auto Type = Arg->getValueAsString("TypeName"); if (!TemplateNameToParmName.contains(Type.str())) - PrintFatalError("Unkown Type Name"); + PrintFatalError("Unknown Type Name"); auto TSIName = "TSI" + std::to_string(PS.UniqueCounter++); Code << " auto *" << TSIName << " = C.getTrivialTypeSourceInfo(QualType(" @@ -73,19 +77,32 @@ ParseTemplateParameterList(ParserState &PS, << TSIName << "->getType(), " << Arg->getValueAsBit("IsVariadic") << ", " << TSIName << ");\n"; } else if (Arg->isSubClassOf("BuiltinNTTP")) { - if (Arg->getValueAsString("TypeName") != "size_t") - PrintFatalError("Unkown Type Name"); - if (!PS.EmittedSizeTInfo) { - Code << "TypeSourceInfo *SizeTInfo = " - "C.getTrivialTypeSourceInfo(C.getSizeType());\n"; - PS.EmittedSizeTInfo = true; + std::string SourceInfo; + if (Arg->getValueAsString("TypeName") == "size_t") { + SourceInfo = "SizeTInfo"; + if (!PS.EmittedSizeTInfo) { + Code << "TypeSourceInfo *SizeTInfo = " + "C.getTrivialTypeSourceInfo(C.getSizeType());\n"; + PS.EmittedSizeTInfo = true; + } + } else if (Arg->getValueAsString("TypeName") == "uint32_t") { + SourceInfo = "Uint32TInfo"; + if (!PS.EmittedUint32TInfo) { + Code << "TypeSourceInfo *Uint32TInfo = " + "C.getTrivialTypeSourceInfo(C.UnsignedIntTy);\n"; + PS.EmittedUint32TInfo = true; + } + } else { + PrintFatalError("Unknown Type Name"); } Code << " auto *" << ParmName << " = NonTypeTemplateParmDecl::Create(C, DC, SourceLocation(), " "SourceLocation(), " - << PS.CurrentDepth << ", " << Position++ - << ", /*Id=*/nullptr, SizeTInfo->getType(), " - "/*ParameterPack=*/false, SizeTInfo);\n"; + << PS.CurrentDepth << ", " << Position++ << ", /*Id=*/nullptr, " + << SourceInfo + << "->getType(), " + "/*ParameterPack=*/false, " + << SourceInfo << ");\n"; } else { 
PrintFatalError("Unknown Argument Type"); } @@ -132,7 +149,8 @@ EmitCreateBuiltinTemplateParameterList(std::vector<const Record *> TemplateArgs, CreateBuiltinTemplateParameterList += " }\n"; } -void EmitBuiltinTemplate(raw_ostream &OS, const Record *BuiltinTemplate) { +void EmitBuiltinTemplate(const Record *BuiltinTemplate) { + auto Class = BuiltinTemplate->getType()->getAsString(); auto Name = BuiltinTemplate->getName(); std::vector<const Record *> TemplateHead = @@ -140,21 +158,49 @@ void EmitBuiltinTemplate(raw_ostream &OS, const Record *BuiltinTemplate) { EmitCreateBuiltinTemplateParameterList(TemplateHead, Name); - TemplateNameList += "BuiltinTemplate("; + TemplateNameList += Class + "("; TemplateNameList += Name; TemplateNameList += ")\n"; + + BuiltinClasses.insert(Class); +} + +void EmitDefaultDefine(llvm::raw_ostream &OS, StringRef Name) { + OS << "#ifndef " << Name << "\n"; + OS << "#define " << Name << "(NAME)" << " " << "BuiltinTemplate" + << "(NAME)\n"; + OS << "#endif\n\n"; +} + +void EmitUndef(llvm::raw_ostream &OS, StringRef Name) { + OS << "#undef " << Name << "\n"; } } // namespace void clang::EmitClangBuiltinTemplates(const llvm::RecordKeeper &Records, llvm::raw_ostream &OS) { emitSourceFileHeader("Tables and code for Clang's builtin templates", OS); + for (const auto *Builtin : Records.getAllDerivedDefinitions("BuiltinTemplate")) - EmitBuiltinTemplate(OS, Builtin); + EmitBuiltinTemplate(Builtin); + + for (const auto &ClassEntry : BuiltinClasses) { + StringRef Class = ClassEntry.getKey(); + if (Class == "BuiltinTemplate") + continue; + EmitDefaultDefine(OS, Class); + } OS << "#if defined(CREATE_BUILTIN_TEMPLATE_PARAMETER_LIST)\n" << CreateBuiltinTemplateParameterList << "#undef CREATE_BUILTIN_TEMPLATE_PARAMETER_LIST\n#else\n" << TemplateNameList << "#undef BuiltinTemplate\n#endif\n"; + + for (const auto &ClassEntry : BuiltinClasses) { + StringRef Class = ClassEntry.getKey(); + if (Class == "BuiltinTemplate") + continue; + EmitUndef(OS, Class); + } 
} _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits