llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang Author: Farzon Lotfi (farzonl) <details> <summary>Changes</summary> fixes #<!-- -->159434 In HLSL, matrices are `matrix_type` in all respects except that they support a constructor style syntax for initializing matrices. This change adds a translation of vector constructor arguments into initializer lists. This supports the following HLSL syntax: (1) HLSL matrices support constructor syntax (2) HLSL matrices are expanded to constituent components in constructor --- Patch is 47.17 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/160960.diff 6 Files Affected: - (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+4-4) - (modified) clang/include/clang/Sema/Initialization.h (+8-4) - (modified) clang/lib/Sema/CheckExprLifetime.cpp (+1) - (modified) clang/lib/Sema/SemaInit.cpp (+186-13) - (added) clang/test/AST/HLSL/matrix-constructors.hlsl (+338) - (added) clang/test/SemaHLSL/BuiltIns/matrix-constructors-errors.hlsl (+24) ``````````diff diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index bd0e53d3086b0..19e4c548d0208 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -6543,9 +6543,9 @@ def warn_extern_init : Warning<"'extern' variable has an initializer">, def err_variable_object_no_init : Error< "variable-sized object may not be initialized">; def err_excess_initializers : Error< - "excess elements in %select{array|vector|scalar|union|struct}0 initializer">; + "excess elements in %select{array|vector|matrix|scalar|union|struct}0 initializer">; def ext_excess_initializers : ExtWarn< - "excess elements in %select{array|vector|scalar|union|struct}0 initializer">, + "excess elements in %select{array|vector|matrix|scalar|union|struct}0 initializer">, InGroup<ExcessInitializers>; def err_excess_initializers_for_sizeless_type : Error< 
"excess elements in initializer for indivisible sizeless type %0">; @@ -11086,8 +11086,8 @@ def err_first_argument_to_cwsc_pdtor_call : Error< def err_second_argument_to_cwsc_not_pointer : Error< "second argument to __builtin_call_with_static_chain must be of pointer type">; -def err_vector_incorrect_num_elements : Error< - "%select{too many|too few}0 elements in vector %select{initialization|operand}3 (expected %1 elements, have %2)">; +def err_tensor_incorrect_num_elements : Error< + "%select{too many|too few}0 elements in %select{vector|matrix}1 %select{initialization|operand}4 (expected %2 elements, have %3)">; def err_altivec_empty_initializer : Error<"expected initializer">; def err_vector_incorrect_bit_count : Error< diff --git a/clang/include/clang/Sema/Initialization.h b/clang/include/clang/Sema/Initialization.h index d7675ea153afb..865b6428f3081 100644 --- a/clang/include/clang/Sema/Initialization.h +++ b/clang/include/clang/Sema/Initialization.h @@ -91,6 +91,10 @@ class alignas(8) InitializedEntity { /// or vector. EK_VectorElement, + /// The entity being initialized is an element of a matrix. + /// or matrix. + EK_MatrixElement, + /// The entity being initialized is a field of block descriptor for /// the copied-in c++ object. EK_BlockElement, @@ -205,8 +209,8 @@ class alignas(8) InitializedEntity { /// virtual base. llvm::PointerIntPair<const CXXBaseSpecifier *, 1> Base; - /// When Kind == EK_ArrayElement, EK_VectorElement, or - /// EK_ComplexElement, the index of the array or vector element being + /// When Kind == EK_ArrayElement, EK_VectorElement, or EK_MatrixElement, + /// or EK_ComplexElement, the index of the array or vector element being /// initialized. unsigned Index; @@ -536,7 +540,7 @@ class alignas(8) InitializedEntity { /// element's index. 
unsigned getElementIndex() const { assert(getKind() == EK_ArrayElement || getKind() == EK_VectorElement || - getKind() == EK_ComplexElement); + getKind() == EK_MatrixElement || getKind() == EK_ComplexElement); return Index; } @@ -544,7 +548,7 @@ class alignas(8) InitializedEntity { /// element, sets the element index. void setElementIndex(unsigned Index) { assert(getKind() == EK_ArrayElement || getKind() == EK_VectorElement || - getKind() == EK_ComplexElement); + getKind() == EK_MatrixElement || getKind() == EK_ComplexElement); this->Index = Index; } diff --git a/clang/lib/Sema/CheckExprLifetime.cpp b/clang/lib/Sema/CheckExprLifetime.cpp index e02e00231e58e..d647fbf007838 100644 --- a/clang/lib/Sema/CheckExprLifetime.cpp +++ b/clang/lib/Sema/CheckExprLifetime.cpp @@ -154,6 +154,7 @@ getEntityLifetime(const InitializedEntity *Entity, case InitializedEntity::EK_LambdaToBlockConversionBlockElement: case InitializedEntity::EK_LambdaCapture: case InitializedEntity::EK_VectorElement: + case clang::InitializedEntity::EK_MatrixElement: case InitializedEntity::EK_ComplexElement: return {nullptr, LK_FullExpression}; diff --git a/clang/lib/Sema/SemaInit.cpp b/clang/lib/Sema/SemaInit.cpp index c97129336736b..0f06d715600b3 100644 --- a/clang/lib/Sema/SemaInit.cpp +++ b/clang/lib/Sema/SemaInit.cpp @@ -17,6 +17,7 @@ #include "clang/AST/ExprCXX.h" #include "clang/AST/ExprObjC.h" #include "clang/AST/IgnoreExpr.h" +#include "clang/AST/TypeBase.h" #include "clang/AST/TypeLoc.h" #include "clang/Basic/SourceManager.h" #include "clang/Basic/Specifiers.h" @@ -403,6 +404,9 @@ class InitListChecker { unsigned &Index, InitListExpr *StructuredList, unsigned &StructuredIndex); + void CheckMatrixType(const InitializedEntity &Entity, InitListExpr *IList, + QualType DeclType, unsigned &Index, + InitListExpr *StructuredList, unsigned &StructuredIndex); void CheckVectorType(const InitializedEntity &Entity, InitListExpr *IList, QualType DeclType, unsigned &Index, InitListExpr *StructuredList, @@ 
-1003,7 +1007,8 @@ InitListChecker::FillInEmptyInitializations(const InitializedEntity &Entity, return; if (ElementEntity.getKind() == InitializedEntity::EK_ArrayElement || - ElementEntity.getKind() == InitializedEntity::EK_VectorElement) + ElementEntity.getKind() == InitializedEntity::EK_VectorElement || + ElementEntity.getKind() == InitializedEntity::EK_MatrixElement) ElementEntity.setElementIndex(Init); if (Init >= NumInits && (ILE->hasArrayFiller() || SkipEmptyInitChecks)) @@ -1273,6 +1278,7 @@ static void warnBracedScalarInit(Sema &S, const InitializedEntity &Entity, switch (Entity.getKind()) { case InitializedEntity::EK_VectorElement: + case InitializedEntity::EK_MatrixElement: case InitializedEntity::EK_ComplexElement: case InitializedEntity::EK_ArrayElement: case InitializedEntity::EK_Parameter: @@ -1372,11 +1378,12 @@ void InitListChecker::CheckExplicitInitList(const InitializedEntity &Entity, SemaRef.Diag(IList->getInit(Index)->getBeginLoc(), DK) << T << IList->getInit(Index)->getSourceRange(); } else { - int initKind = T->isArrayType() ? 0 : - T->isVectorType() ? 1 : - T->isScalarType() ? 2 : - T->isUnionType() ? 3 : - 4; + int initKind = T->isArrayType() ? 0 + : T->isVectorType() ? 1 + : T->isMatrixType() ? 2 + : T->isScalarType() ? 3 + : T->isUnionType() ? 4 + : 5; unsigned DK = ExtraInitsIsError ? 
diag::err_excess_initializers : diag::ext_excess_initializers; @@ -1430,6 +1437,9 @@ void InitListChecker::CheckListElementTypes(const InitializedEntity &Entity, } else if (DeclType->isVectorType()) { CheckVectorType(Entity, IList, DeclType, Index, StructuredList, StructuredIndex); + } else if (DeclType->isMatrixType()) { + CheckMatrixType(Entity, IList, DeclType, Index, StructuredList, + StructuredIndex); } else if (const RecordDecl *RD = DeclType->getAsRecordDecl()) { auto Bases = CXXRecordDecl::base_class_const_range(CXXRecordDecl::base_class_const_iterator(), @@ -1877,6 +1887,93 @@ void InitListChecker::CheckReferenceType(const InitializedEntity &Entity, AggrDeductionCandidateParamTypes->push_back(DeclType); } +void InitListChecker::CheckMatrixType(const InitializedEntity &Entity, + InitListExpr *IList, QualType DeclType, + unsigned &Index, + InitListExpr *StructuredList, + unsigned &StructuredIndex) { + if (!SemaRef.getLangOpts().HLSL) + return; + + const ConstantMatrixType *MT = DeclType->castAs<ConstantMatrixType>(); + QualType ElemTy = MT->getElementType(); + const unsigned Rows = MT->getNumRows(); + const unsigned Cols = MT->getNumColumns(); + const unsigned MaxElts = Rows * Cols; + + unsigned NumEltsInit = 0; + InitializedEntity ElemEnt = + InitializedEntity::InitializeElement(SemaRef.Context, 0, Entity); + + // A Matrix initalizer should be able to take scalars, vectors, and matrices. + auto HandleInit = [&](InitListExpr *List, unsigned &Idx) { + Expr *Init = List->getInit(Idx); + QualType ITy = Init->getType(); + + if (ITy->isVectorType()) { + const VectorType *IVT = ITy->castAs<VectorType>(); + unsigned N = IVT->getNumElements(); + QualType VTy = + ITy->isExtVectorType() + ? 
SemaRef.Context.getExtVectorType(ElemTy, N) + : SemaRef.Context.getVectorType(ElemTy, N, IVT->getVectorKind()); + ElemEnt.setElementIndex(Idx); + CheckSubElementType(ElemEnt, List, VTy, Idx, StructuredList, + StructuredIndex); + NumEltsInit += N; + return; + } + + if (ITy->isMatrixType()) { + const ConstantMatrixType *IMT = ITy->castAs<ConstantMatrixType>(); + unsigned N = IMT->getNumRows() * IMT->getNumColumns(); + QualType MTy = SemaRef.Context.getConstantMatrixType( + ElemTy, IMT->getNumRows(), IMT->getNumColumns()); + ElemEnt.setElementIndex(Idx); + CheckSubElementType(ElemEnt, List, MTy, Idx, StructuredList, + StructuredIndex); + NumEltsInit += N; + return; + } + + // Scalar element + ElemEnt.setElementIndex(Idx); + CheckSubElementType(ElemEnt, List, ElemTy, Idx, StructuredList, + StructuredIndex); + ++NumEltsInit; + }; + + // Column-major: each top-level sublist is treated as a column. + while (NumEltsInit < MaxElts && Index < IList->getNumInits()) { + Expr *Init = IList->getInit(Index); + + if (auto *SubList = dyn_cast<InitListExpr>(Init)) { + unsigned SubIdx = 0; + unsigned Row = 0; + while (Row < Rows && SubIdx < SubList->getNumInits() && + NumEltsInit < MaxElts) { + HandleInit(SubList, SubIdx); + ++Row; + } + ++Index; // advance past this column sublist + continue; + } + + // Not a sublist: just consume directly. + HandleInit(IList, Index); + } + + // HLSL requires exactly Rows*Cols initializers after flattening. 
+ if (NumEltsInit != MaxElts) { + if (!VerifyOnly) + SemaRef.Diag(IList->getBeginLoc(), + diag::err_tensor_incorrect_num_elements) + << (NumEltsInit < MaxElts) << /*matrix*/ 1 << MaxElts << NumEltsInit + << /*initialization*/ 0; + hadError = true; + } +} + void InitListChecker::CheckVectorType(const InitializedEntity &Entity, InitListExpr *IList, QualType DeclType, unsigned &Index, @@ -2026,9 +2123,9 @@ void InitListChecker::CheckVectorType(const InitializedEntity &Entity, if (numEltsInit != maxElements) { if (!VerifyOnly) SemaRef.Diag(IList->getBeginLoc(), - diag::err_vector_incorrect_num_elements) - << (numEltsInit < maxElements) << maxElements << numEltsInit - << /*initialization*/ 0; + diag::err_tensor_incorrect_num_elements) + << (numEltsInit < maxElements) << /*vector*/ 0 << maxElements + << numEltsInit << /*initialization*/ 0; hadError = true; } } @@ -3639,6 +3736,9 @@ InitializedEntity::InitializedEntity(ASTContext &Context, unsigned Index, } else if (const VectorType *VT = Parent.getType()->getAs<VectorType>()) { Kind = EK_VectorElement; Type = VT->getElementType(); + } else if (const MatrixType *MT = Parent.getType()->getAs<MatrixType>()) { + Kind = EK_MatrixElement; + Type = MT->getElementType(); } else { const ComplexType *CT = Parent.getType()->getAs<ComplexType>(); assert(CT && "Unexpected type"); @@ -3687,6 +3787,7 @@ DeclarationName InitializedEntity::getName() const { case EK_Delegating: case EK_ArrayElement: case EK_VectorElement: + case EK_MatrixElement: case EK_ComplexElement: case EK_BlockElement: case EK_LambdaToBlockConversionBlockElement: @@ -3720,6 +3821,7 @@ ValueDecl *InitializedEntity::getDecl() const { case EK_Delegating: case EK_ArrayElement: case EK_VectorElement: + case EK_MatrixElement: case EK_ComplexElement: case EK_BlockElement: case EK_LambdaToBlockConversionBlockElement: @@ -3753,6 +3855,7 @@ bool InitializedEntity::allowsNRVO() const { case EK_Delegating: case EK_ArrayElement: case EK_VectorElement: + case EK_MatrixElement: 
case EK_ComplexElement: case EK_BlockElement: case EK_LambdaToBlockConversionBlockElement: @@ -3792,6 +3895,9 @@ unsigned InitializedEntity::dumpImpl(raw_ostream &OS) const { case EK_Delegating: OS << "Delegating"; break; case EK_ArrayElement: OS << "ArrayElement " << Index; break; case EK_VectorElement: OS << "VectorElement " << Index; break; + case EK_MatrixElement: + OS << "MatrixElement " << Index; + break; case EK_ComplexElement: OS << "ComplexElement " << Index; break; case EK_BlockElement: OS << "Block"; break; case EK_LambdaToBlockConversionBlockElement: @@ -6847,6 +6953,67 @@ void InitializationSequence::InitializeFrom(Sema &S, return; } + if (S.getLangOpts().HLSL && DestType->isMatrixType() && + (SourceType.isNull() || + !Context.hasSameUnqualifiedType(SourceType, DestType))) { + + llvm::SmallVector<Expr *> InitArgs; + + for (Expr *Arg : Args) { + QualType AT = Arg->getType(); + + // Expand matrix arguments element-by-element (col-major). + if (AT->isMatrixType()) { + const auto *MTy = AT->castAs<ConstantMatrixType>(); + unsigned Rows = MTy->getNumRows(); + unsigned Cols = MTy->getNumColumns(); + QualType ElemTy = MTy->getElementType(); + + for (unsigned c = 0; c < Cols; ++c) { + for (unsigned r = 0; r < Rows; ++r) { + // row index literal + Expr *RowIdx = IntegerLiteral::Create( + Context, llvm::APInt(Context.getIntWidth(Context.IntTy), r), + Context.IntTy, Arg->getBeginLoc()); + // column index literal + Expr *ColIdx = IntegerLiteral::Create( + Context, llvm::APInt(Context.getIntWidth(Context.IntTy), c), + Context.IntTy, Arg->getBeginLoc()); + + InitArgs.emplace_back(new (Context) MatrixSubscriptExpr( + Arg, RowIdx, ColIdx, ElemTy, Arg->getEndLoc())); + } + } + + // Keep your vector expansion, in case vectors appear in the argument + // list. 
+ } else if (AT->isExtVectorType()) { + const auto *VTy = AT->castAs<ExtVectorType>(); + unsigned Elm = VTy->getNumElements(); + for (unsigned Idx = 0; Idx < Elm; ++Idx) { + InitArgs.emplace_back(new (Context) ArraySubscriptExpr( + Arg, + IntegerLiteral::Create( + Context, llvm::APInt(Context.getIntWidth(Context.IntTy), Idx), + Context.IntTy, Arg->getBeginLoc()), + VTy->getElementType(), Arg->getValueKind(), Arg->getObjectKind(), + Arg->getEndLoc())); + } + + } else { + // Scalar or other: forward as-is + InitArgs.emplace_back(Arg); + } + } + + InitListExpr *ILE = new (Context) InitListExpr( + S.getASTContext(), SourceLocation(), InitArgs, SourceLocation()); + + Args[0] = ILE; + AddListInitializationStep(DestType); + return; + } + // The remaining cases all need a source type. if (Args.size() > 1) { SetFailed(FK_TooManyInitsForScalar); @@ -6999,6 +7166,7 @@ static AssignmentAction getAssignmentAction(const InitializedEntity &Entity, case InitializedEntity::EK_Binding: case InitializedEntity::EK_ArrayElement: case InitializedEntity::EK_VectorElement: + case InitializedEntity::EK_MatrixElement: case InitializedEntity::EK_ComplexElement: case InitializedEntity::EK_BlockElement: case InitializedEntity::EK_LambdaToBlockConversionBlockElement: @@ -7024,6 +7192,7 @@ static bool shouldBindAsTemporary(const InitializedEntity &Entity) { case InitializedEntity::EK_Base: case InitializedEntity::EK_Delegating: case InitializedEntity::EK_VectorElement: + case InitializedEntity::EK_MatrixElement: case InitializedEntity::EK_ComplexElement: case InitializedEntity::EK_Exception: case InitializedEntity::EK_BlockElement: @@ -7054,6 +7223,7 @@ static bool shouldDestroyEntity(const InitializedEntity &Entity) { case InitializedEntity::EK_Base: case InitializedEntity::EK_Delegating: case InitializedEntity::EK_VectorElement: + case InitializedEntity::EK_MatrixElement: case InitializedEntity::EK_ComplexElement: case InitializedEntity::EK_BlockElement: case 
InitializedEntity::EK_LambdaToBlockConversionBlockElement: @@ -7107,6 +7277,7 @@ static SourceLocation getInitializationLoc(const InitializedEntity &Entity, case InitializedEntity::EK_Base: case InitializedEntity::EK_Delegating: case InitializedEntity::EK_VectorElement: + case InitializedEntity::EK_MatrixElement: case InitializedEntity::EK_ComplexElement: case InitializedEntity::EK_BlockElement: case InitializedEntity::EK_LambdaToBlockConversionBlockElement: @@ -7858,9 +8029,11 @@ ExprResult InitializationSequence::Perform(Sema &S, // HLSL allows vector initialization to function like list initialization, but // use the syntax of a C++-like constructor. - bool IsHLSLVectorInit = S.getLangOpts().HLSL && DestType->isExtVectorType() && - isa<InitListExpr>(Args[0]); - (void)IsHLSLVectorInit; + bool IsHLSLVectorOrMatrixInit = + S.getLangOpts().HLSL && + (DestType->isExtVectorType() || DestType->isMatrixType()) && + isa<InitListExpr>(Args[0]); + (void)IsHLSLVectorOrMatrixInit; // For initialization steps that start with a single initializer, // grab the only argument out the Args and place it into the "current" @@ -7899,7 +8072,7 @@ ExprResult InitializationSequence::Perform(Sema &S, case SK_StdInitializerList: case SK_OCLSamplerInit: case SK_OCLZeroOpaqueType: { - assert(Args.size() == 1 || IsHLSLVectorInit); + assert(Args.size() == 1 || IsHLSLVectorOrMatrixInit); CurInit = Args[0]; if (!CurInit.get()) return ExprError(); break; diff --git a/clang/test/AST/HLSL/matrix-constructors.hlsl b/clang/test/AST/HLSL/matrix-constructors.hlsl new file mode 100644 index 0000000000000..faee0162a314b --- /dev/null +++ b/clang/test/AST/HLSL/matrix-constructors.hlsl @@ -0,0 +1,338 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -x hlsl -ast-dump -o - %s | FileCheck %s + +typedef float float2x1 __attribute__((matrix_type(2,1))); +typedef float float2x3 __attribute__((matrix_type(2,3))); +typedef float float2x2 __attribute__((matrix_type(2,2))); +typedef float float2 
__attribute__((ext_vector_type(2))); +typedef float float4 __attribute__((ext_vector_type(4))); + +[numthreads(1,1,1)] +void ok() { + + // CHECK: VarDecl 0x{{[0-9a-fA-F]+}} <col:3, col:36> col:12 A 'float2x3':'matrix<float, 2, 3>' cinit + // CHECK-NEXT: CXXFunctionalCastExpr 0x{{[0-9a-fA-F]+}} <col:16, col:36> 'float2x3':'matrix<float, 2, 3>' functional cast to float2x3 <NoOp> + // CHECK-NEXT: InitListExpr 0x{{[0-9a-fA-F]+}} <col:25, col:35> 'float2x3':'matrix<float, 2, 3>' + // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:25> 'float' <IntegralToFloating> + // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:25> 'int' 1 + // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:27> 'float' <IntegralToFloating> + // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:27> 'int' 2 + // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:29> 'float' <IntegralToFloating> + // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:29> 'int' 3 + // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:31> 'float' <IntegralToFloating> + // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:31> 'int' 4 + // CHECK-NEXT: ImplicitCastExpr 0x{{[0-9a-fA-F]+}} <col:33> 'float' <IntegralToFloating> + // CHECK-NEXT: IntegerLiteral 0x{{[0-9a-fA-F]+}} <col:33> ... [truncated] `````````` </details> https://github.com/llvm/llvm-project/pull/160960 _______________________________________________ cfe-commits mailing list [email protected] https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
