Commit 7536132

sync: use AtomicBool in broadcast channel future (#6298)
1 parent b6d0c90 commit 7536132
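
In the broadcast channel, each `Recv` future owns a `Waiter` node with a `queued` flag that records whether the node is linked into the channel's wait list. This change turns that flag from a plain `bool` guarded by the tail lock into an `AtomicBool`, so that `Recv::drop` can check it with an `Acquire` load first and skip taking the tail lock entirely when the waiter is not queued. A contention benchmark for the broadcast channel is added alongside.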

File tree

3 files changed: +145 -31 lines changed

benches/Cargo.toml (+5)

@@ -26,6 +26,11 @@ name = "spawn"
 path = "spawn.rs"
 harness = false

+[[bench]]
+name = "sync_broadcast"
+path = "sync_broadcast.rs"
+harness = false
+
 [[bench]]
 name = "sync_mpsc"
 path = "sync_mpsc.rs"
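
With the `[[bench]]` target registered above, the new benchmark can be run on its own, e.g. via `cargo bench --bench sync_broadcast` from within the benches crate.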

benches/sync_broadcast.rs (+82)

@@ -0,0 +1,82 @@
+use rand::{Rng, RngCore, SeedableRng};
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::Arc;
+use tokio::sync::{broadcast, Notify};
+
+use criterion::measurement::WallTime;
+use criterion::{black_box, criterion_group, criterion_main, BenchmarkGroup, Criterion};
+
+fn rt() -> tokio::runtime::Runtime {
+    tokio::runtime::Builder::new_multi_thread()
+        .worker_threads(6)
+        .build()
+        .unwrap()
+}
+
+fn do_work(rng: &mut impl RngCore) -> u32 {
+    use std::fmt::Write;
+    let mut message = String::new();
+    for i in 1..=10 {
+        let _ = write!(&mut message, " {i}={}", rng.gen::<f64>());
+    }
+    message
+        .as_bytes()
+        .iter()
+        .map(|&c| c as u32)
+        .fold(0, u32::wrapping_add)
+}
+
+fn contention_impl<const N_TASKS: usize>(g: &mut BenchmarkGroup<WallTime>) {
+    let rt = rt();
+
+    let (tx, _rx) = broadcast::channel::<usize>(1000);
+    let wg = Arc::new((AtomicUsize::new(0), Notify::new()));
+
+    for n in 0..N_TASKS {
+        let wg = wg.clone();
+        let mut rx = tx.subscribe();
+        let mut rng = rand::rngs::StdRng::seed_from_u64(n as u64);
+        rt.spawn(async move {
+            while let Ok(_) = rx.recv().await {
+                let r = do_work(&mut rng);
+                let _ = black_box(r);
+                if wg.0.fetch_sub(1, Ordering::Relaxed) == 1 {
+                    wg.1.notify_one();
+                }
+            }
+        });
+    }
+
+    const N_ITERS: usize = 100;
+
+    g.bench_function(N_TASKS.to_string(), |b| {
+        b.iter(|| {
+            rt.block_on({
+                let wg = wg.clone();
+                let tx = tx.clone();
+                async move {
+                    for i in 0..N_ITERS {
+                        assert_eq!(wg.0.fetch_add(N_TASKS, Ordering::Relaxed), 0);
+                        tx.send(i).unwrap();
+                        while wg.0.load(Ordering::Relaxed) > 0 {
+                            wg.1.notified().await;
+                        }
+                    }
+                }
+            })
+        })
+    });
+}
+
+fn bench_contention(c: &mut Criterion) {
+    let mut group = c.benchmark_group("contention");
+    contention_impl::<10>(&mut group);
+    contention_impl::<100>(&mut group);
+    contention_impl::<500>(&mut group);
+    contention_impl::<1000>(&mut group);
+    group.finish();
+}
+
+criterion_group!(contention, bench_contention);
+
+criterion_main!(contention);
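
A note on the harness design: the `Arc<(AtomicUsize, Notify)>` pair acts as a makeshift wait group. For each of the `N_ITERS` sends, the driving task adds `N_TASKS` to the counter before broadcasting; every receiver decrements it after `do_work`, and the receiver whose `fetch_sub` returns 1 wakes the driver through the `Notify`. Each iteration therefore times a full fan-out to all subscribers and the fan-in of their acknowledgements, which is exactly where `Recv` futures are created and dropped under contention.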

tokio/src/sync/broadcast.rs (+58 -31)

@@ -117,7 +117,7 @@
 //! ```

 use crate::loom::cell::UnsafeCell;
-use crate::loom::sync::atomic::AtomicUsize;
+use crate::loom::sync::atomic::{AtomicBool, AtomicUsize};
 use crate::loom::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard};
 use crate::util::linked_list::{self, GuardedLinkedList, LinkedList};
 use crate::util::WakeList;

@@ -127,7 +127,7 @@ use std::future::Future;
 use std::marker::PhantomPinned;
 use std::pin::Pin;
 use std::ptr::NonNull;
-use std::sync::atomic::Ordering::SeqCst;
+use std::sync::atomic::Ordering::{Acquire, Relaxed, Release, SeqCst};
 use std::task::{Context, Poll, Waker};
 use std::usize;

@@ -354,7 +354,7 @@ struct Slot<T> {
 /// An entry in the wait queue.
 struct Waiter {
     /// True if queued.
-    queued: bool,
+    queued: AtomicBool,

     /// Task waiting on the broadcast channel.
     waker: Option<Waker>,

@@ -369,7 +369,7 @@ struct Waiter {
 impl Waiter {
     fn new() -> Self {
         Self {
-            queued: false,
+            queued: AtomicBool::new(false),
             waker: None,
             pointers: linked_list::Pointers::new(),
             _p: PhantomPinned,

@@ -897,15 +897,22 @@ impl<T> Shared<T> {
         'outer: loop {
             while wakers.can_push() {
                 match list.pop_back_locked(&mut tail) {
-                    Some(mut waiter) => {
-                        // Safety: `tail` lock is still held.
-                        let waiter = unsafe { waiter.as_mut() };
-
-                        assert!(waiter.queued);
-                        waiter.queued = false;
-
-                        if let Some(waker) = waiter.waker.take() {
-                            wakers.push(waker);
+                    Some(waiter) => {
+                        unsafe {
+                            // Safety: accessing `waker` is safe because
+                            // the tail lock is held.
+                            if let Some(waker) = (*waiter.as_ptr()).waker.take() {
+                                wakers.push(waker);
+                            }
+
+                            // Safety: `queued` is atomic.
+                            let queued = &(*waiter.as_ptr()).queued;
+                            // `Relaxed` suffices because the tail lock is held.
+                            assert!(queued.load(Relaxed));
+                            // `Release` is needed to synchronize with `Recv::drop`.
+                            // It is critical to set this variable **after** waker
+                            // is extracted, otherwise we may data race with `Recv::drop`.
+                            queued.store(false, Release);
                         }
                     }
                     None => {

@@ -1104,8 +1111,13 @@ impl<T> Receiver<T> {
                                 }
                             }

-                            if !(*ptr).queued {
-                                (*ptr).queued = true;
+                            // If the waiter is not already queued, enqueue it.
+                            // `Relaxed` order suffices: we have synchronized with
+                            // all writers through the tail lock that we hold.
+                            if !(*ptr).queued.load(Relaxed) {
+                                // `Relaxed` order suffices: all the readers will
+                                // synchronize with this write through the tail lock.
+                                (*ptr).queued.store(true, Relaxed);
                                 tail.waiters.push_front(NonNull::new_unchecked(&mut *ptr));
                             }
                         });

@@ -1357,7 +1369,7 @@ impl<'a, T> Recv<'a, T> {
         Recv {
             receiver,
             waiter: UnsafeCell::new(Waiter {
-                queued: false,
+                queued: AtomicBool::new(false),
                 waker: None,
                 pointers: linked_list::Pointers::new(),
                 _p: PhantomPinned,

@@ -1402,22 +1414,37 @@ where

 impl<'a, T> Drop for Recv<'a, T> {
     fn drop(&mut self) {
-        // Acquire the tail lock. This is required for safety before accessing
-        // the waiter node.
-        let mut tail = self.receiver.shared.tail.lock();
-
-        // safety: tail lock is held
-        let queued = self.waiter.with(|ptr| unsafe { (*ptr).queued });
-
+        // Safety: `waiter.queued` is atomic.
+        // Acquire ordering is required to synchronize with
+        // `Shared::notify_rx` before we drop the object.
+        let queued = self
+            .waiter
+            .with(|ptr| unsafe { (*ptr).queued.load(Acquire) });
+
+        // If the waiter is queued, we need to unlink it from the waiters list.
+        // If not, no further synchronization is required, since the waiter
+        // is not in the list and, as such, is not shared with any other threads.
         if queued {
-            // Remove the node
-            //
-            // safety: tail lock is held and the wait node is verified to be in
-            // the list.
-            unsafe {
-                self.waiter.with_mut(|ptr| {
-                    tail.waiters.remove((&mut *ptr).into());
-                });
+            // Acquire the tail lock. This is required for safety before accessing
+            // the waiter node.
+            let mut tail = self.receiver.shared.tail.lock();
+
+            // Safety: tail lock is held.
+            // `Relaxed` order suffices because we hold the tail lock.
+            let queued = self
+                .waiter
+                .with_mut(|ptr| unsafe { (*ptr).queued.load(Relaxed) });
+
+            if queued {
+                // Remove the node
+                //
+                // safety: tail lock is held and the wait node is verified to be in
+                // the list.
+                unsafe {
+                    self.waiter.with_mut(|ptr| {
+                        tail.waiters.remove((&mut *ptr).into());
+                    });
+                }
             }
         }
     }
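The core of the change is the `Release`/`Acquire` handshake on `queued`: `notify_rx` must take the waker before it publishes `queued = false`, and `Recv::drop` may only skip the tail lock after observing `false` with `Acquire`. Below is a minimal, self-contained sketch of that handshake; this is illustrative rather than tokio's actual code, and the `QUEUED` and `DATA` statics are hypothetical stand-ins for `Waiter::queued` and the `waker` field.

    use std::sync::atomic::{AtomicBool, AtomicU32};
    use std::sync::atomic::Ordering::{Acquire, Relaxed, Release};
    use std::thread;

    // Hypothetical stand-ins: DATA plays the role of the waiter's `waker`
    // field, QUEUED the role of `Waiter::queued`.
    static DATA: AtomicU32 = AtomicU32::new(0);
    static QUEUED: AtomicBool = AtomicBool::new(true);

    fn main() {
        let notifier = thread::spawn(|| {
            // Like `notify_rx`: touch the waker first...
            DATA.store(42, Relaxed);
            // ...then publish "dequeued" with `Release`, so the store above
            // cannot be reordered after the flag update.
            QUEUED.store(false, Release);
        });

        let dropper = thread::spawn(|| {
            // Like `Recv::drop`: an `Acquire` load that observes `false`
            // proves the notifier is done touching the node, so the node
            // could be freed without taking the tail lock.
            while QUEUED.load(Acquire) {
                thread::yield_now(); // still queued; the real code falls back to the lock
            }
            // The Release/Acquire pair guarantees this value is visible.
            assert_eq!(DATA.load(Relaxed), 42);
        });

        notifier.join().unwrap();
        dropper.join().unwrap();
    }

Note that the real `Recv::drop` does not spin: if the unlocked `Acquire` load still sees `true`, it acquires the tail lock and re-checks `queued` with `Relaxed` before unlinking, since `notify_rx` may have dequeued the waiter between the first load and the lock acquisition.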