Skip to content

Commit

Permalink
Pass serviceAccountToken to freezer service (knative#11917)
Browse files Browse the repository at this point in the history
* pass serviceaccounttoken to freezer service

* fixing temp file/dir creation in test

* removing ioutil

* moving to a struct

* fixing tests

* updates

* fixing up token handling

* cleanup

* using envvar for state-token passing

* always set tokenpath envvar

* token_path not token
  • Loading branch information
psschwei authored Sep 28, 2021
1 parent 1c7f409 commit 67429ed
Show file tree
Hide file tree
Showing 7 changed files with 163 additions and 51 deletions.
11 changes: 9 additions & 2 deletions cmd/queue/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ type config struct {
TracingConfigZipkinEndpoint string `split_words:"true"` // optional

// Concurrency State Endpoint configuration
ConcurrencyStateEndpoint string `split_words:"true"` // optional
ConcurrencyStateEndpoint string `split_words:"true"` // optional
ConcurrencyStateTokenPath string `split_words:"true"` // optional
}

func init() {
Expand Down Expand Up @@ -300,7 +301,13 @@ func buildServer(ctx context.Context, env config, healthState *health.State, rp
var composedHandler http.Handler = httpProxy
if concurrencyStateEnabled {
logger.Info("Concurrency state endpoint set, tracking request counts, using endpoint: ", env.ConcurrencyStateEndpoint)
composedHandler = queue.ConcurrencyStateHandler(logger, composedHandler, queue.Pause(env.ConcurrencyStateEndpoint), queue.Resume(env.ConcurrencyStateEndpoint))
ce := queue.NewConcurrencyEndpoint(env.ConcurrencyStateEndpoint, env.ConcurrencyStateTokenPath)
go func() {
for range time.NewTicker(1 * time.Minute).C {
ce.RefreshToken()
}
}()
composedHandler = queue.ConcurrencyStateHandler(logger, composedHandler, ce.Pause, ce.Resume)
}
if metricsSupported {
composedHandler = requestAppMetricsHandler(logger, composedHandler, breaker, env)
Expand Down
67 changes: 40 additions & 27 deletions pkg/queue/concurrency_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,6 @@ import (
"go.uber.org/zap"
)

//nolint:gosec // Filepath, not hardcoded credentials
const ConcurrencyStateTokenVolumeMountPath = "/var/run/secrets/tokens"

// ConcurrencyStateHandler tracks the in flight requests for the pod. When the requests
// drop to zero, it runs the `pause` function, and when requests scale up from zero, it
// runs the `resume` function. If either of `pause` or `resume` are not passed, it runs
Expand Down Expand Up @@ -93,33 +90,49 @@ func ConcurrencyStateHandler(logger *zap.SugaredLogger, h http.Handler, pause, r
}
}

// concurrencyStateRequest sends a request to the concurrency state endpoint.
func concurrencyStateRequest(endpoint string, action string) func() error {
return func() error {
bodyText := fmt.Sprintf(`{ "action": %q }`, action)
body := bytes.NewBufferString(bodyText)
req, err := http.NewRequest(http.MethodPost, endpoint, body)
if err != nil {
return fmt.Errorf("unable to create request: %w", err)
}
req.Header.Add("Token", "nil") // TODO: use serviceaccountToken from projected volume
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("unable to post request: %w", err)
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("expected 200 response, got: %d: %s", resp.StatusCode, resp.Status)
}
return nil
type ConcurrencyEndpoint struct {
endpoint string
mountPath string
token atomic.Value
}

func NewConcurrencyEndpoint(e, m string) ConcurrencyEndpoint {
c := ConcurrencyEndpoint{
endpoint: e,
mountPath: m,
}
c.RefreshToken()
return c
}

// Pause sends a pause request to the concurrency state endpoint.
func Pause(endpoint string) func() error {
return concurrencyStateRequest(endpoint, "pause")
func (c ConcurrencyEndpoint) Pause() error { return c.Request("pause") }

func (c ConcurrencyEndpoint) Resume() error { return c.Request("resume") }

func (c ConcurrencyEndpoint) Request(action string) error {
bodyText := fmt.Sprintf(`{ "action": %q }`, action)
body := bytes.NewBufferString(bodyText)
req, err := http.NewRequest(http.MethodPost, c.endpoint, body)
if err != nil {
return fmt.Errorf("unable to create request: %w", err)
}
token := fmt.Sprint(c.token.Load())
req.Header.Add("Token", token)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return fmt.Errorf("unable to post request: %w", err)
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("expected 200 response, got: %d: %s", resp.StatusCode, resp.Status)
}
return nil
}

// Resume sends a resume request to the concurrency state endpoint.
func Resume(endpoint string) func() error {
return concurrencyStateRequest(endpoint, "resume")
func (c *ConcurrencyEndpoint) RefreshToken() error {
token, err := os.ReadFile(c.mountPath)
if err != nil {
return fmt.Errorf("could not read token: %w", err)
}
c.token.Store(string(token))
return nil
}
116 changes: 97 additions & 19 deletions pkg/queue/concurrency_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"

Expand Down Expand Up @@ -140,20 +142,50 @@ func TestConcurrencyStateHandlerParallelOverlapping(t *testing.T) {
}
}

func TestConcurrencyStateRequestHeader(t *testing.T) {
func TestConcurrencyStateTokenRefresh(t *testing.T) {
var token string
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for k, v := range r.Header {
if k == "Token" {
// TODO update when using token (https://github.com/knative/serving/issues/11904)
if v[0] != "nil" {
t.Errorf("incorrect token header, expected 'nil', got %s", v)
}
}
tk := r.Header.Get("Token")
if tk != token {
t.Errorf("incorrect token header, expected %s, got %s", token, tk)
}
}))
tokenPath := filepath.Join(t.TempDir(), "secret")
token = "0123456789"
if err := os.WriteFile(tokenPath, []byte(token), 0700); err != nil {
t.Fatal(err)
}

c := NewConcurrencyEndpoint(ts.URL, tokenPath)
if err := c.Pause(); err != nil {
t.Errorf("initial token check returned an error: %s", err)
}

token = "abcdefghijklmnop"
if err := os.WriteFile(tokenPath, []byte(token), 0700); err != nil {
t.Fatal(err)
}
c.RefreshToken()
if err := c.Pause(); err != nil {
t.Errorf("updated token check returned an error: %s", err)
}
}

func TestConcurrencyStatePauseHeader(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("Token")
if token != "0123456789" {
t.Errorf("incorrect token header, expected '0123456789', got %s", token)
}
}))
pause := Pause(ts.URL)
if err := pause(); err != nil {
t.Errorf("header check returned an error: %s", err)

tokenPath := filepath.Join(t.TempDir(), "secret")
if err := os.WriteFile(tokenPath, []byte("0123456789"), 0700); err != nil {
t.Fatal(err)
}
c := NewConcurrencyEndpoint(ts.URL, tokenPath)
if err := c.Pause(); err != nil {
t.Errorf("pause header check returned an error: %s", err)
}
}

Expand All @@ -170,12 +202,50 @@ func TestConcurrencyStatePauseRequest(t *testing.T) {
}
}))

