This is an automated email from the ASF dual-hosted git repository.

yashmayya pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 8075abf0673 Tenant Rebalance Cancellation (#16886)
8075abf0673 is described below

commit 8075abf06732873cfd888d7bea7e7bc25f000219
Author: Jhow <[email protected]>
AuthorDate: Wed Oct 1 14:53:18 2025 -0700

    Tenant Rebalance Cancellation (#16886)
---
 .../api/resources/PinotTableRestletResource.java   |   3 +-
 .../api/resources/PinotTenantRestletResource.java  |  38 ++
 .../helix/core/rebalance/RebalanceChecker.java     |  26 +-
 .../helix/core/rebalance/RebalanceConfig.java      |  35 ++
 .../core/rebalance/TableRebalanceManager.java      |  19 +-
 .../rebalance/tenant/TenantRebalanceChecker.java   | 121 ++---
 .../rebalance/tenant/TenantRebalanceConfig.java    |  20 +
 .../rebalance/tenant/TenantRebalanceContext.java   |  47 ++
 .../rebalance/tenant/TenantRebalanceObserver.java  |  38 --
 .../tenant/TenantRebalanceProgressStats.java       |  44 +-
 .../core/rebalance/tenant/TenantRebalancer.java    |  93 ++--
 .../tenant/ZkBasedTenantRebalanceObserver.java     | 290 +++++++++---
 .../tenant/TenantRebalanceCheckerTest.java         |  55 ++-
 .../rebalance/tenant/TenantRebalancerTest.java     | 510 ++++++++++++++++++++-
 14 files changed, 1050 insertions(+), 289 deletions(-)

diff --git 
a/pinot-controller/src/main/java/org/apache/pinot/controller/api/resources/PinotTableRestletResource.java
 
b/pinot-controller/src/main/java/org/apache/pinot/controller/api/resources/PinotTableRestletResource.java
index 7dbb2097eb2..7f52166d24c 100644
--- 
a/pinot-controller/src/main/java/org/apache/pinot/controller/api/resources/PinotTableRestletResource.java
+++ 
b/pinot-controller/src/main/java/org/apache/pinot/controller/api/resources/PinotTableRestletResource.java
@@ -919,7 +919,8 @@ public class PinotTableRestletResource {
       @Context HttpHeaders headers) {
     tableName = DatabaseUtils.translateTableName(tableName, headers);
     String tableNameWithType = constructTableNameWithType(tableName, 
tableTypeStr);
-    return _tableRebalanceManager.cancelRebalance(tableNameWithType);
+    return TableRebalanceManager.cancelRebalance(tableNameWithType, 
_pinotHelixResourceManager,
+        RebalanceResult.Status.CANCELLED);
   }
 
   @GET
diff --git 
a/pinot-controller/src/main/java/org/apache/pinot/controller/api/resources/PinotTenantRestletResource.java
 
b/pinot-controller/src/main/java/org/apache/pinot/controller/api/resources/PinotTenantRestletResource.java
index 817062d8a19..defdf0c3cf7 100644
--- 
a/pinot-controller/src/main/java/org/apache/pinot/controller/api/resources/PinotTenantRestletResource.java
+++ 
b/pinot-controller/src/main/java/org/apache/pinot/controller/api/resources/PinotTenantRestletResource.java
@@ -54,6 +54,7 @@ import javax.ws.rs.core.Context;
 import javax.ws.rs.core.HttpHeaders;
 import javax.ws.rs.core.MediaType;
 import javax.ws.rs.core.Response;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.helix.model.IdealState;
 import org.apache.helix.model.InstanceConfig;
 import org.apache.pinot.common.assignment.InstancePartitions;
@@ -72,6 +73,7 @@ import 
org.apache.pinot.controller.helix.core.rebalance.tenant.TenantRebalancePr
 import 
org.apache.pinot.controller.helix.core.rebalance.tenant.TenantRebalanceResult;
 import 
org.apache.pinot.controller.helix.core.rebalance.tenant.TenantRebalancer;
 import 
org.apache.pinot.controller.helix.core.rebalance.tenant.TenantTableWithProperties;
+import 
org.apache.pinot.controller.helix.core.rebalance.tenant.ZkBasedTenantRebalanceObserver;
 import org.apache.pinot.controller.util.TableSizeReader;
 import org.apache.pinot.core.auth.Actions;
 import org.apache.pinot.core.auth.Authorize;
@@ -701,6 +703,42 @@ public class PinotTenantRestletResource {
         Response.Status.INTERNAL_SERVER_ERROR);
   }
 
+  /**
+   * REST endpoint that cancels a tenant rebalance job identified by its job id.
+   *
+   * <p>Looks up the job's ZK metadata for {@code ControllerJobTypes.TENANT_REBALANCE}; responds with 404 when no
+   * metadata exists. Otherwise builds a {@link ZkBasedTenantRebalanceObserver} for the job (using the tenant name
+   * stored in the metadata) and delegates the cancellation to {@code cancelJob(true)}.
+   *
+   * @param jobId id of the tenant rebalance job to cancel
+   * @return a success response listing the table rebalance job ids reported by the observer
+   * @throws ControllerApplicationException with NOT_FOUND when the job metadata is missing, or with
+   *         INTERNAL_SERVER_ERROR when the observer reports that the cancellation did not go through
+   */
+  @DELETE
+  @Produces(MediaType.APPLICATION_JSON)
+  @Authenticate(AccessType.DELETE)
+  @Authorize(targetType = TargetType.CLUSTER, action = 
Actions.Cluster.REBALANCE_TENANT_TABLES)
+  @Path("/tenants/rebalance/{jobId}")
+  @ApiOperation(value = "Cancels a running tenant rebalance job")
+  @ApiResponses(value = {
+      @ApiResponse(code = 200, message = "Success", response = 
SuccessResponse.class),
+      @ApiResponse(code = 404, message = "Tenant rebalance job not found"),
+      @ApiResponse(code = 500, message = "Internal server error while 
cancelling the rebalance job")
+  })
+  public SuccessResponse cancelRebalance(
+      @ApiParam(value = "Tenant rebalance job id", required = true) 
@PathParam("jobId") String jobId) {
+    Map<String, String> jobMetadata =
+        _pinotHelixResourceManager.getControllerJobZKMetadata(jobId, 
ControllerJobTypes.TENANT_REBALANCE);
+    if (jobMetadata == null) {
+      throw new ControllerApplicationException(LOGGER, "Tenant rebalance job: 
" + jobId + " not found",
+          Response.Status.NOT_FOUND);
+    }
+    ZkBasedTenantRebalanceObserver observer =
+        new ZkBasedTenantRebalanceObserver(jobId, 
jobMetadata.get(CommonConstants.ControllerJob.TENANT_NAME),
+            _pinotHelixResourceManager);
+    // NOTE(review): the boolean argument presumably marks a user-initiated cancel (the periodic checker passes
+    // false when aborting a stuck job) - confirm against ZkBasedTenantRebalanceObserver#cancelJob.
+    // Left: table rebalance job ids that were cancelled; right: whether the cancellation succeeded overall.
+    Pair<List<String>, Boolean> result = observer.cancelJob(true);
+    if (result.getRight()) {
+      return new SuccessResponse(
+          "Successfully cancelled tenant rebalance job: " + jobId + ". Number 
of table rebalance jobs cancelled: "
+              + result.getLeft().size() + ": " + result.getLeft());
+    } else {
+      // Some table rebalance jobs may already have been cancelled even though the tenant job update failed,
+      // so they are still reported back to the caller.
+      throw new ControllerApplicationException(LOGGER,
+          "Failed to cancel tenant rebalance job: " + jobId
+              + " due to update failure to ZK. Number of table rebalance jobs 
already cancelled: "
+              + result.getLeft().size() + ": " + result.getLeft(),
+          Response.Status.INTERNAL_SERVER_ERROR);
+    }
+  }
+
   @POST
   @Produces(MediaType.APPLICATION_JSON)
   @Authenticate(AccessType.UPDATE)
diff --git 
a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/RebalanceChecker.java
 
b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/RebalanceChecker.java
index 5c73cce4c5b..314f24652a6 100644
--- 
a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/RebalanceChecker.java
+++ 
b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/RebalanceChecker.java
@@ -137,7 +137,8 @@ public class RebalanceChecker extends 
ControllerPeriodicTask<Void> {
     // 3) If configured, we can abort the other rebalance jobs for the table 
by setting their status to FAILED.
 
     if (hasStuckInProgressJobs(tableNameWithType, allJobMetadata)) {
-      abortExistingJobs(tableNameWithType, _pinotHelixResourceManager);
+      TableRebalanceManager.cancelRebalance(tableNameWithType, 
_pinotHelixResourceManager,
+          RebalanceResult.Status.ABORTED);
     }
 
     Map<String/*original jobId*/, Set<Pair<TableRebalanceContext/*job 
attempts*/, Long
@@ -214,29 +215,6 @@ public class RebalanceChecker extends 
ControllerPeriodicTask<Void> {
     return (long) minDelayMs;
   }
 
-  private static void abortExistingJobs(String tableNameWithType, 
PinotHelixResourceManager pinotHelixResourceManager) {
-    boolean updated =
-        pinotHelixResourceManager.updateJobsForTable(tableNameWithType, 
ControllerJobTypes.TABLE_REBALANCE,
-        jobMetadata -> {
-          String jobId = jobMetadata.get(CommonConstants.ControllerJob.JOB_ID);
-          try {
-            String jobStatsInStr = 
jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS);
-            TableRebalanceProgressStats jobStats =
-                JsonUtils.stringToObject(jobStatsInStr, 
TableRebalanceProgressStats.class);
-            if (jobStats.getStatus() != RebalanceResult.Status.IN_PROGRESS) {
-              return;
-            }
-            LOGGER.info("Abort rebalance job: {} for table: {}", jobId, 
tableNameWithType);
-            jobStats.setStatus(RebalanceResult.Status.ABORTED);
-            
jobMetadata.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,
-                JsonUtils.objectToString(jobStats));
-          } catch (Exception e) {
-            LOGGER.error("Failed to abort rebalance job: {} for table: {}", 
jobId, tableNameWithType, e);
-          }
-        });
-    LOGGER.info("Tried to abort existing jobs at best effort and done: {}", 
updated);
-  }
-
   @VisibleForTesting
   static Pair<TableRebalanceContext, Long> getLatestJob(
       Map<String, Set<Pair<TableRebalanceContext, Long>>> candidateJobs) {
diff --git 
a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/RebalanceConfig.java
 
b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/RebalanceConfig.java
index bfb7865d48e..f9d00785dc5 100644
--- 
a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/RebalanceConfig.java
+++ 
b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/RebalanceConfig.java
@@ -22,6 +22,7 @@ import com.fasterxml.jackson.annotation.JsonProperty;
 import com.google.common.base.Preconditions;
 import io.swagger.annotations.ApiModel;
 import io.swagger.annotations.ApiModelProperty;
+import java.util.Objects;
 import org.apache.pinot.controller.api.resources.ForceCommitBatchConfig;
 import org.apache.pinot.spi.utils.Enablement;
 
@@ -377,6 +378,40 @@ public class RebalanceConfig {
     _diskUtilizationThreshold = diskUtilizationThreshold;
   }
 
+  /**
+   * Field-by-field equality over the rebalance configuration values.
+   *
+   * <p>Any instance of {@code RebalanceConfig} (including subclasses) passes the type check. All fields are
+   * compared with {@code ==}, except {@code _diskUtilizationThreshold}, which is compared with
+   * {@code Double.compare}.
+   *
+   * <p>NOTE(review): because the check uses {@code instanceof}, a plain {@code RebalanceConfig} may report equality
+   * with a subclass instance whose own {@code equals} compares additional fields - confirm whether that asymmetry
+   * matters to callers.
+   *
+   * @param o object to compare against
+   * @return {@code true} when {@code o} is a {@code RebalanceConfig} whose compared fields all match
+   */
+  @Override
+  public boolean equals(Object o) {
+    if (!(o instanceof RebalanceConfig)) {
+      return false;
+    }
+    RebalanceConfig that = (RebalanceConfig) o;
+    return _dryRun == that._dryRun && _preChecks == that._preChecks && 
_disableSummary == that._disableSummary
+        && _reassignInstances == that._reassignInstances && _includeConsuming 
== that._includeConsuming
+        && _bootstrap == that._bootstrap && _downtime == that._downtime
+        && _allowPeerDownloadDataLoss == that._allowPeerDownloadDataLoss
+        && _minAvailableReplicas == that._minAvailableReplicas && _lowDiskMode 
== that._lowDiskMode
+        && _bestEfforts == that._bestEfforts && _batchSizePerServer == 
that._batchSizePerServer
+        && _externalViewCheckIntervalInMs == 
that._externalViewCheckIntervalInMs
+        && _externalViewStabilizationTimeoutInMs == 
that._externalViewStabilizationTimeoutInMs
+        && _updateTargetTier == that._updateTargetTier && 
_heartbeatIntervalInMs == that._heartbeatIntervalInMs
+        && _heartbeatTimeoutInMs == that._heartbeatTimeoutInMs && _maxAttempts 
== that._maxAttempts
+        && _retryInitialDelayInMs == that._retryInitialDelayInMs
+        && Double.compare(_diskUtilizationThreshold, 
that._diskUtilizationThreshold) == 0
+        && _forceCommit == that._forceCommit && _forceCommitBatchSize == 
that._forceCommitBatchSize
+        && _forceCommitBatchStatusCheckIntervalMs == 
that._forceCommitBatchStatusCheckIntervalMs
+        && _forceCommitBatchStatusCheckTimeoutMs == 
that._forceCommitBatchStatusCheckTimeoutMs
+        && _minimizeDataMovement == that._minimizeDataMovement;
+  }
+
+  /**
+   * Hash code derived from the same set of fields that {@link #equals(Object)} compares, so equal configs
+   * produce equal hash codes.
+   *
+   * @return hash of the configuration fields
+   */
+  @Override
+  public int hashCode() {
+    return Objects.hash(_dryRun, _preChecks, _disableSummary, 
_reassignInstances, _includeConsuming, _bootstrap,
+        _downtime, _allowPeerDownloadDataLoss, _minAvailableReplicas, 
_lowDiskMode, _bestEfforts, _minimizeDataMovement,
+        _batchSizePerServer, _externalViewCheckIntervalInMs, 
_externalViewStabilizationTimeoutInMs, _updateTargetTier,
+        _heartbeatIntervalInMs, _heartbeatTimeoutInMs, _maxAttempts, 
_retryInitialDelayInMs, _diskUtilizationThreshold,
+        _forceCommit, _forceCommitBatchSize, 
_forceCommitBatchStatusCheckIntervalMs,
+        _forceCommitBatchStatusCheckTimeoutMs);
+  }
+
   @Override
   public String toString() {
     return "RebalanceConfig{" + "_dryRun=" + _dryRun + ", preChecks=" + 
_preChecks + ", _disableSummary="
diff --git 
a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/TableRebalanceManager.java
 
b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/TableRebalanceManager.java
index 305391a88d6..d216ff551e4 100644
--- 
a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/TableRebalanceManager.java
+++ 
b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/TableRebalanceManager.java
@@ -225,11 +225,17 @@ public class TableRebalanceManager {
    * Cancels ongoing rebalance jobs (if any) for the given table.
    *
    * @param tableNameWithType name of the table for which to cancel any 
ongoing rebalance job
+   * @param resourceManager resource manager to use for updating the job 
metadata in ZK
+   * @param setToStatus status to set the cancelled jobs to. Must be either 
{@link RebalanceResult.Status#ABORTED}
+   *                    or {@link RebalanceResult.Status#CANCELLED}
    * @return the list of job IDs that were cancelled
    */
-  public List<String> cancelRebalance(String tableNameWithType) {
+  public static List<String> cancelRebalance(String tableNameWithType, 
PinotHelixResourceManager resourceManager,
+      RebalanceResult.Status setToStatus) {
+    
Preconditions.checkArgument(setToStatus.equals(RebalanceResult.Status.ABORTED) 
|| setToStatus.equals(
+        RebalanceResult.Status.CANCELLED));
     List<String> cancelledJobIds = new ArrayList<>();
-    boolean updated = _resourceManager.updateJobsForTable(tableNameWithType, 
ControllerJobTypes.TABLE_REBALANCE,
+    boolean updated = resourceManager.updateJobsForTable(tableNameWithType, 
ControllerJobTypes.TABLE_REBALANCE,
         jobMetadata -> {
           String jobId = jobMetadata.get(CommonConstants.ControllerJob.JOB_ID);
           try {
@@ -240,8 +246,10 @@ public class TableRebalanceManager {
               return;
             }
 
-            LOGGER.info("Cancelling rebalance job: {} for table: {}", jobId, 
tableNameWithType);
-            jobStats.setStatus(RebalanceResult.Status.CANCELLED);
+            LOGGER.info("{} rebalance job: {} for table: {}",
+                setToStatus.equals(RebalanceResult.Status.ABORTED) ? 
"Aborting" : "Cancelling", jobId,
+                tableNameWithType);
+            jobStats.setStatus(setToStatus);
             
jobMetadata.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,
                 JsonUtils.objectToString(jobStats));
             cancelledJobIds.add(jobId);
@@ -249,7 +257,8 @@ public class TableRebalanceManager {
             LOGGER.error("Failed to cancel rebalance job: {} for table: {}", 
jobId, tableNameWithType, e);
           }
         });
-    LOGGER.info("Tried to cancel existing rebalance jobs for table: {} at best 
effort and done: {}", tableNameWithType,
+    LOGGER.info("Tried to {} existing rebalance jobs for table: {} at best 
effort and done: {}.",
+        setToStatus.equals(RebalanceResult.Status.ABORTED) ? "abort" : 
"cancel", tableNameWithType,
         updated);
     return cancelledJobIds;
   }
diff --git 
a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalanceChecker.java
 
b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalanceChecker.java
index 954a68f5425..0c1eaf88c10 100644
--- 
a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalanceChecker.java
+++ 
b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalanceChecker.java
@@ -19,19 +19,19 @@
 package org.apache.pinot.controller.helix.core.rebalance.tenant;
 
 import com.fasterxml.jackson.core.JsonProcessingException;
+import com.google.common.annotations.VisibleForTesting;
 import java.util.ArrayList;
+import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Properties;
 import java.util.Set;
 import java.util.UUID;
-import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pinot.controller.ControllerConf;
 import org.apache.pinot.controller.helix.core.PinotHelixResourceManager;
 import org.apache.pinot.controller.helix.core.controllerjob.ControllerJobTypes;
 import org.apache.pinot.controller.helix.core.rebalance.RebalanceJobConstants;
-import org.apache.pinot.controller.helix.core.rebalance.RebalanceResult;
-import 
org.apache.pinot.controller.helix.core.rebalance.TableRebalanceProgressStats;
 import org.apache.pinot.core.periodictask.BasePeriodicTask;
 import org.apache.pinot.spi.utils.CommonConstants;
 import org.apache.pinot.spi.utils.JsonUtils;
@@ -87,38 +87,37 @@ public class TenantRebalanceChecker extends 
BasePeriodicTask {
 
       try {
         // Check if the tenant rebalance job is stuck
-        String tenantRebalanceContextStr =
-            
jobZKMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);
-        String tenantRebalanceProgressStatsStr =
-            
jobZKMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS);
-        if (StringUtils.isEmpty(tenantRebalanceContextStr) || 
StringUtils.isEmpty(tenantRebalanceProgressStatsStr)) {
+        TenantRebalanceContext tenantRebalanceContext =
+            
TenantRebalanceContext.fromTenantRebalanceJobMetadata(jobZKMetadata);
+        TenantRebalanceProgressStats progressStats =
+            
TenantRebalanceProgressStats.fromTenantRebalanceJobMetadata(jobZKMetadata);
+        if (tenantRebalanceContext == null || progressStats == null) {
           // Skip rebalance job: {} as it has no job context or progress stats
           LOGGER.info("Skip checking tenant rebalance job: {} as it has no job 
context or progress stats", jobId);
           continue;
         }
-        TenantRebalanceContext tenantRebalanceContext =
-            JsonUtils.stringToObject(tenantRebalanceContextStr, 
TenantRebalanceContext.class);
-        TenantRebalanceProgressStats progressStats =
-            JsonUtils.stringToObject(tenantRebalanceProgressStatsStr, 
TenantRebalanceProgressStats.class);
         long statsUpdatedAt = 
Long.parseLong(jobZKMetadata.get(CommonConstants.ControllerJob.SUBMISSION_TIME_MS));
 
         TenantRebalanceContext retryTenantRebalanceContext =
             prepareRetryIfTenantRebalanceJobStuck(jobZKMetadata, 
tenantRebalanceContext, statsUpdatedAt);
         if (retryTenantRebalanceContext != null) {
-          TenantRebalancer.TenantTableRebalanceJobContext ctx;
-          while ((ctx = 
retryTenantRebalanceContext.getOngoingJobsQueue().poll()) != null) {
-            abortTableRebalanceJob(ctx.getTableName());
-            // the existing table rebalance job is aborted, we need to run the 
rebalance job with a new job ID.
-            TenantRebalancer.TenantTableRebalanceJobContext newCtx =
-                new TenantRebalancer.TenantTableRebalanceJobContext(
-                    ctx.getTableName(), UUID.randomUUID().toString(), 
ctx.shouldRebalanceWithDowntime());
-            retryTenantRebalanceContext.getParallelQueue().addFirst(newCtx);
+          // abort the existing job, then retry with the new job context
+          ZkBasedTenantRebalanceObserver observer =
+              new ZkBasedTenantRebalanceObserver(jobId, 
jobZKMetadata.get(CommonConstants.ControllerJob.TENANT_NAME),
+                  _pinotHelixResourceManager);
+          Pair<List<String>, Boolean> result = observer.cancelJob(false);
+          if (result.getRight()) {
+            TenantRebalancer.TenantTableRebalanceJobContext ctx;
+            while ((ctx = 
retryTenantRebalanceContext.getOngoingJobsQueue().poll()) != null) {
+              TenantRebalancer.TenantTableRebalanceJobContext newCtx =
+                  new TenantRebalancer.TenantTableRebalanceJobContext(
+                      ctx.getTableName(), UUID.randomUUID().toString(), 
ctx.shouldRebalanceWithDowntime());
+              retryTenantRebalanceContext.getParallelQueue().addFirst(newCtx);
+            }
+            retryTenantRebalanceJob(retryTenantRebalanceContext, 
progressStats);
+          } else {
+            LOGGER.warn("Failed to abort the stuck tenant rebalance job: {}, 
will not retry", jobId);
           }
-          // the retry tenant rebalance job id has been created in ZK, we can 
safely mark the original job as
-          // aborted, so that this original job will not be picked up again in 
the future.
-          markTenantRebalanceJobAsAborted(jobId, jobZKMetadata, 
tenantRebalanceContext, progressStats);
-          retryTenantRebalanceJob(retryTenantRebalanceContext, progressStats);
-          // We only retry one stuck tenant rebalance job at a time to avoid 
multiple retries of the same job
           return;
         } else {
           LOGGER.info("Tenant rebalance job: {} is not stuck", jobId);
@@ -130,14 +129,15 @@ public class TenantRebalanceChecker extends 
BasePeriodicTask {
     }
   }
 
-  private void retryTenantRebalanceJob(TenantRebalanceContext 
tenantRebalanceContextForRetry,
+  @VisibleForTesting
+  void retryTenantRebalanceJob(TenantRebalanceContext 
tenantRebalanceContextForRetry,
       TenantRebalanceProgressStats progressStats) {
     ZkBasedTenantRebalanceObserver observer =
         new 
ZkBasedTenantRebalanceObserver(tenantRebalanceContextForRetry.getJobId(),
             tenantRebalanceContextForRetry.getConfig().getTenantName(),
             progressStats, tenantRebalanceContextForRetry, 
_pinotHelixResourceManager);
     _ongoingJobObserver = observer;
-    _tenantRebalancer.rebalanceWithContext(tenantRebalanceContextForRetry, 
observer);
+    _tenantRebalancer.rebalanceWithObserver(observer, 
tenantRebalanceContextForRetry.getConfig());
   }
 
   /**
@@ -260,71 +260,4 @@ public class TenantRebalanceChecker extends 
BasePeriodicTask {
     }
     return false;
   }
-
-  private void abortTableRebalanceJob(String tableNameWithType) {
-    // TODO: This is a duplicate of a private method in RebalanceChecker, we 
should refactor it to a common place.
-    boolean updated =
-        _pinotHelixResourceManager.updateJobsForTable(tableNameWithType, 
ControllerJobTypes.TABLE_REBALANCE,
-            jobMetadata -> {
-              String jobId = 
jobMetadata.get(CommonConstants.ControllerJob.JOB_ID);
-              try {
-                String jobStatsInStr = 
jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS);
-                TableRebalanceProgressStats jobStats =
-                    JsonUtils.stringToObject(jobStatsInStr, 
TableRebalanceProgressStats.class);
-                if (jobStats.getStatus() != 
RebalanceResult.Status.IN_PROGRESS) {
-                  return;
-                }
-                LOGGER.info("Abort rebalance job: {} for table: {}", jobId, 
tableNameWithType);
-                jobStats.setStatus(RebalanceResult.Status.ABORTED);
-                
jobMetadata.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,
-                    JsonUtils.objectToString(jobStats));
-              } catch (Exception e) {
-                LOGGER.error("Failed to abort rebalance job: {} for table: 
{}", jobId, tableNameWithType, e);
-              }
-            });
-    LOGGER.info("Tried to abort existing jobs at best effort and done: {}", 
updated);
-  }
-
-  /**
-   * Mark the tenant rebalance job as aborted by updating the progress stats 
and clearing the queues in the context,
-   * then update the updated job metadata to ZK. The tables that are 
unprocessed will be marked as CANCELLED, and the
-   * tables that are processing will be marked as ABORTED in progress stats.
-   *
-   * @param jobId The ID of the tenant rebalance job.
-   * @param jobMetadata The metadata of the tenant rebalance job.
-   * @param tenantRebalanceContext The context of the tenant rebalance job.
-   * @param progressStats The progress stats of the tenant rebalance job.
-   */
-  private void markTenantRebalanceJobAsAborted(String jobId, Map<String, 
String> jobMetadata,
-      TenantRebalanceContext tenantRebalanceContext,
-      TenantRebalanceProgressStats progressStats) {
-    LOGGER.info("Marking tenant rebalance job: {} as aborted", jobId);
-    TenantRebalanceProgressStats abortedProgressStats = new 
TenantRebalanceProgressStats(progressStats);
-    for (Map.Entry<String, String> entry : 
abortedProgressStats.getTableStatusMap().entrySet()) {
-      if (Objects.equals(entry.getValue(), 
TenantRebalanceProgressStats.TableStatus.UNPROCESSED.name())) {
-        
entry.setValue(TenantRebalanceProgressStats.TableStatus.CANCELLED.name());
-      } else if (Objects.equals(entry.getValue(),
-          TenantRebalanceProgressStats.TableStatus.PROCESSING.name())) {
-        
entry.setValue(TenantRebalanceProgressStats.TableStatus.ABORTED.name());
-      }
-    }
-    tenantRebalanceContext.getSequentialQueue().clear();
-    tenantRebalanceContext.getParallelQueue().clear();
-    tenantRebalanceContext.getOngoingJobsQueue().clear();
-    try {
-      
jobMetadata.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,
-          JsonUtils.objectToString(abortedProgressStats));
-    } catch (JsonProcessingException e) {
-      LOGGER.error("Error serialising rebalance stats to JSON for marking 
tenant rebalance job as aborted {}", jobId,
-          e);
-    }
-    try {
-      jobMetadata.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,
-          JsonUtils.objectToString(tenantRebalanceContext));
-    } catch (JsonProcessingException e) {
-      LOGGER.error("Error serialising rebalance context to JSON for marking 
tenant rebalance job as aborted {}", jobId,
-          e);
-    }
-    _pinotHelixResourceManager.addControllerJobToZK(jobId, jobMetadata, 
ControllerJobTypes.TENANT_REBALANCE);
-  }
 }
diff --git 
a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalanceConfig.java
 
b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalanceConfig.java
index 05d8ba6c482..cb25959db65 100644
--- 
a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalanceConfig.java
+++ 
b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalanceConfig.java
@@ -22,6 +22,7 @@ import com.fasterxml.jackson.annotation.JsonIgnore;
 import com.fasterxml.jackson.annotation.JsonProperty;
 import io.swagger.annotations.ApiModelProperty;
 import java.util.HashSet;
+import java.util.Objects;
 import java.util.Set;
 import org.apache.pinot.controller.helix.core.rebalance.RebalanceConfig;
 
@@ -114,4 +115,23 @@ public class TenantRebalanceConfig extends RebalanceConfig 
{
   public void setVerboseResult(boolean verboseResult) {
     _verboseResult = verboseResult;
   }
+
+  /**
+   * Equality over the tenant-level settings plus everything the superclass compares.
+   *
+   * <p>Delegates to {@code super.equals(o)} for the inherited fields, then compares the degree of parallelism,
+   * the verbose-result flag, the tenant name, the parallel black/white lists and the include/exclude table sets.
+   *
+   * @param o object to compare against
+   * @return {@code true} when {@code o} is a {@code TenantRebalanceConfig} with matching inherited and
+   *         tenant-level fields
+   */
+  @Override
+  public boolean equals(Object o) {
+    if (!(o instanceof TenantRebalanceConfig)) {
+      return false;
+    }
+    TenantRebalanceConfig that = (TenantRebalanceConfig) o;
+    return super.equals(o) && _degreeOfParallelism == 
that._degreeOfParallelism && _verboseResult == that._verboseResult
+        && Objects.equals(_tenantName, that._tenantName) && 
Objects.equals(_parallelBlacklist,
+        that._parallelBlacklist) && Objects.equals(_parallelWhitelist, 
that._parallelWhitelist)
+        && Objects.equals(_includeTables, that._includeTables) && 
Objects.equals(_excludeTables,
+        that._excludeTables);
+  }
+
+  /**
+   * Hash code covering both the inherited rebalance settings and the tenant-level settings.
+   *
+   * <p>{@link #equals(Object)} delegates to {@code super.equals(o)}, so the superclass hash is folded in here as
+   * well. Without it, configs that differ only in inherited fields would always collide.
+   *
+   * @return hash of the inherited and tenant-level configuration fields
+   */
+  @Override
+  public int hashCode() {
+    return Objects.hash(super.hashCode(), _tenantName, _degreeOfParallelism, _parallelBlacklist, _parallelWhitelist,
+        _includeTables, _excludeTables, _verboseResult);
+  }
 }
diff --git 
a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalanceContext.java
 
b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalanceContext.java
index a5ef73a74a1..561355c203e 100644
--- 
a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalanceContext.java
+++ 
b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalanceContext.java
@@ -19,10 +19,18 @@
 package org.apache.pinot.controller.helix.core.rebalance.tenant;
 
 import com.fasterxml.jackson.annotation.JsonProperty;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import java.util.Arrays;
 import java.util.LinkedList;
+import java.util.Map;
+import java.util.Objects;
 import java.util.Queue;
 import java.util.concurrent.ConcurrentLinkedDeque;
 import java.util.concurrent.ConcurrentLinkedQueue;
+import javax.annotation.Nullable;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.pinot.controller.helix.core.rebalance.RebalanceJobConstants;
+import org.apache.pinot.spi.utils.JsonUtils;
 
 
 /**
@@ -58,6 +66,16 @@ public class TenantRebalanceContext {
     _ongoingJobsQueue = new ConcurrentLinkedQueue<>();
   }
 
+  /**
+   * Copy constructor.
+   *
+   * <p>The three queues are copied into new collection instances, so adding to or removing from the copy's queues
+   * does not affect the source context. The copy is shallow: the queue elements and the {@code _config} object are
+   * shared with the source.
+   *
+   * @param context the context to copy from
+   */
+  public TenantRebalanceContext(TenantRebalanceContext context) {
+    _jobId = context._jobId;
+    _originalJobId = context._originalJobId;
+    _config = context._config;
+    _attemptId = context._attemptId;
+    _parallelQueue = new ConcurrentLinkedDeque<>(context._parallelQueue);
+    _sequentialQueue = new LinkedList<>(context._sequentialQueue);
+    _ongoingJobsQueue = new ConcurrentLinkedQueue<>(context._ongoingJobsQueue);
+  }
+
   public TenantRebalanceContext(String originalJobId, TenantRebalanceConfig 
config, int attemptId,
       ConcurrentLinkedDeque<TenantRebalancer.TenantTableRebalanceJobContext> 
parallelQueue,
       Queue<TenantRebalancer.TenantTableRebalanceJobContext> sequentialQueue,
@@ -114,6 +132,25 @@ public class TenantRebalanceContext {
     return _config;
   }
 
+  /**
+   * Equality over job ids, attempt id, config and the contents of the three job queues.
+   *
+   * <p>The queues are compared element by element, in iteration order, through {@code Arrays.equals} on their
+   * {@code toArray()} snapshots rather than through the queues' own {@code equals}.
+   *
+   * <p>NOTE(review): element comparison relies on the {@code equals} of the queued job-context type, which is not
+   * visible here - confirm it is value-based.
+   *
+   * @param o object to compare against
+   * @return {@code true} when {@code o} is a {@code TenantRebalanceContext} with matching fields and queue contents
+   */
+  @Override
+  public boolean equals(Object o) {
+    if (!(o instanceof TenantRebalanceContext)) {
+      return false;
+    }
+    TenantRebalanceContext that = (TenantRebalanceContext) o;
+    return _attemptId == that._attemptId && Objects.equals(_jobId, 
that._jobId) && Objects.equals(
+        _originalJobId, that._originalJobId) && Objects.equals(_config, 
that._config)
+        && Arrays.equals(_ongoingJobsQueue.toArray(), 
that._ongoingJobsQueue.toArray()) && Arrays.equals(
+        _sequentialQueue.toArray(), that._sequentialQueue.toArray()) && 
Arrays.equals(_parallelQueue.toArray(),
+        that._parallelQueue.toArray());
+  }
+
+  /**
+   * Hash code consistent with {@link #equals(Object)}.
+   *
+   * <p>The queues are hashed by content via {@code Arrays.hashCode(queue.toArray())}, mirroring the
+   * {@code Arrays.equals} comparison used in {@code equals}. Hashing the queue objects directly would break the
+   * equals/hashCode contract: the {@code ConcurrentLinkedDeque} and {@code ConcurrentLinkedQueue} instances this
+   * class creates inherit identity-based {@code hashCode} from {@code Object}, so two equal contexts would
+   * produce different hash codes.
+   *
+   * @return hash of the job ids, config, attempt id and queue contents
+   */
+  @Override
+  public int hashCode() {
+    return Objects.hash(_jobId, _originalJobId, _config, _attemptId, Arrays.hashCode(_ongoingJobsQueue.toArray()),
+        Arrays.hashCode(_parallelQueue.toArray()), Arrays.hashCode(_sequentialQueue.toArray()));
+  }
+
   public String toString() {
     return "TenantRebalanceContext{" + "jobId='" + getJobId() + '\'' + ", 
originalJobId='" + getOriginalJobId()
         + '\'' + ", attemptId=" + getAttemptId() + ", parallelQueueSize="
@@ -127,4 +164,14 @@ public class TenantRebalanceContext {
     }
     return originalJobId + "_" + attemptId;
   }
+
+  /**
+   * Reads the tenant rebalance context out of a controller job's metadata map.
+   *
+   * @param jobMetadata job metadata; the context JSON is looked up under
+   *                    {@code RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT}
+   * @return the deserialized context, or {@code null} when the entry is missing or empty
+   * @throws JsonProcessingException if the stored value cannot be deserialized into a context
+   */
+  @Nullable
+  public static TenantRebalanceContext 
fromTenantRebalanceJobMetadata(Map<String, String> jobMetadata)
+      throws JsonProcessingException {
+    String tenantRebalanceContextStr = 
jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT);
+    if (StringUtils.isEmpty(tenantRebalanceContextStr)) {
+      return null;
+    }
+    return JsonUtils.stringToObject(tenantRebalanceContextStr, 
TenantRebalanceContext.class);
+  }
 }
diff --git 
a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalanceObserver.java
 
b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalanceObserver.java
deleted file mode 100644
index 82dd086362b..00000000000
--- 
a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalanceObserver.java
+++ /dev/null
@@ -1,38 +0,0 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.pinot.controller.helix.core.rebalance.tenant;
-
-public interface TenantRebalanceObserver {
-  enum Trigger {
-    // Start of tenant rebalance Trigger
-    START_TRIGGER,
-    // rebalance of a table is started
-    REBALANCE_STARTED_TRIGGER,
-    // rebalance of a table is completed
-    REBALANCE_COMPLETED_TRIGGER,
-    // rebalance of a table is failed
-    REBALANCE_ERRORED_TRIGGER
-  }
-
-  void onTrigger(TenantRebalanceObserver.Trigger trigger, String tableName, 
String description);
-
-  void onSuccess(String msg);
-
-  void onError(String errorMsg);
-}
diff --git 
a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalanceProgressStats.java
 
b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalanceProgressStats.java
index 7bad045c0b7..81db89b971b 100644
--- 
a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalanceProgressStats.java
+++ 
b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalanceProgressStats.java
@@ -18,12 +18,18 @@
  */
 package org.apache.pinot.controller.helix.core.rebalance.tenant;
 
+import com.fasterxml.jackson.core.JsonProcessingException;
 import com.google.common.base.Preconditions;
 import java.util.HashMap;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Set;
 import java.util.function.Function;
 import java.util.stream.Collectors;
+import javax.annotation.Nullable;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.pinot.controller.helix.core.rebalance.RebalanceJobConstants;
+import org.apache.pinot.spi.utils.JsonUtils;
 
 
 public class TenantRebalanceProgressStats {
@@ -45,7 +51,7 @@ public class TenantRebalanceProgressStats {
   public TenantRebalanceProgressStats(Set<String> tables) {
     Preconditions.checkState(tables != null && !tables.isEmpty(), "List of 
tables to observe is empty.");
     _tableStatusMap = tables.stream()
-        .collect(Collectors.toMap(Function.identity(), k -> 
TableStatus.UNPROCESSED.name()));
+        .collect(Collectors.toMap(Function.identity(), k -> 
TableStatus.IN_QUEUE.name()));
     _totalTables = tables.size();
     _remainingTables = _totalTables;
   }
@@ -60,6 +66,16 @@ public class TenantRebalanceProgressStats {
     _completionStatusMsg = other._completionStatusMsg;
   }
 
+  /**
+   * Deserializes the tenant rebalance progress stats stored in a tenant rebalance job's metadata map.
+   *
+   * @param jobMetadata job metadata map, read at
+   *                    {@link RebalanceJobConstants#JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS}
+   * @return the deserialized progress stats, or {@code null} when the entry is missing or empty
+   * @throws JsonProcessingException if the stored string cannot be parsed as a {@link TenantRebalanceProgressStats}
+   */
+  @Nullable
+  public static TenantRebalanceProgressStats fromTenantRebalanceJobMetadata(Map<String, String> jobMetadata)
+      throws JsonProcessingException {
+    String serializedStats = jobMetadata.get(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS);
+    return StringUtils.isEmpty(serializedStats) ? null
+        : JsonUtils.stringToObject(serializedStats, TenantRebalanceProgressStats.class);
+  }
+
   public Map<String, String> getTableStatusMap() {
     return _tableStatusMap;
   }
@@ -121,6 +137,30 @@ public class TenantRebalanceProgressStats {
   }
 
   public enum TableStatus {
-    UNPROCESSED, PROCESSING, PROCESSED, ABORTED, CANCELLED
+    IN_QUEUE,
+    REBALANCING,
+    DONE,
+    CANCELLED, // cancelled by user
+    ABORTED, // cancelled by TenantRebalanceChecker
+    NOT_SCHEDULED // tables IN_QUEUE will be marked as NOT_SCHEDULED once the 
rebalance job is cancelled/aborted
+  }
+
+  /**
+   * Two progress stats are equal when their table counters, timings, per-table status map, per-table rebalance job ID
+   * map and completion message are all equal.
+   */
+  @Override
+  public boolean equals(Object o) {
+    if (!(o instanceof TenantRebalanceProgressStats)) {
+      return false;
+    }
+    TenantRebalanceProgressStats other = (TenantRebalanceProgressStats) o;
+    // Compare the counters and timings first, then the maps and the completion message.
+    if (_totalTables != other._totalTables || _remainingTables != other._remainingTables) {
+      return false;
+    }
+    if (_startTimeMs != other._startTimeMs || _timeToFinishInSeconds != other._timeToFinishInSeconds) {
+      return false;
+    }
+    return Objects.equals(_tableStatusMap, other._tableStatusMap)
+        && Objects.equals(_tableRebalanceJobIdMap, other._tableRebalanceJobIdMap)
+        && Objects.equals(_completionStatusMsg, other._completionStatusMsg);
+  }
+
+  /**
+   * Combines the same seven fields that equals() compares: both maps, both table counters, the start time, the time
+   * to finish and the completion message.
+   */
+  @Override
+  public int hashCode() {
+    return Objects.hash(_tableStatusMap, _tableRebalanceJobIdMap, 
_totalTables, _remainingTables, _startTimeMs,
+        _timeToFinishInSeconds, _completionStatusMsg);
   }
 }
diff --git 
a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalancer.java
 
b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalancer.java
index 1f2cb74f47a..8f30c45d775 100644
--- 
a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalancer.java
+++ 
b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalancer.java
@@ -25,11 +25,11 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedList;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Queue;
 import java.util.Set;
 import java.util.UUID;
 import java.util.concurrent.ConcurrentLinkedDeque;
-import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.atomic.AtomicInteger;
 import javax.annotation.Nullable;
@@ -93,6 +93,21 @@ public class TenantRebalancer {
     public boolean shouldRebalanceWithDowntime() {
       return _withDowntime;
     }
+
+    /**
+     * Two job contexts are equal when they have the same downtime flag, table name and job ID.
+     */
+    @Override
+    public boolean equals(Object o) {
+      if (!(o instanceof TenantTableRebalanceJobContext)) {
+        return false;
+      }
+      TenantTableRebalanceJobContext other = (TenantTableRebalanceJobContext) o;
+      if (_withDowntime != other._withDowntime) {
+        return false;
+      }
+      return Objects.equals(_tableName, other._tableName) && Objects.equals(_jobId, other._jobId);
+    }
+
+    /**
+     * Combines the table name, job ID and downtime flag, i.e. the same three fields that equals() compares.
+     */
+    @Override
+    public int hashCode() {
+      return Objects.hash(_tableName, _jobId, _withDowntime);
+    }
   }
 
   public TenantRebalanceResult rebalance(TenantRebalanceConfig config) {
@@ -141,11 +156,15 @@ public class TenantRebalancer {
         TenantRebalanceContext.forInitialRebalance(tenantRebalanceJobId, 
config, parallelQueue,
             sequentialQueue);
 
+    // The ZK observer is likely to fail to update if the allowed retries are fewer than the degree of parallelism,
+    // because all threads poll at the same time when the tenant rebalance job starts.
+    int observerUpdaterMaxRetries =
+        Math.max(config.getDegreeOfParallelism(), 
ZkBasedTenantRebalanceObserver.DEFAULT_ZK_UPDATE_MAX_RETRIES);
     ZkBasedTenantRebalanceObserver observer =
         new ZkBasedTenantRebalanceObserver(tenantRebalanceContext.getJobId(), 
config.getTenantName(),
-            tables, tenantRebalanceContext, _pinotHelixResourceManager);
+            tables, tenantRebalanceContext, _pinotHelixResourceManager, 
observerUpdaterMaxRetries);
     // Step 4: Spin up threads to consume the parallel queue and sequential 
queue.
-    rebalanceWithContext(tenantRebalanceContext, observer);
+    rebalanceWithObserver(observer, config);
 
     // Step 5: Prepare the rebalance results to be returned to the user. The 
rebalance jobs are running in the
     // background asynchronously.
@@ -166,46 +185,37 @@ public class TenantRebalancer {
   }
 
   /**
-   * Spins up threads to rebalance the tenant with the given context and 
observer.
-   * The rebalance operation is performed in parallel for the tables in the 
parallel queue, then, sequentially for the
-   * tables in the sequential queue.
-   * The observer should be initiated with the tenantRebalanceContext in order 
to track the progress properly.
+   * Spins up threads to rebalance the tenant with the given context and 
observer. The rebalance operation is performed
+   * in parallel for the tables in the parallel queue, then, sequentially for 
the tables in the sequential queue. The
+   * observer should be initiated with the tenantRebalanceContext in order to 
track the progress properly.
    *
-   * @param tenantRebalanceContext The context containing the configuration 
and queues for the rebalance operation.
    * @param observer The observer to notify about the rebalance progress and 
results.
    */
-  public void rebalanceWithContext(TenantRebalanceContext 
tenantRebalanceContext,
-      ZkBasedTenantRebalanceObserver observer) {
-    LOGGER.info("Starting tenant rebalance with context: {}", 
tenantRebalanceContext);
-    TenantRebalanceConfig config = tenantRebalanceContext.getConfig();
-    ConcurrentLinkedDeque<TenantTableRebalanceJobContext> parallelQueue = 
tenantRebalanceContext.getParallelQueue();
-    Queue<TenantTableRebalanceJobContext> sequentialQueue = 
tenantRebalanceContext.getSequentialQueue();
-    ConcurrentLinkedQueue<TenantTableRebalanceJobContext> ongoingJobs = 
tenantRebalanceContext.getOngoingJobsQueue();
-
-    observer.onTrigger(TenantRebalanceObserver.Trigger.START_TRIGGER, null, 
null);
+  public void rebalanceWithObserver(ZkBasedTenantRebalanceObserver observer, 
TenantRebalanceConfig config) {
+    observer.onStart();
 
     // ensure atleast 1 thread is created to run the sequential table 
rebalance operations
     int parallelism = Math.max(config.getDegreeOfParallelism(), 1);
-    LOGGER.info("Spinning up {} threads for tenant rebalance job: {}", 
parallelism, tenantRebalanceContext.getJobId());
+    LOGGER.info("Spinning up {} threads for tenant rebalance job: {}", 
parallelism, observer.getJobId());
     AtomicInteger activeThreads = new AtomicInteger(parallelism);
     try {
       for (int i = 0; i < parallelism; i++) {
         _executorService.submit(() -> {
-          doConsumeTablesFromQueueAndRebalance(parallelQueue, ongoingJobs, 
config, observer);
+          doConsumeTablesFromQueueAndRebalance(config, observer, true);
           // If this is the last thread to finish, start consuming the 
sequential queue
           if (activeThreads.decrementAndGet() == 0) {
             LOGGER.info("All parallel threads completed, starting sequential 
rebalance for job: {}",
-                tenantRebalanceContext.getJobId());
-            doConsumeTablesFromQueueAndRebalance(sequentialQueue, ongoingJobs, 
config, observer);
+                observer.getJobId());
+            doConsumeTablesFromQueueAndRebalance(config, observer, false);
             observer.onSuccess(String.format("Successfully rebalanced tenant 
%s.", config.getTenantName()));
-            LOGGER.info("Completed tenant rebalance job: {}", 
tenantRebalanceContext.getJobId());
+            LOGGER.info("Completed tenant rebalance job: {}", 
observer.getJobId());
           }
         });
       }
     } catch (Exception exception) {
       observer.onError(String.format("Failed to rebalance the tenant %s. 
Cause: %s", config.getTenantName(),
           exception.getMessage()));
-      LOGGER.error("Caught exception in tenant rebalance job: {}, Cause: {}", 
tenantRebalanceContext.getJobId(),
+      LOGGER.error("Caught exception in tenant rebalance job: {}, Cause: {}", 
observer.getJobId(),
           exception.getMessage(), exception);
     }
   }
@@ -216,21 +226,24 @@ public class TenantRebalancer {
    * The ongoing jobs are tracked in the ongoingJobs queue, which is also from 
the monitored
    * DefaultTenantRebalanceContext.
    *
-   * @param queue The queue of TenantTableRebalanceJobContext to consume 
tables from.
-   * @param ongoingJobs The queue to track ongoing rebalance jobs.
    * @param config The rebalance configuration to use for the rebalancing.
    * @param observer The observer to notify about the rebalance progress and 
results, should be initiated with the
    *                 DefaultTenantRebalanceContext that contains `queue` and 
`ongoingJobs`.
    */
-  private void 
doConsumeTablesFromQueueAndRebalance(Queue<TenantTableRebalanceJobContext> 
queue,
-      Queue<TenantTableRebalanceJobContext> ongoingJobs, RebalanceConfig 
config,
-      ZkBasedTenantRebalanceObserver observer) {
+  private void doConsumeTablesFromQueueAndRebalance(RebalanceConfig config,
+      ZkBasedTenantRebalanceObserver observer, boolean isParallel) {
     while (true) {
-      TenantTableRebalanceJobContext jobContext = queue.poll();
+      TenantTableRebalanceJobContext jobContext;
+      try {
+        jobContext = isParallel ? observer.pollParallel() : 
observer.pollSequential();
+      } catch (Exception e) {
+        LOGGER.error("Caught exception while polling from the queue in tenant 
rebalance job: {}",
+            observer.getJobId(), e);
+        break;
+      }
       if (jobContext == null) {
         break;
       }
-      ongoingJobs.add(jobContext);
       String tableName = jobContext.getTableName();
       String rebalanceJobId = jobContext.getJobId();
       RebalanceConfig rebalanceConfig = RebalanceConfig.copy(config);
@@ -241,7 +254,6 @@ public class TenantRebalancer {
       try {
         LOGGER.info("Starting rebalance for table: {} with table rebalance job 
ID: {} in tenant rebalance job: {}",
             tableName, rebalanceJobId, observer.getJobId());
-        
observer.onTrigger(TenantRebalanceObserver.Trigger.REBALANCE_STARTED_TRIGGER, 
tableName, rebalanceJobId);
         // Disallow TABLE rebalance checker to retry the rebalance job here, 
since we want TENANT rebalance checker
         // to do so
         RebalanceResult result =
@@ -251,20 +263,19 @@ public class TenantRebalancer {
         if (result.getStatus().equals(RebalanceResult.Status.DONE)) {
           LOGGER.info("Completed rebalance for table: {} with table rebalance 
job ID: {} in tenant rebalance job: {}",
               tableName, rebalanceJobId, observer.getJobId());
-          ongoingJobs.remove(jobContext);
-          
observer.onTrigger(TenantRebalanceObserver.Trigger.REBALANCE_COMPLETED_TRIGGER, 
tableName, null);
+          observer.onTableJobDone(jobContext);
         } else {
-          LOGGER.warn("Rebalance for table: {} with table rebalance job ID: {} 
in tenant rebalance job: {} is not done."
+          LOGGER.warn(
+              "Rebalance for table: {} with table rebalance job ID: {} in 
tenant rebalance job: {} is not done."
                   + "Status: {}, Description: {}", tableName, rebalanceJobId, 
observer.getJobId(), result.getStatus(),
               result.getDescription());
-          ongoingJobs.remove(jobContext);
-          
observer.onTrigger(TenantRebalanceObserver.Trigger.REBALANCE_ERRORED_TRIGGER, 
tableName,
-              result.getDescription());
+          observer.onTableJobError(jobContext, result.getDescription());
         }
-      } catch (Throwable t) {
-        ongoingJobs.remove(jobContext);
-        
observer.onTrigger(TenantRebalanceObserver.Trigger.REBALANCE_ERRORED_TRIGGER, 
tableName,
-            String.format("Caught exception/error while rebalancing table: 
%s", tableName));
+      } catch (Exception e) {
+        LOGGER.error("Caught exception while rebalancing table: {} with table 
rebalance job ID: {} in tenant "
+            + "rebalance job: {}", tableName, rebalanceJobId, 
observer.getJobId(), e);
+        observer.onTableJobError(jobContext,
+            String.format("Caught exception/error while rebalancing table: %s. 
%s", tableName, e.getMessage()));
       }
     }
   }
diff --git 
a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/ZkBasedTenantRebalanceObserver.java
 
b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/ZkBasedTenantRebalanceObserver.java
index 9606db14b4b..e1237238065 100644
--- 
a/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/ZkBasedTenantRebalanceObserver.java
+++ 
b/pinot-controller/src/main/java/org/apache/pinot/controller/helix/core/rebalance/tenant/ZkBasedTenantRebalanceObserver.java
@@ -20,100 +20,117 @@ package 
org.apache.pinot.controller.helix.core.rebalance.tenant;
 
 import com.fasterxml.jackson.core.JsonProcessingException;
 import com.google.common.annotations.VisibleForTesting;
+import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
-import java.util.stream.Collectors;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicReference;
+import java.util.function.BiConsumer;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pinot.controller.helix.core.PinotHelixResourceManager;
 import org.apache.pinot.controller.helix.core.controllerjob.ControllerJobTypes;
 import org.apache.pinot.controller.helix.core.rebalance.RebalanceJobConstants;
+import org.apache.pinot.controller.helix.core.rebalance.RebalanceResult;
+import org.apache.pinot.controller.helix.core.rebalance.TableRebalanceManager;
 import org.apache.pinot.spi.utils.CommonConstants;
 import org.apache.pinot.spi.utils.JsonUtils;
+import org.apache.pinot.spi.utils.retry.AttemptFailureException;
+import org.apache.pinot.spi.utils.retry.RetryPolicies;
+import org.apache.pinot.spi.utils.retry.RetryPolicy;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 
-public class ZkBasedTenantRebalanceObserver implements TenantRebalanceObserver 
{
+public class ZkBasedTenantRebalanceObserver {
   private static final Logger LOGGER = 
LoggerFactory.getLogger(ZkBasedTenantRebalanceObserver.class);
+  public static final int DEFAULT_ZK_UPDATE_MAX_RETRIES = 3;
+  private static final int MIN_ZK_UPDATE_RETRY_DELAY_MS = 100;
+  private static final int MAX_ZK_UPDATE_RETRY_DELAY_MS = 200;
 
   private final PinotHelixResourceManager _pinotHelixResourceManager;
   private final String _jobId;
   private final String _tenantName;
-  private final List<String> _unprocessedTables;
-  private final TenantRebalanceProgressStats _progressStats;
-  private final TenantRebalanceContext _tenantRebalanceContext;
   // Keep track of number of updates. Useful during debugging.
-  private int _numUpdatesToZk;
+  private final AtomicInteger _numUpdatesToZk;
+  private final int _zkUpdateMaxRetries;
   private boolean _isDone;
 
-  public ZkBasedTenantRebalanceObserver(String jobId, String tenantName, 
TenantRebalanceProgressStats progressStats,
-      TenantRebalanceContext tenantRebalanceContext,
-      PinotHelixResourceManager pinotHelixResourceManager) {
+  private ZkBasedTenantRebalanceObserver(String jobId, String tenantName,
+      PinotHelixResourceManager pinotHelixResourceManager, int 
zkUpdateMaxRetries) {
     _isDone = false;
     _jobId = jobId;
     _tenantName = tenantName;
-    _unprocessedTables = progressStats.getTableStatusMap()
-        .entrySet()
-        .stream()
-        .filter(entry -> 
entry.getValue().equals(TenantRebalanceProgressStats.TableStatus.UNPROCESSED.name()))
-        .map(Map.Entry::getKey)
-        .collect(Collectors.toList());
-    _tenantRebalanceContext = tenantRebalanceContext;
     _pinotHelixResourceManager = pinotHelixResourceManager;
-    _progressStats = progressStats;
-    _numUpdatesToZk = 0;
+    _zkUpdateMaxRetries = zkUpdateMaxRetries;
+    _numUpdatesToZk = new AtomicInteger(0);
+  }
+
+  private ZkBasedTenantRebalanceObserver(String jobId, String tenantName, 
TenantRebalanceProgressStats progressStats,
+      TenantRebalanceContext tenantRebalanceContext, PinotHelixResourceManager 
pinotHelixResourceManager,
+      int zkUpdateMaxRetries) {
+    this(jobId, tenantName, pinotHelixResourceManager, zkUpdateMaxRetries);
+    _pinotHelixResourceManager.addControllerJobToZK(_jobId, 
makeJobMetadata(tenantRebalanceContext, progressStats),
+        ControllerJobTypes.TENANT_REBALANCE);
+    _numUpdatesToZk.incrementAndGet();
   }
 
   public ZkBasedTenantRebalanceObserver(String jobId, String tenantName, 
Set<String> tables,
-      TenantRebalanceContext tenantRebalanceContext,
-      PinotHelixResourceManager pinotHelixResourceManager) {
+      TenantRebalanceContext tenantRebalanceContext, PinotHelixResourceManager 
pinotHelixResourceManager,
+      int zkUpdateMaxRetries) {
     this(jobId, tenantName, new TenantRebalanceProgressStats(tables), 
tenantRebalanceContext,
-        pinotHelixResourceManager);
-  }
-
-  @Override
-  public void onTrigger(Trigger trigger, String tableName, String description) 
{
-    switch (trigger) {
-      case START_TRIGGER:
-        _progressStats.setStartTimeMs(System.currentTimeMillis());
-        break;
-      case REBALANCE_STARTED_TRIGGER:
-        _progressStats.updateTableStatus(tableName, 
TenantRebalanceProgressStats.TableStatus.PROCESSING.name());
-        _progressStats.putTableRebalanceJobId(tableName, description);
-        break;
-      case REBALANCE_COMPLETED_TRIGGER:
-        _progressStats.updateTableStatus(tableName, 
TenantRebalanceProgressStats.TableStatus.PROCESSED.name());
-        _unprocessedTables.remove(tableName);
-        _progressStats.setRemainingTables(_unprocessedTables.size());
-        break;
-      case REBALANCE_ERRORED_TRIGGER:
-        _progressStats.updateTableStatus(tableName, description);
-        _unprocessedTables.remove(tableName);
-        _progressStats.setRemainingTables(_unprocessedTables.size());
-        break;
-      default:
+        pinotHelixResourceManager, zkUpdateMaxRetries);
+  }
+
+  public ZkBasedTenantRebalanceObserver(String jobId, String tenantName,
+      PinotHelixResourceManager pinotHelixResourceManager) {
+    this(jobId, tenantName, pinotHelixResourceManager, 
DEFAULT_ZK_UPDATE_MAX_RETRIES);
+  }
+
+  public ZkBasedTenantRebalanceObserver(String jobId, String tenantName, 
TenantRebalanceProgressStats progressStats,
+      TenantRebalanceContext tenantRebalanceContext, PinotHelixResourceManager 
pinotHelixResourceManager) {
+    this(jobId, tenantName, progressStats, tenantRebalanceContext, 
pinotHelixResourceManager,
+        DEFAULT_ZK_UPDATE_MAX_RETRIES);
+  }
+
+  public void onStart() {
+    try {
+      updateTenantRebalanceJobMetadataInZk(
+          (ctx, progressStats) -> 
progressStats.setStartTimeMs(System.currentTimeMillis()));
+    } catch (AttemptFailureException e) {
+      LOGGER.error("Error updating ZK for jobId: {} on starting tenant 
rebalance", _jobId, e);
+      throw new RuntimeException("Error updating ZK for jobId: " + _jobId + " 
on starting tenant rebalance", e);
     }
-    syncStatsAndContextInZk();
   }
 
-  @Override
   public void onSuccess(String msg) {
-    _progressStats.setCompletionStatusMsg(msg);
-    _progressStats.setTimeToFinishInSeconds((System.currentTimeMillis() - 
_progressStats.getStartTimeMs()) / 1000);
-    syncStatsAndContextInZk();
-    _isDone = true;
+    onFinish(msg);
   }
 
-  @Override
   public void onError(String errorMsg) {
-    _progressStats.setCompletionStatusMsg(errorMsg);
-    _progressStats.setTimeToFinishInSeconds(System.currentTimeMillis() - 
_progressStats.getStartTimeMs());
-    syncStatsAndContextInZk();
+    onFinish(errorMsg);
+  }
+
+  /**
+   * Shared completion path for {@link #onSuccess(String)} and {@link #onError(String)}. Records {@code msg} and the
+   * elapsed time in the persisted progress stats, unless a completion message is already present (in which case the
+   * existing message and time are kept), and then marks this observer as done.
+   *
+   * @param msg completion message to persist; may describe a success or a failure
+   * @throws RuntimeException if the job metadata could not be updated; the observer is not marked done in that case
+   */
+  private void onFinish(String msg) {
+    try {
+      updateTenantRebalanceJobMetadataInZk((ctx, progressStats) -> {
+        // Keep the first completion message that was written, e.g. the one set by cancelJob().
+        if (StringUtils.isEmpty(progressStats.getCompletionStatusMsg())) {
+          progressStats.setCompletionStatusMsg(msg);
+          progressStats.setTimeToFinishInSeconds((System.currentTimeMillis() - progressStats.getStartTimeMs()) / 1000);
+        }
+      });
+    } catch (AttemptFailureException e) {
+      // This path also serves onError(), so the message must not claim that the rebalance succeeded.
+      LOGGER.error("Error updating ZK for jobId: {} on completion of tenant rebalance", _jobId, e);
+      throw new RuntimeException("Error updating ZK for jobId: " + _jobId + " on completion of tenant rebalance", e);
+    }
     _isDone = true;
   }
 
-  private void syncStatsAndContextInZk() {
+  private Map<String, String> makeJobMetadata(TenantRebalanceContext 
tenantRebalanceContext,
+      TenantRebalanceProgressStats progressStats) {
     Map<String, String> jobMetadata = new HashMap<>();
     jobMetadata.put(CommonConstants.ControllerJob.TENANT_NAME, _tenantName);
     jobMetadata.put(CommonConstants.ControllerJob.JOB_ID, _jobId);
@@ -121,19 +138,172 @@ public class ZkBasedTenantRebalanceObserver implements 
TenantRebalanceObserver {
     jobMetadata.put(CommonConstants.ControllerJob.JOB_TYPE, 
ControllerJobTypes.TENANT_REBALANCE.name());
     try {
       
jobMetadata.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_PROGRESS_STATS,
-          JsonUtils.objectToString(_progressStats));
+          JsonUtils.objectToString(progressStats));
     } catch (JsonProcessingException e) {
       LOGGER.error("Error serialising rebalance stats to JSON for persisting 
to ZK {}", _jobId, e);
     }
     try {
       jobMetadata.put(RebalanceJobConstants.JOB_METADATA_KEY_REBALANCE_CONTEXT,
-          JsonUtils.objectToString(_tenantRebalanceContext));
+          JsonUtils.objectToString(tenantRebalanceContext));
     } catch (JsonProcessingException e) {
       LOGGER.error("Error serialising rebalance context to JSON for persisting 
to ZK {}", _jobId, e);
     }
-    _pinotHelixResourceManager.addControllerJobToZK(_jobId, jobMetadata, 
ControllerJobTypes.TENANT_REBALANCE);
-    _numUpdatesToZk++;
-    LOGGER.debug("Number of updates to Zk: {} for rebalanceJob: {}  ", 
_numUpdatesToZk, _jobId);
+    return jobMetadata;
+  }
+
+  /**
+   * Takes the next table job from the parallel or sequential queue of the persisted tenant rebalance context, adds it
+   * to the ongoing jobs queue, marks the table as REBALANCING and records its table rebalance job ID in the progress
+   * stats, all within a single updateTenantRebalanceJobMetadataInZk call.
+   *
+   * @param isParallel true to poll the parallel queue, false to poll the sequential queue
+   * @return the polled job context, or null if the selected queue is empty
+   * @throws RuntimeException wrapping the AttemptFailureException if the ZK update fails
+   */
+  public TenantRebalancer.TenantTableRebalanceJobContext pollQueue(boolean 
isParallel) {
+    // Holder for the polled job, since the lambda below cannot assign to a local variable.
+    AtomicReference<TenantRebalancer.TenantTableRebalanceJobContext> ret = new 
AtomicReference<>();
+    try {
+      updateTenantRebalanceJobMetadataInZk((ctx, progressStats) -> {
+        TenantRebalancer.TenantTableRebalanceJobContext polled =
+            isParallel ? ctx.getParallelQueue().poll() : 
ctx.getSequentialQueue().poll();
+        if (polled != null) {
+          ctx.getOngoingJobsQueue().add(polled);
+          String tableName = polled.getTableName();
+          String rebalanceJobId = polled.getJobId();
+          progressStats.updateTableStatus(tableName, 
TenantRebalanceProgressStats.TableStatus.REBALANCING.name());
+          progressStats.putTableRebalanceJobId(tableName, rebalanceJobId);
+        }
+        // Set on every invocation of the updater (also when null), so the holder reflects the latest attempt.
+        ret.set(polled);
+      });
+    } catch (AttemptFailureException e) {
+      LOGGER.error("Error updating ZK for jobId: {} while polling from {} 
queue", _jobId,
+          isParallel ? "parallel" : "sequential", e);
+      throw new RuntimeException(
+          "Error updating ZK for jobId: " + _jobId + " while polling from " + 
(isParallel ? "parallel" : "sequential")
+              + " queue", e);
+    }
+    return ret.get();
+  }
+
+  /**
+   * Polls the next table job from the parallel queue; shorthand for pollQueue(true).
+   */
+  public TenantRebalancer.TenantTableRebalanceJobContext pollParallel() {
+    return pollQueue(true);
+  }
+
+  /**
+   * Polls the next table job from the sequential queue; shorthand for pollQueue(false).
+   */
+  public TenantRebalancer.TenantTableRebalanceJobContext pollSequential() {
+    return pollQueue(false);
+  }
+
+  /**
+   * Reports that a table rebalance job did not finish successfully. The given error message is passed to
+   * onTableJobComplete, which stores it as the table's status.
+   */
+  public void onTableJobError(TenantRebalancer.TenantTableRebalanceJobContext 
jobContext, String errorMessage) {
+    onTableJobComplete(jobContext, errorMessage);
+  }
+
+  /**
+   * Reports that a table rebalance job finished successfully; the table's status becomes TableStatus.DONE.
+   */
+  public void onTableJobDone(TenantRebalancer.TenantTableRebalanceJobContext 
jobContext) {
+    onTableJobComplete(jobContext, 
TenantRebalanceProgressStats.TableStatus.DONE.name());
+  }
+
+  /**
+   * Removes the given job from the ongoing jobs queue of the persisted context and, if it was present, stores
+   * the given message as the table's status and decrements the remaining table count.
+   *
+   * @param jobContext the table job that has finished
+   * @param message the status to store for the table (TableStatus.DONE name or an error description)
+   * @throws RuntimeException wrapping the AttemptFailureException if the ZK update fails
+   */
+  private void 
onTableJobComplete(TenantRebalancer.TenantTableRebalanceJobContext jobContext, 
String message) {
+    try {
+      updateTenantRebalanceJobMetadataInZk((ctx, progressStats) -> {
+        // Only touch the stats when the job is still listed as ongoing. If it is no longer there (for example
+        // because cancelJob() drained the ongoing queue), the status written by that code path is left as is.
+        if (ctx.getOngoingJobsQueue().remove(jobContext)) {
+          progressStats.updateTableStatus(jobContext.getTableName(), message);
+          progressStats.setRemainingTables(progressStats.getRemainingTables() 
- 1);
+        }
+      });
+    } catch (AttemptFailureException e) {
+      LOGGER.error("Error updating ZK for jobId: {} on completion of table 
rebalance job: {}", _jobId, jobContext, e);
+      throw new RuntimeException(
+          "Error updating ZK for jobId: " + _jobId + " on completion of table 
rebalance job: " + jobContext, e);
+    }
+  }
+
+  /**
+   * Cancels the tenant rebalance job.
+   *
+   * <p>First drains the parallel and sequential queues and marks their tables NOT_SCHEDULED, so that no new table
+   * job is picked up. Then drains the ongoing jobs queue, asks {@link TableRebalanceManager} to cancel the rebalance
+   * of each of those tables, marks them CANCELLED or ABORTED, and records the completion message and elapsed time.
+   *
+   * @param isCancelledByUser true if the cancellation is triggered by user, false if it is triggered by system
+   *                          (e.g. tenant rebalance checker retrying a job)
+   * @return a pair of "list of TABLE rebalance job IDs that are successfully cancelled" and "whether the TENANT
+   *         rebalance job cancellation is successful"
+   */
+  public Pair<List<String>, Boolean> cancelJob(boolean isCancelledByUser) {
+    List<String> cancelledJobs = new ArrayList<>();
+    try {
+      // Empty the queues first to prevent any new jobs from being picked up.
+      updateTenantRebalanceJobMetadataInZk((tenantRebalanceContext, progressStats) -> {
+        TenantRebalancer.TenantTableRebalanceJobContext ctx;
+        while ((ctx = tenantRebalanceContext.getParallelQueue().poll()) != null) {
+          progressStats.getTableStatusMap()
+              .put(ctx.getTableName(), TenantRebalanceProgressStats.TableStatus.NOT_SCHEDULED.name());
+        }
+        while ((ctx = tenantRebalanceContext.getSequentialQueue().poll()) != null) {
+          progressStats.getTableStatusMap()
+              .put(ctx.getTableName(), TenantRebalanceProgressStats.TableStatus.NOT_SCHEDULED.name());
+        }
+      });
+      // Try to cancel ongoing jobs on a best-effort basis. Some ongoing jobs may be marked as cancelled even though
+      // they actually completed, if the table rebalance finished right after TableRebalanceManager marked it.
+      // NOTE(review): this updater has side effects (cancelRebalance, cancelledJobs) and is invoked once per retry
+      // attempt, so a retried update may call cancelRebalance again for the same table - confirm this is acceptable.
+      updateTenantRebalanceJobMetadataInZk((tenantRebalanceContext, progressStats) -> {
+        TenantRebalancer.TenantTableRebalanceJobContext ctx;
+        while ((ctx = tenantRebalanceContext.getOngoingJobsQueue().poll()) != null) {
+          cancelledJobs.addAll(TableRebalanceManager.cancelRebalance(ctx.getTableName(), _pinotHelixResourceManager,
+              isCancelledByUser ? RebalanceResult.Status.CANCELLED : RebalanceResult.Status.ABORTED));
+          progressStats.getTableStatusMap()
+              .put(ctx.getTableName(), isCancelledByUser ? TenantRebalanceProgressStats.TableStatus.CANCELLED.name()
+                  : TenantRebalanceProgressStats.TableStatus.ABORTED.name());
+        }
+        progressStats.setRemainingTables(0);
+        progressStats.setCompletionStatusMsg(
+            "Tenant rebalance job has been " + (isCancelledByUser ? "cancelled." : "aborted."));
+        progressStats.setTimeToFinishInSeconds((System.currentTimeMillis() - progressStats.getStartTimeMs()) / 1000);
+      });
+      return Pair.of(cancelledJobs, true);
+    } catch (AttemptFailureException e) {
+      // The failure is reported to the caller through the returned flag; log the cause so that it is not lost.
+      LOGGER.error("Error updating ZK for jobId: {} while cancelling tenant rebalance", _jobId, e);
+      return Pair.of(cancelledJobs, false);
+    }
+  }
+
+  /**
+   * Applies the given updater to the tenant rebalance context and progress stats stored for this job, as a
+   * read-modify-write cycle: read the job metadata, deserialize both objects, let the updater mutate copies of them,
+   * and write the copies back. The cycle runs under a random-delay retry policy built from _zkUpdateMaxRetries and the
+   * MIN/MAX_ZK_UPDATE_RETRY_DELAY_MS bounds.
+   *
+   * Because the updater runs inside the retried callable, it may be invoked more than once per call of this method.
+   *
+   * @param updater mutates the copies of the context and progress stats that will be written back
+   * @throws AttemptFailureException declared by this method; presumably raised by the retry policy once its attempts
+   *         are exhausted - confirm against RetryPolicy
+   */
+  private void updateTenantRebalanceJobMetadataInZk(
+      BiConsumer<TenantRebalanceContext, TenantRebalanceProgressStats> updater)
+      throws AttemptFailureException {
+    RetryPolicy retry = 
RetryPolicies.randomDelayRetryPolicy(_zkUpdateMaxRetries, 
MIN_ZK_UPDATE_RETRY_DELAY_MS,
+        MAX_ZK_UPDATE_RETRY_DELAY_MS);
+    retry.attempt(() -> {
+      Map<String, String> jobMetadata =
+          _pinotHelixResourceManager.getControllerJobZKMetadata(_jobId, 
ControllerJobTypes.TENANT_REBALANCE);
+      // No metadata for this job: report the attempt as unsuccessful.
+      if (jobMetadata == null) {
+        return false;
+      }
+      TenantRebalanceContext originalContext = 
TenantRebalanceContext.fromTenantRebalanceJobMetadata(jobMetadata);
+      TenantRebalanceProgressStats originalStats =
+          
TenantRebalanceProgressStats.fromTenantRebalanceJobMetadata(jobMetadata);
+      if (originalContext == null || originalStats == null) {
+        LOGGER.warn("Skip updating ZK since rebalance context or progress 
stats is not present in ZK for jobId: {}",
+            _jobId);
+        return false;
+      }
+      // Mutate copies, so that the originals stay available for the comparison in the predicate below.
+      TenantRebalanceContext updatedContext = new 
TenantRebalanceContext(originalContext);
+      TenantRebalanceProgressStats updatedStats = new 
TenantRebalanceProgressStats(originalStats);
+      updater.accept(updatedContext, updatedStats);
+      // The predicate accepts the previously stored metadata only if it still equals what was read above.
+      // NOTE(review): presumably addControllerJobToZK writes only when the predicate returns true and returns
+      // whether the write happened - confirm against PinotHelixResourceManager.
+      boolean updateSuccessful =
+          _pinotHelixResourceManager.addControllerJobToZK(_jobId, 
makeJobMetadata(updatedContext, updatedStats),
+              ControllerJobTypes.TENANT_REBALANCE, prevJobMetadata -> {
+                try {
+                  TenantRebalanceContext prevContext =
+                      
TenantRebalanceContext.fromTenantRebalanceJobMetadata(prevJobMetadata);
+                  TenantRebalanceProgressStats prevStats =
+                      
TenantRebalanceProgressStats.fromTenantRebalanceJobMetadata(prevJobMetadata);
+                  if (prevContext == null || prevStats == null) {
+                    LOGGER.warn(
+                        "Failed to update ZK since rebalance context or 
progress stats was removed in ZK for "
+                            + "jobId: {}", _jobId);
+                    return false;
+                  }
+                  return prevContext.equals(originalContext) && 
prevStats.equals(originalStats);
+                } catch (JsonProcessingException e) {
+                  LOGGER.error("Error deserializing rebalance context from ZK 
for jobId: {}", _jobId, e);
+                  return false;
+                }
+              });
+      if (updateSuccessful) {
+        return true;
+      } else {
+        // Note: this message says "while polling", but the method is used by every update path, not only polling.
+        LOGGER.info(
+            "Tenant rebalance context or progress stats is out of sync with ZK 
while polling, fetching the latest "
+                + "context and progress stats from ZK and retry. jobId: {}", 
_jobId);
+        return false;
+      }
+    });
+    // The counter is incremented once per call of this method that returns normally, not once per attempt.
+    LOGGER.debug("Number of updates to Zk: {} for rebalanceJob: {}  ", 
_numUpdatesToZk.incrementAndGet(), _jobId);
   }
 
   public boolean isDone() {
diff --git 
a/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalanceCheckerTest.java
 
b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalanceCheckerTest.java
index 15c69e63673..85ffcbdfb0b 100644
--- 
a/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalanceCheckerTest.java
+++ 
b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalanceCheckerTest.java
@@ -41,6 +41,7 @@ import org.apache.pinot.spi.utils.CommonConstants;
 import org.apache.pinot.spi.utils.JsonUtils;
 import org.mockito.ArgumentCaptor;
 import org.mockito.Mock;
+import org.mockito.Mockito;
 import org.mockito.MockitoAnnotations;
 import org.testng.annotations.AfterMethod;
 import org.testng.annotations.BeforeMethod;
@@ -122,6 +123,8 @@ public class TenantRebalanceCheckerTest extends 
ControllerTest {
     // Setup mocks
     doReturn(allJobMetadata).when(_mockPinotHelixResourceManager)
         .getAllJobs(eq(Set.of(ControllerJobTypes.TENANT_REBALANCE)), any());
+    doReturn(allJobMetadata.get(JOB_ID)).when(_mockPinotHelixResourceManager)
+        .getControllerJobZKMetadata(eq(JOB_ID), 
eq(ControllerJobTypes.TENANT_REBALANCE));
     doReturn(stuckTableJobMetadata).when(_mockPinotHelixResourceManager)
         .getControllerJobZKMetadata(eq(STUCK_TABLE_JOB_ID), 
eq(ControllerJobTypes.TABLE_REBALANCE));
 
@@ -132,11 +135,13 @@ public class TenantRebalanceCheckerTest extends 
ControllerTest {
         ArgumentCaptor.forClass(ZkBasedTenantRebalanceObserver.class);
 
     // Execute the checker
-    _tenantRebalanceChecker.runTask(new Properties());
+    TenantRebalanceChecker checkerSpy = Mockito.spy(_tenantRebalanceChecker);
+    checkerSpy.runTask(new Properties());
 
     // Verify that the tenant rebalancer was called to resume the job
-    verify(_mockTenantRebalancer, times(1)).rebalanceWithContext(
-        contextCaptor.capture(), observerCaptor.capture());
+    verify(checkerSpy, times(1)).retryTenantRebalanceJob(
+        contextCaptor.capture(), any());
+    verify(_mockTenantRebalancer, 
times(1)).rebalanceWithObserver(observerCaptor.capture(), any());
 
     // Verify the resumed context
     TenantRebalanceContext resumedContext = contextCaptor.getValue();
@@ -180,19 +185,22 @@ public class TenantRebalanceCheckerTest extends 
ControllerTest {
     // Setup mocks
     doReturn(allJobMetadata).when(_mockPinotHelixResourceManager)
         .getAllJobs(eq(Set.of(ControllerJobTypes.TENANT_REBALANCE)), any());
+    doReturn(allJobMetadata.get(JOB_ID)).when(_mockPinotHelixResourceManager)
+        .getControllerJobZKMetadata(eq(JOB_ID), 
eq(ControllerJobTypes.TENANT_REBALANCE));
     doReturn(stuckTableJobMetadata1).when(_mockPinotHelixResourceManager)
         .getControllerJobZKMetadata(eq(STUCK_TABLE_JOB_ID), 
eq(ControllerJobTypes.TABLE_REBALANCE));
     doReturn(stuckTableJobMetadata2).when(_mockPinotHelixResourceManager)
         .getControllerJobZKMetadata(eq(STUCK_TABLE_JOB_ID_2), 
eq(ControllerJobTypes.TABLE_REBALANCE));
 
     // Execute the checker
-    _tenantRebalanceChecker.runTask(new Properties());
+    TenantRebalanceChecker checkerSpy = Mockito.spy(_tenantRebalanceChecker);
+    checkerSpy.runTask(new Properties());
 
     // Verify that the tenant rebalancer was called
     ArgumentCaptor<TenantRebalanceContext> contextCaptor =
         ArgumentCaptor.forClass(TenantRebalanceContext.class);
-    verify(_mockTenantRebalancer, times(1)).rebalanceWithContext(
-        contextCaptor.capture(), any(ZkBasedTenantRebalanceObserver.class));
+    verify(checkerSpy, times(1)).retryTenantRebalanceJob(
+        contextCaptor.capture(), any());
 
     // Verify that both stuck table jobs were moved back to parallel queue
     TenantRebalanceContext resumedContext = contextCaptor.getValue();
@@ -222,7 +230,7 @@ public class TenantRebalanceCheckerTest extends 
ControllerTest {
     _tenantRebalanceChecker.runTask(new Properties());
 
     // Verify that the tenant rebalancer was NOT called
-    verify(_mockTenantRebalancer, never()).rebalanceWithContext(any(), any());
+    verify(_mockTenantRebalancer, never()).rebalanceWithObserver(any(), any());
   }
 
   @Test
@@ -255,7 +263,7 @@ public class TenantRebalanceCheckerTest extends 
ControllerTest {
     _tenantRebalanceChecker.runTask(new Properties());
 
     // Verify that the tenant rebalancer was NOT called
-    verify(_mockTenantRebalancer, never()).rebalanceWithContext(any(), any());
+    verify(_mockTenantRebalancer, never()).rebalanceWithObserver(any(), any());
   }
 
   @Test
@@ -292,8 +300,7 @@ public class TenantRebalanceCheckerTest extends 
ControllerTest {
     _tenantRebalanceChecker.runTask(new Properties());
 
     // Verify that the tenant rebalancer was NOT called because ZK update 
failed
-    verify(_mockTenantRebalancer, times(0)).rebalanceWithContext(
-        contextCaptor.capture(), observerCaptor.capture());
+    verify(_mockTenantRebalancer, never()).rebalanceWithObserver(any(), any());
   }
 
   @Test
@@ -320,7 +327,7 @@ public class TenantRebalanceCheckerTest extends 
ControllerTest {
     _tenantRebalanceChecker.runTask(new Properties());
 
     // Verify that the tenant rebalancer was NOT called
-    verify(_mockTenantRebalancer, never()).rebalanceWithContext(any(), any());
+    verify(_mockTenantRebalancer, never()).rebalanceWithObserver(any(), any());
   }
 
   @Test
@@ -342,12 +349,14 @@ public class TenantRebalanceCheckerTest extends 
ControllerTest {
     // Setup mocks
     doReturn(allJobMetadata).when(_mockPinotHelixResourceManager)
         .getAllJobs(eq(Set.of(ControllerJobTypes.TENANT_REBALANCE)), any());
+    doReturn(allJobMetadata.get(JOB_ID)).when(_mockPinotHelixResourceManager)
+        .getControllerJobZKMetadata(eq(JOB_ID), 
eq(ControllerJobTypes.TENANT_REBALANCE));
+    doReturn(allJobMetadata.get(JOB_ID_2)).when(_mockPinotHelixResourceManager)
+        .getControllerJobZKMetadata(eq(JOB_ID_2), 
eq(ControllerJobTypes.TENANT_REBALANCE));
     doReturn(stuckTableJobMetadata).when(_mockPinotHelixResourceManager)
         .getControllerJobZKMetadata(eq(STUCK_TABLE_JOB_ID), 
eq(ControllerJobTypes.TABLE_REBALANCE));
 
     // Capture the observer handed to the tenant rebalancer when a job is resumed
-    ArgumentCaptor<TenantRebalanceContext> contextCaptor =
-        ArgumentCaptor.forClass(TenantRebalanceContext.class);
     ArgumentCaptor<ZkBasedTenantRebalanceObserver> observerCaptor =
         ArgumentCaptor.forClass(ZkBasedTenantRebalanceObserver.class);
 
@@ -355,23 +364,23 @@ public class TenantRebalanceCheckerTest extends 
ControllerTest {
     _tenantRebalanceChecker.runTask(new Properties());
 
     // Verify that the tenant rebalancer was called to resume the job
-    verify(_mockTenantRebalancer, times(1)).rebalanceWithContext(
-        contextCaptor.capture(), observerCaptor.capture());
+    verify(_mockTenantRebalancer, times(1)).rebalanceWithObserver(
+        observerCaptor.capture(), any());
     // The mock tenant rebalancer never marks the job as done
     assertFalse(observerCaptor.getValue().isDone());
 
     _tenantRebalanceChecker.runTask(new Properties());
     // Since the previous job is not done, rebalanceWithObserver should not 
be called again as we have set the limit
     // to one tenant rebalance retry at a time
-    verify(_mockTenantRebalancer, times(1)).rebalanceWithContext(
-        contextCaptor.capture(), observerCaptor.capture());
+    verify(_mockTenantRebalancer, times(1)).rebalanceWithObserver(
+        observerCaptor.capture(), any());
 
     // Mark the job as done and run the checker again - should pick up another 
job now
     observerCaptor.getValue().setDone(true);
     _tenantRebalanceChecker.runTask(new Properties());
 
-    verify(_mockTenantRebalancer, times(2)).rebalanceWithContext(
-        contextCaptor.capture(), observerCaptor.capture());
+    verify(_mockTenantRebalancer, times(2)).rebalanceWithObserver(
+        observerCaptor.capture(), any());
   }
 
   // Helper methods to create test data
@@ -467,8 +476,8 @@ public class TenantRebalanceCheckerTest extends 
ControllerTest {
 
     TenantRebalanceProgressStats stats = new 
TenantRebalanceProgressStats(tables);
     stats.setStartTimeMs(System.currentTimeMillis() - 60000); // 1 minute ago
-    stats.updateTableStatus(TABLE_NAME_1, 
TenantRebalanceProgressStats.TableStatus.PROCESSING.name());
-    stats.updateTableStatus(TABLE_NAME_2, 
TenantRebalanceProgressStats.TableStatus.UNPROCESSED.name());
+    stats.updateTableStatus(TABLE_NAME_1, 
TenantRebalanceProgressStats.TableStatus.REBALANCING.name());
+    stats.updateTableStatus(TABLE_NAME_2, 
TenantRebalanceProgressStats.TableStatus.IN_QUEUE.name());
 
     return stats;
   }
@@ -480,8 +489,8 @@ public class TenantRebalanceCheckerTest extends 
ControllerTest {
 
     TenantRebalanceProgressStats stats = new 
TenantRebalanceProgressStats(tables);
     stats.setStartTimeMs(System.currentTimeMillis() - 60000);
-    stats.updateTableStatus(TABLE_NAME_1, 
TenantRebalanceProgressStats.TableStatus.PROCESSING.name());
-    stats.updateTableStatus(TABLE_NAME_2, 
TenantRebalanceProgressStats.TableStatus.PROCESSING.name());
+    stats.updateTableStatus(TABLE_NAME_1, 
TenantRebalanceProgressStats.TableStatus.REBALANCING.name());
+    stats.updateTableStatus(TABLE_NAME_2, 
TenantRebalanceProgressStats.TableStatus.REBALANCING.name());
 
     return stats;
   }
diff --git 
a/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalancerTest.java
 
b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalancerTest.java
index d033b9c6a4f..92f8dc729bb 100644
--- 
a/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalancerTest.java
+++ 
b/pinot-controller/src/test/java/org/apache/pinot/controller/helix/core/rebalance/tenant/TenantRebalancerTest.java
@@ -21,16 +21,24 @@ package 
org.apache.pinot.controller.helix.core.rebalance.tenant;
 
 import com.fasterxml.jackson.core.JsonProcessingException;
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.HashSet;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Queue;
 import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.ConcurrentLinkedDeque;
+import java.util.concurrent.ConcurrentLinkedQueue;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
+import java.util.concurrent.Future;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.pinot.common.assignment.InstancePartitions;
 import org.apache.pinot.common.tier.TierFactory;
@@ -139,7 +147,7 @@ public class TenantRebalancerTest extends ControllerTest {
     TenantRebalanceProgressStats progressStats = 
getProgress(result.getJobId());
     
assertTrue(progressStats.getTableRebalanceJobIdMap().containsKey(OFFLINE_TABLE_NAME_A));
     assertEquals(progressStats.getTableStatusMap().get(OFFLINE_TABLE_NAME_A),
-        TenantRebalanceProgressStats.TableStatus.PROCESSED.name());
+        TenantRebalanceProgressStats.TableStatus.DONE.name());
     Map<String, Map<String, String>> idealState =
         
_helixResourceManager.getTableIdealState(OFFLINE_TABLE_NAME_A).getRecord().getMapFields();
     Map<String, Map<String, String>> externalView =
@@ -1086,6 +1094,506 @@ public class TenantRebalancerTest extends 
ControllerTest {
     });
   }
 
+  @Test
+  public void testZkBasedTenantRebalanceObserverPoll()
+      throws Exception {
+    int numServers = 3;
+    for (int i = 0; i < numServers; i++) {
+      addFakeServerInstanceToAutoJoinHelixCluster(SERVER_INSTANCE_ID_PREFIX + 
i, true);
+    }
+
+    TenantRebalancer tenantRebalancer =
+        new TenantRebalancer(_tableRebalanceManager, _helixResourceManager, 
_executorService);
+
+    // tag all servers and brokers to test tenant
+    addTenantTagToInstances(TENANT_NAME);
+
+    // create 2 schemas
+    addDummySchema(RAW_TABLE_NAME_A);
+    addDummySchema(RAW_TABLE_NAME_B);
+
+    // create 2 tables on test tenant
+    createTableWithSegments(RAW_TABLE_NAME_A, TENANT_NAME);
+    createTableWithSegments(RAW_TABLE_NAME_B, TENANT_NAME);
+
+    // Add 3 more servers which will be tagged to default tenant
+    int numServersToAdd = 3;
+    for (int i = 0; i < numServersToAdd; i++) {
+      addFakeServerInstanceToAutoJoinHelixCluster(SERVER_INSTANCE_ID_PREFIX + 
(numServers + i), true);
+    }
+    addTenantTagToInstances(TENANT_NAME);
+
+    String jobId = "test-poll-job-123";
+
+    // Create tenant rebalance context with tables in queues
+    TenantRebalanceConfig config = new TenantRebalanceConfig();
+    config.setTenantName(TENANT_NAME);
+    config.setVerboseResult(true);
+    config.setDryRun(true);
+
+    TenantRebalanceResult dryRunResult = tenantRebalancer.rebalance(config);
+    TenantRebalanceContext context = new TenantRebalanceContext(
+        jobId, config, 1,
+        dryRunResult.getRebalanceTableResults().keySet().stream()
+            .map(tableName -> new 
TenantRebalancer.TenantTableRebalanceJobContext(tableName, tableName + "_job", 
false))
+            .collect(Collectors.toCollection(ConcurrentLinkedDeque::new)),
+        new LinkedList<>(),
+        new ConcurrentLinkedQueue<>()
+    );
+
+    TenantRebalanceProgressStats progressStats =
+        new 
TenantRebalanceProgressStats(dryRunResult.getRebalanceTableResults().keySet());
+
+    // Test polling from parallel queue
+    ZkBasedTenantRebalanceObserver observer =
+        new ZkBasedTenantRebalanceObserver(jobId, TENANT_NAME, progressStats, 
context, _helixResourceManager);
+    TenantRebalancer.TenantTableRebalanceJobContext polledJob = 
observer.pollParallel();
+    assertNotNull(polledJob);
+    
assertTrue(dryRunResult.getRebalanceTableResults().containsKey(polledJob.getTableName()));
+
+    // Test polling from sequential queue (should be empty)
+    TenantRebalancer.TenantTableRebalanceJobContext sequentialJob = 
observer.pollSequential();
+    assertNull(sequentialJob);
+
+    // Verify the job was moved to ongoing queue and status was updated
+    Map<String, String> updatedMetadata =
+        _helixResourceManager.getControllerJobZKMetadata(jobId, 
ControllerJobTypes.TENANT_REBALANCE);
+    assertNotNull(updatedMetadata);
+    TenantRebalanceContext updatedContext = 
TenantRebalanceContext.fromTenantRebalanceJobMetadata(updatedMetadata);
+    TenantRebalanceProgressStats updatedStats =
+        
TenantRebalanceProgressStats.fromTenantRebalanceJobMetadata(updatedMetadata);
+
+    assertNotNull(updatedContext);
+    assertNotNull(updatedStats);
+    assertEquals(updatedContext.getOngoingJobsQueue().size(), 1);
+    
assertEquals(updatedStats.getTableStatusMap().get(polledJob.getTableName()),
+        TenantRebalanceProgressStats.TableStatus.REBALANCING.name());
+
+    _helixResourceManager.deleteOfflineTable(RAW_TABLE_NAME_A);
+    _helixResourceManager.deleteOfflineTable(RAW_TABLE_NAME_B);
+
+    for (int i = 0; i < numServers + numServersToAdd; i++) {
+      stopAndDropFakeInstance(SERVER_INSTANCE_ID_PREFIX + i);
+    }
+  }
+
+  @Test
+  public void testZkBasedTenantRebalanceObserverCancelJob()
+      throws Exception {
+    int numServers = 3;
+    for (int i = 0; i < numServers; i++) {
+      addFakeServerInstanceToAutoJoinHelixCluster(SERVER_INSTANCE_ID_PREFIX + 
i, true);
+    }
+
+    TenantRebalancer tenantRebalancer =
+        new TenantRebalancer(_tableRebalanceManager, _helixResourceManager, 
_executorService);
+
+    // tag all servers and brokers to test tenant
+    addTenantTagToInstances(TENANT_NAME);
+
+    // create 2 schemas
+    addDummySchema(RAW_TABLE_NAME_A);
+    addDummySchema(RAW_TABLE_NAME_B);
+
+    // create 2 tables on test tenant
+    createTableWithSegments(RAW_TABLE_NAME_A, TENANT_NAME);
+    createTableWithSegments(RAW_TABLE_NAME_B, TENANT_NAME);
+
+    // Add 3 more servers which will be tagged to default tenant
+    int numServersToAdd = 3;
+    for (int i = 0; i < numServersToAdd; i++) {
+      addFakeServerInstanceToAutoJoinHelixCluster(SERVER_INSTANCE_ID_PREFIX + 
(numServers + i), true);
+    }
+    addTenantTagToInstances(TENANT_NAME);
+
+    // Create observer and test cancellation
+    String jobId = "test-cancel-job-456";
+
+    // Create tenant rebalance context with tables in queues
+    TenantRebalanceConfig config = new TenantRebalanceConfig();
+    config.setTenantName(TENANT_NAME);
+    config.setVerboseResult(true);
+    config.setDryRun(true);
+
+    TenantRebalanceResult dryRunResult = tenantRebalancer.rebalance(config);
+    Set<String> tableNames = dryRunResult.getRebalanceTableResults().keySet();
+
+    TenantRebalanceContext context = new TenantRebalanceContext(
+        jobId, config, 1,
+        tableNames.stream()
+            .map(tableName -> new 
TenantRebalancer.TenantTableRebalanceJobContext(tableName, tableName + "_job", 
false))
+            .collect(Collectors.toCollection(ConcurrentLinkedDeque::new)),
+        new LinkedList<>(),
+        new ConcurrentLinkedQueue<>()
+    );
+
+    TenantRebalanceProgressStats progressStats = new 
TenantRebalanceProgressStats(tableNames);
+
+    // Test cancellation by user
+    ZkBasedTenantRebalanceObserver observer =
+        new ZkBasedTenantRebalanceObserver(jobId, TENANT_NAME, progressStats, 
context, _helixResourceManager);
+
+    // move one job to ongoing to test cancellation from that state
+    TenantRebalancer.TenantTableRebalanceJobContext polledJob = 
observer.pollParallel();
+    assertNotNull(polledJob);
+    
assertTrue(dryRunResult.getRebalanceTableResults().containsKey(polledJob.getTableName()));
+
+    Pair<List<String>, Boolean> cancelResult = observer.cancelJob(true);
+    assertTrue(cancelResult.getRight()); // cancellation was successful
+    assertTrue(cancelResult.getLeft()
+        .isEmpty()); // no jobs were cancelled (the polled job hasn't started 
its table rebalance job yet thus won't
+    // show in the cancelled list)
+
+    // Verify that queues are emptied and status is updated
+    Map<String, String> updatedMetadata =
+        _helixResourceManager.getControllerJobZKMetadata(jobId, 
ControllerJobTypes.TENANT_REBALANCE);
+    assertNotNull(updatedMetadata);
+    TenantRebalanceContext updatedContext = 
TenantRebalanceContext.fromTenantRebalanceJobMetadata(updatedMetadata);
+    TenantRebalanceProgressStats updatedStats =
+        
TenantRebalanceProgressStats.fromTenantRebalanceJobMetadata(updatedMetadata);
+
+    assertNotNull(updatedContext);
+    assertNotNull(updatedStats);
+    assertTrue(updatedContext.getParallelQueue().isEmpty());
+    assertTrue(updatedContext.getSequentialQueue().isEmpty());
+    assertTrue(updatedContext.getOngoingJobsQueue().isEmpty());
+    assertEquals(updatedStats.getRemainingTables(), 0);
+    assertEquals(updatedStats.getCompletionStatusMsg(), "Tenant rebalance job 
has been cancelled.");
+    assertTrue(updatedStats.getTimeToFinishInSeconds() >= 0);
+
+    // Verify the polled table is marked CANCELLED and all other tables NOT_SCHEDULED
+    for (String tableName : tableNames) {
+      if (tableName.equals(polledJob.getTableName())) {
+        // the polled job was in ongoing, so should be marked as CANCELLED
+        assertEquals(updatedStats.getTableStatusMap().get(tableName),
+            TenantRebalanceProgressStats.TableStatus.CANCELLED.name());
+      } else {
+        assertEquals(updatedStats.getTableStatusMap().get(tableName),
+            TenantRebalanceProgressStats.TableStatus.NOT_SCHEDULED.name());
+      }
+    }
+
+    _helixResourceManager.deleteOfflineTable(RAW_TABLE_NAME_A);
+    _helixResourceManager.deleteOfflineTable(RAW_TABLE_NAME_B);
+
+    for (int i = 0; i < numServers + numServersToAdd; i++) {
+      stopAndDropFakeInstance(SERVER_INSTANCE_ID_PREFIX + i);
+    }
+  }
+
+  @Test
+  public void testZkBasedTenantRebalanceObserverOnTableJobDoneAndError()
+      throws Exception {
+    int numServers = 3;
+    for (int i = 0; i < numServers; i++) {
+      addFakeServerInstanceToAutoJoinHelixCluster(SERVER_INSTANCE_ID_PREFIX + 
i, true);
+    }
+
+    TenantRebalancer tenantRebalancer =
+        new TenantRebalancer(_tableRebalanceManager, _helixResourceManager, 
_executorService);
+
+    // tag all servers and brokers to test tenant
+    addTenantTagToInstances(TENANT_NAME);
+
+    // create 2 schemas
+    addDummySchema(RAW_TABLE_NAME_A);
+    addDummySchema(RAW_TABLE_NAME_B);
+
+    // create 2 tables on test tenant
+    createTableWithSegments(RAW_TABLE_NAME_A, TENANT_NAME);
+    createTableWithSegments(RAW_TABLE_NAME_B, TENANT_NAME);
+
+    // Add 3 more servers which will be tagged to default tenant
+    int numServersToAdd = 3;
+    for (int i = 0; i < numServersToAdd; i++) {
+      addFakeServerInstanceToAutoJoinHelixCluster(SERVER_INSTANCE_ID_PREFIX + 
(numServers + i), true);
+    }
+    addTenantTagToInstances(TENANT_NAME);
+
+    // Create observer and test table job completion
+    String jobId = "test-table-done-job-789";
+
+    // Create tenant rebalance context with tables in ongoing queue
+    TenantRebalanceConfig config = new TenantRebalanceConfig();
+    config.setTenantName(TENANT_NAME);
+    config.setVerboseResult(true);
+    config.setDryRun(true);
+
+    TenantRebalanceResult dryRunResult = tenantRebalancer.rebalance(config);
+    Set<String> tableNames = dryRunResult.getRebalanceTableResults().keySet();
+
+    TenantRebalanceContext context = new TenantRebalanceContext(
+        jobId, config, 1,
+        new ConcurrentLinkedDeque<>(),
+        new LinkedList<>(),
+        tableNames.stream()
+            .map(tableName -> new 
TenantRebalancer.TenantTableRebalanceJobContext(tableName, tableName + "_job", 
false))
+            .collect(Collectors.toCollection(ConcurrentLinkedQueue::new))
+    );
+
+    TenantRebalanceProgressStats progressStats = new 
TenantRebalanceProgressStats(tableNames);
+    // Set initial status to REBALANCING for the tables
+    for (String tableName : tableNames) {
+      progressStats.updateTableStatus(tableName, 
TenantRebalanceProgressStats.TableStatus.REBALANCING.name());
+    }
+
+    // Test onTableJobDone
+    ZkBasedTenantRebalanceObserver observer =
+        new ZkBasedTenantRebalanceObserver(jobId, TENANT_NAME, progressStats, 
context, _helixResourceManager);
+    TenantRebalancer.TenantTableRebalanceJobContext jobContextA =
+        new 
TenantRebalancer.TenantTableRebalanceJobContext(OFFLINE_TABLE_NAME_A, 
OFFLINE_TABLE_NAME_A + "_job", false);
+    observer.onTableJobDone(jobContextA);
+    TenantRebalancer.TenantTableRebalanceJobContext jobContextB =
+        new 
TenantRebalancer.TenantTableRebalanceJobContext(OFFLINE_TABLE_NAME_B, 
OFFLINE_TABLE_NAME_B + "_job", false);
+    String errorMessage = "Test error message";
+    observer.onTableJobError(jobContextB, errorMessage);
+
+    // Verify that the job was removed from ongoing queue and status was 
updated
+    Map<String, String> updatedMetadata =
+        _helixResourceManager.getControllerJobZKMetadata(jobId, 
ControllerJobTypes.TENANT_REBALANCE);
+    assertNotNull(updatedMetadata);
+    TenantRebalanceContext updatedContext = 
TenantRebalanceContext.fromTenantRebalanceJobMetadata(updatedMetadata);
+    TenantRebalanceProgressStats updatedStats =
+        
TenantRebalanceProgressStats.fromTenantRebalanceJobMetadata(updatedMetadata);
+
+    assertNotNull(updatedContext);
+    assertNotNull(updatedStats);
+    assertFalse(updatedContext.getOngoingJobsQueue().contains(jobContextA));
+    assertFalse(updatedContext.getOngoingJobsQueue().contains(jobContextB));
+    assertEquals(updatedStats.getTableStatusMap().get(OFFLINE_TABLE_NAME_A),
+        TenantRebalanceProgressStats.TableStatus.DONE.name());
+    assertEquals(updatedStats.getTableStatusMap().get(OFFLINE_TABLE_NAME_B), 
errorMessage);
+    assertEquals(updatedStats.getRemainingTables(), tableNames.size() - 2);
+
+    _helixResourceManager.deleteOfflineTable(RAW_TABLE_NAME_A);
+    _helixResourceManager.deleteOfflineTable(RAW_TABLE_NAME_B);
+
+    for (int i = 0; i < numServers + numServersToAdd; i++) {
+      stopAndDropFakeInstance(SERVER_INSTANCE_ID_PREFIX + i);
+    }
+  }
+
+  @Test
+  public void testZkBasedTenantRebalanceObserverLifecycle()
+      throws Exception {
+    int numServers = 3;
+    for (int i = 0; i < numServers; i++) {
+      addFakeServerInstanceToAutoJoinHelixCluster(SERVER_INSTANCE_ID_PREFIX + 
i, true);
+    }
+
+    TenantRebalancer tenantRebalancer =
+        new TenantRebalancer(_tableRebalanceManager, _helixResourceManager, 
_executorService);
+
+    // tag all servers and brokers to test tenant
+    addTenantTagToInstances(TENANT_NAME);
+
+    // create 2 schemas
+    addDummySchema(RAW_TABLE_NAME_A);
+    addDummySchema(RAW_TABLE_NAME_B);
+
+    // create 2 tables on test tenant
+    createTableWithSegments(RAW_TABLE_NAME_A, TENANT_NAME);
+    createTableWithSegments(RAW_TABLE_NAME_B, TENANT_NAME);
+
+    // Add 3 more servers which will be tagged to default tenant
+    int numServersToAdd = 3;
+    for (int i = 0; i < numServersToAdd; i++) {
+      addFakeServerInstanceToAutoJoinHelixCluster(SERVER_INSTANCE_ID_PREFIX + 
(numServers + i), true);
+    }
+    addTenantTagToInstances(TENANT_NAME);
+
+    // Create observer and test lifecycle methods
+    String jobId = "test-lifecycle-job-202";
+
+    // Create tenant rebalance context
+    TenantRebalanceConfig config = new TenantRebalanceConfig();
+    config.setTenantName(TENANT_NAME);
+    config.setVerboseResult(true);
+    config.setDryRun(true);
+
+    TenantRebalanceResult dryRunResult = tenantRebalancer.rebalance(config);
+    Set<String> tableNames = dryRunResult.getRebalanceTableResults().keySet();
+
+    TenantRebalanceContext context = new TenantRebalanceContext(
+        jobId, config, 1,
+        tableNames.stream()
+            .map(tableName -> new 
TenantRebalancer.TenantTableRebalanceJobContext(tableName, tableName + "_job", 
false))
+            .collect(Collectors.toCollection(ConcurrentLinkedDeque::new)),
+        new LinkedList<>(),
+        new ConcurrentLinkedQueue<>()
+    );
+
+    TenantRebalanceProgressStats progressStats = new 
TenantRebalanceProgressStats(tableNames);
+
+    // Test onStart
+    ZkBasedTenantRebalanceObserver observer =
+        new ZkBasedTenantRebalanceObserver(jobId, TENANT_NAME, progressStats, 
context, _helixResourceManager);
+    observer.onStart();
+    Map<String, String> startMetadata =
+        _helixResourceManager.getControllerJobZKMetadata(jobId, 
ControllerJobTypes.TENANT_REBALANCE);
+    assertNotNull(startMetadata);
+    TenantRebalanceProgressStats startStats =
+        
TenantRebalanceProgressStats.fromTenantRebalanceJobMetadata(startMetadata);
+    assertNotNull(startStats);
+    assertTrue(startStats.getStartTimeMs() > 0);
+
+    // Test onSuccess
+    String successMessage = "Tenant rebalance completed successfully";
+    observer.onSuccess(successMessage);
+    Map<String, String> successMetadata =
+        _helixResourceManager.getControllerJobZKMetadata(jobId, 
ControllerJobTypes.TENANT_REBALANCE);
+    assertNotNull(successMetadata);
+    TenantRebalanceProgressStats successStats =
+        
TenantRebalanceProgressStats.fromTenantRebalanceJobMetadata(successMetadata);
+    assertNotNull(successStats);
+    assertEquals(successStats.getCompletionStatusMsg(), successMessage);
+    assertTrue(successStats.getTimeToFinishInSeconds() >= 0);
+    assertTrue(observer.isDone());
+
+    _helixResourceManager.deleteOfflineTable(RAW_TABLE_NAME_A);
+    _helixResourceManager.deleteOfflineTable(RAW_TABLE_NAME_B);
+
+    for (int i = 0; i < numServers + numServersToAdd; i++) {
+      stopAndDropFakeInstance(SERVER_INSTANCE_ID_PREFIX + i);
+    }
+  }
+
+  @Test
+  public void testZkBasedTenantRebalanceObserverConcurrentPoll()
+      throws Exception {
+
+    int numServers = 3;
+    for (int i = 0; i < numServers; i++) {
+      addFakeServerInstanceToAutoJoinHelixCluster(SERVER_INSTANCE_ID_PREFIX + 
i, true);
+    }
+
+    // tag all servers and brokers to test tenant
+    addTenantTagToInstances(TENANT_NAME);
+
+    // create multiple schemas and tables for concurrent testing
+    String[] tableNames = {"table1", "table2", "table3", "table4", "table5"};
+    for (String tableName : tableNames) {
+      addDummySchema(tableName);
+      createTableWithSegments(tableName, TENANT_NAME);
+    }
+
+    // Add more servers
+    int numServersToAdd = 3;
+    for (int i = 0; i < numServersToAdd; i++) {
+      addFakeServerInstanceToAutoJoinHelixCluster(SERVER_INSTANCE_ID_PREFIX + 
(numServers + i), true);
+    }
+    addTenantTagToInstances(TENANT_NAME);
+
+    for (int i = 0; i < 3; i++) {
+      runZkBasedTenantRebalanceObserverConcurrentPoll();
+    }
+
+    // Clean up tables
+    for (String tableName : tableNames) {
+      _helixResourceManager.deleteOfflineTable(tableName);
+    }
+
+    for (int i = 0; i < numServers + numServersToAdd; i++) {
+      stopAndDropFakeInstance(SERVER_INSTANCE_ID_PREFIX + i);
+    }
+  }
+
+  private void runZkBasedTenantRebalanceObserverConcurrentPoll()
+      throws Exception {
+    TenantRebalancer tenantRebalancer =
+        new TenantRebalancer(_tableRebalanceManager, _helixResourceManager, 
_executorService);
+
+    String jobId = "test-concurrent-poll-job-303";
+
+    // Create tenant rebalance context with multiple tables
+    TenantRebalanceConfig config = new TenantRebalanceConfig();
+    config.setTenantName(TENANT_NAME);
+    config.setVerboseResult(true);
+    config.setDryRun(true);
+
+    TenantRebalanceResult dryRunResult = tenantRebalancer.rebalance(config);
+    Set<String> offlineTableNames = 
dryRunResult.getRebalanceTableResults().keySet();
+
+    TenantRebalanceContext context = new TenantRebalanceContext(
+        jobId, config, 1,
+        offlineTableNames.stream()
+            .map(tableName -> new 
TenantRebalancer.TenantTableRebalanceJobContext(tableName, tableName + "_job", 
false))
+            .collect(Collectors.toCollection(ConcurrentLinkedDeque::new)),
+        new LinkedList<>(),
+        new ConcurrentLinkedQueue<>()
+    );
+
+    TenantRebalanceProgressStats progressStats = new 
TenantRebalanceProgressStats(offlineTableNames);
+
+    // Create observer
+    ZkBasedTenantRebalanceObserver observer =
+        new ZkBasedTenantRebalanceObserver(jobId, TENANT_NAME, progressStats, 
context, _helixResourceManager);
+
+    // Test concurrent polling with multiple threads
+    int numThreads = offlineTableNames.size();
+    ExecutorService concurrentExecutor = 
Executors.newFixedThreadPool(numThreads);
+    List<Future<TenantRebalancer.TenantTableRebalanceJobContext>> futures = 
new ArrayList<>();
+    Set<String> polledTables = ConcurrentHashMap.newKeySet();
+    AtomicInteger pollCount = new AtomicInteger(0);
+
+    // Submit concurrent polling tasks
+    for (int i = 0; i < numThreads; i++) {
+      futures.add(concurrentExecutor.submit(() -> {
+        TenantRebalancer.TenantTableRebalanceJobContext job = 
observer.pollParallel();
+        if (job != null) {
+          polledTables.add(job.getTableName());
+          pollCount.incrementAndGet();
+        }
+        return job;
+      }));
+    }
+
+    // Wait for all polling tasks to complete
+    List<TenantRebalancer.TenantTableRebalanceJobContext> polledJobs = new 
ArrayList<>();
+    for (Future<TenantRebalancer.TenantTableRebalanceJobContext> future : 
futures) {
+      TenantRebalancer.TenantTableRebalanceJobContext job = future.get(5, 
TimeUnit.SECONDS);
+      if (job != null) {
+        polledJobs.add(job);
+      }
+    }
+
+    // Verify thread safety: no duplicate jobs should be polled
+    assertEquals(polledJobs.size(), pollCount.get());
+    assertEquals(polledTables.size(), polledJobs.size());
+    assertTrue(polledJobs.size() <= offlineTableNames.size());
+
+    // Verify that all polled jobs are valid table names
+    for (TenantRebalancer.TenantTableRebalanceJobContext job : polledJobs) {
+      assertTrue(offlineTableNames.contains(job.getTableName()));
+    }
+
+    // Verify ZK state consistency after concurrent operations
+    Map<String, String> updatedMetadata =
+        _helixResourceManager.getControllerJobZKMetadata(jobId, 
ControllerJobTypes.TENANT_REBALANCE);
+    assertNotNull(updatedMetadata);
+    TenantRebalanceContext updatedContext = 
TenantRebalanceContext.fromTenantRebalanceJobMetadata(updatedMetadata);
+    TenantRebalanceProgressStats updatedStats =
+        
TenantRebalanceProgressStats.fromTenantRebalanceJobMetadata(updatedMetadata);
+
+    assertNotNull(updatedContext);
+    assertNotNull(updatedStats);
+
+    // Verify that polled jobs are in ongoing queue
+    assertEquals(updatedContext.getOngoingJobsQueue().size(), 
polledJobs.size());
+    for (TenantRebalancer.TenantTableRebalanceJobContext job : polledJobs) {
+      assertTrue(updatedContext.getOngoingJobsQueue().contains(job));
+      assertEquals(updatedStats.getTableStatusMap().get(job.getTableName()),
+          TenantRebalanceProgressStats.TableStatus.REBALANCING.name());
+    }
+
+    // Verify that no tables remain in the parallel queue
+    assertEquals(updatedContext.getParallelQueue().size(), 0);
+
+    // Cleanup
+    concurrentExecutor.shutdown();
+    concurrentExecutor.awaitTermination(5, TimeUnit.SECONDS);
+  }
+
   @AfterClass
   public void tearDown() {
     stopFakeInstances();


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to