Skip to content

Commit

Permalink
mdns tunneller: use to forward mdns packet
Browse files Browse the repository at this point in the history
  • Loading branch information
Spxg committed May 23, 2022
0 parents commit 0ea2ccd
Show file tree
Hide file tree
Showing 6 changed files with 352 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
/target
/Cargo.lock
21 changes: 21 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
[package]
name = "mdns-tunneller"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
dns-parser = "0.8.0"
net2 = "0.2"
tokio = { version = "1.18.2", features = ["full"] }
anyhow = "1.0.57"
pnet = "0.30.0"
bytes = "1"
tokio-util = { version = "0.7", features = ["codec"] }
tokio-stream = "0.1.8"
futures = "0.3.21"
clap = { version = "3", features = ["derive"] }
async-channel = "1.6.1"
tracing = "0.1"
tracing-subscriber = "0.2"
7 changes: 7 additions & 0 deletions src/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
pub fn get_filter_domains() -> Vec<String> {
vec![
"_homekit._tcp.local".into(),
"_hap._tcp.local".into(),
"_googlecast._tcp.local".into(),
]
}
114 changes: 114 additions & 0 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
pub mod config;
pub mod mdns;
pub mod tunnel;

use anyhow::Result;
use async_channel::Receiver;
use clap::Parser;
use pnet::datalink::{self, NetworkInterface};
use std::sync::Arc;
use std::thread;
use tokio::{net::TcpListener, sync::Mutex};
use tokio::{
net::TcpStream,
sync::mpsc::{self, UnboundedReceiver},
};
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use tracing::{info, Level};
use tunnel::TunnelPeer;

use crate::config::get_filter_domains;

#[derive(Parser)]
enum Args {
Server {
#[clap(short, long)]
addr: String,
#[clap(short, long)]
interface: String,
},
Client {
#[clap(short, long)]
addr: String,
#[clap(short, long)]
interface: String,
},
}

#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::fmt().with_max_level(Level::INFO).init();

let args = Args::parse();
let (is_client, addr, iface_name) = match args {
Args::Server { addr, interface } => (false, addr, interface),
Args::Client { addr, interface } => (true, addr, interface),
};
info!(?is_client, ?addr, ?iface_name);

let interface_names_match = |iface: &NetworkInterface| iface.name == iface_name;

// Find the network interface with the provided name
let interfaces = datalink::interfaces();
let interface = interfaces
.into_iter()
.filter(interface_names_match)
.next()
.unwrap_or_else(|| panic!("No such network interface: {}", iface_name));

let (channel_tx, channel_rx) = mpsc::unbounded_channel();
let (mdns_sender, mut mdns_listener) = mdns::pair(&interface, channel_tx, get_filter_domains());

let mdns_sender = Arc::new(Mutex::new(mdns_sender));
let channel_rx = forward(channel_rx);

if is_client {
let tcp = TcpStream::connect(&addr).await?;
info!("connected");

thread::spawn(move || mdns_listener.listen());
let tunnel = TunnelPeer {
mdns_sender,
channel_rx,
tcp: Framed::new(tcp, LengthDelimitedCodec::new()),
socket_addr: None,
};
tunnel.select_run().await;
} else {
let listener = TcpListener::bind(&addr).await?;
info!("start listening");

thread::spawn(move || mdns_listener.listen());

while let Ok((con, addr)) = listener.accept().await {
info!(?addr, "connected");

let mdns_sender = mdns_sender.clone();
let channel_rx = channel_rx.clone();

tokio::spawn(async move {
let tunnel = TunnelPeer {
mdns_sender,
channel_rx,
tcp: Framed::new(con, LengthDelimitedCodec::new()),
socket_addr: Some(addr),
};
tunnel.select_run().await;
});
}
}

Ok(())
}

fn forward(mut sc_rx: UnboundedReceiver<Vec<u8>>) -> Receiver<Vec<u8>> {
let (tx, rx) = async_channel::unbounded();
tokio::spawn(async move {
while let Some(packet) = sc_rx.recv().await {
if tx.send(packet).await.is_err() {
break;
}
}
});
rx
}
141 changes: 141 additions & 0 deletions src/mdns.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
use dns_parser::Packet as mDNSPacket;
use pnet::datalink::{DataLinkReceiver, DataLinkSender};
use pnet::{
datalink::{self, Channel::Ethernet, NetworkInterface},
packet::{
ethernet::{EtherTypes, EthernetPacket},
ip::IpNextHeaderProtocols,
ipv4::Ipv4Packet,
udp::UdpPacket,
Packet,
},
};
use std::io;
use tokio::sync::mpsc::UnboundedSender;
use tracing::info;

/// An mDNS listener on a specific interface.
#[allow(non_camel_case_types)]
pub struct mDNSListener {
pub eth_rx: Box<dyn DataLinkReceiver>,
// `EthernetPacket` with `mDNS`
pub channel_tx: UnboundedSender<Vec<u8>>,
pub filter_domains: Vec<String>,
}

