Skip to content
This repository has been archived by the owner on Nov 10, 2023. It is now read-only.

Commit

Permalink
Set value function input and target once
Browse files Browse the repository at this point in the history
Before, the value function/critic target and input in VPG and VAC was
re-set at each training iteration in a single call to `Step`. This is
unneeded and increases the compute time. Now, we only set these values
once.
  • Loading branch information
samuelfneumann committed Sep 24, 2021
1 parent f025b86 commit ea10135
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 21 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -714,3 +714,5 @@ sequential runs of hyperparameter setting `m` of the `Agent` in the
* [ ] Move GAEBuffer and ExpReplay to a new `buffer` package - in which case GAE buffer needs a public API
* [ ] Rename `FifoRemove1ExpReplay` to `Default` and document what default means
* [ ] Agents should have a Close() method, or create an agent.Closer interface, and check if agent is a Closer before closing at the end of main
9 changes: 4 additions & 5 deletions agent/nonlinear/continuous/vanillaac/VanillaAC.go
Original file line number Diff line number Diff line change
Expand Up @@ -453,12 +453,11 @@ func (v *VAC) Step() error {
v.trainPolicyVM.Reset()

// === === Value Function Train === ===
err = v.vTrainValueFn.SetInput(S)
if err != nil {
return fmt.Errorf("step: could not set critic input state: %v", err)
}
for i := 0; i < v.valueGradSteps; i++ {
err = v.vTrainValueFn.SetInput(S)
if err != nil {
return fmt.Errorf("step: could not set critic input state on "+
"training iteration %d: %v", i, err)
}
err = v.vTrainValueFnVM.RunAll()
if err != nil {
return fmt.Errorf("step: could not run critic vm on training "+
Expand Down
33 changes: 17 additions & 16 deletions agent/nonlinear/continuous/vanillapg/VanillaPG.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,23 +323,24 @@ func (v *VPG) Step() error {
}
v.trainPolicyVM.Reset()

// Value function update
for i := 0; i < v.valueGradSteps; i++ {
if err := v.vTrainValueFn.SetInput(obs); err != nil {
return fmt.Errorf("step: could not set value function input "+
"at training iteration %d: %v", i, err)
}
// Set value function input
if err := v.vTrainValueFn.SetInput(obs); err != nil {
return fmt.Errorf("step: could not set value function input: %v", err)
}

trainValueFnTargetsTensor := tensor.NewDense(
tensor.Float64,
v.vTrainValueFnTargets.Shape(),
tensor.WithBacking(ret),
)
err = G.Let(v.vTrainValueFnTargets, trainValueFnTargetsTensor)
if err != nil {
return fmt.Errorf("step: could not set value function target "+
"at training iteration %d: %v", i, err)
}
// Set value function target
trainValueFnTargetsTensor := tensor.NewDense(
tensor.Float64,
v.vTrainValueFnTargets.Shape(),
tensor.WithBacking(ret),
)
err = G.Let(v.vTrainValueFnTargets, trainValueFnTargetsTensor)
if err != nil {
return fmt.Errorf("step: could not set value function target: %v", err)
}

// Update value function
for i := 0; i < v.valueGradSteps; i++ {
if err := v.vTrainValueFnVM.RunAll(); err != nil {
return fmt.Errorf("step: could not run value function vm "+
"at training iteration %d: %v", i, err)
Expand Down

0 comments on commit ea10135

Please sign in to comment.