Skip to content

Commit

Permalink
Merge pull request #4035 from Hccake/3.0
Browse files Browse the repository at this point in the history
🐛 fix gitee I4FP6E, right join bug
  • Loading branch information
qmdx committed Dec 24, 2021
2 parents 31ff5b5 + 5ba2470 commit 2884cd2
Show file tree
Hide file tree
Showing 2 changed files with 342 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,10 @@

import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.LinkedList;
import java.util.List;
Expand Down Expand Up @@ -234,27 +237,55 @@ protected void appendSelectItem(List<SelectItem> selectItems) {
* 处理 PlainSelect
*/
protected void processPlainSelect(PlainSelect plainSelect) {
FromItem fromItem = plainSelect.getFromItem();
Expression where = plainSelect.getWhere();
processWhereSubSelect(where);
if (fromItem instanceof Table) {
Table fromTable = (Table) fromItem;
if (!tenantLineHandler.ignoreTable(fromTable.getName())) {
//#1186 github
plainSelect.setWhere(builderExpression(where, fromTable));
}
} else {
processFromItem(fromItem);
}
//#3087 github
List<SelectItem> selectItems = plainSelect.getSelectItems();
if (CollectionUtils.isNotEmpty(selectItems)) {
selectItems.forEach(this::processSelectItem);
}

// 处理 where 中的子查询
Expression where = plainSelect.getWhere();
processWhereSubSelect(where);

// 处理 fromItem
FromItem fromItem = plainSelect.getFromItem();
List<Table> list = processFromItem(fromItem);
List<Table> mainTables = new ArrayList<>(list);

// 处理 join
List<Join> joins = plainSelect.getJoins();
if (CollectionUtils.isNotEmpty(joins)) {
processJoins(joins);
mainTables = processJoins(mainTables, joins);
}

// 当有 mainTable 时,进行 where 条件追加
if (CollectionUtils.isNotEmpty(mainTables)) {
plainSelect.setWhere(builderExpression(where, mainTables));
}
}

private List<Table> processFromItem(FromItem fromItem) {
// 处理括号括起来的表达式
while (fromItem instanceof ParenthesisFromItem) {
fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
}

List<Table> mainTables = new ArrayList<>();
// 无 join 时的处理逻辑
if (fromItem instanceof Table) {
Table fromTable = (Table) fromItem;
if (!tenantLineHandler.ignoreTable(fromTable.getName())) {
mainTables.add(fromTable);
}
} else if (fromItem instanceof SubJoin) {
// SubJoin 类型则还需要添加上 where 条件
List<Table> tables = processSubJoin((SubJoin) fromItem);
mainTables.addAll(tables);
} else {
// 处理下 fromItem
processOtherFromItem(fromItem);
}
return mainTables;
}

/**
Expand Down Expand Up @@ -282,7 +313,7 @@ protected void processWhereSubSelect(Expression where) {
return;
}
if (where instanceof FromItem) {
processFromItem((FromItem) where);
processOtherFromItem((FromItem) where);
return;
}
if (where.toString().indexOf("SELECT") > 0) {
Expand Down Expand Up @@ -348,16 +379,13 @@ protected void processFunction(Function function) {
/**
* 处理子查询等
*/
protected void processFromItem(FromItem fromItem) {
if (fromItem instanceof SubJoin) {
SubJoin subJoin = (SubJoin) fromItem;
if (subJoin.getJoinList() != null) {
processJoins(subJoin.getJoinList());
}
if (subJoin.getLeft() != null) {
processFromItem(subJoin.getLeft());
}
} else if (fromItem instanceof SubSelect) {
protected void processOtherFromItem(FromItem fromItem) {
// 去除括号
while (fromItem instanceof ParenthesisFromItem) {
fromItem = ((ParenthesisFromItem) fromItem).getFromItem();
}

if (fromItem instanceof SubSelect) {
SubSelect subSelect = (SubSelect) fromItem;
if (subSelect.getSelectBody() != null) {
processSelectBody(subSelect.getSelectBody());
Expand All @@ -375,82 +403,160 @@ protected void processFromItem(FromItem fromItem) {
}
}

/**
* 处理 sub join
*
* @param subJoin subJoin
* @return Table subJoin 中的主表
*/
private List<Table> processSubJoin(SubJoin subJoin) {
List<Table> mainTables = new ArrayList<>();
if (subJoin.getJoinList() != null) {
List<Table> list = processFromItem(subJoin.getLeft());
mainTables.addAll(list);
mainTables = processJoins(mainTables, subJoin.getJoinList());
}
return mainTables;
}

/**
* 处理 joins
*
* @param joins join 集合
* @param mainTables 可以为 null
* @param joins join 集合
* @return List<Table> 右连接查询的 Table 列表
*/
private void processJoins(List<Join> joins) {
private List<Table> processJoins(List<Table> mainTables, List<Join> joins) {
if (mainTables == null) {
mainTables = new ArrayList<>();
}

// join 表达式中最终的主表
Table mainTable = null;
// 当前 join 的左表
Table leftTable = null;
if (mainTables.size() == 1) {
mainTable = mainTables.get(0);
leftTable = mainTable;
}

//对于 on 表达式写在最后的 join,需要记录下前面多个 on 的表名
Deque<Table> tables = new LinkedList<>();
Deque<List<Table>> onTableDeque = new LinkedList<>();
for (Join join : joins) {
// 处理 on 表达式
FromItem fromItem = join.getRightItem();
if (fromItem instanceof Table) {
Table fromTable = (Table) fromItem;
FromItem joinItem = join.getRightItem();

// 获取当前 join 的表,subJoint 可以看作是一张表
List<Table> joinTables = null;
if (joinItem instanceof Table) {
joinTables = new ArrayList<>();
joinTables.add((Table) joinItem);
} else if (joinItem instanceof SubJoin) {
joinTables = processSubJoin((SubJoin) joinItem);
}

if (joinTables != null) {

// 如果是隐式内连接
if (join.isSimple()) {
mainTables.addAll(joinTables);
continue;
}

// 当前表是否忽略
Table joinTable = joinTables.get(0);
boolean joinTableNeedIgnore = tenantLineHandler.ignoreTable(joinTable.getName());

List<Table> onTables = null;
// 如果不要忽略,且是右连接,则记录下当前表
if (join.isRight()) {
mainTable = joinTableNeedIgnore ? null : joinTable;
if (leftTable != null) {
onTables = Collections.singletonList(leftTable);
}
} else if (join.isLeft()) {
if (!joinTableNeedIgnore) {
onTables = Collections.singletonList(joinTable);
}
} else if (join.isInner()) {
if (mainTable == null) {
onTables = Collections.singletonList(joinTable);
} else {
onTables = Arrays.asList(mainTable, joinTable);
}
mainTable = null;
}
mainTables = new ArrayList<>();
if (mainTable != null) {
mainTables.add(mainTable);
}

// 获取 join 尾缀的 on 表达式列表
Collection<Expression> originOnExpressions = join.getOnExpressions();
// 正常 join on 表达式只有一个,立刻处理
if (originOnExpressions.size() == 1) {
processJoin(join);
if (originOnExpressions.size() == 1 && onTables != null) {
List<Expression> onExpressions = new LinkedList<>();
onExpressions.add(builderExpression(originOnExpressions.iterator().next(), onTables));
join.setOnExpressions(onExpressions);
leftTable = joinTable;
continue;
}
// 当前表是否忽略
boolean needIgnore = tenantLineHandler.ignoreTable(fromTable.getName());
// 表名压栈,忽略的表压入 null,以便后续不处理
tables.push(needIgnore ? null : fromTable);
onTableDeque.push(onTables);
// 尾缀多个 on 表达式的时候统一处理
if (originOnExpressions.size() > 1) {
Collection<Expression> onExpressions = new LinkedList<>();
for (Expression originOnExpression : originOnExpressions) {
Table currentTable = tables.poll();
if (currentTable == null) {
List<Table> currentTableList = onTableDeque.poll();
if (CollectionUtils.isEmpty(currentTableList)) {
onExpressions.add(originOnExpression);
} else {
onExpressions.add(builderExpression(originOnExpression, currentTable));
onExpressions.add(builderExpression(originOnExpression, currentTableList));
}
}
join.setOnExpressions(onExpressions);
}
leftTable = joinTable;
} else {
// 处理右边连接的子表达式
processFromItem(fromItem);
processOtherFromItem(joinItem);
leftTable = null;
}

}

return mainTables;
}

/**
* 处理联接语句
* 处理条件
*/
protected void processJoin(Join join) {
if (join.getRightItem() instanceof Table) {
Table fromTable = (Table) join.getRightItem();
if (tenantLineHandler.ignoreTable(fromTable.getName())) {
// 过滤退出执行
return;
protected Expression builderExpression(Expression currentExpression, List<Table> tables) {
// 没有表需要处理直接返回
if (CollectionUtils.isEmpty(tables)) {
return currentExpression;
}
// 租户
Expression tenantId = tenantLineHandler.getTenantId();
// 构造每张表的条件
List<EqualsTo> equalsTos = tables.stream()
.map(item -> new EqualsTo(getAliasColumn(item), tenantId))
.collect(Collectors.toList());
// 注入的表达式
Expression injectExpression = equalsTos.get(0);
// 如果有多表,则用 and 连接
if (equalsTos.size() > 1) {
for (int i = 1; i < equalsTos.size(); i++) {
injectExpression = new AndExpression(injectExpression, equalsTos.get(i));
}
// 走到这里说明 on 表达式肯定只有一个
Collection<Expression> originOnExpressions = join.getOnExpressions();
List<Expression> onExpressions = new LinkedList<>();
onExpressions.add(builderExpression(originOnExpressions.iterator().next(), fromTable));
join.setOnExpressions(onExpressions);
}
}

/**
* 处理条件
*/
protected Expression builderExpression(Expression currentExpression, Table table) {
EqualsTo equalsTo = new EqualsTo();
equalsTo.setLeftExpression(this.getAliasColumn(table));
equalsTo.setRightExpression(tenantLineHandler.getTenantId());
if (currentExpression == null) {
return equalsTo;
return injectExpression;
}
if (currentExpression instanceof OrExpression) {
return new AndExpression(new Parenthesis(currentExpression), equalsTo);
return new AndExpression(new Parenthesis(currentExpression), injectExpression);
} else {
return new AndExpression(currentExpression, equalsTo);
return new AndExpression(currentExpression, injectExpression);
}
}

Expand All @@ -463,10 +569,13 @@ protected Expression builderExpression(Expression currentExpression, Table table
*/
protected Column getAliasColumn(Table table) {
StringBuilder column = new StringBuilder();
// 为了兼容隐式内连接,没有别名时条件就需要加上表名
if (table.getAlias() != null) {
column.append(table.getAlias().getName()).append(StringPool.DOT);
column.append(table.getAlias().getName());
} else {
column.append(table.getName());
}
column.append(tenantLineHandler.getTenantIdColumn());
column.append(StringPool.DOT).append(tenantLineHandler.getTenantIdColumn());
return new Column(column.toString());
}

Expand Down
Loading

0 comments on commit 2884cd2

Please sign in to comment.