Skip to content

Commit

Permalink
Refactor remotes and handshaking to give every address a fair shot (s…
Browse files Browse the repository at this point in the history
  • Loading branch information
nbrownus authored Apr 14, 2021
1 parent 20bef97 commit 710df6a
Show file tree
Hide file tree
Showing 25 changed files with 1,546 additions and 1,370 deletions.
50 changes: 27 additions & 23 deletions control.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,23 +67,11 @@ func (c *Control) RebindUDPServer() {

// ListHostmap returns details about the actual or pending (handshaking) hostmap
func (c *Control) ListHostmap(pendingMap bool) []ControlHostInfo {
var hm *HostMap
if pendingMap {
hm = c.f.handshakeManager.pendingHostMap
return listHostMap(c.f.handshakeManager.pendingHostMap)
} else {
hm = c.f.hostMap
return listHostMap(c.f.hostMap)
}

hm.RLock()
hosts := make([]ControlHostInfo, len(hm.Hosts))
i := 0
for _, v := range hm.Hosts {
hosts[i] = copyHostInfo(v)
i++
}
hm.RUnlock()

return hosts
}

// GetHostInfoByVpnIP returns a single tunnels hostInfo, or nil if not found
Expand All @@ -100,7 +88,7 @@ func (c *Control) GetHostInfoByVpnIP(vpnIP uint32, pending bool) *ControlHostInf
return nil
}

ch := copyHostInfo(h)
ch := copyHostInfo(h, c.f.hostMap.preferredRanges)
return &ch
}

Expand All @@ -112,7 +100,7 @@ func (c *Control) SetRemoteForTunnel(vpnIP uint32, addr udpAddr) *ControlHostInf
}

hostInfo.SetRemote(addr.Copy())
ch := copyHostInfo(hostInfo)
ch := copyHostInfo(hostInfo, c.f.hostMap.preferredRanges)
return &ch
}

Expand Down Expand Up @@ -163,14 +151,17 @@ func (c *Control) CloseAllTunnels(excludeLighthouses bool) (closed int) {
return
}

func copyHostInfo(h *HostInfo) ControlHostInfo {
func copyHostInfo(h *HostInfo, preferredRanges []*net.IPNet) ControlHostInfo {
chi := ControlHostInfo{
VpnIP: int2ip(h.hostId),
LocalIndex: h.localIndexId,
RemoteIndex: h.remoteIndexId,
RemoteAddrs: h.CopyRemotes(),
CachedPackets: len(h.packetStore),
MessageCounter: atomic.LoadUint64(&h.ConnectionState.atomicMessageCounter),
VpnIP: int2ip(h.hostId),
LocalIndex: h.localIndexId,
RemoteIndex: h.remoteIndexId,
RemoteAddrs: h.remotes.CopyAddrs(preferredRanges),
CachedPackets: len(h.packetStore),
}

if h.ConnectionState != nil {
chi.MessageCounter = atomic.LoadUint64(&h.ConnectionState.atomicMessageCounter)
}

if c := h.GetCert(); c != nil {
Expand All @@ -183,3 +174,16 @@ func copyHostInfo(h *HostInfo) ControlHostInfo {

return chi
}

func listHostMap(hm *HostMap) []ControlHostInfo {
hm.RLock()
hosts := make([]ControlHostInfo, len(hm.Hosts))
i := 0
for _, v := range hm.Hosts {
hosts[i] = copyHostInfo(v, hm.preferredRanges)
i++
}
hm.RUnlock()

return hosts
}
10 changes: 6 additions & 4 deletions control_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,12 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
Signature: []byte{1, 2, 1, 2, 1, 3},
}

remotes := []*udpAddr{remote1, remote2}
remotes := NewRemoteList()
remotes.unlockedPrependV4(0, NewIp4AndPort(remote1.IP, uint32(remote1.Port)))
remotes.unlockedPrependV6(0, NewIp6AndPort(remote2.IP, uint32(remote2.Port)))
hm.Add(ip2int(ipNet.IP), &HostInfo{
remote: remote1,
Remotes: remotes,
remotes: remotes,
ConnectionState: &ConnectionState{
peerCert: crt,
},
Expand All @@ -59,7 +61,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {

hm.Add(ip2int(ipNet2.IP), &HostInfo{
remote: remote1,
Remotes: remotes,
remotes: remotes,
ConnectionState: &ConnectionState{
peerCert: nil,
},
Expand All @@ -81,7 +83,7 @@ func TestControl_GetHostInfoByVpnIP(t *testing.T) {
VpnIP: net.IPv4(1, 2, 3, 4).To4(),
LocalIndex: 201,
RemoteIndex: 200,
RemoteAddrs: []*udpAddr{remote1, remote2},
RemoteAddrs: []*udpAddr{remote2, remote1},
CachedPackets: 0,
Cert: crt.Copy(),
MessageCounter: 0,
Expand Down
30 changes: 27 additions & 3 deletions control_tester.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,18 @@ func (c *Control) WaitForTypeByIndex(toIndex uint32, msgType NebulaMessageType,
// InjectLightHouseAddr will push toAddr into the local lighthouse cache for the vpnIp
// This is necessary if you did not configure static hosts or are not running a lighthouse
func (c *Control) InjectLightHouseAddr(vpnIp net.IP, toAddr *net.UDPAddr) {
c.f.lightHouse.AddRemote(ip2int(vpnIp), &udpAddr{IP: toAddr.IP, Port: uint16(toAddr.Port)}, false)
c.f.lightHouse.Lock()
remoteList := c.f.lightHouse.unlockedGetRemoteList(ip2int(vpnIp))
remoteList.Lock()
defer remoteList.Unlock()
c.f.lightHouse.Unlock()

iVpnIp := ip2int(vpnIp)
if v4 := toAddr.IP.To4(); v4 != nil {
remoteList.unlockedPrependV4(iVpnIp, NewIp4AndPort(v4, uint32(toAddr.Port)))
} else {
remoteList.unlockedPrependV6(iVpnIp, NewIp6AndPort(toAddr.IP, uint32(toAddr.Port)))
}
}

// GetFromTun will pull a packet off the tun side of nebula
Expand Down Expand Up @@ -84,14 +95,17 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
SrcPort: layers.UDPPort(fromPort),
DstPort: layers.UDPPort(toPort),
}
udp.SetNetworkLayerForChecksum(&ip)
err := udp.SetNetworkLayerForChecksum(&ip)
if err != nil {
panic(err)
}

buffer := gopacket.NewSerializeBuffer()
opt := gopacket.SerializeOptions{
ComputeChecksums: true,
FixLengths: true,
}
err := gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data))
err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload(data))
if err != nil {
panic(err)
}
Expand All @@ -102,3 +116,13 @@ func (c *Control) InjectTunUDPPacket(toIp net.IP, toPort uint16, fromPort uint16
func (c *Control) GetUDPAddr() string {
return c.f.outside.addr.String()
}

func (c *Control) KillPendingTunnel(vpnIp net.IP) bool {
hostinfo, ok := c.f.handshakeManager.pendingHostMap.Hosts[ip2int(vpnIp)]
if !ok {
return false
}

c.f.handshakeManager.pendingHostMap.DeleteHostInfo(hostinfo)
return true
}
Loading

0 comments on commit 710df6a

Please sign in to comment.