rsmith created this revision.
rsmith added reviewers: aaron.ballman, rjmccall.
Herald added a project: clang.

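Clang used to determine incorrectly whether an expression-statement inside a GNU statement expression produces the value of the statement expression or is a discarded-value expression. It got this wrong in three ways: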

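1. During parsing, an expression-statement followed by the `})` ending a statement expression was always treated as producing the value of the statement expression. That's wrong for `({ if (1) expr; })`: there `expr;` is the body of the `if`, not the last statement of the statement expression, so it must be a discarded-value expression.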

2. During template instantiation, expression-statements that did not appear directly in a compound-statement (for example, the body of an `if`) were not treated as discarded-value expressions. This resulted in missing volatile loads, among other effects.

3. In all contexts, an expression-statement with attributes was not treated as producing the value of the statement expression, e.g., `({ [[attr]] expr; })`.


Repository:
  rC Clang

https://reviews.llvm.org/D57984

Files:
  include/clang/AST/Expr.h
  include/clang/AST/Stmt.h
  include/clang/Basic/StmtNodes.td
  include/clang/Parse/Parser.h
  include/clang/Sema/Sema.h
  lib/AST/Stmt.cpp
  lib/CodeGen/CGStmt.cpp
  lib/Parse/ParseObjc.cpp
  lib/Parse/ParseStmt.cpp
  lib/Sema/SemaExpr.cpp
  lib/Sema/SemaStmt.cpp
  lib/Sema/TreeTransform.h
  test/CodeGenCXX/stmtexpr.cpp
  test/CodeGenCXX/volatile.cpp

Index: test/CodeGenCXX/volatile.cpp
===================================================================
--- test/CodeGenCXX/volatile.cpp
+++ test/CodeGenCXX/volatile.cpp
@@ -1,4 +1,4 @@
-// RUN: %clang_cc1 %s -triple=x86_64-apple-darwin10 -emit-llvm -std=c++98 -o - | FileCheck %s
+// RUN: %clang_cc1 %s -triple=x86_64-apple-darwin10 -emit-llvm -std=c++98 -o - | FileCheck -check-prefix=CHECK -check-prefix=CHECK98 %s
 // RUN: %clang_cc1 %s -triple=x86_64-apple-darwin10 -emit-llvm -std=c++11 -o - | FileCheck -check-prefix=CHECK -check-prefix=CHECK11 %s
 
 // Check that IR gen doesn't try to do an lvalue-to-rvalue conversion
@@ -33,3 +33,19 @@
     *x;
   }
 }
+
+namespace PR40642 {
+  template <class T> struct S {
+    // CHECK-LABEL: define {{.*}} @_ZN7PR406421SIiE3fooEv(
+    void foo() {
+      // CHECK98-NOT: load volatile
+      // CHECK11: load volatile
+      if (true)
+        reinterpret_cast<const volatile unsigned char *>(m_ptr)[0];
+      // CHECK: }
+    }
+    int *m_ptr;
+  };
+
+  void f(S<int> *x) { x->foo(); }
+}
Index: test/CodeGenCXX/stmtexpr.cpp
===================================================================
--- test/CodeGenCXX/stmtexpr.cpp
+++ test/CodeGenCXX/stmtexpr.cpp
@@ -190,3 +190,79 @@
 // CHECK: %[[v2:[^ ]*]] = load float, float* %[[tmp2]]
 // CHECK: store float %[[v1]], float* %v.realp
 // CHECK: store float %[[v2]], float* %v.imagp
+
+extern "C" void then(int);
+
+// CHECK-LABEL: @{{.*}}volatile_load
+void volatile_load() {
+  volatile int n;
+
+  // CHECK-NOT: load volatile
+  // CHECK: load volatile
+  // CHECK-NOT: load volatile
+  ({n;});
+
+  // CHECK-LABEL: @then(i32 1)
+  then(1);
+
+  // CHECK-NOT: load volatile
+  // CHECK: load volatile
+  // CHECK-NOT: load volatile
+  ({goto lab; lab: n;});
+
+  // CHECK-LABEL: @then(i32 2)
+  then(2);
+
+  // CHECK-NOT: load volatile
+  // CHECK: load volatile
+  // CHECK-NOT: load volatile
+  ({[[gsl::suppress("foo")]] n;});
+
+  // CHECK-LABEL: @then(i32 3)
+  then(3);
+
+  // CHECK-NOT: load volatile
+  // CHECK: load volatile
+  // CHECK-NOT: load volatile
+  ({if (true) n;});
+
+  // CHECK: }
+}
+
+// CHECK-LABEL: @{{.*}}volatile_load_template
+template<typename T>
+void volatile_load_template() {
+  volatile T n;
+
+  // CHECK-NOT: load volatile
+  // CHECK: load volatile
+  // CHECK-NOT: load volatile
+  ({n;});
+
+  // CHECK-LABEL: @then(i32 1)
+  then(1);
+
+  // CHECK-NOT: load volatile
+  // CHECK: load volatile
+  // CHECK-NOT: load volatile
+  ({goto lab; lab: n;});
+
+  // CHECK-LABEL: @then(i32 2)
+  then(2);
+
+  // CHECK-NOT: load volatile
+  // CHECK: load volatile
+  // CHECK-NOT: load volatile
+  ({[[gsl::suppress("foo")]] n;});
+
+  // CHECK-LABEL: @then(i32 3)
+  then(3);
+
+  // CHECK-NOT: load volatile
+  // CHECK: load volatile
+  // CHECK-NOT: load volatile
+  ({if (true) n;});
+
+  // CHECK: }
+}
+template void volatile_load_template<int>();
Index: lib/Sema/TreeTransform.h
===================================================================
--- lib/Sema/TreeTransform.h
+++ lib/Sema/TreeTransform.h
@@ -318,6 +318,13 @@
   TypeSourceInfo *TransformTypeWithDeducedTST(TypeSourceInfo *DI);
   /// @}
 
+  /// The reason why the value of a statement is not discarded, if any.
+  enum StmtDiscardKind {
+    SDK_Discarded,
+    SDK_NotDiscarded,
+    SDK_StmtExprResult,
+  };
+
   /// Transform the given statement.
   ///
   /// By default, this routine transforms a statement by delegating to the
@@ -327,7 +334,7 @@
   /// other mechanism.
   ///
   /// \returns the transformed statement.
-  StmtResult TransformStmt(Stmt *S, bool DiscardedValue = false);
+  StmtResult TransformStmt(Stmt *S, StmtDiscardKind SDK = SDK_Discarded);
 
   /// Transform the given statement.
   ///
@@ -672,6 +679,9 @@
 #define STMT(Node, Parent)                        \
   LLVM_ATTRIBUTE_NOINLINE \
   StmtResult Transform##Node(Node *S);
+#define VALUESTMT(Node, Parent)                   \
+  LLVM_ATTRIBUTE_NOINLINE \
+  StmtResult Transform##Node(Node *S, StmtDiscardKind SDK);
 #define EXPR(Node, Parent)                        \
   LLVM_ATTRIBUTE_NOINLINE \
   ExprResult Transform##Node(Node *E);
@@ -3270,7 +3280,7 @@
 };
 
 template <typename Derived>
-StmtResult TreeTransform<Derived>::TransformStmt(Stmt *S, bool DiscardedValue) {
+StmtResult TreeTransform<Derived>::TransformStmt(Stmt *S, StmtDiscardKind SDK) {
   if (!S)
     return S;
 
@@ -3278,8 +3288,12 @@
   case Stmt::NoStmtClass: break;
 
   // Transform individual statement nodes
+  // Pass SDK into statements that can produce a value
 #define STMT(Node, Parent)                                              \
   case Stmt::Node##Class: return getDerived().Transform##Node(cast<Node>(S));
+#define VALUESTMT(Node, Parent)                                         \
+  case Stmt::Node##Class:                                               \
+    return getDerived().Transform##Node(cast<Node>(S), SDK);
 #define ABSTRACT_STMT(Node)
 #define EXPR(Node, Parent)
 #include "clang/AST/StmtNodes.inc"
@@ -3291,10 +3305,10 @@
 #include "clang/AST/StmtNodes.inc"
     {
       ExprResult E = getDerived().TransformExpr(cast<Expr>(S));
-      if (E.isInvalid())
-        return StmtError();
 
-      return getSema().ActOnExprStmt(E, DiscardedValue);
+      if (SDK == SDK_StmtExprResult)
+        E = getSema().ActOnStmtExprResult(E);
+      return getSema().ActOnExprStmt(E, SDK == SDK_Discarded);
     }
   }
 
@@ -6522,8 +6536,9 @@
   bool SubStmtChanged = false;
   SmallVector<Stmt*, 8> Statements;
   for (auto *B : S->body()) {
-    StmtResult Result =
-        getDerived().TransformStmt(B, !IsStmtExpr || B != S->body_back());
+    StmtResult Result = getDerived().TransformStmt(
+        B,
+        IsStmtExpr && B == S->body_back() ? SDK_StmtExprResult : SDK_Discarded);
 
     if (Result.isInvalid()) {
       // Immediately fail if this was a DeclStmt, since it's very
@@ -6586,7 +6601,8 @@
     return StmtError();
 
   // Transform the statement following the case
-  StmtResult SubStmt = getDerived().TransformStmt(S->getSubStmt());
+  StmtResult SubStmt =
+      getDerived().TransformStmt(S->getSubStmt());
   if (SubStmt.isInvalid())
     return StmtError();
 
@@ -6594,11 +6610,11 @@
   return getDerived().RebuildCaseStmtBody(Case.get(), SubStmt.get());
 }
 
-template<typename Derived>
-StmtResult
-TreeTransform<Derived>::TransformDefaultStmt(DefaultStmt *S) {
+template <typename Derived>
+StmtResult TreeTransform<Derived>::TransformDefaultStmt(DefaultStmt *S) {
   // Transform the statement following the default case
-  StmtResult SubStmt = getDerived().TransformStmt(S->getSubStmt());
+  StmtResult SubStmt =
+      getDerived().TransformStmt(S->getSubStmt());
   if (SubStmt.isInvalid())
     return StmtError();
 
@@ -6609,8 +6625,8 @@
 
 template<typename Derived>
 StmtResult
-TreeTransform<Derived>::TransformLabelStmt(LabelStmt *S) {
-  StmtResult SubStmt = getDerived().TransformStmt(S->getSubStmt());
+TreeTransform<Derived>::TransformLabelStmt(LabelStmt *S, StmtDiscardKind SDK) {
+  StmtResult SubStmt = getDerived().TransformStmt(S->getSubStmt(), SDK);
   if (SubStmt.isInvalid())
     return StmtError();
 
@@ -6619,6 +6635,11 @@
   if (!LD)
     return StmtError();
 
+  // If we're transforming "in-place" (we're not creating new local
+  // declarations), assume we're replacing the old label statement
+  // and clear out the reference to it.
+  if (LD == S->getDecl())
+    S->getDecl()->setStmt(nullptr);
 
   // FIXME: Pass the real colon location in.
   return getDerived().RebuildLabelStmt(S->getIdentLoc(),
@@ -6644,7 +6665,9 @@
 }
 
 template <typename Derived>
-StmtResult TreeTransform<Derived>::TransformAttributedStmt(AttributedStmt *S) {
+StmtResult
+TreeTransform<Derived>::TransformAttributedStmt(AttributedStmt *S,
+                                                StmtDiscardKind SDK) {
   bool AttrsChanged = false;
   SmallVector<const Attr *, 1> Attrs;
 
@@ -6655,7 +6678,7 @@
     Attrs.push_back(R);
   }
 
-  StmtResult SubStmt = getDerived().TransformStmt(S->getSubStmt());
+  StmtResult SubStmt = getDerived().TransformStmt(S->getSubStmt(), SDK);
   if (SubStmt.isInvalid())
     return StmtError();
 
@@ -7360,7 +7383,8 @@
 TreeTransform<Derived>::TransformObjCForCollectionStmt(
                                                   ObjCForCollectionStmt *S) {
   // Transform the element statement.
-  StmtResult Element = getDerived().TransformStmt(S->getElement());
+  StmtResult Element =
+      getDerived().TransformStmt(S->getElement(), SDK_NotDiscarded);
   if (Element.isInvalid())
     return StmtError();
 
Index: lib/Sema/SemaStmt.cpp
===================================================================
--- lib/Sema/SemaStmt.cpp
+++ lib/Sema/SemaStmt.cpp
@@ -346,10 +346,6 @@
   return getCurFunction()->CompoundScopes.back();
 }
 
-bool Sema::isCurCompoundStmtAStmtExpr() const {
-  return getCurCompoundScope().IsStmtExpr;
-}
-
 StmtResult Sema::ActOnCompoundStmt(SourceLocation L, SourceLocation R,
                                    ArrayRef<Stmt *> Elts, bool isStmtExpr) {
   const unsigned NumElts = Elts.size();
Index: lib/Sema/SemaExpr.cpp
===================================================================
--- lib/Sema/SemaExpr.cpp
+++ lib/Sema/SemaExpr.cpp
@@ -13283,29 +13283,6 @@
                                      Context.getPointerType(Context.VoidTy));
 }
 
-/// Given the last statement in a statement-expression, check whether
-/// the result is a producing expression (like a call to an
-/// ns_returns_retained function) and, if so, rebuild it to hoist the
-/// release out of the full-expression.  Otherwise, return null.
-/// Cannot fail.
-static Expr *maybeRebuildARCConsumingStmt(Stmt *Statement) {
-  // Should always be wrapped with one of these.
-  ExprWithCleanups *cleanups = dyn_cast<ExprWithCleanups>(Statement);
-  if (!cleanups) return nullptr;
-
-  ImplicitCastExpr *cast = dyn_cast<ImplicitCastExpr>(cleanups->getSubExpr());
-  if (!cast || cast->getCastKind() != CK_ARCConsumeObject)
-    return nullptr;
-
-  // Splice out the cast.  This shouldn't modify any interesting
-  // features of the statement.
-  Expr *producer = cast->getSubExpr();
-  assert(producer->getType() == cast->getType());
-  assert(producer->getValueKind() == cast->getValueKind());
-  cleanups->setSubExpr(producer);
-  return cleanups;
-}
-
 void Sema::ActOnStartStmtExpr() {
   PushExpressionEvaluationContext(ExprEvalContexts.back().Context);
 }
@@ -13339,47 +13316,10 @@
   QualType Ty = Context.VoidTy;
   bool StmtExprMayBindToTemp = false;
   if (!Compound->body_empty()) {
-    Stmt *LastStmt = Compound->body_back();
-    LabelStmt *LastLabelStmt = nullptr;
-    // If LastStmt is a label, skip down through into the body.
-    while (LabelStmt *Label = dyn_cast<LabelStmt>(LastStmt)) {
-      LastLabelStmt = Label;
-      LastStmt = Label->getSubStmt();
-    }
-
-    if (Expr *LastE = dyn_cast<Expr>(LastStmt)) {
-      // Do function/array conversion on the last expression, but not
-      // lvalue-to-rvalue.  However, initialize an unqualified type.
-      ExprResult LastExpr = DefaultFunctionArrayConversion(LastE);
-      if (LastExpr.isInvalid())
-        return ExprError();
-      Ty = LastExpr.get()->getType().getUnqualifiedType();
-
-      if (!Ty->isDependentType() && !LastExpr.get()->isTypeDependent()) {
-        // In ARC, if the final expression ends in a consume, splice
-        // the consume out and bind it later.  In the alternate case
-        // (when dealing with a retainable type), the result
-        // initialization will create a produce.  In both cases the
-        // result will be +1, and we'll need to balance that out with
-        // a bind.
-        if (Expr *rebuiltLastStmt
-              = maybeRebuildARCConsumingStmt(LastExpr.get())) {
-          LastExpr = rebuiltLastStmt;
-        } else {
-          LastExpr = PerformCopyInitialization(
-              InitializedEntity::InitializeStmtExprResult(LPLoc, Ty),
-              SourceLocation(), LastExpr);
-        }
-
-        if (LastExpr.isInvalid())
-          return ExprError();
-        if (LastExpr.get() != nullptr) {
-          if (!LastLabelStmt)
-            Compound->setLastStmt(LastExpr.get());
-          else
-            LastLabelStmt->setSubStmt(LastExpr.get());
-          StmtExprMayBindToTemp = true;
-        }
+    if (ValueStmt *LastStmt = dyn_cast<ValueStmt>(Compound->body_back())) {
+      if (Expr *Value = LastStmt->getExprStmt()) {
+        StmtExprMayBindToTemp = true;
+        Ty = Value->getType();
       }
     }
   }
@@ -13392,6 +13332,37 @@
   return ResStmtExpr;
 }
 
+ExprResult Sema::ActOnStmtExprResult(ExprResult ER) {
+  if (ER.isInvalid())
+    return ExprError();
+
+  // Do function/array conversion on the last expression, but not
+  // lvalue-to-rvalue.  However, initialize an unqualified type.
+  ER = DefaultFunctionArrayConversion(ER.get());
+  if (ER.isInvalid())
+    return ExprError();
+  Expr *E = ER.get();
+
+  if (E->isTypeDependent())
+    return E;
+
+  // In ARC, if the final expression ends in a consume, splice
+  // the consume out and bind it later.  In the alternate case
+  // (when dealing with a retainable type), the result
+  // initialization will create a produce.  In both cases the
+  // result will be +1, and we'll need to balance that out with
+  // a bind.
+  ImplicitCastExpr *Cast = dyn_cast<ImplicitCastExpr>(E);
+  if (Cast && Cast->getCastKind() == CK_ARCConsumeObject)
+    return Cast->getSubExpr();
+
+  // FIXME: Provide a better location for the initialization.
+  return PerformCopyInitialization(
+      InitializedEntity::InitializeStmtExprResult(
+          E->getBeginLoc(), E->getType().getUnqualifiedType()),
+      SourceLocation(), E);
+}
+
 ExprResult Sema::BuildBuiltinOffsetOf(SourceLocation BuiltinLoc,
                                       TypeSourceInfo *TInfo,
                                       ArrayRef<OffsetOfComponent> Components,
@@ -14472,14 +14443,6 @@
     // Make sure we redo semantic analysis
     bool AlwaysRebuild() { return true; }
 
-    // Make sure we handle LabelStmts correctly.
-    // FIXME: This does the right thing, but maybe we need a more general
-    // fix to TreeTransform?
-    StmtResult TransformLabelStmt(LabelStmt *S) {
-      S->getDecl()->setStmt(nullptr);
-      return BaseTransform::TransformLabelStmt(S);
-    }
-
     // We need to special-case DeclRefExprs referring to FieldDecls which
     // are not part of a member pointer formation; normal TreeTransforming
     // doesn't catch this case because of the way we represent them in the AST.
Index: lib/Parse/ParseStmt.cpp
===================================================================
--- lib/Parse/ParseStmt.cpp
+++ lib/Parse/ParseStmt.cpp
@@ -29,7 +29,8 @@
 /// Parse a standalone statement (for instance, as the body of an 'if',
 /// 'while', or 'for').
 StmtResult Parser::ParseStatement(SourceLocation *TrailingElseLoc,
-                                  bool AllowOpenMPStandalone) {
+                                  bool AllowOpenMPStandalone,
+                                  WithinStmtExpr IsInStmtExpr) {
   StmtResult Res;
 
   // We may get back a null statement if we found a #pragma. Keep going until
@@ -39,7 +40,7 @@
     Res = ParseStatementOrDeclaration(
         Stmts, AllowOpenMPStandalone ? ACK_StatementsOpenMPAnyExecutable
                                      : ACK_StatementsOpenMPNonStandalone,
-        TrailingElseLoc);
+        TrailingElseLoc, IsInStmtExpr);
   } while (!Res.isInvalid() && !Res.get());
 
   return Res;
@@ -97,7 +98,8 @@
 StmtResult
 Parser::ParseStatementOrDeclaration(StmtVector &Stmts,
                                     AllowedConstructsKind Allowed,
-                                    SourceLocation *TrailingElseLoc) {
+                                    SourceLocation *TrailingElseLoc,
+                                    WithinStmtExpr IsInStmtExpr) {
 
   ParenBraceBracketBalancer BalancerRAIIObj(*this);
 
@@ -107,7 +109,7 @@
     return StmtError();
 
   StmtResult Res = ParseStatementOrDeclarationAfterAttributes(
-      Stmts, Allowed, TrailingElseLoc, Attrs);
+      Stmts, Allowed, TrailingElseLoc, IsInStmtExpr, Attrs);
 
   assert((Attrs.empty() || Res.isInvalid() || Res.isUsable()) &&
          "attributes on empty statement");
@@ -147,10 +149,10 @@
 };
 }
 
-StmtResult
-Parser::ParseStatementOrDeclarationAfterAttributes(StmtVector &Stmts,
-          AllowedConstructsKind Allowed, SourceLocation *TrailingElseLoc,
-          ParsedAttributesWithRange &Attrs) {
+StmtResult Parser::ParseStatementOrDeclarationAfterAttributes(
+    StmtVector &Stmts, AllowedConstructsKind Allowed,
+    SourceLocation *TrailingElseLoc, WithinStmtExpr IsInStmtExpr,
+    ParsedAttributesWithRange &Attrs) {
   const char *SemiError = nullptr;
   StmtResult Res;
 
@@ -165,7 +167,7 @@
     {
       ProhibitAttributes(Attrs); // TODO: is it correct?
       AtLoc = ConsumeToken();  // consume @
-      return ParseObjCAtStatement(AtLoc);
+      return ParseObjCAtStatement(AtLoc, IsInStmtExpr);
     }
 
   case tok::code_completion:
@@ -177,7 +179,7 @@
     Token Next = NextToken();
     if (Next.is(tok::colon)) { // C99 6.8.1: labeled-statement
       // identifier ':' statement
-      return ParseLabeledStatement(Attrs);
+      return ParseLabeledStatement(Attrs, IsInStmtExpr);
     }
 
     // Look up the identifier, and typo-correct it to a keyword if it's not
@@ -220,7 +222,7 @@
       return StmtError();
     }
 
-    return ParseExprStatement();
+    return ParseExprStatement(IsInStmtExpr);
   }
 
   case tok::kw_case:                // C99 6.8.1: labeled-statement
@@ -382,7 +384,8 @@
 
   case tok::annot_pragma_loop_hint:
     ProhibitAttributes(Attrs);
-    return ParsePragmaLoopHint(Stmts, Allowed, TrailingElseLoc, Attrs);
+    return ParsePragmaLoopHint(Stmts, Allowed, TrailingElseLoc, IsInStmtExpr,
+                               Attrs);
 
   case tok::annot_pragma_dump:
     HandlePragmaDump();
@@ -407,7 +410,7 @@
 }
 
 /// Parse an expression statement.
-StmtResult Parser::ParseExprStatement() {
+StmtResult Parser::ParseExprStatement(WithinStmtExpr IsInStmtExpr) {
   // If a case keyword is missing, this is where it should be inserted.
   Token OldToken = Tok;
 
@@ -438,7 +441,7 @@
 
   // Otherwise, eat the semicolon.
   ExpectAndConsumeSemi(diag::err_expected_semi_after_expr);
-  return Actions.ActOnExprStmt(Expr, isExprValueDiscarded());
+  return handleExprStmt(Expr, IsInStmtExpr);
 }
 
 /// ParseSEHTryBlockCommon
@@ -577,7 +580,8 @@
 ///         identifier ':' statement
 /// [GNU]   identifier ':' attributes[opt] statement
 ///
-StmtResult Parser::ParseLabeledStatement(ParsedAttributesWithRange &attrs) {
+StmtResult Parser::ParseLabeledStatement(ParsedAttributesWithRange &attrs,
+                                         WithinStmtExpr IsInStmtExpr) {
   assert(Tok.is(tok::identifier) && Tok.getIdentifierInfo() &&
          "Not an identifier!");
 
@@ -612,7 +616,7 @@
       // GNU attributes are allowed.
       SubStmt = ParseStatementOrDeclarationAfterAttributes(
           Stmts, /*Allowed=*/ACK_StatementsOpenMPNonStandalone, nullptr,
-          TempAttrs);
+          IsInStmtExpr, TempAttrs);
       if (!TempAttrs.empty() && !SubStmt.isInvalid())
         SubStmt = Actions.ProcessStmtAttributes(SubStmt.get(), TempAttrs,
                                                 TempAttrs.Range);
@@ -623,7 +627,7 @@
 
   // If we've not parsed a statement yet, parse one now.
   if (!SubStmt.isInvalid() && !SubStmt.isUsable())
-    SubStmt = ParseStatement();
+    SubStmt = ParseStatement(nullptr, false, IsInStmtExpr);
 
   // Broken substmt shouldn't prevent the label from being added to the AST.
   if (SubStmt.isInvalid())
@@ -957,14 +961,18 @@
   return true;
 }
 
-bool Parser::isExprValueDiscarded() {
-  if (Actions.isCurCompoundStmtAStmtExpr()) {
-    // Look to see if the next two tokens close the statement expression;
+StmtResult Parser::handleExprStmt(ExprResult E, WithinStmtExpr IsInStmtExpr) {
+  bool IsStmtExprResult = false;
+  if (IsInStmtExpr == WithinStmtExpr::InStmtExpr) {
+    // Look ahead to see if the next two tokens close the statement expression;
     // if so, this expression statement is the last statement in a
     // statment expression.
-    return Tok.isNot(tok::r_brace) || NextToken().isNot(tok::r_paren);
+    IsStmtExprResult = Tok.is(tok::r_brace) && NextToken().is(tok::r_paren);
   }
-  return true;
+
+  if (IsStmtExprResult)
+    E = Actions.ActOnStmtExprResult(E);
+  return Actions.ActOnExprStmt(E, /*DiscardedValue=*/!IsStmtExprResult);
 }
 
 /// ParseCompoundStatementBody - Parse a sequence of statements and invoke the
@@ -1022,6 +1030,9 @@
       Stmts.push_back(R.get());
   }
 
+  WithinStmtExpr IsInStmtExpr =
+      isStmtExpr ? WithinStmtExpr::InStmtExpr : WithinStmtExpr::NotInStmtExpr;
+
   while (!tryParseMisplacedModuleImport() && Tok.isNot(tok::r_brace) &&
          Tok.isNot(tok::eof)) {
     if (Tok.is(tok::annot_pragma_unused)) {
@@ -1034,7 +1045,7 @@
 
     StmtResult R;
     if (Tok.isNot(tok::kw___extension__)) {
-      R = ParseStatementOrDeclaration(Stmts, ACK_Any);
+      R = ParseStatementOrDeclaration(Stmts, ACK_Any, nullptr, IsInStmtExpr);
     } else {
       // __extension__ can start declarations and it can also be a unary
       // operator for expressions.  Consume multiple __extension__ markers here
@@ -1067,11 +1078,12 @@
           continue;
         }
 
-        // FIXME: Use attributes?
         // Eat the semicolon at the end of stmt and convert the expr into a
         // statement.
         ExpectAndConsumeSemi(diag::err_expected_semi_after_expr);
-        R = Actions.ActOnExprStmt(Res, isExprValueDiscarded());
+        R = handleExprStmt(Res, IsInStmtExpr);
+        if (R.isUsable())
+          R = Actions.ProcessStmtAttributes(R.get(), attrs, attrs.Range);
       }
     }
 
@@ -2003,6 +2015,7 @@
 StmtResult Parser::ParsePragmaLoopHint(StmtVector &Stmts,
                                        AllowedConstructsKind Allowed,
                                        SourceLocation *TrailingElseLoc,
+                                       WithinStmtExpr IsInStmtExpr,
                                        ParsedAttributesWithRange &Attrs) {
   // Create temporary attribute list.
   ParsedAttributesWithRange TempAttrs(AttrFactory);
@@ -2024,7 +2037,7 @@
   MaybeParseCXX11Attributes(Attrs);
 
   StmtResult S = ParseStatementOrDeclarationAfterAttributes(
-      Stmts, Allowed, TrailingElseLoc, Attrs);
+      Stmts, Allowed, TrailingElseLoc, IsInStmtExpr, Attrs);
 
   Attrs.takeAllFrom(TempAttrs);
   return S;
Index: lib/Parse/ParseObjc.cpp
===================================================================
--- lib/Parse/ParseObjc.cpp
+++ lib/Parse/ParseObjc.cpp
@@ -2703,7 +2703,8 @@
   return MDecl;
 }
 
-StmtResult Parser::ParseObjCAtStatement(SourceLocation AtLoc) {
+StmtResult Parser::ParseObjCAtStatement(SourceLocation AtLoc,
+                                        WithinStmtExpr IsInStmtExpr) {
   if (Tok.is(tok::code_completion)) {
     Actions.CodeCompleteObjCAtStatement(getCurScope());
     cutOffParsing();
@@ -2740,7 +2741,7 @@
 
   // Otherwise, eat the semicolon.
   ExpectAndConsumeSemi(diag::err_expected_semi_after_expr);
-  return Actions.ActOnExprStmt(Res, isExprValueDiscarded());
+  return handleExprStmt(Res, IsInStmtExpr);
 }
 
 ExprResult Parser::ParseObjCAtExpression(SourceLocation AtLoc) {
Index: lib/CodeGen/CGStmt.cpp
===================================================================
--- lib/CodeGen/CGStmt.cpp
+++ lib/CodeGen/CGStmt.cpp
@@ -391,24 +391,38 @@
     // at the end of a statement expression, they yield the value of their
     // subexpression.  Handle this by walking through all labels we encounter,
     // emitting them before we evaluate the subexpr.
+    // Similar issues arise for attributed statements.
     const Stmt *LastStmt = S.body_back();
-    while (const LabelStmt *LS = dyn_cast<LabelStmt>(LastStmt)) {
-      EmitLabel(LS->getDecl());
-      LastStmt = LS->getSubStmt();
+    const Expr *E = nullptr;
+    while (true) {
+      if (isa<Expr>(LastStmt)) {
+        E = cast<Expr>(LastStmt);
+        break;
+      } else if (const LabelStmt *LS = dyn_cast<LabelStmt>(LastStmt)) {
+        EmitLabel(LS->getDecl());
+        LastStmt = LS->getSubStmt();
+      } else if (const AttributedStmt *AS =
+                     dyn_cast<AttributedStmt>(LastStmt)) {
+        // FIXME: Update this if we ever have attributes that affect the
+        // semantics of an expression.
+        LastStmt = AS->getSubStmt();
+      } else {
+        llvm_unreachable("unknown value statement");
+      }
     }
 
     EnsureInsertPoint();
 
-    QualType ExprTy = cast<Expr>(LastStmt)->getType();
+    QualType ExprTy = E->getType();
     if (hasAggregateEvaluationKind(ExprTy)) {
-      EmitAggExpr(cast<Expr>(LastStmt), AggSlot);
+      EmitAggExpr(E, AggSlot);
     } else {
       // We can't return an RValue here because there might be cleanups at
       // the end of the StmtExpr.  Because of that, we have to emit the result
       // here into a temporary alloca.
       RetAlloca = CreateMemTemp(ExprTy);
-      EmitAnyExprToMem(cast<Expr>(LastStmt), RetAlloca, Qualifiers(),
-                       /*IsInit*/false);
+      EmitAnyExprToMem(E, RetAlloca, Qualifiers(),
+                       /*IsInit*/ false);
     }
 
   }
Index: lib/AST/Stmt.cpp
===================================================================
--- lib/AST/Stmt.cpp
+++ lib/AST/Stmt.cpp
@@ -320,6 +320,23 @@
   return New;
 }
 
+const Expr *ValueStmt::getExprStmt() const {
+  const Stmt *S = this;
+  do {
+    if (const auto *E = dyn_cast<Expr>(S))
+      return E;
+
+    if (const auto *LS = dyn_cast<LabelStmt>(S))
+      S = LS->getSubStmt();
+    else if (const auto *AS = dyn_cast<AttributedStmt>(S))
+      S = AS->getSubStmt();
+    else
+      llvm_unreachable("unknown kind of ValueStmt");
+  } while (isa<ValueStmt>(S));
+
+  return nullptr;
+}
+
 const char *LabelStmt::getName() const {
   return getDecl()->getIdentifier()->getNameStart();
 }
Index: include/clang/Sema/Sema.h
===================================================================
--- include/clang/Sema/Sema.h
+++ include/clang/Sema/Sema.h
@@ -1403,7 +1403,6 @@
   void PopCompoundScope();
 
   sema::CompoundScopeInfo &getCurCompoundScope() const;
-  bool isCurCompoundStmtAStmtExpr() const;
 
   bool hasAnyUnrecoverableErrorsInThisFunction() const;
 
@@ -4533,6 +4532,8 @@
   void ActOnStartStmtExpr();
   ExprResult ActOnStmtExpr(SourceLocation LPLoc, Stmt *SubStmt,
                            SourceLocation RPLoc); // "({..})"
+  // Handle the final expression in a statement expression.
+  ExprResult ActOnStmtExprResult(ExprResult E);
   void ActOnStmtExprError();
 
   // __builtin_offsetof(type, identifier(.identifier|[expr])*)
Index: include/clang/Parse/Parser.h
===================================================================
--- include/clang/Parse/Parser.h
+++ include/clang/Parse/Parser.h
@@ -363,10 +363,15 @@
   /// just a regular sub-expression.
   SourceLocation ExprStatementTokLoc;
 
-  /// Tests whether an expression value is discarded based on token lookahead.
-  /// It will return true if the lexer is currently processing the })
-  /// terminating a GNU statement expression and false otherwise.
-  bool isExprValueDiscarded();
+  /// Whether we are parsing a "top-level" statement in a statement expression.
+  /// The sub-statement of a labeled or attributed statement in a statement
+  /// expression is also considered to be "top-level".
+  enum class WithinStmtExpr { NotInStmtExpr, InStmtExpr };
+
+  /// Act on an expression statement that might be the last statement in a
+  /// GNU statement expression. Checks whether we are actually at the end of
+  /// a statement expression and builds a suitable expression statement.
+  StmtResult handleExprStmt(ExprResult E, WithinStmtExpr IsInStmtExpr);
 
 public:
   Parser(Preprocessor &PP, Sema &Actions, bool SkipFunctionBodies);
@@ -1873,8 +1878,10 @@
   /// A SmallVector of types.
   typedef SmallVector<ParsedType, 12> TypeVector;
 
-  StmtResult ParseStatement(SourceLocation *TrailingElseLoc = nullptr,
-                            bool AllowOpenMPStandalone = false);
+  StmtResult
+  ParseStatement(SourceLocation *TrailingElseLoc = nullptr,
+                 bool AllowOpenMPStandalone = false,
+                 WithinStmtExpr IsInStmtExpr = WithinStmtExpr::NotInStmtExpr);
   enum AllowedConstructsKind {
     /// Allow any declarations, statements, OpenMP directives.
     ACK_Any,
@@ -1883,16 +1890,19 @@
     /// Allow statements and all executable OpenMP directives
     ACK_StatementsOpenMPAnyExecutable
   };
-  StmtResult
-  ParseStatementOrDeclaration(StmtVector &Stmts, AllowedConstructsKind Allowed,
-                              SourceLocation *TrailingElseLoc = nullptr);
+  StmtResult ParseStatementOrDeclaration(
+      StmtVector &Stmts, AllowedConstructsKind Allowed,
+      SourceLocation *TrailingElseLoc = nullptr,
+      WithinStmtExpr IsInStmtExpr = WithinStmtExpr::NotInStmtExpr);
   StmtResult ParseStatementOrDeclarationAfterAttributes(
                                          StmtVector &Stmts,
                                          AllowedConstructsKind Allowed,
                                          SourceLocation *TrailingElseLoc,
+                                         WithinStmtExpr IsInStmtExpr,
                                          ParsedAttributesWithRange &Attrs);
-  StmtResult ParseExprStatement();
-  StmtResult ParseLabeledStatement(ParsedAttributesWithRange &attrs);
+  StmtResult ParseExprStatement(WithinStmtExpr IsInStmtExpr);
+  StmtResult ParseLabeledStatement(ParsedAttributesWithRange &attrs,
+                                   WithinStmtExpr IsInStmtExpr);
   StmtResult ParseCaseStatement(bool MissingCase = false,
                                 ExprResult Expr = ExprResult());
   StmtResult ParseDefaultStatement();
@@ -1920,6 +1930,7 @@
   StmtResult ParsePragmaLoopHint(StmtVector &Stmts,
                                  AllowedConstructsKind Allowed,
                                  SourceLocation *TrailingElseLoc,
+                                 WithinStmtExpr IsInStmtExpr,
                                  ParsedAttributesWithRange &Attrs);
 
   /// Describes the behavior that should be taken for an __if_exists
@@ -1984,7 +1995,8 @@
   //===--------------------------------------------------------------------===//
   // Objective-C Statements
 
-  StmtResult ParseObjCAtStatement(SourceLocation atLoc);
+  StmtResult ParseObjCAtStatement(SourceLocation atLoc,
+                                  WithinStmtExpr IsInStmtExpr);
   StmtResult ParseObjCTryStmt(SourceLocation atLoc);
   StmtResult ParseObjCThrowStmt(SourceLocation atLoc);
   StmtResult ParseObjCSynchronizedStmt(SourceLocation atLoc);
Index: include/clang/Basic/StmtNodes.td
===================================================================
--- include/clang/Basic/StmtNodes.td
+++ include/clang/Basic/StmtNodes.td
@@ -11,8 +11,6 @@
 // Statements
 def NullStmt : Stmt;
 def CompoundStmt : Stmt;
-def LabelStmt : Stmt;
-def AttributedStmt : Stmt;
 def IfStmt : Stmt;
 def SwitchStmt : Stmt;
 def WhileStmt : Stmt;
@@ -29,6 +27,12 @@
 def DefaultStmt : DStmt<SwitchCase>;
 def CapturedStmt : Stmt;
 
+// Statements that might produce a value (for example, as the last non-null
+// statement in a GNU statement-expression).
+def ValueStmt : Stmt<1>;
+def LabelStmt : DStmt<ValueStmt>;
+def AttributedStmt : DStmt<ValueStmt>;
+
 // Asm statements
 def AsmStmt : Stmt<1>;
 def GCCAsmStmt : DStmt<AsmStmt>;
@@ -53,7 +57,7 @@
 def CoreturnStmt : Stmt;
 
 // Expressions
-def Expr : Stmt<1>;
+def Expr : DStmt<ValueStmt, 1>;
 def PredefinedExpr : DStmt<Expr>;
 def DeclRefExpr : DStmt<Expr>;
 def IntegerLiteral : DStmt<Expr>;
Index: include/clang/AST/Stmt.h
===================================================================
--- include/clang/AST/Stmt.h
+++ include/clang/AST/Stmt.h
@@ -1584,21 +1584,44 @@
   llvm_unreachable("SwitchCase is neither a CaseStmt nor a DefaultStmt!");
 }
 
+/// Represents a statement that could possibly have a value and type: an
+/// expression-statement (any Expr), a LabelStmt, or an AttributedStmt.
+///
+/// Value statements have a special meaning when one is the last non-null
+/// statement in a GNU statement expression, where it determines the value
+/// of the statement expression.
+class ValueStmt : public Stmt {
+protected:
+  using Stmt::Stmt; // NOTE: ctor access follows Stmt's, not 'protected'.
+
+public:
+  const Expr *getExprStmt() const; // Defined out of line.
+  Expr *getExprStmt() { // Forwards to the const overload.
+    const ValueStmt *ConstThis = this;
+    return const_cast<Expr*>(ConstThis->getExprStmt());
+  }
+  // ValueStmt's subclasses occupy one contiguous range of StmtClass values.
+  static bool classof(const Stmt *T) {
+    return T->getStmtClass() >= firstValueStmtConstant &&
+           T->getStmtClass() <= lastValueStmtConstant;
+  }
+};
+
 /// LabelStmt - Represents a label, which has a substatement.  For example:
 ///    foo: return;
-class LabelStmt : public Stmt {
+class LabelStmt : public ValueStmt {
   LabelDecl *TheDecl;
   Stmt *SubStmt;
 
 public:
   /// Build a label statement.
   LabelStmt(SourceLocation IL, LabelDecl *D, Stmt *substmt)
-      : Stmt(LabelStmtClass), TheDecl(D), SubStmt(substmt) {
+      : ValueStmt(LabelStmtClass), TheDecl(D), SubStmt(substmt) {
     setIdentLoc(IL);
   }
 
   /// Build an empty label statement.
-  explicit LabelStmt(EmptyShell Empty) : Stmt(LabelStmtClass, Empty) {}
+  explicit LabelStmt(EmptyShell Empty) : ValueStmt(LabelStmtClass, Empty) {}
 
   SourceLocation getIdentLoc() const { return LabelStmtBits.IdentLoc; }
   void setIdentLoc(SourceLocation L) { LabelStmtBits.IdentLoc = L; }
@@ -1627,7 +1650,7 @@
 /// Represents an attribute applied to a statement. For example:
 ///   [[omp::for(...)]] for (...) { ... }
 class AttributedStmt final
-    : public Stmt,
+    : public ValueStmt,
       private llvm::TrailingObjects<AttributedStmt, const Attr *> {
   friend class ASTStmtReader;
   friend TrailingObjects;
@@ -1636,14 +1659,14 @@
 
   AttributedStmt(SourceLocation Loc, ArrayRef<const Attr *> Attrs,
                  Stmt *SubStmt)
-      : Stmt(AttributedStmtClass), SubStmt(SubStmt) {
+      : ValueStmt(AttributedStmtClass), SubStmt(SubStmt) {
     AttributedStmtBits.NumAttrs = Attrs.size();
     AttributedStmtBits.AttrLoc = Loc;
     std::copy(Attrs.begin(), Attrs.end(), getAttrArrayPtr());
   }
 
   explicit AttributedStmt(EmptyShell Empty, unsigned NumAttrs)
-      : Stmt(AttributedStmtClass, Empty) {
+      : ValueStmt(AttributedStmtClass, Empty) {
     AttributedStmtBits.NumAttrs = NumAttrs;
     AttributedStmtBits.AttrLoc = SourceLocation{};
     std::fill_n(getAttrArrayPtr(), NumAttrs, nullptr);
Index: include/clang/AST/Expr.h
===================================================================
--- include/clang/AST/Expr.h
+++ include/clang/AST/Expr.h
@@ -105,13 +105,13 @@
 /// This represents one expression.  Note that Expr's are subclasses of Stmt.
 /// This allows an expression to be transparently used any place a Stmt is
 /// required.
-class Expr : public Stmt {
+class Expr : public ValueStmt {
   QualType TR;
 
 protected:
   Expr(StmtClass SC, QualType T, ExprValueKind VK, ExprObjectKind OK,
        bool TD, bool VD, bool ID, bool ContainsUnexpandedParameterPack)
-    : Stmt(SC)
+    : ValueStmt(SC)
   {
     ExprBits.TypeDependent = TD;
     ExprBits.ValueDependent = VD;
@@ -124,7 +124,7 @@
   }
 
   /// Construct an empty expression.
-  explicit Expr(StmtClass SC, EmptyShell) : Stmt(SC) { }
+  explicit Expr(StmtClass SC, EmptyShell) : ValueStmt(SC) { }
 
 public:
   QualType getType() const { return TR; }
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to