Skip to content

Commit

Permalink
lib/connections: Fix and optimize registry (syncthing#7996)
Browse files Browse the repository at this point in the history
Registry.Get used a full sort to get the minimum of a list, and the sort
was broken because util.AddressUnspecifiedLess assumed it could find out
whether an address is IPv4 or IPv6 from its Network method. However,
net.(TCP|UDP)Addr.Network always returns "tcp"/"udp".
  • Loading branch information
greatroar authored Oct 6, 2021
1 parent c94b797 commit 7c292cc
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 106 deletions.
3 changes: 2 additions & 1 deletion lib/connections/quic_dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ func (d *quicDialer) Dial(ctx context.Context, _ protocol.DeviceID, uri *url.URL
// Given we always pass the connection to quic, it assumes it's a remote connection it never closes it,
// So our wrapper around it needs to close it, but it only needs to close it if it's not the listening connection.
var createdConn net.PacketConn
if listenConn := registry.Get(uri.Scheme, packetConnLess); listenConn != nil {
listenConn := registry.Get(uri.Scheme, packetConnUnspecified)
if listenConn != nil {
conn = listenConn.(net.PacketConn)
} else {
if packetConn, err := net.ListenPacket("udp", ":0"); err != nil {
Expand Down
10 changes: 6 additions & 4 deletions lib/connections/quic_misc.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ import (
"net/url"

"github.com/lucas-clemente/quic-go"
"github.com/syncthing/syncthing/lib/util"
)

var (
Expand Down Expand Up @@ -63,7 +62,10 @@ func (q *quicTlsConn) ConnectionState() tls.ConnectionState {
return q.Session.ConnectionState().TLS.ConnectionState
}

// Sort available packet connections by ip address, preferring unspecified local address.
func packetConnLess(i interface{}, j interface{}) bool {
return util.AddressUnspecifiedLess(i.(net.PacketConn).LocalAddr(), j.(net.PacketConn).LocalAddr())
func packetConnUnspecified(conn interface{}) bool {
// Since QUIC connections are wrapped, we can't do a simple typecheck
// on *net.UDPAddr here.
addr := conn.(net.PacketConn).LocalAddr()
host, _, err := net.SplitHostPort(addr.String())
return err == nil && net.ParseIP(host).IsUnspecified()
}
44 changes: 27 additions & 17 deletions lib/connections/registry/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
package registry

import (
"sort"
"strings"

"github.com/syncthing/syncthing/lib/sync"
Expand Down Expand Up @@ -46,34 +45,45 @@ func (r *Registry) Unregister(scheme string, item interface{}) {
candidates := r.available[scheme]
for i, existingItem := range candidates {
if existingItem == item {
copy(candidates[i:], candidates[i+1:])
candidates[i] = candidates[len(candidates)-1]
candidates[len(candidates)-1] = nil
r.available[scheme] = candidates[:len(candidates)-1]
break
}
}
}

func (r *Registry) Get(scheme string, less func(i, j interface{}) bool) interface{} {
// Get returns an item for a schema compatible with the given scheme.
// If any item satisfies preferred, that has precedence over other items.
func (r *Registry) Get(scheme string, preferred func(interface{}) bool) interface{} {
r.mut.Lock()
defer r.mut.Unlock()

candidates := make([]interface{}, 0)
var (
best interface{}
bestPref bool
bestScheme string
)
for availableScheme, items := range r.available {
// quic:// should be considered ok for both quic4:// and quic6://
if strings.HasPrefix(scheme, availableScheme) {
candidates = append(candidates, items...)
if !strings.HasPrefix(scheme, availableScheme) {
continue
}
for _, item := range items {
better := best == nil
pref := preferred(item)
if !better {
// In case of a tie, prefer "quic" to "quic[46]" etc.
better = pref &&
(!bestPref || len(availableScheme) < len(bestScheme))
}
if !better {
continue
}
best, bestPref, bestScheme = item, pref, availableScheme
}
}

if len(candidates) == 0 {
return nil
}

sort.Slice(candidates, func(i, j int) bool {
return less(candidates[i], candidates[j])
})
return candidates[0]
return best
}

func Register(scheme string, item interface{}) {
Expand All @@ -84,6 +94,6 @@ func Unregister(scheme string, item interface{}) {
Default.Unregister(scheme, item)
}

func Get(scheme string, less func(i, j interface{}) bool) interface{} {
return Default.Get(scheme, less)
func Get(scheme string, preferred func(interface{}) bool) interface{} {
return Default.Get(scheme, preferred)
}
51 changes: 38 additions & 13 deletions lib/connections/registry/registry_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,18 @@
package registry

import (
"net"
"testing"
)

func TestRegistry(t *testing.T) {
r := New()

if res := r.Get("int", intLess); res != nil {
want := func(i int) func(interface{}) bool {
return func(x interface{}) bool { return x.(int) == i }
}

if res := r.Get("int", want(1)); res != nil {
t.Error("unexpected")
}

Expand All @@ -24,30 +29,28 @@ func TestRegistry(t *testing.T) {
r.Register("int6", 6)
r.Register("int6", 66)

if res := r.Get("int", intLess).(int); res != 1 {
if res := r.Get("int", want(1)).(int); res != 1 {
t.Error("unexpected", res)
}

// int is prefix of int4, so returns 1
if res := r.Get("int4", intLess).(int); res != 1 {
if res := r.Get("int4", want(1)).(int); res != 1 {
t.Error("unexpected", res)
}

r.Unregister("int", 1)

// Check that falls through to 11
if res := r.Get("int", intLess).(int); res != 11 {
if res := r.Get("int", want(1)).(int); res == 1 {
t.Error("unexpected", res)
}

// 6 is smaller than 11 available in int.
if res := r.Get("int6", intLess).(int); res != 6 {
if res := r.Get("int6", want(6)).(int); res != 6 {
t.Error("unexpected", res)
}

// Unregister 11, int should be impossible to find
r.Unregister("int", 11)
if res := r.Get("int", intLess); res != nil {
if res := r.Get("int", want(11)); res != nil {
t.Error("unexpected")
}

Expand All @@ -59,13 +62,35 @@ func TestRegistry(t *testing.T) {
r.Register("int", 1)
r.Unregister("int", 1)

if res := r.Get("int4", intLess).(int); res != 1 {
if res := r.Get("int4", want(1)).(int); res != 1 {
t.Error("unexpected", res)
}
}

func intLess(i, j interface{}) bool {
iInt := i.(int)
jInt := j.(int)
return iInt < jInt
func TestShortSchemeFirst(t *testing.T) {
r := New()
r.Register("foo", 0)
r.Register("foobar", 1)

// If we don't care about the value, we should get the one with "foo".
res := r.Get("foo", func(interface{}) bool { return false })
if res != 0 {
t.Error("unexpected", res)
}
}

func BenchmarkGet(b *testing.B) {
r := New()
for _, addr := range []string{"192.168.1.1", "172.1.1.1", "10.1.1.1"} {
r.Register("tcp", &net.TCPAddr{IP: net.ParseIP(addr)})
}

b.ReportAllocs()
b.ResetTimer()

for i := 0; i < b.N; i++ {
r.Get("tcp", func(x interface{}) bool {
return x.(*net.TCPAddr).IP.IsUnspecified()
})
}
}
6 changes: 0 additions & 6 deletions lib/dialer/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"os"
"time"

"github.com/syncthing/syncthing/lib/util"
"golang.org/x/net/proxy"
)

Expand Down Expand Up @@ -61,11 +60,6 @@ func socksDialerFunction(u *url.URL, forward proxy.Dialer) (proxy.Dialer, error)
return proxy.SOCKS5("tcp", u.Host, auth, forward)
}

// Sort available addresses, preferring unspecified address.
func tcpAddrLess(i interface{}, j interface{}) bool {
return util.AddressUnspecifiedLess(i.(*net.TCPAddr), j.(*net.TCPAddr))
}

// dialerConn is needed because proxy dialed connections have RemoteAddr() pointing at the proxy,
// which then screws up various things such as IsLAN checks, and "let's populate the relay invitation address from
// existing connection" shenanigans.
Expand Down
4 changes: 3 additions & 1 deletion lib/dialer/public.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,9 @@ func DialContextReusePort(ctx context.Context, network, addr string) (net.Conn,
return DialContext(ctx, network, addr)
}

localAddrInterface := registry.Get(network, tcpAddrLess)
localAddrInterface := registry.Get(network, func(addr interface{}) bool {
return addr.(*net.TCPAddr).IP.IsUnspecified()
})
if localAddrInterface == nil {
// Nothing listening, nothing to reuse.
return DialContext(ctx, network, addr)
Expand Down
20 changes: 0 additions & 20 deletions lib/util/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ package util
import (
"context"
"fmt"
"net"
"net/url"
"reflect"
"strconv"
Expand Down Expand Up @@ -231,25 +230,6 @@ func Address(network, host string) string {
return u.String()
}

// AddressUnspecifiedLess is a comparator function preferring least specific network address (most widely listening,
// namely preferring 0.0.0.0 over some IP), if both IPs are equal, it prefers the less restrictive network (prefers tcp
// over tcp4)
func AddressUnspecifiedLess(a, b net.Addr) bool {
aIsUnspecified := false
bIsUnspecified := false
if host, _, err := net.SplitHostPort(a.String()); err == nil {
aIsUnspecified = host == "" || net.ParseIP(host).IsUnspecified()
}
if host, _, err := net.SplitHostPort(b.String()); err == nil {
bIsUnspecified = host == "" || net.ParseIP(host).IsUnspecified()
}

if aIsUnspecified == bIsUnspecified {
return len(a.Network()) < len(b.Network())
}
return aIsUnspecified
}

func CallWithContext(ctx context.Context, fn func() error) error {
var err error
done := make(chan struct{})
Expand Down
44 changes: 0 additions & 44 deletions lib/util/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -225,50 +225,6 @@ func TestCopyMatching(t *testing.T) {
}
}

type mockedAddr struct {
network string
addr string
}

func (a mockedAddr) Network() string {
return a.network
}

func (a mockedAddr) String() string {
return a.addr
}

func TestInspecifiedAddressLess(t *testing.T) {
cases := []struct {
netA string
addrA string
netB string
addrB string
}{
// B is assumed the winner.
{"tcp", "127.0.0.1:1234", "tcp", ":1235"},
{"tcp", "127.0.0.1:1234", "tcp", "0.0.0.0:1235"},
{"tcp4", "0.0.0.0:1234", "tcp", "0.0.0.0:1235"}, // tcp4 on the first one
}

for i, testCase := range cases {
addrs := []mockedAddr{
{testCase.netA, testCase.addrA},
{testCase.netB, testCase.addrB},
}

if AddressUnspecifiedLess(addrs[0], addrs[1]) {
t.Error(i, "unexpected")
}
if !AddressUnspecifiedLess(addrs[1], addrs[0]) {
t.Error(i, "unexpected")
}
if AddressUnspecifiedLess(addrs[0], addrs[0]) || AddressUnspecifiedLess(addrs[1], addrs[1]) {
t.Error(i, "unexpected")
}
}
}

func TestFillNil(t *testing.T) {
type A struct {
Slice []int
Expand Down

0 comments on commit 7c292cc

Please sign in to comment.