This is an automated email from the ASF dual-hosted git repository.

schofielaj pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/trunk by this push:
     new ef4b04b5485 KAFKA-19844: ShareFetch related changes to support RENEW. 
(#20826)
ef4b04b5485 is described below

commit ef4b04b5485e11bc5821fd1b8e07c5deaedec79f
Author: Sushant Mahajan <[email protected]>
AuthorDate: Wed Nov 5 23:25:53 2025 +0530

    KAFKA-19844: ShareFetch related changes to support RENEW. (#20826)
    
    * When a ShareFetch request (version >= 2) has isRenewAck set, i.e. it
    carries RENEW acknowledgements, the fetchMessages sub-routine is skipped
    in the KafkaApis handler for the request.
    * Additionally, new validations for the ShareFetch version and ackType have
    been added, along with validations that the maxBytes, minBytes, maxRecords
    and maxWaitMs fields are set to 0 when the version is >= 2 and isRenewAck
    is true.
    * Unit tests have been added to verify the behavior.
    
    Reviewers: Abhinav Dixit <[email protected]>, Apoorv Mittal
     <[email protected]>, Andrew Schofield <[email protected]>
---
 core/src/main/scala/kafka/server/KafkaApis.scala   |  68 ++++-
 .../scala/unit/kafka/server/KafkaApisTest.scala    | 281 ++++++++++++++++++++-
 2 files changed, 331 insertions(+), 18 deletions(-)

diff --git a/core/src/main/scala/kafka/server/KafkaApis.scala 
b/core/src/main/scala/kafka/server/KafkaApis.scala
index 52f1d5bfccd..4ea3c0bf627 100644
--- a/core/src/main/scala/kafka/server/KafkaApis.scala
+++ b/core/src/main/scala/kafka/server/KafkaApis.scala
@@ -3182,6 +3182,35 @@ class KafkaApis(val requestChannel: RequestChannel,
     val newReqMetadata: ShareRequestMetadata = new 
ShareRequestMetadata(Uuid.fromString(memberId), shareSessionEpoch)
     var shareFetchContext: ShareFetchContext = null
 
+    // KIP-1222 enforces setting the maxBytes, minBytes, maxRecords, maxWaitMs
+    // values to 0, in case isRenewAck is true.
+    if (shareFetchRequest.version >= 2 && shareFetchRequest.data.isRenewAck) {
+      val reqData = shareFetchRequest.data
+      var errorMsg: String = ""
+      if (reqData.maxBytes != 0) {
+        errorMsg += "maxBytes must be set to 0, "
+      }
+
+      if (reqData.minBytes != 0) {
+        errorMsg += "minBytes must be set to 0, "
+      }
+
+      if (reqData.maxRecords != 0) {
+        errorMsg += "maxRecords must be set to 0, "
+      }
+
+      if (reqData.maxWaitMs != 0) {
+        errorMsg += "maxWaitMs must be set to 0, "
+      }
+
+      if (errorMsg != "") {
+        errorMsg += "if isRenewAck is true."
+        error(errorMsg)
+        requestHelper.sendMaybeThrottle(request, 
shareFetchRequest.getErrorResponse(AbstractResponse.DEFAULT_THROTTLE_TIME, 
Errors.INVALID_REQUEST.exception(errorMsg)))
+        return CompletableFuture.completedFuture[Unit](())
+      }
+    }
+
     try {
       // Creating the shareFetchContext for Share Session Handling. if context 
creation fails, the request is failed directly here.
       shareFetchContext = sharePartitionManager.newContext(groupId, 
shareFetchData, forgottenTopics, newReqMetadata, isAcknowledgeDataPresent, 
request.context.connectionId)
@@ -3234,18 +3263,24 @@ class KafkaApis(val requestChannel: RequestChannel,
         authorizedTopics,
         groupId,
         memberId,
+        shareFetchRequest.version == 2,
+        shareFetchRequest.data.isRenewAck
       )
     }
 
     // Handling the Fetch from the ShareFetchRequest.
     // Variable to store the topic partition wise result of fetching.
-    val fetchResult: CompletableFuture[Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData]] = handleFetchFromShareFetchRequest(
-      request,
-      shareSessionEpoch,
-      erroneousAndValidPartitionData,
-      sharePartitionManager,
-      authorizedTopics
-    )
+
+    // Here we are populating fetchResult conditionally because per the design 
of
+    // KIP-1222, if a ShareFetch request contains a RENEW ack type piggybacked 
then
+    // we must forego the record fetching as the amount of time spent in 
fetching
+    // might be more than the acquisition lock timeout which got RENEWed and 
as a result
+    // it'll time out again, before the response reaches the share consumer.
+    val fetchResult: CompletableFuture[Map[TopicIdPartition, 
ShareFetchResponseData.PartitionData]] =
+      if (shareFetchRequest.version() >= 2 && 
shareFetchRequest.data.isRenewAck)
+        CompletableFuture.completedFuture(mutable.Map.empty[TopicIdPartition, 
ShareFetchResponseData.PartitionData])
+      else
+        handleFetchFromShareFetchRequest(request, shareSessionEpoch, 
erroneousAndValidPartitionData, sharePartitionManager, authorizedTopics)
 
     def combineShareFetchAndShareAcknowledgeResponses(fetchResult: 
CompletableFuture[Map[TopicIdPartition, ShareFetchResponseData.PartitionData]],
                                                       acknowledgeResult: 
CompletableFuture[Map[TopicIdPartition, 
ShareAcknowledgeResponseData.PartitionData]],
@@ -3422,9 +3457,11 @@ class KafkaApis(val requestChannel: RequestChannel,
                              sharePartitionManagerInstance: 
SharePartitionManager,
                              authorizedTopics: Set[String],
                              groupId: String,
-                             memberId: String): 
CompletableFuture[Map[TopicIdPartition, 
ShareAcknowledgeResponseData.PartitionData]] = {
+                             memberId: String,
+                             supportsRenewAcknowledgements: Boolean,
+                             isRenewAck: Boolean): 
CompletableFuture[Map[TopicIdPartition, 
ShareAcknowledgeResponseData.PartitionData]] = {
 
-    val erroneousTopicIdPartitions = 
validateAcknowledgementBatches(acknowledgementData, erroneous)
+    val erroneousTopicIdPartitions = 
validateAcknowledgementBatches(acknowledgementData, erroneous, 
supportsRenewAcknowledgements, isRenewAck)
     erroneousTopicIdPartitions.foreach(tp => acknowledgementData.remove(tp))
 
     val interested = mutable.Map[TopicIdPartition, 
util.List[ShareAcknowledgementBatch]]()
@@ -3523,7 +3560,7 @@ class KafkaApis(val requestChannel: RequestChannel,
 
     val erroneous = mutable.Map[TopicIdPartition, 
ShareAcknowledgeResponseData.PartitionData]()
     val acknowledgementDataFromRequest = 
getAcknowledgeBatchesFromShareAcknowledgeRequest(shareAcknowledgeRequest, 
topicIdNames, erroneous)
-    handleAcknowledgements(acknowledgementDataFromRequest, erroneous, 
sharePartitionManager, authorizedTopics, groupId, memberId)
+    handleAcknowledgements(acknowledgementDataFromRequest, erroneous, 
sharePartitionManager, authorizedTopics, groupId, memberId, 
shareAcknowledgeRequest.version == 2, shareAcknowledgeRequest.data.isRenewAck)
       .handle[Unit] {(result, exception) =>
         if (exception != null) {
           requestHelper.sendMaybeThrottle(request, 
shareAcknowledgeRequest.getErrorResponse(AbstractResponse.DEFAULT_THROTTLE_TIME,
 exception))
@@ -3987,13 +4024,16 @@ class KafkaApis(val requestChannel: RequestChannel,
 
   // Visible for Testing
   def validateAcknowledgementBatches(acknowledgementDataFromRequest: 
mutable.Map[TopicIdPartition, util.List[ShareAcknowledgementBatch]],
-                                     erroneous: mutable.Map[TopicIdPartition, 
ShareAcknowledgeResponseData.PartitionData]
+                                     erroneous: mutable.Map[TopicIdPartition, 
ShareAcknowledgeResponseData.PartitionData],
+                                     supportsRenewAcknowledgements: Boolean,
+                                     isRenewAck: Boolean
                                     ): mutable.Set[TopicIdPartition] = {
     val erroneousTopicIdPartitions: mutable.Set[TopicIdPartition] = 
mutable.Set.empty[TopicIdPartition]
 
     acknowledgementDataFromRequest.foreach { case (tp: TopicIdPartition, 
acknowledgeBatches: util.List[ShareAcknowledgementBatch]) =>
       var prevEndOffset = -1L
       var isErroneous = false
+      val maxAcknowledgeType = if (supportsRenewAcknowledgements) 4 else 3
       acknowledgeBatches.forEach { batch =>
         if (!isErroneous) {
           if (batch.firstOffset > batch.lastOffset) {
@@ -4012,7 +4052,11 @@ class KafkaApis(val requestChannel: RequestChannel,
             erroneous += tp -> ShareAcknowledgeResponse.partitionResponse(tp, 
Errors.INVALID_REQUEST)
             erroneousTopicIdPartitions.add(tp)
             isErroneous = true
-          } else if (batch.acknowledgeTypes.stream().anyMatch(ackType => 
ackType < 0 || ackType > 3)) {
+          } else if (batch.acknowledgeTypes.stream().anyMatch(ackType => 
ackType < 0 || ackType > maxAcknowledgeType)) {
+            erroneous += tp -> ShareAcknowledgeResponse.partitionResponse(tp, 
Errors.INVALID_REQUEST)
+            erroneousTopicIdPartitions.add(tp)
+            isErroneous = true
+          } else if (batch.acknowledgeTypes.stream().anyMatch(ackType => 
ackType == 4) && !isRenewAck) {
             erroneous += tp -> ShareAcknowledgeResponse.partitionResponse(tp, 
Errors.INVALID_REQUEST)
             erroneousTopicIdPartitions.add(tp)
             isErroneous = true
diff --git a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala 
b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
index b6bc5a81a98..b66c2a14d82 100644
--- a/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
+++ b/core/src/test/scala/unit/kafka/server/KafkaApisTest.scala
@@ -26,6 +26,7 @@ import kafka.server.share.SharePartitionManager
 import kafka.utils.{CoreUtils, Logging, TestUtils}
 import org.apache.kafka.clients.admin.AlterConfigOp.OpType
 import org.apache.kafka.clients.admin.{AlterConfigOp, ConfigEntry}
+import org.apache.kafka.clients.consumer.AcknowledgeType
 import org.apache.kafka.common._
 import org.apache.kafka.common.acl.AclOperation
 import org.apache.kafka.common.compress.Compression
@@ -6582,6 +6583,199 @@ class KafkaApisTest extends Logging {
     assertArrayEquals(expectedAcquiredRecords(10, 19, 1).toArray(), 
topicResponse.partitions.get(0).acquiredRecords.toArray())
   }
 
+  @Test
+  def testHandleShareFetchRequestSuccessWithRenewAcknowledgements(): Unit = {
+    val topicName = "foo"
+    val topicId = Uuid.randomUuid()
+    val partitionIndex = 0
+    metadataCache = initializeMetadataCacheWithShareGroupsEnabled()
+    addTopicToMetadataCache(topicName, 1, topicId = topicId)
+    val memberId: Uuid = Uuid.randomUuid()
+
+    val records1 = memoryRecords(10, 0)
+
+    val groupId = "group"
+
+    when(sharePartitionManager.fetchMessages(any(), any(), any(), any(), 
anyInt(), anyInt(), anyInt(), any())).thenReturn(
+      CompletableFuture.completedFuture(util.Map.of[TopicIdPartition, 
ShareFetchResponseData.PartitionData](
+        new TopicIdPartition(topicId, new TopicPartition(topicName, 
partitionIndex)),
+        new ShareFetchResponseData.PartitionData()
+          .setErrorCode(Errors.NONE.code)
+          .setAcknowledgeErrorCode(Errors.NONE.code)
+          .setRecords(records1)
+          .setAcquiredRecords(new util.ArrayList(util.List.of(
+            new ShareFetchResponseData.AcquiredRecords()
+              .setFirstOffset(0)
+              .setLastOffset(9)
+              .setDeliveryCount(1)
+          )))
+      ))
+    )
+
+    val cachedSharePartitions = new 
ImplicitLinkedHashCollection[CachedSharePartition]
+    cachedSharePartitions.mustAdd(new CachedSharePartition(
+      new TopicIdPartition(topicId, 0, topicName), false
+    ))
+
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any())).thenReturn(
+      new ShareSessionContext(new ShareRequestMetadata(memberId, 0), 
util.List.of(
+        new TopicIdPartition(topicId, partitionIndex, topicName)
+      ))
+    ).thenReturn(new ShareSessionContext(new ShareRequestMetadata(memberId, 
1), new ShareSession(
+      new ShareSessionKey(groupId, memberId), cachedSharePartitions, 2))
+    )
+
+    when(sharePartitionManager.acknowledge(any(), any(), any())).thenReturn(
+      CompletableFuture.completedFuture(util.Map.of[TopicIdPartition, 
ShareAcknowledgeResponseData.PartitionData](
+        new TopicIdPartition(topicId, new TopicPartition(topicName, 0)),
+        new ShareAcknowledgeResponseData.PartitionData()
+          .setPartitionIndex(0)
+          .setErrorCode(Errors.NONE.code),
+      ))
+    )
+
+    // First request to get some records.
+    var shareFetchRequestData = new ShareFetchRequestData().
+      setGroupId(groupId).
+      setMemberId(memberId.toString).
+      setShareSessionEpoch(0).
+      setTopics(new 
ShareFetchRequestData.FetchTopicCollection(util.List.of(new 
ShareFetchRequestData.FetchTopic().
+        setTopicId(topicId).
+        setPartitions(new 
ShareFetchRequestData.FetchPartitionCollection(util.List.of(
+          new ShareFetchRequestData.FetchPartition()
+            .setPartitionIndex(0)).iterator))).iterator))
+
+    var shareFetchRequest = new 
ShareFetchRequest.Builder(shareFetchRequestData).build(ApiKeys.SHARE_FETCH.latestVersion)
+    var request = buildRequest(shareFetchRequest)
+    kafkaApis = createKafkaApis()
+    kafkaApis.handleShareFetchRequest(request)
+    var response = verifyNoThrottling[ShareFetchResponse](request)
+    var responseData = response.data()
+    var topicResponses = responseData.responses()
+
+    assertEquals(Errors.NONE.code, responseData.errorCode)
+    assertEquals(1, topicResponses.size())
+    var topicResponse = topicResponses.stream.findFirst.get
+    assertEquals(topicId, topicResponse.topicId)
+    assertEquals(1, topicResponse.partitions.size())
+    assertEquals(partitionIndex, 
topicResponse.partitions.get(0).partitionIndex)
+    assertEquals(Errors.NONE.code, topicResponse.partitions.get(0).errorCode)
+    assertEquals(records1, topicResponse.partitions.get(0).records)
+    assertArrayEquals(expectedAcquiredRecords(0, 9, 1).toArray(), 
topicResponse.partitions.get(0).acquiredRecords.toArray())
+
+    // Second request with RENEW ack.
+    shareFetchRequestData = new ShareFetchRequestData().
+      setGroupId("group").
+      setMemberId(memberId.toString).
+      setShareSessionEpoch(1).
+      setIsRenewAck(true).
+      setMaxBytes(0).
+      setMinBytes(0).
+      setMaxWaitMs(0).
+      setMaxRecords(0).
+      setTopics(new 
ShareFetchRequestData.FetchTopicCollection(util.List.of(new 
ShareFetchRequestData.FetchTopic().
+        setTopicId(topicId).
+        setPartitions(new 
ShareFetchRequestData.FetchPartitionCollection(util.List.of(
+          new ShareFetchRequestData.FetchPartition()
+            .setAcknowledgementBatches(util.List.of(new AcknowledgementBatch()
+              .setFirstOffset(0)
+              .setLastOffset(8)
+              .setAcknowledgeTypes(util.List.of(AcknowledgeType.ACCEPT.id)),
+              new AcknowledgementBatch()
+                .setFirstOffset(9)
+                .setLastOffset(9)
+                .setAcknowledgeTypes(util.List.of(AcknowledgeType.RENEW.id))))
+            .setPartitionIndex(0)).iterator))).iterator))
+
+    shareFetchRequest = new 
ShareFetchRequest.Builder(shareFetchRequestData).build(ApiKeys.SHARE_FETCH.latestVersion)
+    request = buildRequest(shareFetchRequest)
+    kafkaApis.handleShareFetchRequest(request)
+    response = verifyNoThrottling[ShareFetchResponse](request)
+    responseData = response.data()
+    topicResponses = responseData.responses()
+    val expectedRecords = memoryRecords(0, 0)
+
+    assertEquals(Errors.NONE.code, responseData.errorCode)
+    assertEquals(1, topicResponses.size())
+    topicResponse = topicResponses.stream.findFirst.get
+    assertEquals(topicId, topicResponse.topicId)
+    assertEquals(1, topicResponse.partitions.size())
+    assertEquals(partitionIndex, 
topicResponse.partitions.get(0).partitionIndex)
+    assertEquals(Errors.NONE.code, topicResponse.partitions.get(0).errorCode)
+    assertEquals(Errors.NONE.code, 
topicResponse.partitions.get(0).acknowledgeErrorCode)
+    assertEquals(expectedRecords, topicResponse.partitions.get(0).records)
+    assertEquals(0, topicResponse.partitions.get(0).acquiredRecords.size())
+    // fetchMessages only called once for 1st ShareFetch.
+    verify(sharePartitionManager, times(1)).fetchMessages(any(), any(), any(), 
any(), anyInt(), anyInt(), anyInt(), any())
+  }
+
+  @ParameterizedTest
+  @CsvSource(value=Array("true,false", "false,true"))
+  def testHandleShareAcknowledgeRequestWithRenewAcknowledgements(isRenewAck: 
Boolean, shouldFail: Boolean): Unit = {
+    val topicName = "foo"
+    val topicId = Uuid.randomUuid()
+    val partitionIndex = 0
+    metadataCache = initializeMetadataCacheWithShareGroupsEnabled()
+    addTopicToMetadataCache(topicName, 1, topicId = topicId)
+    val memberId: Uuid = Uuid.randomUuid()
+
+    val groupId = "group"
+
+    when(sharePartitionManager.acknowledge(any(), any(), any())).thenReturn(
+      CompletableFuture.completedFuture(util.Map.of[TopicIdPartition, 
ShareAcknowledgeResponseData.PartitionData](
+        new TopicIdPartition(topicId, new TopicPartition(topicName, 0)),
+        new ShareAcknowledgeResponseData.PartitionData()
+          .setPartitionIndex(0)
+          .setErrorCode(Errors.NONE.code)
+      ))
+    )
+
+    doNothing().when(sharePartitionManager).acknowledgeSessionUpdate(any(), 
any())
+
+    val shareAcknowledgeRequestData = new ShareAcknowledgeRequestData().
+      setGroupId(groupId).
+      setMemberId(memberId.toString).
+      setShareSessionEpoch(1).
+      setIsRenewAck(isRenewAck).
+      setTopics(new 
ShareAcknowledgeRequestData.AcknowledgeTopicCollection(util.List.of(new 
ShareAcknowledgeRequestData.AcknowledgeTopic().
+        setTopicId(topicId).
+        setPartitions(new 
ShareAcknowledgeRequestData.AcknowledgePartitionCollection(util.List.of(
+          new ShareAcknowledgeRequestData.AcknowledgePartition()
+            .setPartitionIndex(partitionIndex)
+            .setAcknowledgementBatches(util.List.of(
+              new ShareAcknowledgeRequestData.AcknowledgementBatch()
+                .setFirstOffset(0)
+                .setLastOffset(8)
+                .setAcknowledgeTypes(util.List.of(AcknowledgeType.ACCEPT.id)),
+              new ShareAcknowledgeRequestData.AcknowledgementBatch()
+                .setFirstOffset(9)
+                .setLastOffset(9)
+                .setAcknowledgeTypes(util.List.of(AcknowledgeType.RENEW.id))
+            ))
+        ).iterator))).iterator))
+
+    val shareAcknowledgeRequest = new 
ShareAcknowledgeRequest.Builder(shareAcknowledgeRequestData)
+      .build(ApiKeys.SHARE_ACKNOWLEDGE.latestVersion)
+    val request = buildRequest(shareAcknowledgeRequest)
+    kafkaApis = createKafkaApis()
+    kafkaApis.handleShareAcknowledgeRequest(request)
+    val response = verifyNoThrottling[ShareAcknowledgeResponse](request)
+    val responseData = response.data()
+    val topicResponses = responseData.responses()
+
+    assertEquals(Errors.NONE.code, responseData.errorCode)
+    assertEquals(1, topicResponses.size())
+    val topicResponse = topicResponses.stream.findFirst.get
+    assertEquals(topicId, topicResponse.topicId)
+    assertEquals(1, topicResponse.partitions.size())
+    assertEquals(partitionIndex, 
topicResponse.partitions.get(0).partitionIndex)
+    if(shouldFail) {
+      assertEquals(Errors.INVALID_REQUEST.code, 
topicResponse.partitions.get(0).errorCode)
+    } else {
+      assertEquals(Errors.NONE.code, topicResponse.partitions.get(0).errorCode)
+    }
+  }
+
   @Test
   def testHandleShareFetchShareGroupDisabled(): Unit = {
     val topicId = Uuid.randomUuid()
@@ -7431,7 +7625,7 @@ class KafkaApisTest extends Logging {
 
     kafkaApis = createKafkaApis()
     val acknowledgeBatches = 
kafkaApis.getAcknowledgeBatchesFromShareFetchRequest(shareFetchRequest, 
topicIdNames, erroneous)
-    val erroneousTopicIdPartitions = 
kafkaApis.validateAcknowledgementBatches(acknowledgeBatches, erroneous)
+    val erroneousTopicIdPartitions = 
kafkaApis.validateAcknowledgementBatches(acknowledgeBatches, erroneous, 
supportsRenewAcknowledgements = true, isRenewAck = false)
 
     assertEquals(3, erroneous.size)
     assertEquals(2, erroneousTopicIdPartitions.size)
@@ -7560,7 +7754,7 @@ class KafkaApisTest extends Logging {
 
     kafkaApis = createKafkaApis()
     val acknowledgeBatches = 
kafkaApis.getAcknowledgeBatchesFromShareAcknowledgeRequest(shareAcknowledgeRequest,
 topicIdNames, erroneous)
-    val erroneousTopicIdPartitions = 
kafkaApis.validateAcknowledgementBatches(acknowledgeBatches, erroneous)
+    val erroneousTopicIdPartitions = 
kafkaApis.validateAcknowledgementBatches(acknowledgeBatches, erroneous, 
supportsRenewAcknowledgements = true, isRenewAck = false)
 
     assertEquals(3, erroneous.size)
     assertEquals(2, erroneousTopicIdPartitions.size)
@@ -7633,7 +7827,9 @@ class KafkaApisTest extends Logging {
       sharePartitionManager,
       authorizedTopics,
       groupId,
-      memberId.toString
+      memberId.toString,
+      supportsRenewAcknowledgements = true,
+      isRenewAck = false
     ).get()
 
     assertEquals(3, ackResult.size)
@@ -7708,7 +7904,9 @@ class KafkaApisTest extends Logging {
       sharePartitionManager,
       authorizedTopics,
       groupId,
-      memberId.toString
+      memberId.toString,
+      supportsRenewAcknowledgements = true,
+      isRenewAck = false
     ).get()
 
     assertEquals(3, ackResult.size)
@@ -7784,7 +7982,9 @@ class KafkaApisTest extends Logging {
       sharePartitionManager,
       authorizedTopics,
       groupId,
-      memberId.toString
+      memberId.toString,
+      supportsRenewAcknowledgements = true,
+      isRenewAck = false
     ).get()
 
     assertEquals(3, ackResult.size)
@@ -7854,7 +8054,9 @@ class KafkaApisTest extends Logging {
       sharePartitionManager,
       authorizedTopics,
       groupId,
-      memberId.toString
+      memberId.toString,
+      supportsRenewAcknowledgements = true,
+      isRenewAck = false
     ).get()
 
     assertEquals(3, ackResult.size)
@@ -13765,6 +13967,73 @@ class KafkaApisTest extends Logging {
     assertEquals(alterShareGroupOffsetsResponseData, response.data)
   }
 
+  @ParameterizedTest
+  @CsvSource(value = Array("1,true,true", "1,false,true", "2,true,false", 
"2,false,true"))
+  def testValidateAcknowledgementBatchesForRenew(version: Short, isRenew: 
Boolean, shouldFail: Boolean): Unit = {
+    kafkaApis = createKafkaApis()
+    val tp = new TopicIdPartition(Uuid.randomUuid(), new 
TopicPartition("topic", 0))
+    val ackMap = mutable.Map(tp -> util.List.of(new 
ShareAcknowledgementBatch(0, 0, util.List.of(AcknowledgeType.RENEW.id))))
+    val erroneous:mutable.Map[TopicIdPartition, 
ShareAcknowledgeResponseData.PartitionData] = mutable.Map()
+    val errorSet = kafkaApis.validateAcknowledgementBatches(ackMap, erroneous, 
supportsRenewAcknowledgements = version == 2, isRenewAck = isRenew)
+    if (shouldFail) {
+      assertEquals(1, errorSet.size, s"expected error topic partition, 
version=${version}, isRenew=${isRenew}")
+      assertTrue(errorSet.contains(tp), s"error topic partition mismatch, 
version=${version}, isRenew=${isRenew}")
+    } else {
+      assertEquals(0, errorSet.size, s"unexpected error topic partition, 
version=${version}, isRenew=${isRenew}")
+    }
+  }
+
+  @Test
+  def testHandleShareFetchRenewInvalidRequest(): Unit = {
+    val topicId = Uuid.randomUuid()
+    val partitionIndex = 0
+    val groupId = "group"
+    val memberId = Uuid.randomUuid()
+    val testPrincipal = new KafkaPrincipal(KafkaPrincipal.USER_TYPE, 
"test-user")
+    val testClientAddress = InetAddress.getByName("192.168.1.100")
+    val testClientId = "test-client-id"
+    metadataCache = initializeMetadataCacheWithShareGroupsEnabled()
+
+    when(sharePartitionManager.newContext(any(), any(), any(), any(), any(), 
any())).thenReturn(
+      new FinalContext()
+    )
+
+    val shareFetchRequestData = new ShareFetchRequestData()
+      .setGroupId(groupId)
+      .setMemberId(memberId.toString)
+      .setShareSessionEpoch(0)
+      .setIsRenewAck(true)
+      .setMinBytes(10)
+      .setMaxBytes(20)
+      .setMaxRecords(30)
+      .setMaxWaitMs(40)
+      .setTopics(new 
ShareFetchRequestData.FetchTopicCollection(util.List.of(new 
ShareFetchRequestData.FetchTopic()
+        .setTopicId(topicId)
+        .setPartitions(new 
ShareFetchRequestData.FetchPartitionCollection(util.List.of(
+          new ShareFetchRequestData.FetchPartition()
+            .setAcknowledgementBatches(util.List.of(new AcknowledgementBatch()
+              .setFirstOffset(0)
+              .setLastOffset(0)
+              .setAcknowledgeTypes(util.List.of(AcknowledgeType.RENEW.id))))
+            .setPartitionIndex(partitionIndex)
+        ).iterator))
+      ).iterator))
+
+    val shareFetchRequest = new 
ShareFetchRequest.Builder(shareFetchRequestData).build(ApiKeys.SHARE_FETCH.latestVersion)
+
+    // Create request with custom principal and client address to test quota 
tags
+    val requestHeader = new RequestHeader(shareFetchRequest.apiKey, 
shareFetchRequest.version, testClientId, 0)
+    val request = buildRequest(shareFetchRequest, testPrincipal, 
testClientAddress,
+      ListenerName.forSecurityProtocol(SecurityProtocol.SSL), 
fromPrivilegedListener = false, Some(requestHeader), requestChannelMetrics)
+
+    val kafkaApis = createKafkaApis()
+    kafkaApis.handleShareFetchRequest(request)
+    val response = verifyNoThrottling[ShareFetchResponse](request)
+    val responseData = response.data()
+
+    assertEquals(Errors.INVALID_REQUEST.code, responseData.errorCode)
+  }
+
   def getShareGroupDescribeResponse(groupIds: util.List[String], 
enableShareGroups: Boolean = true,
                                     verifyNoErr: Boolean = true, authorizer: 
Authorizer = null,
                                     describedGroups: 
util.List[ShareGroupDescribeResponseData.DescribedGroup]): 
ShareGroupDescribeResponse = {

Reply via email to