This is an automated email from the ASF dual-hosted git repository. jakevin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 86fc19397b8 [refactor](Nereids): refactor PredicatePropagation & support to infer Equal Condition (#29644) 86fc19397b8 is described below commit 86fc19397b80613f179b412999cde0e4219bf7b0 Author: jakevin <jakevin...@gmail.com> AuthorDate: Mon Jan 8 19:42:12 2024 +0800 [refactor](Nereids): refactor PredicatePropagation & support to infer Equal Condition (#29644) --- .../nereids/rules/rewrite/EliminateJoinByFK.java | 6 +- .../rules/rewrite/PredicatePropagation.java | 176 ++++++++------------- .../nereids/trees/plans/logical/LogicalJoin.java | 8 +- ...eEquivalenceSet.java => ImmutableEqualSet.java} | 64 +++++--- .../rules/rewrite/PredicatePropagationTest.java | 16 ++ .../data/nereids_p0/hint/fix_leading.out | 2 +- 6 files changed, 134 insertions(+), 138 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFK.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFK.java index 594dee50853..b4a6eac207b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFK.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateJoinByFK.java @@ -36,7 +36,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalRelation; import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; -import org.apache.doris.nereids.util.ImmutableEquivalenceSet; +import org.apache.doris.nereids.util.ImmutableEqualSet; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; @@ -145,7 +145,7 @@ public class EliminateJoinByFK extends DefaultPlanRewriter<JobContext> implement return project; } - private @Nullable Map<Slot, Slot> mapPrimaryToForeign(ImmutableEquivalenceSet<Slot> equivalenceSet, + private @Nullable Map<Slot, Slot> mapPrimaryToForeign(ImmutableEqualSet<Slot> equivalenceSet, Set<Slot> foreignKeys) { ImmutableMap.Builder<Slot, Slot> builder = new ImmutableMap.Builder<>(); for (Slot foreignSlot : foreignKeys) { @@ -164,7 +164,7 @@ public class EliminateJoinByFK extends DefaultPlanRewriter<JobContext> implement // 4. if foreign key is null, add a isNotNull predicate for null-reject join condition private Plan eliminateJoin(LogicalProject<LogicalJoin<?, ?>> project, ForeignKeyContext context) { LogicalJoin<?, ?> join = project.child(); - ImmutableEquivalenceSet<Slot> equalSet = join.getEqualSlots(); + ImmutableEqualSet<Slot> equalSet = join.getEqualSlots(); Set<Slot> leftSlots = Sets.intersection(join.left().getOutputSet(), equalSet.getAllItemSet()); Set<Slot> rightSlots = Sets.intersection(join.right().getOutputSet(), equalSet.getAllItemSet()); if (context.isForeignKey(leftSlots) && context.isPrimaryKey(rightSlots)) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java index 7788bbb7f06..5d11a1fa542 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java @@ -17,16 +17,14 @@ package org.apache.doris.nereids.rules.rewrite; -import org.apache.doris.nereids.parser.NereidsParser; -import org.apache.doris.nereids.rules.expression.rules.DateFunctionRewrite; -import org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InPredicate; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.DateTimeType; import org.apache.doris.nereids.types.DateTimeV2Type; @@ -35,11 +33,15 @@ import org.apache.doris.nereids.types.DateV2Type; import org.apache.doris.nereids.types.coercion.CharacterType; import org.apache.doris.nereids.types.coercion.DateLikeType; import org.apache.doris.nereids.types.coercion.IntegralType; +import org.apache.doris.nereids.util.ImmutableEqualSet; import org.apache.doris.nereids.util.TypeCoercionUtils; -import com.google.common.collect.Sets; - -import java.util.Objects; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -65,59 +67,62 @@ public class PredicatePropagation { } } - private static class EqualInferInfo { - - public final InferType inferType; - public final Expression left; - public final Expression right; - public final ComparisonPredicate comparisonPredicate; - - public EqualInferInfo(InferType inferType, - Expression left, Expression right, - ComparisonPredicate comparisonPredicate) { - this.inferType = inferType; - this.left = left; - this.right = right; - this.comparisonPredicate = comparisonPredicate; - } - } - /** * infer additional predicates. */ public static Set<Expression> infer(Set<Expression> predicates) { - Set<Expression> inferred = Sets.newHashSet(); + ImmutableEqualSet.Builder<Slot> equalSetBuilder = new ImmutableEqualSet.Builder<>(); + Map<Slot, List<Expression>> slotPredicates = new HashMap<>(); + Set<Pair<Slot, Slot>> equalPairs = new HashSet<>(); for (Expression predicate : predicates) { - // if we support more infer predicate expression type, we should impl withInferred() method. - // And should add inferred props in withChildren() method just like ComparisonPredicate, - // and it's subclass, to mark the predicate is from infer. - if (!(predicate instanceof ComparisonPredicate - || (predicate instanceof InPredicate && ((InPredicate) predicate).isLiteralChildren()))) { + Set<Slot> inputSlots = predicate.getInputSlots(); + if (inputSlots.size() == 1) { + if (predicate instanceof ComparisonPredicate + || (predicate instanceof InPredicate && ((InPredicate) predicate).isLiteralChildren())) { + slotPredicates.computeIfAbsent(inputSlots.iterator().next(), k -> new ArrayList<>()).add(predicate); + } continue; } - if (predicate instanceof InPredicate) { - continue; + + if (predicate instanceof EqualTo) { + getEqualSlot(equalSetBuilder, equalPairs, (EqualTo) predicate); } - EqualInferInfo equalInfo = getEqualInferInfo((ComparisonPredicate) predicate); - if (equalInfo.inferType == InferType.NONE) { - continue; + } + + ImmutableEqualSet<Slot> equalSet = equalSetBuilder.build(); + + Set<Expression> inferred = new HashSet<>(); + slotPredicates.forEach((left, exprs) -> { + for (Slot right : equalSet.calEqualSet(left)) { + for (Expression expr : exprs) { + inferred.add(doInferPredicate(left, right, expr)); + } + } + }); + + // infer equal to equal like a = b & b = c -> a = c + // a b c | e f g + // get (a b) (a c) (b c) | (e f) (e g) (f g) + List<Set<Slot>> equalSetList = equalSet.calEqualSetList(); + for (Set<Slot> es : equalSetList) { + List<Slot> el = es.stream().sorted(Comparator.comparingInt(s -> s.getExprId().asInt())) + .collect(Collectors.toList()); + for (int i = 0; i < el.size(); i++) { + Slot left = el.get(i); + for (int j = i + 1; j < el.size(); j++) { + Slot right = el.get(j); + if (!equalPairs.contains(Pair.of(left, right))) { + inferred.add(TypeCoercionUtils.processComparisonPredicate(new EqualTo(left, right)) + .withInferred(true)); + } + } } - Set<Expression> newInferred = predicates.stream() - .filter(p -> !p.equals(predicate)) - .filter(p -> p instanceof ComparisonPredicate || p instanceof InPredicate) - .map(predicateInfo -> doInferPredicate(equalInfo, predicateInfo)) - .filter(Objects::nonNull) - .collect(Collectors.toSet()); - inferred.addAll(newInferred); } - inferred.removeAll(predicates); + return inferred; } - private static Expression doInferPredicate(EqualInferInfo equalInfo, Expression predicate) { - Expression equalLeft = equalInfo.left; - Expression equalRight = equalInfo.right; - + private static Expression doInferPredicate(Expression equalLeft, Expression equalRight, Expression predicate) { DataType leftType = predicate.child(0).getDataType(); InferType inferType; if (leftType instanceof CharacterType) { @@ -160,47 +165,6 @@ public class PredicatePropagation { } } - /** - * Use the left or right child of `leftSlotEqualToRightSlot` to replace the left or right child of `expression` - * Now only support infer `ComparisonPredicate`. - * TODO: We should determine whether `expression` satisfies the condition for replacement - * eg: Satisfy `expression` is non-deterministic - */ - private static Expression doInfer(EqualInferInfo equalInfo, EqualInferInfo predicateInfo) { - Expression equalLeft = equalInfo.left; - Expression equalRight = equalInfo.right; - - Expression predicateLeft = predicateInfo.left; - Expression predicateRight = predicateInfo.right; - Expression newLeft = inferOneSide(predicateLeft, equalLeft, equalRight); - Expression newRight = inferOneSide(predicateRight, equalLeft, equalRight); - if (newLeft == null || newRight == null) { - return null; - } - ComparisonPredicate newPredicate = (ComparisonPredicate) predicateInfo - .comparisonPredicate.withChildren(newLeft, newRight); - Expression expr = SimplifyComparisonPredicate.INSTANCE - .rewrite(TypeCoercionUtils.processComparisonPredicate(newPredicate), null); - return DateFunctionRewrite.INSTANCE.rewrite(expr, null).withInferred(true); - } - - private static Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) { - if (predicateOneSide instanceof SlotReference) { - if (predicateOneSide.equals(equalLeft)) { - return equalRight; - } else if (predicateOneSide.equals(equalRight)) { - return equalLeft; - } - } else if (predicateOneSide.isConstant()) { - if (predicateOneSide instanceof IntegerLikeLiteral) { - return new NereidsParser().parseExpression(((IntegerLikeLiteral) predicateOneSide).toSql()); - } else { - return predicateOneSide; - } - } - return null; - } - private static Optional<Expression> validForInfer(Expression expression, InferType inferType) { if (!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) { return Optional.empty(); @@ -249,7 +213,7 @@ public class PredicatePropagation { return Optional.empty(); } - private static EqualInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) { + private static Optional<Pair<Expression, Expression>> inferInferInfo(ComparisonPredicate comparisonPredicate) { DataType leftType = comparisonPredicate.left().getDataType(); InferType inferType; if (leftType instanceof CharacterType) { @@ -264,29 +228,21 @@ public class PredicatePropagation { Optional<Expression> left = validForInfer(comparisonPredicate.left(), inferType); Optional<Expression> right = validForInfer(comparisonPredicate.right(), inferType); if (!left.isPresent() || !right.isPresent()) { - inferType = InferType.NONE; + return Optional.empty(); } - return new EqualInferInfo(inferType, left.orElse(comparisonPredicate.left()), - right.orElse(comparisonPredicate.right()), comparisonPredicate); + return Optional.of(Pair.of(left.get(), right.get())); } - /** - * Currently only equivalence derivation is supported - * and requires that the left and right sides of an expression must be slot - * <p> - * TODO: NullSafeEqual - */ - private static EqualInferInfo getEqualInferInfo(ComparisonPredicate predicate) { - if (!(predicate instanceof EqualTo)) { - return new EqualInferInfo(InferType.NONE, predicate.left(), predicate.right(), predicate); - } - EqualInferInfo info = inferInferInfo(predicate); - if (info.inferType == InferType.NONE) { - return info; - } - if (info.left instanceof SlotReference && info.right instanceof SlotReference) { - return info; - } - return new EqualInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate); + private static void getEqualSlot(ImmutableEqualSet.Builder<Slot> equalSlots, Set<Pair<Slot, Slot>> equalPairs, + EqualTo predicate) { + inferInferInfo(predicate) + .filter(info -> info.first instanceof Slot && info.second instanceof Slot) + .ifPresent(pair -> { + Slot left = (Slot) pair.first; + Slot right = (Slot) pair.second; + equalSlots.addEqualPair(left, right); + equalPairs.add(left.getExprId().asInt() <= right.getExprId().asInt() + ? Pair.of(left, right) : Pair.of(right, left)); + }); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java index 76fda759ef0..6c78193abf1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java @@ -38,7 +38,7 @@ import org.apache.doris.nereids.trees.plans.PlanType; import org.apache.doris.nereids.trees.plans.algebra.Join; import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; import org.apache.doris.nereids.util.ExpressionUtils; -import org.apache.doris.nereids.util.ImmutableEquivalenceSet; +import org.apache.doris.nereids.util.ImmutableEqualSet; import org.apache.doris.nereids.util.JoinUtils; import org.apache.doris.nereids.util.Utils; @@ -467,12 +467,12 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends /** * get Equal slot from join */ - public ImmutableEquivalenceSet<Slot> getEqualSlots() { + public ImmutableEqualSet<Slot> getEqualSlots() { // TODO: Use fd in the future if (!joinType.isInnerJoin() && !joinType.isSemiJoin()) { - return ImmutableEquivalenceSet.of(); + return ImmutableEqualSet.empty(); } - ImmutableEquivalenceSet.Builder<Slot> builder = new ImmutableEquivalenceSet.Builder<>(); + ImmutableEqualSet.Builder<Slot> builder = new ImmutableEqualSet.Builder<>(); hashJoinConjuncts.stream() .filter(e -> e instanceof EqualPredicate && e.child(0) instanceof Slot diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ImmutableEquivalenceSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ImmutableEqualSet.java similarity index 50% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/util/ImmutableEquivalenceSet.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/util/ImmutableEqualSet.java index 66d20597fc1..724414e2e19 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ImmutableEquivalenceSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ImmutableEqualSet.java @@ -17,64 +17,73 @@ package org.apache.doris.nereids.util; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Set; /** - * EquivalenceSet + * A class representing an immutable set of elements with equivalence relations. */ -public class ImmutableEquivalenceSet<T> { - final Map<T, T> root; +public class ImmutableEqualSet<T> { + private final Map<T, T> root; - ImmutableEquivalenceSet(Map<T, T> root) { + ImmutableEqualSet(Map<T, T> root) { this.root = ImmutableMap.copyOf(root); } - public static <T> ImmutableEquivalenceSet<T> of() { - return new ImmutableEquivalenceSet<>(ImmutableMap.of()); + public static <T> ImmutableEqualSet<T> empty() { + return new ImmutableEqualSet<>(ImmutableMap.of()); } /** - * Builder of ImmutableEquivalenceSet + * Builder for ImmutableEqualSet. */ public static class Builder<T> { - final Map<T, T> parent = new HashMap<>(); + private final Map<T, T> parent = new HashMap<>(); + private final Map<T, Integer> size = new HashMap<>(); + /** + * Add a equal pair + */ public void addEqualPair(T a, T b) { - parent.computeIfAbsent(b, v -> v); - parent.computeIfAbsent(a, v -> v); - union(a, b); - } - - private void union(T a, T b) { T root1 = findRoot(a); T root2 = findRoot(b); if (root1 != root2) { - parent.put(b, root1); - findRoot(b); + // merge by size + if (size.get(root1) < size.get(root2)) { + parent.put(root1, root2); + size.put(root2, size.get(root2) + size.get(root1)); + } else { + parent.put(root2, root1); + size.put(root1, size.get(root1) + size.get(root2)); + } } } private T findRoot(T a) { + parent.putIfAbsent(a, a); // Ensure that the element is added + size.putIfAbsent(a, 1); // Initialize size to 1 + if (!parent.get(a).equals(a)) { - parent.put(a, findRoot(parent.get(a))); + parent.put(a, findRoot(parent.get(a))); // Path compression } return parent.get(a); } - public ImmutableEquivalenceSet<T> build() { + public ImmutableEqualSet<T> build() { parent.keySet().forEach(this::findRoot); - return new ImmutableEquivalenceSet<>(parent); + return new ImmutableEqualSet<>(parent); } } /** - * cal equal set for a except self + * Calculate equal set for a except self */ public Set<T> calEqualSet(T a) { T ra = root.get(a); @@ -83,6 +92,21 @@ public class ImmutableEquivalenceSet<T> { .collect(ImmutableSet.toImmutableSet()); } + /** + * Calculate all equal set + */ + public List<Set<T>> calEqualSetList() { + return root.values() + .stream() + .distinct() + .map(a -> { + T ra = root.get(a); + return root.keySet().stream() + .filter(t -> root.get(t).equals(ra)) + .collect(ImmutableSet.toImmutableSet()); + }).collect(ImmutableList.toImmutableList()); + } + public Set<T> getAllItemSet() { return ImmutableSet.copyOf(root.keySet()); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java index b1aa25df1b1..1efa94451af 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagationTest.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.GreaterThan; import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.literal.Literal; @@ -34,6 +35,7 @@ import java.util.Set; class PredicatePropagationTest { private final SlotReference a = new SlotReference("a", SmallIntType.INSTANCE); private final SlotReference b = new SlotReference("b", BigIntType.INSTANCE); + private final SlotReference c = new SlotReference("c", BigIntType.INSTANCE); @Test void equal() { @@ -48,4 +50,18 @@ class PredicatePropagationTest { Set<Expression> inferExprs = PredicatePropagation.infer(exprs); System.out.println(inferExprs); } + + @Test + void inferSlotEqual() { + Set<Expression> exprs = ImmutableSet.of(new EqualTo(a, b), new EqualTo(a, c)); + Set<Expression> inferExprs = PredicatePropagation.infer(exprs); + System.out.println(inferExprs); + } + + @Test + void inferComplex0() { + Set<Expression> exprs = ImmutableSet.of(new EqualTo(a, b), new EqualTo(a, c), new GreaterThan(a, Literal.of(1))); + Set<Expression> inferExprs = PredicatePropagation.infer(exprs); + System.out.println(inferExprs); + } } diff --git a/regression-test/data/nereids_p0/hint/fix_leading.out b/regression-test/data/nereids_p0/hint/fix_leading.out index 58122945bb6..898fe5882b9 100644 --- a/regression-test/data/nereids_p0/hint/fix_leading.out +++ b/regression-test/data/nereids_p0/hint/fix_leading.out @@ -9,7 +9,7 @@ PhysicalResultSink ----------PhysicalDistribute[DistributionSpecHash] ------------PhysicalOlapScan[t2] --------PhysicalDistribute[DistributionSpecHash] -----------NestedLoopJoin[CROSS_JOIN](t4.c4 = t3.c3)(t3.c3 = t4.c4) +----------NestedLoopJoin[CROSS_JOIN](t3.c3 = t4.c4) ------------PhysicalOlapScan[t3] ------------PhysicalDistribute[DistributionSpecReplicated] --------------PhysicalOlapScan[t4] --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org