From dce9ea5015b7389fb37503ede183c4ce2e8cb2d6 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Fri, 19 May 2023 17:05:24 +0200 Subject: [PATCH 01/15] Add failing unit test for missing flush --- test-harness/Cargo.toml | 1 - yamux/Cargo.toml | 1 + yamux/src/connection.rs | 243 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 244 insertions(+), 1 deletion(-) diff --git a/test-harness/Cargo.toml b/test-harness/Cargo.toml index 977a7c2d..92fe8a2f 100644 --- a/test-harness/Cargo.toml +++ b/test-harness/Cargo.toml @@ -16,4 +16,3 @@ log = "0.4.17" [dev-dependencies] env_logger = "0.10" constrained-connection = "0.1" - diff --git a/yamux/Cargo.toml b/yamux/Cargo.toml index 76475b00..24041444 100644 --- a/yamux/Cargo.toml +++ b/yamux/Cargo.toml @@ -26,6 +26,7 @@ quickcheck = "1.0" tokio = { version = "1.0", features = ["net", "rt-multi-thread", "macros", "time"] } tokio-util = { version = "0.7", features = ["compat"] } constrained-connection = "0.1" +futures_ringbuf = "0.3.1" [[bench]] name = "concurrent" diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs index c2c19ea3..a6cb4ed4 100644 --- a/yamux/src/connection.rs +++ b/yamux/src/connection.rs @@ -934,3 +934,246 @@ impl Active { } } } + +#[cfg(test)] +mod tests { + use std::mem; + use std::pin::Pin; + use futures::AsyncReadExt; + use futures::future::BoxFuture; + use futures::stream::FuturesUnordered; + use futures_ringbuf::Endpoint; + use super::*; + + #[tokio::test] + async fn poll_flush_on_stream_only_returns_ok_if_frame_is_queued_for_sending() { + let (client, server) = Endpoint::pair(1000, 1000); + + let client = Client::new(Connection::new(client, Config::default(), Mode::Client)); + let server = EchoServer::new(Connection::new(server, Config::default(), Mode::Server)); + + let ((), processed) = futures::future::try_join(client, server).await.unwrap(); + + assert_eq!(processed, 1); + } + + /// Our testing client. + /// + /// This struct will open a single outbound stream, send a message, attempt to flush it and assert the internal state of [`Connection`] after it. + enum Client { + Initial { + connection: Connection, + }, + Testing { + connection: Connection, + worker_stream: StreamState, + }, + Closing { + connection: Connection, + }, + Poisoned, + } + + enum StreamState { + Sending(Stream), + Flushing(Stream), + Receiving(Stream), + Closing(Stream), + } + + impl Client { + fn new(connection: Connection) -> Self { + Self::Initial { + connection + } + } + } + + impl Future for Client { + type Output = Result<()>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + loop { + match mem::replace(this, Client::Poisoned) { + // This state matching is out of order to have the interesting one at the top. + Client::Testing { worker_stream: StreamState::Flushing(mut stream), mut connection } => { + match Pin::new(&mut stream).poll_flush(cx)? { + Poll::Ready(()) => { + // Here is the actual test: + // If the stream reports that it successfully flushed, we expect the connection to have queued the frames for sending. + // Because we only have a single stream, this means we can simply assert that there are no pending frames in the channel. 
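+                                // `try_next` on an `mpsc::Receiver` returns `Err` only while the channel
+                                // is empty but still open, so the `expect_err` below asserts exactly that:
+                                // the connection has already drained every queued frame out of the channel.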
+ + let ConnectionState::Active(active) = &mut connection.inner else { + panic!("Connection is not active") + }; + + active.stream_receiver.try_next().expect_err("expected no pending frames in the channel after flushing"); + + *this = Client::Testing { worker_stream: StreamState::Receiving(stream), connection }; + continue; + } + Poll::Pending => {} + } + + drive_connection(this, connection, StreamState::Flushing(stream), cx); + return Poll::Pending; + } + Client::Testing { worker_stream: StreamState::Receiving(mut stream), connection } => { + let mut buffer = [0u8; 5]; + + match Pin::new(&mut stream).poll_read(cx, &mut buffer)? { + Poll::Ready(num_bytes) => { + assert_eq!(num_bytes, 5); + assert_eq!(&buffer, b"hello"); + + *this = Client::Testing { worker_stream: StreamState::Closing(stream), connection }; + continue; + } + Poll::Pending => {} + } + + drive_connection(this, connection, StreamState::Closing(stream), cx); + return Poll::Pending; + } + Client::Testing { worker_stream: StreamState::Closing(mut stream), connection } => { + match Pin::new(&mut stream).poll_close(cx)? { + Poll::Ready(()) => { + *this = Client::Closing { connection }; + continue; + } + Poll::Pending => {} + } + + drive_connection(this, connection, StreamState::Closing(stream), cx); + return Poll::Pending; + } + Client::Initial { mut connection } => { + match connection.poll_new_outbound(cx)? { + Poll::Ready(stream) => { + *this = Client::Testing { connection, worker_stream: StreamState::Sending(stream) }; + continue; + } + Poll::Pending => { + *this = Client::Initial { connection }; + return Poll::Pending; + } + } + } + Client::Testing { worker_stream: StreamState::Sending(mut stream), connection } => { + match Pin::new(&mut stream).poll_write(cx, b"hello")? { + Poll::Ready(written) => { + assert_eq!(written, 5); + *this = Client::Testing { worker_stream: StreamState::Flushing(stream), connection }; + continue; + } + Poll::Pending => {} + } + + drive_connection(this, connection, StreamState::Flushing(stream), cx); + return Poll::Pending; + } + Client::Closing { mut connection } => { + match connection.poll_close(cx)? 
{ + Poll::Ready(()) => { + return Poll::Ready(Ok(())); + } + Poll::Pending => { + *this = Client::Closing { connection }; + return Poll::Pending; + } + } + } + Client::Poisoned => { + unreachable!() + } + } + } + } + } + + fn drive_connection(this: &mut Client, mut connection: Connection, state: StreamState, cx: &mut Context) { + match connection.poll_next_inbound(cx) { + Poll::Ready(Some(_)) => { + panic!("Unexpected inbound stream") + } + Poll::Ready(None) => { + panic!("Unexpected connection close") + } + Poll::Pending => { + *this = Client::Testing { worker_stream: state, connection }; + } + } + } + + struct EchoServer { + connection: Connection, + worker_streams: FuturesUnordered>>, + streams_processed: usize, + connection_closed: bool, + } + + impl EchoServer { + fn new(connection: Connection) -> Self { + Self { + connection, + worker_streams: FuturesUnordered::default(), + streams_processed: 0, + connection_closed: false, + } + } + } + + impl Future for EchoServer + { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = self.get_mut(); + + loop { + match this.worker_streams.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(()))) => { + this.streams_processed += 1; + continue; + } + Poll::Ready(Some(Err(e))) => { + eprintln!("A stream failed: {}", e); + continue; + } + Poll::Ready(None) => { + if this.connection_closed { + return Poll::Ready(Ok(this.streams_processed)); + } + } + Poll::Pending => {} + } + + match this.connection.poll_next_inbound(cx) { + Poll::Ready(Some(Ok(mut stream))) => { + this.worker_streams.push( + async move { + { + let (mut r, mut w) = AsyncReadExt::split(&mut stream); + futures::io::copy(&mut r, &mut w).await?; + } + stream.close().await?; + Ok(()) + } + .boxed(), + ); + continue; + } + Poll::Ready(None) | Poll::Ready(Some(Err(_))) => { + this.connection_closed = true; + continue; + } + Poll::Pending => {} + } + + return Poll::Pending; + } + } + } +} From 76a7e6f7e55025e4c78844635d6976eace1ead83 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Fri, 19 May 2023 18:00:22 +0200 Subject: [PATCH 02/15] Actually wait for flushing on `yamux::Stream` --- yamux/src/connection.rs | 147 +++++++++++++++++++++++--------- yamux/src/connection/closing.rs | 25 ++++-- yamux/src/connection/stream.rs | 56 +++++++++++- 3 files changed, 183 insertions(+), 45 deletions(-) diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs index a6cb4ed4..28748a44 100644 --- a/yamux/src/connection.rs +++ b/yamux/src/connection.rs @@ -102,8 +102,9 @@ use cleanup::Cleanup; use closing::Closing; use futures::{channel::mpsc, future::Either, prelude::*, sink::SinkExt, stream::Fuse}; use nohash_hasher::IntMap; +use std::collections::hash_map::Entry; use std::collections::VecDeque; -use std::task::Context; +use std::task::{Context, Waker}; use std::{fmt, sync::Arc, task::Poll}; pub use stream::{Packet, State, Stream}; @@ -348,6 +349,8 @@ struct Active { socket: Fuse>, next_id: u32, streams: IntMap, + /// Stores the "marks" at which we need to notify a waiting flush task of a [`Stream`]. + flush_marks: IntMap, stream_sender: mpsc::Sender, stream_receiver: mpsc::Receiver, dropped_streams: Vec, @@ -359,6 +362,13 @@ struct Active { pub(crate) enum StreamCommand { /// A new frame should be sent to the remote. SendFrame(Frame>), + Flush { + id: StreamId, + /// How many frames we've queued for sending at the time the flush was requested. + num_frames: u64, + /// The waker to wake once the flush is complete. + waker: Waker, + }, /// Close a stream. 
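    ///
    /// The `ack` flag is passed on to `Frame::close_stream`, so the outgoing
    /// FIN frame can also acknowledge a stream that was never explicitly
    /// acknowledged before.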
CloseStream { id: StreamId, ack: bool }, } @@ -416,6 +426,7 @@ impl Active { config: Arc::new(cfg), socket, streams: IntMap::default(), + flush_marks: Default::default(), stream_sender, stream_receiver, next_id: match mode { @@ -466,6 +477,14 @@ impl Active { self.on_close_stream(id, ack); continue; } + Poll::Ready(Some(StreamCommand::Flush { + id, + num_frames, + waker, + })) => { + self.on_flush_stream(id, num_frames, waker); + continue; + } Poll::Ready(None) => { debug_assert!(false, "Only closed during shutdown") } @@ -526,13 +545,36 @@ impl Active { } fn on_send_frame(&mut self, frame: Frame>) { - log::trace!( - "{}/{}: sending: {}", - self.id, - frame.header().stream_id(), - frame.header() - ); + let stream_id = frame.header().stream_id(); + + log::trace!("{}/{}: sending: {}", self.id, stream_id, frame.header()); self.pending_frames.push_back(frame.into()); + + if let Some(stream) = self.streams.get(&stream_id) { + let mut shared = stream.shared(); + + shared.inc_sent(); + + if let Entry::Occupied(entry) = self.flush_marks.entry(stream_id) { + if shared.num_sent() >= entry.get().0 { + entry.remove().1.wake(); + } + } + } + } + + fn on_flush_stream(&mut self, id: StreamId, new_flush_mark: u64, waker: Waker) { + if let Some(stream) = self.streams.get(&id) { + let shared = stream.shared(); + + // Check if we have already reached the requested flush mark: + if shared.num_sent() >= new_flush_mark { + waker.wake(); + return; + } + + self.flush_marks.insert(id, (new_flush_mark, waker)); + } } fn on_close_stream(&mut self, id: StreamId, ack: bool) { @@ -937,13 +979,13 @@ impl Active { #[cfg(test)] mod tests { - use std::mem; - use std::pin::Pin; - use futures::AsyncReadExt; + use super::*; use futures::future::BoxFuture; use futures::stream::FuturesUnordered; + use futures::AsyncReadExt; use futures_ringbuf::Endpoint; - use super::*; + use std::mem; + use std::pin::Pin; #[tokio::test] async fn poll_flush_on_stream_only_returns_ok_if_frame_is_queued_for_sending() { @@ -983,9 +1025,7 @@ mod tests { impl Client { fn new(connection: Connection) -> Self { - Self::Initial { - connection - } + Self::Initial { connection } } } @@ -998,7 +1038,10 @@ mod tests { loop { match mem::replace(this, Client::Poisoned) { // This state matching is out of order to have the interesting one at the top. - Client::Testing { worker_stream: StreamState::Flushing(mut stream), mut connection } => { + Client::Testing { + worker_stream: StreamState::Flushing(mut stream), + mut connection, + } => { match Pin::new(&mut stream).poll_flush(cx)? { Poll::Ready(()) => { // Here is the actual test: @@ -1009,9 +1052,14 @@ mod tests { panic!("Connection is not active") }; - active.stream_receiver.try_next().expect_err("expected no pending frames in the channel after flushing"); + active.stream_receiver.try_next().expect_err( + "expected no pending frames in the channel after flushing", + ); - *this = Client::Testing { worker_stream: StreamState::Receiving(stream), connection }; + *this = Client::Testing { + worker_stream: StreamState::Receiving(stream), + connection, + }; continue; } Poll::Pending => {} @@ -1020,7 +1068,10 @@ mod tests { drive_connection(this, connection, StreamState::Flushing(stream), cx); return Poll::Pending; } - Client::Testing { worker_stream: StreamState::Receiving(mut stream), connection } => { + Client::Testing { + worker_stream: StreamState::Receiving(mut stream), + connection, + } => { let mut buffer = [0u8; 5]; match Pin::new(&mut stream).poll_read(cx, &mut buffer)? 
{ @@ -1028,7 +1079,10 @@ mod tests { assert_eq!(num_bytes, 5); assert_eq!(&buffer, b"hello"); - *this = Client::Testing { worker_stream: StreamState::Closing(stream), connection }; + *this = Client::Testing { + worker_stream: StreamState::Closing(stream), + connection, + }; continue; } Poll::Pending => {} @@ -1037,7 +1091,10 @@ mod tests { drive_connection(this, connection, StreamState::Closing(stream), cx); return Poll::Pending; } - Client::Testing { worker_stream: StreamState::Closing(mut stream), connection } => { + Client::Testing { + worker_stream: StreamState::Closing(mut stream), + connection, + } => { match Pin::new(&mut stream).poll_close(cx)? { Poll::Ready(()) => { *this = Client::Closing { connection }; @@ -1052,7 +1109,10 @@ mod tests { Client::Initial { mut connection } => { match connection.poll_new_outbound(cx)? { Poll::Ready(stream) => { - *this = Client::Testing { connection, worker_stream: StreamState::Sending(stream) }; + *this = Client::Testing { + connection, + worker_stream: StreamState::Sending(stream), + }; continue; } Poll::Pending => { @@ -1061,11 +1121,17 @@ mod tests { } } } - Client::Testing { worker_stream: StreamState::Sending(mut stream), connection } => { + Client::Testing { + worker_stream: StreamState::Sending(mut stream), + connection, + } => { match Pin::new(&mut stream).poll_write(cx, b"hello")? { Poll::Ready(written) => { assert_eq!(written, 5); - *this = Client::Testing { worker_stream: StreamState::Flushing(stream), connection }; + *this = Client::Testing { + worker_stream: StreamState::Flushing(stream), + connection, + }; continue; } Poll::Pending => {} @@ -1074,17 +1140,15 @@ mod tests { drive_connection(this, connection, StreamState::Flushing(stream), cx); return Poll::Pending; } - Client::Closing { mut connection } => { - match connection.poll_close(cx)? { - Poll::Ready(()) => { - return Poll::Ready(Ok(())); - } - Poll::Pending => { - *this = Client::Closing { connection }; - return Poll::Pending; - } + Client::Closing { mut connection } => match connection.poll_close(cx)? 
{ + Poll::Ready(()) => { + return Poll::Ready(Ok(())); } - } + Poll::Pending => { + *this = Client::Closing { connection }; + return Poll::Pending; + } + }, Client::Poisoned => { unreachable!() } @@ -1093,7 +1157,12 @@ mod tests { } } - fn drive_connection(this: &mut Client, mut connection: Connection, state: StreamState, cx: &mut Context) { + fn drive_connection( + this: &mut Client, + mut connection: Connection, + state: StreamState, + cx: &mut Context, + ) { match connection.poll_next_inbound(cx) { Poll::Ready(Some(_)) => { panic!("Unexpected inbound stream") @@ -1102,7 +1171,10 @@ mod tests { panic!("Unexpected connection close") } Poll::Pending => { - *this = Client::Testing { worker_stream: state, connection }; + *this = Client::Testing { + worker_stream: state, + connection, + }; } } } @@ -1125,8 +1197,7 @@ mod tests { } } - impl Future for EchoServer - { + impl Future for EchoServer { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -1161,7 +1232,7 @@ mod tests { stream.close().await?; Ok(()) } - .boxed(), + .boxed(), ); continue; } diff --git a/yamux/src/connection/closing.rs b/yamux/src/connection/closing.rs index faedb25d..c9c9a7e9 100644 --- a/yamux/src/connection/closing.rs +++ b/yamux/src/connection/closing.rs @@ -7,8 +7,9 @@ use futures::stream::Fuse; use futures::{ready, AsyncRead, AsyncWrite, SinkExt, StreamExt}; use std::collections::VecDeque; use std::future::Future; +use std::mem; use std::pin::Pin; -use std::task::{Context, Poll}; +use std::task::{Context, Poll, Waker}; /// A [`Future`] that gracefully closes the yamux connection. #[must_use] @@ -16,6 +17,7 @@ pub struct Closing { state: State, stream_receiver: mpsc::Receiver, pending_frames: VecDeque>, + pending_flush_wakers: Vec, socket: Fuse>, } @@ -32,6 +34,7 @@ where state: State::ClosingStreamReceiver, stream_receiver, pending_frames, + pending_flush_wakers: vec![], socket, } } @@ -60,10 +63,22 @@ where Some(StreamCommand::SendFrame(frame)) => { this.pending_frames.push_back(frame.into()) } - Some(StreamCommand::CloseStream { id, ack }) => this - .pending_frames - .push_back(Frame::close_stream(id, ack).into()), - None => this.state = State::SendingTermFrame, + Some(StreamCommand::CloseStream { id, ack }) => { + this.pending_frames + .push_back(Frame::close_stream(id, ack).into()); + } + Some(StreamCommand::Flush { waker, .. }) => { + this.pending_flush_wakers.push(waker); + } + None => { + // Receiver is closed, meaning we have queued all frames for sending. + // Notify all pending flush tasks. 
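+                        // `mem::take` swaps an empty `Vec` into place, so the stored wakers
+                        // are consumed by value and each parked flush task is woken exactly once.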
+ for waker in mem::take(&mut this.pending_flush_wakers) { + waker.wake(); + } + + this.state = State::SendingTermFrame; + } } } State::SendingTermFrame => { diff --git a/yamux/src/connection/stream.rs b/yamux/src/connection/stream.rs index d7405cd3..f7434a51 100644 --- a/yamux/src/connection/stream.rs +++ b/yamux/src/connection/stream.rs @@ -331,6 +331,7 @@ impl AsyncWrite for Stream { .sender .poll_ready(cx) .map_err(|_| self.write_zero_err())?); + let body = { let mut shared = self.shared(); if !shared.state().can_write() { @@ -345,6 +346,8 @@ impl AsyncWrite for Stream { let k = std::cmp::min(shared.credit as usize, buf.len()); let k = std::cmp::min(k, self.config.split_send_size); shared.credit = shared.credit.saturating_sub(k as u32); + shared.inc_queued(); + Vec::from(&buf[..k]) }; let n = body.len(); @@ -355,11 +358,36 @@ impl AsyncWrite for Stream { self.sender .start_send(cmd) .map_err(|_| self.write_zero_err())?; + Poll::Ready(Ok(n)) } - fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll> { - Poll::Ready(Ok(())) + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + let num_frames = { + let shared = self.shared(); + + if shared.is_flushed() { + return Poll::Ready(Ok(())); + } + + shared.queued_frames + }; + + ready!(self + .sender + .poll_ready(cx) + .map_err(|_| self.write_zero_err())?); + + let cmd = StreamCommand::Flush { + id: self.id, + num_frames, + waker: cx.waker().clone(), + }; + self.sender + .start_send(cmd) + .map_err(|_| self.write_zero_err())?; + + Poll::Pending } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { @@ -396,6 +424,12 @@ pub(crate) struct Shared { pub(crate) reader: Option, pub(crate) writer: Option, config: Arc, + + /// The number of frames queued for sending via [`StreamCommand::SendFrame`] + queued_frames: u64, + + /// The number of frames sent to the socket by the [`Connection`](crate::Connection). + sent_frames: u64, } impl Shared { @@ -408,6 +442,8 @@ impl Shared { reader: None, writer: None, config, + queued_frames: 0, + sent_frames: 0, } } @@ -415,6 +451,22 @@ impl Shared { self.state } + pub(crate) fn is_flushed(&self) -> bool { + self.queued_frames == self.sent_frames + } + + pub(crate) fn inc_queued(&mut self) { + self.queued_frames += 1; + } + + pub(crate) fn inc_sent(&mut self) { + self.sent_frames += 1; + } + + pub(crate) fn num_sent(&self) -> u64 { + self.sent_frames + } + /// Update the stream state and return the state before it was updated. 
pub(crate) fn update_state( &mut self, From 6c03282f86121b5f3e7f0fd846d14610b036aa83 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Fri, 19 May 2023 18:42:02 +0200 Subject: [PATCH 03/15] Flushing a closed stream is okay --- yamux/src/connection/stream.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/yamux/src/connection/stream.rs b/yamux/src/connection/stream.rs index f7434a51..030a278f 100644 --- a/yamux/src/connection/stream.rs +++ b/yamux/src/connection/stream.rs @@ -363,6 +363,10 @@ impl AsyncWrite for Stream { } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + if self.is_closed() { + return Poll::Ready(Ok(())); + } + let num_frames = { let shared = self.shared(); From 496613e71b3e801b9e066a43cb76df21a3fca032 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Mon, 22 May 2023 11:54:55 +0200 Subject: [PATCH 04/15] Use 1 channel per stream --- yamux/src/connection.rs | 130 +++++++++++--------------------- yamux/src/connection/cleanup.rs | 15 ++-- yamux/src/connection/closing.rs | 19 ++--- yamux/src/connection/stream.rs | 57 +------------- 4 files changed, 63 insertions(+), 158 deletions(-) diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs index 28748a44..87e4023d 100644 --- a/yamux/src/connection.rs +++ b/yamux/src/connection.rs @@ -96,15 +96,15 @@ use crate::{ error::ConnectionError, frame::header::{self, Data, GoAway, Header, Ping, StreamId, Tag, WindowUpdate, CONNECTION_ID}, frame::{self, Frame}, - Config, WindowUpdateMode, DEFAULT_CREDIT, MAX_COMMAND_BACKLOG, + Config, WindowUpdateMode, DEFAULT_CREDIT, }; use cleanup::Cleanup; use closing::Closing; +use futures::stream::SelectAll; use futures::{channel::mpsc, future::Either, prelude::*, sink::SinkExt, stream::Fuse}; use nohash_hasher::IntMap; -use std::collections::hash_map::Entry; use std::collections::VecDeque; -use std::task::{Context, Waker}; +use std::task::Context; use std::{fmt, sync::Arc, task::Poll}; pub use stream::{Packet, State, Stream}; @@ -349,10 +349,7 @@ struct Active { socket: Fuse>, next_id: u32, streams: IntMap, - /// Stores the "marks" at which we need to notify a waiting flush task of a [`Stream`]. - flush_marks: IntMap, - stream_sender: mpsc::Sender, - stream_receiver: mpsc::Receiver, + stream_receivers: SelectAll>, dropped_streams: Vec, pending_frames: VecDeque>, } @@ -362,13 +359,6 @@ struct Active { pub(crate) enum StreamCommand { /// A new frame should be sent to the remote. SendFrame(Frame>), - Flush { - id: StreamId, - /// How many frames we've queued for sending at the time the flush was requested. - num_frames: u64, - /// The waker to wake once the flush is complete. - waker: Waker, - }, /// Close a stream. CloseStream { id: StreamId, ack: bool }, } @@ -418,7 +408,6 @@ impl Active { fn new(socket: T, cfg: Config, mode: Mode) -> Self { let id = Id::random(); log::debug!("new connection: {} ({:?})", id, mode); - let (stream_sender, stream_receiver) = mpsc::channel(MAX_COMMAND_BACKLOG); let socket = frame::Io::new(id, socket, cfg.max_buffer_size).fuse(); Active { id, @@ -426,9 +415,7 @@ impl Active { config: Arc::new(cfg), socket, streams: IntMap::default(), - flush_marks: Default::default(), - stream_sender, - stream_receiver, + stream_receivers: SelectAll::default(), next_id: match mode { Mode::Client => 1, Mode::Server => 2, @@ -440,7 +427,7 @@ impl Active { /// Gracefully close the connection to the remote. 
fn close(self) -> Closing { - Closing::new(self.stream_receiver, self.pending_frames, self.socket) + Closing::new(self.stream_receivers, self.pending_frames, self.socket) } /// Cleanup all our resources. @@ -449,7 +436,7 @@ impl Active { fn cleanup(mut self, error: ConnectionError) -> Cleanup { self.drop_all_streams(); - Cleanup::new(self.stream_receiver, error) + Cleanup::new(self.stream_receivers, error) } fn poll(&mut self, cx: &mut Context<'_>) -> Poll> { @@ -468,7 +455,7 @@ impl Active { Poll::Pending => {} } - match self.stream_receiver.poll_next_unpin(cx) { + match self.stream_receivers.poll_next_unpin(cx) { Poll::Ready(Some(StreamCommand::SendFrame(frame))) => { self.on_send_frame(frame); continue; @@ -477,17 +464,7 @@ impl Active { self.on_close_stream(id, ack); continue; } - Poll::Ready(Some(StreamCommand::Flush { - id, - num_frames, - waker, - })) => { - self.on_flush_stream(id, num_frames, waker); - continue; - } - Poll::Ready(None) => { - debug_assert!(false, "Only closed during shutdown") - } + Poll::Ready(None) => {} Poll::Pending => {} } @@ -527,16 +504,11 @@ impl Active { self.pending_frames.push_back(frame.into()); } - let stream = { - let config = self.config.clone(); - let sender = self.stream_sender.clone(); - let window = self.config.receive_window; - let mut stream = Stream::new(id, self.id, config, window, DEFAULT_CREDIT, sender); - if extra_credit == 0 { - stream.set_flag(stream::Flag::Syn) - } - stream - }; + let mut stream = self.make_new_stream(id, self.config.receive_window, DEFAULT_CREDIT); + + if extra_credit == 0 { + stream.set_flag(stream::Flag::Syn) + } log::debug!("{}: new outbound {} of {}", self.id, stream, self); self.streams.insert(id, stream.clone()); @@ -549,32 +521,6 @@ impl Active { log::trace!("{}/{}: sending: {}", self.id, stream_id, frame.header()); self.pending_frames.push_back(frame.into()); - - if let Some(stream) = self.streams.get(&stream_id) { - let mut shared = stream.shared(); - - shared.inc_sent(); - - if let Entry::Occupied(entry) = self.flush_marks.entry(stream_id) { - if shared.num_sent() >= entry.get().0 { - entry.remove().1.wake(); - } - } - } - } - - fn on_flush_stream(&mut self, id: StreamId, new_flush_mark: u64, waker: Waker) { - if let Some(stream) = self.streams.get(&id) { - let shared = stream.shared(); - - // Check if we have already reached the requested flush mark: - if shared.num_sent() >= new_flush_mark { - waker.wake(); - return; - } - - self.flush_marks.insert(id, (new_flush_mark, waker)); - } } fn on_close_stream(&mut self, id: StreamId, ack: bool) { @@ -670,12 +616,7 @@ impl Active { log::error!("{}: maximum number of streams reached", self.id); return Action::Terminate(Frame::internal_error()); } - let mut stream = { - let config = self.config.clone(); - let credit = DEFAULT_CREDIT; - let sender = self.stream_sender.clone(); - Stream::new(stream_id, self.id, config, credit, credit, sender) - }; + let mut stream = self.make_new_stream(stream_id, DEFAULT_CREDIT, DEFAULT_CREDIT); let mut window_update = None; { let mut shared = stream.shared(); @@ -790,15 +731,11 @@ impl Active { log::error!("{}: maximum number of streams reached", self.id); return Action::Terminate(Frame::protocol_error()); } - let stream = { - let credit = frame.header().credit() + DEFAULT_CREDIT; - let config = self.config.clone(); - let sender = self.stream_sender.clone(); - let mut stream = - Stream::new(stream_id, self.id, config, DEFAULT_CREDIT, credit, sender); - stream.set_flag(stream::Flag::Ack); - stream - }; + + let credit = 
frame.header().credit() + DEFAULT_CREDIT; + let mut stream = self.make_new_stream(stream_id, DEFAULT_CREDIT, credit); + stream.set_flag(stream::Flag::Ack); + if is_finish { stream .shared() @@ -863,6 +800,20 @@ impl Active { Action::None } + fn make_new_stream(&mut self, id: StreamId, window: u32, credit: u32) -> Stream { + let config = self.config.clone(); + + // Create a channel with 0 _additional_ capacity for items. + // `poll_flush` for `Sender` will check whether we can send an item into the stream and + // NOT whether all items have been taken out of the receiver. + // To ensure that `poll_flush` on our `Stream` means that we have sent all frames, + // this channel must be configured with 0 capacity. + let (sender, receiver) = mpsc::channel(0); + self.stream_receivers.push(receiver); + + Stream::new(id, self.id, config, window, credit, sender) + } + fn next_stream_id(&mut self) -> Result { let proposed = StreamId::new(self.next_id); self.next_id = self @@ -1052,9 +1003,16 @@ mod tests { panic!("Connection is not active") }; - active.stream_receiver.try_next().expect_err( - "expected no pending frames in the channel after flushing", - ); + assert_eq!(active.stream_receivers.len(), 1); + active + .stream_receivers + .iter_mut() + .next() + .unwrap() + .try_next() + .expect_err( + "expected no pending frames in the channel after flushing", + ); *this = Client::Testing { worker_stream: StreamState::Receiving(stream), diff --git a/yamux/src/connection/cleanup.rs b/yamux/src/connection/cleanup.rs index c0017700..b611d67e 100644 --- a/yamux/src/connection/cleanup.rs +++ b/yamux/src/connection/cleanup.rs @@ -1,6 +1,7 @@ use crate::connection::StreamCommand; use crate::ConnectionError; use futures::channel::mpsc; +use futures::stream::SelectAll; use futures::{ready, StreamExt}; use std::future::Future; use std::pin::Pin; @@ -10,18 +11,18 @@ use std::task::{Context, Poll}; #[must_use] pub struct Cleanup { state: State, - stream_receiver: mpsc::Receiver, + stream_receivers: SelectAll>, error: Option, } impl Cleanup { pub(crate) fn new( - stream_receiver: mpsc::Receiver, + stream_receivers: SelectAll>, error: ConnectionError, ) -> Self { Self { state: State::ClosingStreamReceiver, - stream_receiver, + stream_receivers, error: Some(error), } } @@ -36,14 +37,14 @@ impl Future for Cleanup { loop { match this.state { State::ClosingStreamReceiver => { - this.stream_receiver.close(); + for stream in this.stream_receivers.iter_mut() { + stream.close(); + } this.state = State::DrainingStreamReceiver; } State::DrainingStreamReceiver => { - this.stream_receiver.close(); - - match ready!(this.stream_receiver.poll_next_unpin(cx)) { + match ready!(this.stream_receivers.poll_next_unpin(cx)) { Some(cmd) => { drop(cmd); } diff --git a/yamux/src/connection/closing.rs b/yamux/src/connection/closing.rs index c9c9a7e9..b3d27198 100644 --- a/yamux/src/connection/closing.rs +++ b/yamux/src/connection/closing.rs @@ -3,7 +3,7 @@ use crate::frame; use crate::frame::Frame; use crate::Result; use futures::channel::mpsc; -use futures::stream::Fuse; +use futures::stream::{Fuse, SelectAll}; use futures::{ready, AsyncRead, AsyncWrite, SinkExt, StreamExt}; use std::collections::VecDeque; use std::future::Future; @@ -15,7 +15,7 @@ use std::task::{Context, Poll, Waker}; #[must_use] pub struct Closing { state: State, - stream_receiver: mpsc::Receiver, + stream_receivers: SelectAll>, pending_frames: VecDeque>, pending_flush_wakers: Vec, socket: Fuse>, @@ -26,13 +26,13 @@ where T: AsyncRead + AsyncWrite + Unpin, { pub(crate) fn 
new( - stream_receiver: mpsc::Receiver, + stream_receiver: SelectAll>, pending_frames: VecDeque>, socket: Fuse>, ) -> Self { Self { state: State::ClosingStreamReceiver, - stream_receiver, + stream_receivers: stream_receiver, pending_frames, pending_flush_wakers: vec![], socket, @@ -52,14 +52,14 @@ where loop { match this.state { State::ClosingStreamReceiver => { - this.stream_receiver.close(); + for stream in this.stream_receivers.iter_mut() { + stream.close(); + } this.state = State::DrainingStreamReceiver; } State::DrainingStreamReceiver => { - this.stream_receiver.close(); - - match ready!(this.stream_receiver.poll_next_unpin(cx)) { + match ready!(this.stream_receivers.poll_next_unpin(cx)) { Some(StreamCommand::SendFrame(frame)) => { this.pending_frames.push_back(frame.into()) } @@ -67,9 +67,6 @@ where this.pending_frames .push_back(Frame::close_stream(id, ack).into()); } - Some(StreamCommand::Flush { waker, .. }) => { - this.pending_flush_wakers.push(waker); - } None => { // Receiver is closed, meaning we have queued all frames for sending. // Notify all pending flush tasks. diff --git a/yamux/src/connection/stream.rs b/yamux/src/connection/stream.rs index 030a278f..7547e39d 100644 --- a/yamux/src/connection/stream.rs +++ b/yamux/src/connection/stream.rs @@ -21,7 +21,7 @@ use futures::{ channel::mpsc, future::Either, io::{AsyncRead, AsyncWrite}, - ready, + ready, SinkExt, }; use parking_lot::{Mutex, MutexGuard}; use std::convert::TryInto; @@ -346,7 +346,6 @@ impl AsyncWrite for Stream { let k = std::cmp::min(shared.credit as usize, buf.len()); let k = std::cmp::min(k, self.config.split_send_size); shared.credit = shared.credit.saturating_sub(k as u32); - shared.inc_queued(); Vec::from(&buf[..k]) }; @@ -363,35 +362,9 @@ impl AsyncWrite for Stream { } fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - if self.is_closed() { - return Poll::Ready(Ok(())); - } - - let num_frames = { - let shared = self.shared(); - - if shared.is_flushed() { - return Poll::Ready(Ok(())); - } - - shared.queued_frames - }; - - ready!(self - .sender - .poll_ready(cx) - .map_err(|_| self.write_zero_err())?); - - let cmd = StreamCommand::Flush { - id: self.id, - num_frames, - waker: cx.waker().clone(), - }; self.sender - .start_send(cmd) - .map_err(|_| self.write_zero_err())?; - - Poll::Pending + .poll_flush_unpin(cx) + .map_err(|_| self.write_zero_err()) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { @@ -428,12 +401,6 @@ pub(crate) struct Shared { pub(crate) reader: Option, pub(crate) writer: Option, config: Arc, - - /// The number of frames queued for sending via [`StreamCommand::SendFrame`] - queued_frames: u64, - - /// The number of frames sent to the socket by the [`Connection`](crate::Connection). - sent_frames: u64, } impl Shared { @@ -446,8 +413,6 @@ impl Shared { reader: None, writer: None, config, - queued_frames: 0, - sent_frames: 0, } } @@ -455,22 +420,6 @@ impl Shared { self.state } - pub(crate) fn is_flushed(&self) -> bool { - self.queued_frames == self.sent_frames - } - - pub(crate) fn inc_queued(&mut self) { - self.queued_frames += 1; - } - - pub(crate) fn inc_sent(&mut self) { - self.sent_frames += 1; - } - - pub(crate) fn num_sent(&self) -> u64 { - self.sent_frames - } - /// Update the stream state and return the state before it was updated. 
pub(crate) fn update_state( &mut self, From ee3898ad65516b4f12cb61a6b0a33119ea5832cd Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Mon, 22 May 2023 12:47:09 +0200 Subject: [PATCH 05/15] Minimize diff --- yamux/src/connection.rs | 9 ++++++--- yamux/src/connection/stream.rs | 3 --- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs index 87e4023d..403fcbd1 100644 --- a/yamux/src/connection.rs +++ b/yamux/src/connection.rs @@ -517,9 +517,12 @@ impl Active { } fn on_send_frame(&mut self, frame: Frame>) { - let stream_id = frame.header().stream_id(); - - log::trace!("{}/{}: sending: {}", self.id, stream_id, frame.header()); + log::trace!( + "{}/{}: sending: {}", + self.id, + frame.header().stream_id(), + frame.header() + ); self.pending_frames.push_back(frame.into()); } diff --git a/yamux/src/connection/stream.rs b/yamux/src/connection/stream.rs index 7547e39d..466f817b 100644 --- a/yamux/src/connection/stream.rs +++ b/yamux/src/connection/stream.rs @@ -331,7 +331,6 @@ impl AsyncWrite for Stream { .sender .poll_ready(cx) .map_err(|_| self.write_zero_err())?); - let body = { let mut shared = self.shared(); if !shared.state().can_write() { @@ -346,7 +345,6 @@ impl AsyncWrite for Stream { let k = std::cmp::min(shared.credit as usize, buf.len()); let k = std::cmp::min(k, self.config.split_send_size); shared.credit = shared.credit.saturating_sub(k as u32); - Vec::from(&buf[..k]) }; let n = body.len(); @@ -357,7 +355,6 @@ impl AsyncWrite for Stream { self.sender .start_send(cmd) .map_err(|_| self.write_zero_err())?; - Poll::Ready(Ok(n)) } From 53d90e4ccd080e6e0f78d63aa8ec4812889a54f8 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Mon, 22 May 2023 13:13:42 +0200 Subject: [PATCH 06/15] Introduce dedicated `CommandReceivers` --- yamux/src/connection.rs | 14 +++--- yamux/src/connection/cleanup.rs | 32 ++++++-------- yamux/src/connection/closing.rs | 52 ++++++++++------------- yamux/src/connection/command_receivers.rs | 47 ++++++++++++++++++++ 4 files changed, 89 insertions(+), 56 deletions(-) create mode 100644 yamux/src/connection/command_receivers.rs diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs index 403fcbd1..6b24c48f 100644 --- a/yamux/src/connection.rs +++ b/yamux/src/connection.rs @@ -89,6 +89,7 @@ mod cleanup; mod closing; +mod command_receivers; mod stream; use crate::Result; @@ -100,13 +101,13 @@ use crate::{ }; use cleanup::Cleanup; use closing::Closing; -use futures::stream::SelectAll; use futures::{channel::mpsc, future::Either, prelude::*, sink::SinkExt, stream::Fuse}; use nohash_hasher::IntMap; use std::collections::VecDeque; use std::task::Context; use std::{fmt, sync::Arc, task::Poll}; +use crate::connection::command_receivers::CommandReceivers; pub use stream::{Packet, State, Stream}; /// How the connection is used. 
@@ -349,7 +350,7 @@ struct Active { socket: Fuse>, next_id: u32, streams: IntMap, - stream_receivers: SelectAll>, + stream_receivers: CommandReceivers, dropped_streams: Vec, pending_frames: VecDeque>, } @@ -415,7 +416,7 @@ impl Active { config: Arc::new(cfg), socket, streams: IntMap::default(), - stream_receivers: SelectAll::default(), + stream_receivers: CommandReceivers::default(), next_id: match mode { Mode::Client => 1, Mode::Server => 2, @@ -455,16 +456,15 @@ impl Active { Poll::Pending => {} } - match self.stream_receivers.poll_next_unpin(cx) { - Poll::Ready(Some(StreamCommand::SendFrame(frame))) => { + match self.stream_receivers.poll_next(cx) { + Poll::Ready(StreamCommand::SendFrame(frame)) => { self.on_send_frame(frame); continue; } - Poll::Ready(Some(StreamCommand::CloseStream { id, ack })) => { + Poll::Ready(StreamCommand::CloseStream { id, ack }) => { self.on_close_stream(id, ack); continue; } - Poll::Ready(None) => {} Poll::Pending => {} } diff --git a/yamux/src/connection/cleanup.rs b/yamux/src/connection/cleanup.rs index b611d67e..b7a7de63 100644 --- a/yamux/src/connection/cleanup.rs +++ b/yamux/src/connection/cleanup.rs @@ -1,8 +1,5 @@ -use crate::connection::StreamCommand; +use crate::connection::command_receivers::CommandReceivers; use crate::ConnectionError; -use futures::channel::mpsc; -use futures::stream::SelectAll; -use futures::{ready, StreamExt}; use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; @@ -11,15 +8,12 @@ use std::task::{Context, Poll}; #[must_use] pub struct Cleanup { state: State, - stream_receivers: SelectAll>, + stream_receivers: CommandReceivers, error: Option, } impl Cleanup { - pub(crate) fn new( - stream_receivers: SelectAll>, - error: ConnectionError, - ) -> Self { + pub(crate) fn new(stream_receivers: CommandReceivers, error: ConnectionError) -> Self { Self { state: State::ClosingStreamReceiver, stream_receivers, @@ -37,26 +31,26 @@ impl Future for Cleanup { loop { match this.state { State::ClosingStreamReceiver => { - for stream in this.stream_receivers.iter_mut() { - stream.close(); - } + this.stream_receivers.close(); this.state = State::DrainingStreamReceiver; } - State::DrainingStreamReceiver => { - match ready!(this.stream_receivers.poll_next_unpin(cx)) { - Some(cmd) => { - drop(cmd); - } - None => { + State::DrainingStreamReceiver => match this.stream_receivers.poll_next(cx) { + Poll::Ready(cmd) => { + drop(cmd); + } + Poll::Pending => { + if this.stream_receivers.is_empty() { return Poll::Ready( this.error .take() .expect("to not be called after completion"), ); } + + return Poll::Pending; } - } + }, } } } diff --git a/yamux/src/connection/closing.rs b/yamux/src/connection/closing.rs index b3d27198..b6cddb88 100644 --- a/yamux/src/connection/closing.rs +++ b/yamux/src/connection/closing.rs @@ -1,23 +1,21 @@ +use crate::connection::command_receivers::CommandReceivers; use crate::connection::StreamCommand; use crate::frame; use crate::frame::Frame; use crate::Result; -use futures::channel::mpsc; -use futures::stream::{Fuse, SelectAll}; -use futures::{ready, AsyncRead, AsyncWrite, SinkExt, StreamExt}; +use futures::stream::Fuse; +use futures::{ready, AsyncRead, AsyncWrite, SinkExt}; use std::collections::VecDeque; use std::future::Future; -use std::mem; use std::pin::Pin; -use std::task::{Context, Poll, Waker}; +use std::task::{Context, Poll}; /// A [`Future`] that gracefully closes the yamux connection. 
#[must_use] pub struct Closing { state: State, - stream_receivers: SelectAll>, + stream_receivers: CommandReceivers, pending_frames: VecDeque>, - pending_flush_wakers: Vec, socket: Fuse>, } @@ -26,15 +24,14 @@ where T: AsyncRead + AsyncWrite + Unpin, { pub(crate) fn new( - stream_receiver: SelectAll>, + stream_receivers: CommandReceivers, pending_frames: VecDeque>, socket: Fuse>, ) -> Self { Self { state: State::ClosingStreamReceiver, - stream_receivers: stream_receiver, + stream_receivers, pending_frames, - pending_flush_wakers: vec![], socket, } } @@ -52,32 +49,27 @@ where loop { match this.state { State::ClosingStreamReceiver => { - for stream in this.stream_receivers.iter_mut() { - stream.close(); - } + this.stream_receivers.close(); this.state = State::DrainingStreamReceiver; } - State::DrainingStreamReceiver => { - match ready!(this.stream_receivers.poll_next_unpin(cx)) { - Some(StreamCommand::SendFrame(frame)) => { - this.pending_frames.push_back(frame.into()) - } - Some(StreamCommand::CloseStream { id, ack }) => { - this.pending_frames - .push_back(Frame::close_stream(id, ack).into()); - } - None => { - // Receiver is closed, meaning we have queued all frames for sending. - // Notify all pending flush tasks. - for waker in mem::take(&mut this.pending_flush_wakers) { - waker.wake(); - } - + State::DrainingStreamReceiver => match this.stream_receivers.poll_next(cx) { + Poll::Ready(StreamCommand::SendFrame(frame)) => { + this.pending_frames.push_back(frame.into()) + } + Poll::Ready(StreamCommand::CloseStream { id, ack }) => { + this.pending_frames + .push_back(Frame::close_stream(id, ack).into()); + } + Poll::Pending => { + if this.stream_receivers.is_empty() { this.state = State::SendingTermFrame; + continue; } + + return Poll::Pending; } - } + }, State::SendingTermFrame => { this.pending_frames.push_back(Frame::term().into()); this.state = State::FlushingPendingFrames; diff --git a/yamux/src/connection/command_receivers.rs b/yamux/src/connection/command_receivers.rs new file mode 100644 index 00000000..0454548d --- /dev/null +++ b/yamux/src/connection/command_receivers.rs @@ -0,0 +1,47 @@ +use crate::connection::StreamCommand; +use futures::channel::mpsc; +use futures::stream::SelectAll; +use futures::{ready, StreamExt}; +use std::task::{Context, Poll, Waker}; + +/// A set of [`mpsc::Receiver`]s for [`StreamCommand`]s. +#[derive(Default)] +pub struct CommandReceivers { + inner: SelectAll>, + waker: Option, +} + +impl CommandReceivers { + /// Push a new [`mpsc::Receiver`]. + pub(crate) fn push(&mut self, receiver: mpsc::Receiver) { + self.inner.push(receiver); + if let Some(waker) = self.waker.take() { + waker.wake(); + } + } + + /// Poll for the next [`StreamCommand`] from any of the internal receivers. + /// + /// The only difference to a plain [`SelectAll`] is that this will never return [`None`] but park the current task instead. + pub(crate) fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll { + match ready!(self.inner.poll_next_unpin(cx)) { + Some(cmd) => Poll::Ready(cmd), + None => { + self.waker = Some(cx.waker().clone()); + Poll::Pending + } + } + } + + /// Close all remaining [`mpsc::Receiver`]s. + pub(crate) fn close(&mut self) { + for stream in self.inner.iter_mut() { + stream.close(); + } + } + + /// Returns `true` if there are no [`mpsc::Receiver`]s. 
+ pub(crate) fn is_empty(&self) -> bool { + self.inner.is_empty() + } +} From 09a660e0ebd275881befdd5ba5eda11f5510a654 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Mon, 22 May 2023 13:19:41 +0200 Subject: [PATCH 07/15] Fix compile error in unit test --- yamux/src/connection.rs | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs index 6b24c48f..965d773c 100644 --- a/yamux/src/connection.rs +++ b/yamux/src/connection.rs @@ -998,24 +998,14 @@ mod tests { } => { match Pin::new(&mut stream).poll_flush(cx)? { Poll::Ready(()) => { - // Here is the actual test: - // If the stream reports that it successfully flushed, we expect the connection to have queued the frames for sending. - // Because we only have a single stream, this means we can simply assert that there are no pending frames in the channel. - let ConnectionState::Active(active) = &mut connection.inner else { panic!("Connection is not active") }; - assert_eq!(active.stream_receivers.len(), 1); - active - .stream_receivers - .iter_mut() - .next() - .unwrap() - .try_next() - .expect_err( - "expected no pending frames in the channel after flushing", - ); + // Here is the actual test: + // If the stream reports that it successfully flushed, we expect the connection to have queued the frames for sending + // and thus not have any more `StreamCommand`s. + assert!(active.stream_receivers.poll_next(cx).is_pending()); *this = Client::Testing { worker_stream: StreamState::Receiving(stream), From 4b5293b7fde4db0797d9e852c9ce751875deae2c Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Mon, 22 May 2023 13:50:07 +0200 Subject: [PATCH 08/15] Fix a bug where closing would hang forever --- yamux/src/connection/cleanup.rs | 16 ++++++---------- yamux/src/connection/closing.rs | 15 ++++----------- yamux/src/connection/command_receivers.rs | 5 ----- 3 files changed, 10 insertions(+), 26 deletions(-) diff --git a/yamux/src/connection/cleanup.rs b/yamux/src/connection/cleanup.rs index b7a7de63..c0e4f6fa 100644 --- a/yamux/src/connection/cleanup.rs +++ b/yamux/src/connection/cleanup.rs @@ -34,21 +34,17 @@ impl Future for Cleanup { this.stream_receivers.close(); this.state = State::DrainingStreamReceiver; } - State::DrainingStreamReceiver => match this.stream_receivers.poll_next(cx) { Poll::Ready(cmd) => { drop(cmd); } + // Poll::Pending means that there are no more commands. Poll::Pending => { - if this.stream_receivers.is_empty() { - return Poll::Ready( - this.error - .take() - .expect("to not be called after completion"), - ); - } - - return Poll::Pending; + return Poll::Ready( + this.error + .take() + .expect("to not be called after completion"), + ) } }, } diff --git a/yamux/src/connection/closing.rs b/yamux/src/connection/closing.rs index b6cddb88..86ed480a 100644 --- a/yamux/src/connection/closing.rs +++ b/yamux/src/connection/closing.rs @@ -62,18 +62,12 @@ where .push_back(Frame::close_stream(id, ack).into()); } Poll::Pending => { - if this.stream_receivers.is_empty() { - this.state = State::SendingTermFrame; - continue; - } - - return Poll::Pending; + // No more frames from streams, append `Term` frame and flush them all. 
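+                        // Once closed, a receiver never parks: it yields any buffered commands
+                        // and then terminates, so `Poll::Pending` here can only mean that the
+                        // backlog of commands is fully drained.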
+ this.pending_frames.push_back(Frame::term().into()); + this.state = State::FlushingPendingFrames; + continue; } }, - State::SendingTermFrame => { - this.pending_frames.push_back(Frame::term().into()); - this.state = State::FlushingPendingFrames; - } State::FlushingPendingFrames => { ready!(this.socket.poll_ready_unpin(cx))?; @@ -95,7 +89,6 @@ where enum State { ClosingStreamReceiver, DrainingStreamReceiver, - SendingTermFrame, FlushingPendingFrames, ClosingSocket, } diff --git a/yamux/src/connection/command_receivers.rs b/yamux/src/connection/command_receivers.rs index 0454548d..e70a76fa 100644 --- a/yamux/src/connection/command_receivers.rs +++ b/yamux/src/connection/command_receivers.rs @@ -39,9 +39,4 @@ impl CommandReceivers { stream.close(); } } - - /// Returns `true` if there are no [`mpsc::Receiver`]s. - pub(crate) fn is_empty(&self) -> bool { - self.inner.is_empty() - } } From c6acd8ef26a78c4fd8ed1dfd95e99d5231463265 Mon Sep 17 00:00:00 2001 From: Thomas Eizinger Date: Mon, 22 May 2023 14:46:28 +0200 Subject: [PATCH 09/15] Replace `garbage_collect` with detecting closed receiver --- yamux/src/connection.rs | 487 ++++++---------------- yamux/src/connection/cleanup.rs | 23 +- yamux/src/connection/closing.rs | 44 +- yamux/src/connection/command_receivers.rs | 42 -- yamux/src/connection/stream.rs | 4 - 5 files changed, 156 insertions(+), 444 deletions(-) delete mode 100644 yamux/src/connection/command_receivers.rs diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs index 965d773c..ca1f6788 100644 --- a/yamux/src/connection.rs +++ b/yamux/src/connection.rs @@ -89,7 +89,6 @@ mod cleanup; mod closing; -mod command_receivers; mod stream; use crate::Result; @@ -101,13 +100,14 @@ use crate::{ }; use cleanup::Cleanup; use closing::Closing; +use futures::stream::SelectAll; use futures::{channel::mpsc, future::Either, prelude::*, sink::SinkExt, stream::Fuse}; use nohash_hasher::IntMap; use std::collections::VecDeque; -use std::task::Context; -use std::{fmt, sync::Arc, task::Poll}; +use std::iter::FromIterator; +use std::task::{Context, Waker}; +use std::{fmt, mem, sync::Arc, task::Poll}; -use crate::connection::command_receivers::CommandReceivers; pub use stream::{Packet, State, Stream}; /// How the connection is used. @@ -349,9 +349,11 @@ struct Active { config: Arc, socket: Fuse>, next_id: u32, + streams: IntMap, - stream_receivers: CommandReceivers, - dropped_streams: Vec, + stream_receivers: Vec<(StreamId, mpsc::Receiver)>, + no_streams_waker: Option, + pending_frames: VecDeque>, } @@ -416,19 +418,27 @@ impl Active { config: Arc::new(cfg), socket, streams: IntMap::default(), - stream_receivers: CommandReceivers::default(), + stream_receivers: Vec::default(), + no_streams_waker: None, next_id: match mode { Mode::Client => 1, Mode::Server => 2, }, - dropped_streams: Vec::new(), pending_frames: VecDeque::default(), } } /// Gracefully close the connection to the remote. fn close(self) -> Closing { - Closing::new(self.stream_receivers, self.pending_frames, self.socket) + Closing::new( + SelectAll::from_iter( + self.stream_receivers + .into_iter() + .map(|(_, receiver)| receiver), + ), + self.pending_frames, + self.socket, + ) } /// Cleanup all our resources. 
@@ -437,13 +447,18 @@ impl Active { fn cleanup(mut self, error: ConnectionError) -> Cleanup { self.drop_all_streams(); - Cleanup::new(self.stream_receivers, error) + Cleanup::new( + SelectAll::from_iter( + self.stream_receivers + .into_iter() + .map(|(_, receiver)| receiver), + ), + error, + ) } fn poll(&mut self, cx: &mut Context<'_>) -> Poll> { loop { - self.garbage_collect(); - if self.socket.poll_ready_unpin(cx).is_ready() { if let Some(frame) = self.pending_frames.pop_front() { self.socket.start_send_unpin(frame)?; @@ -456,16 +471,23 @@ impl Active { Poll::Pending => {} } - match self.stream_receivers.poll_next(cx) { - Poll::Ready(StreamCommand::SendFrame(frame)) => { - self.on_send_frame(frame); - continue; - } - Poll::Ready(StreamCommand::CloseStream { id, ack }) => { - self.on_close_stream(id, ack); - continue; + for (id, mut stream) in mem::take(&mut self.stream_receivers) { + match stream.poll_next_unpin(cx) { + Poll::Ready(Some(StreamCommand::SendFrame(frame))) => { + self.on_send_frame(frame); + self.stream_receivers.push((id, stream)); + } + Poll::Ready(Some(StreamCommand::CloseStream { id, ack })) => { + self.on_close_stream(id, ack); + self.stream_receivers.push((id, stream)); + } + Poll::Ready(None) => { + self.on_drop_stream(id); + } + Poll::Pending => { + self.stream_receivers.push((id, stream)); + } } - Poll::Pending => {} } match self.socket.poll_next_unpin(cx) { @@ -481,6 +503,10 @@ impl Active { Poll::Pending => {} } + if self.stream_receivers.is_empty() { + self.no_streams_waker = Some(cx.waker().clone()); + } + // If we make it this far, at least one of the above must have registered a waker. return Poll::Pending; } @@ -532,6 +558,71 @@ impl Active { .push_back(Frame::close_stream(id, ack).into()); } + fn on_drop_stream(&mut self, id: StreamId) { + let stream = self.streams.remove(&id).expect("stream not found"); + + log::trace!("{}: removing dropped {}", self.id, stream); + let stream_id = stream.id(); + let frame = { + let mut shared = stream.shared(); + let frame = match shared.update_state(self.id, stream_id, State::Closed) { + // The stream was dropped without calling `poll_close`. + // We reset the stream to inform the remote of the closure. + State::Open => { + let mut header = Header::data(stream_id, 0); + header.rst(); + Some(Frame::new(header)) + } + // The stream was dropped without calling `poll_close`. + // We have already received a FIN from remote and send one + // back which closes the stream for good. + State::RecvClosed => { + let mut header = Header::data(stream_id, 0); + header.fin(); + Some(Frame::new(header)) + } + // The stream was properly closed. We either already have + // or will at some later point send our FIN frame. + // The remote may be out of credit though and blocked on + // writing more data. We may need to reset the stream. + State::SendClosed => { + if self.config.window_update_mode == WindowUpdateMode::OnRead + && shared.window == 0 + { + // The remote may be waiting for a window update + // which we will never send, so reset the stream now. + let mut header = Header::data(stream_id, 0); + header.rst(); + Some(Frame::new(header)) + } else { + // The remote has either still credit or will be given more + // (due to an enqueued window update or because the update + // mode is `OnReceive`) or we already have inbound frames in + // the socket buffer which will be processed later. In any + // case we will reply with an RST in `Connection::on_data` + // because the stream will no longer be known. 
+ None + } + } + // The stream was properly closed. We either already have + // or will at some later point send our FIN frame. The + // remote end has already done so in the past. + State::Closed => None, + }; + if let Some(w) = shared.reader.take() { + w.wake() + } + if let Some(w) = shared.writer.take() { + w.wake() + } + frame + }; + if let Some(f) = frame { + log::trace!("{}/{}: sending: {}", self.id, stream_id, f.header()); + self.pending_frames.push_back(f.into()); + } + } + /// Process the result of reading from the socket. /// /// Unless `frame` is `Ok(Some(_))` we will assume the connection got closed @@ -806,13 +897,11 @@ impl Active { fn make_new_stream(&mut self, id: StreamId, window: u32, credit: u32) -> Stream { let config = self.config.clone(); - // Create a channel with 0 _additional_ capacity for items. - // `poll_flush` for `Sender` will check whether we can send an item into the stream and - // NOT whether all items have been taken out of the receiver. - // To ensure that `poll_flush` on our `Stream` means that we have sent all frames, - // this channel must be configured with 0 capacity. - let (sender, receiver) = mpsc::channel(0); - self.stream_receivers.push(receiver); + let (sender, receiver) = mpsc::channel(10); + self.stream_receivers.push((id, receiver)); + if let Some(waker) = self.no_streams_waker.take() { + waker.wake(); + } Stream::new(id, self.id, config, window, credit, sender) } @@ -840,79 +929,6 @@ impl Active { Mode::Server => id.is_client(), } } - - /// Remove stale streams and create necessary messages to be sent to the remote. - fn garbage_collect(&mut self) { - let conn_id = self.id; - let win_update_mode = self.config.window_update_mode; - for stream in self.streams.values_mut() { - if stream.strong_count() > 1 { - continue; - } - log::trace!("{}: removing dropped {}", conn_id, stream); - let stream_id = stream.id(); - let frame = { - let mut shared = stream.shared(); - let frame = match shared.update_state(conn_id, stream_id, State::Closed) { - // The stream was dropped without calling `poll_close`. - // We reset the stream to inform the remote of the closure. - State::Open => { - let mut header = Header::data(stream_id, 0); - header.rst(); - Some(Frame::new(header)) - } - // The stream was dropped without calling `poll_close`. - // We have already received a FIN from remote and send one - // back which closes the stream for good. - State::RecvClosed => { - let mut header = Header::data(stream_id, 0); - header.fin(); - Some(Frame::new(header)) - } - // The stream was properly closed. We either already have - // or will at some later point send our FIN frame. - // The remote may be out of credit though and blocked on - // writing more data. We may need to reset the stream. - State::SendClosed => { - if win_update_mode == WindowUpdateMode::OnRead && shared.window == 0 { - // The remote may be waiting for a window update - // which we will never send, so reset the stream now. - let mut header = Header::data(stream_id, 0); - header.rst(); - Some(Frame::new(header)) - } else { - // The remote has either still credit or will be given more - // (due to an enqueued window update or because the update - // mode is `OnReceive`) or we already have inbound frames in - // the socket buffer which will be processed later. In any - // case we will reply with an RST in `Connection::on_data` - // because the stream will no longer be known. - None - } - } - // The stream was properly closed. We either already have - // or will at some later point send our FIN frame. 
@@ -840,79 +929,6 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
             Mode::Server => id.is_client(),
         }
     }
-
-    /// Remove stale streams and create necessary messages to be sent to the remote.
-    fn garbage_collect(&mut self) {
-        let conn_id = self.id;
-        let win_update_mode = self.config.window_update_mode;
-        for stream in self.streams.values_mut() {
-            if stream.strong_count() > 1 {
-                continue;
-            }
-            log::trace!("{}: removing dropped {}", conn_id, stream);
-            let stream_id = stream.id();
-            let frame = {
-                let mut shared = stream.shared();
-                let frame = match shared.update_state(conn_id, stream_id, State::Closed) {
-                    // The stream was dropped without calling `poll_close`.
-                    // We reset the stream to inform the remote of the closure.
-                    State::Open => {
-                        let mut header = Header::data(stream_id, 0);
-                        header.rst();
-                        Some(Frame::new(header))
-                    }
-                    // The stream was dropped without calling `poll_close`.
-                    // We have already received a FIN from remote and send one
-                    // back which closes the stream for good.
-                    State::RecvClosed => {
-                        let mut header = Header::data(stream_id, 0);
-                        header.fin();
-                        Some(Frame::new(header))
-                    }
-                    // The stream was properly closed. We either already have
-                    // or will at some later point send our FIN frame.
-                    // The remote may be out of credit though and blocked on
-                    // writing more data. We may need to reset the stream.
-                    State::SendClosed => {
-                        if win_update_mode == WindowUpdateMode::OnRead && shared.window == 0 {
-                            // The remote may be waiting for a window update
-                            // which we will never send, so reset the stream now.
-                            let mut header = Header::data(stream_id, 0);
-                            header.rst();
-                            Some(Frame::new(header))
-                        } else {
-                            // The remote has either still credit or will be given more
-                            // (due to an enqueued window update or because the update
-                            // mode is `OnReceive`) or we already have inbound frames in
-                            // the socket buffer which will be processed later. In any
-                            // case we will reply with an RST in `Connection::on_data`
-                            // because the stream will no longer be known.
-                            None
-                        }
-                    }
-                    // The stream was properly closed. We either already have
-                    // or will at some later point send our FIN frame. The
-                    // remote end has already done so in the past.
-                    State::Closed => None,
-                };
-                if let Some(w) = shared.reader.take() {
-                    w.wake()
-                }
-                if let Some(w) = shared.writer.take() {
-                    w.wake()
-                }
-                frame
-            };
-            if let Some(f) = frame {
-                log::trace!("{}/{}: sending: {}", self.id, stream_id, f.header());
-                self.pending_frames.push_back(f.into());
-            }
-            self.dropped_streams.push(stream_id)
-        }
-        for id in self.dropped_streams.drain(..) {
-            self.streams.remove(&id);
-        }
-    }
 }
 
 impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
@@ -930,272 +946,3 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
         }
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use futures::future::BoxFuture;
-    use futures::stream::FuturesUnordered;
-    use futures::AsyncReadExt;
-    use futures_ringbuf::Endpoint;
-    use std::mem;
-    use std::pin::Pin;
-
-    #[tokio::test]
-    async fn poll_flush_on_stream_only_returns_ok_if_frame_is_queued_for_sending() {
-        let (client, server) = Endpoint::pair(1000, 1000);
-
-        let client = Client::new(Connection::new(client, Config::default(), Mode::Client));
-        let server = EchoServer::new(Connection::new(server, Config::default(), Mode::Server));
-
-        let ((), processed) = futures::future::try_join(client, server).await.unwrap();
-
-        assert_eq!(processed, 1);
-    }
-
-    /// Our testing client.
-    ///
-    /// This struct will open a single outbound stream, send a message, attempt to flush it and assert the internal state of [`Connection`] after it.
-    enum Client {
-        Initial {
-            connection: Connection<Endpoint>,
-        },
-        Testing {
-            connection: Connection<Endpoint>,
-            worker_stream: StreamState,
-        },
-        Closing {
-            connection: Connection<Endpoint>,
-        },
-        Poisoned,
-    }
-
-    enum StreamState {
-        Sending(Stream),
-        Flushing(Stream),
-        Receiving(Stream),
-        Closing(Stream),
-    }
-
-    impl Client {
-        fn new(connection: Connection<Endpoint>) -> Self {
-            Self::Initial { connection }
-        }
-    }
-
-    impl Future for Client {
-        type Output = Result<()>;
-
-        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
-            let this = self.get_mut();
-
-            loop {
-                match mem::replace(this, Client::Poisoned) {
-                    // This state matching is out of order to have the interesting one at the top.
-                    Client::Testing {
-                        worker_stream: StreamState::Flushing(mut stream),
-                        mut connection,
-                    } => {
-                        match Pin::new(&mut stream).poll_flush(cx)? {
-                            Poll::Ready(()) => {
-                                let ConnectionState::Active(active) = &mut connection.inner else {
-                                    panic!("Connection is not active")
-                                };
-
-                                // Here is the actual test:
-                                // If the stream reports that it successfully flushed, we expect the connection to have queued the frames for sending
-                                // and thus not have any more `StreamCommand`s.
-                                assert!(active.stream_receivers.poll_next(cx).is_pending());
-
-                                *this = Client::Testing {
-                                    worker_stream: StreamState::Receiving(stream),
-                                    connection,
-                                };
-                                continue;
-                            }
-                            Poll::Pending => {}
-                        }
-
-                        drive_connection(this, connection, StreamState::Flushing(stream), cx);
-                        return Poll::Pending;
-                    }
-                    Client::Testing {
-                        worker_stream: StreamState::Receiving(mut stream),
-                        connection,
-                    } => {
-                        let mut buffer = [0u8; 5];
-
-                        match Pin::new(&mut stream).poll_read(cx, &mut buffer)? {
-                            Poll::Ready(num_bytes) => {
-                                assert_eq!(num_bytes, 5);
-                                assert_eq!(&buffer, b"hello");
-
-                                *this = Client::Testing {
-                                    worker_stream: StreamState::Closing(stream),
-                                    connection,
-                                };
-                                continue;
-                            }
-                            Poll::Pending => {}
-                        }
-
-                        drive_connection(this, connection, StreamState::Closing(stream), cx);
-                        return Poll::Pending;
-                    }
-                    Client::Testing {
-                        worker_stream: StreamState::Closing(mut stream),
-                        connection,
-                    } => {
-                        match Pin::new(&mut stream).poll_close(cx)? {
-                            Poll::Ready(()) => {
-                                *this = Client::Closing { connection };
-                                continue;
-                            }
-                            Poll::Pending => {}
-                        }
-
-                        drive_connection(this, connection, StreamState::Closing(stream), cx);
-                        return Poll::Pending;
-                    }
-                    Client::Initial { mut connection } => {
-                        match connection.poll_new_outbound(cx)? {
-                            Poll::Ready(stream) => {
-                                *this = Client::Testing {
-                                    connection,
-                                    worker_stream: StreamState::Sending(stream),
-                                };
-                                continue;
-                            }
-                            Poll::Pending => {
-                                *this = Client::Initial { connection };
-                                return Poll::Pending;
-                            }
-                        }
-                    }
-                    Client::Testing {
-                        worker_stream: StreamState::Sending(mut stream),
-                        connection,
-                    } => {
-                        match Pin::new(&mut stream).poll_write(cx, b"hello")? {
-                            Poll::Ready(written) => {
-                                assert_eq!(written, 5);
-                                *this = Client::Testing {
-                                    worker_stream: StreamState::Flushing(stream),
-                                    connection,
-                                };
-                                continue;
-                            }
-                            Poll::Pending => {}
-                        }
-
-                        drive_connection(this, connection, StreamState::Flushing(stream), cx);
-                        return Poll::Pending;
-                    }
-                    Client::Closing { mut connection } => match connection.poll_close(cx)? {
-                        Poll::Ready(()) => {
-                            return Poll::Ready(Ok(()));
-                        }
-                        Poll::Pending => {
-                            *this = Client::Closing { connection };
-                            return Poll::Pending;
-                        }
-                    },
-                    Client::Poisoned => {
-                        unreachable!()
-                    }
-                }
-            }
-        }
-    }
-
-    fn drive_connection(
-        this: &mut Client,
-        mut connection: Connection<Endpoint>,
-        state: StreamState,
-        cx: &mut Context,
-    ) {
-        match connection.poll_next_inbound(cx) {
-            Poll::Ready(Some(_)) => {
-                panic!("Unexpected inbound stream")
-            }
-            Poll::Ready(None) => {
-                panic!("Unexpected connection close")
-            }
-            Poll::Pending => {
-                *this = Client::Testing {
-                    worker_stream: state,
-                    connection,
-                };
-            }
-        }
-    }
-
-    struct EchoServer {
-        connection: Connection<Endpoint>,
-        worker_streams: FuturesUnordered<BoxFuture<'static, Result<()>>>,
-        streams_processed: usize,
-        connection_closed: bool,
-    }
-
-    impl EchoServer {
-        fn new(connection: Connection<Endpoint>) -> Self {
-            Self {
-                connection,
-                worker_streams: FuturesUnordered::default(),
-                streams_processed: 0,
-                connection_closed: false,
-            }
-        }
-    }
-
-    impl Future for EchoServer {
-        type Output = Result<usize>;
-
-        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
-            let this = self.get_mut();
-
-            loop {
-                match this.worker_streams.poll_next_unpin(cx) {
-                    Poll::Ready(Some(Ok(()))) => {
-                        this.streams_processed += 1;
-                        continue;
-                    }
-                    Poll::Ready(Some(Err(e))) => {
-                        eprintln!("A stream failed: {}", e);
-                        continue;
-                    }
-                    Poll::Ready(None) => {
-                        if this.connection_closed {
-                            return Poll::Ready(Ok(this.streams_processed));
-                        }
-                    }
-                    Poll::Pending => {}
-                }
-
-                match this.connection.poll_next_inbound(cx) {
-                    Poll::Ready(Some(Ok(mut stream))) => {
-                        this.worker_streams.push(
-                            async move {
-                                {
-                                    let (mut r, mut w) = AsyncReadExt::split(&mut stream);
-                                    futures::io::copy(&mut r, &mut w).await?;
-                                }
-                                stream.close().await?;
-                                Ok(())
-                            }
-                            .boxed(),
-                        );
-                        continue;
-                    }
-                    Poll::Ready(None) | Poll::Ready(Some(Err(_))) => {
-                        this.connection_closed = true;
-                        continue;
-                    }
-                    Poll::Pending => {}
-                }
-
-                return Poll::Pending;
-            }
-        }
-    }
-}
diff --git a/yamux/src/connection/cleanup.rs b/yamux/src/connection/cleanup.rs
index c0e4f6fa..5afeb0fa 100644
--- a/yamux/src/connection/cleanup.rs
+++ b/yamux/src/connection/cleanup.rs
@@ -1,5 +1,8 @@
-use crate::connection::command_receivers::CommandReceivers;
+use crate::connection::StreamCommand;
 use crate::ConnectionError;
+use futures::channel::mpsc;
+use futures::stream::SelectAll;
+use futures::StreamExt;
 use std::future::Future;
 use std::pin::Pin;
 use std::task::{Context, Poll};
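`Cleanup` depends on the close-then-drain contract of `mpsc::Receiver`: `close` rejects any further sends while items already in the channel stay readable, so draining afterwards cannot lose commands. A short sketch of that contract under the same assumptions (`futures` plus `tokio`, illustrative and not patch code):

    use futures::{channel::mpsc, SinkExt, StreamExt};

    #[tokio::main]
    async fn main() {
        let (mut tx, mut rx) = mpsc::channel::<&'static str>(8);
        tx.send("in-flight command").await.unwrap();

        // Closing stops further sends, but already-buffered items stay
        // readable, so draining afterwards loses nothing.
        rx.close();
        assert!(tx.send("rejected").await.is_err());

        assert_eq!(rx.next().await, Some("in-flight command"));
        assert_eq!(rx.next().await, None); // fully drained
    }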
@@ -8,12 +11,15 @@ use std::task::{Context, Poll};
 #[must_use]
 pub struct Cleanup {
     state: State,
-    stream_receivers: CommandReceivers,
+    stream_receivers: SelectAll<mpsc::Receiver<StreamCommand>>,
     error: Option<ConnectionError>,
 }
 
 impl Cleanup {
-    pub(crate) fn new(stream_receivers: CommandReceivers, error: ConnectionError) -> Self {
+    pub(crate) fn new(
+        stream_receivers: SelectAll<mpsc::Receiver<StreamCommand>>,
+        error: ConnectionError,
+    ) -> Self {
         Self {
             state: State::ClosingStreamReceiver,
             stream_receivers,
@@ -31,15 +37,16 @@ impl Future for Cleanup {
         loop {
             match this.state {
                 State::ClosingStreamReceiver => {
-                    this.stream_receivers.close();
+                    for stream in this.stream_receivers.iter_mut() {
+                        stream.close();
+                    }
                     this.state = State::DrainingStreamReceiver;
                 }
 
-                State::DrainingStreamReceiver => match this.stream_receivers.poll_next(cx) {
-                    Poll::Ready(cmd) => {
+                State::DrainingStreamReceiver => match this.stream_receivers.poll_next_unpin(cx) {
+                    Poll::Ready(Some(cmd)) => {
                         drop(cmd);
                     }
-                    // Poll::Pending means that there are no more commands.
-                    Poll::Pending => {
+                    Poll::Ready(None) | Poll::Pending => {
                         return Poll::Ready(
                             this.error
                                 .take()
diff --git a/yamux/src/connection/closing.rs b/yamux/src/connection/closing.rs
index 86ed480a..d7d12ec3 100644
--- a/yamux/src/connection/closing.rs
+++ b/yamux/src/connection/closing.rs
@@ -1,10 +1,10 @@
-use crate::connection::command_receivers::CommandReceivers;
 use crate::connection::StreamCommand;
 use crate::frame;
 use crate::frame::Frame;
 use crate::Result;
-use futures::stream::Fuse;
-use futures::{ready, AsyncRead, AsyncWrite, SinkExt};
+use futures::channel::mpsc;
+use futures::stream::{Fuse, SelectAll};
+use futures::{ready, AsyncRead, AsyncWrite, SinkExt, StreamExt};
 use std::collections::VecDeque;
 use std::future::Future;
 use std::pin::Pin;
@@ -14,7 +14,7 @@ use std::task::{Context, Poll};
 #[must_use]
 pub struct Closing<T> {
     state: State,
-    stream_receivers: CommandReceivers,
+    stream_receivers: SelectAll<mpsc::Receiver<StreamCommand>>,
     pending_frames: VecDeque<Frame<()>>,
     socket: Fuse<frame::Io<T>>,
 }
@@ -24,7 +24,7 @@ where
     T: AsyncRead + AsyncWrite + Unpin,
 {
     pub(crate) fn new(
-        stream_receivers: CommandReceivers,
+        stream_receivers: SelectAll<mpsc::Receiver<StreamCommand>>,
         pending_frames: VecDeque<Frame<()>>,
         socket: Fuse<frame::Io<T>>,
     ) -> Self {
@@ -49,25 +49,29 @@ where
         loop {
             match this.state {
                 State::ClosingStreamReceiver => {
-                    this.stream_receivers.close();
+                    for stream in this.stream_receivers.iter_mut() {
+                        stream.close();
+                    }
                     this.state = State::DrainingStreamReceiver;
                 }
 
-                State::DrainingStreamReceiver => match this.stream_receivers.poll_next(cx) {
-                    Poll::Ready(StreamCommand::SendFrame(frame)) => {
-                        this.pending_frames.push_back(frame.into())
-                    }
-                    Poll::Ready(StreamCommand::CloseStream { id, ack }) => {
-                        this.pending_frames
-                            .push_back(Frame::close_stream(id, ack).into());
+                State::DrainingStreamReceiver => {
+                    match this.stream_receivers.poll_next_unpin(cx) {
+                        Poll::Ready(Some(StreamCommand::SendFrame(frame))) => {
+                            this.pending_frames.push_back(frame.into())
+                        }
+                        Poll::Ready(Some(StreamCommand::CloseStream { id, ack })) => {
+                            this.pending_frames
+                                .push_back(Frame::close_stream(id, ack).into());
+                        }
+                        Poll::Pending | Poll::Ready(None) => {
+                            // No more frames from streams, append `Term` frame and flush them all.
+                            this.pending_frames.push_back(Frame::term().into());
+                            this.state = State::FlushingPendingFrames;
+                            continue;
+                        }
                     }
-                    Poll::Pending => {
-                        // No more frames from streams, append `Term` frame and flush them all.
-                        this.pending_frames.push_back(Frame::term().into());
-                        this.state = State::FlushingPendingFrames;
-                        continue;
-                    }
-                },
+                }
 
                 State::FlushingPendingFrames => {
                     ready!(this.socket.poll_ready_unpin(cx))?;
diff --git a/yamux/src/connection/command_receivers.rs b/yamux/src/connection/command_receivers.rs
deleted file mode 100644
index e70a76fa..00000000
--- a/yamux/src/connection/command_receivers.rs
+++ /dev/null
@@ -1,42 +0,0 @@
-use crate::connection::StreamCommand;
-use futures::channel::mpsc;
-use futures::stream::SelectAll;
-use futures::{ready, StreamExt};
-use std::task::{Context, Poll, Waker};
-
-/// A set of [`mpsc::Receiver`]s for [`StreamCommand`]s.
-#[derive(Default)]
-pub struct CommandReceivers {
-    inner: SelectAll<mpsc::Receiver<StreamCommand>>,
-    waker: Option<Waker>,
-}
-
-impl CommandReceivers {
-    /// Push a new [`mpsc::Receiver`].
-    pub(crate) fn push(&mut self, receiver: mpsc::Receiver<StreamCommand>) {
-        self.inner.push(receiver);
-        if let Some(waker) = self.waker.take() {
-            waker.wake();
-        }
-    }
-
-    /// Poll for the next [`StreamCommand`] from any of the internal receivers.
-    ///
-    /// The only difference from a plain [`SelectAll`] is that this will never return [`None`] but park the current task instead.
-    pub(crate) fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<StreamCommand> {
-        match ready!(self.inner.poll_next_unpin(cx)) {
-            Some(cmd) => Poll::Ready(cmd),
-            None => {
-                self.waker = Some(cx.waker().clone());
-                Poll::Pending
-            }
-        }
-    }
-
-    /// Close all remaining [`mpsc::Receiver`]s.
-    pub(crate) fn close(&mut self) {
-        for stream in self.inner.iter_mut() {
-            stream.close();
-        }
-    }
-}
diff --git a/yamux/src/connection/stream.rs b/yamux/src/connection/stream.rs
index 466f817b..ed918711 100644
--- a/yamux/src/connection/stream.rs
+++ b/yamux/src/connection/stream.rs
@@ -136,10 +136,6 @@ impl Stream {
         self.flag = flag
     }
 
-    pub(crate) fn strong_count(&self) -> usize {
-        Arc::strong_count(&self.shared)
-    }
-
     pub(crate) fn shared(&self) -> MutexGuard<'_, Shared> {
         self.shared.lock()
     }

From e0ec0aa975cfd790a7faa6f36f04552f10ca197a Mon Sep 17 00:00:00 2001
From: Thomas Eizinger
Date: Mon, 22 May 2023 15:08:24 +0200
Subject: [PATCH 10/15] Correctly poll all receivers

---
 yamux/src/connection.rs | 55 ++++++++++++++++++++++++++---------------
 1 file changed, 35 insertions(+), 20 deletions(-)

diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs
index ca1f6788..b3189441 100644
--- a/yamux/src/connection.rs
+++ b/yamux/src/connection.rs
@@ -106,7 +106,7 @@ use nohash_hasher::IntMap;
 use std::collections::VecDeque;
 use std::iter::FromIterator;
 use std::task::{Context, Waker};
-use std::{fmt, mem, sync::Arc, task::Poll};
+use std::{fmt, sync::Arc, task::Poll};
 
 pub use stream::{Packet, State, Stream};
 
@@ -471,23 +471,16 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
             Poll::Pending => {}
         }
 
-        for (id, mut stream) in mem::take(&mut self.stream_receivers) {
-            match stream.poll_next_unpin(cx) {
-                Poll::Ready(Some(StreamCommand::SendFrame(frame))) => {
-                    self.on_send_frame(frame);
-                    self.stream_receivers.push((id, stream));
-                }
-                Poll::Ready(Some(StreamCommand::CloseStream { id, ack })) => {
-                    self.on_close_stream(id, ack);
-                    self.stream_receivers.push((id, stream));
-                }
-                Poll::Ready(None) => {
-                    self.on_drop_stream(id);
-                }
-                Poll::Pending => {
-                    self.stream_receivers.push((id, stream));
-                }
+        match self.poll_stream_receivers(cx) {
+            Poll::Ready(StreamCommand::SendFrame(frame)) => {
+                self.on_send_frame(frame.into());
+                continue;
             }
+            Poll::Ready(StreamCommand::CloseStream { id, ack }) => {
+                self.on_close_stream(id, ack);
+                continue;
+            }
+            Poll::Pending => {}
         }
 
         match self.socket.poll_next_unpin(cx) {
@@ -503,13 +496,35 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
             Poll::Pending => {}
         }
 
-        if self.stream_receivers.is_empty() {
-            self.no_streams_waker = Some(cx.waker().clone());
+            // If we make it this far, at least one of the above must have registered a waker.
+            return Poll::Pending;
+        }
+    }
+
+    fn poll_stream_receivers(&mut self, cx: &mut Context) -> Poll<StreamCommand> {
+        for i in (0..self.stream_receivers.len()).rev() {
+            let (id, mut receiver) = self.stream_receivers.swap_remove(i);
+
+            match receiver.poll_next_unpin(cx) {
+                Poll::Ready(Some(command)) => {
+                    self.stream_receivers.push((id, receiver));
+                    return Poll::Ready(command);
+                }
+                Poll::Ready(None) => {
+                    self.on_drop_stream(id);
+                }
+                Poll::Pending => {
+                    self.stream_receivers.push((id, receiver));
+                }
             }
+        }
 
-        // If we make it this far, at least one of the above must have registered a waker.
+        if self.stream_receivers.is_empty() {
+            self.no_streams_waker = Some(cx.waker().clone());
             return Poll::Pending;
         }
+
+        Poll::Pending
     }
 
     fn new_outbound(&mut self) -> Result<Stream> {

From e755c2c5224803c690155f04fc3b7f2423c69aff Mon Sep 17 00:00:00 2001
From: Thomas Eizinger
Date: Tue, 23 May 2023 22:57:25 +0200
Subject: [PATCH 11/15] Introduce `TaggedStream` so we can use `SelectAll`

---
 yamux/Cargo.toml                |  1 +
 yamux/src/connection.rs         | 52 +++++++++------------------
 yamux/src/connection/cleanup.rs |  9 +++---
 yamux/src/connection/closing.rs | 14 +++++----
 yamux/src/lib.rs                |  1 +
 yamux/src/tagged_stream.rs      | 52 +++++++++++++++++++++++++++++
 6 files changed, 81 insertions(+), 48 deletions(-)
 create mode 100644 yamux/src/tagged_stream.rs

diff --git a/yamux/Cargo.toml b/yamux/Cargo.toml
index 24041444..edb3774f 100644
--- a/yamux/Cargo.toml
+++ b/yamux/Cargo.toml
@@ -16,6 +16,7 @@ nohash-hasher = "0.2"
 parking_lot = "0.12"
 rand = "0.8.3"
 static_assertions = "1"
+pin-project = "1.1.0"
 
 [dev-dependencies]
 anyhow = "1"
diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs
index b3189441..b4cad570 100644
--- a/yamux/src/connection.rs
+++ b/yamux/src/connection.rs
@@ -104,10 +104,10 @@ use futures::stream::SelectAll;
 use futures::{channel::mpsc, future::Either, prelude::*, sink::SinkExt, stream::Fuse};
 use nohash_hasher::IntMap;
 use std::collections::VecDeque;
-use std::iter::FromIterator;
 use std::task::{Context, Waker};
 use std::{fmt, sync::Arc, task::Poll};
 
+use crate::tagged_stream::TaggedStream;
 pub use stream::{Packet, State, Stream};
 
 /// How the connection is used.
@@ -351,7 +351,7 @@ struct Active<T> {
     next_id: u32,
 
     streams: IntMap<StreamId, Stream>,
-    stream_receivers: Vec<(StreamId, mpsc::Receiver<StreamCommand>)>,
+    stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
     no_streams_waker: Option<Waker>,
 
     pending_frames: VecDeque<Frame<()>>,
@@ -418,7 +418,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
             config: Arc::new(cfg),
             socket,
             streams: IntMap::default(),
-            stream_receivers: Vec::default(),
+            stream_receivers: SelectAll::default(),
             no_streams_waker: None,
             next_id: match mode {
                 Mode::Client => 1,
@@ -430,15 +430,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
 
     /// Gracefully close the connection to the remote.
     fn close(self) -> Closing<T> {
-        Closing::new(
-            SelectAll::from_iter(
-                self.stream_receivers
-                    .into_iter()
-                    .map(|(_, receiver)| receiver),
-            ),
-            self.pending_frames,
-            self.socket,
-        )
+        Closing::new(self.stream_receivers, self.pending_frames, self.socket)
     }
 
     /// Cleanup all our resources.
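The move to `SelectAll` works because it polls a dynamic set of streams, removes inner streams as they terminate, and yields `None` once it is empty, which is why the connection must park `no_streams_waker` until a new stream is pushed. A rough illustration of that behaviour (illustrative code, not from the patches):

    use futures::stream::{self, SelectAll, StreamExt};

    #[tokio::main]
    async fn main() {
        let mut all = SelectAll::new();
        all.push(stream::iter(vec![1, 2]));
        all.push(stream::iter(vec![10]));

        let mut seen = Vec::new();
        while let Some(item) = all.next().await {
            seen.push(item); // items from all inner streams, interleaved
        }

        // Exhausted inner streams were removed along the way and the
        // combined stream ended once the set became empty.
        seen.sort();
        assert_eq!(seen, vec![1, 2, 10]);
    }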
@@ -447,14 +439,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
     fn cleanup(mut self, error: ConnectionError) -> Cleanup {
         self.drop_all_streams();
 
-        Cleanup::new(
-            SelectAll::from_iter(
-                self.stream_receivers
-                    .into_iter()
-                    .map(|(_, receiver)| receiver),
-            ),
-            error,
-        )
+        Cleanup::new(self.stream_receivers, error)
     }
 
     fn poll(&mut self, cx: &mut Context<'_>) -> Poll<Result<Stream>> {
@@ -502,29 +487,20 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
     fn poll_stream_receivers(&mut self, cx: &mut Context) -> Poll<StreamCommand> {
-        for i in (0..self.stream_receivers.len()).rev() {
-            let (id, mut receiver) = self.stream_receivers.swap_remove(i);
-
-            match receiver.poll_next_unpin(cx) {
-                Poll::Ready(Some(command)) => {
-                    self.stream_receivers.push((id, receiver));
+        loop {
+            match futures::ready!(self.stream_receivers.poll_next_unpin(cx)) {
+                None => {
+                    self.no_streams_waker = Some(cx.waker().clone());
+                    return Poll::Pending;
+                }
+                Some((_, Some(command))) => {
                     return Poll::Ready(command);
                 }
-                Poll::Ready(None) => {
+                Some((id, None)) => {
                     self.on_drop_stream(id);
                 }
-                Poll::Pending => {
-                    self.stream_receivers.push((id, receiver));
-                }
             }
         }
-
-        if self.stream_receivers.is_empty() {
-            self.no_streams_waker = Some(cx.waker().clone());
-            return Poll::Pending;
-        }
-
-        Poll::Pending
     }
 
     fn new_outbound(&mut self) -> Result<Stream> {
@@ -913,7 +889,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
         let config = self.config.clone();
 
         let (sender, receiver) = mpsc::channel(10);
-        self.stream_receivers.push((id, receiver));
+        self.stream_receivers.push(TaggedStream::new(id, receiver));
         if let Some(waker) = self.no_streams_waker.take() {
             waker.wake();
         }
diff --git a/yamux/src/connection/cleanup.rs b/yamux/src/connection/cleanup.rs
index 5afeb0fa..b4dfa816 100644
--- a/yamux/src/connection/cleanup.rs
+++ b/yamux/src/connection/cleanup.rs
@@ -1,5 +1,6 @@
 use crate::connection::StreamCommand;
-use crate::ConnectionError;
+use crate::tagged_stream::TaggedStream;
+use crate::{ConnectionError, StreamId};
 use futures::channel::mpsc;
 use futures::stream::SelectAll;
 use futures::StreamExt;
@@ -11,13 +12,13 @@ use std::task::{Context, Poll};
 #[must_use]
 pub struct Cleanup {
     state: State,
-    stream_receivers: SelectAll<mpsc::Receiver<StreamCommand>>,
+    stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
     error: Option<ConnectionError>,
 }
 
 impl Cleanup {
     pub(crate) fn new(
-        stream_receivers: SelectAll<mpsc::Receiver<StreamCommand>>,
+        stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
         error: ConnectionError,
     ) -> Self {
         Self {
@@ -38,7 +39,7 @@ impl Future for Cleanup {
             match this.state {
                 State::ClosingStreamReceiver => {
                     for stream in this.stream_receivers.iter_mut() {
-                        stream.close();
+                        stream.inner_mut().close();
                     }
                     this.state = State::DrainingStreamReceiver;
                 }
diff --git a/yamux/src/connection/closing.rs b/yamux/src/connection/closing.rs
index d7d12ec3..37fc365b 100644
--- a/yamux/src/connection/closing.rs
+++ b/yamux/src/connection/closing.rs
@@ -1,7 +1,8 @@
 use crate::connection::StreamCommand;
-use crate::frame;
 use crate::frame::Frame;
+use crate::tagged_stream::TaggedStream;
 use crate::Result;
+use crate::{frame, StreamId};
 use futures::channel::mpsc;
 use futures::stream::{Fuse, SelectAll};
 use futures::{ready, AsyncRead, AsyncWrite, SinkExt, StreamExt};
@@ -14,7 +15,7 @@ use std::task::{Context, Poll};
 #[must_use]
 pub struct Closing<T> {
     state: State,
-    stream_receivers: SelectAll<mpsc::Receiver<StreamCommand>>,
+    stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
     pending_frames: VecDeque<Frame<()>>,
     socket: Fuse<frame::Io<T>>,
 }
@@ -24,7 +25,7 @@ where
     T: AsyncRead + AsyncWrite + Unpin,
 {
     pub(crate) fn new(
-        stream_receivers: SelectAll<mpsc::Receiver<StreamCommand>>,
+        stream_receivers: SelectAll<TaggedStream<StreamId, mpsc::Receiver<StreamCommand>>>,
         pending_frames: VecDeque<Frame<()>>,
         socket: Fuse<frame::Io<T>>,
     ) -> Self {
@@ -50,20 +51,21 @@ where
             match this.state {
                 State::ClosingStreamReceiver => {
                     for stream in this.stream_receivers.iter_mut() {
-                        stream.close();
+                        stream.inner_mut().close();
                     }
                     this.state = State::DrainingStreamReceiver;
                 }
 
                 State::DrainingStreamReceiver => {
                     match this.stream_receivers.poll_next_unpin(cx) {
-                        Poll::Ready(Some(StreamCommand::SendFrame(frame))) => {
+                        Poll::Ready(Some((_, Some(StreamCommand::SendFrame(frame))))) => {
                             this.pending_frames.push_back(frame.into())
                         }
-                        Poll::Ready(Some(StreamCommand::CloseStream { id, ack })) => {
+                        Poll::Ready(Some((_, Some(StreamCommand::CloseStream { id, ack })))) => {
                             this.pending_frames
                                 .push_back(Frame::close_stream(id, ack).into());
                         }
+                        Poll::Ready(Some((_, None))) => {}
                         Poll::Pending | Poll::Ready(None) => {
                             // No more frames from streams, append `Term` frame and flush them all.
                             this.pending_frames.push_back(Frame::term().into());
diff --git a/yamux/src/lib.rs b/yamux/src/lib.rs
index cafbf77b..040bd8d7 100644
--- a/yamux/src/lib.rs
+++ b/yamux/src/lib.rs
@@ -30,6 +30,7 @@ mod error;
 mod frame;
 
 pub(crate) mod connection;
+mod tagged_stream;
 
 pub use crate::connection::{Connection, Mode, Packet, Stream};
 pub use crate::control::{Control, ControlledConnection};
diff --git a/yamux/src/tagged_stream.rs b/yamux/src/tagged_stream.rs
new file mode 100644
index 00000000..5c6035aa
--- /dev/null
+++ b/yamux/src/tagged_stream.rs
@@ -0,0 +1,52 @@
+use futures::Stream;
+use std::pin::Pin;
+use std::task::{Context, Poll};
+
+/// A stream that yields its tag with every item.
+#[pin_project::pin_project]
+pub struct TaggedStream<K, S> {
+    key: K,
+    #[pin]
+    inner: S,
+
+    reported_none: bool,
+}
+
+impl<K, S> TaggedStream<K, S> {
+    pub fn new(key: K, inner: S) -> Self {
+        Self {
+            key,
+            inner,
+            reported_none: false,
+        }
+    }
+
+    pub fn inner_mut(&mut self) -> &mut S {
+        &mut self.inner
+    }
+}
+
+impl<K, S> Stream for TaggedStream<K, S>
+where
+    K: Copy,
+    S: Stream,
+{
+    type Item = (K, Option<S::Item>);
+
+    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
+        let this = self.project();
+
+        if *this.reported_none {
+            return Poll::Ready(None);
+        }
+
+        match futures::ready!(this.inner.poll_next(cx)) {
+            Some(item) => Poll::Ready(Some((*this.key, Some(item)))),
+            None => {
+                *this.reported_none = true;
+
+                Poll::Ready(Some((*this.key, None)))
+            }
+        }
+    }
+}

From fd810c4e55ba483881ddabc2bae74178bf82882a Mon Sep 17 00:00:00 2001
From: Thomas Eizinger
Date: Tue, 23 May 2023 23:00:22 +0200
Subject: [PATCH 12/15] Inline `poll_stream_receivers` and remove ID from
 stream command

---
 yamux/src/connection.rs         | 32 +++++++++++---------------------
 yamux/src/connection/closing.rs |  2 +-
 yamux/src/connection/stream.rs  |  2 +-
 3 files changed, 13 insertions(+), 23 deletions(-)

diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs
index b4cad570..16bbe028 100644
--- a/yamux/src/connection.rs
+++ b/yamux/src/connection.rs
@@ -363,7 +363,7 @@ pub(crate) enum StreamCommand {
     /// A new frame should be sent to the remote.
     SendFrame(Frame<Either<Data, WindowUpdate>>),
     /// Close a stream.
-    CloseStream { id: StreamId, ack: bool },
+    CloseStream { ack: bool },
 }
 
 /// Possible actions as a result of incoming frame handling.
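`TaggedStream` makes every `SelectAll` item identify its origin: items surface as `(key, Some(item))` and the end of an inner stream surfaces exactly once as `(key, None)`, which is the hook for `on_drop_stream`. A toy emulation of the same shape using `map` over option-valued streams (hypothetical keys and messages, not the actual yamux types):

    use futures::stream::{self, SelectAll, StreamExt};

    #[tokio::main]
    async fn main() {
        // Tag every item with a key and emit a final `None`, mirroring the
        // `(K, Option<S::Item>)` item type of `TaggedStream`.
        let a = stream::iter(vec![Some("ping"), Some("pong"), None]).map(|i| (1u32, i));
        let b = stream::iter(vec![Some("syn"), None]).map(|i| (2u32, i));

        let mut all = SelectAll::new();
        all.push(a.boxed());
        all.push(b.boxed());

        while let Some((key, item)) = all.next().await {
            match item {
                Some(msg) => println!("stream {key}: {msg}"),
                None => println!("stream {key} terminated"), // e.g. run on_drop_stream(key)
            }
        }
    }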
@@ -456,15 +456,22 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
             Poll::Pending => {}
         }
 
-        match self.poll_stream_receivers(cx) {
-            Poll::Ready(StreamCommand::SendFrame(frame)) => {
+        match self.stream_receivers.poll_next_unpin(cx) {
+            Poll::Ready(None) => {
+                self.no_streams_waker = Some(cx.waker().clone());
+            }
+            Poll::Ready(Some((_, Some(StreamCommand::SendFrame(frame))))) => {
                 self.on_send_frame(frame.into());
                 continue;
             }
-            Poll::Ready(StreamCommand::CloseStream { id, ack }) => {
+            Poll::Ready(Some((id, Some(StreamCommand::CloseStream { ack })))) => {
                 self.on_close_stream(id, ack);
                 continue;
             }
+            Poll::Ready(Some((id, None))) => {
+                self.on_drop_stream(id);
+                continue;
+            }
             Poll::Pending => {}
         }
 
@@ -486,23 +493,6 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
         }
     }
 
-    fn poll_stream_receivers(&mut self, cx: &mut Context) -> Poll<StreamCommand> {
-        loop {
-            match futures::ready!(self.stream_receivers.poll_next_unpin(cx)) {
-                None => {
-                    self.no_streams_waker = Some(cx.waker().clone());
-                    return Poll::Pending;
-                }
-                Some((_, Some(command))) => {
-                    return Poll::Ready(command);
-                }
-                Some((id, None)) => {
-                    self.on_drop_stream(id);
-                }
-            }
-        }
-    }
-
     fn new_outbound(&mut self) -> Result<Stream> {
diff --git a/yamux/src/connection/closing.rs b/yamux/src/connection/closing.rs
index 37fc365b..d503941f 100644
--- a/yamux/src/connection/closing.rs
+++ b/yamux/src/connection/closing.rs
@@ -61,7 +61,7 @@ where
                         Poll::Ready(Some((_, Some(StreamCommand::SendFrame(frame))))) => {
                             this.pending_frames.push_back(frame.into())
                         }
-                        Poll::Ready(Some((_, Some(StreamCommand::CloseStream { id, ack })))) => {
+                        Poll::Ready(Some((id, Some(StreamCommand::CloseStream { ack })))) => {
                             this.pending_frames
                                 .push_back(Frame::close_stream(id, ack).into());
                         }
diff --git a/yamux/src/connection/stream.rs b/yamux/src/connection/stream.rs
index ed918711..469ca3e2 100644
--- a/yamux/src/connection/stream.rs
+++ b/yamux/src/connection/stream.rs
@@ -375,7 +375,7 @@ impl AsyncWrite for Stream {
             false
         };
         log::trace!("{}/{}: close", self.conn, self.id);
-        let cmd = StreamCommand::CloseStream { id: self.id, ack };
+        let cmd = StreamCommand::CloseStream { ack };
         self.sender
             .start_send(cmd)
             .map_err(|_| self.write_zero_err())?;

From 4060d079732ba2da2a87721ebf716d9488f0957a Mon Sep 17 00:00:00 2001
From: Thomas Eizinger
Date: Tue, 23 May 2023 23:02:34 +0200
Subject: [PATCH 13/15] Add comment about channel size

---
 yamux/src/connection.rs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs
index 16bbe028..08849062 100644
--- a/yamux/src/connection.rs
+++ b/yamux/src/connection.rs
@@ -878,7 +878,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
     fn make_new_stream(&mut self, id: StreamId, window: u32, credit: u32) -> Stream {
         let config = self.config.clone();
 
-        let (sender, receiver) = mpsc::channel(10);
+        let (sender, receiver) = mpsc::channel(10); // 10 is an arbitrary number.
         self.stream_receivers.push(TaggedStream::new(id, receiver));
         if let Some(waker) = self.no_streams_waker.take() {
             waker.wake();
From e8091036987b30a7b6b3aab624ae21795645b6f6 Mon Sep 17 00:00:00 2001
From: Thomas Eizinger
Date: Tue, 23 May 2023 23:04:28 +0200
Subject: [PATCH 14/15] Update docs

---
 yamux/src/connection.rs | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs
index 08849062..6a90e9cc 100644
--- a/yamux/src/connection.rs
+++ b/yamux/src/connection.rs
@@ -562,8 +562,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
                     header.fin();
                     Some(Frame::new(header))
                 }
-                // The stream was properly closed. We either already have
-                // or will at some later point send our FIN frame.
+                // The stream was properly closed. We already sent our FIN frame.
                 // The remote may be out of credit though and blocked on
                 // writing more data. We may need to reset the stream.
                 State::SendClosed => {
@@ -585,8 +584,7 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
                         None
                     }
                 }
-                // The stream was properly closed. We either already have
-                // or will at some later point send our FIN frame. The
+                // The stream was properly closed. We have already sent our FIN frame. The
                 // remote end has already done so in the past.
                 State::Closed => None,
             };

From 53d5ece8f7f4169bab2512d3a68192f7c0027d16 Mon Sep 17 00:00:00 2001
From: Thomas Eizinger
Date: Tue, 23 May 2023 23:05:55 +0200
Subject: [PATCH 15/15] Reduce diff

---
 yamux/src/connection.rs | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/yamux/src/connection.rs b/yamux/src/connection.rs
index 6a90e9cc..43f76226 100644
--- a/yamux/src/connection.rs
+++ b/yamux/src/connection.rs
@@ -457,9 +457,6 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
         }
 
         match self.stream_receivers.poll_next_unpin(cx) {
-            Poll::Ready(None) => {
-                self.no_streams_waker = Some(cx.waker().clone());
-            }
             Poll::Ready(Some((_, Some(StreamCommand::SendFrame(frame))))) => {
                 self.on_send_frame(frame.into());
                 continue;
@@ -472,6 +469,9 @@ impl<T: AsyncRead + AsyncWrite + Unpin> Active<T> {
                 self.on_drop_stream(id);
                 continue;
             }
+            Poll::Ready(None) => {
+                self.no_streams_waker = Some(cx.waker().clone());
+            }
             Poll::Pending => {}
         }
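The `Poll::Ready(None)` arm that this last patch merely reorders is an instance of the usual waker-parking pattern: when the source is empty, stash the waker; whoever adds new work wakes it so the task is polled again (compare `make_new_stream`, which wakes `no_streams_waker` after pushing a new receiver). The pattern in isolation, as a minimal sketch with an invented `WorkQueue` type:

    use std::task::{Context, Poll, Waker};

    /// Same parking idea as `no_streams_waker`: with nothing left to poll,
    /// remember the waker; adding new work wakes the parked task.
    struct WorkQueue<T> {
        items: Vec<T>,
        waker: Option<Waker>,
    }

    impl<T> WorkQueue<T> {
        fn push(&mut self, item: T) {
            self.items.push(item);
            if let Some(w) = self.waker.take() {
                w.wake(); // re-schedule the task parked in `poll_next`
            }
        }

        fn poll_next(&mut self, cx: &mut Context<'_>) -> Poll<T> {
            match self.items.pop() {
                Some(item) => Poll::Ready(item),
                None => {
                    self.waker = Some(cx.waker().clone());
                    Poll::Pending
                }
            }
        }
    }

    fn main() {
        let waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&waker);

        let mut queue = WorkQueue { items: Vec::new(), waker: None };
        assert!(queue.poll_next(&mut cx).is_pending());

        queue.push(42);
        assert_eq!(queue.poll_next(&mut cx), Poll::Ready(42));
    }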