tianshilei1992 created this revision.
tianshilei1992 added reviewers: jdoerfert, ABataev.
Herald added subscribers: guansong, yaxunl.
tianshilei1992 requested review of this revision.
Herald added subscribers: cfe-commits, sstefan1.
Herald added a project: clang.

This patch adds codegen support for `atomic compare capture` in Clang.

Depends on D120007 <https://reviews.llvm.org/D120007> and D118632 
<https://reviews.llvm.org/D118632>


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D120290

Files:
  clang/include/clang/AST/StmtOpenMP.h
  clang/lib/AST/StmtOpenMP.cpp
  clang/lib/CodeGen/CGStmtOpenMP.cpp
  clang/lib/Sema/SemaOpenMP.cpp

Index: clang/lib/Sema/SemaOpenMP.cpp
===================================================================
--- clang/lib/Sema/SemaOpenMP.cpp
+++ clang/lib/Sema/SemaOpenMP.cpp
@@ -11845,8 +11845,10 @@
   Expr *UE = nullptr;
   Expr *D = nullptr;
   Expr *CE = nullptr;
+  Expr *R = nullptr;
   bool IsXLHSInRHSPart = false;
   bool IsPostfixUpdate = false;
+  bool IsFailOnly = false;
   // OpenMP [2.12.6, atomic Construct]
   // In the next expressions:
   // * x and v (as applicable) are both l-value expressions with scalar type.
@@ -12242,8 +12244,15 @@
             << ErrorInfo.Error << ErrorInfo.NoteRange;
         return StmtError();
       }
-      // TODO: We don't set X, D, E, etc. here because in code gen we will emit
-      // error directly.
+      X = Checker.getX();
+      E = Checker.getE();
+      D = Checker.getD();
+      CE = Checker.getCond();
+      V = Checker.getV();
+      R = Checker.getR();
+      // We reuse IsXLHSInRHSPart to tell if it is in the form 'x ordop expr'.
+      IsXLHSInRHSPart = Checker.isXBinopExpr();
+      IsFailOnly = Checker.isFailOnly();
     } else {
       OpenMPAtomicCompareChecker::ErrorInfoTy ErrorInfo;
       OpenMPAtomicCompareChecker Checker(*this);
@@ -12266,8 +12275,8 @@
   setFunctionHasBranchProtectedScope();
 
   return OMPAtomicDirective::Create(Context, StartLoc, EndLoc, Clauses, AStmt,
-                                    X, V, E, UE, D, CE, IsXLHSInRHSPart,
-                                    IsPostfixUpdate);
+                                    X, V, R, E, UE, D, CE, IsXLHSInRHSPart,
+                                    IsPostfixUpdate, IsFailOnly);
 }
 
 StmtResult Sema::ActOnOpenMPTargetDirective(ArrayRef<OMPClause *> Clauses,
Index: clang/lib/CodeGen/CGStmtOpenMP.cpp
===================================================================
--- clang/lib/CodeGen/CGStmtOpenMP.cpp
+++ clang/lib/CodeGen/CGStmtOpenMP.cpp
@@ -6019,8 +6019,10 @@
 
 static void emitOMPAtomicCompareExpr(CodeGenFunction &CGF,
                                      llvm::AtomicOrdering AO, const Expr *X,
+                                     const Expr *V, const Expr *R,
                                      const Expr *E, const Expr *D,
                                      const Expr *CE, bool IsXBinopExpr,
+                                     bool IsPostfixUpdate, bool IsFailOnly,
                                      SourceLocation Loc) {
   llvm::OpenMPIRBuilder &OMPBuilder =
       CGF.CGM.getOpenMPRuntime().getOMPBuilder();
@@ -6050,17 +6052,36 @@
       XPtr, XPtr->getType()->getPointerElementType(),
       X->getType().isVolatileQualified(),
       X->getType()->hasSignedIntegerRepresentation()};
+  // 'v' and 'r' are only present for 'atomic compare capture'. Each operand
+  // must describe its own lvalue (not 'x'); when absent it stays
+  // value-initialized.
+  // NOTE(review): the bool initializers mirror XOpVal above (volatile, then
+  // signed); confirm that matches the field order of AtomicOpValue.
+  llvm::OpenMPIRBuilder::AtomicOpValue VOpVal{}, ROpVal{};
+  if (V) {
+    llvm::Value *VPtr = CGF.EmitLValue(V).getPointer(CGF);
+    VOpVal = {VPtr, VPtr->getType()->getPointerElementType(),
+              V->getType().isVolatileQualified(),
+              V->getType()->hasSignedIntegerRepresentation()};
+  }
+  if (R) {
+    llvm::Value *RPtr = CGF.EmitLValue(R).getPointer(CGF);
+    ROpVal = {RPtr, RPtr->getType()->getPointerElementType(),
+              R->getType().isVolatileQualified(),
+              R->getType()->hasSignedIntegerRepresentation()};
+  }

   CGF.Builder.restoreIP(OMPBuilder.createAtomicCompare(
-      CGF.Builder, XOpVal, EVal, DVal, AO, Op, IsXBinopExpr));
+      CGF.Builder, XOpVal, VOpVal, ROpVal, EVal, DVal, AO, Op, IsXBinopExpr,
+      IsPostfixUpdate, IsFailOnly));
 }
 
 static void emitOMPAtomicExpr(CodeGenFunction &CGF, OpenMPClauseKind Kind,
                               llvm::AtomicOrdering AO, bool IsPostfixUpdate,
-                              const Expr *X, const Expr *V, const Expr *E,
-                              const Expr *UE, const Expr *D, const Expr *CE,
-                              bool IsXLHSInRHSPart, bool IsCompareCapture,
-                              SourceLocation Loc) {
+                              const Expr *X, const Expr *V, const Expr *R,
+                              const Expr *E, const Expr *UE, const Expr *D,
+                              const Expr *CE, bool IsXLHSInRHSPart,
+                              bool IsFailOnly, SourceLocation Loc) {
   switch (Kind) {
   case OMPC_read:
     emitOMPAtomicReadExpr(CGF, AO, X, V, Loc);
@@ -6077,15 +6088,8 @@
                              IsXLHSInRHSPart, Loc);
     break;
   case OMPC_compare: {
-    if (IsCompareCapture) {
-      // Emit an error here.
-      unsigned DiagID = CGF.CGM.getDiags().getCustomDiagID(
-          DiagnosticsEngine::Error,
-          "'atomic compare capture' is not supported for now");
-      CGF.CGM.getDiags().Report(DiagID);
-    } else {
-      emitOMPAtomicCompareExpr(CGF, AO, X, E, D, CE, IsXLHSInRHSPart, Loc);
-    }
+    emitOMPAtomicCompareExpr(CGF, AO, X, V, R, E, D, CE, IsXLHSInRHSPart,
+                             IsPostfixUpdate, IsFailOnly, Loc);
     break;
   }
   case OMPC_if:
@@ -6210,12 +6214,12 @@
     Kind = K;
     KindsEncountered.insert(K);
   }
