This is an automated email from the ASF dual-hosted git repository.
bbejeck pushed a commit to branch trunk
in repository https://gitbox.apache.org/repos/asf/kafka.git
The following commit(s) were added to refs/heads/trunk by this push:
new 9f15f269d5b KAFKA-20158: Fix header-based key deserialization in
MeteredSessionStoreWithHeaders iterator methods (#21734)
9f15f269d5b is described below
commit 9f15f269d5b67de861014e1e16b51a078274d427
Author: Bill Bejeck <[email protected]>
AuthorDate: Sat Mar 14 12:14:10 2026 -0400
KAFKA-20158: Fix header-based key deserialization in
MeteredSessionStoreWithHeaders iterator methods (#21734)
Fixes a bug where MeteredSessionStoreWithHeaders iterator methods fail
when key deserializers require headers.
The class inherits the following iterator-returning methods from
MeteredSessionStore:
- fetch(K) / backwardFetch(K)
- fetch(K, K) / backwardFetch(K, K)
- findSessions(K, long, long) / backwardFindSessions(K, long, long)
- findSessions(K, K, long, long) / backwardFindSessions(K, K, long, long)
- findSessions(long, long)
These methods call serdes.keyFrom(bytes, new RecordHeaders()), which
passes empty headers to the key deserializer and therefore fails
whenever the deserializer requires headers.
The fix overrides each of these methods in MeteredSessionStoreWithHeaders
to return a custom iterator. That iterator first deserializes the value to
extract the headers stored in AggregationWithHeaders, then uses those
headers to deserialize the key.
This is the session-store equivalent of #21705, which fixed the same
issue for MeteredTimestampedWindowStoreWithHeaders.
Reviewers: Matthias Sax <[email protected]>
---
.../internals/MeteredSessionStoreWithHeaders.java | 164 +++++++++++++
.../MeteredSessionStoreWithHeadersTest.java | 262 +++++++++++++++++++++
2 files changed, 426 insertions(+)
diff --git a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeaders.java b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeaders.java
index 444d29834bb..7f0b8d35cf7 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeaders.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeaders.java
@@ -21,6 +21,7 @@ import org.apache.kafka.common.header.internals.RecordHeaders;
import org.apache.kafka.common.serialization.Serde;
import org.apache.kafka.common.utils.Bytes;
import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.streams.KeyValue;
import org.apache.kafka.streams.errors.ProcessorStateException;
import org.apache.kafka.streams.kstream.Windowed;
import org.apache.kafka.streams.query.FailureReason;
@@ -105,6 +106,104 @@ public class MeteredSessionStoreWithHeaders<K, AGG>
return result;
}
+    /**
+     * Delegates {@code fetch} to the wrapped store with {@code key} serialized under empty headers, and
+     * wraps the raw iterator so that every returned key is deserialized with the headers of its value.
+     */
+    @Override
+    public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>> fetch(final K key) {
+        Objects.requireNonNull(key, "key cannot be null");
+        final KeyValueIterator<Windowed<Bytes>, byte[]> rawSessions =
+            wrapped().fetch(keyBytes(key, new RecordHeaders(), serdes));
+        return new MeteredSessionStoreWithHeadersIterator(rawSessions);
+    }
+
+    /**
+     * Reverse-order counterpart of {@code fetch(K)}: delegates {@code backwardFetch} to the wrapped store
+     * and deserializes each key with the headers carried by its value.
+     */
+    @Override
+    public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>> backwardFetch(final K key) {
+        Objects.requireNonNull(key, "key cannot be null");
+        final KeyValueIterator<Windowed<Bytes>, byte[]> rawSessions =
+            wrapped().backwardFetch(keyBytes(key, new RecordHeaders(), serdes));
+        return new MeteredSessionStoreWithHeadersIterator(rawSessions);
+    }
+
+ @Override
+ public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>>
fetch(final K keyFrom,
+
final K keyTo) {
+ return new MeteredSessionStoreWithHeadersIterator(
+ wrapped().fetch(
+ keyBytes(keyFrom, new RecordHeaders(), serdes),
+ keyBytes(keyTo, new RecordHeaders(), serdes))
+ );
+ }
+
+ @Override
+ public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>>
backwardFetch(final K keyFrom,
+
final K keyTo) {
+ return new MeteredSessionStoreWithHeadersIterator(
+ wrapped().backwardFetch(
+ keyBytes(keyFrom, new RecordHeaders(), serdes),
+ keyBytes(keyTo, new RecordHeaders(), serdes))
+ );
+ }
+
+ @Override
+ public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>>
findSessions(final K key,
+
final long earliestSessionEndTime,
+
final long latestSessionStartTime) {
+ Objects.requireNonNull(key, "key cannot be null");
+ return new MeteredSessionStoreWithHeadersIterator(
+ wrapped().findSessions(
+ keyBytes(key, new RecordHeaders(), serdes),
+ earliestSessionEndTime,
+ latestSessionStartTime)
+ );
+ }
+
+ @Override
+ public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>>
backwardFindSessions(final K key,
+
final long earliestSessionEndTime,
+
final long latestSessionStartTime) {
+ Objects.requireNonNull(key, "key cannot be null");
+ return new MeteredSessionStoreWithHeadersIterator(
+ wrapped().backwardFindSessions(
+ keyBytes(key, new RecordHeaders(), serdes),
+ earliestSessionEndTime,
+ latestSessionStartTime)
+ );
+ }
+
+ @Override
+ public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>>
findSessions(final K keyFrom,
+
final K keyTo,
+
final long earliestSessionEndTime,
+
final long latestSessionStartTime) {
+ return new MeteredSessionStoreWithHeadersIterator(
+ wrapped().findSessions(
+ keyBytes(keyFrom, new RecordHeaders(), serdes),
+ keyBytes(keyTo, new RecordHeaders(), serdes),
+ earliestSessionEndTime,
+ latestSessionStartTime)
+ );
+ }
+
+ @Override
+ public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>>
backwardFindSessions(final K keyFrom,
+
final K keyTo,
+
final long earliestSessionEndTime,
+
final long latestSessionStartTime) {
+ return new MeteredSessionStoreWithHeadersIterator(
+ wrapped().backwardFindSessions(
+ keyBytes(keyFrom, new RecordHeaders(), serdes),
+ keyBytes(keyTo, new RecordHeaders(), serdes),
+ earliestSessionEndTime,
+ latestSessionStartTime)
+ );
+ }
+
+    /**
+     * Delegates the time-only {@code findSessions} (no key filter) to the wrapped store and deserializes
+     * every returned key with the headers carried by its value.
+     */
+    @Override
+    public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>> findSessions(final long earliestSessionEndTime,
+                                                                                   final long latestSessionEndTime) {
+        final KeyValueIterator<Windowed<Bytes>, byte[]> rawSessions =
+            wrapped().findSessions(earliestSessionEndTime, latestSessionEndTime);
+        return new MeteredSessionStoreWithHeadersIterator(rawSessions);
+    }
+
@SuppressWarnings("unchecked")
private <R> QueryResult<R> runRangeQuery(final Query<R> query,
final PositionBound positionBound,
@@ -138,4 +237,69 @@ public class MeteredSessionStoreWithHeaders<K, AGG>
return (QueryResult<R>) rawResult;
}
}
+
+    /**
+     * Metered iterator over raw session-store entries that deserializes each value before its key, so the
+     * headers stored inside the {@link AggregationWithHeaders} value can be handed to the key deserializer.
+     * Construction registers the iterator with the open-iterator metrics; {@link #close()} closes the
+     * underlying iterator, records the elapsed time on the fetch and iterator-duration sensors, and
+     * deregisters it.
+     */
+    private class MeteredSessionStoreWithHeadersIterator
+        implements KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>>, MeteredIterator {
+
+        private final KeyValueIterator<Windowed<Bytes>, byte[]> iter;
+        private final long startNs;
+        private final long startTimestampMs;
+        // Entry already deserialized by peekNextKey() but not yet handed out by next().
+        private KeyValue<Windowed<K>, AggregationWithHeaders<AGG>> cachedNext;
+
+        private MeteredSessionStoreWithHeadersIterator(final KeyValueIterator<Windowed<Bytes>, byte[]> iter) {
+            this.iter = iter;
+            this.startNs = time.nanoseconds();
+            this.startTimestampMs = time.milliseconds();
+            numOpenIterators.increment();
+            openIterators.add(this);
+        }
+
+        @Override
+        public long startTimestamp() {
+            return startTimestampMs;
+        }
+
+        @Override
+        public boolean hasNext() {
+            if (cachedNext != null) {
+                return true;
+            }
+            return iter.hasNext();
+        }
+
+        @Override
+        public KeyValue<Windowed<K>, AggregationWithHeaders<AGG>> next() {
+            if (cachedNext == null) {
+                return deserialize(iter.next());
+            }
+            final KeyValue<Windowed<K>, AggregationWithHeaders<AGG>> peeked = cachedNext;
+            cachedNext = null;
+            return peeked;
+        }
+
+        @Override
+        public Windowed<K> peekNextKey() {
+            if (cachedNext == null) {
+                cachedNext = deserialize(iter.next());
+            }
+            return cachedNext.key;
+        }
+
+        @Override
+        public void close() {
+            try {
+                iter.close();
+            } finally {
+                final long elapsedNs = time.nanoseconds() - startNs;
+                fetchSensor.record(elapsedNs);
+                iteratorDurationSensor.record(elapsedNs);
+                numOpenIterators.decrement();
+                openIterators.remove(this);
+            }
+        }
+
+        // Value first: its headers (or fresh empty headers when the value is null) drive key deserialization.
+        private KeyValue<Windowed<K>, AggregationWithHeaders<AGG>> deserialize(final KeyValue<Windowed<Bytes>, byte[]> raw) {
+            final AggregationWithHeaders<AGG> aggregation = serdes.valueFrom(raw.value, new RecordHeaders());
+            final Headers keyHeaders = aggregation == null ? new RecordHeaders() : aggregation.headers();
+            final K sessionKey = serdes.keyFrom(raw.key.key().get(), keyHeaders);
+            return KeyValue.pair(new Windowed<>(sessionKey, raw.key.window()), aggregation);
+        }
+    }
}
diff --git a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeadersTest.java b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeadersTest.java
index 236f704e946..4d6aaa8e2d7 100644
--- a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeadersTest.java
+++ b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeadersTest.java
@@ -74,8 +74,10 @@ import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.doThrow;
+import static org.mockito.Mockito.lenient;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
@@ -806,4 +808,264 @@ public class MeteredSessionStoreWithHeadersTest {
assertTrue(result.isFailure());
}
+
+    // --- Tests verifying that the headers carried by a value are used to deserialize its key ---
+
+    // Headers embedded in the aggregation; the mocked key deserializer only succeeds when given these.
+    private static final Headers HEADERS = new RecordHeaders().add("key1", "value1".getBytes());
+    // VALUE paired with HEADERS, plus its serialized bytes as the mocked inner store hands them back.
+    private static final AggregationWithHeaders<String> AGG_WITH_HEADERS =
+        AggregationWithHeaders.make(VALUE, HEADERS);
+    private static final byte[] SERIALIZED_VALUE =
+        new AggregationWithHeadersSerializer<>(Serdes.String().serializer()).serialize(CHANGELOG_TOPIC, AGG_WITH_HEADERS);
+
+    /**
+     * Builds and initializes a store over {@code innerStore} whose key serde is the caller-supplied mock
+     * {@code keySerde} and whose value serde is a fresh mock. All stubs are lenient:
+     * the key serializer always yields {@code KEY}'s bytes, {@code SERIALIZED_VALUE} deserializes to
+     * {@code AGG_WITH_HEADERS}, and {@code KEY}'s bytes deserialize to {@code KEY} only when the
+     * headers equal {@code HEADERS} (any other headers get the mock's default {@code null}).
+     */
+    @SuppressWarnings("unchecked")
+    private MeteredSessionStoreWithHeaders<String, String> createStoreWithMockSerdes(
+        final Serde<String> keySerde
+    ) {
+        final Serializer<String> keySerializer = mock(Serializer.class);
+        final Deserializer<String> keyDeserializer = mock(Deserializer.class);
+        lenient().when(keySerde.serializer()).thenReturn(keySerializer);
+        lenient().when(keySerde.deserializer()).thenReturn(keyDeserializer);
+        lenient().when(keySerializer.serialize(any(), any(RecordHeaders.class), any()))
+            .thenReturn(KEY.getBytes());
+        lenient().when(keyDeserializer.deserialize(any(), eq(HEADERS), eq(KEY.getBytes())))
+            .thenReturn(KEY);
+
+        final Serde<AggregationWithHeaders<String>> valueSerde = mock(Serde.class);
+        final Deserializer<AggregationWithHeaders<String>> valueDeserializer = mock(Deserializer.class);
+        lenient().when(valueSerde.deserializer()).thenReturn(valueDeserializer);
+        lenient().when(valueDeserializer.deserialize(any(), any(RecordHeaders.class), eq(SERIALIZED_VALUE)))
+            .thenReturn(AGG_WITH_HEADERS);
+
+        final MeteredSessionStoreWithHeaders<String, String> store = new MeteredSessionStoreWithHeaders<>(
+            innerStore,
+            STORE_TYPE,
+            keySerde,
+            valueSerde,
+            new MockTime()
+        );
+        store.init(context, store);
+        return store;
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void shouldUseHeadersFromValueToDeserializeKeyInFetch() {
+        setUp();
+        final Serde<String> keySerde = mock(Serde.class);
+        final MeteredSessionStoreWithHeaders<String, String> store = createStoreWithMockSerdes(keySerde);
+        when(innerStore.fetch(any(Bytes.class))).thenReturn(singleSessionFromInnerStore());
+
+        assertSingleSessionKeyDeserializedWithValueHeaders(store.fetch(KEY), keySerde);
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void shouldUseHeadersFromValueToDeserializeKeyInBackwardFetch() {
+        setUp();
+        final Serde<String> keySerde = mock(Serde.class);
+        final MeteredSessionStoreWithHeaders<String, String> store = createStoreWithMockSerdes(keySerde);
+        when(innerStore.backwardFetch(any(Bytes.class))).thenReturn(singleSessionFromInnerStore());
+
+        assertSingleSessionKeyDeserializedWithValueHeaders(store.backwardFetch(KEY), keySerde);
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void shouldUseHeadersFromValueToDeserializeKeyInFetchRange() {
+        setUp();
+        final Serde<String> keySerde = mock(Serde.class);
+        final MeteredSessionStoreWithHeaders<String, String> store = createStoreWithMockSerdes(keySerde);
+        when(innerStore.fetch(any(Bytes.class), any(Bytes.class))).thenReturn(singleSessionFromInnerStore());
+
+        assertSingleSessionKeyDeserializedWithValueHeaders(store.fetch(KEY, KEY), keySerde);
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void shouldUseHeadersFromValueToDeserializeKeyInBackwardFetchRange() {
+        setUp();
+        final Serde<String> keySerde = mock(Serde.class);
+        final MeteredSessionStoreWithHeaders<String, String> store = createStoreWithMockSerdes(keySerde);
+        when(innerStore.backwardFetch(any(Bytes.class), any(Bytes.class))).thenReturn(singleSessionFromInnerStore());
+
+        assertSingleSessionKeyDeserializedWithValueHeaders(store.backwardFetch(KEY, KEY), keySerde);
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void shouldUseHeadersFromValueToDeserializeKeyInFindSessions() {
+        setUp();
+        final Serde<String> keySerde = mock(Serde.class);
+        final MeteredSessionStoreWithHeaders<String, String> store = createStoreWithMockSerdes(keySerde);
+        when(innerStore.findSessions(any(Bytes.class), eq(0L), eq(100L))).thenReturn(singleSessionFromInnerStore());
+
+        assertSingleSessionKeyDeserializedWithValueHeaders(store.findSessions(KEY, 0, 100), keySerde);
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void shouldUseHeadersFromValueToDeserializeKeyInBackwardFindSessions() {
+        setUp();
+        final Serde<String> keySerde = mock(Serde.class);
+        final MeteredSessionStoreWithHeaders<String, String> store = createStoreWithMockSerdes(keySerde);
+        when(innerStore.backwardFindSessions(any(Bytes.class), eq(0L), eq(100L)))
+            .thenReturn(singleSessionFromInnerStore());
+
+        assertSingleSessionKeyDeserializedWithValueHeaders(store.backwardFindSessions(KEY, 0, 100), keySerde);
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void shouldUseHeadersFromValueToDeserializeKeyInFindSessionsRange() {
+        setUp();
+        final Serde<String> keySerde = mock(Serde.class);
+        final MeteredSessionStoreWithHeaders<String, String> store = createStoreWithMockSerdes(keySerde);
+        when(innerStore.findSessions(any(Bytes.class), any(Bytes.class), eq(0L), eq(100L)))
+            .thenReturn(singleSessionFromInnerStore());
+
+        assertSingleSessionKeyDeserializedWithValueHeaders(store.findSessions(KEY, KEY, 0, 100), keySerde);
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void shouldUseHeadersFromValueToDeserializeKeyInBackwardFindSessionsRange() {
+        setUp();
+        final Serde<String> keySerde = mock(Serde.class);
+        final MeteredSessionStoreWithHeaders<String, String> store = createStoreWithMockSerdes(keySerde);
+        when(innerStore.backwardFindSessions(any(Bytes.class), any(Bytes.class), eq(0L), eq(100L)))
+            .thenReturn(singleSessionFromInnerStore());
+
+        assertSingleSessionKeyDeserializedWithValueHeaders(store.backwardFindSessions(KEY, KEY, 0, 100), keySerde);
+    }
+
+    @SuppressWarnings("unchecked")
+    @Test
+    public void shouldUseHeadersFromValueToDeserializeKeyInFindSessionsByTime() {
+        setUp();
+        final Serde<String> keySerde = mock(Serde.class);
+        final MeteredSessionStoreWithHeaders<String, String> store = createStoreWithMockSerdes(keySerde);
+        when(innerStore.findSessions(eq(0L), eq(100L))).thenReturn(singleSessionFromInnerStore());
+
+        assertSingleSessionKeyDeserializedWithValueHeaders(store.findSessions(0, 100), keySerde);
+    }
+
+    // Raw inner-store iterator holding exactly one entry: WINDOWED_KEY_BYTES -> SERIALIZED_VALUE.
+    private KeyValueIterator<Windowed<Bytes>, byte[]> singleSessionFromInnerStore() {
+        return new KeyValueIteratorStub<>(
+            List.of(KeyValue.pair(WINDOWED_KEY_BYTES, SERIALIZED_VALUE)).iterator());
+    }
+
+    /**
+     * Drains {@code iterator}, which must yield exactly the one session served by
+     * {@link #singleSessionFromInnerStore()}. It checks both {@code peekNextKey()} and {@code next()},
+     * then closes the iterator. Finally it verifies that the key deserializer was called once, with the
+     * headers taken from the value ({@code HEADERS}) rather than empty headers.
+     */
+    private void assertSingleSessionKeyDeserializedWithValueHeaders(
+        final KeyValueIterator<Windowed<String>, AggregationWithHeaders<String>> iterator,
+        final Serde<String> keySerde
+    ) {
+        assertTrue(iterator.hasNext());
+        assertEquals(KEY, iterator.peekNextKey().key());
+        final KeyValue<Windowed<String>, AggregationWithHeaders<String>> result = iterator.next();
+        assertEquals(KEY, result.key.key());
+        assertEquals(AGG_WITH_HEADERS, result.value);
+        assertFalse(iterator.hasNext());
+        iterator.close();
+
+        verify(keySerde.deserializer()).deserialize(any(), eq(HEADERS), eq(KEY.getBytes()));
+    }
}