yashmayya commented on code in PR #16886:
URL: https://github.com/apache/pinot/pull/16886#discussion_r2392572433


##########
pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalancerTest.java:
##########
@@ -1086,6 +1089,368 @@ private void addTenantTagToInstances(String testTenant) {
     });
   }
 
+  @Test
+  public void testZkBasedTenantRebalanceObserverPoll()

Review Comment:
   Maybe one of these tests could try polling concurrently on multiple threads 
and verify that the ZK updates are properly serialized?
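   
   For example, something along these lines (just a rough sketch reusing the setup from this test; `tableNames` stands for `dryRunResult.getRebalanceTableResults().keySet()`, and the thread count / timeout are arbitrary):
   
   ```java
   int numThreads = 4;
   ExecutorService pool = Executors.newFixedThreadPool(numThreads);
   Queue<TenantRebalancer.TenantTableRebalanceJobContext> polled = new ConcurrentLinkedQueue<>();
   CountDownLatch latch = new CountDownLatch(numThreads);
   for (int t = 0; t < numThreads; t++) {
     pool.submit(() -> {
       try {
         TenantRebalancer.TenantTableRebalanceJobContext job;
         while ((job = observer.pollParallel()) != null) {
           polled.add(job);
         }
       } finally {
         latch.countDown();
       }
     });
   }
   assertTrue(latch.await(30, TimeUnit.SECONDS));
   pool.shutdown();
   // Each table job should have been handed out exactly once across all threads: no losses...
   assertEquals(polled.size(), tableNames.size());
   // ...and no duplicates
   assertEquals(polled.stream().map(TenantRebalancer.TenantTableRebalanceJobContext::getTableName).distinct().count(),
       (long) tableNames.size());
   // The ZK-side context should also reflect every polled job as ongoing
   TenantRebalanceContext updatedContext = TenantRebalanceContext.fromTenantRebalanceJobMetadata(
       _helixResourceManager.getControllerJobZKMetadata(jobId, ControllerJobTypes.TENANT_REBALANCE));
   assertEquals(updatedContext.getOngoingJobsQueue().size(), tableNames.size());
   ```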



##########
pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalancerTest.java:
##########
@@ -1086,6 +1089,368 @@ private void addTenantTagToInstances(String testTenant) {
     });
   }
 
+  @Test
+  public void testZkBasedTenantRebalanceObserverPoll()
+      throws Exception {
+    int numServers = 3;
+    for (int i = 0; i < numServers; i++) {
+      addFakeServerInstanceToAutoJoinHelixCluster(SERVER_INSTANCE_ID_PREFIX + i, true);
+    }
+
+    TenantRebalancer tenantRebalancer =
+        new TenantRebalancer(_tableRebalanceManager, _helixResourceManager, _executorService);
+
+    // tag all servers and brokers to test tenant
+    addTenantTagToInstances(TENANT_NAME);
+
+    // create 2 schemas
+    addDummySchema(RAW_TABLE_NAME_A);
+    addDummySchema(RAW_TABLE_NAME_B);
+
+    // create 2 tables on test tenant
+    createTableWithSegments(RAW_TABLE_NAME_A, TENANT_NAME);
+    createTableWithSegments(RAW_TABLE_NAME_B, TENANT_NAME);
+
+    // Add 3 more servers which will be tagged to default tenant
+    int numServersToAdd = 3;
+    for (int i = 0; i < numServersToAdd; i++) {
+      addFakeServerInstanceToAutoJoinHelixCluster(SERVER_INSTANCE_ID_PREFIX + (numServers + i), true);
+    }
+    addTenantTagToInstances(TENANT_NAME);
+
+    String jobId = "test-poll-job-123";
+
+    // Create tenant rebalance context with tables in queues
+    TenantRebalanceConfig config = new TenantRebalanceConfig();
+    config.setTenantName(TENANT_NAME);
+    config.setVerboseResult(true);
+    config.setDryRun(true);
+
+    TenantRebalanceResult dryRunResult = tenantRebalancer.rebalance(config);
+    TenantRebalanceContext context = new TenantRebalanceContext(
+        jobId, config, 1,
+        dryRunResult.getRebalanceTableResults().keySet().stream()
+            .map(tableName ->
+                new TenantRebalancer.TenantTableRebalanceJobContext(tableName, tableName + "_job", false))
+            .collect(Collectors.toCollection(ConcurrentLinkedDeque::new)),
+        new LinkedList<>(),
+        new ConcurrentLinkedQueue<>()
+    );
+
+    TenantRebalanceProgressStats progressStats =
+        new TenantRebalanceProgressStats(dryRunResult.getRebalanceTableResults().keySet());
+
+    // Test polling from parallel queue
+    ZkBasedTenantRebalanceObserver observer =
+        new ZkBasedTenantRebalanceObserver(jobId, TENANT_NAME, progressStats, context, _helixResourceManager);
+    TenantRebalancer.TenantTableRebalanceJobContext polledJob = observer.pollParallel();
+    assertNotNull(polledJob);
+    assertTrue(dryRunResult.getRebalanceTableResults().containsKey(polledJob.getTableName()));
+
+    // Test polling from sequential queue (should be empty)
+    TenantRebalancer.TenantTableRebalanceJobContext sequentialJob = observer.pollSequential();
+    assertNull(sequentialJob);
+
+    // Verify the job was moved to ongoing queue and status was updated
+    Map<String, String> updatedMetadata =
+        _helixResourceManager.getControllerJobZKMetadata(jobId, ControllerJobTypes.TENANT_REBALANCE);
+    assertNotNull(updatedMetadata);
+    TenantRebalanceContext updatedContext = TenantRebalanceContext.fromTenantRebalanceJobMetadata(updatedMetadata);
+    TenantRebalanceProgressStats updatedStats =
+        TenantRebalanceProgressStats.fromTenantRebalanceJobMetadata(updatedMetadata);
+
+    assertNotNull(updatedContext);
+    assertNotNull(updatedStats);
+    assertEquals(updatedContext.getOngoingJobsQueue().size(), 1);
+    assertEquals(updatedStats.getTableStatusMap().get(polledJob.getTableName()),
+        TenantRebalanceProgressStats.TableStatus.REBALANCING.name());
+
+    _helixResourceManager.deleteOfflineTable(RAW_TABLE_NAME_A);
+    _helixResourceManager.deleteOfflineTable(RAW_TABLE_NAME_B);
+
+    for (int i = 0; i < numServers + numServersToAdd; i++) {
+      stopAndDropFakeInstance(SERVER_INSTANCE_ID_PREFIX + i);
+    }
+  }
+
+  @Test
+  public void testZkBasedTenantRebalanceObserverCancelJob()
+      throws Exception {
+    int numServers = 3;
+    for (int i = 0; i < numServers; i++) {
+      addFakeServerInstanceToAutoJoinHelixCluster(SERVER_INSTANCE_ID_PREFIX + i, true);
+    }
+
+    TenantRebalancer tenantRebalancer =
+        new TenantRebalancer(_tableRebalanceManager, _helixResourceManager, _executorService);
+
+    // tag all servers and brokers to test tenant
+    addTenantTagToInstances(TENANT_NAME);
+
+    // create 2 schemas
+    addDummySchema(RAW_TABLE_NAME_A);
+    addDummySchema(RAW_TABLE_NAME_B);
+
+    // create 2 tables on test tenant
+    createTableWithSegments(RAW_TABLE_NAME_A, TENANT_NAME);
+    createTableWithSegments(RAW_TABLE_NAME_B, TENANT_NAME);
+
+    // Add 3 more servers which will be tagged to default tenant
+    int numServersToAdd = 3;
+    for (int i = 0; i < numServersToAdd; i++) {
+      addFakeServerInstanceToAutoJoinHelixCluster(SERVER_INSTANCE_ID_PREFIX + (numServers + i), true);
+    }
+    addTenantTagToInstances(TENANT_NAME);
+
+    // Create observer and test cancellation
+    String jobId = "test-cancel-job-456";
+
+    // Create tenant rebalance context with tables in queues
+    TenantRebalanceConfig config = new TenantRebalanceConfig();
+    config.setTenantName(TENANT_NAME);
+    config.setVerboseResult(true);
+    config.setDryRun(true);
+
+    TenantRebalanceResult dryRunResult = tenantRebalancer.rebalance(config);
+    Set<String> tableNames = dryRunResult.getRebalanceTableResults().keySet();
+
+    TenantRebalanceContext context = new TenantRebalanceContext(
+        jobId, config, 1,
+        tableNames.stream()
+            .map(tableName ->
+                new TenantRebalancer.TenantTableRebalanceJobContext(tableName, tableName + "_job", false))
+            .collect(Collectors.toCollection(ConcurrentLinkedDeque::new)),
+        new LinkedList<>(),
+        new ConcurrentLinkedQueue<>()
+    );
+
+    TenantRebalanceProgressStats progressStats = new TenantRebalanceProgressStats(tableNames);
+
+    // Test cancellation by user
+    ZkBasedTenantRebalanceObserver observer =
+        new ZkBasedTenantRebalanceObserver(jobId, TENANT_NAME, progressStats, context, _helixResourceManager);
+
+    // move one job to ongoing to test cancellation from that state
+    TenantRebalancer.TenantTableRebalanceJobContext polledJob = observer.pollParallel();
+    assertNotNull(polledJob);
+    assertTrue(dryRunResult.getRebalanceTableResults().containsKey(polledJob.getTableName()));
+
+    Pair<List<String>, Boolean> cancelResult = observer.cancelJob(true);
+    assertTrue(cancelResult.getRight()); // cancellation was successful
+    assertTrue(cancelResult.getLeft().isEmpty()); // no jobs were cancelled (they were just in queue)

Review Comment:
   Shouldn't the polled job be cancelled?
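   
   Since `pollParallel()` moved one job to the ongoing queue before `cancelJob(true)` was called, I'd have expected something more like this (a sketch; maybe the catch is that the dry-run table job IDs were never registered in ZK, so `cancelRebalance` finds nothing to cancel?):
   
   ```java
   assertEquals(cancelResult.getLeft().size(), 1); // the one ongoing job should be cancelled
   assertTrue(cancelResult.getLeft().contains(polledJob.getJobId()));
   ```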



##########
pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/TableRebalanceManager.java:
##########
@@ -225,10 +225,15 @@ RebalanceResult rebalanceTable(String tableNameWithType, TableConfig tableConfig
    * Cancels ongoing rebalance jobs (if any) for the given table.
    *
   * @param tableNameWithType name of the table for which to cancel any ongoing rebalance job
+   * @param resourceManager resource manager to use for updating the job metadata in ZK
+   * @param setToStatus status to set the cancelled jobs to. Must be either {@link RebalanceResult.Status#ABORTED}
+   *                    or {@link RebalanceResult.Status#CANCELLED}
    * @return the list of job IDs that were cancelled
    */
   public static List<String> cancelRebalance(String tableNameWithType, PinotHelixResourceManager resourceManager,
       RebalanceResult.Status setToStatus) {
+    Preconditions.checkState(setToStatus.equals(RebalanceResult.Status.ABORTED) || setToStatus.equals(

Review Comment:
   nit: this should be `checkArgument` instead.
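   
   i.e. something like (sketch):
   
   ```java
   Preconditions.checkArgument(
       setToStatus == RebalanceResult.Status.ABORTED || setToStatus == RebalanceResult.Status.CANCELLED,
       "setToStatus must be either ABORTED or CANCELLED, got: %s", setToStatus);
   ```
   
   `checkState` signals that the object is in an invalid state, whereas here it's the argument that's invalid, so `checkArgument` (which throws `IllegalArgumentException`) matches the contract better.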



##########
pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/ZkBasedTenantRebalanceObserver.java:
##########
@@ -74,66 +82,211 @@ public ZkBasedTenantRebalanceObserver(String jobId, String tenantName, Set<Strin
 
   @Override
  public void onTrigger(Trigger trigger, String tableName, String description) {
-    switch (trigger) {
-      case START_TRIGGER:
-        _progressStats.setStartTimeMs(System.currentTimeMillis());
-        break;
-      case REBALANCE_STARTED_TRIGGER:
-        _progressStats.updateTableStatus(tableName, TenantRebalanceProgressStats.TableStatus.PROCESSING.name());
-        _progressStats.putTableRebalanceJobId(tableName, description);
-        break;
-      case REBALANCE_COMPLETED_TRIGGER:
-        _progressStats.updateTableStatus(tableName, TenantRebalanceProgressStats.TableStatus.PROCESSED.name());
-        _unprocessedTables.remove(tableName);
-        _progressStats.setRemainingTables(_unprocessedTables.size());
-        break;
-      case REBALANCE_ERRORED_TRIGGER:
-        _progressStats.updateTableStatus(tableName, description);
-        _unprocessedTables.remove(tableName);
-        _progressStats.setRemainingTables(_unprocessedTables.size());
-        break;
-      default:
+  }
+
+  public void onStart() {
+    try {
+      updateTenantRebalanceJobMetadataInZk(
+          (ctx, progressStats) -> progressStats.setStartTimeMs(System.currentTimeMillis()));
+    } catch (AttemptFailureException e) {
+      LOGGER.error("Error updating ZK for jobId: {} on starting tenant rebalance", _jobId, e);
+      throw new RuntimeException("Error updating ZK for jobId: " + _jobId + " on starting tenant rebalance", e);
     }
-    syncStatsAndContextInZk();
   }
 
   @Override
   public void onSuccess(String msg) {
-    _progressStats.setCompletionStatusMsg(msg);
-    _progressStats.setTimeToFinishInSeconds((System.currentTimeMillis() - _progressStats.getStartTimeMs()) / 1000);
-    syncStatsAndContextInZk();
-    _isDone = true;
+    onFinish(msg);
   }
 
   @Override
   public void onError(String errorMsg) {
-    _progressStats.setCompletionStatusMsg(errorMsg);
-    _progressStats.setTimeToFinishInSeconds(System.currentTimeMillis() - _progressStats.getStartTimeMs());
-    syncStatsAndContextInZk();
+    onFinish(errorMsg);
+  }
+
+  private void onFinish(String msg) {
+    try {
+      updateTenantRebalanceJobMetadataInZk((ctx, progressStats) -> {
+        if (StringUtils.isEmpty(progressStats.getCompletionStatusMsg())) {
+          progressStats.setCompletionStatusMsg(msg);
+          progressStats.setTimeToFinishInSeconds((System.currentTimeMillis() - progressStats.getStartTimeMs()) / 1000);
+        }
+      });
+    } catch (AttemptFailureException e) {
+      LOGGER.error("Error updating ZK for jobId: {} on successful completion 
of tenant rebalance", _jobId, e);
+      throw new RuntimeException(
+          "Error updating ZK for jobId: " + _jobId + " on successful 
completion of tenant rebalance", e);
+    }
     _isDone = true;
   }
 
-  private void syncStatsAndContextInZk() {
+  private Map<String, String> makeJobMetadata(TenantRebalanceContext tenantRebalanceContext,
+      TenantRebalanceProgressStats progressStats) {
     Map<String, String> jobMetadata = new HashMap<>();
     jobMetadata.put(CommonConstants.ControllerJob.TENANT_NAME, _tenantName);
     jobMetadata.put(CommonConstants.ControllerJob.JOB_ID, _jobId);
    jobMetadata.put(CommonConstants.ControllerJob.SUBMISSION_TIME_MS, Long.toString(System.currentTimeMillis()));
    jobMetadata.put(CommonConstants.ControllerJob.JOB_TYPE, ControllerJobTypes.TENANT_REBALANCE.name());
    try {
      jobMetadata.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,
-          JsonUtils.objectToString(_progressStats));
+          JsonUtils.objectToString(progressStats));
     } catch (JsonProcessingException e) {
       LOGGER.error("Error serialising rebalance stats to JSON for persisting 
to ZK {}", _jobId, e);
     }
     try {
       jobMetadata.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,
-          JsonUtils.objectToString(_tenantRebalanceContext));
+          JsonUtils.objectToString(tenantRebalanceContext));
     } catch (JsonProcessingException e) {
       LOGGER.error("Error serialising rebalance context to JSON for persisting 
to ZK {}", _jobId, e);
     }
-    _pinotHelixResourceManager.addControllerJobToZK(_jobId, jobMetadata, ControllerJobTypes.TENANT_REBALANCE);
-    _numUpdatesToZk++;
-    LOGGER.debug("Number of updates to Zk: {} for rebalanceJob: {}  ", _numUpdatesToZk, _jobId);
+    return jobMetadata;
+  }
+
+  public TenantRebalancer.TenantTableRebalanceJobContext pollQueue(boolean isParallel) {
+    final TenantRebalancer.TenantTableRebalanceJobContext[] ret =
+        new TenantRebalancer.TenantTableRebalanceJobContext[1];
+    try {
+      updateTenantRebalanceJobMetadataInZk((ctx, progressStats) -> {
+        TenantRebalancer.TenantTableRebalanceJobContext polled =
+            isParallel ? ctx.getParallelQueue().poll() : ctx.getSequentialQueue().poll();
+        if (polled != null) {
+          ctx.getOngoingJobsQueue().add(polled);
+          String tableName = polled.getTableName();
+          String rebalanceJobId = polled.getJobId();
+          progressStats.updateTableStatus(tableName, TenantRebalanceProgressStats.TableStatus.REBALANCING.name());
+          progressStats.putTableRebalanceJobId(tableName, rebalanceJobId);
+        }
+        ret[0] = polled;
+      });
+    } catch (AttemptFailureException e) {
+      LOGGER.error("Error updating ZK for jobId: {} while polling from {} 
queue", _jobId,
+          isParallel ? "parallel" : "sequential", e);
+      throw new RuntimeException(
+          "Error updating ZK for jobId: " + _jobId + " while polling from " + 
(isParallel ? "parallel" : "sequential")
+              + " queue", e);
+    }
+    return ret[0];
+  }
+
+  public TenantRebalancer.TenantTableRebalanceJobContext pollParallel() {
+    return pollQueue(true);
+  }
+
+  public TenantRebalancer.TenantTableRebalanceJobContext pollSequential() {
+    return pollQueue(false);
+  }
+
+  public void onTableJobError(TenantRebalancer.TenantTableRebalanceJobContext jobContext, String errorMessage) {
+    onTableJobComplete(jobContext, errorMessage);
+  }
+
+  public void onTableJobDone(TenantRebalancer.TenantTableRebalanceJobContext jobContext) {
+    onTableJobComplete(jobContext, TenantRebalanceProgressStats.TableStatus.DONE.name());
+  }
+
+  private void onTableJobComplete(TenantRebalancer.TenantTableRebalanceJobContext jobContext, String message) {
+    try {
+      updateTenantRebalanceJobMetadataInZk((ctx, progressStats) -> {
+        if (ctx.getOngoingJobsQueue().remove(jobContext)) {
+          progressStats.updateTableStatus(jobContext.getTableName(), message);
+          progressStats.setRemainingTables(progressStats.getRemainingTables() - 1);
+        }
+      });
+    } catch (AttemptFailureException e) {
+      LOGGER.error("Error updating ZK for jobId: {} on completion of table 
rebalance job: {}", _jobId, jobContext, e);
+      throw new RuntimeException(
+          "Error updating ZK for jobId: " + _jobId + " on completion of table 
rebalance job: " + jobContext, e);
+    }
+  }
+
+  public Pair<List<String>, Boolean> cancelJob(boolean isCancelledByUser) {
+    List<String> cancelledJobs = new ArrayList<>();
+    try {
+      // Empty the queues first to prevent any new jobs from being picked up.
+      updateTenantRebalanceJobMetadataInZk((tenantRebalanceContext, progressStats) -> {
+        TenantRebalancer.TenantTableRebalanceJobContext ctx;
+        while ((ctx = tenantRebalanceContext.getParallelQueue().poll()) != null) {
+          progressStats.getTableStatusMap()
+              .put(ctx.getTableName(), TenantRebalanceProgressStats.TableStatus.NOT_SCHEDULED.name());
+        }
+        while ((ctx = tenantRebalanceContext.getSequentialQueue().poll()) != null) {
+          progressStats.getTableStatusMap()
+              .put(ctx.getTableName(), TenantRebalanceProgressStats.TableStatus.NOT_SCHEDULED.name());
+        }
+      });
+      // Try to cancel ongoing jobs on a best-effort basis. Some ongoing jobs could be marked cancelled but actually
+      // completed, if the table rebalance finished right after TableRebalanceManager marked them.
+      updateTenantRebalanceJobMetadataInZk((tenantRebalanceContext, progressStats) -> {
+        TenantRebalancer.TenantTableRebalanceJobContext ctx;
+        while ((ctx = tenantRebalanceContext.getOngoingJobsQueue().poll()) != null) {
+          cancelledJobs.addAll(TableRebalanceManager.cancelRebalance(ctx.getTableName(), _pinotHelixResourceManager,
+              isCancelledByUser ? RebalanceResult.Status.CANCELLED : RebalanceResult.Status.ABORTED));
+          progressStats.getTableStatusMap()
+              .put(ctx.getTableName(), isCancelledByUser ? TenantRebalanceProgressStats.TableStatus.CANCELLED.name()
+                  : TenantRebalanceProgressStats.TableStatus.ABORTED.name());
+        }
+        progressStats.setRemainingTables(0);
+        progressStats.setCompletionStatusMsg(
+            "Tenant rebalance job has been " + (isCancelledByUser ? "cancelled." : "aborted."));
+        progressStats.setTimeToFinishInSeconds((System.currentTimeMillis() - progressStats.getStartTimeMs()) / 1000);
+      });
+      return Pair.of(cancelledJobs, true);
+    } catch (AttemptFailureException e) {
+      return Pair.of(cancelledJobs, false);
+    }
+  }
+
+  private void updateTenantRebalanceJobMetadataInZk(
+      BiConsumer<TenantRebalanceContext, TenantRebalanceProgressStats> updater)
+      throws AttemptFailureException {
+    RetryPolicy retry = RetryPolicies.fixedDelayRetryPolicy(ZK_UPDATE_MAX_RETRIES, ZK_UPDATE_RETRY_WAIT_MS);
+    retry.attempt(() -> {

Review Comment:
   Ah, that's a good point, and we should probably add a similar retry 
mechanism there as well.
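   
   Something like the pattern used here should work there too, e.g. (a sketch only, assuming the ZK write being retried is idempotent and returns a boolean success flag):
   
   ```java
   try {
     RetryPolicies.fixedDelayRetryPolicy(ZK_UPDATE_MAX_RETRIES, ZK_UPDATE_RETRY_WAIT_MS)
         .attempt(() -> _pinotHelixResourceManager.addControllerJobToZK(_jobId, jobMetadata,
             ControllerJobTypes.TENANT_REBALANCE));
   } catch (AttemptFailureException e) {
     // Surface the failure to the caller rather than silently dropping the ZK update
     throw new RuntimeException("Error updating ZK for jobId: " + _jobId, e);
   }
   ```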