-  bool IsCompareCapture = false;
+  // We just need to correct Kind here. No need to set a bool saying it is
+  // actually compare capture because we can tell from whether V and R are
+  // nullptr.
   if (KindsEncountered.contains(OMPC_compare) &&
-      KindsEncountered.contains(OMPC_capture)) {
-    IsCompareCapture = true;
+      KindsEncountered.contains(OMPC_capture))
     Kind = OMPC_compare;
-  }
   if (!MemOrderingSpecified) {
     llvm::AtomicOrdering DefaultOrder =
         CGM.getOpenMPRuntime().getDefaultMemoryOrdering();
@@ -6237,8 +6241,9 @@
   LexicalScope Scope(*this, S.getSourceRange());
   EmitStopPoint(S.getAssociatedStmt());
   emitOMPAtomicExpr(*this, Kind, AO, S.isPostfixUpdate(), S.getX(), S.getV(),
-                    S.getExpr(), S.getUpdateExpr(), S.getD(), S.getCondExpr(),
-                    S.isXLHSInRHSPart(), IsCompareCapture, S.getBeginLoc());
+                    S.getR(), S.getExpr(), S.getUpdateExpr(), S.getD(),
+                    S.getCondExpr(), S.isXLHSInRHSPart(), S.isFailOnly(),
+                    S.getBeginLoc());
 }
 
 static void emitCommonOMPTargetDirective(CodeGenFunction &CGF,
Index: clang/lib/AST/StmtOpenMP.cpp
===================================================================
--- clang/lib/AST/StmtOpenMP.cpp
+++ clang/lib/AST/StmtOpenMP.cpp
@@ -863,22 +863,23 @@
                                                    !IsStandalone);
 }
 
-OMPAtomicDirective *
-OMPAtomicDirective::Create(const ASTContext &C, SourceLocation StartLoc,
-                           SourceLocation EndLoc, ArrayRef<OMPClause *> Clauses,
-                           Stmt *AssociatedStmt, Expr *X, Expr *V, Expr *E,
-                           Expr *UE, Expr *D, Expr *Cond, bool IsXLHSInRHSPart,
-                           bool IsPostfixUpdate) {
+OMPAtomicDirective *OMPAtomicDirective::Create(
+    const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
+    ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt, Expr *X, Expr *V,
+    Expr *R, Expr *E, Expr *UE, Expr *D, Expr *Cond, bool IsXLHSInRHSPart,
+    bool IsPostfixUpdate, bool IsFailOnly) {
   auto *Dir = createDirective<OMPAtomicDirective>(
-      C, Clauses, AssociatedStmt, /*NumChildren=*/6, StartLoc, EndLoc);
+      C, Clauses, AssociatedStmt, /*NumChildren=*/7, StartLoc, EndLoc);
   Dir->setX(X);
   Dir->setV(V);
+  Dir->setR(R);
   Dir->setExpr(E);
   Dir->setUpdateExpr(UE);
   Dir->setD(D);
   Dir->setCond(Cond);
   Dir->IsXLHSInRHSPart = IsXLHSInRHSPart;
   Dir->IsPostfixUpdate = IsPostfixUpdate;
+  Dir->IsFailOnly = IsFailOnly;
   return Dir;
 }
 
