Skip to content

Commit

Permalink
add WatchChan method to WatchSet
Browse files Browse the repository at this point in the history
  • Loading branch information
abennett committed Apr 22, 2020
1 parent ac04d7d commit 9ed1fd1
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 0 deletions.
22 changes: 22 additions & 0 deletions watch.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,25 @@ func (w WatchSet) watchMany(ctx context.Context) error {
return ctx.Err()
}
}

// WatchChan returns a channel that is used to wait for either the watch set
// to trigger or for the context to be cancelled.
func (w WatchSet) WatchChan(ctx context.Context) <-chan error {
// Create the outgoing channel
triggerCh := make(chan error)

// Create a goroutine to collect the errors from WatchCtx
go func() {
for {
err := w.WatchCtx(ctx)
triggerCh <- err

// Exit function if context is cancelled
if err != nil {
return
}
}
}()

return triggerCh
}
71 changes: 71 additions & 0 deletions watch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,77 @@ func TestWatch(t *testing.T) {
t.Run("Context", testFactory(true))
}

func testWatchChan(size, fire int) error {
shouldTimeout := true
ws := NewWatchSet()
for i := 0; i < size; i++ {
watchCh := make(chan struct{})
ws.Add(watchCh)
if fire == i {
close(watchCh)
shouldTimeout = false
}
}

ctx, cancelFn := context.WithCancel(context.Background())
defer cancelFn()

doneCh := make(chan bool, 1)
go func() {
err := <-ws.WatchChan(ctx)
doneCh <- err != nil
}()

if shouldTimeout {
select {
case <-doneCh:
return fmt.Errorf("should not trigger")
default:
}

cancelFn()
select {
case didTimeout := <-doneCh:
if !didTimeout {
return fmt.Errorf("should have timed out")
}
case <-time.After(10 * time.Second):
return fmt.Errorf("should have timed out")
}
} else {
select {
case didTimeout := <-doneCh:
if didTimeout {
return fmt.Errorf("should not have timed out")
}
case <-time.After(10 * time.Second):
return fmt.Errorf("should have triggered")
}
cancelFn()
}
return nil
}

func TestWatchChan(t *testing.T) {

// Sweep through a bunch of chunks to hit the various cases of dividing
// the work into watchFew calls.
for size := 0; size < 3*aFew; size++ {
// Fire each possible channel slot.
for fire := 0; fire < size; fire++ {
if err := testWatchChan(size, fire); err != nil {
t.Fatalf("err %d %d: %v", size, fire, err)
}
}

// Run a timeout case as well.
fire := -1
if err := testWatchChan(size, fire); err != nil {
t.Fatalf("err %d %d: %v", size, fire, err)
}
}
}

func TestWatch_AddWithLimit(t *testing.T) {
// Make sure nil doesn't crash.
{
Expand Down

0 comments on commit 9ed1fd1

Please sign in to comment.