pause := Pause(ts.URL)
if err := pause(); err != nil {
tokenPath := filepath.Join(t.TempDir(), "secret")
if err := os.WriteFile(tokenPath, []byte("0123456789"), 0700); err != nil {
t.Fatal(err)
}
c := NewConcurrencyEndpoint(ts.URL, tokenPath)
if err := c.Pause(); err != nil {
t.Errorf("request test returned an error: %s", err)
}
}

func TestConcurrencyStatePauseResponse(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
}))
defer ts.Close()

tokenPath := filepath.Join(t.TempDir(), "secret")
if err := os.WriteFile(tokenPath, []byte("0123456789"), 0700); err != nil {
t.Fatal(err)
}
c := NewConcurrencyEndpoint(ts.URL, tokenPath)
if err := c.Pause(); err == nil {
t.Errorf("pausefunction did not return an error")
}
}

func TestConcurrencyStateResumeHeader(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
token := r.Header.Get("Token")
if token != "0123456789" {
t.Errorf("incorrect token header, expected '0123456789', got %s", token)
}
}))

tokenPath := filepath.Join(t.TempDir(), "secret")
if err := os.WriteFile(tokenPath, []byte("0123456789"), 0700); err != nil {
t.Fatal(err)
}
c := NewConcurrencyEndpoint(ts.URL, tokenPath)
if err := c.Resume(); err != nil {
t.Errorf("resume header check returned an error: %s", err)
}
}

func TestConcurrencyStateResumeRequest(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
Expand All @@ -189,21 +259,29 @@ func TestConcurrencyStateResumeRequest(t *testing.T) {
}
}))

resume := Resume(ts.URL)
if err := resume(); err != nil {
tokenPath := filepath.Join(t.TempDir(), "secret")
if err := os.WriteFile(tokenPath, []byte("0123456789"), 0700); err != nil {
t.Fatal(err)
}
c := NewConcurrencyEndpoint(ts.URL, tokenPath)
if err := c.Resume(); err != nil {
t.Errorf("request test returned an error: %s", err)
}
}

