Skip to content

Commit

Permalink
better error handle
Browse files Browse the repository at this point in the history
  • Loading branch information
zephyrchien committed Feb 20, 2022
1 parent 516c183 commit 8ae8544
Show file tree
Hide file tree
Showing 7 changed files with 135 additions and 37 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ edition = "2018"
[dependencies]
cfg-if = "1"
futures = "0.3"
log = "0.4"

clap = "2"
serde = { version = "1", features = ["derive"] }
Expand Down
12 changes: 10 additions & 2 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,16 @@ fn start_from_conf(conf: String) {
setup_dns(dns);
}

let eps: Vec<Endpoint> =
conf.endpoints.into_iter().map(|epc| epc.build()).collect();
let eps: Vec<Endpoint> = conf
.endpoints
.into_iter()
.map(|epc| {
let ep = epc.build();
println!("inited: {}", &ep);
ep
})
.collect();

run_relay(eps);
}

Expand Down
23 changes: 17 additions & 6 deletions src/relay/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::io::Result;
use log::error;
use futures::future::join_all;

mod tcp;
Expand All @@ -17,34 +17,45 @@ pub async fn run(eps: Vec<Endpoint>) {
join_all(workers).await;
}

async fn proxy_tcp(ep: Endpoint) -> Result<()> {
async fn proxy_tcp(ep: Endpoint) {
let Endpoint {
local,
remote,
opts,
..
} = ep;

let lis = TcpListener::bind(local)
.await
.unwrap_or_else(|_| panic!("unable to bind {}", &local));
while let Ok((stream, _)) = lis.accept().await {

loop {
let (stream, _) = match lis.accept().await {
Ok(x) => x,
Err(ref e) => {
error!("failed to accept tcp connection: {}", e);
continue;
}
};
tokio::spawn(tcp::proxy(stream, remote.clone(), opts));
}
Ok(())
}

#[cfg(feature = "udp")]
mod udp;

#[cfg(feature = "udp")]
async fn proxy_udp(ep: Endpoint) -> Result<()> {
async fn proxy_udp(ep: Endpoint) {
let Endpoint {
local,
remote,
opts,
..
} = ep;
udp::proxy(local, remote, opts).await

if let Err(ref e) = udp::proxy(local, remote, opts).await {
panic!("udp forward exit: {}", e);
}
}

fn compute_workers(workers: &[Endpoint]) -> usize {
Expand Down
47 changes: 28 additions & 19 deletions src/relay/tcp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,21 +78,22 @@ pub async fn proxy(
let (ro, wo) = outbound.split();

#[cfg(all(target_os = "linux", feature = "zero-copy"))]
if zero_copy {
let res = if zero_copy {
use zero_copy::copy;
let _ = try_join!(copy(ri, wo, timeout), copy(ro, wi, timeout));
try_join!(copy(ri, wo, timeout), copy(ro, wi, timeout))
} else {
use normal_copy::copy;
let _ = try_join!(copy(ri, wo, timeout), copy(ro, wi, timeout));
}
try_join!(copy(ri, wo, timeout), copy(ro, wi, timeout))
};

#[cfg(not(all(target_os = "linux", feature = "zero-copy")))]
{
let res = {
use normal_copy::copy;
let _ = try_join!(copy(ri, wo, timeout), copy(ro, wi, timeout));
}
try_join!(copy(ri, wo, timeout), copy(ro, wi, timeout))
};

Ok(())
// ignore read/write n bytes
res.map(|_| ())
}

mod normal_copy {
Expand Down Expand Up @@ -148,7 +149,7 @@ mod zero_copy {
if libc::pipe2(pipes.as_mut_ptr() as *mut c_int, O_NONBLOCK) < 0
{
return Err(Error::new(
ErrorKind::Unsupported,
ErrorKind::Other,
"failed to create a pipe",
));
}
Expand Down Expand Up @@ -209,7 +210,7 @@ mod zero_copy {
let mut n: usize = 0;
let mut done = false;

'LOOP: loop {
let res = 'LOOP: loop {
// read until the socket buffer is empty
// or the pipe is filled
timeoutfut(timeout, rx.readable()).await??;
Expand All @@ -224,7 +225,12 @@ mod zero_copy {
clear_readiness(rx, Interest::READABLE);
break;
}
_ => break 'LOOP,
_ => {
break 'LOOP Err(Error::new(
ErrorKind::Other,
"failed to splice from tcp connection",
))
}
}
}
// write until the pipe is empty
Expand All @@ -233,23 +239,26 @@ mod zero_copy {
match splice_n(rpipe, wfd, n) {
x if x > 0 => n -= x as usize,
x if x < 0 && is_wouldblock() => {
clear_readiness(wx, Interest::WRITABLE)
clear_readiness(wx, Interest::WRITABLE);
}
_ => {
break 'LOOP Err(Error::new(
ErrorKind::Other,
"failed to splice to tcp connection",
))
}
_ => break 'LOOP,
}
}
// complete
if done {
break;
break Ok(());
}
}
};

