Skip to content

Commit

Permalink
Add a semaphore to limit the concurrency of forward executions
Browse files Browse the repository at this point in the history
  • Loading branch information
matteo-grella committed May 2, 2023
1 parent 6a67106 commit 6c657b5
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 2 deletions.
16 changes: 14 additions & 2 deletions ag/operator.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,15 @@ type AutoGradFunction[T DualValue] interface {
Operands() []T
}

var _ Node = &Operator{}
// forwardGuard is a buffered channel that acts as a semaphore to limit the concurrency
// of async forward operations in the Run function. Its buffer size determines the maximum
// number of forward operations that can run concurrently. Acquiring and releasing slots
// in the semaphore ensures that the concurrency level stays within the desired limit.
var forwardGuard chan struct{}

func init() {
forwardGuard = make(chan struct{}, runtime.NumCPU()*2)
}

// Operator is a type of node.
// It's used to represent a function with automatic differentiation features.
Expand Down Expand Up @@ -119,7 +127,11 @@ func (o *Operator) Run(async ...bool) *Operator {
if isAsync {
//lint:ignore S1019 explicitly set the buffer size to 0 as the channel is used as a signal
o.broadcast = make(chan struct{}, 0)
go o.executeForward()
forwardGuard <- struct{}{}
go func() {
o.executeForward()
<-forwardGuard
}()
return o
}

Expand Down
2 changes: 2 additions & 0 deletions ag/operator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import (
"github.com/stretchr/testify/require"
)

var _ Node = &Operator{}

func TestNewOperator(t *testing.T) {
t.Run("float32", testNewOperator[float32])
t.Run("float64", testNewOperator[float64])
Expand Down

0 comments on commit 6c657b5

Please sign in to comment.