morrySnow commented on code in PR #49096:
URL: https://github.com/apache/doris/pull/49096#discussion_r2194612710


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSkewExpr.java:
##########
@@ -0,0 +1,64 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.analysis;
+
+import org.apache.doris.nereids.CascadesContext;
+import org.apache.doris.nereids.hint.DistributeHint;
+import org.apache.doris.nereids.hint.JoinSkewInfo;
+import org.apache.doris.nereids.pattern.MatchingContext;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.util.TypeCoercionUtils;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**bind skew hint in DistributeHint*/
+public class BindSkewExpr extends BindExpression {
+    @Override
+    public List<Rule> buildRules() {
+        return ImmutableList.of(
+            RuleType.BINDING_SKEW_EXPR.build(
+                    logicalJoin().when(join -> 
join.getDistributeHint().getSkewInfo() != null)
+                    .thenApply(this::bindSkewExpr))
+        );
+    }
+
+    private LogicalJoin<Plan, Plan> 
bindSkewExpr(MatchingContext<LogicalJoin<Plan, Plan>> ctx) {

Review Comment:
   Could this class be merged into BindExpression?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Aggregate.java:
##########
@@ -106,4 +111,23 @@ default boolean isDistinct() {
         return getOutputExpressions().stream().allMatch(e -> e instanceof Slot)
                 && getGroupByExpressions().stream().allMatch(e -> e instanceof 
Slot);
     }
+
+    /**canSkewRewrite*/

Review Comment:
   Add a comment explaining the situations in which the rewrite can be applied.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/properties/DistributionSpecHash.java:
##########
@@ -53,6 +53,8 @@ public class DistributionSpecHash extends DistributionSpec {
     private final long tableId;
     private final Set<Long> partitionIds;
     private final long selectedIndexId;
+    // used for window skew rewrite
+    private final boolean isSkew;

Review Comment:
   Why does processing the window require adding a skew flag here?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownAliasThroughJoin.java:
##########
@@ -97,9 +98,13 @@ public Rule build() {
                 List<Expression> newHash = 
replaceJoinConjuncts(join.getHashJoinConjuncts(), replaceMap);
                 List<Expression> newOther = 
replaceJoinConjuncts(join.getOtherJoinConjuncts(), replaceMap);
                 List<Expression> newMark = 
replaceJoinConjuncts(join.getMarkJoinConjuncts(), replaceMap);
-
+                DistributeHint hint = join.getDistributeHint();

Review Comment:
   Is it mutable? That is dangerous; could it be made immutable?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalWindowToPhysicalWindow.java:
##########
@@ -95,10 +95,12 @@ private PhysicalWindow implement(LogicalWindow<GroupPlan> 
logicalWindow) {
 
         Plan newRoot = logicalWindow.child();
         for (PartitionKeyGroup partitionKeyGroup : partitionKeyGroupList) {
-            for (OrderKeyGroup orderKeyGroup : partitionKeyGroup.groups) {
+            boolean isSkew = partitionKeyGroup.isSkew();
+            for (int i = 0; i < partitionKeyGroup.groups.size(); ++i) {
+                OrderKeyGroup orderKeyGroup = partitionKeyGroup.groups.get(i);
                 // in OrderKeyGroup, create PhysicalWindow for each 
WindowFrameGroup;
                 // each PhysicalWindow contains the same windowExpressions as 
WindowFrameGroup.groups
-                newRoot = createPhysicalPlanNodeForWindowFrameGroup(newRoot, 
orderKeyGroup);
+                newRoot = createPhysicalPlanNodeForWindowFrameGroup(newRoot, 
orderKeyGroup, 0 == i && isSkew);

Review Comment:
   Why is `0 == i` needed? Could you add a comment to explain it?



##########
fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java:
##########
@@ -743,6 +743,10 @@ public class SessionVariable implements Serializable, 
Writable {
 
     public static final String PREFER_UDF_OVER_BUILTIN = 
"prefer_udf_over_builtin";
 
+    public static final String JOIN_SKEW_ADD_SALT_EXPLODE_FACTOR = 
"join_skew_add_salt_explode_factor";

Review Comment:
   It would be better for all skew-related variables to share the same prefix.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java:
##########
@@ -2088,4 +2107,130 @@ private boolean couldConvertToMulti(LogicalAggregate<? 
extends Plan> aggregate)
         }
         return true;
     }
+
+    /**
+     * LogicalAggregate(groupByExpr=[a], outputExpr=[a,count(distinct b)])
+     * ->
+     * +--PhysicalHashAggregate(groupByExpr=[a], outputExpr=[a, 
count(partial_count(m))]
+     *   +--PhysicalDistribute(shuffleColumn=[a])
+     *     +--PhysicalHashAggregate(groupByExpr=[a], outputExpr=[a, 
partial_count(m)]
+     *       +--PhysicalHashAggregate(groupByExpr=[a, saltExpr], 
outputExpr=[a, multi_distinct_count(b) as m])
+     *         +--PhysicalDistribute(shuffleColumn=[a, saltExpr])
+     *           +--PhysicalProject(projects=[a, b, xxhash_32(b)%512 as 
saltExpr])
+     *             +--PhysicalHashAggregate(groupByExpr=[a, b], outputExpr=[a, 
b])
+     * */
+    private PhysicalHashAggregate<Plan> 
countDistinctSkewRewrite(LogicalAggregate<GroupPlan> logicalAgg,
+            CascadesContext cascadesContext) {
+        if (!logicalAgg.canSkewRewrite()) {
+            return null;
+        }
+
+        // 1.local agg
+        ImmutableList.Builder<Expression> localAggGroupByBuilder = 
ImmutableList.builderWithExpectedSize(
+                logicalAgg.getGroupByExpressions().size() + 1);
+        localAggGroupByBuilder.addAll(logicalAgg.getGroupByExpressions());
+        Count count = (Count) 
logicalAgg.getAggregateFunctions().iterator().next();
+        if (!(count.child(0) instanceof Slot)) {
+            return null;
+        }
+        localAggGroupByBuilder.add(count.child(0));
+        List<Expression> localAggGroupBy = localAggGroupByBuilder.build();
+        List<NamedExpression> localAggOutput = 
Utils.fastToImmutableList((List) localAggGroupBy);
+        RequireProperties requireAny = 
RequireProperties.of(PhysicalProperties.ANY);
+        boolean maybeUsingStreamAgg = 
maybeUsingStreamAgg(cascadesContext.getConnectContext(),
+                localAggGroupBy);
+        boolean couldBanned = false;
+        AggregateParam localParam = new AggregateParam(AggPhase.LOCAL, 
AggMode.INPUT_TO_BUFFER, couldBanned);
+        PhysicalHashAggregate<Plan> localAgg = new 
PhysicalHashAggregate<>(localAggGroupBy, localAggOutput,
+                Optional.empty(), localParam, maybeUsingStreamAgg, 
Optional.empty(), null,
+                requireAny, logicalAgg.child());
+        // add shuffle expr in project
+        ImmutableList.Builder<NamedExpression> projections = 
ImmutableList.builderWithExpectedSize(
+                localAgg.getOutputs().size() + 1);
+        projections.addAll(localAgg.getOutputs());
+        Alias modAlias = getShuffleExpr(count, cascadesContext);
+        projections.add(modAlias);
+        PhysicalProject<Plan> physicalProject = new 
PhysicalProject<>(projections.build(), null, localAgg);
+
+        // 2.second phase agg: multi_distinct_count(b) group by a,h
+        ImmutableList.Builder<Expression> secondPhaseAggGroupByBuilder = 
ImmutableList.builderWithExpectedSize(
+                logicalAgg.getGroupByExpressions().size() + 1);
+        
secondPhaseAggGroupByBuilder.addAll(logicalAgg.getGroupByExpressions());
+        secondPhaseAggGroupByBuilder.add(modAlias.toSlot());
+        List<Expression> secondPhaseAggGroupBy = 
secondPhaseAggGroupByBuilder.build();
+        ImmutableList.Builder<NamedExpression> secondPhaseAggOutput = 
ImmutableList.builderWithExpectedSize(
+                secondPhaseAggGroupBy.size() + 1);
+        secondPhaseAggOutput.addAll((List) secondPhaseAggGroupBy);
+        Alias aliasTarget = new Alias(new TinyIntLiteral((byte) 0));
+        for (NamedExpression ne : logicalAgg.getOutputExpressions()) {
+            if (ne instanceof Alias) {
+                if (((Alias) ne).child().equals(count)) {
+                    aliasTarget = (Alias) ne;
+                }
+            }
+        }
+        AggregateParam secondParam = new AggregateParam(AggPhase.GLOBAL, 
AggMode.INPUT_TO_RESULT, couldBanned);
+        AggregateFunction multiDistinct = count.convertToMultiDistinct();
+        Alias multiDistinctAlias = new Alias(new 
AggregateExpression(multiDistinct, secondParam));
+        secondPhaseAggOutput.add(multiDistinctAlias);
+        List<ExprId> shuffleIds = new ArrayList<>();
+        for (Expression expr : secondPhaseAggGroupBy) {
+            if (expr instanceof Slot) {
+                shuffleIds.add(((Slot) expr).getExprId());
+            }
+        }
+        RequireProperties secondRequireProperties = RequireProperties.of(
+                PhysicalProperties.createHash(shuffleIds, 
ShuffleType.REQUIRE));
+        PhysicalHashAggregate<Plan> secondPhaseAgg = new 
PhysicalHashAggregate<>(
+                secondPhaseAggGroupBy, secondPhaseAggOutput.build(),
+                Optional.empty(), secondParam, false, Optional.empty(), null,
+                secondRequireProperties, physicalProject);
+
+        // 3. third phase agg
+        List<Expression> thirdPhaseAggGroupBy = 
Utils.fastToImmutableList(logicalAgg.getGroupByExpressions());
+        ImmutableList.Builder<NamedExpression> thirdPhaseAggOutput = 
ImmutableList.builderWithExpectedSize(
+                thirdPhaseAggGroupBy.size() + 1);
+        thirdPhaseAggOutput.addAll((List) thirdPhaseAggGroupBy);
+        AggregateParam thirdParam = new 
AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned);
+        Count thirdCount = new Count(multiDistinctAlias.toSlot());
+        Alias thirdCountAlias = new Alias(new AggregateExpression(thirdCount, 
thirdParam));
+        thirdPhaseAggOutput.add(thirdCountAlias);
+        PhysicalHashAggregate<Plan> thirdPhaseAgg = new 
PhysicalHashAggregate<>(
+                thirdPhaseAggGroupBy, thirdPhaseAggOutput.build(),
+                Optional.empty(), thirdParam, false, Optional.empty(), null,
+                secondRequireProperties, secondPhaseAgg);
+
+        // 4. fourth phase agg
+        ImmutableList.Builder<NamedExpression> fourthPhaseAggOutput = 
ImmutableList.builderWithExpectedSize(
+                thirdPhaseAggGroupBy.size() + 1);
+        fourthPhaseAggOutput.addAll((List) thirdPhaseAggGroupBy);
+        AggregateParam fourthParam = new 
AggregateParam(AggPhase.DISTINCT_GLOBAL, AggMode.BUFFER_TO_RESULT,
+                couldBanned);
+        Alias sumAliasFour = new Alias(aliasTarget.getExprId(),
+                new AggregateExpression(thirdCount, fourthParam, 
thirdCountAlias.toSlot()),
+                aliasTarget.getName());
+        fourthPhaseAggOutput.add(sumAliasFour);
+        List<ExprId> shuffleIdsFour = new ArrayList<>();
+        for (Expression expr : logicalAgg.getExpressions()) {
+            if (expr instanceof Slot) {
+                shuffleIdsFour.add(((Slot) expr).getExprId());
+            }
+        }
+        RequireProperties fourthRequireProperties = RequireProperties.of(
+                PhysicalProperties.createHash(shuffleIdsFour, 
ShuffleType.REQUIRE));
+        return new PhysicalHashAggregate<>(thirdPhaseAggGroupBy,
+                fourthPhaseAggOutput.build(), Optional.empty(), fourthParam,
+                false, Optional.empty(), logicalAgg.getLogicalProperties(),
+                fourthRequireProperties, thirdPhaseAgg);
+    }
+
+    private Alias getShuffleExpr(Count count, CascadesContext cascadesContext) 
{
+        int bucketNum = 
cascadesContext.getConnectContext().getSessionVariable().aggDistinctSkewBucketNum;
+        DataType type = bucketNum <= 256 ? TinyIntType.INSTANCE : 
SmallIntType.INSTANCE;
+        int bucket = bucketNum / 2;

Review Comment:
   > because XxHash32 return negative and positive number
   
   Please add a comment in the code explaining this.



##########
fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java:
##########
@@ -2634,6 +2646,12 @@ public void setDetailShapePlanNodes(String 
detailShapePlanNodes) {
             })
     public boolean preferUdfOverBuiltin = false;
 
+    @VariableMgr.VarAttr(name = JOIN_SKEW_ADD_SALT_EXPLODE_FACTOR, description 
= {
+            "join 加盐优化的扩展因子",
+            "join skew add salt explode factor"

Review Comment:
   Does this variable need a checker?



##########
fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java:
##########
@@ -4656,6 +4674,19 @@ public void checkBatchSize(String batchSize) {
         }
     }
 
+    public void checkAggDistinctSkewRewriteBucketNum(String bucketNumStr) {
+        try {
+            long bucketNum = Long.parseLong(bucketNumStr);
+            if (bucketNum <= 0 || bucketNum >= 65536) {
+                throw new InvalidParameterException(
+                        "agg_distinct_skew_rewrite_bucket_num should be 
between 1 and 65535");

Review Comment:
   The error message should use the static variable.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SaltJoin.java:
##########
@@ -0,0 +1,383 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.hint.DistributeHint;
+import org.apache.doris.nereids.hint.Hint.HintStatus;
+import org.apache.doris.nereids.pattern.MatchingContext;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Cast;
+import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
+import org.apache.doris.nereids.trees.expressions.EqualPredicate;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.InPredicate;
+import org.apache.doris.nereids.trees.expressions.IsNull;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Not;
+import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
+import org.apache.doris.nereids.trees.expressions.Or;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.Function;
+import 
org.apache.doris.nereids.trees.expressions.functions.generator.ExplodeNumbers;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.Random;
+import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.plans.DistributeType;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.types.TinyIntType;
+import org.apache.doris.nereids.util.TypeCoercionUtils;
+import org.apache.doris.nereids.util.Utils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import org.jetbrains.annotations.Nullable;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Current capabilities and limitations of SaltJoin rewrite handling:
+ * - Supports single-side skew in INNER JOIN, NOT support double-side (both 
tables) skew
+ * - Supports left table skew and NOT support right table skew in LEFT JOIN
+ * - Supports right table skew and Not support left table skew in RIGHT JOIN
+ *
+ * INNER JOIN and LEFT JOIN use case:
+ * Applicable when left table is skewed and right table is too large for 
broadcast
+ *
+ * Here are some examples in rewrite:
+ * case1:
+ * LogicalJoin(type:inner, t1.a=t2.a, hint:skew(t1.a(null,1,2)))
+ *   +--LogicalOlapScan(t1)
+ *   +--LogicalOlapScan(t2)
+ * ->
+ * LogicalJoin (type:inner, t1.a=t2.a and r1=r2)
+ *   |--LogicalProject (t1.a, if (t1.a IN (1, 2), random(0, 999), 
DEFAULT_SALT_VALUE) AS r1))
+ *   |  +--LogicalFilter(t1.a is not null)
+ *   |    +--LogicalOlapScan(t1)
+ *   +--LogicalProject (projections: t2.a, if(explodeNumber IS NULL, 
DEFAULT_SALT_VALUE, explodeNumber) as r2)
+ *     +--LogicalJoin (type=right_outer_join, t2.a = skewValue)
+ *       |--LogicalGenerate(generators=[explode_numbers(1000)], 
generatorOutput=[explodeNumber])
+ *       |  +--LogicalUnion(outputs=[skewValue], constantExprsList(1,2))
+ *       +--LogicalFilter(t2.a is not null)
+ *         +--LogicalOlapScan(t2)
+ *
+ * case2:
+ * LogicalJoin(type:inner, t1.a=t2.a, hint:skew(t1.a(1,2)))
+ *   +--LogicalOlapScan(t1)
+ *   +--LogicalOlapScan(t2)
+ * ->
+ * LogicalJoin (type:inner, t1.a=t2.a and r1=r2)
+ *   |--LogicalProject (t1.a, if (t1.a IN (1, 2), random(0, 999), 
DEFAULT_SALT_VALUE) AS r1))
+ *   | +--LogicalOlapScan(t1)
+ *   +--LogicalProject (projections: t2.a, if(explodeNumber IS NULL, 
DEFAULT_SALT_VALUE, explodeNumber) as r2)
+ *     +--LogicalJoin (type=right_outer_join, t2.a = skewValue)
+ *       |--LogicalGenerate(generators=[explode_numbers(1000)], 
generatorOutput=[explodeNumber])
+ *       |  +--LogicalUnion(outputs=[skewValue], constantExprsList(1,2))
+ *       +--LogicalOlapScan(t2)
+ *
+ * case3: not optimize, because rows will not be output in join when join key 
is null
+ * LogicalJoin(type:inner, t1.a=t2.a, hint:skew(t1.a(null)))
+ *   |--LogicalOlapScan(t1)
+ *   +--LogicalOlapScan(t2)
+ * ->
+ * LogicalJoin(type:inner, t1.a=t2.a)
+ *   |--LogicalFilter(t1.a is not null)
+ *   |  +--LogicalOlapScan(t1)
+ *   +--LogicalFilter(t2.a is not null)
+ *     +--LogicalOlapScan(t2)
+ * */
+public class SaltJoin extends OneRewriteRuleFactory {
+    private static final String RANDOM_COLUMN_NAME_LEFT = "r1";
+    private static final String RANDOM_COLUMN_NAME_RIGHT = "r2";
+    private static final String SKEW_VALUE_COLUMN_NAME = "skewValue";
+    private static final String EXPLODE_NUMBER_COLUMN_NAME = "explodeColumn";
+    private static final int SALT_FACTOR = 4;
+    private static final int DEFAULT_SALT_VALUE = 0;
+
+    @Override
+    public Rule build() {
+        return logicalJoin()
+                .when(join -> join.getJoinType().isOneSideOuterJoin() || 
join.getJoinType().isInnerJoin())
+                .when(join -> join.getDistributeHint() != null && 
join.getDistributeHint().getSkewInfo() != null)
+                .whenNot(LogicalJoin::isMarkJoin)
+                .whenNot(join -> 
join.getDistributeHint().isSuccessInSkewRewrite())
+                .thenApply(SaltJoin::transform).toRule(RuleType.SALT_JOIN);
+    }
+
+    private static Plan transform(MatchingContext<LogicalJoin<Plan, Plan>> 
ctx) {
+        LogicalJoin<Plan, Plan> join = ctx.root;
+        DistributeHint hint = join.getDistributeHint();
+        if (hint.distributeType != DistributeType.SHUFFLE_RIGHT) {
+            return null;
+        }
+        Expression skewExpr = hint.getSkewExpr();
+        if (!skewExpr.isSlot()) {
+            return null;
+        }
+        if ((join.getJoinType().isLeftOuterJoin() || 
join.getJoinType().isInnerJoin())
+                && !join.left().getOutput().contains((Slot) skewExpr)
+                || join.getJoinType().isRightOuterJoin() && 
!join.right().getOutput().contains((Slot) skewExpr)) {
+            return null;
+        }
+        int factor = getSaltFactor(ctx);
+        Optional<Expression> literalType = 
TypeCoercionUtils.characterLiteralTypeCoercion(String.valueOf(factor),
+                TinyIntType.INSTANCE);
+        if (!literalType.isPresent()) {
+            return null;
+        }
+        Expression leftSkewExpr = null;
+        Expression rightSkewExpr = null;
+        Expression skewConjunct = null;
+        for (Expression conjunct : join.getHashJoinConjuncts()) {
+            if (skewExpr.equals(conjunct.child(0)) || 
skewExpr.equals(conjunct.child(1))) {
+                if (join.left().getOutputSet().contains((Slot) 
conjunct.child(0))
+                        && join.right().getOutputSet().contains((Slot) 
conjunct.child(1))) {
+                    skewConjunct = conjunct;
+                } else if (join.left().getOutputSet().contains((Slot) 
conjunct.child(1))
+                        && join.right().getOutputSet().contains((Slot) 
conjunct.child(0))) {
+                    skewConjunct = ((ComparisonPredicate) conjunct).commute();
+                } else {
+                    return null;
+                }
+                leftSkewExpr = skewConjunct.child(0);
+                rightSkewExpr = skewConjunct.child(1);
+                break;
+            }
+        }
+        if (leftSkewExpr == null || rightSkewExpr == null) {
+            return null;
+        }
+        List<Expression> skewValues = join.getDistributeHint().getSkewValues();
+        Set<Expression> skewValuesSet = new HashSet<>(skewValues);
+        List<Expression> expandSideValues = 
getSaltedSkewValuesForExpandSide(skewConjunct, skewValuesSet);
+        List<Expression> skewSideValues = 
getSaltedSkewValuesForSkewSide(skewConjunct, skewValuesSet, join);
+        if (skewSideValues.isEmpty()) {
+            return null;
+        }
+        DataType type = literalType.get().getDataType();
+        LogicalProject<Plan> rightProject;
+        LogicalProject<Plan> leftProject;
+        if (join.getJoinType() == JoinType.INNER_JOIN || join.getJoinType() == 
JoinType.LEFT_OUTER_JOIN) {
+            leftProject = addRandomSlot(leftSkewExpr, skewSideValues, 
join.left(), factor, type);
+            rightProject = expandSkewValueRows(rightSkewExpr, 
expandSideValues, join.right(), factor, type);
+        } else {

Review Comment:
   Does this branch process the right outer join? It would be better to add an explicit `if` here and return null in other cases.
   Maybe using a switch statement would be better.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SaltJoin.java:
##########
@@ -0,0 +1,383 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.hint.DistributeHint;
+import org.apache.doris.nereids.hint.Hint.HintStatus;
+import org.apache.doris.nereids.pattern.MatchingContext;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Cast;
+import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
+import org.apache.doris.nereids.trees.expressions.EqualPredicate;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.InPredicate;
+import org.apache.doris.nereids.trees.expressions.IsNull;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Not;
+import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
+import org.apache.doris.nereids.trees.expressions.Or;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.Function;
+import 
org.apache.doris.nereids.trees.expressions.functions.generator.ExplodeNumbers;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.Random;
+import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.plans.DistributeType;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.types.TinyIntType;
+import org.apache.doris.nereids.util.TypeCoercionUtils;
+import org.apache.doris.nereids.util.Utils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import org.jetbrains.annotations.Nullable;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Current capabilities and limitations of SaltJoin rewrite handling:
+ * - Supports single-side skew in INNER JOIN, NOT support double-side (both 
tables) skew
+ * - Supports left table skew and NOT support right table skew in LEFT JOIN
+ * - Supports right table skew and Not support left table skew in RIGHT JOIN
+ *
+ * INNER JOIN and LEFT JOIN use case:
+ * Applicable when left table is skewed and right table is too large for 
broadcast
+ *
+ * Here are some examples in rewrite:
+ * case1:
+ * LogicalJoin(type:inner, t1.a=t2.a, hint:skew(t1.a(null,1,2)))
+ *   +--LogicalOlapScan(t1)
+ *   +--LogicalOlapScan(t2)
+ * ->
+ * LogicalJoin (type:inner, t1.a=t2.a and r1=r2)
+ *   |--LogicalProject (t1.a, if (t1.a IN (1, 2), random(0, 999), 
DEFAULT_SALT_VALUE) AS r1))
+ *   |  +--LogicalFilter(t1.a is not null)
+ *   |    +--LogicalOlapScan(t1)
+ *   +--LogicalProject (projections: t2.a, if(explodeNumber IS NULL, 
DEFAULT_SALT_VALUE, explodeNumber) as r2)
+ *     +--LogicalJoin (type=right_outer_join, t2.a = skewValue)
+ *       |--LogicalGenerate(generators=[explode_numbers(1000)], 
generatorOutput=[explodeNumber])
+ *       |  +--LogicalUnion(outputs=[skewValue], constantExprsList(1,2))
+ *       +--LogicalFilter(t2.a is not null)
+ *         +--LogicalOlapScan(t2)
+ *
+ * case2:
+ * LogicalJoin(type:inner, t1.a=t2.a, hint:skew(t1.a(1,2)))
+ *   +--LogicalOlapScan(t1)
+ *   +--LogicalOlapScan(t2)
+ * ->
+ * LogicalJoin (type:inner, t1.a=t2.a and r1=r2)
+ *   |--LogicalProject (t1.a, if (t1.a IN (1, 2), random(0, 999), 
DEFAULT_SALT_VALUE) AS r1))
+ *   | +--LogicalOlapScan(t1)
+ *   +--LogicalProject (projections: t2.a, if(explodeNumber IS NULL, 
DEFAULT_SALT_VALUE, explodeNumber) as r2)
+ *     +--LogicalJoin (type=right_outer_join, t2.a = skewValue)
+ *       |--LogicalGenerate(generators=[explode_numbers(1000)], 
generatorOutput=[explodeNumber])
+ *       |  +--LogicalUnion(outputs=[skewValue], constantExprsList(1,2))
+ *       +--LogicalOlapScan(t2)
+ *
+ * case3: not optimize, because rows will not be output in join when join key 
is null
+ * LogicalJoin(type:inner, t1.a=t2.a, hint:skew(t1.a(null)))
+ *   |--LogicalOlapScan(t1)
+ *   +--LogicalOlapScan(t2)
+ * ->
+ * LogicalJoin(type:inner, t1.a=t2.a)
+ *   |--LogicalFilter(t1.a is not null)
+ *   |  +--LogicalOlapScan(t1)
+ *   +--LogicalFilter(t2.a is not null)
+ *     +--LogicalOlapScan(t2)
+ * */
+public class SaltJoin extends OneRewriteRuleFactory {
+    private static final String RANDOM_COLUMN_NAME_LEFT = "r1";
+    private static final String RANDOM_COLUMN_NAME_RIGHT = "r2";
+    private static final String SKEW_VALUE_COLUMN_NAME = "skewValue";
+    private static final String EXPLODE_NUMBER_COLUMN_NAME = "explodeColumn";
+    private static final int SALT_FACTOR = 4;
+    private static final int DEFAULT_SALT_VALUE = 0;
+
+    @Override
+    public Rule build() {
+        return logicalJoin()
+                .when(join -> join.getJoinType().isOneSideOuterJoin() || 
join.getJoinType().isInnerJoin())
+                .when(join -> join.getDistributeHint() != null && 
join.getDistributeHint().getSkewInfo() != null)
+                .whenNot(LogicalJoin::isMarkJoin)
+                .whenNot(join -> 
join.getDistributeHint().isSuccessInSkewRewrite())
+                .thenApply(SaltJoin::transform).toRule(RuleType.SALT_JOIN);
+    }
+
+    private static Plan transform(MatchingContext<LogicalJoin<Plan, Plan>> 
ctx) {
+        LogicalJoin<Plan, Plan> join = ctx.root;
+        DistributeHint hint = join.getDistributeHint();
+        if (hint.distributeType != DistributeType.SHUFFLE_RIGHT) {
+            return null;
+        }
+        Expression skewExpr = hint.getSkewExpr();
+        if (!skewExpr.isSlot()) {
+            return null;
+        }
+        if ((join.getJoinType().isLeftOuterJoin() || 
join.getJoinType().isInnerJoin())
+                && !join.left().getOutput().contains((Slot) skewExpr)
+                || join.getJoinType().isRightOuterJoin() && 
!join.right().getOutput().contains((Slot) skewExpr)) {
+            return null;
+        }
+        int factor = getSaltFactor(ctx);
+        Optional<Expression> literalType = 
TypeCoercionUtils.characterLiteralTypeCoercion(String.valueOf(factor),
+                TinyIntType.INSTANCE);
+        if (!literalType.isPresent()) {
+            return null;
+        }
+        Expression leftSkewExpr = null;
+        Expression rightSkewExpr = null;
+        Expression skewConjunct = null;
+        for (Expression conjunct : join.getHashJoinConjuncts()) {
+            if (skewExpr.equals(conjunct.child(0)) || 
skewExpr.equals(conjunct.child(1))) {
+                if (join.left().getOutputSet().contains((Slot) 
conjunct.child(0))
+                        && join.right().getOutputSet().contains((Slot) 
conjunct.child(1))) {
+                    skewConjunct = conjunct;
+                } else if (join.left().getOutputSet().contains((Slot) 
conjunct.child(1))
+                        && join.right().getOutputSet().contains((Slot) 
conjunct.child(0))) {
+                    skewConjunct = ((ComparisonPredicate) conjunct).commute();
+                } else {
+                    return null;
+                }
+                leftSkewExpr = skewConjunct.child(0);
+                rightSkewExpr = skewConjunct.child(1);
+                break;
+            }
+        }
+        if (leftSkewExpr == null || rightSkewExpr == null) {
+            return null;
+        }
+        List<Expression> skewValues = join.getDistributeHint().getSkewValues();
+        Set<Expression> skewValuesSet = new HashSet<>(skewValues);
+        List<Expression> expandSideValues = 
getSaltedSkewValuesForExpandSide(skewConjunct, skewValuesSet);
+        List<Expression> skewSideValues = 
getSaltedSkewValuesForSkewSide(skewConjunct, skewValuesSet, join);
+        if (skewSideValues.isEmpty()) {

Review Comment:
   Why not check expandSideValues?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SaltJoin.java:
##########
@@ -0,0 +1,383 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.hint.DistributeHint;
+import org.apache.doris.nereids.hint.Hint.HintStatus;
+import org.apache.doris.nereids.pattern.MatchingContext;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Cast;
+import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
+import org.apache.doris.nereids.trees.expressions.EqualPredicate;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.InPredicate;
+import org.apache.doris.nereids.trees.expressions.IsNull;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Not;
+import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
+import org.apache.doris.nereids.trees.expressions.Or;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.Function;
+import 
org.apache.doris.nereids.trees.expressions.functions.generator.ExplodeNumbers;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.Random;
+import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.plans.DistributeType;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.types.TinyIntType;
+import org.apache.doris.nereids.util.TypeCoercionUtils;
+import org.apache.doris.nereids.util.Utils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import org.jetbrains.annotations.Nullable;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Current capabilities and limitations of SaltJoin rewrite handling:
+ * - Supports single-side skew in INNER JOIN, NOT support double-side (both 
tables) skew
+ * - Supports left table skew and NOT support right table skew in LEFT JOIN
+ * - Supports right table skew and Not support left table skew in RIGHT JOIN
+ *
+ * INNER JOIN and LEFT JOIN use case:
+ * Applicable when left table is skewed and right table is too large for 
broadcast
+ *
+ * Here are some examples in rewrite:
+ * case1:
+ * LogicalJoin(type:inner, t1.a=t2.a, hint:skew(t1.a(null,1,2)))
+ *   +--LogicalOlapScan(t1)
+ *   +--LogicalOlapScan(t2)
+ * ->
+ * LogicalJoin (type:inner, t1.a=t2.a and r1=r2)
+ *   |--LogicalProject (t1.a, if (t1.a IN (1, 2), random(0, 999), 
DEFAULT_SALT_VALUE) AS r1))
+ *   |  +--LogicalFilter(t1.a is not null)
+ *   |    +--LogicalOlapScan(t1)
+ *   +--LogicalProject (projections: t2.a, if(explodeNumber IS NULL, 
DEFAULT_SALT_VALUE, explodeNumber) as r2)
+ *     +--LogicalJoin (type=right_outer_join, t2.a = skewValue)
+ *       |--LogicalGenerate(generators=[explode_numbers(1000)], 
generatorOutput=[explodeNumber])
+ *       |  +--LogicalUnion(outputs=[skewValue], constantExprsList(1,2))
+ *       +--LogicalFilter(t2.a is not null)
+ *         +--LogicalOlapScan(t2)
+ *
+ * case2:
+ * LogicalJoin(type:inner, t1.a=t2.a, hint:skew(t1.a(1,2)))
+ *   +--LogicalOlapScan(t1)
+ *   +--LogicalOlapScan(t2)
+ * ->
+ * LogicalJoin (type:inner, t1.a=t2.a and r1=r2)
+ *   |--LogicalProject (t1.a, if (t1.a IN (1, 2), random(0, 999), 
DEFAULT_SALT_VALUE) AS r1))
+ *   | +--LogicalOlapScan(t1)
+ *   +--LogicalProject (projections: t2.a, if(explodeNumber IS NULL, 
DEFAULT_SALT_VALUE, explodeNumber) as r2)
+ *     +--LogicalJoin (type=right_outer_join, t2.a = skewValue)
+ *       |--LogicalGenerate(generators=[explode_numbers(1000)], 
generatorOutput=[explodeNumber])
+ *       |  +--LogicalUnion(outputs=[skewValue], constantExprsList(1,2))
+ *       +--LogicalOlapScan(t2)
+ *
+ * case3: not optimize, because rows will not be output in join when join key 
is null
+ * LogicalJoin(type:inner, t1.a=t2.a, hint:skew(t1.a(null)))
+ *   |--LogicalOlapScan(t1)
+ *   +--LogicalOlapScan(t2)
+ * ->
+ * LogicalJoin(type:inner, t1.a=t2.a)
+ *   |--LogicalFilter(t1.a is not null)
+ *   |  +--LogicalOlapScan(t1)
+ *   +--LogicalFilter(t2.a is not null)
+ *     +--LogicalOlapScan(t2)
+ * */
+public class SaltJoin extends OneRewriteRuleFactory {
+    private static final String RANDOM_COLUMN_NAME_LEFT = "r1";
+    private static final String RANDOM_COLUMN_NAME_RIGHT = "r2";
+    private static final String SKEW_VALUE_COLUMN_NAME = "skewValue";
+    private static final String EXPLODE_NUMBER_COLUMN_NAME = "explodeColumn";
+    private static final int SALT_FACTOR = 4;
+    private static final int DEFAULT_SALT_VALUE = 0;
+
+    @Override
+    public Rule build() {
+        return logicalJoin()
+                .when(join -> join.getJoinType().isOneSideOuterJoin() || 
join.getJoinType().isInnerJoin())
+                .when(join -> join.getDistributeHint() != null && 
join.getDistributeHint().getSkewInfo() != null)
+                .whenNot(LogicalJoin::isMarkJoin)
+                .whenNot(join -> 
join.getDistributeHint().isSuccessInSkewRewrite())
+                .thenApply(SaltJoin::transform).toRule(RuleType.SALT_JOIN);
+    }
+
+    private static Plan transform(MatchingContext<LogicalJoin<Plan, Plan>> 
ctx) {
+        LogicalJoin<Plan, Plan> join = ctx.root;
+        DistributeHint hint = join.getDistributeHint();
+        if (hint.distributeType != DistributeType.SHUFFLE_RIGHT) {
+            return null;
+        }
+        Expression skewExpr = hint.getSkewExpr();
+        if (!skewExpr.isSlot()) {
+            return null;
+        }
+        if ((join.getJoinType().isLeftOuterJoin() || 
join.getJoinType().isInnerJoin())
+                && !join.left().getOutput().contains((Slot) skewExpr)
+                || join.getJoinType().isRightOuterJoin() && 
!join.right().getOutput().contains((Slot) skewExpr)) {
+            return null;
+        }
+        int factor = getSaltFactor(ctx);
+        Optional<Expression> literalType = 
TypeCoercionUtils.characterLiteralTypeCoercion(String.valueOf(factor),
+                TinyIntType.INSTANCE);
+        if (!literalType.isPresent()) {
+            return null;
+        }
+        Expression leftSkewExpr = null;
+        Expression rightSkewExpr = null;
+        Expression skewConjunct = null;
+        for (Expression conjunct : join.getHashJoinConjuncts()) {
+            if (skewExpr.equals(conjunct.child(0)) || 
skewExpr.equals(conjunct.child(1))) {
+                if (join.left().getOutputSet().contains((Slot) 
conjunct.child(0))
+                        && join.right().getOutputSet().contains((Slot) 
conjunct.child(1))) {
+                    skewConjunct = conjunct;
+                } else if (join.left().getOutputSet().contains((Slot) 
conjunct.child(1))
+                        && join.right().getOutputSet().contains((Slot) 
conjunct.child(0))) {
+                    skewConjunct = ((ComparisonPredicate) conjunct).commute();
+                } else {
+                    return null;
+                }
+                leftSkewExpr = skewConjunct.child(0);
+                rightSkewExpr = skewConjunct.child(1);
+                break;
+            }
+        }
+        if (leftSkewExpr == null || rightSkewExpr == null) {
+            return null;
+        }
+        List<Expression> skewValues = join.getDistributeHint().getSkewValues();
+        Set<Expression> skewValuesSet = new HashSet<>(skewValues);
+        List<Expression> expandSideValues = 
getSaltedSkewValuesForExpandSide(skewConjunct, skewValuesSet);
+        List<Expression> skewSideValues = 
getSaltedSkewValuesForSkewSide(skewConjunct, skewValuesSet, join);
+        if (skewSideValues.isEmpty()) {
+            return null;
+        }
+        DataType type = literalType.get().getDataType();
+        LogicalProject<Plan> rightProject;
+        LogicalProject<Plan> leftProject;
+        if (join.getJoinType() == JoinType.INNER_JOIN || join.getJoinType() == 
JoinType.LEFT_OUTER_JOIN) {
+            leftProject = addRandomSlot(leftSkewExpr, skewSideValues, 
join.left(), factor, type);
+            rightProject = expandSkewValueRows(rightSkewExpr, 
expandSideValues, join.right(), factor, type);
+        } else {
+            leftProject = expandSkewValueRows(leftSkewExpr, expandSideValues, 
join.left(), factor, type);
+            rightProject = addRandomSlot(rightSkewExpr, skewSideValues, 
join.right(), factor, type);
+        }
+        EqualTo saltEqual = new 
EqualTo(leftProject.getProjects().get(leftProject.getProjects().size() - 
1).toSlot(),
+                
rightProject.getProjects().get(rightProject.getProjects().size() - 1).toSlot());
+        saltEqual = (EqualTo) 
TypeCoercionUtils.processComparisonPredicate(saltEqual);
+        ImmutableList.Builder<Expression> newHashJoinConjuncts = 
ImmutableList.builderWithExpectedSize(
+                join.getHashJoinConjuncts().size() + 1);
+        newHashJoinConjuncts.addAll(join.getHashJoinConjuncts());
+        newHashJoinConjuncts.add(saltEqual);
+        hint.setStatus(HintStatus.SUCCESS);
+        hint.setSkewInfo(hint.getSkewInfo().withSuccessInSaltJoin(true));
+        return new LogicalJoin<>(join.getJoinType(), 
newHashJoinConjuncts.build(), join.getOtherJoinConjuncts(),
+                hint, leftProject, rightProject, JoinReorderContext.EMPTY);
+    }
+
+    // Add a project on top of originPlan, which includes all the original 
columns plus a case when column.
+    private static LogicalProject<Plan> addRandomSlot(Expression skewExpr, 
List<Expression> skewValues,
+            Plan originPlan, int factor, DataType type) {
+        List<Expression> skewValuesExceptNull = 
skewValues.stream().filter(value -> !(value instanceof NullLiteral))
+                .collect(Collectors.toList());
+        Expression ifCondition = getIfCondition(skewExpr, skewValues, 
skewValuesExceptNull);
+        Random random = new Random(new BigIntLiteral(0), new 
BigIntLiteral(factor - 1));
+        Cast cast = new Cast(random, type);
+        If ifExpr = new If(ifCondition, cast, 
Literal.convertToTypedLiteral(DEFAULT_SALT_VALUE, type));
+        ImmutableList.Builder<NamedExpression> namedExpressionsBuilder = 
ImmutableList.builderWithExpectedSize(
+                originPlan.getOutput().size() + 1);
+        namedExpressionsBuilder.addAll(originPlan.getOutput());
+        namedExpressionsBuilder.add(new Alias(ifExpr, 
RANDOM_COLUMN_NAME_LEFT));

Review Comment:
   Use a meaningful name as the prefix, followed by the result of
`org.apache.doris.nereids.StatementContext#generateColumnName`.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/copier/LogicalPlanDeepCopier.java:
##########
@@ -286,8 +287,13 @@ public Plan visitLogicalJoin(LogicalJoin<? extends Plan, ? 
extends Plan> join, D
                     .deepCopy(join.getMarkJoinSlotReference().get(), context));
 
         }
+        DistributeHint hint = join.getDistributeHint();
+        if (hint.getSkewInfo() != null) {
+            Expression skewExpr = 
ExpressionDeepCopier.INSTANCE.deepCopy(hint.getSkewExpr(), context);
+            hint.setSkewInfo(hint.getSkewInfo().withSkewExpr(skewExpr));
+        }

Review Comment:
   This code block is duplicated; could it be extracted into a shared abstraction?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SaltJoin.java:
##########
@@ -0,0 +1,383 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.hint.DistributeHint;
+import org.apache.doris.nereids.hint.Hint.HintStatus;
+import org.apache.doris.nereids.pattern.MatchingContext;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Cast;
+import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
+import org.apache.doris.nereids.trees.expressions.EqualPredicate;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.InPredicate;
+import org.apache.doris.nereids.trees.expressions.IsNull;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Not;
+import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
+import org.apache.doris.nereids.trees.expressions.Or;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.Function;
+import 
org.apache.doris.nereids.trees.expressions.functions.generator.ExplodeNumbers;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.Random;
+import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.plans.DistributeType;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.types.TinyIntType;
+import org.apache.doris.nereids.util.TypeCoercionUtils;
+import org.apache.doris.nereids.util.Utils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import org.jetbrains.annotations.Nullable;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Current capabilities and limitations of SaltJoin rewrite handling:
+ * - Supports single-side skew in INNER JOIN, NOT support double-side (both 
tables) skew
+ * - Supports left table skew and NOT support right table skew in LEFT JOIN
+ * - Supports right table skew and Not support left table skew in RIGHT JOIN
+ *
+ * INNER JOIN and LEFT JOIN use case:
+ * Applicable when left table is skewed and right table is too large for 
broadcast
+ *
+ * Here are some examples in rewrite:
+ * case1:
+ * LogicalJoin(type:inner, t1.a=t2.a, hint:skew(t1.a(null,1,2)))
+ *   +--LogicalOlapScan(t1)
+ *   +--LogicalOlapScan(t2)
+ * ->
+ * LogicalJoin (type:inner, t1.a=t2.a and r1=r2)
+ *   |--LogicalProject (t1.a, if (t1.a IN (1, 2), random(0, 999), 
DEFAULT_SALT_VALUE) AS r1))
+ *   |  +--LogicalFilter(t1.a is not null)
+ *   |    +--LogicalOlapScan(t1)
+ *   +--LogicalProject (projections: t2.a, if(explodeNumber IS NULL, 
DEFAULT_SALT_VALUE, explodeNumber) as r2)
+ *     +--LogicalJoin (type=right_outer_join, t2.a = skewValue)
+ *       |--LogicalGenerate(generators=[explode_numbers(1000)], 
generatorOutput=[explodeNumber])
+ *       |  +--LogicalUnion(outputs=[skewValue], constantExprsList(1,2))
+ *       +--LogicalFilter(t2.a is not null)
+ *         +--LogicalOlapScan(t2)
+ *
+ * case2:
+ * LogicalJoin(type:inner, t1.a=t2.a, hint:skew(t1.a(1,2)))
+ *   +--LogicalOlapScan(t1)
+ *   +--LogicalOlapScan(t2)
+ * ->
+ * LogicalJoin (type:inner, t1.a=t2.a and r1=r2)
+ *   |--LogicalProject (t1.a, if (t1.a IN (1, 2), random(0, 999), 
DEFAULT_SALT_VALUE) AS r1))
+ *   | +--LogicalOlapScan(t1)
+ *   +--LogicalProject (projections: t2.a, if(explodeNumber IS NULL, 
DEFAULT_SALT_VALUE, explodeNumber) as r2)
+ *     +--LogicalJoin (type=right_outer_join, t2.a = skewValue)
+ *       |--LogicalGenerate(generators=[explode_numbers(1000)], 
generatorOutput=[explodeNumber])
+ *       |  +--LogicalUnion(outputs=[skewValue], constantExprsList(1,2))
+ *       +--LogicalOlapScan(t2)
+ *
+ * case3: not optimize, because rows will not be output in join when join key 
is null
+ * LogicalJoin(type:inner, t1.a=t2.a, hint:skew(t1.a(null)))
+ *   |--LogicalOlapScan(t1)
+ *   +--LogicalOlapScan(t2)
+ * ->
+ * LogicalJoin(type:inner, t1.a=t2.a)
+ *   |--LogicalFilter(t1.a is not null)
+ *   |  +--LogicalOlapScan(t1)
+ *   +--LogicalFilter(t2.a is not null)
+ *     +--LogicalOlapScan(t2)
+ * */
+public class SaltJoin extends OneRewriteRuleFactory {
+    private static final String RANDOM_COLUMN_NAME_LEFT = "r1";
+    private static final String RANDOM_COLUMN_NAME_RIGHT = "r2";
+    private static final String SKEW_VALUE_COLUMN_NAME = "skewValue";
+    private static final String EXPLODE_NUMBER_COLUMN_NAME = "explodeColumn";
+    private static final int SALT_FACTOR = 4;
+    private static final int DEFAULT_SALT_VALUE = 0;
+
+    @Override
+    public Rule build() {
+        return logicalJoin()
+                .when(join -> join.getJoinType().isOneSideOuterJoin() || 
join.getJoinType().isInnerJoin())
+                .when(join -> join.getDistributeHint() != null && 
join.getDistributeHint().getSkewInfo() != null)
+                .whenNot(LogicalJoin::isMarkJoin)
+                .whenNot(join -> 
join.getDistributeHint().isSuccessInSkewRewrite())
+                .thenApply(SaltJoin::transform).toRule(RuleType.SALT_JOIN);
+    }
+
+    private static Plan transform(MatchingContext<LogicalJoin<Plan, Plan>> 
ctx) {
+        LogicalJoin<Plan, Plan> join = ctx.root;
+        DistributeHint hint = join.getDistributeHint();
+        if (hint.distributeType != DistributeType.SHUFFLE_RIGHT) {
+            return null;
+        }
+        Expression skewExpr = hint.getSkewExpr();
+        if (!skewExpr.isSlot()) {
+            return null;
+        }
+        if ((join.getJoinType().isLeftOuterJoin() || 
join.getJoinType().isInnerJoin())
+                && !join.left().getOutput().contains((Slot) skewExpr)
+                || join.getJoinType().isRightOuterJoin() && 
!join.right().getOutput().contains((Slot) skewExpr)) {
+            return null;
+        }
+        int factor = getSaltFactor(ctx);
+        Optional<Expression> literalType = 
TypeCoercionUtils.characterLiteralTypeCoercion(String.valueOf(factor),
+                TinyIntType.INSTANCE);
+        if (!literalType.isPresent()) {
+            return null;
+        }
+        Expression leftSkewExpr = null;
+        Expression rightSkewExpr = null;
+        Expression skewConjunct = null;
+        for (Expression conjunct : join.getHashJoinConjuncts()) {
+            if (skewExpr.equals(conjunct.child(0)) || 
skewExpr.equals(conjunct.child(1))) {
+                if (join.left().getOutputSet().contains((Slot) 
conjunct.child(0))
+                        && join.right().getOutputSet().contains((Slot) 
conjunct.child(1))) {
+                    skewConjunct = conjunct;
+                } else if (join.left().getOutputSet().contains((Slot) 
conjunct.child(1))
+                        && join.right().getOutputSet().contains((Slot) 
conjunct.child(0))) {
+                    skewConjunct = ((ComparisonPredicate) conjunct).commute();
+                } else {
+                    return null;
+                }
+                leftSkewExpr = skewConjunct.child(0);
+                rightSkewExpr = skewConjunct.child(1);
+                break;
+            }
+        }
+        if (leftSkewExpr == null || rightSkewExpr == null) {
+            return null;
+        }
+        List<Expression> skewValues = join.getDistributeHint().getSkewValues();
+        Set<Expression> skewValuesSet = new HashSet<>(skewValues);
+        List<Expression> expandSideValues = 
getSaltedSkewValuesForExpandSide(skewConjunct, skewValuesSet);
+        List<Expression> skewSideValues = 
getSaltedSkewValuesForSkewSide(skewConjunct, skewValuesSet, join);
+        if (skewSideValues.isEmpty()) {
+            return null;
+        }
+        DataType type = literalType.get().getDataType();
+        LogicalProject<Plan> rightProject;
+        LogicalProject<Plan> leftProject;
+        if (join.getJoinType() == JoinType.INNER_JOIN || join.getJoinType() == 
JoinType.LEFT_OUTER_JOIN) {
+            leftProject = addRandomSlot(leftSkewExpr, skewSideValues, 
join.left(), factor, type);
+            rightProject = expandSkewValueRows(rightSkewExpr, 
expandSideValues, join.right(), factor, type);
+        } else {
+            leftProject = expandSkewValueRows(leftSkewExpr, expandSideValues, 
join.left(), factor, type);
+            rightProject = addRandomSlot(rightSkewExpr, skewSideValues, 
join.right(), factor, type);
+        }
+        EqualTo saltEqual = new 
EqualTo(leftProject.getProjects().get(leftProject.getProjects().size() - 
1).toSlot(),
+                
rightProject.getProjects().get(rightProject.getProjects().size() - 1).toSlot());
+        saltEqual = (EqualTo) 
TypeCoercionUtils.processComparisonPredicate(saltEqual);
+        ImmutableList.Builder<Expression> newHashJoinConjuncts = 
ImmutableList.builderWithExpectedSize(
+                join.getHashJoinConjuncts().size() + 1);
+        newHashJoinConjuncts.addAll(join.getHashJoinConjuncts());
+        newHashJoinConjuncts.add(saltEqual);
+        hint.setStatus(HintStatus.SUCCESS);
+        hint.setSkewInfo(hint.getSkewInfo().withSuccessInSaltJoin(true));
+        return new LogicalJoin<>(join.getJoinType(), 
newHashJoinConjuncts.build(), join.getOtherJoinConjuncts(),
+                hint, leftProject, rightProject, JoinReorderContext.EMPTY);
+    }
+
+    // Add a project on top of originPlan, which includes all the original 
columns plus a case when column.
+    private static LogicalProject<Plan> addRandomSlot(Expression skewExpr, 
List<Expression> skewValues,
+            Plan originPlan, int factor, DataType type) {
+        List<Expression> skewValuesExceptNull = 
skewValues.stream().filter(value -> !(value instanceof NullLiteral))

Review Comment:
   Collect into an `ImmutableList` instead.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java:
##########
@@ -2086,4 +2104,140 @@ private boolean couldConvertToMulti(LogicalAggregate<? 
extends Plan> aggregate)
         }
         return true;
     }
+
+    /**
+     * LogicalAggregate(groupByExpr=[a], outputExpr=[a,count(distinct b)])
+     * ->
+     * +--PhysicalHashAggregate(groupByExpr=[a], outputExpr=[a, 
count(partial_count(m))]
+     *   +--PhysicalDistribute(shuffleColumn=[a])
+     *     +--PhysicalHashAggregate(groupByExpr=[a], outputExpr=[a, 
partial_count(m)]
+     *       +--PhysicalHashAggregate(groupByExpr=[a, saltExpr], 
outputExpr=[a, multi_distinct_count(b) as m])
+     *         +--PhysicalDistribute(shuffleColumn=[a, saltExpr])
+     *           +--PhysicalProject(projects=[a, b, xxhash_32(b)%512 as 
saltExpr])
+     *             +--PhysicalHashAggregate(groupByExpr=[a, b], outputExpr=[a, 
b])
+     * */
+    private PhysicalHashAggregate<Plan> 
countDistinctSkewRewrite(LogicalAggregate<GroupPlan> logicalAgg,
+            CascadesContext cascadesContext) {
+        if (!logicalAgg.canSkewRewrite()) {
+            return null;
+        }
+
+        // 1.local agg
+        ImmutableList.Builder<Expression> localAggGroupByBuilder = 
ImmutableList.builderWithExpectedSize(
+                logicalAgg.getGroupByExpressions().size() + 1);
+        localAggGroupByBuilder.addAll(logicalAgg.getGroupByExpressions());
+        AggregateFunction aggFunc = 
logicalAgg.getAggregateFunctions().iterator().next();
+        if (!(aggFunc.child(0) instanceof Slot)) {
+            return null;
+        }
+        localAggGroupByBuilder.add(aggFunc.child(0));
+        List<Expression> localAggGroupBy = localAggGroupByBuilder.build();
+        List<NamedExpression> localAggOutput = 
Utils.fastToImmutableList((List) localAggGroupBy);
+        RequireProperties requireAny = 
RequireProperties.of(PhysicalProperties.ANY);
+        boolean maybeUsingStreamAgg = 
maybeUsingStreamAgg(cascadesContext.getConnectContext(),
+                localAggGroupBy);
+        boolean couldBanned = false;
+        AggregateParam localParam = new AggregateParam(AggPhase.LOCAL, 
AggMode.INPUT_TO_BUFFER, couldBanned);
+        PhysicalHashAggregate<Plan> localAgg = new 
PhysicalHashAggregate<>(localAggGroupBy, localAggOutput,
+                Optional.empty(), localParam, maybeUsingStreamAgg, 
Optional.empty(), null,
+                requireAny, logicalAgg.child());
+        // add shuffle expr in project
+        ImmutableList.Builder<NamedExpression> projections = 
ImmutableList.builderWithExpectedSize(
+                localAgg.getOutputs().size() + 1);
+        projections.addAll(localAgg.getOutputs());
+        Alias modAlias = getShuffleExpr(aggFunc, cascadesContext);
+        projections.add(modAlias);
+        PhysicalProject<Plan> physicalProject = new 
PhysicalProject<>(projections.build(), null, localAgg);
+
+        // 2.second phase agg: multi_distinct_count(b) group by a,h
+        ImmutableList.Builder<Expression> secondPhaseAggGroupByBuilder = 
ImmutableList.builderWithExpectedSize(
+                logicalAgg.getGroupByExpressions().size() + 1);
+        
secondPhaseAggGroupByBuilder.addAll(logicalAgg.getGroupByExpressions());
+        secondPhaseAggGroupByBuilder.add(modAlias.toSlot());
+        List<Expression> secondPhaseAggGroupBy = 
secondPhaseAggGroupByBuilder.build();
+        ImmutableList.Builder<NamedExpression> secondPhaseAggOutput = 
ImmutableList.builderWithExpectedSize(
+                secondPhaseAggGroupBy.size() + 1);
+        secondPhaseAggOutput.addAll((List) secondPhaseAggGroupBy);
+        Alias aliasTarget = new Alias(new TinyIntLiteral((byte) 0));
+        for (NamedExpression ne : logicalAgg.getOutputExpressions()) {
+            if (ne instanceof Alias) {
+                if (((Alias) ne).child().equals(aggFunc)) {
+                    aliasTarget = (Alias) ne;
+                }
+            }
+        }
+        AggregateParam secondParam = new AggregateParam(AggPhase.GLOBAL, 
AggMode.INPUT_TO_RESULT, couldBanned);
+        AggregateFunction multiDistinct = ((SupportMultiDistinct) 
aggFunc).convertToMultiDistinct();
+        Alias multiDistinctAlias = new Alias(new 
AggregateExpression(multiDistinct, secondParam));
+        secondPhaseAggOutput.add(multiDistinctAlias);
+        List<ExprId> shuffleIds = new ArrayList<>();
+        for (Expression expr : secondPhaseAggGroupBy) {
+            if (expr instanceof Slot) {
+                shuffleIds.add(((Slot) expr).getExprId());
+            }
+        }
+        RequireProperties secondRequireProperties = RequireProperties.of(
+                PhysicalProperties.createHash(shuffleIds, 
ShuffleType.REQUIRE));
+        PhysicalHashAggregate<Plan> secondPhaseAgg = new 
PhysicalHashAggregate<>(
+                secondPhaseAggGroupBy, secondPhaseAggOutput.build(),
+                Optional.empty(), secondParam, false, Optional.empty(), null,
+                secondRequireProperties, physicalProject);
+
+        // 3. third phase agg
+        List<Expression> thirdPhaseAggGroupBy = 
Utils.fastToImmutableList(logicalAgg.getGroupByExpressions());
+        ImmutableList.Builder<NamedExpression> thirdPhaseAggOutput = 
ImmutableList.builderWithExpectedSize(
+                thirdPhaseAggGroupBy.size() + 1);
+        thirdPhaseAggOutput.addAll((List) thirdPhaseAggGroupBy);
+        AggregateParam thirdParam = new 
AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned);
+        AggregateFunction function = getAggregateFunction(aggFunc);
+        AggregateFunction thirdAggFunc = 
function.withDistinctAndChildren(false,
+                ImmutableList.of(multiDistinctAlias.toSlot()));
+        Alias thirdCountAlias = new Alias(new 
AggregateExpression(thirdAggFunc, thirdParam));
+        thirdPhaseAggOutput.add(thirdCountAlias);
+        PhysicalHashAggregate<Plan> thirdPhaseAgg = new 
PhysicalHashAggregate<>(
+                thirdPhaseAggGroupBy, thirdPhaseAggOutput.build(),
+                Optional.empty(), thirdParam, false, Optional.empty(), null,
+                secondRequireProperties, secondPhaseAgg);
+
+        // 4. fourth phase agg
+        ImmutableList.Builder<NamedExpression> fourthPhaseAggOutput = 
ImmutableList.builderWithExpectedSize(
+                thirdPhaseAggGroupBy.size() + 1);
+        fourthPhaseAggOutput.addAll((List) thirdPhaseAggGroupBy);
+        AggregateParam fourthParam = new 
AggregateParam(AggPhase.DISTINCT_GLOBAL, AggMode.BUFFER_TO_RESULT,
+                couldBanned);
+        Alias sumAliasFour = new Alias(aliasTarget.getExprId(),
+                new AggregateExpression(thirdAggFunc, fourthParam, 
thirdCountAlias.toSlot()),
+                aliasTarget.getName());
+        fourthPhaseAggOutput.add(sumAliasFour);
+        List<ExprId> shuffleIdsFour = new ArrayList<>();
+        for (Expression expr : logicalAgg.getExpressions()) {
+            if (expr instanceof Slot) {
+                shuffleIdsFour.add(((Slot) expr).getExprId());
+            }
+        }
+        RequireProperties fourthRequireProperties = RequireProperties.of(
+                PhysicalProperties.createHash(shuffleIdsFour, 
ShuffleType.REQUIRE));
+        return new PhysicalHashAggregate<>(thirdPhaseAggGroupBy,
+                fourthPhaseAggOutput.build(), Optional.empty(), fourthParam,
+                false, Optional.empty(), logicalAgg.getLogicalProperties(),
+                fourthRequireProperties, thirdPhaseAgg);
+    }
+
+    private AggregateFunction getAggregateFunction(AggregateFunction aggFunc) {
+        if (aggFunc instanceof Count) {
+            return new Sum0(aggFunc.child(0));
+        } else {
+            return aggFunc;
+        }
+    }
+
+    private Alias getShuffleExpr(AggregateFunction aggFunc, CascadesContext 
cascadesContext) {
+        int bucketNum = 
cascadesContext.getConnectContext().getSessionVariable().aggDistinctSkewRewriteBucketNum;
+        int bucket = bucketNum / 2;
+        DataType type = bucket <= 128 ? TinyIntType.INSTANCE : 
SmallIntType.INSTANCE;
+        Mod mod = new Mod(new XxHash32(TypeCoercionUtils.castIfNotSameType(
+                aggFunc.child(0), StringType.INSTANCE)), new 
SmallIntLiteral((short) bucket));
+        Cast cast = new Cast(mod, type);
+        return new Alias(cast, SALT_EXPR);

Review Comment:
   Why is a static name needed? If it is required, I suggest using it as a prefix, followed by the result of
`org.apache.doris.nereids.StatementContext#generateColumnName`.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java:
##########
@@ -436,6 +451,9 @@ && couldConvertToMulti(agg))
                                 secondPhaseRequireGroupByAndDistinctHash, 
fourPhaseRequireGroupByHash
                         );
                     })
+            ),
+            RuleType.COUNT_DISTINCT_AGG_SKEW_REWRITE.build(
+                    basePattern.thenApply(ctx -> 
countDistinctSkewRewrite(ctx.root, ctx.cascadesContext))

Review Comment:
   when(aggregate::canSkewRewrite)



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java:
##########
@@ -436,6 +451,9 @@ && couldConvertToMulti(agg))
                                 secondPhaseRequireGroupByAndDistinctHash, 
fourPhaseRequireGroupByHash
                         );
                     })
+            ),
+            RuleType.COUNT_DISTINCT_AGG_SKEW_REWRITE.build(

Review Comment:
   Is this rule only for count_distinct?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SaltJoin.java:
##########
@@ -0,0 +1,383 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.hint.DistributeHint;
+import org.apache.doris.nereids.hint.Hint.HintStatus;
+import org.apache.doris.nereids.pattern.MatchingContext;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Cast;
+import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
+import org.apache.doris.nereids.trees.expressions.EqualPredicate;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.InPredicate;
+import org.apache.doris.nereids.trees.expressions.IsNull;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Not;
+import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
+import org.apache.doris.nereids.trees.expressions.Or;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.Function;
+import 
org.apache.doris.nereids.trees.expressions.functions.generator.ExplodeNumbers;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.Random;
+import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.plans.DistributeType;
+import org.apache.doris.nereids.trees.plans.JoinType;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
+import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.IntegerType;
+import org.apache.doris.nereids.types.TinyIntType;
+import org.apache.doris.nereids.util.TypeCoercionUtils;
+import org.apache.doris.nereids.util.Utils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import org.jetbrains.annotations.Nullable;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Current capabilities and limitations of SaltJoin rewrite handling:
+ * - Supports single-side skew in INNER JOIN, NOT support double-side (both 
tables) skew
+ * - Supports left table skew and NOT support right table skew in LEFT JOIN
+ * - Supports right table skew and Not support left table skew in RIGHT JOIN
+ *
+ * INNER JOIN and LEFT JOIN use case:
+ * Applicable when left table is skewed and right table is too large for 
broadcast
+ *
+ * Here are some examples in rewrite:
+ * case1:
+ * LogicalJoin(type:inner, t1.a=t2.a, hint:skew(t1.a(null,1,2)))
+ *   +--LogicalOlapScan(t1)
+ *   +--LogicalOlapScan(t2)
+ * ->
+ * LogicalJoin (type:inner, t1.a=t2.a and r1=r2)
+ *   |--LogicalProject (t1.a, if (t1.a IN (1, 2), random(0, 999), 
DEFAULT_SALT_VALUE) AS r1))
+ *   |  +--LogicalFilter(t1.a is not null)
+ *   |    +--LogicalOlapScan(t1)
+ *   +--LogicalProject (projections: t2.a, if(explodeNumber IS NULL, 
DEFAULT_SALT_VALUE, explodeNumber) as r2)
+ *     +--LogicalJoin (type=right_outer_join, t2.a = skewValue)
+ *       |--LogicalGenerate(generators=[explode_numbers(1000)], 
generatorOutput=[explodeNumber])
+ *       |  +--LogicalUnion(outputs=[skewValue], constantExprsList(1,2))
+ *       +--LogicalFilter(t2.a is not null)
+ *         +--LogicalOlapScan(t2)
+ *
+ * case2:
+ * LogicalJoin(type:inner, t1.a=t2.a, hint:skew(t1.a(1,2)))
+ *   +--LogicalOlapScan(t1)
+ *   +--LogicalOlapScan(t2)
+ * ->
+ * LogicalJoin (type:inner, t1.a=t2.a and r1=r2)
+ *   |--LogicalProject (t1.a, if (t1.a IN (1, 2), random(0, 999), 
DEFAULT_SALT_VALUE) AS r1))
+ *   | +--LogicalOlapScan(t1)
+ *   +--LogicalProject (projections: t2.a, if(explodeNumber IS NULL, 
DEFAULT_SALT_VALUE, explodeNumber) as r2)
+ *     +--LogicalJoin (type=right_outer_join, t2.a = skewValue)
+ *       |--LogicalGenerate(generators=[explode_numbers(1000)], 
generatorOutput=[explodeNumber])
+ *       |  +--LogicalUnion(outputs=[skewValue], constantExprsList(1,2))
+ *       +--LogicalOlapScan(t2)
+ *
+ * case3: not optimize, because rows will not be output in join when join key 
is null
+ * LogicalJoin(type:inner, t1.a=t2.a, hint:skew(t1.a(null)))
+ *   |--LogicalOlapScan(t1)
+ *   +--LogicalOlapScan(t2)
+ * ->
+ * LogicalJoin(type:inner, t1.a=t2.a)
+ *   |--LogicalFilter(t1.a is not null)
+ *   |  +--LogicalOlapScan(t1)
+ *   +--LogicalFilter(t2.a is not null)
+ *     +--LogicalOlapScan(t2)
+ * */
+public class SaltJoin extends OneRewriteRuleFactory {
+    private static final String RANDOM_COLUMN_NAME_LEFT = "r1";
+    private static final String RANDOM_COLUMN_NAME_RIGHT = "r2";
+    private static final String SKEW_VALUE_COLUMN_NAME = "skewValue";
+    private static final String EXPLODE_NUMBER_COLUMN_NAME = "explodeColumn";
+    private static final int SALT_FACTOR = 4;
+    private static final int DEFAULT_SALT_VALUE = 0;
+
+    @Override
+    public Rule build() {
+        return logicalJoin()
+                .when(join -> join.getJoinType().isOneSideOuterJoin() || 
join.getJoinType().isInnerJoin())
+                .when(join -> join.getDistributeHint() != null && 
join.getDistributeHint().getSkewInfo() != null)
+                .whenNot(LogicalJoin::isMarkJoin)
+                .whenNot(join -> 
join.getDistributeHint().isSuccessInSkewRewrite())
+                .thenApply(SaltJoin::transform).toRule(RuleType.SALT_JOIN);
+    }
+
+    private static Plan transform(MatchingContext<LogicalJoin<Plan, Plan>> 
ctx) {
+        LogicalJoin<Plan, Plan> join = ctx.root;
+        DistributeHint hint = join.getDistributeHint();
+        if (hint.distributeType != DistributeType.SHUFFLE_RIGHT) {
+            return null;
+        }
+        Expression skewExpr = hint.getSkewExpr();
+        if (!skewExpr.isSlot()) {
+            return null;
+        }
+        if ((join.getJoinType().isLeftOuterJoin() || 
join.getJoinType().isInnerJoin())
+                && !join.left().getOutput().contains((Slot) skewExpr)
+                || join.getJoinType().isRightOuterJoin() && 
!join.right().getOutput().contains((Slot) skewExpr)) {
+            return null;
+        }
+        int factor = getSaltFactor(ctx);
+        Optional<Expression> literalType = 
TypeCoercionUtils.characterLiteralTypeCoercion(String.valueOf(factor),
+                TinyIntType.INSTANCE);

Review Comment:
   Must this be tinyint? That limits the value to the range [-128, 127].



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java:
##########
@@ -2086,4 +2104,140 @@ private boolean couldConvertToMulti(LogicalAggregate<? 
extends Plan> aggregate)
         }
         return true;
     }
