vvivekiyer commented on code in PR #15798:
URL: https://github.com/apache/pinot/pull/15798#discussion_r2114790390


##########
pinot-core/src/main/java/org/apache/pinot/core/accounting/ResourceUsageAccountantFactory.java:
##########
@@ -0,0 +1,383 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.core.accounting;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import javax.annotation.Nullable;
+import org.apache.pinot.spi.accounting.QueryResourceTracker;
+import org.apache.pinot.spi.accounting.ThreadAccountantFactory;
+import org.apache.pinot.spi.accounting.ThreadExecutionContext;
+import org.apache.pinot.spi.accounting.ThreadResourceTracker;
+import org.apache.pinot.spi.accounting.ThreadResourceUsageProvider;
+import org.apache.pinot.spi.accounting.TrackingScope;
+import org.apache.pinot.spi.config.instance.InstanceType;
+import org.apache.pinot.spi.env.PinotConfiguration;
+import org.apache.pinot.spi.trace.Tracing;
+import org.apache.pinot.spi.utils.CommonConstants;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+public class ResourceUsageAccountantFactory implements ThreadAccountantFactory 
{
+
+  @Override
+  public ResourceUsageAccountant init(PinotConfiguration config, String 
instanceId, InstanceType instanceType) {
+    return new ResourceUsageAccountant(config, instanceId, instanceType);
+  }
+
+  public static class ResourceUsageAccountant extends 
Tracing.DefaultThreadResourceUsageAccountant {
+    private static final Logger LOGGER = 
LoggerFactory.getLogger(ResourceUsageAccountant.class);
+    private static final String ACCOUNTANT_TASK_NAME = 
"ResourceUsageAccountant";
+    private static final int ACCOUNTANT_PRIORITY = 4;
+
+    private static final ExecutorService EXECUTOR_SERVICE = 
Executors.newFixedThreadPool(1, r -> {
+      Thread thread = new Thread(r);
+      thread.setPriority(ACCOUNTANT_PRIORITY);
+      thread.setDaemon(true);
+      thread.setName(ACCOUNTANT_TASK_NAME);
+      return thread;
+    });
+
+    private final PinotConfiguration _config;
+
+    // the map to track stats entry for each thread, the entry will 
automatically be added when one calls
+    // setThreadResourceUsageProvider on the thread, including but not limited 
to
+    // server worker thread, runner thread, broker jetty thread, or broker 
netty thread
+    private final ConcurrentHashMap<Thread, 
CPUMemThreadLevelAccountingObjects.ThreadEntry> _threadEntriesMap =
+        new ConcurrentHashMap<>();
+
+    private final ThreadLocal<CPUMemThreadLevelAccountingObjects.ThreadEntry> 
_threadLocalEntry =
+        ThreadLocal.withInitial(() -> {
+          CPUMemThreadLevelAccountingObjects.ThreadEntry ret = new 
CPUMemThreadLevelAccountingObjects.ThreadEntry();
+          _threadEntriesMap.put(Thread.currentThread(), ret);
+          LOGGER.info("Adding thread to _threadLocalEntry: {}", 
Thread.currentThread().getName());
+          return ret;
+        });
+
+    // ThreadResourceUsageProvider(ThreadMXBean wrapper) per runner/worker 
thread
+    private final ThreadLocal<ThreadResourceUsageProvider> 
_threadResourceUsageProvider;
+
+    // track thread cpu time
+    private final boolean _isThreadCPUSamplingEnabled;
+
+    // track memory usage
+    private final boolean _isThreadMemorySamplingEnabled;
+
+    // is sampling allowed for MSE queries
+    private final boolean _isThreadSamplingEnabledForMSE;
+
+    // instance id of the current instance, for logging purpose
+    private final String _instanceId;
+
+    private final WatcherTask _watcherTask;
+
+    private final Map<TrackingScope, ResourceAggregator> _resourceAggregators;
+
+    private final InstanceType _instanceType;
+
+    public ResourceUsageAccountant(PinotConfiguration config, String 
instanceId, InstanceType instanceType) {
+      LOGGER.info("Initializing ResourceUsageAccountant");
+      _config = config;
+      _instanceId = instanceId;
+
+      boolean threadCpuTimeMeasurementEnabled = 
ThreadResourceUsageProvider.isThreadCpuTimeMeasurementEnabled();
+      boolean threadMemoryMeasurementEnabled = 
ThreadResourceUsageProvider.isThreadMemoryMeasurementEnabled();
+      LOGGER.info("threadCpuTimeMeasurementEnabled: {}, 
threadMemoryMeasurementEnabled: {}",
+          threadCpuTimeMeasurementEnabled, threadMemoryMeasurementEnabled);
+
+      boolean cpuSamplingConfig = 
config.getProperty(CommonConstants.Accounting.CONFIG_OF_ENABLE_THREAD_CPU_SAMPLING,
+          CommonConstants.Accounting.DEFAULT_ENABLE_THREAD_CPU_SAMPLING);
+      boolean memorySamplingConfig =
+          
config.getProperty(CommonConstants.Accounting.CONFIG_OF_ENABLE_THREAD_MEMORY_SAMPLING,
+              
CommonConstants.Accounting.DEFAULT_ENABLE_THREAD_MEMORY_SAMPLING);
+      LOGGER.info("cpuSamplingConfig: {}, memorySamplingConfig: {}", 
cpuSamplingConfig, memorySamplingConfig);
+
+      _instanceType = instanceType;
+      _isThreadCPUSamplingEnabled = cpuSamplingConfig && 
threadCpuTimeMeasurementEnabled;
+      _isThreadMemorySamplingEnabled = memorySamplingConfig && 
threadMemoryMeasurementEnabled;
+      LOGGER.info("_isThreadCPUSamplingEnabled: {}, 
_isThreadMemorySamplingEnabled: {}", _isThreadCPUSamplingEnabled,
+          _isThreadMemorySamplingEnabled);
+
+      _isThreadSamplingEnabledForMSE =
+          
config.getProperty(CommonConstants.Accounting.CONFIG_OF_ENABLE_THREAD_SAMPLING_MSE,
+              CommonConstants.Accounting.DEFAULT_ENABLE_THREAD_SAMPLING_MSE);
+      LOGGER.info("_isThreadSamplingEnabledForMSE: {}", 
_isThreadSamplingEnabledForMSE);
+
+      // ThreadMXBean wrapper
+      _threadResourceUsageProvider = new ThreadLocal<>();
+      _watcherTask = new WatcherTask();
+
+      _resourceAggregators = new HashMap<>();
+
+      // Add all aggregators. Configs of enabling/disabling cost 
collection/enforcement are handled in the aggregators.
+      _resourceAggregators.put(TrackingScope.WORKLOAD,
+          new WorkloadAggregator(_isThreadCPUSamplingEnabled, 
_isThreadMemorySamplingEnabled, _config, _instanceType,
+              _instanceId));
+      _resourceAggregators.put(TrackingScope.QUERY,
+          new QueryAggregator(_isThreadCPUSamplingEnabled, 
_isThreadMemorySamplingEnabled, _config, _instanceType,
+              _instanceId));
+    }
+
+    @Override
+    public Collection<? extends ThreadResourceTracker> getThreadResources() {
+      return _threadEntriesMap.values();
+    }
+
+    @Override
+    public void sampleUsage() {
+      sampleThreadBytesAllocated();
+      sampleThreadCPUTime();
+    }
+
+    @Override
+    public void sampleUsageMSE() {
+      if (_isThreadSamplingEnabledForMSE) {
+        sampleThreadBytesAllocated();
+        sampleThreadCPUTime();
+      }
+    }
+
+    /**
+     * for testing only
+     */
+    public int getEntryCount() {
+      return _threadEntriesMap.size();
+    }
+
+    /**
+     * This function aggregates resource usage from all active threads and 
groups by queryId.
+     * @return A map of query id, QueryResourceTracker.
+     */
+    @Override
+    public Map<String, ? extends QueryResourceTracker> getQueryResources() {
+      if (!_resourceAggregators.containsKey(TrackingScope.QUERY)) {
+        return new HashMap<>();
+      }
+
+      QueryAggregator queryAggregator = (QueryAggregator) 
_resourceAggregators.get(TrackingScope.QUERY);
+      Map<String, ? extends QueryResourceTracker> queryResources = 
queryAggregator.getQueryResources(_threadEntriesMap);
+      return queryResources;
+    }
+
+    @Override
+    public void updateResourceUsageConcurrently(String identifier, 
TrackingScope trackingScope) {
+      ResourceAggregator resourceAggregator = 
_resourceAggregators.get(trackingScope);
+      if (resourceAggregator == null) {
+        return;
+      }
+
+      if (_isThreadCPUSamplingEnabled) {
+        long cpuUsageNS = getThreadResourceUsageProvider().getThreadTimeNs();
+        resourceAggregator.updateConcurrentCpuUsage(identifier, cpuUsageNS);
+      }
+      if (_isThreadMemorySamplingEnabled) {
+        long memoryAllocatedBytes = 
getThreadResourceUsageProvider().getThreadAllocatedBytes();
+        resourceAggregator.updateConcurrentMemUsage(identifier, 
memoryAllocatedBytes);
+      }
+    }
+
+    /**
+     * The thread would need to do {@code setThreadResourceUsageProvider} 
first upon it is scheduled.
+     * This is to be called from a worker or a runner thread to update its 
corresponding cpu usage entry
+     */
+    @SuppressWarnings("ConstantConditions")
+    public void sampleThreadCPUTime() {
+      ThreadResourceUsageProvider provider = getThreadResourceUsageProvider();
+      if (_isThreadCPUSamplingEnabled && provider != null) {
+        _threadLocalEntry.get()._currentThreadCPUTimeSampleMS = 
provider.getThreadTimeNs();
+      }
+    }
+
+    /**
+     * The thread would need to do {@code setThreadResourceUsageProvider} 
first upon it is scheduled.
+     * This is to be called from a worker or a runner thread to update its 
corresponding memory usage entry
+     */
+    @SuppressWarnings("ConstantConditions")
+    public void sampleThreadBytesAllocated() {
+      ThreadResourceUsageProvider provider = getThreadResourceUsageProvider();
+      if (_isThreadMemorySamplingEnabled && provider != null) {
+        _threadLocalEntry.get()._currentThreadMemoryAllocationSampleBytes = 
provider.getThreadAllocatedBytes();
+      }
+    }
+
+    private ThreadResourceUsageProvider getThreadResourceUsageProvider() {
+      return _threadResourceUsageProvider.get();
+    }
+
+    @Override
+    public void setThreadResourceUsageProvider(ThreadResourceUsageProvider 
threadResourceUsageProvider) {
+      _threadResourceUsageProvider.set(threadResourceUsageProvider);
+    }
+
+    @Override
+    public void createExecutionContextInner(@Nullable String queryId, int 
taskId,
+        ThreadExecutionContext.TaskType taskType, @Nullable 
ThreadExecutionContext parentContext,
+        @Nullable String workloadName) {
+      _threadLocalEntry.get()._errorStatus.set(null);
+      if (parentContext == null) {
+        // is anchor thread
+        assert queryId != null;
+        _threadLocalEntry.get()
+            .setThreadTaskStatus(queryId, 
CommonConstants.Accounting.ANCHOR_TASK_ID, taskType, Thread.currentThread(),
+                workloadName);
+      } else {
+        // not anchor thread
+        _threadLocalEntry.get()
+            .setThreadTaskStatus(queryId, taskId, parentContext.getTaskType(), 
parentContext.getAnchorThread(),
+                workloadName);
+      }
+    }
+
+    @Override
+    public ThreadExecutionContext getThreadExecutionContext() {
+      return _threadLocalEntry.get().getCurrentThreadTaskStatus();
+    }
+
+    public CPUMemThreadLevelAccountingObjects.ThreadEntry getThreadEntry() {
+      return _threadLocalEntry.get();
+    }
+
+    /**
+     * clears thread accounting info once a runner/worker thread has finished 
a particular run
+     */
+    @SuppressWarnings("ConstantConditions")
+    @Override
+    public void clear() {
+      CPUMemThreadLevelAccountingObjects.ThreadEntry threadEntry = 
_threadLocalEntry.get();
+      // clear task info + stats
+      threadEntry.setToIdle();
+      // clear threadResourceUsageProvider
+      _threadResourceUsageProvider.remove();
+      // clear _anchorThread
+      super.clear();
+    }
+
+    @Override
+    public void startWatcherTask() {
+      EXECUTOR_SERVICE.submit(_watcherTask);
+    }
+
+    @Override
+    public Exception getErrorStatus() {
+      return _threadLocalEntry.get()._errorStatus.getAndSet(null);
+    }
+
+    public List<CPUMemThreadLevelAccountingObjects.ThreadEntry> 
getAnchorThreadEntries() {
+      List<CPUMemThreadLevelAccountingObjects.ThreadEntry> anchorThreadEntries 
= new ArrayList<>();
+
+      for (Map.Entry<Thread, CPUMemThreadLevelAccountingObjects.ThreadEntry> 
entry : _threadEntriesMap.entrySet()) {
+        CPUMemThreadLevelAccountingObjects.ThreadEntry threadEntry = 
entry.getValue();
+        CPUMemThreadLevelAccountingObjects.TaskEntry taskEntry = 
threadEntry.getCurrentThreadTaskStatus();
+        if (taskEntry != null && taskEntry.isAnchorThread()) {
+          anchorThreadEntries.add(threadEntry);
+        }
+      }
+      return anchorThreadEntries;
+    }
+
+
+    class WatcherTask implements Runnable {
+      WatcherTask() {
+      }
+
+      @Override
+      public void run() {
+        LOGGER.debug("Running timed task for {}", this.getClass().getName());
+        while (true) {
+          try {
+            // Preaggregation.
+            runPreAggregation();
+
+            // Aggregation
+            runAggregation();
+
+            // Postaggregation
+            runPostAggregation();
+          } catch (Exception e) {
+            LOGGER.error("Error in WatcherTask", e);
+            // TODO: Add a metric to track the number of watcher task errors.
+          } finally {
+            try {
+              LOGGER.debug("_threadEntriesMap size: {}", 
_threadEntriesMap.size());
+
+              for (ResourceAggregator resourceAggregator : 
_resourceAggregators.values()) {
+                resourceAggregator.cleanUpPostAggregation();
+              }
+              // Get sleeptime from both resourceAggregators. Pick the 
minimum. PerQuery Accountant modifies the sleep
+              // time when condition is critical.
+              int sleepTime = Integer.MAX_VALUE;
+              for (ResourceAggregator resourceAggregator : 
_resourceAggregators.values()) {
+                sleepTime = Math.min(sleepTime, 
resourceAggregator.getAggregationSleepTimeMs());
+              }

Review Comment:
   Yes, I plan to create a follow-up PR for this, as stated in the description.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org

Reply via email to