This is an automated email from the ASF dual-hosted git repository.
xiangfu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new 07a8b895a2d Handle SUM rewrite under null handling (#17338)
07a8b895a2d is described below
commit 07a8b895a2d0b8091f77d260b30edbf5a92d77bd
Author: Xiang Fu <[email protected]>
AuthorDate: Wed Dec 10 18:47:32 2025 -0800
Handle SUM rewrite under null handling (#17338)
---
.../sql/parsers/rewriter/AggregationOptimizer.java | 47 +--
.../parsers/rewriter/AggregationOptimizerTest.java | 452 +++++++++++++++------
2 files changed, 337 insertions(+), 162 deletions(-)
diff --git a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/AggregationOptimizer.java b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/AggregationOptimizer.java
index 95cfcaad895..f63fb923b20 100644
--- a/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/AggregationOptimizer.java
+++ b/pinot-common/src/main/java/org/apache/pinot/sql/parsers/rewriter/AggregationOptimizer.java
@@ -30,10 +30,10 @@ import org.apache.pinot.common.utils.request.RequestUtils;
/**
* AggregationOptimizer optimizes aggregation functions by leveraging
mathematical properties.
* Currently supports:
- * - sum(column + constant) → sum(column) + constant * count(1)
- * - sum(column - constant) → sum(column) - constant * count(1)
- * - sum(constant + column) → sum(column) + constant * count(1)
- * - sum(constant - column) → constant * count(1) - sum(column)
+ * - sum(column + constant) → sum(column) + constant * count(column)
+ * - sum(column - constant) → sum(column) - constant * count(column)
+ * - sum(constant + column) → sum(column) + constant * count(column)
+ * - sum(constant - column) → constant * count(column) - sum(column)
* - sum/avg/min/max(column * constant) → aggregation(column) * constant
* (for min/max, negative constants flip the aggregation to max/min)
*/
@@ -163,7 +163,7 @@ public class AggregationOptimizer implements QueryRewriter {
} else if ("sub".equalsIgnoreCase(operator) ||
"minus".equalsIgnoreCase(operator)) {
return optimizeSubtractionForFunction(left, right,
aggregationFunction);
} else if ("mul".equalsIgnoreCase(operator) ||
"mult".equalsIgnoreCase(operator)
- || "multiply".equalsIgnoreCase(operator)) {
+ || "multiply".equalsIgnoreCase(operator) ||
"times".equalsIgnoreCase(operator)) {
return optimizeMultiplicationForFunction(left, right,
aggregationFunction);
}
}
@@ -190,11 +190,11 @@ public class AggregationOptimizer implements QueryRewriter {
private Expression optimizeAdditionForFunction(Expression left, Expression
right, String aggregationFunction) {
if (isColumn(left) && isConstant(right)) {
// AGG(column + constant) → AGG(column) + constant (for avg/min/max)
- // or AGG(column) + constant * count(1) (for sum)
+ // or AGG(column) + constant * count(column) (for sum)
return createOptimizedAddition(left, right, aggregationFunction);
} else if (isConstant(left) && isColumn(right)) {
// AGG(constant + column) → AGG(column) + constant (for avg/min/max)
- // or AGG(column) + constant * count(1) (for sum)
+ // or AGG(column) + constant * count(column) (for sum)
return createOptimizedAddition(right, left, aggregationFunction);
}
return null;
@@ -206,7 +206,7 @@ public class AggregationOptimizer implements QueryRewriter {
private Expression optimizeSubtractionForFunction(Expression left,
Expression right, String aggregationFunction) {
if (isColumn(left) && isConstant(right)) {
// AGG(column - constant) → AGG(column) - constant (for avg/min/max)
- // or AGG(column) - constant * count(1) (for sum)
+ // or AGG(column) - constant * count(column) (for sum)
return createOptimizedSubtraction(left, right, aggregationFunction);
} else if (isConstant(left) && isColumn(right)) {
// Special cases: constant - AGG(column)
@@ -232,7 +232,7 @@ public class AggregationOptimizer implements QueryRewriter {
/**
* Creates the optimized expression for addition based on aggregation
function.
- * For sum: AGG(column) + constant * count(1)
+ * For sum: AGG(column) + constant * count(column)
* For avg/min/max: AGG(column) + constant
*/
private Expression createOptimizedAddition(Expression column, Expression
constant, String aggregationFunction) {
@@ -240,7 +240,7 @@ public class AggregationOptimizer implements QueryRewriter {
Expression rightOperand;
if ("sum".equals(aggregationFunction)) {
- rightOperand = createConstantTimesCount(constant);
+ rightOperand = createConstantTimesCount(constant, column);
} else {
rightOperand = constant;
}
@@ -250,7 +250,7 @@ public class AggregationOptimizer implements QueryRewriter {
/**
* Creates the optimized expression for subtraction based on aggregation
function.
- * For sum: AGG(column) - constant * count(1)
+ * For sum: AGG(column) - constant * count(column)
* For avg/min/max: AGG(column) - constant
*/
private Expression createOptimizedSubtraction(Expression column, Expression
constant, String aggregationFunction) {
@@ -258,7 +258,7 @@ public class AggregationOptimizer implements QueryRewriter {
Expression rightOperand;
if ("sum".equals(aggregationFunction)) {
- rightOperand = createConstantTimesCount(constant);
+ rightOperand = createConstantTimesCount(constant, column);
} else {
rightOperand = constant;
}
@@ -268,7 +268,7 @@ public class AggregationOptimizer implements QueryRewriter {
/**
* Creates the optimized expression for reversed subtraction based on
aggregation function.
- * For sum: constant * count(1) - sum(column)
+ * For sum: constant * count(column) - sum(column)
* For avg: constant - avg(column)
* For min: constant - max(column)
* For max: constant - min(column)
@@ -279,7 +279,7 @@ public class AggregationOptimizer implements QueryRewriter {
Expression aggColumn;
if ("sum".equals(aggregationFunction)) {
- leftOperand = createConstantTimesCount(constant);
+ leftOperand = createConstantTimesCount(constant, column);
aggColumn = createAggregationExpression(column, "sum");
} else if ("min".equals(aggregationFunction)) {
leftOperand = constant;
@@ -324,22 +324,11 @@ public class AggregationOptimizer implements QueryRewriter {
}
/**
- * Creates constant * count(1) expression
+ * Creates constant * count(column) expression.
*/
- private Expression createConstantTimesCount(Expression constant) {
- Expression countOne = createCountOneExpression();
- return RequestUtils.getFunctionExpression("mult", constant, countOne);
- }
-
- /**
- * Creates count(1) expression
- */
- private Expression createCountOneExpression() {
- Literal oneLiteral = new Literal();
- oneLiteral.setIntValue(1);
- Expression oneExpression = new Expression(ExpressionType.LITERAL);
- oneExpression.setLiteral(oneLiteral);
- return RequestUtils.getFunctionExpression("count", oneExpression);
+ private Expression createConstantTimesCount(Expression constant, Expression
column) {
+ Expression countExpr = RequestUtils.getFunctionExpression("count", column);
+ return RequestUtils.getFunctionExpression("mult", constant, countExpr);
}
/**
diff --git a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/rewriter/AggregationOptimizerTest.java b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/rewriter/AggregationOptimizerTest.java
index 1157a3e6021..3f9c3d0f808 100644
--- a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/rewriter/AggregationOptimizerTest.java
+++ b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/rewriter/AggregationOptimizerTest.java
@@ -19,104 +19,189 @@
package org.apache.pinot.sql.parsers.rewriter;
import java.util.ArrayList;
-import java.util.Collections;
+import java.util.List;
import org.apache.pinot.common.request.Expression;
import org.apache.pinot.common.request.ExpressionType;
import org.apache.pinot.common.request.Function;
import org.apache.pinot.common.request.PinotQuery;
-import org.apache.pinot.common.utils.request.RequestUtils;
+import org.apache.pinot.common.utils.config.QueryOptionsUtils;
+import
org.apache.pinot.spi.utils.CommonConstants.Broker.Request.QueryOptionKey;
import org.apache.pinot.sql.parsers.CalciteSqlParser;
+import org.apache.pinot.sql.parsers.SqlNodeAndOptions;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertNotNull;
+import static org.testng.Assert.assertTrue;
public class AggregationOptimizerTest {
private final AggregationOptimizer _optimizer = new AggregationOptimizer();
+ @BeforeClass
+ public void setUp() {
+ List<QueryRewriter> queryRewriters = new ArrayList<>();
+ for (QueryRewriter queryRewriter :
QueryRewriterFactory.getQueryRewriters()) {
+ if (!(queryRewriter instanceof AggregationOptimizer)) {
+ queryRewriters.add(queryRewriter);
+ }
+ }
+ CalciteSqlParser.QUERY_REWRITERS.clear();
+ CalciteSqlParser.QUERY_REWRITERS.addAll(queryRewriters);
+ }
+
+ @AfterClass
+ public void tearDown() {
+ CalciteSqlParser.QUERY_REWRITERS.clear();
+
CalciteSqlParser.QUERY_REWRITERS.addAll(QueryRewriterFactory.getQueryRewriters());
+ }
+
@Test
public void testSumColumnPlusConstant() {
- // Test: SELECT sum(met + 2) → SELECT sum(met) + 2 * count(1)
+ // Test: SELECT sum(met + 2) → SELECT sum(met) + 2 * count(met)
String query = "SELECT sum(met + 2) FROM mytable";
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
+ Expression original = pinotQuery.getSelectList().get(0);
+ assertTopLevelOperator(original, "sum");
+
// Apply optimizer
_optimizer.rewrite(pinotQuery);
// Verify optimization
Expression selectExpression = pinotQuery.getSelectList().get(0);
+ assertTopLevelOperator(selectExpression, "add");
+ verifyOptimizedAddition(selectExpression, "met", 2);
+ }
+
+ @Test
+ public void testSumColumnPlusConstantWithNullHandlingEnabled() {
+ // Test: Optimizer rewrites using count(column) (null handling on)
+ String query = "SET enableNullHandling=true; SELECT sum(met + 2) FROM
mytable";
+ SqlNodeAndOptions sqlNodeAndOptions =
CalciteSqlParser.compileToSqlNodeAndOptions(query);
+ PinotQuery pinotQuery =
CalciteSqlParser.compileToPinotQuery(sqlNodeAndOptions);
+
+ Expression original = pinotQuery.getSelectList().get(0);
+ assertTopLevelOperator(original, "sum");
+
+ // Apply optimizer
+ _optimizer.rewrite(pinotQuery);
+
+
assertTrue(QueryOptionsUtils.isNullHandlingEnabled(pinotQuery.getQueryOptions()));
+ Expression selectExpression = pinotQuery.getSelectList().get(0);
+ assertTopLevelOperator(selectExpression, "add");
verifyOptimizedAddition(selectExpression, "met", 2);
}
+ @Test
+ public void testSumRewriteUsesCountColumnWithNullHandling() {
+ // Ensure the rewrite uses count(column) when null handling is enabled
+ SqlNodeAndOptions sqlNodeAndOptions =
CalciteSqlParser.compileToSqlNodeAndOptions("SELECT sum(met + 2) FROM t");
+ sqlNodeAndOptions.getOptions().put(QueryOptionKey.ENABLE_NULL_HANDLING,
"true");
+ PinotQuery pinotQuery =
CalciteSqlParser.compileToPinotQuery(sqlNodeAndOptions);
+
+ Expression original = pinotQuery.getSelectList().get(0);
+ assertTopLevelOperator(original, "sum");
+
+ _optimizer.rewrite(pinotQuery);
+
+ Expression selectExpression = pinotQuery.getSelectList().get(0);
+ assertTopLevelOperator(selectExpression, "add");
+ Function multFunction =
selectExpression.getFunctionCall().getOperands().get(1).getFunctionCall();
+ Function countFunction =
multFunction.getOperands().get(1).getFunctionCall();
+ Expression countArg = countFunction.getOperands().get(0);
+
+ // Verify we use count(column) (identifier) instead of count(1) (literal)
to preserve semantics
+ assertEquals(countArg.getType(), ExpressionType.IDENTIFIER);
+ assertEquals(countArg.getIdentifier().getName(), "met");
+ }
+
@Test
public void testSumConstantPlusColumn() {
- // Test: SELECT sum(2 + met) → SELECT sum(met) + 2 * count(1)
+ // Test: SELECT sum(2 + met) → SELECT sum(met) + 2 * count(met)
String query = "SELECT sum(2 + met) FROM mytable";
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
+ Expression original = pinotQuery.getSelectList().get(0);
+ assertTopLevelOperator(original, "sum");
+
// Apply optimizer
_optimizer.rewrite(pinotQuery);
// Verify optimization
Expression selectExpression = pinotQuery.getSelectList().get(0);
+ assertTopLevelOperator(selectExpression, "add");
verifyOptimizedAddition(selectExpression, "met", 2);
}
@Test
public void testSumColumnMinusConstant() {
- // Test: SELECT sum(met - 5) → SELECT sum(met) - 5 * count(1)
+ // Test: SELECT sum(met - 5) → SELECT sum(met) - 5 * count(met)
String query = "SELECT sum(met - 5) FROM mytable";
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
+ Expression original = pinotQuery.getSelectList().get(0);
+ assertTopLevelOperator(original, "sum");
+
// Apply optimizer
_optimizer.rewrite(pinotQuery);
// Verify optimization
Expression selectExpression = pinotQuery.getSelectList().get(0);
+ assertTopLevelOperator(selectExpression, "sub");
verifyOptimizedSubtraction(selectExpression, "met", 5);
}
@Test
public void testSumConstantMinusColumn() {
- // Test: SELECT sum(10 - met) → SELECT 10 * count(1) - sum(met)
+ // Test: SELECT sum(10 - met) → SELECT 10 * count(met) - sum(met)
String query = "SELECT sum(10 - met) FROM mytable";
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
+ Expression original = pinotQuery.getSelectList().get(0);
+ assertTopLevelOperator(original, "sum");
+
// Apply optimizer
_optimizer.rewrite(pinotQuery);
// Verify optimization
Expression selectExpression = pinotQuery.getSelectList().get(0);
+ assertTopLevelOperator(selectExpression, "sub");
verifyOptimizedSubtractionReversed(selectExpression, 10, "met");
}
@Test
public void testSumWithFloatConstant() {
- // Test: SELECT sum(price + 2.5) → SELECT sum(price) + 2.5 * count(1)
+ // Test: SELECT sum(price + 2.5) → SELECT sum(price) + 2.5 * count(price)
String query = "SELECT sum(price + 2.5) FROM mytable";
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
+ Expression original = pinotQuery.getSelectList().get(0);
+ assertTopLevelOperator(original, "sum");
+
// Apply optimizer
_optimizer.rewrite(pinotQuery);
// Verify optimization
Expression selectExpression = pinotQuery.getSelectList().get(0);
+ assertTopLevelOperator(selectExpression, "add");
verifyOptimizedFloatAddition(selectExpression, "price", 2.5);
}
@Test
public void testSumMultiplicationOptimized() {
- // Build sum(met * 2) manually to avoid parser constant folding
- Expression multiplication = RequestUtils.getFunctionExpression("mult",
- RequestUtils.getIdentifierExpression("met"),
RequestUtils.getLiteralExpression(2));
- Expression sum = RequestUtils.getFunctionExpression("sum", multiplication);
- PinotQuery pinotQuery = buildQueryWithSelect(sum);
+ String query = "SELECT sum(met * 2) FROM mytable";
+ PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "sum");
_optimizer.rewrite(pinotQuery);
Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "mult");
assertEquals(rewritten.getOperator(), "mult");
assertEquals(rewritten.getOperands().size(), 2);
@@ -128,15 +213,16 @@ public class AggregationOptimizerTest {
@Test
public void testMinMultiplicationWithNegativeConstant() {
- // Build min(score * -3.5) manually; negative constant should flip MIN to
MAX
- Expression multiplication = RequestUtils.getFunctionExpression("multiply",
- RequestUtils.getIdentifierExpression("score"),
RequestUtils.getLiteralExpression(-3.5));
- Expression min = RequestUtils.getFunctionExpression("min", multiplication);
- PinotQuery pinotQuery = buildQueryWithSelect(min);
+ // Parse min(score * -3.5); negative constant should flip MIN to MAX
+ String query = "SELECT min(score * -3.5) FROM mytable";
+ PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "min");
_optimizer.rewrite(pinotQuery);
Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "mult");
assertEquals(rewritten.getOperator(), "mult");
Function flippedAggregation =
rewritten.getOperands().get(0).getFunctionCall();
@@ -147,15 +233,16 @@ public class AggregationOptimizerTest {
@Test
public void testMaxMultiplicationWithNegativeConstant() {
- // Build max(score * -2) manually; negative constant should flip MAX to MIN
- Expression multiplication = RequestUtils.getFunctionExpression("mul",
- RequestUtils.getIdentifierExpression("score"),
RequestUtils.getLiteralExpression(-2));
- Expression max = RequestUtils.getFunctionExpression("max", multiplication);
- PinotQuery pinotQuery = buildQueryWithSelect(max);
+ // Parse max(score * -2); negative constant should flip MAX to MIN
+ String query = "SELECT max(score * -2) FROM mytable";
+ PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "max");
_optimizer.rewrite(pinotQuery);
Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "mult");
assertEquals(rewritten.getOperator(), "mult");
Function flippedAggregation =
rewritten.getOperands().get(0).getFunctionCall();
@@ -166,25 +253,20 @@ public class AggregationOptimizerTest {
@Test
public void testMultiplicationWithTwoColumnsNotOptimized() {
- // Build sum(a * b) manually; should remain unchanged because neither side
is a constant
- Expression multiplication = RequestUtils.getFunctionExpression("mult",
- RequestUtils.getIdentifierExpression("a"),
RequestUtils.getIdentifierExpression("b"));
- Expression sum = RequestUtils.getFunctionExpression("sum", multiplication);
- PinotQuery pinotQuery = buildQueryWithSelect(sum);
+ // Parse sum(a * b); should remain unchanged because neither side is a
constant
+ String query = "SELECT sum(a * b) FROM mytable";
+ PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "sum");
_optimizer.rewrite(pinotQuery);
Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "sum");
assertEquals(rewritten.getOperator(), "sum");
assertEquals(rewritten.getOperands().size(), 1);
Function multiplicationFunction =
rewritten.getOperands().get(0).getFunctionCall();
- assertEquals(multiplicationFunction.getOperator(), "mult");
- }
-
- private PinotQuery buildQueryWithSelect(Expression expression) {
- PinotQuery pinotQuery = new PinotQuery();
- pinotQuery.setSelectList(new
ArrayList<>(Collections.singletonList(expression)));
- return pinotQuery;
+ assertEquals(multiplicationFunction.getOperator().toLowerCase(), "times");
}
@Test
@@ -193,33 +275,39 @@ public class AggregationOptimizerTest {
String query = "SELECT sum(a + 1), sum(b - 2), avg(c) FROM mytable";
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "sum");
+ assertTopLevelOperator(pinotQuery.getSelectList().get(1), "sum");
+ assertTopLevelOperator(pinotQuery.getSelectList().get(2), "avg");
+
// Apply optimizer
_optimizer.rewrite(pinotQuery);
// Verify optimizations
assertEquals(pinotQuery.getSelectList().size(), 3);
- // First aggregation: sum(a + 1) → sum(a) + 1 * count(1)
+ // First aggregation: sum(a + 1) → sum(a) + 1 * count(a)
Expression firstExpression = pinotQuery.getSelectList().get(0);
+ assertTopLevelOperator(firstExpression, "add");
verifyOptimizedAddition(firstExpression, "a", 1);
- // Second aggregation: sum(b - 2) → sum(b) - 2 * count(1)
+ // Second aggregation: sum(b - 2) → sum(b) - 2 * count(b)
Expression secondExpression = pinotQuery.getSelectList().get(1);
+ assertTopLevelOperator(secondExpression, "sub");
verifyOptimizedSubtraction(secondExpression, "b", 2);
// Third aggregation: avg(c) should remain unchanged
Expression thirdExpression = pinotQuery.getSelectList().get(2);
+ assertTopLevelOperator(thirdExpression, "avg");
assertEquals(thirdExpression.getFunctionCall().getOperator(), "avg");
}
@Test
public void testNoOptimizationForUnsupportedPatterns() {
- // Test patterns that should NOT be optimized
+ // Test patterns that should NOT be optimized; top-level aggregation stays
unchanged
String[] queries = {
"SELECT sum(a / 2) FROM mytable", // division not supported
"SELECT sum(a + b) FROM mytable", // both operands are columns
"SELECT sum(1 + 2) FROM mytable", // both operands are
constants
- "SELECT avg(a + 2) FROM mytable", // not a sum function
"SELECT sum(a) FROM mytable", // no arithmetic expression
"SELECT sum(a + b + c) FROM mytable" // more than 2 operands
};
@@ -229,6 +317,7 @@ public class AggregationOptimizerTest {
// Store original function operator before optimization
String originalOperator =
pinotQuery.getSelectList().get(0).getFunctionCall().getOperator();
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0),
originalOperator);
// Apply optimizer
_optimizer.rewrite(pinotQuery);
@@ -236,6 +325,7 @@ public class AggregationOptimizerTest {
// Verify no optimization occurred - the outer function should remain
unchanged
Expression optimized = pinotQuery.getSelectList().get(0);
assertEquals(originalOperator,
optimized.getFunctionCall().getOperator());
+ assertTopLevelOperator(optimized, originalOperator);
// Additional verification: for queries that have inner arithmetic,
ensure they weren't rewritten
Function outerFunction = optimized.getFunctionCall();
@@ -255,11 +345,14 @@ public class AggregationOptimizerTest {
String query = "SELECT sum(value + 10) FROM mytable GROUP BY category";
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "sum");
+
// Apply optimizer
_optimizer.rewrite(pinotQuery);
// Verify optimization occurred
Expression selectExpression = pinotQuery.getSelectList().get(0);
+ assertTopLevelOperator(selectExpression, "add");
verifyOptimizedAddition(selectExpression, "value", 10);
// Verify GROUP BY is preserved
@@ -274,14 +367,17 @@ public class AggregationOptimizerTest {
String query = "SELECT sum(met + 2) FROM mytable";
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "sum");
+
// Apply optimizer
_optimizer.rewrite(pinotQuery);
// Verify the optimization worked
Expression selectExpression = pinotQuery.getSelectList().get(0);
+ assertTopLevelOperator(selectExpression, "add");
Function function = selectExpression.getFunctionCall();
- // Should be rewritten from sum(met + 2) to add(sum(met), mult(2,
count(1)))
+ // Should be rewritten from sum(met + 2) to add(sum(met), mult(2,
count(met)))
assertEquals(function.getOperator(), "add");
assertEquals(function.getOperands().size(), 2);
@@ -290,226 +386,278 @@ public class AggregationOptimizerTest {
assertEquals(sumExpr.getFunctionCall().getOperator(), "sum");
assertEquals(sumExpr.getFunctionCall().getOperands().get(0).getIdentifier().getName(),
"met");
- // Second operand: mult(2, count(1))
+ // Second operand: mult(2, count(met))
Expression multExpr = function.getOperands().get(1);
assertEquals(multExpr.getFunctionCall().getOperator(), "mult");
- System.out.println("✓ Successfully optimized: sum(met + 2) → sum(met) + 2
* count(1)");
+ System.out.println("✓ Successfully optimized: sum(met + 2) → sum(met) + 2
* count(met)");
}
// ==================== AVG FUNCTION TESTS ====================
- // NOTE: AVG optimizations for column+constant are limited due to Pinot's
parser doing
- // constant folding before our optimizer runs. These tests verify current
behavior.
+ // AVG tests verify that aggregations remain the top-level operator before
rewriting and become arithmetic after.
@Test
public void testAvgColumnPlusConstant() {
- // Test: SELECT avg(value + 10) - Due to constant folding, this is NOT
optimized
- String query = "SELECT avg(value + 10) FROM mytable";
+ // Test: SELECT avg(value + 10)
+ String query = "SELECT avg(\"value\" + 10) FROM mytable";
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
- PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "avg");
_optimizer.rewrite(pinotQuery);
- // Should remain unchanged due to constant folding in parser
- assertEquals(pinotQuery.getSelectList().get(0).toString(),
- originalQuery.getSelectList().get(0).toString());
+ Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "add");
+ assertEquals(rewritten.getOperator(), "add");
+
assertEquals(rewritten.getOperands().get(0).getFunctionCall().getOperator(),
"avg");
+ assertEquals(rewritten.getOperands().get(1).getLiteral().getIntValue(),
10);
}
@Test
public void testAvgConstantPlusColumn() {
- // Test: SELECT avg(5 + salary) - Due to constant folding, this is NOT
optimized
+ // Test: SELECT avg(5 + salary)
String query = "SELECT avg(5 + salary) FROM mytable";
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
- PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "avg");
_optimizer.rewrite(pinotQuery);
- // Should remain unchanged due to constant folding in parser
- assertEquals(pinotQuery.getSelectList().get(0).toString(),
- originalQuery.getSelectList().get(0).toString());
+ Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "add");
+ assertEquals(rewritten.getOperator(), "add");
+
assertEquals(rewritten.getOperands().get(0).getFunctionCall().getOperator(),
"avg");
+ assertEquals(rewritten.getOperands().get(1).getLiteral().getIntValue(), 5);
}
@Test
public void testAvgColumnMinusConstant() {
- // Test: SELECT avg(price - 100) - Due to constant folding, this is NOT
optimized
+ // Test: SELECT avg(price - 100)
String query = "SELECT avg(price - 100) FROM mytable";
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
- PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "avg");
_optimizer.rewrite(pinotQuery);
- // Should remain unchanged due to constant folding in parser
- assertEquals(pinotQuery.getSelectList().get(0).toString(),
- originalQuery.getSelectList().get(0).toString());
+ Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "sub");
+ assertEquals(rewritten.getOperator(), "sub");
+
assertEquals(rewritten.getOperands().get(0).getFunctionCall().getOperator(),
"avg");
+ assertEquals(rewritten.getOperands().get(1).getLiteral().getIntValue(),
100);
}
@Test
public void testAvgConstantMinusColumn() {
- // Test: SELECT avg(1000 - cost) - Due to constant folding, this is NOT
optimized
+ // Test: SELECT avg(1000 - cost)
String query = "SELECT avg(1000 - cost) FROM mytable";
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
- PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "avg");
_optimizer.rewrite(pinotQuery);
- // Should remain unchanged due to constant folding in parser
- assertEquals(pinotQuery.getSelectList().get(0).toString(),
- originalQuery.getSelectList().get(0).toString());
+ Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "sub");
+ assertEquals(rewritten.getOperator(), "sub");
+ assertEquals(rewritten.getOperands().get(0).getLiteral().getIntValue(),
1000);
+
assertEquals(rewritten.getOperands().get(1).getFunctionCall().getOperator(),
"avg");
}
@Test
public void testAvgColumnTimesConstant() {
- // Test: SELECT avg(quantity * 2.5) - Due to constant folding, this is NOT
optimized
+ // Test: SELECT avg(quantity * 2.5)
String query = "SELECT avg(quantity * 2.5) FROM mytable";
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
- PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "avg");
_optimizer.rewrite(pinotQuery);
- // Should remain unchanged due to constant folding in parser
- assertEquals(pinotQuery.getSelectList().get(0).toString(),
- originalQuery.getSelectList().get(0).toString());
+ Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "mult");
+ assertEquals(rewritten.getOperator(), "mult");
+
+ Function avgFunction = rewritten.getOperands().get(0).getFunctionCall();
+ assertEquals(avgFunction.getOperator(), "avg");
+ assertEquals(avgFunction.getOperands().get(0).getIdentifier().getName(),
"quantity");
+ assertEquals(rewritten.getOperands().get(1).getLiteral().getDoubleValue(),
2.5, 0.0001);
}
// ==================== MIN FUNCTION TESTS ====================
- // NOTE: MIN optimizations for column+constant are limited due to Pinot's
parser doing
- // constant folding before our optimizer runs. These tests verify current
behavior.
+ // MIN tests ensure the optimizer rewrites aggregation roots into arithmetic
(or flips on negative multipliers).
@Test
public void testMinColumnPlusConstant() {
- // Test: SELECT min(score + 50) - Due to constant folding, this is NOT
optimized
+ // Test: SELECT min(score + 50) -> min(score) + 50
String query = "SELECT min(score + 50) FROM mytable";
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
- PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "min");
_optimizer.rewrite(pinotQuery);
- // Should remain unchanged due to constant folding in parser
- assertEquals(pinotQuery.getSelectList().get(0).toString(),
- originalQuery.getSelectList().get(0).toString());
+ Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "add");
+ assertEquals(rewritten.getOperator(), "add");
+
assertEquals(rewritten.getOperands().get(0).getFunctionCall().getOperator(),
"min");
+ assertEquals(rewritten.getOperands().get(1).getLiteral().getIntValue(),
50);
}
@Test
public void testMinConstantMinusColumn() {
- // Test: SELECT min(100 - temperature) - Due to constant folding, this is
NOT optimized
+ // Test: SELECT min(100 - temperature) -> 100 - max(temperature)
String query = "SELECT min(100 - temperature) FROM mytable";
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
- PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "min");
_optimizer.rewrite(pinotQuery);
- // Should remain unchanged due to constant folding in parser
- assertEquals(pinotQuery.getSelectList().get(0).toString(),
- originalQuery.getSelectList().get(0).toString());
+ Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "sub");
+ assertEquals(rewritten.getOperator(), "sub");
+ assertEquals(rewritten.getOperands().get(0).getLiteral().getIntValue(),
100);
+
assertEquals(rewritten.getOperands().get(1).getFunctionCall().getOperator(),
"max");
}
@Test
public void testMinColumnTimesPositiveConstant() {
- // Test: SELECT min(value * 3) - Due to constant folding, this is NOT optimized
+ // Parse min(value * 3); positive constant keeps MIN
String query = "SELECT min(value * 3) FROM mytable";
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
- PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "min");
_optimizer.rewrite(pinotQuery);
- // Should remain unchanged due to constant folding in parser
- assertEquals(pinotQuery.getSelectList().get(0).toString(),
- originalQuery.getSelectList().get(0).toString());
+ Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "mult");
+ assertEquals(rewritten.getOperator(), "mult");
+
+ Function aggregation = rewritten.getOperands().get(0).getFunctionCall();
+ assertEquals(aggregation.getOperator(), "min");
+ assertEquals(aggregation.getOperands().get(0).getIdentifier().getName(),
"value");
+ assertEquals(rewritten.getOperands().get(1).getLiteral().getIntValue(), 3);
}
@Test
public void testMinColumnTimesNegativeConstant() {
- // Test: SELECT min(value * -2) - Due to constant folding, this is NOT optimized
+ // Parse min(value * -2); negative constant should flip MIN to MAX
String query = "SELECT min(value * -2) FROM mytable";
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
- PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "min");
_optimizer.rewrite(pinotQuery);
- // Should remain unchanged due to constant folding in parser
- assertEquals(pinotQuery.getSelectList().get(0).toString(),
- originalQuery.getSelectList().get(0).toString());
+ Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "mult");
+ assertEquals(rewritten.getOperator(), "mult");
+
+ Function aggregation = rewritten.getOperands().get(0).getFunctionCall();
+ assertEquals(aggregation.getOperator(), "max");
+ assertEquals(aggregation.getOperands().get(0).getIdentifier().getName(),
"value");
+ assertEquals(rewritten.getOperands().get(1).getLiteral().getIntValue(),
-2);
}
// ==================== MAX FUNCTION TESTS ====================
- // NOTE: MAX optimizations for column+constant are limited due to Pinot's parser doing
- // constant folding before our optimizer runs. These tests verify current behavior.
+ // MAX tests follow the same pattern: aggregation before, arithmetic after (with flips on negative multipliers).
@Test
public void testMaxColumnPlusConstant() {
- // Test: SELECT max(height + 10) - Due to constant folding, this is NOT optimized
+ // Test: SELECT max(height + 10) -> max(height) + 10
String query = "SELECT max(height + 10) FROM mytable";
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
- PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "max");
_optimizer.rewrite(pinotQuery);
- // Should remain unchanged due to constant folding in parser
- assertEquals(pinotQuery.getSelectList().get(0).toString(),
- originalQuery.getSelectList().get(0).toString());
+ Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "add");
+ assertEquals(rewritten.getOperator(), "add");
+
assertEquals(rewritten.getOperands().get(0).getFunctionCall().getOperator(),
"max");
+ assertEquals(rewritten.getOperands().get(1).getLiteral().getIntValue(),
10);
}
@Test
public void testMaxConstantMinusColumn() {
- // Test: SELECT max(200 - age) - Due to constant folding, this is NOT optimized
+ // Test: SELECT max(200 - age) -> 200 - min(age)
String query = "SELECT max(200 - age) FROM mytable";
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
- PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "max");
_optimizer.rewrite(pinotQuery);
- // Should remain unchanged due to constant folding in parser
- assertEquals(pinotQuery.getSelectList().get(0).toString(),
- originalQuery.getSelectList().get(0).toString());
+ Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "sub");
+ assertEquals(rewritten.getOperator(), "sub");
+ assertEquals(rewritten.getOperands().get(0).getLiteral().getIntValue(),
200);
+
assertEquals(rewritten.getOperands().get(1).getFunctionCall().getOperator(),
"min");
}
@Test
public void testMaxColumnTimesNegativeConstant() {
- // Test: SELECT max(profit * -1) - Due to constant folding, this is NOT optimized
+ // Parse max(profit * -1); negative constant should flip MAX to MIN
String query = "SELECT max(profit * -1) FROM mytable";
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
- PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "max");
_optimizer.rewrite(pinotQuery);
- // Should remain unchanged due to constant folding in parser
- assertEquals(pinotQuery.getSelectList().get(0).toString(),
- originalQuery.getSelectList().get(0).toString());
+ Function rewritten = pinotQuery.getSelectList().get(0).getFunctionCall();
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "mult");
+ assertEquals(rewritten.getOperator(), "mult");
+
+ Function flippedAggregation =
rewritten.getOperands().get(0).getFunctionCall();
+ assertEquals(flippedAggregation.getOperator(), "min");
+
assertEquals(flippedAggregation.getOperands().get(0).getIdentifier().getName(),
"profit");
+ assertEquals(rewritten.getOperands().get(1).getLiteral().getIntValue(),
-1);
}
// ==================== COMPLEX MIXED TESTS ====================
@Test
public void testMixedAggregationOptimizations() {
- // Test multiple different aggregations in one query
- // Only SUM should be optimized due to parser constant folding limitations
+ // Test multiple different aggregations in one query; each should be rewritten
String query = "SELECT sum(a + 1), avg(b - 2), min(c * 3), max(d + 4) FROM
mytable";
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "sum");
+ assertTopLevelOperator(pinotQuery.getSelectList().get(1), "avg");
+ assertTopLevelOperator(pinotQuery.getSelectList().get(2), "min");
+ assertTopLevelOperator(pinotQuery.getSelectList().get(3), "max");
+
_optimizer.rewrite(pinotQuery);
assertEquals(pinotQuery.getSelectList().size(), 4);
- // sum(a + 1) → sum(a) + 1 * count(1) - This SHOULD be optimized
+ // sum(a + 1) → sum(a) + 1 * count(a) - This SHOULD be optimized
verifyOptimizedAddition(pinotQuery.getSelectList().get(0), "a", 1);
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0), "add");
+
+ // avg(b - 2) -> avg(b) - 2
+ Function avgRewrite = pinotQuery.getSelectList().get(1).getFunctionCall();
+ assertTopLevelOperator(pinotQuery.getSelectList().get(1), "sub");
+ assertEquals(avgRewrite.getOperator(), "sub");
- // avg(b - 2), min(c * 3), max(d + 4) - These should NOT be optimized due to constant folding
- // We'll verify they remain unchanged by comparing with original parsed query
- String originalQuery = "SELECT sum(a + 1), avg(b - 2), min(c * 3), max(d +
4) FROM mytable";
- PinotQuery originalPinotQuery =
CalciteSqlParser.compileToPinotQuery(originalQuery);
+ // min(c * 3) -> min(c) * 3
+ Function minRewrite = pinotQuery.getSelectList().get(2).getFunctionCall();
+ assertTopLevelOperator(pinotQuery.getSelectList().get(2), "mult");
+ assertEquals(minRewrite.getOperator(), "mult");
- // The original avg, min, max should remain the same (only sum gets optimized)
- assertEquals(pinotQuery.getSelectList().get(1).toString(),
- originalPinotQuery.getSelectList().get(1).toString());
- assertEquals(pinotQuery.getSelectList().get(2).toString(),
- originalPinotQuery.getSelectList().get(2).toString());
- assertEquals(pinotQuery.getSelectList().get(3).toString(),
- originalPinotQuery.getSelectList().get(3).toString());
+ // max(d + 4) -> max(d) + 4
+ Function maxRewrite = pinotQuery.getSelectList().get(3).getFunctionCall();
+ assertTopLevelOperator(pinotQuery.getSelectList().get(3), "add");
+ assertEquals(maxRewrite.getOperator(), "add");
}
@Test
public void testNonOptimizableQueries() {
- // Queries that should NOT be optimized
+ // Queries that should NOT be optimized; verify the aggregation root remains in place
String[] queries = {
"SELECT sum(a * b) FROM mytable", // Both operands are columns
"SELECT avg(func(x)) FROM mytable", // Function call, not arithmetic
@@ -521,16 +669,30 @@ public class AggregationOptimizerTest {
PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query);
PinotQuery originalQuery = CalciteSqlParser.compileToPinotQuery(query);
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0),
+ pinotQuery.getSelectList().get(0).getFunctionCall().getOperator());
+
_optimizer.rewrite(pinotQuery);
// Should remain unchanged
assertEquals(pinotQuery.getSelectList().get(0).toString(),
originalQuery.getSelectList().get(0).toString());
+ assertTopLevelOperator(pinotQuery.getSelectList().get(0),
+
originalQuery.getSelectList().get(0).getFunctionCall().getOperator());
}
}
/**
- * Verifies that the expression is optimized to: sum(column) + constant * count(1)
+ * Helper to assert the top-level function operator of an expression.
+ */
+ private void assertTopLevelOperator(Expression expression, String
expectedOperator) {
+ Function functionCall = expression.getFunctionCall();
+ assertNotNull(functionCall);
+ assertEquals(functionCall.getOperator(), expectedOperator);
+ }
+
+ /**
+ * Verifies that the expression is optimized to: sum(column) + constant * count(column)
*/
private void verifyOptimizedAddition(Expression expression, String
columnName, int constantValue) {
Function function = expression.getFunctionCall();
@@ -543,7 +705,7 @@ public class AggregationOptimizerTest {
assertEquals(sumExpression.getFunctionCall().getOperator(), "sum");
assertEquals(sumExpression.getFunctionCall().getOperands().get(0).getIdentifier().getName(),
columnName);
- // Second operand should be constant * count(1)
+ // Second operand should be constant * count(column)
Expression multExpression = function.getOperands().get(1);
assertEquals(multExpression.getFunctionCall().getOperator(), "mult");
@@ -551,14 +713,17 @@ public class AggregationOptimizerTest {
Expression constantExpr =
multExpression.getFunctionCall().getOperands().get(0);
assertEquals(constantExpr.getLiteral().getIntValue(), constantValue);
- // Verify count(1)
+ // Verify count(column)
Expression countExpr =
multExpression.getFunctionCall().getOperands().get(1);
assertEquals(countExpr.getFunctionCall().getOperator(), "count");
-
assertEquals(countExpr.getFunctionCall().getOperands().get(0).getLiteral().getIntValue(),
1);
+ Expression countOperand = countExpr.getFunctionCall().getOperands().get(0);
+ assertEquals(countOperand.getType(), ExpressionType.IDENTIFIER);
+ assertNotNull(countOperand.getIdentifier());
+ assertEquals(countOperand.getIdentifier().getName(), columnName);
}
/**
- * Verifies that the expression is optimized to: sum(column) + constant *
count(1) for float constants
+ * Verifies that the expression is optimized to: sum(column) + constant * count(column) for float constants
*/
private void verifyOptimizedFloatAddition(Expression expression, String
columnName, double constantValue) {
Function function = expression.getFunctionCall();
@@ -571,17 +736,24 @@ public class AggregationOptimizerTest {
assertEquals(sumExpression.getFunctionCall().getOperator(), "sum");
assertEquals(sumExpression.getFunctionCall().getOperands().get(0).getIdentifier().getName(),
columnName);
- // Second operand should be constant * count(1)
+ // Second operand should be constant * count(column)
Expression multExpression = function.getOperands().get(1);
assertEquals(multExpression.getFunctionCall().getOperator(), "mult");
// Verify constant value (for float, check double value)
Expression constantExpr =
multExpression.getFunctionCall().getOperands().get(0);
assertEquals(constantExpr.getLiteral().getDoubleValue(), constantValue,
0.0001);
+
+ // Verify count(column)
+ Expression countExpr =
multExpression.getFunctionCall().getOperands().get(1);
+ assertEquals(countExpr.getFunctionCall().getOperator(), "count");
+ Expression countOperand = countExpr.getFunctionCall().getOperands().get(0);
+ assertEquals(countOperand.getType(), ExpressionType.IDENTIFIER);
+ assertEquals(countOperand.getIdentifier().getName(), columnName);
}
/**
- * Verifies that the expression is optimized to: sum(column) - constant *
count(1)
+ * Verifies that the expression is optimized to: sum(column) - constant * count(column)
*/
private void verifyOptimizedSubtraction(Expression expression, String
columnName, int constantValue) {
Function function = expression.getFunctionCall();
@@ -594,17 +766,24 @@ public class AggregationOptimizerTest {
assertEquals(sumExpression.getFunctionCall().getOperator(), "sum");
assertEquals(sumExpression.getFunctionCall().getOperands().get(0).getIdentifier().getName(),
columnName);
- // Second operand should be constant * count(1)
+ // Second operand should be constant * count(column)
Expression multExpression = function.getOperands().get(1);
assertEquals(multExpression.getFunctionCall().getOperator(), "mult");
// Verify constant value
Expression constantExpr =
multExpression.getFunctionCall().getOperands().get(0);
assertEquals(constantExpr.getLiteral().getIntValue(), constantValue);
+
+ // Verify count(column)
+ Expression countExpr =
multExpression.getFunctionCall().getOperands().get(1);
+ assertEquals(countExpr.getFunctionCall().getOperator(), "count");
+ Expression countOperand = countExpr.getFunctionCall().getOperands().get(0);
+ assertEquals(countOperand.getType(), ExpressionType.IDENTIFIER);
+ assertEquals(countOperand.getIdentifier().getName(), columnName);
}
/**
- * Verifies that the expression is optimized to: constant * count(1) -
sum(column)
+ * Verifies that the expression is optimized to: constant * count(column) - sum(column)
*/
private void verifyOptimizedSubtractionReversed(Expression expression, int
constantValue, String columnName) {
Function function = expression.getFunctionCall();
@@ -612,7 +791,7 @@ public class AggregationOptimizerTest {
assertEquals(function.getOperator(), "sub");
assertEquals(function.getOperands().size(), 2);
- // First operand should be constant * count(1)
+ // First operand should be constant * count(column)
Expression multExpression = function.getOperands().get(0);
assertEquals(multExpression.getFunctionCall().getOperator(), "mult");
@@ -620,6 +799,13 @@ public class AggregationOptimizerTest {
Expression constantExpr =
multExpression.getFunctionCall().getOperands().get(0);
assertEquals(constantExpr.getLiteral().getIntValue(), constantValue);
+ // Verify count(column)
+ Expression countExpr =
multExpression.getFunctionCall().getOperands().get(1);
+ assertEquals(countExpr.getFunctionCall().getOperator(), "count");
+ Expression countOperand = countExpr.getFunctionCall().getOperands().get(0);
+ assertEquals(countOperand.getType(), ExpressionType.IDENTIFIER);
+ assertEquals(countOperand.getIdentifier().getName(), columnName);
+
// Second operand should be sum(column)
Expression sumExpression = function.getOperands().get(1);
assertEquals(sumExpression.getFunctionCall().getOperator(), "sum");
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]