https://github.com/jurahul updated https://github.com/llvm/llvm-project/pull/144930
>From ee982b8b2d14b1199f051db53aea4f26899d4d77 Mon Sep 17 00:00:00 2001 From: Rahul Joshi <rjo...@nvidia.com> Date: Thu, 19 Jun 2025 10:25:12 -0700 Subject: [PATCH] [LLVM][Clang] Add and enable strict mode for `getTrailingObjects` Under strict mode, the templated `getTrailingObjects` can be called only when there is more than one trailing type. The strict mode can be disabled on a per-call basis when it's not possible to know statically whether there will be a single or multiple trailing types (as in OpenMPClause.h). --- clang/include/clang/AST/OpenMPClause.h | 37 ++++++++----- clang/lib/AST/Expr.cpp | 3 +- llvm/include/llvm/Support/TrailingObjects.h | 55 ++++++++++++++++--- .../unittests/Support/TrailingObjectsTest.cpp | 7 ++- 4 files changed, 74 insertions(+), 28 deletions(-) diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h index c6f99fb21a0f0..5b2206af75bee 100644 --- a/clang/include/clang/AST/OpenMPClause.h +++ b/clang/include/clang/AST/OpenMPClause.h @@ -295,7 +295,8 @@ template <class T> class OMPVarListClause : public OMPClause { /// Fetches list of variables associated with this clause. MutableArrayRef<Expr *> getVarRefs() { - return static_cast<T *>(this)->template getTrailingObjects<Expr *>(NumVars); + return static_cast<T *>(this)->template getTrailingObjectsNonStrict<Expr *>( + NumVars); } /// Sets the list of variables for this clause. @@ -334,8 +335,8 @@ template <class T> class OMPVarListClause : public OMPClause { /// Fetches list of all variables in the clause.
ArrayRef<const Expr *> getVarRefs() const { - return static_cast<const T *>(this)->template getTrailingObjects<Expr *>( - NumVars); + return static_cast<const T *>(this) + ->template getTrailingObjectsNonStrict<Expr *>(NumVars); } }; @@ -380,7 +381,7 @@ template <class T> class OMPDirectiveListClause : public OMPClause { MutableArrayRef<OpenMPDirectiveKind> getDirectiveKinds() { return static_cast<T *>(this) - ->template getTrailingObjects<OpenMPDirectiveKind>(NumKinds); + ->template getTrailingObjectsNonStrict<OpenMPDirectiveKind>(NumKinds); } void setDirectiveKinds(ArrayRef<OpenMPDirectiveKind> DK) { @@ -5921,15 +5922,17 @@ class OMPMappableExprListClause : public OMPVarListClause<T>, /// Get the unique declarations that are in the trailing objects of the /// class. MutableArrayRef<ValueDecl *> getUniqueDeclsRef() { - return static_cast<T *>(this)->template getTrailingObjects<ValueDecl *>( - NumUniqueDeclarations); + return static_cast<T *>(this) + ->template getTrailingObjectsNonStrict<ValueDecl *>( + NumUniqueDeclarations); } /// Get the unique declarations that are in the trailing objects of the /// class. ArrayRef<ValueDecl *> getUniqueDeclsRef() const { return static_cast<const T *>(this) - ->template getTrailingObjects<ValueDecl *>(NumUniqueDeclarations); + ->template getTrailingObjectsNonStrict<ValueDecl *>( + NumUniqueDeclarations); } /// Set the unique declarations that are in the trailing objects of the @@ -5943,15 +5946,15 @@ class OMPMappableExprListClause : public OMPVarListClause<T>, /// Get the number of lists per declaration that are in the trailing /// objects of the class. MutableArrayRef<unsigned> getDeclNumListsRef() { - return static_cast<T *>(this)->template getTrailingObjects<unsigned>( - NumUniqueDeclarations); + return static_cast<T *>(this) + ->template getTrailingObjectsNonStrict<unsigned>(NumUniqueDeclarations); } /// Get the number of lists per declaration that are in the trailing /// objects of the class. 
ArrayRef<unsigned> getDeclNumListsRef() const { - return static_cast<const T *>(this)->template getTrailingObjects<unsigned>( - NumUniqueDeclarations); + return static_cast<const T *>(this) + ->template getTrailingObjectsNonStrict<unsigned>(NumUniqueDeclarations); } /// Set the number of lists per declaration that are in the trailing @@ -5966,7 +5969,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>, /// objects of the class. They are appended after the number of lists. MutableArrayRef<unsigned> getComponentListSizesRef() { return MutableArrayRef<unsigned>( - static_cast<T *>(this)->template getTrailingObjects<unsigned>() + + static_cast<T *>(this) + ->template getTrailingObjectsNonStrict<unsigned>() + NumUniqueDeclarations, NumComponentLists); } @@ -5975,7 +5979,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>, /// objects of the class. They are appended after the number of lists. ArrayRef<unsigned> getComponentListSizesRef() const { return ArrayRef<unsigned>( - static_cast<const T *>(this)->template getTrailingObjects<unsigned>() + + static_cast<const T *>(this) + ->template getTrailingObjectsNonStrict<unsigned>() + NumUniqueDeclarations, NumComponentLists); } @@ -5991,13 +5996,15 @@ class OMPMappableExprListClause : public OMPVarListClause<T>, /// Get the components that are in the trailing objects of the class. MutableArrayRef<MappableComponent> getComponentsRef() { return static_cast<T *>(this) - ->template getTrailingObjects<MappableComponent>(NumComponents); + ->template getTrailingObjectsNonStrict<MappableComponent>( + NumComponents); } /// Get the components that are in the trailing objects of the class. ArrayRef<MappableComponent> getComponentsRef() const { return static_cast<const T *>(this) - ->template getTrailingObjects<MappableComponent>(NumComponents); + ->template getTrailingObjectsNonStrict<MappableComponent>( + NumComponents); } /// Set the components that are in the trailing objects of the class. 
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp index 149b274f36b63..642867c0942b5 100644 --- a/clang/lib/AST/Expr.cpp +++ b/clang/lib/AST/Expr.cpp @@ -2020,7 +2020,8 @@ CXXBaseSpecifier **CastExpr::path_buffer() { #define ABSTRACT_STMT(x) #define CASTEXPR(Type, Base) \ case Stmt::Type##Class: \ - return static_cast<Type *>(this)->getTrailingObjects<CXXBaseSpecifier *>(); + return static_cast<Type *>(this) \ + ->getTrailingObjectsNonStrict<CXXBaseSpecifier *>(); #define STMT(Type, Base) #include "clang/AST/StmtNodes.inc" default: diff --git a/llvm/include/llvm/Support/TrailingObjects.h b/llvm/include/llvm/Support/TrailingObjects.h index f25f2311a81a4..d7211a930ae49 100644 --- a/llvm/include/llvm/Support/TrailingObjects.h +++ b/llvm/include/llvm/Support/TrailingObjects.h @@ -228,12 +228,18 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl< using ParentType::getTrailingObjectsImpl; - // This function contains only a static_assert BaseTy is final. The - // static_assert must be in a function, and not at class-level - // because BaseTy isn't complete at class instantiation time, but - // will be by the time this function is instantiated. - static void verifyTrailingObjectsAssertions() { + template <bool Strict> static void verifyTrailingObjectsAssertions() { + // The static_assert for BaseTy must be in a function, and not at + // class-level because BaseTy isn't complete at class instantiation time, + // but will be by the time this function is instantiated. static_assert(std::is_final<BaseTy>(), "BaseTy must be final."); + + // Verify that templated getTrailingObjects() is used only with multiple + // trailing types. Use getTrailingObjectsNonStrict() which does not check + // this. + static_assert(!Strict || sizeof...(TrailingTys) > 1, + "Use templated getTrailingObjects() only when there are " + "multiple trailing types"); } // These two methods are the base of the recursion for this method. 
@@ -283,7 +289,7 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl< /// (which must be one of those specified in the class template). The /// array may have zero or more elements in it. template <typename T> const T *getTrailingObjects() const { - verifyTrailingObjectsAssertions(); + verifyTrailingObjectsAssertions<true>(); // Forwards to an impl function with overloads, since member // function templates can't be specialized. return this->getTrailingObjectsImpl( @@ -295,7 +301,7 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl< /// (which must be one of those specified in the class template). The /// array may have zero or more elements in it. template <typename T> T *getTrailingObjects() { - verifyTrailingObjectsAssertions(); + verifyTrailingObjectsAssertions<true>(); // Forwards to an impl function with overloads, since member // function templates can't be specialized. return this->getTrailingObjectsImpl( @@ -310,14 +316,20 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl< static_assert(sizeof...(TrailingTys) == 1, "Can use non-templated getTrailingObjects() only when there " "is a single trailing type"); - return getTrailingObjects<FirstTrailingType>(); + verifyTrailingObjectsAssertions<false>(); + return this->getTrailingObjectsImpl( + static_cast<const BaseTy *>(this), + TrailingObjectsBase::OverloadToken<FirstTrailingType>()); } FirstTrailingType *getTrailingObjects() { static_assert(sizeof...(TrailingTys) == 1, "Can use non-templated getTrailingObjects() only when there " "is a single trailing type"); - return getTrailingObjects<FirstTrailingType>(); + verifyTrailingObjectsAssertions<false>(); + return this->getTrailingObjectsImpl( + static_cast<BaseTy *>(this), + TrailingObjectsBase::OverloadToken<FirstTrailingType>()); } // Functions that return the trailing objects as ArrayRefs. 
@@ -337,6 +349,31 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl< return ArrayRef(getTrailingObjects(), N); } + // Non-strict forms of templated `getTrailingObjects` that work with a single + // trailing type. + template <typename T> const T *getTrailingObjectsNonStrict() const { + verifyTrailingObjectsAssertions<false>(); + return this->getTrailingObjectsImpl( + static_cast<const BaseTy *>(this), + TrailingObjectsBase::OverloadToken<T>()); + } + + template <typename T> T *getTrailingObjectsNonStrict() { + verifyTrailingObjectsAssertions<false>(); + return this->getTrailingObjectsImpl( + static_cast<BaseTy *>(this), TrailingObjectsBase::OverloadToken<T>()); + } + + template <typename T> + MutableArrayRef<T> getTrailingObjectsNonStrict(size_t N) { + return MutableArrayRef(getTrailingObjectsNonStrict<T>(), N); + } + + template <typename T> + ArrayRef<T> getTrailingObjectsNonStrict(size_t N) const { + return ArrayRef(getTrailingObjectsNonStrict<T>(), N); + } + /// Returns the size of the trailing data, if an object were /// allocated with the given counts (The counts are in the same order /// as the template arguments). This does not include the size of the diff --git a/llvm/unittests/Support/TrailingObjectsTest.cpp b/llvm/unittests/Support/TrailingObjectsTest.cpp index 2590f375b6598..9184a4dd0cc23 100644 --- a/llvm/unittests/Support/TrailingObjectsTest.cpp +++ b/llvm/unittests/Support/TrailingObjectsTest.cpp @@ -45,9 +45,10 @@ class Class1 final : private TrailingObjects<Class1, short> { template <typename... Ty> using FixedSizeStorage = TrailingObjects::FixedSizeStorage<Ty...>; - using TrailingObjects::totalSizeToAlloc; using TrailingObjects::additionalSizeToAlloc; using TrailingObjects::getTrailingObjects; + using TrailingObjects::getTrailingObjectsNonStrict; + using TrailingObjects::totalSizeToAlloc; }; // Here, there are two singular optional object types appended.
Note @@ -123,11 +124,11 @@ TEST(TrailingObjects, OneArg) { EXPECT_EQ(Class1::totalSizeToAlloc<short>(3), sizeof(Class1) + sizeof(short) * 3); - EXPECT_EQ(C->getTrailingObjects<short>(), reinterpret_cast<short *>(C + 1)); + EXPECT_EQ(C->getTrailingObjects(), reinterpret_cast<short *>(C + 1)); EXPECT_EQ(C->get(0), 1); EXPECT_EQ(C->get(2), 3); - EXPECT_EQ(C->getTrailingObjects(), C->getTrailingObjects<short>()); + EXPECT_EQ(C->getTrailingObjects(), C->getTrailingObjectsNonStrict<short>()); delete C; } _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits