This is an automated email from the ASF dual-hosted git repository.

mjsax pushed a commit to branch 4.3
in repository https://gitbox.apache.org/repos/asf/kafka.git


The following commit(s) were added to refs/heads/4.3 by this push:
     new cc2cb6a9b2b KAFKA-20173: Ensure Metered session-stores pass headers correctly (#21957)
cc2cb6a9b2b is described below

commit cc2cb6a9b2b9f2ae4f56358710397be576719dcc
Author: Matthias J. Sax <[email protected]>
AuthorDate: Tue Apr 7 16:58:10 2026 -0700

    KAFKA-20173: Ensure Metered session-stores pass headers correctly (#21957)
    
    Ensures that all Metered Session-stores (plain and headers) pass headers
    into de/serializers.
    
    Reviewers: Uladzislau Blok <[email protected]>, TengYao Chi <[email protected]>
---
 .../state/internals/MeteredSessionStore.java       | 230 ++++++++++---------
 .../internals/MeteredSessionStoreWithHeaders.java  | 244 ++++++++++++---------
 .../MeteredTimestampedWindowStoreWithHeaders.java  |  60 +----
 ...MeteredWindowedKeyValueWithHeadersIterator.java |  81 +++++++
 .../kafka/streams/state/internals/Utils.java       |  34 ---
 .../state/internals/MeteredSessionStoreTest.java   |  11 +-
 .../MeteredSessionStoreWithHeadersTest.java        |   2 +
 7 files changed, 363 insertions(+), 299 deletions(-)

diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStore.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStore.java
index a838ceab680..fcc964cc255 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStore.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStore.java
@@ -17,7 +17,6 @@
 package org.apache.kafka.streams.state.internals;
 
 import org.apache.kafka.common.TopicPartition;
-import org.apache.kafka.common.header.internals.RecordHeaders;
 import org.apache.kafka.common.metrics.Sensor;
 import org.apache.kafka.common.serialization.Serde;
 import org.apache.kafka.common.utils.Bytes;
@@ -59,8 +58,6 @@ import static org.apache.kafka.common.utils.Utils.mkEntry;
 import static org.apache.kafka.common.utils.Utils.mkMap;
 import static 
org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.maybeMeasureLatency;
 
-// TODO: replace with new method in follow-up PR of KIP-1271
-@SuppressWarnings("deprecation")
 public class MeteredSessionStore<K, V>
     extends WrappedStateStore<SessionStore<Bytes, byte[]>, Windowed<K>, V>
     implements SessionStore<K, V>, MeteredStateStore {
@@ -94,11 +91,13 @@ public class MeteredSessionStore<K, V>
             );
 
 
-    MeteredSessionStore(final SessionStore<Bytes, byte[]> inner,
-                        final String metricsScope,
-                        final Serde<K> keySerde,
-                        final Serde<V> valueSerde,
-                        final Time time) {
+    MeteredSessionStore(
+        final SessionStore<Bytes, byte[]> inner,
+        final String metricsScope,
+        final Serde<K> keySerde,
+        final Serde<V> valueSerde,
+        final Time time
+    ) {
         super(inner);
         this.metricsScope = metricsScope;
         this.keySerde = keySerde;
@@ -107,8 +106,10 @@ public class MeteredSessionStore<K, V>
     }
 
     @Override
-    public void init(final StateStoreContext stateStoreContext,
-                     final StateStore root) {
+    public void init(
+        final StateStoreContext stateStoreContext,
+        final StateStore root
+    ) {
         internalContext = stateStoreContext instanceof 
InternalProcessorContext ? (InternalProcessorContext<?, ?>) stateStoreContext : 
null;
         taskId = stateStoreContext.taskId();
         initStoreSerde(stateStoreContext);
@@ -181,27 +182,35 @@ public class MeteredSessionStore<K, V>
 
     @SuppressWarnings("unchecked")
     @Override
-    public boolean setFlushListener(final CacheFlushListener<Windowed<K>, V> 
listener,
-                                    final boolean sendOldValues) {
+    public boolean setFlushListener(
+        final CacheFlushListener<Windowed<K>, V> listener,
+        final boolean sendOldValues
+    ) {
         final SessionStore<Bytes, byte[]> wrapped = wrapped();
         if (wrapped instanceof CachedStateStore) {
             return ((CachedStateStore<byte[], byte[]>) 
wrapped).setFlushListener(
-                record -> listener.apply(
-                    record.withKey(SessionKeySchema.from(record.key(), 
serdes.keyDeserializer(), record.headers(), serdes.topic()))
-                        .withValue(new Change<>(
-                            record.value().newValue != null ? 
serdes.valueFrom(record.value().newValue, record.headers()) : null,
-                            record.value().oldValue != null ? 
serdes.valueFrom(record.value().oldValue, record.headers()) : null,
-                            record.value().isLatest
-                        ))
-                ),
+                record -> {
+                    final Change<byte[]> change = record.value();
+                    listener.apply(
+                        record
+                            .withKey(SessionKeySchema.from(record.key(), 
serdes.keyDeserializer(), record.headers(), serdes.topic()))
+                            .withValue(new Change<>(
+                                change.newValue != null ? 
serdes.valueFrom(change.newValue, record.headers()) : null,
+                                change.oldValue != null ? 
serdes.valueFrom(change.oldValue, record.headers()) : null,
+                                change.isLatest
+                            ))
+                    );
+                },
                 sendOldValues);
         }
         return false;
     }
 
     @Override
-    public void put(final Windowed<K> sessionKey,
-                    final V aggregate) {
+    public void put(
+        final Windowed<K> sessionKey,
+        final V aggregate
+    ) {
         Objects.requireNonNull(sessionKey, "sessionKey can't be null");
         Objects.requireNonNull(sessionKey.key(), "sessionKey.key() can't be 
null");
         Objects.requireNonNull(sessionKey.window(), "sessionKey.window() can't 
be null");
@@ -209,8 +218,8 @@ public class MeteredSessionStore<K, V>
         try {
             maybeMeasureLatency(
                 () -> {
-                    final Bytes key = keyBytes(sessionKey.key());
-                    wrapped().put(new Windowed<>(key, sessionKey.window()), 
serdes.rawValue(aggregate));
+                    final Bytes key = serializeKey(sessionKey.key());
+                    wrapped().put(new Windowed<>(key, sessionKey.window()), 
serializeValue(aggregate));
                 },
                 time,
                 putSensor
@@ -231,7 +240,7 @@ public class MeteredSessionStore<K, V>
         try {
             maybeMeasureLatency(
                 () -> {
-                    final Bytes key = keyBytes(sessionKey.key());
+                    final Bytes key = serializeKey(sessionKey.key());
                     wrapped().remove(new Windowed<>(key, sessionKey.window()));
                 },
                 time,
@@ -247,18 +256,7 @@ public class MeteredSessionStore<K, V>
     public V fetchSession(final K key, final long earliestSessionEndTime, 
final long latestSessionStartTime) {
         Objects.requireNonNull(key, "key cannot be null");
         return maybeMeasureLatency(
-            () -> {
-                final Bytes bytesKey = keyBytes(key);
-                final byte[] result = wrapped().fetchSession(
-                    bytesKey,
-                    earliestSessionEndTime,
-                    latestSessionStartTime
-                );
-                if (result == null) {
-                    return null;
-                }
-                return serdes.valueFrom(result);
-            },
+            () -> deserializeValue(wrapped().fetchSession(serializeKey(key), 
earliestSessionEndTime, latestSessionStartTime)),
             time,
             fetchSensor
         );
@@ -268,25 +266,26 @@ public class MeteredSessionStore<K, V>
     public KeyValueIterator<Windowed<K>, V> fetch(final K key) {
         Objects.requireNonNull(key, "key cannot be null");
         return new MeteredWindowedKeyValueIterator<>(
-            wrapped().fetch(keyBytes(key)),
+            wrapped().fetch(serializeKey(key)),
             fetchSensor,
             iteratorDurationSensor,
-            bytes -> serdes.keyFrom(bytes, new RecordHeaders()),
-            serdes::valueFrom,
+            this::deserializeKey,
+            this::deserializeValue,
             time,
             numOpenIterators,
-            openIterators);
+            openIterators
+        );
     }
 
     @Override
     public KeyValueIterator<Windowed<K>, V> backwardFetch(final K key) {
         Objects.requireNonNull(key, "key cannot be null");
         return new MeteredWindowedKeyValueIterator<>(
-            wrapped().backwardFetch(keyBytes(key)),
+            wrapped().backwardFetch(serializeKey(key)),
             fetchSensor,
             iteratorDurationSensor,
-            bytes -> serdes.keyFrom(bytes, new RecordHeaders()),
-            serdes::valueFrom,
+            this::deserializeKey,
+            this::deserializeValue,
             time,
             numOpenIterators,
             openIterators
@@ -294,28 +293,33 @@ public class MeteredSessionStore<K, V>
     }
 
     @Override
-    public KeyValueIterator<Windowed<K>, V> fetch(final K keyFrom,
-                                                  final K keyTo) {
+    public KeyValueIterator<Windowed<K>, V> fetch(
+        final K keyFrom,
+        final K keyTo
+    ) {
         return new MeteredWindowedKeyValueIterator<>(
-            wrapped().fetch(keyBytes(keyFrom), keyBytes(keyTo)),
+            wrapped().fetch(serializeKey(keyFrom), serializeKey(keyTo)),
             fetchSensor,
             iteratorDurationSensor,
-            bytes -> serdes.keyFrom(bytes, new RecordHeaders()),
-            serdes::valueFrom,
+            this::deserializeKey,
+            this::deserializeValue,
             time,
             numOpenIterators,
-            openIterators);
+            openIterators
+        );
     }
 
     @Override
-    public KeyValueIterator<Windowed<K>, V> backwardFetch(final K keyFrom,
-                                                          final K keyTo) {
+    public KeyValueIterator<Windowed<K>, V> backwardFetch(
+        final K keyFrom,
+        final K keyTo
+    ) {
         return new MeteredWindowedKeyValueIterator<>(
-            wrapped().backwardFetch(keyBytes(keyFrom), keyBytes(keyTo)),
+            wrapped().backwardFetch(serializeKey(keyFrom), 
serializeKey(keyTo)),
             fetchSensor,
             iteratorDurationSensor,
-            bytes -> serdes.keyFrom(bytes, new RecordHeaders()),
-            serdes::valueFrom,
+            this::deserializeKey,
+            this::deserializeValue,
             time,
             numOpenIterators,
             openIterators
@@ -323,11 +327,13 @@ public class MeteredSessionStore<K, V>
     }
 
     @Override
-    public KeyValueIterator<Windowed<K>, V> findSessions(final K key,
-                                                         final long 
earliestSessionEndTime,
-                                                         final long 
latestSessionStartTime) {
+    public KeyValueIterator<Windowed<K>, V> findSessions(
+        final K key,
+        final long earliestSessionEndTime,
+        final long latestSessionStartTime
+    ) {
         Objects.requireNonNull(key, "key cannot be null");
-        final Bytes bytesKey = keyBytes(key);
+        final Bytes bytesKey = serializeKey(key);
         return new MeteredWindowedKeyValueIterator<>(
             wrapped().findSessions(
                 bytesKey,
@@ -335,19 +341,22 @@ public class MeteredSessionStore<K, V>
                 latestSessionStartTime),
             fetchSensor,
             iteratorDurationSensor,
-            bytes -> serdes.keyFrom(bytes, new RecordHeaders()),
-            serdes::valueFrom,
+            this::deserializeKey,
+            this::deserializeValue,
             time,
             numOpenIterators,
-            openIterators);
+            openIterators
+        );
     }
 
     @Override
-    public KeyValueIterator<Windowed<K>, V> backwardFindSessions(final K key,
-                                                                 final long 
earliestSessionEndTime,
-                                                                 final long 
latestSessionStartTime) {
+    public KeyValueIterator<Windowed<K>, V> backwardFindSessions(
+        final K key,
+        final long earliestSessionEndTime,
+        final long latestSessionStartTime
+    ) {
         Objects.requireNonNull(key, "key cannot be null");
-        final Bytes bytesKey = keyBytes(key);
+        final Bytes bytesKey = serializeKey(key);
         return new MeteredWindowedKeyValueIterator<>(
             wrapped().backwardFindSessions(
                 bytesKey,
@@ -356,8 +365,8 @@ public class MeteredSessionStore<K, V>
             ),
             fetchSensor,
             iteratorDurationSensor,
-            bytes -> serdes.keyFrom(bytes, new RecordHeaders()),
-            serdes::valueFrom,
+            this::deserializeKey,
+            this::deserializeValue,
             time,
             numOpenIterators,
             openIterators
@@ -365,12 +374,14 @@ public class MeteredSessionStore<K, V>
     }
 
     @Override
-    public KeyValueIterator<Windowed<K>, V> findSessions(final K keyFrom,
-                                                         final K keyTo,
-                                                         final long 
earliestSessionEndTime,
-                                                         final long 
latestSessionStartTime) {
-        final Bytes bytesKeyFrom = keyBytes(keyFrom);
-        final Bytes bytesKeyTo = keyBytes(keyTo);
+    public KeyValueIterator<Windowed<K>, V> findSessions(
+        final K keyFrom,
+        final K keyTo,
+        final long earliestSessionEndTime,
+        final long latestSessionStartTime
+    ) {
+        final Bytes bytesKeyFrom = serializeKey(keyFrom);
+        final Bytes bytesKeyTo = serializeKey(keyTo);
         return new MeteredWindowedKeyValueIterator<>(
             wrapped().findSessions(
                 bytesKeyFrom,
@@ -379,34 +390,40 @@ public class MeteredSessionStore<K, V>
                 latestSessionStartTime),
             fetchSensor,
             iteratorDurationSensor,
-            bytes -> serdes.keyFrom(bytes, new RecordHeaders()),
-            serdes::valueFrom,
+            this::deserializeKey,
+            this::deserializeValue,
             time,
             numOpenIterators,
-            openIterators);
+            openIterators
+        );
     }
 
     @Override
-    public KeyValueIterator<Windowed<K>, V> findSessions(final long 
earliestSessionEndTime,
-                                                         final long 
latestSessionEndTime) {
+    public KeyValueIterator<Windowed<K>, V> findSessions(
+        final long earliestSessionEndTime,
+        final long latestSessionEndTime
+    ) {
         return new MeteredWindowedKeyValueIterator<>(
-                wrapped().findSessions(earliestSessionEndTime, 
latestSessionEndTime),
-                fetchSensor,
-                iteratorDurationSensor,
-                bytes -> serdes.keyFrom(bytes, new RecordHeaders()),
-                serdes::valueFrom,
-                time,
-                numOpenIterators,
-                openIterators);
+            wrapped().findSessions(earliestSessionEndTime, 
latestSessionEndTime),
+            fetchSensor,
+            iteratorDurationSensor,
+            this::deserializeKey,
+            this::deserializeValue,
+            time,
+            numOpenIterators,
+            openIterators
+        );
     }
 
     @Override
-    public KeyValueIterator<Windowed<K>, V> backwardFindSessions(final K 
keyFrom,
-                                                                 final K keyTo,
-                                                                 final long 
earliestSessionEndTime,
-                                                                 final long 
latestSessionStartTime) {
-        final Bytes bytesKeyFrom = keyBytes(keyFrom);
-        final Bytes bytesKeyTo = keyBytes(keyTo);
+    public KeyValueIterator<Windowed<K>, V> backwardFindSessions(
+        final K keyFrom,
+        final K keyTo,
+        final long earliestSessionEndTime,
+        final long latestSessionStartTime
+    ) {
+        final Bytes bytesKeyFrom = serializeKey(keyFrom);
+        final Bytes bytesKeyTo = serializeKey(keyTo);
         return new MeteredWindowedKeyValueIterator<>(
             wrapped().backwardFindSessions(
                 bytesKeyFrom,
@@ -416,8 +433,8 @@ public class MeteredSessionStore<K, V>
             ),
             fetchSensor,
             iteratorDurationSensor,
-            bytes -> serdes.keyFrom(bytes, new RecordHeaders()),
-            serdes::valueFrom,
+            this::deserializeKey,
+            this::deserializeValue,
             time,
             numOpenIterators,
             openIterators
@@ -477,9 +494,7 @@ public class MeteredSessionStore<K, V>
         final WindowRangeQuery<K, V> typedQuery = (WindowRangeQuery<K, V>) 
query;
         if (typedQuery.getKey().isPresent()) {
             final WindowRangeQuery<Bytes, byte[]> rawKeyQuery =
-                WindowRangeQuery.withKey(
-                    Bytes.wrap(serdes.rawKey(typedQuery.getKey().get()))
-                );
+                
WindowRangeQuery.withKey(serializeKey(typedQuery.getKey().get()));
             final QueryResult<KeyValueIterator<Windowed<Bytes>, byte[]>> 
rawResult =
                 wrapped().query(rawKeyQuery, positionBound, config);
             if (rawResult.isSuccess()) {
@@ -488,7 +503,7 @@ public class MeteredSessionStore<K, V>
                         rawResult.getResult(),
                         fetchSensor,
                         iteratorDurationSensor,
-                        bytes -> serdes.keyFrom(bytes, new RecordHeaders()),
+                        this::deserializeKey,
                         StoreQueryUtils.deserializeValue(serdes, wrapped()),
                         time,
                         numOpenIterators,
@@ -502,7 +517,6 @@ public class MeteredSessionStore<K, V>
                 result = (QueryResult<R>) rawResult;
             }
         } else {
-
             result = QueryResult.forFailure(
                 FailureReason.UNKNOWN_QUERY_TYPE,
                 "This store (" + getClass() + ") doesn't know how to"
@@ -515,8 +529,20 @@ public class MeteredSessionStore<K, V>
         return result;
     }
 
-    private Bytes keyBytes(final K key) {
-        return key == null ? null : Bytes.wrap(serdes.rawKey(key, new 
RecordHeaders()));
+    private Bytes serializeKey(final K key) {
+        return Bytes.wrap(serdes.rawKey(key, internalContext.headers()));
+    }
+
+    private K deserializeKey(final byte[] rawKey) {
+        return serdes.keyFrom(rawKey, internalContext.headers());
+    }
+
+    protected byte[] serializeValue(final V value) {
+        return value != null ? serdes.rawValue(value, 
internalContext.headers()) : null;
+    }
+
+    protected V deserializeValue(final byte[] rawValue) {
+        return rawValue != null ? serdes.valueFrom(rawValue, 
internalContext.headers()) : null;
     }
 
     void maybeRecordE2ELatency() {
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeaders.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeaders.java
index 45a1b17ecc8..cbddac333cc 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeaders.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeaders.java
@@ -41,17 +41,18 @@ import 
org.apache.kafka.streams.state.SessionStoreWithHeaders;
 import java.util.Objects;
 
 import static 
org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.maybeMeasureLatency;
-import static org.apache.kafka.streams.state.internals.Utils.keyBytes;
 
 public class MeteredSessionStoreWithHeaders<K, AGG>
     extends MeteredSessionStore<K, AggregationWithHeaders<AGG>>
     implements SessionStoreWithHeaders<K, AGG> {
 
-    MeteredSessionStoreWithHeaders(final SessionStore<Bytes, byte[]> inner,
-                                   final String metricsScope,
-                                   final Serde<K> keySerde,
-                                   final Serde<AggregationWithHeaders<AGG>> 
aggSerde,
-                                   final Time time) {
+    MeteredSessionStoreWithHeaders(
+        final SessionStore<Bytes, byte[]> inner,
+        final String metricsScope,
+        final Serde<K> keySerde,
+        final Serde<AggregationWithHeaders<AGG>> aggSerde,
+        final Time time
+    ) {
         super(inner, metricsScope, keySerde, aggSerde, time);
     }
 
@@ -59,13 +60,22 @@ public class MeteredSessionStoreWithHeaders<K, AGG>
     @Override
     protected Serde<AggregationWithHeaders<AGG>> prepareValueSerdeForStore(
             final Serde<AggregationWithHeaders<AGG>> valueSerde,
-            final SerdeGetter getter) {
+            final SerdeGetter getter
+    ) {
         if (valueSerde == null) {
             return new AggregationWithHeadersSerde<>((Serde<AGG>) 
getter.valueSerde());
         }
         return super.prepareValueSerdeForStore(valueSerde, getter);
     }
 
+    private Bytes serializeKey(final K key, final Headers headers) {
+        return Bytes.wrap(serdes.rawKey(key, headers));
+    }
+
+    private K deserializeKey(final byte[] rawKey, final Headers headers) {
+        return serdes.keyFrom(rawKey, headers);
+    }
+
     @Override
     public void put(final Windowed<K> sessionKey, final 
AggregationWithHeaders<AGG> aggregate) {
         Objects.requireNonNull(sessionKey, "sessionKey can't be null");
@@ -89,16 +99,27 @@ public class MeteredSessionStoreWithHeaders<K, AGG>
 
                         try {
                             internalContext.setRecordContext(temporaryContext);
-                            final Bytes key = keyBytes(sessionKey, 
deleteHeaders, serdes);
-                            wrapped().put(new Windowed<>(key, 
sessionKey.window()), serdes.rawValue(null, deleteHeaders));
+                            wrapped().put(
+                                new Windowed<>(
+                                    serializeKey(sessionKey.key(), 
deleteHeaders),
+                                    sessionKey.window()
+                                ),
+                                null
+                            );
                         } finally {
                             // Restore original context
                             internalContext.setRecordContext(currentContext);
                         }
                     } else {
-                        final Headers headers = aggregate.headers();
-                        final Bytes key = keyBytes(sessionKey, headers, 
serdes);
-                        wrapped().put(new Windowed<>(key, 
sessionKey.window()), serdes.rawValue(aggregate, headers));
+                        // it's ok to only pass headers into `serializeKey`, 
because for the value case passed-in headers are
+                        // getting ignored anyway, because the value (of type 
`AggregationWithHeaders`) itself carries the headers
+                        wrapped().put(
+                            new Windowed<>(
+                                serializeKey(sessionKey.key(), 
aggregate.headers()),
+                                sessionKey.window()
+                            ),
+                            serializeValue(aggregate)
+                        );
                     }
                 },
                 time,
@@ -137,8 +158,9 @@ public class MeteredSessionStoreWithHeaders<K, AGG>
 
                     try {
                         internalContext.setRecordContext(temporaryContext);
-                        final Bytes key = keyBytes(sessionKey, deleteHeaders, 
serdes);
-                        wrapped().remove(new Windowed<>(key, 
sessionKey.window()));
+                        wrapped().remove(
+                            new Windowed<>(serializeKey(sessionKey.key(), 
deleteHeaders), sessionKey.window())
+                        );
                     } finally {
                         // Restore original context
                         internalContext.setRecordContext(currentContext);
@@ -154,30 +176,19 @@ public class MeteredSessionStoreWithHeaders<K, AGG>
 
     @SuppressWarnings("unchecked")
     @Override
-    public <R> QueryResult<R> query(final Query<R> query,
-                                    final PositionBound positionBound,
-                                    final QueryConfig config) {
-        final long start = time.nanoseconds();
+    public <R> QueryResult<R> query(
+        final Query<R> query,
+        final PositionBound positionBound,
+        final QueryConfig config
+    ) {
+        final long start = config.isCollectExecutionInfo() ? System.nanoTime() 
: -1L;
         final QueryResult<R> result;
 
         if (query instanceof WindowRangeQuery) {
-            final WindowRangeQuery<K, AGG> windowRangeQuery = 
(WindowRangeQuery<K, AGG>) query;
-            if (windowRangeQuery.getKey().isPresent()) {
-                result = runRangeQuery(query, positionBound, config);
-            } else {
-                result = QueryResult.forFailure(
-                    FailureReason.UNKNOWN_QUERY_TYPE,
-                    "This store (" + getClass() + ") doesn't know how to"
-                        + " execute the given query (" + query + ") because"
-                        + " SessionStores only support 
WindowRangeQuery.withKey."
-                        + " Contact the store maintainer if you need support"
-                        + " for a new query type."
-                );
-            }
+            result = runRangeQuery((WindowRangeQuery<K, AGG>) query, 
positionBound, config);
             if (config.isCollectExecutionInfo()) {
                 result.addExecutionInfo(
-                    "Handled in " + getClass() + " with serdes "
-                        + serdes + " in " + (time.nanoseconds() - start) + 
"ns");
+                    "Handled in " + getClass() + " with serdes " + serdes + " 
in " + (time.nanoseconds() - start) + "ns");
             }
         } else {
             result = wrapped().query(query, positionBound, config);
@@ -193,7 +204,7 @@ public class MeteredSessionStoreWithHeaders<K, AGG>
     public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>> 
fetch(final K key) {
         Objects.requireNonNull(key, "key cannot be null");
         return new MeteredSessionStoreWithHeadersIterator(
-            wrapped().fetch(keyBytes(key, new RecordHeaders(), serdes))
+            wrapped().fetch(serializeKey(key, internalContext.headers()))
         );
     }
 
@@ -201,123 +212,152 @@ public class MeteredSessionStoreWithHeaders<K, AGG>
     public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>> 
backwardFetch(final K key) {
         Objects.requireNonNull(key, "key cannot be null");
         return new MeteredSessionStoreWithHeadersIterator(
-            wrapped().backwardFetch(keyBytes(key, new RecordHeaders(), serdes))
+            wrapped().backwardFetch(serializeKey(key, 
internalContext.headers()))
         );
     }
 
     @Override
-    public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>> 
fetch(final K keyFrom,
-                                                                            
final K keyTo) {
+    public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>> fetch(
+        final K keyFrom,
+        final K keyTo
+    ) {
         return new MeteredSessionStoreWithHeadersIterator(
             wrapped().fetch(
-                keyBytes(keyFrom, new RecordHeaders(), serdes),
-                keyBytes(keyTo, new RecordHeaders(), serdes))
+                serializeKey(keyFrom, internalContext.headers()),
+                serializeKey(keyTo, internalContext.headers())
+            )
         );
     }
 
     @Override
-    public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>> 
backwardFetch(final K keyFrom,
-                                                                               
     final K keyTo) {
+    public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>> 
backwardFetch(
+        final K keyFrom,
+        final K keyTo
+    ) {
         return new MeteredSessionStoreWithHeadersIterator(
             wrapped().backwardFetch(
-                keyBytes(keyFrom, new RecordHeaders(), serdes),
-                keyBytes(keyTo, new RecordHeaders(), serdes))
+                serializeKey(keyFrom, internalContext.headers()),
+                serializeKey(keyTo, internalContext.headers())
+            )
         );
     }
 
     @Override
-    public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>> 
findSessions(final K key,
-                                                                               
    final long earliestSessionEndTime,
-                                                                               
    final long latestSessionStartTime) {
+    public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>> 
findSessions(
+        final K key,
+        final long earliestSessionEndTime,
+        final long latestSessionStartTime
+    ) {
         Objects.requireNonNull(key, "key cannot be null");
         return new MeteredSessionStoreWithHeadersIterator(
             wrapped().findSessions(
-                keyBytes(key, new RecordHeaders(), serdes),
+                serializeKey(key, internalContext.headers()),
                 earliestSessionEndTime,
-                latestSessionStartTime)
+                latestSessionStartTime
+            )
         );
     }
 
     @Override
-    public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>> 
backwardFindSessions(final K key,
-                                                                               
            final long earliestSessionEndTime,
-                                                                               
            final long latestSessionStartTime) {
+    public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>> 
backwardFindSessions(
+        final K key,
+        final long earliestSessionEndTime,
+        final long latestSessionStartTime
+    ) {
         Objects.requireNonNull(key, "key cannot be null");
         return new MeteredSessionStoreWithHeadersIterator(
             wrapped().backwardFindSessions(
-                keyBytes(key, new RecordHeaders(), serdes),
+                serializeKey(key, internalContext.headers()),
                 earliestSessionEndTime,
-                latestSessionStartTime)
+                latestSessionStartTime
+            )
         );
     }
 
     @Override
-    public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>> 
findSessions(final K keyFrom,
-                                                                               
    final K keyTo,
-                                                                               
    final long earliestSessionEndTime,
-                                                                               
    final long latestSessionStartTime) {
+    public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>> 
findSessions(
+        final K keyFrom,
+        final K keyTo,
+        final long earliestSessionEndTime,
+        final long latestSessionStartTime
+    ) {
         return new MeteredSessionStoreWithHeadersIterator(
             wrapped().findSessions(
-                keyBytes(keyFrom, new RecordHeaders(), serdes),
-                keyBytes(keyTo, new RecordHeaders(), serdes),
+                serializeKey(keyFrom, internalContext.headers()),
+                serializeKey(keyTo, internalContext.headers()),
                 earliestSessionEndTime,
-                latestSessionStartTime)
+                latestSessionStartTime
+            )
         );
     }
 
     @Override
-    public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>> 
backwardFindSessions(final K keyFrom,
-                                                                               
            final K keyTo,
-                                                                               
            final long earliestSessionEndTime,
-                                                                               
            final long latestSessionStartTime) {
+    public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>> 
backwardFindSessions(
+        final K keyFrom,
+        final K keyTo,
+        final long earliestSessionEndTime,
+        final long latestSessionStartTime
+    ) {
         return new MeteredSessionStoreWithHeadersIterator(
             wrapped().backwardFindSessions(
-                keyBytes(keyFrom, new RecordHeaders(), serdes),
-                keyBytes(keyTo, new RecordHeaders(), serdes),
+                serializeKey(keyFrom, internalContext.headers()),
+                serializeKey(keyTo, internalContext.headers()),
                 earliestSessionEndTime,
-                latestSessionStartTime)
+                latestSessionStartTime
+            )
         );
     }
 
     @Override
-    public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>> 
findSessions(final long earliestSessionEndTime,
-                                                                               
    final long latestSessionEndTime) {
-        return new MeteredSessionStoreWithHeadersIterator(
-            wrapped().findSessions(earliestSessionEndTime, 
latestSessionEndTime)
-        );
+    public KeyValueIterator<Windowed<K>, AggregationWithHeaders<AGG>> 
findSessions(
+        final long earliestSessionEndTime,
+        final long latestSessionEndTime
+    ) {
+        return new 
MeteredSessionStoreWithHeadersIterator(wrapped().findSessions(earliestSessionEndTime,
 latestSessionEndTime));
     }
 
     @SuppressWarnings("unchecked")
-    private <R> QueryResult<R> runRangeQuery(final Query<R> query,
-                                             final PositionBound positionBound,
-                                             final QueryConfig config) {
-        final WindowRangeQuery<K, AGG> typedQuery = (WindowRangeQuery<K, AGG>) 
query;
-        final WindowRangeQuery<Bytes, byte[]> rawKeyQuery =
-            WindowRangeQuery.withKey(
-                Bytes.wrap(serdes.rawKey(typedQuery.getKey().get(), new 
RecordHeaders()))
-            );
-        final QueryResult<KeyValueIterator<Windowed<Bytes>, byte[]>> rawResult 
=
-            wrapped().query(rawKeyQuery, positionBound, config);
-        if (rawResult.isSuccess()) {
-            final MeteredWindowedKeyValueIterator<K, AGG> typedResult =
-                new MeteredWindowedKeyValueIterator<>(
-                    rawResult.getResult(),
-                    fetchSensor,
-                    iteratorDurationSensor,
-                    bytes -> serdes.keyFrom(bytes, new RecordHeaders()),
-                    byteArray -> {
-                        final AggregationWithHeaders<AGG> awh =
-                            
serdes.valueDeserializer().deserialize(serdes.topic(), byteArray);
-                        return awh == null ? null : awh.aggregation();
-                    },
-                    time,
-                    numOpenIterators,
-                    openIterators
-                );
-            return (QueryResult<R>) 
InternalQueryResultUtil.copyAndSubstituteDeserializedResult(rawResult, 
typedResult);
+    private <R> QueryResult<R> runRangeQuery(
+        final WindowRangeQuery<K, AGG> query,
+        final PositionBound positionBound,
+        final QueryConfig config
+    ) {
+        final QueryResult<R> queryResult;
+
+        if (query.getKey().isPresent()) {
+            final WindowRangeQuery<Bytes, byte[]> rawKeyQuery =
+                WindowRangeQuery.withKey(serializeKey(query.getKey().get(), 
internalContext.headers()));
+            final QueryResult<KeyValueIterator<Windowed<Bytes>, byte[]>> 
rawResult =
+                wrapped().query(rawKeyQuery, positionBound, config);
+            if (rawResult.isSuccess()) {
+                final MeteredWindowedKeyValueIterator<K, AGG> typedResult =
+                    new MeteredWindowedKeyValueWithHeadersIterator<>(
+                        rawResult.getResult(),
+                        fetchSensor,
+                        iteratorDurationSensor,
+                        this::deserializeValue,
+                        this::deserializeKey,
+                        AggregationWithHeaders::headers,
+                        aggregationWithHeaders -> aggregationWithHeaders == 
null ? null : aggregationWithHeaders.aggregation(),
+                        time,
+                        numOpenIterators,
+                        openIterators
+                    );
+                queryResult = (QueryResult<R>) 
InternalQueryResultUtil.copyAndSubstituteDeserializedResult(rawResult, 
typedResult);
+            } else {
+                queryResult = (QueryResult<R>) rawResult;
+            }
         } else {
-            return (QueryResult<R>) rawResult;
+            queryResult = QueryResult.forFailure(
+                FailureReason.UNKNOWN_QUERY_TYPE,
+                "This store (" + getClass() + ") doesn't know how to"
+                    + " execute the given query (" + query + ") because"
+                    + " SessionStores only support WindowRangeQuery.withKey."
+                    + " Contact the store maintainer if you need support"
+                    + " for a new query type."
+            );
         }
+        return queryResult;
     }
 
     private class MeteredSessionStoreWithHeadersIterator
@@ -356,9 +396,9 @@ public class MeteredSessionStoreWithHeaders<K, AGG>
 
             final KeyValue<Windowed<Bytes>, byte[]> next = iter.next();
 
-            final AggregationWithHeaders<AGG> value = 
serdes.valueFrom(next.value, new RecordHeaders());
+            final AggregationWithHeaders<AGG> value = 
deserializeValue(next.value);
             final Headers headers = value != null ? value.headers() : new 
RecordHeaders();
-            final K key = serdes.keyFrom(next.key.key().get(), headers);
+            final K key = deserializeKey(next.key.key().get(), headers);
             final Windowed<K> windowedKey = new Windowed<>(key, 
next.key.window());
             return KeyValue.pair(windowedKey, value);
         }
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredTimestampedWindowStoreWithHeaders.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredTimestampedWindowStoreWithHeaders.java
index a3803316be6..5350628e290 100644
--- 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredTimestampedWindowStoreWithHeaders.java
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredTimestampedWindowStoreWithHeaders.java
@@ -18,7 +18,6 @@ package org.apache.kafka.streams.state.internals;
 
 import org.apache.kafka.common.header.Headers;
 import org.apache.kafka.common.header.internals.RecordHeaders;
-import org.apache.kafka.common.metrics.Sensor;
 import org.apache.kafka.common.serialization.Serde;
 import org.apache.kafka.common.utils.Bytes;
 import org.apache.kafka.common.utils.Time;
@@ -45,9 +44,6 @@ import org.apache.kafka.streams.state.WindowStore;
 import org.apache.kafka.streams.state.WindowStoreIterator;
 
 import java.util.Objects;
-import java.util.Set;
-import java.util.concurrent.atomic.LongAdder;
-import java.util.function.BiFunction;
 import java.util.function.Function;
 
 import static 
org.apache.kafka.streams.processor.internals.metrics.StreamsMetricsImpl.maybeMeasureLatency;
@@ -184,9 +180,7 @@ public class MeteredTimestampedWindowStoreWithHeaders<K, V>
                         }
                     );
 
-                    final 
QueryResult<MeteredWindowStoreIterator<ValueAndTimestamp<V>>> typedQueryResult =
-                            
InternalQueryResultUtil.copyAndSubstituteDeserializedResult(rawResult, 
typedResult);
-                    queryResult = (QueryResult<R>) typedQueryResult;
+                    queryResult = (QueryResult<R>) 
InternalQueryResultUtil.copyAndSubstituteDeserializedResult(rawResult, 
typedResult);
                 } else {
                     // For non-timestamped stores, return plain V
                     final MeteredWindowStoreIterator<V> typedResult = 
meteredIterator(
@@ -196,10 +190,7 @@ public class MeteredTimestampedWindowStoreWithHeaders<K, V>
                             return vth == null ? null : vth.value();
                         }
                     );
-
-                    final QueryResult<MeteredWindowStoreIterator<V>> 
typedQueryResult =
-                            
InternalQueryResultUtil.copyAndSubstituteDeserializedResult(rawResult, 
typedResult);
-                    queryResult = (QueryResult<R>) typedQueryResult;
+                    queryResult = (QueryResult<R>) 
InternalQueryResultUtil.copyAndSubstituteDeserializedResult(rawResult, 
typedResult);
                 }
             } else {
                 queryResult = (QueryResult<R>) rawResult;
@@ -420,7 +411,9 @@ public class MeteredTimestampedWindowStoreWithHeaders<K, V>
             rawResult.getResult(),
             fetchSensor,
             iteratorDurationSensor,
+            this::deserializeValue,
             this::deserializeKey,
+            ValueTimestampHeaders::headers,
             valueConverter,
             time,
             numOpenIterators,
@@ -428,51 +421,6 @@ public class MeteredTimestampedWindowStoreWithHeaders<K, V>
         );
     }
 
-    private final class MeteredWindowedKeyValueWithHeadersIterator<ValueType> 
extends MeteredWindowedKeyValueIterator<K, ValueType> {
-        private final BiFunction<byte[], Headers, K> deserializeKey;
-        private final Function<ValueTimestampHeaders<V>, ValueType> 
valueConverter;
-
-        MeteredWindowedKeyValueWithHeadersIterator(
-            final KeyValueIterator<Windowed<Bytes>, byte[]> iter,
-            final Sensor operationSensor,
-            final Sensor iteratorSensor,
-            final BiFunction<byte[], Headers, K> deserializeKey,
-            final Function<ValueTimestampHeaders<V>, ValueType> valueConverter,
-            final Time time,
-            final LongAdder numOpenIterators,
-            final Set<MeteredIterator> openIterators
-        ) {
-            super(
-                iter,
-                operationSensor,
-                iteratorSensor,
-                null, // should not be used in super-class
-                null, // should not be used in super-class
-                time,
-                numOpenIterators,
-                openIterators
-            );
-
-            this.deserializeKey = deserializeKey;
-            this.valueConverter = valueConverter;
-        }
-
-        @Override
-        public KeyValue<Windowed<K>, ValueType> next() {
-            final KeyValue<Windowed<Bytes>, byte[]> next = iter.next();
-            final ValueTimestampHeaders<V> valueTimestampHeaders = 
deserializeValue(next.value);
-            return KeyValue.pair(
-                windowedKey(next.key, valueTimestampHeaders.headers()),
-                valueConverter.apply(valueTimestampHeaders)
-            );
-        }
-
-        private Windowed<K> windowedKey(final Windowed<Bytes> bytesKey, final 
Headers headers) {
-            final K key = deserializeKey.apply(bytesKey.key().get(), headers);
-            return new Windowed<>(key, bytesKey.window());
-        }
-    }
-
     private boolean isUnderlyingStoreTimestamped() {
         StateStore store = wrapped();
         do {
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowedKeyValueWithHeadersIterator.java
 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowedKeyValueWithHeadersIterator.java
new file mode 100644
index 00000000000..d83a6ea9952
--- /dev/null
+++ 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/MeteredWindowedKeyValueWithHeadersIterator.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.kafka.streams.state.internals;
+
+import org.apache.kafka.common.header.Headers;
+import org.apache.kafka.common.metrics.Sensor;
+import org.apache.kafka.common.utils.Bytes;
+import org.apache.kafka.common.utils.Time;
+import org.apache.kafka.streams.KeyValue;
+import org.apache.kafka.streams.kstream.Windowed;
+import org.apache.kafka.streams.state.KeyValueIterator;
+
+import java.util.Set;
+import java.util.concurrent.atomic.LongAdder;
+import java.util.function.BiFunction;
+import java.util.function.Function;
+
+final class MeteredWindowedKeyValueWithHeadersIterator<K, VInner, VOuter> 
extends MeteredWindowedKeyValueIterator<K, VOuter> {
+    private final Function<byte[], VInner> deserializeValue;
+    private final BiFunction<byte[], Headers, K> deserializeKey;
+    private final Function<VInner, Headers> headersExtractor;
+    private final Function<VInner, VOuter> valueConverter;
+
+    MeteredWindowedKeyValueWithHeadersIterator(
+        final KeyValueIterator<Windowed<Bytes>, byte[]> iter,
+        final Sensor operationSensor,
+        final Sensor iteratorSensor,
+        final Function<byte[], VInner> deserializeValue,
+        final BiFunction<byte[], Headers, K> deserializeKey,
+        final Function<VInner, Headers> headersExtractor,
+        final Function<VInner, VOuter> valueConverter,
+        final Time time,
+        final LongAdder numOpenIterators,
+        final Set<MeteredIterator> openIterators
+    ) {
+        super(
+            iter,
+            operationSensor,
+            iteratorSensor,
+            null, // should not be used in super-class
+            null, // should not be used in super-class
+            time,
+            numOpenIterators,
+            openIterators
+        );
+
+        this.deserializeValue = deserializeValue;
+        this.deserializeKey = deserializeKey;
+        this.headersExtractor = headersExtractor;
+        this.valueConverter = valueConverter;
+    }
+
+    @Override
+    public KeyValue<Windowed<K>, VOuter> next() {
+        final KeyValue<Windowed<Bytes>, byte[]> next = iter.next();
+        final VInner valueTimestampHeaders = 
deserializeValue.apply(next.value);
+        return KeyValue.pair(
+            windowedKey(next.key, 
headersExtractor.apply(valueTimestampHeaders)),
+            valueConverter.apply(valueTimestampHeaders)
+        );
+    }
+
+    private Windowed<K> windowedKey(final Windowed<Bytes> bytesKey, final 
Headers headers) {
+        final K key = deserializeKey.apply(bytesKey.key().get(), headers);
+        return new Windowed<>(key, bytesKey.window());
+    }
+}
diff --git 
a/streams/src/main/java/org/apache/kafka/streams/state/internals/Utils.java 
b/streams/src/main/java/org/apache/kafka/streams/state/internals/Utils.java
index c67bfc8dfbe..e637c8eb2a2 100644
--- a/streams/src/main/java/org/apache/kafka/streams/state/internals/Utils.java
+++ b/streams/src/main/java/org/apache/kafka/streams/state/internals/Utils.java
@@ -21,8 +21,6 @@ import org.apache.kafka.common.header.Headers;
 import org.apache.kafka.common.header.internals.RecordHeaders;
 import org.apache.kafka.common.serialization.LongDeserializer;
 import org.apache.kafka.common.utils.ByteUtils;
-import org.apache.kafka.common.utils.Bytes;
-import org.apache.kafka.streams.kstream.Windowed;
 import org.apache.kafka.streams.state.StateSerdes;
 
 import java.nio.ByteBuffer;
@@ -54,38 +52,6 @@ public class Utils {
         return readHeaders(buffer);
     }
 
-    /**
-     * Serialize the key with headers into bytes
-     * @param key the key to serialize
-     * @param headers the Headers as context
-     * @param serdes the StateSerdes as serializer
-     * @return the Bytes of the key
-     */
-    public static <K> Bytes keyBytes(final K key, final Headers headers, final 
StateSerdes<K, ?> serdes) {
-        return Bytes.wrap(serdes.rawKey(key, headers));
-    }
-
-    /**
-     * Serialize the key into bytes
-     * @param key the key to serialize
-     * @param serdes the StateSerdes as serializer
-     * @return the Bytes of the key
-     */
-    static <K> Bytes keyBytes(final K key, final StateSerdes<K, ?> serdes) {
-        return keyBytes(key, new RecordHeaders(), serdes);
-    }
-
-    /**
-     * Serialize the session key with headers into bytes
-     * @param sessionKey the Windowed session key to serialize
-     * @param headers the Headers as context
-     * @param serdes the StateSerdes as serializer
-     * @return the Bytes of the key
-     */
-    static <K> Bytes keyBytes(final Windowed<K> sessionKey, final Headers 
headers, final StateSerdes<K, ?> serdes) {
-        return keyBytes(sessionKey.key(), headers, serdes);
-    }
-
     /**
      * Extract the raw aggregation bytes from serialized 
AggregationWithHeaders,
      * stripping the headers prefix.
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreTest.java
index 78b49786f75..21ae08edd89 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreTest.java
@@ -103,6 +103,7 @@ public class MeteredSessionStoreTest {
     private static final Windowed<Bytes> WINDOWED_KEY_BYTES = new 
Windowed<>(KEY_BYTES, new SessionWindow(0, 0));
     private static final String VALUE = "value";
     private static final byte[] VALUE_BYTES = VALUE.getBytes();
+    private static final Headers HEADERS = new RecordHeaders();
     private static final long START_TIMESTAMP = 24L;
     private static final long END_TIMESTAMP = 42L;
     private static final int RETENTION_PERIOD = 100;
@@ -139,8 +140,7 @@ public class MeteredSessionStoreTest {
         setUpWithoutContext();
         metrics.config().recordLevel(Sensor.RecordingLevel.DEBUG);
         when(context.applicationId()).thenReturn(APPLICATION_ID);
-        when(context.metrics())
-                .thenReturn(new StreamsMetricsImpl(metrics, "test", mockTime));
+        when(context.metrics()).thenReturn(new StreamsMetricsImpl(metrics, 
"test", mockTime));
         when(context.taskId()).thenReturn(taskId);
         when(context.changelogFor(STORE_NAME)).thenReturn(CHANGELOG_TOPIC);
         when(innerStore.name()).thenReturn(STORE_NAME);
@@ -187,12 +187,13 @@ public class MeteredSessionStoreTest {
         final Deserializer<String> valueDeserializer = 
mock(Deserializer.class);
         final Serializer<String> valueSerializer = mock(Serializer.class);
         when(keySerde.serializer()).thenReturn(keySerializer);
-        when(keySerializer.serialize(topic, new RecordHeaders(), 
KEY)).thenReturn(KEY.getBytes());
+        when(keySerializer.serialize(topic, HEADERS, 
KEY)).thenReturn(KEY.getBytes());
         when(valueSerde.deserializer()).thenReturn(valueDeserializer);
-        when(valueDeserializer.deserialize(topic, new RecordHeaders(), 
VALUE_BYTES)).thenReturn(VALUE);
+        when(valueDeserializer.deserialize(topic, HEADERS, 
VALUE_BYTES)).thenReturn(VALUE);
         when(valueSerde.serializer()).thenReturn(valueSerializer);
-        when(valueSerializer.serialize(topic, new RecordHeaders(), 
VALUE)).thenReturn(VALUE_BYTES);
+        when(valueSerializer.serialize(topic, HEADERS, 
VALUE)).thenReturn(VALUE_BYTES);
         when(innerStore.fetchSession(KEY_BYTES, START_TIMESTAMP, 
END_TIMESTAMP)).thenReturn(VALUE_BYTES);
+        when(context.headers()).thenReturn(HEADERS);
         store = new MeteredSessionStore<>(
             innerStore,
             STORE_TYPE,
diff --git 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeadersTest.java
 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeadersTest.java
index f95254237d6..aafa28d00b2 100644
--- 
a/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeadersTest.java
+++ 
b/streams/src/test/java/org/apache/kafka/streams/state/internals/MeteredSessionStoreWithHeadersTest.java
@@ -841,6 +841,8 @@ public class MeteredSessionStoreWithHeadersTest {
         lenient().when(keyDeserializer.deserialize(any(), eq(HEADERS), 
eq(KEY.getBytes())))
             .thenReturn(KEY);
 
+        when(context.headers()).thenReturn(new RecordHeaders());
+
         final MeteredSessionStoreWithHeaders<String, String> mockStore = new 
MeteredSessionStoreWithHeaders<>(
             innerStore,
             STORE_TYPE,

Reply via email to