Commit: Snowflake split issue (#1963)
ejeffrli authored May 17, 2024
1 parent ee6cba1 commit 71d3320
Showing 5 changed files with 122 additions and 18 deletions.
SnowflakeConstants.java

@@ -29,7 +29,7 @@ public final class SnowflakeConstants
      * This constant limits the number of partitions. The default set to 50. A large number may cause a timeout issue.
      * We arrived at this number after performance testing with datasets of different size
      */
-    public static final int MAX_PARTITION_COUNT = 1;
+    public static final int MAX_PARTITION_COUNT = 50;
     /**
      * This constant limits the number of records to be returned in a single split.
      */
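
Restoring MAX_PARTITION_COUNT to 50 re-enables multi-partition scans; the constant had been pinned to 1, which forced every query through a single partition. As a rough sketch of the intended arithmetic (the 1,000,000-row total below is a hypothetical figure; the handler performs the equivalent division on its COUNT(*) result):

    // Sketch only: how the 50-partition cap translates into a per-partition size.
    long totalRecordCount = 1_000_000L;
    long recordsPerPartition = (long) Math.ceil(
            (double) totalRecordCount / SnowflakeConstants.MAX_PARTITION_COUNT);
    // recordsPerPartition == 20_000, spread over at most 50 LIMIT/OFFSET splits
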
SnowflakeMetadataHandler.java

@@ -58,6 +58,7 @@
 import com.amazonaws.services.athena.AmazonAthena;
 import com.amazonaws.services.secretsmanager.AWSSecretsManager;
 import com.google.common.annotations.VisibleForTesting;
+import com.google.common.base.Strings;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
 import org.apache.arrow.vector.complex.reader.FieldReader;
@@ -73,11 +74,13 @@
 import java.sql.PreparedStatement;
 import java.sql.ResultSet;
 import java.sql.SQLException;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;

@@ -103,6 +106,9 @@ public class SnowflakeMetadataHandler extends JdbcMetadataHandler
         "WHERE table_type = 'BASE TABLE'\n" +
         "AND table_schema= ?\n" +
         "AND TABLE_NAME = ? ";
+    static final String SHOW_PRIMARY_KEYS_QUERY = "SHOW PRIMARY KEYS IN ";
+    static final String PRIMARY_KEY_COLUMN_NAME = "column_name";
+    static final String COUNTS_COLUMN_NAME = "COUNTS";
     private static final String CASE_UPPER = "upper";
     private static final String CASE_LOWER = "lower";
     /**
@@ -178,6 +184,48 @@ public Schema getPartitionSchema(final String catalogName)
                 .addField(BLOCK_PARTITION_COLUMN_NAME, Types.MinorType.VARCHAR.getType());
         return schemaBuilder.build();
     }
+
+    private Optional<String> getPrimaryKey(TableName tableName) throws Exception
+    {
+        List<String> primaryKeys = new ArrayList<String>();
+        try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider())) {
+            try (PreparedStatement preparedStatement = connection.prepareStatement(SHOW_PRIMARY_KEYS_QUERY + tableName.getTableName());
+                    ResultSet rs = preparedStatement.executeQuery()) {
+                while (rs.next()) {
+                    // Concatenate multiple primary keys if they exist
+                    primaryKeys.add(rs.getString(PRIMARY_KEY_COLUMN_NAME));
+                }
+            }
+        }
+        String primaryKey = String.join(", ", primaryKeys);
+        if (!Strings.isNullOrEmpty(primaryKey) && hasUniquePrimaryKey(tableName, primaryKey)) {
+            return Optional.of(primaryKey);
+        }
+        return Optional.empty();
+    }
+
+    /**
+     * Snowflake does not enforce primary key constraints, so we double-check user has unique primary key
+     * before partitioning.
+     */
+    private boolean hasUniquePrimaryKey(TableName tableName, String primaryKey) throws Exception
+    {
+        try (Connection connection = getJdbcConnectionFactory().getConnection(getCredentialProvider())) {
+            try (PreparedStatement preparedStatement = connection.prepareStatement("SELECT " + primaryKey + ", count(*) as COUNTS FROM " + tableName.getTableName() + " GROUP BY " + primaryKey + " ORDER BY COUNTS DESC");
+                    ResultSet rs = preparedStatement.executeQuery()) {
+                if (rs.next()) {
+                    if (rs.getInt(COUNTS_COLUMN_NAME) == 1) {
+                        // Since it is in descending order and 1 is this first count seen,
+                        // this table has a unique primary key
+                        return true;
+                    }
+                }
+            }
+        }
+        LOGGER.warn("Primary key ,{}, is not unique. Falling back to single partition...", primaryKey);
+        return false;
+    }
+
     /**
      * Snowflake manual partition logic based upon number of records
      * @param blockWriter
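
The uniqueness probe above works because the aggregate orders groups by descending count: if even the largest group contains only one row, every key value is unique. A standalone sketch of the same pattern over plain JDBC (the table and key arguments are hypothetical placeholders and, as in the handler itself, are concatenated into the SQL rather than bound as parameters):

    import java.sql.Connection;
    import java.sql.PreparedStatement;
    import java.sql.ResultSet;
    import java.sql.SQLException;

    final class UniqueKeyProbe
    {
        // Returns true when no value of the given key occurs more than once.
        static boolean hasUniqueKey(Connection conn, String table, String key) throws SQLException
        {
            String sql = "SELECT " + key + ", count(*) as COUNTS FROM " + table
                    + " GROUP BY " + key + " ORDER BY COUNTS DESC";
            try (PreparedStatement ps = conn.prepareStatement(sql);
                    ResultSet rs = ps.executeQuery()) {
                // The first row holds the most frequent key value; a count of 1
                // there means every group has exactly one row.
                return rs.next() && rs.getInt("COUNTS") == 1;
            }
        }
    }
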
@@ -217,18 +265,19 @@ public void getPartitions(BlockWriter blockWriter, GetTableLayoutRequest getTabl
                 totalRecordCount = rs.getLong(1);
             }
             if (totalRecordCount > 0) {
-                long pageCount = (long) (Math.ceil(totalRecordCount / MAX_PARTITION_COUNT));
-                long partitionRecordCount = (totalRecordCount <= SINGLE_SPLIT_LIMIT_COUNT) ? (long) totalRecordCount : pageCount;
-                LOGGER.info(" Total Page Count" + partitionRecordCount);
-                double limit = (int) Math.ceil(totalRecordCount / partitionRecordCount);
+                Optional<String> primaryKey = getPrimaryKey(getTableLayoutRequest.getTableName());
+                long recordsInPartition = (long) (Math.ceil(totalRecordCount / MAX_PARTITION_COUNT));
+                long partitionRecordCount = (totalRecordCount <= SINGLE_SPLIT_LIMIT_COUNT || !primaryKey.isPresent()) ? (long) totalRecordCount : recordsInPartition;
+                LOGGER.info(" Total Page Count: " + partitionRecordCount);
+                double numberOfPartitions = (int) Math.ceil(totalRecordCount / partitionRecordCount);
                 long offset = 0;
                 /**
                  * Custom pagination based partition logic will be applied with limit and offset clauses.
                  * It will have maximum 50 partitions and number of records in each partition is decided by dividing total number of records by 50
                  * the partition values we are setting the limit and offset values like p-limit-3000-offset-0
                  */
-                for (int i = 1; i <= limit; i++) {
-                    final String partitionVal = BLOCK_PARTITION_COLUMN_NAME + "-limit-" + partitionRecordCount + "-offset-" + offset;
+                for (int i = 1; i <= numberOfPartitions; i++) {
+                    final String partitionVal = BLOCK_PARTITION_COLUMN_NAME + "-primary-" + primaryKey.orElse("") + "-limit-" + partitionRecordCount + "-offset-" + offset;
                     LOGGER.info("partitionVal {} ", partitionVal);
                     blockWriter.writeRows((Block block, int rowNum) ->
                     {
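
For a concrete picture of what this loop writes, consider a hypothetical 10,000-row table whose unique primary key is pkey, split into 50 partitions of 200 records each:

    // Sketch only: partition values in the new
    // partition-primary-<PK>-limit-<N>-offset-<M> layout.
    long total = 10_000L;
    long perPartition = total / 50; // 200 records per partition
    for (long offset = 0; offset < total; offset += perPartition) {
        System.out.println("partition-primary-pkey-limit-" + perPartition + "-offset-" + offset);
    }
    // prints partition-primary-pkey-limit-200-offset-0 ... -offset-9800
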
SnowflakeQueryStringBuilder.java

@@ -72,17 +72,24 @@ protected List<String> getPartitionWhereClauses(Split split)
     @Override
     protected String appendLimitOffset(Split split)
     {
+        String primaryKey = "";
         String xLimit = "";
         String xOffset = "";
-        String partitionVal = split.getProperty(split.getProperties().keySet().iterator().next()); //p-limit-3000-offset-0
+        String partitionVal = split.getProperty(split.getProperties().keySet().iterator().next()); //p-primary-<PRIMARYKEY>-limit-3000-offset-0
         if (!partitionVal.contains("-")) {
             return EMPTY_STRING;
         }
         else {
             String[] arr = partitionVal.split("-");
-            xLimit = arr[2];
-            xOffset = arr[4];
+            primaryKey = arr[2];
+            xLimit = arr[4];
+            xOffset = arr[6];
         }
-        return " limit " + xLimit + " offset " + xOffset;
+
+        // if no primary key, single split only
+        if (primaryKey.equals("")) {
+            return "";
+        }
+        return "ORDER BY " + primaryKey + " limit " + xLimit + " offset " + xOffset;
     }
 }
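
Because the split property round-trips through a plain split("-"), the segment positions are fixed by the format written by the metadata handler. A sketch of the indices the builder consumes (pkey is a placeholder key name):

    // Sketch only: segment layout of the new partition value.
    String partitionVal = "partition-primary-pkey-limit-3000-offset-0";
    String[] arr = partitionVal.split("-");
    String primaryKey = arr[2]; // "pkey"
    String xLimit = arr[4];     // "3000"
    String xOffset = arr[6];    // "0"
    // resulting clause: "ORDER BY pkey limit 3000 offset 0"

Ordering by the primary key is what makes LIMIT/OFFSET pagination deterministic across splits: without a stable sort order, Snowflake may return rows in any order, so splits could overlap or drop rows.
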
SnowflakeMetadataHandlerTest.java

@@ -100,12 +100,28 @@ public void doGetTableLayout()
         Mockito.when(this.connection.prepareStatement(SnowflakeMetadataHandler.COUNT_RECORDS_QUERY)).thenReturn(preparedStatement);
         String[] columns = {"partition"};
         int[] types = {Types.VARCHAR};
-        Object[][] values = {{"partition : partition-limit-500000-offset-0"},{"partition : partition-limit-500000-offset-500000"}};
+        Object[][] values = {{"partition : partition-primary-pkey-limit-500000-offset-0"},{"partition : partition-primary-pkey-limit-500000-offset-500000"}};
         ResultSet resultSet = mockResultSet(columns, types, values, new AtomicInteger(-1));
         Mockito.when(preparedStatement.executeQuery()).thenReturn(resultSet);
         Mockito.when(this.connection.getMetaData().getSearchStringEscape()).thenReturn(null);
         double totalActualRecordCount = 10001;
         Mockito.when(resultSet.getLong(1)).thenReturn((long) totalActualRecordCount);
+
+        PreparedStatement primaryKeyPreparedStatement = Mockito.mock(PreparedStatement.class);
+        String[] primaryKeyColumns = new String[] {SnowflakeMetadataHandler.PRIMARY_KEY_COLUMN_NAME};
+        String[][] primaryKeyValues = new String[][]{new String[] {"pkey"}};
+        ResultSet primaryKeyResultSet = mockResultSet(primaryKeyColumns, primaryKeyValues, new AtomicInteger(-1));
+        Mockito.when(this.connection.prepareStatement(SnowflakeMetadataHandler.SHOW_PRIMARY_KEYS_QUERY + "testTable")).thenReturn(primaryKeyPreparedStatement);
+        Mockito.when(primaryKeyPreparedStatement.executeQuery()).thenReturn(primaryKeyResultSet);
+
+        PreparedStatement countsPreparedStatement = Mockito.mock(PreparedStatement.class);
+        String GET_PKEY_COUNTS_QUERY = "SELECT pkey, count(*) as COUNTS FROM testTable GROUP BY pkey ORDER BY COUNTS DESC";
+        String[] countsColumns = new String[] {"pkey", SnowflakeMetadataHandler.COUNTS_COLUMN_NAME};
+        Object[][] countsValues = {{"a", 1}};
+        ResultSet countsResultSet = mockResultSet(countsColumns, countsValues, new AtomicInteger(-1));
+        Mockito.when(this.connection.prepareStatement(GET_PKEY_COUNTS_QUERY)).thenReturn(countsPreparedStatement);
+        Mockito.when(countsPreparedStatement.executeQuery()).thenReturn(countsResultSet);
+
         GetTableLayoutResponse getTableLayoutResponse = this.snowflakeMetadataHandler.doGetTableLayout(blockAllocator, getTableLayoutRequest);
         List<String> expectedValues = new ArrayList<>();
         for (int i = 0; i < getTableLayoutResponse.getPartitions().getRowCount(); i++) {
@@ -120,7 +136,7 @@ public void doGetTableLayout()
             if (i > 1) {
                 offset = offset + partitionActualRecordCount;
             }
-            actualValues.add("[partition : partition-limit-" +partitionActualRecordCount + "-offset-" + offset + "]");
+            actualValues.add("[partition : partition-primary-pkey-limit-" +partitionActualRecordCount + "-offset-" + offset + "]");
         }
         Assert.assertEquals((int)limit, getTableLayoutResponse.getPartitions().getRowCount());
         Assert.assertEquals(expectedValues, actualValues);
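
Each query the handler now issues gets the same stubbing treatment in these tests: a mocked PreparedStatement handed back by the connection, which in turn returns a mocked ResultSet. A condensed sketch of that Mockito chain (connection stands in for the test's mocked JDBC connection):

    // Sketch only: one statement/result-set stub per query under test.
    PreparedStatement pkStatement = Mockito.mock(PreparedStatement.class);
    ResultSet pkResult = Mockito.mock(ResultSet.class);
    Mockito.when(connection.prepareStatement("SHOW PRIMARY KEYS IN testTable")).thenReturn(pkStatement);
    Mockito.when(pkStatement.executeQuery()).thenReturn(pkResult);
    Mockito.when(pkResult.next()).thenReturn(true, false); // one key row, then exhausted
    Mockito.when(pkResult.getString("column_name")).thenReturn("pkey");
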
@@ -148,12 +164,28 @@ public void doGetTableLayoutSinglePartition()

         String[] columns = {"partition"};
         int[] types = {Types.VARCHAR};
-        Object[][] values = {{"partition : partition-limit-500000-offset-0"}};
+        Object[][] values = {{"partition : partition-primary-pkey-limit-500000-offset-0"}};
         ResultSet resultSet = mockResultSet(columns, types, values, new AtomicInteger(-1));
         Mockito.when(preparedStatement.executeQuery()).thenReturn(resultSet);

         Mockito.when(this.connection.getMetaData().getSearchStringEscape()).thenReturn(null);
         Mockito.when(resultSet.getLong(1)).thenReturn(1001L);
+
+        PreparedStatement primaryKeyPreparedStatement = Mockito.mock(PreparedStatement.class);
+        String[] primaryKeyColumns = new String[] {SnowflakeMetadataHandler.PRIMARY_KEY_COLUMN_NAME};
+        String[][] primaryKeyValues = new String[][]{new String[] {""}};
+        ResultSet primaryKeyResultSet = mockResultSet(primaryKeyColumns, primaryKeyValues, new AtomicInteger(-1));
+        Mockito.when(this.connection.prepareStatement(SnowflakeMetadataHandler.SHOW_PRIMARY_KEYS_QUERY + "testTable")).thenReturn(primaryKeyPreparedStatement);
+        Mockito.when(primaryKeyPreparedStatement.executeQuery()).thenReturn(primaryKeyResultSet);
+
+        PreparedStatement countsPreparedStatement = Mockito.mock(PreparedStatement.class);
+        String GET_PKEY_COUNTS_QUERY = "SELECT pkey, count(*) as COUNTS FROM testTable GROUP BY pkey ORDER BY COUNTS DESC";
+        String[] countsColumns = new String[] {"pkey", SnowflakeMetadataHandler.COUNTS_COLUMN_NAME};
+        Object[][] countsValues = {{"a", 1}};
+        ResultSet countsResultSet = mockResultSet(countsColumns, countsValues, new AtomicInteger(-1));
+        Mockito.when(this.connection.prepareStatement(GET_PKEY_COUNTS_QUERY)).thenReturn(countsPreparedStatement);
+        Mockito.when(countsPreparedStatement.executeQuery()).thenReturn(countsResultSet);
+
         GetTableLayoutResponse getTableLayoutResponse = this.snowflakeMetadataHandler.doGetTableLayout(blockAllocator, getTableLayoutRequest);

         Assert.assertEquals(values.length, getTableLayoutResponse.getPartitions().getRowCount());
@@ -162,7 +194,7 @@ public void doGetTableLayoutSinglePartition()
         for (int i = 0; i < getTableLayoutResponse.getPartitions().getRowCount(); i++) {
             expectedValues.add(BlockUtils.rowToString(getTableLayoutResponse.getPartitions(), i));
         }
-        Assert.assertEquals(expectedValues, Arrays.asList("[partition : partition-limit-1001-offset-0]"));
+        Assert.assertEquals(expectedValues, Arrays.asList("[partition : partition-primary--limit-1001-offset-0]"));

         SchemaBuilder expectedSchemaBuilder = SchemaBuilder.newBuilder();
         expectedSchemaBuilder.addField(FieldBuilder.newBuilder("partition", org.apache.arrow.vector.types.Types.MinorType.VARCHAR.getType()).build());
@@ -194,11 +226,27 @@ public void doGetTableLayoutMaxPartition()
         long offset = 0;
         String[] columns = {"partition"};
         int[] types = {Types.VARCHAR};
-        Object[][] values = {{"partition : partition-limit-500000-offset-0"}};
+        Object[][] values = {{"partition : partition-primary-pkey-limit-500000-offset-0"}};
         ResultSet resultSet = mockResultSet(columns, types, values, new AtomicInteger(-1));
         Mockito.when(preparedStatement.executeQuery()).thenReturn(resultSet);
         Mockito.when(this.connection.getMetaData().getSearchStringEscape()).thenReturn(null);
         Mockito.when(resultSet.getLong(1)).thenReturn((long)totalActualRecordCount);
+
+        PreparedStatement primaryKeyPreparedStatement = Mockito.mock(PreparedStatement.class);
+        String[] primaryKeyColumns = new String[] {SnowflakeMetadataHandler.PRIMARY_KEY_COLUMN_NAME};
+        String[][] primaryKeyValues = new String[][]{new String[] {"pkey"}};
+        ResultSet primaryKeyResultSet = mockResultSet(primaryKeyColumns, primaryKeyValues, new AtomicInteger(-1));
+        Mockito.when(this.connection.prepareStatement(SnowflakeMetadataHandler.SHOW_PRIMARY_KEYS_QUERY + "testTable")).thenReturn(primaryKeyPreparedStatement);
+        Mockito.when(primaryKeyPreparedStatement.executeQuery()).thenReturn(primaryKeyResultSet);
+
+        PreparedStatement countsPreparedStatement = Mockito.mock(PreparedStatement.class);
+        String GET_PKEY_COUNTS_QUERY = "SELECT pkey, count(*) as COUNTS FROM testTable GROUP BY pkey ORDER BY COUNTS DESC";
+        String[] countsColumns = new String[] {"pkey", SnowflakeMetadataHandler.COUNTS_COLUMN_NAME};
+        Object[][] countsValues = {{"a", 1}};
+        ResultSet countsResultSet = mockResultSet(countsColumns, countsValues, new AtomicInteger(-1));
+        Mockito.when(this.connection.prepareStatement(GET_PKEY_COUNTS_QUERY)).thenReturn(countsPreparedStatement);
+        Mockito.when(countsPreparedStatement.executeQuery()).thenReturn(countsResultSet);
+
         GetTableLayoutResponse getTableLayoutResponse = this.snowflakeMetadataHandler.doGetTableLayout(blockAllocator, getTableLayoutRequest);
         List<String> actualValues = new ArrayList<>();
         List<String> expectedValues = new ArrayList<>();
@@ -209,7 +257,7 @@ public void doGetTableLayoutMaxPartition()
             if (i > 1) {
                 offset = offset + partitionActualRecordCount;
             }
-            actualValues.add("[partition : partition-limit-" +partitionActualRecordCount + "-offset-" + offset + "]");
+            actualValues.add("[partition : partition-primary-pkey-limit-" +partitionActualRecordCount + "-offset-" + offset + "]");
         }
         Assert.assertEquals(expectedValues,actualValues);
         SchemaBuilder expectedSchemaBuilder = SchemaBuilder.newBuilder();
SnowflakeQueryStringBuilderTest.java

@@ -41,7 +41,7 @@ public void testQueryBuilderNew()
         Split split = Mockito.mock(Split.class);
         SnowflakeQueryStringBuilder builder = new SnowflakeQueryStringBuilder(SNOWFLAKE_QUOTE_CHARACTER, new SnowflakeFederationExpressionParser(SNOWFLAKE_QUOTE_CHARACTER));
         Mockito.when(split.getProperties()).thenReturn(Collections.singletonMap("partition", "p0"));
-        Mockito.when(split.getProperty(Mockito.eq("partition"))).thenReturn("p1-p2-p3-p4-p5");
+        Mockito.when(split.getProperty(Mockito.eq("partition"))).thenReturn("p1-p2-p3-p4-p5-p6-p7");
         builder.getFromClauseWithSplit("default", "", "table", split);
         builder.appendLimitOffset(split);
     }
