-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
mdns tunneller: use to forward mdns packet
- Loading branch information
0 parents
commit 0ea2ccd
Showing
6 changed files
with
352 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
/target | ||
/Cargo.lock |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
[package] | ||
name = "mdns-tunneller" | ||
version = "0.1.0" | ||
edition = "2021" | ||
|
||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html | ||
|
||
[dependencies] | ||
dns-parser = "0.8.0" | ||
net2 = "0.2" | ||
tokio = { version = "1.18.2", features = ["full"] } | ||
anyhow = "1.0.57" | ||
pnet = "0.30.0" | ||
bytes = "1" | ||
tokio-util = { version = "0.7", features = ["codec"] } | ||
tokio-stream = "0.1.8" | ||
futures = "0.3.21" | ||
clap = { version = "3", features = ["derive"] } | ||
async-channel = "1.6.1" | ||
tracing = "0.1" | ||
tracing-subscriber = "0.2" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
pub fn get_filter_domains() -> Vec<String> { | ||
vec![ | ||
"_homekit._tcp.local".into(), | ||
"_hap._tcp.local".into(), | ||
"_googlecast._tcp.local".into(), | ||
] | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
pub mod config; | ||
pub mod mdns; | ||
pub mod tunnel; | ||
|
||
use anyhow::Result; | ||
use async_channel::Receiver; | ||
use clap::Parser; | ||
use pnet::datalink::{self, NetworkInterface}; | ||
use std::sync::Arc; | ||
use std::thread; | ||
use tokio::{net::TcpListener, sync::Mutex}; | ||
use tokio::{ | ||
net::TcpStream, | ||
sync::mpsc::{self, UnboundedReceiver}, | ||
}; | ||
use tokio_util::codec::{Framed, LengthDelimitedCodec}; | ||
use tracing::{info, Level}; | ||
use tunnel::TunnelPeer; | ||
|
||
use crate::config::get_filter_domains; | ||
|
||
#[derive(Parser)] | ||
enum Args { | ||
Server { | ||
#[clap(short, long)] | ||
addr: String, | ||
#[clap(short, long)] | ||
interface: String, | ||
}, | ||
Client { | ||
#[clap(short, long)] | ||
addr: String, | ||
#[clap(short, long)] | ||
interface: String, | ||
}, | ||
} | ||
|
||
#[tokio::main] | ||
async fn main() -> Result<()> { | ||
tracing_subscriber::fmt().with_max_level(Level::INFO).init(); | ||
|
||
let args = Args::parse(); | ||
let (is_client, addr, iface_name) = match args { | ||
Args::Server { addr, interface } => (false, addr, interface), | ||
Args::Client { addr, interface } => (true, addr, interface), | ||
}; | ||
info!(?is_client, ?addr, ?iface_name); | ||
|
||
let interface_names_match = |iface: &NetworkInterface| iface.name == iface_name; | ||
|
||
// Find the network interface with the provided name | ||
let interfaces = datalink::interfaces(); | ||
let interface = interfaces | ||
.into_iter() | ||
.filter(interface_names_match) | ||
.next() | ||
.unwrap_or_else(|| panic!("No such network interface: {}", iface_name)); | ||
|
||
let (channel_tx, channel_rx) = mpsc::unbounded_channel(); | ||
let (mdns_sender, mut mdns_listener) = mdns::pair(&interface, channel_tx, get_filter_domains()); | ||
|
||
let mdns_sender = Arc::new(Mutex::new(mdns_sender)); | ||
let channel_rx = forward(channel_rx); | ||
|
||
if is_client { | ||
let tcp = TcpStream::connect(&addr).await?; | ||
info!("connected"); | ||
|
||
thread::spawn(move || mdns_listener.listen()); | ||
let tunnel = TunnelPeer { | ||
mdns_sender, | ||
channel_rx, | ||
tcp: Framed::new(tcp, LengthDelimitedCodec::new()), | ||
socket_addr: None, | ||
}; | ||
tunnel.select_run().await; | ||
} else { | ||
let listener = TcpListener::bind(&addr).await?; | ||
info!("start listening"); | ||
|
||
thread::spawn(move || mdns_listener.listen()); | ||
|
||
while let Ok((con, addr)) = listener.accept().await { | ||
info!(?addr, "connected"); | ||
|
||
let mdns_sender = mdns_sender.clone(); | ||
let channel_rx = channel_rx.clone(); | ||
|
||
tokio::spawn(async move { | ||
let tunnel = TunnelPeer { | ||
mdns_sender, | ||
channel_rx, | ||
tcp: Framed::new(con, LengthDelimitedCodec::new()), | ||
socket_addr: Some(addr), | ||
}; | ||
tunnel.select_run().await; | ||
}); | ||
} | ||
} | ||
|
||
Ok(()) | ||
} | ||
|
||
fn forward(mut sc_rx: UnboundedReceiver<Vec<u8>>) -> Receiver<Vec<u8>> { | ||
let (tx, rx) = async_channel::unbounded(); | ||
tokio::spawn(async move { | ||
while let Some(packet) = sc_rx.recv().await { | ||
if tx.send(packet).await.is_err() { | ||
break; | ||
} | ||
} | ||
}); | ||
rx | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
use dns_parser::Packet as mDNSPacket; | ||
use pnet::datalink::{DataLinkReceiver, DataLinkSender}; | ||
use pnet::{ | ||
datalink::{self, Channel::Ethernet, NetworkInterface}, | ||
packet::{ | ||
ethernet::{EtherTypes, EthernetPacket}, | ||
ip::IpNextHeaderProtocols, | ||
ipv4::Ipv4Packet, | ||
udp::UdpPacket, | ||
Packet, | ||
}, | ||
}; | ||
use std::io; | ||
use tokio::sync::mpsc::UnboundedSender; | ||
use tracing::info; | ||
|
||
/// An mDNS listener on a specific interface. | ||
#[allow(non_camel_case_types)] | ||
pub struct mDNSListener { | ||
pub eth_rx: Box<dyn DataLinkReceiver>, | ||
// `EthernetPacket` with `mDNS` | ||
pub channel_tx: UnboundedSender<Vec<u8>>, | ||
pub filter_domains: Vec<String>, | ||
} | ||
|
||
impl mDNSListener { | ||
/// Listen mDNS packet, than send `EthernetPacket` to channel | ||
pub fn listen(&mut self) { | ||
// mDNSPacket<'a> | ||
let mut mdns_buf = Vec::new(); | ||
|
||
while let Ok(packet) = self.eth_rx.next() { | ||
if let Some(eth) = EthernetPacket::new(packet) { | ||
if let Some(mdns) = mdns_packet(ð, &mut mdns_buf) { | ||
if filter_packet(&mdns, &self.filter_domains) { | ||
if self.channel_tx.send(packet.to_vec()).is_err() { | ||
break; | ||
} | ||
} | ||
} | ||
}; | ||
} | ||
} | ||
} | ||
|
||
/// An mDNS Sender on a specific interface. | ||
#[allow(non_camel_case_types)] | ||
pub struct mDNSSender { | ||
pub eth_tx: Box<dyn DataLinkSender>, | ||
} | ||
|
||
impl mDNSSender { | ||
/// packet is a `EthernetPacket` with `mDNS` | ||
pub fn send(&mut self, packet: &[u8]) -> Option<Result<(), io::Error>> { | ||
self.eth_tx.send_to(packet, None) | ||
} | ||
} | ||
|
||
pub fn pair( | ||
interface: &NetworkInterface, | ||
channel_tx: UnboundedSender<Vec<u8>>, | ||
filter_domains: Vec<String>, | ||
) -> (mDNSSender, mDNSListener) { | ||
// Create a channel to receive on | ||
let (tx, rx) = match datalink::channel(&interface, Default::default()) { | ||
Ok(Ethernet(tx, rx)) => (tx, rx), | ||
Ok(_) => panic!("unhandled channel type"), | ||
Err(e) => panic!("unable to create channel: {}", e), | ||
}; | ||
( | ||
mDNSSender { eth_tx: tx }, | ||
mDNSListener { | ||
eth_rx: rx, | ||
channel_tx, | ||
filter_domains, | ||
}, | ||
) | ||
} | ||
|
||
/// get multicast dns packet | ||
fn mdns_packet<'a>(ethernet: &EthernetPacket, buf: &'a mut Vec<u8>) -> Option<mDNSPacket<'a>> { | ||
fn ipv4_packet(payload: &[u8]) -> Option<Ipv4Packet> { | ||
let packet = Ipv4Packet::new(payload)?; | ||
if !packet.get_destination().is_multicast() | ||
|| !matches!(packet.get_next_level_protocol(), IpNextHeaderProtocols::Udp) | ||
{ | ||
return None; | ||
} | ||
Some(packet) | ||
} | ||
|
||
fn udp_packet(payload: &[u8]) -> Option<UdpPacket> { | ||
UdpPacket::new(payload) | ||
} | ||
|
||
match ethernet.get_ethertype() { | ||
EtherTypes::Ipv4 => { | ||
let ipv4_packet = ipv4_packet(ethernet.payload())?; | ||
let udp_packet = udp_packet(ipv4_packet.payload())?; | ||
*buf = udp_packet.payload().to_vec(); | ||
mDNSPacket::parse(buf).ok() | ||
} | ||
_ => None, | ||
} | ||
} | ||
|
||
fn filter_packet(packet: &mDNSPacket, domains: &Vec<String>) -> bool { | ||
let question_matched = packet | ||
.questions | ||
.iter() | ||
.filter(|record| { | ||
let record_name = record.qname.to_string(); | ||
let matched = domains.contains(&record_name); | ||
if matched { | ||
info!("found quary packet, domain: {}", record_name); | ||
} | ||
matched | ||
}) | ||
.count() | ||
> 0; | ||
|
||
let answer_matched = packet | ||
.answers | ||
.iter() | ||
.filter(|record| { | ||
let record_name = record.name.to_string(); | ||
let matched = domains.contains(&record_name); | ||
|
||
if matched { | ||
info!( | ||
"found answser packet, domain: {} at: {:?}", | ||
record_name, &record.data | ||
); | ||
} | ||
matched | ||
}) | ||
.count() | ||
> 0; | ||
|
||
question_matched || answer_matched | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
use crate::mdns::mDNSSender; | ||
use async_channel::Receiver; | ||
use bytes::Bytes; | ||
use futures::SinkExt; | ||
use std::net::SocketAddr; | ||
use std::sync::Arc; | ||
use tokio::net::TcpStream; | ||
use tokio::sync::Mutex; | ||
use tokio_stream::StreamExt; | ||
use tokio_util::codec::{Framed, LengthDelimitedCodec}; | ||
use tracing::{error, info}; | ||
|
||
pub struct TunnelPeer { | ||
pub mdns_sender: Arc<Mutex<mDNSSender>>, | ||
pub channel_rx: Receiver<Vec<u8>>, | ||
pub tcp: Framed<TcpStream, LengthDelimitedCodec>, | ||
pub socket_addr: Option<SocketAddr>, | ||
} | ||
|
||
impl TunnelPeer { | ||
pub async fn select_run(self) { | ||
let TunnelPeer { | ||
mdns_sender, | ||
channel_rx, | ||
mut tcp, | ||
socket_addr, | ||
} = self; | ||
|
||
loop { | ||
tokio::select! { | ||
matched = channel_rx.recv() => { | ||
match matched { | ||
Ok(packet) => { | ||
let bytes = Bytes::copy_from_slice(&packet); | ||
if let Err(e) = tcp.send(bytes).await { | ||
error!(?e, "tcp send err"); | ||
break; | ||
} | ||
}, | ||
Err(_) => break | ||
} | ||
}, | ||
matched = tcp.next() => { | ||
match matched { | ||
Some(Ok(packet)) => { | ||
let mut lock = mdns_sender.lock().await; | ||
if let Some(Err(e)) = lock.send(&packet.to_vec()) { | ||
error!(?e, "mdns sender send err"); | ||
break; | ||
} | ||
}, | ||
Some(Err(e)) => { | ||
error!(?e, "read buf error!"); | ||
break; | ||
}, | ||
None => break | ||
} | ||
}, | ||
} | ||
} | ||
if let Some(addr) = socket_addr { | ||
info!(?addr, "peer closed!"); | ||
} else { | ||
info!("peer closed!"); | ||
} | ||
} | ||
} |