This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-4.1
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-4.1 by this push:
     new a8ff3007e5b branch-4.1: [fix](fe) Prevent cast project pushdown through union distinct #64080 (#64557)
a8ff3007e5b is described below

commit a8ff3007e5b4bac1ecc57792a098a47a79c754d1
Author: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Wed Jun 17 09:30:05 2026 +0800

    branch-4.1: [fix](fe) Prevent cast project pushdown through union distinct #64080 (#64557)
    
    Cherry-picked from #64080
    
    Co-authored-by: morrySnow <[email protected]>
---
 .../rules/rewrite/PushProjectThroughUnion.java     | 33 +++++++++--
 .../org/apache/doris/nereids/types/ArrayType.java  |  9 +++
 .../apache/doris/nereids/types/BooleanType.java    |  8 +++
 .../org/apache/doris/nereids/types/DataType.java   |  4 ++
 .../apache/doris/nereids/types/DateTimeType.java   |  9 +++
 .../apache/doris/nereids/types/DateTimeV2Type.java | 13 +++++
 .../org/apache/doris/nereids/types/DateType.java   |  6 ++
 .../apache/doris/nereids/types/DecimalV2Type.java  | 14 +++++
 .../apache/doris/nereids/types/DecimalV3Type.java  | 14 +++++
 .../org/apache/doris/nereids/types/MapType.java    | 10 ++++
 .../org/apache/doris/nereids/types/StructType.java | 18 ++++++
 .../doris/nereids/types/TimeStampTzType.java       | 10 ++++
 .../org/apache/doris/nereids/types/TimeV2Type.java | 10 ++++
 .../apache/doris/nereids/types/VariantType.java    |  5 ++
 .../nereids/types/coercion/CharacterType.java      |  5 ++
 .../doris/nereids/types/coercion/IntegralType.java | 15 +++++
 .../apache/doris/nereids/util/ExpressionUtils.java | 14 +++++
 .../rules/rewrite/PushProjectThroughUnionTest.java | 57 +++++++++++++++++++
 .../apache/doris/nereids/types/DataTypeTest.java   | 66 ++++++++++++++++++++++
 .../data/nereids_syntax_p0/set_operation.out       |  9 ++-
 .../suites/nereids_syntax_p0/set_operation.groovy  | 21 +++++++
 21 files changed, 343 insertions(+), 7 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnion.java
index dfaef55b56c..87b239556ec 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnion.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnion.java
@@ -20,12 +20,14 @@ package org.apache.doris.nereids.rules.rewrite;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
 import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
 import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation;
 import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
@@ -35,9 +37,11 @@ import org.apache.doris.nereids.util.ExpressionUtils;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableList.Builder;
 import com.google.common.collect.Maps;
