Skip to content
This repository has been archived by the owner on Nov 10, 2023. It is now read-only.

Commit

Permalink
Rename FifoRemove1Cache to default
Browse files Browse the repository at this point in the history
Renamed the FifoRemove1Cache experience replayer to reflect the fact
that it is actually the default usage of experience replay in the
literature.
  • Loading branch information
samuelfneumann committed Sep 24, 2021
1 parent ea10135 commit 63a89e8
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 103 deletions.
204 changes: 102 additions & 102 deletions expreplay/FifoRemove1ExpReplay.go → expreplay/Default.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@ import (
"github.com/samuelfneumann/golearn/timestep"
)

// fifoRemove1Cache implements a concrete ExperienceReplayer where
// elements are removed from the buffer in a FiFo manner, and only a
// defaultCache implements a concrete ExperienceReplayer where
// elements are removed from the buffer in a FiFo manner and only a
// single element is removed from the cache at a time. This is the most
// common use of experience replay.
//
// The fifoRemove1Cache is implemented to increase the efficiency of the
// The defaultCache is implemented to increase the efficiency of the
// cache struct when a FiFo Remover is used that removes only a single
// element from the cache at a time. In such cases, we can reduce the
// used RAM and increase the computational speed since we can take
// advantage of knowing the concrete type of the Remover.
type fifoRemove1Cache struct {
type defaultCache struct {
// includeNextAction denotes whether the next action in the SARSA
// tuple should be stored and returned
includeNextAction bool
Expand All @@ -43,7 +43,7 @@ type fifoRemove1Cache struct {
actionSize int
}

// newFifoRemove1Cache returns a new fifoRemove1Cache. The sampler
// newDefaultCache returns a new defaultCache. The sampler
// parameter is a Selector which determines how data is sampled
// from the replay buffer. The featureSize and actionSize
// parameters define the size of the feature and action vectors.
Expand All @@ -53,8 +53,8 @@ type fifoRemove1Cache struct {
// allowed in the buffer at any given time.
// The includeNextAction parameter determines whether or not
// the next action in the SARSA tuple should also be stored.
func newFifoRemove1Cache(sampler Selector, minCapacity, maxCapacity,
featureSize, actionSize int, includeNextAction bool) *fifoRemove1Cache {
func newDefaultCache(sampler Selector, minCapacity, maxCapacity,
featureSize, actionSize int, includeNextAction bool) *defaultCache {
stateCache := make([]float64, maxCapacity*featureSize)
nextStateCache := make([]float64, maxCapacity*featureSize)

Expand All @@ -72,7 +72,7 @@ func newFifoRemove1Cache(sampler Selector, minCapacity, maxCapacity,
indices[i] = i
}

return &fifoRemove1Cache{
return &defaultCache{
includeNextAction: includeNextAction,

stateCache: stateCache,
Expand All @@ -95,224 +95,224 @@ func newFifoRemove1Cache(sampler Selector, minCapacity, maxCapacity,
}
}

// String returns the string representation of the fifoRemove1Cache
func (c *fifoRemove1Cache) String() string {
// String returns the string representation of the defaultCache
func (d *defaultCache) String() string {
var emptyIndices []int
var usedIndices []int
if !c.isFull {
emptyIndices = c.indices[c.currentInUsePos:]
usedIndices = c.indices[:c.currentInUsePos]
if !d.isFull {
emptyIndices = d.indices[d.currentInUsePos:]
usedIndices = d.indices[:d.currentInUsePos]
} else {
emptyIndices = []int{}
usedIndices = c.indices
usedIndices = d.indices
}

baseStr := "Indices Available: %v \nIndices Used: %v \nStates: %v" +
" \nActions: %v \nRewards: %v \nDiscounts: %v \nNext States: %v \n" +
"Next Actions: %v"
return fmt.Sprintf(baseStr, emptyIndices, usedIndices, c.stateCache,
c.actionCache, c.rewardCache, c.discountCache, c.nextStateCache, c.nextActionCache)
return fmt.Sprintf(baseStr, emptyIndices, usedIndices, d.stateCache,
d.actionCache, d.rewardCache, d.discountCache, d.nextStateCache, d.nextActionCache)
}

// BatchSize returns the number of transitions drawn by each call to
// Sample(), i.e. the batch size
func (c *fifoRemove1Cache) BatchSize() int {
return c.sampler.BatchSize()
func (d *defaultCache) BatchSize() int {
return d.sampler.BatchSize()
}

// insertOrder returns the insertion order of samples into the buffer
func (c *fifoRemove1Cache) insertOrder(n int) []int {
c.wait.Wait()
func (d *defaultCache) insertOrder(n int) []int {
d.wait.Wait()

if !c.isFull {
return c.indices[:c.currentInUsePos]
if !d.isFull {
return d.indices[:d.currentInUsePos]
}

currentIndices := make([]int, c.MaxCapacity())
copy(currentIndices[c.currentInUsePos:], c.indices[c.currentInUsePos:])
copy(currentIndices[:c.currentInUsePos], c.indices[:c.currentInUsePos])
currentIndices := make([]int, d.MaxCapacity())
copy(currentIndices[d.currentInUsePos:], d.indices[d.currentInUsePos:])
copy(currentIndices[:d.currentInUsePos], d.indices[:d.currentInUsePos])

return currentIndices[:n]
}

// sampleFrom returns the slice of indices to sample from
func (c *fifoRemove1Cache) sampleFrom() []int {
c.wait.Wait()
func (d *defaultCache) sampleFrom() []int {
d.wait.Wait()

if !c.isFull {
return c.indices[:c.currentInUsePos]
if !d.isFull {
return d.indices[:d.currentInUsePos]
}
return c.indices
return d.indices
}

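// Sample samples and returns a batch of transitions from the replay
// buffer. The returned values are the state, action, reward, discount,
// next state, and next action, followed by an error. The next action
// batch is nil when next actions are not stored. An error is returned
// if the buffer is empty or holds fewer than MinCapacity() transitions.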
func (c *fifoRemove1Cache) Sample() ([]float64, []float64, []float64,
func (d *defaultCache) Sample() ([]float64, []float64, []float64,
[]float64, []float64, []float64, error) {
c.wait.Wait()
d.wait.Wait()

if c.Capacity() == 0 {
if d.Capacity() == 0 {
err := &ExpReplayError{
Op: "sample",
Err: errEmptyCache,
}
return nil, nil, nil, nil, nil, nil, err
}
if c.Capacity() < c.MinCapacity() {
if d.Capacity() < d.MinCapacity() {
err := &ExpReplayError{
Op: "sample",
Err: errInsufficientSamples,
}
return nil, nil, nil, nil, nil, nil, err
}

indices := c.sampler.choose(c)
indices := d.sampler.choose(d)

// Create the state batches
stateBatch := make([]float64, c.BatchSize()*c.featureSize)
nextStateBatch := make([]float64, c.BatchSize()*c.featureSize)
stateBatch := make([]float64, d.BatchSize()*d.featureSize)
nextStateBatch := make([]float64, d.BatchSize()*d.featureSize)

// Fill the state batches
c.wait.Add(2 * len(indices))
d.wait.Add(2 * len(indices))
for i, index := range indices {
batchStartInd := i * c.featureSize
expStartInd := index * c.featureSize
batchStartInd := i * d.featureSize
expStartInd := index * d.featureSize

go func() {
copyInto(stateBatch, batchStartInd, batchStartInd+c.featureSize,
c.stateCache[expStartInd:expStartInd+c.featureSize])
c.wait.Done()
copyInto(stateBatch, batchStartInd, batchStartInd+d.featureSize,
d.stateCache[expStartInd:expStartInd+d.featureSize])
d.wait.Done()
}()

go func() {
copyInto(nextStateBatch, batchStartInd,
batchStartInd+c.featureSize,
c.nextStateCache[expStartInd:expStartInd+c.featureSize],
batchStartInd+d.featureSize,
d.nextStateCache[expStartInd:expStartInd+d.featureSize],
)
c.wait.Done()
d.wait.Done()
}()
}

// Create the action batches
actionBatch := make([]float64, c.BatchSize()*c.actionSize)
actionBatch := make([]float64, d.BatchSize()*d.actionSize)
var nextActionBatch []float64
if c.includeNextAction {
nextActionBatch = make([]float64, c.BatchSize()*c.actionSize)
if d.includeNextAction {
nextActionBatch = make([]float64, d.BatchSize()*d.actionSize)
}

// Fill the action batches
c.wait.Add(2 * len(indices))
d.wait.Add(2 * len(indices))
for i, index := range indices {
batchStartInd := i * c.actionSize
expStartInd := index * c.actionSize
batchStartInd := i * d.actionSize
expStartInd := index * d.actionSize

go func() {
copyInto(actionBatch, batchStartInd, batchStartInd+c.actionSize,
c.actionCache[expStartInd:expStartInd+c.actionSize],
copyInto(actionBatch, batchStartInd, batchStartInd+d.actionSize,
d.actionCache[expStartInd:expStartInd+d.actionSize],
)
c.wait.Done()
d.wait.Done()
}()

go func() {
if c.includeNextAction {
if d.includeNextAction {
copyInto(nextActionBatch, batchStartInd,
batchStartInd+c.actionSize,
c.nextActionCache[expStartInd:expStartInd+c.actionSize],
batchStartInd+d.actionSize,
d.nextActionCache[expStartInd:expStartInd+d.actionSize],
)
}
c.wait.Done()
d.wait.Done()
}()
}

rewardBatch := make([]float64, c.BatchSize())
discountBatch := make([]float64, c.BatchSize())
rewardBatch := make([]float64, d.BatchSize())
discountBatch := make([]float64, d.BatchSize())
for i, index := range indices {
discountBatch[i] = c.discountCache[index]
rewardBatch[i] = c.rewardCache[index]
discountBatch[i] = d.discountCache[index]
rewardBatch[i] = d.rewardCache[index]
}

c.wait.Wait()
d.wait.Wait()
return stateBatch, actionBatch, rewardBatch, discountBatch, nextStateBatch,
nextActionBatch, nil
}

// Capacity returns the current number of elements in the fifoRemove1Cache that
// Capacity returns the current number of elements in the defaultCache that
// are available for sampling
func (c *fifoRemove1Cache) Capacity() int {
c.wait.Wait()
func (d *defaultCache) Capacity() int {
d.wait.Wait()

if c.isFull {
return c.MaxCapacity()
if d.isFull {
return d.MaxCapacity()
}
return c.currentInUsePos
return d.currentInUsePos
}

// MaxCapacity returns the maximum number of elements that are allowed
// in the fifoRemove1Cache
func (c *fifoRemove1Cache) MaxCapacity() int {
return c.maxCapacity
// in the defaultCache
func (d *defaultCache) MaxCapacity() int {
return d.maxCapacity
}

// MinCapacity returns the minimum number of elements required in the
// fifoRemove1Cache before sampling is allowed
func (c *fifoRemove1Cache) MinCapacity() int {
return c.minCapacity
// defaultCache before sampling is allowed
func (d *defaultCache) MinCapacity() int {
return d.minCapacity
}

// Add adds a transition to the fifoRemove1Cache
func (c *fifoRemove1Cache) Add(t timestep.Transition) error {
// Add adds a transition to the defaultCache
func (d *defaultCache) Add(t timestep.Transition) error {
// Finish the last Add operation, then start
c.wait.Wait()
c.wait.Add(4)
d.wait.Wait()
d.wait.Add(4)

index := c.currentInUsePos
if !c.isFull && index+1 == c.MaxCapacity() {
c.isFull = true
index := d.currentInUsePos
if !d.isFull && index+1 == d.MaxCapacity() {
d.isFull = true
}

if t.State.Len() != c.featureSize || t.NextState.Len() != c.featureSize {
if t.State.Len() != d.featureSize || t.NextState.Len() != d.featureSize {
return fmt.Errorf("add: invalid feature size \n\twant(%v)\n\thave(%v)",
t.State.Len(), c.featureSize)
t.State.Len(), d.featureSize)
}
if t.Action.Len() != c.actionSize || t.NextAction.Len() != c.actionSize {
if t.Action.Len() != d.actionSize || t.NextAction.Len() != d.actionSize {
return fmt.Errorf("add: invalid action size \n\twant(%v)\n\thave(%v)",
t.Action.Len(), c.actionSize)
t.Action.Len(), d.actionSize)
}

// Copy states
stateInd := index * c.featureSize
stateInd := index * d.featureSize
go func() {
copyInto(c.stateCache, stateInd, stateInd+c.featureSize,
copyInto(d.stateCache, stateInd, stateInd+d.featureSize,
t.State.RawVector().Data)
c.wait.Done()
d.wait.Done()
}()
go func() {
copyInto(c.nextStateCache, stateInd, stateInd+c.featureSize,
copyInto(d.nextStateCache, stateInd, stateInd+d.featureSize,
t.NextState.RawVector().Data)
c.wait.Done()
d.wait.Done()
}()

// Copy actions
actionInd := index * c.actionSize
actionInd := index * d.actionSize
go func() {
copyInto(c.actionCache, actionInd, actionInd+c.actionSize,
copyInto(d.actionCache, actionInd, actionInd+d.actionSize,
t.Action.RawVector().Data)
c.wait.Done()
d.wait.Done()
}()
go func() {
if c.includeNextAction {
copyInto(c.nextActionCache, actionInd, actionInd+c.actionSize,
if d.includeNextAction {
copyInto(d.nextActionCache, actionInd, actionInd+d.actionSize,
t.NextAction.RawVector().Data)
}
c.wait.Done()
d.wait.Done()
}()

// Copy reward and discount
c.rewardCache[index] = t.Reward
c.discountCache[index] = t.Discount
d.rewardCache[index] = t.Reward
d.discountCache[index] = t.Discount

c.wait.Wait()
c.currentInUsePos = (c.currentInUsePos + 1) % c.MaxCapacity()
d.wait.Wait()
d.currentInUsePos = (d.currentInUsePos + 1) % d.MaxCapacity()
return nil
}
2 changes: 1 addition & 1 deletion expreplay/ExpReplay.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ func New(remover, sampler Selector, minCapacity, maxCapacity, featureSize,
}

if _, ok := remover.(*fifoSelector); ok && remover.BatchSize() == 1 {
return newFifoRemove1Cache(sampler, minCapacity, maxCapacity, featureSize,
return newDefaultCache(sampler, minCapacity, maxCapacity, featureSize,
actionSize, includeNextAction), nil
}

Expand Down

0 comments on commit 63a89e8

Please sign in to comment.