This is an automated email from the ASF dual-hosted git repository.
gujiaweijoe pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/bifromq.git
The following commit(s) were added to refs/heads/main by this push:
new c07c2c00 Switch to use WCH-balance mode with sticky enabled for dist
pub rpc (#170)
c07c2c00 is described below
commit c07c2c0099f39282ba5fd30d6af8c2e91b03348d
Author: Yonny(Yu) Hao <[email protected]>
AuthorDate: Sat Aug 16 20:53:30 2025 +0800
Switch to use WCH-balance mode with sticky enabled for dist pub rpc (#170)
1. Support hash-based routing with sticky behavior for pipelined unary methods,
which is better suited to the dist pub method;
2. Replace EnhancedMarshaller with a memory-efficient implementation;
3. Make weighted round-robin (WRR) deterministic.
---
.../baserpc/client/DummyServerSelector.java | 7 +-
.../bifromq/baserpc/client/ManagedBiDiStream.java | 27 +--
.../baserpc/client/ManagedMessageStream.java | 4 +-
.../baserpc/client/ManagedRequestPipeline.java | 4 +-
.../baserpc/client/loadbalancer/HRWRouter.java | 111 +++++++++++
.../client/loadbalancer/IServerGroupRouter.java | 4 +-
.../loadbalancer/TenantAwareServerSelector.java | 58 ++++++
.../loadbalancer/WeightedServerGroupRouter.java | 21 ++-
.../baserpc/client/loadbalancer/HRWRouterTest.java | 209 +++++++++++++++++++++
.../WeightedServerGroupRouterTest.java | 39 +++-
.../io/grpc/protobuf/lite/EnhancedMarshaller.java | 131 -------------
.../java/org/apache/bifromq/baserpc/BluePrint.java | 29 +--
.../baserpc/marshaller/HLCStampedInputStream.java | 125 ++++++++++++
.../baserpc/marshaller/HLCStampedMarshaller.java | 88 +++++++++
.../marshaller/HLCStampedInputStreamTest.java | 68 +++++++
.../marshaller/HLCStampedMarshallerTest.java | 84 +++++++++
.../scheduler/BatchPubCallBuilderFactory.java | 10 +-
.../java/org/apache/bifromq/dist/RPCBluePrint.java | 4 +-
18 files changed, 844 insertions(+), 179 deletions(-)
diff --git
a/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/DummyServerSelector.java
b/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/DummyServerSelector.java
index 0f5ed20d..3e37c853 100644
---
a/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/DummyServerSelector.java
+++
b/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/DummyServerSelector.java
@@ -19,9 +19,9 @@
package org.apache.bifromq.baserpc.client;
+import java.util.Optional;
import org.apache.bifromq.baserpc.client.loadbalancer.IServerGroupRouter;
import org.apache.bifromq.baserpc.client.loadbalancer.IServerSelector;
-import java.util.Optional;
class DummyServerSelector implements IServerSelector {
public static final IServerSelector INSTANCE = new DummyServerSelector();
@@ -58,6 +58,11 @@ class DummyServerSelector implements IServerSelector {
public Optional<String> hashing(String key) {
return Optional.empty();
}
+
+ @Override
+ public Optional<String> stickyHashing(String key) {
+ return Optional.empty();
+ }
};
}
diff --git
a/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/ManagedBiDiStream.java
b/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/ManagedBiDiStream.java
index 2003dc16..27f63fcf 100644
---
a/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/ManagedBiDiStream.java
+++
b/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/ManagedBiDiStream.java
@@ -14,7 +14,7 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
- * under the License.
+ * under the License.
*/
package org.apache.bifromq.baserpc.client;
@@ -52,6 +52,7 @@ abstract class ManagedBiDiStream<InT, OutT> {
private final CompositeDisposable disposables = new CompositeDisposable();
private final String tenantId;
private final String wchKey;
+ private final boolean sticky;
private final String targetServerId;
private final Supplier<Map<String, String>> metadataSupplier;
private final Channel channel;
@@ -66,18 +67,20 @@ abstract class ManagedBiDiStream<InT, OutT> {
ManagedBiDiStream(String tenantId,
String wchKey,
String targetServerId,
- BluePrint.BalanceMode balanceMode,
+ BluePrint.MethodSemantic methodSemantic,
Supplier<Map<String, String>> metadataSupplier,
Channel channel,
CallOptions callOptions,
MethodDescriptor<InT, OutT> methodDescriptor) {
- checkArgument(balanceMode != BluePrint.BalanceMode.DDBalanced ||
targetServerId != null,
+ checkArgument(methodSemantic.mode() !=
BluePrint.BalanceMode.DDBalanced || targetServerId != null,
"targetServerId is required");
- checkArgument(balanceMode != BluePrint.BalanceMode.WCHBalanced |
wchKey != null, "wchKey is required");
+ checkArgument(methodSemantic.mode() !=
BluePrint.BalanceMode.WCHBalanced || wchKey != null,
+ "wchKey is required");
this.tenantId = tenantId;
this.wchKey = wchKey;
this.targetServerId = targetServerId;
- this.balanceMode = balanceMode;
+ this.balanceMode = methodSemantic.mode();
+ this.sticky = methodSemantic instanceof
BluePrint.HRWPipelineUnaryMethod;
this.metadataSupplier = metadataSupplier;
this.channel = channel;
this.callOptions = callOptions;
@@ -149,11 +152,11 @@ abstract class ManagedBiDiStream<InT, OutT> {
return;
}
Optional<String> currentServer = prevRouter.hashing(wchKey);
- Optional<String> newServer = router.hashing(wchKey);
+ Optional<String> newServer = sticky ?
router.stickyHashing(wchKey) : router.hashing(wchKey);
if (newServer.isEmpty()) {
// cancel current bidi-stream
synchronized (this) {
- bidiStream.get().bidiStream().cancel("no server
available");
+ bidiStream.get().bidiStream().cancel("No server
available");
}
} else if (!newServer.equals(currentServer)) {
switch (state.get()) {
@@ -179,7 +182,7 @@ abstract class ManagedBiDiStream<InT, OutT> {
if (newServer.isEmpty()) {
// cancel current bidi-stream
synchronized (this) {
- bidiStream.get().bidiStream().cancel("no server
available");
+ bidiStream.get().bidiStream().cancel("No server
available");
}
} else {
switch (state.get()) {
@@ -212,7 +215,7 @@ abstract class ManagedBiDiStream<InT, OutT> {
if (newServer.isEmpty()) {
// cancel current bidi-stream
synchronized (this) {
- bidiStream.get().bidiStream().cancel("no server
available");
+ bidiStream.get().bidiStream().cancel("No server
available");
}
} else {
switch (state.get()) {
@@ -294,7 +297,7 @@ abstract class ManagedBiDiStream<InT, OutT> {
}
case WCHBalanced -> {
IServerGroupRouter router = serverSelector.get(tenantId);
- Optional<String> selectedServer = router.hashing(wchKey);
+ Optional<String> selectedServer = sticky ?
router.stickyHashing(wchKey) : router.hashing(wchKey);
if (selectedServer.isEmpty()) {
state.set(State.NoServerAvailable);
reportServiceUnavailable();
@@ -423,13 +426,13 @@ abstract class ManagedBiDiStream<InT, OutT> {
@Override
public void cancel(String message) {
// do nothing
- managedBiDiStream.onStreamError(new
IllegalStateException("bidi-stream is not ready"));
+ managedBiDiStream.onStreamError(new IllegalStateException("Stream
is not ready"));
}
@Override
public void send(InT in) {
// do nothing
- managedBiDiStream.onStreamError(new
IllegalStateException("bidi-stream is not ready"));
+ managedBiDiStream.onStreamError(new IllegalStateException("Stream
is not ready"));
}
@Override
diff --git
a/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/ManagedMessageStream.java
b/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/ManagedMessageStream.java
index 58ea18db..2a44018c 100644
---
a/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/ManagedMessageStream.java
+++
b/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/ManagedMessageStream.java
@@ -14,7 +14,7 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
- * under the License.
+ * under the License.
*/
package org.apache.bifromq.baserpc.client;
@@ -55,7 +55,7 @@ class ManagedMessageStream<MsgT, AckT> extends
ManagedBiDiStream<AckT, MsgT>
super(tenantId,
wchKey,
targetServerId,
- bluePrint.semantic(methodDescriptor.getFullMethodName()).mode(),
+ bluePrint.semantic(methodDescriptor.getFullMethodName()),
metadataSupplier,
channelHolder.channel(),
callOptions,
diff --git
a/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/ManagedRequestPipeline.java
b/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/ManagedRequestPipeline.java
index 852646c9..dd2d5b4c 100644
---
a/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/ManagedRequestPipeline.java
+++
b/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/ManagedRequestPipeline.java
@@ -14,7 +14,7 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
- * under the License.
+ * under the License.
*/
package org.apache.bifromq.baserpc.client;
@@ -63,7 +63,7 @@ public class ManagedRequestPipeline<ReqT, RespT> extends
ManagedBiDiStream<ReqT,
super(tenantId,
wchKey,
targetServerId,
- bluePrint.semantic(methodDescriptor.getFullMethodName()).mode(),
+ bluePrint.semantic(methodDescriptor.getFullMethodName()),
metadataSupplier,
channelHolder.channel(),
callOptions,
diff --git
a/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/loadbalancer/HRWRouter.java
b/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/loadbalancer/HRWRouter.java
new file mode 100644
index 00000000..c405e190
--- /dev/null
+++
b/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/loadbalancer/HRWRouter.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.bifromq.baserpc.client.loadbalancer;
+
+import com.google.common.base.Charsets;
+import com.google.common.hash.HashCode;
+import com.google.common.hash.Hashing;
+import java.util.Collection;
+import java.util.Objects;
+
+/**
+ * Weighted rendezvous (highest-random-weight, HRW) router: for an object key, each node with a positive weight is scored -ln(U)/weight, where U is derived from hashing the object key concatenated with the node key; the lowest score wins.
+ */
+class HRWRouter<T> {
+ private static final long IEEE754_DOUBLE_1 = 0x3FF0000000000000L; // 1.0
in IEEE 754 double format
+ private final Collection<T> nodes;
+ private final KeyFunction<T> keyFunction;
+ private final WeightFunction<T> weightFunction;
+ private final HashFunction hashFunction;
+
+ HRWRouter(Collection<T> nodes, KeyFunction<T> keyFunction,
WeightFunction<T> weightFunction) {
+ this(nodes, keyFunction, weightFunction, (key) -> { // default hash function: murmur3_128 over the UTF-8 key, reduced to a long via HashCode.asLong()
+ HashCode code = Hashing.murmur3_128().hashString(key,
Charsets.UTF_8);
+ return code.asLong();
+ });
+ }
+
+ HRWRouter(Collection<T> nodes,
+ KeyFunction<T> keyFunction,
+ WeightFunction<T> weightFunction,
+ HashFunction hashFunction) {
+ this.nodes = Objects.requireNonNull(nodes); // stored by reference, not copied: routeNode sees later changes to this collection
+ this.keyFunction = Objects.requireNonNull(keyFunction);
+ this.weightFunction = Objects.requireNonNull(weightFunction);
+ this.hashFunction = Objects.requireNonNull(hashFunction);
+ }
+
+ // use the top 52 hash bits as the mantissa of a double in [1,2), subtract 1.0 to get [0,1), then clamp into [eps, 1 - eps] so -ln(u) stays finite
+ private static double hashToUnitInterval(long x) {
+ double u = Double.longBitsToDouble((x >>> 12) | IEEE754_DOUBLE_1) -
1.0;
+ final double eps = 1e-12;
+ if (u <= 0) {
+ u = eps;
+ }
+ if (u >= 1) {
+ u = 1 - eps;
+ }
+ return u;
+ }
+
+ /**
+ * Route to the node with the lowest weighted score for the given object key; on an exact score tie, the node seen first in iteration order wins.
+ *
+ * @param objectKey the key to route
+ * @return the best node, or null if there are no nodes or every node has a weight <= 0
+ */
+ T routeNode(String objectKey) {
+ if (nodes.isEmpty()) {
+ return null;
+ }
+ T bestNode = null;
+ double bestScore = Double.POSITIVE_INFINITY;
+
+ for (T n : nodes) {
+ String key = keyFunction.getKey(n);
+ int w = weightFunction.getWeight(n);
+ if (w <= 0) { // non-positive weights are never selectable
+ continue;
+ }
+ long h = hashFunction.hash64(objectKey + key); // NOTE(review): no separator between objectKey and node key, so different pairs (e.g. a+bc, ab+c) share a hash input; consider a delimiter if that correlation matters
+ // weighted rendezvous hashing: the node minimizing -ln(U)/w wins, which selects each node with probability proportional to w
+ double u = hashToUnitInterval(h);
+ double score = -Math.log(u) / (double) w;
+
+ if (score < bestScore) { // strict '<': on an exact tie the earlier node keeps the slot
+ bestScore = score;
+ bestNode = n;
+ }
+ }
+ return bestNode;
+ }
+
+ interface KeyFunction<T> { // node identity string that is hashed together with the object key
+ String getKey(T node);
+ }
+
+ interface WeightFunction<T> { // node weight; nodes with weight <= 0 are skipped by routeNode
+ int getWeight(T node);
+ }
+
+ interface HashFunction { // 64-bit hash of (objectKey + node key) used to derive U
+ long hash64(String key);
+ }
+}
\ No newline at end of file
diff --git
a/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/loadbalancer/IServerGroupRouter.java
b/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/loadbalancer/IServerGroupRouter.java
index 0796eb2b..e5c0687b 100644
---
a/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/loadbalancer/IServerGroupRouter.java
+++
b/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/loadbalancer/IServerGroupRouter.java
@@ -14,7 +14,7 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
- * under the License.
+ * under the License.
*/
package org.apache.bifromq.baserpc.client.loadbalancer;
@@ -31,4 +31,6 @@ public interface IServerGroupRouter {
Optional<String> tryRoundRobin();
Optional<String> hashing(String key);
+
+ Optional<String> stickyHashing(String key);
}
diff --git
a/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/loadbalancer/TenantAwareServerSelector.java
b/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/loadbalancer/TenantAwareServerSelector.java
new file mode 100644
index 00000000..8072f8f4
--- /dev/null
+++
b/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/loadbalancer/TenantAwareServerSelector.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.bifromq.baserpc.client.loadbalancer;
+
+import com.google.common.collect.Maps;
+import java.util.Map;
+import java.util.Set;
+import lombok.EqualsAndHashCode;
+
+@EqualsAndHashCode
+class TenantAwareServerSelector implements IServerSelector { // equals/hashCode cover the three topology maps only; the derived tenantRouter is excluded
+ private final Map<String, Boolean> allServers;
+ private final Map<String, Set<String>> serverGroupTags;
+ private final Map<String, Map<String, Integer>> trafficDirective;
+ @EqualsAndHashCode.Exclude
+ private final ITenantRouter tenantRouter;
+
+ public TenantAwareServerSelector(Map<String, Boolean> allServers,
+ Map<String, Set<String>> serverGroupTags,
+ Map<String, Map<String, Integer>>
trafficDirective) {
+ this.allServers = Maps.newHashMap(allServers); // shallow copies: top-level maps are not aliased, but nested Set/Map values stay shared with the caller
+ this.serverGroupTags = Maps.newHashMap(serverGroupTags);
+ this.trafficDirective = Maps.newHashMap(trafficDirective);
+ this.tenantRouter = new TenantRouter(this.allServers, // built from the copied maps, not the caller's originals
this.trafficDirective, this.serverGroupTags);
+ }
+
+ @Override
+ public boolean exists(String serverId) {
+ return allServers.containsKey(serverId); // true for any known server, regardless of its Boolean flag
+ }
+
+ @Override
+ public IServerGroupRouter get(String tenantId) {
+ return tenantRouter.get(tenantId); // per-tenant routing is delegated to the router built in the constructor
+ }
+
+ @Override
+ public String toString() {
+ return allServers.toString(); // only the server map is printed; group tags and traffic directives are omitted
+ }
+}
diff --git
a/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/loadbalancer/WeightedServerGroupRouter.java
b/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/loadbalancer/WeightedServerGroupRouter.java
index b6b2b8c8..c4c94c9f 100644
---
a/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/loadbalancer/WeightedServerGroupRouter.java
+++
b/base-rpc/base-rpc-client/src/main/java/org/apache/bifromq/baserpc/client/loadbalancer/WeightedServerGroupRouter.java
@@ -14,7 +14,7 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
- * under the License.
+ * under the License.
*/
package org.apache.bifromq.baserpc.client.loadbalancer;
@@ -26,12 +26,14 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
+import java.util.SortedMap;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
class WeightedServerGroupRouter implements IServerGroupRouter {
- private final Map<String, Integer> weightedServers;
+ private final SortedMap<String, Integer> weightedServers;
private final List<String> weightedServerRRSequence;
+ private final HRWRouter<String> hrwRouter;
private final WCHRouter<String> chRouter;
private final AtomicInteger rrIndex = new AtomicInteger(0);
private final Set<String> inProcServers = new HashSet<>();
@@ -39,7 +41,7 @@ class WeightedServerGroupRouter implements IServerGroupRouter
{
WeightedServerGroupRouter(Map<String, Boolean> allServers,
Map<String, Integer> groupWeights,
Map<String, Set<String>> serverGroups) {
- weightedServers = Maps.newHashMap();
+ weightedServers = Maps.newTreeMap();
for (String group : groupWeights.keySet()) {
int weight = Math.abs(groupWeights.get(group)) % 11; // weight
range: 0-10
serverGroups.getOrDefault(group,
Collections.emptySet()).forEach(serverId ->
@@ -53,6 +55,8 @@ class WeightedServerGroupRouter implements IServerGroupRouter
{
}
weightedServerRRSequence =
LBUtils.toWeightedRRSequence(weightedServers);
chRouter = new WCHRouter<>(weightedServers.keySet(), serverId ->
serverId, weightedServers::get, 100);
+ hrwRouter = new HRWRouter<>(weightedServers.keySet(), serverId ->
serverId, weightedServers::get);
+
// if inproc server is not in the weightedServers, it will be ignored
for (String serverId : weightedServers.keySet()) {
if (allServers.getOrDefault(serverId, false)) {
@@ -92,7 +96,7 @@ class WeightedServerGroupRouter implements IServerGroupRouter
{
if (!inProcServers.isEmpty()) {
return inProcServers.stream().findFirst();
} else {
- int i = rrIndex.incrementAndGet();
+ int i = rrIndex.getAndIncrement();
if (i >= size) {
int oldi = i;
i %= size;
@@ -124,4 +128,13 @@ class WeightedServerGroupRouter implements
IServerGroupRouter {
public Optional<String> hashing(String key) {
return Optional.ofNullable(chRouter.routeNode(key));
}
+
+ @Override
+ public Optional<String> stickyHashing(String key) {
+ // prefer in-proc server
+ if (!inProcServers.isEmpty()) {
+ return inProcServers.stream().findFirst();
+ }
+ return Optional.ofNullable(hrwRouter.routeNode(key));
+ }
}
diff --git
a/base-rpc/base-rpc-client/src/test/java/org/apache/bifromq/baserpc/client/loadbalancer/HRWRouterTest.java
b/base-rpc/base-rpc-client/src/test/java/org/apache/bifromq/baserpc/client/loadbalancer/HRWRouterTest.java
new file mode 100644
index 00000000..d90285ae
--- /dev/null
+++
b/base-rpc/base-rpc-client/src/test/java/org/apache/bifromq/baserpc/client/loadbalancer/HRWRouterTest.java
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.bifromq.baserpc.client.loadbalancer;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNull;
+import static org.testng.Assert.assertTrue;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import org.testng.annotations.Test;
+
+public class HRWRouterTest { // behavioral and statistical checks for HRWRouter using string nodes and map-backed weights
+
+ @Test
+ public void testEmptyNodesReturnsNull() {
+ HRWRouter<String> r = newRouter(Collections.emptyList(),
Collections.emptyMap());
+ assertNull(r.routeNode("any"));
+ }
+
+ @Test
+ public void testAllZeroWeightsReturnsNull() {
+ List<String> nodes = List.of("A", "B", "C");
+ Map<String, Integer> w = Map.of("A", 0, "B", 0, "C", 0);
+ HRWRouter<String> r = newRouter(nodes, w);
+ assertNull(r.routeNode("k1"));
+ assertNull(r.routeNode("k2"));
+ }
+
+ @Test
+ public void testNegativeAndZeroWeightsAreIgnored() {
+ List<String> nodes = List.of("A", "B", "C");
+ Map<String, Integer> w = new HashMap<>();
+ w.put("A", -1);
+ w.put("B", 0);
+ w.put("C", 3);
+ HRWRouter<String> r = newRouter(nodes, w);
+ for (int i = 0; i < 100; i++) {
+ assertEquals(r.routeNode("k" + i), "C");
+ }
+ }
+
+ @Test
+ public void testDeterministicForSameKeyAndNodes() {
+ List<String> nodes = List.of("A", "B", "C");
+ Map<String, Integer> w = Map.of("A", 1, "B", 1, "C", 1);
+ HRWRouter<String> r1 = newRouter(nodes, w);
+ HRWRouter<String> r2 = newRouter(nodes, w);
+
+ for (int i = 0; i < 1000; i++) {
+ String key = "k" + i;
+ assertEquals(r1.routeNode(key), r2.routeNode(key));
+ }
+ }
+
+ @Test
+ public void testNearUniformWhenEqualWeights() {
+ List<String> nodes = List.of("A", "B", "C", "D");
+ Map<String, Integer> w = Map.of("A", 1, "B", 1, "C", 1, "D", 1);
+ HRWRouter<String> r = newRouter(nodes, w);
+
+ int samples = 20_000;
+ Map<String, Integer> count = new HashMap<>();
+ nodes.forEach(n -> count.put(n, 0));
+
+ for (int i = 0; i < samples; i++) {
+ String key = "k" + i;
+ String owner = r.routeNode(key);
+ count.compute(owner, (k, v) -> v + 1);
+ }
+
+ double expect = samples / (double) nodes.size();
+ double tol = expect * 0.05; // tolerance: 5% of the expected per-node count
+ for (String n : nodes) {
+ assertTrue(Math.abs(count.get(n) - expect) <= tol);
+ }
+ }
+
+ @Test
+ public void testWeightedProportion() {
+ // A:2, B:1, C:1 => probabilities 0.5, 0.25, 0.25
+ List<String> nodes = List.of("A", "B", "C");
+ Map<String, Integer> w = Map.of("A", 2, "B", 1, "C", 1);
+ HRWRouter<String> r = newRouter(nodes, w);
+
+ int samples = 20_000;
+ Map<String, Integer> count = new HashMap<>();
+ nodes.forEach(n -> count.put(n, 0));
+
+ for (int i = 0; i < samples; i++) {
+ String key = "k" + i;
+ String owner = r.routeNode(key);
+ count.compute(owner, (k, v) -> v + 1);
+ }
+
+ double pA = 2.0 / 4.0;
+ double pB = 1.0 / 4.0;
+ double pC = 1.0 / 4.0;
+ double tol = samples * 0.03; // tolerance: 3% of the total sample count
+ assertTrue(Math.abs(count.get("A") - pA * samples) <= tol);
+ assertTrue(Math.abs(count.get("B") - pB * samples) <= tol);
+ assertTrue(Math.abs(count.get("C") - pC * samples) <= tol);
+ }
+
+ @Test
+ public void testIncreasingWeightIncreasesShare() {
+ List<String> nodes = List.of("A", "B");
+ Map<String, Integer> w1 = Map.of("A", 1, "B", 1);
+ Map<String, Integer> w2 = Map.of("A", 3, "B", 1);
+
+ HRWRouter<String> r1 = newRouter(nodes, w1);
+ HRWRouter<String> r2 = newRouter(nodes, w2);
+
+ int samples = 20_000;
+ int a1 = 0, a2 = 0;
+ for (int i = 0; i < samples; i++) {
+ String key = "k" + i;
+ if ("A".equals(r1.routeNode(key))) {
+ a1++;
+ }
+ if ("A".equals(r2.routeNode(key))) {
+ a2++;
+ }
+ }
+ assertTrue(a2 > a1);
+ }
+
+ @Test
+ public void testMinimalMovementOnAddNode() {
+ // N=3 -> N=4, theoretically about ~1/(N+1)=~25% of keys move to the
new node
+ List<String> nodes3 = List.of("A", "B", "C");
+ Map<String, Integer> w3 = Map.of("A", 1, "B", 1, "C", 1);
+ HRWRouter<String> r3 = newRouter(nodes3, w3);
+
+ List<String> nodes4 = List.of("A", "B", "C", "D");
+ Map<String, Integer> w4 = Map.of("A", 1, "B", 1, "C", 1, "D", 1);
+ HRWRouter<String> r4 = newRouter(nodes4, w4);
+
+ int samples = 20_000;
+ int movedToD = 0;
+ for (int i = 0; i < samples; i++) {
+ String key = "k" + i;
+ String o3 = r3.routeNode(key);
+ String o4 = r4.routeNode(key);
+ if (!Objects.equals(o3, o4) && "D".equals(o4)) {
+ movedToD++;
+ }
+ }
+ double ratio = movedToD / (double) samples; // NOTE(review): keys whose owner changed to a node other than D are not counted; asserting there are none would test minimal movement directly
+ // 0.25 ± 0.08
+ assertTrue(Math.abs(ratio - 0.25) <= 0.08,
+ "movement ratio to new node not close to 1/(N+1): " + ratio);
+ }
+
+ @Test
+ public void testMinimalMovementOnRemoveNode() {
+ // N=4 -> N=3, remove D, theoretically about ~1/(N)=~25% of keys move
to other nodes
+ List<String> nodes4 = List.of("A", "B", "C", "D");
+ Map<String, Integer> w4 = Map.of("A", 1, "B", 1, "C", 1, "D", 1);
+ HRWRouter<String> r4 = newRouter(nodes4, w4);
+
+ List<String> nodes3 = List.of("A", "B", "C");
+ Map<String, Integer> w3 = Map.of("A", 1, "B", 1, "C", 1);
+ HRWRouter<String> r3 = newRouter(nodes3, w3);
+
+ int samples = 20_000;
+ int moved = 0;
+ int nodesInD = 0;
+ for (int i = 0; i < samples; i++) {
+ String key = "k" + i;
+ String o4 = r4.routeNode(key);
+ String o3 = r3.routeNode(key);
+ if ("D".equals(o4)) {
+ nodesInD++;
+ if (!Objects.equals(o3, o4)) {
+ moved++;
+ }
+ }
+ }
+ // moved should be close to nodesInD (all keys owned by D should move)
+ double ratio = moved / (double) Math.max(1, nodesInD);
+ assertTrue(ratio > 0.95); // NOTE(review): r3 has no node D, so every D-owned key must move and ratio is exactly 1.0 when nodesInD > 0; checking that non-D keys keep their owner would be the stronger assertion
+ }
+
+ private HRWRouter<String> newRouter(Collection<String> nodes, Map<String,
Integer> weights) {
+ return new HRWRouter<>(nodes, n -> n, weights::get);
+ }
+}
\ No newline at end of file
diff --git
a/base-rpc/base-rpc-client/src/test/java/org/apache/bifromq/baserpc/client/loadbalancer/WeightedServerGroupRouterTest.java
b/base-rpc/base-rpc-client/src/test/java/org/apache/bifromq/baserpc/client/loadbalancer/WeightedServerGroupRouterTest.java
index bae437f2..73ac09ad 100644
---
a/base-rpc/base-rpc-client/src/test/java/org/apache/bifromq/baserpc/client/loadbalancer/WeightedServerGroupRouterTest.java
+++
b/base-rpc/base-rpc-client/src/test/java/org/apache/bifromq/baserpc/client/loadbalancer/WeightedServerGroupRouterTest.java
@@ -101,14 +101,13 @@ public class WeightedServerGroupRouterTest {
Optional<String> server = router.random();
assertTrue(server.isPresent());
- assertNotEquals("server4", server.get());
+ assertNotEquals(server.get(), "server4");
server = router.roundRobin();
assertTrue(server.isPresent());
- assertNotEquals("server4", server.get());
+ assertNotEquals(server.get(), "server4");
}
-
@Test
void roundRobinRouting() {
Map<String, Boolean> allServers = Map.of("server1", false, "server2",
false, "server3", false);
@@ -122,9 +121,9 @@ public class WeightedServerGroupRouterTest {
new WeightedServerGroupRouter(allServers, trafficAssignment,
groupAssignment);
// First round-robin call should return server1, then server2
- assertEquals("server1", router.roundRobin().get());
- assertEquals("server2", router.roundRobin().get());
- assertEquals("server1", router.roundRobin().get());
+ assertEquals(router.roundRobin().get(), "server1");
+ assertEquals(router.roundRobin().get(), "server2");
+ assertEquals(router.roundRobin().get(), "server1");
}
@Test
@@ -139,9 +138,9 @@ public class WeightedServerGroupRouterTest {
WeightedServerGroupRouter router =
new WeightedServerGroupRouter(allServers, trafficAssignment,
groupAssignment);
- assertEquals("server1", router.roundRobin().get());
- assertEquals("server1", router.roundRobin().get());
- assertEquals("server1", router.roundRobin().get());
+ assertEquals(router.roundRobin().get(), "server1");
+ assertEquals(router.roundRobin().get(), "server1");
+ assertEquals(router.roundRobin().get(), "server1");
}
@Test
@@ -161,4 +160,26 @@ public class WeightedServerGroupRouterTest {
assertTrue(server.isPresent());
assertTrue(Set.of("server1", "server2",
"server3").contains(server.get()));
}
+
+ @Test
+ void stickHashingRouting() {
+ Map<String, Boolean> allServers = Map.of("server1", false, "server2",
false, "server3", false);
+ Map<String, Boolean> allServersWithInProc = Map.of("server1", false,
"server2", true, "server3", false);
+ Map<String, Integer> trafficAssignment = Map.of("group1", 5);
+ Map<String, Set<String>> groupAssignment = Map.of(
+ "group1", Set.of("server1", "server2", "server3")
+ );
+ WeightedServerGroupRouter router =
+ new WeightedServerGroupRouter(allServers, trafficAssignment,
groupAssignment);
+ WeightedServerGroupRouter routerWithInProc =
+ new WeightedServerGroupRouter(allServersWithInProc,
trafficAssignment, groupAssignment);
+ String key = "myKey";
+ Optional<String> server = router.stickyHashing(key);
+ assertTrue(server.isPresent());
+ assertEquals(server.get(), "server1");
+
+ server = routerWithInProc.stickyHashing(key);
+ assertTrue(server.isPresent());
+ assertEquals(server.get(), "server2");
+ }
}
diff --git
a/base-rpc/base-rpc-common/src/main/java/io/grpc/protobuf/lite/EnhancedMarshaller.java
b/base-rpc/base-rpc-common/src/main/java/io/grpc/protobuf/lite/EnhancedMarshaller.java
deleted file mode 100644
index a75d6144..00000000
---
a/base-rpc/base-rpc-common/src/main/java/io/grpc/protobuf/lite/EnhancedMarshaller.java
+++ /dev/null
@@ -1,131 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-package io.grpc.protobuf.lite;
-
-import com.google.protobuf.CodedInputStream;
-import com.google.protobuf.ExtensionRegistryLite;
-import com.google.protobuf.Message;
-import com.google.protobuf.MessageLite;
-import com.google.protobuf.Parser;
-import com.google.protobuf.UnknownFieldSet;
-import io.grpc.KnownLength;
-import io.grpc.MethodDescriptor;
-import java.io.InputStream;
-import java.lang.ref.Reference;
-import java.lang.ref.WeakReference;
-import lombok.SneakyThrows;
-import org.apache.bifromq.basehlc.HLC;
-
-/**
- * Enhance marshaller with HLC piggybacking & aliasing enabled.
- * The reason why it's put in io.grpc.protobuf.lite package is because
ProtoInputStream is package-private.
- *
- * @param <T> the type of the message
- */
-public class EnhancedMarshaller<T> implements
MethodDescriptor.PrototypeMarshaller<T> {
- private static final int DEFAULT_MAX_MESSAGE_SIZE = 4 * 1024 * 1024;
- private static final int PIGGYBACK_FIELD_ID = Short.MAX_VALUE;
- private static final ThreadLocal<Reference<byte[]>> bufs = new
ThreadLocal<>();
- private final T defaultInstance;
- private final Parser<T> parser;
- private final ThreadLocal<UnknownFieldSet.Builder> localFieldSetBuilder =
- ThreadLocal.withInitial(UnknownFieldSet::newBuilder);
-
- private final ThreadLocal<UnknownFieldSet.Field.Builder> localFieldBuilder
=
- ThreadLocal.withInitial(UnknownFieldSet.Field::newBuilder);
-
- @SuppressWarnings("unchecked")
- public EnhancedMarshaller(T defaultInstance) {
- this.defaultInstance = defaultInstance;
- parser = (Parser<T>) ((MessageLite)
defaultInstance).getParserForType();
- }
-
- @Override
- public T getMessagePrototype() {
- return defaultInstance;
- }
-
- @SuppressWarnings("unchecked")
- @Override
- public Class<T> getMessageClass() {
- return (Class<T>) defaultInstance.getClass();
- }
-
- @SuppressWarnings("unchecked")
- @Override
- public InputStream stream(T value) {
- UnknownFieldSet.Field hlcField =
localFieldBuilder.get().clear().addFixed64(HLC.INST.get()).build();
- UnknownFieldSet fieldSet =
localFieldSetBuilder.get().addField(PIGGYBACK_FIELD_ID, hlcField).build();
- return new ProtoInputStream(((Message)
value).toBuilder().setUnknownFields(fieldSet).build(), parser);
- }
-
- @SneakyThrows
- @Override
- public T parse(InputStream stream) {
- if (stream instanceof ProtoInputStream protoStream) {
- if (protoStream.parser() == parser) {
- @SuppressWarnings("unchecked")
- T message = (T) protoStream.message();
- return message;
- }
- }
- CodedInputStream cis = null;
- if (stream instanceof KnownLength) {
- int size = stream.available();
- if (size > 0 && size <= DEFAULT_MAX_MESSAGE_SIZE) {
- Reference<byte[]> ref;
- byte[] buf;
- if ((ref = bufs.get()) == null || (buf = ref.get()) == null ||
buf.length < size) {
- buf = new byte[size];
- bufs.set(new WeakReference<>(buf));
- }
- int remaining = size;
- while (remaining > 0) {
- int position = size - remaining;
- int count = stream.read(buf, position, remaining);
- if (count == -1) {
- break;
- }
- remaining -= count;
- }
- if (remaining != 0) {
- int position = size - remaining;
- throw new IllegalStateException("Wrong size: " + size + "
!= " + position);
- }
- cis = CodedInputStream.newInstance(buf, 0, size);
- } else if (size == 0) {
- return getMessagePrototype();
- }
- }
- if (cis == null) {
- cis = CodedInputStream.newInstance(stream);
- }
- cis.setSizeLimit(Integer.MAX_VALUE);
-
- // we need aliasing to be enabled
- cis.enableAliasing(true);
-
- T message = parser.parseFrom(cis,
ExtensionRegistryLite.getEmptyRegistry());
- cis.checkLastTagWas(0);
- UnknownFieldSet.Field piggybackField = ((Message)
message).getUnknownFields().getField(PIGGYBACK_FIELD_ID);
- HLC.INST.update(piggybackField.getFixed64List().get(0));
- return message;
- }
-}
diff --git
a/base-rpc/base-rpc-common/src/main/java/org/apache/bifromq/baserpc/BluePrint.java
b/base-rpc/base-rpc-common/src/main/java/org/apache/bifromq/baserpc/BluePrint.java
index 7b8b4574..47db9504 100644
---
a/base-rpc/base-rpc-common/src/main/java/org/apache/bifromq/baserpc/BluePrint.java
+++
b/base-rpc/base-rpc-common/src/main/java/org/apache/bifromq/baserpc/BluePrint.java
@@ -14,7 +14,7 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
- * under the License.
+ * under the License.
*/
package org.apache.bifromq.baserpc;
@@ -25,7 +25,6 @@ import static java.util.Collections.unmodifiableMap;
import io.grpc.MethodDescriptor;
import io.grpc.ServiceDescriptor;
-import io.grpc.protobuf.lite.EnhancedMarshaller;
import java.util.ArrayList;
import java.util.Map;
import java.util.Set;
@@ -33,6 +32,7 @@ import java.util.function.Function;
import lombok.AccessLevel;
import lombok.Builder;
import lombok.NoArgsConstructor;
+import org.apache.bifromq.baserpc.marshaller.HLCStampedMarshaller;
/**
* BluePrint is a configuration class for a service. It contains the service
descriptor, method semantics, and method
@@ -40,7 +40,6 @@ import lombok.NoArgsConstructor;
public final class BluePrint {
private final ServiceDescriptor serviceDescriptor;
private final Map<String, MethodSemantic> methodSemantics;
- private final Map<String, MethodDescriptor<?, ?>> methods;
private final Map<String, MethodDescriptor<?, ?>> wrappedMethods;
private BluePrint(
@@ -50,7 +49,6 @@ public final class BluePrint {
Map<String, MethodDescriptor<?, ?>> wrappedMethods) {
this.serviceDescriptor = serviceDescriptor;
this.methodSemantics = methodSemantics;
- this.methods = methods;
this.wrappedMethods = wrappedMethods;
if (!serviceDescriptor.getMethods().containsAll(methods.values())) {
throw new RuntimeException("Some method is not defined in the
supplied service descriptor");
@@ -245,6 +243,18 @@ public final class BluePrint {
}
}
+    /**
+     * Pipelined-unary method semantic whose balance mode is {@code BalanceMode.WCHBalanced}. Per this change it
+     * backs hashing-based routing with sticky behavior and is assigned to the dist pub method.
+     */
+    @NoArgsConstructor(access = AccessLevel.PRIVATE)
+    public static final class HRWPipelineUnaryMethod implements PipelineUnary {
+        /**
+         * Returns a newly created semantic object (a fresh instance per call).
+         */
+        public static HRWPipelineUnaryMethod getInstance() {
+            final HRWPipelineUnaryMethod semantic = new HRWPipelineUnaryMethod();
+            return semantic;
+        }
+
+        @Override
+        public BalanceMode mode() {
+            final BalanceMode balanceMode = BalanceMode.WCHBalanced;
+            return balanceMode;
+        }
+    }
+
@NoArgsConstructor(access = AccessLevel.PRIVATE)
public static final class DDStreamingMethod implements Streaming {
public static DDStreamingMethod getInstance() {
@@ -317,10 +327,8 @@ public final class BluePrint {
this.methodSemantics.add(methodSemanticValue);
this.methods.add(methodSemanticKey);
this.wrappedMethods.add(methodSemanticKey.toBuilder()
-
.setRequestMarshaller(enhance((MethodDescriptor.PrototypeMarshaller<ReqT>)
- methodSemanticKey.getRequestMarshaller()))
-
.setResponseMarshaller(enhance((MethodDescriptor.PrototypeMarshaller<RespT>)
- methodSemanticKey.getResponseMarshaller()))
+
.setRequestMarshaller(enhance(methodSemanticKey.getRequestMarshaller()))
+
.setResponseMarshaller(enhance(methodSemanticKey.getResponseMarshaller()))
.build());
return this;
}
@@ -374,9 +382,8 @@ public final class BluePrint {
return new BluePrint(serviceDescriptor, methodSemanticMap,
methodsMap, wrappedMethods);
}
- private <T> MethodDescriptor.PrototypeMarshaller<T> enhance(
- MethodDescriptor.PrototypeMarshaller<T> marshaller) {
- return new EnhancedMarshaller<>(marshaller.getMessagePrototype());
+        /**
+         * Wraps a method's marshaller in an {@code HLCStampedMarshaller}, which stamps outgoing messages with the
+         * local HLC and feeds incoming stamps back into it.
+         *
+         * @param marshaller the original request or response marshaller
+         * @param <T> message type
+         * @return the HLC-stamping wrapper around {@code marshaller}
+         */
+        private <T> MethodDescriptor.Marshaller<T> enhance(MethodDescriptor.Marshaller<T> marshaller) {
+            final MethodDescriptor.Marshaller<T> hlcStamped = new HLCStampedMarshaller<>(marshaller);
+            return hlcStamped;
 }
}
}
diff --git
a/base-rpc/base-rpc-common/src/main/java/org/apache/bifromq/baserpc/marshaller/HLCStampedInputStream.java
b/base-rpc/base-rpc-common/src/main/java/org/apache/bifromq/baserpc/marshaller/HLCStampedInputStream.java
new file mode 100644
index 00000000..0e53443a
--- /dev/null
+++
b/base-rpc/base-rpc-common/src/main/java/org/apache/bifromq/baserpc/marshaller/HLCStampedInputStream.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.bifromq.baserpc.marshaller;
+
+import com.google.protobuf.CodedOutputStream;
+import com.google.protobuf.WireFormat;
+import io.grpc.Drainable;
+import io.grpc.KnownLength;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import org.apache.bifromq.basehlc.HLC;
+
+/**
+ * InputStream that emits an HLC stamp ahead of the bytes of a wrapped proto stream.
+ *
+ * <p>The stamp is encoded as a protobuf field ({@link #HLC_FIELD_ID}, fixed64 wire type), so a plain protobuf
+ * parser reading the combined bytes sees it as an unknown field. The HLC value is captured once, at construction.
+ */
+class HLCStampedInputStream extends InputStream implements Drainable, KnownLength {
+    public static final int HLC_FIELD_ID = Short.MAX_VALUE;
+    // Varint-encoded protobuf key for field number 32767 with wire type 1 (fixed64).
+    // Also read by HLCStampedMarshaller to recognize stamped input; must not be mutated.
+    public static final byte[] HLC_TAG = new byte[CodedOutputStream.computeTagSize(HLC_FIELD_ID)];
+    // Full stamp length: the tag followed by the 8-byte little-endian fixed64 HLC value.
+    public static final int HLC_FIELD_LENGTH = HLC_TAG.length + Long.BYTES;
+
+    static {
+        CodedOutputStream cos = CodedOutputStream.newInstance(HLC_TAG);
+        try {
+            cos.writeTag(HLC_FIELD_ID, WireFormat.WIRETYPE_FIXED64);
+            // fail fast if the computed tag size and the bytes actually written disagree
+            cos.checkNoSpaceLeft();
+        } catch (IOException e) {
+            throw new ExceptionInInitializerError(e);
+        }
+    }
+
+    private final long hlc;
+    private final InputStream protoStream;
+    // Number of stamp bytes already consumed; bounded by HLC_FIELD_LENGTH so available() stays exact.
+    private int cursor = 0;
+
+    HLCStampedInputStream(InputStream protoStream) {
+        assert protoStream instanceof Drainable;
+        assert protoStream instanceof KnownLength;
+        this.hlc = HLC.INST.get();
+        this.protoStream = protoStream;
+    }
+
+    /**
+     * Writes the unread part of the stamp followed by the rest of the proto stream.
+     *
+     * @return the number of bytes actually written to {@code target}
+     */
+    @Override
+    public int drainTo(OutputStream target) throws IOException {
+        int stampWritten = 0;
+        while (cursor < HLC_FIELD_LENGTH) {
+            target.write(stampByte(cursor++));
+            stampWritten++;
+        }
+        return stampWritten + ((Drainable) protoStream).drainTo(target);
+    }
+
+    @Override
+    public int available() throws IOException {
+        return (HLC_FIELD_LENGTH - cursor) + protoStream.available();
+    }
+
+    @Override
+    public int read() throws IOException {
+        if (cursor < HLC_FIELD_LENGTH) {
+            return stampByte(cursor++);
+        }
+        return protoStream.read();
+    }
+
+    @Override
+    public int read(byte[] b, int off, int len) throws IOException {
+        if (off < 0 || len < 0 || len > b.length - off) {
+            throw new IndexOutOfBoundsException("off=" + off + ", len=" + len + ", length=" + b.length);
+        }
+        if (len == 0) {
+            return 0;
+        }
+        int total = 0;
+        while (total < len && cursor < HLC_FIELD_LENGTH) {
+            b[off + total] = (byte) stampByte(cursor++);
+            total++;
+        }
+        if (total < len) {
+            // bulk-read the remainder straight from the proto stream
+            int n = protoStream.read(b, off + total, len - total);
+            if (n == -1) {
+                return total == 0 ? -1 : total;
+            }
+            total += n;
+        }
+        return total;
+    }
+
+    @Override
+    public long skip(long n) throws IOException {
+        if (n <= 0) {
+            return 0;
+        }
+        long fromStamp = Math.min(n, HLC_FIELD_LENGTH - cursor);
+        cursor += (int) fromStamp;
+        long skipped = fromStamp;
+        if (n > fromStamp) {
+            skipped += protoStream.skip(n - fromStamp);
+        }
+        return skipped;
+    }
+
+    @Override
+    public void close() throws IOException {
+        protoStream.close();
+    }
+
+    public InputStream protoStream() {
+        return protoStream;
+    }
+
+    // Returns the stamp byte at the given position (0 <= pos < HLC_FIELD_LENGTH) as an unsigned value.
+    private int stampByte(int pos) {
+        if (pos < HLC_TAG.length) {
+            return HLC_TAG[pos] & 0xFF;
+        }
+        // fixed64 is little-endian: lowest-order byte first
+        int shift = (pos - HLC_TAG.length) * Byte.SIZE;
+        return (int) ((hlc >>> shift) & 0xFF);
+    }
+}
diff --git
a/base-rpc/base-rpc-common/src/main/java/org/apache/bifromq/baserpc/marshaller/HLCStampedMarshaller.java
b/base-rpc/base-rpc-common/src/main/java/org/apache/bifromq/baserpc/marshaller/HLCStampedMarshaller.java
new file mode 100644
index 00000000..63b4e535
--- /dev/null
+++
b/base-rpc/base-rpc-common/src/main/java/org/apache/bifromq/baserpc/marshaller/HLCStampedMarshaller.java
@@ -0,0 +1,88 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.bifromq.baserpc.marshaller;
+
+import static
org.apache.bifromq.baserpc.marshaller.HLCStampedInputStream.HLC_FIELD_ID;
+import static
org.apache.bifromq.baserpc.marshaller.HLCStampedInputStream.HLC_TAG;
+
+import com.google.protobuf.Message;
+import com.google.protobuf.UnknownFieldSet;
+import io.grpc.MethodDescriptor;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.PushbackInputStream;
+import lombok.SneakyThrows;
+import lombok.extern.slf4j.Slf4j;
+import org.apache.bifromq.basehlc.HLC;
+
+/**
+ * Marshaller that prepends an HLC timestamp to the proto bytes produced by a delegate marshaller.
+ *
+ * <p>On parse, the stamp is stripped and passed to {@code HLC.INST.update} before the remaining bytes are handed
+ * to the delegate. Input from the obsolete EnhancedMarshaller, which carried the HLC as an unknown field, is still
+ * accepted. Input that carries no HLC at all is parsed without updating the clock.
+ */
+@Slf4j
+public class HLCStampedMarshaller<T> implements MethodDescriptor.Marshaller<T> {
+    private final MethodDescriptor.Marshaller<T> delegate;
+
+    public HLCStampedMarshaller(MethodDescriptor.Marshaller<T> delegate) {
+        this.delegate = delegate;
+    }
+
+    /**
+     * Returns the message class handled by the delegate.
+     *
+     * @return the delegate's message class if it is a {@link MethodDescriptor.ReflectableMarshaller}; otherwise
+     *     {@code Object.class}, because the concrete message type cannot be determined
+     */
+    @SuppressWarnings("unchecked")
+    public Class<T> getMessageClass() {
+        if (delegate instanceof MethodDescriptor.ReflectableMarshaller<?> reflectable) {
+            return (Class<T>) reflectable.getMessageClass();
+        }
+        return (Class<T>) (Class<?>) Object.class;
+    }
+
+    @Override
+    public InputStream stream(T value) {
+        return new HLCStampedInputStream(delegate.stream(value));
+    }
+
+    @SneakyThrows
+    @Override
+    public T parse(InputStream stream) {
+        if (stream instanceof HLCStampedInputStream stampedStream) {
+            // optimized for in-proc: hand the untouched proto stream to the delegate
+            return delegate.parse(stampedStream.protoStream());
+        }
+        PushbackInputStream pis = new PushbackInputStream(stream, HLC_TAG.length);
+        byte[] prefix = new byte[HLC_TAG.length];
+        int prefixLen = pis.readNBytes(prefix, 0, prefix.length);
+        if (!isHlcTag(prefix, prefixLen)) {
+            // push back only the bytes that were really read, so short streams are not corrupted
+            if (prefixLen > 0) {
+                pis.unread(prefix, 0, prefixLen);
+            }
+            return parseLegacy(pis);
+        }
+        long hlc = 0;
+        for (int i = 0; i < Long.BYTES; i++) {
+            int b = pis.read();
+            if (b == -1) {
+                throw new IOException("Unexpected end of stream while reading HLC");
+            }
+            // fixed64 is little-endian
+            hlc |= ((long) b & 0xFF) << (i * Byte.SIZE);
+        }
+        HLC.INST.update(hlc);
+        return delegate.parse(pis);
+    }
+
+    // Backward compatible with the obsolete EnhancedMarshaller: HLC, if any, is an unknown field on the message.
+    private T parseLegacy(InputStream stream) {
+        T message = delegate.parse(stream);
+        if (message instanceof Message protoMessage) {
+            UnknownFieldSet.Field piggybackField = protoMessage.getUnknownFields().getField(HLC_FIELD_ID);
+            if (!piggybackField.getFixed64List().isEmpty()) {
+                HLC.INST.update(piggybackField.getFixed64List().get(0));
+            }
+        }
+        return message;
+    }
+
+    // True when exactly HLC_TAG.length bytes were read and they equal HLC_TAG.
+    private static boolean isHlcTag(byte[] prefix, int prefixLen) {
+        if (prefixLen != HLC_TAG.length) {
+            return false;
+        }
+        for (int i = 0; i < HLC_TAG.length; i++) {
+            if (prefix[i] != HLC_TAG[i]) {
+                return false;
+            }
+        }
+        return true;
+    }
+}
diff --git
a/base-rpc/base-rpc-common/src/test/java/org/apache/bifromq/baserpc/marshaller/HLCStampedInputStreamTest.java
b/base-rpc/base-rpc-common/src/test/java/org/apache/bifromq/baserpc/marshaller/HLCStampedInputStreamTest.java
new file mode 100644
index 00000000..6b8260b0
--- /dev/null
+++
b/base-rpc/base-rpc-common/src/test/java/org/apache/bifromq/baserpc/marshaller/HLCStampedInputStreamTest.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.bifromq.baserpc.marshaller;
+
+import static io.grpc.protobuf.lite.ProtoLiteUtils.marshaller;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertTrue;
+
+import com.google.protobuf.Struct;
+import com.google.protobuf.UnknownFieldSet;
+import com.google.protobuf.Value;
+import io.grpc.MethodDescriptor;
+import java.io.ByteArrayOutputStream;
+import java.io.InputStream;
+import lombok.SneakyThrows;
+import org.testng.annotations.Test;
+
+/**
+ * Tests for {@link HLCStampedInputStream}: the stamped bytes must still parse as the original message with one
+ * extra fixed64 unknown field.
+ */
+public class HLCStampedInputStreamTest {
+    @SneakyThrows
+    @Test
+    public void testDrainTo() {
+        Struct orig = Struct.newBuilder()
+            .putFields("key", Value.newBuilder().setNumberValue(123).build())
+            .build();
+        MethodDescriptor.Marshaller<Struct> baseMarshaller = marshaller(Struct.getDefaultInstance());
+        InputStream stream = baseMarshaller.stream(orig);
+        HLCStampedInputStream stampedStream = new HLCStampedInputStream(stream);
+        ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+        int drained = stampedStream.drainTo(outputStream);
+        // drainTo must report exactly what it wrote
+        assertEquals(drained, outputStream.size());
+        Struct decoded = Struct.parseFrom(outputStream.toByteArray());
+        UnknownFieldSet.Field piggybackField = decoded.getUnknownFields().getField(Short.MAX_VALUE);
+        assertTrue(piggybackField != null && piggybackField.getFixed64List().size() == 1);
+        assertEquals(decoded.toBuilder().setUnknownFields(UnknownFieldSet.getDefaultInstance()).build(), orig);
+    }
+
+    @SneakyThrows
+    @Test
+    public void testRead() {
+        Struct orig = Struct.newBuilder()
+            .putFields("key", Value.newBuilder().setNumberValue(123).build())
+            .build();
+        MethodDescriptor.Marshaller<Struct> baseMarshaller = marshaller(Struct.getDefaultInstance());
+        HLCStampedInputStream stampedStream = new HLCStampedInputStream(baseMarshaller.stream(orig));
+        int expectedLength = stampedStream.available();
+        // exercises the read(byte[], int, int) path instead of drainTo
+        byte[] bytes = stampedStream.readAllBytes();
+        assertEquals(bytes.length, expectedLength);
+        Struct decoded = Struct.parseFrom(bytes);
+        UnknownFieldSet.Field piggybackField = decoded.getUnknownFields().getField(Short.MAX_VALUE);
+        assertTrue(piggybackField != null && piggybackField.getFixed64List().size() == 1);
+        assertEquals(decoded.toBuilder().setUnknownFields(UnknownFieldSet.getDefaultInstance()).build(), orig);
+    }
+
+    @SneakyThrows
+    @Test
+    public void testAvailable() {
+        Struct orig = Struct.newBuilder()
+            .putFields("key", Value.newBuilder().setNumberValue(123).build())
+            .build();
+        MethodDescriptor.Marshaller<Struct> baseMarshaller = marshaller(Struct.getDefaultInstance());
+        InputStream stream = baseMarshaller.stream(orig);
+        HLCStampedInputStream stampedStream = new HLCStampedInputStream(stream);
+        assertTrue(stampedStream.available() > 0);
+        stampedStream.drainTo(new ByteArrayOutputStream());
+        assertEquals(stampedStream.available(), 0);
+        assertEquals(stampedStream.read(), -1);
+        assertEquals(stampedStream.available(), 0);
+    }
+}
diff --git
a/base-rpc/base-rpc-common/src/test/java/org/apache/bifromq/baserpc/marshaller/HLCStampedMarshallerTest.java
b/base-rpc/base-rpc-common/src/test/java/org/apache/bifromq/baserpc/marshaller/HLCStampedMarshallerTest.java
new file mode 100644
index 00000000..57e85af9
--- /dev/null
+++
b/base-rpc/base-rpc-common/src/test/java/org/apache/bifromq/baserpc/marshaller/HLCStampedMarshallerTest.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.bifromq.baserpc.marshaller;
+
+import static io.grpc.protobuf.lite.ProtoLiteUtils.marshaller;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertTrue;
+
+import com.google.protobuf.Struct;
+import com.google.protobuf.UnknownFieldSet;
+import com.google.protobuf.Value;
+import io.grpc.MethodDescriptor;
+import java.io.InputStream;
+import lombok.SneakyThrows;
+import org.apache.bifromq.basehlc.HLC;
+import org.testng.annotations.Test;
+
+/**
+ * Tests for {@link HLCStampedMarshaller}: in-proc and wire round trips, and interoperability with plain protobuf
+ * parsers and with the legacy unknown-field HLC encoding.
+ */
+public class HLCStampedMarshallerTest {
+    @Test
+    public void testHLCStamping() {
+        Struct orig = Struct.newBuilder().putFields("key", Value.newBuilder().setNumberValue(123).build()).build();
+
+        MethodDescriptor.Marshaller<Struct> baseMarshaller = marshaller(Struct.getDefaultInstance());
+        HLCStampedMarshaller<Struct> stampedMarshaller = new HLCStampedMarshaller<>(baseMarshaller);
+
+        long before = HLC.INST.get();
+        InputStream stream = stampedMarshaller.stream(orig);
+        Struct decoded = stampedMarshaller.parse(stream);
+        long after = HLC.INST.get();
+
+        assertEquals(decoded, orig);
+
+        assertTrue(after >= before);
+    }
+
+    @SneakyThrows
+    @Test
+    public void testWireRoundTrip() {
+        Struct orig = Struct.newBuilder().putFields("key", Value.newBuilder().setNumberValue(123).build()).build();
+        HLCStampedMarshaller<Struct> stampedMarshaller =
+            new HLCStampedMarshaller<>(marshaller(Struct.getDefaultInstance()));
+        // materialize raw bytes so parse() cannot take the in-proc shortcut and must decode the stamp
+        byte[] wire = stampedMarshaller.stream(orig).readAllBytes();
+        Struct decoded = stampedMarshaller.parse(new java.io.ByteArrayInputStream(wire));
+        assertEquals(decoded, orig);
+    }
+
+    @SneakyThrows
+    @Test
+    public void testForwardCompatibility() {
+        Struct orig = Struct.newBuilder()
+            .putFields("key", Value.newBuilder().setNumberValue(123).build())
+            .build();
+        MethodDescriptor.Marshaller<Struct> baseMarshaller = marshaller(Struct.getDefaultInstance());
+        HLCStampedMarshaller<Struct> stampedMarshaller = new HLCStampedMarshaller<>(baseMarshaller);
+        InputStream stream = stampedMarshaller.stream(orig);
+        Struct decoded = Struct.parseFrom(stream);
+        UnknownFieldSet.Field piggybackField = decoded.getUnknownFields().getField(Short.MAX_VALUE);
+        assertTrue(piggybackField != null && piggybackField.getFixed64List().size() == 1);
+        assertEquals(decoded.toBuilder().setUnknownFields(UnknownFieldSet.getDefaultInstance()).build(), orig);
+    }
+
+    @Test
+    public void testBackwardCompatibility() {
+        long before = HLC.INST.get();
+        Struct orig = Struct.newBuilder()
+            .putFields("key", Value.newBuilder().setNumberValue(123).build())
+            .setUnknownFields(UnknownFieldSet.newBuilder()
+                .addField(Short.MAX_VALUE, UnknownFieldSet.Field.newBuilder()
+                    .addFixed64(before).build())
+                .build())
+            .build();
+        MethodDescriptor.Marshaller<Struct> baseMarshaller = marshaller(Struct.getDefaultInstance());
+        HLCStampedMarshaller<Struct> stampedMarshaller = new HLCStampedMarshaller<>(baseMarshaller);
+        Struct parsed = stampedMarshaller.parse(orig.toByteString().newInput());
+        long after = HLC.INST.get();
+        // the legacy encoding must decode to the very same message, unknown HLC field included
+        assertEquals(parsed, orig);
+        assertTrue(after >= before);
+    }
+}
diff --git
a/bifromq-dist/bifromq-dist-client/src/main/java/org/apache/bifromq/dist/client/scheduler/BatchPubCallBuilderFactory.java
b/bifromq-dist/bifromq-dist-client/src/main/java/org/apache/bifromq/dist/client/scheduler/BatchPubCallBuilderFactory.java
index c869f5ba..c4a14c01 100644
---
a/bifromq-dist/bifromq-dist-client/src/main/java/org/apache/bifromq/dist/client/scheduler/BatchPubCallBuilderFactory.java
+++
b/bifromq-dist/bifromq-dist-client/src/main/java/org/apache/bifromq/dist/client/scheduler/BatchPubCallBuilderFactory.java
@@ -14,13 +14,15 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
- * under the License.
+ * under the License.
*/
package org.apache.bifromq.dist.client.scheduler;
import static java.util.Collections.emptyMap;
+import java.time.Duration;
+import java.util.UUID;
import org.apache.bifromq.baserpc.client.IRPCClient;
import org.apache.bifromq.basescheduler.IBatchCall;
import org.apache.bifromq.basescheduler.IBatchCallBuilder;
@@ -30,7 +32,6 @@ import org.apache.bifromq.dist.rpc.proto.DistReply;
import org.apache.bifromq.dist.rpc.proto.DistRequest;
import org.apache.bifromq.dist.rpc.proto.DistServiceGrpc;
import org.apache.bifromq.sysprops.props.DataPlaneMaxBurstLatencyMillis;
-import java.time.Duration;
public class BatchPubCallBuilderFactory implements
IBatchCallBuilderFactory<PubRequest, PubResult, PubCallBatcherKey> {
private final IRPCClient rpcClient;
@@ -45,8 +46,9 @@ public class BatchPubCallBuilderFactory implements
IBatchCallBuilderFactory<PubR
public IBatchCallBuilder<PubRequest, PubResult, PubCallBatcherKey>
newBuilder(String name,
PubCallBatcherKey batcherKey) {
IRPCClient.IRequestPipeline<DistRequest, DistReply> ppln =
- rpcClient.createRequestPipeline(batcherKey.tenantId(), null, null,
emptyMap(),
- DistServiceGrpc.getDistMethod());
+ rpcClient.createRequestPipeline(batcherKey.tenantId(), null,
+ // using random UUID for pipeline routing key to achieve
better load balancing
+ UUID.randomUUID().toString(), emptyMap(),
DistServiceGrpc.getDistMethod());
return new IBatchCallBuilder<>() {
@Override
public IBatchCall<PubRequest, PubResult, PubCallBatcherKey>
newBatchCall() {
diff --git
a/bifromq-dist/bifromq-dist-rpc-definition/src/main/java/org/apache/bifromq/dist/RPCBluePrint.java
b/bifromq-dist/bifromq-dist-rpc-definition/src/main/java/org/apache/bifromq/dist/RPCBluePrint.java
index 36d3ec23..8ac271fe 100644
---
a/bifromq-dist/bifromq-dist-rpc-definition/src/main/java/org/apache/bifromq/dist/RPCBluePrint.java
+++
b/bifromq-dist/bifromq-dist-rpc-definition/src/main/java/org/apache/bifromq/dist/RPCBluePrint.java
@@ -14,7 +14,7 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
- * under the License.
+ * under the License.
*/
package org.apache.bifromq.dist;
@@ -27,6 +27,6 @@ public class RPCBluePrint {
.serviceDescriptor(DistServiceGrpc.getServiceDescriptor())
.methodSemantic(DistServiceGrpc.getMatchMethod(),
BluePrint.WRUnaryMethod.getInstance())
.methodSemantic(DistServiceGrpc.getUnmatchMethod(),
BluePrint.WRUnaryMethod.getInstance())
- .methodSemantic(DistServiceGrpc.getDistMethod(),
BluePrint.WRRPipelineUnaryMethod.getInstance())
+ .methodSemantic(DistServiceGrpc.getDistMethod(),
BluePrint.HRWPipelineUnaryMethod.getInstance())
.build();
}