Skip to content

Commit

Permalink
Support headers
Browse files Browse the repository at this point in the history
  • Loading branch information
Benjamin Sparks committed Jul 5, 2024
1 parent dd32a61 commit 7192995
Show file tree
Hide file tree
Showing 6 changed files with 341 additions and 100 deletions.
2 changes: 1 addition & 1 deletion nats_client/src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ impl ConnectionHandle {
echo: None,
jwt: None,
no_responders: None,
headers: None,
headers: Some(true),
}))
.await
.expect("Failed to send `CONNECT`");
Expand Down
88 changes: 61 additions & 27 deletions nats_codec/src/decoding/header.rs
Original file line number Diff line number Diff line change
@@ -1,40 +1,74 @@
/*
use winnow::{combinator::seq, token::{literal, take_until}, Parser};
use std::collections::HashMap;

use super::{char_spliterator, slice_spliterator, CommandDecoderResult, ServerError};
use memchr::memchr_iter;

pub struct HeadersDecoder { pub headers_length: usize }
use crate::{HeaderName, HeaderValue};

impl super::CommandDecoder<crate::HeaderMap> for HeadersDecoder {
const PREFIX: &'static [u8] = b"NATS/1.0\r\n";
use super::slice_spliterator;

fn decode_body(&self, buffer: &[u8]) -> CommandDecoderResult<crate::HeaderMap> {
if self.headers_length > buffer.len() {
return CommandDecoderResult::FrameTooShort;
}
if self.headers_length < buffer.len() || !buffer.ends_with(&crate::CRLF) {
return CommandDecoderResult::FatalError(ServerError::BadHeaders)
}
// Guaranteed to consume entirety of `buffer` due to previous checks
let mut spliterator = slice_spliterator(buffer, &crate::CRLF);
pub enum HeaderDecodeError {
BadLength,
MissingNatsVersion,
BadHeaderName,
BadHeaderValue,
NoColon,
}

let Some((b"NATS/1.0", _)) = spliterator.next() else {
return CommandDecoderResult::FatalError(ServerError::BadHeaders)
pub fn parse_headers(
header_buffer: &[u8],
headers_length: usize,
) -> Result<crate::HeaderMap, HeaderDecodeError> {
if headers_length < 2 {
return Err(HeaderDecodeError::BadLength);
}

let mut spliterator = slice_spliterator(header_buffer, &crate::CRLF);
let Some((b"NATS/1.0", _)) = spliterator.next() else {
return Err(HeaderDecodeError::MissingNatsVersion);
};

let mut headers = HashMap::new();
loop {
let Some((slice, offset)) = spliterator.next() else {
return Err(HeaderDecodeError::BadLength);
};

let mut map: HashMap<_, Vec<_>> = HashMap::new();
for (mut slice, _) in spliterator {
let Ok(crate::Header { name, value }) = parse_header(&mut slice) else {
return CommandDecoderResult::FatalError(ServerError::BadHeaders)
};
map.entry(name).or_default().push(value);
if slice.is_empty() && offset == headers_length {
break;
} else if !slice.is_empty() && offset != headers_length {
let (name, value) = parse_header(slice)?;
let vs: &mut Vec<_> = headers.entry(name).or_default();
vs.push(value);
} else {
return Err(HeaderDecodeError::BadLength);
}
let headers = crate::HeaderMap(map);
CommandDecoderResult::Advance((headers, buffer.len()))
}

Ok(crate::HeaderMap(headers))
}

fn parse_header(slice: &[u8]) -> Result<(HeaderName, HeaderValue), HeaderDecodeError> {
let mut colon_iter = memchr_iter(b':', slice);
let (Some(colon_index), None) = (colon_iter.next(), colon_iter.next()) else {
return Err(HeaderDecodeError::NoColon);
};

let name = &slice[..colon_index];

let value = &slice[colon_index + 1..];
let value = value.strip_prefix(&b" "[..]).unwrap_or(value);
let value = value.strip_suffix(&b" "[..]).unwrap_or(value);

let Ok(name) = std::str::from_utf8(name) else {
return Err(HeaderDecodeError::BadHeaderName);
};

let Ok(value) = std::str::from_utf8(value) else {
return Err(HeaderDecodeError::BadHeaderValue);
};

Ok((HeaderName(name.into()), HeaderValue(value.into())))
}
*/

