Skip to content

Commit

Permalink
WIP: Experiment with potential fixes for race conditions during backward pass
Browse files Browse the repository at this point in the history

This work-in-progress commit explores the use of atomic operations to address race conditions that have been observed during the backward pass.
  • Loading branch information
matteo-grella committed May 2, 2023
1 parent d09b07d commit 40f923c
Showing 1 changed file with 31 additions and 20 deletions.
51 changes: 31 additions & 20 deletions ag/operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ package ag
import (
"log"
"reflect"
"runtime"
"sync"
"sync/atomic"

Expand All @@ -28,7 +29,7 @@ func SetForceSyncExecution(enable bool) {

// backwardState is an enumeration type associated with an Operator, to keep
// track of its visited status among different backpropagation phases.
type backwardState byte
type backwardState = uint32

const (
// idle reports that gradient propagation is not pending for an
Expand Down Expand Up @@ -162,7 +163,7 @@ func (o *Operator) Value() mat.Matrix {

// Grad returns the gradients accumulated during the backward pass.
func (o *Operator) Grad() mat.Matrix {
if o.backwardState == idle || atomic.LoadInt64(&o.pendingGrads) == 0 {
if o.isBackwardIdle() || atomic.LoadInt64(&o.pendingGrads) == 0 {
return o.Value().Grad()
}

Expand Down Expand Up @@ -217,7 +218,7 @@ func (o *Operator) AccGrad(grad mat.Matrix) {
o.Value().AccGrad(grad)

// Don't decrement the counter if the backward pass is not running.
if o.backwardState != idle && atomic.AddInt64(&o.pendingGrads, -1) == 0 {
if !o.isBackwardIdle() && atomic.AddInt64(&o.pendingGrads, -1) == 0 {
close(o.broadcastGrad) // notify all goroutines that have been waiting for the gradients
}
}
Expand All @@ -236,16 +237,15 @@ func (o *Operator) prepareBackwardPass() {
if !o.RequiresGrad() {
return
}

o.pendingGrads++
if o.backwardState == idle {
o.backwardState = pending
//lint:ignore S1019 explicitly set the buffer size to 0 as the channel is used as a signal
o.broadcastGrad = make(chan struct{}, 0)
o.traverseOperandsForPreparation()
if !o.trySetBackwardPending() {
return
}
}

func (o *Operator) traverseOperandsForPreparation() {
//lint:ignore S1019 explicitly set the buffer size to 0 as the channel is used as a signal
o.broadcastGrad = make(chan struct{}, 0)

for _, operand := range o.Operands() {
if oo, ok := operand.(*Operator); ok {
oo.prepareBackwardPass()
Expand All @@ -254,15 +254,18 @@ func (o *Operator) traverseOperandsForPreparation() {
}

func (o *Operator) processBackwardPass(wg *sync.WaitGroup) {
if !o.RequiresGrad() || o.backwardState != pending {
if !o.RequiresGrad() || !o.trySetBackwardOngoing() {
return
}
o.backwardState = ongoing

wg.Add(1) // decrement when the backward pass is done
go o.executeBackward(wg)

o.traverseOperandsForBackward(wg)
for _, operand := range o.Operands() {
if oo, ok := operand.(*Operator); ok {
oo.processBackwardPass(wg)
}
}
}

func (o *Operator) executeBackward(wg *sync.WaitGroup) {
Expand All @@ -271,16 +274,24 @@ func (o *Operator) executeBackward(wg *sync.WaitGroup) {
log.Fatalf("ag: error during backward pass: %v", err) // TODO: handle error
}
}
o.backwardState = idle
o.setBackwardIdle()
wg.Done()
}

func (o *Operator) traverseOperandsForBackward(wg *sync.WaitGroup) {
for _, operand := range o.Operands() {
if oo, ok := operand.(*Operator); ok {
oo.processBackwardPass(wg)
}
}
// isBackwardIdle reports whether the operator's backward state, read
// atomically, is currently idle.
func (o *Operator) isBackwardIdle() bool {
	current := atomic.LoadUint32(&o.backwardState)
	return current == idle
}

// setBackwardIdle atomically resets the operator's backward state to idle.
func (o *Operator) setBackwardIdle() {
	state := &o.backwardState
	atomic.StoreUint32(state, idle)
}

// trySetBackwardPending atomically moves the backward state from idle to
// pending via compare-and-swap. It returns true only if the swap happened,
// i.e. the state was idle at the time of the call.
func (o *Operator) trySetBackwardPending() bool {
	swapped := atomic.CompareAndSwapUint32(&o.backwardState, idle, pending)
	return swapped
}

// trySetBackwardOngoing atomically moves the backward state from pending to
// ongoing via compare-and-swap. It returns true only if the swap happened,
// i.e. the state was pending at the time of the call.
func (o *Operator) trySetBackwardOngoing() bool {
	swapped := atomic.CompareAndSwapUint32(&o.backwardState, pending, ongoing)
	return swapped
}

// isNil returns true if the gradients are nil.
Expand Down

0 comments on commit 40f923c

Please sign in to comment.