Skip to content

Commit

Permalink
[enhance](Nereids): handle project of OuterJoin in Reorder. (apache#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwener committed Apr 27, 2023
1 parent 0f89564 commit a35fc02
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 121 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.apache.doris.nereids.rules.exploration.join.LogicalJoinSemiJoinTranspose;
import org.apache.doris.nereids.rules.exploration.join.LogicalJoinSemiJoinTransposeProject;
import org.apache.doris.nereids.rules.exploration.join.OuterJoinAssoc;
import org.apache.doris.nereids.rules.exploration.join.OuterJoinAssocProject;
import org.apache.doris.nereids.rules.exploration.join.OuterJoinLAsscom;
import org.apache.doris.nereids.rules.exploration.join.OuterJoinLAsscomProject;
import org.apache.doris.nereids.rules.exploration.join.PushdownProjectThroughInnerJoin;
Expand Down Expand Up @@ -164,6 +165,7 @@ public class RuleSet {
.add(JoinExchange.INSTANCE)
.add(JoinExchangeBothProject.INSTANCE)
.add(OuterJoinAssoc.INSTANCE)
.add(OuterJoinAssocProject.INSTANCE)
.build();

public List<Rule> getOtherReorderRules() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,6 @@
* Common
*/
public class CBOUtils {
/**
* Split project according to whether namedExpr contains by splitChildExprIds.
* Notice: projects must all be Slot.
*/
public static Map<Boolean, List<NamedExpression>> splitProject(List<NamedExpression> projects,
Set<ExprId> splitChildExprIds) {
return projects.stream()
.collect(Collectors.partitioningBy(expr -> {
Slot slot = (Slot) expr;
return splitChildExprIds.contains(slot.getExprId());
}));
}

/**
* If projects is empty or project output equal plan output, return the original plan.
*/
Expand All @@ -58,23 +45,6 @@ public static Plan projectOrSelf(List<NamedExpression> projects, Plan plan) {
return new LogicalProject<>(projects, plan);
}

/**
* When project not empty, we add all slots used by hashOnCondition into projects.
*/
public static void addSlotsUsedByOn(Set<Slot> usedSlots, List<NamedExpression> projects) {
if (projects.isEmpty()) {
return;
}
Set<ExprId> projectExprIdSet = projects.stream()
.map(NamedExpression::getExprId)
.collect(Collectors.toSet());
usedSlots.forEach(slot -> {
if (!projectExprIdSet.contains(slot.getExprId())) {
projects.add(slot);
}
});
}

public static Set<Slot> joinChildConditionSlots(LogicalJoin<? extends Plan, ? extends Plan> join, boolean left) {
Set<Slot> childSlots = left ? join.left().getOutputSet() : join.right().getOutputSet();
return join.getConditionSlot().stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,18 @@
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;

import com.google.common.collect.ImmutableSet;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* OuterJoinAssocProject.
Expand Down Expand Up @@ -68,12 +63,10 @@ public Rule build() {
.thenApply(ctx -> {
LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>, GroupPlan> topJoin = ctx.root;
/* ********** init ********** */
List<NamedExpression> projects = topJoin.left().getProjects();
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left().child();
GroupPlan a = bottomJoin.left();
GroupPlan b = bottomJoin.right();
GroupPlan c = topJoin.right();
Set<ExprId> aOutputExprIds = a.getOutputExprIdSet();

/*
* Paper `On the Correct and Complete Enumeration of the Core Search Space`.
Expand All @@ -92,39 +85,15 @@ public Rule build() {
}
}

/* ********** Split projects ********** */
Map<Boolean, List<NamedExpression>> map = CBOUtils.splitProject(projects, aOutputExprIds);
List<NamedExpression> aProjects = map.get(true);
List<NamedExpression> bProjects = map.get(false);
if (bProjects.isEmpty()) {
return null;
}
Set<ExprId> aProjectsExprIds = aProjects.stream().map(NamedExpression::getExprId)
.collect(Collectors.toSet());

// topJoin condition can't contain aProject. just can (B C)
if (Stream.concat(topJoin.getHashJoinConjuncts().stream(), topJoin.getOtherJoinConjuncts().stream())
.anyMatch(expr -> Utils.isIntersecting(expr.getInputSlotExprIds(), aProjectsExprIds))) {
return null;
}

// Add all slots used by OnCondition when projects not empty.
Map<Boolean, Set<Slot>> abOnUsedSlots = Stream.concat(
bottomJoin.getHashJoinConjuncts().stream(),
bottomJoin.getHashJoinConjuncts().stream())
.flatMap(onExpr -> onExpr.getInputSlots().stream())
.collect(Collectors.partitioningBy(
slot -> aOutputExprIds.contains(slot.getExprId()), Collectors.toSet()));
CBOUtils.addSlotsUsedByOn(abOnUsedSlots.get(true), aProjects);
CBOUtils.addSlotsUsedByOn(abOnUsedSlots.get(false), bProjects);

bProjects.addAll(OuterJoinLAsscomProject.forceToNullable(c.getOutputSet()));
/* ********** new Plan ********** */
LogicalJoin newBottomJoin = topJoin.withChildrenNoContext(b, c);
newBottomJoin.getJoinReorderContext().copyFrom(bottomJoin.getJoinReorderContext());

Plan left = CBOUtils.projectOrSelf(aProjects, a);
Plan right = CBOUtils.projectOrSelf(bProjects, newBottomJoin);
Set<ExprId> topUsedExprIds = new HashSet<>(topJoin.getOutputExprIdSet());
bottomJoin.getHashJoinConjuncts().forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds()));
bottomJoin.getOtherJoinConjuncts().forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds()));
Plan left = CBOUtils.newProject(topUsedExprIds, a);
Plan right = CBOUtils.newProject(topUsedExprIds, newBottomJoin);

LogicalJoin newTopJoin = bottomJoin.withChildrenNoContext(left, right);
newTopJoin.getJoinReorderContext().copyFrom(topJoin.getJoinReorderContext());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ public Rule build() {
* <p>
* Same with OtherJoinConjunct.
*/
private boolean checkCondition(LogicalJoin<? extends Plan, GroupPlan> topJoin, Set<ExprId> bOutputExprIdSet) {
public static boolean checkCondition(LogicalJoin<? extends Plan, GroupPlan> topJoin, Set<ExprId> bOutputExprIdSet) {
return Stream.concat(
topJoin.getHashJoinConjuncts().stream(),
topJoin.getOtherJoinConjuncts().stream())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,13 @@
import org.apache.doris.nereids.rules.exploration.CBOUtils;
import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.util.Utils;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;

/**
* Rule for change inner join LAsscom (associative and commutive).
Expand All @@ -61,52 +54,27 @@ public Rule build() {
.when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left().child()))
.whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint())
.whenNot(join -> join.isMarkJoin() || join.left().child().isMarkJoin())
.when(topJoin -> OuterJoinLAsscom.checkCondition(topJoin,
topJoin.left().child().right().getOutputExprIdSet()))
.when(join -> join.left().isAllSlots())
.then(topJoin -> {
/* ********** init ********** */
List<NamedExpression> projects = topJoin.left().getProjects();
LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left().child();
GroupPlan a = bottomJoin.left();
GroupPlan b = bottomJoin.right();
GroupPlan c = topJoin.right();
Set<ExprId> aOutputExprIds = a.getOutputExprIdSet();

/* ********** Split projects ********** */
Map<Boolean, List<NamedExpression>> map = CBOUtils.splitProject(projects, aOutputExprIds);
List<NamedExpression> aProjects = map.get(true);
if (aProjects.isEmpty()) {
return null;
}
List<NamedExpression> bProjects = map.get(false);
Set<ExprId> bProjectsExprIds = bProjects.stream().map(NamedExpression::getExprId)
.collect(Collectors.toSet());

// topJoin condition can't contain bProject output. just can (A C)
if (Stream.concat(topJoin.getHashJoinConjuncts().stream(), topJoin.getOtherJoinConjuncts().stream())
.anyMatch(expr -> Utils.isIntersecting(expr.getInputSlotExprIds(), bProjectsExprIds))) {
return null;
}

// Add all slots used by OnCondition when projects not empty.
Map<Boolean, Set<Slot>> abOnUsedSlots = Stream.concat(
bottomJoin.getHashJoinConjuncts().stream(),
bottomJoin.getHashJoinConjuncts().stream())
.flatMap(onExpr -> onExpr.getInputSlots().stream())
.collect(Collectors.partitioningBy(
slot -> aOutputExprIds.contains(slot.getExprId()), Collectors.toSet()));
CBOUtils.addSlotsUsedByOn(abOnUsedSlots.get(true), aProjects);
CBOUtils.addSlotsUsedByOn(abOnUsedSlots.get(false), bProjects);

aProjects.addAll(forceToNullable(c.getOutputSet()));

/* ********** new Plan ********** */
LogicalJoin newBottomJoin = topJoin.withChildrenNoContext(a, c);
newBottomJoin.getJoinReorderContext().copyFrom(bottomJoin.getJoinReorderContext());
newBottomJoin.getJoinReorderContext().setHasLAsscom(false);
newBottomJoin.getJoinReorderContext().setHasCommute(false);

Plan left = CBOUtils.projectOrSelf(aProjects, newBottomJoin);
Plan right = CBOUtils.projectOrSelf(bProjects, b);
Set<ExprId> topUsedExprIds = new HashSet<>(topJoin.getOutputExprIdSet());
bottomJoin.getHashJoinConjuncts().forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds()));
bottomJoin.getOtherJoinConjuncts().forEach(e -> topUsedExprIds.addAll(e.getInputSlotExprIds()));
Plan left = CBOUtils.newProject(topUsedExprIds, newBottomJoin);
Plan right = CBOUtils.newProject(topUsedExprIds, b);

LogicalJoin newTopJoin = bottomJoin.withChildrenNoContext(left, right);
newTopJoin.getJoinReorderContext().copyFrom(topJoin.getJoinReorderContext());
Expand All @@ -115,17 +83,4 @@ public Rule build() {
return CBOUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin);
}).toRule(RuleType.LOGICAL_OUTER_JOIN_LASSCOM_PROJECT);
}

/**
* Force all slots in set to nullable.
*/
public static Set<Slot> forceToNullable(Set<Slot> slotSet) {
return slotSet.stream().map(s -> (Slot) s.rewriteUp(e -> {
if (e instanceof SlotReference) {
return ((SlotReference) e).withNullable(true);
} else {
return e;
}
})).collect(Collectors.toSet());
}
}

0 comments on commit a35fc02

Please sign in to comment.