From 25ea8da0a0a56ee45ad79daa2b41cecf114312b2 Mon Sep 17 00:00:00 2001 From: Mahendra Chavan Date: Thu, 30 Jan 2025 13:38:48 +0530 Subject: [PATCH] Issue#2550 - Fixed getGeneratedKeys functionality for execute API (#2554) * Issue#2550 - Fixed getGeneratedKeys functionality for execute API * Adapted the test for working with jdk8 * Fixed indenetation * Enable thre new TCGenKeys tests for AzureDW * Incorporated review comments. * Incorporated review comments * Incorporated review comments * Add a test for PreparedStatement * Adding a fix and test case for issue # 2587 * Added a new test for execute API with set no count --- .../jdbc/SQLServerPreparedStatement.java | 2 +- .../sqlserver/jdbc/SQLServerStatement.java | 41 +- .../microsoft/sqlserver/jdbc/StreamDone.java | 2 +- .../jdbc/unit/statement/StatementTest.java | 386 ++++++++++++++++++ 4 files changed, 419 insertions(+), 12 deletions(-) diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java index 96ce9858f..2a0c79d97 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerPreparedStatement.java @@ -705,7 +705,7 @@ final void doExecutePreparedStatement(PrepStmtExecCmd command) throws SQLServerE if (EXECUTE_QUERY == executeMethod && null == resultSet) { SQLServerException.makeFromDriverError(connection, this, SQLServerException.getErrString("R_noResultset"), null, true); - } else if (EXECUTE_UPDATE == executeMethod && null != resultSet) { + } else if ((EXECUTE_UPDATE == executeMethod) && (null != resultSet) && !bRequestedGeneratedKeys) { SQLServerException.makeFromDriverError(connection, this, SQLServerException.getErrString("R_resultsetGeneratedForUpdate"), null, false); } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerStatement.java b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerStatement.java index ab9e9fbe2..3faa0ebb3 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerStatement.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/SQLServerStatement.java @@ -1601,9 +1601,15 @@ boolean onDone(TDSReader tdsReader) throws SQLServerException { if (null != procedureName) return false; + //For Insert, we must fetch additional TDS_DONE token that comes with the actual update count + if ((StreamDone.CMD_INSERT == doneToken.getCurCmd()) && (-1 != doneToken.getUpdateCount()) && EXECUTE == executeMethod) { + return true; + } + // Always return all update counts from statements executed through Statement.execute() - if (EXECUTE == executeMethod) - return false; + if (EXECUTE == executeMethod) { + return false; + } // Statement.executeUpdate() may or may not return this update count depending on the // setting of the lastUpdateCount connection property: @@ -2357,17 +2363,27 @@ public final ResultSet getGeneratedKeys() throws SQLServerException { if (null == autoGeneratedKeys) { long orgUpd = updateCount; + // + //A case of SET NOCOUNT ON and GENERATED KEYS requested + //where we may not have received update count but would have already read the resultset + //so directly consume it. + // + if ((executeMethod != EXECUTE_QUERY) && bRequestedGeneratedKeys && (resultSet != null)) { + autoGeneratedKeys = resultSet; + updateCount = orgUpd; + } else { - // Generated keys are returned in a ResultSet result right after the update count. - // Try to get that ResultSet. If there are no more results after the update count, - // or if the next result isn't a ResultSet, then something is wrong. - if (!getNextResult(true) || null == resultSet) { - SQLServerException.makeFromDriverError(connection, this, - SQLServerException.getErrString("R_statementMustBeExecuted"), null, false); + // Generated keys are returned in a ResultSet result right after the update count. + // Try to get that ResultSet. If there are no more results after the update count, + // or if the next result isn't a ResultSet, then something is wrong. + if (!getNextResult(true) || null == resultSet) { + SQLServerException.makeFromDriverError(connection, this, + SQLServerException.getErrString("R_statementMustBeExecuted"), null, false); + } + autoGeneratedKeys = resultSet; + updateCount = orgUpd; } - autoGeneratedKeys = resultSet; - updateCount = orgUpd; } loggerExternal.exiting(getClassNameLogging(), "getGeneratedKeys", autoGeneratedKeys); return autoGeneratedKeys; @@ -2616,6 +2632,11 @@ SQLServerColumnEncryptionKeyStoreProvider getColumnEncryptionKeyStoreProvider( lock.unlock(); } } + + protected void setAutoGeneratedKey(SQLServerResultSet rs) { + autoGeneratedKeys = rs; + } + } diff --git a/src/main/java/com/microsoft/sqlserver/jdbc/StreamDone.java b/src/main/java/com/microsoft/sqlserver/jdbc/StreamDone.java index 319c9ad6c..ede354260 100644 --- a/src/main/java/com/microsoft/sqlserver/jdbc/StreamDone.java +++ b/src/main/java/com/microsoft/sqlserver/jdbc/StreamDone.java @@ -232,7 +232,7 @@ final long getUpdateCount() { } final boolean cmdIsDMLOrDDL() { - switch (curCmd) { + switch (curCmd) { case CMD_INSERT: case CMD_BULKINSERT: case CMD_DELETE: diff --git a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/StatementTest.java b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/StatementTest.java index 9c814916d..959837f0a 100644 --- a/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/StatementTest.java +++ b/src/test/java/com/microsoft/sqlserver/jdbc/unit/statement/StatementTest.java @@ -24,6 +24,8 @@ import java.sql.Statement; import java.sql.Types; import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; import java.util.Random; import java.util.UUID; import java.util.concurrent.Executors; @@ -2692,4 +2694,388 @@ public void terminate() throws Exception { } } } + + + @Nested + public class TCGenKeys { + private final String tableName = AbstractSQLGenerator + .escapeIdentifier(RandomUtil.getIdentifier("TCInsertWithGenKeys")); + private final String idTableName = AbstractSQLGenerator + .escapeIdentifier(RandomUtil.getIdentifier("TCInsertWithGenKeysIDs")); + + private final String triggerName = AbstractSQLGenerator.escapeIdentifier("Trigger"); + private final int NUM_ROWS = 3; + + @BeforeEach + public void setup() throws Exception { + try (Connection con = getConnection()) { + con.setAutoCommit(false); + try (Statement stmt = con.createStatement()) { + TestUtils.dropTriggerIfExists(triggerName, stmt); + stmt.executeUpdate("CREATE TABLE " + tableName + " (ID int NOT NULL IDENTITY(1,1) PRIMARY KEY, NAME varchar(32));"); + stmt.executeUpdate("CREATE TABLE " + idTableName + "(ID int NOT NULL IDENTITY(1,1) PRIMARY KEY);"); + stmt.executeUpdate("CREATE TRIGGER " + triggerName + " ON " + tableName + + " FOR INSERT AS INSERT INTO " + idTableName + " DEFAULT VALUES;"); + for (int i = 0; i < NUM_ROWS; i++) { + stmt.executeUpdate("INSERT INTO " + tableName + " (NAME) VALUES ('test')"); + } + } + con.commit(); + } + } + + /** + * Tests executeUpdate for Insert followed by getGenerateKeys + * + * @throws Exception + */ + @Test + public void testExecuteUpdateInsertAndGenKeys() { + try (Connection con = getConnection()) { + try(Statement stmt = con.createStatement()) { + String sql = "INSERT INTO " + tableName + " (NAME) VALUES('test')"; + List lst = Arrays.asList("ID"); + String[] arr = lst.toArray(new String[0]); + stmt.executeUpdate(sql, arr); + try (ResultSet generatedKeys = stmt.getGeneratedKeys()) { + if (generatedKeys.next()) { + int id = generatedKeys.getInt(1); + assertEquals(id, 4, "id should have been 4, but received : " + id); + } + } + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Tests executeUpdate using PreparedStatement for Insert followed by getGenerateKeys + * + * @throws Exception + */ + @Test + public void testPrepStmtExecuteUpdateInsertAndGenKeys() { + try (Connection con = getConnection()) { + String sql = "INSERT INTO " + tableName + " (NAME) VALUES('test')"; + try(PreparedStatement stmt = con.prepareStatement(sql,PreparedStatement.RETURN_GENERATED_KEYS)) { + stmt.executeUpdate(); + try (ResultSet generatedKeys = stmt.getGeneratedKeys()) { + if (generatedKeys.next()) { + int id = generatedKeys.getInt(1); + assertEquals(id, 4, "id should have been 4, but received : " + id); + } + } + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Tests executeUpdate using PreparedStatement for Insert followed by getGenerateKeys + * + * @throws Exception + */ + @Test + public void testPrepStmtNoCountExecuteUpdateInsertAndGenKeys() { + try (Connection con = getConnection()) { + String sql = "SET NOCOUNT ON; INSERT INTO " + tableName + " (NAME) VALUES('test')"; + try(PreparedStatement stmt = con.prepareStatement(sql,PreparedStatement.RETURN_GENERATED_KEYS)) { + stmt.executeUpdate(); + try (ResultSet generatedKeys = stmt.getGeneratedKeys()) { + if (generatedKeys.next()) { + int id = generatedKeys.getInt(1); + assertEquals(id, 4, "id should have been 4, but received : " + id); + } + } + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Tests executeUpdate using PreparedStatement for Insert followed by getGenerateKeys + * + * @throws Exception + */ + @Test + public void testPrepStmtNoCountExecuteInsertAndGenKeys() { + try (Connection con = getConnection()) { + String sql = "SET NOCOUNT ON; INSERT INTO " + tableName + " (NAME) VALUES('test')"; + try(PreparedStatement stmt = con.prepareStatement(sql,PreparedStatement.RETURN_GENERATED_KEYS)) { + stmt.execute(); + try (ResultSet generatedKeys = stmt.getGeneratedKeys()) { + if (generatedKeys.next()) { + int id = generatedKeys.getInt(1); + assertEquals(id, 4, "id should have been 4, but received : " + id); + } + } + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + + /** + * Tests execute for Insert followed by getGenerateKeys + * + * @throws Exception + */ + @Test + public void testExecuteInsertAndGenKeys() { + try (Connection con = getConnection()) { + try(Statement stmt = con.createStatement()) { + String sql = "INSERT INTO " + tableName + " (NAME) VALUES('test')"; + List lst = Arrays.asList("ID"); + String[] arr = lst.toArray(new String[0]); + stmt.execute(sql, arr); + try (ResultSet generatedKeys = stmt.getGeneratedKeys()) { + if (generatedKeys.next()) { + int id = generatedKeys.getInt(1); + assertEquals(id, 4, "generated key should have been 4"); + } + } + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Tests execute for Insert followed by select + * + * @throws Exception + */ + @Test + public void testExecuteInsertAndSelect() { + + try (Connection con = getConnection()) { + try(Statement stmt = con.createStatement()) { + boolean retval = stmt.execute("INSERT INTO " + tableName +" (NAME) VALUES('test') SELECT NAME FROM " + tableName + " WHERE ID = 1"); + do { + if (!retval) { + int count = stmt.getUpdateCount(); + if (count == -1) { + // no more results + break; + } else { + assertEquals(count, 1, "update count should have been 1"); + } + } else { + // process ResultSet + try (ResultSet rs = stmt.getResultSet()) { + if (rs.next()) { + String val = rs.getString(1); + assertEquals(val, "test", "read value should have been 'test'"); + } + } + } + retval = stmt.getMoreResults(); + } while (true); + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + + /** + * Tests execute for Merge followed by select + * + * @throws Exception + */ + @Test + public void testExecuteMergeAndSelect() { + try (Connection con = getConnection()) { + try(Statement stmt = con.createStatement()) { + boolean retval = stmt.execute("MERGE INTO " + tableName + " AS target USING (VALUES ('test1')) AS source (name) ON target.name = source.name WHEN NOT MATCHED THEN INSERT (name) VALUES ('test1'); SELECT NAME FROM " + tableName + " WHERE ID = 1"); + do { + if (!retval) { + int count = stmt.getUpdateCount(); + if (count == -1) { + // no more results + break; + } else { + assertEquals(count, 1, "update count should have been 1"); + } + } else { + // process ResultSet + try (ResultSet rs = stmt.getResultSet()) { + if (rs.next()) { + String val = rs.getString(1); + assertEquals(val, "test", "read value should have been 'test'"); + } + } + + } + retval = stmt.getMoreResults(); + } while (true); + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Tests execute for Insert multiple rows followed by select + * + * @throws Exception + */ + @Test + public void testExecuteInsertManyRowsAndSelect() { + try (Connection con = getConnection()) { + try (Statement stmt = con.createStatement()) { + boolean retval = stmt.execute("INSERT INTO " + tableName + " SELECT NAME FROM " + tableName + " SELECT NAME FROM " + tableName + " WHERE ID = 1"); + do { + if (!retval) { + int count = stmt.getUpdateCount(); + if (count == -1) { + // no more results + break; + } else { + assertEquals(count, 3, "update count should have been 6"); + } + } else { + // process ResultSet + try (ResultSet rs = stmt.getResultSet()) { + if (rs.next()) { + String val = rs.getString(1); + assertEquals(val, "test", "read value should have been 'test'"); + } + } + + } + retval = stmt.getMoreResults(); + } while (true); + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Tests execute two Inserts followed by select + * + * @throws Exception + */ + @Test + public void testExecuteTwoInsertsRowsAndSelect() { + try (Connection con = getConnection()) { + try(Statement stmt = con.createStatement()) { + boolean retval = stmt.execute("INSERT INTO " + tableName + " (NAME) VALUES('test') INSERT INTO " + tableName + " (NAME) VALUES('test') SELECT NAME from " + tableName + " WHERE ID = 1"); + do { + if (!retval) { + int count = stmt.getUpdateCount(); + if (count == -1) { + // no more results + break; + } else { + assertEquals(count, 1, "update count should have been 2"); + } + } else { + // process ResultSet + try (ResultSet rs = stmt.getResultSet()) { + if (rs.next()) { + String val = rs.getString(1); + assertEquals(val, "test", "read value should have been 'test'"); + } + } + + } + retval = stmt.getMoreResults(); + } while (true); + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + + /** + * Tests execute for Update followed by select + * + * @throws Exception + */ + @Test + public void testExecuteUpdAndSelect() { + try (Connection con = getConnection()) { + try(Statement stmt = con.createStatement()) { + boolean retval = stmt.execute("UPDATE " + tableName +" SET NAME = 'test' SELECT NAME FROM " + tableName + " WHERE ID = 1"); + do { + if (!retval) { + int count = stmt.getUpdateCount(); + if (count == -1) { + // no more results + break; + } else { + assertEquals(count, 3, "update count should have been 3"); + } + } else { + // process ResultSet + try (ResultSet rs = stmt.getResultSet()) { + if (rs.next()) { + String val = rs.getString(1); + assertEquals(val, "test", "read value should have been 'test'"); + } + } + } + retval = stmt.getMoreResults(); + } while (true); + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + /** + * Tests execute for Update followed by select + * + * @throws Exception + */ + @Test + public void testExecuteDelAndSelect() { + try (Connection con = getConnection()) { + try(Statement stmt = con.createStatement()) { + boolean retval = stmt.execute("DELETE FROM " + tableName +" WHERE ID = 1 SELECT NAME FROM " + tableName + " WHERE ID = 2"); + do { + if (!retval) { + int count = stmt.getUpdateCount(); + if (count == -1) { + // no more results + break; + } else { + assertEquals(count, 1, "update count should have been 1"); + } + } else { + // process ResultSet + try (ResultSet rs = stmt.getResultSet()) { + if (rs.next()) { + String val = rs.getString(1); + assertEquals(val, "test", "read value should have been 'test'"); + } + } + } + retval = stmt.getMoreResults(); + } while (true); + } + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + + @AfterEach + public void terminate() { + try (Connection con = getConnection(); Statement stmt = con.createStatement()) { + TestUtils.dropTriggerIfExists(triggerName, stmt); + TestUtils.dropTableIfExists(idTableName, stmt); + TestUtils.dropTableIfExists(tableName, stmt); + } catch (SQLException e) { + fail(TestResource.getResource("R_unexpectedException") + e.getMessage()); + } + } + } + }