diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java index 6874deab4..5f076f9af 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerConnection.java @@ -2446,7 +2446,7 @@ Connection connectInternal(Properties propsIn, if (null != sPropValue) validateMaxSQLLoginName(sPropKey, sPropValue); else - activeConnectionProperties.setProperty(sPropKey, SQLServerDriver.DEFAULT_APP_NAME); + activeConnectionProperties.setProperty(sPropKey, SQLServerDriver.constructedAppName); sPropKey = SQLServerDriverBooleanProperty.LAST_UPDATE_COUNT.toString(); sPropValue = activeConnectionProperties.getProperty(sPropKey); @@ -6110,10 +6110,11 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw } while (true) { + int millisecondsRemaining = timerRemaining(timerExpire); if (authenticationString.equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString())) { fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, user, activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()), - authenticationString); + authenticationString, millisecondsRemaining); // Break out of the retry loop in successful case. break; @@ -6125,12 +6126,12 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) { fedAuthToken = SQLServerSecurityUtility.getManagedIdentityCredAuthToken(fedAuthInfo.spn, - managedIdentityClientId); + managedIdentityClientId, millisecondsRemaining); break; } fedAuthToken = SQLServerSecurityUtility.getManagedIdentityCredAuthToken(fedAuthInfo.spn, - activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString())); + activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()), millisecondsRemaining); // Break out of the retry loop in successful case. break; @@ -6141,12 +6142,12 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw if (aadPrincipalID != null && !aadPrincipalID.isEmpty() && aadPrincipalSecret != null && !aadPrincipalSecret.isEmpty()) { fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipal(fedAuthInfo, aadPrincipalID, - aadPrincipalSecret, authenticationString); + aadPrincipalSecret, authenticationString, millisecondsRemaining); } else { fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenPrincipal(fedAuthInfo, activeConnectionProperties.getProperty(SQLServerDriverStringProperty.USER.toString()), activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()), - authenticationString); + authenticationString, millisecondsRemaining); } // Break out of the retry loop in successful case. @@ -6159,7 +6160,7 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw activeConnectionProperties.getProperty(SQLServerDriverStringProperty.USER.toString()), servicePrincipalCertificate, activeConnectionProperties.getProperty(SQLServerDriverStringProperty.PASSWORD.toString()), - servicePrincipalCertificateKey, servicePrincipalCertificatePassword, authenticationString); + servicePrincipalCertificateKey, servicePrincipalCertificatePassword, authenticationString, millisecondsRemaining); // Break out of the retry loop in successful case. break; @@ -6194,7 +6195,7 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw throw new SQLServerException(form.format(msgArgs), null); } - int millisecondsRemaining = timerRemaining(timerExpire); + millisecondsRemaining = timerRemaining(timerExpire); if (ActiveDirectoryAuthentication.GET_ACCESS_TOKEN_TRANSIENT_ERROR != errorCategory || timerHasExpired(timerExpire) || (fedauthSleepInterval >= millisecondsRemaining)) { @@ -6240,7 +6241,7 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw Object[] msgArgs = {SQLServerDriver.AUTH_DLL_NAME, authenticationString}; throw new SQLServerException(form.format(msgArgs), null, 0, null); } - fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenIntegrated(fedAuthInfo, authenticationString); + fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenIntegrated(fedAuthInfo, authenticationString, millisecondsRemaining); } // Break out of the retry loop in successful case. break; @@ -6248,7 +6249,7 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw .equalsIgnoreCase(SqlAuthentication.ACTIVE_DIRECTORY_INTERACTIVE.toString())) { // interactive flow fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthTokenInteractive(fedAuthInfo, user, - authenticationString); + authenticationString, millisecondsRemaining); // Break out of the retry loop in successful case. break; @@ -6258,12 +6259,12 @@ private SqlAuthenticationToken getFedAuthToken(SqlFedAuthInfo fedAuthInfo) throw if (null != managedIdentityClientId && !managedIdentityClientId.isEmpty()) { fedAuthToken = SQLServerSecurityUtility.getDefaultAzureCredAuthToken(fedAuthInfo.spn, - managedIdentityClientId); + managedIdentityClientId, millisecondsRemaining); break; } fedAuthToken = SQLServerSecurityUtility.getDefaultAzureCredAuthToken(fedAuthInfo.spn, - activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString())); + activeConnectionProperties.getProperty(SQLServerDriverStringProperty.MSI_CLIENT_ID.toString()), millisecondsRemaining); break; } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java index da7688e60..480f36ba3 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDataSource.java @@ -170,7 +170,7 @@ public void setApplicationName(String applicationName) { @Override public String getApplicationName() { return getStringProperty(connectionProps, SQLServerDriverStringProperty.APPLICATION_NAME.toString(), - SQLServerDriverStringProperty.APPLICATION_NAME.getDefaultValue()); + SQLServerDriver.constructedAppName); } /** diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java index e4b1d59ee..1ffffa6f0 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerDriver.java @@ -731,7 +731,32 @@ public final class SQLServerDriver implements java.sql.Driver { static final String AUTH_DLL_NAME = "mssql-jdbc_auth-" + SQLJdbcVersion.MAJOR + "." + SQLJdbcVersion.MINOR + "." + SQLJdbcVersion.PATCH + "." + Util.getJVMArchOnWindows() + SQLJdbcVersion.RELEASE_EXT; static final String DEFAULT_APP_NAME = "Microsoft JDBC Driver for SQL Server"; + static final String APP_NAME_TEMPLATE = "Microsoft JDBC - %s, %s - %s"; + static final String constructedAppName; + static { + constructedAppName = getAppName(); + } + /** + * Constructs the application name using system properties for OS, platform, and architecture. + * If any of the properties cannot be fetched, it falls back to the default application name. + * Format -> Microsoft JDBC - {OS}, {Platform} - {architecture} + * + * @return the constructed application name or the default application name if properties are not available + */ + static String getAppName() { + String osName = System.getProperty("os.name", ""); + String osArch = System.getProperty("os.arch", ""); + String javaVmName = System.getProperty("java.vm.name", ""); + String javaVmVersion = System.getProperty("java.vm.version", ""); + String platform = javaVmName.isEmpty() || javaVmVersion.isEmpty() ? "" : javaVmName + " " + javaVmVersion; + + if (osName.isEmpty() && platform.isEmpty() && osArch.isEmpty()) { + return DEFAULT_APP_NAME; + } + return String.format(APP_NAME_TEMPLATE, osName, platform, osArch); + } + private static final String[] TRUE_FALSE = {"true", "false"}; private static final SQLServerDriverPropertyInfo[] DRIVER_PROPERTIES = { @@ -741,7 +766,7 @@ public final class SQLServerDriver implements java.sql.Driver { SQLServerDriverStringProperty.APPLICATION_INTENT.getDefaultValue(), false, new String[] {ApplicationIntent.READ_ONLY.toString(), ApplicationIntent.READ_WRITE.toString()}), new SQLServerDriverPropertyInfo(SQLServerDriverStringProperty.APPLICATION_NAME.toString(), - SQLServerDriverStringProperty.APPLICATION_NAME.getDefaultValue(), false, null), + SQLServerDriverStringProperty.APPLICATION_NAME.getDefaultValue(), false, null), new SQLServerDriverPropertyInfo(SQLServerDriverStringProperty.COLUMN_ENCRYPTION.toString(), SQLServerDriverStringProperty.COLUMN_ENCRYPTION.getDefaultValue(), false, new String[] {ColumnEncryptionSetting.DISABLED.toString(), @@ -1028,6 +1053,9 @@ String getClassNameLogging() { drLogger.finer("Error registering driver: " + e); } } + if (loggerExternal.isLoggable(Level.FINE)) { + loggerExternal.log(Level.FINE, "Application Name: " + SQLServerDriver.constructedAppName); + } } // Check for jdk.net.ExtendedSocketOptions to set TCP keep-alive options for idle connection resiliency @@ -1266,6 +1294,9 @@ public java.sql.Connection connect(String url, Properties suppliedProperties) th Properties connectProperties = parseAndMergeProperties(url, suppliedProperties); if (connectProperties != null) { result = DriverJDBCVersion.getSQLServerConnection(toString()); + if (connectProperties.getProperty(SQLServerDriverStringProperty.APPLICATION_NAME.toString()) == null) { + connectProperties.setProperty(SQLServerDriverStringProperty.APPLICATION_NAME.toString(), SQLServerDriver.constructedAppName); + } result.connect(connectProperties, null); } loggerExternal.exiting(getClassNameLogging(), "connect", result); diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java index 689347db3..8850e74fc 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerMSAL4JUtils.java @@ -25,7 +25,9 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; import java.util.logging.Level; @@ -64,7 +66,8 @@ class SQLServerMSAL4JUtils { static final String REDIRECTURI = "http://localhost"; static final String SLASH_DEFAULT = "/.default"; static final String ACCESS_TOKEN_EXPIRE = "access token expires: "; - + static final long TOKEN_WAIT_DURATION_MS = 20000; + static final long TOKEN_SEM_WAIT_DURATION_MS = 5000; private static final TokenCacheMap TOKEN_CACHE_MAP = new TokenCacheMap(); private final static String LOGCONTEXT = "MSAL version " @@ -77,19 +80,28 @@ private SQLServerMSAL4JUtils() { throw new UnsupportedOperationException(SQLServerException.getErrString("R_notSupported")); } - private static final Lock lock = new ReentrantLock(); + private static final Semaphore sem = new Semaphore(1); static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, String user, String password, - String authenticationString) throws SQLServerException { + String authenticationString, int millisecondsRemaining) throws SQLServerException { ExecutorService executorService = Executors.newSingleThreadExecutor(); if (logger.isLoggable(Level.FINEST)) { logger.finest(LOGCONTEXT + authenticationString + ": get FedAuth token for user: " + user); } - lock.lock(); - + boolean isSemAcquired = false; try { + // + //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. + //The purpose is to optimize the token acquisition process, the first caller succeeding does caching + //which is then leveraged by subsequent threads. However, if the first thread takes considerable time, + //then we want the others to also go and try after waiting for a while. + //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints + //to get their tokens at the same time, stressing the auth endpoint. + // + isSemAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + String hashedSecret = getHashedSecret(new String[] {fedAuthInfo.stsurl, user, password}); PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = TOKEN_CACHE_MAP.getEntry(user, hashedSecret); @@ -116,7 +128,7 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str .builder(Collections.singleton(fedAuthInfo.spn + SLASH_DEFAULT), user, password.toCharArray()) .build()); - final IAuthenticationResult authenticationResult = future.get(); + final IAuthenticationResult authenticationResult = future.get(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); if (logger.isLoggable(Level.FINER)) { logger.finer( @@ -132,14 +144,18 @@ static SqlAuthenticationToken getSqlFedAuthToken(SqlFedAuthInfo fedAuthInfo, Str throw new SQLServerException(e.getMessage(), e); } catch (MalformedURLException | ExecutionException e) { throw getCorrectedException(e, user, authenticationString); + } catch (TimeoutException e) { + throw getCorrectedException(new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e), user, authenticationString); } finally { - lock.unlock(); + if (isSemAcquired) { + sem.release(); + } executorService.shutdown(); } } static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuthInfo, String aadPrincipalID, - String aadPrincipalSecret, String authenticationString) throws SQLServerException { + String aadPrincipalSecret, String authenticationString, int millisecondsRemaining) throws SQLServerException { ExecutorService executorService = Executors.newSingleThreadExecutor(); if (logger.isLoggable(Level.FINEST)) { @@ -151,10 +167,19 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth : fedAuthInfo.spn + defaultScopeSuffix; Set scopes = new HashSet<>(); scopes.add(scope); - - lock.lock(); - + + boolean isSemAcquired = false; try { + // + //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. + //The purpose is to optimize the token acquisition process, the first caller succeeding does caching + //which is then leveraged by subsequent threads. However, if the first thread takes considerable time, + //then we want the others to also go and try after waiting for a while. + //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints + //to get their tokens at the same time, stressing the auth endpoint. + // + isSemAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + String hashedSecret = getHashedSecret( new String[] {fedAuthInfo.stsurl, aadPrincipalID, aadPrincipalSecret}); PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = TOKEN_CACHE_MAP.getEntry(aadPrincipalID, @@ -181,7 +206,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth final CompletableFuture future = clientApplication .acquireToken(ClientCredentialParameters.builder(scopes).build()); - final IAuthenticationResult authenticationResult = future.get(); + final IAuthenticationResult authenticationResult = future.get(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); if (logger.isLoggable(Level.FINER)) { logger.finer( @@ -197,15 +222,19 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipal(SqlFedAuthInfo fedAuth throw new SQLServerException(e.getMessage(), e); } catch (MalformedURLException | ExecutionException e) { throw getCorrectedException(e, aadPrincipalID, authenticationString); + } catch (TimeoutException e) { + throw getCorrectedException(new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e), aadPrincipalID, authenticationString); } finally { - lock.unlock(); + if (isSemAcquired) { + sem.release(); + } executorService.shutdown(); } } static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthInfo fedAuthInfo, String aadPrincipalID, String certFile, String certPassword, String certKey, String certKeyPassword, - String authenticationString) throws SQLServerException { + String authenticationString, int millisecondsRemaining) throws SQLServerException { ExecutorService executorService = Executors.newSingleThreadExecutor(); if (logger.isLoggable(Level.FINEST)) { @@ -219,9 +248,18 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI Set scopes = new HashSet<>(); scopes.add(scope); - lock.lock(); - + boolean isSemAcquired = false; try { + // + //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. + //The purpose is to optimize the token acquisition process, the first caller succeeding does caching + //which is then leveraged by subsequent threads. However, if the first thread takes considerable time, + //then we want the others to also go and try after waiting for a while. + //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints + //to get their tokens at the same time, stressing the auth endpoint. + // + isSemAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + String hashedSecret = getHashedSecret(new String[] {fedAuthInfo.stsurl, aadPrincipalID, certFile, certPassword, certKey, certKeyPassword}); PersistentTokenCacheAccessAspect persistentTokenCacheAccessAspect = TOKEN_CACHE_MAP.getEntry(aadPrincipalID, @@ -297,7 +335,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI final CompletableFuture future = clientApplication .acquireToken(ClientCredentialParameters.builder(scopes).build()); - final IAuthenticationResult authenticationResult = future.get(); + final IAuthenticationResult authenticationResult = future.get(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); if (logger.isLoggable(Level.FINER)) { logger.finer( @@ -315,17 +353,21 @@ static SqlAuthenticationToken getSqlFedAuthTokenPrincipalCertificate(SqlFedAuthI // this includes all certificate exceptions throw new SQLServerException(SQLServerException.getErrString("R_readCertError") + e.getMessage(), null, 0, null); + } catch (TimeoutException e) { + throw getCorrectedException(new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e), aadPrincipalID, authenticationString); } catch (Exception e) { throw getCorrectedException(e, aadPrincipalID, authenticationString); } finally { - lock.unlock(); + if (isSemAcquired) { + sem.release(); + } executorService.shutdown(); } } static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAuthInfo, - String authenticationString) throws SQLServerException { + String authenticationString, int millisecondsRemaining) throws SQLServerException { ExecutorService executorService = Executors.newSingleThreadExecutor(); /* @@ -340,9 +382,18 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut + "realm name:" + kerberosPrincipal.getRealm()); } - lock.lock(); - + boolean isSemAcquired = false; try { + // + //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. + //The purpose is to optimize the token acquisition process, the first caller succeeding does caching + //which is then leveraged by subsequent threads. However, if the first thread takes considerable time, + //then we want the others to also go and try after waiting for a while. + //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints + //to get their tokens at the same time, stressing the auth endpoint. + // + isSemAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + final PublicClientApplication pca = PublicClientApplication .builder(ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID).executorService(executorService) .setTokenCacheAccessAspect(PersistentTokenCacheAccessAspect.getInstance()) @@ -352,7 +403,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut .acquireToken(IntegratedWindowsAuthenticationParameters .builder(Collections.singleton(fedAuthInfo.spn + SLASH_DEFAULT), user).build()); - final IAuthenticationResult authenticationResult = future.get(); + final IAuthenticationResult authenticationResult = future.get(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); if (logger.isLoggable(Level.FINER)) { logger.finer( @@ -368,23 +419,36 @@ static SqlAuthenticationToken getSqlFedAuthTokenIntegrated(SqlFedAuthInfo fedAut throw new SQLServerException(e.getMessage(), e); } catch (IOException | ExecutionException e) { throw getCorrectedException(e, user, authenticationString); + } catch (TimeoutException e) { + throw getCorrectedException(new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e), user, authenticationString); } finally { - lock.unlock(); + if (isSemAcquired) { + sem.release(); + } executorService.shutdown(); } } static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAuthInfo, String user, - String authenticationString) throws SQLServerException { + String authenticationString, int millisecondsRemaining) throws SQLServerException { ExecutorService executorService = Executors.newSingleThreadExecutor(); if (logger.isLoggable(Level.FINER)) { logger.finer(LOGCONTEXT + authenticationString + ": get FedAuth token interactive for user: " + user); } - lock.lock(); - + boolean isSemAcquired = false; try { + // + //Just try to acquire the semaphore and if can't then proceed to attempt to get the token. + //The purpose is to optimize the token acquisition process, the first caller succeeding does caching + //which is then leveraged by subsequent threads. However, if the first thread takes considerable time, + //then we want the others to also go and try after waiting for a while. + //If we were to let say 30 threads try in parallel, they would all miss the cache and hit the AAD auth endpoints + //to get their tokens at the same time, stressing the auth endpoint. + // + isSemAcquired = sem.tryAcquire(Math.min(millisecondsRemaining, TOKEN_SEM_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); + PublicClientApplication pca = PublicClientApplication .builder(ActiveDirectoryAuthentication.JDBC_FEDAUTH_CLIENT_ID).executorService(executorService) .setTokenCacheAccessAspect(PersistentTokenCacheAccessAspect.getInstance()) @@ -432,7 +496,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu } if (null != future) { - authenticationResult = future.get(); + authenticationResult = future.get(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); } else { // acquire token interactively with system browser if (logger.isLoggable(Level.FINEST)) { @@ -444,7 +508,7 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu .loginHint(user).scopes(Collections.singleton(fedAuthInfo.spn + SLASH_DEFAULT)).build(); future = pca.acquireToken(parameters); - authenticationResult = future.get(); + authenticationResult = future.get(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), TimeUnit.MILLISECONDS); } if (logger.isLoggable(Level.FINER)) { @@ -461,8 +525,12 @@ static SqlAuthenticationToken getSqlFedAuthTokenInteractive(SqlFedAuthInfo fedAu throw new SQLServerException(e.getMessage(), e); } catch (MalformedURLException | URISyntaxException | ExecutionException e) { throw getCorrectedException(e, user, authenticationString); + } catch (TimeoutException e) { + throw getCorrectedException(new SQLServerException(SQLServerException.getErrString("R_connectionTimedOut"), e), user, authenticationString); } finally { - lock.unlock(); + if (isSemAcquired) { + sem.release(); + } executorService.shutdown(); } } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java index 2128f9cef..244cfb696 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java @@ -1246,16 +1246,25 @@ public final java.sql.ResultSetMetaData getMetaData() throws SQLServerException, */ private SQLServerResultSet buildExecuteMetaData() throws SQLServerException, SQLTimeoutException { String fmtSQL = userSQL; - + SQLServerResultSet emptyResultSet = null; try { - fmtSQL = replaceMarkerWithNull(fmtSQL); internalStmt = (SQLServerStatement) connection.createStatement(); emptyResultSet = internalStmt.executeQueryInternal("set fmtonly on " + fmtSQL + "\nset fmtonly off"); } catch (SQLServerException sqle) { // Ignore empty result set errors, otherwise propagate the server error. if (!sqle.getMessage().equals(SQLServerException.getErrString("R_noResultset"))) { - throw sqle; + //try by replacing ? characters in case that was an issue + try { + fmtSQL = replaceMarkerWithNull(fmtSQL); + internalStmt = (SQLServerStatement) connection.createStatement(); + emptyResultSet = internalStmt.executeQueryInternal("set fmtonly on " + fmtSQL + "\nset fmtonly off"); + } catch (SQLServerException ex) { + // Ignore empty result set errors, otherwise propagate the server error. + if (!ex.getMessage().equals(SQLServerException.getErrString("R_noResultset"))) { + throw ex; + } + } } } return emptyResultSet; diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java index 70c50ca28..d4e49ccde 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerSecurityUtility.java @@ -8,6 +8,8 @@ import java.security.InvalidKeyException; import java.security.NoSuchAlgorithmException; import java.text.MessageFormat; +import java.time.Duration; +import java.time.temporal.ChronoUnit; import java.util.Arrays; import java.util.HashMap; import java.util.Optional; @@ -56,6 +58,8 @@ class SQLServerSecurityUtility { private static final Lock CREDENTIAL_LOCK = new ReentrantLock(); + private static final int TOKEN_WAIT_DURATION_MS = 20000; + private SQLServerSecurityUtility() { throw new UnsupportedOperationException(SQLServerException.getErrString("R_notSupported")); } @@ -340,7 +344,7 @@ static void verifyColumnMasterKeyMetadata(SQLServerConnection connection, SQLSer * @throws SQLServerException */ static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource, - String managedIdentityClientId) throws SQLServerException { + String managedIdentityClientId, long millisecondsRemaining) throws SQLServerException { if (logger.isLoggable(java.util.logging.Level.FINEST)) { logger.finest("Getting Managed Identity authentication token for: " + managedIdentityClientId); @@ -379,7 +383,7 @@ static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource, SqlAuthenticationToken sqlFedAuthToken = null; - Optional accessTokenOptional = mic.getToken(tokenRequestContext).blockOptional(); + Optional accessTokenOptional = mic.getToken(tokenRequestContext).timeout(Duration.of(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), ChronoUnit.MILLIS)).blockOptional(); if (!accessTokenOptional.isPresent()) { throw new SQLServerException(SQLServerException.getErrString("R_ManagedIdentityTokenAcquisitionFail"), @@ -408,7 +412,7 @@ static SqlAuthenticationToken getManagedIdentityCredAuthToken(String resource, * @throws SQLServerException */ static SqlAuthenticationToken getDefaultAzureCredAuthToken(String resource, - String managedIdentityClientId) throws SQLServerException { + String managedIdentityClientId, int millisecondsRemaining) throws SQLServerException { String intellijKeepassPath = System.getenv(INTELLIJ_KEEPASS_PASS); String[] additionallyAllowedTenants = getAdditonallyAllowedTenants(); @@ -463,7 +467,7 @@ static SqlAuthenticationToken getDefaultAzureCredAuthToken(String resource, SqlAuthenticationToken sqlFedAuthToken = null; - Optional accessTokenOptional = dac.getToken(tokenRequestContext).blockOptional(); + Optional accessTokenOptional = dac.getToken(tokenRequestContext).timeout(Duration.of(Math.min(millisecondsRemaining, TOKEN_WAIT_DURATION_MS), ChronoUnit.MILLIS)).blockOptional(); if (!accessTokenOptional.isPresent()) { throw new SQLServerException(SQLServerException.getErrString("R_ManagedIdentityTokenAcquisitionFail"), diff --git a/src/main/java/microsoft/sql/DateTimeOffset.java b/src/main/java/microsoft/sql/DateTimeOffset.java index bf9e95c7b..dd1de85b2 100644 --- a/src/main/java/microsoft/sql/DateTimeOffset.java +++ b/src/main/java/microsoft/sql/DateTimeOffset.java @@ -5,6 +5,8 @@ package microsoft.sql; +import java.time.OffsetDateTime; +import java.time.ZoneOffset; import java.util.Calendar; import java.util.Locale; import java.util.TimeZone; @@ -190,7 +192,6 @@ public String toString() { .substring(2), // -> "123456" formattedOffset); } - return result; } @@ -257,12 +258,32 @@ public java.sql.Timestamp getTimestamp() { * @return OffsetDateTime equivalent to this DateTimeOffset object. */ public java.time.OffsetDateTime getOffsetDateTime() { - java.time.ZoneOffset zoneOffset = java.time.ZoneOffset.ofTotalSeconds(60 * minutesOffset); - java.time.LocalDateTime localDateTime = java.time.LocalDateTime.ofEpochSecond(utcMillis / 1000, nanos, - zoneOffset); - return java.time.OffsetDateTime.of(localDateTime, zoneOffset); + // Format the offset as +hh:mm or -hh:mm. Zero offset is formatted as +00:00. + String formattedOffset = (minutesOffset < 0) ? + String.format(Locale.US, "-%1$02d:%2$02d", -minutesOffset / 60, -minutesOffset % 60) : + String.format(Locale.US, "+%1$02d:%2$02d", minutesOffset / 60, minutesOffset % 60); + + // Create a Calendar instance with the time zone set to GMT plus the formatted offset + Calendar calendar = Calendar.getInstance(TimeZone.getTimeZone("GMT" + formattedOffset), Locale.US); + // Initialize the calendar with the UTC milliseconds value + calendar.setTimeInMillis(utcMillis); + + // Extract the date and time components from the calendar + int year = calendar.get(Calendar.YEAR); + int month = calendar.get(Calendar.MONTH) + 1; // Calendar.MONTH is zero-based + int day = calendar.get(Calendar.DAY_OF_MONTH); + int hour = calendar.get(Calendar.HOUR_OF_DAY); + int minute = calendar.get(Calendar.MINUTE); + int second = calendar.get(Calendar.SECOND); + + // Create the ZoneOffset from the minutesOffset + ZoneOffset offset = ZoneOffset.ofTotalSeconds(minutesOffset * 60); + + // Create and return the OffsetDateTime + return OffsetDateTime.of(year, month, day, hour, minute, second, nanos, offset); } - + + /** * Returns this DateTimeOffset object's offset value. * diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java index 916aa419f..787c8151e 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerConnectionTest.java @@ -42,6 +42,7 @@ import com.microsoft.aad.msal4j.TokenCache; import com.microsoft.aad.msal4j.TokenCacheAccessContext; +import com.microsoft.sqlserver.jdbc.SQLServerConnection.SqlFedAuthInfo; import com.microsoft.sqlserver.testframework.AbstractSQLGenerator; import com.microsoft.sqlserver.testframework.AbstractTest; import com.microsoft.sqlserver.testframework.Constants; @@ -50,6 +51,7 @@ @RunWith(JUnitPlatform.class) public class SQLServerConnectionTest extends AbstractTest { + // If no retry is done, the function should at least exit in 5 seconds static int threshHoldForNoRetryInMilliseconds = 5000; static int loginTimeOutInSeconds = 10; @@ -1321,4 +1323,51 @@ public void testServerNameField() throws SQLException { assertTrue(e.getMessage().matches(TestUtils.formatErrorMsg("R_errorServerName"))); } } + + + @Test + public void testGetSqlFedAuthTokenFailure() throws SQLException { + try (Connection conn = getConnection()){ + SqlFedAuthInfo fedAuthInfo = ((SQLServerConnection) conn).new SqlFedAuthInfo(); + fedAuthInfo.spn = "https://database.windows.net/"; + fedAuthInfo.stsurl = "https://login.windows.net/xxx"; + SqlAuthenticationToken fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, "xxx", + "xxx",SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString(), 10); + fail(TestResource.getResource("R_expectedExceptionNotThrown")); + } catch (SQLServerException e) { + //test pass + assertTrue(e.getMessage().contains(SQLServerException.getErrString("R_connectionTimedOut")), "Expected Timeout Exception was not thrown"); + } + } + + @Test + public void testGetSqlFedAuthTokenFailureNoWaiting() throws SQLException { + try (Connection conn = getConnection()){ + SqlFedAuthInfo fedAuthInfo = ((SQLServerConnection) conn).new SqlFedAuthInfo(); + fedAuthInfo.spn = "https://database.windows.net/"; + fedAuthInfo.stsurl = "https://login.windows.net/xxx"; + SqlAuthenticationToken fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, "xxx", + "xxx",SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString(), 0); + fail(TestResource.getResource("R_expectedExceptionNotThrown")); + } catch (SQLServerException e) { + //test pass + assertTrue(e.getMessage().contains(SQLServerException.getErrString("R_connectionTimedOut")), "Expected Timeout Exception was not thrown"); + } + } + + @Test + public void testGetSqlFedAuthTokenFailureNagativeWaiting() throws SQLException { + try (Connection conn = getConnection()){ + SqlFedAuthInfo fedAuthInfo = ((SQLServerConnection) conn).new SqlFedAuthInfo(); + fedAuthInfo.spn = "https://database.windows.net/"; + fedAuthInfo.stsurl = "https://login.windows.net/xxx"; + SqlAuthenticationToken fedAuthToken = SQLServerMSAL4JUtils.getSqlFedAuthToken(fedAuthInfo, "xxx", + "xxx",SqlAuthentication.ACTIVE_DIRECTORY_PASSWORD.toString(), -1); + fail(TestResource.getResource("R_expectedExceptionNotThrown")); + } catch (SQLServerException e) { + //test pass + assertTrue(e.getMessage().contains(SQLServerException.getErrString("R_connectionTimedOut")), "Expected Timeout Exception was not thrown"); + } + } + } diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerDriverTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerDriverTest.java index 5309780ca..646ad75e9 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerDriverTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/SQLServerDriverTest.java @@ -2,6 +2,8 @@ import static org.junit.Assert.fail; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertTrue; import java.sql.Connection; @@ -190,4 +192,78 @@ public void testConnectionDriver() throws SQLException { } } } + + /** + * test application name + * + * @throws SQLException + */ + @Test + public void testApplicationName() throws SQLException { + try (Connection conn = DriverManager.getConnection(connectionString); + Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT program_name FROM sys.dm_exec_sessions WHERE session_id = @@SPID")) { + if (rs.next()) { + assertEquals(SQLServerDriver.constructedAppName, rs.getString("program_name")); + } + } catch (SQLException e) { + fail(e.getMessage()); + } + } + + /** + * test application name by executing select app_name() + * + * @throws SQLException + */ + @Test + public void testApplicationNameUsingApp_Name() throws SQLException { + try (Connection conn = DriverManager.getConnection(connectionString); + Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT app_name()")) { + if (rs.next()) { + assertEquals(SQLServerDriver.constructedAppName, rs.getString(1)); + } + } catch (SQLException e) { + fail(e.getMessage()); + } + } + + /** + * test application name by executing select app_name() + * + * @throws SQLException + */ + @Test + public void testAppNameWithSpecifiedApplicationName() throws SQLException { + String url = connectionString + ";applicationName={0123456789012345678901234567890123456789012345678901234567890123456789012345678901234589012345678901234567890123456789012345678}"; + + try (Connection conn = DriverManager.getConnection(url); + Statement stmt = conn.createStatement(); + ResultSet rs = stmt.executeQuery("SELECT app_name()")) { + if (rs.next()) { + assertEquals("0123456789012345678901234567890123456789012345678901234567890123456789012345678901234589012345678901234567890123456789012345678", rs.getString(1)); + } + } catch (SQLException e) { + fail(e.getMessage()); + } + } + + /** + * test application name when system properties are empty + * + */ + @Test + public void testGetAppName() { + String appName = SQLServerDriver.getAppName(); + assertNotNull(appName, "Application name should not be null"); + assertFalse(appName.isEmpty(), "Application name should not be empty"); + + System.setProperty("os.name", ""); + System.setProperty("os.arch", ""); + System.setProperty("java.vm.name", ""); + System.setProperty("java.vm.version", ""); + String defaultAppName = SQLServerDriver.getAppName(); + assertEquals(SQLServerDriver.DEFAULT_APP_NAME, defaultAppName, "Application name should be the default one"); + } } diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/datatypes/DataTypesTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/datatypes/DataTypesTest.java index 445d23dd0..d1bb30a2c 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/datatypes/DataTypesTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/datatypes/DataTypesTest.java @@ -1945,6 +1945,34 @@ public void testDateTimeOffsetValueOfOffsetDateTime() throws Exception { assertEquals(expected, DateTimeOffset.valueOf(roundUp).getOffsetDateTime()); assertEquals(expected, DateTimeOffset.valueOf(roundDown).getOffsetDateTime()); } + + @Test + public void testPreGregorianDateTime() throws Exception { + try (Connection conn = getConnection(); + Statement stmt = conn.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_UPDATABLE);) { + + conn.setAutoCommit(false); + TestUtils.dropTableIfExists(escapedTableName, stmt); + + stmt.executeUpdate("CREATE TABLE " + escapedTableName + " (dob datetimeoffset(7) null)"); + stmt.executeUpdate("INSERT INTO " + escapedTableName + " VALUES ('1500-12-16 00:00:00.0000000+08:00')"); + stmt.executeUpdate("INSERT INTO " + escapedTableName + " VALUES ('1400-09-27 09:30:00.0000000+08:00')"); + stmt.executeUpdate("INSERT INTO " + escapedTableName + " VALUES ('2024-12-16 23:40:00.0000000+08:00')"); + + try (ResultSet rs = stmt.executeQuery("select dob from " + escapedTableName + " order by dob")) { + while (rs.next()) { + String strDateTimeOffset = rs.getString(1).substring(0, 10); + DateTimeOffset objDateTimeOffset = (DateTimeOffset) rs.getObject(1); + OffsetDateTime objOffsetDateTime = objDateTimeOffset.getOffsetDateTime(); + + String strOffsetDateTime = objOffsetDateTime.toString().substring(0, 10); + assertEquals(strDateTimeOffset, strOffsetDateTime, "Mismatch found in DateTimeOffset : " + + objDateTimeOffset + " and OffsetDateTime : " + objOffsetDateTime); + } + } + TestUtils.dropTableIfExists(escapedTableName, stmt); + } + } static LocalDateTime getUnstorableValue() throws Exception { ZoneId systemTimezone = ZoneId.systemDefault(); diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/PreparedStatementTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/PreparedStatementTest.java index 23c89071d..87e0994aa 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/PreparedStatementTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/PreparedStatementTest.java @@ -13,8 +13,10 @@ import java.lang.reflect.Field; import java.sql.BatchUpdateException; +import java.sql.Connection; import java.sql.PreparedStatement; import java.sql.ResultSet; +import java.sql.ResultSetMetaData; import java.sql.SQLException; import java.sql.Statement; import java.sql.Types; @@ -124,6 +126,25 @@ public void testPreparedStatementWithSpPrepare() throws SQLException { } } } + + @Test + void testDatabaseQueryMetaData() throws SQLException { + try (Connection connection = getConnection()) { + try (SQLServerPreparedStatement stmt = (SQLServerPreparedStatement) connection.prepareStatement( + "select 1 as \"any questions ???\"")) { + ResultSetMetaData metaData = stmt.getMetaData(); + String actualLabel = metaData.getColumnLabel(1); + String actualName = metaData.getColumnName(1); + + String expected = "any questions ???"; + assertEquals(expected, actualLabel, "Column label should match the expected value"); + assertEquals(expected, actualName, "Column name should match the expected value"); + } + } catch (SQLException e) { + e.printStackTrace(); + fail("SQLException occurred during test: " + e.getMessage()); + } + } @Test public void testPreparedStatementParamNameSpacingWithMultipleParams() throws SQLException { @@ -927,5 +948,5 @@ private static void dropTables() throws Exception { TestUtils.dropTableIfExists(AbstractSQLGenerator.escapeIdentifier(tableName5), stmt); } } - + }