Skip to content

Commit

Permalink
Tidy up
Browse files Browse the repository at this point in the history
  • Loading branch information
cloudhead committed Dec 19, 2021
1 parent 1a4af0c commit 011ae45
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 87 deletions.
3 changes: 3 additions & 0 deletions p2p/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ mod syncmgr;
#[cfg(test)]
mod tests;

// Futures executor.
mod executor;

use addrmgr::AddressManager;
use cbfmgr::FilterManager;
use invmgr::InventoryManager;
Expand Down
146 changes: 146 additions & 0 deletions p2p/src/protocol/executor.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
#![allow(dead_code)]
use std::cell::RefCell;
use std::fmt;
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc;
use std::sync::Arc;
use std::task::{Context, Poll};

struct Waker;

impl std::task::Wake for Waker {
fn wake(self: Arc<Self>) {}
fn wake_by_ref(self: &Arc<Self>) {}
}

type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;

struct Task {
future: Option<BoxFuture<'static, ()>>,
}

impl fmt::Debug for Task {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Task").finish()
}
}

#[derive(Debug, Clone)]
pub struct Request<T> {
result: Rc<RefCell<Option<Result<T, ()>>>>,
}

impl<T> Request<T> {
pub fn new() -> Self {
Self {
result: Rc::new(RefCell::new(None)),
}
}

pub fn complete(&mut self, result: Result<T, ()>) {
*self.result.borrow_mut() = Some(result);
}
}

impl<T: Clone + std::marker::Unpin + std::fmt::Debug> Future for Request<T> {
type Output = Result<T, ()>;

fn poll(
self: std::pin::Pin<&mut Self>,
_ctx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
// TODO: Use `take()` instead of cloning, once you figure it out.
// For now we have to clone, as multiple futures may share the same
// refcell.
if let Some(result) = self.get_mut().result.borrow().clone() {
Poll::Ready(result)
} else {
Poll::Pending
}
}
}

#[derive(Clone, Debug)]
pub struct Executor {
tasks: Rc<RefCell<Vec<Task>>>,
}

impl Executor {
pub fn new() -> Self {
Self {
tasks: Rc::new(RefCell::new(Vec::new())),
}
}

/// Spawn a future to be executed.
pub fn spawn(&mut self, future: impl Future<Output = ()> + 'static) {
self.tasks.borrow_mut().push(Task {
future: Some(Box::pin(future)),
});
}

/// Poll all tasks for completion.
pub fn poll(&mut self) -> Poll<()> {
let mut tasks = self.tasks.borrow_mut();
let waker = Arc::new(Waker).into();
let mut cx = Context::from_waker(&waker);

for task in tasks.iter_mut() {
if let Some(mut fut) = task.future.take() {
if fut.as_mut().poll(&mut cx).is_pending() {
task.future = Some(fut);
}
}
}
// Clear out all completed futures.
tasks.retain(|t| t.future.is_some());

if tasks.is_empty() {
return Poll::Ready(());
}
Poll::Pending
}
}

#[cfg(test)]
mod tests {
use super::*;

struct Random<T> {
val: T,
}

impl<T: Clone + std::fmt::Debug + Unpin> Future for Random<T> {
type Output = T;

fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
if fastrand::bool() {
Poll::Ready(self.val.clone())
} else {
Poll::Pending
}
}
}

#[test]
fn test_executor() {
let mut exe = Executor::new();

exe.spawn(async {
Random { val: 1 }.await;
});
exe.spawn(async {
Random { val: 2 }.await;
});
exe.spawn(async {
Random { val: 3 }.await;
});

loop {
if let Poll::Ready(()) = exe.poll() {
break;
}
}
}
}
100 changes: 13 additions & 87 deletions p2p/src/protocol/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,8 @@
use log::*;
use std::cell::RefCell;
use std::collections::{HashMap, VecDeque};
use std::future::Future;
use std::pin::Pin;
use std::rc::Rc;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use std::{fmt, io, net};

pub use crossbeam_channel as chan;
Expand All @@ -34,6 +30,8 @@ use nakamoto_common::block::{BlockHash, BlockHeader, BlockTime, Height};

use crate::protocol::{Event, PeerId};

use super::executor::Executor;
use super::executor::Request;
use super::invmgr::Inventories;
use super::network::Network;
use super::{addrmgr, cbfmgr, invmgr, peermgr, pingmgr, syncmgr, Locators};
Expand Down Expand Up @@ -124,48 +122,6 @@ impl fmt::Display for DisconnectReason {
}
}

struct Waker;

impl std::task::Wake for Waker {
fn wake(self: Arc<Self>) {}
fn wake_by_ref(self: &Arc<Self>) {}
}

type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a>>;

struct Task {
future: Option<BoxFuture<'static, ()>>,
}

