From 823374b5786468d3a9f4a4d62b5a8c56b6154545 Mon Sep 17 00:00:00 2001 From: Teppei Fukuda Date: Tue, 24 Dec 2019 16:49:56 +0200 Subject: [PATCH] feat(client/server): add --token-headers option (#326) * feat(option): add token-header * feat(client): add token header * feat(server): add token header * test(token): fix tests * test(token): add integration tests * feat(client): add --custom-headers --- README.md | 10 +++ integration/client_server_test.go | 60 +++++++++++------ internal/app.go | 23 +++++-- internal/client/config/config.go | 33 ++++++++-- internal/client/config/config_test.go | 94 ++++++++++++++++++++------- internal/client/inject.go | 3 +- internal/client/run.go | 3 +- internal/client/wire_gen.go | 6 +- internal/server/config/config.go | 6 +- pkg/rpc/client/headers.go | 20 ++++++ pkg/rpc/client/headers_test.go | 59 +++++++++++++++++ pkg/rpc/client/library/client.go | 12 ++-- pkg/rpc/client/library/client_test.go | 9 ++- pkg/rpc/client/ospkg/client.go | 12 ++-- pkg/rpc/client/ospkg/client_test.go | 8 ++- pkg/rpc/client/token.go | 35 ---------- pkg/rpc/client/token_test.go | 72 -------------------- pkg/rpc/server/server.go | 8 +-- 18 files changed, 281 insertions(+), 192 deletions(-) create mode 100644 pkg/rpc/client/headers.go create mode 100644 pkg/rpc/client/headers_test.go delete mode 100644 pkg/rpc/client/token.go delete mode 100644 pkg/rpc/client/token_test.go diff --git a/README.md b/README.md index 3e12e7948017..833b441a164f 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,7 @@ A Simple and Comprehensive Vulnerability Scanner for Containers, Suitable for CI - [Client/Server](#client--server) - [Server](#server) - [Client](#client) + - [Authentication](#authentication) - [Continuous Integration (CI)](#continuous-integration-ci) - [Travis CI](#travis-ci) - [CircleCI](#circleci) @@ -1126,6 +1127,15 @@ Total: 3 (UNKNOWN: 0, LOW: 1, MEDIUM: 2, HIGH: 0, CRITICAL: 0) ``` +### Authentication + +``` +$ trivy server --listen localhost:8080 --token dummy +``` + +``` +$ trivy client --remote http://localhost:8080 --token dummy alpine:3.10 +``` ### Deprecated options diff --git a/integration/client_server_test.go b/integration/client_server_test.go index 8d68b0d8fd45..b0b893a87386 100644 --- a/integration/client_server_test.go +++ b/integration/client_server_test.go @@ -20,13 +20,15 @@ import ( func TestClientServer(t *testing.T) { type args struct { - Version string - IgnoreUnfixed bool - Severity []string - IgnoreIDs []string - Input string - ClientToken string - ServerToken string + Version string + IgnoreUnfixed bool + Severity []string + IgnoreIDs []string + Input string + ClientToken string + ClientTokenHeader string + ServerToken string + ServerTokenHeader string } cases := []struct { name string @@ -45,10 +47,12 @@ func TestClientServer(t *testing.T) { { name: "alpine 3.10 integration with token", testArgs: args{ - Version: "dev", - Input: "testdata/fixtures/alpine-310.tar.gz", - ClientToken: "token", - ServerToken: "token", + Version: "dev", + Input: "testdata/fixtures/alpine-310.tar.gz", + ClientToken: "token", + ClientTokenHeader: "Trivy-Token", + ServerToken: "token", + ServerTokenHeader: "Trivy-Token", }, golden: "testdata/alpine-310.json.golden", }, @@ -276,10 +280,24 @@ func TestClientServer(t *testing.T) { { name: "invalid token", testArgs: args{ - Version: "dev", - Input: "testdata/fixtures/distroless-base.tar.gz", - ClientToken: "invalidtoken", - ServerToken: "token", + Version: "dev", + Input: "testdata/fixtures/distroless-base.tar.gz", + ClientToken: "invalidtoken", + ClientTokenHeader: "Trivy-Token", + ServerToken: "token", + ServerTokenHeader: "Trivy-Token", + }, + wantErr: "twirp error unauthenticated: invalid token", + }, + { + name: "invalid token header", + testArgs: args{ + Version: "dev", + Input: "testdata/fixtures/distroless-base.tar.gz", + ClientToken: "valid-token", + ClientTokenHeader: "Trivy-Token", + ServerToken: "valid-token", + ServerTokenHeader: "Invalid", }, wantErr: "twirp error unauthenticated: invalid token", }, @@ -299,7 +317,7 @@ func TestClientServer(t *testing.T) { // Setup CLI App app := internal.NewApp(c.testArgs.Version) app.Writer = ioutil.Discard - osArgs := setupServer(addr, c.testArgs.ServerToken, cacheDir) + osArgs := setupServer(addr, c.testArgs.ServerToken, c.testArgs.ServerTokenHeader, cacheDir) // Run Trivy server require.NoError(t, app.Run(osArgs), c.name) @@ -313,7 +331,7 @@ func TestClientServer(t *testing.T) { app.Writer = ioutil.Discard osArgs, outputFile, cleanup := setupClient(t, c.testArgs.IgnoreUnfixed, c.testArgs.Severity, - c.testArgs.IgnoreIDs, addr, c.testArgs.ClientToken, c.testArgs.Input, cacheDir, c.golden) + c.testArgs.IgnoreIDs, addr, c.testArgs.ClientToken, c.testArgs.ClientTokenHeader, c.testArgs.Input, cacheDir, c.golden) defer cleanup() // Run Trivy client @@ -338,16 +356,16 @@ func TestClientServer(t *testing.T) { } } -func setupServer(addr, token, cacheDir string) []string { +func setupServer(addr, token, tokenHeader, cacheDir string) []string { osArgs := []string{"trivy", "server", "--skip-update", "--cache-dir", cacheDir, "--listen", addr} if token != "" { - osArgs = append(osArgs, []string{"--token", token}...) + osArgs = append(osArgs, []string{"--token", token, "--token-header", tokenHeader}...) } return osArgs } func setupClient(t *testing.T, ignoreUnfixed bool, severity, ignoreIDs []string, - addr, token, input, cacheDir, golden string) ([]string, string, func()) { + addr, token, tokenHeader, input, cacheDir, golden string) ([]string, string, func()) { t.Helper() osArgs := []string{"trivy", "client", "--cache-dir", cacheDir, "--format", "json", "--remote", "http://" + addr} @@ -371,7 +389,7 @@ func setupClient(t *testing.T, ignoreUnfixed bool, severity, ignoreIDs []string, osArgs = append(osArgs, []string{"--ignorefile", trivyIgnore}...) } if token != "" { - osArgs = append(osArgs, []string{"--token", token}...) + osArgs = append(osArgs, []string{"--token", token, "--token-header", tokenHeader}...) } if input != "" { osArgs = append(osArgs, []string{"--input", input}...) diff --git a/internal/app.go b/internal/app.go index d076ff28267d..10e08e334194 100644 --- a/internal/app.go +++ b/internal/app.go @@ -4,15 +4,14 @@ import ( "strings" "time" + "github.com/urfave/cli" + + "github.com/aquasecurity/trivy-db/pkg/types" "github.com/aquasecurity/trivy/internal/client" "github.com/aquasecurity/trivy/internal/server" - "github.com/aquasecurity/trivy/internal/standalone" - "github.com/aquasecurity/trivy/pkg/vulnerability" - - "github.com/aquasecurity/trivy-db/pkg/types" "github.com/aquasecurity/trivy/pkg/utils" - "github.com/urfave/cli" + "github.com/aquasecurity/trivy/pkg/vulnerability" ) var ( @@ -144,6 +143,13 @@ var ( Usage: "for authentication", EnvVar: "TRIVY_TOKEN", } + + tokenHeader = cli.StringFlag{ + Name: "token-header", + Value: "Trivy-Token", + Usage: "specify a header name for token", + EnvVar: "TRIVY_TOKEN_HEADER", + } ) func NewApp(version string) *cli.App { @@ -243,12 +249,18 @@ func NewClientCommand() cli.Command { // original flags token, + tokenHeader, cli.StringFlag{ Name: "remote", Value: "http://localhost:4954", Usage: "server address", EnvVar: "TRIVY_REMOTE", }, + cli.StringSliceFlag{ + Name: "custom-headers", + Usage: "custom headers", + EnvVar: "TRIVY_CUSTOM_HEADERS", + }, }, } } @@ -269,6 +281,7 @@ func NewServerCommand() cli.Command { // original flags token, + tokenHeader, cli.StringFlag{ Name: "listen", Value: "localhost:4954", diff --git a/internal/client/config/config.go b/internal/client/config/config.go index 2adf757c0ed2..6d99eb6a9819 100644 --- a/internal/client/config/config.go +++ b/internal/client/config/config.go @@ -1,6 +1,7 @@ package config import ( + "net/http" "os" "strings" "time" @@ -36,8 +37,11 @@ type Config struct { IgnoreUnfixed bool ExitCode int - RemoteAddr string - Token string + RemoteAddr string + token string + tokenHeader string + customHeaders []string + CustomHeaders http.Header // these variables are generated by Init() ImageName string @@ -76,8 +80,10 @@ func New(c *cli.Context) (Config, error) { IgnoreUnfixed: c.Bool("ignore-unfixed"), ExitCode: c.Int("exit-code"), - RemoteAddr: c.String("remote"), - Token: c.String("token"), + RemoteAddr: c.String("remote"), + token: c.String("token"), + tokenHeader: c.String("token-header"), + customHeaders: c.StringSlice("custom-headers"), }, nil } @@ -85,6 +91,12 @@ func (c *Config) Init() (err error) { c.Severities = c.splitSeverity(c.severities) c.VulnType = strings.Split(c.vulnType, ",") c.AppVersion = c.context.App.Version + c.CustomHeaders = splitCustomHeaders(c.customHeaders) + + // add token to custom headers + if c.token != "" { + c.CustomHeaders.Set(c.tokenHeader, c.token) + } // --clear-cache doesn't conduct the scan if c.ClearCache { @@ -135,3 +147,16 @@ func (c *Config) splitSeverity(severity string) []dbTypes.Severity { } return severities } + +func splitCustomHeaders(headers []string) http.Header { + result := make(http.Header) + for _, header := range headers { + // e.g. x-api-token:XXX + s := strings.SplitN(header, ":", 2) + if len(s) != 2 { + continue + } + result.Set(s[0], s[1]) + } + return result +} diff --git a/internal/client/config/config_test.go b/internal/client/config/config_test.go index bf2d3fc9ca91..9e493b0f1d01 100644 --- a/internal/client/config/config_test.go +++ b/internal/client/config/config_test.go @@ -2,7 +2,9 @@ package config import ( "flag" + "net/http" "os" + "reflect" "testing" "time" @@ -83,6 +85,8 @@ func TestConfig_Init(t *testing.T) { onlyUpdate string refresh bool autoRefresh bool + token string + tokenHeader string } tests := []struct { name string @@ -95,20 +99,27 @@ func TestConfig_Init(t *testing.T) { { name: "happy path", fields: fields{ - severities: "CRITICAL", - vulnType: "os", - Quiet: true, + severities: "CRITICAL", + vulnType: "os", + Quiet: true, + token: "foobar", + tokenHeader: "Trivy-Token", }, args: []string{"alpine:3.10"}, want: Config{ - AppVersion: "0.0.0", - Severities: []dbTypes.Severity{dbTypes.SeverityCritical}, - severities: "CRITICAL", - ImageName: "alpine:3.10", - VulnType: []string{"os"}, - vulnType: "os", - Output: os.Stdout, - Quiet: true, + AppVersion: "0.0.0", + Severities: []dbTypes.Severity{dbTypes.SeverityCritical}, + severities: "CRITICAL", + ImageName: "alpine:3.10", + VulnType: []string{"os"}, + vulnType: "os", + Output: os.Stdout, + Quiet: true, + token: "foobar", + tokenHeader: "Trivy-Token", + CustomHeaders: http.Header{ + "Trivy-Token": []string{"foobar"}, + }, }, }, { @@ -122,13 +133,14 @@ func TestConfig_Init(t *testing.T) { "unknown severity option: unknown severity: INVALID", }, want: Config{ - AppVersion: "0.0.0", - Severities: []dbTypes.Severity{dbTypes.SeverityCritical, dbTypes.SeverityUnknown}, - severities: "CRITICAL,INVALID", - ImageName: "centos:7", - VulnType: []string{"os", "library"}, - vulnType: "os,library", - Output: os.Stdout, + AppVersion: "0.0.0", + Severities: []dbTypes.Severity{dbTypes.SeverityCritical, dbTypes.SeverityUnknown}, + severities: "CRITICAL,INVALID", + ImageName: "centos:7", + VulnType: []string{"os", "library"}, + vulnType: "os,library", + Output: os.Stdout, + CustomHeaders: make(http.Header), }, }, { @@ -143,13 +155,14 @@ func TestConfig_Init(t *testing.T) { "You should avoid using the :latest tag as it is cached. You need to specify '--clear-cache' option when :latest image is changed", }, want: Config{ - AppVersion: "0.0.0", - Severities: []dbTypes.Severity{dbTypes.SeverityLow}, - severities: "LOW", - ImageName: "gcr.io/distroless/base", - VulnType: []string{"os", "library"}, - vulnType: "os,library", - Output: os.Stdout, + AppVersion: "0.0.0", + Severities: []dbTypes.Severity{dbTypes.SeverityLow}, + severities: "LOW", + ImageName: "gcr.io/distroless/base", + VulnType: []string{"os", "library"}, + vulnType: "os,library", + Output: os.Stdout, + CustomHeaders: make(http.Header), }, }, { @@ -200,6 +213,8 @@ func TestConfig_Init(t *testing.T) { ExitCode: tt.fields.ExitCode, ImageName: tt.fields.ImageName, Output: tt.fields.Output, + token: tt.fields.token, + tokenHeader: tt.fields.tokenHeader, } err := c.Init() @@ -227,3 +242,32 @@ func TestConfig_Init(t *testing.T) { }) } } + +func Test_splitCustomHeaders(t *testing.T) { + type args struct { + headers []string + } + tests := []struct { + name string + args args + want http.Header + }{ + { + name: "happy path", + args: args{ + headers: []string{"x-api-token:foo bar", "Authorization:user:password"}, + }, + want: http.Header{ + "X-Api-Token": []string{"foo bar"}, + "Authorization": []string{"user:password"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := splitCustomHeaders(tt.args.headers); !reflect.DeepEqual(got, tt.want) { + t.Errorf("splitCustomHeaders() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/internal/client/inject.go b/internal/client/inject.go index 0e8cd813eece..3cba7ddc0566 100644 --- a/internal/client/inject.go +++ b/internal/client/inject.go @@ -10,7 +10,8 @@ import ( "github.com/google/wire" ) -func initializeScanner(ospkgToken ospkg.Token, libToken library.Token, ospkgURL ospkg.RemoteURL, libURL library.RemoteURL) scanner.Scanner { +func initializeScanner(ospkgCustomHeaders ospkg.CustomHeaders, libraryCustomHeaders library.CustomHeaders, + ospkgURL ospkg.RemoteURL, libURL library.RemoteURL) scanner.Scanner { wire.Build(scanner.ClientSet) return scanner.Scanner{} } diff --git a/internal/client/run.go b/internal/client/run.go index a5ef3db856fd..0b032237fe95 100644 --- a/internal/client/run.go +++ b/internal/client/run.go @@ -41,11 +41,10 @@ func run(c config.Config) (err error) { VulnType: c.VulnType, Timeout: c.Timeout, RemoteURL: c.RemoteAddr, - Token: c.Token, } log.Logger.Debugf("Vulnerability type: %s", scanOptions.VulnType) - scanner := initializeScanner(ospkg.Token(c.Token), library.Token(c.Token), + scanner := initializeScanner(ospkg.CustomHeaders(c.CustomHeaders), library.CustomHeaders(c.CustomHeaders), ospkg.RemoteURL(c.RemoteAddr), library.RemoteURL(c.RemoteAddr)) results, err := scanner.ScanImage(c.ImageName, c.Input, scanOptions) if err != nil { diff --git a/internal/client/wire_gen.go b/internal/client/wire_gen.go index c79e8f0831ef..6eb753104cda 100644 --- a/internal/client/wire_gen.go +++ b/internal/client/wire_gen.go @@ -17,12 +17,12 @@ import ( // Injectors from inject.go: -func initializeScanner(ospkgToken ospkg.Token, libToken library.Token, ospkgURL ospkg.RemoteURL, libURL library.RemoteURL) scanner.Scanner { +func initializeScanner(ospkgCustomHeaders ospkg.CustomHeaders, libraryCustomHeaders library.CustomHeaders, ospkgURL ospkg.RemoteURL, libURL library.RemoteURL) scanner.Scanner { osDetector := ospkg.NewProtobufClient(ospkgURL) - detector := ospkg.NewDetector(ospkgToken, osDetector) + detector := ospkg.NewDetector(ospkgCustomHeaders, osDetector) ospkgScanner := ospkg2.NewScanner(detector) libDetector := library.NewProtobufClient(libURL) - libraryDetector := library.NewDetector(libToken, libDetector) + libraryDetector := library.NewDetector(libraryCustomHeaders, libDetector) libraryScanner := library2.NewScanner(libraryDetector) scannerScanner := scanner.NewScanner(ospkgScanner, libraryScanner) return scannerScanner diff --git a/internal/server/config/config.go b/internal/server/config/config.go index 7e6e591cea71..23816066bfbc 100644 --- a/internal/server/config/config.go +++ b/internal/server/config/config.go @@ -15,8 +15,9 @@ type Config struct { DownloadDBOnly bool SkipUpdate bool - Listen string - Token string + Listen string + Token string + TokenHeader string // these variables are generated by Init() AppVersion string @@ -36,6 +37,7 @@ func New(c *cli.Context) Config { SkipUpdate: c.Bool("skip-update"), Listen: c.String("listen"), Token: c.String("token"), + TokenHeader: c.String("token-header"), } } diff --git a/pkg/rpc/client/headers.go b/pkg/rpc/client/headers.go new file mode 100644 index 000000000000..5f7244baa09e --- /dev/null +++ b/pkg/rpc/client/headers.go @@ -0,0 +1,20 @@ +package client + +import ( + "context" + "net/http" + + "github.com/twitchtv/twirp" + + "github.com/aquasecurity/trivy/pkg/log" +) + +func WithCustomHeaders(ctx context.Context, customHeaders http.Header) context.Context { + // Attach the headers to a context + ctxWithToken, err := twirp.WithHTTPRequestHeaders(ctx, customHeaders) + if err != nil { + log.Logger.Warnf("twirp error setting headers: %s", err) + return ctx + } + return ctxWithToken +} diff --git a/pkg/rpc/client/headers_test.go b/pkg/rpc/client/headers_test.go new file mode 100644 index 000000000000..09900c3e1e14 --- /dev/null +++ b/pkg/rpc/client/headers_test.go @@ -0,0 +1,59 @@ +package client + +import ( + "context" + "net/http" + "os" + "testing" + + "github.com/twitchtv/twirp" + + "github.com/aquasecurity/trivy/pkg/log" + + "github.com/stretchr/testify/assert" +) + +func TestMain(m *testing.M) { + _ = log.InitLogger(false, true) + os.Exit(m.Run()) +} + +func TestWithCustomHeaders(t *testing.T) { + type args struct { + ctx context.Context + customHeaders http.Header + } + tests := []struct { + name string + args args + want http.Header + }{ + { + name: "happy path", + args: args{ + ctx: context.Background(), + customHeaders: http.Header{ + "Trivy-Token": []string{"token"}, + }, + }, + want: http.Header{ + "Trivy-Token": []string{"token"}, + }, + }, + { + name: "sad path, invalid headers passed in", + args: args{ + ctx: context.Background(), + customHeaders: http.Header{ + "Content-Type": []string{"token"}, + }, + }, + want: http.Header(nil), + }, + } + for _, tt := range tests { + gotCtx := WithCustomHeaders(tt.args.ctx, tt.args.customHeaders) + header, _ := twirp.HTTPRequestHeaders(gotCtx) + assert.Equal(t, tt.want, header, tt.name) + } +} diff --git a/pkg/rpc/client/library/client.go b/pkg/rpc/client/library/client.go index fb2f76e99503..cb6ff64ee8e2 100644 --- a/pkg/rpc/client/library/client.go +++ b/pkg/rpc/client/library/client.go @@ -28,19 +28,19 @@ func NewProtobufClient(remoteURL RemoteURL) rpc.LibDetector { return rpc.NewLibDetectorProtobufClient(string(remoteURL), &http.Client{}) } -type Token string +type CustomHeaders http.Header type Detector struct { - token Token - client rpc.LibDetector + customHeaders CustomHeaders + client rpc.LibDetector } -func NewDetector(token Token, detector rpc.LibDetector) Detector { - return Detector{token: token, client: detector} +func NewDetector(customHeaders CustomHeaders, detector rpc.LibDetector) Detector { + return Detector{customHeaders: customHeaders, client: detector} } func (d Detector) Detect(filePath string, libs []ptypes.Library) ([]types.DetectedVulnerability, error) { - ctx := client.WithToken(context.Background(), string(d.token)) + ctx := client.WithCustomHeaders(context.Background(), http.Header(d.customHeaders)) res, err := d.client.Detect(ctx, &rpc.LibDetectRequest{ FilePath: filePath, Libraries: r.ConvertToRpcLibraries(libs), diff --git a/pkg/rpc/client/library/client_test.go b/pkg/rpc/client/library/client_test.go index 64b1c2ea2889..2486935b36ea 100644 --- a/pkg/rpc/client/library/client_test.go +++ b/pkg/rpc/client/library/client_test.go @@ -48,8 +48,9 @@ func TestDetectClient_Detect(t *testing.T) { } type fields struct { - token Token + customHeaders CustomHeaders } + type args struct { filePath string libs []ptypes.Library @@ -65,7 +66,9 @@ func TestDetectClient_Detect(t *testing.T) { { name: "happy path", fields: fields{ - token: "token", + customHeaders: CustomHeaders{ + "Trivy-Token": []string{"token"}, + }, }, args: args{ filePath: "app/Pipfile.lock", @@ -141,7 +144,7 @@ func TestDetectClient_Detect(t *testing.T) { mockDetector.On("Detect", mock.Anything, tt.detect.input.req).Return( tt.detect.output.res, tt.detect.output.err) - d := NewDetector(tt.fields.token, mockDetector) + d := NewDetector(tt.fields.customHeaders, mockDetector) got, err := d.Detect(tt.args.filePath, tt.args.libs) if tt.wantErr != "" { require.NotNil(t, err, tt.name) diff --git a/pkg/rpc/client/ospkg/client.go b/pkg/rpc/client/ospkg/client.go index 8c1706bf762e..2a6e8fd6d772 100644 --- a/pkg/rpc/client/ospkg/client.go +++ b/pkg/rpc/client/ospkg/client.go @@ -28,19 +28,19 @@ func NewProtobufClient(remoteURL RemoteURL) rpc.OSDetector { return rpc.NewOSDetectorProtobufClient(string(remoteURL), &http.Client{}) } -type Token string +type CustomHeaders http.Header type Detector struct { - token Token - client rpc.OSDetector + customHeaders CustomHeaders + client rpc.OSDetector } -func NewDetector(token Token, detector rpc.OSDetector) Detector { - return Detector{token: token, client: detector} +func NewDetector(customHeaders CustomHeaders, detector rpc.OSDetector) Detector { + return Detector{customHeaders: customHeaders, client: detector} } func (d Detector) Detect(osFamily, osName string, pkgs []analyzer.Package) ([]types.DetectedVulnerability, bool, error) { - ctx := client.WithToken(context.Background(), string(d.token)) + ctx := client.WithCustomHeaders(context.Background(), http.Header(d.customHeaders)) res, err := d.client.Detect(ctx, &rpc.OSDetectRequest{ OsFamily: osFamily, OsName: osName, diff --git a/pkg/rpc/client/ospkg/client_test.go b/pkg/rpc/client/ospkg/client_test.go index 6f8d34f67d34..035a6ef4226e 100644 --- a/pkg/rpc/client/ospkg/client_test.go +++ b/pkg/rpc/client/ospkg/client_test.go @@ -48,7 +48,7 @@ func TestDetectClient_Detect(t *testing.T) { } type fields struct { - token Token + customHeaders CustomHeaders } type args struct { osFamily string @@ -66,7 +66,9 @@ func TestDetectClient_Detect(t *testing.T) { { name: "happy path", fields: fields{ - token: "token", + customHeaders: CustomHeaders{ + "Trivy-Token": []string{"token"}, + }, }, args: args{ osFamily: "alpine", @@ -168,7 +170,7 @@ func TestDetectClient_Detect(t *testing.T) { mockDetector.On("Detect", mock.Anything, tt.detect.input.req).Return( tt.detect.output.res, tt.detect.output.err) - d := NewDetector(tt.fields.token, mockDetector) + d := NewDetector(tt.fields.customHeaders, mockDetector) got, _, err := d.Detect(tt.args.osFamily, tt.args.osName, tt.args.pkgs) if tt.wantErr != "" { require.NotNil(t, err, tt.name) diff --git a/pkg/rpc/client/token.go b/pkg/rpc/client/token.go deleted file mode 100644 index fa71f83f7961..000000000000 --- a/pkg/rpc/client/token.go +++ /dev/null @@ -1,35 +0,0 @@ -package client - -import ( - "context" - "net/http" - - "github.com/twitchtv/twirp" - - "github.com/aquasecurity/trivy/pkg/log" -) - -var ( - buildRequestHeaderFunc = buildRequestHeader -) - -func buildRequestHeader(inputHeaders map[string]string) http.Header { - header := make(http.Header) - for k, v := range inputHeaders { - header.Set(k, v) - } - return header -} - -func WithToken(ctx context.Context, token string) context.Context { - // Prepare custom header - header := buildRequestHeaderFunc(map[string]string{"Trivy-Token": token}) - - // Attach the headers to a context - ctxWithToken, err := twirp.WithHTTPRequestHeaders(ctx, header) - if err != nil { - log.Logger.Warnf("twirp error setting headers: %s", err) - return ctx - } - return ctxWithToken -} diff --git a/pkg/rpc/client/token_test.go b/pkg/rpc/client/token_test.go deleted file mode 100644 index 2ce79da3d469..000000000000 --- a/pkg/rpc/client/token_test.go +++ /dev/null @@ -1,72 +0,0 @@ -package client - -import ( - "context" - "net/http" - "os" - "testing" - - "github.com/twitchtv/twirp" - - "github.com/aquasecurity/trivy/pkg/log" - - "github.com/stretchr/testify/assert" -) - -func TestMain(m *testing.M) { - _ = log.InitLogger(false, true) - os.Exit(m.Run()) -} - -func TestWithToken(t *testing.T) { - type args struct { - ctx context.Context - token string - } - tests := []struct { - name string - args args - buildRequestHeaderFunc func(map[string]string) http.Header - want http.Header - }{ - { - name: "happy path", - args: args{ - ctx: context.Background(), - token: "token", - }, - want: http.Header{ - "Trivy-Token": []string{"token"}, - }, - buildRequestHeaderFunc: buildRequestHeader, - }, - { - name: "sad path, invalid headers passed in", - args: args{ - ctx: context.Background(), - token: "token", - }, - want: http.Header(nil), - buildRequestHeaderFunc: func(m map[string]string) http.Header { - header := make(http.Header) - for k, v := range m { - header.Set(k, v) - } - - // add an extra header that is reserved for twirp - header.Set("Content-Type", "foobar") - return header - }, - }, - } - for _, tt := range tests { - oldbuildRequestHeaderFunc := buildRequestHeaderFunc - defer func() { - buildRequestHeaderFunc = oldbuildRequestHeaderFunc - }() - buildRequestHeaderFunc = tt.buildRequestHeaderFunc - gotCtx := WithToken(tt.args.ctx, tt.args.token) - header, _ := twirp.HTTPRequestHeaders(gotCtx) - assert.Equal(t, tt.want, header, tt.name) - } -} diff --git a/pkg/rpc/server/server.go b/pkg/rpc/server/server.go index bd45634f5462..5ab89dbd5e07 100644 --- a/pkg/rpc/server/server.go +++ b/pkg/rpc/server/server.go @@ -58,19 +58,19 @@ func ListenAndServe(addr string, c config.Config) error { mux := http.NewServeMux() osHandler := rpc.NewOSDetectorServer(initializeOspkgServer(), nil) - mux.Handle(rpc.OSDetectorPathPrefix, withToken(withWaitGroup(osHandler), c.Token)) + mux.Handle(rpc.OSDetectorPathPrefix, withToken(withWaitGroup(osHandler), c.Token, c.TokenHeader)) libHandler := rpc.NewLibDetectorServer(initializeLibServer(), nil) - mux.Handle(rpc.LibDetectorPathPrefix, withToken(withWaitGroup(libHandler), c.Token)) + mux.Handle(rpc.LibDetectorPathPrefix, withToken(withWaitGroup(libHandler), c.Token, c.TokenHeader)) log.Logger.Infof("Listening %s...", addr) return http.ListenAndServe(addr, mux) } -func withToken(base http.Handler, token string) http.Handler { +func withToken(base http.Handler, token, tokenHeader string) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if token != "" && token != r.Header.Get("Trivy-Token") { + if token != "" && token != r.Header.Get(tokenHeader) { rpc.WriteError(w, twirp.NewError(twirp.Unauthenticated, "invalid token")) return }