Skip to content

Commit

Permalink
🐛 Fix max_ip_reply_num when CNAME exists
Browse files Browse the repository at this point in the history
  • Loading branch information
mokeyish authored Jan 7, 2024
1 parent 01faf2c commit a7aa39c
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 15 deletions.
11 changes: 3 additions & 8 deletions src/dns_mw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,26 +241,21 @@ mod tests {
}

pub fn with_record(self, record: Record) -> Self {
self.with_multi_records(record.name().clone(), vec![record])
self.with_multi_records(record.name().clone(), record.record_type(), vec![record])
}

pub fn with_multi_records<Name: IntoName + Debug>(
mut self,
name: Name,
record_type: RecordType,
records: Vec<Record>,
) -> Self {
let name = match name.into_name() {
Ok(name) => name,
Err(err) => panic!("invalid Name {}", err),
};

let query = Query::query(
name,
records
.first()
.expect("must at least one record")
.record_type(),
);
let query = Query::query(name, record_type);

self.map.insert(
query.clone(),
Expand Down
84 changes: 77 additions & 7 deletions src/dns_mw_addr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,27 @@ impl Middleware<DnsContext, DnsRequest, DnsResponse, DnsError> for AddressMiddle
Ok(lookup) => Ok({
let mut records = Cow::Borrowed(lookup.records());

if let Some(max_reply_ip_num) = ctx.cfg().max_reply_ip_num() {
let max_reply_ip_num = max_reply_ip_num as usize;
if max_reply_ip_num > 0 && records.len() > max_reply_ip_num {
records.to_mut().truncate(max_reply_ip_num);
if query_type.is_ip_addr() {
if let Some(mut max_reply_ip_num) = ctx.cfg().max_reply_ip_num() {
if max_reply_ip_num > 0 {
let mut truncate = None;
for (i, r) in records.iter().enumerate() {
if matches!(r.data(), Some(RData::A(_) | RData::AAAA(_))) {
max_reply_ip_num -= 1;
if max_reply_ip_num == 0 {
truncate = Some(i + 1);
break;
}
}
}

match truncate {
Some(truncate) if records.len() > truncate => {
records.to_mut().truncate(truncate);
}
_ => (),
}
}
}
}

Expand Down Expand Up @@ -151,6 +168,7 @@ mod tests {
use crate::{
dns_conf::{DomainAddress, RuntimeConfig},
dns_mw::*,
libdns::proto::rr::rdata,
};

#[tokio::test(flavor = "multi_thread")]
Expand Down Expand Up @@ -229,6 +247,7 @@ mod tests {
let mock = DnsMockMiddleware::mock(AddressMiddleware)
.with_multi_records(
"dns.google",
RecordType::A,
vec![
Record::from_rdata(
"dns.google".parse().unwrap(),
Expand Down Expand Up @@ -259,6 +278,7 @@ mod tests {
let mock = DnsMockMiddleware::mock(AddressMiddleware)
.with_multi_records(
"dns.google",
RecordType::A,
vec![
Record::from_rdata(
"dns.google".parse().unwrap(),
Expand Down Expand Up @@ -292,6 +312,7 @@ mod tests {
let mock = DnsMockMiddleware::mock(AddressMiddleware)
.with_multi_records(
"dns.google",
RecordType::A,
vec![
Record::from_rdata(
"dns.google".parse().unwrap(),
Expand Down Expand Up @@ -326,6 +347,7 @@ mod tests {
let mock = DnsMockMiddleware::mock(AddressMiddleware)
.with_multi_records(
"dns.google",
RecordType::A,
vec![
Record::from_rdata(
"dns.google".parse().unwrap(),
Expand Down Expand Up @@ -360,6 +382,7 @@ mod tests {
let mock = DnsMockMiddleware::mock(AddressMiddleware)
.with_multi_records(
"dns.google",
RecordType::A,
vec![
Record::from_rdata(
"dns.google".parse().unwrap(),
Expand Down Expand Up @@ -393,12 +416,13 @@ mod tests {
.with("rr-ttl-max 66")
.with("rr-ttl-min 55")
.with("rr-ttl-reply-max 30")
.with("max-reply-ip-num 2")
.with("max-reply-ip-num 1")
.build();

let mock = DnsMockMiddleware::mock(AddressMiddleware)
.with_multi_records(
"dns.google",
RecordType::A,
vec![Record::from_rdata(
"dns.google".parse().unwrap(),
96,
Expand All @@ -420,12 +444,13 @@ mod tests {
.with("rr-ttl-max 66")
.with("rr-ttl-min 55")
.with("rr-ttl-reply-max 30")
.with("max-reply-ip-num 0")
.with("max-reply-ip-num 2")
.build();

let mock = DnsMockMiddleware::mock(AddressMiddleware)
.with_multi_records(
"dns.google",
RecordType::A,
vec![
Record::from_rdata(
"dns.google".parse().unwrap(),
Expand All @@ -448,7 +473,52 @@ mod tests {

let lookup = mock.lookup("dns.google", RecordType::A).await?;

assert_eq!(lookup.records().len(), 3);
assert_eq!(lookup.records().len(), 2);

Ok(())
}

#[tokio::test(flavor = "multi_thread")]
async fn test_ttl_clip_ttl_cname_max_reply_ip_num_2() -> Result<(), DnsError> {
let cfg = RuntimeConfig::builder()
.with("rr-ttl-max 66")
.with("rr-ttl-min 55")
.with("rr-ttl-reply-max 30")
.with("max-reply-ip-num 2")
.build();

let mock = DnsMockMiddleware::mock(AddressMiddleware)
.with_multi_records(
"dns.google",
RecordType::A,
vec![
Record::from_rdata(
"dns.google".parse().unwrap(),
96,
RData::CNAME(rdata::CNAME("dns.google".parse::<Name>().unwrap())),
),
Record::from_rdata(
"dns.google".parse().unwrap(),
96,
RData::A("8.8.8.8".parse().unwrap()),
),
Record::from_rdata(
"dns.google".parse().unwrap(),
48,
RData::A("8.8.4.4".parse().unwrap()),
),
],
)
.build(cfg);

let lookup = mock.lookup("dns.google", RecordType::A).await?;

let ip_count: u8 = lookup
.record_iter()
.map(|r| matches!(r.data(), Some(RData::A(_) | RData::AAAA(_))) as u8)
.sum();

assert_eq!(ip_count, 2);

Ok(())
}
Expand Down

0 comments on commit a7aa39c

Please sign in to comment.