// Dummy impl; definitely not safe!
/*
Expand Down
81 changes: 24 additions & 57 deletions nats_codec/src/decoding/hmsg.rs
Original file line number Diff line number Diff line change
@@ -1,19 +1,12 @@
use std::collections::HashMap;

use tokio_util::bytes::Bytes;

use super::{char_spliterator, slice_spliterator, CommandDecoderResult, ServerDecodeError};
use super::{
char_spliterator, header::parse_headers, slice_spliterator, CommandDecoderResult,
ServerDecodeError,
};

pub struct Decoder;

struct Metadata<'a> {
subject: &'a [u8],
sid: &'a [u8],
reply_to: Option<&'a [u8]>,
header_bytes: usize,
total_bytes: usize,
}

impl super::CommandDecoder<crate::ServerCommand, ServerDecodeError> for Decoder {
fn decode_body(
&self,
Expand All @@ -22,7 +15,7 @@ impl super::CommandDecoder<crate::ServerCommand, ServerDecodeError> for Decoder
let mut crlf_iter = slice_spliterator(buffer, &crate::CRLF);

// Load fixed sized blocks, namely metadata and NATS version field; buffer if too short
let Some((metadata, _)) = crlf_iter.next() else {
let Some((metadata, metadata_len)) = crlf_iter.next() else {
return CommandDecoderResult::FrameTooShort(None);
};

Expand Down Expand Up @@ -68,62 +61,34 @@ impl super::CommandDecoder<crate::ServerCommand, ServerDecodeError> for Decoder
return CommandDecoderResult::FatalError(ServerDecodeError::BadHMsg);
};

if total_bytes <= header_bytes {
if total_bytes < header_bytes {
return CommandDecoderResult::FatalError(ServerDecodeError::BadHMsg);
}
if total_bytes > buffer.len() {
return CommandDecoderResult::FrameTooShort(None);
}

// Metadata parsing complete!
let metadata = Metadata {
let headers = &buffer[metadata_len..metadata_len + header_bytes];
let payload = &buffer[metadata_len + header_bytes..metadata_len + total_bytes];

let parts = HMsgParts {
subject,
sid,
reply_to,
header_bytes,
total_bytes,
};

// TODO: Add length checks
/*let headers = if header_bytes > 0 {
let parsing = crlf_iter
.take_while(|(_, length)| *length <= header_bytes)
.map(|(slice, ending)| {
if ending != header_bytes {
return Err(());
}
// TODO parse headers here
return Ok((slice.into(), ending));
});
let Ok(headers) = parsing.collect::<Result<Vec<_>, _>>() else {
return CommandDecoderResult::FatalError(Error::BadHMsg);
};
Some(headers)
} else {
None
};*/
// crlf_iter.next();

let Some((payload, _)) = crlf_iter.next() else {
return CommandDecoderResult::FatalError(ServerDecodeError::BadHMsg);
};

let parts = HMsgParts {
subject: metadata.subject,
sid: metadata.sid,
reply_to: metadata.reply_to,
header_bytes: metadata.header_bytes,
total_bytes: metadata.total_bytes,
headers: &b""[..],
headers,
payload,
};
let hmsg = match parts.try_into() {
Ok(hmsg) => hmsg,
Err(e) => return CommandDecoderResult::FatalError(e),
};

CommandDecoderResult::Advance((crate::ServerCommand::HMsg(hmsg), 0))
CommandDecoderResult::Advance((
crate::ServerCommand::HMsg(hmsg),
metadata_len + total_bytes + 2,
))
}
}

Expand All @@ -142,32 +107,34 @@ impl std::convert::TryFrom<HMsgParts<'_>> for crate::HMsg {

fn try_from(value: HMsgParts<'_>) -> Result<Self, Self::Error> {
let Ok(subject) = std::str::from_utf8(value.subject) else {
return Err(Self::Error::BadMsg);
return Err(Self::Error::BadHMsg);
};

let Ok(sid) = std::str::from_utf8(value.sid) else {
return Err(Self::Error::BadMsg);
return Err(Self::Error::BadHMsg);
};

let Ok(reply_to) = value.reply_to.map(std::str::from_utf8).transpose() else {
return Err(Self::Error::BadMsg);
return Err(Self::Error::BadHMsg);
};

if value.total_bytes - value.header_bytes == 0
if value.total_bytes < value.header_bytes
|| value.payload.len() != value.total_bytes - value.header_bytes
{
return Err(Self::Error::BadMsg);
return Err(Self::Error::BadHMsg);
}

let _ = value.headers;
let Ok(headers) = parse_headers(value.headers, value.header_bytes) else {
return Err(Self::Error::BadHMsg);
};

Ok(crate::HMsg {
subject: subject.into(),
sid: sid.into(),
reply_to: reply_to.map(Into::into),
header_bytes: value.header_bytes,
total_bytes: value.total_bytes,
headers: crate::HeaderMap(HashMap::new()),
headers,
payload: Bytes::copy_from_slice(value.payload),
})
}
Expand Down
122 changes: 119 additions & 3 deletions nats_codec/src/decoding/hpub.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,128 @@
use super::{ClientDecodeError, CommandDecoderResult};
use super::{
char_spliterator, header::parse_headers, slice_spliterator, ClientDecodeError,
CommandDecoderResult,
};
use tokio_util::bytes::Bytes;

