@@ -42,6 +42,110 @@ final class PostgresClientTests: XCTestCase {
42
42
taskGroup. cancelAll ( )
43
43
}
44
44
}
45
+
46
+ func testTransaction( ) async throws {
47
+ var mlogger = Logger ( label: " test " )
48
+ mlogger. logLevel = . debug
49
+ let logger = mlogger
50
+ let eventLoopGroup = MultiThreadedEventLoopGroup ( numberOfThreads: 8 )
51
+ self . addTeardownBlock {
52
+ try await eventLoopGroup. shutdownGracefully ( )
53
+ }
54
+
55
+ let tableName = " test_client_transactions "
56
+
57
+ let clientConfig = PostgresClient . Configuration. makeTestConfiguration ( )
58
+ let client = PostgresClient ( configuration: clientConfig, eventLoopGroup: eventLoopGroup, backgroundLogger: logger)
59
+
60
+ do {
61
+ try await withThrowingTaskGroup ( of: Void . self) { taskGroup in
62
+ taskGroup. addTask {
63
+ await client. run ( )
64
+ }
65
+
66
+ try await client. query (
67
+ """
68
+ CREATE TABLE IF NOT EXISTS " \( unescaped: tableName) " (
69
+ id INT PRIMARY KEY GENERATED ALWAYS AS IDENTITY,
70
+ uuid UUID NOT NULL
71
+ );
72
+ """ ,
73
+ logger: logger
74
+ )
75
+
76
+ let iterations = 1000
77
+
78
+ for _ in 0 ..< iterations {
79
+ taskGroup. addTask {
80
+ let _ = try await client. withTransaction { transaction in
81
+ try await transaction. query (
82
+ """
83
+ INSERT INTO " \( unescaped: tableName) " (uuid) VALUES ( \( UUID ( ) ) );
84
+ """ ,
85
+ logger: logger
86
+ )
87
+ }
88
+ }
89
+ }
90
+
91
+ for _ in 0 ..< iterations {
92
+ _ = await taskGroup. nextResult ( ) !
93
+ }
94
+
95
+ let rows = try await client. query ( #"SELECT COUNT(1)::INT AS table_size FROM " \#( unescaped: tableName) ";"# , logger: logger) . decode ( Int . self)
96
+ for try await (count) in rows {
97
+ XCTAssertEqual ( count, iterations)
98
+ }
99
+
100
+ /// Test roll back
101
+ taskGroup. addTask {
102
+
103
+ do {
104
+ let _ = try await client. withTransaction { transaction in
105
+ /// insert valid data
106
+ try await transaction. query (
107
+ """
108
+ INSERT INTO " \( unescaped: tableName) " (uuid) VALUES ( \( UUID ( ) ) );
109
+ """ ,
110
+ logger: logger
111
+ )
112
+
113
+ /// insert invalid data
114
+ try await transaction. query (
115
+ """
116
+ INSERT INTO " \( unescaped: tableName) " (uuid) VALUES ( \( iterations) );
117
+ """ ,
118
+ logger: logger
119
+ )
120
+ }
121
+ } catch {
122
+ XCTAssertNotNil ( error)
123
+ guard let error = error as? PSQLError else { return XCTFail ( " Unexpected error type " ) }
124
+
125
+ XCTAssertEqual ( error. code, . server)
126
+ XCTAssertEqual ( error. serverInfo ? [ . severity] , " ERROR " )
127
+ }
128
+ }
129
+
130
+ let row = try await client. query ( #"SELECT COUNT(1)::INT AS table_size FROM " \#( unescaped: tableName) ";"# , logger: logger) . decode ( Int . self)
131
+
132
+ for try await (count) in row {
133
+ XCTAssertEqual ( count, iterations)
134
+ }
135
+
136
+ try await client. query (
137
+ """
138
+ DROP TABLE " \( unescaped: tableName) " ;
139
+ """ ,
140
+ logger: logger
141
+ )
142
+
143
+ taskGroup. cancelAll ( )
144
+ }
145
+ } catch {
146
+ XCTFail ( " Unexpected error: \( String ( reflecting: error) ) " )
147
+ }
148
+ }
45
149
46
150
func testApplicationNameIsForwardedCorrectly( ) async throws {
47
151
var mlogger = Logger ( label: " test " )
0 commit comments