Skip to content

Commit 69f5b17

Browse files
committedApr 18, 2024·
feat: reached mvp
1 parent 4ff8207 commit 69f5b17

File tree

10 files changed

+553
-33
lines changed

10 files changed

+553
-33
lines changed
 

‎.github/workflows/ci.yml

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
name: Rust
2+
3+
on:
4+
push:
5+
branches: [ "main" ]
6+
pull_request:
7+
branches: [ "main" ]
8+
9+
env:
10+
CARGO_TERM_COLOR: always
11+
12+
jobs:
13+
build:
14+
15+
runs-on: ubuntu-latest
16+
17+
steps:
18+
- uses: actions/checkout@v3
19+
- name: Build
20+
run: cargo build --verbose
21+
- name: Run tests
22+
run: cargo test --verbose

‎Cargo.lock

+74
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

‎Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ bytes = "1"
88
tokio = {version = "1", features = ["full"]}
99
tracing = "0.1"
1010
rustc-hash = "1"
11+
tracing-subscriber = "0.3.18"
1112

1213
[dev-dependencies]
1314
criterion = "0.5.1"

‎benches/bench_db.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use std::path::Path;
44
use std::sync::Arc;
55
use std::time::Duration;
66

7-
use criterion::{criterion_group, criterion_main, Criterion, black_box};
7+
use criterion::{black_box, criterion_group, criterion_main, Criterion};
88
use csv::ReaderBuilder;
99
use rayon::prelude::*;
1010

‎src/command.rs

+102
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
use std::time::Duration;
2+
3+
use crate::frame::{Frame, FrameError};
4+
5+
#[derive(Debug)]
6+
pub enum Command {
7+
Ping(Option<String>),
8+
Set(String, String, Duration),
9+
Get(String),
10+
Unknown(String),
11+
}
12+
13+
impl Command {
14+
/// Parse a command from a frame of array of bulk. This method should not be passed
15+
/// another type or frame contents. Because RESP commands are represented as an array of bulk frames.
16+
pub fn from_frame(frame: Frame) -> Result<Self, FrameError> {
17+
// users are aware they should only give an array of bulk frame to this method but we have to
18+
// check anyway. We opted to perform lazy validation as validating the whole array upfront
19+
// can be expensive.
20+
let frames = frame.get_array().ok_or(FrameError::Syntax(
21+
"commands can only be frame array".to_string(),
22+
))?;
23+
if let Some(name) = Self::get_name(frames) {
24+
return match name.to_ascii_uppercase().as_str() {
25+
"PING" => Command::parse_ping(&frames[1..]),
26+
"GET" => Command::parse_get(&frames[1..]),
27+
"SET" => Command::parse_set(&frames[1..]),
28+
_ => Ok(Command::Unknown(name.to_string())),
29+
};
30+
}
31+
Err(FrameError::Syntax(
32+
"RESP command name should be of type bulk frame".to_string(),
33+
))
34+
}
35+
36+
/// parse_get tries to retrieve get args from a slice of frames
37+
fn parse_get(frames: &[Frame]) -> Result<Command, FrameError> {
38+
if frames.len() != 1 {
39+
return Err(FrameError::Syntax(
40+
"PING command takes at most 1 argument".to_string(),
41+
));
42+
}
43+
Ok(Command::Get(Self::get_string(frames, 0)?))
44+
}
45+
46+
fn get_name(frames: &[Frame]) -> Option<&String> {
47+
if frames.is_empty() {
48+
return None;
49+
}
50+
if let Some((_, name)) = frames[0].get_bulk() {
51+
return Some(name);
52+
}
53+
None
54+
}
55+
56+
/// parse_ping tries to retrieve ping args from a slice of frames
57+
fn parse_ping(frames: &[Frame]) -> Result<Command, FrameError> {
58+
match frames.len() {
59+
0 => Ok(Command::Ping(None)),
60+
1 => Ok(Command::Ping(Some(Self::get_string(frames, 0)?))),
61+
_ => Err(FrameError::Syntax(
62+
"PING command takes at most 1 argument".to_string(),
63+
)),
64+
}
65+
}
66+
67+
fn get_string(frames: &[Frame], index: usize) -> Result<String, FrameError> {
68+
match frames[index].get_bulk() {
69+
Some(message) => Ok(message.1.to_owned()),
70+
None => Err(FrameError::Syntax(
71+
"RESP args should be of type bulk frame".to_string(),
72+
)),
73+
}
74+
}
75+
76+
fn parse_set(frames: &[Frame]) -> Result<Command, FrameError> {
77+
let len = frames.len();
78+
if len != 2 && len != 4 {
79+
return Err(FrameError::Syntax(
80+
"SET should take 2 or 4 arguments".to_string(),
81+
));
82+
}
83+
let key = Self::get_string(frames, 0)?;
84+
let value = Self::get_string(frames, 1)?;
85+
86+
let mut ttl = Duration::from_millis(0);
87+
88+
// check if we've got the right option to set the time in millis
89+
if len > 2 && Self::get_string(frames, 2)?.to_uppercase() == "PX" {
90+
let expiration = Self::get_u64(frames, 3)?;
91+
ttl = Duration::from_millis(expiration);
92+
}
93+
94+
Ok(Command::Set(key, value, ttl))
95+
}
96+
97+
fn get_u64(frames: &[Frame], index: usize) -> Result<u64, FrameError> {
98+
let content = Self::get_string(frames, index)?;
99+
let content = content.parse::<u64>().map_err(|_| FrameError::UTF8ToInt)?;
100+
Ok(content)
101+
}
102+
}

