Skip to content

Commit

Permalink
[feat][dingo-executor] Support multiple hybridSearch in SQL
Browse files Browse the repository at this point in the history
  • Loading branch information
githubgxll authored and ketor committed Oct 9, 2024
1 parent e2f258a commit 01c3663
Show file tree
Hide file tree
Showing 3 changed files with 87 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
import org.apache.calcite.sql2rel.SqlVectorOperator;
import org.checkerframework.checker.nullness.qual.Nullable;

import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

public class DingoSqlValidator extends SqlValidatorImpl {

@Getter
Expand All @@ -46,6 +49,8 @@ public class DingoSqlValidator extends SqlValidatorImpl {
@Getter
@Setter
private String hybridSearchSql;
@Getter
private Map<SqlBasicCall, String> hybridSearchMap;

static Config CONFIG = Config.DEFAULT
.withConformance(DingoParser.PARSER_CONFIG.conformance());
Expand All @@ -66,6 +71,7 @@ public class DingoSqlValidator extends SqlValidatorImpl {
);
this.hybridSearch = false;
this.hybridSearchSql = "";
this.hybridSearchMap = new ConcurrentHashMap<>();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,9 +162,9 @@ protected RelDataType validateImpl(RelDataType targetRowType) {
);
this.rowType = rowType;

if (((DingoSqlValidator)validator).isHybridSearch()) {
throw new RuntimeException("Multiple hybridSearch in SQL is not supported");
}
// if (((DingoSqlValidator)validator).isHybridSearch()) {
// throw new RuntimeException("Multiple hybridSearch in SQL is not supported");
// }
String sql = HybridSearchSqlUtils.hybridSearchSqlReplace(
vectorWeight,
documentWeight,
Expand All @@ -179,6 +179,7 @@ protected RelDataType validateImpl(RelDataType targetRowType) {
);
((DingoSqlValidator)validator).setHybridSearch(true);
((DingoSqlValidator)validator).setHybridSearchSql(sql);
((DingoSqlValidator)validator).getHybridSearchMap().put(this.function, sql);
return rowType;
} else {
throw new RuntimeException("unsupported operator type.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@
import io.dingodb.common.ExecuteVariables;
import io.dingodb.common.Location;
import io.dingodb.common.ProcessInfo;
import io.dingodb.common.audit.DingoAudit;
import io.dingodb.common.config.DingoConfiguration;
import io.dingodb.common.environment.ExecutionEnvironment;
import io.dingodb.common.audit.DingoAudit;
import io.dingodb.common.log.LogUtils;
import io.dingodb.common.metrics.DingoMetrics;
import io.dingodb.common.profile.CommitProfile;
Expand Down Expand Up @@ -116,6 +116,7 @@
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import javax.annotation.Nonnull;
Expand Down Expand Up @@ -381,23 +382,40 @@ public Meta.Signature parseQuery(
}
if (statementType == Meta.StatementType.SELECT) {
if (((DingoSqlValidator)validator).isHybridSearch()) {
String hybridSearchSql = ((DingoSqlValidator)validator).getHybridSearchSql();
LogUtils.info(log, "HybridSearchSql: {}", hybridSearchSql);
SqlNode hybridSqlNode;
try {
hybridSqlNode = parse(hybridSearchSql);
} catch (SqlParseException e) {
throw ExceptionUtils.toRuntime(e);
}
syntacticSugar(hybridSqlNode);
SqlNode originalSqlNode;
try {
originalSqlNode = parse(sql);
} catch (SqlParseException e) {
throw ExceptionUtils.toRuntime(e);
}
syntacticSugar(originalSqlNode);
lockUpHybridSearchNode(originalSqlNode, hybridSqlNode);
if(((DingoSqlValidator) validator).getHybridSearchMap().size() == 1) {
String hybridSearchSql = ((DingoSqlValidator)validator).getHybridSearchSql();
LogUtils.info(log, "HybridSearchSql: {}", hybridSearchSql);
SqlNode hybridSqlNode;
try {
hybridSqlNode = parse(hybridSearchSql);
} catch (SqlParseException e) {
throw ExceptionUtils.toRuntime(e);
}
syntacticSugar(hybridSqlNode);
lockUpHybridSearchNode(originalSqlNode, hybridSqlNode);
} else {
ConcurrentHashMap<SqlBasicCall, SqlNode> sqlNodeHashMap = new ConcurrentHashMap<>();
for (Map.Entry<SqlBasicCall, String> entry : ((DingoSqlValidator) validator).getHybridSearchMap().entrySet()) {
SqlBasicCall key = entry.getKey();
String value = entry.getValue();
SqlNode hybridSqlNode;
try {
hybridSqlNode = parse(value);
} catch (SqlParseException e) {
throw ExceptionUtils.toRuntime(e);
}
syntacticSugar(hybridSqlNode);
sqlNodeHashMap.put(key, hybridSqlNode);
}
lockUpHybridSearchNode(originalSqlNode, sqlNodeHashMap);
}
LogUtils.info(log, "HybridSearch Rewrite Sql: {}", originalSqlNode.toString());
if (originalSqlNode.getKind().equals(SqlKind.EXPLAIN)) {
assert originalSqlNode instanceof SqlExplain;
Expand Down Expand Up @@ -920,6 +938,53 @@ private void lockUpHybridSearchNode(SqlNode sqlNode, SqlNode subSqlNode) {
}
}

private void lockUpHybridSearchNode(SqlNode sqlNode, ConcurrentHashMap<SqlBasicCall, SqlNode> subSqlNode) {
if (sqlNode instanceof SqlSelect) {
SqlNode from = ((SqlSelect) sqlNode).getFrom();
if (from instanceof SqlBasicCall && (((SqlBasicCall) from).getOperator() instanceof SqlHybridSearchOperator)) {
SqlBasicCall removeKey = null;
for (Map.Entry<SqlBasicCall, SqlNode> entry : subSqlNode.entrySet()) {
SqlBasicCall key = entry.getKey();
SqlNode value = entry.getValue();
if (from.toString().equals(key.toString())) {
((SqlSelect) sqlNode).setFrom(value);
removeKey = key;
break;
}
}
if (removeKey != null) {
subSqlNode.remove(removeKey);
}
} else {
if (from instanceof SqlJoin) {
lockUpHybridSearchNode(((SqlJoin) from).getLeft(), subSqlNode);
lockUpHybridSearchNode(((SqlJoin) from).getRight(), subSqlNode);
}
}
} else if (sqlNode instanceof SqlJoin) {
lockUpHybridSearchNode(((SqlJoin) sqlNode).getLeft(), subSqlNode);
lockUpHybridSearchNode(((SqlJoin) sqlNode).getRight(), subSqlNode);
} else if (sqlNode instanceof SqlBasicCall) {
if (((SqlBasicCall) sqlNode).getOperator() instanceof SqlAsOperator) {
deepLockUpChildren(((SqlBasicCall) sqlNode).getOperandList(), subSqlNode);
}
} else if (sqlNode instanceof SqlOrderBy) {
lockUpHybridSearchNode(((SqlOrderBy)sqlNode).query, subSqlNode);
} else if (sqlNode instanceof SqlExplain) {
lockUpHybridSearchNode(((SqlExplain) sqlNode).getExplicandum(), subSqlNode);
}

}

private void deepLockUpChildren(List<SqlNode> sqlNodes, ConcurrentHashMap<SqlBasicCall, SqlNode> subSqlNode) {
if (sqlNodes == null) {
return;
}
for (int i = 0; i < sqlNodes.size(); i ++) {
lockUpHybridSearchNode(sqlNodes.get(i), subSqlNode);
}
}

private void deepLockUpChildren(List<SqlNode> sqlNodes, SqlNode subSqlNode) {
if (sqlNodes == null) {
return;
Expand All @@ -928,6 +993,7 @@ private void deepLockUpChildren(List<SqlNode> sqlNodes, SqlNode subSqlNode) {
lockUpHybridSearchNode(sqlNodes.get(i), subSqlNode);
}
}

private void syntacticSugar(SqlNode sqlNode) {
if (sqlNode instanceof SqlSelect) {
SqlNodeList sqlNodes = ((SqlSelect) sqlNode).getSelectList();
Expand Down

0 comments on commit 01c3663

Please sign in to comment.