This is an automated email from the ASF dual-hosted git repository.

xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new fb02f04516 Fix Direct Memory OOM on Server (#15335)
fb02f04516 is described below

commit fb02f045169eb34aee0587c040b078fabf803dbb
Author: Chaitanya Deepthi <45308220+deepthi...@users.noreply.github.com>
AuthorDate: Sat Apr 12 11:35:38 2025 -0400

    Fix Direct Memory OOM on Server (#15335)
    
    * Direct Memory on Server OOM issue fix
    
    * Add a method in the factory class
    
    * update license and checkstyle fixes
    
    * change it to lower case
    
    * Remove channels removal from QueryServer using listener
    
    * Review comments fix
    
    * Delete DirectOOMServerHandler
    
    * Add license header
    
    * Fix checkstyle changes
    
    * checkstyle fixes
    
    * spotless fix
    
    * change the log message
    
    * remove stack trace
    
    * fixes in the arguments
    
    * Edit comment
    
    * Make variables final
    
    * checkstyle changes
    
    * checkstyle fixes
    
    * Add Server and Broker metrics and logs to track RESERVED_MEMORY
    
    * Add a common getReservedMemory() method
    
    * Remove unused imports
    
    * Remove unused imports
---
 .../apache/pinot/common/metrics/ServerMeter.java   |  2 +
 .../core/transport/ChannelHandlerFactory.java      |  7 +-
 .../pinot/core/transport/DirectOOMHandler.java     | 87 +++++++++++++++++-----
 .../PooledByteBufAllocatorWithLimits.java          | 72 ++++++++++++++++++
 .../apache/pinot/core/transport/QueryServer.java   | 13 +++-
 .../pinot/core/transport/ServerChannels.java       | 20 ++---
 6 files changed, 170 insertions(+), 31 deletions(-)

diff --git 
a/pinot-common/src/main/java/org/apache/pinot/common/metrics/ServerMeter.java 
b/pinot-common/src/main/java/org/apache/pinot/common/metrics/ServerMeter.java
index e3c94bda96..61270c4060 100644
--- 
a/pinot-common/src/main/java/org/apache/pinot/common/metrics/ServerMeter.java
+++ 
b/pinot-common/src/main/java/org/apache/pinot/common/metrics/ServerMeter.java
@@ -133,6 +133,8 @@ public enum ServerMeter implements AbstractMetrics.Meter {
   TOTAL_THREAD_CPU_TIME_MILLIS("millis", false),
   LARGE_QUERY_RESPONSE_SIZE_EXCEPTIONS("exceptions", false),
 
+  DIRECT_MEMORY_OOM("directMemoryOOMCount", true),
+
   // Multi-stage
   /**
    * Number of times the max number of rows in the hash table has been reached.
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/transport/ChannelHandlerFactory.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/transport/ChannelHandlerFactory.java
index 00545f2607..9f46cf3fae 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/transport/ChannelHandlerFactory.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/transport/ChannelHandlerFactory.java
@@ -19,6 +19,7 @@
 package org.apache.pinot.core.transport;
 
 import io.netty.channel.ChannelHandler;
+import io.netty.channel.socket.ServerSocketChannel;
 import io.netty.channel.socket.SocketChannel;
 import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
 import io.netty.handler.codec.LengthFieldPrepender;
@@ -103,7 +104,9 @@ public class ChannelHandlerFactory {
   }
 
   public static ChannelHandler getDirectOOMHandler(QueryRouter queryRouter, 
ServerRoutingInstance serverRoutingInstance,
-      ConcurrentHashMap<ServerRoutingInstance, ServerChannels.ServerChannel> 
serverToChannelMap) {
-    return new DirectOOMHandler(queryRouter, serverRoutingInstance, 
serverToChannelMap);
+      ConcurrentHashMap<ServerRoutingInstance, ServerChannels.ServerChannel> 
serverToChannelMap,
+      ConcurrentHashMap<SocketChannel, Boolean> allChannels, 
ServerSocketChannel serverSocketChannel) {
+    return new DirectOOMHandler(queryRouter, serverRoutingInstance, 
serverToChannelMap, allChannels,
+        serverSocketChannel);
   }
 }
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/transport/DirectOOMHandler.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/transport/DirectOOMHandler.java
index 0ba00327f6..13b89edd0b 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/transport/DirectOOMHandler.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/transport/DirectOOMHandler.java
@@ -20,20 +20,25 @@ package org.apache.pinot.core.transport;
 
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.socket.ServerSocketChannel;
+import io.netty.channel.socket.SocketChannel;
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.atomic.AtomicBoolean;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.pinot.common.metrics.BrokerMeter;
 import org.apache.pinot.common.metrics.BrokerMetrics;
+import org.apache.pinot.common.metrics.ServerMeter;
+import org.apache.pinot.common.metrics.ServerMetrics;
 import org.apache.pinot.spi.exception.QueryCancelledException;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+
 /**
- * Handling netty direct memory OOM. In this case there is a great chance that 
multiple channels are receiving
- * large data tables from servers concurrently. We want to close all channels 
to servers to proactively release
- * the direct memory, because the execution of netty threads can all block in 
allocating direct memory, in which case
- * no one will reach channelRead0.
+ * Handling netty direct memory OOM on broker and server. In this case there 
is a great chance that multiple channels
+ * are receiving large data tables from servers concurrently. We want to close 
all channels to servers to
+ * proactively release the direct memory, because the execution of netty 
threads can all block in allocating direct
+ * memory, in which case no one will reach channelRead0.
  */
 public class DirectOOMHandler extends ChannelInboundHandlerAdapter {
   private static final Logger LOGGER = 
LoggerFactory.getLogger(DirectOOMHandler.class);
@@ -42,12 +47,17 @@ public class DirectOOMHandler extends 
ChannelInboundHandlerAdapter {
   private final ServerRoutingInstance _serverRoutingInstance;
   private final ConcurrentHashMap<ServerRoutingInstance, 
ServerChannels.ServerChannel> _serverToChannelMap;
   private volatile boolean _silentShutDown = false;
+  private final ConcurrentHashMap<SocketChannel, Boolean> _allChannels;
+  private final ServerSocketChannel _serverSocketChannel;
 
   public DirectOOMHandler(QueryRouter queryRouter, ServerRoutingInstance 
serverRoutingInstance,
-      ConcurrentHashMap<ServerRoutingInstance, ServerChannels.ServerChannel> 
serverToChannelMap) {
+      ConcurrentHashMap<ServerRoutingInstance, ServerChannels.ServerChannel> 
serverToChannelMap,
+      ConcurrentHashMap<SocketChannel, Boolean> allChannels, 
ServerSocketChannel serverSocketChannel) {
     _queryRouter = queryRouter;
     _serverRoutingInstance = serverRoutingInstance;
     _serverToChannelMap = serverToChannelMap;
+    _allChannels = allChannels;
+    _serverSocketChannel = serverSocketChannel;
   }
 
   public void setSilentShutDown() {
@@ -63,25 +73,64 @@ public class DirectOOMHandler extends 
ChannelInboundHandlerAdapter {
     ctx.fireChannelInactive();
   }
 
+  /**
+   * Closes and removes all active channels from the map to release direct 
memory.
+   */
+  private void closeAllChannels() {
+    LOGGER.warn("OOM detected: Closing all channels to server to release 
direct memory");
+    for (SocketChannel channel : _allChannels.keySet()) {
+      try {
+        if (channel != null) {
+          LOGGER.info("Closing channel: {}", channel);
+          setSilentShutdown(channel);
+          channel.close();
+        }
+      } catch (Exception e) {
+        LOGGER.error("Error while closing channel: {}", channel, e);
+      } finally {
+        if (channel != null) {
+          _allChannels.remove(channel);
+        }
+      }
+    }
+  }
+
+  // silent shutdown for the channels without firing channelInactive
+  private void setSilentShutdown(SocketChannel socketChannel) {
+    if (socketChannel != null) {
+      DirectOOMHandler directOOMHandler = 
socketChannel.pipeline().get(DirectOOMHandler.class);
+      if (directOOMHandler != null) {
+        directOOMHandler.setSilentShutDown();
+      }
+    }
+  }
+
   @Override
-  public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) 
throws Exception {
+  public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
     // catch direct memory oom here
-    if (cause instanceof OutOfMemoryError
-        && StringUtils.containsIgnoreCase(cause.getMessage(), "direct 
buffer")) {
-      BrokerMetrics.get().addMeteredGlobalValue(BrokerMeter.DIRECT_MEMORY_OOM, 
1L);
+    if (cause instanceof OutOfMemoryError && 
StringUtils.containsIgnoreCase(cause.getMessage(), "direct buffer")) {
       // only one thread can get here and do the shutdown
       if (DIRECT_OOM_SHUTTING_DOWN.compareAndSet(false, true)) {
         try {
-          LOGGER.error("Closing ALL channels to servers, as we are running out 
of direct memory "
-              + "while receiving response from {}", _serverRoutingInstance, 
cause);
-          // close all channels to servers
-          _serverToChannelMap.keySet().forEach(serverRoutingInstance -> {
-            ServerChannels.ServerChannel removed = 
_serverToChannelMap.remove(serverRoutingInstance);
-            removed.closeChannel();
-            removed.setSilentShutdown();
-          });
-          _queryRouter.markServerDown(_serverRoutingInstance,
-              new QueryCancelledException("Query cancelled as broker is out of 
direct memory"));
+          if (_serverToChannelMap != null && !_serverToChannelMap.isEmpty()) {
+            LOGGER.error("Closing ALL channels to servers, as we are running 
out of direct memory "
+                + "while receiving response from {}", _serverRoutingInstance, 
cause); // broker side direct OOM handler
+            
BrokerMetrics.get().addMeteredGlobalValue(BrokerMeter.DIRECT_MEMORY_OOM, 1L);
+
+            // close all channels to servers
+            _serverToChannelMap.keySet().forEach(serverRoutingInstance -> {
+              ServerChannels.ServerChannel removed = 
_serverToChannelMap.remove(serverRoutingInstance);
+              removed.closeChannel();
+              removed.setSilentShutdown();
+            });
+            _queryRouter.markServerDown(_serverRoutingInstance,
+                new QueryCancelledException("Query cancelled as broker is out 
of direct memory"));
+          } else if (_allChannels != null && !_allChannels.isEmpty()) { // 
server side direct OOM handler
+            LOGGER.error("Closing channel from broker, as we are running out 
of direct memory "
+                + "while initiating request to server channel {}", 
_serverSocketChannel, cause);
+            
ServerMetrics.get().addMeteredGlobalValue(ServerMeter.DIRECT_MEMORY_OOM, 1L);
+            closeAllChannels();
+          }
         } catch (Exception e) {
           LOGGER.error("Caught exception while handling direct memory OOM", e);
         } finally {
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/transport/PooledByteBufAllocatorWithLimits.java
 
b/pinot-core/src/main/java/org/apache/pinot/core/transport/PooledByteBufAllocatorWithLimits.java
new file mode 100644
index 0000000000..54ab85dbeb
--- /dev/null
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/transport/PooledByteBufAllocatorWithLimits.java
@@ -0,0 +1,72 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.transport;
+
+import io.grpc.netty.shaded.io.netty.util.internal.SystemPropertyUtil;
+import io.netty.buffer.PooledByteBufAllocator;
+import io.netty.buffer.PooledByteBufAllocatorMetric;
+import io.netty.util.NettyRuntime;
+import io.netty.util.internal.PlatformDependent;
+import java.lang.reflect.Field;
+import java.util.concurrent.atomic.AtomicLong;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * Utility class for setting limits in the PooledByteBufAllocator.
+ */
+public class PooledByteBufAllocatorWithLimits {
+  private static final Logger LOGGER = 
LoggerFactory.getLogger(PooledByteBufAllocatorWithLimits.class);
+
+  private PooledByteBufAllocatorWithLimits() {
+  }
+
+  // Reduce the number of direct arenas when using netty channels on broker 
and server side to limit the direct
+  // memory usage
+  public static PooledByteBufAllocator 
getBufferAllocatorWithLimits(PooledByteBufAllocatorMetric metric) {
+    int defaultPageSize = 
SystemPropertyUtil.getInt("io.netty.allocator.pageSize", 8192);
+    final int defaultMinNumArena = NettyRuntime.availableProcessors() * 2;
+    int defaultMaxOrder = 
SystemPropertyUtil.getInt("io.netty.allocator.maxOrder", 9);
+    final int defaultChunkSize = defaultPageSize << defaultMaxOrder;
+    long maxDirectMemory = PlatformDependent.maxDirectMemory();
+    long remainingDirectMemory = maxDirectMemory - getReservedMemory();
+
+    int numDirectArenas = Math.max(0, 
SystemPropertyUtil.getInt("io.netty.allocator.numDirectArenas",
+        (int) Math.min(defaultMinNumArena, remainingDirectMemory / 
defaultChunkSize / 5)));
+    boolean useCacheForAllThreads = 
SystemPropertyUtil.getBoolean("io.netty.allocator.useCacheForAllThreads", 
false);
+
+    return new PooledByteBufAllocator(true, metric.numHeapArenas(), 
numDirectArenas, defaultPageSize, defaultMaxOrder,
+        metric.smallCacheSize(), metric.normalCacheSize(), 
useCacheForAllThreads);
+  }
+
+  //Get reserved direct memory allocated so far
+  private static long getReservedMemory() {
+    try {
+      Class<?> bitsClass = Class.forName("java.nio.Bits");
+      Field reservedMemoryField = 
bitsClass.getDeclaredField("RESERVED_MEMORY");
+      reservedMemoryField.setAccessible(true);
+      AtomicLong reserved = (AtomicLong) reservedMemoryField.get(null);
+      return reserved.get();
+    } catch (Exception e) {
+      LOGGER.error("Failed to get the direct reserved memory");
+      return 0;
+    }
+  }
+}
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/transport/QueryServer.java 
b/pinot-core/src/main/java/org/apache/pinot/core/transport/QueryServer.java
index 0e1c7d1aa0..f7e51c2540 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/transport/QueryServer.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/transport/QueryServer.java
@@ -36,6 +36,7 @@ import io.netty.channel.nio.NioEventLoopGroup;
 import io.netty.channel.socket.ServerSocketChannel;
 import io.netty.channel.socket.SocketChannel;
 import io.netty.channel.socket.nio.NioServerSocketChannel;
+import java.util.concurrent.ConcurrentHashMap;
 import java.util.concurrent.TimeUnit;
 import org.apache.pinot.common.config.NettyConfig;
 import org.apache.pinot.common.config.TlsConfig;
@@ -60,6 +61,8 @@ public class QueryServer {
   private final Class<? extends ServerSocketChannel> _channelClass;
   private final ChannelHandler _instanceRequestHandler;
   private ServerSocketChannel _channel;
+  private final ConcurrentHashMap<SocketChannel, Boolean> _allChannels = new 
ConcurrentHashMap<>();
+
 
   /**
    * Create an unsecured server instance
@@ -116,6 +119,9 @@ public class QueryServer {
       PooledByteBufAllocator bufAllocator = PooledByteBufAllocator.DEFAULT;
       PooledByteBufAllocatorMetric metric = bufAllocator.metric();
       ServerMetrics metrics = ServerMetrics.get();
+      PooledByteBufAllocator bufAllocatorWithLimits =
+          
PooledByteBufAllocatorWithLimits.getBufferAllocatorWithLimits(metric);
+      metric = bufAllocatorWithLimits.metric();
       
metrics.setOrUpdateGlobalGauge(ServerGauge.NETTY_POOLED_USED_DIRECT_MEMORY, 
metric::usedDirectMemory);
       
metrics.setOrUpdateGlobalGauge(ServerGauge.NETTY_POOLED_USED_HEAP_MEMORY, 
metric::usedHeapMemory);
       metrics.setOrUpdateGlobalGauge(ServerGauge.NETTY_POOLED_ARENAS_DIRECT, 
metric::numDirectArenas);
@@ -126,9 +132,14 @@ public class QueryServer {
       metrics.setOrUpdateGlobalGauge(ServerGauge.NETTY_POOLED_CHUNK_SIZE, 
metric::chunkSize);
       _channel = (ServerSocketChannel) serverBootstrap.group(_bossGroup, 
_workerGroup).channel(_channelClass)
           .option(ChannelOption.SO_BACKLOG, 
128).childOption(ChannelOption.SO_KEEPALIVE, true)
-          .option(ChannelOption.ALLOCATOR, bufAllocator).childHandler(new 
ChannelInitializer<SocketChannel>() {
+          .option(ChannelOption.ALLOCATOR, bufAllocatorWithLimits)
+          .childHandler(new ChannelInitializer<SocketChannel>() {
             @Override
             protected void initChannel(SocketChannel ch) {
+              _allChannels.put(ch, true);
+
+              ch.pipeline()
+                  .addLast(ChannelHandlerFactory.getDirectOOMHandler(null, 
null, null, _allChannels, _channel));
               if (_tlsConfig != null) {
                 // Add SSL handler first to encrypt and decrypt everything.
                 ch.pipeline()
diff --git 
a/pinot-core/src/main/java/org/apache/pinot/core/transport/ServerChannels.java 
b/pinot-core/src/main/java/org/apache/pinot/core/transport/ServerChannels.java
index b1e0d8c60b..451f4f887b 100644
--- 
a/pinot-core/src/main/java/org/apache/pinot/core/transport/ServerChannels.java
+++ 
b/pinot-core/src/main/java/org/apache/pinot/core/transport/ServerChannels.java
@@ -159,6 +159,9 @@ public class ServerChannels {
       _serverRoutingInstance = serverRoutingInstance;
       PooledByteBufAllocator bufAllocator = PooledByteBufAllocator.DEFAULT;
       PooledByteBufAllocatorMetric metric = bufAllocator.metric();
+      PooledByteBufAllocator bufAllocatorWithLimits =
+          
PooledByteBufAllocatorWithLimits.getBufferAllocatorWithLimits(metric);
+      metric = bufAllocatorWithLimits.metric();
       
_brokerMetrics.setOrUpdateGlobalGauge(BrokerGauge.NETTY_POOLED_USED_DIRECT_MEMORY,
 metric::usedDirectMemory);
       
_brokerMetrics.setOrUpdateGlobalGauge(BrokerGauge.NETTY_POOLED_USED_HEAP_MEMORY,
 metric::usedHeapMemory);
       
_brokerMetrics.setOrUpdateGlobalGauge(BrokerGauge.NETTY_POOLED_ARENAS_DIRECT, 
metric::numDirectArenas);
@@ -169,26 +172,25 @@ public class ServerChannels {
       
_brokerMetrics.setOrUpdateGlobalGauge(BrokerGauge.NETTY_POOLED_CHUNK_SIZE, 
metric::chunkSize);
 
       _bootstrap = new 
Bootstrap().remoteAddress(serverRoutingInstance.getHostname(), 
serverRoutingInstance.getPort())
-          .option(ChannelOption.ALLOCATOR, bufAllocator)
-          
.group(_eventLoopGroup).channel(_channelClass).option(ChannelOption.SO_KEEPALIVE,
 true)
-          .handler(new ChannelInitializer<SocketChannel>() {
+          .option(ChannelOption.ALLOCATOR, 
bufAllocatorWithLimits).group(_eventLoopGroup).channel(_channelClass)
+          .option(ChannelOption.SO_KEEPALIVE, true).handler(new 
ChannelInitializer<SocketChannel>() {
             @Override
             protected void initChannel(SocketChannel ch) {
               if (_tlsConfig != null) {
                 // Add SSL handler first to encrypt and decrypt everything.
-                ch.pipeline().addLast(
-                    ChannelHandlerFactory.SSL, 
ChannelHandlerFactory.getClientTlsHandler(_tlsConfig, ch));
+                ch.pipeline()
+                    .addLast(ChannelHandlerFactory.SSL, 
ChannelHandlerFactory.getClientTlsHandler(_tlsConfig, ch));
               }
 
               
ch.pipeline().addLast(ChannelHandlerFactory.getLengthFieldBasedFrameDecoder());
               
ch.pipeline().addLast(ChannelHandlerFactory.getLengthFieldPrepender());
               ch.pipeline().addLast(
-                  ChannelHandlerFactory.getDirectOOMHandler(_queryRouter, 
_serverRoutingInstance, _serverToChannelMap)
-              );
+                  ChannelHandlerFactory.getDirectOOMHandler(_queryRouter, 
_serverRoutingInstance, _serverToChannelMap,
+                      null, null));
               // NOTE: data table de-serialization happens inside this handler
               // Revisit if this becomes a bottleneck
-              ch.pipeline().addLast(ChannelHandlerFactory
-                      .getDataTableHandler(_queryRouter, 
_serverRoutingInstance, _brokerMetrics));
+              ch.pipeline().addLast(
+                  ChannelHandlerFactory.getDataTableHandler(_queryRouter, 
_serverRoutingInstance, _brokerMetrics));
             }
           });
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org

Reply via email to