Skip to content
This repository has been archived by the owner on Nov 10, 2023. It is now read-only.

Commit

Permalink
Add ability to close agents
Browse files Browse the repository at this point in the history
Added functionality to close agents when they are no longer needed. This
is done through the new `agent.Closer` interface. Any `agent.Closer`
must be closed to release the resources it uses. For now, this is
restricted to neural network agents using Gorgonia.
  • Loading branch information
samuelfneumann committed Sep 24, 2021
1 parent 159b63a commit 79bbc7d
Show file tree
Hide file tree
Showing 8 changed files with 189 additions and 1 deletion.
7 changes: 7 additions & 0 deletions agent/Agent.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ type Agent interface {
Policy
}

// A Closer is an Agent that holds resources (e.g. Gorgonia virtual
// machines) which must be released by calling Close once the agent is
// done learning. Callers that may receive a Closer should type-assert
// and call Close when the agent is no longer needed.
type Closer interface {
Agent
Close() error
}

// Learner implements a learning algorithm that defines how weights are
// updated.
type Learner interface {
Expand Down Expand Up @@ -62,6 +68,7 @@ type NNPolicy interface {
Clone() (NNPolicy, error)
CloneWithBatch(int) (NNPolicy, error)
Network() network.NeuralNet
Close() error
}

// EGreedyNNPolicy implements an epsilon greedy policy using neural
Expand Down
8 changes: 8 additions & 0 deletions agent/nonlinear/continuous/policy/CategoricalMLP.go
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,11 @@ func (c *CategoricalMLP) Eval() {
// IsEval returns whether the policy is currently in evaluation mode.
func (c *CategoricalMLP) IsEval() bool {
	return c.eval
}

// Close releases the resources held by the policy's virtual machine.
// It is a no-op when no virtual machine has been created.
func (c *CategoricalMLP) Close() error {
	if c.vm == nil {
		return nil
	}
	return c.vm.Close()
}
9 changes: 9 additions & 0 deletions agent/nonlinear/continuous/policy/GaussianTreeMLP.go
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ func NewGaussianTreeMLP(env environment.Environment, batchForLogProb int,
source := rand.NewSource(seed)
normal, ok := distmv.NewNormal(means, stds, source)
if !ok {
// This should never happen
panic("newGaussianTreeMLP: could not create standard normal for " +
"action selection")
}
Expand Down Expand Up @@ -320,3 +321,11 @@ func (g *GaussianTreeMLP) Mean() G.Value {
// StdDev returns the standard deviation of the policy's Gaussian
// distribution as a Gorgonia value. NOTE(review): presumably this is
// the value computed on the policy's last run — confirm with callers.
func (g *GaussianTreeMLP) StdDev() G.Value {
	return g.stddevVal
}

// Close releases the resources held by the policy's virtual machine.
// It is a no-op when no virtual machine has been created.
func (g *GaussianTreeMLP) Close() error {
	if g.vm == nil {
		return nil
	}
	return g.vm.Close()
}
60 changes: 60 additions & 0 deletions agent/nonlinear/continuous/vanillaac/VanillaAC.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package vanillaac

import (
"fmt"
"strings"

"github.com/samuelfneumann/golearn/agent"
env "github.com/samuelfneumann/golearn/environment"
Expand Down Expand Up @@ -525,3 +526,62 @@ func (v *VAC) TdError(t ts.Transition) float64 {

return r + *nextStateValue[0] - stateValue[0]
}

// Close cleans up any used resources by closing the virtual machines
// of the agent's policies and value functions. If any component fails
// to close, the returned error lists every component that could not be
// closed; otherwise Close returns nil.
func (v *VAC) Close() error {
	// NOTE(review): v.vTrainValueFnVM is closed twice below; the entry
	// labelled "value function" was most likely meant to close a
	// separate value-function VM — confirm the intended field name.
	results := []struct {
		name string
		err  error
	}{
		{"behaviour policy", v.behaviour.Close()},
		{"train policy", v.trainPolicy.Close()},
		{"value function", v.vTrainValueFnVM.Close()},
		{"train value function", v.vTrainValueFnVM.Close()},
		{"target value function", v.vTargetValueFnVM.Close()},
	}

	var failed []string
	for _, r := range results {
		if r.err != nil {
			failed = append(failed, r.name)
		}
	}
	if len(failed) == 0 {
		return nil
	}

	// Use an explicit %s verb so the joined component names are never
	// interpreted as a format string (go vet printf check).
	return fmt.Errorf("close: could not close %s", strings.Join(failed, ", "))
}
50 changes: 50 additions & 0 deletions agent/nonlinear/continuous/vanillapg/VanillaPG.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package vanillapg

import (
"fmt"
"strings"

"github.com/samuelfneumann/golearn/agent"
"github.com/samuelfneumann/golearn/environment"
Expand Down Expand Up @@ -394,3 +395,52 @@ func (v *VPG) TdError(t ts.Transition) float64 {

return r + *nextStateValue[0] - stateValue[0]
}

// Close cleans up any used resources by closing the virtual machines
// of the agent's policies and value functions. If any component fails
// to close, the returned error lists every component that could not be
// closed; otherwise Close returns nil.
func (v *VPG) Close() error {
	// NOTE(review): v.vTrainValueFnVM is closed twice below; the entry
	// labelled "value function" was most likely meant to close a
	// separate value-function VM — confirm the intended field name.
	results := []struct {
		name string
		err  error
	}{
		{"behaviour policy", v.behaviour.Close()},
		{"train policy", v.trainPolicy.Close()},
		{"value function", v.vTrainValueFnVM.Close()},
		{"train value function", v.vTrainValueFnVM.Close()},
	}

	var failed []string
	for _, r := range results {
		if r.err != nil {
			failed = append(failed, r.name)
		}
	}
	if len(failed) == 0 {
		return nil
	}

	// Use an explicit %s verb so the joined component names are never
	// interpreted as a format string (go vet printf check).
	return fmt.Errorf("close: could not close %s", strings.Join(failed, ", "))
}
42 changes: 41 additions & 1 deletion agent/nonlinear/discrete/deepq/DeepQ.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package deepq

import (
"fmt"
"strings"

"github.com/samuelfneumann/golearn/agent"
"github.com/samuelfneumann/golearn/agent/linear/discrete/qlearning"
Expand Down Expand Up @@ -405,5 +406,44 @@ func (d *DeepQ) IsEval() bool {
return d.policy.IsEval()
}

// Cleanup at the end of an episode
// EndEpisode performs cleanup at the end of an episode. DeepQ keeps no
// per-episode state to reset, so this is a no-op.
func (d *DeepQ) EndEpisode() {}

// Close cleans up any used resources by closing the policy and the
// train- and target-network virtual machines. If any component fails
// to close, the returned error lists every component that could not be
// closed; otherwise Close returns nil.
func (d *DeepQ) Close() error {
	results := []struct {
		name string
		err  error
	}{
		{"policy", d.policy.Close()},
		{"train network", d.trainNetVM.Close()},
		{"target network", d.targetNetVM.Close()},
	}

	var failed []string
	for _, r := range results {
		if r.err != nil {
			failed = append(failed, r.name)
		}
	}
	if len(failed) == 0 {
		return nil
	}

	// Use an explicit %s verb so the joined component names are never
	// interpreted as a format string (go vet printf check).
	return fmt.Errorf("close: could not close %s", strings.Join(failed, ", "))
}
8 changes: 8 additions & 0 deletions agent/nonlinear/discrete/policy/EGreedyMLP.go
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,14 @@ func (e *MultiHeadEGreedyMLP) SelectAction(t timestep.TimeStep) *mat.VecDense {
return mat.NewVecDense(1, []float64{float64(action)})
}

// Close releases the resources held by the policy's virtual machine.
// It is a no-op when no virtual machine has been created.
func (e *MultiHeadEGreedyMLP) Close() error {
	if e.vm == nil {
		return nil
	}
	return e.vm.Close()
}

// numActions returns the number of actions that the policy chooses
// between.
func (e *MultiHeadEGreedyMLP) numActions() int {
Expand Down
6 changes: 6 additions & 0 deletions experiment/Online.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,12 @@ func (o *Online) Run() error {
if env, ok := o.environment.(env.Closer); ok {
env.Close()
}

// Close the agent if needed
if agent, ok := o.agent.(ag.Closer); ok {
agent.Close()
}

o.progBar.Close()
return nil
}
Expand Down

0 comments on commit 79bbc7d

Please sign in to comment.