From 011ae45ecfa01a9c0969ecc38b91ee00d40bb03b Mon Sep 17 00:00:00 2001 From: Alexis Sellier Date: Sun, 19 Dec 2021 17:34:11 +0100 Subject: [PATCH] Tidy up --- p2p/src/protocol.rs | 3 + p2p/src/protocol/executor.rs | 146 +++++++++++++++++++++++++++++++++++ p2p/src/protocol/output.rs | 100 ++++-------------------- 3 files changed, 162 insertions(+), 87 deletions(-) create mode 100644 p2p/src/protocol/executor.rs diff --git a/p2p/src/protocol.rs b/p2p/src/protocol.rs index 0b1bf2a1..81faa67f 100644 --- a/p2p/src/protocol.rs +++ b/p2p/src/protocol.rs @@ -19,6 +19,9 @@ mod syncmgr; #[cfg(test)] mod tests; +// Futures executor. +mod executor; + use addrmgr::AddressManager; use cbfmgr::FilterManager; use invmgr::InventoryManager; diff --git a/p2p/src/protocol/executor.rs b/p2p/src/protocol/executor.rs new file mode 100644 index 00000000..14d7dad8 --- /dev/null +++ b/p2p/src/protocol/executor.rs @@ -0,0 +1,146 @@ +#![allow(dead_code)] +use std::cell::RefCell; +use std::fmt; +use std::future::Future; +use std::pin::Pin; +use std::rc::Rc; +use std::sync::Arc; +use std::task::{Context, Poll}; + +struct Waker; + +impl std::task::Wake for Waker { + fn wake(self: Arc) {} + fn wake_by_ref(self: &Arc) {} +} + +type BoxFuture<'a, T> = Pin + 'a>>; + +struct Task { + future: Option>, +} + +impl fmt::Debug for Task { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Task").finish() + } +} + +#[derive(Debug, Clone)] +pub struct Request { + result: Rc>>>, +} + +impl Request { + pub fn new() -> Self { + Self { + result: Rc::new(RefCell::new(None)), + } + } + + pub fn complete(&mut self, result: Result) { + *self.result.borrow_mut() = Some(result); + } +} + +impl Future for Request { + type Output = Result; + + fn poll( + self: std::pin::Pin<&mut Self>, + _ctx: &mut std::task::Context<'_>, + ) -> std::task::Poll { + // TODO: Use `take()` instead of cloning, once you figure it out. + // For now we have to clone, as multiple futures may share the same + // refcell. + if let Some(result) = self.get_mut().result.borrow().clone() { + Poll::Ready(result) + } else { + Poll::Pending + } + } +} + +#[derive(Clone, Debug)] +pub struct Executor { + tasks: Rc>>, +} + +impl Executor { + pub fn new() -> Self { + Self { + tasks: Rc::new(RefCell::new(Vec::new())), + } + } + + /// Spawn a future to be executed. + pub fn spawn(&mut self, future: impl Future + 'static) { + self.tasks.borrow_mut().push(Task { + future: Some(Box::pin(future)), + }); + } + + /// Poll all tasks for completion. + pub fn poll(&mut self) -> Poll<()> { + let mut tasks = self.tasks.borrow_mut(); + let waker = Arc::new(Waker).into(); + let mut cx = Context::from_waker(&waker); + + for task in tasks.iter_mut() { + if let Some(mut fut) = task.future.take() { + if fut.as_mut().poll(&mut cx).is_pending() { + task.future = Some(fut); + } + } + } + // Clear out all completed futures. + tasks.retain(|t| t.future.is_some()); + + if tasks.is_empty() { + return Poll::Ready(()); + } + Poll::Pending + } +} + +#[cfg(test)] +mod tests { + use super::*; + + struct Random { + val: T, + } + + impl Future for Random { + type Output = T; + + fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + if fastrand::bool() { + Poll::Ready(self.val.clone()) + } else { + Poll::Pending + } + } + } + + #[test] + fn test_executor() { + let mut exe = Executor::new(); + + exe.spawn(async { + Random { val: 1 }.await; + }); + exe.spawn(async { + Random { val: 2 }.await; + }); + exe.spawn(async { + Random { val: 3 }.await; + }); + + loop { + if let Poll::Ready(()) = exe.poll() { + break; + } + } + } +} diff --git a/p2p/src/protocol/output.rs b/p2p/src/protocol/output.rs index b6a9161a..107e60ed 100644 --- a/p2p/src/protocol/output.rs +++ b/p2p/src/protocol/output.rs @@ -8,12 +8,8 @@ use log::*; use std::cell::RefCell; use std::collections::{HashMap, VecDeque}; -use std::future::Future; -use std::pin::Pin; use std::rc::Rc; use std::sync::Arc; -use std::task::Context; -use std::task::Poll; use std::{fmt, io, net}; pub use crossbeam_channel as chan; @@ -34,6 +30,8 @@ use nakamoto_common::block::{BlockHash, BlockHeader, BlockTime, Height}; use crate::protocol::{Event, PeerId}; +use super::executor::Executor; +use super::executor::Request; use super::invmgr::Inventories; use super::network::Network; use super::{addrmgr, cbfmgr, invmgr, peermgr, pingmgr, syncmgr, Locators}; @@ -124,48 +122,6 @@ impl fmt::Display for DisconnectReason { } } -struct Waker; - -impl std::task::Wake for Waker { - fn wake(self: Arc) {} - fn wake_by_ref(self: &Arc) {} -} - -type BoxFuture<'a, T> = Pin + 'a>>; - -struct Task { - future: Option>, -} - -impl fmt::Debug for Task { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("Task").finish() - } -} - -#[derive(Debug, Clone)] -pub struct Request { - result: Rc>>>, -} - -impl Future for Request { - type Output = Result; - - fn poll( - self: std::pin::Pin<&mut Self>, - _ctx: &mut std::task::Context<'_>, - ) -> std::task::Poll { - // TODO: Use `take()` instead of cloning, once you figure it out. - // For now we have to clone, as multiple futures may share the same - // refcell. - if let Some(result) = self.get_mut().result.borrow().clone() { - Poll::Ready(result) - } else { - Poll::Pending - } - } -} - pub(crate) mod message { use nakamoto_common::bitcoin::consensus::Encodable; use nakamoto_common::bitcoin::network::message::RawNetworkMessage; @@ -204,6 +160,8 @@ pub(crate) mod message { pub struct Outbox { /// Protocol version. version: u32, + /// Futures executor. + executor: Executor, /// Output queue. outbound: Rc>>, /// Message outbox. @@ -214,8 +172,6 @@ pub struct Outbox { builder: message::Builder, /// Log target. target: &'static str, - - tasks: Rc>>, } impl Outbox { @@ -223,12 +179,12 @@ impl Outbox { pub fn new(network: Network, version: u32, target: &'static str) -> Self { Self { version, + executor: Executor::new(), outbound: Rc::new(RefCell::new(VecDeque::new())), outbox: Rc::new(RefCell::new(HashMap::new())), block_requests: Rc::new(RefCell::new(HashMap::new())), builder: message::Builder::new(network), target, - tasks: Rc::new(RefCell::new(Vec::new())), } } @@ -294,38 +250,9 @@ impl Outbox { pub fn block_received(&mut self, blk: Block) { let block_hash = blk.block_hash(); - if let Some(req) = self.block_requests.borrow_mut().remove(&block_hash) { - *req.result.borrow_mut() = Some(Ok(blk.clone())); - } - } - - /// Spawn a future to be executed. - pub fn spawn(&mut self, future: impl Future + 'static) { - self.tasks.borrow_mut().push(Task { - future: Some(Box::pin(future)), - }); - } - - /// Poll all tasks for completion. - pub fn poll(&mut self) -> Poll<()> { - let mut tasks = self.tasks.borrow_mut(); - let waker = Arc::new(Waker).into(); - let mut cx = Context::from_waker(&waker); - - for task in tasks.iter_mut() { - if let Some(mut fut) = task.future.take() { - if fut.as_mut().poll(&mut cx).is_pending() { - task.future = Some(fut); - } - } + if let Some(mut req) = self.block_requests.borrow_mut().remove(&block_hash) { + req.complete(Ok(blk.clone())); } - // Clear out all completed futures. - tasks.retain(|t| t.future.is_some()); - - if tasks.is_empty() { - return Poll::Ready(()); - } - Poll::Pending } } @@ -368,9 +295,8 @@ impl Blocks for Outbox { fn get_block(&mut self, hash: BlockHash, addr: &PeerId) -> Request { use std::collections::hash_map::Entry; - let request = Request { - result: Rc::new(RefCell::new(None)), - }; + let request = Request::new(); + match self.block_requests.borrow_mut().entry(hash) { Entry::Vacant(e) => { e.insert(request.clone()); @@ -638,7 +564,7 @@ pub mod test { let remote: net::SocketAddr = ([88, 88, 88, 88], 8333).into(); let mut outbox = Outbox::new(network, crate::protocol::PROTOCOL_VERSION, "test"); - outbox.spawn({ + outbox.executor.spawn({ let mut outbox = outbox.clone(); async move { @@ -648,7 +574,7 @@ pub mod test { .unwrap(); } }); - outbox.spawn({ + outbox.executor.spawn({ let mut outbox = outbox.clone(); async move { @@ -658,7 +584,7 @@ pub mod test { .unwrap(); } }); - assert!(outbox.poll().is_pending()); + assert!(outbox.executor.poll().is_pending()); assert_matches!( messages(&mut outbox, &remote).next(), @@ -666,6 +592,6 @@ pub mod test { ); outbox.block_received(network.genesis_block()); - assert!(outbox.poll().is_ready()); + assert!(outbox.executor.poll().is_ready()); } }