Skip to content

Commit

Permalink
Move BPF event filtering to the kernel (gravitational#23017)
Browse files Browse the repository at this point in the history
* Move BPF event filtering to the kernel

gravitational#19354 moved filtering of disk events to the kernel space. This PR continues work in this area and moves all events to be filtered in the kernel space.

* Improve comments.
Minor code fixes.
  • Loading branch information
jakule authored Mar 20, 2023
1 parent 4099bf7 commit 35a5688
Show file tree
Hide file tree
Showing 11 changed files with 251 additions and 116 deletions.
26 changes: 24 additions & 2 deletions bpf/enhancedrecording/command.bpf.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <bpf/bpf_core_read.h> /* for BPF CO-RE helpers */
#include <bpf/bpf_tracing.h> /* for getting kprobe arguments */

#include "./common.h"
#include "../helpers.h"

#define ARGSIZE 128
Expand All @@ -13,6 +14,9 @@
// the userspace can adjust this value based on config.
#define EVENTS_BUF_SIZE (4096*8)

// hashmap keeps all cgroups id that should be monitored by Teleport.
BPF_HASH(monitored_cgroups, u64, int64_t, MAX_MONITORED_SESSIONS);

char LICENSE[] SEC("license") = "Dual BSD/GPL";

enum event_type {
Expand Down Expand Up @@ -61,9 +65,18 @@ static int enter_execve(const char *filename,
// create data here and pass to submit_arg to save stack space (#555)
struct data_t data = {};
struct task_struct *task;
u64 cgroup = bpf_get_current_cgroup_id();
u64 *is_monitored;

// Check if the cgroup should be monitored.
is_monitored = bpf_map_lookup_elem(&monitored_cgroups, &cgroup);
if (is_monitored == NULL) {
// Missed entry.
return 0;
}

data.pid = bpf_get_current_pid_tgid() >> 32;
data.cgroup = bpf_get_current_cgroup_id();
data.cgroup = cgroup;

task = (struct task_struct *)bpf_get_current_task();
data.ppid = BPF_CORE_READ(task, real_parent, tgid);
Expand All @@ -90,9 +103,18 @@ static int exit_execve(int ret)
{
struct data_t data = {};
struct task_struct *task;
u64 cgroup = bpf_get_current_cgroup_id();
u64 *is_monitored;

// Check if the cgroup should be monitored.
is_monitored = bpf_map_lookup_elem(&monitored_cgroups, &cgroup);
if (is_monitored == NULL) {
// cgroup has not been marked for monitoring, ignore.
return 0;
}

data.pid = bpf_get_current_pid_tgid() >> 32;
data.cgroup = bpf_get_current_cgroup_id();
data.cgroup = cgroup;

task = (struct task_struct *)bpf_get_current_task();
data.ppid = BPF_CORE_READ(task, real_parent, tgid);
Expand Down
7 changes: 7 additions & 0 deletions bpf/enhancedrecording/common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#ifndef BPF_COMMON_H
#define BPF_COMMON_H

// Maximum monitored sessions.
#define MAX_MONITORED_SESSIONS 1024

#endif // BPF_COMMON_H
5 changes: 2 additions & 3 deletions bpf/enhancedrecording/disk.bpf.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <bpf/bpf_core_read.h> /* for BPF CO-RE helpers */
#include <bpf/bpf_tracing.h> /* for getting kprobe arguments */

#include "./common.h"
#include "../helpers.h"

// Maximum number of in-flight open syscalls supported
Expand All @@ -14,8 +15,6 @@
// the userspace can adjust this value based on config.
#define EVENTS_BUF_SIZE (4096*128)

// Maximum monitored sessions.
#define MAX_MONITORED_SESSIONS 1024

char LICENSE[] SEC("license") = "Dual BSD/GPL";

Expand Down Expand Up @@ -73,7 +72,7 @@ static int exit_open(int ret) {
// Check if the cgroup should be monitored.
is_monitored = bpf_map_lookup_elem(&monitored_cgroups, &cgroup);
if (is_monitored == NULL) {
// Missed entry.
// cgroup has not been marked for monitoring, ignore.
return 0;
}

Expand Down
15 changes: 14 additions & 1 deletion bpf/enhancedrecording/network.bpf.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <bpf/bpf_core_read.h> /* for BPF CO-RE helpers */
#include <bpf/bpf_tracing.h> /* for getting kprobe arguments */

#include "./common.h"
#include "../helpers.h"

char LICENSE[] SEC("license") = "Dual BSD/GPL";
Expand All @@ -17,6 +18,9 @@ char LICENSE[] SEC("license") = "Dual BSD/GPL";

BPF_HASH(currsock, u32, struct sock *, INFLIGHT_MAX);

// hashmap keeps all cgroups id that should be monitored by Teleport.
BPF_HASH(monitored_cgroups, u64, int64_t, MAX_MONITORED_SESSIONS);

// separate data structs for ipv4 and ipv6
struct ipv4_data_t {
u64 cgroup;
Expand Down Expand Up @@ -55,7 +59,16 @@ static int trace_connect_entry(struct sock *sk)
static int trace_connect_return(int ret, short ipver)
{
u64 pid_tgid = bpf_get_current_pid_tgid();
u32 id = (u32)pid_tgid;
u64 cgroup = bpf_get_current_cgroup_id();
u32 id = (u32)pid_tgid;
u64 *is_monitored;

// Check if the cgroup should be monitored.
is_monitored = bpf_map_lookup_elem(&monitored_cgroups, &cgroup);
if (is_monitored == NULL) {
// cgroup has not been marked for monitoring, ignore.
return 0;
}

struct sock **skpp;
skpp = bpf_map_lookup_elem(&currsock, &id);
Expand Down
33 changes: 27 additions & 6 deletions lib/bpf/bpf.go
Original file line number Diff line number Diff line change
Expand Up @@ -233,9 +233,24 @@ func (s *Service) OpenSession(ctx *SessionContext) (uint64, error) {
return 0, trace.Wrap(err)
}

// Register cgroup in the BPF module.
if err := s.open.startSession(cgroupID); err != nil {
return 0, trace.Wrap(err)
// initializedModClosures holds all already opened modules closures.
initializedModClosures := make([]interface{ endSession(uint64) error }, 0)
for _, module := range []cgroupRegister{
s.open,
s.exec,
s.conn,
} {
// Register cgroup in the BPF module.
if err := module.startSession(cgroupID); err != nil {
// Clean up all already opened modules.
for _, closer := range initializedModClosures {
if closeErr := closer.endSession(cgroupID); closeErr != nil {
log.Debugf("failed to close session: %v", closeErr)
}
}
return 0, trace.Wrap(err)
}
initializedModClosures = append(initializedModClosures, module)
}

// Start watching for any events that come from this cgroup.
Expand Down Expand Up @@ -268,9 +283,15 @@ func (s *Service) CloseSession(ctx *SessionContext) error {
errs = append(errs, trace.Wrap(err))
}

// Remove the cgroup from BPF module.
if err := s.open.endSession(cgroupID); err != nil {
errs = append(errs, trace.Wrap(err))
for _, module := range []interface{ endSession(cgroupID uint64) error }{
s.open,
s.exec,
s.conn,
} {
// Remove the cgroup from BPF module.
if err := module.endSession(cgroupID); err != nil {
errs = append(errs, trace.Wrap(err))
}
}

return trace.NewAggregate(errs...)
Expand Down
67 changes: 31 additions & 36 deletions lib/bpf/bpf_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,7 @@ func TestRootObfuscate(t *testing.T) {
for {
select {
case <-ticker.C:
if err := osexec.Command(fileName).Run(); err != nil {
t.Logf("Failed to run script: %v.", err)
}
runCmd(t, reexecInCGroupCmd, fileName, execsnoop)
case <-done:
return
}
Expand Down Expand Up @@ -241,10 +239,8 @@ func TestRootScript(t *testing.T) {
case <-done:
return
case <-ticker.C:
// Run script.
if err := osexec.Command(fileName).Run(); err != nil {
t.Logf("Failed to run script: %v.", err)
}
// Run script in a cgroup.
runCmd(t, reexecInCGroupCmd, fileName, execsnoop)
}
}
}()
Expand Down Expand Up @@ -304,20 +300,18 @@ func TestRootPrograms(t *testing.T) {
// Loop over all three programs and make sure events are received off the
// perf buffer.
var tests = []struct {
inName string
inCommand string
inCommandArgs []string
inEventCh <-chan []byte
inHTTP bool
verifyFn func(event []byte) bool
inName string
inEventCh <-chan []byte
genEvents func(t *testing.T, ctx context.Context)
verifyFn func(event []byte) bool
}{
// Run execsnoop with "ls".
{
inName: "execsnoop",
inCommand: "ls",
inCommandArgs: []string{},
inEventCh: execsnoop.events(),
inHTTP: false,
inName: "execsnoop",
inEventCh: execsnoop.events(),
genEvents: func(t *testing.T, ctx context.Context) {
executeCommand(t, ctx, "ls", execsnoop)
},
verifyFn: func(event []byte) bool {
var e rawExecEvent
err := unmarshalEvent(event, &e)
Expand All @@ -327,11 +321,11 @@ func TestRootPrograms(t *testing.T) {
// Run opensnoop with "ls". This is fine because "ls" will open some
// shared library.
{
inName: "opensnoop",
inCommand: "ls",
inCommandArgs: []string{},
inEventCh: opensnoop.events(),
inHTTP: false,
inName: "opensnoop",
inEventCh: opensnoop.events(),
genEvents: func(t *testing.T, ctx context.Context) {
executeCommand(t, ctx, "ls", opensnoop)
},
verifyFn: func(event []byte) bool {
var e rawOpenEvent
err := unmarshalEvent(event, &e)
Expand All @@ -342,7 +336,9 @@ func TestRootPrograms(t *testing.T) {
{
inName: "tcpconnect",
inEventCh: tcpconnect.v4Events(),
inHTTP: true,
genEvents: func(t *testing.T, ctx context.Context) {
executeHTTP(t, ctx, ts.URL, tcpconnect)
},
verifyFn: func(event []byte) bool {
var e rawConn4Event
err := unmarshalEvent(event, &e)
Expand All @@ -359,11 +355,8 @@ func TestRootPrograms(t *testing.T) {
// second will continue to execute or an HTTP GET in a processAccessEvents attempting to
// trigger an event.
go waitForEvent(doneContext, doneFunc, tt.inEventCh, tt.verifyFn)
if tt.inHTTP {
go executeHTTP(t, doneContext, ts.URL)
} else {
go executeCommand(t, doneContext, tt.inCommand, opensnoop)
}

go tt.genEvents(t, doneContext)

// Wait for an event to arrive from execsnoop. If an event does not arrive
// within 10 seconds, timeout.
Expand Down Expand Up @@ -526,14 +519,17 @@ func executeCommand(t *testing.T, doneContext context.Context, file string,
t.Logf("Failed to find executable %q: %v.", file, err)
}

runCmd(t, path, traceCgroup)
fullPath, err := osexec.LookPath(path)
require.NoError(t, err)

runCmd(t, reexecInCGroupCmd, fullPath, traceCgroup)
case <-doneContext.Done():
return
}
}
}

func runCmd(t *testing.T, cmdName string, traceCgroup cgroupRegister) {
func runCmd(t *testing.T, reexecCmd string, arg string, traceCgroup cgroupRegister) {
t.Helper()

// Create a pipe to communicate with the child process after re-exec.
Expand All @@ -545,11 +541,8 @@ func runCmd(t *testing.T, cmdName string, traceCgroup cgroupRegister) {
writeP.Close()
})

path, err := osexec.LookPath(cmdName)
require.NoError(t, err)

// Re-exec the test binary. We can then move the binary to a new cgroup.
cmd := osexec.Command(os.Args[0], reexecInCGroupCmd, path)
cmd := osexec.Command(os.Args[0], reexecCmd, arg)

cmd.ExtraFiles = append(cmd.ExtraFiles, readP)

Expand Down Expand Up @@ -578,7 +571,7 @@ func runCmd(t *testing.T, cmdName string, traceCgroup cgroupRegister) {
}

// executeHTTP will perform a HTTP GET to some endpoint in a loop.
func executeHTTP(t *testing.T, doneContext context.Context, endpoint string) {
func executeHTTP(t *testing.T, doneContext context.Context, endpoint string, traceCgroup cgroupRegister) {
t.Helper()

ticker := time.NewTicker(250 * time.Millisecond)
Expand All @@ -592,6 +585,8 @@ func executeHTTP(t *testing.T, doneContext context.Context, endpoint string) {
t.Logf("HTTP request failed: %v.", err)
}

runCmd(t, networkInCgroupCmd, endpoint, traceCgroup)

case <-doneContext.Done():
return
}
Expand Down
16 changes: 8 additions & 8 deletions lib/bpf/command.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ type rawExecEvent struct {
}

type exec struct {
module *libbpfgo.Module
session

eventBuf *RingBuffer
lost *Counter
Expand All @@ -90,36 +90,36 @@ func startExec(bufferSize int) (*exec, error) {
return nil, trace.Wrap(err)
}

e.module, err = libbpfgo.NewModuleFromBuffer(commandBPF, "command")
e.session.module, err = libbpfgo.NewModuleFromBuffer(commandBPF, "command")
if err != nil {
return nil, trace.Wrap(err)
}

// Resizing the ring buffer must be done here, after the module
// was created but before it's loaded into the kernel.
if err = ResizeMap(e.module, commandEventsBuffer, uint32(bufferSize*pageSize)); err != nil {
if err = ResizeMap(e.session.module, commandEventsBuffer, uint32(bufferSize*pageSize)); err != nil {
return nil, trace.Wrap(err)
}

// Load into the kernel
if err = e.module.BPFLoadObject(); err != nil {
if err = e.session.module.BPFLoadObject(); err != nil {
return nil, trace.Wrap(err)
}

syscalls := []string{"execve", "execveat"}

for _, syscall := range syscalls {
if err = AttachSyscallTracepoint(e.module, syscall); err != nil {
if err = AttachSyscallTracepoint(e.session.module, syscall); err != nil {
return nil, trace.Wrap(err)
}
}

e.eventBuf, err = NewRingBuffer(e.module, commandEventsBuffer)
e.eventBuf, err = NewRingBuffer(e.session.module, commandEventsBuffer)
if err != nil {
return nil, trace.Wrap(err)
}

e.lost, err = NewCounter(e.module, "lost", lostCommandEvents)
e.lost, err = NewCounter(e.session.module, "lost", lostCommandEvents)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -132,7 +132,7 @@ func startExec(bufferSize int) (*exec, error) {
func (e *exec) close() {
e.lost.Close()
e.eventBuf.Close()
e.module.Close()
e.session.module.Close()
}

// events contains raw events off the perf buffer.
Expand Down
Loading

0 comments on commit 35a5688

Please sign in to comment.