
Change output gradient assignment: set to 1 only for scalar nodes, previously assigned to non-scalars too
matteo-grella committed May 10, 2023
1 parent 55418c9 commit 5bb695b
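
In practice, the caller-visible change is that Backward now returns an error instead of silently seeding a ones gradient on non-scalar outputs. A minimal sketch of the new calling convention, assuming this is the spago ag package (the TrainStep helper and the import path are illustrative, not part of the commit):

```go
package example

import (
	"log"

	"github.com/nlpodyssey/spago/ag" // assumed module path for this repository
)

// TrainStep sketches the new calling convention: the result of Backward
// must now be checked. The loss node is any scalar DualValue built elsewhere.
func TrainStep(loss ag.DualValue) {
	if err := ag.Backward(loss); err != nil {
		log.Fatal(err) // e.g. the output is non-scalar and has no gradients
	}
}
```

With a scalar loss the behavior is unchanged: dy/dy = 1 is still assigned automatically.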
Showing 2 changed files with 29 additions and 16 deletions.
ag/backward.go (15 additions, 9 deletions)

```diff
@@ -10,35 +10,41 @@ import "sync"
 //
 // The function operates according to the following mutually exclusive rules:
 //   - If the node already has gradients (likely assigned externally via node.AccGrads()), those gradients are used.
-//   - If the node does not have gradients assigned, the output gradients are automatically assigned by finding the derivative of the node with respect to itself (dy/dy = 1).
+//   - If the node does not have gradients assigned and is a scalar, the output gradients are automatically assigned
+//     by finding the derivative of the node with respect to itself (dy/dy = 1).
+//   - If the node does not have gradients assigned and is not a scalar, it returns an error.
 //
 // During the back-propagation process, the gradients of all nodes, except for the given node, are summed to the existing gradients.
 // Unless you intend to do so, ensure that all nodes have zero gradients.
-func Backward(xs ...DualValue) {
+func Backward(xs ...DualValue) error {
 	ops := filterOperators(xs)
 	if len(ops) == 0 {
-		return
+		return nil
 	}
 
-	// The three for loops below are intentionally executed in sequence. They perform the following steps:
+	// The three for loops below are intentionally executed in sequence.
+
 	// 1. Prepare the backward pass for each operator.
-	// 2. Set the output gradients for each operator.
-	// 3. Process the backward pass for each operator in parallel using wait groups.
-	//
-	// These steps must occur in this order, so the loops cannot be combined due to their sequential dependencies.
 	for _, op := range ops {
 		op.prepareBackwardPass()
 	}
 
+	// 2. Assign the output gradients for each operator.
 	for _, op := range ops {
-		op.setOutputGrad()
+		if err := op.assignOutputGradient(); err != nil {
+			return err
+		}
 	}
 
+	// 3. Process the backward pass for each operator in parallel using wait groups.
+	// These steps must occur in this order, so the loops cannot be combined due to their sequential dependencies.
 	wg := &sync.WaitGroup{}
 	for _, op := range ops {
 		op.processBackwardPass(wg)
 	}
 	wg.Wait()
+
+	return nil
 }
 
 // filterOperators returns a list of operators from a list of nodes.
```
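
For non-scalar outputs, what the old code did implicitly (seeding a ones gradient) must now be done explicitly before the backward pass. A migration sketch under that assumption, reusing the OnesLike and AccGrad calls visible in the diff on a hypothetical output node y (BackwardFromVector and the import path are illustrative):

```go
package example

import "github.com/nlpodyssey/spago/ag" // assumed module path for this repository

// BackwardFromVector seeds a ones gradient on a non-scalar output y and then
// runs the backward pass, reproducing what setOutputGrad used to do for any
// shape. It assumes y's DualValue interface exposes the Value and AccGrad
// methods seen in the diff above.
func BackwardFromVector(y ag.DualValue) error {
	y.AccGrad(y.Value().OnesLike()) // dy/dy = 1 for every component
	return ag.Backward(y)           // gradients are assigned, so no error
}
```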
ag/operator.go (14 additions, 7 deletions)

```diff
@@ -5,6 +5,7 @@
 package ag
 
 import (
+	"fmt"
 	"log"
 	"reflect"
 	"runtime"
@@ -215,14 +216,20 @@ func (o *Operator) AccGrad(grad mat.Matrix) {
 	}
 }
 
-func (o *Operator) setOutputGrad() {
-	if isNil(o.Value().Grad()) {
-		gx := o.Value().OnesLike()
-		o.AccGrad(gx)
-		return
+func (o *Operator) assignOutputGradient() error {
+	grad := o.Value().Grad()
+
+	if !isNil(grad) {
+		o.pendingGrads--
+		return nil
 	}
-	// If the node already has gradients, we can use them directly.
-	o.pendingGrads--
+
+	if o.Value().Size() == 1 {
+		o.AccGrad(o.Value().NewScalar(1.))
+		return nil
+	}
+
+	return fmt.Errorf("ag: missing gradient for %v", o)
 }
 
 func (o *Operator) prepareBackwardPass() {
```
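
Read together with the updated doc comment in backward.go, assignOutputGradient encodes three mutually exclusive rules. A standalone restatement of that logic, assuming mat.Matrix exposes the Size and NewScalar methods used in the diff (seedOutputGrad is illustrative, and a plain nil check stands in for the repo's isNil helper):

```go
package example

import (
	"fmt"

	"github.com/nlpodyssey/spago/mat" // assumed module path for this repository
)

// seedOutputGrad restates the three rules for an output value v whose current
// gradient is g; acc stands in for the operator's AccGrad method.
func seedOutputGrad(v, g mat.Matrix, acc func(mat.Matrix)) error {
	if g != nil {
		return nil // rule 1: externally assigned gradients are used as-is
	}
	if v.Size() == 1 {
		acc(v.NewScalar(1.)) // rule 2: scalar node, seed dy/dy = 1
		return nil
	}
	// rule 3: a non-scalar node without gradients is now an error
	return fmt.Errorf("missing output gradient for %v", v)
}
```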
