Skip to content

Commit

Permalink
AE unicode data corruption guard
Browse files Browse the repository at this point in the history
  • Loading branch information
tkyc committed Jan 6, 2024
1 parent c309aae commit 3af81f3
Show file tree
Hide file tree
Showing 8 changed files with 323 additions and 19 deletions.
10 changes: 6 additions & 4 deletions src/main/java/com/microsoft/sqlserver/jdbc/Parameter.java
Original file line number Diff line number Diff line change
Expand Up @@ -415,19 +415,21 @@ void setValue(JDBCType jdbcType, Object value, JavaType javaType, StreamSetterAr
// the value with the appropriate corresponding Unicode type.
// JavaType.OBJECT == javaType when calling setNull()
if (con.sendStringParametersAsUnicode() && (JavaType.STRING == javaType || JavaType.READER == javaType
|| JavaType.CLOB == javaType || JavaType.OBJECT == javaType) && jdbcType != JDBCType.VARCHAR) {
|| JavaType.CLOB == javaType || JavaType.OBJECT == javaType) && !((jdbcType == JDBCType.VARCHAR || jdbcType == JDBCType.CHAR) && con.isColumnEncryptionSettingEnabled())) {
jdbcType = getSSPAUJDBCType(jdbcType);
}

DTV newDTV = new DTV();
newDTV.setValue(con.getDatabaseCollation(), jdbcType, value, javaType, streamSetterArgs, calendar, scale, con,
forceEncrypt);