pub struct Decoder;

impl super::CommandDecoder<crate::ClientCommand, ClientDecodeError> for Decoder {
fn decode_body(
&self,
_buffer: &[u8],
buffer: &[u8],
) -> CommandDecoderResult<crate::ClientCommand, ClientDecodeError> {
CommandDecoderResult::WrongDecoder
let mut crlf_iter = slice_spliterator(buffer, &crate::CRLF);

// Load fixed sized blocks, namely metadata and NATS version field; buffer if too short
let Some((metadata, metadata_len)) = crlf_iter.next() else {
return CommandDecoderResult::FrameTooShort(None);
};

// Split metadata by space and gather 4 to 5 fields
let mut meta_iter = char_spliterator(metadata, b' ');
let (subject, reply_to, header_bytes, total_bytes) = match (
meta_iter.next(),
meta_iter.next(),
meta_iter.next(),
meta_iter.next(),
) {
(Some((subject, _)), Some((reply_to, _)), Some((header_bytes, last)), None) => {
(subject, Some(reply_to), header_bytes, &metadata[last..])
}
(Some((subject, _)), Some((header_bytes, last)), None, None) => {
(subject, None, header_bytes, &metadata[last..])
}
_ => {
return CommandDecoderResult::FrameTooShort(None);
}
};

let (Ok(headers), Ok(totals)) = (
std::str::from_utf8(header_bytes),
std::str::from_utf8(total_bytes),
) else {
return CommandDecoderResult::FatalError(ClientDecodeError::BadHPub);
};
let (Ok(header_bytes), Ok(total_bytes)) =
(headers.parse::<usize>(), totals.parse::<usize>())
else {
return CommandDecoderResult::FatalError(ClientDecodeError::BadHPub);
};

if total_bytes < header_bytes {
return CommandDecoderResult::FatalError(ClientDecodeError::BadHPub);
}
if total_bytes > buffer.len() {
return CommandDecoderResult::FrameTooShort(None);
}

dbg!(buffer[metadata_len..].len());
dbg!(total_bytes);

let headers = &buffer[metadata_len..metadata_len + header_bytes];
dbg!(std::str::from_utf8(headers).unwrap());

let payload = &buffer[metadata_len + header_bytes..metadata_len + total_bytes];
dbg!(std::str::from_utf8(payload).unwrap());

let parts = HPubParts {
subject,
reply_to,
header_bytes,
total_bytes,
headers,
payload,
};

let Ok(hpub) = parts.try_into() else {
return CommandDecoderResult::FatalError(ClientDecodeError::BadHPub);
};

CommandDecoderResult::Advance((
crate::ClientCommand::HPub(hpub),
metadata_len + total_bytes + 2,
))
}
}

struct HPubParts<'a> {
subject: &'a [u8],
reply_to: Option<&'a [u8]>,
header_bytes: usize,
total_bytes: usize,
headers: &'a [u8],
payload: &'a [u8],
}

impl std::convert::TryFrom<HPubParts<'_>> for crate::HPub {
type Error = ClientDecodeError;

fn try_from(value: HPubParts<'_>) -> Result<Self, Self::Error> {
let Ok(subject) = std::str::from_utf8(value.subject) else {
return Err(Self::Error::BadHPub);
};

let Ok(reply_to) = value.reply_to.map(std::str::from_utf8).transpose() else {
return Err(Self::Error::BadHPub);
};

if value.total_bytes < value.header_bytes
|| value.payload.len() != value.total_bytes - value.header_bytes
{
return Err(Self::Error::BadHPub);
}

let Ok(headers) = parse_headers(value.headers, value.header_bytes) else {
return Err(Self::Error::BadHPub);
};

Ok(crate::HPub {
subject: subject.into(),
reply_to: reply_to.map(Into::into),
header_bytes: value.header_bytes,
total_bytes: value.total_bytes,
headers,
payload: Bytes::copy_from_slice(value.payload),
})
}
}
2 changes: 1 addition & 1 deletion nats_codec/src/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ impl tokio_util::codec::Encoder<crate::ClientCommand> for crate::ClientCodec {
match item {
crate::ClientCommand::Connect(c) => connect(c, dst)?,
crate::ClientCommand::Pub(p) => publish(p, dst)?,
crate::ClientCommand::HPub => todo!(),
crate::ClientCommand::HPub(hpub) => todo!(),
crate::ClientCommand::Sub(s) => subscribe(s, dst)?,
crate::ClientCommand::Unsub(us) => unsubscribe(us, dst)?,
crate::ClientCommand::Ping => ping(dst)?,
Expand Down
Loading

0 comments on commit 7192995

Please sign in to comment.