Skip to content

Commit

Permalink
refactor: revert separate bi/uni channels
Browse files Browse the repository at this point in the history
  • Loading branch information
b-zee authored and bochaco committed Dec 12, 2022
1 parent 38fb7dd commit f88b5a4
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 86 deletions.
9 changes: 4 additions & 5 deletions examples/p2p_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
use bytes::Bytes;
use color_eyre::eyre::Result;
use futures::StreamExt;
use qp2p::{Config, Endpoint};
use std::{
env,
Expand Down Expand Up @@ -54,13 +53,13 @@ async fn main() -> Result<()> {
.expect("Invalid SocketAddr. Use the form 127.0.0.1:1234");
let msg = Bytes::from(MSG_MARCO);
println!("Sending to {:?} --> {:?}\n", peer, msg);
let conn = node.connect_to(&peer).await?;
let (conn, mut incoming) = node.connect_to(&peer).await?;
conn.send((Bytes::new(), Bytes::new(), msg.clone())).await?;
// `Endpoint` no longer having `connection_pool` to hold established connection.
// Which means the connection get closed immediately when it reaches end of life span.
// And causes the receiver side a sending error when reply via the in-coming connection.
// Hence here have to listen for the reply to avoid such error
let reply = conn.accept_uni().next().await.unwrap();
let reply = incoming.next().await.unwrap();
println!("Received from {:?} --> {:?}", peer, reply);
}

Expand All @@ -72,11 +71,11 @@ async fn main() -> Result<()> {
println!("---\n");

// loop over incoming connections
while let Some(connection) = incoming_conns.next().await {
while let Some((connection, mut incoming)) = incoming_conns.next().await {
let src = connection.remote_address();

// loop over incoming messages
while let Some(Ok((_, _, bytes))) = connection.accept_uni().next().await {
while let Ok(Some((_, _, bytes))) = incoming.next().await {
println!("Received from {:?} --> {:?}", src, bytes);
if bytes == *MSG_MARCO {
let reply = Bytes::from(MSG_POLO);
Expand Down
119 changes: 56 additions & 63 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,6 @@ use crate::{
error::{ConnectionError, RecvError, RpcError, SendError, StreamError},
wire_msg::{UsrMsgBytes, WireMsg},
};
use bytes::Bytes;
use futures::lock::Mutex;
use quinn::VarInt;
use std::{fmt, net::SocketAddr, time::Duration};
use tokio::{
sync::mpsc::{Receiver, Sender},
Expand All @@ -24,48 +21,24 @@ const ENDPOINT_VERIFICATION_TIMEOUT: Duration = Duration::from_secs(30);
// Error reason for closing a connection when triggered manually by qp2p apis
const QP2P_CLOSED_CONNECTION: &str = "The connection was closed intentionally by qp2p.";

type ResultUni = Result<(Bytes, Bytes, Bytes), RecvError>;
type ResultBi = Result<((Bytes, Bytes, Bytes), SendStream), RecvError>;
type IncomingMsg = Result<(UsrMsgBytes, Option<SendStream>), RecvError>;

/// The sending API for a connection.
#[derive(Clone)]
pub struct Connection {
inner: quinn::Connection,

// Wrapped in mutex to allow receiving messages on separate threads with shared references to this connection.
rx_uni: Mutex<Receiver<ResultUni>>,
rx_bi: Mutex<Receiver<ResultBi>>,
}

impl Drop for Connection {
fn drop(&mut self) {
self.inner.close(VarInt::from_u32(0), b"lost interest");
}
}

impl Connection {
pub(crate) fn new(connection: quinn::Connection, endpoint: quinn::Endpoint) -> Connection {
let (tx_uni, rx_uni) = tokio::sync::mpsc::channel(INCOMING_MESSAGE_BUFFER_LEN);
let (tx_bi, rx_bi) = tokio::sync::mpsc::channel(INCOMING_MESSAGE_BUFFER_LEN);
listen_on_uni_streams(connection.clone(), tx_uni);
listen_on_bi_streams(connection.clone(), endpoint, tx_bi);

Self {
inner: connection,
rx_uni: Mutex::new(rx_uni),
rx_bi: Mutex::new(rx_bi),
}
}

///
pub async fn recv_bi(&self) -> Option<ResultBi> {
let mut rx = self.rx_bi.lock().await;
rx.recv().await
}
pub(crate) fn new(
connection: quinn::Connection,
endpoint: quinn::Endpoint,
) -> (Connection, ConnectionIncoming) {
let (tx, rx) = tokio::sync::mpsc::channel(INCOMING_MESSAGE_BUFFER_LEN);
listen_on_uni_streams(connection.clone(), tx.clone());
listen_on_bi_streams(connection.clone(), endpoint, tx);

///
pub async fn recv_uni(&self) -> Option<ResultUni> {
let mut rx = self.rx_uni.lock().await;
rx.recv().await
(Self { inner: connection }, ConnectionIncoming(rx))
}

/// Returns `Some(...)` if the connection is closed.
Expand Down Expand Up @@ -163,10 +136,7 @@ impl Connection {
}
}

fn listen_on_uni_streams(
connection: quinn::Connection,
tx: Sender<Result<(Bytes, Bytes, Bytes), RecvError>>,
) {
fn listen_on_uni_streams(connection: quinn::Connection, tx: Sender<IncomingMsg>) {
let id = connection.stable_id();

let _ = tokio::spawn(async move {
Expand Down Expand Up @@ -203,7 +173,7 @@ fn listen_on_uni_streams(
};

// Send away the msg or error
let _ = tx.send(msg).await;
let _ = tx.send(msg.map(|r| (r, None))).await;
});
}

Expand All @@ -215,7 +185,7 @@ fn listen_on_uni_streams(
fn listen_on_bi_streams(
connection: quinn::Connection,
endpoint: quinn::Endpoint,
tx: Sender<Result<((Bytes, Bytes, Bytes), SendStream), RecvError>>,
tx: Sender<IncomingMsg>,
) {
let id = connection.stable_id();
let conn_id = format!("{}{}", connection.remote_address(), connection.stable_id());
Expand Down Expand Up @@ -272,7 +242,7 @@ fn listen_on_bi_streams(
};

// Pass the stream, so it can be used to respond to the user message.
let msg = msg.map(|msg| (msg, SendStream::new(send, conn_id)));
let msg = msg.map(|msg| (msg, Some(SendStream::new(send, conn_id))));
// Send away the msg or error
let _ = tx.send(msg).await;
});
Expand Down Expand Up @@ -316,6 +286,27 @@ impl fmt::Display for StreamId {
}
}

///
#[derive(Debug)]
pub struct ConnectionIncoming(Receiver<IncomingMsg>);
impl ConnectionIncoming {
/// Get the next message sent by the peer, over any stream.
pub async fn next(&mut self) -> Result<Option<UsrMsgBytes>, RecvError> {
if let Some((bytes, _opt)) = self.next_with_stream().await? {
Ok(Some(bytes))
} else {
Ok(None)
}
}

/// Get the next message sent by the peer, over any stream along with the stream to respond with.
pub async fn next_with_stream(
&mut self,
) -> Result<Option<(UsrMsgBytes, Option<SendStream>)>, RecvError> {
self.0.recv().await.transpose()
}
}

/// The sending API for a QUIC stream.
pub struct SendStream {
conn_id: String,
Expand Down Expand Up @@ -501,39 +492,40 @@ mod tests {
let peer2 = quinn::Endpoint::server(config.server.clone(), local_addr())?;

{
let p1_tx = Connection::new(
let (p1_conn, mut p1_incoming) = Connection::new(
peer1.connect(peer2.local_addr()?, SERVER_NAME)?.await?,
peer1.clone(),
);

let p2_tx = if let Some(connection) = timeout(OptionFuture::from(peer2.accept().await))
.await?
.and_then(|c| c.ok())
let (p2_conn, mut p2_incoming) = if let Some(connection) =
timeout(OptionFuture::from(peer2.accept().await))
.await?
.and_then(|c| c.ok())
{
Connection::new(connection, peer2.clone())
} else {
bail!("did not receive incoming connection when one was expected");
};

p1_tx
p1_conn
.open_uni()
.await?
.send_user_msg((Bytes::new(), Bytes::new(), Bytes::from_static(b"hello")))
.await?;

if let Some(Ok((_, _, msg))) = timeout(p2_tx.recv_uni()).await? {
if let Ok(Some((_, _, msg))) = timeout(p2_incoming.next()).await? {
assert_eq!(&msg[..], b"hello");
} else {
bail!("did not receive message when one was expected");
}

p2_tx
p2_conn
.open_uni()
.await?
.send_user_msg((Bytes::new(), Bytes::new(), Bytes::from_static(b"world")))
.await?;

if let Some(Ok((_, _, msg))) = timeout(p1_tx.recv_uni()).await? {
if let Ok(Some((_, _, msg))) = timeout(p1_incoming.next()).await? {
assert_eq!(&msg[..], b"world");
} else {
bail!("did not receive message when one was expected");
Expand Down Expand Up @@ -561,14 +553,15 @@ mod tests {
let peer2 = quinn::Endpoint::server(config.server.clone(), local_addr())?;

// open a connection between the two peers
let p1_conn = Connection::new(
let (p1_conn, _) = Connection::new(
peer1.connect(peer2.local_addr()?, SERVER_NAME)?.await?,
peer1.clone(),
);

let p2_conn = if let Some(connection) = timeout(OptionFuture::from(peer2.accept().await))
.await?
.and_then(|c| c.ok())
let (_p2_conn, mut p2_incoming) = if let Some(connection) =
timeout(OptionFuture::from(peer2.accept().await))
.await?
.and_then(|c| c.ok())
{
Connection::new(connection, peer2.clone())
} else {
Expand All @@ -588,8 +581,8 @@ mod tests {
}

// trying to receive should NOT return an error
match p2_conn.recv_uni().await {
None => {}
match p2_incoming.next().await {
Err(_) => {}
res => bail!("unexpected recv result: {:?}", res),
}

Expand All @@ -607,7 +600,7 @@ mod tests {
let peer2 = quinn::Endpoint::server(config.server.clone(), local_addr())?;

{
let p1_tx = Connection::new(
let (p1_conn, _) = Connection::new(
peer1.connect(peer2.local_addr()?, SERVER_NAME)?.await?,
peer1.clone(),
);
Expand All @@ -623,7 +616,7 @@ mod tests {
bail!("did not receive incoming connection when one was expected");
};

let (mut send_stream, mut recv_stream) = p1_tx.open_bi().await?;
let (mut send_stream, mut recv_stream) = p1_conn.open_bi().await?;
send_stream.send_wire_msg(WireMsg::EndpointEchoReq).await?;

if let Some(msg) = timeout(recv_stream.next_wire_msg()).await?? {
Expand Down Expand Up @@ -659,7 +652,7 @@ mod tests {
peer2.set_default_client_config(config.client);

{
let p1_tx = Connection::new(
let (p1_conn, _) = Connection::new(
peer1.connect(peer2.local_addr()?, SERVER_NAME)?.await?,
peer1.clone(),
);
Expand All @@ -675,7 +668,7 @@ mod tests {
bail!("did not receive incoming connection when one was expected");
};

let (mut send_stream, mut recv_stream) = p1_tx.open_bi().await?;
let (mut send_stream, mut recv_stream) = p1_conn.open_bi().await?;
send_stream
.send_wire_msg(WireMsg::EndpointVerificationReq(peer1.local_addr()?))
.await?;
Expand Down Expand Up @@ -704,7 +697,7 @@ mod tests {
}

// only one msg per bi-stream is supported, let's create a new bi-stream for this test
let (mut send_stream, mut recv_stream) = p1_tx.open_bi().await?;
let (mut send_stream, mut recv_stream) = p1_conn.open_bi().await?;
send_stream
.send_wire_msg(WireMsg::EndpointVerificationReq(local_addr()))
.await?;
Expand Down
Loading

0 comments on commit f88b5a4

Please sign in to comment.