if done {
w.shutdown().await?;
Ok(())
} else {
Err(Error::new(ErrorKind::ConnectionReset, "connection reset"))
}
};
res
}
}

Expand Down
28 changes: 23 additions & 5 deletions src/relay/udp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use std::net::SocketAddr;
use std::sync::{Arc, RwLock};
use std::collections::HashMap;

use log::error;

use tokio::net::UdpSocket;
use tokio::time::timeout as timeoutfut;

Expand All @@ -25,16 +27,28 @@ pub async fn proxy(
..
} = conn_opts;
let sock_map: SockMap = Arc::new(RwLock::new(HashMap::new()));
let local_sock = Arc::new(UdpSocket::bind(&local).await.unwrap());
let local_sock = Arc::new(UdpSocket::bind(&local).await?);
let timeout = Duration::from_secs(timeout as u64);
let mut buf = vec![0u8; BUFFERSIZE];

loop {
let (n, client_addr) = local_sock.recv_from(&mut buf).await?;
let (n, client_addr) = match local_sock.recv_from(&mut buf).await {
Ok(x) => x,
Err(ref e) => {
error!("failed to recv udp packet: {}", e);
continue;
}
};

let remote_addr = remote.to_sockaddr().await?;
let remote_addr = match remote.to_sockaddr().await {
Ok(x) => x,
Err(ref e) => {
error!("failed to resolve remote addr: {}", e);
continue;
}
};

// the socket associated with a unique client
// the old/new socket associated with a unique client
let alloc_sock = match get_socket(&sock_map, &client_addr) {
Some(x) => x,
None => {
Expand All @@ -50,8 +64,12 @@ pub async fn proxy(
}
};

alloc_sock.send_to(&buf[..n], &remote_addr).await?;
if let Err(ref e) = alloc_sock.send_to(&buf[..n], &remote_addr).await {
error!("failed to send udp packet: {}", e);
}
}

// Err(Error::new(ErrorKind::Other, "unknown error"))
}

async fn send_back(
Expand Down
60 changes: 55 additions & 5 deletions src/utils/types.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::io::Result;
use std::fmt::{Formatter, Display};
use std::net::{IpAddr, SocketAddr, ToSocketAddrs};

use crate::dns;
Expand All @@ -14,6 +15,7 @@ pub enum RemoteAddr {

#[derive(Clone, Copy)]
pub struct ConnectOpts {
pub use_udp: bool,
pub fast_open: bool,
pub zero_copy: bool,
pub tcp_timeout: usize,
Expand All @@ -23,10 +25,9 @@ pub struct ConnectOpts {

#[derive(Clone)]
pub struct Endpoint {
pub udp: bool,
pub local: SocketAddr,
pub remote: RemoteAddr,
pub conn_opts: ConnectOpts,
pub opts: ConnectOpts,
}

impl RemoteAddr {
Expand Down Expand Up @@ -71,7 +72,7 @@ impl Endpoint {
local: &str,
remote: &str,
through: &str,
udp: bool,
use_udp: bool,
fast_open: bool,
zero_copy: bool,
tcp_timeout: usize,
Expand Down Expand Up @@ -109,10 +110,10 @@ impl Endpoint {
};

Endpoint {
udp,
local,
remote,
conn_opts: ConnectOpts {
opts: ConnectOpts {
use_udp,
fast_open,
zero_copy,
tcp_timeout,
Expand All @@ -122,3 +123,52 @@ impl Endpoint {
}
}
}

impl Display for RemoteAddr {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
use RemoteAddr::*;
match self {
SocketAddr(x) => write!(f, "{}", x),
DomainName(addr, port) => write!(f, "{}:{}", addr, port),
}
}
}

impl Display for ConnectOpts {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
macro_rules! on_off {
($x: expr) => {
if $x {
"on"
} else {
"off"
}
};
}
if let Some(send_through) = &self.send_through {
write!(f, "send-through={}, ", send_through)?;
}
write!(
f,
"udp-forward={}, tcp-fast-open={}, tcp-zero-copy={}, ",
on_off!(self.use_udp),
on_off!(self.fast_open),
on_off!(self.zero_copy)
)?;
write!(
f,
"tcp-timeout={}s, udp-timeout={}s",
self.tcp_timeout, self.udp_timeout
)
}
}

impl Display for Endpoint {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(
f,
"{} -> {}, options: {}",
&self.local, &self.remote, &self.opts
)
}
}

0 comments on commit 8ae8544

Please sign in to comment.