This repository has been archived by the owner on Nov 10, 2023. It is now read-only.

Commit

Separate TdError method from Agent
Separated the TdError method from the Agent interface into a new
TdErrorer interface. Generally, this method is only needed for average
reward calculation, and since very few algorithms will ever use the
average reward, there is no point in requiring all agents to implement
this method.
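
Because the method now lives in its own interface, code that wants a TD error can check for it at runtime instead of every agent having to provide one. A minimal sketch of that pattern, using a standard Go type assertion (the helper function, its arguments, and the fallback behaviour are hypothetical, and import paths are omitted since they do not appear in the diff):

// updateTarget is a hypothetical helper: it uses the differential TD
// error when the Learner also implements agent.TdErrorer, and falls
// back to the environmental reward otherwise.
func updateTarget(l agent.Learner, t timestep.Transition, reward float64) float64 {
    if te, ok := l.(agent.TdErrorer); ok {
        return te.TdError(t) // TD error as the update target
    }
    return reward // raw reward as the update target
}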
samuelfneumann committed Oct 21, 2021
1 parent c9c1f88 commit 925ed0d
Showing 2 changed files with 16 additions and 9 deletions.
15 changes: 10 additions & 5 deletions agent/Agent.go
@@ -36,13 +36,18 @@ type Learner interface {
     // ObserveFirst records the first timestep in an episode
     ObserveFirst(timestep.TimeStep) error
 
-    // TdError returns the TD error on a transition
-    TdError(t timestep.Transition) float64
-
     // EndEpisode performs cleanup at the end of an episode
     EndEpisode()
 }
 
+// TdErrorer is a Learner that can return the TdError of some transition
+type TdErrorer interface {
+    Learner
+
+    // TdError returns the TD error on a transition
+    TdError(t timestep.Transition) float64
+}
+
 // Policy represents a policy that an agent can have.
 //
 // Policies determine how agents select actions. Agents usually have a
@@ -51,8 +56,8 @@ type Learner interface {
 // makes to the weights are reflected in the actions the Policy chooses
 type Policy interface {
     SelectAction(t timestep.TimeStep) *mat.VecDense
-    Eval()        // Set agent to evaluation mode
-    Train()       // Set agent to training mode
+    Eval()        // Set policy to evaluation mode
+    Train()       // Set policy to training mode
     IsEval() bool // Indicates if in evaluation mode
 }

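For a concrete learner type, nothing changes unless it also defines TdError. A compile-time check is a common Go way to document which interfaces a type satisfies; someLearner below is a hypothetical type, not one from this repository:

// The first assertion held before this commit; the second only
// compiles if someLearner also implements TdError.
var _ agent.Learner   = (*someLearner)(nil)
var _ agent.TdErrorer = (*someLearner)(nil)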
10 changes: 6 additions & 4 deletions environment/wrappers/AverageReward.go
@@ -38,15 +38,15 @@ import (
 // generate the differential TD error to use in updates. Usually, the
 // agent.Learner of the agent.Agent that is acting in the environment
 // should be used. If not, the differential return may diverge,
-// resulting in algorithms that do no learn.
+// resulting in algorithms that do not learn.
 type AverageReward struct {
     environment.Environment
     avgReward    float64
     learningRate float64
     useTDError   bool
 
     // learner calculates the TD error for the average reward update target
-    learner agent.Learner
+    learner agent.TdErrorer
 
     // The last timestep is needed to calculate the TD error, since it stores
     // the reward R_{t} for the last action A_{t} taken in the last state S_{t}
@@ -66,6 +66,8 @@ type AverageReward struct {
 // the average reward estimate is updated using the TD error of a
 // registered learner as the update target or not. If false, then the
 // environmental reward is used as the average reward update target.
+// Either method is acceptable, but the first (useTDError == true) has
+// lower variance. See the RL Book.
 func NewAverageReward(env environment.Environment, init, learningRate float64,
     useTDError bool) (*AverageReward, ts.TimeStep, error) {
     // Get the first step from the embedded environment
@@ -106,7 +108,7 @@ func (a *AverageReward) Reset() (ts.TimeStep, error) {
 
 }
 
-// Register registers an agent.Learner with the environment so that
+// Register registers an agent.TdErrorer with the environment so that
 // the TD error can be calculated and used to update the average reward
 // estimate.
 //
@@ -115,7 +117,7 @@ func (a *AverageReward) Reset() (ts.TimeStep, error) {
 // before calling the Step() method. Failure to do so will result in a
 // panic. If the useTDError parameter was set to false, then this
 // method will have no effect.
-func (a *AverageReward) Register(l agent.Learner) {
+func (a *AverageReward) Register(l agent.TdErrorer) {
     a.learner = l
 }

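In the standard average-reward formulation (Sutton & Barto), the two update targets described in the doc comment correspond to R̄ ← R̄ + α·δ when useTDError is true, with δ the registered learner's differential TD error, and R̄ ← R̄ + α·(R − R̄) otherwise. A rough usage sketch of the new signatures, assuming the package is imported as wrappers, with baseEnv, learner, and the hyperparameter values all hypothetical:

// Wrap a base environment with an average-reward wrapper that uses the
// TD error of a registered TdErrorer as its update target.
env, first, err := wrappers.NewAverageReward(baseEnv, 0.0, 0.01, true)
if err != nil {
    panic(err)
}
env.Register(learner) // learner must implement agent.TdErrorer
_ = first             // first timestep from the wrapped environment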