+import com.google.common.collect.Sets;
 
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 
 /**
  * this rule push down the project through union to let MergeUnion could do better
@@ -56,14 +60,31 @@ public class PushProjectThroughUnion extends OneRewriteRuleFactory {
 
     /** canPushProject */
     public static boolean canPushProject(List<NamedExpression> projects, LogicalSetOperation logicalSetOperation) {
-        return projects.size() == logicalSetOperation.getOutput().size() && projects.stream().allMatch(e -> {
-            if (e instanceof SlotReference) {
-                return true;
+        if (projects.size() != logicalSetOperation.getOutput().size()) {
+            return false;
+        }
+        boolean isAll = logicalSetOperation.getQualifier().equals(Qualifier.ALL);
+        Set<ExprId> projectInputExprIds = Sets.newHashSetWithExpectedSize(projects.size());
+        for (NamedExpression project : projects) {
+            Expression input;
+            if (project instanceof SlotReference) {
+                input = project;
+            } else if (isAll) {
+                input = ExpressionUtils.getExpressionCoveredByCast(project.child(0));
             } else {
-                Expression expr = ExpressionUtils.getExpressionCoveredByCast(e.child(0));
-                return expr instanceof SlotReference;
+                input = ExpressionUtils.getExpressionCoveredBySafetyCast(project.child(0));
+            }
+            if (!(input instanceof SlotReference)) {
+                return false;
             }
-        });
+            projectInputExprIds.add(((SlotReference) input).getExprId());
+        }
+        if (isAll) {
+            return true;
+        }
+        Set<ExprId> outputExprIds = Sets.newHashSetWithExpectedSize(logicalSetOperation.getOutput().size());
+        logicalSetOperation.getOutput().forEach(output -> outputExprIds.add(output.getExprId()));
+        return projectInputExprIds.equals(outputExprIds);
     }
 
     /** doPushProject */
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/ArrayType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/ArrayType.java
index 6d4ec539ff3..c189154fb42 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/ArrayType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/ArrayType.java
@@ -18,6 +18,7 @@
 package org.apache.doris.nereids.types;
 
 import org.apache.doris.catalog.Type;
+import org.apache.doris.nereids.types.coercion.CharacterType;
 import org.apache.doris.nereids.types.coercion.ComplexDataType;
 
 import java.util.Objects;
@@ -63,6 +64,14 @@ public class ArrayType extends DataType implements ComplexDataType, NestedColumn
         return containsNull;
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        if (target instanceof ArrayType) {
+            return itemType.isInjectiveCastTo(((ArrayType) target).itemType);
+        }
+        return target instanceof CharacterType;
+    }
+
     @Override
     public Type toCatalogDataType() {
         return new org.apache.doris.catalog.ArrayType(itemType.toCatalogDataType(), containsNull);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BooleanType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BooleanType.java
index 49b2a6e72d7..708801a883f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BooleanType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BooleanType.java
@@ -31,6 +31,14 @@ public class BooleanType extends PrimitiveType {
     private BooleanType() {
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        return target instanceof BooleanType || target.isIntegralType() || target.isFloatLikeType()
+                || (target instanceof DecimalV2Type && ((DecimalV2Type) target).getRange() >= 1)
+                || (target instanceof DecimalV3Type && ((DecimalV3Type) target).getRange() >= 1)
+                || target.isStringLikeType();
+    }
+
     @Override
     public Type toCatalogDataType() {
         return Type.BOOLEAN;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java
index 8c30b835b05..4e48e5aa484 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java
@@ -812,6 +812,10 @@ public abstract class DataType {
 
     public abstract int width();
 
+    public boolean isInjectiveCastTo(DataType target) {
+        return this.equals(target);
+    }
+
     public static List<DataType> trivialTypes() {
         return Type.getTrivialTypes()
                 .stream()
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeType.java
index 8a0250d7b44..a93bfda0364 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeType.java
@@ -19,6 +19,7 @@ package org.apache.doris.nereids.types;
 
 import org.apache.doris.catalog.Type;
 import org.apache.doris.common.Config;
+import org.apache.doris.nereids.types.coercion.CharacterType;
 import org.apache.doris.nereids.types.coercion.DateLikeType;
 
 import java.time.DateTimeException;
@@ -45,6 +46,14 @@ public class DateTimeType extends DateLikeType {
         this.shouldConversion = shouldConversion;
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        if (target instanceof DateTimeType || target instanceof DateTimeV2Type || target instanceof CharacterType) {
+            return true;
+        }
+        return false;
+    }
+
     @Override
     public DataType conversion() {
         if (Config.enable_date_conversion && shouldConversion) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeV2Type.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeV2Type.java
index f56b4662f8b..13097339554 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeV2Type.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeV2Type.java
@@ -23,6 +23,7 @@ import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.format.DateTimeChecker;
+import org.apache.doris.nereids.types.coercion.CharacterType;
 import org.apache.doris.nereids.types.coercion.DateLikeType;
 import org.apache.doris.nereids.types.coercion.IntegralType;
 import org.apache.doris.nereids.types.coercion.ScaleTimeType;
@@ -128,6 +129,18 @@ public class DateTimeV2Type extends DateLikeType implements ScaleTimeType {
         return super.toSql() + "(" + scale + ")";
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        if (target instanceof DateTimeV2Type) {
+            DateTimeV2Type t2 = (DateTimeV2Type) target;
+            return this.scale <= t2.scale;
+        }
+        if (target instanceof DateTimeType) {
+            return this.scale == 0;
+        }
+        return target instanceof CharacterType;
+    }
+
     @Override
     public Type toCatalogDataType() {
         return ScalarType.createDatetimeV2Type(scale);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateType.java
index d127ab16069..c6ce702ebe7 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateType.java
@@ -19,6 +19,7 @@ package org.apache.doris.nereids.types;
 
 import org.apache.doris.catalog.Type;
 import org.apache.doris.common.Config;
+import org.apache.doris.nereids.types.coercion.CharacterType;
 import org.apache.doris.nereids.types.coercion.DateLikeType;
 
 import java.time.DateTimeException;
@@ -45,6 +46,11 @@ public class DateType extends DateLikeType {
         this.shouldConversion = shouldConversion;
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        return target instanceof DateType || target instanceof DateV2Type || target instanceof CharacterType;
+    }
+
     @Override
     public DataType conversion() {
         if (Config.enable_date_conversion && shouldConversion) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV2Type.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV2Type.java
index b601aaa9f13..b055172f262 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV2Type.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV2Type.java
@@ -21,6 +21,7 @@ import org.apache.doris.catalog.PrimitiveType;
 import org.apache.doris.catalog.ScalarType;
 import org.apache.doris.catalog.Type;
 import org.apache.doris.common.Config;
+import org.apache.doris.nereids.types.coercion.CharacterType;
 import org.apache.doris.nereids.types.coercion.FractionalType;
 
 import com.google.common.base.Preconditions;
@@ -159,6 +160,19 @@ public class DecimalV2Type extends FractionalType {
         return DecimalV2Type.createDecimalV2Type(range + scale, scale);
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        if (target instanceof DecimalV2Type) {
+            DecimalV2Type decimalV2Type = (DecimalV2Type) target;
+            return decimalV2Type.getRange() >= this.getRange() && decimalV2Type.getScale() >= this.getScale();
+        }
+        if (target instanceof DecimalV3Type) {
+            DecimalV3Type decimalV3Type = (DecimalV3Type) target;
+            return decimalV3Type.getRange() >= this.getRange() && decimalV3Type.getScale() >= this.getScale();
+        }
+        return target instanceof CharacterType;
+    }
+
     @Override
     public Type toCatalogDataType() {
         return ScalarType.createDecimalType(PrimitiveType.DECIMALV2, precision, scale);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV3Type.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV3Type.java
index 3c0a83e95c4..b366568cb35 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV3Type.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV3Type.java
@@ -22,6 +22,7 @@ import org.apache.doris.catalog.Type;
 import org.apache.doris.nereids.annotation.Developing;
 import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.exceptions.NotSupportedException;
+import org.apache.doris.nereids.types.coercion.CharacterType;
 import org.apache.doris.nereids.types.coercion.FractionalType;
 import org.apache.doris.qe.ConnectContext;
 import org.apache.doris.qe.SessionVariable;
@@ -214,6 +215,19 @@ public class DecimalV3Type extends FractionalType {
         }
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        if (target instanceof DecimalV2Type) {
+            DecimalV2Type decimalV2Type = (DecimalV2Type) target;
+            return decimalV2Type.getRange() >= this.getRange() && decimalV2Type.getScale() >= this.getScale();
+        }
+        if (target instanceof DecimalV3Type) {
+            DecimalV3Type decimalV3Type = (DecimalV3Type) target;
+            return decimalV3Type.getRange() >= this.getRange() && decimalV3Type.getScale() >= this.getScale();
+        }
+        return target instanceof CharacterType;
+    }
+
     @Override
     public Type toCatalogDataType() {
         return ScalarType.createDecimalV3Type(precision, scale);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/MapType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/MapType.java
index 176c1db1d0d..fc6e9ba2f94 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/MapType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/MapType.java
@@ -19,6 +19,7 @@ package org.apache.doris.nereids.types;
 
 import org.apache.doris.catalog.Type;
 import org.apache.doris.nereids.annotation.Developing;
+import org.apache.doris.nereids.types.coercion.CharacterType;
 import org.apache.doris.nereids.types.coercion.ComplexDataType;
 
 import java.util.Objects;
@@ -63,6 +64,15 @@ public class MapType extends DataType implements ComplexDataType, NestedColumnPr
         return MapType.of(keyType.conversion(), valueType.conversion());
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        if (target instanceof MapType) {
+            MapType mapType = (MapType) target;
+            return keyType.isInjectiveCastTo(mapType.keyType) && valueType.isInjectiveCastTo(mapType.valueType);
+        }
+        return target instanceof CharacterType;
+    }
+
     @Override
     public Type toCatalogDataType() {
         return new org.apache.doris.catalog.MapType(keyType.toCatalogDataType(), valueType.toCatalogDataType());
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/StructType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/StructType.java
index 0c33a6d2dec..13f28c2e06e 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/StructType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/StructType.java
@@ -20,6 +20,7 @@ package org.apache.doris.nereids.types;
 import org.apache.doris.catalog.Type;
 import org.apache.doris.nereids.annotation.Developing;
 import org.apache.doris.nereids.exceptions.AnalysisException;
+import org.apache.doris.nereids.types.coercion.CharacterType;
 import org.apache.doris.nereids.types.coercion.ComplexDataType;
 
 import com.google.common.collect.ImmutableList;
@@ -84,6 +85,23 @@ public class StructType extends DataType implements ComplexDataType, NestedColum
         return new StructType(fields.stream().map(StructField::conversion).collect(Collectors.toList()));
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        if (target instanceof StructType) {
+            StructType structType = (StructType) target;
+            if (this.fields.size() != structType.fields.size()) {
+                return false;
+            }
+            for (int i = 0; i < fields.size(); i++) {
+                if (!this.fields.get(i).getDataType().isInjectiveCastTo(structType.fields.get(i).getDataType())) {
+                    return false;
+                }
+            }
+            return true;
+        }
+        return target instanceof CharacterType;
+    }
+
     @Override
     public Type toCatalogDataType() {
         return new org.apache.doris.catalog.StructType(fields.stream()
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeStampTzType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeStampTzType.java
index c3c99cad6fc..4f9c09b5f7f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeStampTzType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeStampTzType.java
@@ -21,6 +21,7 @@ import org.apache.doris.catalog.ScalarType;
 import org.apache.doris.catalog.Type;
 import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral;
+import org.apache.doris.nereids.types.coercion.CharacterType;
 import org.apache.doris.nereids.types.coercion.DateLikeType;
 import org.apache.doris.nereids.types.coercion.ScaleTimeType;
 
@@ -46,6 +47,15 @@ public class TimeStampTzType extends DateLikeType implements ScaleTimeType {
         this.scale = scale;
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        if (target instanceof TimeStampTzType) {
+            TimeStampTzType timeStampTzType = (TimeStampTzType) target;
+            return timeStampTzType.getScale() >= this.scale;
+        }
+        return target instanceof CharacterType;
+    }
+
     @Override
     public Type toCatalogDataType() {
         return ScalarType.createTimeStampTzType(scale);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeV2Type.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeV2Type.java
index 39f420e6931..af758c02cb5 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeV2Type.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeV2Type.java
@@ -22,6 +22,7 @@ import org.apache.doris.catalog.Type;
 import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.TimeV2Literal;
+import org.apache.doris.nereids.types.coercion.CharacterType;
 import org.apache.doris.nereids.types.coercion.IntegralType;
 import org.apache.doris.nereids.types.coercion.PrimitiveType;
 import org.apache.doris.nereids.types.coercion.RangeScalable;
@@ -48,6 +49,15 @@ public class TimeV2Type extends PrimitiveType implements RangeScalable, ScaleTim
         scale = 0;
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        if (target instanceof TimeV2Type) {
+            TimeV2Type timeV2Type = (TimeV2Type) target;
+            return timeV2Type.scale >= scale;
+        }
+        return target instanceof CharacterType;
+    }
+
     @Override
     public Type toCatalogDataType() {
         return ScalarType.createTimeV2Type(scale);
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantType.java
index 4b54aad5414..75e05a6dde9 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantType.java
@@ -115,6 +115,11 @@ public class VariantType extends PrimitiveType {
         this.enableNestedGroup = enableNestedGroup;
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        return target.equals(this) || target instanceof VariantType;
+    }
+
     @Override
     public DataType conversion() {
         return new VariantType(predefinedFields.stream().map(VariantField::conversion)
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/CharacterType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/CharacterType.java
index 781b1257028..3d8590534f5 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/CharacterType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/CharacterType.java
@@ -42,6 +42,11 @@ public abstract class CharacterType extends PrimitiveType {
         return len;
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        return target instanceof CharacterType;
+    }
+
     @Override
     public Type toCatalogDataType() {
         throw new RuntimeException("CharacterType is only used for implicit cast.");
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java
index b1e58805388..fe625fa34bc 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java
@@ -19,6 +19,8 @@ package org.apache.doris.nereids.types.coercion;
 
 import org.apache.doris.nereids.types.BigIntType;
 import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.DecimalV3Type;
+import org.apache.doris.nereids.types.LargeIntType;
 
 import org.apache.commons.lang3.NotImplementedException;
 
@@ -44,6 +46,19 @@ public class IntegralType extends NumericType {
         return "integral";
     }
 
+    @Override
+    public boolean isInjectiveCastTo(DataType target) {
+        if (target instanceof IntegralType) {
+            return this.equals(target) || ((IntegralType) target).widerThan(this);
+        }
+        if (target instanceof DecimalV3Type && !(this instanceof LargeIntType)) {
+            DecimalV3Type other = (DecimalV3Type) target;
+            DecimalV3Type self = DecimalV3Type.forType(this);
+            return other.getRange() >= self.getRange();
+        }
+        return target instanceof CharacterType;
+    }
+
     public boolean widerThan(IntegralType other) {
         return this.width() > other.width();
     }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
index faa8ff65f05..ec93e1da46f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -1126,6 +1126,20 @@ public class ExpressionUtils {
         return expression;
     }
 
+    /**
+     * Strip only casts that preserve distinctness of the child expression.
+     */
+    public static Expression getExpressionCoveredBySafetyCast(Expression expression) {
+        while (expression instanceof Cast) {
+            if (((Cast) expression).child().getDataType().isInjectiveCastTo(expression.getDataType())) {
+                expression = ((Cast) expression).child();
+            } else {
+                break;
+            }
+        }
+        return expression;
+    }
+
     /**
      * the expressions can be used as runtime filter targets
      */
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnionTest.java
index 328c390d52f..a877348491c 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnionTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnionTest.java
@@ -28,6 +28,8 @@ import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
 import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
 import org.apache.doris.nereids.types.BigIntType;
+import org.apache.doris.nereids.types.DateTimeType;
+import org.apache.doris.nereids.types.DateType;
 import org.apache.doris.nereids.types.IntegerType;
 import org.apache.doris.nereids.util.MemoTestUtils;
 import org.apache.doris.nereids.util.PlanChecker;
@@ -107,6 +109,61 @@ public class PushProjectThroughUnionTest {
         }
     }
 
+    @Test
+    public void testCastProjectPushThroughUnionByQualifierAndSafety() {
+        SlotReference unionOutput = new SlotReference(new ExprId(10), "s",
+                IntegerType.INSTANCE, true, ImmutableList.of());
+        Alias castProject = new Alias(new ExprId(100),
+                new Cast(unionOutput, BigIntType.INSTANCE), "n");
+        ImmutableList<NamedExpression> projects = ImmutableList.of(castProject);
+
+        LogicalUnion unionAll = new LogicalUnion(Qualifier.ALL,
+                ImmutableList.of(unionOutput), ImmutableList.of(), ImmutableList.of(), false, ImmutableList.of());
+        Assertions.assertTrue(PushProjectThroughUnion.canPushProject(projects, unionAll));
+
+        LogicalUnion unionDistinct = new LogicalUnion(Qualifier.DISTINCT,
+                ImmutableList.of(unionOutput), ImmutableList.of(), ImmutableList.of(), false, ImmutableList.of());
+        Assertions.assertTrue(PushProjectThroughUnion.canPushProject(projects, unionDistinct));
+
+        SlotReference dateTimeOutput = new SlotReference(new ExprId(11), "dt",
+                DateTimeType.INSTANCE, true, ImmutableList.of());
+        Alias unsafeCastProject = new Alias(new ExprId(101),
+                new Cast(dateTimeOutput, DateType.INSTANCE), "d");
+        ImmutableList<NamedExpression> unsafeProjects = ImmutableList.of(unsafeCastProject);
+
+        LogicalUnion unionAllWithUnsafeCast = new LogicalUnion(Qualifier.ALL,
+                ImmutableList.of(dateTimeOutput), ImmutableList.of(), ImmutableList.of(), false, ImmutableList.of());
+        Assertions.assertTrue(PushProjectThroughUnion.canPushProject(unsafeProjects, unionAllWithUnsafeCast));
+
+        LogicalUnion unionDistinctWithUnsafeCast = new LogicalUnion(Qualifier.DISTINCT,
+                ImmutableList.of(dateTimeOutput), ImmutableList.of(), ImmutableList.of(), false, ImmutableList.of());
+        Assertions.assertFalse(PushProjectThroughUnion.canPushProject(unsafeProjects, unionDistinctWithUnsafeCast));
+    }
+
+    @Test
+    public void testDistinctProjectRequiresAllOutputSlotsExactlyOnce() {
+        SlotReference firstOutput = new SlotReference(new ExprId(10), "a",
+                IntegerType.INSTANCE, true, ImmutableList.of());
+        SlotReference secondOutput = new SlotReference(new ExprId(11), "b",
+                IntegerType.INSTANCE, true, ImmutableList.of());
+        ImmutableList<NamedExpression> outputs = ImmutableList.of(firstOutput, secondOutput);
+        LogicalUnion unionAll = new LogicalUnion(Qualifier.ALL,
+                outputs, ImmutableList.of(), ImmutableList.of(), false, ImmutableList.of());
+        LogicalUnion unionDistinct = new LogicalUnion(Qualifier.DISTINCT,
+                outputs, ImmutableList.of(), ImmutableList.of(), false, ImmutableList.of());
+
+        ImmutableList<NamedExpression> duplicateProjects = ImmutableList.of(
+                new Alias(new ExprId(100), new Cast(firstOutput, BigIntType.INSTANCE), "a1"),
+                new Alias(new ExprId(101), new Cast(firstOutput, BigIntType.INSTANCE), "a2"));
+        Assertions.assertTrue(PushProjectThroughUnion.canPushProject(duplicateProjects, unionAll));
+        Assertions.assertFalse(PushProjectThroughUnion.canPushProject(duplicateProjects, unionDistinct));
+
+        ImmutableList<NamedExpression> permutationProjects = ImmutableList.of(
+                new Alias(new ExprId(102), new Cast(secondOutput, BigIntType.INSTANCE), "b"),
+                new Alias(new ExprId(103), new Cast(firstOutput, BigIntType.INSTANCE), "a"));
+        Assertions.assertTrue(PushProjectThroughUnion.canPushProject(permutationProjects, unionDistinct));
+    }
+
     private LogicalUnion findUnion(Plan p) {
         if (p instanceof LogicalUnion) {
             return (LogicalUnion) p;
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/types/DataTypeTest.java 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/types/DataTypeTest.java
index 59509fce805..9720fa1c550 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/types/DataTypeTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/types/DataTypeTest.java
@@ -138,6 +138,62 @@ public class DataTypeTest {
 
     }
 
+    /**
+     * Checks {@code DataType#isInjectiveCastTo} for scalar types: integral widening and
+     * narrowing, integral/boolean to DECIMALV3, DECIMALV3 precision/scale changes,
+     * DATETIME/DATETIMEV2/DATE conversions, and VARCHAR/STRING conversions.
+     */
+    @Test
+    public void testIsInjectiveCastToForPrimitiveTypes() {
+        // Identity and widening integral casts are accepted; narrowing is rejected.
+        assertSafeCast(IntegerType.INSTANCE, IntegerType.INSTANCE);
+        assertSafeCast(IntegerType.INSTANCE, BigIntType.INSTANCE);
+        assertUnsafeCast(BigIntType.INSTANCE, IntegerType.INSTANCE);
+        // Integral -> decimal is expected to be injective only when the decimal has enough
+        // integer digits for the source range: decimal(10, 0) is accepted for INT,
+        // decimal(9, 0) is not, and decimal(38, 0) is rejected for LARGEINT.
+        assertSafeCast(IntegerType.INSTANCE, 
DecimalV3Type.createDecimalV3Type(10, 0));
+        assertUnsafeCast(IntegerType.INSTANCE, 
DecimalV3Type.createDecimalV3Type(9, 0));
+        assertUnsafeCast(LargeIntType.INSTANCE, 
DecimalV3Type.createDecimalV3Type(38, 0));
+
+        // BOOLEAN -> decimal is accepted with one integer digit, rejected with none (1, 1).
+        assertSafeCast(BooleanType.INSTANCE, 
DecimalV3Type.createDecimalV3Type(1, 0));
+        assertUnsafeCast(BooleanType.INSTANCE, 
DecimalV3Type.createDecimalV3Type(1, 1));
+
+        // Decimal -> decimal is accepted when neither scale nor integer digits shrink:
+        // (6, 2) -> (8, 3) grows both, (6, 2) -> (6, 1) loses scale and
+        // (6, 2) -> (5, 2) loses an integer digit.
+        assertSafeCast(DecimalV3Type.createDecimalV3Type(6, 2), 
DecimalV3Type.createDecimalV3Type(8, 3));
+        assertUnsafeCast(DecimalV3Type.createDecimalV3Type(6, 2), 
DecimalV3Type.createDecimalV3Type(6, 1));
+        assertUnsafeCast(DecimalV3Type.createDecimalV3Type(6, 2), 
DecimalV3Type.createDecimalV3Type(5, 2));
+
+        // DATETIME and DATETIMEV2(0) are expected to convert injectively in both directions,
+        // and raising the DATETIMEV2 scale is accepted. DATETIMEV2(3) -> DATETIME and
+        // DATETIME -> DATE are rejected (presumably because fractional seconds / the
+        // time-of-day part would be discarded).
+        assertSafeCast(DateTimeType.INSTANCE, DateTimeV2Type.of(0));
+        assertSafeCast(DateTimeV2Type.of(0), DateTimeType.INSTANCE);
+        assertSafeCast(DateTimeV2Type.of(3), DateTimeV2Type.of(6));
+        assertUnsafeCast(DateTimeV2Type.of(3), DateTimeType.INSTANCE);
+        assertUnsafeCast(DateTimeType.INSTANCE, DateType.INSTANCE);
+
+        // Character casts are expected to be injective regardless of declared length,
+        // including VARCHAR(20) -> VARCHAR(10) and STRING -> VARCHAR(10).
+        // NOTE(review): this presumably relies on the cast not truncating to the target
+        // length; confirm against the CharacterType implementation.
+        assertSafeCast(VarcharType.createVarcharType(10), 
VarcharType.createVarcharType(20));
+        assertSafeCast(VarcharType.createVarcharType(10), StringType.INSTANCE);
+        assertSafeCast(VarcharType.createVarcharType(20), 
VarcharType.createVarcharType(10));
+        assertSafeCast(StringType.INSTANCE, VarcharType.createVarcharType(10));
+    }
+
+    /**
+     * Checks {@code DataType#isInjectiveCastTo} for nested types: element-wise widening and
+     * narrowing of ARRAY elements, MAP keys/values and STRUCT fields, a STRUCT field-count
+     * mismatch, and casts of complex types to STRING.
+     */
+    @Test
+    public void testIsInjectiveCastToForComplexTypes() {
+        // ARRAY follows its element cast: INT -> BIGINT accepted, BIGINT -> INT rejected.
+        assertSafeCast(ArrayType.of(IntegerType.INSTANCE), 
ArrayType.of(BigIntType.INSTANCE));
+        assertUnsafeCast(ArrayType.of(BigIntType.INSTANCE), 
ArrayType.of(IntegerType.INSTANCE));
+
+        // MAP is accepted when both key and value casts are; a narrowing key cast
+        // (BIGINT -> INT) rejects it even though the value cast widens.
+        assertSafeCast(MapType.of(IntegerType.INSTANCE, 
VarcharType.createVarcharType(10)),
+                MapType.of(BigIntType.INSTANCE, StringType.INSTANCE));
+        assertUnsafeCast(MapType.of(BigIntType.INSTANCE, 
VarcharType.createVarcharType(10)),
+                MapType.of(IntegerType.INSTANCE, StringType.INSTANCE));
+
+        // Struct fixtures: INT/VARCHAR(10), BIGINT/STRING, and a single INT field.
+        StructType intStringStruct = new StructType(ImmutableList.of(
+                new StructField("a", IntegerType.INSTANCE, true, ""),
+                new StructField("b", VarcharType.createVarcharType(10), true, 
"")));
+        StructType bigintStringStruct = new StructType(ImmutableList.of(
+                new StructField("a", BigIntType.INSTANCE, true, ""),
+                new StructField("b", StringType.INSTANCE, true, "")));
+        StructType intOnlyStruct = new StructType(ImmutableList.of(
+                new StructField("a", IntegerType.INSTANCE, true, "")));
+
+        // STRUCT follows its field casts; a different field count (1 -> 2) is rejected.
+        assertSafeCast(intStringStruct, bigintStringStruct);
+        assertUnsafeCast(bigintStringStruct, intStringStruct);
+        assertUnsafeCast(intOnlyStruct, intStringStruct);
+
+        // ARRAY, MAP and STRUCT values are all expected to cast injectively to STRING.
+        assertSafeCast(ArrayType.of(IntegerType.INSTANCE), 
StringType.INSTANCE);
+        assertSafeCast(MapType.of(IntegerType.INSTANCE, StringType.INSTANCE), 
StringType.INSTANCE);
+        assertSafeCast(intStringStruct, StringType.INSTANCE);
+    }
+
     @Test
     public void testAnyAccept() {
         AnyDataType dateType = AnyDataType.INSTANCE_WITHOUT_INDEX;
@@ -654,4 +710,14 @@ public class DataTypeTest {
         DataType type = ArrayType.of(MapType.of(VarcharType.SYSTEM_DEFAULT, 
IntegerType.INSTANCE));
         Assertions.assertDoesNotThrow(type::validateDataType);
     }
+
+    /**
+     * Asserts that {@code source.isInjectiveCastTo(target)} returns true. "Safe" in the
+     * name and failure message means injective: distinct source values must stay distinct
+     * after the cast.
+     *
+     * @param source type being cast from
+     * @param target type being cast to
+     */
+    private void assertSafeCast(DataType source, DataType target) {
+        Assertions.assertTrue(source.isInjectiveCastTo(target), source.toSql() 
+ " should safely cast to "
+                + target.toSql());
+    }
+
+    /**
+     * Asserts that {@code source.isInjectiveCastTo(target)} returns false, i.e. the cast
+     * may map distinct source values to the same target value.
+     *
+     * @param source type being cast from
+     * @param target type being cast to
+     */
+    private void assertUnsafeCast(DataType source, DataType target) {
+        Assertions.assertFalse(source.isInjectiveCastTo(target), 
source.toSql() + " should not safely cast to "
+                + target.toSql());
+    }
 }
diff --git a/regression-test/data/nereids_syntax_p0/set_operation.out 
b/regression-test/data/nereids_syntax_p0/set_operation.out
index 5afc4fac2ad..b21f59c0f65 100644
--- a/regression-test/data/nereids_syntax_p0/set_operation.out
+++ b/regression-test/data/nereids_syntax_p0/set_operation.out
@@ -592,6 +592,14 @@ hell0
 -- !union45 --
 2
 
+-- !union46 --
+2020-01-01
+2020-01-01
+
+-- !union47 --
+1      1
+1      1
+
 -- !check_child_col_order --
 205548764.21875        3601
 53950855.65625 3602
@@ -599,4 +607,3 @@ hell0
 -- !intersect_case --
 0
 1
-
diff --git a/regression-test/suites/nereids_syntax_p0/set_operation.groovy 
b/regression-test/suites/nereids_syntax_p0/set_operation.groovy
index c81c04aae5d..e3126212827 100644
--- a/regression-test/suites/nereids_syntax_p0/set_operation.groovy
+++ b/regression-test/suites/nereids_syntax_p0/set_operation.groovy
@@ -291,6 +291,27 @@ suite("test_nereids_set_operation") {
         select count(*) from (select 1, 2 union select 1,1 ) a;
     """
 
+    // do not push non-injective cast project below UNION DISTINCT.
+    // The two datetime values are distinct before the outer cast, but become
+    // equal after casting to date. The correct result keeps both rows.
+    order_qt_union46 """
+        select cast(dt as date) from (
+            select cast('2020-01-01 00:00:00' as datetime) dt
+            union
+            select cast('2020-01-01 01:00:00' as datetime) dt
+        ) t
+    """
+
+    // The project duplicates one UNION output and drops the other. Pushing it
+    // below UNION DISTINCT would collapse the two rows into one.
+    order_qt_union47 """
+        select cast(a as bigint), cast(a as bigint) from (
+            select 1 a, 2 b
+            union
+            select 1 a, 3 b
+        ) t
+    """
+
     def tables = [
             "dwd_daytable",
     ]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to