yashmayya commented on code in PR #15135: URL: https://github.com/apache/pinot/pull/15135#discussion_r1973219056
########## pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchablePlanContext.java: ########## @@ -130,14 +134,26 @@ public List<DispatchablePlanFragment> constructDispatchablePlanFragmentList(Plan dispatchablePlanFragment.setTimeBoundaryInfo(dispatchablePlanMetadata.getTimeBoundaryInfo()); } } - return Arrays.asList(dispatchablePlanFragmentArray); + return dispatchablePlanFragmentMap; } - private void createDispatchablePlanFragmentList(DispatchablePlanFragment[] dispatchablePlanFragmentArray, - PlanFragment planFragmentRoot) { - dispatchablePlanFragmentArray[planFragmentRoot.getFragmentId()] = new DispatchablePlanFragment(planFragmentRoot); - for (PlanFragment childPlanFragment : planFragmentRoot.getChildren()) { - createDispatchablePlanFragmentList(dispatchablePlanFragmentArray, childPlanFragment); + private Map<Integer, DispatchablePlanFragment> createDispatchablePlanFragmentMap(PlanFragment planFragmentRoot) { + HashMap<Integer, DispatchablePlanFragment> result = + Maps.newHashMapWithExpectedSize(_dispatchablePlanMetadataMap.size()); + Queue<PlanFragment> pendingPlanFragmentIds = new ArrayDeque<>(); + pendingPlanFragmentIds.add(planFragmentRoot); + while (!pendingPlanFragmentIds.isEmpty()) { + PlanFragment planFragment = pendingPlanFragmentIds.poll(); + int planFragmentId = planFragment.getFragmentId(); + + if (result.containsKey(planFragmentId)) { + LOGGER.info("plan fragment {} found twice", planFragmentId); + continue; + } Review Comment: This shouldn't happen, right? 
########## pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/SpoolIntegrationTest.java: ########## @@ -117,13 +121,146 @@ public void intermediateSpool() JsonNode stats = jsonNode.get("stageStats"); assertNoError(jsonNode); DocumentContext parsed = JsonPath.parse(stats.toString()); - List<Map<String, Object>> stage4On3 = parsed.read("$..[?(@.stage == 3)]..[?(@.stage == 4)]"); - Assert.assertEquals(stage4On3.size(), 1, "Stage 4 should be descended from stage 3 exactly once"); - List<Map<String, Object>> stage4On7 = parsed.read("$..[?(@.stage == 7)]..[?(@.stage == 4)]"); - Assert.assertEquals(stage4On3.size(), 1, "Stage 4 should be descended from stage 7 exactly once"); + checkSpoolTimes(parsed, 4, 3, 1); + checkSpoolTimes(parsed, 4, 7, 1); + checkSpoolSame(parsed, 4, 3, 7); + } + + /** + * Test a complex with nested spools. Don't try to understand it, just check that the spools are correct. Review Comment: 😆 I unfortunately only noticed this comment after I tried to decipher this crazy query and gave up :P ########## pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchableSubPlan.java: ########## @@ -56,8 +60,38 @@ public DispatchableSubPlan(PairList<Integer, String> fields, List<DispatchablePl * Get the list of stage plan root node. * @return stage plan map. */ - public List<DispatchablePlanFragment> getQueryStageList() { - return _queryStageList; + public Map<Integer, DispatchablePlanFragment> getQueryStageMap() { Review Comment: nit: let's update the Javadoc to state that the keys are the stage IDs with the values being the stage's root node fragment? 
########## pinot-query-planner/src/main/java/org/apache/pinot/query/planner/physical/DispatchableSubPlan.java: ########## @@ -90,7 +124,11 @@ public Map<String, Set<String>> getTableToUnavailableSegmentsMap() { public int getEstimatedNumQueryThreads() { int estimatedNumQueryThreads = 0; // Skip broker reduce root stage - for (DispatchablePlanFragment stage : _queryStageList.subList(1, _queryStageList.size())) { + for (Map.Entry<Integer, DispatchablePlanFragment> entry : _queryStageMap.entrySet()) { + if (entry.getKey() == 0) { + continue; + } + DispatchablePlanFragment stage = entry.getValue(); Review Comment: We could simply use `getQueryStagesWithoutRoot` here instead, right? ########## pinot-query-runtime/src/main/java/org/apache/pinot/query/service/dispatch/QueryDispatcher.java: ########## @@ -244,15 +245,15 @@ private boolean isQueryCancellationEnabled() { return _serversByQuery != null; } - private <E> void execute(long requestId, List<DispatchablePlanFragment> stagePlans, + private <E> void execute(long requestId, Set<DispatchablePlanFragment> stagePlans, long timeoutMs, Map<String, String> queryOptions, SendRequest<E> sendRequest, Set<QueryServerInstance> serverInstancesOut, BiConsumer<E, QueryServerInstance> resultConsumer) throws ExecutionException, InterruptedException, TimeoutException { Deadline deadline = Deadline.after(timeoutMs, TimeUnit.MILLISECONDS); - List<StageInfo> stageInfos = serializePlanFragments(stagePlans, serverInstancesOut, deadline); + Map<DispatchablePlanFragment, StageInfo> stageInfos = serializePlanFragments(stagePlans, serverInstancesOut, deadline); Review Comment: This is causing checkstyle to fail due to line length. 
########## pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/SpoolIntegrationTest.java: ########## @@ -117,13 +121,146 @@ public void intermediateSpool() JsonNode stats = jsonNode.get("stageStats"); assertNoError(jsonNode); DocumentContext parsed = JsonPath.parse(stats.toString()); - List<Map<String, Object>> stage4On3 = parsed.read("$..[?(@.stage == 3)]..[?(@.stage == 4)]"); - Assert.assertEquals(stage4On3.size(), 1, "Stage 4 should be descended from stage 3 exactly once"); - List<Map<String, Object>> stage4On7 = parsed.read("$..[?(@.stage == 7)]..[?(@.stage == 4)]"); - Assert.assertEquals(stage4On3.size(), 1, "Stage 4 should be descended from stage 7 exactly once"); + checkSpoolTimes(parsed, 4, 3, 1); + checkSpoolTimes(parsed, 4, 7, 1); + checkSpoolSame(parsed, 4, 3, 7); + } + + /** + * Test a complex with nested spools. Don't try to understand it, just check that the spools are correct. + * The test name corresponds to the PR that fixed the issue. + */ + @Test + public void test15135() + throws Exception { + JsonNode jsonNode = postQuery("SET useSpools = true;\n" + + "\n" + + "WITH\n" + + " q1 AS (\n" + + " SELECT ArrTimeBlk as userUUID,\n" + + " Dest as deviceOS,\n" + + " SUM(ArrTime) AS totalTrips\n" + + " FROM mytable\n" + + " GROUP BY ArrTimeBlk, Dest\n" + + " ),\n" + + " q2 AS (\n" + + " SELECT userUUID,\n" + + " deviceOS,\n" + + " SUM(totalTrips) AS totalTrips,\n" + + " COUNT(DISTINCT userUUID) AS reach\n" + + " FROM q1\n" + + " GROUP BY userUUID,\n" + + " deviceOS\n" + + " ),\n" + + " q3 AS (\n" + + " SELECT userUUID,\n" + + " (totalTrips / reach) AS frequency\n" + + " FROM q2\n" + + " ),\n" + + " q4 AS (\n" + + " SELECT rd.userUUID,\n" + + " rd.deviceOS,\n" + + " rd.totalTrips as totalTrips,\n" + + " rd.reach AS reach\n" + + " FROM q2 rd\n" + + " ),\n" + + " q5 AS (\n" + + " SELECT userUUID,\n" + + " SUM(totalTrips) AS totalTrips\n" + + " FROM q4\n" + + " GROUP BY userUUID\n" + + " ),\n" + + " q6 AS (\n" + + " SELECT s.userUUID,\n" 
+ + " s.totalTrips,\n" + + " (s.totalTrips / o.frequency) AS reach,\n" + + " 'Traditional TV + OTT' AS deviceOS\n" + + " FROM q5 s\n" + + " JOIN q3 o ON s.userUUID = o.userUUID\n" + + " ),\n" + + " q7 AS (\n" + + " SELECT rd.userUUID,\n" + + " rd.totalTrips,\n" + + " rd.reach,\n" + + " rd.deviceOS\n" + + " FROM q4 rd\n" + + " UNION ALL\n" + + " SELECT f.userUUID,\n" + + " f.totalTrips,\n" + + " f.reach,\n" + + " f.deviceOS\n" + + " FROM q6 f\n" + + " ),\n" + + " q8 AS (\n" + + " SELECT sd.*\n" + + " FROM q7 sd\n" + + " JOIN (\n" + + " SELECT deviceOS,\n" + + " PERCENTILETDigest(totalTrips, 20) AS p20\n" + + " FROM q7\n" + + " GROUP BY deviceOS\n" + + " ) q ON sd.deviceOS = q.deviceOS\n" + + " )\n" + + "SELECT *\n" + + "FROM q8"); + JsonNode stats = jsonNode.get("stageStats"); + assertNoError(jsonNode); + DocumentContext parsed = JsonPath.parse(stats.toString()); + + checkSpoolTimes(parsed, 6, 5, 1); + checkSpoolTimes(parsed, 6, 14, 1); + checkSpoolSame(parsed, 6, 5, 14); + + checkSpoolTimes(parsed, 7, 6, 2); + + checkSpoolTimes(parsed, 4, 3, 1); + checkSpoolTimes(parsed, 4, 7, 2); // because there are 2 copies of 7 as well + checkSpoolTimes(parsed, 4, 9, 1); + checkSpoolTimes(parsed, 4, 12, 1); + checkSpoolTimes(parsed, 4, 18, 1); + checkSpoolSame(parsed, 4, 3, 7, 9, 12, 18); + } + + private List<Map<String, Object>> findDescendantById(DocumentContext stats, int parent, int descendant) { + return stats.read(parentDescendantJsonPathExpression(parent, descendant)); + } + + private void checkSpoolTimes(DocumentContext stats, int spoolStageId, int parent, int times) { + List<Map<String, Object>> descendants = findDescendantById(stats, parent, spoolStageId); + Assert.assertEquals(descendants.size(), times, "Stage " + spoolStageId + " should be descended from stage " + + parent + " exactly " + times + " times"); + Map<String, Object> firstSpool = descendants.get(0); + for (int i = 1; i < descendants.size(); i++) { + Assert.assertEquals(descendants.get(i), firstSpool, 
"Stage " + spoolStageId + " should be the same in " + + "all " + times + " descendants"); + } + } + + private void checkSpoolSame(DocumentContext stats, int spoolStageId, int... parents) { + List<Pair<Integer, List<Map<String, Object>>>> spools = Arrays.stream(parents) + .mapToObj(parent -> Pair.of(parent, findDescendantById(stats, parent, spoolStageId))) + .collect(Collectors.toList()); + Pair<Integer, List<Map<String, Object>>> notEmpty = spools.stream() + .filter(s -> !s.getValue().isEmpty()) + .findFirst() + .orElse(null); + if (notEmpty == null) { + Assert.fail("None of the parent nodes " + Arrays.toString(parents) + " have a descendant with id " + + spoolStageId); + } + List<Pair<Integer, List<Map<String, Object>>>> allNotEqual = spools.stream() + .filter(s -> !s.getValue().get(0).equals(notEmpty.getValue().get(0))) + .collect(Collectors.toList()); + if (!allNotEqual.isEmpty()) { + Assert.fail("The descendant with id " + spoolStageId + " is not the same in all parent nodes " + + spools); + } + } - Assert.assertEquals(stage4On3, stage4On7, "Stage 4 should be the same in both stage 3 and stage 7"); + @Language("jsonpath") + private String parentDescendantJsonPathExpression(int parent, int child) { Review Comment: I'm not sure I follow what this method is doing and how / why? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org For additional commands, e-mail: commits-h...@pinot.apache.org