diff --git a/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseMetadataHandler.java b/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseMetadataHandler.java index 4a980d5830..6be56ee456 100644 --- a/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseMetadataHandler.java +++ b/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseMetadataHandler.java @@ -236,9 +236,11 @@ public GetTableResponse doGetTable(BlockAllocator blockAllocator, GetTableReques private GetTableResponse getTableResponse(GetTableRequest request, Schema origSchema, com.amazonaws.athena.connector.lambda.domain.TableName tableName) + throws IOException { + TableName hbaseName = HbaseTableNameUtils.getHbaseTableName(configOptions, getOrCreateConn(request), tableName); if (origSchema == null) { - origSchema = HbaseSchemaUtils.inferSchema(getOrCreateConn(request), tableName, NUM_ROWS_TO_SCAN); + origSchema = HbaseSchemaUtils.inferSchema(getOrCreateConn(request), hbaseName, NUM_ROWS_TO_SCAN); } SchemaBuilder schemaBuilder = SchemaBuilder.newBuilder(); @@ -253,7 +255,10 @@ private GetTableResponse getTableResponse(GetTableRequest request, Schema origSc Schema schema = schemaBuilder.build(); logger.info("doGetTable: return {}", schema); - return new GetTableResponse(request.getCatalogName(), request.getTableName(), schema); + return new GetTableResponse( + request.getCatalogName(), + new com.amazonaws.athena.connector.lambda.domain.TableName(hbaseName.getNamespaceAsString(), hbaseName.getNameAsString()), + schema); } /** @@ -287,7 +292,7 @@ public GetSplitsResponse doGetSplits(BlockAllocator blockAllocator, GetSplitsReq Set splits = new HashSet<>(); //We can read each region in parallel - for (HRegionInfo info : getOrCreateConn(request).getTableRegions(HbaseSchemaUtils.getQualifiedTable(request.getTableName()))) { + for (HRegionInfo info : getOrCreateConn(request).getTableRegions(HbaseTableNameUtils.getQualifiedTable(request.getTableName()))) { Split.Builder splitBuilder = Split.newBuilder(makeSpillLocation(request), makeEncryptionKey()) .add(HBASE_CONN_STR, getConnStr(request)) .add(START_KEY_FIELD, new String(info.getStartKey())) diff --git a/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseRecordHandler.java b/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseRecordHandler.java index 9a1ee6b3cf..9e10b503df 100644 --- a/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseRecordHandler.java +++ b/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseRecordHandler.java @@ -162,7 +162,7 @@ protected void readWithConstraint(BlockSpiller blockSpiller, ReadRecordsRequest addToProjection(scan, next); } - getOrCreateConn(conStr).scanTable(HbaseSchemaUtils.getQualifiedTable(tableNameObj), + getOrCreateConn(conStr).scanTable(HbaseTableNameUtils.getQualifiedTable(tableNameObj), scan, (ResultScanner scanner) -> scanFilterProject(scanner, request, blockSpiller, queryStatusChecker)); } diff --git a/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseSchemaUtils.java b/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseSchemaUtils.java index 36458deb14..d0974d0efe 100644 --- a/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseSchemaUtils.java +++ b/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseSchemaUtils.java @@ -20,7 +20,6 @@ package com.amazonaws.athena.connectors.hbase; import com.amazonaws.athena.connector.lambda.data.SchemaBuilder; -import com.amazonaws.athena.connector.lambda.domain.TableName; import com.amazonaws.athena.connectors.hbase.connection.HBaseConnection; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.ArrowType; @@ -35,6 +34,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.IOException; import java.nio.ByteBuffer; import java.util.HashMap; import java.util.Map; @@ -46,8 +46,6 @@ public class HbaseSchemaUtils { //Field name for the special 'row' column which represets the HBase key used to store a given row. protected static final String ROW_COLUMN_NAME = "row"; - //The HBase namespce qualifier character which commonly separates namespaces and column families from tables and columns. - protected static final String NAMESPACE_QUALIFIER = ":"; private static final Logger logger = LoggerFactory.getLogger(HbaseSchemaUtils.class); private HbaseSchemaUtils() {} @@ -62,11 +60,11 @@ private HbaseSchemaUtils() {} * @param numToScan The number of records to scan as part of producing the Schema. * @return An Apache Arrow Schema representing the schema of the HBase table. */ - public static Schema inferSchema(HBaseConnection client, TableName tableName, int numToScan) + public static Schema inferSchema(HBaseConnection client, org.apache.hadoop.hbase.TableName tableName, int numToScan) + throws IOException { Scan scan = new Scan().setMaxResultSize(numToScan).setFilter(new PageFilter(numToScan)); - org.apache.hadoop.hbase.TableName hbaseTableName = org.apache.hadoop.hbase.TableName.valueOf(getQualifiedTableName(tableName)); - return client.scanTable(hbaseTableName, scan, (ResultScanner scanner) -> { + return client.scanTable(tableName, scan, (ResultScanner scanner) -> { try { return scanAndInferSchema(scanner); } @@ -132,7 +130,7 @@ private static Schema scanAndInferSchema(ResultScanner scanner) throws java.io.U for (Map.Entry> nextFamily : schemaInference.entrySet()) { String family = nextFamily.getKey(); for (Map.Entry nextCol : nextFamily.getValue().entrySet()) { - schemaBuilder.addField(family + NAMESPACE_QUALIFIER + nextCol.getKey(), nextCol.getValue()); + schemaBuilder.addField(family + HbaseTableNameUtils.NAMESPACE_QUALIFIER + nextCol.getKey(), nextCol.getValue()); } } @@ -144,28 +142,6 @@ private static Schema scanAndInferSchema(ResultScanner scanner) throws java.io.U return schema; } - /** - * Helper which goes from an Athena Federation SDK TableName to an HBase table name string. - * - * @param tableName An Athena Federation SDK TableName. - * @return The corresponding HBase table name string. - */ - public static String getQualifiedTableName(TableName tableName) - { - return tableName.getSchemaName() + NAMESPACE_QUALIFIER + tableName.getTableName(); - } - - /** - * Helper which goes from an Athena Federation SDK TableName to an HBase TableName. - * - * @param tableName An Athena Federation SDK TableName. - * @return The corresponding HBase TableName. - */ - public static org.apache.hadoop.hbase.TableName getQualifiedTable(TableName tableName) - { - return org.apache.hadoop.hbase.TableName.valueOf(tableName.getSchemaName() + NAMESPACE_QUALIFIER + tableName.getTableName()); - } - /** * Given a value from HBase attempt to infer it's type. * @@ -243,7 +219,7 @@ public static Object coerceType(boolean isNative, ArrowType type, byte[] value) */ public static String[] extractColumnParts(String glueColumnName) { - return glueColumnName.split(NAMESPACE_QUALIFIER); + return glueColumnName.split(HbaseTableNameUtils.NAMESPACE_QUALIFIER); } /** diff --git a/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseTableNameUtils.java b/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseTableNameUtils.java new file mode 100644 index 0000000000..d8ec2f3155 --- /dev/null +++ b/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/HbaseTableNameUtils.java @@ -0,0 +1,154 @@ +/*- + * #%L + * athena-hbase + * %% + * Copyright (C) 2019 Amazon Web Services + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +package com.amazonaws.athena.connectors.hbase; + +import com.amazonaws.athena.connector.lambda.domain.TableName; +import com.amazonaws.athena.connectors.hbase.connection.HBaseConnection; +import com.google.common.collect.ArrayListMultimap; +import com.google.common.collect.Multimap; +import org.apache.arrow.util.VisibleForTesting; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Collection; +import java.util.Locale; +import java.util.Map; + +/** + * This class helps with resolving the differences in casing between HBase and Presto. Presto expects all + * databases, tables, and columns to be lower case. This class allows us to resolve HBase tables + * which may have captial letters in them without issue. It does so by fetching all table names and doing + * a case insensitive search over them. It will first try to do a targeted get to reduce the penalty for + * tables which don't have capitalization. + * + * Modeled off of DynamoDBTableResolver.java + * + * TODO add caching + */ +public final class HbaseTableNameUtils +{ + //The HBase namespce qualifier character which commonly separates namespaces and column families from tables and columns. + protected static final String NAMESPACE_QUALIFIER = ":"; + protected static final String ENABLE_CASE_INSENSITIVE_MATCH = "enable_case_insensitive_match"; + private static final Logger logger = LoggerFactory.getLogger(HbaseTableNameUtils.class); + + private HbaseTableNameUtils() {} + + /** + * Helper which goes from a schema and table name to an HBase table name string + * @param schema a schema name + * @param table the name of the table + * @return + */ + public static String getQualifiedTableName(String schema, String table) + { + return schema + NAMESPACE_QUALIFIER + table; + } + + /** + * Helper which goes from an Athena Federation SDK TableName to an HBase table name string. + * + * @param tableName An Athena Federation SDK TableName. + * @return The corresponding HBase table name string. + */ + public static String getQualifiedTableName(TableName tableName) + { + return getQualifiedTableName(tableName.getSchemaName(), tableName.getTableName()); + } + + /** + * Helper which goes from a schema and table name to an HBase TableName + * @param schema the schema name + * @param table the name of the table + * @return The corresponding HBase TableName + */ + public static org.apache.hadoop.hbase.TableName getQualifiedTable(String schema, String table) + { + return org.apache.hadoop.hbase.TableName.valueOf(getQualifiedTableName(schema, table)); + } + + /** + * Helper which goes from an Athena Federation SDK TableName to an HBase TableName. + * + * @param tableName An Athena Federation SDK TableName. + * @return The corresponding HBase TableName. + */ + public static org.apache.hadoop.hbase.TableName getQualifiedTable(TableName tableName) + { + return org.apache.hadoop.hbase.TableName.valueOf(getQualifiedTableName(tableName)); + } + + /** + * Gets the hbase table name from Athena table name. This is to allow athena to query uppercase table names + * (since athena does not support them). If an hbase table name is found with the athena table name, it is returned. + * Otherwise, tryCaseInsensitiveSearch is used to find the corresponding hbase table. + * + * @param tableName the case insensitive table name + * @return the hbase table name + */ + public static org.apache.hadoop.hbase.TableName getHbaseTableName(Map configOptions, HBaseConnection conn, TableName athTableName) + throws IOException + { + if (!isCaseInsensitiveMatchEnable(configOptions) || !athTableName.getTableName().equals(athTableName.getTableName().toLowerCase())) { + return getQualifiedTable(athTableName); + } + return tryCaseInsensitiveSearch(conn, athTableName); + } + + /** + * Performs a case insensitive table search by listing all table names in the schema (namespace), mapping them + * to their lowercase transformation, and then mapping the given tableName back to a unique table. To prevent ambiguity, + * an IllegalStateException is thrown if multiple tables map to the given tableName. + * @param conn the HBaseConnection used to retrieve the tables + * @param tableName The Athena TableName to find the mapping to + * @return The HBase TableName containing the found HBase table and the Athena Schema (namespace) + * @throws IOException + */ + @VisibleForTesting + protected static org.apache.hadoop.hbase.TableName tryCaseInsensitiveSearch(HBaseConnection conn, TableName tableName) + throws IOException + { + logger.info("Case Insensitive Match enabled. Searching for Table {}.", tableName.getTableName()); + Multimap lowerCaseNameMapping = ArrayListMultimap.create(); + org.apache.hadoop.hbase.TableName[] tableNames = conn.listTableNamesByNamespace(tableName.getSchemaName()); + for (org.apache.hadoop.hbase.TableName nextTableName : tableNames) { + lowerCaseNameMapping.put(nextTableName.getQualifierAsString().toLowerCase(Locale.ENGLISH), nextTableName.getNameAsString()); + } + Collection mappedNames = lowerCaseNameMapping.get(tableName.getTableName()); + if (mappedNames.size() != 1) { + throw new IllegalStateException(String.format("Either no tables or multiple tables resolved from case insensitive name %s: %s", tableName.getTableName(), mappedNames)); + } + org.apache.hadoop.hbase.TableName result = org.apache.hadoop.hbase.TableName.valueOf(mappedNames.iterator().next()); + logger.info("CaseInsensitiveMatch, TableName resolved to: {}", result.getNameAsString()); + return result; + } + + private static boolean isCaseInsensitiveMatchEnable(Map configOptions) + { + String enableCaseInsensitiveMatchEnvValue = configOptions.getOrDefault(ENABLE_CASE_INSENSITIVE_MATCH, "false").toLowerCase(); + boolean enableCaseInsensitiveMatch = enableCaseInsensitiveMatchEnvValue.equals("true"); + logger.info("{} environment variable set to: {}. Resolved to: {}", + ENABLE_CASE_INSENSITIVE_MATCH, enableCaseInsensitiveMatchEnvValue, enableCaseInsensitiveMatch); + + return enableCaseInsensitiveMatch; + } + +} diff --git a/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/connection/HBaseConnection.java b/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/connection/HBaseConnection.java index f70ccf674c..bdbe461cd2 100644 --- a/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/connection/HBaseConnection.java +++ b/athena-hbase/src/main/java/com/amazonaws/athena/connectors/hbase/connection/HBaseConnection.java @@ -156,6 +156,20 @@ public T scanTable(TableName tableName, Scan scan, ResultProcessor result }); } + /** + * Retrieves whether the table exists + * + * @param tableName The fully qualified HBase TableName for which to check existence. + * @return Whether the table exists or not. + */ + public boolean tableExists(TableName tableName) + { + return callWithReconnectAndRetry(() -> { + Admin admin = getConnection().getAdmin(); + return admin.tableExists(tableName); + }); + } + /** * Used to close this connection by closing the underlying HBase Connection. */ diff --git a/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/HbaseSchemaUtilsTest.java b/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/HbaseSchemaUtilsTest.java index 3f3e3ac350..9c51424c99 100644 --- a/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/HbaseSchemaUtilsTest.java +++ b/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/HbaseSchemaUtilsTest.java @@ -71,8 +71,9 @@ public void inferSchema() ResultProcessor processor = (ResultProcessor) invocationOnMock.getArguments()[2]; return processor.scan(mockScanner); }); + when(mockConnection.tableExists(any())).thenReturn(true); - Schema schema = HbaseSchemaUtils.inferSchema(mockConnection, tableName, numToScan); + Schema schema = HbaseSchemaUtils.inferSchema(mockConnection, HbaseTableNameUtils.getQualifiedTable(tableName), numToScan); Map actualFields = new HashMap<>(); schema.getFields().stream().forEach(next -> actualFields.put(next.getName(), Types.getMinorTypeForArrowType(next.getType()))); @@ -91,26 +92,6 @@ public void inferSchema() verify(mockScanner, times(1)).iterator(); } - @Test - public void getQualifiedTableName() - { - String table = "table"; - String schema = "schema"; - String expected = "schema:table"; - String actual = HbaseSchemaUtils.getQualifiedTableName(new TableName(schema, table)); - assertEquals(expected, actual); - } - - @Test - public void getQualifiedTable() - { - String table = "table"; - String schema = "schema"; - org.apache.hadoop.hbase.TableName expected = org.apache.hadoop.hbase.TableName.valueOf(schema + ":" + table); - org.apache.hadoop.hbase.TableName actual = HbaseSchemaUtils.getQualifiedTable(new TableName(schema, table)); - assertEquals(expected, actual); - } - @Test public void inferType() { diff --git a/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/HbaseTableNameUtilsTest.java b/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/HbaseTableNameUtilsTest.java new file mode 100644 index 0000000000..97912c549f --- /dev/null +++ b/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/HbaseTableNameUtilsTest.java @@ -0,0 +1,167 @@ +/*- + * #%L + * athena-hbase + * %% + * Copyright (C) 2019 Amazon Web Services + * %% + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * #L% + */ +package com.amazonaws.athena.connectors.hbase; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.io.IOException; +import java.util.Map; + +import org.junit.Test; + +import com.amazonaws.athena.connector.lambda.domain.TableName; +import com.amazonaws.athena.connectors.hbase.connection.HBaseConnection; + +public class HbaseTableNameUtilsTest +{ + private final Map config = com.google.common.collect.ImmutableMap.of(HbaseTableNameUtils.ENABLE_CASE_INSENSITIVE_MATCH, "true"); + + @Test + public void getQualifiedTableName() + { + String table = "table"; + String schema = "schema"; + String expected = "schema:table"; + String actualWithTable = HbaseTableNameUtils.getQualifiedTableName(new TableName(schema, table)); + String actualWithStrings = HbaseTableNameUtils.getQualifiedTableName(schema, table); + assertEquals(expected, actualWithTable); + assertEquals(expected, actualWithStrings); + } + + @Test + public void getQualifiedTable() + { + String table = "table"; + String schema = "schema"; + org.apache.hadoop.hbase.TableName expected = org.apache.hadoop.hbase.TableName.valueOf(schema + ":" + table); + org.apache.hadoop.hbase.TableName actualWithTable = HbaseTableNameUtils.getQualifiedTable(new TableName(schema, table)); + org.apache.hadoop.hbase.TableName actualWithStrings = HbaseTableNameUtils.getQualifiedTable(schema, table); + assertEquals(expected, actualWithTable); + assertEquals(expected, actualWithStrings); + } + + @Test + public void getHbaseTableName() + throws IOException + { + org.apache.hadoop.hbase.TableName[] tableNames = { + org.apache.hadoop.hbase.TableName.valueOf("schema:Test") + }; + HBaseConnection mockConnection = mock(HBaseConnection.class); + when(mockConnection.listTableNamesByNamespace(any())).thenReturn(tableNames); + when(mockConnection.tableExists(any())).thenReturn(false); + + TableName input = new TableName("schema", "test"); + org.apache.hadoop.hbase.TableName expected = HbaseTableNameUtils.getQualifiedTable("schema", "Test"); + org.apache.hadoop.hbase.TableName result = HbaseTableNameUtils.getHbaseTableName(config, mockConnection, input); + assertEquals(expected, result); + } + + @Test + public void getHbaseTableNameFlagFalse() + throws IOException + { + HBaseConnection mockConnection = mock(HBaseConnection.class); + + TableName input = new TableName("schema", "Test"); + org.apache.hadoop.hbase.TableName expected = HbaseTableNameUtils.getQualifiedTable("schema", "Test"); + org.apache.hadoop.hbase.TableName result = HbaseTableNameUtils.getHbaseTableName(com.google.common.collect.ImmutableMap.of(), mockConnection, input); + assertEquals(expected, result); + verify(mockConnection, times(0)).tableExists(any()); + verify(mockConnection, times(0)).listTableNamesByNamespace(any()); + } + + @Test(expected = IllegalStateException.class) + public void getHbaseTableNameDNE() + throws IOException + { + org.apache.hadoop.hbase.TableName[] tableNames = { + org.apache.hadoop.hbase.TableName.valueOf("schema:test") + }; + HBaseConnection mockConnection = mock(HBaseConnection.class); + when(mockConnection.listTableNamesByNamespace(any())).thenReturn(tableNames); + when(mockConnection.tableExists(any())).thenReturn(false); + + TableName input = new TableName("schema", "table"); + HbaseTableNameUtils.getHbaseTableName(config, mockConnection, input); + } + + @Test + public void tryCaseInsensitiveSearch() + throws IOException + { + org.apache.hadoop.hbase.TableName[] tableNames = { + org.apache.hadoop.hbase.TableName.valueOf("schema:test") + }; + TableName input = new TableName("schema", "test"); + HBaseConnection mockConnection = mock(HBaseConnection.class); + when(mockConnection.listTableNamesByNamespace(any())).thenReturn(tableNames); + org.apache.hadoop.hbase.TableName result = HbaseTableNameUtils.tryCaseInsensitiveSearch(mockConnection, input); + org.apache.hadoop.hbase.TableName expected = HbaseTableNameUtils.getQualifiedTable("schema", "test"); + assertEquals(expected, result); + } + + @Test + public void tryCaseInsensitiveSearchSingle() + throws IOException + { + org.apache.hadoop.hbase.TableName[] tableNames = { + org.apache.hadoop.hbase.TableName.valueOf("schema:Test") + }; + TableName input = new TableName("schema", "test"); + HBaseConnection mockConnection = mock(HBaseConnection.class); + when(mockConnection.listTableNamesByNamespace(any())).thenReturn(tableNames); + org.apache.hadoop.hbase.TableName result = HbaseTableNameUtils.tryCaseInsensitiveSearch(mockConnection, input); + org.apache.hadoop.hbase.TableName expected = HbaseTableNameUtils.getQualifiedTable("schema", "Test"); + assertEquals(expected, result); + } + + @Test(expected = IllegalStateException.class) + public void tryCaseInsensitiveSearchMultiple() + throws IOException + { + org.apache.hadoop.hbase.TableName[] tableNames = { + org.apache.hadoop.hbase.TableName.valueOf("schema:Test"), + org.apache.hadoop.hbase.TableName.valueOf("schema:tEst") + }; + TableName input = new TableName("schema", "test"); + HBaseConnection mockConnection = mock(HBaseConnection.class); + when(mockConnection.listTableNamesByNamespace(any())).thenReturn(tableNames); + HbaseTableNameUtils.tryCaseInsensitiveSearch(mockConnection, input); + } + + @Test(expected = IllegalStateException.class) + public void tryCaseInsensitiveSearchNone() + throws IOException + { + org.apache.hadoop.hbase.TableName[] tableNames = { + org.apache.hadoop.hbase.TableName.valueOf("schema:other") + }; + TableName input = new TableName("schema", "test"); + HBaseConnection mockConnection = mock(HBaseConnection.class); + when(mockConnection.listTableNamesByNamespace(any())).thenReturn(tableNames); + HbaseTableNameUtils.tryCaseInsensitiveSearch(mockConnection, input); + } +} diff --git a/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/connection/HBaseConnectionTest.java b/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/connection/HBaseConnectionTest.java index 8ec3dcea5a..76eaa8a3cf 100644 --- a/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/connection/HBaseConnectionTest.java +++ b/athena-hbase/src/test/java/com/amazonaws/athena/connectors/hbase/connection/HBaseConnectionTest.java @@ -43,6 +43,7 @@ import static org.junit.Assert.*; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.booleanThat; import static org.mockito.ArgumentMatchers.nullable; import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; @@ -367,6 +368,71 @@ public void scanTableWithCallerException() logger.info("scanTable: exit"); } + @Test + public void tableExists() + throws IOException + { + logger.info("tableExists: enter"); + when(mockConnection.getAdmin()).thenReturn(mockAdmin); + when(mockAdmin.tableExists(any())).thenReturn(true); + + boolean result = connection.tableExists(null); + assertNotNull(result); + assertTrue(connection.isHealthy()); + assertEquals(0, connection.getRetries()); + verify(mockConnection, atLeastOnce()).getAdmin(); + verify(mockAdmin, atLeastOnce()).tableExists(any()); + logger.info("tableExists: exit"); + } + + @Test + public void tableExistsWithRetry() + throws IOException + { + logger.info("tableExistsWithRetry: enter"); + when(mockConnection.getAdmin()).thenAnswer(new Answer() + { + private int count = 0; + + public Object answer(InvocationOnMock invocation) + { + if (++count == 1) { + //first invocation should throw + return new RuntimeException("Retryable"); + } + + return mockAdmin; + } + }); + when(mockAdmin.tableExists(any())).thenReturn(false); + + boolean result = connection.tableExists(null); + assertNotNull(result); + assertTrue(connection.isHealthy()); + assertEquals(1, connection.getRetries()); + verify(mockConnection, atLeastOnce()).getAdmin(); + verify(mockAdmin, atLeastOnce()).tableExists(any()); + logger.info("tableExistsWithRetry: exit"); + } + + @Test + public void tableExistsRetryExhausted() + throws IOException + { + logger.info("tableExistsRetryExhausted: enter"); + when(mockConnection.getAdmin()).thenThrow(new RuntimeException("Retryable")); + try { + connection.tableExists(null); + fail("Should not reach this line because retries should be exhausted."); + } + catch (RuntimeException ex) { + logger.info("tableExistsRetryExhausted: Encountered expected exception.", ex); + } + assertFalse(connection.isHealthy()); + assertEquals(3, connection.getRetries()); + logger.info("tableExistsRetryExhausted: exit"); + } + @Test public void close() throws IOException