Skip to content

Commit

Permalink
Support SO_BINDTODEVICE in ICMP sockets
Browse files Browse the repository at this point in the history
Adds support for the SO_BINDTODEVICE socket option in ICMP sockets with an
accompanying packetimpact test to exercise use of this socket option.

Adds a unit test to exercise the NIC selection logic introduced by this change.
The remaining unit tests for ICMP sockets need to be added in a subsequent CL.
See https://gvisor.dev/issues/5623 for the list of remaining unit tests.

Adds a "timeout" field to PacketimpactTestInfo, necessary due to the long
runtime of the newly added packetimpact test.

Fixes google#5678
Fixes google#4896
Updates google#5623
Updates google#5681
Updates google#5763
Updates google#5956
Updates google#5966
Updates google#5967

PiperOrigin-RevId: 376271581
  • Loading branch information
puradox authored and gvisor-bot committed May 27, 2021
1 parent 17df2df commit 121af37
Show file tree
Hide file tree
Showing 12 changed files with 1,167 additions and 389 deletions.
5 changes: 3 additions & 2 deletions pkg/tcpip/socketops.go
Original file line number Diff line number Diff line change
Expand Up @@ -601,9 +601,10 @@ func (so *SocketOptions) GetBindToDevice() int32 {
return atomic.LoadInt32(&so.bindToDevice)
}

// SetBindToDevice sets value for SO_BINDTODEVICE option.
// SetBindToDevice sets value for SO_BINDTODEVICE option. If bindToDevice is
// zero, the socket device binding is removed.
func (so *SocketOptions) SetBindToDevice(bindToDevice int32) Error {
if !so.handler.HasNIC(bindToDevice) {
if bindToDevice != 0 && !so.handler.HasNIC(bindToDevice) {
return &ErrUnknownDevice{}
}

Expand Down
21 changes: 20 additions & 1 deletion pkg/tcpip/transport/icmp/BUILD
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
load("//tools:defs.bzl", "go_library")
load("//tools:defs.bzl", "go_library", "go_test")
load("//tools/go_generics:defs.bzl", "go_template_instance")

package(licenses = ["notice"])
Expand Down Expand Up @@ -38,3 +38,22 @@ go_library(
"//pkg/waiter",
],
)

go_test(
name = "icmp_x_test",
size = "small",
srcs = ["icmp_test.go"],
deps = [
":icmp",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/checker",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/stack",
"//pkg/tcpip/testutil",
"//pkg/waiter",
],
)
20 changes: 15 additions & 5 deletions pkg/tcpip/transport/icmp/endpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)

// TODO(https://gvisor.dev/issues/5623): Unit test this package.

