https://github.com/ro-i created https://github.com/llvm/llvm-project/pull/146404

Section 12.1.2 of the OpenMP 6.0 specification defines the behavior of the
strict modifier for the num_threads clause on parallel directives, along with
the message and severity clauses. This commit implements the necessary device
runtime changes.

From cf566c60db9eef81c39a45082645c9d44992bec5 Mon Sep 17 00:00:00 2001
From: Robert Imschweiler <robert.imschwei...@amd.com>
Date: Fri, 27 Jun 2025 07:54:07 -0500
Subject: [PATCH] [OpenMP][clang] 6.0: num_threads strict (part 2: device
 runtime)

OpenMP 6.0 12.1.2 specifies the behavior of the strict modifier for the
num_threads clause on parallel directives, along with the message and
severity clauses. This commit implements necessary device runtime
changes.
---
 offload/DeviceRTL/include/DeviceTypes.h |  6 ++
 offload/DeviceRTL/src/Parallelism.cpp   | 78 +++++++++++++++++++------
 openmp/runtime/src/kmp.h                |  1 +
 3 files changed, 67 insertions(+), 18 deletions(-)

diff --git a/offload/DeviceRTL/include/DeviceTypes.h b/offload/DeviceRTL/include/DeviceTypes.h
index 2e5d92380f040..111143a5578f1 100644
--- a/offload/DeviceRTL/include/DeviceTypes.h
+++ b/offload/DeviceRTL/include/DeviceTypes.h
@@ -136,6 +136,12 @@ struct omp_lock_t {
   void *Lock;
 };
 
+// see definition in openmp/runtime kmp.h
+typedef enum omp_severity_t {
+  severity_warning = 1,
+  severity_fatal = 2
+} omp_severity_t;
+
 using InterWarpCopyFnTy = void (*)(void *src, int32_t warp_num);
 using ShuffleReductFnTy = void (*)(void *rhsData, int16_t lane_id,
                                    int16_t lane_offset, int16_t shortCircuit);
