Skip to content

Commit

Permalink
[dns] add Name::ComapreMultipleLabels() and update Matches() (ope…
Browse files Browse the repository at this point in the history
…nthread#9744)

This commit introduces a new method, `Name::CompareMultipleLabels()`,
to efficiently parse and compare multiple DNS name labels directly
from a message. This is then used to optimize `Name::Matches()`
eliminating the need to read the entire name into a separate buffer.

This commit also updates `Name::Matches()` method to treat the
first label as a single label allowing it to include dot character
(which is useful for service instance label).

Additionally, `test_dns` unit test is updated to validate the
functionality of the new methods.
  • Loading branch information
abtink authored Dec 29, 2023
1 parent 3f99b11 commit 666a9bd
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 48 deletions.
62 changes: 42 additions & 20 deletions src/core/net/dns_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,40 +100,41 @@ Error Header::ResponseCodeToError(Response aResponse)
return error;
}

bool Name::Matches(const char *aFirstLabels, const char *aSecondLabels, const char *aDomain) const
bool Name::Matches(const char *aFirstLabel, const char *aLabels, const char *aDomain) const
{
bool matches = false;
const char *namePtr;
Buffer nameBuffer;
bool matches = false;

VerifyOrExit(!IsEmpty());

if (IsFromCString())
{
namePtr = mString;
const char *namePtr = mString;

if (aFirstLabel != nullptr)
{
matches = CompareAndSkipLabels(namePtr, aFirstLabel, kLabelSeparatorChar);
VerifyOrExit(matches);
}

matches = CompareAndSkipLabels(namePtr, aLabels, kLabelSeparatorChar);
VerifyOrExit(matches);

matches = CompareAndSkipLabels(namePtr, aDomain, kNullChar);
}
else
{
uint16_t offset = mOffset;

SuccessOrExit(ReadName(*mMessage, offset, nameBuffer));
namePtr = nameBuffer;
}

if (aFirstLabels != nullptr)
{
matches = CompareAndSkipLabels(namePtr, aFirstLabels, kLabelSeparatorChar);
VerifyOrExit(matches);
}
if (aFirstLabel != nullptr)
{
SuccessOrExit(CompareLabel(*mMessage, offset, aFirstLabel));
}

if (aSecondLabels != nullptr)
{
matches = CompareAndSkipLabels(namePtr, aSecondLabels, kLabelSeparatorChar);
VerifyOrExit(matches);
SuccessOrExit(CompareMultipleLabels(*mMessage, offset, aLabels));
SuccessOrExit(CompareName(*mMessage, offset, aDomain));
matches = true;
}

matches = CompareAndSkipLabels(namePtr, aDomain, kNullChar);

exit:
return matches;
}
Expand Down Expand Up @@ -431,6 +432,27 @@ Error Name::CompareLabel(const Message &aMessage, uint16_t &aOffset, const char
return error;
}

Error Name::CompareMultipleLabels(const Message &aMessage, uint16_t &aOffset, const char *aLabels)
{
Error error;
LabelIterator iterator(aMessage, aOffset);

while (true)
{
SuccessOrExit(error = iterator.GetNextLabel());
VerifyOrExit(iterator.CompareLabel(aLabels, !kIsSingleLabel), error = kErrorNotFound);

if (*aLabels == kNullChar)
{
aOffset = iterator.mNextLabelOffset;
ExitNow();
}
}

exit:
return error;
}