// +stateify savable
type icmpPacket struct {
icmpPacketEntry
Expand Down Expand Up @@ -134,7 +132,8 @@ func (e *endpoint) Close() {
e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
switch e.state {
case stateBound, stateConnected:
e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */)
bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
e.stack.UnregisterTransportEndpoint([]tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, bindToDevice)
}

// Close the receive list and drain it.
Expand Down Expand Up @@ -305,6 +304,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
// Reject destination address if it goes through a different
// NIC than the endpoint was bound to.
nicID := to.NIC
if nicID == 0 {
nicID = tcpip.NICID(e.ops.GetBindToDevice())
}
if e.BindNICID != 0 {
if nicID != 0 && nicID != e.BindNICID {
return 0, &tcpip.ErrNoRoute{}
Expand Down Expand Up @@ -349,6 +351,13 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
return int64(len(v)), nil
}

var _ tcpip.SocketOptionsHandler = (*endpoint)(nil)

// HasNIC implements tcpip.SocketOptionsHandler.
func (e *endpoint) HasNIC(id int32) bool {
return e.stack.HasNIC(tcpip.NICID(id))
}

// SetSockOpt sets a socket option.
func (*endpoint) SetSockOpt(tcpip.SettableSocketOption) tcpip.Error {
return nil
Expand Down Expand Up @@ -608,17 +617,18 @@ func (*endpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, tcpi
}

func (e *endpoint) registerWithStack(_ tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.Error) {
bindToDevice := tcpip.NICID(e.ops.GetBindToDevice())
if id.LocalPort != 0 {
// The endpoint already has a local port, just attempt to
// register it.
err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */)
err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice)
return id, err
}

// We need to find a port for the endpoint.
_, err := e.stack.PickEphemeralPort(e.stack.Rand(), func(p uint16) (bool, tcpip.Error) {
id.LocalPort = p
err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */)
err := e.stack.RegisterTransportEndpoint(netProtos, e.TransProto, id, e, ports.Flags{}, bindToDevice)
switch err.(type) {
case nil:
return true, nil
Expand Down
235 changes: 235 additions & 0 deletions pkg/tcpip/transport/icmp/icmp_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
// Copyright 2021 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package icmp_test

import (
"testing"

"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/testutil"
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
"gvisor.dev/gvisor/pkg/waiter"
)

// TODO(https://gvisor.dev/issues/5623): Finish unit testing the icmp package.
// See the issue for remaining areas of work.

var (
localV4Addr1 = testutil.MustParse4("10.0.0.1")
localV4Addr2 = testutil.MustParse4("10.0.0.2")
remoteV4Addr = testutil.MustParse4("10.0.0.3")
)

func addNICWithDefaultRoute(t *testing.T, s *stack.Stack, id tcpip.NICID, name string, addrV4 tcpip.Address) *channel.Endpoint {
t.Helper()

ep := channel.New(1 /* size */, header.IPv4MinimumMTU, "" /* linkAddr */)
t.Cleanup(ep.Close)

wep := stack.LinkEndpoint(ep)
if testing.Verbose() {
wep = sniffer.New(ep)
}

opts := stack.NICOptions{Name: name}
if err := s.CreateNICWithOptions(id, wep, opts); err != nil {
t.Fatalf("s.CreateNIC(%d, _) = %s", id, err)
}

if err := s.AddAddress(id, ipv4.ProtocolNumber, addrV4); err != nil {
t.Fatalf("s.AddAddress(%d, %d, %s) = %s", id, ipv4.ProtocolNumber, addrV4, err)
}

s.AddRoute(tcpip.Route{
Destination: header.IPv4EmptySubnet,
NIC: id,
})

return ep
}

func writePayload(buf []byte) {
for i := range buf {
buf[i] = byte(i)
}
}

func newICMPv4EchoRequest(payloadSize uint32) buffer.View {
buf := buffer.NewView(header.ICMPv4MinimumSize + int(payloadSize))
writePayload(buf[header.ICMPv4MinimumSize:])

icmp := header.ICMPv4(buf)
icmp.SetType(header.ICMPv4Echo)
// No need to set the checksum; it is reset by the socket before the packet
// is sent.

return buf
}

// TestWriteUnboundWithBindToDevice exercises writing to an unbound ICMP socket
// when SO_BINDTODEVICE is set to the non-default NIC for that subnet.
//
// Only IPv4 is tested. The logic to determine which NIC to use is agnostic to
// the version of IP.
func TestWriteUnboundWithBindToDevice(t *testing.T) {
s := stack.New(stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4},
HandleLocal: true,
})

// Add two NICs, both with default routes on the same subnet. The first NIC
// added will be the default NIC for that subnet.
defaultEP := addNICWithDefaultRoute(t, s, 1, "default", localV4Addr1)
alternateEP := addNICWithDefaultRoute(t, s, 2, "alternate", localV4Addr2)

socket, err := s.NewEndpoint(icmp.ProtocolNumber4, ipv4.ProtocolNumber, &waiter.Queue{})
if err != nil {
t.Fatalf("s.NewEndpoint(%d, %d, _) = %s", icmp.ProtocolNumber4, ipv4.ProtocolNumber, err)
}
defer socket.Close()

echoPayloadSize := defaultEP.MTU() - header.IPv4MinimumSize - header.ICMPv4MinimumSize

// Send a packet without SO_BINDTODEVICE. This verifies that the first NIC
// to be added is the default NIC to send packets when not explicitly bound.
{
buf := newICMPv4EchoRequest(echoPayloadSize)
r := buf.Reader()
n, err := socket.Write(&r, tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: remoteV4Addr},
})
if err != nil {
t.Fatalf("socket.Write(_, {To:%s}) = %s", remoteV4Addr, err)
}
if n != int64(len(buf)) {
t.Fatalf("got n = %d, want n = %d", n, len(buf))
}

// Verify the packet was sent out the default NIC.
p, ok := defaultEP.Read()
if !ok {
t.Fatalf("got defaultEP.Read(_) = _, false; want = _, true (packet wasn't written out)")
}

vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
b := vv.ToView()

checker.IPv4(t, b, []checker.NetworkChecker{
checker.SrcAddr(localV4Addr1),
checker.DstAddr(remoteV4Addr),
checker.ICMPv4(
checker.ICMPv4Type(header.ICMPv4Echo),
checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]),
),
}...)

