https://github.com/jurahul updated 
https://github.com/llvm/llvm-project/pull/144930

>From f5216d4c55c4dffa8785ff2fa051492ed98f405a Mon Sep 17 00:00:00 2001
From: Rahul Joshi <rjo...@nvidia.com>
Date: Thu, 19 Jun 2025 10:25:12 -0700
Subject: [PATCH] [LLVM][Clang] Add and enable strict mode for
 `getTrailingObjects`

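Under strict mode, the templated `getTrailingObjects` can be called
only when there is more than one trailing type; classes with a single
trailing type should use the non-templated `getTrailingObjects()`
instead. Strict mode can be disabled on a per-call basis when it is not
possible to know statically whether there will be a single or multiple
trailing types (as in OpenMPClause.h and `CastExpr::path_buffer` in
Expr.cpp).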
---
 clang/include/clang/AST/OpenMPClause.h        | 46 ++++++++++++-------
 clang/lib/AST/Expr.cpp                        |  3 +-
 llvm/include/llvm/Support/TrailingObjects.h   | 21 +++++----
 .../unittests/Support/TrailingObjectsTest.cpp |  5 +-
 4 files changed, 47 insertions(+), 28 deletions(-)

diff --git a/clang/include/clang/AST/OpenMPClause.h b/clang/include/clang/AST/OpenMPClause.h
index 2fa8fa529741e..b62ebd614e4c7 100644
--- a/clang/include/clang/AST/OpenMPClause.h
+++ b/clang/include/clang/AST/OpenMPClause.h
@@ -295,7 +295,8 @@ template <class T> class OMPVarListClause : public OMPClause {
 
   /// Fetches list of variables associated with this clause.
   MutableArrayRef<Expr *> getVarRefs() {
-    return static_cast<T *>(this)->template getTrailingObjects<Expr *>(NumVars);
+    return static_cast<T *>(this)
+        ->template getTrailingObjects<Expr *, /*Strict=*/false>(NumVars);
   }
 
   /// Sets the list of variables for this clause.
@@ -334,8 +335,8 @@ template <class T> class OMPVarListClause : public OMPClause {
 
   /// Fetches list of all variables in the clause.
   ArrayRef<const Expr *> getVarRefs() const {
-    return static_cast<const T *>(this)->template getTrailingObjects<Expr *>(
-        NumVars);
+    return static_cast<const T *>(this)
+        ->template getTrailingObjects<Expr *, /*Strict=*/false>(NumVars);
   }
 };
 
@@ -380,7 +381,8 @@ template <class T> class OMPDirectiveListClause : public OMPClause {
 
   MutableArrayRef<OpenMPDirectiveKind> getDirectiveKinds() {
     return static_cast<T *>(this)
-        ->template getTrailingObjects<OpenMPDirectiveKind>(NumKinds);
+        ->template getTrailingObjects<OpenMPDirectiveKind, /*Strict=*/false>(
+            NumKinds);
   }
 
   void setDirectiveKinds(ArrayRef<OpenMPDirectiveKind> DK) {
@@ -5901,15 +5903,17 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
   /// Get the unique declarations that are in the trailing objects of the
   /// class.
   MutableArrayRef<ValueDecl *> getUniqueDeclsRef() {
-    return static_cast<T *>(this)->template getTrailingObjects<ValueDecl *>(
-        NumUniqueDeclarations);
+    return static_cast<T *>(this)
+        ->template getTrailingObjects<ValueDecl *, /*Strict=*/false>(
+            NumUniqueDeclarations);
   }
 
   /// Get the unique declarations that are in the trailing objects of the
   /// class.
   ArrayRef<ValueDecl *> getUniqueDeclsRef() const {
     return static_cast<const T *>(this)
-        ->template getTrailingObjects<ValueDecl *>(NumUniqueDeclarations);
+        ->template getTrailingObjects<ValueDecl *, /*Strict=*/false>(
+            NumUniqueDeclarations);
   }
 
   /// Set the unique declarations that are in the trailing objects of the
@@ -5923,15 +5927,17 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
   /// Get the number of lists per declaration that are in the trailing
   /// objects of the class.
   MutableArrayRef<unsigned> getDeclNumListsRef() {
-    return static_cast<T *>(this)->template getTrailingObjects<unsigned>(
-        NumUniqueDeclarations);
+    return static_cast<T *>(this)
+        ->template getTrailingObjects<unsigned, /*Strict=*/false>(
+            NumUniqueDeclarations);
   }
 
   /// Get the number of lists per declaration that are in the trailing
   /// objects of the class.
   ArrayRef<unsigned> getDeclNumListsRef() const {
-    return static_cast<const T *>(this)->template getTrailingObjects<unsigned>(
-        NumUniqueDeclarations);
+    return static_cast<const T *>(this)
+        ->template getTrailingObjects<unsigned, /*Strict=*/false>(
+            NumUniqueDeclarations);
   }
 
   /// Set the number of lists per declaration that are in the trailing
@@ -5946,7 +5952,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
   /// objects of the class. They are appended after the number of lists.
   MutableArrayRef<unsigned> getComponentListSizesRef() {
     return MutableArrayRef<unsigned>(
-        static_cast<T *>(this)->template getTrailingObjects<unsigned>() +
+        static_cast<T *>(this)
+                ->template getTrailingObjects<unsigned, /*Strict=*/false>() +
             NumUniqueDeclarations,
         NumComponentLists);
   }
@@ -5955,7 +5962,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
   /// objects of the class. They are appended after the number of lists.
   ArrayRef<unsigned> getComponentListSizesRef() const {
     return ArrayRef<unsigned>(
-        static_cast<const T *>(this)->template getTrailingObjects<unsigned>() +
+        static_cast<const T *>(this)
+                ->template getTrailingObjects<unsigned, /*Strict=*/false>() +
             NumUniqueDeclarations,
         NumComponentLists);
   }
@@ -5971,13 +5979,15 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
   /// Get the components that are in the trailing objects of the class.
   MutableArrayRef<MappableComponent> getComponentsRef() {
     return static_cast<T *>(this)
-        ->template getTrailingObjects<MappableComponent>(NumComponents);
+        ->template getTrailingObjects<MappableComponent, /*Strict=*/false>(
+            NumComponents);
   }
 
   /// Get the components that are in the trailing objects of the class.
   ArrayRef<MappableComponent> getComponentsRef() const {
     return static_cast<const T *>(this)
-        ->template getTrailingObjects<MappableComponent>(NumComponents);
+        ->template getTrailingObjects<MappableComponent, /*Strict=*/false>(
+            NumComponents);
   }
 
   /// Set the components that are in the trailing objects of the class.
@@ -6084,7 +6094,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
     assert(SupportsMapper &&
            "Must be a clause that is possible to have user-defined mappers");
     return llvm::MutableArrayRef<Expr *>(
-        static_cast<T *>(this)->template getTrailingObjects<Expr *>() +
+        static_cast<T *>(this)
+                ->template getTrailingObjects<Expr *, /*Strict=*/false>() +
             OMPVarListClause<T>::varlist_size(),
         OMPVarListClause<T>::varlist_size());
   }
@@ -6095,7 +6106,8 @@ class OMPMappableExprListClause : public OMPVarListClause<T>,
     assert(SupportsMapper &&
            "Must be a clause that is possible to have user-defined mappers");
     return llvm::ArrayRef<Expr *>(
-        static_cast<const T *>(this)->template getTrailingObjects<Expr *>() +
+        static_cast<const T *>(this)
+                ->template getTrailingObjects<Expr *, /*Strict=*/false>() +
             OMPVarListClause<T>::varlist_size(),
         OMPVarListClause<T>::varlist_size());
   }
diff --git a/clang/lib/AST/Expr.cpp b/clang/lib/AST/Expr.cpp
index c3722c65abf6e..b93a31ca4ed36 100644
--- a/clang/lib/AST/Expr.cpp
+++ b/clang/lib/AST/Expr.cpp
@@ -2024,7 +2024,8 @@ CXXBaseSpecifier **CastExpr::path_buffer() {
 #define ABSTRACT_STMT(x)
 #define CASTEXPR(Type, Base)                                                   \
   case Stmt::Type##Class:                                                      \
-    return static_cast<Type *>(this)->getTrailingObjects<CXXBaseSpecifier *>();
+    return static_cast<Type *>(this)                                           \
+        ->getTrailingObjects<CXXBaseSpecifier *, /*Strict=*/false>();
 #define STMT(Type, Base)
 #include "clang/AST/StmtNodes.inc"
   default:
diff --git a/llvm/include/llvm/Support/TrailingObjects.h b/llvm/include/llvm/Support/TrailingObjects.h
index f25f2311a81a4..3d701de93b4f1 100644
--- a/llvm/include/llvm/Support/TrailingObjects.h
+++ b/llvm/include/llvm/Support/TrailingObjects.h
@@ -282,7 +282,13 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
   /// Returns a pointer to the trailing object array of the given type
   /// (which must be one of those specified in the class template). The
   /// array may have zero or more elements in it.
-  template <typename T> const T *getTrailingObjects() const {
+  /// If \p Strict is true (the default), the class must have more than one
+  /// trailing type; otherwise use the non-templated getTrailingObjects().
+  template <typename T, bool Strict = true>
+  const T *getTrailingObjects() const {
+    static_assert(!Strict || sizeof...(TrailingTys) > 1,
+                  "Can use templated getTrailingObjects() only when there are "
+                  "multiple trailing types");
     verifyTrailingObjectsAssertions();
     // Forwards to an impl function with overloads, since member
     // function templates can't be specialized.
@@ -294,7 +300,12 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
   /// Returns a pointer to the trailing object array of the given type
   /// (which must be one of those specified in the class template). The
   /// array may have zero or more elements in it.
-  template <typename T> T *getTrailingObjects() {
+  /// If \p Strict is true (the default), the class must have more than one
+  /// trailing type; otherwise use the non-templated getTrailingObjects().
+  template <typename T, bool Strict = true> T *getTrailingObjects() {
+    static_assert(!Strict || sizeof...(TrailingTys) > 1,
+                  "Can use templated getTrailingObjects() only when there are "
+                  "multiple trailing types");
     verifyTrailingObjectsAssertions();
     // Forwards to an impl function with overloads, since member
     // function templates can't be specialized.
@@ -310,23 +321,25 @@ class TrailingObjects : private trailing_objects_internal::TrailingObjectsImpl<
     static_assert(sizeof...(TrailingTys) == 1,
                   "Can use non-templated getTrailingObjects() only when there "
                   "is a single trailing type");
-    return getTrailingObjects<FirstTrailingType>();
+    return getTrailingObjects<FirstTrailingType, /*Strict=*/false>();
   }
 
   FirstTrailingType *getTrailingObjects() {
     static_assert(sizeof...(TrailingTys) == 1,
                   "Can use non-templated getTrailingObjects() only when there "
                   "is a single trailing type");
-    return getTrailingObjects<FirstTrailingType>();
+    return getTrailingObjects<FirstTrailingType, /*Strict=*/false>();
   }
 
   // Functions that return the trailing objects as ArrayRefs.
-  template <typename T> MutableArrayRef<T> getTrailingObjects(size_t N) {
-    return MutableArrayRef(getTrailingObjects<T>(), N);
+  template <typename T, bool Strict = true>
+  MutableArrayRef<T> getTrailingObjects(size_t N) {
+    return MutableArrayRef(getTrailingObjects<T, Strict>(), N);
   }
 
-  template <typename T> ArrayRef<T> getTrailingObjects(size_t N) const {
-    return ArrayRef(getTrailingObjects<T>(), N);
+  template <typename T, bool Strict = true>
+  ArrayRef<T> getTrailingObjects(size_t N) const {
+    return ArrayRef(getTrailingObjects<T, Strict>(), N);
   }
 
   MutableArrayRef<FirstTrailingType> getTrailingObjects(size_t N) {
diff --git a/llvm/unittests/Support/TrailingObjectsTest.cpp b/llvm/unittests/Support/TrailingObjectsTest.cpp
index 2590f375b6598..83afb22f837aa 100644
--- a/llvm/unittests/Support/TrailingObjectsTest.cpp
+++ b/llvm/unittests/Support/TrailingObjectsTest.cpp
@@ -123,11 +123,12 @@ TEST(TrailingObjects, OneArg) {
   EXPECT_EQ(Class1::totalSizeToAlloc<short>(3),
             sizeof(Class1) + sizeof(short) * 3);
 
-  EXPECT_EQ(C->getTrailingObjects<short>(), reinterpret_cast<short *>(C + 1));
+  EXPECT_EQ(C->getTrailingObjects(), reinterpret_cast<short *>(C + 1));
   EXPECT_EQ(C->get(0), 1);
   EXPECT_EQ(C->get(2), 3);
 
-  EXPECT_EQ(C->getTrailingObjects(), C->getTrailingObjects<short>());
+  EXPECT_EQ(C->getTrailingObjects(),
+            (C->getTrailingObjects<short, /*Strict=*/false>()));
 
   delete C;
 }

_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to