https://github.com/jurahul updated https://github.com/llvm/llvm-project/pull/144930
>From f5216d4c55c4dffa8785ff2fa051492ed98f405a Mon Sep 17 00:00:00 2001 From: Rahul Joshi <rjo...@nvidia.com> Date: Thu, 19 Jun 2025 10:25:12 -0700 Subject: [PATCH] [LLVM][Clang] Add and enable strict mode for `getTrailingObjects` Under strict mode, the templated `getTrailingObjects` can be called only when there is more than one trailing type. The strict mode can be disabled on a per-call basis when it's not possible to know statically if there will be a single or multiple trailing types (like in OpenMPClause.h). --- clang/include/clang/AST/OpenMPClause.h | 46 ++++++++++++------- clang/lib/AST/Expr.cpp | 3 +- llvm/include/llvm/Support/TrailingObjects.h | 21 +++++---- .../unittests/Support/TrailingObjectsTest.cpp | 5 +- 4 files changed, 47 insertions(+), 28 deletions(-) diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h index 2fa8fa529741e..b62ebd614e4c7 100644 --- a/clang/include/clang/AST/OpenMPClause.h +++ b/clang/include/clang/AST/OpenMPClause.h @@ -295,7 +295,8 @@ template <class T> class OMPVarListClause : public OMPClause { /// Fetches list of variables associated with this clause. MutableArrayRef<Expr *> getVarRefs() { - return static_cast<T *>(this)->template getTrailingObjects<Expr *>(NumVars); + return static_cast<T *>(this) + ->template getTrailingObjects<Expr *, /*Strict=*/false>(NumVars); } /// Sets the list of variables for this clause. @@ -334,8 +335,8 @@ template <class T> class OMPVarListClause : public OMPClause { /// Fetches list of all variables in the clause. 
ArrayRef<const Expr *> getVarRefs() const { - return static_cast<const T *>(this)->template getTrailingObjects<Expr *>( - NumVars); + return static_cast<const T *>(this) + ->template getTrailingObjects<Expr *, /*Strict=*/false>(NumVars); } }; @@ -380,7 +381,8 @@ template <class T> class OMPDirectiveListClause : public OMPClause { MutableArrayRef<OpenMPDirectiveKind> getDirectiveKinds() { return static_cast<T *>(this) - ->template getTrailingObjects<OpenMPDirectiveKind>(NumKinds); + ->template getTrailingObjects<OpenMPDirectiveKind, /*Strict=*/false>( + NumKinds); } void setDirectiveKinds(ArrayRef<OpenMPDirectiveKind> DK) { @@ -5901,15 +5903,17 @@ class OMPMappableExprListClause : public OMPVarListClause<T>, /// Get the unique declarations that are in the trailing objects of the /// class. MutableArrayRef<ValueDecl *> getUniqueDeclsRef() { - return static_cast<T *>(this)->template getTrailingObjects<ValueDecl *>( - NumUniqueDeclarations); + return static_cast<T *>(this) + ->template getTrailingObjects<ValueDecl *, /*Strict=*/false>( + NumUniqueDeclarations); } /// Get the unique declarations that are in the trailing objects of the /// class. ArrayRef<ValueDecl *> getUniqueDeclsRef() const { return static_cast<const T *>(this) - ->template getTrailingObjects<ValueDecl *>(NumUniqueDeclarations); + ->template getTrailingObjects<ValueDecl *, /*Strict=*/false>( + NumUniqueDeclarations); } /// Set the unique declarations that are in the trailing objects of the @@ -5923,15 +5927,17 @@ class OMPMappableExprListClause : public OMPVarListClause<T>, /// Get the number of lists per declaration that are in the trailing /// objects of the class. 
MutableArrayRef<unsigned> getDeclNumListsRef() { - return static_cast<T *>(this)->template getTrailingObjects<unsigned>( - NumUniqueDeclarations); + return static_cast<T *>(this) + ->template getTrailingObjects<unsigned, /*Strict=*/false>( + NumUniqueDeclarations); } /// Get the number of lists per declaration that are in the trailing /// objects of the class. ArrayRef<unsigned> getDeclNumListsRef() const { - return static_cast<const T *>(this)->template getTrailingObjects<unsigned>( - NumUniqueDeclarations); + return static_cast<const T *>(this) + ->template getTrailingObjects<unsigned, /*Strict=*/false>( + NumUniqueDeclarations); } /// Set the number of lists per declaration that are in the trailing @@ -5946,7 +5952,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>, /// objects of the class. They are appended after the number of lists. MutableArrayRef<unsigned> getComponentListSizesRef() { return MutableArrayRef<unsigned>( - static_cast<T *>(this)->template getTrailingObjects<unsigned>() + + static_cast<T *>(this) + ->template getTrailingObjects<unsigned, /*Strict=*/false>() + NumUniqueDeclarations, NumComponentLists); } @@ -5955,7 +5962,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>, /// objects of the class. They are appended after the number of lists. ArrayRef<unsigned> getComponentListSizesRef() const { return ArrayRef<unsigned>( - static_cast<const T *>(this)->template getTrailingObjects<unsigned>() + + static_cast<const T *>(this) + ->template getTrailingObjects<unsigned, /*Strict=*/false>() + NumUniqueDeclarations, NumComponentLists); } @@ -5971,13 +5979,15 @@ class OMPMappableExprListClause : public OMPVarListClause<T>, /// Get the components that are in the trailing objects of the class. 
MutableArrayRef<MappableComponent> getComponentsRef() { return static_cast<T *>(this) - ->template getTrailingObjects<MappableComponent>(NumComponents); + ->template getTrailingObjects<MappableComponent, /*Strict=*/false>( + NumComponents); } /// Get the components that are in the trailing objects of the class. ArrayRef<MappableComponent> getComponentsRef() const { return static_cast<const T *>(this) - ->template getTrailingObjects<MappableComponent>(NumComponents); + ->template getTrailingObjects<MappableComponent, /*Strict=*/false>( + NumComponents); } /// Set the components that are in the trailing objects of the class. @@ -6084,7 +6094,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>, assert(SupportsMapper && "Must be a clause that is possible to have user-defined mappers"); return llvm::MutableArrayRef<Expr *>( - static_cast<T *>(this)->template getTrailingObjects<Expr *>() + + static_cast<T *>(this) + ->template getTrailingObjects<Expr *, /*Strict=*/false>() + OMPVarListClause<T>::varlist_size(), OMPVarListClause<T>::varlist_size()); } @@ -6095,7 +6106,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>, assert(SupportsMapper && "Must be a clause that is possible to have user-defined mappers"); return llvm::ArrayRef<Expr *>( - static_cast<const T *>(this)->template getTrailingObjects<Expr *>() + + static_cast<const T *>(this) + ->template getTrailingObjects<Expr *, /*Strict=*/false>() + OMPVarListClause<T>::varlist_size(), OMPVarListClause<T>::varlist_size()); } diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp index c3722c65abf6e..b93a31ca4ed36 100644 --- a/clang/lib/AST/Expr.cpp +++ b/clang/lib/AST/Expr.cpp @@ -2024,7 +2024,8 @@ CXXBaseSpecifier **CastExpr::path_buffer() { #define ABSTRACT_STMT(x) #define CASTEXPR(Type, Base) \ case Stmt::Type##Class: \ - return static_cast<Type *>(this)->getTrailingObjects<CXXBaseSpecifier *>(); + return static_cast<Type *>(this) \ + ->getTrailingObjects<CXXBaseSpecifier *, 
/*Strict=*/false>(); #define STMT(Type, Base) #include "clang/AST/StmtNodes.inc" default: diff --git a/llvm/include/llvm/Support/TrailingObjects.h b/llvm/include/llvm/Support/TrailingObjects.h index f25f2311a81a4..3d701de93b4f1 100644 --- a/llvm/include/llvm/Support/TrailingObjects.h +++ b/llvm/include/llvm/Support/TrailingObjects.h @@ -282,7 +282,9 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl< /// Returns a pointer to the trailing object array of the given type /// (which must be one of those specified in the class template). The /// array may have zero or more elements in it. - template <typename T> const T *getTrailingObjects() const { + template <typename T, bool Strict = true> + const T *getTrailingObjects() const { + static_assert(!Strict || sizeof...(TrailingTys) > 1); verifyTrailingObjectsAssertions(); // Forwards to an impl function with overloads, since member // function templates can't be specialized. @@ -294,7 +296,8 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl< /// Returns a pointer to the trailing object array of the given type /// (which must be one of those specified in the class template). The /// array may have zero or more elements in it. - template <typename T> T *getTrailingObjects() { + template <typename T, bool Strict = true> T *getTrailingObjects() { + static_assert(!Strict || sizeof...(TrailingTys) > 1); verifyTrailingObjectsAssertions(); // Forwards to an impl function with overloads, since member // function templates can't be specialized. 
@@ -310,23 +313,25 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl< static_assert(sizeof...(TrailingTys) == 1, "Can use non-templated getTrailingObjects() only when there " "is a single trailing type"); - return getTrailingObjects<FirstTrailingType>(); + return getTrailingObjects<FirstTrailingType, /*Strict=*/false>(); } FirstTrailingType *getTrailingObjects() { static_assert(sizeof...(TrailingTys) == 1, "Can use non-templated getTrailingObjects() only when there " "is a single trailing type"); - return getTrailingObjects<FirstTrailingType>(); + return getTrailingObjects<FirstTrailingType, /*Strict=*/false>(); } // Functions that return the trailing objects as ArrayRefs. - template <typename T> MutableArrayRef<T> getTrailingObjects(size_t N) { - return MutableArrayRef(getTrailingObjects<T>(), N); + template <typename T, bool Strict = true> + MutableArrayRef<T> getTrailingObjects(size_t N) { + return MutableArrayRef(getTrailingObjects<T, Strict>(), N); } - template <typename T> ArrayRef<T> getTrailingObjects(size_t N) const { - return ArrayRef(getTrailingObjects<T>(), N); + template <typename T, bool Strict = true> + ArrayRef<T> getTrailingObjects(size_t N) const { + return ArrayRef(getTrailingObjects<T, Strict>(), N); } MutableArrayRef<FirstTrailingType> getTrailingObjects(size_t N) { diff --git a/llvm/unittests/Support/TrailingObjectsTest.cpp b/llvm/unittests/Support/TrailingObjectsTest.cpp index 2590f375b6598..83afb22f837aa 100644 --- a/llvm/unittests/Support/TrailingObjectsTest.cpp +++ b/llvm/unittests/Support/TrailingObjectsTest.cpp @@ -123,11 +123,12 @@ TEST(TrailingObjects, OneArg) { EXPECT_EQ(Class1::totalSizeToAlloc<short>(3), sizeof(Class1) + sizeof(short) * 3); - EXPECT_EQ(C->getTrailingObjects<short>(), reinterpret_cast<short *>(C + 1)); + EXPECT_EQ(C->getTrailingObjects(), reinterpret_cast<short *>(C + 1)); EXPECT_EQ(C->get(0), 1); EXPECT_EQ(C->get(2), 3); - EXPECT_EQ(C->getTrailingObjects(), 
C->getTrailingObjects<short>()); + EXPECT_EQ(C->getTrailingObjects(), + (C->getTrailingObjects<short, /*Strict=*/false>())); delete C; } _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits