Skip to content

tokio-postgres: use tokio mpsc for CopyBoth streams #30

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions tokio-postgres/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ pub struct Responses {
}

pub struct CopyBothHandles {
pub(crate) stream_receiver: mpsc::Receiver<Result<Message, Error>>,
pub(crate) sink_sender: mpsc::Sender<FrontendMessage>,
pub(crate) stream_receiver: tokio::sync::mpsc::Receiver<Result<Message, Error>>,
pub(crate) sink_sender: tokio::sync::mpsc::Sender<FrontendMessage>,
}

impl Responses {
Expand Down Expand Up @@ -124,8 +124,8 @@ impl InnerClient {

pub fn start_copy_both(&self) -> Result<CopyBothHandles, Error> {
let (sender, receiver) = mpsc::channel(16);
let (stream_sender, stream_receiver) = mpsc::channel(16);
let (sink_sender, sink_receiver) = mpsc::channel(16);
let (stream_sender, stream_receiver) = tokio::sync::mpsc::channel(16);
let (sink_sender, sink_receiver) = tokio::sync::mpsc::channel(16);

let responses = Responses {
receiver,
Expand Down
45 changes: 21 additions & 24 deletions tokio-postgres/src/copy_both.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use crate::client::{InnerClient, Responses};
use crate::codec::FrontendMessage;
use crate::{simple_query, Error};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use futures_channel::mpsc;
use futures_util::{ready, Sink, SinkExt, Stream, StreamExt};
use log::debug;
use pin_project_lite::pin_project;
Expand All @@ -12,6 +11,8 @@ use postgres_protocol::message::frontend::CopyData;
use std::marker::{PhantomData, PhantomPinned};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::sync::mpsc;
use tokio_util::sync::PollSender;

/// The state machine of CopyBothReceiver
///
Expand Down Expand Up @@ -70,7 +71,7 @@ pub struct CopyBothReceiver {
/// Receiver of frontend messages sent by the user using <CopyBothDuplex as Sink>
sink_receiver: mpsc::Receiver<FrontendMessage>,
/// Sender of CopyData contents to be consumed by the user using <CopyBothDuplex as Stream>
stream_sender: mpsc::Sender<Result<Message, Error>>,
stream_sender: PollSender<Result<Message, Error>>,
/// The current state of the subprotocol
state: CopyBothState,
/// Holds a buffered message until we are ready to send it to the user's stream
Expand All @@ -86,7 +87,7 @@ impl CopyBothReceiver {
CopyBothReceiver {
responses,
sink_receiver,
stream_sender,
stream_sender: PollSender::new(stream_sender),
state: CopyBothState::Setup,
buffered_message: None,
}
Expand All @@ -108,10 +109,10 @@ impl CopyBothReceiver {
// Deliver the buffered message (if any) to the user to ensure we can potentially
// buffer a new one in response to a server message
if let Some(message) = self.buffered_message.take() {
match self.stream_sender.poll_ready(cx) {
match self.stream_sender.poll_ready_unpin(cx) {
Poll::Ready(_) => {
// If the receiver has hung up we'll just drop the message
let _ = self.stream_sender.start_send(message);
let _ = self.stream_sender.start_send_unpin(message);
}
Poll::Pending => {
// Stash the message and try again later
Expand Down Expand Up @@ -147,7 +148,7 @@ impl CopyBothReceiver {
match self.state {
CopyNone => self.state = CopyComplete,
CopyComplete => {
self.stream_sender.close_channel();
self.stream_sender.close();
self.sink_receiver.close();
self.state = CommandComplete;
}
Expand All @@ -168,7 +169,7 @@ impl CopyBothReceiver {
Some(Ok(Message::ReadyForQuery(_))) => match self.state {
CommandComplete => {
self.sink_receiver.close();
self.stream_sender.close_channel();
self.stream_sender.close();
}
_ => self.unexpected_message(),
},
Expand All @@ -190,7 +191,7 @@ impl Stream for CopyBothReceiver {
match self.poll_backend(cx) {
Poll::Ready(()) => Poll::Ready(None),
Poll::Pending => match self.state {
Setup | CopyBoth | CopyIn => match ready!(self.sink_receiver.poll_next_unpin(cx)) {
Setup | CopyBoth | CopyIn => match ready!(self.sink_receiver.poll_recv(cx)) {
Some(msg) => Poll::Ready(Some(msg)),
None => match self.state {
// The user has cancelled their interest to this CopyBoth query but we're
Expand Down Expand Up @@ -252,9 +253,7 @@ pin_project! {
/// }
/// ```
pub struct CopyBothDuplex<T> {
#[pin]
sink_sender: mpsc::Sender<FrontendMessage>,
#[pin]
sink_sender: PollSender<FrontendMessage>,
stream_receiver: mpsc::Receiver<Result<Message, Error>>,
buf: BytesMut,
#[pin]
Expand All @@ -267,7 +266,7 @@ impl<T> Stream for CopyBothDuplex<T> {
type Item = Result<Bytes, Error>;

fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Poll::Ready(match ready!(self.project().stream_receiver.poll_next(cx)) {
Poll::Ready(match ready!(self.project().stream_receiver.poll_recv(cx)) {
Some(Ok(Message::CopyData(body))) => Some(Ok(body.into_bytes())),
Some(Ok(_)) => Some(Err(Error::unexpected_message())),
Some(Err(err)) => Some(Err(err)),
Expand All @@ -285,7 +284,7 @@ where
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
self.project()
.sink_sender
.poll_ready(cx)
.poll_ready_unpin(cx)
.map_err(|_| Error::closed())
}

Expand All @@ -309,30 +308,28 @@ where

let data = CopyData::new(data).map_err(Error::encode)?;
this.sink_sender
.start_send(FrontendMessage::CopyData(data))
.start_send_unpin(FrontendMessage::CopyData(data))
.map_err(|_| Error::closed())
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
let mut this = self.project();
let this = self.project();

if !this.buf.is_empty() {
ready!(this.sink_sender.as_mut().poll_ready(cx)).map_err(|_| Error::closed())?;
ready!(this.sink_sender.poll_ready_unpin(cx)).map_err(|_| Error::closed())?;
let data: Box<dyn Buf + Send> = Box::new(this.buf.split().freeze());
let data = CopyData::new(data).map_err(Error::encode)?;
this.sink_sender
.as_mut()
.start_send(FrontendMessage::CopyData(data))
.start_send_unpin(FrontendMessage::CopyData(data))
.map_err(|_| Error::closed())?;
}

this.sink_sender.poll_flush(cx).map_err(|_| Error::closed())
Poll::Ready(Ok(()))
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Error>> {
ready!(self.as_mut().poll_flush(cx))?;
let mut this = self.as_mut().project();
this.sink_sender.disconnect();
let this = self.as_mut().project();
this.sink_sender.close();
Poll::Ready(Ok(()))
}
}
Expand All @@ -356,14 +353,14 @@ where
.await
.map_err(|_| Error::closed())?;

match handles.stream_receiver.next().await.transpose()? {
match handles.stream_receiver.recv().await.transpose()? {
Some(Message::CopyBothResponse(_)) => {}
_ => return Err(Error::unexpected_message()),
}

Ok(CopyBothDuplex {
stream_receiver: handles.stream_receiver,
sink_sender: handles.sink_sender,
sink_sender: PollSender::new(handles.sink_sender),
buf: BytesMut::new(),
_p: PhantomPinned,
_p2: PhantomData,
Expand Down
Loading