Skip to content

Commit

Permalink
return err on subcriber callback (dgraph-io#1167)
Browse files Browse the repository at this point in the history
  • Loading branch information
poonai authored Dec 20, 2019
1 parent 779d9a0 commit ab4352b
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 19 deletions.
4 changes: 2 additions & 2 deletions badger/cmd/bank.go
Original file line number Diff line number Diff line change
Expand Up @@ -544,13 +544,13 @@ func runTest(cmd *cobra.Command, args []string) error {
for i := 0; i < numAccounts; i++ {
accountIDS = append(accountIDS, key(i))
}
updater := func(kvs *pb.KVList) {
updater := func(kvs *pb.KVList) error {
batch := subscribeDB.NewWriteBatch()
for _, kv := range kvs.GetKv() {
y.Check(batch.Set(kv.Key, kv.Value))
}

y.Check(batch.Flush())
return batch.Flush()
}
_ = db.Subscribe(ctx, updater, accountIDS...)
}()
Expand Down
30 changes: 17 additions & 13 deletions db.go
Original file line number Diff line number Diff line change
Expand Up @@ -1563,7 +1563,7 @@ type KVList = pb.KVList
// This function blocks until the given context is done or an error occurs.
// The given function will be called with a new KVList containing the modified keys and the
// corresponding values.
func (db *DB) Subscribe(ctx context.Context, cb func(kv *KVList), prefixes ...[]byte) error {
func (db *DB) Subscribe(ctx context.Context, cb func(kv *KVList) error, prefixes ...[]byte) error {
if cb == nil {
return ErrNilCallback
}
Expand All @@ -1572,37 +1572,41 @@ func (db *DB) Subscribe(ctx context.Context, cb func(kv *KVList), prefixes ...[]
}
c := y.NewCloser(1)
recvCh, id := db.pub.newSubscriber(c, prefixes...)
slurp := func(batch *pb.KVList) {
defer func() {
if len(batch.GetKv()) > 0 {
cb(batch)
}
}()
slurp := func(batch *pb.KVList) error {
for {
select {
case kvs := <-recvCh:
batch.Kv = append(batch.Kv, kvs.Kv...)
default:
return
if len(batch.GetKv()) > 0 {
return cb(batch)
}
return nil
}
}
}
for {
select {
case <-c.HasBeenClosed():
slurp(new(pb.KVList))
// Drain if any pending updates.
c.Done()
// No need to delete here. Closer will be called only while
// closing DB. Subscriber will be deleted by cleanSubscribers.
return nil
err := slurp(new(pb.KVList))
// Drain if any pending updates.
c.Done()
return err
case <-ctx.Done():
c.Done()
db.pub.deleteSubscriber(id)
// Delete the subscriber to avoid further updates.
return ctx.Err()
case batch := <-recvCh:
slurp(batch)
err := slurp(batch)
if err != nil {
c.Done()
// Delete the subsriber if there is an error by the callback.
db.pub.deleteSubscriber(id)
return err
}
}
}
}
Expand Down
6 changes: 4 additions & 2 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1634,10 +1634,11 @@ func TestGoroutineLeak(t *testing.T) {
var wg sync.WaitGroup
wg.Add(1)
go func() {
err := db.Subscribe(ctx, func(kvs *pb.KVList) {
err := db.Subscribe(ctx, func(kvs *pb.KVList) error {
require.Equal(t, []byte("value"), kvs.Kv[0].GetValue())
updated = true
wg.Done()
return nil
}, []byte("key"))
if err != nil {
require.Equal(t, err.Error(), context.Canceled.Error())
Expand Down Expand Up @@ -1994,10 +1995,11 @@ func ExampleDB_Subscribe() {
wg.Add(1)
go func() {
defer wg.Done()
cb := func(kvs *KVList) {
cb := func(kvs *KVList) error {
for _, kv := range kvs.Kv {
fmt.Printf("%s is now set to %s\n", kv.Key, kv.Value)
}
return nil
}
if err := db.Subscribe(ctx, cb, prefix); err != nil && err != context.Canceled {
log.Fatal(err)
Expand Down
6 changes: 4 additions & 2 deletions publisher_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,15 @@ func TestPublisherOrdering(t *testing.T) {
go func() {
subWg.Done()
updates := 0
err := db.Subscribe(context.Background(), func(kvs *pb.KVList) {
err := db.Subscribe(context.Background(), func(kvs *pb.KVList) error {
updates += len(kvs.GetKv())
for _, kv := range kvs.GetKv() {
order = append(order, string(kv.Value))
}
if updates == 5 {
wg.Done()
}
return nil
}, []byte("ke"))
if err != nil {
require.Equal(t, err.Error(), context.Canceled.Error())
Expand Down Expand Up @@ -72,7 +73,7 @@ func TestMultiplePrefix(t *testing.T) {
go func() {
subWg.Done()
updates := 0
err := db.Subscribe(context.Background(), func(kvs *pb.KVList) {
err := db.Subscribe(context.Background(), func(kvs *pb.KVList) error {
updates += len(kvs.GetKv())
for _, kv := range kvs.GetKv() {
if string(kv.Key) == "key" {
Expand All @@ -84,6 +85,7 @@ func TestMultiplePrefix(t *testing.T) {
if updates == 2 {
wg.Done()
}
return nil
}, []byte("ke"), []byte("hel"))
if err != nil {
require.Equal(t, err.Error(), context.Canceled.Error())
Expand Down

0 comments on commit ab4352b

Please sign in to comment.