mjsax commented on code in PR #21365:
URL: https://github.com/apache/kafka/pull/21365#discussion_r2744349397


##########
streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/RebalanceTaskClosureIntegrationTest.java:
##########
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.integration;
+
+import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.common.serialization.LongSerializer;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.serialization.StringSerializer;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.streams.CloseOptions;
+import org.apache.kafka.streams.KafkaStreams;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.TopologyWrapper;
+import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
+import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
+import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.streams.processor.StateStoreContext;
+import org.apache.kafka.streams.processor.internals.StreamThread;
+import org.apache.kafka.streams.state.KeyValueStore;
+import org.apache.kafka.streams.state.StoreBuilder;
+import org.apache.kafka.streams.state.internals.AbstractStoreBuilder;
+import org.apache.kafka.streams.state.internals.CacheFlushListener;
+import org.apache.kafka.streams.state.internals.CachedStateStore;
+import org.apache.kafka.streams.state.internals.RocksDBStore;
+import org.apache.kafka.test.MockApiProcessorSupplier;
+import org.apache.kafka.test.TestUtils;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInfo;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.util.List;
+import java.util.Properties;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.apache.kafka.streams.utils.TestUtils.safeUniqueTestName;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class RebalanceTaskClosureIntegrationTest {
+
+    private static final int NUM_BROKERS = 1;
+    protected static final String INPUT_TOPIC_NAME = "input-topic";
+    private static final int NUM_PARTITIONS = 3;
+
+    private final EmbeddedKafkaCluster cluster = new EmbeddedKafkaCluster(NUM_BROKERS);
+
+    private KafkaStreamsWrapper streams1;
+    private KafkaStreamsWrapper streams2;
+    private String safeTestName;
+
+    @BeforeEach
+    public void before(final TestInfo testInfo) throws InterruptedException, IOException {
+        cluster.start();
+        cluster.createTopic(INPUT_TOPIC_NAME, NUM_PARTITIONS, 1);
+        safeTestName = safeUniqueTestName(testInfo);
+
+    }
+
+    @AfterEach
+    public void after() {
+        cluster.stop();
+        if (streams1 != null) {
+            streams1.close(Duration.ofSeconds(30));
+        }
+        if (streams2 != null) {
+            streams2.close(Duration.ofSeconds(30));
+        }
+    }
+
+    @Test
+    public void shouldClosePendingTasksToInitAfterRebalance() throws Exception {
+        final CountDownLatch recycleLatch = new CountDownLatch(1);
+        final CountDownLatch pendingShutdownLatch = new CountDownLatch(1);
+        // Count how many times we initialize and close stores
+        final AtomicInteger initCount = new AtomicInteger();
+        final AtomicInteger closeCount = new AtomicInteger();
+        final StoreBuilder<KeyValueStore<Bytes, byte[]>> storeBuilder = new AbstractStoreBuilder<>("testStateStore", Serdes.Integer(), Serdes.ByteArray(), new MockTime()) {
+
+            @Override
+            public KeyValueStore<Bytes, byte[]> build() {
+                return new TestRocksDBStore(name, recycleLatch, pendingShutdownLatch, initCount, closeCount);
+            }
+        };
+
+        final TopologyWrapper topology = new TopologyWrapper();
+        topology.addSource("ingest", INPUT_TOPIC_NAME);
+        topology.addProcessor("my-processor", new MockApiProcessorSupplier<>(), "ingest");
+        topology.addStateStore(storeBuilder, "my-processor");
+
+        streams1 = new KafkaStreamsWrapper(topology, props("1"));
+        streams1.setStreamThreadStateListener((t, newState, oldState) -> {
+            if (newState == StreamThread.State.PENDING_SHUTDOWN) {
+                pendingShutdownLatch.countDown();
+            }
+        });
+        streams1.start();
+
+        TestUtils.waitForCondition(() -> streams1.state() == KafkaStreams.State.RUNNING, "Streams never reached RUNNING state");
+
+        streams2 = new KafkaStreamsWrapper(topology, props("2"));
+        streams2.start();
+
+        TestUtils.waitForCondition(() -> streams2.state() == KafkaStreams.State.RUNNING, "Streams never reached RUNNING state");
+
+        // starting the second KS app triggered a rebalance. Which in turn will recycle active tasks that need to become standby.
+        // That's exactly what we are waiting for
+        recycleLatch.await();
+
+        // sending a message to disable retries in the consumer client. if there are no messages, it retries the whole sequence of actions,
+        // including the rebalance data. which we don't want, because we just staged the right condition
+        IntegrationTestUtils.produceKeyValuesSynchronously(INPUT_TOPIC_NAME, List.of(new KeyValue<>(1L, "key")),
+                TestUtils.producerConfig(cluster.bootstrapServers(), LongSerializer.class, StringSerializer.class, new Properties()), cluster.time);
+        // Now we can close both apps. The StreamThreadStateListener will unblock the clearCache call, letting the rebalance finish.
+        // We don't want it to happen any sooner, because we want the stream thread to stop before it gets to moving messages from task registry to state updater.
+        streams1.close(CloseOptions.groupMembershipOperation(CloseOptions.GroupMembershipOperation.LEAVE_GROUP));
+        streams2.close(CloseOptions.groupMembershipOperation(CloseOptions.GroupMembershipOperation.LEAVE_GROUP));
+
+        assertEquals(initCount.get(), closeCount.get());
+    }
+
+    private Properties props(final String storePathSuffix) {
+        final Properties streamsConfiguration = new Properties();
+
+        streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, safeTestName);
+        streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, cluster.bootstrapServers());
+        streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest");

Review Comment:
   This is the default IIRC -- no need to set it
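   If it is, the line could simply be dropped (empty suggestion below, assuming nothing in the test relies on an explicit reset policy):
   ```suggestion
   ```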



##########
streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/RebalanceTaskClosureIntegrationTest.java:
##########
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.integration;
+
+import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.common.serialization.LongSerializer;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.serialization.StringSerializer;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.streams.CloseOptions;
+import org.apache.kafka.streams.KafkaStreams;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.TopologyWrapper;
+import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
+import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
+import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.streams.processor.StateStoreContext;
+import org.apache.kafka.streams.processor.internals.StreamThread;
+import org.apache.kafka.streams.state.KeyValueStore;
+import org.apache.kafka.streams.state.StoreBuilder;
+import org.apache.kafka.streams.state.internals.AbstractStoreBuilder;
+import org.apache.kafka.streams.state.internals.CacheFlushListener;
+import org.apache.kafka.streams.state.internals.CachedStateStore;
+import org.apache.kafka.streams.state.internals.RocksDBStore;
+import org.apache.kafka.test.MockApiProcessorSupplier;
+import org.apache.kafka.test.TestUtils;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInfo;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.util.List;
+import java.util.Properties;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.apache.kafka.streams.utils.TestUtils.safeUniqueTestName;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class RebalanceTaskClosureIntegrationTest {
+
+    private static final int NUM_BROKERS = 1;
+    protected static final String INPUT_TOPIC_NAME = "input-topic";
+    private static final int NUM_PARTITIONS = 3;
+
+    private final EmbeddedKafkaCluster cluster = new EmbeddedKafkaCluster(NUM_BROKERS);
+
+    private KafkaStreamsWrapper streams1;
+    private KafkaStreamsWrapper streams2;
+    private String safeTestName;
+
+    @BeforeEach
+    public void before(final TestInfo testInfo) throws InterruptedException, IOException {
+        cluster.start();
+        cluster.createTopic(INPUT_TOPIC_NAME, NUM_PARTITIONS, 1);
+        safeTestName = safeUniqueTestName(testInfo);
+

Review Comment:
   ```suggestion
   ```



##########
streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/RebalanceTaskClosureIntegrationTest.java:
##########
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.integration;
+
+import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.common.serialization.LongSerializer;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.serialization.StringSerializer;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.streams.CloseOptions;
+import org.apache.kafka.streams.KafkaStreams;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.TopologyWrapper;
+import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
+import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
+import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.streams.processor.StateStoreContext;
+import org.apache.kafka.streams.processor.internals.StreamThread;
+import org.apache.kafka.streams.state.KeyValueStore;
+import org.apache.kafka.streams.state.StoreBuilder;
+import org.apache.kafka.streams.state.internals.AbstractStoreBuilder;
+import org.apache.kafka.streams.state.internals.CacheFlushListener;
+import org.apache.kafka.streams.state.internals.CachedStateStore;
+import org.apache.kafka.streams.state.internals.RocksDBStore;
+import org.apache.kafka.test.MockApiProcessorSupplier;
+import org.apache.kafka.test.TestUtils;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInfo;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.util.List;
+import java.util.Properties;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.apache.kafka.streams.utils.TestUtils.safeUniqueTestName;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class RebalanceTaskClosureIntegrationTest {
+
+    private static final int NUM_BROKERS = 1;
+    protected static final String INPUT_TOPIC_NAME = "input-topic";
+    private static final int NUM_PARTITIONS = 3;
+
+    private final EmbeddedKafkaCluster cluster = new EmbeddedKafkaCluster(NUM_BROKERS);
+
+    private KafkaStreamsWrapper streams1;
+    private KafkaStreamsWrapper streams2;
+    private String safeTestName;
+
+    @BeforeEach
+    public void before(final TestInfo testInfo) throws InterruptedException, IOException {
+        cluster.start();
+        cluster.createTopic(INPUT_TOPIC_NAME, NUM_PARTITIONS, 1);
+        safeTestName = safeUniqueTestName(testInfo);
+
+    }
+
+    @AfterEach
+    public void after() {
+        cluster.stop();
+        if (streams1 != null) {
+            streams1.close(Duration.ofSeconds(30));
+        }
+        if (streams2 != null) {
+            streams2.close(Duration.ofSeconds(30));
+        }
+    }
+
+    @Test
+    public void shouldClosePendingTasksToInitAfterRebalance() throws Exception {
+        final CountDownLatch recycleLatch = new CountDownLatch(1);
+        final CountDownLatch pendingShutdownLatch = new CountDownLatch(1);
+        // Count how many times we initialize and close stores
+        final AtomicInteger initCount = new AtomicInteger();
+        final AtomicInteger closeCount = new AtomicInteger();
+        final StoreBuilder<KeyValueStore<Bytes, byte[]>> storeBuilder = new AbstractStoreBuilder<>("testStateStore", Serdes.Integer(), Serdes.ByteArray(), new MockTime()) {
+
+            @Override
+            public KeyValueStore<Bytes, byte[]> build() {
+                return new TestRocksDBStore(name, recycleLatch, pendingShutdownLatch, initCount, closeCount);
+            }
+        };
+
+        final TopologyWrapper topology = new TopologyWrapper();
+        topology.addSource("ingest", INPUT_TOPIC_NAME);
+        topology.addProcessor("my-processor", new MockApiProcessorSupplier<>(), "ingest");
+        topology.addStateStore(storeBuilder, "my-processor");
+
+        streams1 = new KafkaStreamsWrapper(topology, props("1"));
+        streams1.setStreamThreadStateListener((t, newState, oldState) -> {
+            if (newState == StreamThread.State.PENDING_SHUTDOWN) {
+                pendingShutdownLatch.countDown();
+            }
+        });
+        streams1.start();
+
+        TestUtils.waitForCondition(() -> streams1.state() == KafkaStreams.State.RUNNING, "Streams never reached RUNNING state");
+
+        streams2 = new KafkaStreamsWrapper(topology, props("2"));
+        streams2.start();
+
+        TestUtils.waitForCondition(() -> streams2.state() == KafkaStreams.State.RUNNING, "Streams never reached RUNNING state");
+
+        // starting the second KS app triggered a rebalance. Which in turn will recycle active tasks that need to become standby.
+        // That's exactly what we are waiting for
+        recycleLatch.await();
+
+        // sending a message to disable retries in the consumer client. if there are no messages, it retries the whole sequence of actions,
+        // including the rebalance data. which we don't want, because we just staged the right condition
+        IntegrationTestUtils.produceKeyValuesSynchronously(INPUT_TOPIC_NAME, List.of(new KeyValue<>(1L, "key")),

Review Comment:
   Not sure if I understand. "disable retries in the consumer client" -- what retries are you referring to, and what consumer exactly?
   
   "it retries the whole sequence of actions, including the rebalance data" -- what sequence are you referring to? What do you mean by "rebalance data"?
   
   "which we don't want, because we just staged the right condition" -- Why do we not want this? Not sure what the "right condition" is, and why we would "lose it" w/o sending an input record.



##########
streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/RebalanceTaskClosureIntegrationTest.java:
##########
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.integration;
+
+import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.common.serialization.LongSerializer;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.serialization.StringSerializer;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.streams.CloseOptions;
+import org.apache.kafka.streams.KafkaStreams;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.TopologyWrapper;
+import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
+import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
+import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.streams.processor.StateStoreContext;
+import org.apache.kafka.streams.processor.internals.StreamThread;
+import org.apache.kafka.streams.state.KeyValueStore;
+import org.apache.kafka.streams.state.StoreBuilder;
+import org.apache.kafka.streams.state.internals.AbstractStoreBuilder;
+import org.apache.kafka.streams.state.internals.CacheFlushListener;
+import org.apache.kafka.streams.state.internals.CachedStateStore;
+import org.apache.kafka.streams.state.internals.RocksDBStore;
+import org.apache.kafka.test.MockApiProcessorSupplier;
+import org.apache.kafka.test.TestUtils;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInfo;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.util.List;
+import java.util.Properties;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.apache.kafka.streams.utils.TestUtils.safeUniqueTestName;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class RebalanceTaskClosureIntegrationTest {
+
+    private static final int NUM_BROKERS = 1;
+    protected static final String INPUT_TOPIC_NAME = "input-topic";
+    private static final int NUM_PARTITIONS = 3;
+
+    private final EmbeddedKafkaCluster cluster = new EmbeddedKafkaCluster(NUM_BROKERS);
+
+    private KafkaStreamsWrapper streams1;
+    private KafkaStreamsWrapper streams2;
+    private String safeTestName;
+
+    @BeforeEach
+    public void before(final TestInfo testInfo) throws InterruptedException, IOException {
+        cluster.start();
+        cluster.createTopic(INPUT_TOPIC_NAME, NUM_PARTITIONS, 1);
+        safeTestName = safeUniqueTestName(testInfo);
+
+    }
+
+    @AfterEach
+    public void after() {
+        cluster.stop();
+        if (streams1 != null) {
+            streams1.close(Duration.ofSeconds(30));
+        }
+        if (streams2 != null) {
+            streams2.close(Duration.ofSeconds(30));
+        }
+    }
+
+    @Test
+    public void shouldClosePendingTasksToInitAfterRebalance() throws Exception {
+        final CountDownLatch recycleLatch = new CountDownLatch(1);
+        final CountDownLatch pendingShutdownLatch = new CountDownLatch(1);
+        // Count how many times we initialize and close stores
+        final AtomicInteger initCount = new AtomicInteger();
+        final AtomicInteger closeCount = new AtomicInteger();
+        final StoreBuilder<KeyValueStore<Bytes, byte[]>> storeBuilder = new AbstractStoreBuilder<>("testStateStore", Serdes.Integer(), Serdes.ByteArray(), new MockTime()) {
+
+            @Override
+            public KeyValueStore<Bytes, byte[]> build() {
+                return new TestRocksDBStore(name, recycleLatch, pendingShutdownLatch, initCount, closeCount);
+            }
+        };
+
+        final TopologyWrapper topology = new TopologyWrapper();
+        topology.addSource("ingest", INPUT_TOPIC_NAME);
+        topology.addProcessor("my-processor", new MockApiProcessorSupplier<>(), "ingest");
+        topology.addStateStore(storeBuilder, "my-processor");
+
+        streams1 = new KafkaStreamsWrapper(topology, props("1"));
+        streams1.setStreamThreadStateListener((t, newState, oldState) -> {
+            if (newState == StreamThread.State.PENDING_SHUTDOWN) {
+                pendingShutdownLatch.countDown();
+            }
+        });
+        streams1.start();
+
+        TestUtils.waitForCondition(() -> streams1.state() == KafkaStreams.State.RUNNING, "Streams never reached RUNNING state");
+
+        streams2 = new KafkaStreamsWrapper(topology, props("2"));
+        streams2.start();
+
+        TestUtils.waitForCondition(() -> streams2.state() == KafkaStreams.State.RUNNING, "Streams never reached RUNNING state");
+
+        // starting the second KS app triggered a rebalance. Which in turn will recycle active tasks that need to become standby.
+        // That's exactly what we are waiting for
+        recycleLatch.await();
+
+        // sending a message to disable retries in the consumer client. if there are no messages, it retries the whole sequence of actions,
+        // including the rebalance data. which we don't want, because we just staged the right condition
+        IntegrationTestUtils.produceKeyValuesSynchronously(INPUT_TOPIC_NAME, List.of(new KeyValue<>(1L, "key")),
+                TestUtils.producerConfig(cluster.bootstrapServers(), LongSerializer.class, StringSerializer.class, new Properties()), cluster.time);
+        // Now we can close both apps. The StreamThreadStateListener will unblock the clearCache call, letting the rebalance finish.
+        // We don't want it to happen any sooner, because we want the stream thread to stop before it gets to moving messages from task registry to state updater.
+        streams1.close(CloseOptions.groupMembershipOperation(CloseOptions.GroupMembershipOperation.LEAVE_GROUP));
+        streams2.close(CloseOptions.groupMembershipOperation(CloseOptions.GroupMembershipOperation.LEAVE_GROUP));
+
+        assertEquals(initCount.get(), closeCount.get());
+    }
+
+    private Properties props(final String storePathSuffix) {
+        final Properties streamsConfiguration = new Properties();
+
+        streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, safeTestName);
+        streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, cluster.bootstrapServers());
+        streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest");
+        streamsConfiguration.put(ConsumerConfig.DEFAULT_API_TIMEOUT_MS_CONFIG, 1000);
+        streamsConfiguration.put(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG, 1000);
+        streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath() + "/" + storePathSuffix);
+        streamsConfiguration.put(StreamsConfig.STATESTORE_CACHE_MAX_BYTES_CONFIG, 0);
+        streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L);
+        streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.LongSerde.class);
+        streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.StringSerde.class);
+        streamsConfiguration.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 1);
+        streamsConfiguration.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1);
+
+        return streamsConfiguration;
+    }
+
+    private static class TestRocksDBStore extends RocksDBStore implements CachedStateStore<Bytes, byte[]> {
+
+        private final CountDownLatch recycleLatch;
+        private final CountDownLatch pendingShutdownLatch;
+        private final AtomicInteger initCount;
+        private final AtomicInteger closeCount;
+
+        public TestRocksDBStore(final String name,
+                                final CountDownLatch recycleLatch,
+                                final CountDownLatch pendingShutdownLatch,
+                                final AtomicInteger initCount,
+                                final AtomicInteger closeCount) {
+            super(name, "rocksdb");
+            this.recycleLatch = recycleLatch;
+            this.pendingShutdownLatch = pendingShutdownLatch;
+            this.initCount = initCount;
+            this.closeCount = closeCount;
+        }
+
+        @Override
+        public void init(final StateStoreContext stateStoreContext,
+                         final StateStore root) {
+            initCount.incrementAndGet();
+            super.init(stateStoreContext, root);
+        }
+
+        @Override
+        public boolean setFlushListener(final CacheFlushListener<Bytes, byte[]> listener,
+                                        final boolean sendOldValues) {
+            return false;

Review Comment:
   Why not `return super.setFlushListener(...)`?
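   A minimal sketch of what delegating could look like, assuming a class up the hierarchy actually provides a `CachedStateStore` implementation to inherit (if `RocksDBStore` does not, the hard-coded `false` may be the only option, which would be worth a short code comment):
   ```java
   @Override
   public boolean setFlushListener(final CacheFlushListener<Bytes, byte[]> listener,
                                   final boolean sendOldValues) {
       // Delegate instead of silently dropping the listener, so the test store
       // behaves like the production store. Hypothetical: compiles only if a
       // superclass implements CachedStateStore#setFlushListener.
       return super.setFlushListener(listener, sendOldValues);
   }
   ```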



##########
streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/RebalanceTaskClosureIntegrationTest.java:
##########
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.integration;
+
+import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.common.serialization.LongSerializer;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.serialization.StringSerializer;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.streams.CloseOptions;
+import org.apache.kafka.streams.KafkaStreams;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.TopologyWrapper;
+import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
+import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
+import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.streams.processor.StateStoreContext;
+import org.apache.kafka.streams.processor.internals.StreamThread;
+import org.apache.kafka.streams.state.KeyValueStore;
+import org.apache.kafka.streams.state.StoreBuilder;
+import org.apache.kafka.streams.state.internals.AbstractStoreBuilder;
+import org.apache.kafka.streams.state.internals.CacheFlushListener;
+import org.apache.kafka.streams.state.internals.CachedStateStore;
+import org.apache.kafka.streams.state.internals.RocksDBStore;
+import org.apache.kafka.test.MockApiProcessorSupplier;
+import org.apache.kafka.test.TestUtils;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInfo;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.util.List;
+import java.util.Properties;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.apache.kafka.streams.utils.TestUtils.safeUniqueTestName;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class RebalanceTaskClosureIntegrationTest {
+
+    private static final int NUM_BROKERS = 1;
+    protected static final String INPUT_TOPIC_NAME = "input-topic";
+    private static final int NUM_PARTITIONS = 3;
+
+    private final EmbeddedKafkaCluster cluster = new EmbeddedKafkaCluster(NUM_BROKERS);
+
+    private KafkaStreamsWrapper streams1;
+    private KafkaStreamsWrapper streams2;
+    private String safeTestName;
+
+    @BeforeEach
+    public void before(final TestInfo testInfo) throws InterruptedException, IOException {
+        cluster.start();
+        cluster.createTopic(INPUT_TOPIC_NAME, NUM_PARTITIONS, 1);
+        safeTestName = safeUniqueTestName(testInfo);
+
+    }
+
+    @AfterEach
+    public void after() {
+        cluster.stop();
+        if (streams1 != null) {
+            streams1.close(Duration.ofSeconds(30));
+        }
+        if (streams2 != null) {
+            streams2.close(Duration.ofSeconds(30));
+        }
+    }
+
+    @Test
+    public void shouldClosePendingTasksToInitAfterRebalance() throws Exception {

Review Comment:
   Might be good to add a test description at the top, explaining what exactly we are testing.
   ```
   Starting two KS apps.
   Blocking the first KS app during the rebalance (which is triggered when the second app starts) because <explanation>
   ```
   
   and so forth. I am not sure I understand the test yet, and I am sure that if somebody comes back to it at some point in the future, they might struggle, too.



##########
streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/RebalanceTaskClosureIntegrationTest.java:
##########
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.integration;
+
+import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.common.serialization.LongSerializer;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.serialization.StringSerializer;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.streams.CloseOptions;
+import org.apache.kafka.streams.KafkaStreams;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.TopologyWrapper;
+import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
+import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
+import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.streams.processor.StateStoreContext;
+import org.apache.kafka.streams.processor.internals.StreamThread;
+import org.apache.kafka.streams.state.KeyValueStore;
+import org.apache.kafka.streams.state.StoreBuilder;
+import org.apache.kafka.streams.state.internals.AbstractStoreBuilder;
+import org.apache.kafka.streams.state.internals.CacheFlushListener;
+import org.apache.kafka.streams.state.internals.CachedStateStore;
+import org.apache.kafka.streams.state.internals.RocksDBStore;
+import org.apache.kafka.test.MockApiProcessorSupplier;
+import org.apache.kafka.test.TestUtils;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInfo;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.util.List;
+import java.util.Properties;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.apache.kafka.streams.utils.TestUtils.safeUniqueTestName;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class RebalanceTaskClosureIntegrationTest {
+
+    private static final int NUM_BROKERS = 1;
+    protected static final String INPUT_TOPIC_NAME = "input-topic";
+    private static final int NUM_PARTITIONS = 3;
+
+    private final EmbeddedKafkaCluster cluster = new EmbeddedKafkaCluster(NUM_BROKERS);
+
+    private KafkaStreamsWrapper streams1;
+    private KafkaStreamsWrapper streams2;
+    private String safeTestName;
+
+    @BeforeEach
+    public void before(final TestInfo testInfo) throws InterruptedException, IOException {
+        cluster.start();
+        cluster.createTopic(INPUT_TOPIC_NAME, NUM_PARTITIONS, 1);
+        safeTestName = safeUniqueTestName(testInfo);
+
+    }
+
+    @AfterEach
+    public void after() {
+        cluster.stop();
+        if (streams1 != null) {
+            streams1.close(Duration.ofSeconds(30));
+        }
+        if (streams2 != null) {
+            streams2.close(Duration.ofSeconds(30));
+        }
+    }
+
+    @Test
+    public void shouldClosePendingTasksToInitAfterRebalance() throws Exception {
+        final CountDownLatch recycleLatch = new CountDownLatch(1);
+        final CountDownLatch pendingShutdownLatch = new CountDownLatch(1);
+        // Count how many times we initialize and close stores
+        final AtomicInteger initCount = new AtomicInteger();
+        final AtomicInteger closeCount = new AtomicInteger();
+        final StoreBuilder<KeyValueStore<Bytes, byte[]>> storeBuilder = new AbstractStoreBuilder<>("testStateStore", Serdes.Integer(), Serdes.ByteArray(), new MockTime()) {
+
+            @Override
+            public KeyValueStore<Bytes, byte[]> build() {
+                return new TestRocksDBStore(name, recycleLatch, pendingShutdownLatch, initCount, closeCount);
+            }
+        };
+
+        final TopologyWrapper topology = new TopologyWrapper();
+        topology.addSource("ingest", INPUT_TOPIC_NAME);
+        topology.addProcessor("my-processor", new MockApiProcessorSupplier<>(), "ingest");
+        topology.addStateStore(storeBuilder, "my-processor");
+
+        streams1 = new KafkaStreamsWrapper(topology, props("1"));
+        streams1.setStreamThreadStateListener((t, newState, oldState) -> {
+            if (newState == StreamThread.State.PENDING_SHUTDOWN) {
+                pendingShutdownLatch.countDown();
+            }
+        });
+        streams1.start();
+
+        TestUtils.waitForCondition(() -> streams1.state() == KafkaStreams.State.RUNNING, "Streams never reached RUNNING state");
+
+        streams2 = new KafkaStreamsWrapper(topology, props("2"));
+        streams2.start();
+
+        TestUtils.waitForCondition(() -> streams2.state() == KafkaStreams.State.RUNNING, "Streams never reached RUNNING state");
+
+        // starting the second KS app triggered a rebalance. Which in turn will recycle active tasks that need to become standby.
+        // That's exactly what we are waiting for
+        recycleLatch.await();
+
+        // sending a message to disable retries in the consumer client. if there are no messages, it retries the whole sequence of actions,
+        // including the rebalance data. which we don't want, because we just staged the right condition
+        IntegrationTestUtils.produceKeyValuesSynchronously(INPUT_TOPIC_NAME, List.of(new KeyValue<>(1L, "key")),
+                TestUtils.producerConfig(cluster.bootstrapServers(), LongSerializer.class, StringSerializer.class, new Properties()), cluster.time);
+        // Now we can close both apps. The StreamThreadStateListener will unblock the clearCache call, letting the rebalance finish.
+        // We don't want it to happen any sooner, because we want the stream thread to stop before it gets to moving messages from task registry to state updater.
+        streams1.close(CloseOptions.groupMembershipOperation(CloseOptions.GroupMembershipOperation.LEAVE_GROUP));
+        streams2.close(CloseOptions.groupMembershipOperation(CloseOptions.GroupMembershipOperation.LEAVE_GROUP));
+
+        assertEquals(initCount.get(), closeCount.get());
+    }
+
+    private Properties props(final String storePathSuffix) {
+        final Properties streamsConfiguration = new Properties();
+
+        streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, safeTestName);
+        streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, cluster.bootstrapServers());
+        streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest");
+        streamsConfiguration.put(ConsumerConfig.DEFAULT_API_TIMEOUT_MS_CONFIG, 1000);
+        streamsConfiguration.put(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG, 1000);
+        streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath() + "/" + storePathSuffix);
+        streamsConfiguration.put(StreamsConfig.STATESTORE_CACHE_MAX_BYTES_CONFIG, 0);
+        streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L);
+        streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.LongSerde.class);
+        streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.StringSerde.class);
+        streamsConfiguration.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 1);

Review Comment:
   Same. Already the default.



##########
streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/RebalanceTaskClosureIntegrationTest.java:
##########
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.integration;
+
+import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.common.serialization.LongSerializer;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.serialization.StringSerializer;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.streams.CloseOptions;
+import org.apache.kafka.streams.KafkaStreams;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.TopologyWrapper;
+import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
+import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
+import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.streams.processor.StateStoreContext;
+import org.apache.kafka.streams.processor.internals.StreamThread;
+import org.apache.kafka.streams.state.KeyValueStore;
+import org.apache.kafka.streams.state.StoreBuilder;
+import org.apache.kafka.streams.state.internals.AbstractStoreBuilder;
+import org.apache.kafka.streams.state.internals.CacheFlushListener;
+import org.apache.kafka.streams.state.internals.CachedStateStore;
+import org.apache.kafka.streams.state.internals.RocksDBStore;
+import org.apache.kafka.test.MockApiProcessorSupplier;
+import org.apache.kafka.test.TestUtils;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInfo;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.util.List;
+import java.util.Properties;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.apache.kafka.streams.utils.TestUtils.safeUniqueTestName;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class RebalanceTaskClosureIntegrationTest {
+
+    private static final int NUM_BROKERS = 1;
+    protected static final String INPUT_TOPIC_NAME = "input-topic";
+    private static final int NUM_PARTITIONS = 3;
+
+    private final EmbeddedKafkaCluster cluster = new EmbeddedKafkaCluster(NUM_BROKERS);
+
+    private KafkaStreamsWrapper streams1;
+    private KafkaStreamsWrapper streams2;
+    private String safeTestName;
+
+    @BeforeEach
+    public void before(final TestInfo testInfo) throws InterruptedException, IOException {
+        cluster.start();
+        cluster.createTopic(INPUT_TOPIC_NAME, NUM_PARTITIONS, 1);
+        safeTestName = safeUniqueTestName(testInfo);
+
+    }
+
+    @AfterEach
+    public void after() {
+        cluster.stop();
+        if (streams1 != null) {
+            streams1.close(Duration.ofSeconds(30));
+        }
+        if (streams2 != null) {
+            streams2.close(Duration.ofSeconds(30));
+        }
+    }
+
+    @Test
+    public void shouldClosePendingTasksToInitAfterRebalance() throws Exception {
+        final CountDownLatch recycleLatch = new CountDownLatch(1);
+        final CountDownLatch pendingShutdownLatch = new CountDownLatch(1);
+        // Count how many times we initialize and close stores
+        final AtomicInteger initCount = new AtomicInteger();
+        final AtomicInteger closeCount = new AtomicInteger();
+        final StoreBuilder<KeyValueStore<Bytes, byte[]>> storeBuilder = new AbstractStoreBuilder<>("testStateStore", Serdes.Integer(), Serdes.ByteArray(), new MockTime()) {
+
+            @Override
+            public KeyValueStore<Bytes, byte[]> build() {
+                return new TestRocksDBStore(name, recycleLatch, pendingShutdownLatch, initCount, closeCount);
+            }
+        };
+
+        final TopologyWrapper topology = new TopologyWrapper();
+        topology.addSource("ingest", INPUT_TOPIC_NAME);
+        topology.addProcessor("my-processor", new MockApiProcessorSupplier<>(), "ingest");
+        topology.addStateStore(storeBuilder, "my-processor");
+
+        streams1 = new KafkaStreamsWrapper(topology, props("1"));
+        streams1.setStreamThreadStateListener((t, newState, oldState) -> {
+            if (newState == StreamThread.State.PENDING_SHUTDOWN) {
+                pendingShutdownLatch.countDown();
+            }
+        });
+        streams1.start();
+
+        TestUtils.waitForCondition(() -> streams1.state() == KafkaStreams.State.RUNNING, "Streams never reached RUNNING state");
+
+        streams2 = new KafkaStreamsWrapper(topology, props("2"));
+        streams2.start();
+
+        TestUtils.waitForCondition(() -> streams2.state() == KafkaStreams.State.RUNNING, "Streams never reached RUNNING state");
+
+        // starting the second KS app triggered a rebalance. Which in turn will recycle active tasks that need to become standby.
+        // That's exactly what we are waiting for
+        recycleLatch.await();
+
+        // sending a message to disable retries in the consumer client. if there are no messages, it retries the whole sequence of actions,
+        // including the rebalance data. which we don't want, because we just staged the right condition
+        IntegrationTestUtils.produceKeyValuesSynchronously(INPUT_TOPIC_NAME, List.of(new KeyValue<>(1L, "key")),
+                TestUtils.producerConfig(cluster.bootstrapServers(), LongSerializer.class, StringSerializer.class, new Properties()), cluster.time);
+        // Now we can close both apps. The StreamThreadStateListener will unblock the clearCache call, letting the rebalance finish.
+        // We don't want it to happen any sooner, because we want the stream thread to stop before it gets to moving messages from task registry to state updater.

Review Comment:
   "before it gets to moving messages from task registry to state updater" -- 
do you mean "task" instead of "messaged"?



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java:
##########
@@ -193,7 +193,9 @@ public synchronized void removeTask(final Task taskToRemove) {
             throw new IllegalStateException("Attempted to remove a task that is not closed or suspended: " + taskId);
         }
 
-        if (taskToRemove.isActive()) {
+        if (pendingTasksToInit.contains(taskToRemove)) {

Review Comment:
   We throw above if a task is not CLOSED or SUSPENDED, but a task in `pendingTasksToInit` should be in state `CREATED`? -- Do we need to update the above condition?
   
   Are we testing this code path properly? If no test fails, it seems we never call `removeTask()` with a pending task?
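   One possible shape of such an update -- a sketch only, assuming `Task#state()` and the `Task.State` values `CREATED`/`CLOSED`/`SUSPENDED`; the actual guard in `removeTask()` may need to look different:
   ```java
   final Task.State state = taskToRemove.state();
   // Also tolerate tasks pending init: they are still in CREATED because they
   // were never initialized, so "closed or suspended" does not apply to them.
   if (state != Task.State.CLOSED
           && state != Task.State.SUSPENDED
           && !pendingTasksToInit.contains(taskToRemove)) {
       throw new IllegalStateException(
           "Attempted to remove a task that is not closed, suspended, or pending init: " + taskId);
   }
   ```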



##########
streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/RebalanceTaskClosureIntegrationTest.java:
##########
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.integration;
+
+import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.common.serialization.LongSerializer;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.serialization.StringSerializer;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.streams.CloseOptions;
+import org.apache.kafka.streams.KafkaStreams;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.TopologyWrapper;
+import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
+import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
+import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.streams.processor.StateStoreContext;
+import org.apache.kafka.streams.processor.internals.StreamThread;
+import org.apache.kafka.streams.state.KeyValueStore;
+import org.apache.kafka.streams.state.StoreBuilder;
+import org.apache.kafka.streams.state.internals.AbstractStoreBuilder;
+import org.apache.kafka.streams.state.internals.CacheFlushListener;
+import org.apache.kafka.streams.state.internals.CachedStateStore;
+import org.apache.kafka.streams.state.internals.RocksDBStore;
+import org.apache.kafka.test.MockApiProcessorSupplier;
+import org.apache.kafka.test.TestUtils;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInfo;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.util.List;
+import java.util.Properties;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.apache.kafka.streams.utils.TestUtils.safeUniqueTestName;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class RebalanceTaskClosureIntegrationTest {
+
+    private static final int NUM_BROKERS = 1;
+    protected static final String INPUT_TOPIC_NAME = "input-topic";
+    private static final int NUM_PARTITIONS = 3;
+
+    private final EmbeddedKafkaCluster cluster = new 
EmbeddedKafkaCluster(NUM_BROKERS);
+
+    private KafkaStreamsWrapper streams1;
+    private KafkaStreamsWrapper streams2;
+    private String safeTestName;
+
+    @BeforeEach
+    public void before(final TestInfo testInfo) throws InterruptedException, 
IOException {
+        cluster.start();
+        cluster.createTopic(INPUT_TOPIC_NAME, NUM_PARTITIONS, 1);
+        safeTestName = safeUniqueTestName(testInfo);
+
+    }
+
+    @AfterEach
+    public void after() {
+        cluster.stop();
+        if (streams1 != null) {
+            streams1.close(Duration.ofSeconds(30));
+        }
+        if (streams2 != null) {
+            streams2.close(Duration.ofSeconds(30));
+        }
+    }
+
+    @Test
+    public void shouldClosePendingTasksToInitAfterRebalance() throws Exception 
{
+        final CountDownLatch recycleLatch = new CountDownLatch(1);
+        final CountDownLatch pendingShutdownLatch = new CountDownLatch(1);
+        // Count how many times we initialize and close stores
+        final AtomicInteger initCount = new AtomicInteger();
+        final AtomicInteger closeCount = new AtomicInteger();
+        final StoreBuilder<KeyValueStore<Bytes, byte[]>> storeBuilder = new 
AbstractStoreBuilder<>("testStateStore", Serdes.Integer(), Serdes.ByteArray(), 
new MockTime()) {
+
+            @Override
+            public KeyValueStore<Bytes, byte[]> build() {
+                return new TestRocksDBStore(name, recycleLatch, 
pendingShutdownLatch, initCount, closeCount);
+            }
+        };
+
+        final TopologyWrapper topology = new TopologyWrapper();
+        topology.addSource("ingest", INPUT_TOPIC_NAME);
+        topology.addProcessor("my-processor", new 
MockApiProcessorSupplier<>(), "ingest");
+        topology.addStateStore(storeBuilder, "my-processor");
+
+        streams1 = new KafkaStreamsWrapper(topology, props("1"));
+        streams1.setStreamThreadStateListener((t, newState, oldState) -> {
+            if (newState == StreamThread.State.PENDING_SHUTDOWN) {
+                pendingShutdownLatch.countDown();
+            }
+        });
+        streams1.start();
+
+        TestUtils.waitForCondition(() -> streams1.state() == 
KafkaStreams.State.RUNNING, "Streams never reached RUNNING state");
+
+        streams2 = new KafkaStreamsWrapper(topology, props("2"));
+        streams2.start();
+
+        TestUtils.waitForCondition(() -> streams2.state() == 
KafkaStreams.State.RUNNING, "Streams never reached RUNNING state");
+
+        // starting the second KS app triggered a rebalance. Which in turn 
will recycle active tasks that need to become standby.
+        // That's exactly what we are waiting for
+        recycleLatch.await();
+
+        // sending a message to disable retries in the consumer client. if 
there are no messages, it retries the whole sequence of actions,
+        // including the rebalance data. which we don't want, because we just 
staged the right condition
+        IntegrationTestUtils.produceKeyValuesSynchronously(INPUT_TOPIC_NAME, 
List.of(new KeyValue<>(1L, "key")),
+                TestUtils.producerConfig(cluster.bootstrapServers(), 
LongSerializer.class, StringSerializer.class, new Properties()), cluster.time);
+        // Now we can close both apps. The StreamThreadStateListener will 
unblock the clearCache call, letting the rebalance finish.
+        // We don't want it to happen any sooner, because we want the stream 
thread to stop before it gets to moving messages from task registry to state 
updater.
+        
streams1.close(CloseOptions.groupMembershipOperation(CloseOptions.GroupMembershipOperation.LEAVE_GROUP));
+        
streams2.close(CloseOptions.groupMembershipOperation(CloseOptions.GroupMembershipOperation.LEAVE_GROUP));
+
+        assertEquals(initCount.get(), closeCount.get());
+    }
+
+    private Properties props(final String storePathSuffix) {
+        final Properties streamsConfiguration = new Properties();
+
+        streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, 
safeTestName);
+        streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, 
cluster.bootstrapServers());
+        streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, 
"earliest");
+        streamsConfiguration.put(ConsumerConfig.DEFAULT_API_TIMEOUT_MS_CONFIG, 
1000);
+        streamsConfiguration.put(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG, 
1000);
+        streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, 
TestUtils.tempDirectory().getPath() + "/" + storePathSuffix);
+        
streamsConfiguration.put(StreamsConfig.STATESTORE_CACHE_MAX_BYTES_CONFIG, 0);
+        streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 
100L);
+        streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, 
Serdes.LongSerde.class);
+        
streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, 
Serdes.StringSerde.class);
+        streamsConfiguration.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 1);
+        streamsConfiguration.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1);
+
+        return streamsConfiguration;
+    }
+
+    private static class TestRocksDBStore extends RocksDBStore implements 
CachedStateStore<Bytes, byte[]> {
+
+        private final CountDownLatch recycleLatch;
+        private final CountDownLatch pendingShutdownLatch;
+        private final AtomicInteger initCount;
+        private final AtomicInteger closeCount;
+
+        public TestRocksDBStore(final String name,
+                                final CountDownLatch recycleLatch,
+                                final CountDownLatch pendingShutdownLatch,
+                                final AtomicInteger initCount,
+                                final AtomicInteger closeCount) {
+            super(name, "rocksdb");
+            this.recycleLatch = recycleLatch;
+            this.pendingShutdownLatch = pendingShutdownLatch;
+            this.initCount = initCount;
+            this.closeCount = closeCount;
+        }
+
+        @Override
+        public void init(final StateStoreContext stateStoreContext,
+                         final StateStore root) {
+            initCount.incrementAndGet();
+            super.init(stateStoreContext, root);
+        }
+
+        @Override
+        public boolean setFlushListener(final CacheFlushListener<Bytes, byte[]> listener,
+                                        final boolean sendOldValues) {
+            return false;
+        }
+
+        @Override
+        public void flushCache() {
+        }

Review Comment:
   Should we call `super.flushCache()`?
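   If `RocksDBStore` actually provides a `flushCache()` implementation (rather than the method coming only from the `CachedStateStore` interface, in which case there would be no super method to call), a minimal sketch of the delegating override could look like this:

   ```java
   @Override
   public void flushCache() {
       // Hypothetical: preserve the base store's flush behavior instead of a no-op.
       // The empty body above may instead be deliberate, to keep the test store's
       // cache inert while the rebalance condition is staged.
       super.flushCache();
   }
   ```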



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/TaskManager.java:
##########
@@ -1523,7 +1532,7 @@ private Collection<Task> tryCloseCleanActiveTasks(final Collection<Task> activeT
                                                      final boolean clean,
                                                      final AtomicReference<RuntimeException> firstException) {
         if (!clean) {
-            return activeTaskIterable();
+            return activeTasksToClose;

Review Comment:
   For my own understanding: the original code here was not wrong, in the sense that before this PR the `activeTasksToClose` that got passed in was the same as `activeTaskIterable()`? But with this PR, we also include pending tasks in `activeTasksToClose`, and that's why we need to update the code here?
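   A sketch of how I read the difference (the `pendingActiveTasksToInit` name below is my assumption, not necessarily what the PR calls it):

   ```java
   // Before this PR: the caller effectively passed activeTaskIterable(),
   // so returning either collection on the unclean path was equivalent.
   Collection<Task> activeTasksToClose = activeTaskIterable();

   // With this PR: pending tasks are included as well, so the two collections
   // diverge, and the unclean path must return exactly what was passed in.
   activeTasksToClose = new ArrayList<>(activeTaskIterable());
   activeTasksToClose.addAll(pendingActiveTasksToInit); // hypothetical name
   ```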



##########
streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/RebalanceTaskClosureIntegrationTest.java:
##########
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.integration;
+
+import org.apache.kafka.clients.consumer.ConsumerConfig;
+import org.apache.kafka.common.serialization.LongSerializer;
+import org.apache.kafka.common.serialization.Serdes;
+import org.apache.kafka.common.serialization.StringSerializer;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.common.utils.MockTime;
+import org.apache.kafka.streams.CloseOptions;
+import org.apache.kafka.streams.KafkaStreams;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.StreamsConfig;
+import org.apache.kafka.streams.TopologyWrapper;
+import org.apache.kafka.streams.integration.utils.EmbeddedKafkaCluster;
+import org.apache.kafka.streams.integration.utils.IntegrationTestUtils;
+import org.apache.kafka.streams.processor.StateStore;
+import org.apache.kafka.streams.processor.StateStoreContext;
+import org.apache.kafka.streams.processor.internals.StreamThread;
+import org.apache.kafka.streams.state.KeyValueStore;
+import org.apache.kafka.streams.state.StoreBuilder;
+import org.apache.kafka.streams.state.internals.AbstractStoreBuilder;
+import org.apache.kafka.streams.state.internals.CacheFlushListener;
+import org.apache.kafka.streams.state.internals.CachedStateStore;
+import org.apache.kafka.streams.state.internals.RocksDBStore;
+import org.apache.kafka.test.MockApiProcessorSupplier;
+import org.apache.kafka.test.TestUtils;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.TestInfo;
+
+import java.io.IOException;
+import java.time.Duration;
+import java.util.List;
+import java.util.Properties;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import static org.apache.kafka.streams.utils.TestUtils.safeUniqueTestName;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class RebalanceTaskClosureIntegrationTest {
+
+    private static final int NUM_BROKERS = 1;
+    protected static final String INPUT_TOPIC_NAME = "input-topic";
+    private static final int NUM_PARTITIONS = 3;
+
+    private final EmbeddedKafkaCluster cluster = new EmbeddedKafkaCluster(NUM_BROKERS);
+
+    private KafkaStreamsWrapper streams1;
+    private KafkaStreamsWrapper streams2;
+    private String safeTestName;
+
+    @BeforeEach
+    public void before(final TestInfo testInfo) throws InterruptedException, IOException {
+        cluster.start();
+        cluster.createTopic(INPUT_TOPIC_NAME, NUM_PARTITIONS, 1);
+        safeTestName = safeUniqueTestName(testInfo);
+
+    }
+
+    @AfterEach
+    public void after() {
+        cluster.stop();
+        if (streams1 != null) {
+            streams1.close(Duration.ofSeconds(30));
+        }
+        if (streams2 != null) {
+            streams2.close(Duration.ofSeconds(30));
+        }
+    }
+
+    @Test
+    public void shouldClosePendingTasksToInitAfterRebalance() throws Exception {
+        final CountDownLatch recycleLatch = new CountDownLatch(1);
+        final CountDownLatch pendingShutdownLatch = new CountDownLatch(1);
+        // Count how many times we initialize and close stores
+        final AtomicInteger initCount = new AtomicInteger();
+        final AtomicInteger closeCount = new AtomicInteger();
+        final StoreBuilder<KeyValueStore<Bytes, byte[]>> storeBuilder = new AbstractStoreBuilder<>("testStateStore", Serdes.Integer(), Serdes.ByteArray(), new MockTime()) {
+
+            @Override
+            public KeyValueStore<Bytes, byte[]> build() {
+                return new TestRocksDBStore(name, recycleLatch, pendingShutdownLatch, initCount, closeCount);
+            }
+        };
+
+        final TopologyWrapper topology = new TopologyWrapper();
+        topology.addSource("ingest", INPUT_TOPIC_NAME);
+        topology.addProcessor("my-processor", new 
MockApiProcessorSupplier<>(), "ingest");
+        topology.addStateStore(storeBuilder, "my-processor");
+
+        streams1 = new KafkaStreamsWrapper(topology, props("1"));
+        streams1.setStreamThreadStateListener((t, newState, oldState) -> {
+            if (newState == StreamThread.State.PENDING_SHUTDOWN) {
+                pendingShutdownLatch.countDown();
+            }
+        });
+        streams1.start();
+
+        TestUtils.waitForCondition(() -> streams1.state() == KafkaStreams.State.RUNNING, "Streams never reached RUNNING state");
+
+        streams2 = new KafkaStreamsWrapper(topology, props("2"));
+        streams2.start();
+
+        TestUtils.waitForCondition(() -> streams2.state() == KafkaStreams.State.RUNNING, "Streams never reached RUNNING state");
+
+        // Starting the second KS app triggers a rebalance, which in turn recycles active tasks that need to become standby.
+        // That's exactly what we are waiting for.
+        recycleLatch.await();
+
+        // Send a message to prevent retries in the consumer client. If there are no messages,
+        // the consumer retries the whole sequence of actions, including the rebalance,
+        // which we don't want, because we just staged the right condition.
+        IntegrationTestUtils.produceKeyValuesSynchronously(INPUT_TOPIC_NAME, List.of(new KeyValue<>(1L, "key")),
+                TestUtils.producerConfig(cluster.bootstrapServers(), LongSerializer.class, StringSerializer.class, new Properties()), cluster.time);
+        // Now we can close both apps. The StreamThreadStateListener will unblock the clearCache call, letting the rebalance finish.
+        // We don't want that to happen any sooner, because we want the stream thread to stop before it gets to moving tasks from the task registry to the state updater.
+        streams1.close(CloseOptions.groupMembershipOperation(CloseOptions.GroupMembershipOperation.LEAVE_GROUP));
+        streams2.close(CloseOptions.groupMembershipOperation(CloseOptions.GroupMembershipOperation.LEAVE_GROUP));
+
+        assertEquals(initCount.get(), closeCount.get());
+    }
+
+    private Properties props(final String storePathSuffix) {
+        final Properties streamsConfiguration = new Properties();
+
+        streamsConfiguration.put(StreamsConfig.APPLICATION_ID_CONFIG, safeTestName);
+        streamsConfiguration.put(StreamsConfig.BOOTSTRAP_SERVERS_CONFIG, cluster.bootstrapServers());
+        streamsConfiguration.put(ConsumerConfig.AUTO_OFFSET_RESET_CONFIG, "earliest");
+        streamsConfiguration.put(ConsumerConfig.DEFAULT_API_TIMEOUT_MS_CONFIG, 1000);
+        streamsConfiguration.put(ConsumerConfig.REQUEST_TIMEOUT_MS_CONFIG, 1000);
+        streamsConfiguration.put(StreamsConfig.STATE_DIR_CONFIG, TestUtils.tempDirectory().getPath() + "/" + storePathSuffix);
+        streamsConfiguration.put(StreamsConfig.STATESTORE_CACHE_MAX_BYTES_CONFIG, 0);
+        streamsConfiguration.put(StreamsConfig.COMMIT_INTERVAL_MS_CONFIG, 100L);
+        streamsConfiguration.put(StreamsConfig.DEFAULT_KEY_SERDE_CLASS_CONFIG, Serdes.LongSerde.class);
+        streamsConfiguration.put(StreamsConfig.DEFAULT_VALUE_SERDE_CLASS_CONFIG, Serdes.StringSerde.class);
+        streamsConfiguration.put(StreamsConfig.NUM_STREAM_THREADS_CONFIG, 1);
+        streamsConfiguration.put(StreamsConfig.NUM_STANDBY_REPLICAS_CONFIG, 1);
+
+        return streamsConfiguration;
+    }
+
+    private static class TestRocksDBStore extends RocksDBStore implements CachedStateStore<Bytes, byte[]> {
+
+        private final CountDownLatch recycleLatch;
+        private final CountDownLatch pendingShutdownLatch;
+        private final AtomicInteger initCount;
+        private final AtomicInteger closeCount;
+
+        public TestRocksDBStore(final String name,
+                                final CountDownLatch recycleLatch,
+                                final CountDownLatch pendingShutdownLatch,
+                                final AtomicInteger initCount,
+                                final AtomicInteger closeCount) {
+            super(name, "rocksdb");
+            this.recycleLatch = recycleLatch;
+            this.pendingShutdownLatch = pendingShutdownLatch;
+            this.initCount = initCount;
+            this.closeCount = closeCount;
+        }
+
+        @Override
+        public void init(final StateStoreContext stateStoreContext,
+                         final StateStore root) {
+            initCount.incrementAndGet();
+            super.init(stateStoreContext, root);
+        }
+
+        @Override
+        public boolean setFlushListener(final CacheFlushListener<Bytes, byte[]> listener,
+                                        final boolean sendOldValues) {
+            return false;
+        }
+
+        @Override
+        public void flushCache() {
+        }
+
+        @Override
+        public void clearCache() {

Review Comment:
   Should we call `super.clearCache()` at some point (not sure if in the 
beginning or at the end of the method)?
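   To illustrate the question, assuming `RocksDBStore` exposes a real `clearCache()` and that this override blocks on the latches as described in the test's comments (the body is elided above, so this is a guess at its shape, not the actual code):

   ```java
   @Override
   public void clearCache() {
       recycleLatch.countDown();         // hypothetical: signal the test that recycling reached clearCache
       try {
           pendingShutdownLatch.await(); // hypothetical: hold the rebalance until shutdown starts
       } catch (final InterruptedException e) {
           Thread.currentThread().interrupt();
       }
       super.clearCache();               // the question: call it here, or before the latch handling?
   }
   ```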



##########
streams/src/main/java/org/apache/kafka/streams/processor/internals/Tasks.java:
##########


Review Comment:
   Can you explain in more detail why we also need a fix for this method? It makes sense to me that we need to fix `shutdown()`, but I'm not totally sure about this one.
   
   I'm not saying we don't need to fix it; I just don't see the full picture of the bug yet.



##########
streams/integration-tests/src/test/java/org/apache/kafka/streams/integration/KafkaStreamsWrapper.java:
##########
@@ -48,7 +48,11 @@ public List<StreamThread> streamThreads() {
     public void setStreamThreadStateListener(final StreamThread.StateListener listener) {
         if (state == State.CREATED) {
             for (final StreamThread thread : threads) {
-                thread.setStateListener(listener);
+                StreamThread.StateListener originalListener = thread.getStateListener();
+                thread.setStateListener((t, newState, oldState) -> {
+                    originalListener.onChange(t, newState, oldState);
+                    listener.onChange(t, newState, oldState);
+                });

Review Comment:
   Why do we need this change? It seems you are trying to allow registering a second state listener? In general, we only allow registering one -- why do we need two? (Also, if we need multiple, why limit it to two?)
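   For illustration, if multiple listeners were genuinely needed, one option (a sketch, not a claim about what this wrapper should do; `stateListeners`, `addStreamThreadStateListener`, and `installCompositeListener` are made-up names) would be to fan out over a list instead of special-casing two:

   ```java
   // assumes java.util.List and java.util.concurrent.CopyOnWriteArrayList
   private final List<StreamThread.StateListener> stateListeners = new CopyOnWriteArrayList<>();

   public void addStreamThreadStateListener(final StreamThread.StateListener listener) {
       stateListeners.add(listener);
   }

   // Installed once per thread; forwards every state change to all registered listeners.
   private void installCompositeListener(final StreamThread thread) {
       thread.setStateListener((t, newState, oldState) ->
           stateListeners.forEach(l -> l.onChange(t, newState, oldState)));
   }
   ```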



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
