================
@@ -14175,27 +14222,350 @@ bool SemaOpenMP::checkTransformableLoopNest(
         return false;
       },
       [&OriginalInits](OMPLoopBasedDirective *Transform) {
-        Stmt *DependentPreInits;
-        if (auto *Dir = dyn_cast<OMPTileDirective>(Transform))
-          DependentPreInits = Dir->getPreInits();
-        else if (auto *Dir = dyn_cast<OMPStripeDirective>(Transform))
-          DependentPreInits = Dir->getPreInits();
-        else if (auto *Dir = dyn_cast<OMPUnrollDirective>(Transform))
-          DependentPreInits = Dir->getPreInits();
-        else if (auto *Dir = dyn_cast<OMPReverseDirective>(Transform))
-          DependentPreInits = Dir->getPreInits();
-        else if (auto *Dir = dyn_cast<OMPInterchangeDirective>(Transform))
-          DependentPreInits = Dir->getPreInits();
-        else
-          llvm_unreachable("Unhandled loop transformation");
-
-        appendFlattenedStmtList(OriginalInits.back(), DependentPreInits);
+        updatePreInits(Transform, OriginalInits);
       });
   assert(OriginalInits.back().empty() && "No preinit after innermost loop");
   OriginalInits.pop_back();
   return Result;
 }
 
