From 0ea2ccd335793fdf7492b22065309d1d34052580 Mon Sep 17 00:00:00 2001 From: Spxg Date: Sun, 22 May 2022 11:38:30 +0800 Subject: [PATCH] mdns tunneller: use to forward mdns packet --- .gitignore | 2 + Cargo.toml | 21 ++++++++ src/config.rs | 7 +++ src/main.rs | 114 ++++++++++++++++++++++++++++++++++++++++ src/mdns.rs | 141 ++++++++++++++++++++++++++++++++++++++++++++++++++ src/tunnel.rs | 67 ++++++++++++++++++++++++ 6 files changed, 352 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 src/config.rs create mode 100644 src/main.rs create mode 100644 src/mdns.rs create mode 100644 src/tunnel.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4fffb2f --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +/Cargo.lock diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..ee11684 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "mdns-tunneller" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +dns-parser = "0.8.0" +net2 = "0.2" +tokio = { version = "1.18.2", features = ["full"] } +anyhow = "1.0.57" +pnet = "0.30.0" +bytes = "1" +tokio-util = { version = "0.7", features = ["codec"] } +tokio-stream = "0.1.8" +futures = "0.3.21" +clap = { version = "3", features = ["derive"] } +async-channel = "1.6.1" +tracing = "0.1" +tracing-subscriber = "0.2" diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..b8c3d06 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,7 @@ +pub fn get_filter_domains() -> Vec { + vec![ + "_homekit._tcp.local".into(), + "_hap._tcp.local".into(), + "_googlecast._tcp.local".into(), + ] +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..b8679be --- /dev/null +++ b/src/main.rs @@ -0,0 +1,114 @@ +pub mod config; +pub mod mdns; +pub mod tunnel; + +use anyhow::Result; +use async_channel::Receiver; +use clap::Parser; +use pnet::datalink::{self, NetworkInterface}; +use std::sync::Arc; +use std::thread; +use tokio::{net::TcpListener, sync::Mutex}; +use tokio::{ + net::TcpStream, + sync::mpsc::{self, UnboundedReceiver}, +}; +use tokio_util::codec::{Framed, LengthDelimitedCodec}; +use tracing::{info, Level}; +use tunnel::TunnelPeer; + +use crate::config::get_filter_domains; + +#[derive(Parser)] +enum Args { + Server { + #[clap(short, long)] + addr: String, + #[clap(short, long)] + interface: String, + }, + Client { + #[clap(short, long)] + addr: String, + #[clap(short, long)] + interface: String, + }, +} + +#[tokio::main] +async fn main() -> Result<()> { + tracing_subscriber::fmt().with_max_level(Level::INFO).init(); + + let args = Args::parse(); + let (is_client, addr, iface_name) = match args { + Args::Server { addr, interface } => (false, addr, interface), + Args::Client { addr, interface } => (true, addr, interface), + }; + info!(?is_client, ?addr, ?iface_name); + + let interface_names_match = |iface: &NetworkInterface| iface.name == iface_name; + + // Find the network interface with the provided name + let interfaces = datalink::interfaces(); + let interface = interfaces + .into_iter() + .filter(interface_names_match) + .next() + .unwrap_or_else(|| panic!("No such network interface: {}", iface_name)); + + let (channel_tx, channel_rx) = mpsc::unbounded_channel(); + let (mdns_sender, mut mdns_listener) = mdns::pair(&interface, channel_tx, get_filter_domains()); + + let mdns_sender = Arc::new(Mutex::new(mdns_sender)); + let channel_rx = forward(channel_rx); + + if is_client { + let tcp = TcpStream::connect(&addr).await?; + info!("connected"); + + thread::spawn(move || mdns_listener.listen()); + let tunnel = TunnelPeer { + mdns_sender, + channel_rx, + tcp: Framed::new(tcp, LengthDelimitedCodec::new()), + socket_addr: None, + }; + tunnel.select_run().await; + } else { + let listener = TcpListener::bind(&addr).await?; + info!("start listening"); + + thread::spawn(move || mdns_listener.listen()); + + while let Ok((con, addr)) = listener.accept().await { + info!(?addr, "connected"); + + let mdns_sender = mdns_sender.clone(); + let channel_rx = channel_rx.clone(); + + tokio::spawn(async move { + let tunnel = TunnelPeer { + mdns_sender, + channel_rx, + tcp: Framed::new(con, LengthDelimitedCodec::new()), + socket_addr: Some(addr), + }; + tunnel.select_run().await; + }); + } + } + + Ok(()) +} + +fn forward(mut sc_rx: UnboundedReceiver>) -> Receiver> { + let (tx, rx) = async_channel::unbounded(); + tokio::spawn(async move { + while let Some(packet) = sc_rx.recv().await { + if tx.send(packet).await.is_err() { + break; + } + } + }); + rx +} diff --git a/src/mdns.rs b/src/mdns.rs new file mode 100644 index 0000000..887370d --- /dev/null +++ b/src/mdns.rs @@ -0,0 +1,141 @@ +use dns_parser::Packet as mDNSPacket; +use pnet::datalink::{DataLinkReceiver, DataLinkSender}; +use pnet::{ + datalink::{self, Channel::Ethernet, NetworkInterface}, + packet::{ + ethernet::{EtherTypes, EthernetPacket}, + ip::IpNextHeaderProtocols, + ipv4::Ipv4Packet, + udp::UdpPacket, + Packet, + }, +}; +use std::io; +use tokio::sync::mpsc::UnboundedSender; +use tracing::info; + +/// An mDNS listener on a specific interface. +#[allow(non_camel_case_types)] +pub struct mDNSListener { + pub eth_rx: Box, + // `EthernetPacket` with `mDNS` + pub channel_tx: UnboundedSender>, + pub filter_domains: Vec, +} + +impl mDNSListener { + /// Listen mDNS packet, than send `EthernetPacket` to channel + pub fn listen(&mut self) { + // mDNSPacket<'a> + let mut mdns_buf = Vec::new(); + + while let Ok(packet) = self.eth_rx.next() { + if let Some(eth) = EthernetPacket::new(packet) { + if let Some(mdns) = mdns_packet(ð, &mut mdns_buf) { + if filter_packet(&mdns, &self.filter_domains) { + if self.channel_tx.send(packet.to_vec()).is_err() { + break; + } + } + } + }; + } + } +} + +/// An mDNS Sender on a specific interface. +#[allow(non_camel_case_types)] +pub struct mDNSSender { + pub eth_tx: Box, +} + +impl mDNSSender { + /// packet is a `EthernetPacket` with `mDNS` + pub fn send(&mut self, packet: &[u8]) -> Option> { + self.eth_tx.send_to(packet, None) + } +} + +pub fn pair( + interface: &NetworkInterface, + channel_tx: UnboundedSender>, + filter_domains: Vec, +) -> (mDNSSender, mDNSListener) { + // Create a channel to receive on + let (tx, rx) = match datalink::channel(&interface, Default::default()) { + Ok(Ethernet(tx, rx)) => (tx, rx), + Ok(_) => panic!("unhandled channel type"), + Err(e) => panic!("unable to create channel: {}", e), + }; + ( + mDNSSender { eth_tx: tx }, + mDNSListener { + eth_rx: rx, + channel_tx, + filter_domains, + }, + ) +} + +/// get multicast dns packet +fn mdns_packet<'a>(ethernet: &EthernetPacket, buf: &'a mut Vec) -> Option> { + fn ipv4_packet(payload: &[u8]) -> Option { + let packet = Ipv4Packet::new(payload)?; + if !packet.get_destination().is_multicast() + || !matches!(packet.get_next_level_protocol(), IpNextHeaderProtocols::Udp) + { + return None; + } + Some(packet) + } + + fn udp_packet(payload: &[u8]) -> Option { + UdpPacket::new(payload) + } + + match ethernet.get_ethertype() { + EtherTypes::Ipv4 => { + let ipv4_packet = ipv4_packet(ethernet.payload())?; + let udp_packet = udp_packet(ipv4_packet.payload())?; + *buf = udp_packet.payload().to_vec(); + mDNSPacket::parse(buf).ok() + } + _ => None, + } +} + +fn filter_packet(packet: &mDNSPacket, domains: &Vec) -> bool { + let question_matched = packet + .questions + .iter() + .filter(|record| { + let record_name = record.qname.to_string(); + let matched = domains.contains(&record_name); + if matched { + info!("found quary packet, domain: {}", record_name); + } + matched + }) + .count() + > 0; + + let answer_matched = packet + .answers + .iter() + .filter(|record| { + let record_name = record.name.to_string(); + let matched = domains.contains(&record_name); + + if matched { + info!( + "found answser packet, domain: {} at: {:?}", + record_name, &record.data + ); + } + matched + }) + .count() + > 0; + + question_matched || answer_matched +} diff --git a/src/tunnel.rs b/src/tunnel.rs new file mode 100644 index 0000000..7c5b636 --- /dev/null +++ b/src/tunnel.rs @@ -0,0 +1,67 @@ +use crate::mdns::mDNSSender; +use async_channel::Receiver; +use bytes::Bytes; +use futures::SinkExt; +use std::net::SocketAddr; +use std::sync::Arc; +use tokio::net::TcpStream; +use tokio::sync::Mutex; +use tokio_stream::StreamExt; +use tokio_util::codec::{Framed, LengthDelimitedCodec}; +use tracing::{error, info}; + +pub struct TunnelPeer { + pub mdns_sender: Arc>, + pub channel_rx: Receiver>, + pub tcp: Framed, + pub socket_addr: Option, +} + +impl TunnelPeer { + pub async fn select_run(self) { + let TunnelPeer { + mdns_sender, + channel_rx, + mut tcp, + socket_addr, + } = self; + + loop { + tokio::select! { + matched = channel_rx.recv() => { + match matched { + Ok(packet) => { + let bytes = Bytes::copy_from_slice(&packet); + if let Err(e) = tcp.send(bytes).await { + error!(?e, "tcp send err"); + break; + } + }, + Err(_) => break + } + }, + matched = tcp.next() => { + match matched { + Some(Ok(packet)) => { + let mut lock = mdns_sender.lock().await; + if let Some(Err(e)) = lock.send(&packet.to_vec()) { + error!(?e, "mdns sender send err"); + break; + } + }, + Some(Err(e)) => { + error!(?e, "read buf error!"); + break; + }, + None => break + } + }, + } + } + if let Some(addr) = socket_addr { + info!(?addr, "peer closed!"); + } else { + info!("peer closed!"); + } + } +}