func TestConcurrencyStateRequestResponse(t *testing.T) {
func TestConcurrencyStateResumeResponse(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
}))
defer ts.Close()

pause := Pause(ts.URL)
if err := pause(); err == nil {
t.Errorf("failed function did not return an error")
tokenPath := filepath.Join(t.TempDir(), "secret")
if err := os.WriteFile(tokenPath, []byte("0123456789"), 0700); err != nil {
t.Fatal(err)
}
c := NewConcurrencyEndpoint(ts.URL, tokenPath)
if err := c.Resume(); err == nil {
t.Errorf("resume function did not return an error")
}
}

Expand Down
9 changes: 7 additions & 2 deletions pkg/reconciler/revision/resources/deploy.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ import (
"k8s.io/apimachinery/pkg/util/intstr"
)

//nolint:gosec // Filepath, not hardcoded credentials
const concurrencyStateTokenVolumeMountPath = "/var/run/secrets/tokens"
const concurrencyStateTokenName = "state-token"
const concurrencyStateToken = concurrencyStateTokenVolumeMountPath + "/" + concurrencyStateTokenName

var (
varLogVolume = corev1.Volume{
Name: "knative-var-log",
Expand All @@ -57,7 +62,7 @@ var (
Sources: []corev1.VolumeProjection{{
ServiceAccountToken: &corev1.ServiceAccountTokenProjection{
ExpirationSeconds: ptr.Int64(600),
Path: "state-token",
Path: concurrencyStateTokenName,
Audience: "concurrency-state-hook"},
}},
},
Expand All @@ -66,7 +71,7 @@ var (

varTokenVolumeMount = corev1.VolumeMount{
Name: varTokenVolume.Name,
MountPath: queue.ConcurrencyStateTokenVolumeMountPath,
MountPath: concurrencyStateTokenVolumeMountPath,
}

// This PreStop hook is actually calling an endpoint on the queue-proxy
Expand Down
4 changes: 4 additions & 0 deletions pkg/reconciler/revision/resources/deploy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,9 @@ var (
}, {
Name: "CONCURRENCY_STATE_ENDPOINT",
Value: "",
}, {
Name: "CONCURRENCY_STATE_TOKEN_PATH",
Value: "/var/run/secrets/tokens/state-token",
}, {
Name: "ENABLE_HTTP2_AUTO_DETECTION",
Value: "false",
Expand Down Expand Up @@ -1133,6 +1136,7 @@ func TestMakePodSpec(t *testing.T) {
}}
},
withEnvVar("CONCURRENCY_STATE_ENDPOINT", `freeze-proxy`),
withEnvVar("CONCURRENCY_STATE_TOKEN_PATH", `/var/run/secrets/tokens/state-token`),
),
},
withAppendedVolumes(varTokenVolume),
Expand Down
3 changes: 3 additions & 0 deletions pkg/reconciler/revision/resources/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,9 @@ func makeQueueContainer(rev *v1.Revision, cfg *config.Config) (*corev1.Container
}, {
Name: "CONCURRENCY_STATE_ENDPOINT",
Value: cfg.Deployment.ConcurrencyStateEndpoint,
}, {
Name: "CONCURRENCY_STATE_TOKEN_PATH",
Value: concurrencyStateToken,
}, {
Name: "ENABLE_HTTP2_AUTO_DETECTION",
Value: strconv.FormatBool(cfg.Features.AutoDetectHTTP2 == apicfg.Enabled),
Expand Down
4 changes: 3 additions & 1 deletion pkg/reconciler/revision/resources/queue_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,8 @@ func TestMakeQueueContainer(t *testing.T) {
},
want: queueContainer(func(c *corev1.Container) {
c.Env = env(map[string]string{
"CONCURRENCY_STATE_ENDPOINT": "freeze-proxy",
"CONCURRENCY_STATE_ENDPOINT": "freeze-proxy",
"CONCURRENCY_STATE_TOKEN_PATH": "/var/run/secrets/tokens/state-token",
})
}),
}, {
Expand Down Expand Up @@ -846,6 +847,7 @@ func TestTCPProbeGeneration(t *testing.T) {

var defaultEnv = map[string]string{
"CONCURRENCY_STATE_ENDPOINT": "",
"CONCURRENCY_STATE_TOKEN_PATH": "/var/run/secrets/tokens/state-token",
"CONTAINER_CONCURRENCY": "0",
"ENABLE_HTTP2_AUTO_DETECTION": "false",
"ENABLE_PROFILING": "false",
Expand Down

0 comments on commit 67429ed

Please sign in to comment.