Copilot commented on code in PR #17168:
URL: https://github.com/apache/pinot/pull/17168#discussion_r2586691031
##########
pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RelToPlanNodeConverter.java:
##########
@@ -159,6 +182,392 @@ public PlanNode toPlanNode(RelNode node) {
return result;
}
+ private UnnestNode convertLogicalUncollect(Uncollect node) {
+ // Extract array expressions (typically from a Project with one or more
expressions)
+ List<RexExpression> arrayExprs = new ArrayList<>();
+ RelNode input = node.getInput();
+ boolean withOrdinality = node.withOrdinality;
+
+ if (input instanceof Project) {
+ Project p = (Project) input;
+ List<RelDataTypeField> outputFields = node.getRowType().getFieldList();
+ List<RelDataTypeField> projectFields = p.getRowType().getFieldList();
+ int numProjects = p.getProjects().size();
+ int numOutputFields = outputFields.size();
+
+ // Check if WITH ORDINALITY is present: output fields = project
expressions + ordinality
+ if (numOutputFields > numProjects) {
+ withOrdinality = true;
+ }
+
+ // Extract all array expressions from the Project
+ for (int i = 0; i < numProjects; i++) {
+ arrayExprs.add(RexExpressionUtils.fromRexNode(p.getProjects().get(i)));
+ }
+ }
+
+ if (arrayExprs.isEmpty()) {
+ // Fallback: refer to first input ref
+ arrayExprs.add(new RexExpression.InputRef(0));
+ List<RelDataTypeField> fields = node.getRowType().getFieldList();
+ // Check for ordinality in fallback case
+ if (fields.size() > 1 && !withOrdinality) {
+ withOrdinality = true;
+ }
+ }
+
+ List<Integer> elementIndexes = new ArrayList<>(arrayExprs.size());
+ for (int i = 0; i < arrayExprs.size(); i++) {
+ elementIndexes.add(i);
+ }
+ int ordinalityIndex = withOrdinality ? arrayExprs.size() :
UnnestNode.UNSPECIFIED_INDEX;
+ UnnestNode.TableFunctionContext tableFunctionContext =
+ new UnnestNode.TableFunctionContext(withOrdinality, elementIndexes,
ordinalityIndex);
+ return new UnnestNode(DEFAULT_STAGE_ID, toDataSchema(node.getRowType()),
NodeHint.EMPTY,
+ convertInputs(node.getInputs()), arrayExprs, tableFunctionContext);
+ }
+
+ private BasePlanNode convertLogicalCorrelate(LogicalCorrelate node) {
+ // Pattern: Correlate(left, Uncollect(Project(correlatedFields...)))
+ RelNode right = node.getRight();
+ RelDataType leftRowType = node.getLeft().getRowType();
+ Project aliasProject = right instanceof Project ? (Project) right : null;
+ Project correlatedProject = findProjectUnderUncollect(right);
+ List<RexExpression> arrayExprs = new ArrayList<>();
+ if (correlatedProject != null) {
+ List<RelDataTypeField> outputFields = node.getRowType().getFieldList();
+ // Extract all array expressions from the Project
+ // The output fields include: left columns + array elements +
(ordinality if present)
+ // We need to extract only the array element columns (skip left columns,
skip ordinality if present)
+ int leftColumnCount = leftRowType.getFieldCount();
+ int numProjects = correlatedProject.getProjects().size();
+ List<RelDataTypeField> projectFields =
correlatedProject.getRowType().getFieldList();
+ for (int i = 0; i < numProjects; i++) {
+ RexNode rex = correlatedProject.getProjects().get(i);
+ RexExpression arrayExpr = deriveArrayExpression(rex,
correlatedProject, leftRowType);
+ if (arrayExpr == null) {
+ arrayExpr = RexExpressionUtils.fromRexNode(rex);
+ }
+ arrayExprs.add(arrayExpr);
+ }
+ }
+ if (arrayExprs.isEmpty()) {
+ // Fallback: refer to first input ref
+ arrayExprs.add(new RexExpression.InputRef(0));
+ }
+ LogicalFilter correlateFilter = findCorrelateFilter(right);
+ boolean wrapWithFilter = correlateFilter != null;
+ RexNode filterCondition = wrapWithFilter ? correlateFilter.getCondition()
: null;
+ // Use the entire correlate output schema
+ PlanNode inputNode = toPlanNode(node.getLeft());
+ // Ensure inputs list is mutable because downstream visitors (e.g.,
withInputs methods) may modify the inputs list
+ List<PlanNode> inputs = new ArrayList<>();
+ inputs.add(inputNode);
+ ElementOrdinalInfo ordinalInfo = deriveElementOrdinalInfo(right,
leftRowType, node.getRowType(), arrayExprs.size());
+ boolean withOrdinality = ordinalInfo.hasOrdinality();
+ List<Integer> elementIndexes = ordinalInfo.getElementIndexes();
+ int ordinalityIndex = ordinalInfo.getOrdinalityIndex();
+ UnnestNode.TableFunctionContext tableFunctionContext =
+ new UnnestNode.TableFunctionContext(withOrdinality, elementIndexes,
ordinalityIndex);
+ UnnestNode unnest = new UnnestNode(DEFAULT_STAGE_ID,
toDataSchema(node.getRowType()), NodeHint.EMPTY,
+ inputs, arrayExprs, tableFunctionContext);
+ if (wrapWithFilter) {
+ // Wrap Unnest with a FilterNode; rewrite filter InputRefs to absolute
output indexes
+ // For multiple arrays, we need to handle rewriting differently
+ RexExpression rewritten = rewriteInputRefsForMultipleArrays(
+ RexExpressionUtils.fromRexNode(filterCondition), elementIndexes,
ordinalityIndex);
+ return new FilterNode(DEFAULT_STAGE_ID, toDataSchema(node.getRowType()),
NodeHint.EMPTY,
+ new ArrayList<>(List.of(unnest)), rewritten);
+ }
+ return unnest;
+ }
+
+ @Nullable
+ private static Project findProjectUnderUncollect(RelNode node) {
+ RelNode current = node;
+ while (current != null) {
+ if (current instanceof Uncollect) {
+ RelNode input = ((Uncollect) current).getInput();
+ return input instanceof Project ? (Project) input : null;
+ }
+ if (current instanceof Project) {
+ current = ((Project) current).getInput();
+ } else if (current instanceof LogicalFilter) {
+ current = ((LogicalFilter) current).getInput();
+ } else {
+ return null;
+ }
+ }
+ return null;
+ }
+
+ @Nullable
+ private static Uncollect findUncollect(RelNode node) {
+ RelNode current = node;
+ while (current != null) {
+ if (current instanceof Uncollect) {
+ return (Uncollect) current;
+ }
+ if (current instanceof Project) {
+ current = ((Project) current).getInput();
+ } else if (current instanceof LogicalFilter) {
+ current = ((LogicalFilter) current).getInput();
+ } else {
+ return null;
+ }
+ }
+ return null;
+ }
+
+ @Nullable
+ private RexExpression deriveArrayExpression(RexNode rex, Project project,
RelDataType leftRowType) {
+ Integer idx = resolveInputRefFromCorrel(rex, leftRowType);
+ if (idx != null) {
+ return new RexExpression.InputRef(idx);
+ }
+ RexExpression candidate = RexExpressionUtils.fromRexNode(rex);
+ return candidate instanceof RexExpression.InputRef ? candidate : null;
+ }
+
+ @Nullable
+ private static LogicalFilter findCorrelateFilter(RelNode node) {
+ RelNode current = node;
+ while (current instanceof Project || current instanceof LogicalFilter) {
+ if (current instanceof LogicalFilter) {
+ return (LogicalFilter) current;
+ }
+ current = ((Project) current).getInput();
+ }
+ return null;
+ }
+
+ private static ElementOrdinalInfo deriveElementOrdinalInfo(RelNode right,
RelDataType leftRowType,
+ RelDataType correlateOutputRowType, int numArrays) {
+ Uncollect uncollect = findUncollect(right);
+ boolean hasOrdinality = uncollect != null && uncollect.withOrdinality;
+ ElementOrdinalAccumulator accumulator =
+ new ElementOrdinalAccumulator(leftRowType.getFieldCount(), numArrays,
hasOrdinality);
+ if (correlateOutputRowType != null) {
+ // Use the Correlate's output row type which includes left columns +
unnested elements + ordinality
+ accumulator.populateFromCorrelateOutput(correlateOutputRowType,
leftRowType.getFieldCount());
+ } else {
+ // Fallback to old logic for non-Correlate cases
+ if (right instanceof Uncollect) {
+ accumulator.populateFromRowType(right.getRowType());
+ } else if (right instanceof Project) {
+ accumulator.populateFromProject((Project) right);
+ } else if (right instanceof LogicalFilter) {
+ LogicalFilter filter = (LogicalFilter) right;
+ RelNode filterInput = filter.getInput();
+ if (filterInput instanceof Uncollect) {
+ accumulator.populateFromRowType(filter.getRowType());
+ } else if (filterInput instanceof Project) {
+ accumulator.populateFromProject((Project) filterInput);
+ }
+ }
+ }
+ if (uncollect != null) {
+ accumulator.ensureOrdinalityFromRowType(uncollect.getRowType(),
uncollect.withOrdinality);
+ }
+ return accumulator.toInfo();
+ }
+
+ private static final class ElementOrdinalAccumulator {
+ private final int _base;
+ private final int _numArrays;
+ private final boolean _hasOrdinality;
+ private final List<Integer> _elementIndexes = new ArrayList<>();
+ private int _ordinalityIndex = -1;
+
+ ElementOrdinalAccumulator(int base, int numArrays, boolean hasOrdinality) {
+ _base = base;
+ _numArrays = numArrays;
+ _hasOrdinality = hasOrdinality;
+ }
+
+ void populateFromRowType(RelDataType rowType) {
+ List<RelDataTypeField> fields = rowType.getFieldList();
+ // Extract element aliases and indexes for all arrays
+ for (int i = 0; i < _numArrays && i < fields.size(); i++) {
+ _elementIndexes.add(_base + i);
+ }
+ for (int i = fields.size(); i < _numArrays; i++) {
+ _elementIndexes.add(_base + i);
+ }
+ // Check if ordinality is present: fields.size() should be numArrays + 1
+ if (fields.size() > _numArrays && _ordinalityIndex < 0) {
+ _ordinalityIndex = _base + _numArrays;
+ }
+ }
+
+ void populateFromProject(Project project) {
+ List<RexNode> projects = project.getProjects();
+ List<RelDataTypeField> projFields = project.getRowType().getFieldList();
+ // Extract element aliases and indexes from project outputs
+ for (int j = 0; j < projects.size() && j < _numArrays; j++) {
+ _elementIndexes.add(_base + j);
+ }
+ for (int j = projects.size(); j < _numArrays; j++) {
+ _elementIndexes.add(_base + j);
+ }
+ // Check if ordinality is present: projFields.size() should be numArrays
+ 1
+ if (projFields.size() > _numArrays && _ordinalityIndex < 0) {
+ _ordinalityIndex = _base + _numArrays;
+ }
+ }
+
+ void populateFromCorrelateOutput(RelDataType correlateOutputRowType, int
leftColumnCount) {
+ List<RelDataTypeField> fields = correlateOutputRowType.getFieldList();
+ int rightFieldCount = Math.min(_numArrays + (_hasOrdinality ? 1 : 0),
fields.size());
+ int actualLeftColumns = Math.max(0, fields.size() - rightFieldCount);
+ int missingLeftColumns = Math.max(0, leftColumnCount -
actualLeftColumns);
+ int adjustedBase = Math.max(0, leftColumnCount - missingLeftColumns);
+
+ for (int i = 0; i < _numArrays; i++) {
+ int fieldIndex = adjustedBase + i;
+ if (fieldIndex < fields.size()) {
+ _elementIndexes.add(fieldIndex);
+ } else {
+ _elementIndexes.add(_base + i);
+ }
+ }
+ int ordinalityFieldIndex = adjustedBase + _numArrays;
+ if (_hasOrdinality && ordinalityFieldIndex < fields.size() &&
_ordinalityIndex < 0) {
+ _ordinalityIndex = ordinalityFieldIndex;
+ }
+ }
+
+ void ensureOrdinalityFromRowType(RelDataType rowType, boolean
hasOrdinality) {
+ if (!hasOrdinality) {
+ return;
+ }
+ List<RelDataTypeField> fields = rowType.getFieldList();
+ if (_ordinalityIndex < 0) {
+ _ordinalityIndex = _base + _numArrays;
+ }
+ }
+
+ ElementOrdinalInfo toInfo() {
+ // For backward compatibility, provide single element index if only one
array
+ int singleElementIndex = _elementIndexes.isEmpty() ? -1 :
_elementIndexes.get(0);
+ return new ElementOrdinalInfo(singleElementIndex, _ordinalityIndex,
_elementIndexes);
+ }
+ }
+
+ private static final class ElementOrdinalInfo {
+ private final int _elementIndex;
+ private final int _ordinalityIndex;
+ private final List<Integer> _elementIndexes;
+
+ ElementOrdinalInfo(int elementIndex, int ordinalityIndex) {
+ this(elementIndex, ordinalityIndex, elementIndex >= 0 ?
List.of(elementIndex) : List.of());
+ }
+
+ ElementOrdinalInfo(int elementIndex, int ordinalityIndex, List<Integer>
elementIndexes) {
+ _elementIndex = elementIndex;
+ _ordinalityIndex = ordinalityIndex;
+ _elementIndexes = elementIndexes;
+ }
+
+ int getElementIndex() {
+ return _elementIndex;
+ }
+
+ List<Integer> getElementIndexes() {
+ return _elementIndexes;
+ }
+
+ int getOrdinalityIndex() {
+ return _ordinalityIndex;
+ }
+
+ boolean hasOrdinality() {
+ return _ordinalityIndex >= 0;
+ }
+ }
+
+ private static RexExpression rewriteInputRefs(RexExpression expr, int
elemOutIdx, int ordOutIdx) {
+ if (expr instanceof RexExpression.InputRef) {
+ int idx = ((RexExpression.InputRef) expr).getIndex();
+ if (idx == 0 && elemOutIdx >= 0) {
+ return new RexExpression.InputRef(elemOutIdx);
+ } else if (idx == 1 && ordOutIdx >= 0) {
+ return new RexExpression.InputRef(ordOutIdx);
+ } else {
+ return expr;
+ }
+ } else if (expr instanceof RexExpression.FunctionCall) {
+ RexExpression.FunctionCall fc = (RexExpression.FunctionCall) expr;
+ List<RexExpression> ops = fc.getFunctionOperands();
+ List<RexExpression> rewritten = new ArrayList<>(ops.size());
+ for (RexExpression op : ops) {
+ rewritten.add(rewriteInputRefs(op, elemOutIdx, ordOutIdx));
+ }
+ return new RexExpression.FunctionCall(fc.getDataType(),
fc.getFunctionName(), rewritten);
+ } else {
+ return expr;
+ }
+ }
+
+ private static RexExpression rewriteInputRefsForMultipleArrays(RexExpression
expr, List<Integer> elemOutIdxs,
+ int ordOutIdx) {
+ if (expr instanceof RexExpression.InputRef) {
+ int idx = ((RexExpression.InputRef) expr).getIndex();
+ // Map element indexes: 0 -> first element index, 1 -> second element
index, etc.
+ if (idx >= 0 && idx < elemOutIdxs.size() && elemOutIdxs.get(idx) >= 0) {
+ return new RexExpression.InputRef(elemOutIdxs.get(idx));
+ } else if (idx == elemOutIdxs.size() && ordOutIdx >= 0) {
+ // Ordinality index comes after all element indexes
+ return new RexExpression.InputRef(ordOutIdx);
+ } else {
+ return expr;
+ }
+ } else if (expr instanceof RexExpression.FunctionCall) {
+ RexExpression.FunctionCall fc = (RexExpression.FunctionCall) expr;
+ List<RexExpression> ops = fc.getFunctionOperands();
+ List<RexExpression> rewritten = new ArrayList<>(ops.size());
+ for (RexExpression op : ops) {
+ rewritten.add(rewriteInputRefsForMultipleArrays(op, elemOutIdxs,
ordOutIdx));
+ }
+ return new RexExpression.FunctionCall(fc.getDataType(),
fc.getFunctionName(), rewritten);
+ } else {
+ return expr;
+ }
+ }
+
+ private Integer resolveInputRefFromCorrel(RexNode expr, RelDataType
leftRowType) {
+ if (expr instanceof RexFieldAccess) {
+ RexFieldAccess fieldAccess = (RexFieldAccess) expr;
+ if (fieldAccess.getReferenceExpr() instanceof RexCorrelVariable) {
+ String fieldName = fieldAccess.getField().getName();
+ List<RelDataTypeField> fields = leftRowType.getFieldList();
+ // SQL field names are case-insensitive by default in Calcite, so we
use equalsIgnoreCase for matching.
+ // NOTE: This assumes that the schema is configured with Calcite's
default case-insensitivity.
+ // If the schema is case-sensitive, this approach may produce
incorrect results. Update logic if needed.
+ for (int i = 0; i < fields.size(); i++) {
+ String candidateName = fields.get(i).getName();
+ if (candidateName == null) {
+ continue;
+ }
+ if (candidateName.equals(fieldName)) {
+ return i;
+ }
+ if (candidateName.equalsIgnoreCase(fieldName)) {
+ if (_caseSensitive) {
+ LOGGER.warn("Skipping correlated reference '{}' vs '{}' because
schema is case-sensitive. "
+ + "Adjust query or enable case-insensitive matching.",
fieldName, candidateName);
+ } else {
+ LOGGER.warn("Case-insensitive field match for correlated
reference '{}' vs '{}'. "
+ + "Ensure schema case-sensitivity is configured as
expected.", fieldName, candidateName);
+ return i;
+ }
+ }
Review Comment:
[nitpick] Consider extracting the field name matching logic into a separate
helper method. The nested if-else structure makes it harder to follow the
matching strategy and would benefit from separation. A method like
`findFieldIndex(String fieldName, List<RelDataTypeField> fields)` would improve
readability.
```suggestion
return findFieldIndex(fieldName, fields);
}
}
return null;
}
/**
* Finds the index of the field in the given list matching the specified
field name.
* Handles both case-sensitive and case-insensitive matching, with logging
for ambiguous cases.
*/
private Integer findFieldIndex(String fieldName, List<RelDataTypeField>
fields) {
for (int i = 0; i < fields.size(); i++) {
String candidateName = fields.get(i).getName();
if (candidateName == null) {
continue;
}
if (candidateName.equals(fieldName)) {
return i;
}
if (candidateName.equalsIgnoreCase(fieldName)) {
if (_caseSensitive) {
LOGGER.warn("Skipping correlated reference '{}' vs '{}' because
schema is case-sensitive. "
+ "Adjust query or enable case-insensitive matching.",
fieldName, candidateName);
} else {
LOGGER.warn("Case-insensitive field match for correlated reference
'{}' vs '{}'. "
+ "Ensure schema case-sensitivity is configured as expected.",
fieldName, candidateName);
return i;
```
##########
pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/custom/UnnestIntegrationTest.java:
##########
@@ -0,0 +1,397 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.integration.tests.custom;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.google.common.cache.Cache;
+import com.google.common.cache.CacheBuilder;
+import java.io.File;
+import java.util.List;
+import org.apache.avro.file.DataFileWriter;
+import org.apache.avro.generic.GenericData;
+import org.apache.avro.generic.GenericDatumWriter;
+import org.apache.commons.lang3.RandomStringUtils;
+import org.apache.pinot.spi.data.FieldSpec;
+import org.apache.pinot.spi.data.Schema;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNotNull;
+import static org.testng.Assert.assertTrue;
+
+
+@Test(suiteName = "CustomClusterIntegrationTest")
+public class UnnestIntegrationTest extends
CustomDataQueryClusterIntegrationTest {
+
+ private static final String DEFAULT_TABLE_NAME = "UnnestIntegrationTest";
+ private static final String INT_COLUMN = "intCol";
+ private static final String LONG_COLUMN = "longCol";
+ private static final String FLOAT_COLUMN = "floatCol";
+ private static final String DOUBLE_COLUMN = "doubleCol";
+ private static final String STRING_COLUMN = "stringCol";
+ private static final String TIMESTAMP_COLUMN = "timestampCol";
+ private static final String GROUP_BY_COLUMN = "groupKey";
+ private static final String LONG_ARRAY_COLUMN = "longArrayCol";
+ private static final String DOUBLE_ARRAY_COLUMN = "doubleArrayCol";
+ private static final String STRING_ARRAY_COLUMN = "stringArrayCol";
+
+ @Test(dataProvider = "useV2QueryEngine")
+ public void testCountWithCrossJoinUnnest(boolean useMultiStageQueryEngine)
+ throws Exception {
+ setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+ String query = String.format("SELECT COUNT(*) FROM %s CROSS JOIN
UNNEST(longArrayCol) AS u(elem)", getTableName());
+ JsonNode json = postQuery(query);
+ JsonNode rows = json.get("resultTable").get("rows");
+ assertNotNull(rows);
+ long count = rows.get(0).get(0).asLong();
+ assertEquals(count, 4 * getCountStarResult());
+ }
+
+ @Test(dataProvider = "useV2QueryEngine")
+ public void testSelectWithCrossJoinUnnest(boolean useMultiStageQueryEngine)
+ throws Exception {
+ setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+ String query = String.format("SELECT intCol, u.elem FROM %s CROSS JOIN
UNNEST(stringArrayCol) AS u(elem)"
+ + " ORDER BY intCol", getTableName());
+ JsonNode json = postQuery(query);
+ JsonNode rows = json.get("resultTable").get("rows");
+ assertNotNull(rows);
+ assertEquals(rows.size(), 3 * getCountStarResult());
+ for (int i = 0; i < rows.size(); i++) {
+ JsonNode row = rows.get(i);
+ assertEquals(row.get(0).asInt(), i / 3);
+ switch (i % 3) {
+ case 0:
+ assertEquals(row.get(1).asText(), "a");
+ break;
+ case 1:
+ assertEquals(row.get(1).asText(), "b");
+ break;
+ case 2:
+ assertEquals(row.get(1).asText(), "c");
+ break;
+ default:
+ break;
+ }
+ }
+ }
+
+ @Test(dataProvider = "useV2QueryEngine")
+ public void testSelectWithCrossJoinUnnestOnMultiColumn(boolean
useMultiStageQueryEngine)
+ throws Exception {
+ setUseMultiStageQueryEngine(useMultiStageQueryEngine);
+ String query = String.format(
+ "SELECT intCol, u.longValue, u.stringValue FROM %s CROSS JOIN
UNNEST(longArrayCol, stringArrayCol) AS u"
+ + "(longValue, stringValue)"
+ + " ORDER BY intCol", getTableName());
+ JsonNode json = postQuery(query);
+ JsonNode rows = json.get("resultTable").get("rows");
+ assertNotNull(rows);
+ assertEquals(rows.size(), 4 * getCountStarResult());
+ for (int i = 0; i < rows.size(); i++) {
+ JsonNode row = rows.get(i);
+ assertEquals(row.get(0).asInt(), i / 4);
+ switch (i % 4) {
+ case 0:
+ assertEquals(row.get(1).asLong(), 0L);
+ assertEquals(row.get(2).asText(), "a");
+ break;
+ case 1:
+ assertEquals(row.get(1).asLong(), 1L);
+ assertEquals(row.get(2).asText(), "b");
+ break;
+ case 2:
+ assertEquals(row.get(1).asLong(), 2L);
+ assertEquals(row.get(2).asText(), "c");
+ break;
+ case 3:
+ assertEquals(row.get(1).asLong(), 3L);
+ assertEquals(row.get(2).asText(), "null");
Review Comment:
Using `asText()` to check for null values returns the string "null" rather
than checking if the node itself is null. Use `row.get(2).isNull()` instead to
properly verify null values, as shown correctly in other parts of the code
(e.g., line 236).
```suggestion
assertTrue(row.get(2).isNull());
```
##########
pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/UnnestOperator.java:
##########
@@ -0,0 +1,301 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.query.runtime.operator;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import org.apache.pinot.common.datatable.StatMap;
+import org.apache.pinot.common.utils.DataSchema;
+import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.query.planner.plannode.UnnestNode;
+import org.apache.pinot.query.runtime.blocks.MseBlock;
+import org.apache.pinot.query.runtime.blocks.RowHeapDataBlock;
+import org.apache.pinot.query.runtime.operator.operands.TransformOperand;
+import
org.apache.pinot.query.runtime.operator.operands.TransformOperandFactory;
+import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+
+/**
+ * UnnestOperator expands array/collection values per input row into zero or
more output rows.
+ * Supports multiple arrays, aligning them by index (like a zip operation).
+ * If arrays have different lengths, shorter arrays are padded with null
values.
+ * The output schema is provided by the associated UnnestNode's data schema.
+ */
+public class UnnestOperator extends MultiStageOperator {
+ private static final Logger LOGGER =
LoggerFactory.getLogger(UnnestOperator.class);
+ private static final String EXPLAIN_NAME = "UNNEST";
+
+ private final MultiStageOperator _input;
+ private final List<TransformOperand> _arrayExprOperands;
+ private final DataSchema _resultSchema;
+ private final boolean _withOrdinality;
+ private final List<Integer> _elementIndexes;
+ private final int _ordinalityIndex;
+ private final StatMap<StatKey> _statMap = new StatMap<>(StatKey.class);
+ private boolean _loggedElementOverflow;
+
+ public UnnestOperator(OpChainExecutionContext context, MultiStageOperator
input, DataSchema inputSchema,
+ UnnestNode node) {
+ super(context);
+ _input = input;
+ List<RexExpression> arrayExprs = node.getArrayExprs();
+ _arrayExprOperands = new ArrayList<>(arrayExprs.size());
+ for (RexExpression arrayExpr : arrayExprs) {
+
_arrayExprOperands.add(TransformOperandFactory.getTransformOperand(arrayExpr,
inputSchema));
+ }
+ _resultSchema = node.getDataSchema();
+ _withOrdinality = node.isWithOrdinality();
+ _elementIndexes = node.getElementIndexes();
+ _ordinalityIndex = node.getOrdinalityIndex();
+ }
+
+ @Override
+ public void registerExecution(long time, int numRows, long memoryUsedBytes,
long gcTimeMs) {
+ _statMap.merge(StatKey.EXECUTION_TIME_MS, time);
+ _statMap.merge(StatKey.EMITTED_ROWS, numRows);
+ _statMap.merge(StatKey.ALLOCATED_MEMORY_BYTES, memoryUsedBytes);
+ _statMap.merge(StatKey.GC_TIME_MS, gcTimeMs);
+ }
+
+ @Override
+ protected Logger logger() {
+ return LOGGER;
+ }
+
+ @Override
+ public List<MultiStageOperator> getChildOperators() {
+ return List.of(_input);
+ }
+
+ @Override
+ public Type getOperatorType() {
+ return Type.UNNEST;
+ }
+
+ @Override
+ public String toExplainString() {
+ return EXPLAIN_NAME;
+ }
+
+ /**
+ * Produces zipped rows across the configured array expressions for each
input row.
+ * If an expression evaluates to a scalar instead of an array/list, we
intentionally treat it as a single-element
+ * array so the row still participates in UNNEST output instead of being
dropped.
+ */
+ @Override
+ protected MseBlock getNextBlock() {
+ MseBlock block = _input.nextBlock();
+ if (block.isEos()) {
+ return block;
+ }
+ MseBlock.Data dataBlock = (MseBlock.Data) block;
+ List<Object[]> inputRows = dataBlock.asRowHeap().getRows();
+ List<Object[]> outRows = new ArrayList<>();
+
+ for (Object[] row : inputRows) {
+ // Extract all arrays from the input row
+ List<List<Object>> arrays = new ArrayList<>();
+ for (TransformOperand operand : _arrayExprOperands) {
+ Object value = operand.apply(row);
+ List<Object> elements = extractArrayElements(value);
+ arrays.add(elements);
+ }
+ // Align arrays by index (zip operation)
+ alignArraysByIndex(row, arrays, outRows);
+ }
+
+ return new RowHeapDataBlock(outRows, _resultSchema);
+ }
+
+ private List<Object> extractArrayElements(Object value) {
+ List<Object> elements = new ArrayList<>();
+ if (value == null) {
+ return elements;
+ }
+ if (value instanceof List) {
+ elements.addAll((List<?>) value);
+ } else if (value.getClass().isArray()) {
+ if (value instanceof int[]) {
+ int[] arr = (int[]) value;
+ for (int v : arr) {
+ elements.add(v);
+ }
+ } else if (value instanceof long[]) {
+ long[] arr = (long[]) value;
+ for (long v : arr) {
+ elements.add(v);
+ }
+ } else if (value instanceof double[]) {
+ double[] arr = (double[]) value;
+ for (double v : arr) {
+ elements.add(v);
+ }
+ } else if (value instanceof float[]) {
+ float[] arr = (float[]) value;
+ for (float v : arr) {
+ elements.add(v);
+ }
+ } else if (value instanceof boolean[]) {
+ boolean[] arr = (boolean[]) value;
+ for (boolean v : arr) {
+ elements.add(v);
+ }
+ } else if (value instanceof char[]) {
+ char[] arr = (char[]) value;
+ for (char v : arr) {
+ elements.add(v);
+ }
+ } else if (value instanceof short[]) {
+ short[] arr = (short[]) value;
+ for (short v : arr) {
+ elements.add(v);
+ }
+ } else if (value instanceof byte[]) {
+ byte[] arr = (byte[]) value;
+ for (byte v : arr) {
+ elements.add(v);
+ }
+ } else if (value instanceof String[]) {
+ String[] arr = (String[]) value;
+ Collections.addAll(elements, arr);
+ } else if (value instanceof Object[]) {
+ Object[] arr = (Object[]) value;
+ Collections.addAll(elements, arr);
+ } else {
+ // Last-resort fallback for uncommon array types; use reflection only
in this slow path.
+ int length = java.lang.reflect.Array.getLength(value);
+ for (int i = 0; i < length; i++) {
+ elements.add(java.lang.reflect.Array.get(value, i));
+ }
+ }
+ } else {
+ // If not array-like, treat as a single element
+ elements.add(value);
+ }
+ return elements;
+ }
Review Comment:
[nitpick] The `extractArrayElements` method has high cyclomatic complexity
with 11 consecutive if-else branches for different primitive array types.
Consider extracting each array type handling into separate methods (e.g.,
`extractIntArray`, `extractLongArray`) or using a strategy pattern to reduce
the method's complexity and improve testability.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]