Skip to content

Commit

Permalink
Fix memleak when dns clusters removed
Browse files Browse the repository at this point in the history
Signed-off-by: Zhonghu Xu <[email protected]>
  • Loading branch information
hzxuzhonghu committed Sep 19, 2024
1 parent d9f38b5 commit 0bf5b40
Show file tree
Hide file tree
Showing 3 changed files with 127 additions and 17 deletions.
7 changes: 2 additions & 5 deletions pkg/controller/ads/ads_processor.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,8 @@ func (p *processor) handleCdsResponse(resp *service_discovery_v3.DiscoveryRespon
log.Debugf("unchanged cluster %s", cluster.GetName())
}
}

if len(dnsClusters) > 0 {
// send dns clusters to dns resolver
p.DnsResolverChan <- dnsClusters
}
// send dns clusters to dns resolver, even dnsClusters is empty, we need to send empty list to dns resolver to clear the cache
p.DnsResolverChan <- dnsClusters
removed := p.Cache.ClusterCache.GetResourceNames().Difference(current)
for key := range removed {
p.Cache.UpdateApiClusterStatus(key, core_v2.ApiStatus_DELETE)
Expand Down
27 changes: 20 additions & 7 deletions pkg/dns/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@ var (
)

const (
MaxConcurrency uint32 = 5
RetryAfter = 5 * time.Millisecond
MaxConcurrency uint32 = 5
RetryAfter = 5 * time.Millisecond
DeRefreshInterval = 15 * time.Second
)

type DNSResolver struct {
Expand Down Expand Up @@ -232,6 +233,9 @@ func (r *DNSResolver) resolve(v *pendingResolveDomain) {
if ttl > v.refreshRate {
ttl = v.refreshRate
}
if ttl == 0 {
ttl = DeRefreshInterval
}
if !slices.Equal(entry.addresses, addrs) {
for _, c := range v.clusters {
ready := overwriteDnsCluster(c, v.domainName, addrs)
Expand Down Expand Up @@ -280,14 +284,23 @@ func (r *DNSResolver) refreshDNS() bool {
return true
}

func (r *DNSResolver) GetCacheResult(name string) []string {
var res []string
func (r *DNSResolver) GetDNSAddresses(domain string) []string {
r.RLock()
defer r.RUnlock()
if entry, ok := r.cache[name]; ok {
res = entry.addresses
if entry, ok := r.cache[domain]; ok {
return entry.addresses
}
return nil
}

func (r *DNSResolver) GetAllCachedDomains() []string {
r.RLock()
defer r.RUnlock()
out := make([]string, 0, len(r.cache))
for domain := range r.cache {
out = append(out, domain)
}
return res
return out
}

// doResolve is copied and adapted from github.com/istio/istio/pilot/pkg/model/network.go.
Expand Down
110 changes: 105 additions & 5 deletions pkg/dns/dns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"math/rand"
"net"
"reflect"
"slices"
"sync"
"testing"
"time"
Expand All @@ -31,7 +30,9 @@ import (
v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
endpointv3 "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/wrapperspb"
"istio.io/istio/pkg/slices"

core_v2 "kmesh.net/kmesh/api/v2/core"
"kmesh.net/kmesh/pkg/controller/ads"
Expand All @@ -48,13 +49,14 @@ type fakeDNSServer struct {
}

func TestDNS(t *testing.T) {
fakeDNSServer := newFakeDNSServer()
fakeDNSServer := NewFakeDNSServer()

testDNSResolver, err := NewDNSResolver(ads.NewAdsCache())
if err != nil {
t.Fatal(err)
}
stopCh := make(chan struct{})
defer close(stopCh)
testDNSResolver.StartDNSResolver(stopCh)
testDNSResolver.resolvConfServers = []string{fakeDNSServer.Server.PacketConn.LocalAddr().String()}

Expand Down Expand Up @@ -144,7 +146,7 @@ func TestDNS(t *testing.T) {

time.Sleep(2 * time.Second)

res := testDNSResolver.GetCacheResult(testcase.domain)
res := testDNSResolver.GetDNSAddresses(testcase.domain)
if len(res) != 0 || len(testcase.expected) != 0 {
if !reflect.DeepEqual(res, testcase.expected) {
t.Errorf("dns resolve for %s do not match. \n got %v\nwant %v", testcase.domain, res, testcase.expected)
Expand All @@ -153,7 +155,7 @@ func TestDNS(t *testing.T) {
if testcase.expectedAfterTTL != nil {
ttl := time.Duration(math.Min(float64(testcase.ttl), float64(testcase.refreshRate)))
time.Sleep(ttl + 1)
res = testDNSResolver.GetCacheResult(testcase.domain)
res = testDNSResolver.GetDNSAddresses(testcase.domain)
if !reflect.DeepEqual(res, testcase.expectedAfterTTL) {
t.Errorf("dns refresh after ttl failed, for %s do not match. \n got %v\nwant %v", testcase.domain, res, testcase.expectedAfterTTL)
}
Expand Down Expand Up @@ -254,7 +256,7 @@ func TestOverwriteDNSCluster(t *testing.T) {
}
}

func newFakeDNSServer() *fakeDNSServer {
func NewFakeDNSServer() *fakeDNSServer {
var wg sync.WaitGroup
wg.Add(1)
s := &fakeDNSServer{
Expand Down Expand Up @@ -421,3 +423,101 @@ func TestGetPendingResolveDomain(t *testing.T) {
})
}
}

func TestHandleCdsResponseWithDns(t *testing.T) {
cluster1 := &clusterv3.Cluster{
Name: "ut-cluster1",
ClusterDiscoveryType: &clusterv3.Cluster_Type{
Type: clusterv3.Cluster_LOGICAL_DNS,
},
LoadAssignment: &endpointv3.ClusterLoadAssignment{
Endpoints: []*endpointv3.LocalityLbEndpoints{
{
LbEndpoints: []*endpointv3.LbEndpoint{
{
HostIdentifier: &endpointv3.LbEndpoint_Endpoint{
Endpoint: &endpointv3.Endpoint{
Address: &v3.Address{
Address: &v3.Address_SocketAddress{
SocketAddress: &v3.SocketAddress{
Address: "foo.bar",
PortSpecifier: &v3.SocketAddress_PortValue{
PortValue: uint32(9898),
},
},
},
},
},
},
},
},
},
},
},
}
cluster2 := &clusterv3.Cluster{
Name: "ut-cluster2",
ClusterDiscoveryType: &clusterv3.Cluster_Type{
Type: clusterv3.Cluster_STRICT_DNS,
},
LoadAssignment: &endpointv3.ClusterLoadAssignment{
Endpoints: []*endpointv3.LocalityLbEndpoints{
{
LbEndpoints: []*endpointv3.LbEndpoint{
{
HostIdentifier: &endpointv3.LbEndpoint_Endpoint{
Endpoint: &endpointv3.Endpoint{
Address: &v3.Address{
Address: &v3.Address_SocketAddress{
SocketAddress: &v3.SocketAddress{
Address: "foo.baz",
PortSpecifier: &v3.SocketAddress_PortValue{
PortValue: uint32(9898),
},
},
},
},
},
},
},
},
},
},
},
}

testcases := []struct {
name string
clusters []*clusterv3.Cluster
expected []string
}{
{
name: "add clusters with DNS type",
clusters: []*clusterv3.Cluster{cluster1, cluster2},
expected: []string{"foo.bar", "foo.baz"},
},
{
name: "remove all DNS type clusters",
clusters: []*clusterv3.Cluster{},
expected: []string{},
},
}

p := ads.NewController().Processor
stopCh := make(chan struct{})
defer close(stopCh)
dnsResolver, err := NewDNSResolver(ads.NewAdsCache())
assert.NoError(t, err)
dnsResolver.StartDNSResolver(stopCh)
p.DnsResolverChan = dnsResolver.DnsResolverChan
for _, tc := range testcases {
t.Run(tc.name, func(t *testing.T) {
// notify dns resolver
dnsResolver.DnsResolverChan <- tc.clusters
time.Sleep(1 * time.Second)
if !slices.EqualUnordered(tc.expected, dnsResolver.GetAllCachedDomains()) {
t.Errorf("expected domain %v, but found %v", tc.expected, dnsResolver.GetAllCachedDomains())
}
})
}
}

0 comments on commit 0bf5b40

Please sign in to comment.