diff --git a/offload/DeviceRTL/src/Parallelism.cpp b/offload/DeviceRTL/src/Parallelism.cpp
index 08ce616aee1c4..78438a60454b8 100644
--- a/offload/DeviceRTL/src/Parallelism.cpp
+++ b/offload/DeviceRTL/src/Parallelism.cpp
@@ -45,7 +45,24 @@ using namespace ompx;
 
 namespace {
 
-uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
+void num_threads_strict_error(int32_t nt_strict, int32_t nt_severity,
+                              const char *nt_message, int32_t requested,
+                              int32_t actual) {
+  if (nt_message)
+    printf("%s\n", nt_message);
+  else
+    printf("The computed number of threads (%u) does not match the requested "
+           "number of threads (%d). Consider that it might not be supported "
+           "to select exactly %d threads on this target device.\n",
+           actual, requested, requested);
+  if (nt_severity == severity_fatal)
+    __builtin_trap();
+}
+
+uint32_t determineNumberOfThreads(int32_t NumThreadsClause,
+                                  int32_t nt_strict = false,
+                                  int32_t nt_severity = severity_fatal,
+                                  const char *nt_message = nullptr) {
   uint32_t NThreadsICV =
       NumThreadsClause != -1 ? NumThreadsClause : icv::NThreads;
   uint32_t NumThreads = mapping::getMaxTeamThreads();
@@ -55,13 +72,17 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
 
   // SPMD mode allows any number of threads, for generic mode we round down to a
   // multiple of WARPSIZE since it is legal to do so in OpenMP.
-  if (mapping::isSPMDMode())
-    return NumThreads;
+  if (!mapping::isSPMDMode()) {
+    if (NumThreads < mapping::getWarpSize())
+      NumThreads = 1;
+    else
+      NumThreads = (NumThreads & ~((uint32_t)mapping::getWarpSize() - 1));
+  }
 
-  if (NumThreads < mapping::getWarpSize())
-    NumThreads = 1;
-  else
-    NumThreads = (NumThreads & ~((uint32_t)mapping::getWarpSize() - 1));
+  if (NumThreadsClause != -1 && nt_strict &&
+      NumThreads != static_cast<uint32_t>(NumThreadsClause))
+    num_threads_strict_error(nt_strict, nt_severity, nt_message,
+                             NumThreadsClause, NumThreads);
 
   return NumThreads;
 }
@@ -82,12 +103,14 @@ uint32_t determineNumberOfThreads(int32_t NumThreadsClause) {
 
 extern "C" {
 
-[[clang::always_inline]] void __kmpc_parallel_spmd(IdentTy *ident,
-                                                   int32_t num_threads,
-                                                   void *fn, void **args,
-                                                   const int64_t nargs) {
+[[clang::always_inline]] void
+__kmpc_parallel_spmd(IdentTy *ident, int32_t num_threads, void *fn, void **args,
+                     const int64_t nargs, int32_t nt_strict = false,
+                     int32_t nt_severity = severity_fatal,
+                     const char *nt_message = nullptr) {
   uint32_t TId = mapping::getThreadIdInBlock();
-  uint32_t NumThreads = determineNumberOfThreads(num_threads);
+  uint32_t NumThreads =
+      determineNumberOfThreads(num_threads, nt_strict, nt_severity, nt_message);
   uint32_t PTeamSize =
       NumThreads == mapping::getMaxTeamThreads() ? 0 : NumThreads;
   // Avoid the race between the read of the `icv::Level` above and the write
@@ -140,10 +163,11 @@ extern "C" {
   return;
 }
 
-[[clang::always_inline]] void
-__kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
-                   int32_t num_threads, int proc_bind, void *fn,
-                   void *wrapper_fn, void **args, int64_t nargs) {
+[[clang::always_inline]] void __kmpc_parallel_51(
+    IdentTy *ident, int32_t, int32_t if_expr, int32_t num_threads,
+    int proc_bind, void *fn, void *wrapper_fn, void **args, int64_t nargs,
+    int32_t nt_strict = false, int32_t nt_severity = severity_fatal,
+    const char *nt_message = nullptr) {
   uint32_t TId = mapping::getThreadIdInBlock();
 
   // Assert the parallelism level is zero if disabled by the user.
@@ -156,6 +180,12 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
   // 3) nested parallel regions
   if (OMP_UNLIKELY(!if_expr || state::HasThreadState ||
                    (config::mayUseNestedParallelism() && icv::Level))) {
+    // OpenMP 6.0 12.1.2 requires the num_threads 'strict' modifier to also have
+    // effect when parallel execution is disabled by a corresponding if clause
+    // attached to the parallel directive.
+    if (nt_strict && num_threads > 1)
+      num_threads_strict_error(nt_strict, nt_severity, nt_message, num_threads,
+                               1);
     state::DateEnvironmentRAII DERAII(ident);
     ++icv::Level;
     invokeMicrotask(TId, 0, fn, args, nargs);
@@ -169,12 +199,14 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
     // This was moved to its own routine so it could be called directly
     // in certain situations to avoid resource consumption of unused
     // logic in parallel_51.
-    __kmpc_parallel_spmd(ident, num_threads, fn, args, nargs);
+    __kmpc_parallel_spmd(ident, num_threads, fn, args, nargs, nt_strict,
+                         nt_severity, nt_message);
 
     return;
   }
 
-  uint32_t NumThreads = determineNumberOfThreads(num_threads);
+  uint32_t NumThreads =
+      determineNumberOfThreads(num_threads, nt_strict, nt_severity, nt_message);
   uint32_t MaxTeamThreads = mapping::getMaxTeamThreads();
   uint32_t PTeamSize = NumThreads == MaxTeamThreads ? 0 : NumThreads;
 
@@ -277,6 +309,16 @@ __kmpc_parallel_51(IdentTy *ident, int32_t, int32_t if_expr,
     __kmpc_end_sharing_variables();
 }
 
+[[clang::always_inline]] void __kmpc_parallel_60(
+    IdentTy *ident, int32_t id, int32_t if_expr, int32_t num_threads,
+    int proc_bind, void *fn, void *wrapper_fn, void **args, int64_t nargs,
+    int32_t nt_strict = false, int32_t nt_severity = severity_fatal,
+    const char *nt_message = nullptr) {
+  return __kmpc_parallel_51(ident, id, if_expr, num_threads, proc_bind, fn,
+                            wrapper_fn, args, nargs, nt_strict, nt_severity,
+                            nt_message);
+}
+
 [[clang::noinline]] bool __kmpc_kernel_parallel(ParallelRegionFnTy *WorkFn) {
   // Work function and arguments for L1 parallel region.
   *WorkFn = state::ParallelRegionFn;
diff --git a/openmp/runtime/src/kmp.h b/openmp/runtime/src/kmp.h
index a2cacc8792b15..983e1c34f76b8 100644
--- a/openmp/runtime/src/kmp.h
+++ b/openmp/runtime/src/kmp.h
@@ -4666,6 +4666,7 @@ static inline int __kmp_adjust_gtid_for_hidden_helpers(int gtid) {
 }
 
 // Support for error directive
+// See definition in offload/DeviceRTL DeviceTypes.h
 typedef enum kmp_severity_t {
   severity_warning = 1,
   severity_fatal = 2

_______________________________________________
llvm-branch-commits mailing list
llvm-branch-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/llvm-branch-commits

Reply via email to