impl fmt::Debug for Task {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Task").finish()
}
}

#[derive(Debug, Clone)]
pub struct Request<T> {
result: Rc<RefCell<Option<Result<T, ()>>>>,
}

impl<T: Clone + std::marker::Unpin + std::fmt::Debug> Future for Request<T> {
type Output = Result<T, ()>;

fn poll(
self: std::pin::Pin<&mut Self>,
_ctx: &mut std::task::Context<'_>,
) -> std::task::Poll<Self::Output> {
// TODO: Use `take()` instead of cloning, once you figure it out.
// For now we have to clone, as multiple futures may share the same
// refcell.
if let Some(result) = self.get_mut().result.borrow().clone() {
Poll::Ready(result)
} else {
Poll::Pending
}
}
}

pub(crate) mod message {
use nakamoto_common::bitcoin::consensus::Encodable;
use nakamoto_common::bitcoin::network::message::RawNetworkMessage;
Expand Down Expand Up @@ -204,6 +160,8 @@ pub(crate) mod message {
pub struct Outbox {
/// Protocol version.
version: u32,
/// Futures executor.
executor: Executor,
/// Output queue.
outbound: Rc<RefCell<VecDeque<Io>>>,
/// Message outbox.
Expand All @@ -214,21 +172,19 @@ pub struct Outbox {
builder: message::Builder,
/// Log target.
target: &'static str,

tasks: Rc<RefCell<Vec<Task>>>,
}

impl Outbox {
/// Create a new channel.
pub fn new(network: Network, version: u32, target: &'static str) -> Self {
Self {
version,
executor: Executor::new(),
outbound: Rc::new(RefCell::new(VecDeque::new())),
outbox: Rc::new(RefCell::new(HashMap::new())),
block_requests: Rc::new(RefCell::new(HashMap::new())),
builder: message::Builder::new(network),
target,
tasks: Rc::new(RefCell::new(Vec::new())),
}
}

Expand Down Expand Up @@ -294,38 +250,9 @@ impl Outbox {
pub fn block_received(&mut self, blk: Block) {
let block_hash = blk.block_hash();

if let Some(req) = self.block_requests.borrow_mut().remove(&block_hash) {
*req.result.borrow_mut() = Some(Ok(blk.clone()));
}
}

/// Spawn a future to be executed.
pub fn spawn(&mut self, future: impl Future<Output = ()> + 'static) {
self.tasks.borrow_mut().push(Task {
future: Some(Box::pin(future)),
});
}

/// Poll all tasks for completion.
pub fn poll(&mut self) -> Poll<()> {
let mut tasks = self.tasks.borrow_mut();
let waker = Arc::new(Waker).into();
let mut cx = Context::from_waker(&waker);

for task in tasks.iter_mut() {
if let Some(mut fut) = task.future.take() {
if fut.as_mut().poll(&mut cx).is_pending() {
task.future = Some(fut);
}
}
if let Some(mut req) = self.block_requests.borrow_mut().remove(&block_hash) {
req.complete(Ok(blk.clone()));
}
// Clear out all completed futures.
tasks.retain(|t| t.future.is_some());

if tasks.is_empty() {
return Poll::Ready(());
}
Poll::Pending
}
}

Expand Down Expand Up @@ -368,9 +295,8 @@ impl Blocks for Outbox {
fn get_block(&mut self, hash: BlockHash, addr: &PeerId) -> Request<Block> {
use std::collections::hash_map::Entry;

let request = Request {
result: Rc::new(RefCell::new(None)),
};
let request = Request::new();

match self.block_requests.borrow_mut().entry(hash) {
Entry::Vacant(e) => {
e.insert(request.clone());
Expand Down Expand Up @@ -638,7 +564,7 @@ pub mod test {
let remote: net::SocketAddr = ([88, 88, 88, 88], 8333).into();
let mut outbox = Outbox::new(network, crate::protocol::PROTOCOL_VERSION, "test");

outbox.spawn({
outbox.executor.spawn({
let mut outbox = outbox.clone();

async move {
Expand All @@ -648,7 +574,7 @@ pub mod test {
.unwrap();
}
});
outbox.spawn({
outbox.executor.spawn({
let mut outbox = outbox.clone();

async move {
Expand All @@ -658,14 +584,14 @@ pub mod test {
.unwrap();
}
});
assert!(outbox.poll().is_pending());
assert!(outbox.executor.poll().is_pending());

assert_matches!(
messages(&mut outbox, &remote).next(),
Some(NetworkMessage::GetData(_))
);

outbox.block_received(network.genesis_block());
assert!(outbox.poll().is_ready());
assert!(outbox.executor.poll().is_ready());
}
}

0 comments on commit 011ae45

Please sign in to comment.