+
+    /**
+     * LogicalAggregate(groupByExpr=[a], outputExpr=[a,count(distinct b)])
+     * ->
+     * +--PhysicalHashAggregate(groupByExpr=[a], outputExpr=[a, 
count(partial_count(m))]
+     *   +--PhysicalDistribute(shuffleColumn=[a])
+     *     +--PhysicalHashAggregate(groupByExpr=[a], outputExpr=[a, 
partial_count(m)]
+     *       +--PhysicalHashAggregate(groupByExpr=[a, saltExpr], 
outputExpr=[a, multi_distinct_count(b) as m])
+     *         +--PhysicalDistribute(shuffleColumn=[a, saltExpr])
+     *           +--PhysicalProject(projects=[a, b, xxhash_32(b)%512 as 
saltExpr])
+     *             +--PhysicalHashAggregate(groupByExpr=[a, b], outputExpr=[a, 
b])
+     * */
+    private PhysicalHashAggregate<Plan> 
countDistinctSkewRewrite(LogicalAggregate<GroupPlan> logicalAgg,
+            CascadesContext cascadesContext) {
+        if (!logicalAgg.canSkewRewrite()) {
+            return null;
+        }
+
+        // 1.local agg
+        ImmutableList.Builder<Expression> localAggGroupByBuilder = 
ImmutableList.builderWithExpectedSize(
+                logicalAgg.getGroupByExpressions().size() + 1);
+        localAggGroupByBuilder.addAll(logicalAgg.getGroupByExpressions());
+        AggregateFunction aggFunc = 
logicalAgg.getAggregateFunctions().iterator().next();
+        if (!(aggFunc.child(0) instanceof Slot)) {
+            return null;
+        }

Review Comment:
   This check should be added to canSkewRewrite.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to