// Verify the packet was not sent out the alternate NIC.
if p, ok := alternateEP.Read(); ok {
t.Fatalf("got alternateEP.Read(_) = %+v, true; want = _, false", p)
}
}

// Send a packet with SO_BINDTODEVICE. This exercises reliance on
// SO_BINDTODEVICE to route the packet to the alternate NIC.
{
// Use SO_BINDTODEVICE to send over the alternate NIC by default.
socket.SocketOptions().SetBindToDevice(2)

buf := newICMPv4EchoRequest(echoPayloadSize)
r := buf.Reader()
n, err := socket.Write(&r, tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: remoteV4Addr},
})
if err != nil {
t.Fatalf("socket.Write(_, {To:%s}) = %s", tcpip.Address(remoteV4Addr), err)
}
if n != int64(len(buf)) {
t.Fatalf("got n = %d, want n = %d", n, len(buf))
}

// Verify the packet was not sent out the default NIC.
if p, ok := defaultEP.Read(); ok {
t.Fatalf("got defaultEP.Read(_) = %+v, true; want = _, false", p)
}

// Verify the packet was sent out the alternate NIC.
p, ok := alternateEP.Read()
if !ok {
t.Fatalf("got alternateEP.Read(_) = _, false; want = _, true (packet wasn't written out)")
}

vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
b := vv.ToView()

checker.IPv4(t, b, []checker.NetworkChecker{
checker.SrcAddr(localV4Addr2),
checker.DstAddr(remoteV4Addr),
checker.ICMPv4(
checker.ICMPv4Type(header.ICMPv4Echo),
checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]),
),
}...)
}

// Send a packet with SO_BINDTODEVICE cleared. This verifies that clearing
// the device binding will fallback to using the default NIC to send
// packets.
{
socket.SocketOptions().SetBindToDevice(0)

buf := newICMPv4EchoRequest(echoPayloadSize)
r := buf.Reader()
n, err := socket.Write(&r, tcpip.WriteOptions{
To: &tcpip.FullAddress{Addr: remoteV4Addr},
})
if err != nil {
t.Fatalf("socket.Write(_, {To:%s}) = %s", tcpip.Address(remoteV4Addr), err)
}
if n != int64(len(buf)) {
t.Fatalf("got n = %d, want n = %d", n, len(buf))
}

// Verify the packet was sent out the default NIC.
p, ok := defaultEP.Read()
if !ok {
t.Fatalf("got defaultEP.Read(_) = _, false; want = _, true (packet wasn't written out)")
}

vv := buffer.NewVectorisedView(p.Pkt.Size(), p.Pkt.Views())
b := vv.ToView()

checker.IPv4(t, b, []checker.NetworkChecker{
checker.SrcAddr(localV4Addr1),
checker.DstAddr(remoteV4Addr),
checker.ICMPv4(
checker.ICMPv4Type(header.ICMPv4Echo),
checker.ICMPv4Payload(buf[header.ICMPv4MinimumSize:]),
),
}...)

// Verify the packet was not sent out the alternate NIC.
if p, ok := alternateEP.Read(); ok {
t.Fatalf("got alternateEP.Read(_) = %+v, true; want = _, false", p)
}
}
}
Loading

0 comments on commit 121af37

Please sign in to comment.