danix800 updated this revision to Diff 553514.
danix800 edited the summary of this revision.
danix800 added a comment.

1. Move the request offset/count computation into a separate helper function;
2. Only handle a request count that is a `ConcreteInt` (bail out otherwise).


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D158813/new/

https://reviews.llvm.org/D158813

Files:
  clang/lib/StaticAnalyzer/Checkers/MPI-Checker/MPIChecker.cpp
  clang/test/Analysis/mpichecker.cpp

Index: clang/test/Analysis/mpichecker.cpp
===================================================================
--- clang/test/Analysis/mpichecker.cpp
+++ clang/test/Analysis/mpichecker.cpp
@@ -272,6 +272,55 @@
   MPI_Wait(&rs.req2, MPI_STATUS_IGNORE);
 } // no error
 
+void nestedRequestWithCount() { // Waitall on sub-ranges of a request array.
+  typedef struct {
+    MPI_Request req[3];
+    MPI_Request req2;
+  } ReqStruct;
+
+  ReqStruct rs;
+  int rank = 0;
+  double buf = 0;
+  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+
+  MPI_Ireduce(MPI_IN_PLACE, &buf, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD,
+              &rs.req[0]);
+  MPI_Ireduce(MPI_IN_PLACE, &buf, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD,
+              &rs.req[1]);
+  MPI_Ireduce(MPI_IN_PLACE, &buf, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD,
+              &rs.req[2]);
+  MPI_Ireduce(MPI_IN_PLACE, &buf, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD,
+              &rs.req2);
+  MPI_Waitall(2, rs.req, MPI_STATUSES_IGNORE);     // rs.req[0], rs.req[1]
+  MPI_Waitall(1, rs.req + 2, MPI_STATUSES_IGNORE); // rs.req[2]
+  MPI_Wait(&rs.req2, MPI_STATUS_IGNORE);
+} // no error: every request has exactly one matching wait
+
+void nestedRequestWithCountMissingNonBlockingWait() {
+  typedef struct {
+    MPI_Request req[3];
+    MPI_Request req2;
+  } ReqStruct;
+
+  ReqStruct rs;
+  int rank = 0;
+  double buf = 0;
+  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+
+  MPI_Ireduce(MPI_IN_PLACE, &buf, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD,
+              &rs.req[0]);
+  MPI_Ireduce(MPI_IN_PLACE, &buf, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD,
+              &rs.req[1]);
+  MPI_Ireduce(MPI_IN_PLACE, &buf, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD,
+              &rs.req[2]);
+  MPI_Ireduce(MPI_IN_PLACE, &buf, 1, MPI_DOUBLE, MPI_SUM, 0, MPI_COMM_WORLD,
+              &rs.req2);
+  MPI_Waitall(1, rs.req, MPI_STATUSES_IGNORE);     // rs.req[0]
+  // Intentionally omitted: MPI_Waitall(1, rs.req + 1, MPI_STATUSES_IGNORE);
+  MPI_Waitall(1, rs.req + 2, MPI_STATUSES_IGNORE); // rs.req[2]
+  MPI_Wait(&rs.req2, MPI_STATUS_IGNORE);
+} // expected-warning{{Request 'rs.req[1]' has no matching wait.}}
+
 void singleRequestInWaitall() {
   MPI_Request r;
   int rank = 0;
Index: clang/lib/StaticAnalyzer/Checkers/MPI-Checker/MPIChecker.cpp
===================================================================
--- clang/lib/StaticAnalyzer/Checkers/MPI-Checker/MPIChecker.cpp
+++ clang/lib/StaticAnalyzer/Checkers/MPI-Checker/MPIChecker.cpp
@@ -143,6 +143,38 @@
   }
 }
 