impl mDNSListener {
/// Listen mDNS packet, than send `EthernetPacket` to channel
pub fn listen(&mut self) {
// mDNSPacket<'a>
let mut mdns_buf = Vec::new();

while let Ok(packet) = self.eth_rx.next() {
if let Some(eth) = EthernetPacket::new(packet) {
if let Some(mdns) = mdns_packet(&eth, &mut mdns_buf) {
if filter_packet(&mdns, &self.filter_domains) {
if self.channel_tx.send(packet.to_vec()).is_err() {
break;
}
}
}
};
}
}
}

/// An mDNS Sender on a specific interface.
#[allow(non_camel_case_types)]
pub struct mDNSSender {
pub eth_tx: Box<dyn DataLinkSender>,
}

impl mDNSSender {
/// packet is a `EthernetPacket` with `mDNS`
pub fn send(&mut self, packet: &[u8]) -> Option<Result<(), io::Error>> {
self.eth_tx.send_to(packet, None)
}
}

pub fn pair(
interface: &NetworkInterface,
channel_tx: UnboundedSender<Vec<u8>>,
filter_domains: Vec<String>,
) -> (mDNSSender, mDNSListener) {
// Create a channel to receive on
let (tx, rx) = match datalink::channel(&interface, Default::default()) {
Ok(Ethernet(tx, rx)) => (tx, rx),
Ok(_) => panic!("unhandled channel type"),
Err(e) => panic!("unable to create channel: {}", e),
};
(
mDNSSender { eth_tx: tx },
mDNSListener {
eth_rx: rx,
channel_tx,
filter_domains,
},
)
}

/// get multicast dns packet
fn mdns_packet<'a>(ethernet: &EthernetPacket, buf: &'a mut Vec<u8>) -> Option<mDNSPacket<'a>> {
fn ipv4_packet(payload: &[u8]) -> Option<Ipv4Packet> {
let packet = Ipv4Packet::new(payload)?;
if !packet.get_destination().is_multicast()
|| !matches!(packet.get_next_level_protocol(), IpNextHeaderProtocols::Udp)
{
return None;
}
Some(packet)
}

fn udp_packet(payload: &[u8]) -> Option<UdpPacket> {
UdpPacket::new(payload)
}

match ethernet.get_ethertype() {
EtherTypes::Ipv4 => {
let ipv4_packet = ipv4_packet(ethernet.payload())?;
let udp_packet = udp_packet(ipv4_packet.payload())?;
*buf = udp_packet.payload().to_vec();
mDNSPacket::parse(buf).ok()
}
_ => None,
}
}

fn filter_packet(packet: &mDNSPacket, domains: &Vec<String>) -> bool {
let question_matched = packet
.questions
.iter()
.filter(|record| {
let record_name = record.qname.to_string();
let matched = domains.contains(&record_name);
if matched {
info!("found quary packet, domain: {}", record_name);
}
matched
})
.count()
> 0;

let answer_matched = packet
.answers
.iter()
.filter(|record| {
let record_name = record.name.to_string();
let matched = domains.contains(&record_name);

if matched {
info!(
"found answser packet, domain: {} at: {:?}",
record_name, &record.data
);
}
matched
})
.count()
> 0;

question_matched || answer_matched
}
67 changes: 67 additions & 0 deletions src/tunnel.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
use crate::mdns::mDNSSender;
use async_channel::Receiver;
use bytes::Bytes;
use futures::SinkExt;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpStream;
use tokio::sync::Mutex;
use tokio_stream::StreamExt;
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use tracing::{error, info};

pub struct TunnelPeer {
pub mdns_sender: Arc<Mutex<mDNSSender>>,
pub channel_rx: Receiver<Vec<u8>>,
pub tcp: Framed<TcpStream, LengthDelimitedCodec>,
pub socket_addr: Option<SocketAddr>,
}

impl TunnelPeer {
pub async fn select_run(self) {
let TunnelPeer {
mdns_sender,
channel_rx,
mut tcp,
socket_addr,
} = self;

loop {
tokio::select! {
matched = channel_rx.recv() => {
match matched {
Ok(packet) => {
let bytes = Bytes::copy_from_slice(&packet);
if let Err(e) = tcp.send(bytes).await {
error!(?e, "tcp send err");
break;
}
},
Err(_) => break
}
},
matched = tcp.next() => {
match matched {
Some(Ok(packet)) => {
let mut lock = mdns_sender.lock().await;
if let Some(Err(e)) = lock.send(&packet.to_vec()) {
error!(?e, "mdns sender send err");
break;
}
},
Some(Err(e)) => {
error!(?e, "read buf error!");
break;
},
None => break
}
},
}
}
if let Some(addr) = socket_addr {
info!(?addr, "peer closed!");
} else {
info!("peer closed!");
}
}
}

0 comments on commit 0ea2ccd

Please sign in to comment.