llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang Author: Ryosuke Niwa (rniwa) <details> <summary>Changes</summary> This PR fixes a bug where alpha.webkit.UncountedLocalVarsChecker erroneously treats a trivial recursive function as non-trivial. This was caused by the overload of TrivialFunctionAnalysis::isTrivialImpl that takes a statement as an argument: while traversing a statement within a recursive function to determine its triviality, TrivialFunctionAnalysisVisitor's WithCachedResult populated the cache with "false". Because IsFunctionTrivial honors an entry in the cache, this resulted in the whole function being treated as non-trivial. Thankfully, TrivialFunctionAnalysisVisitor::IsFunctionTrivial already handles recursive functions correctly, so this PR applies the same logic to TrivialFunctionAnalysisVisitor::WithCachedResult by sharing code between the two functions. This prevents the cache from being pre-populated with "false" while traversing statements in a recursive function. --- Full diff: https://github.com/llvm/llvm-project/pull/110973.diff 3 Files Affected: - (modified) clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp (+22-36) - (modified) clang/test/Analysis/Checkers/WebKit/uncounted-local-vars.cpp (+29) - (modified) clang/test/Analysis/Checkers/WebKit/uncounted-obj-arg.cpp (+12) ``````````diff diff --git a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp index 4d145be808f6d8..9f192707e83f89 100644 --- a/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp +++ b/clang/lib/StaticAnalyzer/Checkers/WebKit/PtrTypesSemantics.cpp @@ -281,16 +281,29 @@ class TrivialFunctionAnalysisVisitor return true; } - template <typename CheckFunction> - bool WithCachedResult(const Stmt *S, CheckFunction Function) { - // If the statement isn't in the cache, conservatively assume that - // it's not trivial until analysis completes. 
Insert false to the cache - // first to avoid infinite recursion. - auto [It, IsNew] = Cache.insert(std::make_pair(S, false)); + template <typename StmtOrDecl, typename CheckFunction> + bool WithCachedResult(const StmtOrDecl *S, CheckFunction Function) { + auto CacheIt = Cache.find(S); + if (CacheIt != Cache.end()) + return CacheIt->second; + + // Treat a recursive statement to be trivial until proven otherwise. + auto [RecursiveIt, IsNew] = RecursiveFn.insert(std::make_pair(S, true)); if (!IsNew) - return It->second; + return RecursiveIt->second; + bool Result = Function(); + + if (!Result) { + for (auto &It : RecursiveFn) + It.second = false; + } + RecursiveIt = RecursiveFn.find(S); + assert(RecursiveIt != RecursiveFn.end()); + Result = RecursiveIt->second; + RecursiveFn.erase(RecursiveIt); Cache[S] = Result; + return Result; } @@ -300,16 +313,7 @@ class TrivialFunctionAnalysisVisitor TrivialFunctionAnalysisVisitor(CacheTy &Cache) : Cache(Cache) {} bool IsFunctionTrivial(const Decl *D) { - auto CacheIt = Cache.find(D); - if (CacheIt != Cache.end()) - return CacheIt->second; - - // Treat a recursive function call to be trivial until proven otherwise. - auto [RecursiveIt, IsNew] = RecursiveFn.insert(std::make_pair(D, true)); - if (!IsNew) - return RecursiveIt->second; - - bool Result = [&]() { + return WithCachedResult(D, [&]() { if (auto *CtorDecl = dyn_cast<CXXConstructorDecl>(D)) { for (auto *CtorInit : CtorDecl->inits()) { if (!Visit(CtorInit->getInit())) @@ -320,20 +324,7 @@ class TrivialFunctionAnalysisVisitor if (!Body) return false; return Visit(Body); - }(); - - if (!Result) { - // D and its mutually recursive callers are all non-trivial. 
- for (auto &It : RecursiveFn) - It.second = false; - } - RecursiveIt = RecursiveFn.find(D); - assert(RecursiveIt != RecursiveFn.end()); - Result = RecursiveIt->second; - RecursiveFn.erase(RecursiveIt); - Cache[D] = Result; - - return Result; + }); } bool VisitStmt(const Stmt *S) { @@ -590,11 +581,6 @@ bool TrivialFunctionAnalysis::isTrivialImpl( bool TrivialFunctionAnalysis::isTrivialImpl( const Stmt *S, TrivialFunctionAnalysis::CacheTy &Cache) { - // If the statement isn't in the cache, conservatively assume that - // it's not trivial until analysis completes. Unlike a function case, - // we don't insert an entry into the cache until Visit returns - // since Visit* functions themselves make use of the cache. - TrivialFunctionAnalysisVisitor V(Cache); bool Result = V.Visit(S); assert(Cache.contains(S) && "Top-level statement not properly cached!"); diff --git a/clang/test/Analysis/Checkers/WebKit/uncounted-local-vars.cpp b/clang/test/Analysis/Checkers/WebKit/uncounted-local-vars.cpp index 25776870dd3ae0..beb8f69512df63 100644 --- a/clang/test/Analysis/Checkers/WebKit/uncounted-local-vars.cpp +++ b/clang/test/Analysis/Checkers/WebKit/uncounted-local-vars.cpp @@ -289,3 +289,32 @@ void foo() { } } // namespace local_assignment_to_global + +namespace local_var_in_recursive_function { + +struct TreeNode { + Ref<TreeNode> create() { return Ref(*new TreeNode); } + + void ref() const { ++refCount; } + void deref() const { + if (!--refCount) + delete this; + } + + int recursiveCost(); + + int cost { 0 }; + mutable unsigned refCount { 0 }; + TreeNode* nextSibling { nullptr }; + TreeNode* firstChild { nullptr }; +}; + +int TreeNode::recursiveCost() { + // no warnings + unsigned totalCost = cost; + for (TreeNode* node = firstChild; node; node = node->nextSibling) + totalCost += recursiveCost(); + return totalCost; +} + +} // namespace local_var_in_recursive_function diff --git a/clang/test/Analysis/Checkers/WebKit/uncounted-obj-arg.cpp 
b/clang/test/Analysis/Checkers/WebKit/uncounted-obj-arg.cpp index 97efb354f0371d..9205254edb6b48 100644 --- a/clang/test/Analysis/Checkers/WebKit/uncounted-obj-arg.cpp +++ b/clang/test/Analysis/Checkers/WebKit/uncounted-obj-arg.cpp @@ -245,6 +245,15 @@ class RefCounted { void mutuallyRecursive8() { mutuallyRecursive9(); someFunction(); } void mutuallyRecursive9() { mutuallyRecursive8(); } + int recursiveCost() { + unsigned totalCost = 0; + for (unsigned i = 0; i < sizeof(children)/sizeof(*children); ++i) { + if (auto* child = children[i]) + totalCost += child->recursiveCost(); + } + return totalCost; + } + int trivial1() { return 123; } float trivial2() { return 0.3; } float trivial3() { return (float)0.4; } @@ -431,6 +440,7 @@ class RefCounted { Number* number { nullptr }; ComplexNumber complex; Enum enumValue { Enum::Value1 }; + RefCounted* children[4]; }; unsigned RefCounted::s_v = 0; @@ -539,6 +549,8 @@ class UnrelatedClass { getFieldTrivial().mutuallyRecursive9(); // expected-warning@-1{{Call argument for 'this' parameter is uncounted and unsafe}} + getFieldTrivial().recursiveCost(); // no-warning + getFieldTrivial().someFunction(); // expected-warning@-1{{Call argument for 'this' parameter is uncounted and unsafe}} getFieldTrivial().nonTrivial1(); `````````` </details> https://github.com/llvm/llvm-project/pull/110973 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits