This is an automated email from the ASF dual-hosted git repository.
yiguolei pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-4.1 by this push:
new a8ff3007e5b branch-4.1: [fix](fe) Prevent cast project pushdown through union distinct #64080 (#64557)
a8ff3007e5b is described below
commit a8ff3007e5b4bac1ecc57792a098a47a79c754d1
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Wed Jun 17 09:30:05 2026 +0800
branch-4.1: [fix](fe) Prevent cast project pushdown through union distinct #64080 (#64557)
Cherry-picked from #64080
Co-authored-by: morrySnow <[email protected]>
---
.../rules/rewrite/PushProjectThroughUnion.java | 33 +++++++++--
.../org/apache/doris/nereids/types/ArrayType.java | 9 +++
.../apache/doris/nereids/types/BooleanType.java | 8 +++
.../org/apache/doris/nereids/types/DataType.java | 4 ++
.../apache/doris/nereids/types/DateTimeType.java | 9 +++
.../apache/doris/nereids/types/DateTimeV2Type.java | 13 +++++
.../org/apache/doris/nereids/types/DateType.java | 6 ++
.../apache/doris/nereids/types/DecimalV2Type.java | 14 +++++
.../apache/doris/nereids/types/DecimalV3Type.java | 14 +++++
.../org/apache/doris/nereids/types/MapType.java | 10 ++++
.../org/apache/doris/nereids/types/StructType.java | 18 ++++++
.../doris/nereids/types/TimeStampTzType.java | 10 ++++
.../org/apache/doris/nereids/types/TimeV2Type.java | 10 ++++
.../apache/doris/nereids/types/VariantType.java | 5 ++
.../nereids/types/coercion/CharacterType.java | 5 ++
.../doris/nereids/types/coercion/IntegralType.java | 15 +++++
.../apache/doris/nereids/util/ExpressionUtils.java | 14 +++++
.../rules/rewrite/PushProjectThroughUnionTest.java | 57 +++++++++++++++++++
.../apache/doris/nereids/types/DataTypeTest.java | 66 ++++++++++++++++++++++
.../data/nereids_syntax_p0/set_operation.out | 9 ++-
.../suites/nereids_syntax_p0/set_operation.groovy | 21 +++++++
21 files changed, 343 insertions(+), 7 deletions(-)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnion.java
index dfaef55b56c..87b239556ec 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnion.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnion.java
@@ -20,12 +20,14 @@ package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
@@ -35,9 +37,11 @@ import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
import java.util.List;
import java.util.Map;
+import java.util.Set;
/**
* this rule push down the project through union to let MergeUnion could do
better
@@ -56,14 +60,31 @@ public class PushProjectThroughUnion extends
OneRewriteRuleFactory {
/** canPushProject */
public static boolean canPushProject(List<NamedExpression> projects,
LogicalSetOperation logicalSetOperation) {
- return projects.size() == logicalSetOperation.getOutput().size() &&
projects.stream().allMatch(e -> {
- if (e instanceof SlotReference) {
- return true;
+ if (projects.size() != logicalSetOperation.getOutput().size()) {
+ return false;
+ }
+ boolean isAll =
logicalSetOperation.getQualifier().equals(Qualifier.ALL);
+ Set<ExprId> projectInputExprIds =
Sets.newHashSetWithExpectedSize(projects.size());
+ for (NamedExpression project : projects) {
+ Expression input;
+ if (project instanceof SlotReference) {
+ input = project;
+ } else if (isAll) {
+ input =
ExpressionUtils.getExpressionCoveredByCast(project.child(0));
} else {
- Expression expr =
ExpressionUtils.getExpressionCoveredByCast(e.child(0));
- return expr instanceof SlotReference;
+ input =
ExpressionUtils.getExpressionCoveredBySafetyCast(project.child(0));
+ }
+ if (!(input instanceof SlotReference)) {
+ return false;
}
- });
+ projectInputExprIds.add(((SlotReference) input).getExprId());
+ }
+ if (isAll) {
+ return true;
+ }
+ Set<ExprId> outputExprIds =
Sets.newHashSetWithExpectedSize(logicalSetOperation.getOutput().size());
+ logicalSetOperation.getOutput().forEach(output ->
outputExprIds.add(output.getExprId()));
+ return projectInputExprIds.equals(outputExprIds);
}
/** doPushProject */
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/ArrayType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/ArrayType.java
index 6d4ec539ff3..c189154fb42 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/ArrayType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/ArrayType.java
@@ -18,6 +18,7 @@
package org.apache.doris.nereids.types;
import org.apache.doris.catalog.Type;
+import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.types.coercion.ComplexDataType;
import java.util.Objects;
@@ -63,6 +64,14 @@ public class ArrayType extends DataType implements
ComplexDataType, NestedColumn
return containsNull;
}
+ @Override
+ public boolean isInjectiveCastTo(DataType target) {
+ if (target instanceof ArrayType) {
+ return itemType.isInjectiveCastTo(((ArrayType) target).itemType);
+ }
+ return target instanceof CharacterType;
+ }
+
@Override
public Type toCatalogDataType() {
return new
org.apache.doris.catalog.ArrayType(itemType.toCatalogDataType(), containsNull);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BooleanType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BooleanType.java
index 49b2a6e72d7..708801a883f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BooleanType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BooleanType.java
@@ -31,6 +31,14 @@ public class BooleanType extends PrimitiveType {
private BooleanType() {
}
+ @Override
+ public boolean isInjectiveCastTo(DataType target) {
+ return target instanceof BooleanType || target.isIntegralType() ||
target.isFloatLikeType()
+ || (target instanceof DecimalV2Type && ((DecimalV2Type)
target).getRange() >= 1)
+ || (target instanceof DecimalV3Type && ((DecimalV3Type)
target).getRange() >= 1)
+ || target.isStringLikeType();
+ }
+
@Override
public Type toCatalogDataType() {
return Type.BOOLEAN;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java
index 8c30b835b05..4e48e5aa484 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java
@@ -812,6 +812,10 @@ public abstract class DataType {
public abstract int width();
+ public boolean isInjectiveCastTo(DataType target) {
+ return this.equals(target);
+ }
+
public static List<DataType> trivialTypes() {
return Type.getTrivialTypes()
.stream()
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeType.java
index 8a0250d7b44..a93bfda0364 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeType.java
@@ -19,6 +19,7 @@ package org.apache.doris.nereids.types;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.Config;
+import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.types.coercion.DateLikeType;
import java.time.DateTimeException;
@@ -45,6 +46,14 @@ public class DateTimeType extends DateLikeType {
this.shouldConversion = shouldConversion;
}
+ @Override
+ public boolean isInjectiveCastTo(DataType target) {
+ if (target instanceof DateTimeType || target instanceof DateTimeV2Type
|| target instanceof CharacterType) {
+ return true;
+ }
+ return false;
+ }
+
@Override
public DataType conversion() {
if (Config.enable_date_conversion && shouldConversion) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeV2Type.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeV2Type.java
index f56b4662f8b..13097339554 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeV2Type.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeV2Type.java
@@ -23,6 +23,7 @@ import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral;
import
org.apache.doris.nereids.trees.expressions.literal.format.DateTimeChecker;
+import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.types.coercion.DateLikeType;
import org.apache.doris.nereids.types.coercion.IntegralType;
import org.apache.doris.nereids.types.coercion.ScaleTimeType;
@@ -128,6 +129,18 @@ public class DateTimeV2Type extends DateLikeType
implements ScaleTimeType {
return super.toSql() + "(" + scale + ")";
}
+ @Override
+ public boolean isInjectiveCastTo(DataType target) {
+ if (target instanceof DateTimeV2Type) {
+ DateTimeV2Type t2 = (DateTimeV2Type) target;
+ return this.scale <= t2.scale;
+ }
+ if (target instanceof DateTimeType) {
+ return this.scale == 0;
+ }
+ return target instanceof CharacterType;
+ }
+
@Override
public Type toCatalogDataType() {
return ScalarType.createDatetimeV2Type(scale);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateType.java
index d127ab16069..c6ce702ebe7 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateType.java
@@ -19,6 +19,7 @@ package org.apache.doris.nereids.types;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.Config;
+import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.types.coercion.DateLikeType;
import java.time.DateTimeException;
@@ -45,6 +46,11 @@ public class DateType extends DateLikeType {
this.shouldConversion = shouldConversion;
}
+ @Override
+ public boolean isInjectiveCastTo(DataType target) {
+ return target instanceof DateType || target instanceof DateV2Type ||
target instanceof CharacterType;
+ }
+
@Override
public DataType conversion() {
if (Config.enable_date_conversion && shouldConversion) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV2Type.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV2Type.java
index b601aaa9f13..b055172f262 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV2Type.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV2Type.java
@@ -21,6 +21,7 @@ import org.apache.doris.catalog.PrimitiveType;
import org.apache.doris.catalog.ScalarType;
import org.apache.doris.catalog.Type;
import org.apache.doris.common.Config;
+import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.types.coercion.FractionalType;
import com.google.common.base.Preconditions;
@@ -159,6 +160,19 @@ public class DecimalV2Type extends FractionalType {
return DecimalV2Type.createDecimalV2Type(range + scale, scale);
}
+ @Override
+ public boolean isInjectiveCastTo(DataType target) {
+ if (target instanceof DecimalV2Type) {
+ DecimalV2Type decimalV2Type = (DecimalV2Type) target;
+ return decimalV2Type.getRange() >= this.getRange() &&
decimalV2Type.getScale() >= this.getScale();
+ }
+ if (target instanceof DecimalV3Type) {
+ DecimalV3Type decimalV3Type = (DecimalV3Type) target;
+ return decimalV3Type.getRange() >= this.getRange() &&
decimalV3Type.getScale() >= this.getScale();
+ }
+ return target instanceof CharacterType;
+ }
+
@Override
public Type toCatalogDataType() {
return ScalarType.createDecimalType(PrimitiveType.DECIMALV2,
precision, scale);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV3Type.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV3Type.java
index 3c0a83e95c4..b366568cb35 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV3Type.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV3Type.java
@@ -22,6 +22,7 @@ import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.exceptions.NotSupportedException;
+import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.types.coercion.FractionalType;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.qe.SessionVariable;
@@ -214,6 +215,19 @@ public class DecimalV3Type extends FractionalType {
}
}
+ @Override
+ public boolean isInjectiveCastTo(DataType target) {
+ if (target instanceof DecimalV2Type) {
+ DecimalV2Type decimalV2Type = (DecimalV2Type) target;
+ return decimalV2Type.getRange() >= this.getRange() &&
decimalV2Type.getScale() >= this.getScale();
+ }
+ if (target instanceof DecimalV3Type) {
+ DecimalV3Type decimalV3Type = (DecimalV3Type) target;
+ return decimalV3Type.getRange() >= this.getRange() &&
decimalV3Type.getScale() >= this.getScale();
+ }
+ return target instanceof CharacterType;
+ }
+
@Override
public Type toCatalogDataType() {
return ScalarType.createDecimalV3Type(precision, scale);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/MapType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/MapType.java
index 176c1db1d0d..fc6e9ba2f94 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/MapType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/MapType.java
@@ -19,6 +19,7 @@ package org.apache.doris.nereids.types;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.annotation.Developing;
+import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.types.coercion.ComplexDataType;
import java.util.Objects;
@@ -63,6 +64,15 @@ public class MapType extends DataType implements
ComplexDataType, NestedColumnPr
return MapType.of(keyType.conversion(), valueType.conversion());
}
+ @Override
+ public boolean isInjectiveCastTo(DataType target) {
+ if (target instanceof MapType) {
+ MapType mapType = (MapType) target;
+ return keyType.isInjectiveCastTo(mapType.keyType) &&
valueType.isInjectiveCastTo(mapType.valueType);
+ }
+ return target instanceof CharacterType;
+ }
+
@Override
public Type toCatalogDataType() {
return new
org.apache.doris.catalog.MapType(keyType.toCatalogDataType(),
valueType.toCatalogDataType());
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/StructType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/StructType.java
index 0c33a6d2dec..13f28c2e06e 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/StructType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/StructType.java
@@ -20,6 +20,7 @@ package org.apache.doris.nereids.types;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.exceptions.AnalysisException;
+import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.types.coercion.ComplexDataType;
import com.google.common.collect.ImmutableList;
@@ -84,6 +85,23 @@ public class StructType extends DataType implements
ComplexDataType, NestedColum
return new
StructType(fields.stream().map(StructField::conversion).collect(Collectors.toList()));
}
+ @Override
+ public boolean isInjectiveCastTo(DataType target) {
+ if (target instanceof StructType) {
+ StructType structType = (StructType) target;
+ if (this.fields.size() != structType.fields.size()) {
+ return false;
+ }
+ for (int i = 0; i < fields.size(); i++) {
+ if
(!this.fields.get(i).getDataType().isInjectiveCastTo(structType.fields.get(i).getDataType()))
{
+ return false;
+ }
+ }
+ return true;
+ }
+ return target instanceof CharacterType;
+ }
+
@Override
public Type toCatalogDataType() {
return new org.apache.doris.catalog.StructType(fields.stream()
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeStampTzType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeStampTzType.java
index c3c99cad6fc..4f9c09b5f7f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeStampTzType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeStampTzType.java
@@ -21,6 +21,7 @@ import org.apache.doris.catalog.ScalarType;
import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral;
+import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.types.coercion.DateLikeType;
import org.apache.doris.nereids.types.coercion.ScaleTimeType;
@@ -46,6 +47,15 @@ public class TimeStampTzType extends DateLikeType implements
ScaleTimeType {
this.scale = scale;
}
+ @Override
+ public boolean isInjectiveCastTo(DataType target) {
+ if (target instanceof TimeStampTzType) {
+ TimeStampTzType timeStampTzType = (TimeStampTzType) target;
+ return timeStampTzType.getScale() >= this.scale;
+ }
+ return target instanceof CharacterType;
+ }
+
@Override
public Type toCatalogDataType() {
return ScalarType.createTimeStampTzType(scale);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeV2Type.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeV2Type.java
index 39f420e6931..af758c02cb5 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeV2Type.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeV2Type.java
@@ -22,6 +22,7 @@ import org.apache.doris.catalog.Type;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TimeV2Literal;
+import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.types.coercion.IntegralType;
import org.apache.doris.nereids.types.coercion.PrimitiveType;
import org.apache.doris.nereids.types.coercion.RangeScalable;
@@ -48,6 +49,15 @@ public class TimeV2Type extends PrimitiveType implements
RangeScalable, ScaleTim
scale = 0;
}
+ @Override
+ public boolean isInjectiveCastTo(DataType target) {
+ if (target instanceof TimeV2Type) {
+ TimeV2Type timeV2Type = (TimeV2Type) target;
+ return timeV2Type.scale >= scale;
+ }
+ return target instanceof CharacterType;
+ }
+
@Override
public Type toCatalogDataType() {
return ScalarType.createTimeV2Type(scale);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantType.java
index 4b54aad5414..75e05a6dde9 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantType.java
@@ -115,6 +115,11 @@ public class VariantType extends PrimitiveType {
this.enableNestedGroup = enableNestedGroup;
}
+ @Override
+ public boolean isInjectiveCastTo(DataType target) {
+ return target.equals(this) || target instanceof VariantType;
+ }
+
@Override
public DataType conversion() {
return new
VariantType(predefinedFields.stream().map(VariantField::conversion)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/CharacterType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/CharacterType.java
index 781b1257028..3d8590534f5 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/CharacterType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/CharacterType.java
@@ -42,6 +42,11 @@ public abstract class CharacterType extends PrimitiveType {
return len;
}
+ @Override
+ public boolean isInjectiveCastTo(DataType target) {
+ return target instanceof CharacterType;
+ }
+
@Override
public Type toCatalogDataType() {
throw new RuntimeException("CharacterType is only used for implicit
cast.");
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java
index b1e58805388..fe625fa34bc 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java
@@ -19,6 +19,8 @@ package org.apache.doris.nereids.types.coercion;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.DecimalV3Type;
+import org.apache.doris.nereids.types.LargeIntType;
import org.apache.commons.lang3.NotImplementedException;
@@ -44,6 +46,19 @@ public class IntegralType extends NumericType {
return "integral";
}
+ @Override
+ public boolean isInjectiveCastTo(DataType target) {
+ if (target instanceof IntegralType) {
+ return this.equals(target) || ((IntegralType)
target).widerThan(this);
+ }
+ if (target instanceof DecimalV3Type && !(this instanceof
LargeIntType)) {
+ DecimalV3Type other = (DecimalV3Type) target;
+ DecimalV3Type self = DecimalV3Type.forType(this);
+ return other.getRange() >= self.getRange();
+ }
+ return target instanceof CharacterType;
+ }
+
public boolean widerThan(IntegralType other) {
return this.width() > other.width();
}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
index faa8ff65f05..ec93e1da46f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -1126,6 +1126,20 @@ public class ExpressionUtils {
return expression;
}
+ /**
+ * Strip only casts that preserve distinctness of the child expression.
+ */
+ public static Expression getExpressionCoveredBySafetyCast(Expression
expression) {
+ while (expression instanceof Cast) {
+ if (((Cast)
expression).child().getDataType().isInjectiveCastTo(expression.getDataType())) {
+ expression = ((Cast) expression).child();
+ } else {
+ break;
+ }
+ }
+ return expression;
+ }
+
/**
* the expressions can be used as runtime filter targets
*/
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnionTest.java
index 328c390d52f..a877348491c 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnionTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnionTest.java
@@ -28,6 +28,8 @@ import
org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.types.BigIntType;
+import org.apache.doris.nereids.types.DateTimeType;
+import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
@@ -107,6 +109,61 @@ public class PushProjectThroughUnionTest {
}
}
+ @Test
+ public void testCastProjectPushThroughUnionByQualifierAndSafety() {
+ SlotReference unionOutput = new SlotReference(new ExprId(10), "s",
+ IntegerType.INSTANCE, true, ImmutableList.of());
+ Alias castProject = new Alias(new ExprId(100),
+ new Cast(unionOutput, BigIntType.INSTANCE), "n");
+ ImmutableList<NamedExpression> projects =
ImmutableList.of(castProject);
+
+ LogicalUnion unionAll = new LogicalUnion(Qualifier.ALL,
+ ImmutableList.of(unionOutput), ImmutableList.of(),
ImmutableList.of(), false, ImmutableList.of());
+ Assertions.assertTrue(PushProjectThroughUnion.canPushProject(projects,
unionAll));
+
+ LogicalUnion unionDistinct = new LogicalUnion(Qualifier.DISTINCT,
+ ImmutableList.of(unionOutput), ImmutableList.of(),
ImmutableList.of(), false, ImmutableList.of());
+ Assertions.assertTrue(PushProjectThroughUnion.canPushProject(projects,
unionDistinct));
+
+ SlotReference dateTimeOutput = new SlotReference(new ExprId(11), "dt",
+ DateTimeType.INSTANCE, true, ImmutableList.of());
+ Alias unsafeCastProject = new Alias(new ExprId(101),
+ new Cast(dateTimeOutput, DateType.INSTANCE), "d");
+ ImmutableList<NamedExpression> unsafeProjects =
ImmutableList.of(unsafeCastProject);
+
+ LogicalUnion unionAllWithUnsafeCast = new LogicalUnion(Qualifier.ALL,
+ ImmutableList.of(dateTimeOutput), ImmutableList.of(),
ImmutableList.of(), false, ImmutableList.of());
+
Assertions.assertTrue(PushProjectThroughUnion.canPushProject(unsafeProjects,
unionAllWithUnsafeCast));
+
+ LogicalUnion unionDistinctWithUnsafeCast = new
LogicalUnion(Qualifier.DISTINCT,
+ ImmutableList.of(dateTimeOutput), ImmutableList.of(),
ImmutableList.of(), false, ImmutableList.of());
+
Assertions.assertFalse(PushProjectThroughUnion.canPushProject(unsafeProjects,
unionDistinctWithUnsafeCast));
+ }
+
+ @Test
+ public void testDistinctProjectRequiresAllOutputSlotsExactlyOnce() {
+ SlotReference firstOutput = new SlotReference(new ExprId(10), "a",
+ IntegerType.INSTANCE, true, ImmutableList.of());
+ SlotReference secondOutput = new SlotReference(new ExprId(11), "b",
+ IntegerType.INSTANCE, true, ImmutableList.of());
+ ImmutableList<NamedExpression> outputs = ImmutableList.of(firstOutput,
secondOutput);
+ LogicalUnion unionAll = new LogicalUnion(Qualifier.ALL,
+ outputs, ImmutableList.of(), ImmutableList.of(), false,
ImmutableList.of());
+ LogicalUnion unionDistinct = new LogicalUnion(Qualifier.DISTINCT,
+ outputs, ImmutableList.of(), ImmutableList.of(), false,
ImmutableList.of());
+
+ ImmutableList<NamedExpression> duplicateProjects = ImmutableList.of(
+ new Alias(new ExprId(100), new Cast(firstOutput,
BigIntType.INSTANCE), "a1"),
+ new Alias(new ExprId(101), new Cast(firstOutput,
BigIntType.INSTANCE), "a2"));
+
Assertions.assertTrue(PushProjectThroughUnion.canPushProject(duplicateProjects,
unionAll));
+
Assertions.assertFalse(PushProjectThroughUnion.canPushProject(duplicateProjects,
unionDistinct));
+
+ ImmutableList<NamedExpression> permutationProjects = ImmutableList.of(
+ new Alias(new ExprId(102), new Cast(secondOutput,
BigIntType.INSTANCE), "b"),
+ new Alias(new ExprId(103), new Cast(firstOutput,
BigIntType.INSTANCE), "a"));
+
Assertions.assertTrue(PushProjectThroughUnion.canPushProject(permutationProjects,
unionDistinct));
+ }
+
private LogicalUnion findUnion(Plan p) {
if (p instanceof LogicalUnion) {
return (LogicalUnion) p;
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/types/DataTypeTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/types/DataTypeTest.java
index 59509fce805..9720fa1c550 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/types/DataTypeTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/types/DataTypeTest.java
@@ -138,6 +138,62 @@ public class DataTypeTest {
}
+ @Test
+ public void testIsInjectiveCastToForPrimitiveTypes() {
+ assertSafeCast(IntegerType.INSTANCE, IntegerType.INSTANCE);
+ assertSafeCast(IntegerType.INSTANCE, BigIntType.INSTANCE);
+ assertUnsafeCast(BigIntType.INSTANCE, IntegerType.INSTANCE);
+ assertSafeCast(IntegerType.INSTANCE,
DecimalV3Type.createDecimalV3Type(10, 0));
+ assertUnsafeCast(IntegerType.INSTANCE,
DecimalV3Type.createDecimalV3Type(9, 0));
+ assertUnsafeCast(LargeIntType.INSTANCE,
DecimalV3Type.createDecimalV3Type(38, 0));
+
+ assertSafeCast(BooleanType.INSTANCE,
DecimalV3Type.createDecimalV3Type(1, 0));
+ assertUnsafeCast(BooleanType.INSTANCE,
DecimalV3Type.createDecimalV3Type(1, 1));
+
+ assertSafeCast(DecimalV3Type.createDecimalV3Type(6, 2),
DecimalV3Type.createDecimalV3Type(8, 3));
+ assertUnsafeCast(DecimalV3Type.createDecimalV3Type(6, 2),
DecimalV3Type.createDecimalV3Type(6, 1));
+ assertUnsafeCast(DecimalV3Type.createDecimalV3Type(6, 2),
DecimalV3Type.createDecimalV3Type(5, 2));
+
+ assertSafeCast(DateTimeType.INSTANCE, DateTimeV2Type.of(0));
+ assertSafeCast(DateTimeV2Type.of(0), DateTimeType.INSTANCE);
+ assertSafeCast(DateTimeV2Type.of(3), DateTimeV2Type.of(6));
+ assertUnsafeCast(DateTimeV2Type.of(3), DateTimeType.INSTANCE);
+ assertUnsafeCast(DateTimeType.INSTANCE, DateType.INSTANCE);
+
+ assertSafeCast(VarcharType.createVarcharType(10),
VarcharType.createVarcharType(20));
+ assertSafeCast(VarcharType.createVarcharType(10), StringType.INSTANCE);
+ assertSafeCast(VarcharType.createVarcharType(20),
VarcharType.createVarcharType(10));
+ assertSafeCast(StringType.INSTANCE, VarcharType.createVarcharType(10));
+ }
+
+ @Test
+ public void testIsInjectiveCastToForComplexTypes() {
+ assertSafeCast(ArrayType.of(IntegerType.INSTANCE),
ArrayType.of(BigIntType.INSTANCE));
+ assertUnsafeCast(ArrayType.of(BigIntType.INSTANCE),
ArrayType.of(IntegerType.INSTANCE));
+
+ assertSafeCast(MapType.of(IntegerType.INSTANCE,
VarcharType.createVarcharType(10)),
+ MapType.of(BigIntType.INSTANCE, StringType.INSTANCE));
+ assertUnsafeCast(MapType.of(BigIntType.INSTANCE,
VarcharType.createVarcharType(10)),
+ MapType.of(IntegerType.INSTANCE, StringType.INSTANCE));
+
+ StructType intStringStruct = new StructType(ImmutableList.of(
+ new StructField("a", IntegerType.INSTANCE, true, ""),
+ new StructField("b", VarcharType.createVarcharType(10), true,
"")));
+ StructType bigintStringStruct = new StructType(ImmutableList.of(
+ new StructField("a", BigIntType.INSTANCE, true, ""),
+ new StructField("b", StringType.INSTANCE, true, "")));
+ StructType intOnlyStruct = new StructType(ImmutableList.of(
+ new StructField("a", IntegerType.INSTANCE, true, "")));
+
+ assertSafeCast(intStringStruct, bigintStringStruct);
+ assertUnsafeCast(bigintStringStruct, intStringStruct);
+ assertUnsafeCast(intOnlyStruct, intStringStruct);
+
+ assertSafeCast(ArrayType.of(IntegerType.INSTANCE),
StringType.INSTANCE);
+ assertSafeCast(MapType.of(IntegerType.INSTANCE, StringType.INSTANCE),
StringType.INSTANCE);
+ assertSafeCast(intStringStruct, StringType.INSTANCE);
+ }
+
@Test
public void testAnyAccept() {
AnyDataType dateType = AnyDataType.INSTANCE_WITHOUT_INDEX;
@@ -654,4 +710,14 @@ public class DataTypeTest {
DataType type = ArrayType.of(MapType.of(VarcharType.SYSTEM_DEFAULT,
IntegerType.INSTANCE));
Assertions.assertDoesNotThrow(type::validateDataType);
}
+
+ private void assertSafeCast(DataType source, DataType target) {
+ Assertions.assertTrue(source.isInjectiveCastTo(target), source.toSql()
+ " should safely cast to "
+ + target.toSql());
+ }
+
+ private void assertUnsafeCast(DataType source, DataType target) {
+ Assertions.assertFalse(source.isInjectiveCastTo(target),
source.toSql() + " should not safely cast to "
+ + target.toSql());
+ }
}
diff --git a/regression-test/data/nereids_syntax_p0/set_operation.out b/regression-test/data/nereids_syntax_p0/set_operation.out
index 5afc4fac2ad..b21f59c0f65 100644
--- a/regression-test/data/nereids_syntax_p0/set_operation.out
+++ b/regression-test/data/nereids_syntax_p0/set_operation.out
@@ -592,6 +592,14 @@ hell0
-- !union45 --
2
+-- !union46 --
+2020-01-01
+2020-01-01
+
+-- !union47 --
+1 1
+1 1
+
-- !check_child_col_order --
205548764.21875 3601
53950855.65625 3602
@@ -599,4 +607,3 @@ hell0
-- !intersect_case --
0
1
-
diff --git a/regression-test/suites/nereids_syntax_p0/set_operation.groovy b/regression-test/suites/nereids_syntax_p0/set_operation.groovy
index c81c04aae5d..e3126212827 100644
--- a/regression-test/suites/nereids_syntax_p0/set_operation.groovy
+++ b/regression-test/suites/nereids_syntax_p0/set_operation.groovy
@@ -291,6 +291,27 @@ suite("test_nereids_set_operation") {
select count(*) from (select 1, 2 union select 1,1 ) a;
"""
+ // do not push non-injective cast project below UNION DISTINCT.
+ // The two datetime values are distinct before the outer cast, but become
+ // equal after casting to date. The correct result keeps both rows.
+ order_qt_union46 """
+ select cast(dt as date) from (
+ select cast('2020-01-01 00:00:00' as datetime) dt
+ union
+ select cast('2020-01-01 01:00:00' as datetime) dt
+ ) t
+ """
+
+ // The project duplicates one UNION output and drops the other. Pushing it
+ // below UNION DISTINCT would collapse the two rows into one.
+ order_qt_union47 """
+ select cast(a as bigint), cast(a as bigint) from (
+ select 1 a, 2 b
+ union
+ select 1 a, 3 b
+ ) t
+ """
+
def tables = [
"dwd_daytable",
]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]