From 1426095e108fdfd65d560a13e924f6fcc32f7dfd Mon Sep 17 00:00:00 2001 From: Andrew Harding Date: Fri, 19 Jan 2024 13:48:00 -0700 Subject: [PATCH] Fix racy tests that test streaming RPCs (#4810) * Fix racy tests that test streaming RPCs Signed-off-by: Andrew Harding --- pkg/agent/endpoints/workload/handler_test.go | 36 ++------------------ pkg/server/api/entry/v1/service_test.go | 8 ++++- test/spiretest/apiserver.go | 34 ++++++++++++++++++ 3 files changed, 44 insertions(+), 34 deletions(-) diff --git a/pkg/agent/endpoints/workload/handler_test.go b/pkg/agent/endpoints/workload/handler_test.go index 946a80bd99..9862e93517 100644 --- a/pkg/agent/endpoints/workload/handler_test.go +++ b/pkg/agent/endpoints/workload/handler_test.go @@ -34,7 +34,6 @@ import ( "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" - "google.golang.org/grpc/stats" "google.golang.org/protobuf/types/known/structpb" ) @@ -1509,18 +1508,18 @@ func runTest(t *testing.T, params testParams, fn func(ctx context.Context, clien AllowedForeignJWTClaims: params.AllowedForeignJWTClaims, }) + drainHandler := spiretest.NewDrainHandlerMiddleware() unaryInterceptor, streamInterceptor := middleware.Interceptors(middleware.Chain( + drainHandler, middleware.WithLogger(log), middleware.Preprocess(func(ctx context.Context, fullMethod string, req any) (context.Context, error) { return rpccontext.WithCallerPID(ctx, params.AsPID), nil }), )) - sh := newStatsHandler() server := grpc.NewServer( grpc.UnaryInterceptor(unaryInterceptor), grpc.StreamInterceptor(streamInterceptor), - grpc.StatsHandler(sh), ) workloadPB.RegisterSpiffeWorkloadAPIServer(server, handler) addr := spiretest.ServeGRPCServerOnTempUDSSocket(t, server) @@ -1547,9 +1546,7 @@ func runTest(t *testing.T, params testParams, fn func(ctx context.Context, clien // reports that all RPCs are complete before checking that Finish was // called. server.GracefulStop() - assert.Eventually(t, func() bool { - return sh.Outstanding() == 0 - }, time.Second*10, time.Millisecond*50) + drainHandler.Wait() assert.Equal(t, 0, manager.Subscribers(), "there should be no more subscribers") @@ -1667,30 +1664,3 @@ func pkcs8FromSigner(t *testing.T, key crypto.Signer) []byte { require.NoError(t, err) return keyBytes } - -type statsHandler struct { - outstanding int32 -} - -func newStatsHandler() *statsHandler { - return &statsHandler{} -} - -func (c *statsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context { return ctx } - -func (c *statsHandler) HandleRPC(_ context.Context, s stats.RPCStats) { - switch s.(type) { - case *stats.Begin: - atomic.AddInt32(&c.outstanding, 1) - case *stats.End: - atomic.AddInt32(&c.outstanding, -1) - } -} - -func (c *statsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { return ctx } - -func (c *statsHandler) HandleConn(_ context.Context, _ stats.ConnStats) {} - -func (c *statsHandler) Outstanding() int { - return int(atomic.LoadInt32(&c.outstanding)) -} diff --git a/pkg/server/api/entry/v1/service_test.go b/pkg/server/api/entry/v1/service_test.go index 148370e126..d31a778792 100644 --- a/pkg/server/api/entry/v1/service_test.go +++ b/pkg/server/api/entry/v1/service_test.go @@ -4769,7 +4769,10 @@ func setupServiceTest(t *testing.T, ds datastore.DataStore, options ...serviceTe return ctx, nil }) + drainHandler := spiretest.NewDrainHandlerMiddleware() + unaryInterceptor, streamInterceptor := middleware.Interceptors(middleware.Chain( + drainHandler, ppMiddleware, // Add audit log with local tracking disabled middleware.WithAuditLog(false), @@ -4780,7 +4783,10 @@ func setupServiceTest(t *testing.T, ds datastore.DataStore, options ...serviceTe ) conn, done := spiretest.NewAPIServerWithMiddleware(t, registerFn, server) - test.done = done + test.done = func() { + done() + drainHandler.Wait() + } test.client = entryv1.NewEntryClient(conn) return test diff --git a/test/spiretest/apiserver.go b/test/spiretest/apiserver.go index 37dea68534..c4a1947c6a 100644 --- a/test/spiretest/apiserver.go +++ b/test/spiretest/apiserver.go @@ -4,6 +4,7 @@ import ( "context" "errors" "net" + "sync" "testing" "github.com/stretchr/testify/assert" @@ -54,6 +55,39 @@ func newAPIServer(tb testing.TB, registerFn func(s *grpc.Server), server *grpc.S return conn, done } +type DrainHandlerMiddleware struct { + wg sync.WaitGroup +} + +func NewDrainHandlerMiddleware() *DrainHandlerMiddleware { + return &DrainHandlerMiddleware{} +} + +func (m *DrainHandlerMiddleware) Wait() { + m.wg.Wait() +} + +func (m *DrainHandlerMiddleware) Preprocess(ctx context.Context, _ string, _ any) (context.Context, error) { + m.wg.Add(1) + return ctx, nil +} + +func (m *DrainHandlerMiddleware) Postprocess(context.Context, string, bool, error) { + m.wg.Done() +} + +func (m *DrainHandlerMiddleware) UnaryServerInterceptor(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + m.wg.Add(1) + defer m.wg.Done() + return handler(ctx, req) +} + +func (m *DrainHandlerMiddleware) StreamServerInterceptor(srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error { + m.wg.Add(1) + defer m.wg.Done() + return handler(srv, ss) +} + func unaryInterceptor(fn func(ctx context.Context) context.Context) func(context.Context, any, *grpc.UnaryServerInfo, grpc.UnaryHandler) (any, error) { return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { return handler(fn(ctx), req)