From 2a3d3728a61a2b9c27f11f6dc8bf59089d5ca79d Mon Sep 17 00:00:00 2001 From: Mahendra Chavan Date: Wed, 15 Jan 2025 09:54:14 +0530 Subject: [PATCH 1/5] Introduced timeouts for MSAL calls. (#2562) * Introduced timeouts for MSAL calls. * Fixed indentation issues. * Added unit tests * Added a max wait duration of 20 seconds to MSAL calls - Added more tests - Improved test to check for specific error message * Added Timeout Exception catch clause for one of the auth methods * Replaced lock with tryLock. - Replaced lock with tryLock to avoid potential long waiting for other threads while one thread is taking long to complete. * Replaced lock with semaphore for beter readablility. - Added detailed comment for the usage of semaphore. * Renamed semAcquired to isSemAcquired * Fixed indentation for an existing code line * Change to use Mono::timeout method * Updated TOKEN_WAIT_DURATION_MS to correct value. * Improved error messages --- .../sqlserver/jdbc/SQLServerConnection.java | 23 ++-- .../sqlserver/jdbc/SQLServerMSAL4JUtils.java | 126 ++++++++++++++---- .../jdbc/SQLServerSecurityUtility.java | 12 +- .../jdbc/SQLServerConnectionTest.java | 49 +++++++ 4 files changed, 166 insertions(+), 44 deletions(-) diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java index 42ae6647d..5f076f9af 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java @@ -6110,10 +6110,11 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw } while (true) { + int millisecondsRemaining = timerRemaining(timerExpire); if (authenticationString.equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString())) { fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, user, activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()), - authenticationString); + authenticationString, millisecondsRemaining); // Break out of the retry loop in successful case. break; @@ -6125,12 +6126,12 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) { fedAuthToken = SQLServerSecurityUtility.getManagedIdentityCredAuthToken(fedAuthInfo.spn, - managedIdentityClientId); + managedIdentityClientId, millisecondsRemaining); break; } fedAuthToken = SQLServerSecurityUtility.getManagedIdentityCredAuthToken(fedAuthInfo.spn, - activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString())); + activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()), millisecondsRemaining); // Break out of the retry loop in successful case. break; @@ -6141,12 +6142,12 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw if (aadPrincipalID != null && !aadPrincipalID.isEmpty() && aadPrincipalSecret != null && !aadPrincipalSecret.isEmpty()) { fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipal(fedAuthInfo, aadPrincipalID, - aadPrincipalSecret, authenticationString); + aadPrincipalSecret, authenticationString, millisecondsRemaining); } else { fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipal(fedAuthInfo, activeConnectionProperties.getProperty(SQLServerDriverStringProperty.USER.toString()), activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()), - authenticationString); + authenticationString, millisecondsRemaining); } // Break out of the retry loop in successful case. @@ -6159,7 +6160,7 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw activeConnectionProperties.getProperty(SQLServerDriverStringProperty.USER.toString()), servicePrincipalCertificate, activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()), - servicePrincipalCertificateKey, servicePrincipalCertificatePassword, authenticationString); + servicePrincipalCertificateKey, servicePrincipalCertificatePassword, authenticationString, millisecondsRemaining); // Break out of the retry loop in successful case. break; @@ -6194,7 +6195,7 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw throw new SQLServerException(form.format(msgArgs), null); } - int millisecondsRemaining = timerRemaining(timerExpire); + millisecondsRemaining = timerRemaining(timerExpire); if (ActiveDirectoryAuthentication.GET_ACCESS_TOKEN_TRANSIENT_ERROR != errorCategory || timerHasExpired(timerExpire) || (fedauthSleepInterval >= millisecondsRemaining)) { @@ -6240,7 +6241,7 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw Object[] msgArgs = {SQLServerDriver.AUTH_DLL_NAME, authenticationString}; throw new SQLServerException(form.format(msgArgs), null, 0, null); } - fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenIntegrated(fedAuthInfo, authenticationString); + fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenIntegrated(fedAuthInfo, authenticationString, millisecondsRemaining); } // Break out of the retry loop in successful case. break; @@ -6248,7 +6249,7 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw .equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_INTERACTIVE.toString())) { // interactive flow fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenInteractive(fedAuthInfo, user, - authenticationString); + authenticationString, millisecondsRemaining); // Break out of the retry loop in successful case. break; @@ -6258,12 +6259,12 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) { fedAuthToken = SQLServerSecurityUtility.getDefaultAzureCredAuthToken(fedAuthInfo.spn, - managedIdentityClientId); + managedIdentityClientId, millisecondsRemaining); break; } fedAuthToken = SQLServerSecurityUtility.getDefaultAzureCredAuthToken(fedAuthInfo.spn, - activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString())); + activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()), millisecondsRemaining); break; } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java index 689347db3..8850e74fc 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java @@ -25,7 +25,9 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import java.util.logging.Level; @@ -64,7 +66,8 @@ class SQLServerMSAL4JUtils { static final String REDIRECTURI = "http://localhost"; static final String SLASH_DEFAULT = "/.default"; static final String ACCESS_TOKEN_EXPIRE = "access token expires: "; - + static final long TOKEN_WAIT_DURATION_MS = 20000; + static final long TOKEN_SEM_WAIT_DURATION_MS = 5000; private static final TokenCacheMap TOKEN_CACHE_MAP = new TokenCacheMap(); private final static String LOGCONTEXT = "MSAL version " @@ -77,19 +80,28 @@ private SQLServerMSAL4JUtils() { throw new UnsupportedOperationException(SQLServerException.getErrString("R_notSupported")); } - private static final Lock lock = new ReentrantLock(); + private static final Semaphore sem = new Semaphore(1); static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, String user, String password, - String authenticationString) throws SQLServerException { + String authenticationString, int millisecondsRemaining) throws SQLServerException { ExecutorService executorService = Executors.newSingleThreadExecutor(); if (logger.isLoggable(Level.FINEST)) { logger.finest(LOGCONTEXT + authenticationString + ": get FedAuth token for user: " + user); } - lock.lock(); - + boolean isSemAcquired = false; try { + // + //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. + //The purpose is to optimize the token acquisition process, the first caller succeeding does caching + //which is then leveraged by subsequent threads. However, if the first thread takes considerable time, + //then we want the others to also go and try after waiting for a while. + //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints + //to get their tokens at the same time, stressing the auth endpoint. + // + isSemAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + String hashedSecret = getHashedSecret(new String[] {fedAuthInfo.stsurl, user, password}); PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = TOKEN_CACHE_MAP.getEntry(user, hashedSecret); @@ -116,7 +128,7 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str .builder(Collections.singleton(fedAuthInfo.spn + SLASH_DEFAULT), user, password.toCharArray()) .build()); - final IAuthenticationResult authenticationResult = future.get(); + final IAuthenticationResult authenticationResult = future.get(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); if (logger.isLoggable(Level.FINER)) { logger.finer( @@ -132,14 +144,18 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str throw new SQLServerException(e.getMessage(), e); } catch (MalformedURLException | ExecutionException e) { throw getCorrectedException(e, user, authenticationString); + } catch (TimeoutException e) { + throw getCorrectedException(new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e), user, authenticationString); } finally { - lock.unlock(); + if (isSemAcquired) { + sem.release(); + } executorService.shutdown(); } } static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuthInfo, String aadPrincipalID, - String aadPrincipalSecret, String authenticationString) throws SQLServerException { + String aadPrincipalSecret, String authenticationString, int millisecondsRemaining) throws SQLServerException { ExecutorService executorService = Executors.newSingleThreadExecutor(); if (logger.isLoggable(Level.FINEST)) { @@ -151,10 +167,19 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth : fedAuthInfo.spn + defaultScopeSuffix; Set scopes = new HashSet<>(); scopes.add(scope); - - lock.lock(); - + + boolean isSemAcquired = false; try { + // + //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. + //The purpose is to optimize the token acquisition process, the first caller succeeding does caching + //which is then leveraged by subsequent threads. However, if the first thread takes considerable time, + //then we want the others to also go and try after waiting for a while. + //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints + //to get their tokens at the same time, stressing the auth endpoint. + // + isSemAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + String hashedSecret = getHashedSecret( new String[] {fedAuthInfo.stsurl, aadPrincipalID, aadPrincipalSecret}); PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = TOKEN_CACHE_MAP.getEntry(aadPrincipalID, @@ -181,7 +206,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth final CompletableFuture future = clientApplication .acquireToken(ClientCredentialParameters.builder(scopes).build()); - final IAuthenticationResult authenticationResult = future.get(); + final IAuthenticationResult authenticationResult = future.get(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); if (logger.isLoggable(Level.FINER)) { logger.finer( @@ -197,15 +222,19 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth throw new SQLServerException(e.getMessage(), e); } catch (MalformedURLException | ExecutionException e) { throw getCorrectedException(e, aadPrincipalID, authenticationString); + } catch (TimeoutException e) { + throw getCorrectedException(new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e), aadPrincipalID, authenticationString); } finally { - lock.unlock(); + if (isSemAcquired) { + sem.release(); + } executorService.shutdown(); } } static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthInfo fedAuthInfo, String aadPrincipalID, String certFile, String certPassword, String certKey, String certKeyPassword, - String authenticationString) throws SQLServerException { + String authenticationString, int millisecondsRemaining) throws SQLServerException { ExecutorService executorService = Executors.newSingleThreadExecutor(); if (logger.isLoggable(Level.FINEST)) { @@ -219,9 +248,18 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI Set scopes = new HashSet<>(); scopes.add(scope); - lock.lock(); - + boolean isSemAcquired = false; try { + // + //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. + //The purpose is to optimize the token acquisition process, the first caller succeeding does caching + //which is then leveraged by subsequent threads. However, if the first thread takes considerable time, + //then we want the others to also go and try after waiting for a while. + //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints + //to get their tokens at the same time, stressing the auth endpoint. + // + isSemAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + String hashedSecret = getHashedSecret(new String[] {fedAuthInfo.stsurl, aadPrincipalID, certFile, certPassword, certKey, certKeyPassword}); PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = TOKEN_CACHE_MAP.getEntry(aadPrincipalID, @@ -297,7 +335,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI final CompletableFuture future = clientApplication .acquireToken(ClientCredentialParameters.builder(scopes).build()); - final IAuthenticationResult authenticationResult = future.get(); + final IAuthenticationResult authenticationResult = future.get(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); if (logger.isLoggable(Level.FINER)) { logger.finer( @@ -315,17 +353,21 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI // this includes all certificate exceptions throw new SQLServerException(SQLServerException.getErrString("R_readCertError") + e.getMessage(), null, 0, null); + } catch (TimeoutException e) { + throw getCorrectedException(new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e), aadPrincipalID, authenticationString); } catch (Exception e) { throw getCorrectedException(e, aadPrincipalID, authenticationString); } finally { - lock.unlock(); + if (isSemAcquired) { + sem.release(); + } executorService.shutdown(); } } static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAuthInfo, - String authenticationString) throws SQLServerException { + String authenticationString, int millisecondsRemaining) throws SQLServerException { ExecutorService executorService = Executors.newSingleThreadExecutor(); /* @@ -340,9 +382,18 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut + "realm name:" + kerberosPrincipal.getRealm()); } - lock.lock(); - + boolean isSemAcquired = false; try { + // + //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. + //The purpose is to optimize the token acquisition process, the first caller succeeding does caching + //which is then leveraged by subsequent threads. However, if the first thread takes considerable time, + //then we want the others to also go and try after waiting for a while. + //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints + //to get their tokens at the same time, stressing the auth endpoint. + // + isSemAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + final PublicClientApplication pca = PublicClientApplication .builder(ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID).executorService(executorService) .setTokenCacheAccessAspect(PersistentTokenCacheAccessAspect.getInstance()) @@ -352,7 +403,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut .acquireToken(IntegratedWindowsAuthenticationParameters .builder(Collections.singleton(fedAuthInfo.spn + SLASH_DEFAULT), user).build()); - final IAuthenticationResult authenticationResult = future.get(); + final IAuthenticationResult authenticationResult = future.get(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); if (logger.isLoggable(Level.FINER)) { logger.finer( @@ -368,23 +419,36 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut throw new SQLServerException(e.getMessage(), e); } catch (IOException | ExecutionException e) { throw getCorrectedException(e, user, authenticationString); + } catch (TimeoutException e) { + throw getCorrectedException(new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e), user, authenticationString); } finally { - lock.unlock(); + if (isSemAcquired) { + sem.release(); + } executorService.shutdown(); } } static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAuthInfo, String user, - String authenticationString) throws SQLServerException { + String authenticationString, int millisecondsRemaining) throws SQLServerException { ExecutorService executorService = Executors.newSingleThreadExecutor(); if (logger.isLoggable(Level.FINER)) { logger.finer(LOGCONTEXT + authenticationString + ": get FedAuth token interactive for user: " + user); } - lock.lock(); - + boolean isSemAcquired = false; try { + // + //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. + //The purpose is to optimize the token acquisition process, the first caller succeeding does caching + //which is then leveraged by subsequent threads. However, if the first thread takes considerable time, + //then we want the others to also go and try after waiting for a while. + //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints + //to get their tokens at the same time, stressing the auth endpoint. + // + isSemAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + PublicClientApplication pca = PublicClientApplication .builder(ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID).executorService(executorService) .setTokenCacheAccessAspect(PersistentTokenCacheAccessAspect.getInstance()) @@ -432,7 +496,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu } if (null != future) { - authenticationResult = future.get(); + authenticationResult = future.get(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); } else { // acquire token interactively with system browser if (logger.isLoggable(Level.FINEST)) { @@ -444,7 +508,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu .loginHint(user).scopes(Collections.singleton(fedAuthInfo.spn + SLASH_DEFAULT)).build(); future = pca.acquireToken(parameters); - authenticationResult = future.get(); + authenticationResult = future.get(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); } if (logger.isLoggable(Level.FINER)) { @@ -461,8 +525,12 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu throw new SQLServerException(e.getMessage(), e); } catch (MalformedURLException | URISyntaxException | ExecutionException e) { throw getCorrectedException(e, user, authenticationString); + } catch (TimeoutException e) { + throw getCorrectedException(new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e), user, authenticationString); } finally { - lock.unlock(); + if (isSemAcquired) { + sem.release(); + } executorService.shutdown(); } } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java index 70c50ca28..d4e49ccde 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java @@ -8,6 +8,8 @@ import java.security.InvalidKeyException; import java.security.NoSuchAlgorithmException; import java.text.MessageFormat; +import java.time.Duration; +import java.time.temporal.ChronoUnit; import java.util.Arrays; import java.util.HashMap; import java.util.Optional; @@ -56,6 +58,8 @@ class SQLServerSecurityUtility { private static final Lock CREDENTIAL_LOCK = new ReentrantLock(); + private static final int TOKEN_WAIT_DURATION_MS = 20000; + private SQLServerSecurityUtility() { throw new UnsupportedOperationException(SQLServerException.getErrString("R_notSupported")); } @@ -340,7 +344,7 @@ static void verifyColumnMasterKeyMetadata(SQLServerConnection connection, SQLSer * @throws SQLServerException */ static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource, - String managedIdentityClientId) throws SQLServerException { + String managedIdentityClientId, long millisecondsRemaining) throws SQLServerException { if (logger.isLoggable(java.util.logging.Level.FINEST)) { logger.finest("Getting Managed Identity authentication token for: " + managedIdentityClientId); @@ -379,7 +383,7 @@ static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource, SqlAuthenticationToken sqlFedAuthToken = null; - Optional accessTokenOptional = mic.getToken(tokenRequestContext).blockOptional(); + Optional accessTokenOptional = mic.getToken(tokenRequestContext).timeout(Duration.of(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), ChronoUnit.MILLIS)).blockOptional(); if (!accessTokenOptional.isPresent()) { throw new SQLServerException(SQLServerException.getErrString("R_ManagedIdentityTokenAcquisitionFail"), @@ -408,7 +412,7 @@ static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource, * @throws SQLServerException */ static SqlAuthenticationToken getDefaultAzureCredAuthToken(String resource, - String managedIdentityClientId) throws SQLServerException { + String managedIdentityClientId, int millisecondsRemaining) throws SQLServerException { String intellijKeepassPath = System.getenv(INTELLIJ_KEEPASS_PASS); String[] additionallyAllowedTenants = getAdditonallyAllowedTenants(); @@ -463,7 +467,7 @@ static SqlAuthenticationToken getDefaultAzureCredAuthToken(String resource, SqlAuthenticationToken sqlFedAuthToken = null; - Optional accessTokenOptional = dac.getToken(tokenRequestContext).blockOptional(); + Optional accessTokenOptional = dac.getToken(tokenRequestContext).timeout(Duration.of(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), ChronoUnit.MILLIS)).blockOptional(); if (!accessTokenOptional.isPresent()) { throw new SQLServerException(SQLServerException.getErrString("R_ManagedIdentityTokenAcquisitionFail"), diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java index 916aa419f..787c8151e 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java @@ -42,6 +42,7 @@ import com.microsoft.aad.msal4j.TokenCache; import com.microsoft.aad.msal4j.TokenCacheAccessContext; +import com.microsoft.sqlserver.jdbc.SQLServerConnection.SqlFedAuthInfo; import com.microsoft.sqlserver.testframework.AbstractSQLGenerator; import com.microsoft.sqlserver.testframework.AbstractTest; import com.microsoft.sqlserver.testframework.Constants; @@ -50,6 +51,7 @@ @RunWith(JUnitPlatform.class) public class SQLServerConnectionTest extends AbstractTest { + // If no retry is done, the function should at least exit in 5 seconds static int threshHoldForNoRetryInMilliseconds = 5000; static int loginTimeOutInSeconds = 10; @@ -1321,4 +1323,51 @@ public void testServerNameField() throws SQLException { assertTrue(e.getMessage().matches(TestUtils.formatErrorMsg("R_errorServerName"))); } } + + + @Test + public void testGetSqlFedAuthTokenFailure() throws SQLException { + try (Connection conn = getConnection()){ + SqlFedAuthInfo fedAuthInfo = ((SQLServerConnection) conn).new SqlFedAuthInfo(); + fedAuthInfo.spn = "https://database.windows.net/"; + fedAuthInfo.stsurl = "https://login.windows.net/xxx"; + SqlAuthenticationToken fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, "xxx", + "xxx",SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString(), 10); + fail(TestResource.getResource("R_expectedExceptionNotThrown")); + } catch (SQLServerException e) { + //test pass + assertTrue(e.getMessage().contains(SQLServerException.getErrString("R_connectionTimedOut")), "Expected Timeout Exception was not thrown"); + } + } + + @Test + public void testGetSqlFedAuthTokenFailureNoWaiting() throws SQLException { + try (Connection conn = getConnection()){ + SqlFedAuthInfo fedAuthInfo = ((SQLServerConnection) conn).new SqlFedAuthInfo(); + fedAuthInfo.spn = "https://database.windows.net/"; + fedAuthInfo.stsurl = "https://login.windows.net/xxx"; + SqlAuthenticationToken fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, "xxx", + "xxx",SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString(), 0); + fail(TestResource.getResource("R_expectedExceptionNotThrown")); + } catch (SQLServerException e) { + //test pass + assertTrue(e.getMessage().contains(SQLServerException.getErrString("R_connectionTimedOut")), "Expected Timeout Exception was not thrown"); + } + } + + @Test + public void testGetSqlFedAuthTokenFailureNagativeWaiting() throws SQLException { + try (Connection conn = getConnection()){ + SqlFedAuthInfo fedAuthInfo = ((SQLServerConnection) conn).new SqlFedAuthInfo(); + fedAuthInfo.spn = "https://database.windows.net/"; + fedAuthInfo.stsurl = "https://login.windows.net/xxx"; + SqlAuthenticationToken fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, "xxx", + "xxx",SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString(), -1); + fail(TestResource.getResource("R_expectedExceptionNotThrown")); + } catch (SQLServerException e) { + //test pass + assertTrue(e.getMessage().contains(SQLServerException.getErrString("R_connectionTimedOut")), "Expected Timeout Exception was not thrown"); + } + } + } From 5b2cd71bf08bdca906faa8ef37519debeeb1c440 Mon Sep 17 00:00:00 2001 From: Mahendra Chavan Date: Fri, 24 Jan 2025 10:10:00 +0530 Subject: [PATCH 2/5] Incident#567732673 - Add provision to set SQLServerBulkCopy options in PreparedStatement (#2555) * Add provision to set SQLServerBulkCopy options in PreparedStatement * Reanmed test * Renamed bcCopyOptions to bcOptions * Fixed indentation * Fail the test that enables consraint checking when exception is not thrown * Added connection string options for all the available bulk copy options * Added test case for bulkCopyOptionDefaultsBatchSize, bulkCopyOptionDefaultsKeepIdentity and bulkCopyOptionDefaultsTableLock * Added test case for bulkCopty options in Prepared Statement * Updated failures in test case * Added test case testBulkCopyOptionDefaultsTimeoutLowerValue * Added test case for bulkCopyOptionDefaultsAllowEncryptedValueModifications * Fixed failures * Added test scenario for bulkCopyOptionDefaultsAllowEncryptedValueModifications * Improved messages. * Remove the setBulkCopyOptions method from PreparedStatement -Since the change has added connection string options for setting various bulk copy options, this non-standard API is not needed. * Fixed indent issue. * Update name for bulk copy batchSize option * Update name for bulk copy timeout option * Update name for bulk copy check constraints option * Update name for bulk copy fireTrigger, keepIdentity, keepNulls, tableLock, internalTrnsaction, encrptedValue options * Remove timeout option as configured already in bulk copy option * Updated value of bulkCopyForBatchInsertTimeout to 60s * Removed timeout and internal transaction option * removed local changes --------- Co-authored-by: muskan124947 Co-authored-by: Divang Sharma --- .../sqlserver/jdbc/ISQLServerConnection.java | 105 +++ .../sqlserver/jdbc/ISQLServerDataSource.java | 105 +++ .../jdbc/SQLServerBulkCopyOptions.java | 14 + .../sqlserver/jdbc/SQLServerConnection.java | 273 ++++++- .../jdbc/SQLServerConnectionPoolProxy.java | 139 ++++ .../sqlserver/jdbc/SQLServerDataSource.java | 90 +++ .../sqlserver/jdbc/SQLServerDriver.java | 26 +- .../jdbc/SQLServerPreparedStatement.java | 4 +- .../sqlserver/jdbc/SQLServerResource.java | 9 +- .../RequestBoundaryMethodsTest.java | 134 ++-- .../BatchExecutionWithBCOptionsTest.java | 683 ++++++++++++++++++ 11 files changed, 1533 insertions(+), 49 deletions(-) create mode 100644 src/test/java/com/microsoft/sqlserver/jdbc/preparedStatement/BatchExecutionWithBCOptionsTest.java diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerConnection.java b/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerConnection.java index 52f1b67e6..f8959a9b1 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerConnection.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerConnection.java @@ -547,4 +547,109 @@ CallableStatement prepareCall(String sql, int nType, int nConcur, int nHold, * @return cacheBulkCopyMetadata boolean value */ boolean getcacheBulkCopyMetadata(); + + /** + * Specifies the default batch size for bulk copy operations created from batch insert operations. + * + * @param bulkCopyForBatchInsertBatchSize + * integer value for bulkCopyForBatchInsertBatchSize. + */ + void setBulkCopyForBatchInsertBatchSize(int bulkCopyForBatchInsertBatchSize); + + /** + * Returns the default batch size for bulk copy operations created from batch insert operations. + * + * @return integer value for bulkCopyForBatchInsertBatchSize. + */ + int getBulkCopyForBatchInsertBatchSize(); + + /** + * Specifies the default check constraints for bulk copy operations created from batch insert operations. + * + * @param bulkCopyForBatchInsertCheckConstraints + * boolean value for bulkCopyForBatchInsertCheckConstraints. + */ + void setBulkCopyForBatchInsertCheckConstraints(boolean bulkCopyForBatchInsertCheckConstraints); + + /** + * Returns the default check constraints for bulk copy operations created from batch insert operations. + * + * @return boolean value for bulkCopyForBatchInsertCheckConstraints. + */ + boolean getBulkCopyForBatchInsertCheckConstraints(); + + /** + * Specifies the default fire triggers for bulk copy operations created from batch insert operations. + * + * @param bulkCopyForBatchInsertFireTriggers + * boolean value for bulkCopyForBatchInsertFireTriggers. + */ + void setBulkCopyForBatchInsertFireTriggers(boolean bulkCopyForBatchInsertFireTriggers); + + /** + * Returns the default fire triggers for bulk copy operations created from batch insert operations. + * + * @return boolean value for bulkCopyForBatchInsertFireTriggers. + */ + boolean getBulkCopyForBatchInsertFireTriggers(); + + /** + * Specifies the default keep identity for bulk copy operations created from batch insert operations. + * + * @param bulkCopyForBatchInsertKeepIdentity + * boolean value for bulkCopyForBatchInsertKeepIdentity. + */ + void setBulkCopyForBatchInsertKeepIdentity(boolean bulkCopyForBatchInsertKeepIdentity); + + /** + * Returns the default keep identity for bulk copy operations created from batch insert operations. + * + * @return boolean value for bulkCopyForBatchInsertKeepIdentity. + */ + boolean getBulkCopyForBatchInsertKeepIdentity(); + + /** + * Specifies the default keep nulls for bulk copy operations created from batch insert operations. + * + * @param bulkCopyForBatchInsertKeepNulls + * boolean value for bulkCopyForBatchInsertKeepNulls. + */ + void setBulkCopyForBatchInsertKeepNulls(boolean bulkCopyForBatchInsertKeepNulls); + + /** + * Returns the default keep nulls for bulk copy operations created from batch insert operations. + * + * @return boolean value for bulkCopyForBatchInsertKeepNulls. + */ + boolean getBulkCopyForBatchInsertKeepNulls(); + + /** + * Specifies the default table lock for bulk copy operations created from batch insert operations. + * + * @param bulkCopyForBatchInsertTableLock + * boolean value for bulkCopyForBatchInsertTableLock. + */ + void setBulkCopyForBatchInsertTableLock(boolean bulkCopyForBatchInsertTableLock); + + /** + * Returns the default table lock for bulk copy operations created from batch insert operations. + * + * @return boolean value for bulkCopyForBatchInsertTableLock. + */ + boolean getBulkCopyForBatchInsertTableLock(); + + /** + * Specifies the default allow encrypted value modifications for bulk copy operations created from batch insert operations. + * + * @param bulkCopyForBatchInsertAllowEncryptedValueModifications + * boolean value for bulkCopyForBatchInsertAllowEncryptedValueModifications. + */ + void setBulkCopyForBatchInsertAllowEncryptedValueModifications(boolean bulkCopyForBatchInsertAllowEncryptedValueModifications); + + /** + * Returns the default allow encrypted value modifications for bulk copy operations created from batch insert operations. + * + * @return boolean value for bulkCopyForBatchInsertAllowEncryptedValueModifications. + */ + boolean getBulkCopyForBatchInsertAllowEncryptedValueModifications(); } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerDataSource.java b/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerDataSource.java index ec7067220..62fdeda8b 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerDataSource.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/ISQLServerDataSource.java @@ -980,6 +980,111 @@ public interface ISQLServerDataSource extends javax.sql.CommonDataSource { */ void setUseBulkCopyForBatchInsert(boolean useBulkCopyForBatchInsert); + /** + * Sets the default batch size for bulk copy operations created from batch insert operations. + * + * @param bulkCopyForBatchInsertBatchSize + * the default batch size for bulk copy operations created from batch insert operations. + */ + void setBulkCopyForBatchInsertBatchSize(int bulkCopyForBatchInsertBatchSize); + + /** + * Returns the default batch size for bulk copy operations created from batch insert operations. + * + * @return the default batch size for bulk copy operations created from batch insert operations. + */ + int getBulkCopyForBatchInsertBatchSize(); + + /** + * Sets whether to check constraints during bulk copy operations created from batch insert operations. + * + * @param bulkCopyForBatchInsertCheckConstraints + * indicates whether to check constraints during bulk copy operations created from batch insert operations. + */ + void setBulkCopyForBatchInsertCheckConstraints(boolean bulkCopyForBatchInsertCheckConstraints); + + /** + * Returns whether to check constraints during bulk copy operations created from batch insert operations. + * + * @return whether to check constraints during bulk copy operations created from batch insert operations. + */ + boolean getBulkCopyForBatchInsertCheckConstraints(); + + /** + * Sets whether to fire triggers during bulk copy operations created from batch insert operations. + * + * @param bulkCopyForBatchInsertFireTriggers + * indicates whether to fire triggers during bulk copy operations created from batch insert operations. + */ + void setBulkCopyForBatchInsertFireTriggers(boolean bulkCopyForBatchInsertFireTriggers); + + /** + * Returns whether to fire triggers during bulk copy operations created from batch insert operations. + * + * @return whether to fire triggers during bulk copy operations created from batch insert operations. + */ + boolean getBulkCopyForBatchInsertFireTriggers(); + + /** + * Sets whether to keep identity values during bulk copy operations created from batch insert operations. + * + * @param bulkCopyForBatchInsertKeepIdentity + * indicates whether to keep identity values during bulk copy operations created from batch insert operations. + */ + void setBulkCopyForBatchInsertKeepIdentity(boolean bulkCopyForBatchInsertKeepIdentity); + + /** + * Returns whether to keep identity values during bulk copy operations created from batch insert operations. + * + * @return whether to keep identity values during bulk copy operations created from batch insert operations. + */ + boolean getBulkCopyForBatchInsertKeepIdentity(); + + /** + * Sets whether to keep null values during bulk copy operations created from batch insert operations. + * + * @param bulkCopyForBatchInsertKeepNulls + * indicates whether to keep null values during bulk copy operations created from batch insert operations. + */ + void setBulkCopyForBatchInsertKeepNulls(boolean bulkCopyForBatchInsertKeepNulls); + + /** + * Returns whether to keep null values during bulk copy operations created from batch insert operations. + * + * @return whether to keep null values during bulk copy operations created from batch insert operations. + */ + boolean getBulkCopyForBatchInsertKeepNulls(); + + /** + * Sets whether to use table lock during bulk copy operations created from batch insert operations. + * + * @param bulkCopyForBatchInsertTableLock + * indicates whether to use table lock during bulk copy operations created from batch insert operations. + */ + void setBulkCopyForBatchInsertTableLock(boolean bulkCopyForBatchInsertTableLock); + + /** + * Returns whether to use table lock during bulk copy operations created from batch insert operations. + * + * @return whether to use table lock during bulk copy operations created from batch insert operations. + */ + boolean getBulkCopyForBatchInsertTableLock(); + + /** + * Sets whether to allow encrypted value modifications during bulk copy operations created from batch insert operations. + * + * @param bulkCopyForBatchInsertAllowEncryptedValueModifications + * indicates whether to allow encrypted value modifications during bulk copy operations created from batch insert operations. + */ + void setBulkCopyForBatchInsertAllowEncryptedValueModifications(boolean bulkCopyForBatchInsertAllowEncryptedValueModifications); + + /** + * Returns whether to allow encrypted value modifications during bulk copy operations created from batch insert operations. + * + * @return whether to allow encrypted value modifications during bulk copy operations created from batch insert operations. + */ + boolean getBulkCopyForBatchInsertAllowEncryptedValueModifications(); + /** * Sets the client id to be used to retrieve the access token for a user-assigned Managed Identity. * diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCopyOptions.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCopyOptions.java index 24e145a26..7756add0e 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCopyOptions.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCopyOptions.java @@ -103,6 +103,20 @@ public SQLServerBulkCopyOptions() { useInternalTransaction = false; allowEncryptedValueModifications = false; } + + /** + * Constructs a SQLServerBulkCopySettings class using defaults from given connection + */ + SQLServerBulkCopyOptions(SQLServerConnection conn) { + batchSize = conn.getBulkCopyForBatchInsertBatchSize(); + checkConstraints = conn.getBulkCopyForBatchInsertCheckConstraints(); + fireTriggers = conn.getBulkCopyForBatchInsertFireTriggers(); + keepIdentity = conn.getBulkCopyForBatchInsertKeepIdentity(); + keepNulls = conn.getBulkCopyForBatchInsertKeepNulls(); + tableLock = conn.getBulkCopyForBatchInsertTableLock(); + allowEncryptedValueModifications = conn.getBulkCopyForBatchInsertAllowEncryptedValueModifications(); + } + /** * Returns the number of rows in each batch. At the end of each batch, the rows in the batch are sent to the server. diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java index 5f076f9af..be4df3139 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java @@ -833,7 +833,7 @@ final int getSocketTimeoutMilliseconds() { * boolean value for deciding if the driver should use bulk copy API for batch inserts. */ private boolean useBulkCopyForBatchInsert; - + /** * Returns the useBulkCopyForBatchInsert value. * @@ -843,7 +843,7 @@ final int getSocketTimeoutMilliseconds() { public boolean getUseBulkCopyForBatchInsert() { return useBulkCopyForBatchInsert; } - + /** * Specifies the flag for using Bulk Copy API for batch insert operations. * @@ -855,6 +855,175 @@ public void setUseBulkCopyForBatchInsert(boolean useBulkCopyForBatchInsert) { this.useBulkCopyForBatchInsert = useBulkCopyForBatchInsert; } + /** + * The default batch size for bulk copy operations created from batch insert operations. + */ + private int bulkCopyForBatchInsertBatchSize = 0; + + /** + * Returns the bulkCopyForBatchInsertBatchSize value. + * + * @return the bulkCopyForBatchInsertBatchSize value. + */ + public int getBulkCopyForBatchInsertBatchSize() { + return bulkCopyForBatchInsertBatchSize; + } + + /** + * Sets the bulkCopyForBatchInsertBatchSize value. + * + * @param bulkCopyForBatchInsertBatchSize + * the bulkCopyForBatchInsertBatchSize value to set. + */ + public void setBulkCopyForBatchInsertBatchSize(int bulkCopyForBatchInsertBatchSize) { + this.bulkCopyForBatchInsertBatchSize = bulkCopyForBatchInsertBatchSize; + } + + /** + * Whether to check constraints during bulk copy operations. + */ + private boolean bulkCopyForBatchInsertCheckConstraints = false; + + /** + * Returns the bulkCopyForBatchInsertCheckConstraints value. + * + * @return the bulkCopyForBatchInsertCheckConstraints value. + */ + public boolean getBulkCopyForBatchInsertCheckConstraints() { + return bulkCopyForBatchInsertCheckConstraints; + } + + /** + * Sets the bulkCopyForBatchInsertCheckConstraints value. + * + * @param bulkCopyForBatchInsertCheckConstraints + * the bulkCopyForBatchInsertCheckConstraints value to set. + */ + public void setBulkCopyForBatchInsertCheckConstraints(boolean bulkCopyForBatchInsertCheckConstraints) { + this.bulkCopyForBatchInsertCheckConstraints = bulkCopyForBatchInsertCheckConstraints; + } + + /** + * Returns the bulkCopyForBatchInsertFireTriggers value. + * + * @return the bulkCopyForBatchInsertFireTriggers value. + */ + public boolean getBulkCopyForBatchInsertFireTriggers() { + return bulkCopyForBatchInsertFireTriggers; + } + + /** + * Whether to fire triggers during bulk copy operations. + */ + private boolean bulkCopyForBatchInsertFireTriggers = false; + + /** + * Sets the bulkCopyForBatchInsertFireTriggers value. + * + * @param bulkCopyForBatchInsertFireTriggers + * the bulkCopyForBatchInsertFireTriggers value to set. + */ + public void setBulkCopyForBatchInsertFireTriggers(boolean bulkCopyForBatchInsertFireTriggers) { + this.bulkCopyForBatchInsertFireTriggers = bulkCopyForBatchInsertFireTriggers; + } + + /** + * Whether to keep identity values during bulk copy operations. + */ + private boolean bulkCopyForBatchInsertKeepIdentity = false; + + /** + * Returns the bulkCopyForBatchInsertKeepIdentity value. + * + * @return the bulkCopyForBatchInsertKeepIdentity value. + */ + public boolean getBulkCopyForBatchInsertKeepIdentity() { + return bulkCopyForBatchInsertKeepIdentity; + } + + /** + * Sets the bulkCopyForBatchInsertKeepIdentity value. + * + * @param bulkCopyForBatchInsertKeepIdentity + * the bulkCopyForBatchInsertKeepIdentity value to set. + */ + public void setBulkCopyForBatchInsertKeepIdentity(boolean bulkCopyForBatchInsertKeepIdentity) { + this.bulkCopyForBatchInsertKeepIdentity = bulkCopyForBatchInsertKeepIdentity; + } + + /** + * Whether to keep null values during bulk copy operations. + */ + private boolean bulkCopyForBatchInsertKeepNulls = false; + + /** + * Returns the bulkCopyForBatchInsertKeepNulls value. + * + * @return the bulkCopyForBatchInsertKeepNulls value. + */ + public boolean getBulkCopyForBatchInsertKeepNulls() { + return bulkCopyForBatchInsertKeepNulls; + } + + /** + * Sets the bulkCopyForBatchInsertKeepNulls value. + * + * @param bulkCopyForBatchInsertKeepNulls + * the bulkCopyForBatchInsertKeepNulls value to set. + */ + public void setBulkCopyForBatchInsertKeepNulls(boolean bulkCopyForBatchInsertKeepNulls) { + this.bulkCopyForBatchInsertKeepNulls = bulkCopyForBatchInsertKeepNulls; + } + + /** + * Whether to use table lock during bulk copy operations. + */ + private boolean bulkCopyForBatchInsertTableLock = false; + + /** + * Returns the bulkCopyForBatchInsertTableLock value. + * + * @return the bulkCopyForBatchInsertTableLock value. + */ + public boolean getBulkCopyForBatchInsertTableLock() { + return bulkCopyForBatchInsertTableLock; + } + + /** + * Sets the bulkCopyForBatchInsertTableLock value. + * + * @param bulkCopyForBatchInsertTableLock + * the bulkCopyForBatchInsertTableLock value to set. + */ + public void setBulkCopyForBatchInsertTableLock(boolean bulkCopyForBatchInsertTableLock) { + this.bulkCopyForBatchInsertTableLock = bulkCopyForBatchInsertTableLock; + } + + /** + * Whether to allow encrypted value modifications during bulk copy operations. + */ + private boolean bulkCopyForBatchInsertAllowEncryptedValueModifications = false; + + + /** + * Returns the bulkCopyForBatchInsertAllowEncryptedValueModifications value. + * + * @return the bulkCopyForBatchInsertAllowEncryptedValueModifications value. + */ + public boolean getBulkCopyForBatchInsertAllowEncryptedValueModifications() { + return bulkCopyForBatchInsertAllowEncryptedValueModifications; + } + + /** + * Sets the bulkCopyForBatchInsertAllowEncryptedValueModifications value. + * + * @param bulkCopyForBatchInsertAllowEncryptedValueModifications + * the bulkCopyForBatchInsertAllowEncryptedValueModifications value to set. + */ + public void setBulkCopyForBatchInsertAllowEncryptedValueModifications(boolean bulkCopyForBatchInsertAllowEncryptedValueModifications) { + this.bulkCopyForBatchInsertAllowEncryptedValueModifications = bulkCopyForBatchInsertAllowEncryptedValueModifications; + } + /** user set TNIR flag */ boolean userSetTNIR = true; @@ -3118,6 +3287,48 @@ else if (0 == requestedPacketSize) useBulkCopyForBatchInsert = isBooleanPropertyOn(sPropKey, sPropValue); } + sPropKey = SQLServerDriverIntProperty.BULK_COPY_FOR_BATCH_INSERT_BATCH_SIZE.toString(); + sPropValue = activeConnectionProperties.getProperty(sPropKey); + if (null != sPropValue) { + bulkCopyForBatchInsertBatchSize = Integer.parseInt(sPropValue); + } + + sPropKey = SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_CHECK_CONSTRAINTS.toString(); + sPropValue = activeConnectionProperties.getProperty(sPropKey); + if (null != sPropValue) { + bulkCopyForBatchInsertCheckConstraints = isBooleanPropertyOn(sPropKey, sPropValue); + } + + sPropKey = SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_FIRE_TRIGGERS.toString(); + sPropValue = activeConnectionProperties.getProperty(sPropKey); + if (null != sPropValue) { + bulkCopyForBatchInsertFireTriggers = isBooleanPropertyOn(sPropKey, sPropValue); + } + + sPropKey = SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_KEEP_IDENTITY.toString(); + sPropValue = activeConnectionProperties.getProperty(sPropKey); + if (null != sPropValue) { + bulkCopyForBatchInsertKeepIdentity = isBooleanPropertyOn(sPropKey, sPropValue); + } + + sPropKey = SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_KEEP_NULLS.toString(); + sPropValue = activeConnectionProperties.getProperty(sPropKey); + if (null != sPropValue) { + bulkCopyForBatchInsertKeepNulls = isBooleanPropertyOn(sPropKey, sPropValue); + } + + sPropKey = SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_TABLE_LOCK.toString(); + sPropValue = activeConnectionProperties.getProperty(sPropKey); + if (null != sPropValue) { + bulkCopyForBatchInsertTableLock = isBooleanPropertyOn(sPropKey, sPropValue); + } + + sPropKey = SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_ALLOW_ENCRYPTED_VALUE_MODIFICATIONS.toString(); + sPropValue = activeConnectionProperties.getProperty(sPropKey); + if (null != sPropValue) { + bulkCopyForBatchInsertAllowEncryptedValueModifications = isBooleanPropertyOn(sPropKey, sPropValue); + } + sPropKey = SQLServerDriverBooleanProperty.ENABLE_BULK_COPY_CACHE.toString(); sPropValue = activeConnectionProperties.getProperty(sPropKey); if (null != sPropValue) { @@ -7630,6 +7841,27 @@ public T unwrap(Class iface) throws SQLException { /** original useBulkCopyForBatchInsert flag */ private boolean originalUseBulkCopyForBatchInsert; + /** original bulkCopyForBatchInsertBatchSize */ + private int originalBulkCopyForBatchInsertBatchSize; + + /** original bulkCopyForBatchInsertCheckConstraints flag */ + private boolean originalBulkCopyForBatchInsertCheckConstraints; + + /** original bulkCopyForBatchInsertFireTriggers flag */ + private boolean originalBulkCopyForBatchInsertFireTriggers; + + /** original bulkCopyForBatchInsertKeepIdentity flag */ + private boolean originalBulkCopyForBatchInsertKeepIdentity; + + /** original bulkCopyForBatchInsertKeepNulls flag */ + private boolean originalBulkCopyForBatchInsertKeepNulls; + + /** original bulkCopyForBatchInsertTableLock flag */ + private boolean originalBulkCopyForBatchInsertTableLock; + + /** original bulkCopyForBatchInsertAllowEncryptedValueModifications flag */ + private boolean originalBulkCopyForBatchInsertAllowEncryptedValueModifications; + /** original SqlWarnings */ private volatile SQLWarning originalSqlWarnings; @@ -7665,6 +7897,13 @@ void beginRequestInternal() throws SQLException { originalEnablePrepareOnFirstPreparedStatementCall = getEnablePrepareOnFirstPreparedStatementCall(); originalSCatalog = sCatalog; originalUseBulkCopyForBatchInsert = getUseBulkCopyForBatchInsert(); + originalBulkCopyForBatchInsertBatchSize = getBulkCopyForBatchInsertBatchSize(); + originalBulkCopyForBatchInsertCheckConstraints = getBulkCopyForBatchInsertCheckConstraints(); + originalBulkCopyForBatchInsertFireTriggers = getBulkCopyForBatchInsertFireTriggers(); + originalBulkCopyForBatchInsertKeepIdentity = getBulkCopyForBatchInsertKeepIdentity(); + originalBulkCopyForBatchInsertKeepNulls = getBulkCopyForBatchInsertKeepNulls(); + originalBulkCopyForBatchInsertTableLock = getBulkCopyForBatchInsertTableLock(); + originalBulkCopyForBatchInsertAllowEncryptedValueModifications = getBulkCopyForBatchInsertAllowEncryptedValueModifications(); originalSqlWarnings = sqlWarnings; openStatements = new LinkedList<>(); originalUseFmtOnly = useFmtOnly; @@ -7722,9 +7961,39 @@ void endRequestInternal() throws SQLException { if (!sCatalog.equals(originalSCatalog)) { setCatalog(originalSCatalog); } + if (getUseBulkCopyForBatchInsert() != originalUseBulkCopyForBatchInsert) { setUseBulkCopyForBatchInsert(originalUseBulkCopyForBatchInsert); } + + if (getBulkCopyForBatchInsertBatchSize() != originalBulkCopyForBatchInsertBatchSize) { + setBulkCopyForBatchInsertBatchSize(originalBulkCopyForBatchInsertBatchSize); + } + + if (getBulkCopyForBatchInsertCheckConstraints() != originalBulkCopyForBatchInsertCheckConstraints) { + setBulkCopyForBatchInsertCheckConstraints(originalBulkCopyForBatchInsertCheckConstraints); + } + + if (getBulkCopyForBatchInsertFireTriggers() != originalBulkCopyForBatchInsertFireTriggers) { + setBulkCopyForBatchInsertFireTriggers(originalBulkCopyForBatchInsertFireTriggers); + } + + if (getBulkCopyForBatchInsertKeepIdentity() != originalBulkCopyForBatchInsertKeepIdentity) { + setBulkCopyForBatchInsertKeepIdentity(originalBulkCopyForBatchInsertKeepIdentity); + } + + if (getBulkCopyForBatchInsertKeepNulls() != originalBulkCopyForBatchInsertKeepNulls) { + setBulkCopyForBatchInsertKeepNulls(originalBulkCopyForBatchInsertKeepNulls); + } + + if (getBulkCopyForBatchInsertTableLock() != originalBulkCopyForBatchInsertTableLock) { + setBulkCopyForBatchInsertTableLock(originalBulkCopyForBatchInsertTableLock); + } + + if (getBulkCopyForBatchInsertAllowEncryptedValueModifications() != originalBulkCopyForBatchInsertAllowEncryptedValueModifications) { + setBulkCopyForBatchInsertAllowEncryptedValueModifications(originalBulkCopyForBatchInsertAllowEncryptedValueModifications); + } + if (delayLoadingLobs != originalDelayLoadingLobs) { setDelayLoadingLobs(originalDelayLoadingLobs); } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionPoolProxy.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionPoolProxy.java index d9eb047ea..9f67a8321 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionPoolProxy.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionPoolProxy.java @@ -753,4 +753,143 @@ public boolean getUseBulkCopyForBatchInsert() { public void setUseBulkCopyForBatchInsert(boolean useBulkCopyForBatchInsert) { wrappedConnection.setUseBulkCopyForBatchInsert(useBulkCopyForBatchInsert); } + + /** + * The default batch size for bulk copy operations created from batch insert operations. + */ + private int bulkCopyForBatchInsertBatchSize = 0; + + /** + * Returns the bulkCopyForBatchInsertBatchSize value. + * + * @return the bulkCopyForBatchInsertBatchSize value. + */ + public int getBulkCopyForBatchInsertBatchSize() { + return wrappedConnection.getBulkCopyForBatchInsertBatchSize(); + } + + /** + * Sets the bulkCopyForBatchInsertBatchSize value. + * + * @param bulkCopyForBatchInsertBatchSize + * the bulkCopyForBatchInsertBatchSize value to set. + */ + public void setBulkCopyForBatchInsertBatchSize(int bulkCopyForBatchInsertBatchSize) { + wrappedConnection.setBulkCopyForBatchInsertBatchSize(bulkCopyForBatchInsertBatchSize); + } + + /** + * Returns the bulkCopyForBatchInsertCheckConstraints value. + * + * @return the bulkCopyForBatchInsertCheckConstraints value. + */ + public boolean getBulkCopyForBatchInsertCheckConstraints() { + return wrappedConnection.getBulkCopyForBatchInsertCheckConstraints(); + } + + /** + * Sets the bulkCopyForBatchInsertCheckConstraints value. + * + * @param bulkCopyForBatchInsertCheckConstraints + * the bulkCopyForBatchInsertCheckConstraints value to set. + */ + public void setBulkCopyForBatchInsertCheckConstraints(boolean bulkCopyForBatchInsertCheckConstraints) { + wrappedConnection.setBulkCopyForBatchInsertCheckConstraints(bulkCopyForBatchInsertCheckConstraints); + } + + /** + * Returns the bulkCopyForBatchInsertFireTriggers value. + * + * @return the bulkCopyForBatchInsertFireTriggers value. + */ + public boolean getBulkCopyForBatchInsertFireTriggers() { + return wrappedConnection.getBulkCopyForBatchInsertFireTriggers(); + } + + /** + * Sets the bulkCopyForBatchInsertFireTriggers value. + * + * @param bulkCopyForBatchInsertFireTriggers + * the bulkCopyForBatchInsertFireTriggers value to set. + */ + public void setBulkCopyForBatchInsertFireTriggers(boolean bulkCopyForBatchInsertFireTriggers) { + wrappedConnection.setBulkCopyForBatchInsertFireTriggers(bulkCopyForBatchInsertFireTriggers); + } + + /** + * Returns the bulkCopyForBatchInsertKeepIdentity value. + * + * @return the bulkCopyForBatchInsertKeepIdentity value. + */ + public boolean getBulkCopyForBatchInsertKeepIdentity() { + return wrappedConnection.getBulkCopyForBatchInsertKeepIdentity(); + } + + /** + * Sets the bulkCopyForBatchInsertKeepIdentity value. + * + * @param bulkCopyForBatchInsertKeepIdentity + * the bulkCopyForBatchInsertKeepIdentity value to set. + */ + public void setBulkCopyForBatchInsertKeepIdentity(boolean bulkCopyForBatchInsertKeepIdentity) { + wrappedConnection.setBulkCopyForBatchInsertKeepIdentity(bulkCopyForBatchInsertKeepIdentity); + } + + /** + * Returns the bulkCopyForBatchInsertKeepNulls value. + * + * @return the bulkCopyForBatchInsertKeepNulls value. + */ + public boolean getBulkCopyForBatchInsertKeepNulls() { + return wrappedConnection.getBulkCopyForBatchInsertKeepNulls(); + } + + /** + * Sets the bulkCopyForBatchInsertKeepNulls value. + * + * @param bulkCopyForBatchInsertKeepNulls + * the bulkCopyForBatchInsertKeepNulls value to set. + */ + public void setBulkCopyForBatchInsertKeepNulls(boolean bulkCopyForBatchInsertKeepNulls) { + wrappedConnection.setBulkCopyForBatchInsertKeepNulls(bulkCopyForBatchInsertKeepNulls); + } + + /** + * Returns the bulkCopyForBatchInsertTableLock value. + * + * @return the bulkCopyForBatchInsertTableLock value. + */ + public boolean getBulkCopyForBatchInsertTableLock() { + return wrappedConnection.getBulkCopyForBatchInsertTableLock(); + } + + /** + * Sets the bulkCopyForBatchInsertTableLock value. + * + * @param bulkCopyForBatchInsertTableLock + * the bulkCopyForBatchInsertTableLock value to set. + */ + public void setBulkCopyForBatchInsertTableLock(boolean bulkCopyForBatchInsertTableLock) { + wrappedConnection.setBulkCopyForBatchInsertTableLock(bulkCopyForBatchInsertTableLock); + } + + /** + * Returns the bulkCopyForBatchInsertAllowEncryptedValueModifications value. + * + * @return the bulkCopyForBatchInsertAllowEncryptedValueModifications value. + */ + public boolean getBulkCopyForBatchInsertAllowEncryptedValueModifications() { + return wrappedConnection.getBulkCopyForBatchInsertAllowEncryptedValueModifications(); + } + + /** + * Sets the bulkCopyForBatchInsertAllowEncryptedValueModifications value. + * + * @param bulkCopyForBatchInsertAllowEncryptedValueModifications + * the bulkCopyForBatchInsertAllowEncryptedValueModifications value to set. + */ + public void setBulkCopyForBatchInsertAllowEncryptedValueModifications(boolean bulkCopyForBatchInsertAllowEncryptedValueModifications) { + wrappedConnection.setBulkCopyForBatchInsertAllowEncryptedValueModifications(bulkCopyForBatchInsertAllowEncryptedValueModifications); + } + } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java index 480f36ba3..1f3c7993b 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java @@ -1025,6 +1025,96 @@ public boolean getUseBulkCopyForBatchInsert() { SQLServerDriverBooleanProperty.USE_BULK_COPY_FOR_BATCH_INSERT.getDefaultValue()); } + @Override + public void setBulkCopyForBatchInsertBatchSize(int bulkCopyForBatchInsertBatchSize) { + setIntProperty(connectionProps, SQLServerDriverIntProperty.BULK_COPY_FOR_BATCH_INSERT_BATCH_SIZE.toString(), + bulkCopyForBatchInsertBatchSize); + } + + @Override + public int getBulkCopyForBatchInsertBatchSize() { + return getIntProperty(connectionProps, + SQLServerDriverIntProperty.BULK_COPY_FOR_BATCH_INSERT_BATCH_SIZE.toString(), + SQLServerDriverIntProperty.BULK_COPY_FOR_BATCH_INSERT_BATCH_SIZE.getDefaultValue()); + } + + @Override + public void setBulkCopyForBatchInsertCheckConstraints(boolean bulkCopyForBatchInsertCheckConstraints) { + setBooleanProperty(connectionProps, SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_CHECK_CONSTRAINTS.toString(), + bulkCopyForBatchInsertCheckConstraints); + } + + @Override + public boolean getBulkCopyForBatchInsertCheckConstraints() { + return getBooleanProperty(connectionProps, + SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_CHECK_CONSTRAINTS.toString(), + SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_CHECK_CONSTRAINTS.getDefaultValue()); + } + + @Override + public void setBulkCopyForBatchInsertFireTriggers(boolean bulkCopyForBatchInsertFireTriggers) { + setBooleanProperty(connectionProps, SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_FIRE_TRIGGERS.toString(), + bulkCopyForBatchInsertFireTriggers); + } + + @Override + public boolean getBulkCopyForBatchInsertFireTriggers() { + return getBooleanProperty(connectionProps, + SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_FIRE_TRIGGERS.toString(), + SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_FIRE_TRIGGERS.getDefaultValue()); + } + + @Override + public void setBulkCopyForBatchInsertKeepIdentity(boolean bulkCopyForBatchInsertKeepIdentity) { + setBooleanProperty(connectionProps, SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_KEEP_IDENTITY.toString(), + bulkCopyForBatchInsertKeepIdentity); + } + + @Override + public boolean getBulkCopyForBatchInsertKeepIdentity() { + return getBooleanProperty(connectionProps, + SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_KEEP_IDENTITY.toString(), + SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_KEEP_IDENTITY.getDefaultValue()); + } + + @Override + public void setBulkCopyForBatchInsertKeepNulls(boolean bulkCopyForBatchInsertKeepNulls) { + setBooleanProperty(connectionProps, SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_KEEP_NULLS.toString(), + bulkCopyForBatchInsertKeepNulls); + } + + @Override + public boolean getBulkCopyForBatchInsertKeepNulls() { + return getBooleanProperty(connectionProps, + SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_KEEP_NULLS.toString(), + SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_KEEP_NULLS.getDefaultValue()); + } + + @Override + public void setBulkCopyForBatchInsertTableLock(boolean bulkCopyForBatchInsertTableLock) { + setBooleanProperty(connectionProps, SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_TABLE_LOCK.toString(), + bulkCopyForBatchInsertTableLock); + } + + @Override + public boolean getBulkCopyForBatchInsertTableLock() { + return getBooleanProperty(connectionProps, + SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_TABLE_LOCK.toString(), + SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_TABLE_LOCK.getDefaultValue()); + } + + @Override + public void setBulkCopyForBatchInsertAllowEncryptedValueModifications(boolean bulkCopyForBatchInsertAllowEncryptedValueModifications) { + setBooleanProperty(connectionProps, SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_ALLOW_ENCRYPTED_VALUE_MODIFICATIONS.toString(), + bulkCopyForBatchInsertAllowEncryptedValueModifications); + } + + @Override + public boolean getBulkCopyForBatchInsertAllowEncryptedValueModifications() { + return getBooleanProperty(connectionProps, + SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_ALLOW_ENCRYPTED_VALUE_MODIFICATIONS.toString(), + SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_ALLOW_ENCRYPTED_VALUE_MODIFICATIONS.getDefaultValue()); + } /** * @deprecated */ diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java index 1ffffa6f0..17264ce59 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java @@ -644,7 +644,8 @@ enum SQLServerDriverIntProperty { STATEMENT_POOLING_CACHE_SIZE("statementPoolingCacheSize", SQLServerConnection.DEFAULT_STATEMENT_POOLING_CACHE_SIZE), CANCEL_QUERY_TIMEOUT("cancelQueryTimeout", -1), CONNECT_RETRY_COUNT("connectRetryCount", 1, 0, 255), - CONNECT_RETRY_INTERVAL("connectRetryInterval", 10, 1, 60); + CONNECT_RETRY_INTERVAL("connectRetryInterval", 10, 1, 60), + BULK_COPY_FOR_BATCH_INSERT_BATCH_SIZE("bulkCopyForBatchInsertBatchSize", 0); private final String name; private final int defaultValue; @@ -694,6 +695,12 @@ enum SQLServerDriverBooleanProperty { ENABLE_PREPARE_ON_FIRST_PREPARED_STATEMENT("enablePrepareOnFirstPreparedStatementCall", SQLServerConnection.DEFAULT_ENABLE_PREPARE_ON_FIRST_PREPARED_STATEMENT_CALL), ENABLE_BULK_COPY_CACHE("cacheBulkCopyMetadata", false), USE_BULK_COPY_FOR_BATCH_INSERT("useBulkCopyForBatchInsert", false), + BULK_COPY_FOR_BATCH_INSERT_CHECK_CONSTRAINTS("bulkCopyForBatchInsertCheckConstraints", false), + BULK_COPY_FOR_BATCH_INSERT_FIRE_TRIGGERS("bulkCopyForBatchInsertFireTriggers", false), + BULK_COPY_FOR_BATCH_INSERT_KEEP_IDENTITY("bulkCopyForBatchInsertKeepIdentity", false), + BULK_COPY_FOR_BATCH_INSERT_KEEP_NULLS("bulkCopyForBatchInsertKeepNulls", false), + BULK_COPY_FOR_BATCH_INSERT_TABLE_LOCK("bulkCopyForBatchInsertTableLock", false), + BULK_COPY_FOR_BATCH_INSERT_ALLOW_ENCRYPTED_VALUE_MODIFICATIONS("bulkCopyForBatchInsertAllowEncryptedValueModifications", false), USE_FMT_ONLY("useFmtOnly", false), SEND_TEMPORAL_DATATYPES_AS_STRING_FOR_BULK_COPY("sendTemporalDataTypesAsStringForBulkCopy", true), DELAY_LOADING_LOBS("delayLoadingLobs", true), @@ -946,9 +953,22 @@ static String getAppName() { new SQLServerDriverPropertyInfo(SQLServerDriverBooleanProperty.USE_BULK_COPY_FOR_BATCH_INSERT.toString(), Boolean.toString(SQLServerDriverBooleanProperty.USE_BULK_COPY_FOR_BATCH_INSERT.getDefaultValue()), false, TRUE_FALSE), + new SQLServerDriverPropertyInfo(SQLServerDriverIntProperty.BULK_COPY_FOR_BATCH_INSERT_BATCH_SIZE.toString(), + Integer.toString(SQLServerDriverIntProperty.BULK_COPY_FOR_BATCH_INSERT_BATCH_SIZE.getDefaultValue()),false, null), + new SQLServerDriverPropertyInfo(SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_CHECK_CONSTRAINTS.toString(), + Boolean.toString(SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_CHECK_CONSTRAINTS.getDefaultValue()),false, TRUE_FALSE), + new SQLServerDriverPropertyInfo(SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_FIRE_TRIGGERS.toString(), + Boolean.toString(SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_FIRE_TRIGGERS.getDefaultValue()),false, TRUE_FALSE), + new SQLServerDriverPropertyInfo(SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_KEEP_IDENTITY.toString(), + Boolean.toString(SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_KEEP_IDENTITY.getDefaultValue()),false, TRUE_FALSE), + new SQLServerDriverPropertyInfo(SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_KEEP_NULLS.toString(), + Boolean.toString(SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_KEEP_NULLS.getDefaultValue()),false, TRUE_FALSE), + new SQLServerDriverPropertyInfo(SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_TABLE_LOCK.toString(), + Boolean.toString(SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_TABLE_LOCK.getDefaultValue()),false, TRUE_FALSE), + new SQLServerDriverPropertyInfo(SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_ALLOW_ENCRYPTED_VALUE_MODIFICATIONS.toString(), + Boolean.toString(SQLServerDriverBooleanProperty.BULK_COPY_FOR_BATCH_INSERT_ALLOW_ENCRYPTED_VALUE_MODIFICATIONS.getDefaultValue()),false, TRUE_FALSE), new SQLServerDriverPropertyInfo(SQLServerDriverBooleanProperty.ENABLE_BULK_COPY_CACHE.toString(), - Boolean.toString(SQLServerDriverBooleanProperty.ENABLE_BULK_COPY_CACHE.getDefaultValue()), - false, TRUE_FALSE), + Boolean.toString(SQLServerDriverBooleanProperty.ENABLE_BULK_COPY_CACHE.getDefaultValue()),false, TRUE_FALSE), new SQLServerDriverPropertyInfo(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString(), SQLServerDriverStringProperty.MSI_CLIENT_ID.getDefaultValue(), false, null), new SQLServerDriverPropertyInfo(SQLServerDriverStringProperty.KEY_VAULT_PROVIDER_CLIENT_ID.toString(), diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java index 244cfb696..96ce9858f 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java @@ -2224,7 +2224,7 @@ public int[] executeBatch() throws SQLServerException, BatchUpdateException, SQL if (null == bcOperation) { bcOperation = new SQLServerBulkCopy(connection); - SQLServerBulkCopyOptions option = new SQLServerBulkCopyOptions(); + SQLServerBulkCopyOptions option = new SQLServerBulkCopyOptions(connection); option.setBulkCopyTimeout(queryTimeout); bcOperation.setBulkCopyOptions(option); bcOperation.setDestinationTableName(bcOperationTableName); @@ -2405,7 +2405,7 @@ public long[] executeLargeBatch() throws SQLServerException, BatchUpdateExceptio if (null == bcOperation) { bcOperation = new SQLServerBulkCopy(connection); - SQLServerBulkCopyOptions option = new SQLServerBulkCopyOptions(); + SQLServerBulkCopyOptions option = new SQLServerBulkCopyOptions(connection); option.setBulkCopyTimeout(queryTimeout); bcOperation.setBulkCopyOptions(option); bcOperation.setDestinationTableName(bcOperationTableName); diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java index c9d875e58..9d09e1edd 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerResource.java @@ -465,7 +465,14 @@ protected Object[][] getContents() { {"R_invalidSSLProtocol", "SSL Protocol {0} label is not valid. Only TLS, TLSv1, TLSv1.1, and TLSv1.2 are supported."}, {"R_cancelQueryTimeoutPropertyDescription", "The number of seconds to wait to cancel sending a query timeout."}, {"R_invalidCancelQueryTimeout", "The cancel timeout value {0} is not valid."}, - {"R_useBulkCopyForBatchInsertPropertyDescription", "Whether the driver will use bulk copy API for batch insert operations"}, + {"R_useBulkCopyForBatchInsertPropertyDescription", "Determines whether the driver will use bulk copy API for batch insert operations."}, + {"R_bulkCopyForBatchInsertBatchSizePropertyDescription", "The default batch size for bulk copy operations created from batch insert operations."}, + {"R_bulkCopyForBatchInsertCheckConstraintsPropertyDescription", "Determines whether to check constraints during bulk copy operations created from batch insert operations."}, + {"R_bulkCopyForBatchInsertFireTriggersPropertyDescription", "Determines whether to fire triggers during bulk copy operations created from batch insert operations."}, + {"R_bulkCopyForBatchInsertKeepIdentityPropertyDescription", "Determines whether to keep identity values during bulk copy operations created from batch insert operations."}, + {"R_bulkCopyForBatchInsertKeepNullsPropertyDescription", "Determines whether to keep null values during bulk copy operations created from batch insert operations."}, + {"R_bulkCopyForBatchInsertTableLockPropertyDescription", "Determines whether to use table lock during bulk copy operations created from batch insert operations."}, + {"R_bulkCopyForBatchInsertAllowEncryptedValueModificationsPropertyDescription", "Determines whether to allow encrypted value modifications during bulk copy operations created from batch insert operations."}, {"R_UnknownDataClsTokenNumber", "Unknown token for Data Classification."}, // From Server {"R_InvalidDataClsVersionNumber", "Invalid version number {0} for Data Classification."}, // From Server {"R_unknownUTF8SupportValue", "Unknown value for UTF8 support."}, diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/connection/RequestBoundaryMethodsTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/connection/RequestBoundaryMethodsTest.java index 402cccc8c..89df6595a 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/connection/RequestBoundaryMethodsTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/connection/RequestBoundaryMethodsTest.java @@ -73,6 +73,13 @@ public void testModifiableConnectionProperties() throws SQLException { boolean enablePrepareOnFirstPreparedStatementCall1 = false; String sCatalog1 = "master"; boolean useBulkCopyForBatchInsert1 = true; + int bulkCopyForBatchInsertBatchSize1 = 1000; + boolean bulkCopyForBatchInsertCheckConstraints1 = true; + boolean bulkCopyForBatchInsertFireTriggers1 = true; + boolean bulkCopyForBatchInsertKeepIdentity1 = true; + boolean bulkCopyForBatchInsertKeepNulls1 = true; + boolean bulkCopyForBatchInsertTableLock1 = true; + boolean bulkCopyForBatchInsertAllowEncryptedValueModifications1 = true; boolean useFmtOnly1 = true; boolean delayLoadingLobs1 = false; boolean ignoreOffsetOnDateTimeOffsetConversion1 = true; @@ -88,6 +95,13 @@ public void testModifiableConnectionProperties() throws SQLException { boolean enablePrepareOnFirstPreparedStatementCall2 = true; String sCatalog2 = RandomUtil.getIdentifier("RequestBoundaryDatabase"); boolean useBulkCopyForBatchInsert2 = false; + int bulkCopyForBatchInsertBatchSize2 = 0; + boolean bulkCopyForBatchInsertCheckConstraints2 = false; + boolean bulkCopyForBatchInsertFireTriggers2 = false; + boolean bulkCopyForBatchInsertKeepIdentity2 = false; + boolean bulkCopyForBatchInsertKeepNulls2 = false; + boolean bulkCopyForBatchInsertTableLock2 = false; + boolean bulkCopyForBatchInsertAllowEncryptedValueModifications2 = false; boolean useFmtOnly2 = false; boolean delayLoadingLobs2 = true; boolean ignoreOffsetOnDateTimeOffsetConversion2 = false; @@ -101,62 +115,78 @@ public void testModifiableConnectionProperties() throws SQLException { setConnectionFields(con, autoCommitMode1, transactionIsolationLevel1, networkTimeout1, holdability1, sendTimeAsDatetime1, statementPoolingCacheSize1, disableStatementPooling1, serverPreparedStatementDiscardThreshold1, enablePrepareOnFirstPreparedStatementCall1, sCatalog1, - useBulkCopyForBatchInsert1, useFmtOnly1, delayLoadingLobs1, - ignoreOffsetOnDateTimeOffsetConversion1); + useBulkCopyForBatchInsert1, useFmtOnly1, delayLoadingLobs1, ignoreOffsetOnDateTimeOffsetConversion1, + bulkCopyForBatchInsertBatchSize1, bulkCopyForBatchInsertCheckConstraints1, + bulkCopyForBatchInsertFireTriggers1, bulkCopyForBatchInsertKeepIdentity1, bulkCopyForBatchInsertKeepNulls1, + bulkCopyForBatchInsertTableLock1,bulkCopyForBatchInsertAllowEncryptedValueModifications1); con.beginRequest(); // Call setters with the second set of values inside beginRequest()/endRequest() block. setConnectionFields(con, autoCommitMode2, transactionIsolationLevel2, networkTimeout2, holdability2, sendTimeAsDatetime2, statementPoolingCacheSize2, disableStatementPooling2, serverPreparedStatementDiscardThreshold2, enablePrepareOnFirstPreparedStatementCall2, sCatalog2, - useBulkCopyForBatchInsert2, useFmtOnly2, delayLoadingLobs2, - ignoreOffsetOnDateTimeOffsetConversion2); + useBulkCopyForBatchInsert2, useFmtOnly2, delayLoadingLobs2, ignoreOffsetOnDateTimeOffsetConversion2, + bulkCopyForBatchInsertBatchSize2, bulkCopyForBatchInsertCheckConstraints2, + bulkCopyForBatchInsertFireTriggers2, bulkCopyForBatchInsertKeepIdentity2, bulkCopyForBatchInsertKeepNulls2, + bulkCopyForBatchInsertTableLock2, bulkCopyForBatchInsertAllowEncryptedValueModifications2); con.endRequest(); // Test if endRequest() resets the SQLServerConnection properties back to the first set of values. compareValuesAgainstConnection(con, autoCommitMode1, transactionIsolationLevel1, networkTimeout1, holdability1, sendTimeAsDatetime1, statementPoolingCacheSize1, disableStatementPooling1, serverPreparedStatementDiscardThreshold1, enablePrepareOnFirstPreparedStatementCall1, sCatalog1, - useBulkCopyForBatchInsert1, useFmtOnly1, delayLoadingLobs1, - ignoreOffsetOnDateTimeOffsetConversion1); - + useBulkCopyForBatchInsert1, useFmtOnly1, delayLoadingLobs1, ignoreOffsetOnDateTimeOffsetConversion1, + bulkCopyForBatchInsertBatchSize1, bulkCopyForBatchInsertCheckConstraints1, + bulkCopyForBatchInsertFireTriggers1, bulkCopyForBatchInsertKeepIdentity1, bulkCopyForBatchInsertKeepNulls1, + bulkCopyForBatchInsertTableLock1, bulkCopyForBatchInsertAllowEncryptedValueModifications1); // Multiple calls to beginRequest() without an intervening call to endRequest() are no-op. setConnectionFields(con, autoCommitMode2, transactionIsolationLevel2, networkTimeout2, holdability2, sendTimeAsDatetime2, statementPoolingCacheSize2, disableStatementPooling2, serverPreparedStatementDiscardThreshold2, enablePrepareOnFirstPreparedStatementCall2, sCatalog2, - useBulkCopyForBatchInsert2, useFmtOnly2, delayLoadingLobs2, - ignoreOffsetOnDateTimeOffsetConversion2); + useBulkCopyForBatchInsert2, useFmtOnly2, delayLoadingLobs2, ignoreOffsetOnDateTimeOffsetConversion2, + bulkCopyForBatchInsertBatchSize2, bulkCopyForBatchInsertCheckConstraints2, + bulkCopyForBatchInsertFireTriggers2, bulkCopyForBatchInsertKeepIdentity2, bulkCopyForBatchInsertKeepNulls2, + bulkCopyForBatchInsertTableLock2, bulkCopyForBatchInsertAllowEncryptedValueModifications2); con.beginRequest(); setConnectionFields(con, autoCommitMode1, transactionIsolationLevel1, networkTimeout1, holdability1, sendTimeAsDatetime1, statementPoolingCacheSize1, disableStatementPooling1, serverPreparedStatementDiscardThreshold1, enablePrepareOnFirstPreparedStatementCall1, sCatalog1, - useBulkCopyForBatchInsert1, useFmtOnly1, delayLoadingLobs1, - ignoreOffsetOnDateTimeOffsetConversion1); + useBulkCopyForBatchInsert1, useFmtOnly1, delayLoadingLobs1, ignoreOffsetOnDateTimeOffsetConversion1, + bulkCopyForBatchInsertBatchSize1, bulkCopyForBatchInsertCheckConstraints1, + bulkCopyForBatchInsertFireTriggers1, bulkCopyForBatchInsertKeepIdentity1, bulkCopyForBatchInsertKeepNulls1, + bulkCopyForBatchInsertTableLock1, bulkCopyForBatchInsertAllowEncryptedValueModifications1); con.beginRequest(); con.endRequest(); // Same values as before the first beginRequest() compareValuesAgainstConnection(con, autoCommitMode2, transactionIsolationLevel2, networkTimeout2, holdability2, sendTimeAsDatetime2, statementPoolingCacheSize2, disableStatementPooling2, serverPreparedStatementDiscardThreshold2, enablePrepareOnFirstPreparedStatementCall2, sCatalog2, - useBulkCopyForBatchInsert2, useFmtOnly2, delayLoadingLobs2, - ignoreOffsetOnDateTimeOffsetConversion2); - + useBulkCopyForBatchInsert2, useFmtOnly2, delayLoadingLobs2, ignoreOffsetOnDateTimeOffsetConversion2, + bulkCopyForBatchInsertBatchSize2, bulkCopyForBatchInsertCheckConstraints2, + bulkCopyForBatchInsertFireTriggers2, bulkCopyForBatchInsertKeepIdentity2, bulkCopyForBatchInsertKeepNulls2, + bulkCopyForBatchInsertTableLock2, bulkCopyForBatchInsertAllowEncryptedValueModifications2); // A call to endRequest() without an intervening call to beginRequest() is no-op. setConnectionFields(con, autoCommitMode1, transactionIsolationLevel1, networkTimeout1, holdability1, sendTimeAsDatetime1, statementPoolingCacheSize1, disableStatementPooling1, serverPreparedStatementDiscardThreshold1, enablePrepareOnFirstPreparedStatementCall1, sCatalog1, - useBulkCopyForBatchInsert1, useFmtOnly1, delayLoadingLobs1, - ignoreOffsetOnDateTimeOffsetConversion1); + useBulkCopyForBatchInsert1, useFmtOnly1, delayLoadingLobs1, ignoreOffsetOnDateTimeOffsetConversion1, + bulkCopyForBatchInsertBatchSize1, bulkCopyForBatchInsertCheckConstraints1, + bulkCopyForBatchInsertFireTriggers1, bulkCopyForBatchInsertKeepIdentity1, bulkCopyForBatchInsertKeepNulls1, + bulkCopyForBatchInsertTableLock1, bulkCopyForBatchInsertAllowEncryptedValueModifications1); + setConnectionFields(con, autoCommitMode2, transactionIsolationLevel2, networkTimeout2, holdability2, sendTimeAsDatetime2, statementPoolingCacheSize2, disableStatementPooling2, serverPreparedStatementDiscardThreshold2, enablePrepareOnFirstPreparedStatementCall2, sCatalog2, - useBulkCopyForBatchInsert2, useFmtOnly2, delayLoadingLobs2, - ignoreOffsetOnDateTimeOffsetConversion2); - con.endRequest(); + useBulkCopyForBatchInsert2, useFmtOnly2, delayLoadingLobs2, ignoreOffsetOnDateTimeOffsetConversion2, + bulkCopyForBatchInsertBatchSize2, bulkCopyForBatchInsertCheckConstraints2, + bulkCopyForBatchInsertFireTriggers2, bulkCopyForBatchInsertKeepIdentity2, bulkCopyForBatchInsertKeepNulls2, + bulkCopyForBatchInsertTableLock2, bulkCopyForBatchInsertAllowEncryptedValueModifications2); con.endRequest(); // No change. compareValuesAgainstConnection(con, autoCommitMode2, transactionIsolationLevel2, networkTimeout2, holdability2, sendTimeAsDatetime2, statementPoolingCacheSize2, disableStatementPooling2, serverPreparedStatementDiscardThreshold2, enablePrepareOnFirstPreparedStatementCall2, sCatalog2, - useBulkCopyForBatchInsert2, useFmtOnly2, delayLoadingLobs2, - ignoreOffsetOnDateTimeOffsetConversion2); + useBulkCopyForBatchInsert2, useFmtOnly2, delayLoadingLobs2, ignoreOffsetOnDateTimeOffsetConversion2, + bulkCopyForBatchInsertBatchSize2, bulkCopyForBatchInsertCheckConstraints2, + bulkCopyForBatchInsertFireTriggers2, bulkCopyForBatchInsertKeepIdentity2, bulkCopyForBatchInsertKeepNulls2, + bulkCopyForBatchInsertTableLock2, bulkCopyForBatchInsertAllowEncryptedValueModifications2); } finally { TestUtils.dropDatabaseIfExists(sCatalog2, connectionString); } @@ -400,8 +430,12 @@ private void setConnectionFields(SQLServerConnection con, boolean autoCommitMode int networkTimeout, int holdability, boolean sendTimeAsDatetime, int statementPoolingCacheSize, boolean disableStatementPooling, int serverPreparedStatementDiscardThreshold, boolean enablePrepareOnFirstPreparedStatementCall, String sCatalog, boolean useBulkCopyForBatchInsert, - boolean useFmtOnly, boolean delayLoadingLobs, - boolean ignoreOffsetOnDateTimeOffsetConversion) throws SQLException { + boolean useFmtOnly, boolean delayLoadingLobs, boolean ignoreOffsetOnDateTimeOffsetConversion, + int bulkCopyForBatchInsertBatchSize, boolean bulkCopyForBatchInsertCheckConstraints, + boolean bulkCopyForBatchInsertFireTriggers, boolean bulkCopyForBatchInsertKeepIdentity, + boolean bulkCopyForBatchInsertKeepNulls, boolean bulkCopyForBatchInsertTableLock, + boolean bulkCopyForBatchInsertAllowEncryptedValueModifications) throws SQLException { + con.setAutoCommit(autoCommitMode); con.setTransactionIsolation(transactionIsolationLevel); con.setNetworkTimeout(null, networkTimeout); @@ -416,38 +450,49 @@ private void setConnectionFields(SQLServerConnection con, boolean autoCommitMode con.setUseFmtOnly(useFmtOnly); con.setDelayLoadingLobs(delayLoadingLobs); con.setIgnoreOffsetOnDateTimeOffsetConversion(ignoreOffsetOnDateTimeOffsetConversion); + con.setBulkCopyForBatchInsertBatchSize(bulkCopyForBatchInsertBatchSize); + con.setBulkCopyForBatchInsertCheckConstraints(bulkCopyForBatchInsertCheckConstraints); + con.setBulkCopyForBatchInsertFireTriggers(bulkCopyForBatchInsertFireTriggers); + con.setBulkCopyForBatchInsertKeepIdentity(bulkCopyForBatchInsertKeepIdentity); + con.setBulkCopyForBatchInsertKeepNulls(bulkCopyForBatchInsertKeepNulls); + con.setBulkCopyForBatchInsertTableLock(bulkCopyForBatchInsertTableLock); + con.setBulkCopyForBatchInsertAllowEncryptedValueModifications(bulkCopyForBatchInsertAllowEncryptedValueModifications); } - + private void compareValuesAgainstConnection(SQLServerConnection con, boolean autoCommitMode, int transactionIsolationLevel, int networkTimeout, int holdability, boolean sendTimeAsDatetime, int statementPoolingCacheSize, boolean disableStatementPooling, int serverPreparedStatementDiscardThreshold, boolean enablePrepareOnFirstPreparedStatementCall, String sCatalog, boolean useBulkCopyForBatchInsert, - boolean useFmtOnly, boolean delayLoadingLobs, - boolean ignoreOffsetOnDateTimeOffsetConversion) throws SQLException { + boolean useFmtOnly, boolean delayLoadingLobs, boolean ignoreOffsetOnDateTimeOffsetConversion, + int bulkCopyForBatchInsertBatchSize, boolean bulkCopyForBatchInsertCheckConstraints, + boolean bulkCopyForBatchInsertFireTriggers, boolean bulkCopyForBatchInsertKeepIdentity, + boolean bulkCopyForBatchInsertKeepNulls, boolean bulkCopyForBatchInsertTableLock, + boolean bulkCopyForBatchInsertAllowEncryptedValueModifications) throws SQLException { + final String description = " values do not match."; assertEquals(autoCommitMode, con.getAutoCommit(), "autoCommitmode" + description); - assertEquals(transactionIsolationLevel, con.getTransactionIsolation(), - "transactionIsolationLevel" + description); + assertEquals(transactionIsolationLevel, con.getTransactionIsolation(), "transactionIsolationLevel" + description); assertEquals(networkTimeout, con.getNetworkTimeout(), "networkTimeout" + description); assertEquals(holdability, con.getHoldability(), "holdability" + description); assertEquals(sendTimeAsDatetime, con.getSendTimeAsDatetime(), "sendTimeAsDatetime" + description); - assertEquals(statementPoolingCacheSize, con.getStatementPoolingCacheSize(), - "statementPoolingCacheSize" + description); - assertEquals(disableStatementPooling, con.getDisableStatementPooling(), - "disableStatementPooling" + description); - assertEquals(serverPreparedStatementDiscardThreshold, con.getServerPreparedStatementDiscardThreshold(), - "serverPreparedStatementDiscardThreshold" + description); - assertEquals(enablePrepareOnFirstPreparedStatementCall, con.getEnablePrepareOnFirstPreparedStatementCall(), - "enablePrepareOnFirstPreparedStatementCall" + description); + assertEquals(statementPoolingCacheSize, con.getStatementPoolingCacheSize(), "statementPoolingCacheSize" + description); + assertEquals(disableStatementPooling, con.getDisableStatementPooling(), "disableStatementPooling" + description); + assertEquals(serverPreparedStatementDiscardThreshold, con.getServerPreparedStatementDiscardThreshold(), "serverPreparedStatementDiscardThreshold" + description); + assertEquals(enablePrepareOnFirstPreparedStatementCall, con.getEnablePrepareOnFirstPreparedStatementCall(), "enablePrepareOnFirstPreparedStatementCall" + description); assertEquals(sCatalog, con.getCatalog(), "sCatalog" + description); - assertEquals(useBulkCopyForBatchInsert, con.getUseBulkCopyForBatchInsert(), - "useBulkCopyForBatchInsert" + description); + assertEquals(useBulkCopyForBatchInsert, con.getUseBulkCopyForBatchInsert(), "useBulkCopyForBatchInsert" + description); assertEquals(useFmtOnly, con.getUseFmtOnly(), "useFmtOnly" + description); assertEquals(delayLoadingLobs, con.getDelayLoadingLobs(), "delayLoadingLobs" + description); - assertEquals(ignoreOffsetOnDateTimeOffsetConversion, con.getIgnoreOffsetOnDateTimeOffsetConversion(), - "ignoreOffsetOnDateTimeOffsetConversion" + description); + assertEquals(ignoreOffsetOnDateTimeOffsetConversion, con.getIgnoreOffsetOnDateTimeOffsetConversion(), "ignoreOffsetOnDateTimeOffsetConversion" + description); + assertEquals(bulkCopyForBatchInsertBatchSize, con.getBulkCopyForBatchInsertBatchSize(), "bulkCopyForBatchInsertBatchSize" + description); + assertEquals(bulkCopyForBatchInsertCheckConstraints, con.getBulkCopyForBatchInsertCheckConstraints(), "bulkCopyForBatchInsertCheckConstraints" + description); + assertEquals(bulkCopyForBatchInsertFireTriggers, con.getBulkCopyForBatchInsertFireTriggers(), "bulkCopyForBatchInsertFireTriggers" + description); + assertEquals(bulkCopyForBatchInsertKeepIdentity, con.getBulkCopyForBatchInsertKeepIdentity(), "bulkCopyForBatchInsertKeepIdentity" + description); + assertEquals(bulkCopyForBatchInsertKeepNulls, con.getBulkCopyForBatchInsertKeepNulls(), "bulkCopyForBatchInsertKeepNulls" + description); + assertEquals(bulkCopyForBatchInsertTableLock, con.getBulkCopyForBatchInsertTableLock(), "bulkCopyForBatchInsertTableLock" + description); + assertEquals(bulkCopyForBatchInsertAllowEncryptedValueModifications, con.getBulkCopyForBatchInsertAllowEncryptedValueModifications(), "bulkCopyForBatchInsertAllowEncryptedValueModifications" + description); } - + private void generateWarning(Connection con) throws SQLException { con.setClientInfo("name", "value"); } @@ -476,6 +521,13 @@ private List getVerifiedMethodNames() { verifiedMethodNames.add("setDisableStatementPooling"); verifiedMethodNames.add("setTransactionIsolation"); verifiedMethodNames.add("setUseBulkCopyForBatchInsert"); + verifiedMethodNames.add("setBulkCopyForBatchInsertBatchSize"); + verifiedMethodNames.add("setBulkCopyForBatchInsertCheckConstraints"); + verifiedMethodNames.add("setBulkCopyForBatchInsertFireTriggers"); + verifiedMethodNames.add("setBulkCopyForBatchInsertKeepIdentity"); + verifiedMethodNames.add("setBulkCopyForBatchInsertKeepNulls"); + verifiedMethodNames.add("setBulkCopyForBatchInsertTableLock"); + verifiedMethodNames.add("setBulkCopyForBatchInsertAllowEncryptedValueModifications"); verifiedMethodNames.add("commit"); verifiedMethodNames.add("clearWarnings"); verifiedMethodNames.add("prepareStatement"); diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/preparedStatement/BatchExecutionWithBCOptionsTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/preparedStatement/BatchExecutionWithBCOptionsTest.java new file mode 100644 index 000000000..836d57a52 --- /dev/null +++ b/src/test/java/com/microsoft/sqlserver/jdbc/preparedStatement/BatchExecutionWithBCOptionsTest.java @@ -0,0 +1,683 @@ +/* + * Microsoft JDBC Driver for SQL Server Copyright(c) Microsoft Corporation All rights reserved. This program is made + * available under the terms of the MIT License. See the LICENSE file in the project root for more information. + */ +package com.microsoft.sqlserver.jdbc.preparedStatement; + +import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.platform.runner.JUnitPlatform; +import org.junit.runner.RunWith; + +import com.microsoft.sqlserver.jdbc.RandomUtil; +import com.microsoft.sqlserver.jdbc.SQLServerBulkCopyOptions; +import com.microsoft.sqlserver.jdbc.SQLServerPreparedStatement; +import com.microsoft.sqlserver.jdbc.TestResource; +import com.microsoft.sqlserver.jdbc.TestUtils; +import com.microsoft.sqlserver.testframework.AbstractSQLGenerator; +import com.microsoft.sqlserver.testframework.AbstractTest; +import com.microsoft.sqlserver.testframework.PrepUtil; + +@RunWith(JUnitPlatform.class) +public class BatchExecutionWithBCOptionsTest extends AbstractTest { + + private static final String tableName = AbstractSQLGenerator + .escapeIdentifier(RandomUtil.getIdentifier("BatchInsertWithBCOptions")); + + /** + * Test with useBulkCopyBatchInsert=true without passing + * bulkCopyForBatchInsertCheckConstraints + * + * @throws SQLException + */ + @Test + public void testBulkInsertNoConnStrOptions() throws Exception { + try (Connection connection = PrepUtil.getConnection(connectionString + ";useBulkCopyForBatchInsert=true;")) { + try (PreparedStatement pstmt = connection.prepareStatement("insert into " + tableName + " values(?, ?)")) { + pstmt.setInt(1, 1); + pstmt.setInt(2, 0); + pstmt.addBatch(); + + pstmt.setInt(1, 2); + pstmt.setInt(2, 2); + pstmt.addBatch(); + + pstmt.setInt(1, 3); + pstmt.setInt(2, 0); + pstmt.addBatch(); + + pstmt.setInt(1, 4); + pstmt.setInt(2, 4); + pstmt.addBatch(); + + pstmt.executeBatch(); + + try (Statement stmt = connection.createStatement()) { + try (ResultSet rs = stmt.executeQuery("select count(*) from " + tableName)) { + if (rs.next()) { + int cnt = rs.getInt(1); + assertEquals(cnt, 4, "row count should have been 4"); + } + } + } + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Test with useBulkCopyBatchInsert=true and + * bulkCopyForBatchInsertCheckConstraints=true + * + * @throws SQLException + */ + @Test + public void testBulkInsertWithConnStrConstraintCheckEnabled() throws Exception { + try (Connection connection = PrepUtil.getConnection( + connectionString + ";useBulkCopyForBatchInsert=true;bulkCopyForBatchInsertCheckConstraints=true")) { + try (PreparedStatement pstmt = connection.prepareStatement("insert into " + tableName + " values(?, ?)")) { + + pstmt.setInt(1, 1); + pstmt.setInt(2, 0); + pstmt.addBatch(); + + pstmt.setInt(1, 2); + pstmt.setInt(2, 2); + pstmt.addBatch(); + + pstmt.setInt(1, 3); + pstmt.setInt(2, 0); + pstmt.addBatch(); + + pstmt.setInt(1, 4); + pstmt.setInt(2, 4); + pstmt.addBatch(); + + pstmt.executeBatch(); + + fail(TestResource.getResource("R_expectedExceptionNotThrown")); + + } + } catch (SQLException e) { + if (!e.getMessage().contains("CHECK")) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + } + + /** + * Test with useBulkCopyBatchInsert=true and + * bulkCopyForBatchInsertCheckConstraints=false + * + * @throws SQLException + */ + @Test + public void testBulkInsertWithConnStrCheckConstraintsDisabled() throws Exception { + try (Connection connection = PrepUtil.getConnection( + connectionString + ";useBulkCopyForBatchInsert=true;bulkCopyForBatchInsertCheckConstraints=false")) { + try (PreparedStatement pstmt = connection.prepareStatement("insert into " + tableName + " values(?, ?)")) { + + pstmt.setInt(1, 1); + pstmt.setInt(2, 0); + pstmt.addBatch(); + + pstmt.setInt(1, 2); + pstmt.setInt(2, 2); + pstmt.addBatch(); + + pstmt.setInt(1, 3); + pstmt.setInt(2, 0); + pstmt.addBatch(); + + pstmt.setInt(1, 4); + pstmt.setInt(2, 4); + pstmt.addBatch(); + + pstmt.executeBatch(); + + try (Statement stmt = connection.createStatement()) { + try (ResultSet rs = stmt.executeQuery("select count(*) from " + tableName)) { + if (rs.next()) { + int cnt = rs.getInt(1); + assertEquals(cnt, 4, "row count should have been 4"); + } + } + } + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Test with useBulkCopyBatchInsert=true and bulkCopyForBatchInsertBatchSize set + * + * @throws SQLException + */ + @Test + public void testBulkInsertWithBatchSize() throws Exception { + try (Connection connection = PrepUtil.getConnection( + connectionString + ";useBulkCopyForBatchInsert=true;bulkCopyForBatchInsertBatchSize=2")) { + try (PreparedStatement pstmt = connection.prepareStatement("insert into " + tableName + " values(?, ?)")) { + pstmt.setInt(1, 1); + pstmt.setInt(2, 1); + pstmt.addBatch(); + + pstmt.setInt(1, 2); + pstmt.setInt(2, 2); + pstmt.addBatch(); + + pstmt.setInt(1, 3); + pstmt.setInt(2, 3); + pstmt.addBatch(); + + pstmt.setInt(1, 4); + pstmt.setInt(2, 4); + pstmt.addBatch(); + + pstmt.executeBatch(); + + try (Statement stmt = connection.createStatement()) { + try (ResultSet rs = stmt.executeQuery("select count(*) from " + tableName)) { + if (rs.next()) { + int cnt = rs.getInt(1); + assertEquals(cnt, 4, "row count should have been 4"); + } + } + } + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Test with useBulkCopyBatchInsert=true and + * bulkCopyForBatchInsertKeepIdentity=true + * + * @throws SQLException + */ + @Test + public void testBulkInsertWithKeepIdentity() throws Exception { + try (Connection connection = PrepUtil.getConnection( + connectionString + ";useBulkCopyForBatchInsert=true;bulkCopyForBatchInsertKeepIdentity=true")) { + try (PreparedStatement pstmt = connection.prepareStatement("insert into " + tableName + " values(?, ?)")) { + pstmt.setInt(1, 1); + pstmt.setInt(2, 1); + pstmt.addBatch(); + + pstmt.setInt(1, 2); + pstmt.setInt(2, 2); + pstmt.addBatch(); + + pstmt.setInt(1, 3); + pstmt.setInt(2, 3); + pstmt.addBatch(); + + pstmt.setInt(1, 4); + pstmt.setInt(2, 4); + pstmt.addBatch(); + + pstmt.executeBatch(); + + try (Statement stmt = connection.createStatement()) { + try (ResultSet rs = stmt.executeQuery("select count(*) from " + tableName)) { + if (rs.next()) { + int cnt = rs.getInt(1); + assertEquals(cnt, 4, "row count should have been 4"); + } + } + } + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Test with useBulkCopyBatchInsert=true and + * bulkCopyForBatchInsertKeepIdentity=true where identity insert fails + * + * @throws SQLException + */ + @Test + public void testBulkInsertWithKeepIdentityFailure() throws Exception { + try (Connection connection = PrepUtil.getConnection( + connectionString + ";useBulkCopyForBatchInsert=true;bulkCopyForBatchInsertKeepIdentity=true")) { + try (PreparedStatement pstmt = connection.prepareStatement("insert into " + tableName + " values(?, ?)")) { + pstmt.setInt(1, 1); + pstmt.setInt(2, 1); + pstmt.addBatch(); + + pstmt.setInt(1, 1); + pstmt.setInt(2, 2); + pstmt.addBatch(); + + pstmt.executeBatch(); + + fail(TestResource.getResource("R_expectedExceptionNotThrown")); + } + } catch (SQLException e) { + if (!e.getMessage().contains("Violation of PRIMARY KEY constraint")) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + } + + /** + * Test with useBulkCopyBatchInsert=true without passing + * SQLServerBulkCopyOptions + * + * @throws SQLException + */ + @Test + public void testBulkInsertNoOptions() throws Exception { + try (Connection connection = PrepUtil.getConnection(connectionString + ";useBulkCopyForBatchInsert=true;")) { + try (PreparedStatement pstmt = connection.prepareStatement("insert into " + tableName + " values(?, ?)")) { + pstmt.setInt(1, 1); + pstmt.setInt(2, 0); + pstmt.addBatch(); + + pstmt.setInt(1, 2); + pstmt.setInt(2, 2); + pstmt.addBatch(); + + pstmt.setInt(1, 3); + pstmt.setInt(2, 0); + pstmt.addBatch(); + + pstmt.setInt(1, 4); + pstmt.setInt(2, 4); + pstmt.addBatch(); + + pstmt.executeBatch(); + + try (Statement stmt = connection.createStatement()) { + try (ResultSet rs = stmt.executeQuery("select count(*) from " + tableName)) { + if (rs.next()) { + int cnt = rs.getInt(1); + assertEquals(cnt, 4, "row count should have been 4"); + } + } + } + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Test with useBulkCopyBatchInsert=true and + * bulkCopyForBatchInsertTableLock=true + * + * @throws SQLException + */ + @Test + public void testBulkInsertWithTableLock() throws Exception { + try (Connection connection = PrepUtil.getConnection( + connectionString + ";useBulkCopyForBatchInsert=true;bulkCopyForBatchInsertTableLock=true")) { + try (PreparedStatement pstmt = connection.prepareStatement("insert into " + tableName + " values(?, ?)")) { + pstmt.setInt(1, 1); + pstmt.setInt(2, 1); + pstmt.addBatch(); + + pstmt.setInt(1, 2); + pstmt.setInt(2, 2); + pstmt.addBatch(); + + pstmt.setInt(1, 3); + pstmt.setInt(2, 3); + pstmt.addBatch(); + + pstmt.setInt(1, 4); + pstmt.setInt(2, 4); + pstmt.addBatch(); + + pstmt.executeBatch(); + + try (Statement stmt = connection.createStatement()) { + try (ResultSet rs = stmt.executeQuery("select count(*) from " + tableName)) { + if (rs.next()) { + int cnt = rs.getInt(1); + assertEquals(cnt, 4, "row count should have been 4"); + } + } + } + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Test with useBulkCopyBatchInsert=true and + * bulkCopyForBatchInsertTableLock=true where insert fails + * + * @throws SQLException + */ + @Test + public void testBulkInsertWithTableLockFailure() throws Exception { + try (Connection connection = PrepUtil.getConnection( + connectionString + ";useBulkCopyForBatchInsert=true;bulkCopyForBatchInsertTableLock=true")) { + try (PreparedStatement pstmt = connection.prepareStatement("insert into " + tableName + " values(?, ?)")) { + pstmt.setInt(1, 1); + pstmt.setInt(2, 1); + pstmt.addBatch(); + + pstmt.setInt(1, 2); + pstmt.setInt(2, 2); + pstmt.addBatch(); + + pstmt.setInt(1, 3); + pstmt.setInt(2, 3); + pstmt.addBatch(); + + pstmt.setInt(1, 4); + pstmt.setInt(2, 4); + pstmt.addBatch(); + + // Start a transaction and acquire a table lock + connection.setAutoCommit(false); + try (Statement stmt = connection.createStatement()) { + String lockTableSQL = "SELECT * FROM " + tableName + " WITH (TABLOCKX)"; + stmt.execute(lockTableSQL); + + try (Connection connection2 = PrepUtil.getConnection( + connectionString + ";useBulkCopyForBatchInsert=true;bulkCopyForBatchInsertTableLock=true"); + PreparedStatement pstmt2 = connection2 + .prepareStatement("insert into " + tableName + " values(?, ?)")) { + + pstmt2.setInt(1, 5); + pstmt2.setInt(2, 5); + pstmt2.addBatch(); + + // Set a query timeout to prevent the test from running indefinitely + pstmt2.setQueryTimeout(5); + + pstmt2.executeBatch(); // This should fail due to the table lock + fail("Expected exception due to table lock was not thrown"); + } catch (SQLException e) { + System.out.println("Bulk insert failed as expected: " + e.getMessage()); + } + // Release the lock + connection.rollback(); + } + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Test with useBulkCopyBatchInsert=true and + * bulkCopyForBatchInsertFireTriggers=true + * + * @throws SQLException + */ + @Test + public void testBulkCopyOptionDefaultsFireTriggers() throws Exception { + try (Connection connection = PrepUtil.getConnection( + connectionString + ";useBulkCopyForBatchInsert=true;bulkCopyForBatchInsertFireTriggers=true")) { + try (PreparedStatement pstmt = connection.prepareStatement("insert into " + tableName + " values(?, ?)")) { + pstmt.setInt(1, 1); + pstmt.setInt(2, 1); + pstmt.addBatch(); + + pstmt.setInt(1, 2); + pstmt.setInt(2, 2); + pstmt.addBatch(); + + pstmt.setInt(1, 3); + pstmt.setInt(2, 3); + pstmt.addBatch(); + + pstmt.setInt(1, 4); + pstmt.setInt(2, 4); + pstmt.addBatch(); + + pstmt.executeBatch(); + try (Statement stmt = connection.createStatement()) { + try (ResultSet rs = stmt.executeQuery("select count(*) from " + tableName)) { + if (rs.next()) { + int cnt = rs.getInt(1); + assertEquals(cnt, 4, "Row count should have been 4"); + } + } + } + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Test with useBulkCopyBatchInsert=true and + * bulkCopyForBatchInsertFireTriggers=true where insert fails + * + * @throws SQLException + */ + @Test + public void testBulkCopyOptionDefaultsFireTriggersFailure() throws Exception { + try (Connection connection = PrepUtil.getConnection( + connectionString + ";useBulkCopyForBatchInsert=true;bulkCopyForBatchInsertFireTriggers=true")) { + try (PreparedStatement pstmt = connection.prepareStatement("insert into " + tableName + " values(?, ?)")) { + pstmt.setInt(1, 1); + pstmt.setInt(2, 1); + pstmt.addBatch(); + + pstmt.setInt(1, 2); + pstmt.setInt(2, 2); + pstmt.addBatch(); + + pstmt.setInt(1, 3); + pstmt.setInt(2, 3); + pstmt.addBatch(); + + pstmt.setInt(1, 4); + pstmt.setInt(2, 4); + pstmt.addBatch(); + + // Created a trigger that will cause the batch insert to fail + try (Statement stmt = connection.createStatement()) { + String createTriggerSQL = "CREATE TRIGGER trgFailInsert ON " + tableName + + " AFTER INSERT AS BEGIN " + + "RAISERROR('Trigger failure', 16, 1); " + + "ROLLBACK TRANSACTION; END"; + stmt.execute(createTriggerSQL); + } + + try { + pstmt.executeBatch(); + fail("Expected trigger failure exception was not thrown"); + } catch (SQLException e) { + System.out.println("Batch execution failed as expected: " + e.getMessage()); + } + + // Cleaning up by dropping the trigger + try (Statement stmt = connection.createStatement()) { + String dropTriggerSQL = "DROP TRIGGER trgFailInsert"; + stmt.execute(dropTriggerSQL); + } + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Test with useBulkCopyBatchInsert=true and + * bulkCopyForBatchInsertKeepNulls=true + * + * @throws SQLException + */ + @Test + public void testBulkCopyOptionDefaultsKeepNulls() throws Exception { + try (Connection connection = PrepUtil.getConnection( + connectionString + ";useBulkCopyForBatchInsert=true;bulkCopyForBatchInsertKeepNulls=true")) { + try (PreparedStatement pstmt = connection.prepareStatement("insert into " + tableName + " values(?, ?)")) { + pstmt.setInt(1, 1); + pstmt.setNull(2, java.sql.Types.INTEGER); + pstmt.addBatch(); + + pstmt.setInt(1, 2); + pstmt.setNull(2, java.sql.Types.INTEGER); + pstmt.addBatch(); + + pstmt.setInt(1, 3); + pstmt.setNull(2, java.sql.Types.INTEGER); + pstmt.addBatch(); + + pstmt.setInt(1, 4); + pstmt.setNull(2, java.sql.Types.INTEGER); + pstmt.addBatch(); + + pstmt.executeBatch(); + + try (Statement stmt = connection.createStatement()) { + try (ResultSet rs = stmt + .executeQuery("select count(*) from " + tableName + " where b is null")) { + if (rs.next()) { + int cnt = rs.getInt(1); + assertEquals(cnt, 4, "Row count with null values should have been 4"); + } + } + } + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Test with useBulkCopyBatchInsert=true and + * bulkCopyForBatchInsertKeepNulls=false + * + * @throws SQLException + */ + @Test + public void testBulkCopyOptionDefaultsKeepNullsFalse() throws Exception { + try (Connection connection = PrepUtil.getConnection( + connectionString + ";useBulkCopyForBatchInsert=true;bulkCopyForBatchInsertKeepNulls=false")) { + try (PreparedStatement pstmt = connection.prepareStatement("insert into " + tableName + " values(?, ?)")) { + pstmt.setInt(1, 1); + pstmt.setNull(2, java.sql.Types.INTEGER); + pstmt.addBatch(); + + pstmt.setInt(1, 2); + pstmt.setNull(2, java.sql.Types.INTEGER); + pstmt.addBatch(); + + pstmt.setInt(1, 3); + pstmt.setNull(2, java.sql.Types.INTEGER); + pstmt.addBatch(); + + pstmt.setInt(1, 4); + pstmt.setNull(2, java.sql.Types.INTEGER); + pstmt.addBatch(); + + pstmt.executeBatch(); + + try (Statement stmt = connection.createStatement()) { + try (ResultSet rs = stmt + .executeQuery("select count(*) from " + tableName + " where b is not null")) { + if (rs.next()) { + int cnt = rs.getInt(1); + assertEquals(cnt, 0, "Row count with non-null values should have been 0"); + } + } + } + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Test with useBulkCopyBatchInsert=true and + * bulkCopyForBatchInsertAllowEncryptedValueModifications=true + * + * @throws SQLException + */ + @Test + public void testBulkInsertWithEncryptedValueModifications() throws Exception { + try (Connection connection = PrepUtil.getConnection(connectionString + ";useBulkCopyForBatchInsert=true;bulkCopyForBatchInsertAllowEncryptedValueModifications=true")) { + try (PreparedStatement pstmt = connection.prepareStatement("insert into " + tableName + " values(?, ?)")) { + pstmt.setInt(1, 1); + pstmt.setInt(2, 0); + pstmt.addBatch(); + + pstmt.setInt(1, 2); + pstmt.setInt(2, 2); + pstmt.addBatch(); + + pstmt.setInt(1, 3); + pstmt.setInt(2, 0); + pstmt.addBatch(); + + pstmt.setInt(1, 4); + pstmt.setInt(2, 4); + pstmt.addBatch(); + + pstmt.executeBatch(); + + try (Statement stmt = connection.createStatement()) { + try (ResultSet rs = stmt.executeQuery("select count(*) from " + tableName)) { + if (rs.next()) { + int cnt = rs.getInt(1); + assertEquals(cnt, 4, "row count should have been 4"); + } + } + } + } + } catch (SQLException e) { + if (e.getMessage().contains("Invalid column type from bcp client for colid 1")) { + return; + } else { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + fail("Expected exception 'Invalid column type from bcp client for colid 1' was not thrown."); + } + + @BeforeEach + public void init() throws Exception { + try (Connection con = getConnection()) { + con.setAutoCommit(false); + try (Statement stmt = con.createStatement()) { + TestUtils.dropTableIfExists(tableName, stmt); + String sql1 = "create table " + tableName + "(a INT PRIMARY KEY, b INT CHECK (b > 0))"; + stmt.executeUpdate(sql1); + } + con.commit(); + } + } + + @AfterEach + public void terminate() throws Exception { + try (Connection con = getConnection()) { + try (Statement stmt = con.createStatement()) { + TestUtils.dropTableIfExists(tableName, stmt); + } + } + } + + @BeforeAll + public static void setupTests() throws Exception { + setConnection(); + } + +} From 08cd6fd9e91779fa1c5f37f218feafaae3251e0a Mon Sep 17 00:00:00 2001 From: muskan124947 Date: Fri, 24 Jan 2025 11:42:06 +0530 Subject: [PATCH 3/5] Apply new naming scheme to interfaceLibName (#2577) * Apply new naming scheme to interfaceLibName * Added changes to comment getAppName functionality --- .../sqlserver/jdbc/SQLServerConnection.java | 5 +- .../sqlserver/jdbc/SQLServerDataSource.java | 2 +- .../sqlserver/jdbc/SQLServerDriver.java | 46 +++++------ .../sqlserver/jdbc/SQLServerDriverTest.java | 76 +++++++++---------- 4 files changed, 65 insertions(+), 64 deletions(-) diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java index be4df3139..f5e6bc317 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java @@ -2615,7 +2615,7 @@ Connection connectInternal(Properties propsIn, if (null != sPropValue) validateMaxSQLLoginName(sPropKey, sPropValue); else - activeConnectionProperties.setProperty(sPropKey, SQLServerDriver.constructedAppName); + activeConnectionProperties.setProperty(sPropKey, SQLServerDriver.DEFAULT_APP_NAME); sPropKey = SQLServerDriverBooleanProperty.LAST_UPDATE_COUNT.toString(); sPropValue = activeConnectionProperties.getProperty(sPropKey); @@ -6892,7 +6892,8 @@ final boolean complete(LogonCommand logonCommand, TDSReader tdsReader) throws SQ String sPwd = activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()); String appName = activeConnectionProperties .getProperty(SQLServerDriverStringProperty.APPLICATION_NAME.toString()); - String interfaceLibName = "Microsoft JDBC Driver " + SQLJdbcVersion.MAJOR + "." + SQLJdbcVersion.MINOR; + String interfaceLibName = "Microsoft JDBC Driver " + SQLJdbcVersion.MAJOR + "." + SQLJdbcVersion.MINOR; + //String interfaceLibName = SQLServerDriver.constructedAppName; String databaseName = activeConnectionProperties .getProperty(SQLServerDriverStringProperty.DATABASE_NAME.toString()); diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java index 1f3c7993b..61b9f8f0c 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java @@ -170,7 +170,7 @@ public void setApplicationName(String applicationName) { @Override public String getApplicationName() { return getStringProperty(connectionProps, SQLServerDriverStringProperty.APPLICATION_NAME.toString(), - SQLServerDriver.constructedAppName); + SQLServerDriverStringProperty.APPLICATION_NAME.getDefaultValue()); } /** diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java index 17264ce59..a265d8099 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java @@ -738,11 +738,11 @@ public final class SQLServerDriver implements java.sql.Driver { static final String AUTH_DLL_NAME = "mssql-jdbc_auth-" + SQLJdbcVersion.MAJOR + "." + SQLJdbcVersion.MINOR + "." + SQLJdbcVersion.PATCH + "." + Util.getJVMArchOnWindows() + SQLJdbcVersion.RELEASE_EXT; static final String DEFAULT_APP_NAME = "Microsoft JDBC Driver for SQL Server"; - static final String APP_NAME_TEMPLATE = "Microsoft JDBC - %s, %s - %s"; - static final String constructedAppName; - static { - constructedAppName = getAppName(); - } + // static final String APP_NAME_TEMPLATE = "Microsoft JDBC - %s, %s - %s"; + // static final String constructedAppName; + // static { + // constructedAppName = getAppName(); + // } /** * Constructs the application name using system properties for OS, platform, and architecture. @@ -751,18 +751,18 @@ public final class SQLServerDriver implements java.sql.Driver { * * @return the constructed application name or the default application name if properties are not available */ - static String getAppName() { - String osName = System.getProperty("os.name", ""); - String osArch = System.getProperty("os.arch", ""); - String javaVmName = System.getProperty("java.vm.name", ""); - String javaVmVersion = System.getProperty("java.vm.version", ""); - String platform = javaVmName.isEmpty() || javaVmVersion.isEmpty() ? "" : javaVmName + " " + javaVmVersion; - - if (osName.isEmpty() && platform.isEmpty() && osArch.isEmpty()) { - return DEFAULT_APP_NAME; - } - return String.format(APP_NAME_TEMPLATE, osName, platform, osArch); - } + // static String getAppName() { + // String osName = System.getProperty("os.name", ""); + // String osArch = System.getProperty("os.arch", ""); + // String javaVmName = System.getProperty("java.vm.name", ""); + // String javaVmVersion = System.getProperty("java.vm.version", ""); + // String platform = javaVmName.isEmpty() || javaVmVersion.isEmpty() ? "" : javaVmName + " " + javaVmVersion; + + // if (osName.isEmpty() && platform.isEmpty() && osArch.isEmpty()) { + // return DEFAULT_APP_NAME; + // } + // return String.format(APP_NAME_TEMPLATE, osName, platform, osArch); + // } private static final String[] TRUE_FALSE = {"true", "false"}; @@ -1073,9 +1073,9 @@ String getClassNameLogging() { drLogger.finer("Error registering driver: " + e); } } - if (loggerExternal.isLoggable(Level.FINE)) { - loggerExternal.log(Level.FINE, "Application Name: " + SQLServerDriver.constructedAppName); - } + // if (loggerExternal.isLoggable(Level.FINE)) { + // loggerExternal.log(Level.FINE, "Application Name: " + SQLServerDriver.constructedAppName); + // } } // Check for jdk.net.ExtendedSocketOptions to set TCP keep-alive options for idle connection resiliency @@ -1314,9 +1314,9 @@ public java.sql.Connection connect(String url, Properties suppliedProperties) th Properties connectProperties = parseAndMergeProperties(url, suppliedProperties); if (connectProperties != null) { result = DriverJDBCVersion.getSQLServerConnection(toString()); - if (connectProperties.getProperty(SQLServerDriverStringProperty.APPLICATION_NAME.toString()) == null) { - connectProperties.setProperty(SQLServerDriverStringProperty.APPLICATION_NAME.toString(), SQLServerDriver.constructedAppName); - } + // if (connectProperties.getProperty(SQLServerDriverStringProperty.APPLICATION_NAME.toString()) == null) { + // connectProperties.setProperty(SQLServerDriverStringProperty.APPLICATION_NAME.toString(), SQLServerDriver.constructedAppName); + // } result.connect(connectProperties, null); } loggerExternal.exiting(getClassNameLogging(), "connect", result); diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerDriverTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerDriverTest.java index 646ad75e9..be9333e24 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerDriverTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerDriverTest.java @@ -204,7 +204,7 @@ public void testApplicationName() throws SQLException { Statement stmt = conn.createStatement(); ResultSet rs = stmt.executeQuery("SELECT program_name FROM sys.dm_exec_sessions WHERE session_id = @@SPID")) { if (rs.next()) { - assertEquals(SQLServerDriver.constructedAppName, rs.getString("program_name")); + assertEquals(SQLServerDriverStringProperty.APPLICATION_NAME.getDefaultValue(), rs.getString("program_name")); } } catch (SQLException e) { fail(e.getMessage()); @@ -216,54 +216,54 @@ public void testApplicationName() throws SQLException { * * @throws SQLException */ - @Test - public void testApplicationNameUsingApp_Name() throws SQLException { - try (Connection conn = DriverManager.getConnection(connectionString); - Statement stmt = conn.createStatement(); - ResultSet rs = stmt.executeQuery("SELECT app_name()")) { - if (rs.next()) { - assertEquals(SQLServerDriver.constructedAppName, rs.getString(1)); - } - } catch (SQLException e) { - fail(e.getMessage()); - } - } + // @Test + // public void testApplicationNameUsingApp_Name() throws SQLException { + // try (Connection conn = DriverManager.getConnection(connectionString); + // Statement stmt = conn.createStatement(); + // ResultSet rs = stmt.executeQuery("SELECT app_name()")) { + // if (rs.next()) { + // assertEquals(SQLServerDriver.constructedAppName, rs.getString(1)); + // } + // } catch (SQLException e) { + // fail(e.getMessage()); + // } + // } /** * test application name by executing select app_name() * * @throws SQLException */ - @Test - public void testAppNameWithSpecifiedApplicationName() throws SQLException { - String url = connectionString + ";applicationName={0123456789012345678901234567890123456789012345678901234567890123456789012345678901234589012345678901234567890123456789012345678}"; + // @Test + // public void testAppNameWithSpecifiedApplicationName() throws SQLException { + // String url = connectionString + ";applicationName={0123456789012345678901234567890123456789012345678901234567890123456789012345678901234589012345678901234567890123456789012345678}"; - try (Connection conn = DriverManager.getConnection(url); - Statement stmt = conn.createStatement(); - ResultSet rs = stmt.executeQuery("SELECT app_name()")) { - if (rs.next()) { - assertEquals("0123456789012345678901234567890123456789012345678901234567890123456789012345678901234589012345678901234567890123456789012345678", rs.getString(1)); - } - } catch (SQLException e) { - fail(e.getMessage()); - } - } + // try (Connection conn = DriverManager.getConnection(url); + // Statement stmt = conn.createStatement(); + // ResultSet rs = stmt.executeQuery("SELECT app_name()")) { + // if (rs.next()) { + // assertEquals("0123456789012345678901234567890123456789012345678901234567890123456789012345678901234589012345678901234567890123456789012345678", rs.getString(1)); + // } + // } catch (SQLException e) { + // fail(e.getMessage()); + // } + // } /** * test application name when system properties are empty * */ - @Test - public void testGetAppName() { - String appName = SQLServerDriver.getAppName(); - assertNotNull(appName, "Application name should not be null"); - assertFalse(appName.isEmpty(), "Application name should not be empty"); + // @Test + // public void testGetAppName() { + // String appName = SQLServerDriver.getAppName(); + // assertNotNull(appName, "Application name should not be null"); + // assertFalse(appName.isEmpty(), "Application name should not be empty"); - System.setProperty("os.name", ""); - System.setProperty("os.arch", ""); - System.setProperty("java.vm.name", ""); - System.setProperty("java.vm.version", ""); - String defaultAppName = SQLServerDriver.getAppName(); - assertEquals(SQLServerDriver.DEFAULT_APP_NAME, defaultAppName, "Application name should be the default one"); - } + // System.setProperty("os.name", ""); + // System.setProperty("os.arch", ""); + // System.setProperty("java.vm.name", ""); + // System.setProperty("java.vm.version", ""); + // String defaultAppName = SQLServerDriver.getAppName(); + // assertEquals(SQLServerDriver.DEFAULT_APP_NAME, defaultAppName, "Application name should be the default one"); + // } } From 03cfcfd3f298a56f690d0a4d298fb6d91eaae635 Mon Sep 17 00:00:00 2001 From: Mahendra Chavan Date: Mon, 27 Jan 2025 16:38:13 +0530 Subject: [PATCH 4/5] Fixed issue with SQLServerBulkCopy from CSV with setEscapeColumnDelimerts set to true (#2575) * Fixed issue with SQLServerBulkCopy from CSV for setEscapeColumnDelimitersCSV * Fixed scenario with last line having blank content * Simplied if condition * Changed sb.isEmpty to sb.toString().isEmpty() * Use length instead of toString * Added dropTable in finally block. --- .../jdbc/SQLServerBulkCSVFileRecord.java | 2 +- .../jdbc/bulkCopy/BulkCopyCSVTest.java | 59 +++++++++++++++++++ ...TestInputDelimiterEscapeNoNewLineAtEnd.csv | 15 +++++ 3 files changed, 75 insertions(+), 1 deletion(-) create mode 100644 src/test/resources/BulkCopyCSVTestInputDelimiterEscapeNoNewLineAtEnd.csv diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCSVFileRecord.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCSVFileRecord.java index 2db4339fa..1b21544f9 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCSVFileRecord.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerBulkCSVFileRecord.java @@ -246,7 +246,7 @@ private String readLineEscapeDelimiters() throws SQLServerException { if (c == -1 && quoteCount % 2 != 0) { // stream ended, but we are within quotes -- data problem throw new SQLServerException(SQLServerException.getErrString("R_InvalidCSVQuotes"), null, 0, null); } - if (c == -1) { // keep semantics of readLine() by returning a null when there is no more data + if ((c == -1) && (sb.length() == 0)) { // keep semantics of readLine() by returning a null when there is no more data return null; } } catch (IOException e) { diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/bulkCopy/BulkCopyCSVTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/bulkCopy/BulkCopyCSVTest.java index 2a8fdb0a0..057c5068c 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/bulkCopy/BulkCopyCSVTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/bulkCopy/BulkCopyCSVTest.java @@ -66,6 +66,7 @@ public class BulkCopyCSVTest extends AbstractTest { static String inputFile = "BulkCopyCSVTestInput.csv"; static String inputFileNoColumnName = "BulkCopyCSVTestInputNoColumnName.csv"; static String inputFileDelimiterEscape = "BulkCopyCSVTestInputDelimiterEscape.csv"; + static String inputFileDelimiterEscapeNoNewLineAtEnd = "BulkCopyCSVTestInputDelimiterEscapeNoNewLineAtEnd.csv"; static String inputFileMultipleDoubleQuotes = "BulkCopyCSVTestInputMultipleDoubleQuotes.csv"; static String encoding = "UTF-8"; static String delimiter = ","; @@ -197,12 +198,70 @@ public void testEscapeColumnDelimitersCSV() throws Exception { assertEquals(expectedEscaped[i][3], rs.getString("c4")); i++; } + assertEquals(i, 12, "Expected to load 12 records, but loaded " + i + " records"); } TestUtils.dropTableIfExists(tableName, stmt); } } + @Test + @DisplayName("Test setEscapeColumnDelimitersCSVNoNewLineAtEnd") + public void testEscapeColumnDelimitersCSVNoNewLineAtEnd() throws Exception { + String tableName = AbstractSQLGenerator.escapeIdentifier(RandomUtil.getIdentifier("BulkEscape")); + String fileName = filePath + inputFileDelimiterEscapeNoNewLineAtEnd; + /* + * The list below is the copy of inputFileDelimiterEsc ape with quotes removed. + */ + String[][] expectedEscaped = new String[12][4]; + expectedEscaped[0] = new String[] {"test", " test\"", "no@split", " testNoQuote", ""}; + expectedEscaped[1] = new String[] {null, null, null, null, ""}; + expectedEscaped[2] = new String[] {"\"", "test\"test", "test@\" test", null, ""}; + expectedEscaped[3] = new String[] {"testNoQuote ", " testSpaceAround ", " testSpaceInside ", + " testSpaceQuote\" ", ""}; + expectedEscaped[4] = new String[] {null, null, null, " testSpaceInside ", ""}; + expectedEscaped[5] = new String[] {"1997", "Ford", "E350", "E63", ""}; + expectedEscaped[6] = new String[] {"1997", "Ford", "E350", "E63", ""}; + expectedEscaped[7] = new String[] {"1997", "Ford", "E350", "Super@ luxurious truck", ""}; + expectedEscaped[8] = new String[] {"1997", "Ford", "E350", "Super@ \"luxurious\" truck", ""}; + expectedEscaped[9] = new String[] {"1997", "Ford", "E350", "E63", ""}; + expectedEscaped[10] = new String[] {"1997", "Ford", "E350", " Super luxurious truck ", ""}; + expectedEscaped[11] = new String[] {"1997", "F\r\no\r\nr\r\nd", "E350", "\"Super\" \"luxurious\" \"truck\"", + ""}; + + try (Connection con = getConnection(); Statement stmt = con.createStatement(); + SQLServerBulkCopy bulkCopy = new SQLServerBulkCopy(con); + SQLServerBulkCSVFileRecord fileRecord = new SQLServerBulkCSVFileRecord(fileName, encoding, "@", + false)) { + bulkCopy.setDestinationTableName(tableName); + fileRecord.setEscapeColumnDelimitersCSV(true); + fileRecord.addColumnMetadata(1, null, java.sql.Types.INTEGER, 0, 0); + fileRecord.addColumnMetadata(2, null, java.sql.Types.VARCHAR, 50, 0); + fileRecord.addColumnMetadata(3, null, java.sql.Types.VARCHAR, 50, 0); + fileRecord.addColumnMetadata(4, null, java.sql.Types.VARCHAR, 50, 0); + fileRecord.addColumnMetadata(5, null, java.sql.Types.VARCHAR, 50, 0); + fileRecord.addColumnMetadata(6, null, java.sql.Types.VARCHAR, 50, 0); + stmt.executeUpdate("CREATE TABLE " + tableName + + " (id INT IDENTITY(1,1), c1 VARCHAR(50), c2 VARCHAR(50), c3 VARCHAR(50), c4 VARCHAR(50), c5 VARCHAR(50))"); + bulkCopy.writeToServer(fileRecord); + + int i = 0; + try (ResultSet rs = stmt.executeQuery("SELECT * FROM " + tableName + " ORDER BY id"); + BufferedReader br = new BufferedReader(new FileReader(fileName));) { + while (rs.next()) { + assertEquals(expectedEscaped[i][0], rs.getString("c1")); + assertEquals(expectedEscaped[i][1], rs.getString("c2")); + assertEquals(expectedEscaped[i][2], rs.getString("c3")); + assertEquals(expectedEscaped[i][3], rs.getString("c4")); + i++; + } + assertEquals(i, 12, "Expected to load 12 records, but loaded " + i + " records"); + } finally { + TestUtils.dropTableIfExists(tableName, stmt); + } + } + } + /** * test simple csv file for bulkcopy, for GitHub issue 1391 Tests to ensure that the set returned by * getColumnOrdinals doesn't have to be ordered diff --git a/src/test/resources/BulkCopyCSVTestInputDelimiterEscapeNoNewLineAtEnd.csv b/src/test/resources/BulkCopyCSVTestInputDelimiterEscapeNoNewLineAtEnd.csv new file mode 100644 index 000000000..98ac9f114 --- /dev/null +++ b/src/test/resources/BulkCopyCSVTestInputDelimiterEscapeNoNewLineAtEnd.csv @@ -0,0 +1,15 @@ +1@"test"@ " test"""@ "no@split" @ testNoQuote@ +2@""@ ""@ ""@ ""@ +3@""""@ "test""test"@ "test@"" test"@ ""@ +4@testNoQuote @ testSpaceAround @ " testSpaceInside "@ " testSpaceQuote"" "@ +5@""@ ""@ ""@ " testSpaceInside "@ +6@1997@Ford@E350@E63@ +7@"1997"@"Ford"@"E350"@"E63"@ +8@1997@Ford@E350@"Super@ luxurious truck"@ +9@1997@Ford@E350@"Super@ ""luxurious"" truck"@ +10@1997@ "Ford" @E350@ "E63"@ +11@1997@Ford@E350@" Super luxurious truck "@ +12@1997@"F +o +r +d"@"E350"@"""Super"" ""luxurious"" ""truck"""@ \ No newline at end of file From 25ea8da0a0a56ee45ad79daa2b41cecf114312b2 Mon Sep 17 00:00:00 2001 From: Mahendra Chavan Date: Thu, 30 Jan 2025 13:38:48 +0530 Subject: [PATCH 5/5] Issue#2550 - Fixed getGeneratedKeys functionality for execute API (#2554) * Issue#2550 - Fixed getGeneratedKeys functionality for execute API * Adapted the test for working with jdk8 * Fixed indenetation * Enable thre new TCGenKeys tests for AzureDW * Incorporated review comments. * Incorporated review comments * Incorporated review comments * Add a test for PreparedStatement * Adding a fix and test case for issue # 2587 * Added a new test for execute API with set no count --- .../jdbc/SQLServerPreparedStatement.java | 2 +- .../sqlserver/jdbc/SQLServerStatement.java | 41 +- .../microsoft/sqlserver/jdbc/StreamDone.java | 2 +- .../jdbc/unit/statement/StatementTest.java | 386 ++++++++++++++++++ 4 files changed, 419 insertions(+), 12 deletions(-) diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java index 96ce9858f..2a0c79d97 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java @@ -705,7 +705,7 @@ final void doExecutePreparedStatement(PrepStmtExecCmd command) throws SQLServerE if (EXECUTE_QUERY == executeMethod && null == resultSet) { SQLServerException.makeFromDriverError(connection, this, SQLServerException.getErrString("R_noResultset"), null, true); - } else if (EXECUTE_UPDATE == executeMethod && null != resultSet) { + } else if ((EXECUTE_UPDATE == executeMethod) && (null != resultSet) && !bRequestedGeneratedKeys) { SQLServerException.makeFromDriverError(connection, this, SQLServerException.getErrString("R_resultsetGeneratedForUpdate"), null, false); } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerStatement.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerStatement.java index ab9e9fbe2..3faa0ebb3 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerStatement.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerStatement.java @@ -1601,9 +1601,15 @@ boolean onDone(TDSReader tdsReader) throws SQLServerException { if (null != procedureName) return false; + //For Insert, we must fetch additional TDS_DONE token that comes with the actual update count + if ((StreamDone.CMD_INSERT == doneToken.getCurCmd()) && (-1 != doneToken.getUpdateCount()) && EXECUTE == executeMethod) { + return true; + } + // Always return all update counts from statements executed through Statement.execute() - if (EXECUTE == executeMethod) - return false; + if (EXECUTE == executeMethod) { + return false; + } // Statement.executeUpdate() may or may not return this update count depending on the // setting of the lastUpdateCount connection property: @@ -2357,17 +2363,27 @@ public final ResultSet getGeneratedKeys() throws SQLServerException { if (null == autoGeneratedKeys) { long orgUpd = updateCount; + // + //A case of SET NOCOUNT ON and GENERATED KEYS requested + //where we may not have received update count but would have already read the resultset + //so directly consume it. + // + if ((executeMethod != EXECUTE_QUERY) && bRequestedGeneratedKeys && (resultSet != null)) { + autoGeneratedKeys = resultSet; + updateCount = orgUpd; + } else { - // Generated keys are returned in a ResultSet result right after the update count. - // Try to get that ResultSet. If there are no more results after the update count, - // or if the next result isn't a ResultSet, then something is wrong. - if (!getNextResult(true) || null == resultSet) { - SQLServerException.makeFromDriverError(connection, this, - SQLServerException.getErrString("R_statementMustBeExecuted"), null, false); + // Generated keys are returned in a ResultSet result right after the update count. + // Try to get that ResultSet. If there are no more results after the update count, + // or if the next result isn't a ResultSet, then something is wrong. + if (!getNextResult(true) || null == resultSet) { + SQLServerException.makeFromDriverError(connection, this, + SQLServerException.getErrString("R_statementMustBeExecuted"), null, false); + } + autoGeneratedKeys = resultSet; + updateCount = orgUpd; } - autoGeneratedKeys = resultSet; - updateCount = orgUpd; } loggerExternal.exiting(getClassNameLogging(), "getGeneratedKeys", autoGeneratedKeys); return autoGeneratedKeys; @@ -2616,6 +2632,11 @@ SQLServerColumnEncryptionKeyStoreProvider getColumnEncryptionKeyStoreProvider( lock.unlock(); } } + + protected void setAutoGeneratedKey(SQLServerResultSet rs) { + autoGeneratedKeys = rs; + } + } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/StreamDone.java b/src/main/java/com/microsoft/sqlserver/jdbc/StreamDone.java index 319c9ad6c..ede354260 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/StreamDone.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/StreamDone.java @@ -232,7 +232,7 @@ final long getUpdateCount() { } final boolean cmdIsDMLOrDDL() { - switch (curCmd) { + switch (curCmd) { case CMD_INSERT: case CMD_BULKINSERT: case CMD_DELETE: diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/StatementTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/StatementTest.java index 9c814916d..959837f0a 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/StatementTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/StatementTest.java @@ -24,6 +24,8 @@ import java.sql.Statement; import java.sql.Types; import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import java.util.Random; import java.util.UUID; import java.util.concurrent.Executors; @@ -2692,4 +2694,388 @@ public void terminate() throws Exception { } } } + + + @Nested + public class TCGenKeys { + private final String tableName = AbstractSQLGenerator + .escapeIdentifier(RandomUtil.getIdentifier("TCInsertWithGenKeys")); + private final String idTableName = AbstractSQLGenerator + .escapeIdentifier(RandomUtil.getIdentifier("TCInsertWithGenKeysIDs")); + + private final String triggerName = AbstractSQLGenerator.escapeIdentifier("Trigger"); + private final int NUM_ROWS = 3; + + @BeforeEach + public void setup() throws Exception { + try (Connection con = getConnection()) { + con.setAutoCommit(false); + try (Statement stmt = con.createStatement()) { + TestUtils.dropTriggerIfExists(triggerName, stmt); + stmt.executeUpdate("CREATE TABLE " + tableName + " (ID int NOT NULL IDENTITY(1,1) PRIMARY KEY, NAME varchar(32));"); + stmt.executeUpdate("CREATE TABLE " + idTableName + "(ID int NOT NULL IDENTITY(1,1) PRIMARY KEY);"); + stmt.executeUpdate("CREATE TRIGGER " + triggerName + " ON " + tableName + + " FOR INSERT AS INSERT INTO " + idTableName + " DEFAULT VALUES;"); + for (int i = 0; i < NUM_ROWS; i++) { + stmt.executeUpdate("INSERT INTO " + tableName + " (NAME) VALUES ('test')"); + } + } + con.commit(); + } + } + + /** + * Tests executeUpdate for Insert followed by getGenerateKeys + * + * @throws Exception + */ + @Test + public void testExecuteUpdateInsertAndGenKeys() { + try (Connection con = getConnection()) { + try(Statement stmt = con.createStatement()) { + String sql = "INSERT INTO " + tableName + " (NAME) VALUES('test')"; + List lst = Arrays.asList("ID"); + String[] arr = lst.toArray(new String[0]); + stmt.executeUpdate(sql, arr); + try (ResultSet generatedKeys = stmt.getGeneratedKeys()) { + if (generatedKeys.next()) { + int id = generatedKeys.getInt(1); + assertEquals(id, 4, "id should have been 4, but received : " + id); + } + } + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Tests executeUpdate using PreparedStatement for Insert followed by getGenerateKeys + * + * @throws Exception + */ + @Test + public void testPrepStmtExecuteUpdateInsertAndGenKeys() { + try (Connection con = getConnection()) { + String sql = "INSERT INTO " + tableName + " (NAME) VALUES('test')"; + try(PreparedStatement stmt = con.prepareStatement(sql,PreparedStatement.RETURN_GENERATED_KEYS)) { + stmt.executeUpdate(); + try (ResultSet generatedKeys = stmt.getGeneratedKeys()) { + if (generatedKeys.next()) { + int id = generatedKeys.getInt(1); + assertEquals(id, 4, "id should have been 4, but received : " + id); + } + } + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Tests executeUpdate using PreparedStatement for Insert followed by getGenerateKeys + * + * @throws Exception + */ + @Test + public void testPrepStmtNoCountExecuteUpdateInsertAndGenKeys() { + try (Connection con = getConnection()) { + String sql = "SET NOCOUNT ON; INSERT INTO " + tableName + " (NAME) VALUES('test')"; + try(PreparedStatement stmt = con.prepareStatement(sql,PreparedStatement.RETURN_GENERATED_KEYS)) { + stmt.executeUpdate(); + try (ResultSet generatedKeys = stmt.getGeneratedKeys()) { + if (generatedKeys.next()) { + int id = generatedKeys.getInt(1); + assertEquals(id, 4, "id should have been 4, but received : " + id); + } + } + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Tests executeUpdate using PreparedStatement for Insert followed by getGenerateKeys + * + * @throws Exception + */ + @Test + public void testPrepStmtNoCountExecuteInsertAndGenKeys() { + try (Connection con = getConnection()) { + String sql = "SET NOCOUNT ON; INSERT INTO " + tableName + " (NAME) VALUES('test')"; + try(PreparedStatement stmt = con.prepareStatement(sql,PreparedStatement.RETURN_GENERATED_KEYS)) { + stmt.execute(); + try (ResultSet generatedKeys = stmt.getGeneratedKeys()) { + if (generatedKeys.next()) { + int id = generatedKeys.getInt(1); + assertEquals(id, 4, "id should have been 4, but received : " + id); + } + } + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + + /** + * Tests execute for Insert followed by getGenerateKeys + * + * @throws Exception + */ + @Test + public void testExecuteInsertAndGenKeys() { + try (Connection con = getConnection()) { + try(Statement stmt = con.createStatement()) { + String sql = "INSERT INTO " + tableName + " (NAME) VALUES('test')"; + List lst = Arrays.asList("ID"); + String[] arr = lst.toArray(new String[0]); + stmt.execute(sql, arr); + try (ResultSet generatedKeys = stmt.getGeneratedKeys()) { + if (generatedKeys.next()) { + int id = generatedKeys.getInt(1); + assertEquals(id, 4, "generated key should have been 4"); + } + } + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Tests execute for Insert followed by select + * + * @throws Exception + */ + @Test + public void testExecuteInsertAndSelect() { + + try (Connection con = getConnection()) { + try(Statement stmt = con.createStatement()) { + boolean retval = stmt.execute("INSERT INTO " + tableName +" (NAME) VALUES('test') SELECT NAME FROM " + tableName + " WHERE ID = 1"); + do { + if (!retval) { + int count = stmt.getUpdateCount(); + if (count == -1) { + // no more results + break; + } else { + assertEquals(count, 1, "update count should have been 1"); + } + } else { + // process ResultSet + try (ResultSet rs = stmt.getResultSet()) { + if (rs.next()) { + String val = rs.getString(1); + assertEquals(val, "test", "read value should have been 'test'"); + } + } + } + retval = stmt.getMoreResults(); + } while (true); + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + + /** + * Tests execute for Merge followed by select + * + * @throws Exception + */ + @Test + public void testExecuteMergeAndSelect() { + try (Connection con = getConnection()) { + try(Statement stmt = con.createStatement()) { + boolean retval = stmt.execute("MERGE INTO " + tableName + " AS target USING (VALUES ('test1')) AS source (name) ON target.name = source.name WHEN NOT MATCHED THEN INSERT (name) VALUES ('test1'); SELECT NAME FROM " + tableName + " WHERE ID = 1"); + do { + if (!retval) { + int count = stmt.getUpdateCount(); + if (count == -1) { + // no more results + break; + } else { + assertEquals(count, 1, "update count should have been 1"); + } + } else { + // process ResultSet + try (ResultSet rs = stmt.getResultSet()) { + if (rs.next()) { + String val = rs.getString(1); + assertEquals(val, "test", "read value should have been 'test'"); + } + } + + } + retval = stmt.getMoreResults(); + } while (true); + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Tests execute for Insert multiple rows followed by select + * + * @throws Exception + */ + @Test + public void testExecuteInsertManyRowsAndSelect() { + try (Connection con = getConnection()) { + try (Statement stmt = con.createStatement()) { + boolean retval = stmt.execute("INSERT INTO " + tableName + " SELECT NAME FROM " + tableName + " SELECT NAME FROM " + tableName + " WHERE ID = 1"); + do { + if (!retval) { + int count = stmt.getUpdateCount(); + if (count == -1) { + // no more results + break; + } else { + assertEquals(count, 3, "update count should have been 6"); + } + } else { + // process ResultSet + try (ResultSet rs = stmt.getResultSet()) { + if (rs.next()) { + String val = rs.getString(1); + assertEquals(val, "test", "read value should have been 'test'"); + } + } + + } + retval = stmt.getMoreResults(); + } while (true); + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Tests execute two Inserts followed by select + * + * @throws Exception + */ + @Test + public void testExecuteTwoInsertsRowsAndSelect() { + try (Connection con = getConnection()) { + try(Statement stmt = con.createStatement()) { + boolean retval = stmt.execute("INSERT INTO " + tableName + " (NAME) VALUES('test') INSERT INTO " + tableName + " (NAME) VALUES('test') SELECT NAME from " + tableName + " WHERE ID = 1"); + do { + if (!retval) { + int count = stmt.getUpdateCount(); + if (count == -1) { + // no more results + break; + } else { + assertEquals(count, 1, "update count should have been 2"); + } + } else { + // process ResultSet + try (ResultSet rs = stmt.getResultSet()) { + if (rs.next()) { + String val = rs.getString(1); + assertEquals(val, "test", "read value should have been 'test'"); + } + } + + } + retval = stmt.getMoreResults(); + } while (true); + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + + /** + * Tests execute for Update followed by select + * + * @throws Exception + */ + @Test + public void testExecuteUpdAndSelect() { + try (Connection con = getConnection()) { + try(Statement stmt = con.createStatement()) { + boolean retval = stmt.execute("UPDATE " + tableName +" SET NAME = 'test' SELECT NAME FROM " + tableName + " WHERE ID = 1"); + do { + if (!retval) { + int count = stmt.getUpdateCount(); + if (count == -1) { + // no more results + break; + } else { + assertEquals(count, 3, "update count should have been 3"); + } + } else { + // process ResultSet + try (ResultSet rs = stmt.getResultSet()) { + if (rs.next()) { + String val = rs.getString(1); + assertEquals(val, "test", "read value should have been 'test'"); + } + } + } + retval = stmt.getMoreResults(); + } while (true); + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Tests execute for Update followed by select + * + * @throws Exception + */ + @Test + public void testExecuteDelAndSelect() { + try (Connection con = getConnection()) { + try(Statement stmt = con.createStatement()) { + boolean retval = stmt.execute("DELETE FROM " + tableName +" WHERE ID = 1 SELECT NAME FROM " + tableName + " WHERE ID = 2"); + do { + if (!retval) { + int count = stmt.getUpdateCount(); + if (count == -1) { + // no more results + break; + } else { + assertEquals(count, 1, "update count should have been 1"); + } + } else { + // process ResultSet + try (ResultSet rs = stmt.getResultSet()) { + if (rs.next()) { + String val = rs.getString(1); + assertEquals(val, "test", "read value should have been 'test'"); + } + } + } + retval = stmt.getMoreResults(); + } while (true); + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + @AfterEach + public void terminate() { + try (Connection con = getConnection(); Statement stmt = con.createStatement()) { + TestUtils.dropTriggerIfExists(triggerName, stmt); + TestUtils.dropTableIfExists(idTableName, stmt); + TestUtils.dropTableIfExists(tableName, stmt); + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + } + }