Skip to content

Commit

Permalink
refactor: single WireMsg per stream
Browse files Browse the repository at this point in the history
  • Loading branch information
b-zee authored and joshuef committed Jan 10, 2023
1 parent e33ded1 commit fc6f565
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 69 deletions.
66 changes: 26 additions & 40 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,9 +172,7 @@ fn listen_on_uni_streams(connection: quinn::Connection, tx: Sender<IncomingMsg>)
let _ = tokio::spawn(async move {
let msg = WireMsg::read_from_stream(&mut recv).await;
let msg = match msg {
// Stream was finished or connection closed.
Ok(None) => return,
Ok(Some(msg)) => match msg {
Ok(msg) => match msg {
WireMsg::UserMsg(msg) => Ok(msg),
_ => Err(RecvError::UnexpectedMsgReceived(msg.to_string())),
},
Expand Down Expand Up @@ -227,17 +225,15 @@ fn listen_on_bi_streams(
let _ = tokio::spawn(async move {
let msg = WireMsg::read_from_stream(&mut recv).await;
let msg = match msg {
// Stream was finished or connection closed.
Ok(None) => return,
Ok(Some(WireMsg::UserMsg(msg))) => Ok(msg),
Ok(Some(WireMsg::EndpointEchoReq)) => {
Ok(WireMsg::UserMsg(msg)) => Ok(msg),
Ok(WireMsg::EndpointEchoReq) => {
if let Err(error) = handle_endpoint_echo(send, addr).await {
// TODO: consider more carefully how to handle this
warn!("Error handling endpoint echo request on conn_id {conn_id}: {error}");
}
return;
}
Ok(Some(WireMsg::EndpointVerificationReq(addr))) => {
Ok(WireMsg::EndpointVerificationReq(addr)) => {
if let Err(error) =
handle_endpoint_verification(&endpoint, send, addr).await
{
Expand All @@ -247,7 +243,7 @@ fn listen_on_bi_streams(
return;
}
// We do not expect other types.
Ok(Some(msg)) => Err(RecvError::UnexpectedMsgReceived(msg.to_string())),
Ok(msg) => Err(RecvError::UnexpectedMsgReceived(msg.to_string())),
Err(err) => Err(err),
};

Expand Down Expand Up @@ -399,15 +395,14 @@ impl RecvStream {
}

/// Get the next message sent by the peer over this stream.
pub async fn next(&mut self) -> Result<Option<UsrMsgBytes>, RecvError> {
pub async fn next(&mut self) -> Result<UsrMsgBytes, RecvError> {
match self.next_wire_msg().await? {
Some(WireMsg::UserMsg(msg)) => Ok(Some(msg)),
Some(msg) => Err(RecvError::UnexpectedMsgReceived(msg.to_string())),
None => Ok(None),
WireMsg::UserMsg(msg) => Ok(msg),
msg => Err(RecvError::UnexpectedMsgReceived(msg.to_string())),
}
}

pub(crate) async fn next_wire_msg(&mut self) -> Result<Option<WireMsg>, RecvError> {
pub(crate) async fn next_wire_msg(&mut self) -> Result<WireMsg, RecvError> {
WireMsg::read_from_stream(&mut self.inner).await
}
}
Expand Down Expand Up @@ -452,13 +447,13 @@ async fn handle_endpoint_verification(
.await?;

match WireMsg::read_from_stream(&mut recv_stream).await? {
Some(WireMsg::EndpointEchoResp(_)) => {
WireMsg::EndpointEchoResp(_) => {
trace!("EndpointVerificationReq: Received EndpointEchoResp from {addr}");
Ok(())
}
msg => Err(RpcError::EchoResponseMissing {
peer: addr,
response: msg.map(|m| m.to_string()),
response: Some(msg.to_string()),
}),
}
};
Expand Down Expand Up @@ -630,17 +625,14 @@ mod tests {
let (mut send_stream, mut recv_stream) = p1_conn.open_bi().await?;
send_stream.send_wire_msg(WireMsg::EndpointEchoReq).await?;

if let Some(msg) = timeout(recv_stream.next_wire_msg()).await?? {
if let WireMsg::EndpointEchoResp(addr) = msg {
assert_eq!(addr, peer1.local_addr()?);
} else {
bail!(
"received unexpected message when EndpointEchoResp was expected: {:?}",
msg
);
}
let msg = timeout(recv_stream.next_wire_msg()).await??;
if let WireMsg::EndpointEchoResp(addr) = msg {
assert_eq!(addr, peer1.local_addr()?);
} else {
bail!("did not receive incoming message when one was expected");
bail!(
"received unexpected message when EndpointEchoResp was expected: {:?}",
msg
);
}
}

Expand Down Expand Up @@ -695,16 +687,13 @@ mod tests {
bail!("did not receive incoming connection when one was expected");
};

if let Some(msg) = timeout(recv_stream.next_wire_msg()).await?? {
if let WireMsg::EndpointVerificationResp(true) = msg {
} else {
bail!(
let msg = timeout(recv_stream.next_wire_msg()).await??;
if let WireMsg::EndpointVerificationResp(true) = msg {
} else {
bail!(
"received unexpected message when EndpointVerificationResp(true) was expected: {:?}",
msg
);
}
} else {
bail!("did not receive incoming message when one was expected");
}

// only one msg per bi-stream is supported, let's create a new bi-stream for this test
Expand All @@ -713,16 +702,13 @@ mod tests {
.send_wire_msg(WireMsg::EndpointVerificationReq(local_addr()))
.await?;

if let Some(msg) = timeout(recv_stream.next_wire_msg()).await?? {
if let WireMsg::EndpointVerificationResp(false) = msg {
} else {
bail!(
let msg = timeout(recv_stream.next_wire_msg()).await??;
if let WireMsg::EndpointVerificationResp(false) = msg {
} else {
bail!(
"received unexpected message when EndpointVerificationResp(false) was expected: {:?}",
msg
);
}
} else {
bail!("did not receive incoming message when one was expected");
}
}

Expand Down
19 changes: 6 additions & 13 deletions src/endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,21 +230,14 @@ impl Endpoint {
send_stream.send_wire_msg(WireMsg::EndpointEchoReq).await?;

match timeout(ECHO_SERVICE_QUERY_TIMEOUT, recv_stream.next_wire_msg()).await?? {
Some(WireMsg::EndpointEchoResp(_)) => Ok(()),
Some(other) => {
WireMsg::EndpointEchoResp(_) => Ok(()),
other => {
info!(
"Unexpected message type when verifying reachability: {}",
&other
);
Ok(())
}
None => {
info!(
"Peer {} did not reply when verifying reachability",
peer_addr
);
Ok(())
}
}
}

Expand Down Expand Up @@ -372,10 +365,10 @@ impl Endpoint {
send.send_wire_msg(WireMsg::EndpointEchoReq).await?;

match timeout(ECHO_SERVICE_QUERY_TIMEOUT, recv.next_wire_msg()).await?? {
Some(WireMsg::EndpointEchoResp(addr)) => Ok(addr),
WireMsg::EndpointEchoResp(addr) => Ok(addr),
msg => Err(RpcError::EchoResponseMissing {
peer: contact.remote_address(),
response: msg.map(|m| m.to_string()),
response: Some(msg.to_string()),
}),
}
}
Expand All @@ -392,10 +385,10 @@ impl Endpoint {
.await?;

match timeout(ECHO_SERVICE_QUERY_TIMEOUT, recv.next_wire_msg()).await?? {
Some(WireMsg::EndpointVerificationResp(valid)) => Ok(valid),
WireMsg::EndpointVerificationResp(valid) => Ok(valid),
msg => Err(RpcError::EndpointVerificationRespMissing {
peer: contact.remote_address(),
response: msg.map(|m| m.to_string()),
response: Some(msg.to_string()),
}),
}
}
Expand Down
27 changes: 11 additions & 16 deletions src/wire_msg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use crate::{
utils,
};
use bytes::{Bytes, BytesMut};
use futures::TryFutureExt;
use quinn::VarInt;
use serde::{Deserialize, Serialize};
use std::convert::TryFrom;
use std::{fmt, net::SocketAddr};
Expand Down Expand Up @@ -41,19 +41,11 @@ const ECHO_SRVC_MSG_FLAG: u8 = 0x01;

impl WireMsg {
// Read a message's bytes from the provided stream
pub(crate) async fn read_from_stream(
recv: &mut quinn::RecvStream,
) -> Result<Option<Self>, RecvError> {
/// # Cancellation safety
/// Warning: This method is not cancellation safe!
pub(crate) async fn read_from_stream(recv: &mut quinn::RecvStream) -> Result<Self, RecvError> {
let mut header_bytes = [0; MSG_HEADER_LEN];
match recv.read(&mut header_bytes).err_into().await {
Err(error) => return Err(error),
Ok(None) => return Ok(None),
Ok(Some(len)) => {
if len < MSG_HEADER_LEN {
recv.read_exact(&mut header_bytes[len..]).await?;
}
}
}
recv.read_exact(&mut header_bytes).await?;

let msg_header = MsgHeader::from_bytes(header_bytes);
// https://github.com/rust-lang/rust/issues/70460 for work on a cleaner alternative:
Expand Down Expand Up @@ -81,16 +73,19 @@ impl WireMsg {
recv.read_exact(&mut dst_data).await?;
recv.read_exact(&mut payload_data).await?;

// let sender know we won't receive any more.
let _ = recv.stop(VarInt::from_u32(0));

if payload_data.is_empty() {
Err(RecvError::EmptyMsgPayload)
} else if msg_flag == USER_MSG_FLAG {
Ok(Some(WireMsg::UserMsg((
Ok(WireMsg::UserMsg((
header_data.freeze(),
dst_data.freeze(),
payload_data.freeze(),
))))
)))
} else if msg_flag == ECHO_SRVC_MSG_FLAG {
Ok(Some(bincode::deserialize(&payload_data)?))
Ok(bincode::deserialize(&payload_data)?)
} else {
Err(RecvError::InvalidMsgTypeFlag(msg_flag))
}
Expand Down

0 comments on commit fc6f565

Please sign in to comment.