Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue#2550 - Fixed getGeneratedKeys functionality for execute API #2554

Merged
merged 12 commits into from
Jan 30, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -1601,9 +1601,15 @@ boolean onDone(TDSReader tdsReader) throws SQLServerException {
if (null != procedureName)
return false;

//For Insert, we must fetch additional TDS_DONE token that comes with the actual update count
if (doneToken.cmdIsInsert() && (-1 != doneToken.getUpdateCount()) && EXECUTE == executeMethod) {
machavan marked this conversation as resolved.
Show resolved Hide resolved
return true;
}

// Always return all update counts from statements executed through Statement.execute()
if (EXECUTE == executeMethod)
return false;
if (EXECUTE == executeMethod) {
return false;
}

// Statement.executeUpdate() may or may not return this update count depending on the
// setting of the lastUpdateCount connection property:
Expand Down
6 changes: 5 additions & 1 deletion src/main/java/com/microsoft/sqlserver/jdbc/StreamDone.java
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,12 @@ final long getUpdateCount() {
}
}

final boolean cmdIsInsert() {
return (CMD_INSERT == curCmd);
}

final boolean cmdIsDMLOrDDL() {
switch (curCmd) {
switch (curCmd) {
case CMD_INSERT:
case CMD_BULKINSERT:
case CMD_DELETE:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import java.sql.Statement;
import java.sql.Types;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;
import java.util.UUID;
import java.util.concurrent.Executors;
Expand Down Expand Up @@ -2692,4 +2694,323 @@ public void terminate() throws Exception {
}
}
}


@Nested
public class TCGenKeys {
private final String tableName = AbstractSQLGenerator
.escapeIdentifier(RandomUtil.getIdentifier("TCInsertWithGenKeys"));
private final String idTableName = AbstractSQLGenerator
.escapeIdentifier(RandomUtil.getIdentifier("TCInsertWithGenKeysIDs"));

private final String triggerName = AbstractSQLGenerator.escapeIdentifier("Trigger");
private final int NUM_ROWS = 3;

@BeforeEach
public void setup() throws Exception {
try (Connection con = getConnection()) {
con.setAutoCommit(false);
try (Statement stmt = con.createStatement()) {
TestUtils.dropTriggerIfExists(triggerName, stmt);
stmt.executeUpdate("CREATE TABLE " + tableName + " (ID int NOT NULL IDENTITY(1,1) PRIMARY KEY, NAME varchar(32));");

stmt.executeUpdate("CREATE TABLE " + idTableName + "(ID int NOT NULL IDENTITY(1,1) PRIMARY KEY);");

stmt.executeUpdate("CREATE TRIGGER " + triggerName + " ON " + tableName + " FOR INSERT AS INSERT INTO " + idTableName + " DEFAULT VALUES;");
machavan marked this conversation as resolved.
Show resolved Hide resolved

for (int i = 0; i < NUM_ROWS; i++) {
stmt.executeUpdate("INSERT INTO " + tableName + " (NAME) VALUES ('test')");
}

}
con.commit();
}
}

/**
* Tests executeUpdate for Insert followed by getGenerateKeys
*
* @throws Exception
*/
@Test
public void testExecuteUpdateInsertAndGenKeys() throws Exception {
machavan marked this conversation as resolved.
Show resolved Hide resolved
try (Connection con = getConnection()) {
try(Statement stmt = con.createStatement()) {
String sql = "INSERT INTO " + tableName + " (NAME) VALUES('test')";
List<String> lst = Arrays.asList("ID");
String[] arr = lst.toArray(new String[0]);
stmt.executeUpdate(sql, arr);
try (ResultSet generatedKeys = stmt.getGeneratedKeys()) {
if (generatedKeys.next()) {
int id = generatedKeys.getInt(1);
assertEquals(id, 4, "id should have been 4, but received : " + id);
}
}
}
} catch (SQLException e) {
fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
}
}

/**
* Tests execute for Insert followed by getGenerateKeys
*
* @throws Exception
*/
@Test
public void testExecuteInsertAndGenKeys() throws Exception {
try (Connection con = getConnection()) {
try(Statement stmt = con.createStatement()) {
String sql = "INSERT INTO " + tableName + " (NAME) VALUES('test')";
List<String> lst = Arrays.asList("ID");
String[] arr = lst.toArray(new String[0]);
stmt.execute(sql, arr);
try (ResultSet generatedKeys = stmt.getGeneratedKeys()) {
if (generatedKeys.next()) {
int id = generatedKeys.getInt(1);
assertEquals(id, 4, "generated key should have been 4");
}
}
}
} catch (SQLException e) {
fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
}
}

/**
* Tests execute for Insert followed by select
*
* @throws Exception
*/
@Test
public void testExecuteInsertAndSelect() throws Exception {

try (Connection con = getConnection()) {
try(Statement stmt = con.createStatement()) {
boolean retval = stmt.execute("INSERT INTO " + tableName +" (NAME) VALUES('test') SELECT NAME FROM " + tableName + " WHERE ID = 1");
do {
if (retval == false) {
machavan marked this conversation as resolved.
Show resolved Hide resolved
int count = stmt.getUpdateCount();
if (count == -1) {
// no more results
break;
} else {
assertEquals(count, 1, "update count should have been 1");
}
} else {
// process ResultSet
try (ResultSet rs = stmt.getResultSet()) {
if (rs.next()) {
String val = rs.getString(1);
assertEquals(val, "test", "read value should have been 'test'");
}
}
}
retval = stmt.getMoreResults();
} while (true);
}
} catch (SQLException e) {
fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
}
}


/**
* Tests execute for Merge followed by select
*
* @throws Exception
*/
@Test
public void testExecuteMergeAndSelect() throws Exception {
try (Connection con = getConnection()) {
try(Statement stmt = con.createStatement()) {
boolean retval = stmt.execute("MERGE INTO " + tableName + " AS target USING (VALUES ('test1')) AS source (name) ON target.name = source.name WHEN NOT MATCHED THEN INSERT (name) VALUES ('test1'); SELECT NAME FROM " + tableName + " WHERE ID = 1");
do {
if (retval == false) {
int count = stmt.getUpdateCount();
if (count == -1) {
// no more results
break;
} else {
assertEquals(count, 1, "update count should have been 1");
}
} else {
// process ResultSet
try (ResultSet rs = stmt.getResultSet()) {
if (rs.next()) {
String val = rs.getString(1);
assertEquals(val, "test", "read value should have been 'test'");
}
}

}
retval = stmt.getMoreResults();
} while (true);
}
} catch (SQLException e) {
fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
}
}

/**
* Tests execute for Insert multiple rows followed by select
*
* @throws Exception
*/
@Test
public void testExecuteInsertManyRowsAndSelect() throws Exception {
try (Connection con = getConnection()) {
try (Statement stmt = con.createStatement()) {
boolean retval = stmt.execute("INSERT INTO " + tableName + " SELECT NAME FROM " + tableName + " SELECT NAME FROM " + tableName + " WHERE ID = 1");
do {
if (retval == false) {
int count = stmt.getUpdateCount();
if (count == -1) {
// no more results
break;
} else {
assertEquals(count, 3, "update count should have been 6");
}
} else {
// process ResultSet
try (ResultSet rs = stmt.getResultSet()) {
if (rs.next()) {
String val = rs.getString(1);
assertEquals(val, "test", "read value should have been 'test'");
}
}

}
retval = stmt.getMoreResults();
} while (true);
}
} catch (SQLException e) {
fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
}
}

/**
* Tests execute two Inserts followed by select
*
* @throws Exception
*/
@Test
public void testExecuteTwoInsertsRowsAndSelect() throws Exception {
try (Connection con = getConnection()) {
try(Statement stmt = con.createStatement()) {
boolean retval = stmt.execute("INSERT INTO " + tableName + " (NAME) VALUES('test') INSERT INTO " + tableName + " (NAME) VALUES('test') SELECT NAME from " + tableName + " WHERE ID = 1");
do {
if (retval == false) {
int count = stmt.getUpdateCount();
if (count == -1) {
// no more results
break;
} else {
assertEquals(count, 1, "update count should have been 2");
}
} else {
// process ResultSet
try (ResultSet rs = stmt.getResultSet()) {
if (rs.next()) {
String val = rs.getString(1);
assertEquals(val, "test", "read value should have been 'test'");
}
}

}
retval = stmt.getMoreResults();
} while (true);
}
} catch (SQLException e) {
fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
}
}


/**
* Tests execute for Update followed by select
*
* @throws Exception
*/
@Test
public void testExecuteUpdAndSelect() throws Exception {
try (Connection con = getConnection()) {
try(Statement stmt = con.createStatement()) {
boolean retval = stmt.execute("UPDATE " + tableName +" SET NAME = 'test' SELECT NAME FROM " + tableName + " WHERE ID = 1");
do {
if (retval == false) {
int count = stmt.getUpdateCount();
if (count == -1) {
// no more results
break;
} else {
assertEquals(count, 3, "update count should have been 3");
}
} else {
// process ResultSet
try (ResultSet rs = stmt.getResultSet()) {
if (rs.next()) {
String val = rs.getString(1);
assertEquals(val, "test", "read value should have been 'test'");
}
}
}
retval = stmt.getMoreResults();
} while (true);
}
} catch (SQLException e) {
fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
}
}

/**
* Tests execute for Update followed by select
*
* @throws Exception
*/
@Test
public void testExecuteDelAndSelect() throws Exception {
try (Connection con = getConnection()) {
try(Statement stmt = con.createStatement()) {
boolean retval = stmt.execute("DELETE FROM " + tableName +" WHERE ID = 1 SELECT NAME FROM " + tableName + " WHERE ID = 2");
do {
if (retval == false) {
int count = stmt.getUpdateCount();
if (count == -1) {
// no more results
break;
} else {
assertEquals(count, 1, "update count should have been 1");
}
} else {
// process ResultSet
try (ResultSet rs = stmt.getResultSet()) {
if (rs.next()) {
String val = rs.getString(1);
assertEquals(val, "test", "read value should have been 'test'");
}
}
}
retval = stmt.getMoreResults();
} while (true);
}
} catch (SQLException e) {
fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
}
}

@AfterEach
public void terminate() throws Exception {
try (Connection con = getConnection(); Statement stmt = con.createStatement()) {
try {
TestUtils.dropTriggerIfExists(triggerName, stmt);
TestUtils.dropTableIfExists(idTableName, stmt);
TestUtils.dropTableIfExists(tableName, stmt);
} catch (SQLException e) {
fail(TestResource.getResource("R_unexpectedException") + e.getMessage());
}
}
}
}

}
Loading