From 01c36637f3015693fbb08380f3d741b67026ef70 Mon Sep 17 00:00:00 2001 From: githubgxll <1094462054@qq.com> Date: Wed, 9 Oct 2024 17:25:34 +0800 Subject: [PATCH] [feat][dingo-executor] Support multiple hybridSearch in SQL --- .../io/dingodb/calcite/DingoSqlValidator.java | 6 ++ .../TableHybridFunctionNamespace.java | 7 +- .../io/dingodb/driver/DingoDriverParser.java | 88 ++++++++++++++++--- 3 files changed, 87 insertions(+), 14 deletions(-) diff --git a/dingo-calcite/src/main/java/io/dingodb/calcite/DingoSqlValidator.java b/dingo-calcite/src/main/java/io/dingodb/calcite/DingoSqlValidator.java index 15d420a42..b76b1ade8 100644 --- a/dingo-calcite/src/main/java/io/dingodb/calcite/DingoSqlValidator.java +++ b/dingo-calcite/src/main/java/io/dingodb/calcite/DingoSqlValidator.java @@ -38,6 +38,9 @@ import org.apache.calcite.sql2rel.SqlVectorOperator; import org.checkerframework.checker.nullness.qual.Nullable; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + public class DingoSqlValidator extends SqlValidatorImpl { @Getter @@ -46,6 +49,8 @@ public class DingoSqlValidator extends SqlValidatorImpl { @Getter @Setter private String hybridSearchSql; + @Getter + private Map hybridSearchMap; static Config CONFIG = Config.DEFAULT .withConformance(DingoParser.PARSER_CONFIG.conformance()); @@ -66,6 +71,7 @@ public class DingoSqlValidator extends SqlValidatorImpl { ); this.hybridSearch = false; this.hybridSearchSql = ""; + this.hybridSearchMap = new ConcurrentHashMap<>(); } @Override diff --git a/dingo-calcite/src/main/java/org/apache/calcite/sql/validate/TableHybridFunctionNamespace.java b/dingo-calcite/src/main/java/org/apache/calcite/sql/validate/TableHybridFunctionNamespace.java index ff2430348..39ce5935c 100644 --- a/dingo-calcite/src/main/java/org/apache/calcite/sql/validate/TableHybridFunctionNamespace.java +++ b/dingo-calcite/src/main/java/org/apache/calcite/sql/validate/TableHybridFunctionNamespace.java @@ -162,9 +162,9 @@ protected RelDataType validateImpl(RelDataType targetRowType) { ); this.rowType = rowType; - if (((DingoSqlValidator)validator).isHybridSearch()) { - throw new RuntimeException("Multiple hybridSearch in SQL is not supported"); - } +// if (((DingoSqlValidator)validator).isHybridSearch()) { +// throw new RuntimeException("Multiple hybridSearch in SQL is not supported"); +// } String sql = HybridSearchSqlUtils.hybridSearchSqlReplace( vectorWeight, documentWeight, @@ -179,6 +179,7 @@ protected RelDataType validateImpl(RelDataType targetRowType) { ); ((DingoSqlValidator)validator).setHybridSearch(true); ((DingoSqlValidator)validator).setHybridSearchSql(sql); + ((DingoSqlValidator)validator).getHybridSearchMap().put(this.function, sql); return rowType; } else { throw new RuntimeException("unsupported operator type."); diff --git a/dingo-driver/host/src/main/java/io/dingodb/driver/DingoDriverParser.java b/dingo-driver/host/src/main/java/io/dingodb/driver/DingoDriverParser.java index f58c2ae40..cd2098945 100644 --- a/dingo-driver/host/src/main/java/io/dingodb/driver/DingoDriverParser.java +++ b/dingo-driver/host/src/main/java/io/dingodb/driver/DingoDriverParser.java @@ -44,9 +44,9 @@ import io.dingodb.common.ExecuteVariables; import io.dingodb.common.Location; import io.dingodb.common.ProcessInfo; +import io.dingodb.common.audit.DingoAudit; import io.dingodb.common.config.DingoConfiguration; import io.dingodb.common.environment.ExecutionEnvironment; -import io.dingodb.common.audit.DingoAudit; import io.dingodb.common.log.LogUtils; import io.dingodb.common.metrics.DingoMetrics; import io.dingodb.common.profile.CommitProfile; @@ -116,6 +116,7 @@ import java.util.Map; import java.util.Objects; import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import javax.annotation.Nonnull; @@ -381,15 +382,6 @@ public Meta.Signature parseQuery( } if (statementType == Meta.StatementType.SELECT) { if (((DingoSqlValidator)validator).isHybridSearch()) { - String hybridSearchSql = ((DingoSqlValidator)validator).getHybridSearchSql(); - LogUtils.info(log, "HybridSearchSql: {}", hybridSearchSql); - SqlNode hybridSqlNode; - try { - hybridSqlNode = parse(hybridSearchSql); - } catch (SqlParseException e) { - throw ExceptionUtils.toRuntime(e); - } - syntacticSugar(hybridSqlNode); SqlNode originalSqlNode; try { originalSqlNode = parse(sql); @@ -397,7 +389,33 @@ public Meta.Signature parseQuery( throw ExceptionUtils.toRuntime(e); } syntacticSugar(originalSqlNode); - lockUpHybridSearchNode(originalSqlNode, hybridSqlNode); + if(((DingoSqlValidator) validator).getHybridSearchMap().size() == 1) { + String hybridSearchSql = ((DingoSqlValidator)validator).getHybridSearchSql(); + LogUtils.info(log, "HybridSearchSql: {}", hybridSearchSql); + SqlNode hybridSqlNode; + try { + hybridSqlNode = parse(hybridSearchSql); + } catch (SqlParseException e) { + throw ExceptionUtils.toRuntime(e); + } + syntacticSugar(hybridSqlNode); + lockUpHybridSearchNode(originalSqlNode, hybridSqlNode); + } else { + ConcurrentHashMap sqlNodeHashMap = new ConcurrentHashMap<>(); + for (Map.Entry entry : ((DingoSqlValidator) validator).getHybridSearchMap().entrySet()) { + SqlBasicCall key = entry.getKey(); + String value = entry.getValue(); + SqlNode hybridSqlNode; + try { + hybridSqlNode = parse(value); + } catch (SqlParseException e) { + throw ExceptionUtils.toRuntime(e); + } + syntacticSugar(hybridSqlNode); + sqlNodeHashMap.put(key, hybridSqlNode); + } + lockUpHybridSearchNode(originalSqlNode, sqlNodeHashMap); + } LogUtils.info(log, "HybridSearch Rewrite Sql: {}", originalSqlNode.toString()); if (originalSqlNode.getKind().equals(SqlKind.EXPLAIN)) { assert originalSqlNode instanceof SqlExplain; @@ -920,6 +938,53 @@ private void lockUpHybridSearchNode(SqlNode sqlNode, SqlNode subSqlNode) { } } + private void lockUpHybridSearchNode(SqlNode sqlNode, ConcurrentHashMap subSqlNode) { + if (sqlNode instanceof SqlSelect) { + SqlNode from = ((SqlSelect) sqlNode).getFrom(); + if (from instanceof SqlBasicCall && (((SqlBasicCall) from).getOperator() instanceof SqlHybridSearchOperator)) { + SqlBasicCall removeKey = null; + for (Map.Entry entry : subSqlNode.entrySet()) { + SqlBasicCall key = entry.getKey(); + SqlNode value = entry.getValue(); + if (from.toString().equals(key.toString())) { + ((SqlSelect) sqlNode).setFrom(value); + removeKey = key; + break; + } + } + if (removeKey != null) { + subSqlNode.remove(removeKey); + } + } else { + if (from instanceof SqlJoin) { + lockUpHybridSearchNode(((SqlJoin) from).getLeft(), subSqlNode); + lockUpHybridSearchNode(((SqlJoin) from).getRight(), subSqlNode); + } + } + } else if (sqlNode instanceof SqlJoin) { + lockUpHybridSearchNode(((SqlJoin) sqlNode).getLeft(), subSqlNode); + lockUpHybridSearchNode(((SqlJoin) sqlNode).getRight(), subSqlNode); + } else if (sqlNode instanceof SqlBasicCall) { + if (((SqlBasicCall) sqlNode).getOperator() instanceof SqlAsOperator) { + deepLockUpChildren(((SqlBasicCall) sqlNode).getOperandList(), subSqlNode); + } + } else if (sqlNode instanceof SqlOrderBy) { + lockUpHybridSearchNode(((SqlOrderBy)sqlNode).query, subSqlNode); + } else if (sqlNode instanceof SqlExplain) { + lockUpHybridSearchNode(((SqlExplain) sqlNode).getExplicandum(), subSqlNode); + } + + } + + private void deepLockUpChildren(List sqlNodes, ConcurrentHashMap subSqlNode) { + if (sqlNodes == null) { + return; + } + for (int i = 0; i < sqlNodes.size(); i ++) { + lockUpHybridSearchNode(sqlNodes.get(i), subSqlNode); + } + } + private void deepLockUpChildren(List sqlNodes, SqlNode subSqlNode) { if (sqlNodes == null) { return; @@ -928,6 +993,7 @@ private void deepLockUpChildren(List sqlNodes, SqlNode subSqlNode) { lockUpHybridSearchNode(sqlNodes.get(i), subSqlNode); } } + private void syntacticSugar(SqlNode sqlNode) { if (sqlNode instanceof SqlSelect) { SqlNodeList sqlNodes = ((SqlSelect) sqlNode).getSelectList();