================
@@ -3068,12 +3124,240 @@ void MallocChecker::checkDeadSymbols(SymbolReaper &SymReaper,
   C.addTransition(state->set<RegionState>(RS), N);
 }
 
+// Helper function to check if a name is a recognized smart pointer name.
+// Only the owning smart pointers are listed; weak_ptr is intentionally
+// excluded.
+static bool isSmartPtrName(StringRef Name) {
+  return Name == "unique_ptr" || Name == "shared_ptr";
+}
+
+// Allowlist of owning smart pointers we want to recognize (see
+// isSmartPtrName). For concrete types the match is purely by record name, so
+// both the std templates and custom implementations with the same name are
+// accepted.
+static bool isSmartOwningPtrType(QualType QT) {
+  QT = QT->getCanonicalTypeUnqualified();
+
+  // For a concrete (non-dependent) specialization such as
+  // std::unique_ptr<int> the canonical type is a RecordType, so this is the
+  // path taken in practice. The enclosing namespace is deliberately not
+  // checked, to also accept custom smart pointer implementations.
+  if (const auto *RD = QT->getAsCXXRecordDecl())
+    return isSmartPtrName(RD->getName());
+
+  // A canonical type is only still a TemplateSpecializationType when it is
+  // dependent; in that case accept only the std templates.
+  if (const auto *TST = QT->getAs<TemplateSpecializationType>()) {
+    const TemplateDecl *TD = TST->getTemplateName().getAsTemplateDecl();
+    if (!TD)
+      return false;
+
+    const auto *ND = dyn_cast_or_null<NamedDecl>(TD->getTemplatedDecl());
+    return ND && isWithinStdNamespace(ND) && isSmartPtrName(ND->getName());
+  }
+
+  return false;
+}
+
+/// Check if a record type has smart pointer fields (directly or in base
+/// classes). Returns false for a null or incomplete record.
+static bool hasSmartPtrField(const CXXRecordDecl *CRD) {
+  // bases() may only be queried on a class with a definition; an incomplete
+  // record has no fields to inspect either.
+  if (!CRD || !CRD->hasDefinition())
+    return false;
+
+  // Check direct fields
+  if (llvm::any_of(CRD->fields(), [](const FieldDecl *FD) {
+        return isSmartOwningPtrType(FD->getType());
+      }))
+    return true;
+
+  // Check fields from base classes (recursively)
+  return llvm::any_of(CRD->bases(), [](const CXXBaseSpecifier &Base) {
+    return hasSmartPtrField(Base.getType()->getAsCXXRecordDecl());
+  });
+}
+
+/// Check if an expression is an rvalue record type passed by value.
+/// Only prvalues of class type built by one of the expression forms listed
+/// below are accepted.
+static bool isRvalueByValueRecord(const Expr *AE) {
+  if (AE->isGLValue())
+    return false;
+
+  // A record type is never also a reference type, so one check suffices.
+  if (!AE->getType()->isRecordType())
+    return false;
+
+  // Accept common temp/construct forms but don't overfit.
+  // MaterializeTemporaryExpr is always a glvalue (rejected above) and
+  // CXXTemporaryObjectExpr derives from CXXConstructExpr, so neither needs to
+  // be listed separately.
+  return isa<CXXConstructExpr, InitListExpr, ImplicitCastExpr,
+             CXXBindTemporaryExpr>(AE);
+}
+
+/// Check if an expression is an rvalue record with smart pointer fields passed
+/// by value.
+static bool isRvalueByValueRecordWithSmartPtr(const Expr *AE) {
+  if (!isRvalueByValueRecord(AE))
+    return false;
+
+  const auto *CRD = AE->getType()->getAsCXXRecordDecl();
+  return CRD && hasSmartPtrField(CRD);
+}
+
+/// Check if a CXXRecordDecl has a name matching recognized smart pointer names.
+/// Both std and custom smart pointer implementations are accepted; the
+/// namespace is not checked.
+static bool isSmartOwningPtrRecord(const CXXRecordDecl *RD) {
+  return RD && isSmartPtrName(RD->getName());
+}
+
+/// Check if a call is a constructor of a smart pointer class that accepts
+/// pointer parameters.
+static bool isSmartPtrCall(const CallEvent &Call) {
+  // Only smart pointer constructor calls are of interest.
+  const auto *CD = dyn_cast_or_null<CXXConstructorDecl>(Call.getDecl());
+  if (!CD || !isSmartOwningPtrRecord(CD->getParent()))
+    return false;
+
+  // The constructor must take at least one object pointer parameter;
+  // function pointers and void pointers are not considered.
+  return llvm::any_of(CD->parameters(), [](const ParmVarDecl *Param) {
+    QualType ParamType = Param->getType();
+    return ParamType->isPointerType() && !ParamType->isFunctionPointerType() &&
+           !ParamType->isVoidPointerType();
+  });
+}
+
+/// Append to Out the regions of the smart owning pointer fields of the record
+/// of type RecQT located at Base, including fields inherited from base classes
+/// (collected recursively through base-object regions).
+static void collectDirectSmartOwningPtrFieldRegions(
+    const MemRegion *Base, QualType RecQT, CheckerContext &C,
+    SmallVectorImpl<const MemRegion *> &Out) {
+  if (!Base)
+    return;
+  const auto *CRD = RecQT->getAsCXXRecordDecl();
+  // bases() below requires a class with a definition.
+  if (!CRD || !CRD->hasDefinition())
+    return;
+
+  ProgramStateRef State = C.getState();
+
+  // Collect direct fields
+  for (const FieldDecl *FD : CRD->fields()) {
+    if (!isSmartOwningPtrType(FD->getType()))
+      continue;
+    SVal L = State->getLValue(FD, loc::MemRegionVal(Base));
+    if (const MemRegion *FR = L.getAsRegion())
+      Out.push_back(FR);
+  }
+
+  // Base-object regions need a SubRegion as their super region; without one
+  // there is nothing to anchor the inherited fields to.
+  const auto *SuperRegion = Base->getAs<SubRegion>();
+  if (!SuperRegion)
+    return;
+
+  // Collect fields from base classes
+  for (const CXXBaseSpecifier &BaseSpec : CRD->bases()) {
+    const CXXRecordDecl *BaseDecl = BaseSpec.getType()->getAsCXXRecordDecl();
+    if (!BaseDecl)
+      continue;
+    SVal BaseL = State->getLValue(BaseDecl, SuperRegion, BaseSpec.isVirtual());
+    if (const MemRegion *BaseRegion = BaseL.getAsRegion())
+      collectDirectSmartOwningPtrFieldRegions(BaseRegion, BaseSpec.getType(), C,
+                                              Out);
+  }
+}
+
+/// Handle smart pointer constructor calls by escaping allocated symbols
+/// that are passed as pointer arguments to the constructor.
+/// Only parameters of object pointer type are considered; trailing arguments
+/// with no declared parameter (C-variadic constructors) are ignored.
+ProgramStateRef MallocChecker::handleSmartPointerConstructorArguments(
+    const CallEvent &Call, ProgramStateRef State) const {
+  const auto *CD = cast<CXXConstructorDecl>(Call.getDecl());
+  // getParamDecl() must not be indexed past the declared parameters, which
+  // the argument count of a variadic constructor call can exceed.
+  const unsigned NumParams = CD->getNumParams();
+  for (unsigned I = 0, E = Call.getNumArgs(); I < E && I < NumParams; ++I) {
+    if (!Call.getArgExpr(I))
+      continue;
+
+    QualType ParamType = CD->getParamDecl(I)->getType();
+    if (!ParamType->isPointerType() || ParamType->isFunctionPointerType() ||
+        ParamType->isVoidPointerType())
+      continue;
+
+    // This argument is a pointer being passed to the smart pointer
+    // constructor: if it is a tracked, still-allocated symbol, mark it as
+    // escaped.
+    SymbolRef Sym = Call.getArgSVal(I).getAsSymbol();
+    if (!Sym)
+      continue;
+    // get() yields null for untracked symbols, so no separate contains().
+    const RefState *RS = State->get<RegionState>(Sym);
+    if (RS && (RS->isAllocated() || RS->isAllocatedOfSizeZero()))
+      State = State->set<RegionState>(Sym, RefState::getEscaped(RS));
+  }
+  return State;
+}
+
+/// Handle all smart pointer related processing in function calls.
+/// This includes both direct smart pointer constructor calls and by-value
+/// arguments containing smart pointer fields.
+ProgramStateRef MallocChecker::handleSmartPointerRelatedCalls(
+    const CallEvent &Call, CheckerContext &C, ProgramStateRef State) const {
+  // Handle direct smart pointer constructor calls first
+  if (isSmartPtrCall(Call))
+    return handleSmartPointerConstructorArguments(Call, State);
+
+  // Handle smart pointer fields in by-value record arguments
+  SmallVector<const MemRegion *, 8> SmartPtrFieldRoots;
+  for (unsigned I = 0, E = Call.getNumArgs(); I != E; ++I) {
+    const Expr *AE = Call.getArgExpr(I);
+    if (!AE)
+      continue;
+    AE = AE->IgnoreParenImpCasts();
+
+    if (!isRvalueByValueRecordWithSmartPtr(AE))
+      continue;
+
+    // Find a region for the argument.
+    // NOTE(review): if by-value record arguments arrive here as
+    // nonloc::LazyCompoundVal, getAsRegion() presumably returns null and the
+    // argument is always skipped -- confirm against the tests.
+    const MemRegion *ArgRegion = Call.getArgSVal(I).getAsRegion();
+    if (!ArgRegion) {
+      // Skip this argument to prevent overly broad escaping that would
+      // suppress legitimate leak detection
+      continue;
+    }
+
+    // Collect direct smart owning pointer field regions
+    collectDirectSmartOwningPtrFieldRegions(ArgRegion, AE->getType(), C,
+                                            SmartPtrFieldRoots);
+  }
+
+  // Escape symbols reachable from smart pointer fields
+  if (!SmartPtrFieldRoots.empty()) {
+    State = EscapeTrackedCallback::EscapeTrackedRegionsReachableFrom(
+        SmartPtrFieldRoots, State);
+  }
+
+  return State;
+}
+
 void MallocChecker::checkPostCall(const CallEvent &Call,
                                   CheckerContext &C) const {
+  // Handle existing post-call handlers first
   if (const auto *PostFN = PostFnMap.lookup(Call)) {
     (*PostFN)(this, C.getState(), Call, C);
-    return;
+    return; // Post-handler already called addTransition, we're done
   }
+
+  // Handle smart pointer related processing only if no post-handler was called
+  ProgramStateRef State = handleSmartPointerRelatedCalls(Call, C, C.getState());
+  C.addTransition(State);
----------------
steakhal wrote:
```suggestion
  C.addTransition(handleSmartPointerRelatedCalls(Call, C, C.getState()));
```

https://github.com/llvm/llvm-project/pull/152751
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits