This is an automated email from the ASF dual-hosted git repository.

xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 07a8b895a2d Handle SUM rewrite under null handling (#17338)
07a8b895a2d is described below

commit 07a8b895a2d0b8091f77d260b30edbf5a92d77bd
Author: Xiang Fu <[email protected]>
AuthorDate: Wed Dec 10 18:47:32 2025 -0800

    Handle SUM rewrite under null handling (#17338)
---
 .../sql/parsers/rewriter/AggregationOptimizer.java |  47 +--
 .../parsers/rewriter/AggregationOptimizerTest.java | 452 +++++++++++++++------
 2 files changed, 337 insertions(+), 162 deletions(-)

diff --git a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/AggregationOptimizer.java b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/AggregationOptimizer.java
index 95cfcaad895..f63fb923b20 100644
--- a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/AggregationOptimizer.java
+++ b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/AggregationOptimizer.java
@@ -30,10 +30,10 @@ import org.apache.pinot.common.utils.request.RequestUtils;
 /**
  * AggregationOptimizer optimizes aggregation functions by leveraging 
mathematical properties.
  * Currently supports:
- * - sum(column + constant) → sum(column) + constant * count(1)
- * - sum(column - constant) → sum(column) - constant * count(1)
- * - sum(constant + column) → sum(column) + constant * count(1)
- * - sum(constant - column) → constant * count(1) - sum(column)
+ * - sum(column + constant) → sum(column) + constant * count(column)
+ * - sum(column - constant) → sum(column) - constant * count(column)
+ * - sum(constant + column) → sum(column) + constant * count(column)
+ * - sum(constant - column) → constant * count(column) - sum(column)
  * - sum/avg/min/max(column * constant) → aggregation(column) * constant
  *   (for min/max, negative constants flip the aggregation to max/min)
  */
@@ -163,7 +163,7 @@ public class AggregationOptimizer implements QueryRewriter {
       } else if ("sub".equalsIgnoreCase(operator) || 
"minus".equalsIgnoreCase(operator)) {
         return optimizeSubtractionForFunction(left, right, 
aggregationFunction);
       } else if ("mul".equalsIgnoreCase(operator) || 
"mult".equalsIgnoreCase(operator)
-          || "multiply".equalsIgnoreCase(operator)) {
+          || "multiply".equalsIgnoreCase(operator) || 
"times".equalsIgnoreCase(operator)) {
         return optimizeMultiplicationForFunction(left, right, 
aggregationFunction);
       }
     }
@@ -190,11 +190,11 @@ public class AggregationOptimizer implements 
QueryRewriter {
   private Expression optimizeAdditionForFunction(Expression left, Expression 
right, String aggregationFunction) {
     if (isColumn(left) && isConstant(right)) {
       // AGG(column + constant) → AGG(column) + constant (for avg/min/max)
-      // or AGG(column) + constant * count(1) (for sum)
+      // or AGG(column) + constant * count(column) (for sum)
       return createOptimizedAddition(left, right, aggregationFunction);
     } else if (isConstant(left) && isColumn(right)) {
       // AGG(constant + column) → AGG(column) + constant (for avg/min/max)
-      // or AGG(column) + constant * count(1) (for sum)
+      // or AGG(column) + constant * count(column) (for sum)
       return createOptimizedAddition(right, left, aggregationFunction);
     }
     return null;
@@ -206,7 +206,7 @@ public class AggregationOptimizer implements QueryRewriter {
   private Expression optimizeSubtractionForFunction(Expression left, 
Expression right, String aggregationFunction) {
     if (isColumn(left) && isConstant(right)) {
       // AGG(column - constant) → AGG(column) - constant (for avg/min/max)
-      // or AGG(column) - constant * count(1) (for sum)
+      // or AGG(column) - constant * count(column) (for sum)
       return createOptimizedSubtraction(left, right, aggregationFunction);
     } else if (isConstant(left) && isColumn(right)) {
       // Special cases: constant - AGG(column)
@@ -232,7 +232,7 @@ public class AggregationOptimizer implements QueryRewriter {
 
   /**
    * Creates the optimized expression for addition based on aggregation 
function.
-   * For sum: AGG(column) + constant * count(1)
+   * For sum: AGG(column) + constant * count(column)
    * For avg/min/max: AGG(column) + constant
    */
   private Expression createOptimizedAddition(Expression column, Expression 
constant, String aggregationFunction) {
@@ -240,7 +240,7 @@ public class AggregationOptimizer implements QueryRewriter {
     Expression rightOperand;
 
     if ("sum".equals(aggregationFunction)) {
-      rightOperand = createConstantTimesCount(constant);
+      rightOperand = createConstantTimesCount(constant, column);
     } else {
       rightOperand = constant;
     }
@@ -250,7 +250,7 @@ public class AggregationOptimizer implements QueryRewriter {
 
   /**
    * Creates the optimized expression for subtraction based on aggregation 
function.
-   * For sum: AGG(column) - constant * count(1)
+   * For sum: AGG(column) - constant * count(column)
    * For avg/min/max: AGG(column) - constant
    */
   private Expression createOptimizedSubtraction(Expression column, Expression 
constant, String aggregationFunction) {
@@ -258,7 +258,7 @@ public class AggregationOptimizer implements QueryRewriter {
     Expression rightOperand;
 
     if ("sum".equals(aggregationFunction)) {
-      rightOperand = createConstantTimesCount(constant);
+      rightOperand = createConstantTimesCount(constant, column);
     } else {
       rightOperand = constant;
     }
@@ -268,7 +268,7 @@ public class AggregationOptimizer implements QueryRewriter {
 
   /**
    * Creates the optimized expression for reversed subtraction based on 
aggregation function.
-   * For sum: constant * count(1) - sum(column)
+   * For sum: constant * count(column) - sum(column)
    * For avg: constant - avg(column)
    * For min: constant - max(column)
    * For max: constant - min(column)
@@ -279,7 +279,7 @@ public class AggregationOptimizer implements QueryRewriter {
     Expression aggColumn;
 
     if ("sum".equals(aggregationFunction)) {
-      leftOperand = createConstantTimesCount(constant);
+      leftOperand = createConstantTimesCount(constant, column);
       aggColumn = createAggregationExpression(column, "sum");
     } else if ("min".equals(aggregationFunction)) {
       leftOperand = constant;
@@ -324,22 +324,11 @@ public class AggregationOptimizer implements 
QueryRewriter {
   }
 
   /**
-   * Creates constant * count(1) expression
+   * Creates constant * count(column) expression.
    */
-  private Expression createConstantTimesCount(Expression constant) {
-    Expression countOne = createCountOneExpression();
-    return RequestUtils.getFunctionExpression("mult", constant, countOne);
-  }
-
-  /**
-   * Creates count(1) expression
-   */
-  private Expression createCountOneExpression() {
-    Literal oneLiteral = new Literal();
-    oneLiteral.setIntValue(1);
-    Expression oneExpression = new Expression(ExpressionType.LITERAL);
-    oneExpression.setLiteral(oneLiteral);
-    return RequestUtils.getFunctionExpression("count", oneExpression);
+  private Expression createConstantTimesCount(Expression constant, Expression 
column) {
+    Expression countExpr = RequestUtils.getFunctionExpression("count", column);
+    return RequestUtils.getFunctionExpression("mult", constant, countExpr);
   }
 
   /**
diff --git a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/rewriter/AggregationOptimizerTest.java b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/rewriter/AggregationOptimizerTest.java
index 1157a3e6021..3f9c3d0f808 100644
--- a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/rewriter/AggregationOptimizerTest.java
+++ b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/rewriter/AggregationOptimizerTest.java
@@ -19,104 +19,189 @@
 package org.apache.pinot.sql.parsers.rewriter;
 
 import java.util.ArrayList;
-import java.util.Collections;
+import java.util.List;
 import org.apache.pinot.common.request.Expression;
 import org.apache.pinot.common.request.ExpressionType;
 import org.apache.pinot.common.request.Function;
 import org.apache.pinot.common.request.PinotQuery;
-import org.apache.pinot.common.utils.request.RequestUtils;
+import org.apache.pinot.common.utils.config.QueryOptionsUtils;
+import 
org.apache.pinot.spi.utils.CommonConstants.Broker.Request.QueryOptionKey;
 import org.apache.pinot.sql.parsers.CalciteSqlParser;
+import org.apache.pinot.sql.parsers.SqlNodeAndOptions;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
 import org.testng.annotations.Test;
 
 import static org.testng.Assert.assertEquals;
 import static org.testng.Assert.assertNotNull;
+import static org.testng.Assert.assertTrue;
 
 
 public class AggregationOptimizerTest {
 
   private final AggregationOptimizer _optimizer = new AggregationOptimizer();
 
+  @BeforeClass
+  public void setUp() {
+    List<QueryRewriter> queryRewriters = new ArrayList<>();
+    for (QueryRewriter queryRewriter : 
QueryRewriterFactory.getQueryRewriters()) {
+      if (!(queryRewriter instanceof AggregationOptimizer)) {
+        queryRewriters.add(queryRewriter);
+      }
+    }
+    CalciteSqlParser.QUERY_REWRITERS.clear();
+    CalciteSqlParser.QUERY_REWRITERS.addAll(queryRewriters);
+  }
+
+  @AfterClass
+  public void tearDown() {
+    CalciteSqlParser.QUERY_REWRITERS.clear();
+    
CalciteSqlParser.QUERY_REWRITERS.addAll(QueryRewriterFactory.getQueryRewriters());
+  }
+
   @Test
   public void testSumColumnPlusConstant() {
-    // Test: SELECT sum(met + 2) → SELECT sum(met) + 2 * count(1)
+    // Test: SELECT sum(met + 2) → SELECT sum(met) + 2 * count(met)
     String query = "SELECT sum(met + 2) FROM mytable";
     PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
 
+    Expression original = pinotQuery.getSelectList().get(0);
+    assertTopLevelOperator(original, "sum");
+
     // Apply optimizer
     _optimizer.rewrite(pinotQuery);
 
     // Verify optimization
     Expression selectExpression = pinotQuery.getSelectList().get(0);
+    assertTopLevelOperator(selectExpression, "add");
+    verifyOptimizedAddition(selectExpression, "met", 2);
+  }
+
+  @Test
+  public void testSumColumnPlusConstantWithNullHandlingEnabled() {
+    // Test: Optimizer rewrites using count(column) (null handling on)
+    String query = "SET enableNullHandling=true; SELECT sum(met + 2) FROM 
mytable";
+    SqlNodeAndOptions sqlNodeAndOptions = 
CalciteSqlParser.compileToSqlNodeAndOptions(query);
+    PinotQuery pinotQuery = 
CalciteSqlParser.compileToPinotQuery(sqlNodeAndOptions);
+
+    Expression original = pinotQuery.getSelectList().get(0);
+    assertTopLevelOperator(original, "sum");
+
+    // Apply optimizer
+    _optimizer.rewrite(pinotQuery);
+
+    
assertTrue(QueryOptionsUtils.isNullHandlingEnabled(pinotQuery.getQueryOptions()));
+    Expression selectExpression = pinotQuery.getSelectList().get(0);
+    assertTopLevelOperator(selectExpression, "add");
     verifyOptimizedAddition(selectExpression, "met", 2);
   }
 
+  @Test
+  public void testSumRewriteUsesCountColumnWithNullHandling() {
+    // Ensure the rewrite uses count(column) when null handling is enabled
+    SqlNodeAndOptions sqlNodeAndOptions = 
CalciteSqlParser.compileToSqlNodeAndOptions("SELECT sum(met + 2) FROM t");
+    sqlNodeAndOptions.getOptions().put(QueryOptionKey.ENABLE_NULL_HANDLING, 
"true");
+    PinotQuery pinotQuery = 
CalciteSqlParser.compileToPinotQuery(sqlNodeAndOptions);
+
+    Expression original = pinotQuery.getSelectList().get(0);
+    assertTopLevelOperator(original, "sum");
+
+    _optimizer.rewrite(pinotQuery);
+
+    Expression selectExpression = pinotQuery.getSelectList().get(0);
+    assertTopLevelOperator(selectExpression, "add");
+    Function multFunction = 
selectExpression.getFunctionCall().getOperands().get(1).getFunctionCall();
+    Function countFunction = 
multFunction.getOperands().get(1).getFunctionCall();
+    Expression countArg = countFunction.getOperands().get(0);
+
+    // Verify we use count(column) (identifier) instead of count(1) (literal) 
to preserve semantics
+    assertEquals(countArg.getType(), ExpressionType.IDENTIFIER);
+    assertEquals(countArg.getIdentifier().getName(), "met");
+  }
+
   @Test
   public void testSumConstantPlusColumn() {
-    // Test: SELECT sum(2 + met) → SELECT sum(met) + 2 * count(1)
+    // Test: SELECT sum(2 + met) → SELECT sum(met) + 2 * count(met)
     String query = "SELECT sum(2 + met) FROM mytable";
     PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
 
+    Expression original = pinotQuery.getSelectList().get(0);
+    assertTopLevelOperator(original, "sum");
+
     // Apply optimizer
     _optimizer.rewrite(pinotQuery);
 
     // Verify optimization
     Expression selectExpression = pinotQuery.getSelectList().get(0);
+    assertTopLevelOperator(selectExpression, "add");
     verifyOptimizedAddition(selectExpression, "met", 2);
   }
 
   @Test
   public void testSumColumnMinusConstant() {
-    // Test: SELECT sum(met - 5) → SELECT sum(met) - 5 * count(1)
+    // Test: SELECT sum(met - 5) → SELECT sum(met) - 5 * count(met)
     String query = "SELECT sum(met - 5) FROM mytable";
     PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
 
+    Expression original = pinotQuery.getSelectList().get(0);
+    assertTopLevelOperator(original, "sum");
+
     // Apply optimizer
     _optimizer.rewrite(pinotQuery);
 
     // Verify optimization
     Expression selectExpression = pinotQuery.getSelectList().get(0);
+    assertTopLevelOperator(selectExpression, "sub");
     verifyOptimizedSubtraction(selectExpression, "met", 5);
   }
 
   @Test
   public void testSumConstantMinusColumn() {
-    // Test: SELECT sum(10 - met) → SELECT 10 * count(1) - sum(met)
+    // Test: SELECT sum(10 - met) → SELECT 10 * count(met) - sum(met)
     String query = "SELECT sum(10 - met) FROM mytable";
     PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
 
+    Expression original = pinotQuery.getSelectList().get(0);
+    assertTopLevelOperator(original, "sum");
+
     // Apply optimizer
     _optimizer.rewrite(pinotQuery);
 
     // Verify optimization
     Expression selectExpression = pinotQuery.getSelectList().get(0);
+    assertTopLevelOperator(selectExpression, "sub");
     verifyOptimizedSubtractionReversed(selectExpression, 10, "met");
   }
 
   @Test
   public void testSumWithFloatConstant() {
-    // Test: SELECT sum(price + 2.5) → SELECT sum(price) + 2.5 * count(1)
+    // Test: SELECT sum(price + 2.5) → SELECT sum(price) + 2.5 * count(price)
     String query = "SELECT sum(price + 2.5) FROM mytable";
     PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
 
+    Expression original = pinotQuery.getSelectList().get(0);
+    assertTopLevelOperator(original, "sum");
+
     // Apply optimizer
     _optimizer.rewrite(pinotQuery);
 
     // Verify optimization
     Expression selectExpression = pinotQuery.getSelectList().get(0);
+    assertTopLevelOperator(selectExpression, "add");
     verifyOptimizedFloatAddition(selectExpression, "price", 2.5);
   }
 
   @Test
   public void testSumMultiplicationOptimized() {
-    // Build sum(met * 2) manually to avoid parser constant folding
-    Expression multiplication = RequestUtils.getFunctionExpression("mult",
-        RequestUtils.getIdentifierExpression("met"), 
RequestUtils.getLiteralExpression(2));
-    Expression sum = RequestUtils.getFunctionExpression("sum", multiplication);
-    PinotQuery pinotQuery = buildQueryWithSelect(sum);
+    String query = "SELECT sum(met * 2) FROM mytable";
+    PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "sum");
 
     _optimizer.rewrite(pinotQuery);
 
     Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "mult");
     assertEquals(rewritten.getOperator(), "mult");
     assertEquals(rewritten.getOperands().size(), 2);
 
@@ -128,15 +213,16 @@ public class AggregationOptimizerTest {
 
   @Test
   public void testMinMultiplicationWithNegativeConstant() {
-    // Build min(score * -3.5) manually; negative constant should flip MIN to 
MAX
-    Expression multiplication = RequestUtils.getFunctionExpression("multiply",
-        RequestUtils.getIdentifierExpression("score"), 
RequestUtils.getLiteralExpression(-3.5));
-    Expression min = RequestUtils.getFunctionExpression("min", multiplication);
-    PinotQuery pinotQuery = buildQueryWithSelect(min);
+    // Parse min(score * -3.5); negative constant should flip MIN to MAX
+    String query = "SELECT min(score * -3.5) FROM mytable";
+    PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "min");
 
     _optimizer.rewrite(pinotQuery);
 
     Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "mult");
     assertEquals(rewritten.getOperator(), "mult");
 
     Function flippedAggregation = 
rewritten.getOperands().get(0).getFunctionCall();
@@ -147,15 +233,16 @@ public class AggregationOptimizerTest {
 
   @Test
   public void testMaxMultiplicationWithNegativeConstant() {
-    // Build max(score * -2) manually; negative constant should flip MAX to MIN
-    Expression multiplication = RequestUtils.getFunctionExpression("mul",
-        RequestUtils.getIdentifierExpression("score"), 
RequestUtils.getLiteralExpression(-2));
-    Expression max = RequestUtils.getFunctionExpression("max", multiplication);
-    PinotQuery pinotQuery = buildQueryWithSelect(max);
+    // Parse max(score * -2); negative constant should flip MAX to MIN
+    String query = "SELECT max(score * -2) FROM mytable";
+    PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "max");
 
     _optimizer.rewrite(pinotQuery);
 
     Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "mult");
     assertEquals(rewritten.getOperator(), "mult");
 
     Function flippedAggregation = 
rewritten.getOperands().get(0).getFunctionCall();
@@ -166,25 +253,20 @@ public class AggregationOptimizerTest {
 
   @Test
   public void testMultiplicationWithTwoColumnsNotOptimized() {
-    // Build sum(a * b) manually; should remain unchanged because neither side 
is a constant
-    Expression multiplication = RequestUtils.getFunctionExpression("mult",
-        RequestUtils.getIdentifierExpression("a"), 
RequestUtils.getIdentifierExpression("b"));
-    Expression sum = RequestUtils.getFunctionExpression("sum", multiplication);
-    PinotQuery pinotQuery = buildQueryWithSelect(sum);
+    // Parse sum(a * b); should remain unchanged because neither side is a 
constant
+    String query = "SELECT sum(a * b) FROM mytable";
+    PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "sum");
 
     _optimizer.rewrite(pinotQuery);
 
     Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "sum");
     assertEquals(rewritten.getOperator(), "sum");
     assertEquals(rewritten.getOperands().size(), 1);
     Function multiplicationFunction = 
rewritten.getOperands().get(0).getFunctionCall();
-    assertEquals(multiplicationFunction.getOperator(), "mult");
-  }
-
-  private PinotQuery buildQueryWithSelect(Expression expression) {
-    PinotQuery pinotQuery = new PinotQuery();
-    pinotQuery.setSelectList(new 
ArrayList<>(Collections.singletonList(expression)));
-    return pinotQuery;
+    assertEquals(multiplicationFunction.getOperator().toLowerCase(), "times");
   }
 
   @Test
@@ -193,33 +275,39 @@ public class AggregationOptimizerTest {
     String query = "SELECT sum(a + 1), sum(b - 2), avg(c) FROM mytable";
     PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
 
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "sum");
+    assertTopLevelOperator(pinotQuery.getSelectList().get(1), "sum");
+    assertTopLevelOperator(pinotQuery.getSelectList().get(2), "avg");
+
     // Apply optimizer
     _optimizer.rewrite(pinotQuery);
 
     // Verify optimizations
     assertEquals(pinotQuery.getSelectList().size(), 3);
 
-    // First aggregation: sum(a + 1) → sum(a) + 1 * count(1)
+    // First aggregation: sum(a + 1) → sum(a) + 1 * count(a)
     Expression firstExpression = pinotQuery.getSelectList().get(0);
+    assertTopLevelOperator(firstExpression, "add");
     verifyOptimizedAddition(firstExpression, "a", 1);
 
-    // Second aggregation: sum(b - 2) → sum(b) - 2 * count(1)
+    // Second aggregation: sum(b - 2) → sum(b) - 2 * count(b)
     Expression secondExpression = pinotQuery.getSelectList().get(1);
+    assertTopLevelOperator(secondExpression, "sub");
     verifyOptimizedSubtraction(secondExpression, "b", 2);
 
     // Third aggregation: avg(c) should remain unchanged
     Expression thirdExpression = pinotQuery.getSelectList().get(2);
+    assertTopLevelOperator(thirdExpression, "avg");
     assertEquals(thirdExpression.getFunctionCall().getOperator(), "avg");
   }
 
   @Test
   public void testNoOptimizationForUnsupportedPatterns() {
-    // Test patterns that should NOT be optimized
+    // Test patterns that should NOT be optimized; top-level aggregation stays 
unchanged
     String[] queries = {
         "SELECT sum(a / 2) FROM mytable",         // division not supported
         "SELECT sum(a + b) FROM mytable",         // both operands are columns
         "SELECT sum(1 + 2) FROM mytable",         // both operands are 
constants
-        "SELECT avg(a + 2) FROM mytable",         // not a sum function
         "SELECT sum(a) FROM mytable",             // no arithmetic expression
         "SELECT sum(a + b + c) FROM mytable"      // more than 2 operands
     };
@@ -229,6 +317,7 @@ public class AggregationOptimizerTest {
 
       // Store original function operator before optimization
       String originalOperator = 
pinotQuery.getSelectList().get(0).getFunctionCall().getOperator();
+      assertTopLevelOperator(pinotQuery.getSelectList().get(0), 
originalOperator);
 
       // Apply optimizer
       _optimizer.rewrite(pinotQuery);
@@ -236,6 +325,7 @@ public class AggregationOptimizerTest {
       // Verify no optimization occurred - the outer function should remain 
unchanged
       Expression optimized = pinotQuery.getSelectList().get(0);
       assertEquals(originalOperator, 
optimized.getFunctionCall().getOperator());
+      assertTopLevelOperator(optimized, originalOperator);
 
       // Additional verification: for queries that have inner arithmetic, 
ensure they weren't rewritten
       Function outerFunction = optimized.getFunctionCall();
@@ -255,11 +345,14 @@ public class AggregationOptimizerTest {
     String query = "SELECT sum(value + 10) FROM mytable GROUP BY category";
     PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
 
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "sum");
+
     // Apply optimizer
     _optimizer.rewrite(pinotQuery);
 
     // Verify optimization occurred
     Expression selectExpression = pinotQuery.getSelectList().get(0);
+    assertTopLevelOperator(selectExpression, "add");
     verifyOptimizedAddition(selectExpression, "value", 10);
 
     // Verify GROUP BY is preserved
@@ -274,14 +367,17 @@ public class AggregationOptimizerTest {
     String query = "SELECT sum(met + 2) FROM mytable";
     PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
 
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "sum");
+
     // Apply optimizer
     _optimizer.rewrite(pinotQuery);
 
     // Verify the optimization worked
     Expression selectExpression = pinotQuery.getSelectList().get(0);
+    assertTopLevelOperator(selectExpression, "add");
     Function function = selectExpression.getFunctionCall();
 
-    // Should be rewritten from sum(met + 2) to add(sum(met), mult(2, 
count(1)))
+    // Should be rewritten from sum(met + 2) to add(sum(met), mult(2, 
count(met)))
     assertEquals(function.getOperator(), "add");
     assertEquals(function.getOperands().size(), 2);
 
@@ -290,226 +386,278 @@ public class AggregationOptimizerTest {
     assertEquals(sumExpr.getFunctionCall().getOperator(), "sum");
     
assertEquals(sumExpr.getFunctionCall().getOperands().get(0).getIdentifier().getName(),
 "met");
 
-    // Second operand: mult(2, count(1))
+    // Second operand: mult(2, count(met))
     Expression multExpr = function.getOperands().get(1);
     assertEquals(multExpr.getFunctionCall().getOperator(), "mult");
 
-    System.out.println("✓ Successfully optimized: sum(met + 2) → sum(met) + 2 
* count(1)");
+    System.out.println("✓ Successfully optimized: sum(met + 2) → sum(met) + 2 
* count(met)");
   }
 
   // ==================== AVG FUNCTION TESTS ====================
-  // NOTE: AVG optimizations for column+constant are limited due to Pinot's 
parser doing
-  // constant folding before our optimizer runs. These tests verify current 
behavior.
+  // AVG tests verify that aggregations remain the top-level operator before 
rewriting and become arithmetic after.
 
   @Test
   public void testAvgColumnPlusConstant() {
-    // Test: SELECT avg(value + 10) - Due to constant folding, this is NOT 
optimized
-    String query = "SELECT avg(value + 10) FROM mytable";
+    // Test: SELECT avg(value + 10)
+    String query = "SELECT avg(\"value\" + 10) FROM mytable";
     PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
-    PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "avg");
 
     _optimizer.rewrite(pinotQuery);
 
-    // Should remain unchanged due to constant folding in parser
-    assertEquals(pinotQuery.getSelectList().get(0).toString(),
-        originalQuery.getSelectList().get(0).toString());
+    Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "add");
+    assertEquals(rewritten.getOperator(), "add");
+    
assertEquals(rewritten.getOperands().get(0).getFunctionCall().getOperator(), 
"avg");
+    assertEquals(rewritten.getOperands().get(1).getLiteral().getIntValue(), 
10);
   }
 
   @Test
   public void testAvgConstantPlusColumn() {
-    // Test: SELECT avg(5 + salary) - Due to constant folding, this is NOT 
optimized
+    // Test: SELECT avg(5 + salary)
     String query = "SELECT avg(5 + salary) FROM mytable";
     PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
-    PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "avg");
 
     _optimizer.rewrite(pinotQuery);
 
-    // Should remain unchanged due to constant folding in parser
-    assertEquals(pinotQuery.getSelectList().get(0).toString(),
-        originalQuery.getSelectList().get(0).toString());
+    Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "add");
+    assertEquals(rewritten.getOperator(), "add");
+    
assertEquals(rewritten.getOperands().get(0).getFunctionCall().getOperator(), 
"avg");
+    assertEquals(rewritten.getOperands().get(1).getLiteral().getIntValue(), 5);
   }
 
   @Test
   public void testAvgColumnMinusConstant() {
-    // Test: SELECT avg(price - 100) - Due to constant folding, this is NOT 
optimized
+    // Test: SELECT avg(price - 100)
     String query = "SELECT avg(price - 100) FROM mytable";
     PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
-    PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "avg");
 
     _optimizer.rewrite(pinotQuery);
 
-    // Should remain unchanged due to constant folding in parser
-    assertEquals(pinotQuery.getSelectList().get(0).toString(),
-        originalQuery.getSelectList().get(0).toString());
+    Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "sub");
+    assertEquals(rewritten.getOperator(), "sub");
+    
assertEquals(rewritten.getOperands().get(0).getFunctionCall().getOperator(), 
"avg");
+    assertEquals(rewritten.getOperands().get(1).getLiteral().getIntValue(), 
100);
   }
 
   @Test
   public void testAvgConstantMinusColumn() {
-    // Test: SELECT avg(1000 - cost) - Due to constant folding, this is NOT 
optimized
+    // Test: SELECT avg(1000 - cost)
     String query = "SELECT avg(1000 - cost) FROM mytable";
     PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
-    PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "avg");
 
     _optimizer.rewrite(pinotQuery);
 
-    // Should remain unchanged due to constant folding in parser
-    assertEquals(pinotQuery.getSelectList().get(0).toString(),
-        originalQuery.getSelectList().get(0).toString());
+    Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "sub");
+    assertEquals(rewritten.getOperator(), "sub");
+    assertEquals(rewritten.getOperands().get(0).getLiteral().getIntValue(), 
1000);
+    
assertEquals(rewritten.getOperands().get(1).getFunctionCall().getOperator(), 
"avg");
   }
 
   @Test
   public void testAvgColumnTimesConstant() {
-    // Test: SELECT avg(quantity * 2.5) - Due to constant folding, this is NOT 
optimized
+    // Test: SELECT avg(quantity * 2.5)
     String query = "SELECT avg(quantity * 2.5) FROM mytable";
     PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
-    PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "avg");
 
     _optimizer.rewrite(pinotQuery);
 
-    // Should remain unchanged due to constant folding in parser
-    assertEquals(pinotQuery.getSelectList().get(0).toString(),
-        originalQuery.getSelectList().get(0).toString());
+    Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "mult");
+    assertEquals(rewritten.getOperator(), "mult");
+
+    Function avgFunction = rewritten.getOperands().get(0).getFunctionCall();
+    assertEquals(avgFunction.getOperator(), "avg");
+    assertEquals(avgFunction.getOperands().get(0).getIdentifier().getName(), 
"quantity");
+    assertEquals(rewritten.getOperands().get(1).getLiteral().getDoubleValue(), 
2.5, 0.0001);
   }
 
   // ==================== MIN FUNCTION TESTS ====================
-  // NOTE: MIN optimizations for column+constant are limited due to Pinot's 
parser doing
-  // constant folding before our optimizer runs. These tests verify current 
behavior.
+  // MIN tests ensure the optimizer rewrites aggregation roots into arithmetic 
(or flips on negative multipliers).
 
   @Test
   public void testMinColumnPlusConstant() {
-    // Test: SELECT min(score + 50) - Due to constant folding, this is NOT 
optimized
+    // Test: SELECT min(score + 50) -> min(score) + 50
     String query = "SELECT min(score + 50) FROM mytable";
     PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
-    PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "min");
 
     _optimizer.rewrite(pinotQuery);
 
-    // Should remain unchanged due to constant folding in parser
-    assertEquals(pinotQuery.getSelectList().get(0).toString(),
-        originalQuery.getSelectList().get(0).toString());
+    Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "add");
+    assertEquals(rewritten.getOperator(), "add");
+    
assertEquals(rewritten.getOperands().get(0).getFunctionCall().getOperator(), 
"min");
+    assertEquals(rewritten.getOperands().get(1).getLiteral().getIntValue(), 
50);
   }
 
   @Test
   public void testMinConstantMinusColumn() {
-    // Test: SELECT min(100 - temperature) - Due to constant folding, this is 
NOT optimized
+    // Test: SELECT min(100 - temperature) -> 100 - max(temperature)
     String query = "SELECT min(100 - temperature) FROM mytable";
     PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
-    PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "min");
 
     _optimizer.rewrite(pinotQuery);
 
-    // Should remain unchanged due to constant folding in parser
-    assertEquals(pinotQuery.getSelectList().get(0).toString(),
-        originalQuery.getSelectList().get(0).toString());
+    Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "sub");
+    assertEquals(rewritten.getOperator(), "sub");
+    assertEquals(rewritten.getOperands().get(0).getLiteral().getIntValue(), 
100);
+    
assertEquals(rewritten.getOperands().get(1).getFunctionCall().getOperator(), 
"max");
   }
 
   @Test
   public void testMinColumnTimesPositiveConstant() {
-    // Test: SELECT min(value * 3) - Due to constant folding, this is NOT 
optimized
+    // Parse min(value * 3); positive constant keeps MIN
     String query = "SELECT min(value * 3) FROM mytable";
     PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
-    PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "min");
 
     _optimizer.rewrite(pinotQuery);
 
-    // Should remain unchanged due to constant folding in parser
-    assertEquals(pinotQuery.getSelectList().get(0).toString(),
-        originalQuery.getSelectList().get(0).toString());
+    Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "mult");
+    assertEquals(rewritten.getOperator(), "mult");
+
+    Function aggregation = rewritten.getOperands().get(0).getFunctionCall();
+    assertEquals(aggregation.getOperator(), "min");
+    assertEquals(aggregation.getOperands().get(0).getIdentifier().getName(), 
"value");
+    assertEquals(rewritten.getOperands().get(1).getLiteral().getIntValue(), 3);
   }
 
   @Test
   public void testMinColumnTimesNegativeConstant() {
-    // Test: SELECT min(value * -2) - Due to constant folding, this is NOT 
optimized
+    // Parse min(value * -2); negative constant should flip MIN to MAX
     String query = "SELECT min(value * -2) FROM mytable";
     PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
-    PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "min");
 
     _optimizer.rewrite(pinotQuery);
 
-    // Should remain unchanged due to constant folding in parser
-    assertEquals(pinotQuery.getSelectList().get(0).toString(),
-        originalQuery.getSelectList().get(0).toString());
+    Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "mult");
+    assertEquals(rewritten.getOperator(), "mult");
+
+    Function aggregation = rewritten.getOperands().get(0).getFunctionCall();
+    assertEquals(aggregation.getOperator(), "max");
+    assertEquals(aggregation.getOperands().get(0).getIdentifier().getName(), 
"value");
+    assertEquals(rewritten.getOperands().get(1).getLiteral().getIntValue(), 
-2);
   }
 
   // ==================== MAX FUNCTION TESTS ====================
-  // NOTE: MAX optimizations for column+constant are limited due to Pinot's 
parser doing
-  // constant folding before our optimizer runs. These tests verify current 
behavior.
+  // MAX tests follow the same pattern: aggregation before, arithmetic after 
(with flips on negative multipliers).
 
   @Test
   public void testMaxColumnPlusConstant() {
-    // Test: SELECT max(height + 10) - Due to constant folding, this is NOT 
optimized
+    // Test: SELECT max(height + 10) -> max(height) + 10
     String query = "SELECT max(height + 10) FROM mytable";
     PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
-    PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "max");
 
     _optimizer.rewrite(pinotQuery);
 
-    // Should remain unchanged due to constant folding in parser
-    assertEquals(pinotQuery.getSelectList().get(0).toString(),
-        originalQuery.getSelectList().get(0).toString());
+    Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "add");
+    assertEquals(rewritten.getOperator(), "add");
+    
assertEquals(rewritten.getOperands().get(0).getFunctionCall().getOperator(), 
"max");
+    assertEquals(rewritten.getOperands().get(1).getLiteral().getIntValue(), 
10);
   }
 
   @Test
   public void testMaxConstantMinusColumn() {
-    // Test: SELECT max(200 - age) - Due to constant folding, this is NOT 
optimized
+    // Test: SELECT max(200 - age) -> 200 - min(age)
     String query = "SELECT max(200 - age) FROM mytable";
     PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
-    PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "max");
 
     _optimizer.rewrite(pinotQuery);
 
-    // Should remain unchanged due to constant folding in parser
-    assertEquals(pinotQuery.getSelectList().get(0).toString(),
-        originalQuery.getSelectList().get(0).toString());
+    Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "sub");
+    assertEquals(rewritten.getOperator(), "sub");
+    assertEquals(rewritten.getOperands().get(0).getLiteral().getIntValue(), 
200);
+    
assertEquals(rewritten.getOperands().get(1).getFunctionCall().getOperator(), 
"min");
   }
 
   @Test
   public void testMaxColumnTimesNegativeConstant() {
-    // Test: SELECT max(profit * -1) - Due to constant folding, this is NOT 
optimized
+    // Parse max(profit * -1); negative constant should flip MAX to MIN
     String query = "SELECT max(profit * -1) FROM mytable";
     PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
-    PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "max");
 
     _optimizer.rewrite(pinotQuery);
 
-    // Should remain unchanged due to constant folding in parser
-    assertEquals(pinotQuery.getSelectList().get(0).toString(),
-        originalQuery.getSelectList().get(0).toString());
+    Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "mult");
+    assertEquals(rewritten.getOperator(), "mult");
+
+    Function flippedAggregation = 
rewritten.getOperands().get(0).getFunctionCall();
+    assertEquals(flippedAggregation.getOperator(), "min");
+    
assertEquals(flippedAggregation.getOperands().get(0).getIdentifier().getName(), 
"profit");
+    assertEquals(rewritten.getOperands().get(1).getLiteral().getIntValue(), 
-1);
   }
 
   // ==================== COMPLEX MIXED TESTS ====================
 
   @Test
   public void testMixedAggregationOptimizations() {
-    // Test multiple different aggregations in one query
-    // Only SUM should be optimized due to parser constant folding limitations
+    // Test multiple different aggregations in one query; each should be 
rewritten
     String query = "SELECT sum(a + 1), avg(b - 2), min(c * 3), max(d + 4) FROM 
mytable";
     PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
 
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "sum");
+    assertTopLevelOperator(pinotQuery.getSelectList().get(1), "avg");
+    assertTopLevelOperator(pinotQuery.getSelectList().get(2), "min");
+    assertTopLevelOperator(pinotQuery.getSelectList().get(3), "max");
+
     _optimizer.rewrite(pinotQuery);
 
     assertEquals(pinotQuery.getSelectList().size(), 4);
 
-    // sum(a + 1) → sum(a) + 1 * count(1) - This SHOULD be optimized
+    // sum(a + 1) → sum(a) + 1 * count(a) - This SHOULD be optimized
     verifyOptimizedAddition(pinotQuery.getSelectList().get(0), "a", 1);
+    assertTopLevelOperator(pinotQuery.getSelectList().get(0), "add");
+
+    // avg(b - 2) -> avg(b) - 2
+    Function avgRewrite = pinotQuery.getSelectList().get(1).getFunctionCall();
+    assertTopLevelOperator(pinotQuery.getSelectList().get(1), "sub");
+    assertEquals(avgRewrite.getOperator(), "sub");
 
-    // avg(b - 2), min(c * 3), max(d + 4) - These should NOT be optimized due 
to constant folding
-    // We'll verify they remain unchanged by comparing with original parsed 
query
-    String originalQuery = "SELECT sum(a + 1), avg(b - 2), min(c * 3), max(d + 
4) FROM mytable";
-    PinotQuery originalPinotQuery = 
CalciteSqlParser.compileToPinotQuery(originalQuery);
+    // min(c * 3) -> min(c) * 3
+    Function minRewrite = pinotQuery.getSelectList().get(2).getFunctionCall();
+    assertTopLevelOperator(pinotQuery.getSelectList().get(2), "mult");
+    assertEquals(minRewrite.getOperator(), "mult");
 
-    // The original avg, min, max should remain the same (only sum gets 
optimized)
-    assertEquals(pinotQuery.getSelectList().get(1).toString(),
-        originalPinotQuery.getSelectList().get(1).toString());
-    assertEquals(pinotQuery.getSelectList().get(2).toString(),
-        originalPinotQuery.getSelectList().get(2).toString());
-    assertEquals(pinotQuery.getSelectList().get(3).toString(),
-        originalPinotQuery.getSelectList().get(3).toString());
+    // max(d + 4) -> max(d) + 4
+    Function maxRewrite = pinotQuery.getSelectList().get(3).getFunctionCall();
+    assertTopLevelOperator(pinotQuery.getSelectList().get(3), "add");
+    assertEquals(maxRewrite.getOperator(), "add");
   }
 
   @Test
   public void testNonOptimizableQueries() {
-    // Queries that should NOT be optimized
+    // Queries that should NOT be optimized; verify the aggregation root 
remains in place
     String[] queries = {
         "SELECT sum(a * b) FROM mytable",  // Both operands are columns
         "SELECT avg(func(x)) FROM mytable",  // Function call, not arithmetic
@@ -521,16 +669,30 @@ public class AggregationOptimizerTest {
       PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
       PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
 
+      assertTopLevelOperator(pinotQuery.getSelectList().get(0),
+          pinotQuery.getSelectList().get(0).getFunctionCall().getOperator());
+
       _optimizer.rewrite(pinotQuery);
 
       // Should remain unchanged
       assertEquals(pinotQuery.getSelectList().get(0).toString(),
           originalQuery.getSelectList().get(0).toString());
+      assertTopLevelOperator(pinotQuery.getSelectList().get(0),
+          
originalQuery.getSelectList().get(0).getFunctionCall().getOperator());
     }
   }
 
   /**
-   * Verifies that the expression is optimized to: sum(column) + constant * 
count(1)
+   * Helper to assert the top-level function operator of an expression.
+   */
+  private void assertTopLevelOperator(Expression expression, String 
expectedOperator) {
+    Function functionCall = expression.getFunctionCall();
+    assertNotNull(functionCall);
+    assertEquals(functionCall.getOperator(), expectedOperator);
+  }
+
+  /**
+   * Verifies that the expression is optimized to: sum(column) + constant * 
count(column)
    */
   private void verifyOptimizedAddition(Expression expression, String 
columnName, int constantValue) {
     Function function = expression.getFunctionCall();
@@ -543,7 +705,7 @@ public class AggregationOptimizerTest {
     assertEquals(sumExpression.getFunctionCall().getOperator(), "sum");
     
assertEquals(sumExpression.getFunctionCall().getOperands().get(0).getIdentifier().getName(),
 columnName);
 
-    // Second operand should be constant * count(1)
+    // Second operand should be constant * count(column)
     Expression multExpression = function.getOperands().get(1);
     assertEquals(multExpression.getFunctionCall().getOperator(), "mult");
 
@@ -551,14 +713,17 @@ public class AggregationOptimizerTest {
     Expression constantExpr = 
multExpression.getFunctionCall().getOperands().get(0);
     assertEquals(constantExpr.getLiteral().getIntValue(), constantValue);
 
-    // Verify count(1)
+    // Verify count(column)
     Expression countExpr = 
multExpression.getFunctionCall().getOperands().get(1);
     assertEquals(countExpr.getFunctionCall().getOperator(), "count");
-    
assertEquals(countExpr.getFunctionCall().getOperands().get(0).getLiteral().getIntValue(),
 1);
+    Expression countOperand = countExpr.getFunctionCall().getOperands().get(0);
+    assertEquals(countOperand.getType(), ExpressionType.IDENTIFIER);
+    assertNotNull(countOperand.getIdentifier());
+    assertEquals(countOperand.getIdentifier().getName(), columnName);
   }
 
   /**
-   * Verifies that the expression is optimized to: sum(column) + constant * 
count(1) for float constants
+   * Verifies that the expression is optimized to: sum(column) + constant * 
count(column) for float constants
    */
   private void verifyOptimizedFloatAddition(Expression expression, String 
columnName, double constantValue) {
     Function function = expression.getFunctionCall();
@@ -571,17 +736,24 @@ public class AggregationOptimizerTest {
     assertEquals(sumExpression.getFunctionCall().getOperator(), "sum");
     
assertEquals(sumExpression.getFunctionCall().getOperands().get(0).getIdentifier().getName(),
 columnName);
 
-    // Second operand should be constant * count(1)
+    // Second operand should be constant * count(column)
     Expression multExpression = function.getOperands().get(1);
     assertEquals(multExpression.getFunctionCall().getOperator(), "mult");
 
     // Verify constant value (for float, check double value)
     Expression constantExpr = 
multExpression.getFunctionCall().getOperands().get(0);
     assertEquals(constantExpr.getLiteral().getDoubleValue(), constantValue, 
0.0001);
+
+    // Verify count(column)
+    Expression countExpr = 
multExpression.getFunctionCall().getOperands().get(1);
+    assertEquals(countExpr.getFunctionCall().getOperator(), "count");
+    Expression countOperand = countExpr.getFunctionCall().getOperands().get(0);
+    assertEquals(countOperand.getType(), ExpressionType.IDENTIFIER);
+    assertEquals(countOperand.getIdentifier().getName(), columnName);
   }
 
   /**
-   * Verifies that the expression is optimized to: sum(column) - constant * 
count(1)
+   * Verifies that the expression is optimized to: sum(column) - constant * 
count(column)
    */
   private void verifyOptimizedSubtraction(Expression expression, String 
columnName, int constantValue) {
     Function function = expression.getFunctionCall();
@@ -594,17 +766,24 @@ public class AggregationOptimizerTest {
     assertEquals(sumExpression.getFunctionCall().getOperator(), "sum");
     
assertEquals(sumExpression.getFunctionCall().getOperands().get(0).getIdentifier().getName(),
 columnName);
 
-    // Second operand should be constant * count(1)
+    // Second operand should be constant * count(column)
     Expression multExpression = function.getOperands().get(1);
     assertEquals(multExpression.getFunctionCall().getOperator(), "mult");
 
     // Verify constant value
     Expression constantExpr = 
multExpression.getFunctionCall().getOperands().get(0);
     assertEquals(constantExpr.getLiteral().getIntValue(), constantValue);
+
+    // Verify count(column)
+    Expression countExpr = 
multExpression.getFunctionCall().getOperands().get(1);
+    assertEquals(countExpr.getFunctionCall().getOperator(), "count");
+    Expression countOperand = countExpr.getFunctionCall().getOperands().get(0);
+    assertEquals(countOperand.getType(), ExpressionType.IDENTIFIER);
+    assertEquals(countOperand.getIdentifier().getName(), columnName);
   }
 
   /**
-   * Verifies that the expression is optimized to: constant * count(1) - 
sum(column)
+   * Verifies that the expression is optimized to: constant * count(column) - 
sum(column)
    */
   private void verifyOptimizedSubtractionReversed(Expression expression, int 
constantValue, String columnName) {
     Function function = expression.getFunctionCall();
@@ -612,7 +791,7 @@ public class AggregationOptimizerTest {
     assertEquals(function.getOperator(), "sub");
     assertEquals(function.getOperands().size(), 2);
 
-    // First operand should be constant * count(1)
+    // First operand should be constant * count(column)
     Expression multExpression = function.getOperands().get(0);
     assertEquals(multExpression.getFunctionCall().getOperator(), "mult");
 
@@ -620,6 +799,13 @@ public class AggregationOptimizerTest {
     Expression constantExpr = 
multExpression.getFunctionCall().getOperands().get(0);
     assertEquals(constantExpr.getLiteral().getIntValue(), constantValue);
 
+    // Verify count(column)
+    Expression countExpr = 
multExpression.getFunctionCall().getOperands().get(1);
+    assertEquals(countExpr.getFunctionCall().getOperator(), "count");
+    Expression countOperand = countExpr.getFunctionCall().getOperands().get(0);
+    assertEquals(countOperand.getType(), ExpressionType.IDENTIFIER);
+    assertEquals(countOperand.getIdentifier().getName(), columnName);
+
     // Second operand should be sum(column)
     Expression sumExpression = function.getOperands().get(1);
     assertEquals(sumExpression.getFunctionCall().getOperator(), "sum");


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to