@@ -886,7 +887,7 @@
                                                     unsigned NumClauses,
                                                     EmptyShell) {
   return createEmptyDirective<OMPAtomicDirective>(
-      C, NumClauses, /*HasAssociatedStmt=*/true, /*NumChildren=*/6);
+      C, NumClauses, /*HasAssociatedStmt=*/true, /*NumChildren=*/7);
 }
 
 OMPTargetDirective *OMPTargetDirective::Create(const ASTContext &C,
Index: clang/include/clang/AST/StmtOpenMP.h
===================================================================
--- clang/include/clang/AST/StmtOpenMP.h
+++ clang/include/clang/AST/StmtOpenMP.h
@@ -2842,6 +2842,9 @@
   /// This field is true for the first(postfix) form of the expression and false
   /// otherwise.
   bool IsPostfixUpdate = false;
+  /// True if 'v' is updated only when the condition is false (compare capture
+  /// only).
+  bool IsFailOnly = false;
 
   /// Build directive with the given start and end location.
   ///
@@ -2865,6 +2868,7 @@
     POS_UpdateExpr,
     POS_D,
     POS_Cond,
+    POS_R,
   };
 
   /// Set 'x' part of the associated expression/statement.
@@ -2877,6 +2881,8 @@
   }
   /// Set 'v' part of the associated expression/statement.
   void setV(Expr *V) { Data->getChildren()[DataPositionTy::POS_V] = V; }
+  /// Set 'r' part of the associated expression/statement.
+  void setR(Expr *R) { Data->getChildren()[DataPositionTy::POS_R] = R; }
   /// Set 'expr' part of the associated expression/statement.
   void setExpr(Expr *E) { Data->getChildren()[DataPositionTy::POS_E] = E; }
   /// Set 'd' part of the associated expression/statement.
@@ -2896,6 +2902,7 @@
   /// \param AssociatedStmt Statement, associated with the directive.
   /// \param X 'x' part of the associated expression/statement.
   /// \param V 'v' part of the associated expression/statement.
+  /// \param R 'r' part of the associated expression/statement.
   /// \param E 'expr' part of the associated expression/statement.
   /// \param UE Helper expression of the form
   /// 'OpaqueValueExpr(x) binop OpaqueValueExpr(expr)' or
@@ -2909,8 +2916,8 @@
   static OMPAtomicDirective *
   Create(const ASTContext &C, SourceLocation StartLoc, SourceLocation EndLoc,
          ArrayRef<OMPClause *> Clauses, Stmt *AssociatedStmt, Expr *X, Expr *V,
-         Expr *E, Expr *UE, Expr *D, Expr *Cond, bool IsXLHSInRHSPart,
-         bool IsPostfixUpdate);
+         Expr *R, Expr *E, Expr *UE, Expr *D, Expr *Cond, bool IsXLHSInRHSPart,
+         bool IsPostfixUpdate, bool IsFailOnly);
 
   /// Creates an empty directive with the place for \a NumClauses
   /// clauses.
@@ -2943,6 +2950,9 @@
   /// 'OpaqueValueExpr(x) binop OpaqueValueExpr(expr)' and false if it has form
   /// 'OpaqueValueExpr(expr) binop OpaqueValueExpr(x)'.
   bool isXLHSInRHSPart() const { return IsXLHSInRHSPart; }
+  /// Return true if 'v' is updated only when the condition is evaluated false
+  /// (compare capture only).
+  bool isFailOnly() const { return IsFailOnly; }
   /// Return true if 'v' expression must be updated to original value of
   /// 'x', false if 'v' must be updated to the new value of 'x'.
   bool isPostfixUpdate() const { return IsPostfixUpdate; }
@@ -2953,6 +2963,13 @@
   const Expr *getV() const {
     return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_V]);
   }
+  /// Get 'r' part of the associated expression/statement.
+  Expr *getR() {
+    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_R]);
+  }
+  const Expr *getR() const {
+    return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_R]);
+  }
   /// Get 'expr' part of the associated expression/statement.
   Expr *getExpr() {
     return cast_or_null<Expr>(Data->getChildren()[DataPositionTy::POS_E]);
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
  • [PATCH] D120290: [WIP][Clang][... Shilei Tian via Phabricator via cfe-commits

Reply via email to