@@ -737,17 +737,22 @@ pub struct Streams(RwLock<HashMap<String, StreamRef>>);
737
737
// 4. When first event is sent to stream (update the schema)
738
738
// 5. When set alert API is called (update the alert)
739
739
impl Streams {
740
- pub fn create (
740
+ /// Checks after getting an exclusive lock whether the stream already exists, else creates it.
741
+ /// NOTE: This is done to ensure we don't have contention among threads.
742
+ pub fn get_or_create (
741
743
& self ,
742
744
options : Arc < Options > ,
743
745
stream_name : String ,
744
746
metadata : LogStreamMetadata ,
745
747
ingestor_id : Option < String > ,
746
748
) -> StreamRef {
749
+ let mut guard = self . write ( ) . expect ( LOCK_EXPECT ) ;
750
+ if let Some ( stream) = guard. get ( & stream_name) {
751
+ return stream. clone ( ) ;
752
+ }
753
+
747
754
let stream = Stream :: new ( options, & stream_name, metadata, ingestor_id) ;
748
- self . write ( )
749
- . expect ( LOCK_EXPECT )
750
- . insert ( stream_name, stream. clone ( ) ) ;
755
+ guard. insert ( stream_name, stream. clone ( ) ) ;
751
756
752
757
stream
753
758
}
@@ -812,7 +817,7 @@ impl Streams {
812
817
813
818
#[ cfg( test) ]
814
819
mod tests {
815
- use std:: time:: Duration ;
820
+ use std:: { sync :: Barrier , thread :: spawn , time:: Duration } ;
816
821
817
822
use arrow_array:: { Int32Array , StringArray , TimestampMillisecondArray } ;
818
823
use arrow_schema:: { DataType , Field , TimeUnit } ;
@@ -1187,4 +1192,113 @@ mod tests {
1187
1192
assert_eq ! ( staging. parquet_files( ) . len( ) , 2 ) ;
1188
1193
assert_eq ! ( staging. arrow_files( ) . len( ) , 1 ) ;
1189
1194
}
1195
+
1196
/// Two `get_or_create` calls with the same name must hand back the same stream and leave
/// exactly one entry in the map.
#[test]
fn get_or_create_returns_existing_stream() {
    let stream_name = "test_stream";
    let options = Arc::new(Options::default());
    let metadata = LogStreamMetadata::default();
    let ingestor_id = Some("test_ingestor".to_owned());
    let streams = Streams::default();

    // Every invocation passes fresh copies of the same arguments.
    let fetch = || {
        streams.get_or_create(
            Arc::clone(&options),
            stream_name.to_owned(),
            metadata.clone(),
            ingestor_id.clone(),
        )
    };

    // The first call creates the stream, the second must find it.
    let created = fetch();
    let fetched = fetch();

    // Same allocation, not merely an equal value.
    assert!(Arc::ptr_eq(&created, &fetched));

    // The second call must not have added another entry.
    let entries = streams.read().expect("Failed to acquire read lock").len();
    assert_eq!(entries, 1);
}
1227
+
1228
/// `get_or_create` on an empty map must build the stream, register it under the given name and
/// return it carrying the supplied ingestor id.
#[test]
fn create_and_return_new_stream_when_name_does_not_exist() {
    let stream_name = "new_stream";
    let ingestor_id = Some("new_ingestor".to_owned());
    let options = Arc::new(Options::default());
    let metadata = LogStreamMetadata::default();
    let streams = Streams::default();

    // Precondition: nothing is registered. The read guard is scoped so it is released before
    // `get_or_create` asks for the write lock.
    {
        let before = streams.read().expect("Failed to acquire read lock");
        assert_eq!(before.len(), 0);
        assert!(!before.contains_key(stream_name));
    }

    let stream = streams.get_or_create(
        options,
        stream_name.to_owned(),
        metadata,
        ingestor_id.clone(),
    );

    // The created stream carries the ingestor id it was created with.
    assert_eq!(stream.ingestor_id, ingestor_id);

    // Postcondition: exactly one entry, stored under the requested name.
    let after = streams.read().expect("Failed to acquire read lock");
    assert_eq!(after.len(), 1);
    assert!(after.contains_key(stream_name));
}
1258
+
1259
/// Two threads released together and asking for the same name must end up sharing one stream,
/// with a single entry in the map.
#[test]
fn get_or_create_stream_concurrently() {
    let streams = Arc::new(Streams::default());
    let options = Arc::new(Options::default());
    let metadata = LogStreamMetadata::default();
    let stream_name = String::from("concurrent_stream");
    let ingestor_id = Some(String::from("concurrent_ingestor"));

    // Both workers block here until the other arrives, so their calls start together.
    let barrier = Arc::new(Barrier::new(2));

    // First worker: operates on its own copies of every argument.
    let first_worker = {
        let streams = Arc::clone(&streams);
        let options = Arc::clone(&options);
        let barrier = Arc::clone(&barrier);
        let stream_name = stream_name.clone();
        let metadata = metadata.clone();
        let ingestor_id = ingestor_id.clone();

        spawn(move || {
            barrier.wait();
            streams.get_or_create(options, stream_name, metadata, ingestor_id)
        })
    };

    // Second worker: takes over the originals; only the map handle is cloned because it is
    // still needed for the final check.
    let second_worker = {
        let streams = Arc::clone(&streams);

        spawn(move || {
            barrier.wait();
            streams.get_or_create(options, stream_name, metadata, ingestor_id)
        })
    };

    // Collect what each thread got back.
    let from_first = first_worker.join().expect("Thread 1 panicked");
    let from_second = second_worker.join().expect("Thread 2 panicked");

    // Both threads must hold the very same allocation.
    assert!(Arc::ptr_eq(&from_first, &from_second));

    // Only one entry may have been stored.
    let entries = streams.read().expect("Failed to acquire read lock").len();
    assert_eq!(entries, 1);
}
1190
1304
}
0 commit comments