This is an automated email from the ASF dual-hosted git repository.

xhsun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new db20a2e  [TE] Pushdown topk filter (#5495)
db20a2e is described below

commit db20a2e50660822096b962ba0120a35c082a8b3d
Author: Xiaohui Sun <xh...@linkedin.com>
AuthorDate: Wed Jun 10 14:18:11 2020 -0700

    [TE] Pushdown topk filter (#5495)
    
    * [TE] Push down top k filter to data provider
    
    * Add top-N support for the CSV and SQL data sources
    
    * [TE] Fix failing tests
    
    Co-authored-by: Xiaohui Sun <xh...@xhsun-mn1.linkedin.biz>
---
 .../pinot/thirdeye/datasource/sql/SqlUtils.java    |  22 ++--
 .../pinot/thirdeye/detection/DataProvider.java     |   4 +-
 .../thirdeye/detection/DefaultDataProvider.java    |  10 +-
 .../detection/DefaultInputDataFetcher.java         |   2 +-
 .../detection/StaticDetectionPipeline.java         |   2 +-
 .../algorithm/BaselineRuleFilterWrapper.java       |   4 +-
 .../detection/algorithm/DimensionWrapper.java      |   8 +-
 .../detection/algorithm/LegacyMergeWrapper.java    |   2 +-
 .../algorithm/ThresholdRuleFilterWrapper.java      |   2 +-
 .../thirdeye/datasource/sql/TestSqlUtils.java      | 124 +++++++++++++++++++++
 .../pinot/thirdeye/detection/DataProviderTest.java |  64 ++++++++++-
 .../pinot/thirdeye/detection/MockDataProvider.java |  12 +-
 12 files changed, 226 insertions(+), 30 deletions(-)

diff --git 
a/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/datasource/sql/SqlUtils.java
 
b/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/datasource/sql/SqlUtils.java
index 28f6272..f8b98a2 100644
--- 
a/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/datasource/sql/SqlUtils.java
+++ 
b/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/datasource/sql/SqlUtils.java
@@ -258,16 +258,17 @@ public class SqlUtils {
       sb.append(" AND ").append(dimensionWhereClause);
     }
 
-    if (limit <= 0) {
-      limit = DEFAULT_LIMIT;
-    }
-
     String groupByClause = getDimensionGroupByClause(groupBy, timeGranularity, 
dataTimeSpec);
     if (StringUtils.isNotBlank(groupByClause)) {
       sb.append(" ").append(groupByClause);
-      sb.append(" LIMIT " + limit);
     }
 
+    if (limit > 0 ){
+      sb.append(" ORDER BY " + getSelectMetricClause(metricConfig, 
metricFunction) + " DESC");
+    }
+
+    limit = limit > 0 ? limit : DEFAULT_LIMIT;
+    sb.append(" LIMIT " + limit);
     return sb.toString();
   }
 
@@ -290,12 +291,20 @@ public class SqlUtils {
       } else { //timeFormat case
         builder.append(dateTimeSpec.getColumnName()).append(", ");
       }
-  }
+    }
 
     for (String groupByKey: groupByKeys) {
       builder.append(groupByKey).append(", ");
     }
 