+/// Return the index of \p MR within its enclosing request array together with
+/// the element count getDynamicElementCountWithOffset() gives for argument 1.
+static std::optional<std::pair<NonLoc, llvm::APSInt>>
+getRequestRegionOffsetAndCount(const MemRegion *const MR, const CallEvent &CE) {
+  const auto *ER = MR->getAs<ElementRegion>();
+  if (!ER || CE.getNumArgs() < 2)
+    return std::nullopt;
+  ProgramStateRef State = CE.getState();
+  SValBuilder &SVB = State->getStateManager().getSValBuilder();
+  ASTContext &ASTCtx = SVB.getContext();
+  QualType RequestTy = CE.getArgExpr(1)->getType()->getPointeeType();
+  auto RequestRegionCount =
+      getDynamicElementCountWithOffset(State, CE.getArgSVal(1), RequestTy)
+          .getAs<nonloc::ConcreteInt>();
+  // getAsOffset() is relative to the top-level base region (e.g. the whole
+  // struct), so subtract the array's own offset to get an in-array offset.
+  RegionOffset ElemOffset = ER->getAsOffset();
+  RegionOffset ArrayOffset = ER->getSuperRegion()->getAsOffset();
+  if (!RequestRegionCount || ElemOffset.hasSymbolicOffset() ||
+      ArrayOffset.hasSymbolicOffset())
+    return std::nullopt;
+  // Guard the division below against a zero-sized MPI_Request type.
+  CharUnits TypeSize = ASTCtx.getTypeSizeInChars(RequestTy);
+  int64_t TypeSizeInBits = (TypeSize.isZero() ? 1 : TypeSize.getQuantity()) *
+                           ASTCtx.getCharWidth();
+  int64_t OffsetInBits = ElemOffset.getOffset() - ArrayOffset.getOffset();
+  if (OffsetInBits % TypeSizeInBits != 0) // Not on an MPI_Request boundary.
+    return std::nullopt;
+  return std::make_pair(SVB.makeArrayIndex(OffsetInBits / TypeSizeInBits),
+                        RequestRegionCount->getValue());
+}
+
 void MPIChecker::allRegionsUsedByWait(
     llvm::SmallVector<const MemRegion *, 2> &ReqRegions,
     const MemRegion *const MR, const CallEvent &CE, CheckerContext &Ctx) const {
@@ -161,20 +193,34 @@
       return;
     }
 
-    DefinedOrUnknownSVal ElementCount = getDynamicElementCount(
-        Ctx.getState(), SuperRegion, Ctx.getSValBuilder(),
-        CE.getArgExpr(1)->getType()->getPointeeType());
-    const llvm::APSInt &ArrSize =
-        ElementCount.castAs<nonloc::ConcreteInt>().getValue();
+    auto RequestRegionOffsetAndCount = getRequestRegionOffsetAndCount(MR, CE);
+    if (!RequestRegionOffsetAndCount)
+      return;
+
+    auto [RegionOffset, RegionCount] = *RequestRegionOffsetAndCount;
 
-    for (size_t i = 0; i < ArrSize; ++i) {
-      const NonLoc Idx = Ctx.getSValBuilder().makeArrayIndex(i);
+    QualType MPIReqTy = CE.getArgExpr(1)->getType()->getPointeeType();
+    SValBuilder &SVB = Ctx.getSValBuilder();
 
-      const ElementRegion *const ER = RegionManager.getElementRegion(
-          CE.getArgExpr(1)->getType()->getPointeeType(), Idx, SuperRegion,
-          Ctx.getASTContext());
+    auto RequestedCountSVal = CE.getArgSVal(0).getAs<nonloc::ConcreteInt>();
+    if (!RequestedCountSVal)
+      return;
 
-      ReqRegions.push_back(ER->getAs<MemRegion>());
+    const llvm::APSInt &RequestedCount = RequestedCountSVal->getValue();
+    // TODO: i >= RegionCount is an OOB UB, we could report it here but a better
+    // approach is adding this constraint as a summary into generic checker like
+    // StdCLibraryFunctions
+    for (size_t i = 0; i < RegionCount && i < RequestedCount; ++i) {
+      auto RegionIndex =
+          SVB.evalBinOp(Ctx.getState(), BO_Add, SVB.makeArrayIndex(i),
+                        RegionOffset, SVB.getArrayIndexType())
+              .getAs<NonLoc>();
+      if (RegionIndex) {
+        const ElementRegion *const RequestRegion =
+            RegionManager.getElementRegion(MPIReqTy, *RegionIndex, SuperRegion,
+                                           Ctx.getASTContext());
+        ReqRegions.push_back(RequestRegion);
+      }
     }
   } else if (FuncClassifier->isMPI_Wait(CE.getCalleeIdentifier())) {
     ReqRegions.push_back(MR);
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to