Skip to content

Commit

Permalink
Fix racy tests that test streaming RPCs (spiffe#4810)
Browse files Browse the repository at this point in the history
* Fix racy tests that test streaming RPCs

Signed-off-by: Andrew Harding <[email protected]>
  • Loading branch information
azdagron authored Jan 19, 2024
1 parent cb46cb6 commit 1426095
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 34 deletions.
36 changes: 3 additions & 33 deletions pkg/agent/endpoints/workload/handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/stats"
"google.golang.org/protobuf/types/known/structpb"
)

Expand Down Expand Up @@ -1509,18 +1508,18 @@ func runTest(t *testing.T, params testParams, fn func(ctx context.Context, clien
AllowedForeignJWTClaims: params.AllowedForeignJWTClaims,
})

drainHandler := spiretest.NewDrainHandlerMiddleware()
unaryInterceptor, streamInterceptor := middleware.Interceptors(middleware.Chain(
drainHandler,
middleware.WithLogger(log),
middleware.Preprocess(func(ctx context.Context, fullMethod string, req any) (context.Context, error) {
return rpccontext.WithCallerPID(ctx, params.AsPID), nil
}),
))

sh := newStatsHandler()
server := grpc.NewServer(
grpc.UnaryInterceptor(unaryInterceptor),
grpc.StreamInterceptor(streamInterceptor),
grpc.StatsHandler(sh),
)
workloadPB.RegisterSpiffeWorkloadAPIServer(server, handler)
addr := spiretest.ServeGRPCServerOnTempUDSSocket(t, server)
Expand All @@ -1547,9 +1546,7 @@ func runTest(t *testing.T, params testParams, fn func(ctx context.Context, clien
// reports that all RPCs are complete before checking that Finish was
// called.
server.GracefulStop()
assert.Eventually(t, func() bool {
return sh.Outstanding() == 0
}, time.Second*10, time.Millisecond*50)
drainHandler.Wait()

assert.Equal(t, 0, manager.Subscribers(), "there should be no more subscribers")

Expand Down Expand Up @@ -1667,30 +1664,3 @@ func pkcs8FromSigner(t *testing.T, key crypto.Signer) []byte {
require.NoError(t, err)
return keyBytes
}

type statsHandler struct {
outstanding int32
}

func newStatsHandler() *statsHandler {
return &statsHandler{}
}

func (c *statsHandler) TagRPC(ctx context.Context, _ *stats.RPCTagInfo) context.Context { return ctx }

func (c *statsHandler) HandleRPC(_ context.Context, s stats.RPCStats) {
switch s.(type) {
case *stats.Begin:
atomic.AddInt32(&c.outstanding, 1)
case *stats.End:
atomic.AddInt32(&c.outstanding, -1)
}
}

func (c *statsHandler) TagConn(ctx context.Context, _ *stats.ConnTagInfo) context.Context { return ctx }

func (c *statsHandler) HandleConn(_ context.Context, _ stats.ConnStats) {}

func (c *statsHandler) Outstanding() int {
return int(atomic.LoadInt32(&c.outstanding))
}
8 changes: 7 additions & 1 deletion pkg/server/api/entry/v1/service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4769,7 +4769,10 @@ func setupServiceTest(t *testing.T, ds datastore.DataStore, options ...serviceTe
return ctx, nil
})

drainHandler := spiretest.NewDrainHandlerMiddleware()

unaryInterceptor, streamInterceptor := middleware.Interceptors(middleware.Chain(
drainHandler,
ppMiddleware,
// Add audit log with local tracking disabled
middleware.WithAuditLog(false),
Expand All @@ -4780,7 +4783,10 @@ func setupServiceTest(t *testing.T, ds datastore.DataStore, options ...serviceTe
)

conn, done := spiretest.NewAPIServerWithMiddleware(t, registerFn, server)
test.done = done
test.done = func() {
done()
drainHandler.Wait()
}
test.client = entryv1.NewEntryClient(conn)

return test
Expand Down
34 changes: 34 additions & 0 deletions test/spiretest/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"net"
"sync"
"testing"

"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -54,6 +55,39 @@ func newAPIServer(tb testing.TB, registerFn func(s *grpc.Server), server *grpc.S
return conn, done
}

type DrainHandlerMiddleware struct {
wg sync.WaitGroup
}

func NewDrainHandlerMiddleware() *DrainHandlerMiddleware {
return &DrainHandlerMiddleware{}
}

func (m *DrainHandlerMiddleware) Wait() {
m.wg.Wait()
}

func (m *DrainHandlerMiddleware) Preprocess(ctx context.Context, _ string, _ any) (context.Context, error) {
m.wg.Add(1)
return ctx, nil
}

func (m *DrainHandlerMiddleware) Postprocess(context.Context, string, bool, error) {
m.wg.Done()
}

func (m *DrainHandlerMiddleware) UnaryServerInterceptor(ctx context.Context, req any, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
m.wg.Add(1)
defer m.wg.Done()
return handler(ctx, req)
}

func (m *DrainHandlerMiddleware) StreamServerInterceptor(srv any, ss grpc.ServerStream, _ *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
m.wg.Add(1)
defer m.wg.Done()
return handler(srv, ss)
}

func unaryInterceptor(fn func(ctx context.Context) context.Context) func(context.Context, any, *grpc.UnaryServerInfo, grpc.UnaryHandler) (any, error) {
return func(ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
return handler(fn(ctx), req)
Expand Down

0 comments on commit 1426095

Please sign in to comment.