+    String selectMetricClause = getSelectMetricClause(metricConfig, 
metricFunction);
+    builder.append(selectMetricClause);
+
+    return builder.toString();
+  }
+
+  private static String getSelectMetricClause(MetricConfigDTO metricConfig, 
MetricFunction metricFunction) {
+    StringBuilder builder = new StringBuilder();
     String metricName = null;
     if (metricFunction.getMetricName().equals("*")) {
       metricName = "*";
@@ -303,7 +312,6 @@ public class SqlUtils {
       metricName = metricConfig.getName();
     }
     
builder.append(convertAggFunction(metricFunction.getFunctionName())).append("(").append(metricName).append(")");
-
     return builder.toString();
   }
 
diff --git 
a/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/DataProvider.java
 
b/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/DataProvider.java
index 5d1361d..3ad473c 100644
--- 
a/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/DataProvider.java
+++ 
b/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/DataProvider.java
@@ -69,9 +69,11 @@ public interface DataProvider {
    *
    * @param slices metric slices
    * @param dimensions dimensions to group by
+   * @param limit max number of records to return, ordered by metric value (descending);
+   *              no limit is applied if it is zero or negative
    * @return map of aggregation values (keyed by slice)
    */
-  Map<MetricSlice, DataFrame> fetchAggregates(Collection<MetricSlice> slices, 
List<String> dimensions);
+  Map<MetricSlice, DataFrame> fetchAggregates(Collection<MetricSlice> slices, 
List<String> dimensions, int limit);
 
   /**
    * Returns a multimap of anomalies (keyed by slice) for a given set of 
slices.
diff --git 
a/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/DefaultDataProvider.java
 
b/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/DefaultDataProvider.java
index d24fbe3..4d82342 100644
--- 
a/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/DefaultDataProvider.java
+++ 
b/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/DefaultDataProvider.java
@@ -118,16 +118,12 @@ public class DefaultDataProvider implements DataProvider {
   }
 
   @Override
-  public Map<MetricSlice, DataFrame> fetchAggregates(Collection<MetricSlice> 
slices, final List<String> dimensions) {
+  public Map<MetricSlice, DataFrame> fetchAggregates(Collection<MetricSlice> 
slices, final List<String> dimensions, int limit) {
     try {
       Map<MetricSlice, Future<DataFrame>> futures = new HashMap<>();
       for (final MetricSlice slice : slices) {
-        futures.put(slice, this.executor.submit(new Callable<DataFrame>() {
-          @Override
-          public DataFrame call() throws Exception {
-            return 
DefaultDataProvider.this.aggregationLoader.loadAggregate(slice, dimensions, -1);
-          }
-        }));
+        futures.put(slice, this.executor.submit(
+            () -> 
DefaultDataProvider.this.aggregationLoader.loadAggregate(slice, dimensions, 
limit)));
       }
 
       final long deadline = System.currentTimeMillis() + TIMEOUT;
diff --git 
a/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/DefaultInputDataFetcher.java
 
b/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/DefaultInputDataFetcher.java
index 2dd1efc..c5072a9 100644
--- 
a/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/DefaultInputDataFetcher.java
+++ 
b/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/DefaultInputDataFetcher.java
@@ -57,7 +57,7 @@ public class DefaultInputDataFetcher implements 
InputDataFetcher {
    */
   public InputData fetchData(InputDataSpec inputDataSpec) {
     Map<MetricSlice, DataFrame> timeseries = 
provider.fetchTimeseries(inputDataSpec.getTimeseriesSlices());
-    Map<MetricSlice, DataFrame> aggregates = 
provider.fetchAggregates(inputDataSpec.getAggregateSlices(), 
Collections.<String>emptyList());
+    Map<MetricSlice, DataFrame> aggregates = 
provider.fetchAggregates(inputDataSpec.getAggregateSlices(), 
Collections.<String>emptyList(), -1);
 
     Collection<AnomalySlice> slicesWithConfigId = new HashSet<>();
     for (AnomalySlice slice : inputDataSpec.getAnomalySlices()) {
diff --git 
a/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/StaticDetectionPipeline.java
 
b/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/StaticDetectionPipeline.java
index d068685..42dea97 100644
--- 
a/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/StaticDetectionPipeline.java
+++ 
b/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/StaticDetectionPipeline.java
@@ -81,7 +81,7 @@ public abstract class StaticDetectionPipeline extends 
DetectionPipeline {
   public final DetectionPipelineResult run() throws Exception {
     InputDataSpec dataSpec = this.getInputDataSpec();
     Map<MetricSlice, DataFrame> timeseries = 
this.provider.fetchTimeseries(dataSpec.getTimeseriesSlices());
-    Map<MetricSlice, DataFrame> aggregates = 
this.provider.fetchAggregates(dataSpec.getAggregateSlices(), 
Collections.<String>emptyList());
+    Map<MetricSlice, DataFrame> aggregates = 
this.provider.fetchAggregates(dataSpec.getAggregateSlices(), 
Collections.<String>emptyList(), -1);
 
     Collection<AnomalySlice> updatedSlices = new HashSet<>();
     for (AnomalySlice slice : dataSpec.getAnomalySlices()) {
diff --git 
a/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/algorithm/BaselineRuleFilterWrapper.java
 
b/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/algorithm/BaselineRuleFilterWrapper.java
index 686bd30..ea43d58 100644
--- 
a/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/algorithm/BaselineRuleFilterWrapper.java
+++ 
b/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/algorithm/BaselineRuleFilterWrapper.java
@@ -87,7 +87,7 @@ public class BaselineRuleFilterWrapper extends 
RuleBasedFilterWrapper {
         MetricSlice.from(me.getId(), anomaly.getStartTime(), 
anomaly.getEndTime(), me.getFilters());
     MetricSlice baselineSlice = this.baseline.scatter(currentSlice).get(0);
 
-    Map<MetricSlice, DataFrame> aggregates = 
this.provider.fetchAggregates(Arrays.asList(currentSlice, baselineSlice), 
Collections.<String>emptyList());
+    Map<MetricSlice, DataFrame> aggregates = 
this.provider.fetchAggregates(Arrays.asList(currentSlice, baselineSlice), 
Collections.<String>emptyList(), -1);
     double currentValue = getValueFromAggregates(currentSlice, aggregates);
     double baselineValue = getValueFromAggregates(baselineSlice, aggregates);
     if (!Double.isNaN(this.difference) && Math.abs(currentValue - 
baselineValue) < this.difference) {
@@ -102,7 +102,7 @@ public class BaselineRuleFilterWrapper extends 
RuleBasedFilterWrapper {
       MetricSlice siteWideSlice = this.baseline.scatter(
           MetricSlice.from(siteWideEntity.getId(), anomaly.getStartTime(), 
anomaly.getEndTime(), me.getFilters())).get(0);
       double siteWideBaselineValue = getValueFromAggregates(siteWideSlice,
-          this.provider.fetchAggregates(Collections.singleton(siteWideSlice), 
Collections.<String>emptyList()));
+          this.provider.fetchAggregates(Collections.singleton(siteWideSlice), 
Collections.<String>emptyList(), -1));
 
       if (siteWideBaselineValue != 0 && (Math.abs(currentValue - 
baselineValue) / siteWideBaselineValue) < this.siteWideImpactThreshold) {
         return false;
diff --git 
a/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/algorithm/DimensionWrapper.java
 
b/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/algorithm/DimensionWrapper.java
index 7507e48..d9d034d 100644
--- 
a/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/algorithm/DimensionWrapper.java
+++ 
b/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/algorithm/DimensionWrapper.java
@@ -170,7 +170,13 @@ public class DimensionWrapper extends DetectionPipeline {
       MetricEntity metric = MetricEntity.fromURN(this.metricUrn);
       MetricSlice slice = MetricSlice.from(metric.getId(), 
this.start.getMillis(), this.end.getMillis(), metric.getFilters());
 
-      DataFrame aggregates = 
this.provider.fetchAggregates(Collections.singletonList(slice), 
this.dimensions).get(slice);
+      // Push the top-k filter down to the data provider only when min contribution is not defined;
+      // otherwise a truncated result would make the contribution calculation inaccurate.
+      int limit = -1;
+      if (Double.isNaN(this.minContribution) && this.k > 0) {
+        limit = this.k;
+      }
+      DataFrame aggregates = 
this.provider.fetchAggregates(Collections.singletonList(slice), 
this.dimensions, limit).get(slice);
 
       if (aggregates.isEmpty()) {
         return nestedMetrics;
diff --git 
a/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/algorithm/LegacyMergeWrapper.java
 
b/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/algorithm/LegacyMergeWrapper.java
index 48473b7..7d3435e 100644
--- 
a/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/algorithm/LegacyMergeWrapper.java
+++ 
b/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/algorithm/LegacyMergeWrapper.java
@@ -256,7 +256,7 @@ public class LegacyMergeWrapper extends DetectionPipeline {
           if (!StringUtils.isBlank(anomalyFunctionSpec.getGlobalMetric())) {
             MetricSlice slice = makeGlobalSlice(anomalyFunctionSpec, 
mergedAnomalyResult);
 
-            double valGlobal = 
this.provider.fetchAggregates(Collections.singleton(slice), 
Collections.<String>emptyList()).get(slice).getDouble(COL_VALUE, 0);
+            double valGlobal = 
this.provider.fetchAggregates(Collections.singleton(slice), 
Collections.<String>emptyList(), -1).get(slice).getDouble(COL_VALUE, 0);
             double diffLocal = mergedAnomalyResult.getAvgCurrentVal() - 
mergedAnomalyResult.getAvgBaselineVal();
 
             mergedAnomalyResult.setImpactToGlobal(diffLocal / valGlobal);
diff --git 
a/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/algorithm/ThresholdRuleFilterWrapper.java
 
b/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/algorithm/ThresholdRuleFilterWrapper.java
index dbc907a..71676de 100644
--- 
a/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/algorithm/ThresholdRuleFilterWrapper.java
+++ 
b/thirdeye/thirdeye-pinot/src/main/java/org/apache/pinot/thirdeye/detection/algorithm/ThresholdRuleFilterWrapper.java
@@ -56,7 +56,7 @@ public class ThresholdRuleFilterWrapper extends 
RuleBasedFilterWrapper {
     MetricEntity me = MetricEntity.fromURN(anomaly.getMetricUrn());
     MetricSlice currentSlice = MetricSlice.from(me.getId(), 
anomaly.getStartTime(), anomaly.getEndTime(), me.getFilters());
 
-    Map<MetricSlice, DataFrame> aggregates = 
this.provider.fetchAggregates(Collections.singleton(currentSlice), 
Collections.<String>emptyList());
+    Map<MetricSlice, DataFrame> aggregates = 
this.provider.fetchAggregates(Collections.singleton(currentSlice), 
Collections.<String>emptyList(), -1);
     double currentValue = getValueFromAggregates(currentSlice, aggregates);
     if (!Double.isNaN(this.min) && currentValue < this.min) {
       return false;
diff --git 
a/thirdeye/thirdeye-pinot/src/test/java/org/apache/pinot/thirdeye/datasource/sql/TestSqlUtils.java
 
b/thirdeye/thirdeye-pinot/src/test/java/org/apache/pinot/thirdeye/datasource/sql/TestSqlUtils.java
new file mode 100644
index 0000000..fad8db9
--- /dev/null
+++ 
b/thirdeye/thirdeye-pinot/src/test/java/org/apache/pinot/thirdeye/datasource/sql/TestSqlUtils.java
@@ -0,0 +1,124 @@
+/**
+ * Copyright (C) 2014-2018 LinkedIn Corp. (pinot-c...@linkedin.com)
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *         http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.pinot.thirdeye.datasource.sql;
+
+import com.google.common.cache.LoadingCache;
+import com.google.common.collect.HashMultimap;
+import java.util.concurrent.TimeUnit;
+import org.apache.pinot.thirdeye.common.time.TimeGranularity;
+import org.apache.pinot.thirdeye.common.time.TimeSpec;
+import org.apache.pinot.thirdeye.constant.MetricAggFunction;
+import org.apache.pinot.thirdeye.datalayer.bao.DAOTestBase;
+import org.apache.pinot.thirdeye.datalayer.dto.DatasetConfigDTO;
+import org.apache.pinot.thirdeye.datalayer.dto.MetricConfigDTO;
+import org.apache.pinot.thirdeye.datasource.DAORegistry;
+import org.apache.pinot.thirdeye.datasource.MetricFunction;
+import org.apache.pinot.thirdeye.datasource.ThirdEyeCacheRegistry;
+import org.apache.pinot.thirdeye.datasource.ThirdEyeRequest;
+import org.apache.pinot.thirdeye.datasource.cache.MetricDataset;
+import org.joda.time.DateTime;
+import org.joda.time.DateTimeZone;
+import org.joda.time.format.DateTimeFormat;
+import org.joda.time.format.DateTimeFormatter;
+import org.mockito.Mockito;
+import org.testng.Assert;
+import org.testng.annotations.AfterMethod;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.Test;
+
+
+public class TestSqlUtils {
+
+  private final String dataset = "mysql.db.table";
+  private final String metric = "metric";
+
+  private MetricDataset metricDataset;
+  private MetricFunction metricFunction;
+  private DAOTestBase daoTestBase;
+
+  @BeforeMethod
+  public void beforeMethod() throws Exception {
+    this.daoTestBase = DAOTestBase.getInstance();
+    this.metricDataset = new MetricDataset(metric, dataset);
+
+    LoadingCache<String, DatasetConfigDTO> mockDatasetConfigCache = 
Mockito.mock(LoadingCache.class);
+    Mockito.when(mockDatasetConfigCache.get(this.dataset)).thenReturn(new 
DatasetConfigDTO());
+
+    LoadingCache<MetricDataset, MetricConfigDTO> mockMetricConfigCache = 
Mockito.mock(LoadingCache.class);
+    Mockito.when(mockMetricConfigCache.get(this.metricDataset)).thenReturn(new 
MetricConfigDTO());
+
+    
ThirdEyeCacheRegistry.getInstance().registerDatasetConfigCache(mockDatasetConfigCache);
+    
ThirdEyeCacheRegistry.getInstance().registerMetricConfigCache(mockMetricConfigCache);
+
+    MetricConfigDTO metricConfigDTO = new MetricConfigDTO();
+    metricConfigDTO.setDataset(this.dataset);
+    metricConfigDTO.setName(this.metricDataset.getMetricName());
+    metricConfigDTO.setAlias(this.metricDataset.getDataset() + "::" + 
this.metricDataset.getMetricName());
+
+    metricFunction = new MetricFunction();
+    metricFunction.setDataset(dataset);
+    metricFunction.setMetricId(1L);
+    metricFunction.setMetricName(metric);
+    metricFunction.setFunctionName(MetricAggFunction.SUM);
+
+    DAORegistry.getInstance().getMetricConfigDAO().save(metricConfigDTO);
+  }
+
+  @AfterMethod
+  public void afterMethod() {
+    try { this.daoTestBase.cleanup(); } catch (Exception ignore) {}
+  }
+
+  @Test
+  public void testSqlWithExplicitLimit() {
+    TimeGranularity timeGranularity = new TimeGranularity(1, TimeUnit.DAYS);
+    DateTimeFormatter formatter = 
DateTimeFormat.forPattern("yyyy-MM-dd").withZone(DateTimeZone.UTC);
+    ThirdEyeRequest request = ThirdEyeRequest.newBuilder()
+        .setDataSource(this.dataset)
+        .setLimit(100)
+        .setGroupBy("country")
+        .setStartTimeInclusive(DateTime.parse("2020-05-01", formatter))
+        .setEndTimeExclusive(DateTime.parse("2020-05-01", formatter))
+        .setGroupByTimeGranularity(timeGranularity)
+        .build("");
+
+    String timeFormat = TimeSpec.SINCE_EPOCH_FORMAT;
+    TimeSpec timeSpec = new TimeSpec("date", timeGranularity, timeFormat);
+    String actualSql = SqlUtils.getSql(request, this.metricFunction, 
HashMultimap.create(), timeSpec, this.dataset);
+    String expected = "SELECT date, country, SUM(metric) FROM table WHERE  
date = 18383 GROUP BY date, country ORDER BY SUM(metric) DESC LIMIT 100";
+    Assert.assertEquals(actualSql, expected);
+  }
+
+  @Test
+  public void testSqlWithoutExplicitLimit() {
+    TimeGranularity timeGranularity = new TimeGranularity(1, TimeUnit.DAYS);
+    DateTimeFormatter formatter = 
DateTimeFormat.forPattern("yyyy-MM-dd").withZone(DateTimeZone.UTC);
+    ThirdEyeRequest request = ThirdEyeRequest.newBuilder()
+        .setDataSource(this.dataset)
+        .setGroupBy("country")
+        .setStartTimeInclusive(DateTime.parse("2020-05-01", formatter))
+        .setEndTimeExclusive(DateTime.parse("2020-05-01", formatter))
+        .setGroupByTimeGranularity(timeGranularity)
+        .build("");
+
+    String timeFormat = TimeSpec.SINCE_EPOCH_FORMAT;
+    TimeSpec timeSpec = new TimeSpec("date", timeGranularity, timeFormat);
+    String actual = SqlUtils.getSql(request, this.metricFunction, 
HashMultimap.create(), timeSpec, this.dataset);
+    String expected = "SELECT date, country, SUM(metric) FROM table WHERE  
date = 18383 GROUP BY date, country LIMIT 100000";
+    Assert.assertEquals(actual, expected);
+  }
+}
diff --git 
a/thirdeye/thirdeye-pinot/src/test/java/org/apache/pinot/thirdeye/detection/DataProviderTest.java
 
b/thirdeye/thirdeye-pinot/src/test/java/org/apache/pinot/thirdeye/detection/DataProviderTest.java
index b261c95..51533f7 100644
--- 
a/thirdeye/thirdeye-pinot/src/test/java/org/apache/pinot/thirdeye/detection/DataProviderTest.java
+++ 
b/thirdeye/thirdeye-pinot/src/test/java/org/apache/pinot/thirdeye/detection/DataProviderTest.java
@@ -17,6 +17,8 @@
 
 package org.apache.pinot.thirdeye.detection;
 
+import com.google.common.cache.LoadingCache;
+import com.google.common.collect.ArrayListMultimap;
 import com.google.common.collect.HashMultimap;
 import com.google.common.collect.SetMultimap;
 import java.io.InputStreamReader;
@@ -29,9 +31,11 @@ import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.Executors;
 import java.util.concurrent.TimeUnit;
 import org.apache.pinot.thirdeye.anomaly.AnomalyType;
 import org.apache.pinot.thirdeye.dataframe.DataFrame;
+import org.apache.pinot.thirdeye.dataframe.util.MetricSlice;
 import org.apache.pinot.thirdeye.datalayer.bao.DAOTestBase;
 import org.apache.pinot.thirdeye.datalayer.bao.DatasetConfigManager;
 import org.apache.pinot.thirdeye.datalayer.bao.DetectionConfigManager;
@@ -45,13 +49,20 @@ import org.apache.pinot.thirdeye.datalayer.dto.EventDTO;
 import org.apache.pinot.thirdeye.datalayer.dto.MergedAnomalyResultDTO;
 import org.apache.pinot.thirdeye.datalayer.dto.MetricConfigDTO;
 import org.apache.pinot.thirdeye.datasource.DAORegistry;
+import org.apache.pinot.thirdeye.datasource.ThirdEyeCacheRegistry;
+import org.apache.pinot.thirdeye.datasource.ThirdEyeDataSource;
+import org.apache.pinot.thirdeye.datasource.cache.MetricDataset;
 import org.apache.pinot.thirdeye.datasource.cache.QueryCache;
+import org.apache.pinot.thirdeye.datasource.csv.CSVThirdEyeDataSource;
+import org.apache.pinot.thirdeye.datasource.loader.AggregationLoader;
+import org.apache.pinot.thirdeye.datasource.loader.DefaultAggregationLoader;
 import org.apache.pinot.thirdeye.datasource.loader.DefaultTimeSeriesLoader;
 import org.apache.pinot.thirdeye.datasource.loader.TimeSeriesLoader;
 import org.apache.pinot.thirdeye.detection.cache.builder.AnomaliesCacheBuilder;
 import 
org.apache.pinot.thirdeye.detection.cache.builder.TimeSeriesCacheBuilder;
 import org.apache.pinot.thirdeye.detection.spi.model.AnomalySlice;
 import org.apache.pinot.thirdeye.detection.spi.model.EventSlice;
+import org.mockito.Mockito;
 import org.testng.Assert;
 import org.testng.annotations.AfterClass;
 import org.testng.annotations.BeforeMethod;
@@ -69,8 +80,9 @@ public class DataProviderTest {
   private DatasetConfigManager datasetDAO;
   private EvaluationManager evaluationDAO;
   private DetectionConfigManager detectionDAO;
-  private QueryCache cache;
+  private QueryCache queryCache;
   private TimeSeriesLoader timeseriesLoader;
+  private AggregationLoader aggregationLoader;
 
   private DataFrame data;
 
@@ -82,6 +94,8 @@ public class DataProviderTest {
   private List<Long> datasetIds;
   private List<Long> detectionIds;
 
+  private static final MetricDataset METRIC = new MetricDataset("metric", 
"collection1");
+
   @BeforeMethod
   public void beforeMethod() throws Exception {
     this.testBase = DAOTestBase.getInstance();
@@ -139,12 +153,47 @@ public class DataProviderTest {
       this.data.addSeries(COL_TIME, 
this.data.getLongs(COL_TIME).multiply(1000));
     }
 
-    // loaders
-    this.timeseriesLoader = new DefaultTimeSeriesLoader(this.metricDAO, 
this.datasetDAO, this.cache, null);
+    // register caches
+
+    LoadingCache<String, DatasetConfigDTO> mockDatasetConfigCache = 
Mockito.mock(LoadingCache.class);
+    DatasetConfigDTO datasetConfig = 
this.datasetDAO.findByDataset("myDataset2");
+    
Mockito.when(mockDatasetConfigCache.get("myDataset2")).thenReturn(datasetConfig);
+
+
+    LoadingCache<String, Long> mockDatasetMaxDataTimeCache = 
Mockito.mock(LoadingCache.class);
+    Mockito.when(mockDatasetMaxDataTimeCache.get("myDataset2"))
+        .thenReturn(Long.MAX_VALUE);
+
+    MetricDataset metricDataset = new MetricDataset("myMetric2", "myDataset2");
+    LoadingCache<MetricDataset, MetricConfigDTO> mockMetricConfigCache = 
Mockito.mock(LoadingCache.class);
+    MetricConfigDTO metricConfig = 
this.metricDAO.findByMetricAndDataset("myMetric2", "myDataset2");
+    
Mockito.when(mockMetricConfigCache.get(metricDataset)).thenReturn(metricConfig);
+
+    Map<String, DataFrame> datasets = new HashMap<>();
+    datasets.put("myDataset1", data);
+    datasets.put("myDataset2", data);
+
+    Map<Long, String> id2name = new HashMap<>();
+    id2name.put(this.metricIds.get(1), "value");
+    Map<String, ThirdEyeDataSource> dataSourceMap = new HashMap<>();
+    dataSourceMap.put("myDataSource", 
CSVThirdEyeDataSource.fromDataFrame(datasets, id2name));
+    this.queryCache = new QueryCache(dataSourceMap, 
Executors.newSingleThreadExecutor());
+
+    ThirdEyeCacheRegistry cacheRegistry = ThirdEyeCacheRegistry.getInstance();
+    cacheRegistry.registerMetricConfigCache(mockMetricConfigCache);
+    cacheRegistry.registerDatasetConfigCache(mockDatasetConfigCache);
+    cacheRegistry.registerQueryCache(this.queryCache);
+    cacheRegistry.registerDatasetMaxDataTimeCache(mockDatasetMaxDataTimeCache);
+
+    // time series loader
+    this.timeseriesLoader = new DefaultTimeSeriesLoader(this.metricDAO, 
this.datasetDAO, this.queryCache, null);
+
+    // aggregation loader
+    this.aggregationLoader = new DefaultAggregationLoader(this.metricDAO, 
this.datasetDAO, this.queryCache, mockDatasetMaxDataTimeCache);
 
     // provider
     this.provider = new DefaultDataProvider(this.metricDAO, this.datasetDAO, 
this.eventDAO, this.anomalyDAO,
-        this.evaluationDAO, this.timeseriesLoader, null, null,
+        this.evaluationDAO, this.timeseriesLoader, aggregationLoader, null,
         TimeSeriesCacheBuilder.getInstance(), 
AnomaliesCacheBuilder.getInstance());
   }
 
@@ -179,6 +228,13 @@ public class DataProviderTest {
     Assert.assertTrue(metrics.contains(makeMetric(this.metricIds.get(2), 
"myMetric3", "myDataset1")));
   }
 
+  @Test
+  public void testFetchAggregation() {
+    MetricSlice metricSlice = MetricSlice.from(this.metricIds.get(1), 0L, 
32400000L, ArrayListMultimap.create());
+    Map<MetricSlice, DataFrame> aggregates = 
this.provider.fetchAggregates(Collections.singletonList(metricSlice), 
Collections.emptyList(), 1);
+    Assert.assertEquals(aggregates.keySet().size(), 1);
+  }
+
   //
   // datasets
   //
diff --git 
a/thirdeye/thirdeye-pinot/src/test/java/org/apache/pinot/thirdeye/detection/MockDataProvider.java
 
b/thirdeye/thirdeye-pinot/src/test/java/org/apache/pinot/thirdeye/detection/MockDataProvider.java
index 965d6fd..7bd0628 100644
--- 
a/thirdeye/thirdeye-pinot/src/test/java/org/apache/pinot/thirdeye/detection/MockDataProvider.java
+++ 
b/thirdeye/thirdeye-pinot/src/test/java/org/apache/pinot/thirdeye/detection/MockDataProvider.java
@@ -112,7 +112,7 @@ public class MockDataProvider implements DataProvider {
   }
 
   @Override
-  public Map<MetricSlice, DataFrame> fetchAggregates(Collection<MetricSlice> 
slices, final List<String> dimensions) {
+  public Map<MetricSlice, DataFrame> fetchAggregates(Collection<MetricSlice> 
slices, final List<String> dimensions, int limit) {
     Map<MetricSlice, DataFrame> result = new HashMap<>();
     for (MetricSlice slice : slices) {
       List<String> expr = new ArrayList<>();
@@ -125,9 +125,13 @@ public class MockDataProvider implements DataProvider {
         result.put(slice, this.aggregates.get(slice.withFilters(NO_FILTERS)));
 
       } else {
-        result.put(slice, this.aggregates.get(slice.withFilters(NO_FILTERS))
-            .groupByValue(new ArrayList<>(dimensions)).aggregate(expr)
-            .dropSeries(COL_KEY).setIndex(dimensions));
+        DataFrame aggResult = 
this.aggregates.get(slice.withFilters(NO_FILTERS))
+            .groupByValue(new ArrayList<>(dimensions)).aggregate(expr);
+
+        if (limit > 0) {
+          aggResult = aggResult.sortedBy(COL_VALUE).reverse().head(limit);
+        }
+        result.put(slice, aggResult.dropSeries(COL_KEY).setIndex(dimensions));
       }
     }
     return result;


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org

Reply via email to