Skip to content

Commit 712740b

Browse files
thoven87fabianfett
andauthored
Add withTransaction API (#519)
Co-authored-by: Fabian Fett <[email protected]>
1 parent 8d07f20 commit 712740b

File tree

2 files changed

+126
-0
lines changed

2 files changed

+126
-0
lines changed

Diff for: Sources/PostgresNIO/Pool/PostgresClient.swift

+22
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,28 @@ public final class PostgresClient: Sendable, ServiceLifecycle.Service {
307307

308308
return try await closure(connection)
309309
}
310+
311+
/// Lease a connection for the provided `closure`'s lifetime.
312+
/// A transation starts with call to withConnection
313+
/// A transaction should end with a call to COMMIT or ROLLBACK
314+
/// COMMIT is called upon successful completion and ROLLBACK is called should any steps fail
315+
///
316+
/// - Parameter closure: A closure that uses the passed `PostgresConnection`. The closure **must not** capture
317+
/// the provided `PostgresConnection`.
318+
/// - Returns: The closure's return value.
319+
public func withTransaction<Result>(_ process: (PostgresConnection) async throws -> Result) async throws -> Result {
320+
try await withConnection { connection in
321+
try await connection.query("BEGIN;", logger: self.backgroundLogger)
322+
do {
323+
let value = try await process(connection)
324+
try await connection.query("COMMIT;", logger: self.backgroundLogger)
325+
return value
326+
} catch {
327+
try await connection.query("ROLLBACK;", logger: self.backgroundLogger)
328+
throw error
329+
}
330+
}
331+
}
310332

311333
/// Run a query on the Postgres server the client is connected to.
312334
///

Diff for: Tests/IntegrationTests/PostgresClientTests.swift

+104
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,110 @@ final class PostgresClientTests: XCTestCase {
4242
taskGroup.cancelAll()
4343
}
4444
}
45+
46+
func testTransaction() async throws {
47+
var mlogger = Logger(label: "test")
48+
mlogger.logLevel = .debug
49+
let logger = mlogger
50+
let eventLoopGroup = MultiThreadedEventLoopGroup(numberOfThreads: 8)
51+
self.addTeardownBlock {
52+
try await eventLoopGroup.shutdownGracefully()
53+
}
54+
55+
let tableName = "test_client_transactions"
56+
57+
let clientConfig = PostgresClient.Configuration.makeTestConfiguration()
58+
let client = PostgresClient(configuration: clientConfig, eventLoopGroup: eventLoopGroup, backgroundLogger: logger)
59+
60+
do {
61+
try await withThrowingTaskGroup(of: Void.self) { taskGroup in
62+
taskGroup.addTask {
63+
await client.run()
64+
}
65+
66+
try await client.query(
67+
"""
68+
CREATE TABLE IF NOT EXISTS "\(unescaped: tableName)" (
69+
id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY,
70+
uuid UUID NOT NULL
71+
);
72+
""",
73+
logger: logger
74+
)
75+
76+
let iterations = 1000
77+
78+
for _ in 0..<iterations {
79+
taskGroup.addTask {
80+
let _ = try await client.withTransaction { transaction in
81+
try await transaction.query(
82+
"""
83+
INSERT INTO "\(unescaped: tableName)" (uuid) VALUES (\(UUID()));
84+
""",
85+
logger: logger
86+
)
87+
}
88+
}
89+
}
90+
91+
for _ in 0..<iterations {
92+
_ = await taskGroup.nextResult()!
93+
}
94+
95+
let rows = try await client.query(#"SELECT COUNT(1)::INT AS table_size FROM "\#(unescaped: tableName)";"#, logger: logger).decode(Int.self)
96+
for try await (count) in rows {
97+
XCTAssertEqual(count, iterations)
98+
}
99+
100+
/// Test roll back
101+
taskGroup.addTask {
102+
103+
do {
104+
let _ = try await client.withTransaction { transaction in
105+
/// insert valid data
106+
try await transaction.query(
107+
"""
108+
INSERT INTO "\(unescaped: tableName)" (uuid) VALUES (\(UUID()));
109+
""",
110+
logger: logger
111+
)
112+
113+
/// insert invalid data
114+
try await transaction.query(
115+
"""
116+
INSERT INTO "\(unescaped: tableName)" (uuid) VALUES (\(iterations));
117+
""",
118+
logger: logger
119+
)
120+
}
121+
} catch {
122+
XCTAssertNotNil(error)
123+
guard let error = error as? PSQLError else { return XCTFail("Unexpected error type") }
124+
125+
XCTAssertEqual(error.code, .server)
126+
XCTAssertEqual(error.serverInfo?[.severity], "ERROR")
127+
}
128+
}
129+
130+
let row = try await client.query(#"SELECT COUNT(1)::INT AS table_size FROM "\#(unescaped: tableName)";"#, logger: logger).decode(Int.self)
131+
132+
for try await (count) in row {
133+
XCTAssertEqual(count, iterations)
134+
}
135+
136+
try await client.query(
137+
"""
138+
DROP TABLE "\(unescaped: tableName)";
139+
""",
140+
logger: logger
141+
)
142+
143+
taskGroup.cancelAll()
144+
}
145+
} catch {
146+
XCTFail("Unexpected error: \(String(reflecting: error))")
147+
}
148+
}
45149

46150
func testApplicationNameIsForwardedCorrectly() async throws {
47151
var mlogger = Logger(label: "test")

0 commit comments

Comments
 (0)