ymandel created this revision.
ymandel added a reviewer: gribozavr2.
ymandel requested review of this revision.
Herald added a project: clang.

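This patch adds a `buildAccess` function, which constructs a string with the
proper access operator (`.`, `->`, or nothing for an implicit `this`) based on
the expression's form and type. It also moves two smart-pointer predicates,
`isSmartPointerType` and `isSmartDereference`, from Stencil.cpp into the public
SourceCodeBuilders API: `buildAccess` needs them, and they are also of general
value. Stencil.cpp's member-access stencil now delegates to `buildAccess`.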


Repository:
  rG LLVM Github Monorepo

https://reviews.llvm.org/D116377

Files:
  clang/include/clang/Tooling/Transformer/SourceCodeBuilders.h
  clang/lib/Tooling/Transformer/SourceCodeBuilders.cpp
  clang/lib/Tooling/Transformer/Stencil.cpp
  clang/unittests/Tooling/SourceCodeBuildersTest.cpp

Index: clang/unittests/Tooling/SourceCodeBuildersTest.cpp
===================================================================
--- clang/unittests/Tooling/SourceCodeBuildersTest.cpp
+++ clang/unittests/Tooling/SourceCodeBuildersTest.cpp
@@ -7,6 +7,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "clang/Tooling/Transformer/SourceCodeBuilders.h"
+#include "clang/AST/Type.h"
 #include "clang/ASTMatchers/ASTMatchFinder.h"
 #include "clang/ASTMatchers/ASTMatchers.h"
 #include "clang/Tooling/Tooling.h"
@@ -24,8 +25,16 @@
 
 // Create a valid translation unit from a statement.
 static std::string wrapSnippet(StringRef StatementCode) {
-  return ("struct S { S(); S(int); int field; };\n"
+  return ("namespace std {\n" // Empty, name-only smart-pointer stand-ins.
+          "template <typename T> class unique_ptr {};\n"
+          "template <typename T> class shared_ptr {};\n"
+          "}\n"
+          "struct S { S(); S(int); int field; };\n"
           "S operator+(const S &a, const S &b);\n"
+          "struct Smart {\n" // Duck-typed smart pointer.
+          "  S* operator->() const;\n"
+          "  S& operator*() const;\n"
+          "};\n"
           "auto test_snippet = []{" +
           StatementCode + "};")
       .str();
 }
@@ -126,6 +135,69 @@
   testPredicateOnArg(mayEverNeedParens, "void f(S); f(3 + 5);", true);
 }
 
+TEST(SourceCodeBuildersTest, isSmartPointerTypeUniquePtr) {
+  const std::string Code = "std::unique_ptr<int> P; P;";
+  const auto Match = matchStmt(Code, expr(hasType(qualType().bind("ty"))));
+  ASSERT_TRUE(Match) << "Snippet: " << Code;
+  const QualType &Ty = *Match->Result.Nodes.getNodeAs<QualType>("ty");
+  const bool IsSmart = isSmartPointerType(Ty, *Match->Result.Context);
+  // Recognized by name alone; the fixture's stub has no members.
+  EXPECT_TRUE(IsSmart) << "Snippet: " << Code;
+}
+
+TEST(SourceCodeBuildersTest, isSmartPointerTypeSharedPtr) {
+  const std::string Code = "std::shared_ptr<int> P; P;";
+  const auto Match = matchStmt(Code, expr(hasType(qualType().bind("ty"))));
+  ASSERT_TRUE(Match) << "Snippet: " << Code;
+  const QualType &Ty = *Match->Result.Nodes.getNodeAs<QualType>("ty");
+  const bool IsSmart = isSmartPointerType(Ty, *Match->Result.Context);
+  // Recognized by name alone; the fixture's stub has no members.
+  EXPECT_TRUE(IsSmart) << "Snippet: " << Code;
+}
+
+TEST(SourceCodeBuildersTest, isSmartPointerTypeDuckType) {
+  const std::string Code = "Smart P; P;";
+  const auto Match = matchStmt(Code, expr(hasType(qualType().bind("ty"))));
+  ASSERT_TRUE(Match) << "Snippet: " << Code;
+  const QualType &Ty = *Match->Result.Nodes.getNodeAs<QualType>("ty");
+  const bool IsSmart = isSmartPointerType(Ty, *Match->Result.Context);
+  // `Smart` is matched structurally, via its `operator->` and `operator*`.
+  EXPECT_TRUE(IsSmart) << "Snippet: " << Code;
+}
+
+TEST(SourceCodeBuildersTest, isSmartPointerTypeNormalTypeFalse) {
+  const std::string Code = "int *P; P;";
+  const auto Match = matchStmt(Code, expr(hasType(qualType().bind("ty"))));
+  ASSERT_TRUE(Match) << "Snippet: " << Code;
+  const QualType &Ty = *Match->Result.Nodes.getNodeAs<QualType>("ty");
+  const bool IsSmart = isSmartPointerType(Ty, *Match->Result.Context);
+  // A raw pointer is pointer-like, but it is not a smart pointer.
+  EXPECT_FALSE(IsSmart) << "Snippet: " << Code;
+}
+
+TEST(SourceCodeBuildersTest, isSmartDereferenceTrue) {
+  const std::string Code = "Smart P; *P;";
+  const auto Deref = cxxOperatorCallExpr(hasUnaryOperand(expr().bind("arg")));
+  const auto Match = matchStmt(Code, expr(Deref).bind("expr"));
+  ASSERT_TRUE(Match) << "Snippet: " << Code;
+  const BoundNodes &Nodes = Match->Result.Nodes;
+  // On a smart pointer, `*P` should yield its operand `P`.
+  const Expr *Expected = Nodes.getNodeAs<Expr>("arg");
+  const Expr *Actual = isSmartDereference(*Nodes.getNodeAs<Expr>("expr"),
+                                          *Match->Result.Context);
+  EXPECT_EQ(Expected, Actual) << "Snippet: " << Code;
+}
+
+TEST(SourceCodeBuildersTest, isSmartDereferenceFalse) {
+  const std::string Code = "int *P; *P;";
+  const auto Match = matchStmt(Code, expr().bind("expr"));
+  ASSERT_TRUE(Match) << "Snippet: " << Code;
+  const Expr &Deref = *Match->Result.Nodes.getNodeAs<Expr>("expr");
+  // Built-in `*` on a raw pointer is not an overloaded smart dereference.
+  EXPECT_EQ(nullptr, isSmartDereference(Deref, *Match->Result.Context))
+      << "Snippet: " << Code;
+}
+
 static void testBuilder(
     llvm::Optional<std::string> (*Builder)(const Expr &, const ASTContext &),
     StringRef Snippet, StringRef Expected) {
@@ -136,6 +208,16 @@
               ValueIs(std::string(Expected)));
 }
 
+static void testBuilder(llvm::Optional<std::string> (*Builder)(const Expr &,
+                                                               ASTContext &),
+                        StringRef Snippet, StringRef Expected) {
+  const auto Match = matchStmt(Snippet, expr().bind("expr"));
+  ASSERT_TRUE(Match);
+  const Expr &E = *Match->Result.Nodes.getNodeAs<Expr>("expr");
+  const llvm::Optional<std::string> Built = Builder(E, *Match->Result.Context);
+  EXPECT_THAT(Built, ValueIs(std::string(Expected)));
+}
+
 TEST(SourceCodeBuildersTest, BuildParensUnaryOp) {
   testBuilder(buildParens, "-4;", "(-4)");
 }
@@ -245,4 +327,88 @@
 TEST(SourceCodeBuildersTest, BuildArrowValueAddressWithParens) {
   testBuilder(buildArrow, "S x; &(true ? x : x);", "(true ? x : x).");
 }
+
+// `buildAccess` on values, raw pointers and the built-in `*` and `&`.
+TEST(SourceCodeBuildersTest, BuildAccessValue) {
+  testBuilder(buildAccess, "S x; x;", "x.");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessPointerDereference) {
+  testBuilder(buildAccess, "S *x; *x;", "x->");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessPointerDereferenceIgnoresParens) {
+  testBuilder(buildAccess, "S *x; *(x);", "x->");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessValueBinaryOperation) {
+  testBuilder(buildAccess, "S x; x + x;", "(x + x).");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessPointerDereferenceExprWithParens) {
+  testBuilder(buildAccess, "S *x; *(x + 1);", "(x + 1)->");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessPointer) {
+  testBuilder(buildAccess, "S *x; x;", "x->");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessValueAddress) {
+  testBuilder(buildAccess, "S x; &x;", "x.");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessValueAddressIgnoresParens) {
+  testBuilder(buildAccess, "S x; &(x);", "x.");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessPointerBinaryOperation) {
+  testBuilder(buildAccess, "S *x; x + 1;", "(x + 1)->");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessValueAddressWithParens) {
+  testBuilder(buildAccess, "S x; &(true ? x : x);", "(true ? x : x).");
+}
+
+// `buildAccess` on `Smart`, which overloads `operator->` and `operator*`.
+TEST(SourceCodeBuildersTest, BuildAccessSmartPointer) {
+  testBuilder(buildAccess, "Smart x; x;", "x->");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessSmartPointerDeref) {
+  testBuilder(buildAccess, "Smart x; *x;", "x->");
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessSmartPointerMemberCall) {
+  // The object expression of `x->field` is the overloaded `operator->` call,
+  // which `buildAccess` strips back to `x` before appending `->`.
+  StringRef Snippet = R"cc(
+    Smart x;
+    x->field;
+  )cc";
+  auto StmtMatch =
+      matchStmt(Snippet, memberExpr(hasObjectExpression(expr().bind("expr"))));
+  ASSERT_TRUE(StmtMatch);
+  EXPECT_THAT(buildAccess(*StmtMatch->Result.Nodes.getNodeAs<Expr>("expr"),
+                          *StmtMatch->Result.Context),
+              ValueIs(std::string("x->")));
+}
+
+TEST(SourceCodeBuildersTest, BuildAccessImplicitThis) {
+  // Nothing is emitted for an implicit `this`: an empty string, not `None`.
+  StringRef Snippet = R"cc(
+    struct Struct {
+      void foo() {}
+      void bar() {
+        foo();
+      }
+    };
+  )cc";
+  auto StmtMatch = matchStmt(
+      Snippet,
+      cxxMemberCallExpr(onImplicitObjectArgument(cxxThisExpr().bind("expr"))));
+  ASSERT_TRUE(StmtMatch);
+  EXPECT_THAT(buildAccess(*StmtMatch->Result.Nodes.getNodeAs<Expr>("expr"),
+                          *StmtMatch->Result.Context),
+              ValueIs(std::string()));
+}
 } // namespace
Index: clang/lib/Tooling/Transformer/Stencil.cpp
===================================================================
--- clang/lib/Tooling/Transformer/Stencil.cpp
+++ clang/lib/Tooling/Transformer/Stencil.cpp
@@ -11,7 +11,6 @@
 #include "clang/AST/ASTTypeTraits.h"
 #include "clang/AST/Expr.h"
 #include "clang/ASTMatchers/ASTMatchFinder.h"
-#include "clang/ASTMatchers/ASTMatchers.h"
 #include "clang/Basic/SourceLocation.h"
 #include "clang/Lex/Lexer.h"
 #include "clang/Tooling/Transformer/SourceCode.h"
@@ -56,39 +55,6 @@
   return Error::success();
 }
 
-// FIXME: Consider memoizing this function using the `ASTContext`.
-static bool isSmartPointerType(QualType Ty, ASTContext &Context) {
-  using namespace ::clang::ast_matchers;
-
-  // Optimization: hard-code common smart-pointer types. This can/should be
-  // removed if we start caching the results of this function.
-  auto KnownSmartPointer =
-      cxxRecordDecl(hasAnyName("::std::unique_ptr", "::std::shared_ptr"));
-  const auto QuacksLikeASmartPointer = cxxRecordDecl(
-      hasMethod(cxxMethodDecl(hasOverloadedOperatorName("->"),
-                              returns(qualType(pointsTo(type()))))),
-      hasMethod(cxxMethodDecl(hasOverloadedOperatorName("*"),
-                              returns(qualType(references(type()))))));
-  const auto SmartPointer = qualType(hasDeclaration(
-      cxxRecordDecl(anyOf(KnownSmartPointer, QuacksLikeASmartPointer))));
-  return match(SmartPointer, Ty, Context).size() > 0;
-}
-
-// Identifies use of `operator*` on smart pointers, and returns the underlying
-// smart-pointer expression; otherwise, returns null.
-static const Expr *isSmartDereference(const Expr &E, ASTContext &Context) {
-  using namespace ::clang::ast_matchers;
-
-  const auto HasOverloadedArrow = cxxRecordDecl(hasMethod(cxxMethodDecl(
-      hasOverloadedOperatorName("->"), returns(qualType(pointsTo(type()))))));
-  // Verify it is a smart pointer by finding `operator->` in the class
-  // declaration.
-  auto Deref = cxxOperatorCallExpr(
-      hasOverloadedOperatorName("*"), hasUnaryOperand(expr().bind("arg")),
-      callee(cxxMethodDecl(ofClass(HasOverloadedArrow))));
-  return selectFirst<Expr>("arg", match(Deref, E, Context));
-}
-
 namespace {
 // An arbitrary fragment of code within a stencil.
 class RawTextStencil : public StencilInterface {
@@ -196,7 +162,7 @@
       break;
     case UnaryNodeOperator::MaybeDeref:
       if (E->getType()->isAnyPointerType() ||
-          isSmartPointerType(E->getType(), *Match.Context)) {
+          tooling::isSmartPointerType(E->getType(), *Match.Context)) {
         // Strip off any operator->. This can only occur inside an actual arrow
         // member access, so we treat it as equivalent to an actual object
         // expression.
@@ -216,7 +182,7 @@
       break;
     case UnaryNodeOperator::MaybeAddressOf:
       if (E->getType()->isAnyPointerType() ||
-          isSmartPointerType(E->getType(), *Match.Context)) {
+          tooling::isSmartPointerType(E->getType(), *Match.Context)) {
         // Strip off any operator->. This can only occur inside an actual arrow
         // member access, so we treat it as equivalent to an actual object
         // expression.
@@ -311,34 +277,12 @@
     if (E == nullptr)
       return llvm::make_error<StringError>(errc::invalid_argument,
                                            "Id not bound: " + BaseId);
-    if (!E->isImplicitCXXThis()) {
-      llvm::Optional<std::string> S;
-      if (E->getType()->isAnyPointerType() ||
-          isSmartPointerType(E->getType(), *Match.Context)) {
-        // Strip off any operator->. This can only occur inside an actual arrow
-        // member access, so we treat it as equivalent to an actual object
-        // expression.
-        if (const auto *OpCall = dyn_cast<clang::CXXOperatorCallExpr>(E)) {
-          if (OpCall->getOperator() == clang::OO_Arrow &&
-              OpCall->getNumArgs() == 1) {
-            E = OpCall->getArg(0);
-          }
-        }
-        S = tooling::buildArrow(*E, *Match.Context);
-      } else if (const auto *Operand = isSmartDereference(*E, *Match.Context)) {
-        // `buildDot` already handles the built-in dereference operator, so we
-        // only need to catch overloaded `operator*`.
-        S = tooling::buildArrow(*Operand, *Match.Context);
-      } else {
-        S = tooling::buildDot(*E, *Match.Context);
-      }
-      if (S.hasValue())
-        *Result += *S;
-      else
-        return llvm::make_error<StringError>(
-            errc::invalid_argument,
-            "Could not construct object text from ID: " + BaseId);
-    }
+    llvm::Optional<std::string> S = tooling::buildAccess(*E, *Match.Context);
+    if (!S.hasValue())
+      return llvm::make_error<StringError>(
+          errc::invalid_argument,
+          "Could not construct object text from ID: " + BaseId);
+    *Result += *S;
     return Member->eval(Match, Result);
   }
 };
Index: clang/lib/Tooling/Transformer/SourceCodeBuilders.cpp
===================================================================
--- clang/lib/Tooling/Transformer/SourceCodeBuilders.cpp
+++ clang/lib/Tooling/Transformer/SourceCodeBuilders.cpp
@@ -10,6 +10,8 @@
 #include "clang/AST/ASTContext.h"
 #include "clang/AST/Expr.h"
 #include "clang/AST/ExprCXX.h"
+#include "clang/ASTMatchers/ASTMatchFinder.h"
+#include "clang/ASTMatchers/ASTMatchers.h"
 #include "clang/Tooling/Transformer/SourceCode.h"
 #include "llvm/ADT/Twine.h"
 #include <string>
@@ -60,6 +62,37 @@
   return false;
 }
 
+// FIXME: Consider memoizing this function using the `ASTContext`.
+bool tooling::isSmartPointerType(QualType Ty, ASTContext &Context) {
+  using namespace ::clang::ast_matchers;
+  // Optimization: hard-code common smart-pointer types. This can/should be
+  // removed if we start caching the results of this function.
+  const auto KnownSmartPointer =
+      cxxRecordDecl(hasAnyName("::std::unique_ptr", "::std::shared_ptr"));
+  // Otherwise, duck-type: `->` must return a pointer, `*` a reference.
+  const auto QuacksLikeASmartPointer = cxxRecordDecl(
+      hasMethod(cxxMethodDecl(hasOverloadedOperatorName("->"),
+                              returns(qualType(pointsTo(type()))))),
+      hasMethod(cxxMethodDecl(hasOverloadedOperatorName("*"),
+                              returns(qualType(references(type()))))));
+  const auto SmartPointer = qualType(hasDeclaration(
+      cxxRecordDecl(anyOf(KnownSmartPointer, QuacksLikeASmartPointer))));
+  return !match(SmartPointer, Ty, Context).empty();
+}
+
+const Expr *tooling::isSmartDereference(const Expr &E, ASTContext &Context) {
+  using namespace ::clang::ast_matchers;
+  // Only classes declaring an `operator->` that returns a pointer count as
+  // smart pointers here.
+  const auto ArrowOperator = cxxMethodDecl(hasOverloadedOperatorName("->"),
+                                           returns(qualType(pointsTo(type()))));
+  // A unary `operator*` call on such a class; its operand is bound as "arg".
+  const auto SmartDeref = cxxOperatorCallExpr(
+      hasOverloadedOperatorName("*"), hasUnaryOperand(expr().bind("arg")),
+      callee(cxxMethodDecl(ofClass(cxxRecordDecl(hasMethod(ArrowOperator))))));
+  return selectFirst<Expr>("arg", match(SmartDeref, E, Context));
+}
+
 llvm::Optional<std::string> tooling::buildParens(const Expr &E,
                                                  const ASTContext &Context) {
   StringRef Text = getText(E, Context);
@@ -160,3 +193,33 @@
     return ("(" + Text + ")->").str();
   return (Text + "->").str();
 }
+
+llvm::Optional<std::string> tooling::buildAccess(const Expr &E,
+                                                 ASTContext &Context) {
+  // An implicit `this` is not spelled in the source, so no access operator is
+  // emitted. Return the empty string rather than `None`, because `None`
+  // signals failure.
+  if (E.isImplicitCXXThis())
+    return std::string();
+
+  // Pointer-like objects (raw or smart pointers) are accessed through `->`.
+  const QualType Ty = E.getType();
+  const bool IsPointerLike =
+      Ty->isAnyPointerType() || isSmartPointerType(Ty, Context);
+  if (!IsPointerLike) {
+    // `buildDot` already simplifies the built-in `*p`; only an overloaded
+    // `operator*` on a smart pointer needs special handling here.
+    if (const Expr *SmartOperand = isSmartDereference(E, Context))
+      return buildArrow(*SmartOperand, Context);
+    return buildDot(E, Context);
+  }
+
+  // An overloaded `operator->` call can only occur as the object of an arrow
+  // member access, so treat it as the object expression it was applied to.
+  const Expr *Base = &E;
+  if (const auto *Arrow = dyn_cast<CXXOperatorCallExpr>(&E)) {
+    if (Arrow->getOperator() == OO_Arrow && Arrow->getNumArgs() == 1)
+      Base = Arrow->getArg(0);
+  }
+  return buildArrow(*Base, Context);
+}
Index: clang/include/clang/Tooling/Transformer/SourceCodeBuilders.h
===================================================================
--- clang/include/clang/Tooling/Transformer/SourceCodeBuilders.h
+++ clang/include/clang/Tooling/Transformer/SourceCodeBuilders.h
@@ -43,6 +43,14 @@
 /// Determines whether printing this expression to the right of a unary operator
 /// requires a parentheses to preserve its meaning.
 bool needParensAfterUnaryOperator(const Expr &E);
+
+/// Heuristic that guesses whether `Ty` is a "smart-pointer" type based on its
+/// name (`std::unique_ptr`, `std::shared_ptr`) or overloaded operators.
+bool isSmartPointerType(QualType Ty, ASTContext &Context);
+
+/// If `E` is an overloaded `operator*` call on a smart pointer, returns the
+/// underlying smart-pointer expression (the operand); otherwise, returns null.
+const Expr *isSmartDereference(const Expr &E, ASTContext &Context);
 /// @}
 
 /// \name Basic code-string generation utilities.
@@ -79,6 +87,18 @@
 ///  `a+b` becomes `(a+b)->`
 llvm::Optional<std::string> buildArrow(const Expr &E,
                                        const ASTContext &Context);
+
+/// Adds an appropriate access operator (`.`, `->` or nothing, in the case of
+/// implicit `this`) to the end of the given expression, but adds parentheses
+/// when needed by the syntax, strips any `operator->` call and simplifies when
+/// possible. For example:
+///
+///  `x` becomes `x->` or `x.`, depending on `E`'s type
+///  the overloaded `operator->` call in `x->m` becomes `x->`
+///  `a+b` becomes `(a+b)->` or `(a+b).`, depending on `E`'s type
+///  `&a` becomes `a.`
+///  `*a` becomes `a->`
+llvm::Optional<std::string> buildAccess(const Expr &E, ASTContext &Context);
 /// @}
 
 } // namespace tooling
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to