if (!con.sendStringParametersAsUnicode() || (con.sendStringParametersAsUnicode() && jdbcType == JDBCType.VARCHAR)) {
if (!con.sendStringParametersAsUnicode() || (con.sendStringParametersAsUnicode()
&& con.isColumnEncryptionSettingEnabled() && (jdbcType == JDBCType.VARCHAR || jdbcType == JDBCType.CHAR))) {
newDTV.sendStringParametersAsUnicode = false;
}

if (con.sendStringParametersAsUnicode() && jdbcType == JDBCType.VARCHAR && (!con.getDatabaseCollation().isUtf8Encoding() || con.getServerMajorVersion() < 15)) {
if (con.sendStringParametersAsUnicode() && (jdbcType == JDBCType.VARCHAR || jdbcType == JDBCType.CHAR) && con.isColumnEncryptionSettingEnabled()
&& (!con.getDatabaseCollation().isUtf8Encoding() || con.getServerMajorVersion() < 15)) {
throw new SQLServerException(SQLServerException.getErrString("R_possibleColumnDataCorruption"), null);
}

Expand Down Expand Up @@ -812,7 +814,7 @@ private void setTypeDefinition(DTV dtv) {
} else {
param.typeDefinition = SSType.VARCHAR.toString() + "(" + valueLength + ")";

if (DataTypes.SHORT_VARTYPE_MAX_BYTES <= valueLength) {
if (DataTypes.SHORT_VARTYPE_MAX_BYTES < valueLength) {
param.typeDefinition = VARCHAR_MAX;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import java.io.FileReader;
import java.io.IOException;
import java.math.BigDecimal;
import java.sql.Connection;
import java.sql.Date;
import java.sql.JDBCType;
import java.sql.SQLException;
Expand Down Expand Up @@ -62,7 +63,6 @@ public class AESetup extends AbstractTest {
static String cekWin = Constants.CEK_NAME + "_WIN";
static String cekAkv = Constants.CEK_NAME + "_AKV";
static SQLServerStatementColumnEncryptionSetting stmtColEncSetting = null;

static String AETestConnectionString;
static String enclaveProperties = "";

Expand All @@ -73,6 +73,8 @@ public class AESetup extends AbstractTest {
.escapeSingleQuotes(AbstractSQLGenerator.escapeIdentifier(RandomUtil.getIdentifier("AETest_")));
public static final String CHAR_TABLE_AE = TestUtils
.escapeSingleQuotes(AbstractSQLGenerator.escapeIdentifier(RandomUtil.getIdentifier("JDBCEncryptedChar")));
public static final String CHAR_TABLE_AE_NON_UNICODE = TestUtils
.escapeSingleQuotes(AbstractSQLGenerator.escapeIdentifier(RandomUtil.getIdentifier("JDBCEncryptedCharNonUnicode")));
public static final String BINARY_TABLE_AE = TestUtils
.escapeSingleQuotes(AbstractSQLGenerator.escapeIdentifier(RandomUtil.getIdentifier("JDBCEncryptedBinary")));
public static final String DATE_TABLE_AE = TestUtils
Expand Down Expand Up @@ -107,6 +109,16 @@ enum ColumnType {
{"Varchar8000", "varchar(8000) COLLATE Latin1_General_BIN2", "CHAR"},
{"Nvarchar4000", "nvarchar(4000) COLLATE Latin1_General_BIN2", "NCHAR"},};

static String charTableNonUnicode[][] = {{"Char", "char(20) COLLATE Latin1_General_BIN2", "CHAR"},
{"Varchar", "varchar(50) COLLATE Latin1_General_BIN2", "CHAR"},
{"VarcharMax", "varchar(max) COLLATE Latin1_General_BIN2", "LONGVARCHAR"},
{"Varchar8000", "varchar(8000) COLLATE Latin1_General_BIN2", "CHAR"},};

static String charTableUTF8Collate[][] = {{"Char", "char(20) COLLATE Latin1_General_100_BIN2_UTF8", "CHAR"},
{"Varchar", "varchar(50) COLLATE Latin1_General_100_BIN2_UTF8", "CHAR"},
{"VarcharMax", "varchar(max) COLLATE Latin1_General_100_BIN2_UTF8", "LONGVARCHAR"},
{"Varchar8000", "varchar(8000) COLLATE Latin1_General_100_BIN2_UTF8", "CHAR"},};

static String dateTable[][] = {{"Date", "date", "DATE"}, {"Datetime2Default", "datetime2", "TIMESTAMP"},
{"DatetimeoffsetDefault", "datetimeoffset", "DATETIMEOFFSET"}, {"TimeDefault", "time", "TIME"},
{"Datetime", "datetime", "DATETIME"}, {"Smalldatetime", "smalldatetime", "SMALLDATETIME"}};
Expand Down Expand Up @@ -338,6 +350,25 @@ protected static void createTable(String tableName, String cekName, String table
}
}

protected static void createTable(String tableName, String cekName, String table[][], SQLServerStatement stmt) {
try {
String sql = "";
for (int i = 0; i < table.length; i++) {
sql += ColumnType.PLAIN.name() + table[i][0] + " " + table[i][1] + " NULL,";
sql += ColumnType.DETERMINISTIC.name() + table[i][0] + " " + table[i][1]
+ String.format(encryptSql, ColumnType.DETERMINISTIC.name(), cekName) + ") NULL,";
sql += ColumnType.RANDOMIZED.name() + table[i][0] + " " + table[i][1]
+ String.format(encryptSql, ColumnType.RANDOMIZED.name(), cekName) + ") NULL,";
}
TestUtils.dropTableIfExists(tableName, stmt);
sql = String.format(createSql, tableName, sql);
stmt.execute(sql);
stmt.execute("DBCC FREEPROCCACHE");
} catch (SQLException e) {
fail(e.getMessage());
}
}

protected static void createPrecisionTable(String tableName, String table[][], String cekName, int floatPrecision,
int precision, int scale) throws SQLException {
try (SQLServerConnection con = (SQLServerConnection) PrepUtil.getConnection(AETestConnectionString, AEInfo);
Expand Down Expand Up @@ -400,6 +431,22 @@ protected static void createScaleTable(String tableName, String table[][], Strin
}
}

protected static void createDatabaseWithUtf8Collation(Connection conn, String dbName) throws SQLException {
try (SQLServerStatement stmt = (SQLServerStatement) conn.createStatement()) {
String dropDB = "IF EXISTS (SELECT name FROM sys.databases WHERE name = N'"+ dbName + "') DROP DATABASE " + dbName + ";";
String createDB = "CREATE DATABASE " + dbName + " COLLATE Latin1_General_100_CS_AS_WS_SC_UTF8";
stmt.execute(dropDB);
stmt.execute(createDB);
}
}

protected static void dropDatabaseWithUtf8Collation(Connection conn, String dbName) throws SQLException {
try (SQLServerStatement stmt = (SQLServerStatement) conn.createStatement()) {
String dropDB = "IF EXISTS (SELECT name FROM sys.databases WHERE name = N'"+ dbName + "') DROP DATABASE " + dbName + ";";
stmt.execute(dropDB);
}
}

/**
* Create a list of binary values
*
Expand Down Expand Up @@ -449,6 +496,24 @@ protected static String[] createCharValues(boolean nullable) {
return values;
}

/**
* Create a list of char values for non-unicode data types
*
* @param nullable
*/
protected static String[] createCharValuesNonUnicode(boolean nullable) {

boolean encrypted = true;
String char20 = RandomData.generateCharTypes("20", nullable, encrypted);
String varchar50 = RandomData.generateCharTypes("50", nullable, encrypted);
String varcharmax = RandomData.generateCharTypes("max", nullable, encrypted);
String varchar8000 = RandomData.generateCharTypes("8000", nullable, encrypted);

String[] values = {char20.trim(), varchar50, varcharmax, varchar8000};

return values;
}

/**
* Create a list of numeric values
*
Expand Down Expand Up @@ -805,11 +870,12 @@ protected static void populateBinaryNullCase() throws SQLException {
* @param charValues
* @throws SQLException
*/
protected static void populateCharNormalCase(String[] charValues) throws SQLException {
protected static void populateCharNormalCase(String[] charValues, boolean sendStringParametersAsUnicode) throws SQLException {
String sql = "insert into " + CHAR_TABLE_AE + " values( " + "?,?,?," + "?,?,?," + "?,?,?," + "?,?,?," + "?,?,?,"
+ "?,?,?," + "?,?,?," + "?,?,?," + "?,?,?" + ")";

try (SQLServerConnection con = (SQLServerConnection) PrepUtil.getConnection(AETestConnectionString, AEInfo);
String connectionString = TestUtils.addOrOverrideProperty(AETestConnectionString, "sendStringParametersAsUnicode", Boolean.toString(sendStringParametersAsUnicode));
try (SQLServerConnection con = (SQLServerConnection) PrepUtil.getConnection(connectionString, AEInfo);
SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) TestUtils.getPreparedStmt(con, sql,
stmtColEncSetting)) {

Expand Down Expand Up @@ -866,6 +932,82 @@ protected static void populateCharNormalCase(String[] charValues) throws SQLExce
}
}

/**
* Populate char data non-unicode.
*
* @param charValues
* @throws SQLException
*/
protected static void populateCharNormalCaseNonUnicode(String connectionString, String[] charValues, boolean sendStringParametersAsUnicode) throws SQLException {
String sql = "insert into " + CHAR_TABLE_AE_NON_UNICODE + " values( " + "?,?,?," + "?,?,?," + "?,?,?," + "?,?,?)";

String cs = TestUtils.addOrOverrideProperty(connectionString, "sendStringParametersAsUnicode", Boolean.toString(sendStringParametersAsUnicode));
try (SQLServerConnection con = (SQLServerConnection) PrepUtil.getConnection(cs, AEInfo);
SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) TestUtils.getPreparedStmt(con, sql,
stmtColEncSetting)) {

// char
for (int i = 1; i <= 3; i++) {
pstmt.setString(i, charValues[0]);
}

// varchar
for (int i = 4; i <= 6; i++) {
pstmt.setString(i, charValues[1]);
}

// varchar(max)
for (int i = 7; i <= 9; i++) {
pstmt.setString(i, charValues[2]);
}

// varchar8000
for (int i = 10; i <= 12; i++) {
pstmt.setString(i, charValues[3]);
}

pstmt.execute();
}
}

/**
* Populate char data using set object.
*
* @param charValues
* @throws SQLException
*/
protected static void populateCharSetObjectNonUnicode(String connectionString, String[] charValues, boolean sendStringParametersAsUnicode) throws SQLException {
String sql = "insert into " + CHAR_TABLE_AE_NON_UNICODE + " values( " + "?,?,?," + "?,?,?," + "?,?,?," + "?,?,?)";

String cs = TestUtils.addOrOverrideProperty(connectionString, "sendStringParametersAsUnicode", Boolean.toString(sendStringParametersAsUnicode));
try (SQLServerConnection con = (SQLServerConnection) PrepUtil.getConnection(cs, AEInfo);
SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) TestUtils.getPreparedStmt(con, sql,
stmtColEncSetting)) {

// char
for (int i = 1; i <= 3; i++) {
pstmt.setObject(i, charValues[0]);
}

// varchar
for (int i = 4; i <= 6; i++) {
pstmt.setObject(i, charValues[1]);
}

// varchar(max)
for (int i = 7; i <= 9; i++) {
pstmt.setObject(i, charValues[2], java.sql.Types.LONGVARCHAR);
}

// varchar8000
for (int i = 10; i <= 12; i++) {
pstmt.setObject(i, charValues[3]);
}

pstmt.execute();
}
}

/**
* Populate char data using set object.
*
Expand All @@ -876,7 +1018,8 @@ protected static void populateCharSetObject(String[] charValues) throws SQLExcep
String sql = "insert into " + CHAR_TABLE_AE + " values( " + "?,?,?," + "?,?,?," + "?,?,?," + "?,?,?," + "?,?,?,"
+ "?,?,?," + "?,?,?," + "?,?,?," + "?,?,?" + ")";

try (SQLServerConnection con = (SQLServerConnection) PrepUtil.getConnection(AETestConnectionString, AEInfo);
String connectionString = TestUtils.addOrOverrideProperty(AETestConnectionString, "sendStringParametersAsUnicode", "false");
try (SQLServerConnection con = (SQLServerConnection) PrepUtil.getConnection(connectionString, AEInfo);
SQLServerPreparedStatement pstmt = (SQLServerPreparedStatement) TestUtils.getPreparedStmt(con, sql,
stmtColEncSetting)) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ public static void initValues() throws Exception {
createTable(BINARY_TABLE_AE, cekJks, binaryTable);

createDateTableCallableStatement(cekJks);
populateCharNormalCase(charValues);
populateCharNormalCase(charValues, false);
populateNumericSetObject(numericValues);
populateBinaryNormalCase(byteValues);
populateDateNormalCase();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ public void testAEv2Disabled(String serverName, String url, String protocol) thr
String[] values = createCharValues(false);
TestUtils.dropTableIfExists(CHAR_TABLE_AE, stmt);
createTable(CHAR_TABLE_AE, cekJks, charTable);
populateCharNormalCase(values);
populateCharNormalCase(values, false);
testAlterColumnEncryption(stmt, CHAR_TABLE_AE, charTable, cekJks);
fail(TestResource.getResource("R_expectedExceptionNotThrown"));
} catch (Throwable e) {
Expand Down Expand Up @@ -346,7 +346,7 @@ public void testChar(String serverName, String url, String protocol) throws Exce
SQLServerStatement stmt = (SQLServerStatement) con.createStatement()) {
TestUtils.dropTableIfExists(CHAR_TABLE_AE, stmt);
createTable(CHAR_TABLE_AE, cekJks, charTable);
populateCharNormalCase(createCharValues(false));
populateCharNormalCase(createCharValues(false), false);
testAlterColumnEncryption(stmt, CHAR_TABLE_AE, charTable, cekJks);
}
}
Expand All @@ -363,7 +363,7 @@ public void testCharAkv(String serverName, String url, String protocol) throws E
SQLServerStatement stmt = (SQLServerStatement) con.createStatement()) {
TestUtils.dropTableIfExists(CHAR_TABLE_AE, stmt);
createTable(CHAR_TABLE_AE, cekAkv, charTable);
populateCharNormalCase(createCharValues(false));
populateCharNormalCase(createCharValues(false), false);
testAlterColumnEncryption(stmt, CHAR_TABLE_AE, charTable, cekAkv);
}
}
Expand Down
Loading

0 comments on commit 3af81f3

Please sign in to comment.