This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new a1d1c3b1af Always pass the QueryContext into TransformFunctionFactory (#10320)
a1d1c3b1af is described below

commit a1d1c3b1afd2451c5d048b725eaa2547a75f3956
Author: Xiaotian (Jackie) Jiang <17555551+jackie-ji...@users.noreply.github.com>
AuthorDate: Thu Feb 23 13:31:03 2023 -0800

    Always pass the QueryContext into TransformFunctionFactory (#10320)
---
 .../operator/filter/ExpressionFilterOperator.java  | 10 +--
 .../filter/H3InclusionIndexFilterOperator.java     | 10 +--
 .../operator/filter/H3IndexFilterOperator.java     | 12 ++-
 .../operator/query/SelectionOrderByOperator.java   |  5 +-
 .../transform/PassThroughTransformOperator.java    | 17 ++--
 .../core/operator/transform/TransformOperator.java | 48 +++---------
 .../function/TransformFunctionFactory.java         | 90 +++++++++-------------
 .../org/apache/pinot/core/plan/FilterPlanNode.java |  6 +-
 .../apache/pinot/core/plan/TransformPlanNode.java  |  2 +-
 .../startree/plan/StarTreeTransformPlanNode.java   |  4 +-
 10 files changed, 81 insertions(+), 123 deletions(-)

diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/filter/ExpressionFilterOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/filter/ExpressionFilterOperator.java
index 9342dc8047..5d3390fc46 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/filter/ExpressionFilterOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/filter/ExpressionFilterOperator.java
@@ -33,6 +33,7 @@ import 
org.apache.pinot.core.operator.filter.predicate.PredicateEvaluator;
 import 
org.apache.pinot.core.operator.filter.predicate.PredicateEvaluatorProvider;
 import org.apache.pinot.core.operator.transform.function.TransformFunction;
 import 
org.apache.pinot.core.operator.transform.function.TransformFunctionFactory;
+import org.apache.pinot.core.query.request.context.QueryContext;
 import org.apache.pinot.segment.spi.IndexSegment;
 import org.apache.pinot.segment.spi.datasource.DataSource;
 
@@ -45,7 +46,7 @@ public class ExpressionFilterOperator extends 
BaseFilterOperator {
   private final TransformFunction _transformFunction;
   private final PredicateEvaluator _predicateEvaluator;
 
-  public ExpressionFilterOperator(IndexSegment segment, Predicate predicate, 
int numDocs) {
+  public ExpressionFilterOperator(IndexSegment segment, QueryContext 
queryContext, Predicate predicate, int numDocs) {
     _numDocs = numDocs;
 
     _dataSourceMap = new HashMap<>();
@@ -56,9 +57,9 @@ public class ExpressionFilterOperator extends 
BaseFilterOperator {
       _dataSourceMap.put(column, segment.getDataSource(column));
     }
 
-    _transformFunction = TransformFunctionFactory.get(lhs, _dataSourceMap);
-    _predicateEvaluator = PredicateEvaluatorProvider
-        .getPredicateEvaluator(predicate, _transformFunction.getDictionary(),
+    _transformFunction = TransformFunctionFactory.get(lhs, _dataSourceMap, 
queryContext);
+    _predicateEvaluator =
+        PredicateEvaluatorProvider.getPredicateEvaluator(predicate, 
_transformFunction.getDictionary(),
             _transformFunction.getResultMetadata().getDataType());
   }
 
@@ -68,7 +69,6 @@ public class ExpressionFilterOperator extends 
BaseFilterOperator {
         new ExpressionFilterDocIdSet(_transformFunction, _predicateEvaluator, 
_dataSourceMap, _numDocs));
   }
 
-
   @Override
   public List<Operator> getChildOperators() {
     return Collections.emptyList();
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/filter/H3InclusionIndexFilterOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/filter/H3InclusionIndexFilterOperator.java
index 081afdf271..f3b0c7bde3 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/filter/H3InclusionIndexFilterOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/filter/H3InclusionIndexFilterOperator.java
@@ -47,24 +47,23 @@ import org.roaringbitmap.buffer.MutableRoaringBitmap;
  * A filter operator that uses H3 index for geospatial data inclusion
  */
 public class H3InclusionIndexFilterOperator extends BaseFilterOperator {
-
   private static final String EXPLAIN_NAME = "INCLUSION_FILTER_H3_INDEX";
   private static final String LITERAL_H3_CELLS_CACHE_NAME = 
"st_contain_literal_h3_cells";
 
   private final IndexSegment _segment;
+  private final QueryContext _queryContext;
   private final Predicate _predicate;
   private final int _numDocs;
   private final H3IndexReader _h3IndexReader;
   private final Geometry _geometry;
   private final boolean _isPositiveCheck;
-  private final QueryContext _queryContext;
 
-  public H3InclusionIndexFilterOperator(IndexSegment segment, Predicate 
predicate, QueryContext queryContext,
+  public H3InclusionIndexFilterOperator(IndexSegment segment, QueryContext 
queryContext, Predicate predicate,
       int numDocs) {
     _segment = segment;
+    _queryContext = queryContext;
     _predicate = predicate;
     _numDocs = numDocs;
-    _queryContext = queryContext;
 
     List<ExpressionContext> arguments = 
predicate.getLhs().getFunction().getArguments();
     EqPredicate eqPredicate = (EqPredicate) predicate;
@@ -125,7 +124,8 @@ public class H3InclusionIndexFilterOperator extends 
BaseFilterOperator {
    * Returns the filter block based on the given the partial match doc ids.
    */
   private FilterBlock getFilterBlock(MutableRoaringBitmap fullMatchDocIds, 
MutableRoaringBitmap partialMatchDocIds) {
-    ExpressionFilterOperator expressionFilterOperator = new 
ExpressionFilterOperator(_segment, _predicate, _numDocs);
+    ExpressionFilterOperator expressionFilterOperator =
+        new ExpressionFilterOperator(_segment, _queryContext, _predicate, 
_numDocs);
     ScanBasedDocIdIterator docIdIterator =
         (ScanBasedDocIdIterator) 
expressionFilterOperator.getNextBlock().getBlockDocIdSet().iterator();
     MutableRoaringBitmap result = docIdIterator.applyAnd(partialMatchDocIds);
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/filter/H3IndexFilterOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/filter/H3IndexFilterOperator.java
index 9f91127c12..b0dc6ca6ac 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/filter/H3IndexFilterOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/filter/H3IndexFilterOperator.java
@@ -33,6 +33,7 @@ import 
org.apache.pinot.core.operator.dociditerators.ScanBasedDocIdIterator;
 import org.apache.pinot.core.operator.docidsets.BitmapDocIdSet;
 import org.apache.pinot.core.operator.docidsets.EmptyDocIdSet;
 import org.apache.pinot.core.operator.docidsets.MatchAllDocIdSet;
+import org.apache.pinot.core.query.request.context.QueryContext;
 import org.apache.pinot.segment.local.utils.GeometrySerializer;
 import org.apache.pinot.segment.local.utils.H3Utils;
 import org.apache.pinot.segment.spi.IndexSegment;
@@ -46,9 +47,10 @@ import org.roaringbitmap.buffer.MutableRoaringBitmap;
  * A filter operator that uses H3 index for geospatial data retrieval
  */
 public class H3IndexFilterOperator extends BaseFilterOperator {
-
   private static final String EXPLAIN_NAME = "FILTER_H3_INDEX";
+
   private final IndexSegment _segment;
+  private final QueryContext _queryContext;
   private final Predicate _predicate;
   private final int _numDocs;
   private final H3IndexReader _h3IndexReader;
@@ -57,8 +59,9 @@ public class H3IndexFilterOperator extends BaseFilterOperator 
{
   private final double _lowerBound;
   private final double _upperBound;
 
-  public H3IndexFilterOperator(IndexSegment segment, Predicate predicate, int 
numDocs) {
+  public H3IndexFilterOperator(IndexSegment segment, QueryContext 
queryContext, Predicate predicate, int numDocs) {
     _segment = segment;
+    _queryContext = queryContext;
     _predicate = predicate;
     _numDocs = numDocs;
 
@@ -182,7 +185,7 @@ public class H3IndexFilterOperator extends 
BaseFilterOperator {
       return getFilterBlock(fullMatchDocIds, partialMatchDocIds);
     } catch (Exception e) {
       // Fall back to ExpressionFilterOperator when the execution encounters 
exception (e.g. numRings is too large)
-      return new ExpressionFilterOperator(_segment, _predicate, 
_numDocs).getNextBlock();
+      return new ExpressionFilterOperator(_segment, _queryContext, _predicate, 
_numDocs).getNextBlock();
     }
   }
 
@@ -229,7 +232,8 @@ public class H3IndexFilterOperator extends 
BaseFilterOperator {
    * Returns the filter block based on the given full match doc ids and the 
partial match doc ids.
    */
   private FilterBlock getFilterBlock(MutableRoaringBitmap fullMatchDocIds, 
MutableRoaringBitmap partialMatchDocIds) {
-    ExpressionFilterOperator expressionFilterOperator = new 
ExpressionFilterOperator(_segment, _predicate, _numDocs);
+    ExpressionFilterOperator expressionFilterOperator =
+        new ExpressionFilterOperator(_segment, _queryContext, _predicate, 
_numDocs);
     ScanBasedDocIdIterator docIdIterator =
         (ScanBasedDocIdIterator) 
expressionFilterOperator.getNextBlock().getBlockDocIdSet().iterator();
     MutableRoaringBitmap result = docIdIterator.applyAnd(partialMatchDocIds);
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/SelectionOrderByOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/SelectionOrderByOperator.java
index d45652bba8..107cbe0d7b 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/query/SelectionOrderByOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/query/SelectionOrderByOperator.java
@@ -72,6 +72,7 @@ public class SelectionOrderByOperator extends 
BaseOperator<SelectionResultsBlock
   private static final String EXPLAIN_NAME = "SELECT_ORDERBY";
 
   private final IndexSegment _indexSegment;
+  private final QueryContext _queryContext;
   private final boolean _nullHandlingEnabled;
   // Deduped order-by expressions followed by output expressions from 
SelectionOperatorUtils.extractExpressions()
   private final List<ExpressionContext> _expressions;
@@ -87,6 +88,7 @@ public class SelectionOrderByOperator extends 
BaseOperator<SelectionResultsBlock
   public SelectionOrderByOperator(IndexSegment indexSegment, QueryContext 
queryContext,
       List<ExpressionContext> expressions, TransformOperator 
transformOperator) {
     _indexSegment = indexSegment;
+    _queryContext = queryContext;
     _nullHandlingEnabled = queryContext.isNullHandlingEnabled();
     _expressions = expressions;
     _transformOperator = transformOperator;
@@ -262,7 +264,8 @@ public class SelectionOrderByOperator extends 
BaseOperator<SelectionResultsBlock
     }
     ProjectionOperator projectionOperator =
         new ProjectionOperator(dataSourceMap, new 
BitmapDocIdSetOperator(docIds, numRows));
-    TransformOperator transformOperator = new 
TransformOperator(projectionOperator, nonOrderByExpressions);
+    TransformOperator transformOperator =
+        new TransformOperator(_queryContext, projectionOperator, 
nonOrderByExpressions);
 
     // Fill the non-order-by expression values
     int numNonOrderByExpressions = nonOrderByExpressions.size();
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/PassThroughTransformOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/PassThroughTransformOperator.java
index 75991bf7ea..fa58606f20 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/PassThroughTransformOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/PassThroughTransformOperator.java
@@ -23,24 +23,18 @@ import 
org.apache.pinot.common.request.context.ExpressionContext;
 import org.apache.pinot.core.operator.ProjectionOperator;
 import org.apache.pinot.core.operator.blocks.PassThroughTransformBlock;
 import org.apache.pinot.core.operator.blocks.ProjectionBlock;
+import org.apache.pinot.core.query.request.context.QueryContext;
 
 
 /**
  * Class for evaluating pass through transform expressions.
  */
 public class PassThroughTransformOperator extends TransformOperator {
-
   private static final String EXPLAIN_NAME = "TRANSFORM_PASSTHROUGH";
 
-  /**
-   * Constructor for the class
-   *
-   * @param projectionOperator Projection operator
-   * @param expressions Collection of expressions to evaluate
-   */
-  public PassThroughTransformOperator(ProjectionOperator projectionOperator,
+  public PassThroughTransformOperator(QueryContext queryContext, 
ProjectionOperator projectionOperator,
       Collection<ExpressionContext> expressions) {
-    super(projectionOperator, expressions);
+    super(queryContext, projectionOperator, expressions);
   }
 
   @Override
@@ -53,9 +47,8 @@ public class PassThroughTransformOperator extends 
TransformOperator {
     }
   }
 
-
   @Override
-  public String toExplainString() {
-    return toExplainString(EXPLAIN_NAME);
+  public String getExplainName() {
+    return EXPLAIN_NAME;
   }
 }
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/TransformOperator.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/TransformOperator.java
index 9f52fe839d..986947a1fe 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/TransformOperator.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/TransformOperator.java
@@ -18,16 +18,14 @@
  */
 package org.apache.pinot.core.operator.transform;
 
-import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
-import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
-import javax.annotation.Nullable;
+import java.util.stream.Collectors;
+import org.apache.commons.lang3.StringUtils;
 import org.apache.pinot.common.request.context.ExpressionContext;
-import org.apache.pinot.core.common.Operator;
 import org.apache.pinot.core.operator.BaseOperator;
 import org.apache.pinot.core.operator.ExecutionStatistics;
 import org.apache.pinot.core.operator.ProjectionOperator;
@@ -51,31 +49,16 @@ public class TransformOperator extends 
BaseOperator<TransformBlock> {
   protected final Map<String, DataSource> _dataSourceMap;
   protected final Map<ExpressionContext, TransformFunction> 
_transformFunctionMap = new HashMap<>();
 
-  /**
-   *
-   * @param queryContext the query context
-   * @param projectionOperator Projection operator
-   * @param expressions Collection of expressions to evaluate
-   */
-  public TransformOperator(@Nullable QueryContext queryContext, 
ProjectionOperator projectionOperator,
+  public TransformOperator(QueryContext queryContext, ProjectionOperator 
projectionOperator,
       Collection<ExpressionContext> expressions) {
     _projectionOperator = projectionOperator;
     _dataSourceMap = projectionOperator.getDataSourceMap();
     for (ExpressionContext expression : expressions) {
-      TransformFunction transformFunction = 
TransformFunctionFactory.get(queryContext, expression, _dataSourceMap);
+      TransformFunction transformFunction = 
TransformFunctionFactory.get(expression, _dataSourceMap, queryContext);
       _transformFunctionMap.put(expression, transformFunction);
     }
   }
 
-  /**
-   *
-   * @param projectionOperator Projection operator
-   * @param expressions Collection of expressions to evaluate
-   */
-  public TransformOperator(ProjectionOperator projectionOperator, 
Collection<ExpressionContext> expressions) {
-    this(null, projectionOperator, expressions);
-  }
-
   /**
    * Returns the number of columns projected.
    *
@@ -116,30 +99,19 @@ public class TransformOperator extends 
BaseOperator<TransformBlock> {
     }
   }
 
-
   @Override
   public String toExplainString() {
-    return toExplainString(EXPLAIN_NAME);
+    List<String> expressions =
+        
_transformFunctionMap.keySet().stream().map(ExpressionContext::toString).sorted().collect(Collectors.toList());
+    return getExplainName() + "(" + StringUtils.join(expressions, ", ") + ")";
   }
-  public String toExplainString(String explainName) {
-    ExpressionContext[] functions = _transformFunctionMap.keySet().toArray(new 
ExpressionContext[0]);
-
-    // Sort to make the order, in which names appear within the operator, 
deterministic.
-    Arrays.sort(functions, Comparator.comparing(ExpressionContext::toString));
-
-    StringBuilder stringBuilder = new StringBuilder(explainName).append("(");
-    if (functions != null && functions.length > 0) {
-      stringBuilder.append(functions[0].toString());
-      for (int i = 1; i < functions.length; i++) {
-        stringBuilder.append(", ").append(functions[i].toString());
-      }
-    }
 
-    return stringBuilder.append(')').toString();
+  protected String getExplainName() {
+    return EXPLAIN_NAME;
   }
 
   @Override
-  public List<Operator> getChildOperators() {
+  public List<ProjectionOperator> getChildOperators() {
     return Collections.singletonList(_projectionOperator);
   }
 
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
index cec8847f1a..291ea3fc40 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java
@@ -18,6 +18,7 @@
  */
 package org.apache.pinot.core.operator.transform.function;
 
+import com.google.common.annotations.VisibleForTesting;
 import java.util.ArrayList;
 import java.util.EnumMap;
 import java.util.HashMap;
@@ -71,6 +72,9 @@ import 
org.apache.pinot.core.operator.transform.function.TrigonometricTransformF
 import org.apache.pinot.core.query.request.context.QueryContext;
 import org.apache.pinot.segment.spi.datasource.DataSource;
 import org.apache.pinot.spi.exception.BadQueryRequestException;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
 
 /**
  * Factory class for transformation functions.
@@ -79,6 +83,7 @@ public class TransformFunctionFactory {
   private TransformFunctionFactory() {
   }
 
+  private static final Logger LOGGER = 
LoggerFactory.getLogger(TransformFunctionFactory.class);
   private static final Map<String, Class<? extends TransformFunction>> 
TRANSFORM_FUNCTION_MAP = createRegistry();
 
   private static Map<String, Class<? extends TransformFunction>> 
createRegistry() {
@@ -105,14 +110,10 @@ public class TransformFunctionFactory {
     typeToImplementation.put(TransformFunctionType.TRUNCATE, 
TruncateDecimalTransformFunction.class);
 
     typeToImplementation.put(TransformFunctionType.CAST, 
CastTransformFunction.class);
-    typeToImplementation.put(TransformFunctionType.JSONEXTRACTSCALAR,
-        JsonExtractScalarTransformFunction.class);
-    typeToImplementation.put(TransformFunctionType.JSONEXTRACTKEY,
-        JsonExtractKeyTransformFunction.class);
-    typeToImplementation.put(TransformFunctionType.TIMECONVERT,
-        TimeConversionTransformFunction.class);
-    typeToImplementation.put(TransformFunctionType.DATETIMECONVERT,
-        DateTimeConversionTransformFunction.class);
+    typeToImplementation.put(TransformFunctionType.JSONEXTRACTSCALAR, 
JsonExtractScalarTransformFunction.class);
+    typeToImplementation.put(TransformFunctionType.JSONEXTRACTKEY, 
JsonExtractKeyTransformFunction.class);
+    typeToImplementation.put(TransformFunctionType.TIMECONVERT, 
TimeConversionTransformFunction.class);
+    typeToImplementation.put(TransformFunctionType.DATETIMECONVERT, 
DateTimeConversionTransformFunction.class);
     typeToImplementation.put(TransformFunctionType.DATETRUNC, 
DateTruncTransformFunction.class);
     typeToImplementation.put(TransformFunctionType.YEAR, 
DateTimeTransformFunction.Year.class);
     typeToImplementation.put(TransformFunctionType.YEAR_OF_WEEK, 
DateTimeTransformFunction.YearOfWeek.class);
@@ -135,12 +136,10 @@ public class TransformFunctionFactory {
     typeToImplementation.put(TransformFunctionType.EXTRACT, 
ExtractTransformFunction.class);
 
     // Regexp functions
-    typeToImplementation.put(TransformFunctionType.REGEXP_EXTRACT,
-        RegexpExtractTransformFunction.class);
+    typeToImplementation.put(TransformFunctionType.REGEXP_EXTRACT, 
RegexpExtractTransformFunction.class);
 
     // Array functions
-    typeToImplementation.put(TransformFunctionType.ARRAYAVERAGE,
-        ArrayAverageTransformFunction.class);
+    typeToImplementation.put(TransformFunctionType.ARRAYAVERAGE, 
ArrayAverageTransformFunction.class);
     typeToImplementation.put(TransformFunctionType.ARRAYMAX, 
ArrayMaxTransformFunction.class);
     typeToImplementation.put(TransformFunctionType.ARRAYMIN, 
ArrayMinTransformFunction.class);
     typeToImplementation.put(TransformFunctionType.ARRAYSUM, 
ArraySumTransformFunction.class);
@@ -150,13 +149,10 @@ public class TransformFunctionFactory {
 
     typeToImplementation.put(TransformFunctionType.EQUALS, 
EqualsTransformFunction.class);
     typeToImplementation.put(TransformFunctionType.NOT_EQUALS, 
NotEqualsTransformFunction.class);
-    typeToImplementation.put(TransformFunctionType.GREATER_THAN,
-        GreaterThanTransformFunction.class);
-    typeToImplementation.put(TransformFunctionType.GREATER_THAN_OR_EQUAL,
-        GreaterThanOrEqualTransformFunction.class);
+    typeToImplementation.put(TransformFunctionType.GREATER_THAN, 
GreaterThanTransformFunction.class);
+    typeToImplementation.put(TransformFunctionType.GREATER_THAN_OR_EQUAL, 
GreaterThanOrEqualTransformFunction.class);
     typeToImplementation.put(TransformFunctionType.LESS_THAN, 
LessThanTransformFunction.class);
-    typeToImplementation.put(TransformFunctionType.LESS_THAN_OR_EQUAL,
-        LessThanOrEqualTransformFunction.class);
+    typeToImplementation.put(TransformFunctionType.LESS_THAN_OR_EQUAL, 
LessThanOrEqualTransformFunction.class);
     typeToImplementation.put(TransformFunctionType.IN, 
InTransformFunction.class);
     typeToImplementation.put(TransformFunctionType.NOT_IN, 
NotInTransformFunction.class);
 
@@ -167,22 +163,17 @@ public class TransformFunctionFactory {
 
     // geo functions
     // geo constructors
-    typeToImplementation.put(TransformFunctionType.ST_GEOG_FROM_TEXT,
-        StGeogFromTextFunction.class);
-    typeToImplementation.put(TransformFunctionType.ST_GEOG_FROM_WKB,
-        StGeogFromWKBFunction.class);
-    typeToImplementation.put(TransformFunctionType.ST_GEOM_FROM_TEXT,
-        StGeomFromTextFunction.class);
-    typeToImplementation.put(TransformFunctionType.ST_GEOM_FROM_WKB,
-        StGeomFromWKBFunction.class);
+    typeToImplementation.put(TransformFunctionType.ST_GEOG_FROM_TEXT, 
StGeogFromTextFunction.class);
+    typeToImplementation.put(TransformFunctionType.ST_GEOG_FROM_WKB, 
StGeogFromWKBFunction.class);
+    typeToImplementation.put(TransformFunctionType.ST_GEOM_FROM_TEXT, 
StGeomFromTextFunction.class);
+    typeToImplementation.put(TransformFunctionType.ST_GEOM_FROM_WKB, 
StGeomFromWKBFunction.class);
     typeToImplementation.put(TransformFunctionType.ST_POINT, 
StPointFunction.class);
     typeToImplementation.put(TransformFunctionType.ST_POLYGON, 
StPolygonFunction.class);
 
     // geo measurements
     typeToImplementation.put(TransformFunctionType.ST_AREA, 
StAreaFunction.class);
     typeToImplementation.put(TransformFunctionType.ST_DISTANCE, 
StDistanceFunction.class);
-    typeToImplementation.put(TransformFunctionType.ST_GEOMETRY_TYPE,
-        StGeometryTypeFunction.class);
+    typeToImplementation.put(TransformFunctionType.ST_GEOMETRY_TYPE, 
StGeometryTypeFunction.class);
 
     // geo outputs
     typeToImplementation.put(TransformFunctionType.ST_AS_BINARY, 
StAsBinaryFunction.class);
@@ -202,8 +193,7 @@ public class TransformFunctionFactory {
 
     // null handling
     typeToImplementation.put(TransformFunctionType.IS_NULL, 
IsNullTransformFunction.class);
-    typeToImplementation.put(TransformFunctionType.IS_NOT_NULL,
-        IsNotNullTransformFunction.class);
+    typeToImplementation.put(TransformFunctionType.IS_NOT_NULL, 
IsNotNullTransformFunction.class);
     typeToImplementation.put(TransformFunctionType.COALESCE, 
CoalesceTransformFunction.class);
     typeToImplementation.put(TransformFunctionType.IS_DISTINCT_FROM, 
IsDistinctFromTransformFunction.class);
     typeToImplementation.put(TransformFunctionType.IS_NOT_DISTINCT_FROM, 
IsNotDistinctFromTransformFunction.class);
@@ -234,7 +224,7 @@ public class TransformFunctionFactory {
 
   /**
    * Initializes the factory with a set of transform function classes.
-   * <p>Should be called only once before calling {@link 
#get(ExpressionContext, Map)}.
+   * <p>Should be called only once before using the factory.
    *
    * @param transformFunctionClasses Set of transform function classes
    */
@@ -242,17 +232,17 @@ public class TransformFunctionFactory {
     for (Class<TransformFunction> transformFunctionClass : 
transformFunctionClasses) {
       TransformFunction transformFunction;
       try {
-        transformFunction = transformFunctionClass.newInstance();
-      } catch (InstantiationException | IllegalAccessException e) {
+        transformFunction = 
transformFunctionClass.getDeclaredConstructor().newInstance();
+      } catch (Exception e) {
         throw new RuntimeException(
-            "Caught exception while instantiating transform function from 
class: " + transformFunctionClass.toString(),
-            e);
+            "Caught exception while instantiating transform function from 
class: " + transformFunctionClass, e);
       }
       String transformFunctionName = canonicalize(transformFunction.getName());
-      if (TRANSFORM_FUNCTION_MAP.containsKey(transformFunctionName)) {
-        throw new IllegalArgumentException("Transform function: " + 
transformFunctionName + " already exists");
+      if (TRANSFORM_FUNCTION_MAP.put(transformFunctionName, 
transformFunctionClass) == null) {
+        LOGGER.info("Registering function: {} with class: {}", 
transformFunctionName, transformFunctionClass);
+      } else {
+        LOGGER.info("Replacing function: {} with class: {}", 
transformFunctionName, transformFunctionClass);
       }
-      TRANSFORM_FUNCTION_MAP.put(transformFunctionName, 
transformFunctionClass);
     }
   }
 
@@ -261,22 +251,11 @@ public class TransformFunctionFactory {
    *
    * @param expression Transform expression
    * @param dataSourceMap Map from column name to column data source
-   * @return Transform function
-   */
-  public static TransformFunction get(ExpressionContext expression, 
Map<String, DataSource> dataSourceMap) {
-    return get(null, expression, dataSourceMap);
-  }
-
-  /**
-   * Returns an instance of transform function for the given expression.
-   *
    * @param queryContext the query context if available
-   * @param expression Transform expression
-   * @param dataSourceMap Map from column name to column data source
    * @return Transform function
    */
-  public static TransformFunction get(@Nullable QueryContext queryContext, 
ExpressionContext expression,
-      Map<String, DataSource> dataSourceMap) {
+  public static TransformFunction get(ExpressionContext expression, 
Map<String, DataSource> dataSourceMap,
+      @Nullable QueryContext queryContext) {
     switch (expression.getType()) {
       case FUNCTION:
         FunctionContext function = expression.getFunction();
@@ -289,7 +268,7 @@ public class TransformFunctionFactory {
         if (transformFunctionClass != null) {
           // Transform function
           try {
-            transformFunction = transformFunctionClass.newInstance();
+            transformFunction = 
transformFunctionClass.getDeclaredConstructor().newInstance();
           } catch (Exception e) {
             throw new RuntimeException("Caught exception while constructing 
transform function: " + functionName, e);
           }
@@ -309,7 +288,7 @@ public class TransformFunctionFactory {
 
         List<TransformFunction> transformFunctionArguments = new 
ArrayList<>(numArguments);
         for (ExpressionContext argument : arguments) {
-          
transformFunctionArguments.add(TransformFunctionFactory.get(queryContext, 
argument, dataSourceMap));
+          
transformFunctionArguments.add(TransformFunctionFactory.get(argument, 
dataSourceMap, queryContext));
         }
         try {
           transformFunction.init(transformFunctionArguments, dataSourceMap);
@@ -330,6 +309,11 @@ public class TransformFunctionFactory {
     }
   }
 
+  @VisibleForTesting
+  public static TransformFunction get(ExpressionContext expression, 
Map<String, DataSource> dataSourceMap) {
+    return get(expression, dataSourceMap, null);
+  }
+
   /**
    * Converts the transform function name into its canonical form
    *
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/FilterPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/FilterPlanNode.java
index 75bd69f7ac..588929f2ea 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/FilterPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/FilterPlanNode.java
@@ -228,13 +228,13 @@ public class FilterPlanNode implements PlanNode {
         ExpressionContext lhs = predicate.getLhs();
         if (lhs.getType() == ExpressionContext.Type.FUNCTION) {
           if (canApplyH3IndexForDistanceCheck(predicate, lhs.getFunction())) {
-            return new H3IndexFilterOperator(_indexSegment, predicate, 
numDocs);
+            return new H3IndexFilterOperator(_indexSegment, _queryContext, 
predicate, numDocs);
           } else if (canApplyH3IndexForInclusionCheck(predicate, 
lhs.getFunction())) {
-            return new H3InclusionIndexFilterOperator(_indexSegment, 
predicate, _queryContext, numDocs);
+            return new H3InclusionIndexFilterOperator(_indexSegment, 
_queryContext, predicate, numDocs);
           } else {
             // TODO: ExpressionFilterOperator does not support predicate types 
without PredicateEvaluator (IS_NULL,
             //       IS_NOT_NULL, TEXT_MATCH)
-            return new ExpressionFilterOperator(_indexSegment, predicate, 
numDocs);
+            return new ExpressionFilterOperator(_indexSegment, _queryContext, 
predicate, numDocs);
           }
         } else {
           String column = lhs.getIdentifier();
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/plan/TransformPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/plan/TransformPlanNode.java
index 592fc9246f..f50c83fc6b 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/plan/TransformPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/plan/TransformPlanNode.java
@@ -70,7 +70,7 @@ public class TransformPlanNode implements PlanNode {
     if (hasNonIdentifierExpression) {
       return new TransformOperator(_queryContext, projectionOperator, 
_expressions);
     } else {
-      return new PassThroughTransformOperator(projectionOperator, 
_expressions);
+      return new PassThroughTransformOperator(_queryContext, 
projectionOperator, _expressions);
     }
   }
 }
diff --git a/pinot-core/src/main/java/org/apache/pinot/core/startree/plan/StarTreeTransformPlanNode.java b/pinot-core/src/main/java/org/apache/pinot/core/startree/plan/StarTreeTransformPlanNode.java
index 283891dcbc..1059693980 100644
--- a/pinot-core/src/main/java/org/apache/pinot/core/startree/plan/StarTreeTransformPlanNode.java
+++ b/pinot-core/src/main/java/org/apache/pinot/core/startree/plan/StarTreeTransformPlanNode.java
@@ -35,12 +35,14 @@ import 
org.apache.pinot.segment.spi.index.startree.StarTreeV2;
 
 
 public class StarTreeTransformPlanNode implements PlanNode {
+  private final QueryContext _queryContext;
   private final List<ExpressionContext> _groupByExpressions;
   private final StarTreeProjectionPlanNode _starTreeProjectionPlanNode;
 
   public StarTreeTransformPlanNode(QueryContext queryContext, StarTreeV2 
starTreeV2,
       AggregationFunctionColumnPair[] aggregationFunctionColumnPairs, 
@Nullable ExpressionContext[] groupByExpressions,
       Map<String, List<CompositePredicateEvaluator>> predicateEvaluatorsMap) {
+    _queryContext = queryContext;
     Set<String> projectionColumns = new HashSet<>();
     for (AggregationFunctionColumnPair aggregationFunctionColumnPair : 
aggregationFunctionColumnPairs) {
       projectionColumns.add(aggregationFunctionColumnPair.toColumnName());
@@ -67,6 +69,6 @@ public class StarTreeTransformPlanNode implements PlanNode {
     // NOTE: Here we do not put aggregation expressions into TransformOperator 
based on the following assumptions:
     //       - They are all columns (not functions or constants), where no 
transform is required
     //       - We never call TransformOperator.getResultMetadata() or 
TransformOperator.getDictionary() on them
-    return new TransformOperator(_starTreeProjectionPlanNode.run(), 
_groupByExpressions);
+    return new TransformOperator(_queryContext, 
_starTreeProjectionPlanNode.run(), _groupByExpressions);
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@pinot.apache.org
For additional commands, e-mail: commits-h...@pinot.apache.org

Reply via email to