Skip to content

Commit

Permalink
🐛 fix gitee I4FP6E, right join bug
Browse files Browse the repository at this point in the history
  • Loading branch information
Hccake committed Nov 9, 2021
1 parent 8ebc873 commit b350147
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@

import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Deque;
import java.util.LinkedList;
Expand Down Expand Up @@ -235,26 +236,34 @@ protected void appendSelectItem(List<SelectItem> selectItems) {
*/
protected void processPlainSelect(PlainSelect plainSelect) {
FromItem fromItem = plainSelect.getFromItem();

//#3087 github
List<SelectItem> selectItems = plainSelect.getSelectItems();
if (CollectionUtils.isNotEmpty(selectItems)) {
selectItems.forEach(this::processSelectItem);
}

// #I4FP6E gitee:右连接查询时,where 条件需要过滤
List<Table> rightJointTables;
List<Join> joins = plainSelect.getJoins();
if (CollectionUtils.isNotEmpty(joins)) {
rightJointTables = processJoins(joins);
}else {
rightJointTables = new ArrayList<>();
}

Expression where = plainSelect.getWhere();
processWhereSubSelect(where);
if (fromItem instanceof Table) {
Table fromTable = (Table) fromItem;
if (!tenantLineHandler.ignoreTable(fromTable.getName())) {
//#1186 github
plainSelect.setWhere(builderExpression(where, fromTable));
plainSelect.setWhere(builderExpression(where, fromTable, rightJointTables));
}
} else {
processFromItem(fromItem);
}
//#3087 github
List<SelectItem> selectItems = plainSelect.getSelectItems();
if (CollectionUtils.isNotEmpty(selectItems)) {
selectItems.forEach(this::processSelectItem);
}
List<Join> joins = plainSelect.getJoins();
if (CollectionUtils.isNotEmpty(joins)) {
processJoins(joins);
}

}

/**
Expand Down Expand Up @@ -379,8 +388,12 @@ protected void processFromItem(FromItem fromItem) {
* 处理 joins
*
* @param joins join 集合
* @return List<Table> 右连接查询的 Table 列表
*/
private void processJoins(List<Join> joins) {
private List<Table> processJoins(List<Join> joins) {

List<Table> rightJointTables = new ArrayList<>();

//对于 on 表达式写在最后的 join,需要记录下前面多个 on 的表名
Deque<Table> tables = new LinkedList<>();
for (Join join : joins) {
Expand All @@ -390,13 +403,20 @@ private void processJoins(List<Join> joins) {
Table fromTable = (Table) fromItem;
// 获取 join 尾缀的 on 表达式列表
Collection<Expression> originOnExpressions = join.getOnExpressions();

// 当前表是否忽略
boolean needIgnore = tenantLineHandler.ignoreTable(fromTable.getName());
// 如果不要忽略,且是右连接,则记录下当前表
if (!needIgnore && join.isRight()) {
rightJointTables.add(fromTable);
}

// 正常 join on 表达式只有一个,立刻处理
if (originOnExpressions.size() == 1) {
processJoin(join);
continue;
}
// 当前表是否忽略
boolean needIgnore = tenantLineHandler.ignoreTable(fromTable.getName());

// 表名压栈,忽略的表压入 null,以便后续不处理
tables.push(needIgnore ? null : fromTable);
// 尾缀多个 on 表达式的时候统一处理
Expand All @@ -417,6 +437,8 @@ private void processJoins(List<Join> joins) {
processFromItem(fromItem);
}
}

return rightJointTables;
}

/**
Expand All @@ -441,16 +463,30 @@ protected void processJoin(Join join) {
* 处理条件
*/
protected Expression builderExpression(Expression currentExpression, Table table) {
EqualsTo equalsTo = new EqualsTo();
equalsTo.setLeftExpression(this.getAliasColumn(table));
equalsTo.setRightExpression(tenantLineHandler.getTenantId());
return builderExpression(currentExpression, table, new ArrayList<>());
}

/**
* 处理条件
*/
protected Expression builderExpression(Expression currentExpression, Table table, List<Table> rightJointTables) {
// 租户
Expression tenantId = tenantLineHandler.getTenantId();
// 注入的表达式
Expression injectExpression = new EqualsTo(getAliasColumn(table), tenantId);
// 如果有右连接的主表,则添加对应主表的 where 条件
for (Table rightJointTable : rightJointTables) {
EqualsTo rightExpression = new EqualsTo(getAliasColumn(rightJointTable), tenantId);
injectExpression = new AndExpression(injectExpression, rightExpression);
}

if (currentExpression == null) {
return equalsTo;
return injectExpression;
}
if (currentExpression instanceof OrExpression) {
return new AndExpression(new Parenthesis(currentExpression), equalsTo);
return new AndExpression(new Parenthesis(currentExpression), injectExpression);
} else {
return new AndExpression(currentExpression, equalsTo);
return new AndExpression(currentExpression, injectExpression);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,14 +187,22 @@ void selectRightJoin() {
"right join entity1 e1 on e1.id = e.id",
"SELECT * FROM entity e " +
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"WHERE e.tenant_id = 1");
"WHERE e.tenant_id = 1 AND e1.tenant_id = 1");

assertSql("SELECT * FROM entity e " +
"right join entity1 e1 on e1.id = e.id " +
"WHERE e.id = ? OR e.name = ?",
"SELECT * FROM entity e " +
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1");
"WHERE (e.id = ? OR e.name = ?) AND e.tenant_id = 1 AND e1.tenant_id = 1");

assertSql("SELECT * FROM entity e " +
"right join entity1 e1 on e1.id = e.id " +
"right join entity2 e2 on e1.id = e2.id ",
"SELECT * FROM entity e " +
"RIGHT JOIN entity1 e1 ON e1.id = e.id AND e1.tenant_id = 1 " +
"RIGHT JOIN entity2 e2 ON e1.id = e2.id AND e2.tenant_id = 1 " +
"WHERE e.tenant_id = 1 AND e1.tenant_id = 1 AND e2.tenant_id = 1");
}

@Test
Expand Down

0 comments on commit b350147

Please sign in to comment.