‎src/db.rs

+4-5
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
//! We do not store the state for eviction. Time-based eviction is used and we perform lazy eviction.
44
//! If an expired key is read, this key is deleted and no value is returned to the user.
55
//! ...
6-
use std::collections::{BinaryHeap};
6+
use rustc_hash::FxHashMap;
7+
use std::collections::BinaryHeap;
78
use std::fmt::Debug;
89
use std::hash::{Hash, Hasher};
910
use std::sync::atomic::{AtomicUsize, Ordering};
1011
use std::sync::{Arc, Mutex};
1112
use std::time::{Duration, Instant};
12-
use rustc_hash::FxHashMap;
1313

1414
struct Shard {
1515
storage: FxHashMap<String, String>,
@@ -42,7 +42,7 @@ impl Shard {
4242
0
4343
}
4444
}
45-
45+
4646
fn latest_is_expired(&self) -> bool {
4747
if let Some((instant, _)) = self.eviction_state.peek() {
4848
if Instant::now() > *instant {
@@ -51,7 +51,7 @@ impl Shard {
5151
}
5252
false
5353
}
54-
54+
5555
fn del_latest(&mut self) {
5656
if let Some((_, key)) = self.eviction_state.pop() {
5757
self.storage.remove(&key);
@@ -150,7 +150,6 @@ impl Storage {
150150
}
151151
}
152152

153-
154153
#[cfg(test)]
155154
mod tests {
156155
use super::*;

‎src/frame.rs

+182-24
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
//! simple frames are valid at their creation. We do that because we want to pay the cost of checking
99
//! this property only if needed as it is expensive.
1010
11-
use crate::frame::FrameData::Nested;
12-
use tokio::io::ErrorKind;
13-
use tokio::io::{AsyncBufReadExt, AsyncReadExt, BufReader};
11+
use std::fmt::{self, Debug, Display, Formatter};
12+
use tokio::io::{
13+
self, AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader, BufWriter, ErrorKind,
14+
};
15+
use tracing::error;
1416

1517
const CR: u8 = b'\r';
1618
const LF: u8 = b'\n';
@@ -20,7 +22,7 @@ const LF: u8 = b'\n';
2022
/// up front, but we have changed our mind.
2123
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
2224
#[repr(u8)]
23-
pub(crate) enum FrameID {
25+
pub enum FrameID {
2426
Integer = 58, // ':'
2527
// @TODO: remove for now
2628
// Double = 44, // ','
@@ -89,21 +91,156 @@ enum FrameData {
8991
Nested(Vec<Frame>),
9092
}
9193

94+
impl FrameData {
95+
fn get_integer(&self) -> Option<i64> {
96+
match self {
97+
FrameData::Integer(value) => Some(*value),
98+
_ => None,
99+
}
100+
}
101+
fn get_string(&self) -> Option<&String> {
102+
match self {
103+
FrameData::Simple(value) => Some(value),
104+
_ => None,
105+
}
106+
}
107+
fn get_bulk(&self) -> Option<(usize, &String)> {
108+
match self {
109+
FrameData::Bulk(size, data) => Some((*size, data)),
110+
_ => None,
111+
}
112+
}
113+
fn get_boolean(&self) -> Option<bool> {
114+
match self {
115+
FrameData::Boolean(value) => Some(*value),
116+
_ => None,
117+
}
118+
}
119+
pub fn get_nested(&self) -> Option<&Vec<Frame>> {
120+
match self {
121+
FrameData::Nested(value) => Some(value),
122+
_ => None,
123+
}
124+
}
125+
}
126+
92127
#[derive(Debug, Eq, PartialEq)]
93-
struct Frame {
128+
pub struct Frame {
94129
frame_type: FrameID,
95130
frame_data: FrameData,
96131
}
97132

98-
fn validate_bool(data: String) -> Result<bool, FrameError> {
99-
match data.as_str() {
133+
impl Frame {
134+
pub fn get_id(&self) -> FrameID {
135+
self.frame_type
136+
}
137+
138+
pub fn get_array(&self) -> Option<&Vec<Frame>> {
139+
if self.frame_type != FrameID::Array {
140+
return None;
141+
}
142+
self.frame_data.get_nested()
143+
}
144+
pub fn get_bulk(&self) -> Option<(usize, &String)> {
145+
match &self.frame_data {
146+
FrameData::Bulk(size, data) => Some((*size, data)),
147+
_ => None,
148+
}
149+
}
150+
151+
pub async fn write_flush_all<T>(&self, stream: &mut BufWriter<T>) -> io::Result<()>
152+
where
153+
T: AsyncWriteExt + Unpin,
154+
{
155+
stream.write_all(self.to_string().as_bytes()).await?;
156+
stream.flush().await
157+
}
158+
159+
pub fn new_bulk_error(inner: String) -> Frame {
160+
Frame {
161+
frame_type: FrameID::BulkError,
162+
frame_data: FrameData::Bulk(inner.len(), inner),
163+
}
164+
}
165+
166+
pub fn new_simple_string(inner: String) -> Frame {
167+
Frame {
168+
frame_type: FrameID::SimpleString,
169+
frame_data: FrameData::Simple(inner),
170+
}
171+
}
172+
173+
pub fn new_bulk_string(inner: String) -> Frame {
174+
Frame {
175+
frame_type: FrameID::BulkString,
176+
frame_data: FrameData::Bulk(inner.len(), inner),
177+
}
178+
}
179+
180+
pub fn new_null() -> Frame {
181+
Frame {
182+
frame_type: FrameID::Null,
183+
frame_data: FrameData::Null,
184+
}
185+
}
186+
}
187+
188+
impl fmt::Display for Frame {
189+
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
190+
match self.frame_type {
191+
FrameID::Integer => {
192+
let value = self.frame_data.get_integer().ok_or(fmt::Error)?;
193+
write!(f, ":{}\r\n", value)
194+
}
195+
FrameID::SimpleString => {
196+
let value = self.frame_data.get_string().ok_or(fmt::Error)?;
197+
write!(f, "+{}\r\n", value)
198+
}
199+
FrameID::SimpleError => {
200+
let value = self.frame_data.get_string().ok_or(fmt::Error)?;
201+
write!(f, "-{}\r\n", value)
202+
}
203+
FrameID::BulkString => {
204+
let (len, value) = self.get_bulk().ok_or(fmt::Error)?;
205+
write!(f, "${}\r\n{}\r\n", len, value)
206+
}
207+
FrameID::BulkError => {
208+
let (len, value) = self.frame_data.get_bulk().ok_or(fmt::Error)?;
209+
write!(f, "!{}\r\n{}\r\n", len, value)
210+
}
211+
FrameID::Boolean => {
212+
let value = self.frame_data.get_boolean().ok_or(fmt::Error)?;
213+
let value = if value { "t" } else { "f" };
214+
write!(f, "#{}\r\n", value)
215+
}
216+
FrameID::Null => {
217+
write!(f, "_\r\n")
218+
}
219+
FrameID::BigNumber => {
220+
let value = self.frame_data.get_string().ok_or(fmt::Error)?;
221+
write!(f, "({}\r\n", value)
222+
}
223+
FrameID::Array => {
224+
let frames = self.frame_data.get_nested().ok_or(fmt::Error)?;
225+
write!(f, "*{}\r\n", frames.len())?;
226+
for v in frames {
227+
write!(f, "{}", v)?;
228+
}
229+
Ok(())
230+
}
231+
}
232+
}
233+
}
234+
235+
fn validate_bool(data: &str) -> Result<bool, FrameError> {
236+
match data {
100237
"t" => Ok(true),
101238
"f" => Ok(false),
102239
_ => Err(FrameError::Invalid),
103240
}
104241
}
105242

106-
async fn decode<T>(stream: &mut BufReader<T>) -> Result<Frame, FrameError>
243+
pub async fn decode<T>(stream: &mut BufReader<T>) -> Result<Frame, FrameError>
107244
where
108245
T: AsyncReadExt + Unpin,
109246
{
@@ -122,7 +259,7 @@ where
122259
let frame_vec = process_aggregate_frames(id, stream).await?;
123260
Ok(Frame {
124261
frame_type: FrameID::Array,
125-
frame_data: Nested(frame_vec),
262+
frame_data: FrameData::Nested(frame_vec),
126263
})
127264
}
128265
}
@@ -138,7 +275,7 @@ where
138275
let data = read_simple_string(stream).await?;
139276
match id {
140277
FrameID::Boolean => {
141-
let bool = validate_bool(data)?;
278+
let bool = validate_bool(&data)?;
142279
Ok(Frame {
143280
frame_type: id,
144281
frame_data: FrameData::Boolean(bool),
@@ -189,7 +326,9 @@ where
189326
T: AsyncReadExt + Unpin,
190327
{
191328
match id {
192-
FrameID::Array => Err(FrameError::Syntax),
329+
FrameID::Array => Err(FrameError::Syntax(
330+
"received aggregate frame in non aggregate decoding".to_string(),
331+
)),
193332
FrameID::BulkString | FrameID::BulkError => process_bulk_frames(id, stream).await,
194333
_ => process_simple_frames(id, stream).await,
195334
}
@@ -249,7 +388,7 @@ where
249388
// to build the right aggregate.
250389
frames.push(Frame {
251390
frame_type: *id,
252-
frame_data: Nested(last_vec_of_frames),
391+
frame_data: FrameData::Nested(last_vec_of_frames),
253392
});
254393
*count -= 1;
255394
if *count != 0 {
@@ -269,9 +408,9 @@ pub enum FrameError {
269408
// Frame is not correctly formatted
270409
Invalid,
271410
// Empty buffer should not be passed to get frame from, so this is an error.
272-
EmptyBuffer,
411+
// EmptyBuffer,
273412
// reached expected EOF
274-
EOF,
413+
Eof,
275414
// Connection unexpectedly reset
276415
ConnectionReset,
277416
// Unidentified IO error
@@ -281,7 +420,23 @@ pub enum FrameError {
281420
// Unknown frame type
282421
Unknown,
283422
// This is a programming error. It should not happen.
284-
Syntax,
423+
Syntax(String),
424+
}
425+
426+
impl Display for FrameError {
427+
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
428+
match self {
429+
FrameError::Incomplete => write!(f, "not enough data to decode a full frame"),
430+
FrameError::Invalid => write!(f, "frame is not correctly formatted"),
431+
FrameError::Eof => write!(f, "seen EOF, this is generally a graceful disconnection"),
432+
FrameError::ConnectionReset => write!(f, "unexpected connection reset"),
433+
// this error should not happen in practice
434+
FrameError::IOError => write!(f, "unexpected IO error"),
435+
FrameError::UTF8ToInt => write!(f, "utf8 to int decoding error"),
436+
FrameError::Unknown => write!(f, "unable to identify the frame type"),
437+
FrameError::Syntax(message) => write!(f, "{}", message),
438+
}
439+
}
285440
}
286441

287442
/// `read_simple_string` gets a simple string from the network. As a reminder, such string does
@@ -295,7 +450,7 @@ where
295450
{
296451
let mut buf = Vec::new();
297452
match stream.read_until(LF, &mut buf).await {
298-
Ok(0) => Err(FrameError::EOF),
453+
Ok(0) => Err(FrameError::Eof),
299454
Ok(size) => {
300455
if size < 2 {
301456
return Err(FrameError::Incomplete);
@@ -308,8 +463,10 @@ where
308463
Ok(String::from_utf8_lossy(&buf[0..size - 2]).to_string())
309464
}
310465
Err(e) if e.kind() == ErrorKind::ConnectionReset => Err(FrameError::ConnectionReset),
311-
// @TODO log e later
312-
Err(e) => Err(FrameError::IOError),
466+
Err(e) => {
467+
error!("unexpected io error: {}", e);
468+
Err(FrameError::IOError)
469+
}
313470
}
314471
}
315472

@@ -322,7 +479,7 @@ where
322479
Some(id) => Ok(id),
323480
None => Err(FrameError::Unknown),
324481
},
325-
Err(e) if e.kind() == ErrorKind::UnexpectedEof => Err(FrameError::EOF),
482+
Err(e) if e.kind() == ErrorKind::UnexpectedEof => Err(FrameError::Eof),
326483
// @TODO log e later
327484
Err(e) => Err(FrameError::IOError),
328485
}
@@ -360,7 +517,7 @@ where
360517
))
361518
}
362519
// The caller will treat EOF differently, so it needs to be returned explicitly
363-
Err(e) if e.kind() == ErrorKind::UnexpectedEof => Err(FrameError::EOF),
520+
Err(e) if e.kind() == ErrorKind::UnexpectedEof => Err(FrameError::Eof),
364521
// @TODO log e later
365522
Err(e) => Err(FrameError::IOError),
366523
}
@@ -369,6 +526,7 @@ where
369526
#[cfg(test)]
370527
mod tests {
371528
use super::*;
529+
372530
#[tokio::test]
373531
async fn decode_test() {
374532
//
@@ -394,7 +552,7 @@ mod tests {
394552
let frame = decode(&mut stream).await;
395553
assert_eq!(
396554
frame,
397-
Err(FrameError::EOF),
555+
Err(FrameError::Eof),
398556
"should return EOF error variant"
399557
);
400558

@@ -441,7 +599,7 @@ mod tests {
441599
let mut stream = BufReader::new("*3\r\n:1\r\n+Two\r\n$5\r\nThree\r\n".as_bytes());
442600
let frame = decode(&mut stream).await.unwrap();
443601
assert_eq!(frame.frame_type, FrameID::Array);
444-
let frame_data = Nested(vec![
602+
let frame_data = FrameData::Nested(vec![
445603
Frame {
446604
frame_type: FrameID::Integer,
447605
frame_data: FrameData::Integer(1),
@@ -461,7 +619,7 @@ mod tests {
461619
let mut stream = BufReader::new("*2\r\n:1\r\n*1\r\n+Three\r\n".as_bytes());
462620
let frame = decode(&mut stream).await.unwrap();
463621
assert_eq!(frame.frame_type, FrameID::Array);
464-
let frame_data = Nested(vec![
622+
let frame_data = FrameData::Nested(vec![
465623
Frame {
466624
frame_type: FrameID::Integer,
467625
frame_data: FrameData::Integer(1),
@@ -480,7 +638,7 @@ mod tests {
480638
let mut stream = BufReader::new("*3\r\n:1\r\n*1\r\n+Three\r\n-Err\r\n".as_bytes());
481639
let frame = decode(&mut stream).await.unwrap();
482640
assert_eq!(frame.frame_type, FrameID::Array);
483-
let frame_data = Nested(vec![
641+
let frame_data = FrameData::Nested(vec![
484642
Frame {
485643
frame_type: FrameID::Integer,
486644
frame_data: FrameData::Integer(1),

‎src/lib.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
1+
mod command;
12
pub mod db;
2-
mod frame;
3+
pub mod frame;
4+
pub mod server;

‎src/main.rs

+15-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,18 @@
1+
use mredis::server::{Config, Server};
2+
13
mod frame;
24

3-
fn main() {
4-
println!("Hello, world!");
5+
#[tokio::main]
6+
pub async fn main() -> std::io::Result<()> {
7+
// console_subscriber::init();
8+
let cfg = Config {
9+
ip_addr: "127.0.0.1".to_string(),
10+
port: 6379,
11+
capacity: 1_000_000,
12+
shard_count: 32,
13+
};
14+
tracing_subscriber::fmt::try_init().expect("unable to initialize logging");
15+
let server = Server::new(&cfg).await;
16+
server.listen().await;
17+
Ok(())
518
}

‎src/server.rs

+149
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
use crate::command::Command;
2+
use crate::db::Storage;
3+
use crate::frame::{decode, Frame, FrameError};
4+
use std::sync::Arc;
5+
use tokio::io::{AsyncWriteExt, BufReader, BufWriter};
6+
use tokio::net::{TcpListener, TcpStream};
7+
use tracing::{debug, error};
8+
9+
pub struct Config {
10+
pub ip_addr: String,
11+
pub port: u16,
12+
pub capacity: usize,
13+
pub shard_count: usize,
14+
}
15+
16+
pub struct Server {
17+
storage: Arc<Storage>,
18+
tcp_listener: TcpListener,
19+
}
20+
21+
impl Server {
22+
pub async fn new(cfg: &Config) -> Self {
23+
let tcp_listener = TcpListener::bind((cfg.ip_addr.to_owned(), cfg.port))
24+
.await
25+
.expect("failed to start TCP server");
26+
let storage = Arc::new(Storage::new(cfg.capacity, cfg.shard_count));
27+
28+
Server {
29+
storage,
30+
tcp_listener,
31+
}
32+
}
33+
34+
pub async fn listen(&self) {
35+
loop {
36+
let conn_string = self.tcp_listener.accept().await;
37+
match conn_string {
38+
Ok((mut stream, addr)) => {
39+
debug!("new connection established: {}", addr);
40+
41+
let state = self.storage.clone();
42+
43+
tokio::spawn(async move { Self::process_stream(&mut stream, state).await });
44+
}
45+
Err(err) => {
46+
debug!("error accepting client connection: {:?}", err);
47+
}
48+
}
49+
}
50+
}
51+
52+
async fn process_stream(stream: &mut TcpStream, state: Arc<Storage>) {
53+
let (reader_half, writer_half) = stream.split();
54+
55+
let mut reader = BufReader::with_capacity(8 * 1024, reader_half);
56+
let mut writer = BufWriter::with_capacity(8 * 1024, writer_half);
57+
58+
loop {
59+
let frame = decode(&mut reader).await;
60+
match frame {
61+
Ok(frame) => {
62+
debug!("command frame received!");
63+
process_frame(frame, &state, &mut writer).await;
64+
}
65+
Err(err) => {
66+
if seen_eof(&err, &mut writer).await {
67+
debug!("client gracefully closed connection");
68+
return;
69+
}
70+
}
71+
}
72+
}
73+
}
74+
}
75+
76+
async fn process_frame<T>(frame: Frame, state: &Arc<Storage>, stream_writer: &mut BufWriter<T>)
77+
where
78+
T: AsyncWriteExt + Unpin,
79+
{
80+
match Command::from_frame(frame) {
81+
Ok(cmd) => {
82+
apply_command(&cmd, state, stream_writer).await;
83+
}
84+
Err(err) => send_error(&err, stream_writer).await,
85+
}
86+
}
87+
88+
async fn apply_command<T: AsyncWriteExt + Unpin>(command: &Command, state: &Arc<Storage>, stream_writer: &mut BufWriter<T>)
89+
{
90+
let response_frame = match command {
91+
Command::Ping(message) => {
92+
if let Some(message) = message {
93+
Frame::new_bulk_string(message.to_string())
94+
} else {
95+
Frame::new_simple_string("PONG".to_string())
96+
}
97+
}
98+
Command::Get(key) => {
99+
if let Some(ans) = state.get_v(key) {
100+
Frame::new_simple_string(ans)
101+
} else {
102+
Frame::new_null()
103+
}
104+
}
105+
Command::Set(key, value, ttl) => {
106+
// if let Some(ans) = state.set_kv(key, value, *ttl) {
107+
// Frame::new_simple_string(ans)
108+
// } else {
109+
// Frame::new_null()
110+
// }
111+
state.set_kv(key, value, *ttl);
112+
Frame::new_simple_string("OK".to_string())
113+
}
114+
Command::Unknown(name) => Frame::new_bulk_error(format!("unknown command: {}", name)),
115+
};
116+
response_frame.write_flush_all(stream_writer).await.unwrap()
117+
}
118+
119+
/// seen_eof filters FrameError because some errors need to be sent back to the client via
120+
/// the network, for instance, syntax errors. In this case, send the error to the client. If EOF
121+
/// return true to the caller. And, only log over error variants.
122+
async fn seen_eof<T: AsyncWriteExt + Unpin>(err: &FrameError, stream_writer: &mut BufWriter<T>) -> bool
123+
{
124+
match err {
125+
FrameError::Eof => true,
126+
FrameError::Incomplete
127+
| FrameError::Invalid
128+
| FrameError::Unknown
129+
| FrameError::UTF8ToInt
130+
| FrameError::Syntax(_) => {
131+
send_error(err, stream_writer).await;
132+
false
133+
}
134+
_ => {
135+
error!("error while decoding frame: {}", err);
136+
false
137+
}
138+
}
139+
}
140+
141+
/// send_error is a wrapper to send errors to the client over the network.
142+
/// These are mostly syntax errors.
143+
async fn send_error<T: AsyncWriteExt + Unpin>(err: &FrameError, stream: &mut BufWriter<T>)
144+
{
145+
let err_frame = Frame::new_bulk_error(err.to_string());
146+
if let Err(err) = err_frame.write_flush_all(stream).await {
147+
error!("failed to write to network: {}", err);
148+
}
149+
}

0 commit comments

Comments
 (0)
Please sign in to comment.