Error Name::CompareName(const Message &aMessage, uint16_t &aOffset, const char *aName)
{
Error error;
Expand Down
46 changes: 38 additions & 8 deletions src/core/net/dns_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -666,20 +666,27 @@ class Name : public Clearable<Name>
* Matches the `Name` with a given set of labels and domain name.
*
* This method allows the caller to specify name components separately, enabling scenarios like comparing "service
* instance name" with separate instance label, service type, and domain strings.
* instance name" with separate instance label (which can include dot character), service type, and domain strings.
*
* @p aFirstLabels or @p aSecondLabels can be `nullptr` if not needed. But if non-null, these strings MUST NOT
* end with dot. @p aDomain MUST NOT be `nullptr` and MUST always end with a dot `.` character.
* @p aFirstLabel can be `nullptr` if not needed. But if non-null, it is treated as a single label and can itself
* include dot `.` character.
*
* @param[in] aFirstLabels A string of dot separated labels, MUST NOT end with dot. Can be `nullptr`.
* @param[in] aSecondLabels A string of dot separated labels, MUST NOT end with dot. Can be `nullptr`.
* The @p aLabels MUST NOT be `nullptr` and MUST follow "<label1>.<label2>.<label3>", i.e., a sequence of one or
* more labels separated by dot '.' char, and it MUST NOT end with dot `.`.
*
* @p aDomain MUST NOT be `nullptr` and MUST have at least one label and MUST always end with a dot `.` character.
*
* If the above conditions are not satisfied, the behavior of this method is undefined.
*
* @param[in] aFirstLabel A first label to check. Can be `nullptr`.
* @param[in] aLabels A string of dot separated labels, MUST NOT end with dot.
* @param[in] aDomain Domain name. MUST end with dot.
*
* @retval TRUE The name matches the given labels.
* @retval FALSE The name does not match the given labels.
* @retval TRUE The name matches the given components.
* @retval FALSE The name does not match the given components.
*
*/
bool Matches(const char *aFirstLabels, const char *aSecondLabels, const char *aDomain) const;
bool Matches(const char *aFirstLabel, const char *aLabels, const char *aDomain) const;

/**
* Encodes and appends the name to a message.
Expand Down Expand Up @@ -909,6 +916,29 @@ class Name : public Clearable<Name>
*/
static Error CompareLabel(const Message &aMessage, uint16_t &aOffset, const char *aLabel);

/**
* Parses and compares multiple name labels from a message.
*
* Can be used to read and compare a group of labels from an encoded DNS name in a message with possibly more
* labels remaining to read.
*
* The @p aLabels must follow "<label1>.<label2>.<label3>", i.e., a sequence of labels separated by dot '.' char.
*
* @param[in] aMessage The message to read the labels from to compare. `aMessage.GetOffset()` MUST point
* to the start of DNS header (this is used to handle compressed names).
* @param[in,out] aOffset On input, the offset in @p aMessage pointing to the start of the labels to read.
* On exit and only when all labels are successfully read and match @p aLabels,
* @p aOffset is updated to point to the start of the next label.
* @param[in] aLabels A pointer to a null terminated string containing the labels to compare with.
*
* @retval kErrorNone The labels from @p aMessage matches @p aLabels. @p aOffset is updated.
* @retval kErrorNotFound The labels from @p aMessage does not match @p aLabel (note that @p aOffset is not
* updated in this case).
* @retval kErrorParse Name could not be parsed (invalid format).
*
*/
static Error CompareMultipleLabels(const Message &aMessage, uint16_t &aOffset, const char *aLabels);

/**
* Parses and compares a full name from a message with a given name.
*
Expand Down
144 changes: 124 additions & 20 deletions tests/unit/test_dns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ void TestDnsName(void)
struct TestMatches
{
const char *mFullName;
const char *mFirstLabels;
const char *mSecondLabels;
const char *mFirstLabel;
const char *mLabels;
const char *mDomain;
bool mShouldMatch;
};
Expand All @@ -79,6 +79,7 @@ void TestDnsName(void)
const char *domain2;
const char *fullName;
const char *suffixName;
Dns::Name dnsName;

static const uint8_t kEncodedName1[] = {7, 'e', 'x', 'a', 'm', 'p', 'l', 'e', 3, 'c', 'o', 'm', 0};
static const uint8_t kEncodedName2[] = {3, 'f', 'o', 'o', 1, 'a', 2, 'b', 'b', 3, 'e', 'd', 'u', 0};
Expand Down Expand Up @@ -150,21 +151,16 @@ void TestDnsName(void)

static const TestMatches kTestMatches[] = {
{"foo.bar.local.", "foo", "bar", "local.", true},
{"foo.bar.local.", "foo.bar", nullptr, "local.", true},
{"foo.bar.local.", nullptr, "foo.bar", "local.", true},
{"foo.bar.local.", nullptr, nullptr, "foo.bar.local.", true},
{"foo.bar.local.", "foo", "ba", "local.", false},
{"foo.bar.local.", "fooooo", "bar", "local.", false},
{"foo.bar.local.", "foo", "bar", "locall.", false},
{"foo.bar.local.", "f", "bar", "local.", false},
{"foo.bar.local.", "foo", "barr", "local.", false},
{"foo.bar.local.", "foo", "bar", ".local.", false},
{"My Lovely Instance._mt._udp.local.", "mY lovely instancE", "_mt._udp", "local.", true},
{"My Lovely Instance._mt._udp.local.", "mY lovely instancE._mt", "_udp", "local.", true},
{"_s1._sub._srv._udp.default.service.arpa.", "_s1._sub", "_srv._udp", "default.service.arpa.", true},
{"_s1._sub._srv._udp.default.service.arpa.", "_s1._sub", "_srv._udp", "default.service.arpa", false},
{"_s1._sub._srv._udp.default.service.arpa.", "_s1._sub", "_srv._udp.", "default.service.arpa.", false},
{"_s1._sub._srv._udp.default.service.arpa.", "_s1._sub.", "_srv._udp", "default.service.arpa.", false},
{"My Lovely Instance._mt._udp.local.", nullptr, "mY lovely instancE._mt._udp", "local.", true},
{"_s1._sub._srv._udp.default.service.arpa.", "_s1", "_sub._srv._udp", "default.service.arpa.", true},
};

printf("================================================================\n");
Expand Down Expand Up @@ -555,20 +551,48 @@ void TestDnsName(void)

for (const TestMatches &test : kTestMatches)
{
Dns::Name name;

printf(" \"%s\"\n", test.mFullName);

name.Set(test.mFullName);
VerifyOrQuit(name.Matches(test.mFirstLabels, test.mSecondLabels, test.mDomain) == test.mShouldMatch);
dnsName.Set(test.mFullName);
VerifyOrQuit(dnsName.Matches(test.mFirstLabel, test.mLabels, test.mDomain) == test.mShouldMatch);

IgnoreError(message->SetLength(0));
SuccessOrQuit(name.AppendTo(*message));
SuccessOrQuit(dnsName.AppendTo(*message));

name.SetFromMessage(*message, 0);
VerifyOrQuit(name.Matches(test.mFirstLabels, test.mSecondLabels, test.mDomain) == test.mShouldMatch);
dnsName.SetFromMessage(*message, 0);
VerifyOrQuit(dnsName.Matches(test.mFirstLabel, test.mLabels, test.mDomain) == test.mShouldMatch);
}

IgnoreError(message->SetLength(0));
dnsName.SetFromMessage(*message, 0);
SuccessOrQuit(Dns::Name::AppendLabel("Name.With.Dot", *message));
SuccessOrQuit(Dns::Name::AppendName("_srv._udp.local.", *message));

VerifyOrQuit(dnsName.Matches("Name.With.Dot", "_srv._udp", "local."));
VerifyOrQuit(dnsName.Matches("nAme.with.dOT", "_srv._udp", "local."));
VerifyOrQuit(dnsName.Matches("Name.With.Dot", "_srv", "_udp.local."));

VerifyOrQuit(!dnsName.Matches("Name", "With.Dot._srv._udp", "local."));
VerifyOrQuit(!dnsName.Matches("Name.", "With.Dot._srv._udp", "local."));
VerifyOrQuit(!dnsName.Matches("Name.With", "Dot._srv._udp", "local."));

VerifyOrQuit(!dnsName.Matches("Name.With.Dott", "_srv._udp", "local."));
VerifyOrQuit(!dnsName.Matches("Name.With.Dot.", "_srv._udp", "local."));
VerifyOrQuit(!dnsName.Matches("Name.With.Dot", "_srv._tcp", "local."));
VerifyOrQuit(!dnsName.Matches("Name.With.Dot", "_srv._udp", "arpa."));

offset = 0;
SuccessOrQuit(Dns::Name::ReadName(*message, offset, name));
dnsName.Set(name);

VerifyOrQuit(dnsName.Matches("Name.With.Dot", "_srv._udp", "local."));
VerifyOrQuit(dnsName.Matches("nAme.with.dOT", "_srv._udp", "local."));
VerifyOrQuit(dnsName.Matches("Name.With.Dot", "_srv", "_udp.local."));
VerifyOrQuit(!dnsName.Matches("Name.With.Dott", "_srv._udp", "local."));
VerifyOrQuit(!dnsName.Matches("Name.With.Dot.", "_srv._udp", "local."));
VerifyOrQuit(!dnsName.Matches("Name.With.Dot", "_srv._tcp", "local."));
VerifyOrQuit(!dnsName.Matches("Name.With.Dot", "_srv._udp", "arpa."));

message->Free();
testFreeInstance(instance);
}
Expand Down Expand Up @@ -601,6 +625,13 @@ void TestDnsCompressedName(void)
static const char *kName3Labels[] = {"ISI", "ARPA"};
static const char *kName4Labels[] = {"Human.Readable", "F", "ISI", "ARPA"};

static const char *kName1MultiLabels[] = {"F.ISI", "ARPA"};
static const char *kName2MultiLabels1[] = {"FOO", "F.ISI.ARPA."};
static const char *kName2MultiLabels2[] = {"FOO.F.", "ISI.ARPA."};

static const char kName1BadMultiLabels[] = "F.ISI.ARPA.MORE";
static const char kName2BadMultiLabels[] = "FOO.F.IS";

static const char kExpectedReadName1[] = "F.ISI.ARPA.";
static const char kExpectedReadName2[] = "FOO.F.ISI.ARPA.";
static const char kExpectedReadName3[] = "ISI.ARPA.";
Expand Down Expand Up @@ -712,11 +743,35 @@ void TestDnsCompressedName(void)
VerifyOrQuit(offset == name1Offset + sizeof(kEncodedName), "Name::ReadName() returned incorrect offset");

offset = name1Offset;

for (const char *nameLabel : kName1Labels)
{
SuccessOrQuit(Dns::Name::CompareLabel(*message, offset, nameLabel));
}
SuccessOrQuit(Dns::Name::CompareName(*message, offset, "."));

offset = name1Offset;
for (const char *nameLabel : kName1Labels)
{
SuccessOrQuit(Dns::Name::CompareMultipleLabels(*message, offset, nameLabel));
}
SuccessOrQuit(Dns::Name::CompareName(*message, offset, "."));

offset = name1Offset;
SuccessOrQuit(Dns::Name::CompareMultipleLabels(*message, offset, kExpectedReadName1));
SuccessOrQuit(Dns::Name::CompareName(*message, offset, "."));

offset = name1Offset;
VerifyOrQuit(Dns::Name::CompareMultipleLabels(*message, offset, kBadName) == kErrorNotFound);

offset = name1Offset;
VerifyOrQuit(Dns::Name::CompareMultipleLabels(*message, offset, kName1BadMultiLabels) == kErrorNotFound);

offset = name1Offset;
for (const char *nameLabels : kName1MultiLabels)
{
SuccessOrQuit(Dns::Name::CompareMultipleLabels(*message, offset, nameLabels));
}
SuccessOrQuit(Dns::Name::CompareName(*message, offset, "."));

offset = name1Offset;
SuccessOrQuit(Dns::Name::CompareName(*message, offset, kExpectedReadName1));
Expand Down Expand Up @@ -769,11 +824,42 @@ void TestDnsCompressedName(void)
VerifyOrQuit(offset == name2Offset + kName2EncodedSize, "Name::ReadName() returned incorrect offset");

offset = name2Offset;

for (const char *nameLabel : kName2Labels)
{
SuccessOrQuit(Dns::Name::CompareLabel(*message, offset, nameLabel));
}
SuccessOrQuit(Dns::Name::CompareName(*message, offset, "."));

offset = name2Offset;
for (const char *nameLabel : kName2Labels)
{
SuccessOrQuit(Dns::Name::CompareMultipleLabels(*message, offset, nameLabel));
}
SuccessOrQuit(Dns::Name::CompareName(*message, offset, "."));

offset = name2Offset;
SuccessOrQuit(Dns::Name::CompareMultipleLabels(*message, offset, kExpectedReadName2));
SuccessOrQuit(Dns::Name::CompareName(*message, offset, "."));

offset = name2Offset;
VerifyOrQuit(Dns::Name::CompareMultipleLabels(*message, offset, kBadName) == kErrorNotFound);

offset = name2Offset;
VerifyOrQuit(Dns::Name::CompareMultipleLabels(*message, offset, kName2BadMultiLabels) == kErrorNotFound);

offset = name2Offset;
for (const char *nameLabels : kName2MultiLabels1)
{
SuccessOrQuit(Dns::Name::CompareMultipleLabels(*message, offset, nameLabels));
}
SuccessOrQuit(Dns::Name::CompareName(*message, offset, "."));

offset = name2Offset;
for (const char *nameLabels : kName2MultiLabels2)
{
SuccessOrQuit(Dns::Name::CompareMultipleLabels(*message, offset, nameLabels));
}
SuccessOrQuit(Dns::Name::CompareName(*message, offset, "."));

offset = name2Offset;
SuccessOrQuit(Dns::Name::CompareName(*message, offset, kExpectedReadName2));
Expand Down Expand Up @@ -826,11 +912,22 @@ void TestDnsCompressedName(void)
VerifyOrQuit(offset == name3Offset + kName3EncodedSize, "Name::ReadName() returned incorrect offset");

offset = name3Offset;

for (const char *nameLabel : kName3Labels)
{
SuccessOrQuit(Dns::Name::CompareLabel(*message, offset, nameLabel));
}
SuccessOrQuit(Dns::Name::CompareName(*message, offset, "."));

offset = name3Offset;
for (const char *nameLabel : kName3Labels)
{
SuccessOrQuit(Dns::Name::CompareMultipleLabels(*message, offset, nameLabel));
}
SuccessOrQuit(Dns::Name::CompareName(*message, offset, "."));

offset = name3Offset;
SuccessOrQuit(Dns::Name::CompareMultipleLabels(*message, offset, kExpectedReadName3));
SuccessOrQuit(Dns::Name::CompareName(*message, offset, "."));

offset = name3Offset;
SuccessOrQuit(Dns::Name::CompareName(*message, offset, kExpectedReadName3));
Expand Down Expand Up @@ -880,11 +977,18 @@ void TestDnsCompressedName(void)
VerifyOrQuit(offset == name4Offset + kName4EncodedSize, "Name::ParseName() returned incorrect offset");

offset = name4Offset;

for (const char *nameLabel : kName4Labels)
{
SuccessOrQuit(Dns::Name::CompareLabel(*message, offset, nameLabel));
}
SuccessOrQuit(Dns::Name::CompareName(*message, offset, "."));

offset = name4Offset;
for (const char *nameLabel : kName4Labels)
{
SuccessOrQuit(Dns::Name::CompareMultipleLabels(*message, offset, nameLabel));
}
SuccessOrQuit(Dns::Name::CompareName(*message, offset, "."));

offset = name4Offset;
SuccessOrQuit(Dns::Name::CompareName(*message, offset, *message, offset), "Name::CompareName() with itself failed");
Expand Down

0 comments on commit 666a9bd

Please sign in to comment.