Extended QPT to athena-dynamodb (#1819)
Co-authored-by: AbdulRehman Faraj <[email protected]>
AbdulR3hman and AbdulRehman Faraj committed Mar 22, 2024
1 parent 9965f05 commit 2007aaf
Showing 6 changed files with 229 additions and 19 deletions.
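
This commit extends Athena's Query Passthrough (QPT) to the DynamoDB connector: the metadata handler now advertises the passthrough capability and infers a schema by peeking at a few sample results, while the record handler executes the raw PartiQL statement directly via ExecuteStatement. As a rough usage sketch, hypothetical names throughout (the ddb_catalog catalog, primary workgroup, and s3://my-athena-results/ output location are assumptions, and the table-function form follows the system.query(QUERY => ...) signature defined in DDBQueryPassthrough below), a caller could exercise the new passthrough from the AWS SDK for Java v2 like this:

import software.amazon.awssdk.services.athena.AthenaClient;
import software.amazon.awssdk.services.athena.model.ResultConfiguration;
import software.amazon.awssdk.services.athena.model.StartQueryExecutionRequest;

public class DdbQptInvocationSketch
{
    public static void main(String[] args)
    {
        AthenaClient athena = AthenaClient.create();

        // The raw PartiQL statement is handed to the connector verbatim through the
        // QUERY argument of the system.query table function.
        String sql = "SELECT * FROM TABLE(ddb_catalog.system.query("
                + "QUERY => 'SELECT * FROM \"my_ddb_table\" WHERE pk = ''42'''))";

        String executionId = athena.startQueryExecution(StartQueryExecutionRequest.builder()
                .queryString(sql)
                .workGroup("primary") // placeholder workgroup
                .resultConfiguration(ResultConfiguration.builder()
                        .outputLocation("s3://my-athena-results/") // placeholder bucket
                        .build())
                .build())
                .queryExecutionId();
        System.out.println("Started query execution: " + executionId);
    }
}
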
1 change: 1 addition & 0 deletions athena-dynamodb/athena-dynamodb.yaml
@@ -105,6 +105,7 @@ Resources:
- dynamodb:ListTables
- dynamodb:Query
- dynamodb:Scan
- dynamodb:PartiQLSelect
- glue:GetTableVersions
- glue:GetPartitions
- glue:GetTables
athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBMetadataHandler.java
@@ -30,6 +30,8 @@
import com.amazonaws.athena.connector.lambda.domain.predicate.ValueSet;
import com.amazonaws.athena.connector.lambda.domain.spill.SpillLocation;
import com.amazonaws.athena.connector.lambda.handlers.GlueMetadataHandler;
import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesRequest;
import com.amazonaws.athena.connector.lambda.metadata.GetDataSourceCapabilitiesResponse;
import com.amazonaws.athena.connector.lambda.metadata.GetSplitsRequest;
import com.amazonaws.athena.connector.lambda.metadata.GetSplitsResponse;
import com.amazonaws.athena.connector.lambda.metadata.GetTableLayoutRequest;
@@ -40,12 +42,14 @@
import com.amazonaws.athena.connector.lambda.metadata.ListTablesRequest;
import com.amazonaws.athena.connector.lambda.metadata.ListTablesResponse;
import com.amazonaws.athena.connector.lambda.metadata.glue.GlueFieldLexer;
import com.amazonaws.athena.connector.lambda.metadata.optimizations.OptimizationSubType;
import com.amazonaws.athena.connector.lambda.security.EncryptionKeyFactory;
import com.amazonaws.athena.connectors.dynamodb.constants.DynamoDBConstants;
import com.amazonaws.athena.connectors.dynamodb.credentials.CrossAccountCredentialsProviderV2;
import com.amazonaws.athena.connectors.dynamodb.model.DynamoDBIndex;
import com.amazonaws.athena.connectors.dynamodb.model.DynamoDBPaginatedTables;
import com.amazonaws.athena.connectors.dynamodb.model.DynamoDBTable;
import com.amazonaws.athena.connectors.dynamodb.qpt.DDBQueryPassthrough;
import com.amazonaws.athena.connectors.dynamodb.resolver.DynamoDBTableResolver;
import com.amazonaws.athena.connectors.dynamodb.util.DDBPredicateUtils;
import com.amazonaws.athena.connectors.dynamodb.util.DDBRecordMetadata;
@@ -59,6 +63,7 @@
import com.amazonaws.services.secretsmanager.AWSSecretsManager;
import com.amazonaws.util.json.Jackson;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableMap;
import org.apache.arrow.vector.complex.reader.FieldReader;
import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.Field;
@@ -68,6 +73,8 @@
import software.amazon.awssdk.enhanced.dynamodb.document.EnhancedDocument;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.ExecuteStatementRequest;
import software.amazon.awssdk.services.dynamodb.model.ExecuteStatementResponse;

import java.util.ArrayList;
import java.util.Collections;
@@ -98,6 +105,7 @@
import static com.amazonaws.athena.connectors.dynamodb.constants.DynamoDBConstants.SEGMENT_ID_PROPERTY;
import static com.amazonaws.athena.connectors.dynamodb.constants.DynamoDBConstants.TABLE_METADATA;
import static com.amazonaws.athena.connectors.dynamodb.throttling.DynamoDBExceptionFilter.EXCEPTION_FILTER;
import static com.amazonaws.athena.connectors.dynamodb.util.DDBTableUtils.SCHEMA_INFERENCE_NUM_RECORDS;

/**
* Handles metadata requests for the Athena DynamoDB Connector.
@@ -134,6 +142,8 @@ public class DynamoDBMetadataHandler
private final AWSGlue glueClient;
private final DynamoDBTableResolver tableResolver;

private final DDBQueryPassthrough queryPassthrough;

public DynamoDBMetadataHandler(java.util.Map<String, String> configOptions)
{
super(SOURCE_TYPE, configOptions);
@@ -143,6 +153,7 @@ public DynamoDBMetadataHandler(java.util.Map<String, String> configOptions)
this.glueClient = getAwsGlue();
this.invoker = ThrottlingInvoker.newDefaultBuilder(EXCEPTION_FILTER, configOptions).build();
this.tableResolver = new DynamoDBTableResolver(invoker, ddbClient);
this.queryPassthrough = new DDBQueryPassthrough();
}

@VisibleForTesting
@@ -161,6 +172,16 @@ public DynamoDBMetadataHandler(java.util.Map<String, String> configOptions)
this.ddbClient = ddbClient;
this.invoker = ThrottlingInvoker.newDefaultBuilder(EXCEPTION_FILTER, configOptions).build();
this.tableResolver = new DynamoDBTableResolver(invoker, ddbClient);
this.queryPassthrough = new DDBQueryPassthrough();
}

@Override
public GetDataSourceCapabilitiesResponse doGetDataSourceCapabilities(BlockAllocator allocator, GetDataSourceCapabilitiesRequest request)
{
ImmutableMap.Builder<String, List<OptimizationSubType>> capabilities = ImmutableMap.builder();
this.queryPassthrough.addQueryPassthroughCapabilityIfEnabled(capabilities, this.configOptions);

return new GetDataSourceCapabilitiesResponse(request.getCatalogName(), capabilities.build());
}

/**
Expand Down Expand Up @@ -230,6 +251,27 @@ public ListTablesResponse doListTables(BlockAllocator allocator, ListTablesReque
return new ListTablesResponse(request.getCatalogName(), new ArrayList<>(combinedTables), token);
}

@Override
public GetTableResponse doGetQueryPassthroughSchema(BlockAllocator allocator, GetTableRequest request) throws Exception
{
if (!request.isQueryPassthrough()) {
throw new IllegalArgumentException("No query passed through [" + request + "]");
}

queryPassthrough.verify(request.getQueryPassthroughArguments());
String partiQLStatement = request.getQueryPassthroughArguments().get(DDBQueryPassthrough.QUERY);
ExecuteStatementRequest executeStatementRequest =
ExecuteStatementRequest.builder()
.statement(partiQLStatement)
.limit(SCHEMA_INFERENCE_NUM_RECORDS)
.build();
// PartiQL on DynamoDB doesn't allow a dry run, so we "peek" at the first few records to infer the schema.
ExecuteStatementResponse response = ddbClient.executeStatement(executeStatementRequest);
SchemaBuilder schemaBuilder = DDBTableUtils.buildSchemaFromItems(response.items());

return new GetTableResponse(request.getCatalogName(), request.getTableName(), schemaBuilder.build(), Collections.emptySet());
}

/**
* Fetches a table's schema from Glue DataCatalog if present and not disabled, otherwise falls
* back to doing a small table scan and deriving a schema from that.
@@ -268,6 +310,10 @@ public GetTableResponse doGetTable(BlockAllocator allocator, GetTableRequest req
@Override
public void enhancePartitionSchema(SchemaBuilder partitionSchemaBuilder, GetTableLayoutRequest request)
{
if (request.getTableName().getQualifiedTableName().equalsIgnoreCase(queryPassthrough.getFunctionSignature())) {
// Query passthrough does not support partitions
return;
}
// use the source table name from the schema if available (in case Glue table name != actual table name)
String tableName = getSourceTableName(request.getSchema());
if (tableName == null) {
@@ -414,6 +460,11 @@ private void precomputeAdditionalMetadata(Set<String> columnsToIgnore, Map<Strin
@Override
public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest request)
{
if (request.getConstraints().isQueryPassThrough()) {
logger.info("QPT Split Requested");
return setupQueryPassthroughSplit(request);
}

int partitionContd = decodeContinuationToken(request);
Set<Split> splits = new HashSet<>();
Block partitions = request.getPartitions();
@@ -509,4 +560,21 @@ private String encodeContinuationToken(int partition)
{
return String.valueOf(partition);
}

/**
* Helper function that provides a single split for Query Pass-Through.
*/
private GetSplitsResponse setupQueryPassthroughSplit(GetSplitsRequest request)
{
// Every split must have a unique spill location to avoid failures when spilling.
SpillLocation spillLocation = makeSpillLocation(request);

// Since this is a QPT query, we return a single fixed split.
Map<String, String> qptArguments = request.getConstraints().getQueryPassthroughArguments();
return new GetSplitsResponse(request.getCatalogName(),
Split.newBuilder(spillLocation, makeEncryptionKey())
.applyProperties(qptArguments)
.build());
}
}
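
doGetQueryPassthroughSchema above infers the result shape by running the statement with a small limit and handing the sampled items to DDBTableUtils.buildSchemaFromItems. The standalone sketch below (hypothetical; SchemaPeekSketch and inferSchema are not part of this codebase) illustrates the idea: every attribute observed across the sampled records contributes a column, typed by whichever DynamoDB AttributeValue variant is populated.

import software.amazon.awssdk.services.dynamodb.model.AttributeValue;

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

public class SchemaPeekSketch
{
    // Infers a column -> DynamoDB type map from sampled items; the first variant seen wins.
    static Map<String, String> inferSchema(List<Map<String, AttributeValue>> sampleItems)
    {
        Map<String, String> columns = new LinkedHashMap<>();
        for (Map<String, AttributeValue> item : sampleItems) {
            for (Map.Entry<String, AttributeValue> attr : item.entrySet()) {
                // AttributeValue.type() reports the populated variant (S, N, BOOL, M, L, ...).
                columns.putIfAbsent(attr.getKey(), attr.getValue().type().toString());
            }
        }
        return columns;
    }
}

Because the first variant seen wins, sampling several records rather than one helps with sparse or mixed-type attributes.
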
athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBRecordHandler.java
@@ -30,6 +30,7 @@
import com.amazonaws.athena.connector.lambda.handlers.RecordHandler;
import com.amazonaws.athena.connector.lambda.records.ReadRecordsRequest;
import com.amazonaws.athena.connectors.dynamodb.credentials.CrossAccountCredentialsProviderV2;
import com.amazonaws.athena.connectors.dynamodb.qpt.DDBQueryPassthrough;
import com.amazonaws.athena.connectors.dynamodb.resolver.DynamoDBFieldResolver;
import com.amazonaws.athena.connectors.dynamodb.util.DDBPredicateUtils;
import com.amazonaws.athena.connectors.dynamodb.util.DDBRecordMetadata;
@@ -50,6 +51,8 @@
import software.amazon.awssdk.enhanced.dynamodb.document.EnhancedDocument;
import software.amazon.awssdk.services.dynamodb.DynamoDbClient;
import software.amazon.awssdk.services.dynamodb.model.AttributeValue;
import software.amazon.awssdk.services.dynamodb.model.ExecuteStatementRequest;
import software.amazon.awssdk.services.dynamodb.model.ExecuteStatementResponse;
import software.amazon.awssdk.services.dynamodb.model.QueryRequest;
import software.amazon.awssdk.services.dynamodb.model.QueryResponse;
import software.amazon.awssdk.services.dynamodb.model.ScanRequest;
@@ -104,6 +107,8 @@ public class DynamoDBRecordHandler
private final LoadingCache<String, ThrottlingInvoker> invokerCache;
private final DynamoDbClient ddbClient;

private final DDBQueryPassthrough queryPassthrough = new DDBQueryPassthrough();

public DynamoDBRecordHandler(java.util.Map<String, String> configOptions)
{
super(sourceType, configOptions);
@@ -149,6 +154,11 @@ public ThrottlingInvoker load(String tableName)
protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker)
throws ExecutionException
{
if (recordsRequest.getConstraints().isQueryPassThrough()) {
logger.info("readWithConstraint for QueryPassthrough PartiQL Query");
handleQueryPassthroughPartiQLQuery(spiller, recordsRequest, queryStatusChecker);
return;
}
Split split = recordsRequest.getSplit();
// use the property instead of the request table name because of case sensitivity
String tableName = split.getProperty(TABLE_METADATA);
@@ -190,8 +200,43 @@ protected void readWithConstraint(BlockSpiller spiller, ReadRecordsRequest recor
}

Iterator<Map<String, AttributeValue>> itemIterator = getIterator(split, tableName, recordsRequest.getSchema(), recordsRequest.getConstraints(), disableProjectionAndCasing);
writeItemsToBlock(spiller, recordsRequest, queryStatusChecker, recordMetadata, itemIterator, disableProjectionAndCasing);
}

private void handleQueryPassthroughPartiQLQuery(BlockSpiller spiller, ReadRecordsRequest recordsRequest, QueryStatusChecker queryStatusChecker)
{
if (!recordsRequest.getConstraints().isQueryPassThrough()) {
throw new RuntimeException("Attempting to read with Query Passthrough, but no PartiQL query was provided");
}
queryPassthrough.verify(recordsRequest.getConstraints().getQueryPassthroughArguments());

DDBRecordMetadata recordMetadata = new DDBRecordMetadata(recordsRequest.getSchema());

String partiQLStatement = recordsRequest.getConstraints().getQueryPassthroughArguments().get(DDBQueryPassthrough.QUERY);
ExecuteStatementRequest executeStatementRequest =
ExecuteStatementRequest.builder()
.statement(partiQLStatement)
.build();

ExecuteStatementResponse response = ddbClient.executeStatement(executeStatementRequest);

Iterator<Map<String, AttributeValue>> itemIterator = response.items().iterator();
writeItemsToBlock(spiller, recordsRequest, queryStatusChecker, recordMetadata, itemIterator, false);
}
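
// A hedged sketch, not part of this commit: executeStatement returns results one page at a
// time, and handleQueryPassthroughPartiQLQuery above consumes only the first response page.
// If larger passthrough result sets needed draining, a hypothetical nextToken loop along
// these lines (field names from the AWS SDK for Java v2) could replace the single call.
private void readAllPartiQLPagesSketch(BlockSpiller spiller, ReadRecordsRequest recordsRequest,
        QueryStatusChecker queryStatusChecker, DDBRecordMetadata recordMetadata, String partiQLStatement)
{
    String nextToken = null;
    do {
        ExecuteStatementResponse page = ddbClient.executeStatement(ExecuteStatementRequest.builder()
                .statement(partiQLStatement)
                .nextToken(nextToken)
                .build());
        writeItemsToBlock(spiller, recordsRequest, queryStatusChecker, recordMetadata,
                page.items().iterator(), false);
        nextToken = page.nextToken();
    } while (nextToken != null);
}
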

private void writeItemsToBlock(
BlockSpiller spiller,
ReadRecordsRequest recordsRequest,
QueryStatusChecker queryStatusChecker,
DDBRecordMetadata recordMetadata,
Iterator<Map<String, AttributeValue>> itemIterator,
boolean disableProjectionAndCasing)
{
DynamoDBFieldResolver resolver = new DynamoDBFieldResolver(recordMetadata);

String disableProjectionAndCasingEnvValue = configOptions.getOrDefault(DISABLE_PROJECTION_AND_CASING_ENV, "auto").toLowerCase();
logger.info(DISABLE_PROJECTION_AND_CASING_ENV + " environment variable set to: " + disableProjectionAndCasingEnvValue);

GeneratedRowWriter.RowWriterBuilder rowWriterBuilder = GeneratedRowWriter.newBuilder(recordsRequest.getConstraints());
// Register an extractor and field writer factory for each field.
for (Field next : recordsRequest.getSchema().getFields()) {
athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/qpt/DDBQueryPassthrough.java
@@ -0,0 +1,93 @@
/*-
* #%L
* athena-dynamodb
* %%
* Copyright (C) 2019 Amazon Web Services
* %%
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* #L%
*/
package com.amazonaws.athena.connectors.dynamodb.qpt;

import com.amazonaws.athena.connector.lambda.metadata.optimizations.querypassthrough.QueryPassthroughSignature;
import com.google.common.collect.ImmutableSet;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

public class DDBQueryPassthrough implements QueryPassthroughSignature
{
// Constant value representing the function name.
public static final String NAME = "query";

// Constant value representing the function's schema (domain).
public static final String SCHEMA_NAME = "system";

// Name of the single argument that carries the raw PartiQL statement.
public static final String QUERY = "QUERY";

// List of arguments for the function, statically initialized as it always contains the same value.
public static final List<String> ARGUMENTS = Arrays.asList(QUERY);

private static final Logger LOGGER = LoggerFactory.getLogger(DDBQueryPassthrough.class);

@Override
public String getFunctionSchema()
{
return SCHEMA_NAME;
}

@Override
public String getFunctionName()
{
return NAME;
}

@Override
public List<String> getFunctionArguments()
{
return ARGUMENTS;
}

@Override
public Logger getLogger()
{
return LOGGER;
}

@Override
public void customConnectorVerifications(Map<String, String> engineQptArguments)
{
String partiQLStatement = engineQptArguments.get(QUERY);
String upperCaseStatement = partiQLStatement.trim().toUpperCase(Locale.ENGLISH);

// Immediately check if the statement starts with "SELECT"
if (!upperCaseStatement.startsWith("SELECT")) {
throw new UnsupportedOperationException("Statement does not start with SELECT.");
}

// List of disallowed keywords
Set<String> disallowedKeywords = ImmutableSet.of("INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER");

// Check if the statement contains any disallowed keywords
for (String keyword : disallowedKeywords) {
if (upperCaseStatement.contains(keyword)) {
throw new UnsupportedOperationException("Unsupported operation; only SELECT statements are allowed. Found: " + keyword);
}
}
}
}
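
A quick, hypothetical driver (QptVerificationSketch is not part of the commit) showing how these guardrails behave: a plain SELECT passes, anything else throws.

import java.util.Map;

public class QptVerificationSketch
{
    public static void main(String[] args)
    {
        DDBQueryPassthrough qpt = new DDBQueryPassthrough();

        // Passes: a plain SELECT statement.
        qpt.customConnectorVerifications(Map.of(DDBQueryPassthrough.QUERY,
                "SELECT * FROM \"my_table\" WHERE pk = '42'"));

        // Throws UnsupportedOperationException: the statement does not start with SELECT.
        qpt.customConnectorVerifications(Map.of(DDBQueryPassthrough.QUERY,
                "DELETE FROM \"my_table\""));
    }
}

Note that the keyword gate is a plain substring match over the upper-cased statement, so a SELECT whose text merely contains a blocked word (say, a column named last_updated, which contains UPDATE) is also rejected; a conservative trade-off rather than a full PartiQL parse.
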