+// Counts the total number of nested loops, including the outermost loop (the
+// original loop). PRECONDITION of this visitor is that it must be invoked from
+// the original loop to be analyzed. Traversal stops at Decls and Exprs, since
+// they may contain inner loops (lambdas, local classes, statement
+// expressions...) that must not be counted. Traversal also stops at any
+// statement other than a CompoundStmt, ForStmt or CXXForRangeStmt (e.g.
+// IfStmt, SwitchStmt), since such statements break perfect loop nesting.
+//
+// Example AST structure for the code:
+//
+// int main() {
+//     #pragma omp fuse
+//     {
+//         for (int i = 0; i < 100; i++) {    <-- Outer loop
+//             []() {
+//                 for(int j = 0; j < 100; j++) {}  <-- NOT A LOOP
+//             };
+//             for(int j = 0; j < 5; ++j) {}    <-- Inner loop
+//         }
+//         for (int r = 0; r < 100; r++) {    <-- Outer loop
+//             struct LocalClass {
+//                 void bar() {
+//                     for(int j = 0; j < 100; j++) {}  <-- NOT A LOOP
+//                 }
+//             };
+//             for(int k = 0; k < 10; ++k) {}    <-- Inner loop
+//             ({x = 5; for(k = 0; k < 10; ++k) x += k; x;}); <-- NOT A LOOP
+//         }
+//     }
+// }
+// Result: Loop 'i' contains 2 loops, Loop 'r' also contains 2 loops
+namespace {
+class NestedLoopCounterVisitor : public DynamicRecursiveASTVisitor {
+  unsigned NestedLoopCount = 0;
+
+public:
+  /// Returns the number of ForStmt/CXXForRangeStmt nodes visited so far.
+  unsigned getNestedLoopCount() const { return NestedLoopCount; }
+
+  bool VisitForStmt(ForStmt *FS) override {
+    ++NestedLoopCount;
+    return true;
+  }
+
+  bool VisitCXXForRangeStmt(CXXForRangeStmt *FRS) override {
+    ++NestedLoopCount;
+    return true;
+  }
+
+  bool TraverseStmt(Stmt *S) override {
+    if (!S)
+      return true;
+
+    // Skip all expressions, including LambdaExpr, StmtExpr, BlockExpr and
+    // RequiresExpr. They may contain statements (even loops), but they are
+    // not part of the syntactic body of the surrounding loop nest and
+    // therefore must not be counted.
+    if (isa<Expr>(S))
+      return true;
+
+    // Only recurse into compound statements ({} blocks) and loops.
+    if (isa<CompoundStmt, ForStmt, CXXForRangeStmt>(S))
+      return DynamicRecursiveASTVisitor::TraverseStmt(S);
+
+    // Any other statement (control flow such as IfStmt, SwitchStmt...) breaks
+    // perfect loop nesting: do not look inside it.
+    return true;
+  }
+
+  bool TraverseDecl(Decl *D) override {
+    // Declarations (CXXRecordDecl, RecordDecl, FunctionDecl...) are not
+    // relevant for finding nested loops; do not look inside them.
+    return true;
+  }
+};
+} // namespace
+
+bool SemaOpenMP::analyzeLoopSequence(
+    Stmt *LoopSeqStmt, unsigned &LoopSeqSize, unsigned &NumLoops,
+    SmallVectorImpl<OMPLoopBasedDirective::HelperExprs> &LoopHelpers,
+    SmallVectorImpl<Stmt *> &ForStmts,
+    SmallVectorImpl<SmallVector<Stmt *, 0>> &OriginalInits,
+    SmallVectorImpl<SmallVector<Stmt *, 0>> &TransformsPreInits,
+    SmallVectorImpl<SmallVector<Stmt *, 0>> &LoopSequencePreInits,
+    SmallVectorImpl<OMPLoopCategory> &LoopCategories, ASTContext &Context,
+    OpenMPDirectiveKind Kind) {
+  VarsWithInheritedDSAType TmpDSA;
+  QualType BaseInductionVarType;
+
+  // Helper lambda that appends the loop's init statement (or, for a range-based
+  // for, its begin statement) to OriginalInits.back() and the loop itself to
+  // ForStmts. For a ForStmt it also warns when the induction variable type
+  // differs from the one of the first ForStmt seen by this invocation.
+  auto storeLoopStatements = [&OriginalInits, &ForStmts, &BaseInductionVarType,
+                              Kind, this, &Context](Stmt *LoopStmt) {
+    if (auto *For = dyn_cast<ForStmt>(LoopStmt)) {
+      OriginalInits.back().push_back(For->getInit());
+      ForStmts.push_back(For);
+      // Extract the induction variable. getSingleDecl() requires a DeclStmt
+      // holding exactly one declaration, so check that first.
+      auto *InitStmt = dyn_cast_or_null<DeclStmt>(For->getInit());
+      if (!InitStmt || !InitStmt->isSingleDecl())
+        return;
+      auto *InitDecl = dyn_cast<VarDecl>(InitStmt->getSingleDecl());
+      if (!InitDecl)
+        return;
+      QualType InductionVarType = InitDecl->getType().getCanonicalType();
+      // The first loop fixes the reference type; later loops are compared
+      // against it.
+      if (BaseInductionVarType.isNull()) {
+        BaseInductionVarType = InductionVarType;
+      } else if (!Context.hasSameType(BaseInductionVarType, InductionVarType)) {
+        Diag(InitDecl->getBeginLoc(),
+             diag::warn_omp_different_loop_ind_var_types)
+            << getOpenMPDirectiveName(Kind) << BaseInductionVarType
+            << InductionVarType;
+      }
+      return;
+    }
+    auto *CXXFor = cast<CXXForRangeStmt>(LoopStmt);
+    OriginalInits.back().push_back(CXXFor->getBeginStmt());
+    ForStmts.push_back(CXXFor);
+  };
+
+  // Helper lambdas encapsulating the processing of the different derivations
+  // of the canonical loop sequence grammar. Each returns false after emitting
+  // a diagnostic when the construct is invalid, and is responsible for
+  // updating LoopSeqSize and NumLoops for what it consumed.
+  //
+  // Handles a loop transformation directive (loop generating construct).
+  auto analyzeLoopGeneration = [&storeLoopStatements, &LoopHelpers,
+                                &OriginalInits, &TransformsPreInits,
+                                &LoopCategories, &LoopSeqSize, &NumLoops, Kind,
+                                &TmpDSA, &ForStmts, &Context,
+                                &LoopSequencePreInits, this](Stmt *Child) {
+    // Callers only pass statements for which isLoopGeneratingStmt() holds.
+    auto *LoopTransform = cast<OMPLoopTransformationDirective>(Child);
+    Stmt *TransformedStmt = LoopTransform->getTransformedStmt();
+    unsigned NumGeneratedLoopNests = LoopTransform->getNumGeneratedLoopNests();
+    unsigned NumGeneratedLoops = LoopTransform->getNumGeneratedLoops();
+
+    // A transformation that generates no loop nest (e.g. unroll full) cannot
+    // be part of a loop sequence, whether or not it has been applied yet.
+    if (NumGeneratedLoopNests == 0) {
+      Diag(Child->getBeginLoc(), diag::err_omp_not_for)
+          << 0 << getOpenMPDirectiveName(Kind);
+      return false;
+    }
+
+    // The transformed statement is not available due to dependent contexts:
+    // only account for the loop nests and loops it will generate.
+    if (!TransformedStmt) {
+      LoopSeqSize += NumGeneratedLoopNests;
+      NumLoops += NumGeneratedLoops;
+      return true;
+    }
+
+    // Loop sequence generating transformations (e.g. split or loopranged
+    // fuse) produce several loop nests.
+    if (NumGeneratedLoopNests > 1) {
+      // These preinits differ slightly from the regular inits/preinits of
+      // single loop generating transformations (interchange, unroll...) since
+      // they are not bound to a particular loop nest, so they are kept apart.
+      LoopSequencePreInits.emplace_back();
+      updatePreInits(LoopTransform, LoopSequencePreInits);
+      return analyzeLoopSequence(TransformedStmt, LoopSeqSize, NumLoops,
+                                 LoopHelpers, ForStmts, OriginalInits,
+                                 TransformsPreInits, LoopSequencePreInits,
+                                 LoopCategories, Context, Kind);
+    }
+
+    // Vast majority: a single generated loop nest (tile, unroll, stripe,
+    // reverse, interchange, fuse...).
+    OriginalInits.emplace_back();
+    TransformsPreInits.emplace_back();
+    LoopHelpers.emplace_back();
+    LoopCategories.push_back(OMPLoopCategory::TransformSingleLoop);
+
+    // Fill the helper that was just appended. Indexing with LoopSeqSize is
+    // not safe: the dependent-context path above advances LoopSeqSize
+    // without appending to LoopHelpers.
+    unsigned IsCanonical =
+        checkOpenMPLoop(Kind, nullptr, nullptr, TransformedStmt, SemaRef,
+                        *DSAStack, TmpDSA, LoopHelpers.back());
+    if (!IsCanonical) {
+      Diag(TransformedStmt->getBeginLoc(), diag::err_omp_not_canonical_loop)
+          << getOpenMPDirectiveName(Kind);
+      return false;
+    }
+    storeLoopStatements(TransformedStmt);
+    updatePreInits(LoopTransform, TransformsPreInits);
+
+    NumLoops += NumGeneratedLoops;
+    ++LoopSeqSize;
+    return true;
+  };
+
+  // Handles a regular canonical loop nest (ForStmt or CXXForRangeStmt).
+  auto analyzeRegularLoop = [&storeLoopStatements, &LoopHelpers, &OriginalInits,
+                             &LoopSeqSize, &NumLoops, Kind, &TmpDSA,
+                             &LoopCategories, this](Stmt *Child) {
+    OriginalInits.emplace_back();
+    LoopHelpers.emplace_back();
+    LoopCategories.push_back(OMPLoopCategory::RegularLoop);
+
+    // See analyzeLoopGeneration for why back() is used instead of an index.
+    unsigned IsCanonical =
+        checkOpenMPLoop(Kind, nullptr, nullptr, Child, SemaRef, *DSAStack,
+                        TmpDSA, LoopHelpers.back());
+    if (!IsCanonical) {
+      Diag(Child->getBeginLoc(), diag::err_omp_not_canonical_loop)
+          << getOpenMPDirectiveName(Kind);
+      return false;
+    }
+
+    storeLoopStatements(Child);
+    // Count the loops of this nest, including the outermost one.
+    NestedLoopCounterVisitor NLCV;
+    NLCV.TraverseStmt(Child);
+    NumLoops += NLCV.getNestedLoopCount();
+    // A regular loop nest is a single element of the loop sequence.
+    ++LoopSeqSize;
+    return true;
+  };
+
+  // Helpers to validate that the canonical loop sequence grammar is matched.
+  auto isLoopSequenceDerivation = [](auto *Child) {
+    return isa<ForStmt, CXXForRangeStmt, OMPLoopTransformationDirective>(Child);
+  };
+  auto isLoopGeneratingStmt = [](auto *Child) {
+    return isa<OMPLoopTransformationDirective>(Child);
+  };
+
+  // High level grammar validation.
+  for (Stmt *Child : LoopSeqStmt->children()) {
+    if (!Child)
+      continue;
+
+    // Look through container statements around anything that is not
+    // directly a loop sequence derivation.
+    if (!isLoopSequenceDerivation(Child)) {
+      Child = Child->IgnoreContainers();
+      // Defensive: nothing left to analyze.
+      if (!Child)
+        continue;
+      // A nested loop sequence (including an empty {}) cannot be handled by
+      // ignoring containers alone; it requires a recursive traversal.
+      if (isa<CompoundStmt>(Child)) {
+        if (!analyzeLoopSequence(Child, LoopSeqSize, NumLoops, LoopHelpers,
+                                 ForStmts, OriginalInits, TransformsPreInits,
+                                 LoopSequencePreInits, LoopCategories, Context,
+                                 Kind))
+          return false;
+        // Already handled by the recursive call, skip this child.
+        continue;
+      }
+    }
+
+    // Report an error for an invalid statement inside the loop sequence.
+    if (!isLoopSequenceDerivation(Child)) {
+      Diag(Child->getBeginLoc(), diag::err_omp_not_for)
+          << 0 << getOpenMPDirectiveName(Kind);
+      return false;
+    }
+
+    // Both helpers update the loop sequence size accordingly.
+    bool IsValid = isLoopGeneratingStmt(Child) ? analyzeLoopGeneration(Child)
+                                               : analyzeRegularLoop(Child);
+    if (!IsValid)
+      return false;
+  }
+  return true;
+}
+
+bool SemaOpenMP::checkTransformableLoopSequence(
+    OpenMPDirectiveKind Kind, Stmt *AStmt, unsigned &LoopSeqSize,
+    unsigned &NumLoops,
+    SmallVectorImpl<OMPLoopBasedDirective::HelperExprs> &LoopHelpers,
+    SmallVectorImpl<Stmt *> &ForStmts,
+    SmallVectorImpl<SmallVector<Stmt *, 0>> &OriginalInits,
+    SmallVectorImpl<SmallVector<Stmt *, 0>> &TransformsPreInits,
+    SmallVectorImpl<SmallVector<Stmt *, 0>> &LoopSequencePreInits,
+    SmallVectorImpl<OMPLoopCategory> &LoopCategories, ASTContext &Context) {
+
+  // Checks whether the given statement is a compound statement
+  if (!isa<CompoundStmt>(AStmt)) {
+    Diag(AStmt->getBeginLoc(), diag::err_omp_not_a_loop_sequence)
+        << getOpenMPDirectiveName(Kind);
+    return false;
+  }
+  // Number of top level canonical loop nests observed (And acts as index)
+  LoopSeqSize = 0;
+  // Number of total observed loops
+  NumLoops = 0;
+
+  // Following OpenMP 6.0 API Specification, a Canonical Loop Sequence follows
+  // the grammar:
+  //
+  // canonical-loop-sequence:
+  //  {
+  //    loop-sequence+
+  //  }
+  // where loop-sequence can be any of the following:
+  // 1. canonical-loop-sequence
+  // 2. loop-nest
+  // 3. loop-sequence-generating-construct (i.e OMPLoopTransformationDirective)
+  //
+  // To recognise and traverse this structure the following helper functions
+  // have been defined. analyzeLoopSequence serves as the recurisve entry point
+  // and tries to match the input AST to the canonical loop sequence grammar
+  // structure. This function will perform both a semantic and syntactical
+  // analysis of the given statement according to OpenMP 6.0 definition of
+  // the aforementioned canonical loop sequence
+
+  // Recursive entry point to process the main loop sequence
+  if (!analyzeLoopSequence(AStmt, LoopSeqSize, NumLoops, LoopHelpers, ForStmts,
+                           OriginalInits, TransformsPreInits,
+                           LoopSequencePreInits, LoopCategories, Context,
+                           Kind)) {
+    return false;
+  }
----------------
alexey-bataev wrote:

```suggestion
                           Kind))
    return false;
```

https://github.com/llvm/llvm-project/pull/139293
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to