Simplify Env setup
Kaixhin committed Jun 5, 2016
1 parent 77ffd85, commit 0de62cb
Showing 4 changed files with 20 additions and 49 deletions.
Master.lua: 32 changes (8 additions, 24 deletions)
@@ -1,12 +1,11 @@
+local Display = require 'Display'
+local signal = require 'posix.signal'
 local _ = require 'moses'
-local signal = require 'posix.signal'
+local classic = require 'classic'
 local gnuplot = require 'gnuplot'
 local Singleton = require 'structures/Singleton'
 local Agent = require 'Agent'
 local Evaluator = require 'Evaluator'
-local classic = require 'classic'
-
-local Display = require 'Display'
 
 local Master = classic.class('Master')
 
@@ -16,28 +15,13 @@ function Master:_init(opt)
   -- Set up singleton global object for transferring step
   self.globals = Singleton({step = 1}) -- Initial step
 
-  ----- Environment + Agent Setup -----
-
   -- Initialise Catch or Arcade Learning Environment
   log.info('Setting up ' .. (opt.ale and 'Arcade Learning Environment' or 'Catch'))
-  if opt.ale then
-    local Atari = require 'rlenvs.Atari'
-    self.env = Atari(opt)
-    local stateSpec = self.env:getStateSpec()
-
-    -- Provide original channels, height and width for resizing from
-    opt.origChannels, opt.origHeight, opt.origWidth = table.unpack(stateSpec[2])
-  else
-    local Catch = require 'rlenvs.Catch'
-    self.env = Catch()
-    local stateSpec = self.env:getStateSpec()
-
-    -- Provide original channels, height and width for resizing from
-    opt.origChannels, opt.origHeight, opt.origWidth = table.unpack(stateSpec[2])
-
-    -- Adjust height and width
-    opt.height, opt.width = stateSpec[2][2], stateSpec[2][3]
-  end
+  local Env = opt.ale and require 'rlenvs.Atari' or require 'rlenvs.Catch'
+  self.env = Env(opt)
+  local stateSpec = self.env:getStateSpec()
+  -- Provide original channels, height and width for resizing from
+  opt.origChannels, opt.origHeight, opt.origWidth = table.unpack(stateSpec[2])
 
   -- Create DQN agent
   log.info('Creating DQN')
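The collapsed setup relies on Lua's `and`/`or` selection idiom. A minimal sketch of its behaviour, with a made-up `opt` table (only the two rlenvs module names are taken from the diff):

-- `cond and a or b` short-circuits, so only one module is loaded:
-- it yields `a` when `cond` is truthy, otherwise `b`. The caveat is
-- that it falls through to `b` whenever `a` itself is false or nil;
-- require returns the module table, which is always truthy, so the
-- idiom is safe here.
local opt = { ale = false } -- made-up example value
local Env = opt.ale and require 'rlenvs.Atari' or require 'rlenvs.Catch'
local env = Env(opt) -- both environments are now constructed the same way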
Setup.lua: 4 changes (2 additions, 2 deletions)
@@ -97,7 +97,7 @@ function Setup:parseOptions(arg)
   cmd:option('-memSize', 1e6, 'Experience replay memory size (number of tuples)')
   cmd:option('-memSampleFreq', 4, 'Interval of steps between sampling from memory to learn')
   cmd:option('-memNSamples', 1, 'Number of times to sample per learning step')
-  cmd:option('-memPriority', 'rank', 'Type of prioritised experience replay: none|rank|proportional')
+  cmd:option('-memPriority', 'rank', 'Type of prioritised experience replay: none|rank|proportional') -- TODO: Implement proportional prioritised experience replay
   cmd:option('-alpha', 0.65, 'Prioritised experience replay exponent α') -- Best vals are rank = 0.7, proportional = 0.6
   cmd:option('-betaZero', 0.45, 'Initial value of importance-sampling exponent β') -- Best vals are rank = 0.5, proportional = 0.4
   -- Reinforcement learning parameters
@@ -156,7 +156,7 @@ function Setup:parseOptions(arg)

   -- Process async agent options
   if opt.async == 'false' then opt.async = false end
-  if opt.async then opt.gpu = 0 end
+  if opt.async then opt.gpu = 0 end -- Asynchronous agents are CPU-only
 
   -- Set ID as game name if not set
   if opt._id == '' then
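The string comparison above is needed because values parsed by torch's CmdLine keep the type of their declared default; `-async` presumably defaults to the string 'false', and any non-nil, non-false value (including that string) is truthy in Lua. A minimal sketch of the coercion, with made-up values:

local opt = { async = 'false', gpu = 1 } -- made-up example values
if opt.async == 'false' then opt.async = false end -- string -> boolean
if opt.async then opt.gpu = 0 end -- skipped: opt.async is now a real false
print(opt.gpu) --> 1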
async/AsyncModel.lua: 27 changes (7 additions, 20 deletions)
@@ -4,26 +4,13 @@ local Model = require 'Model'
 local AsyncModel = classic.class('AsyncModel')
 
 function AsyncModel:_init(opt)
   -- Initialise Catch or Arcade Learning Environment
   log.info('Setting up ' .. (opt.ale and 'Arcade Learning Environment' or 'Catch'))
-
-  if opt.ale then
-    local Atari = require 'rlenvs.Atari'
-    self.env = Atari(opt)
-    local stateSpec = self.env:getStateSpec()
-
-    -- Provide original channels, height and width for resizing from
-    opt.origChannels, opt.origHeight, opt.origWidth = table.unpack(stateSpec[2])
-  else
-    local Catch = require 'rlenvs.Catch'
-    self.env = Catch()
-    local stateSpec = self.env:getStateSpec()
-
-    -- Provide original channels, height and width for resizing from
-    opt.origChannels, opt.origHeight, opt.origWidth = table.unpack(stateSpec[2])
-
-    -- Adjust height and width
-    opt.height, opt.width = stateSpec[2][2], stateSpec[2][3]
-  end
+  local Env = opt.ale and require 'rlenvs.Atari' or require 'rlenvs.Catch'
+  self.env = Env(opt)
+  local stateSpec = self.env:getStateSpec()
+  -- Provide original channels, height and width for resizing from
+  opt.origChannels, opt.origHeight, opt.origWidth = table.unpack(stateSpec[2])
 
   self.model = Model(opt)
+  self.a3c = opt.async == 'A3C'
@@ -42,4 +29,4 @@ function AsyncModel:createNet()
 end
 
 
-return AsyncModel
\ No newline at end of file
+return AsyncModel
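For reference, the `stateSpec` unpacking kept by both rewritten files appears to follow the rlenvs convention of a {type, shape, bounds} triple, so index 2 holds the tensor dimensions. A sketch with illustrative values (the concrete numbers are assumptions, not taken from rlenvs):

local stateSpec = {'real', {1, 24, 24}, {0, 1}} -- assumed Catch-like spec
local origChannels, origHeight, origWidth = table.unpack(stateSpec[2])
print(origChannels, origHeight, origWidth) --> 1  24  24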
run.sh: 6 changes (3 additions, 3 deletions)

@@ -32,7 +32,7 @@ fi

if [ "$PAPER" == "demo" ]; then
# Catch demo
th main.lua -gpu 0 -hiddenSize 32 -optimiser adam -steps 500000 -learnStart 50000 -tau 4 -memSize 50000 -epsilonSteps 10000 -valFreq 10000 -valSteps 6000 -bootstraps 0 -PALpha 0 "$@"
th main.lua -gpu 0 -height 24 -width 24 -hiddenSize 32 -optimiser adam -steps 500000 -learnStart 50000 -tau 4 -memSize 50000 -epsilonSteps 10000 -valFreq 10000 -valSteps 6000 -bootstraps 0 -PALpha 0 "$@"
elif [ "$PAPER" == "nature" ]; then
# Nature
th main.lua -game $GAME -duel false -bootstraps 0 -memPriority none -epsilonEnd 0.1 -tau 10000 -doubleQ false -PALpha 0 -eta 0.00025 -gradClip 0 "$@"
@@ -62,10 +62,10 @@ elif [ "$PAPER" == "recurrent" ]; then
 # Async modes
 elif [ "$PAPER" == "demo-async" ]; then
   # N-Step Q-learning Catch demo
-  th main.lua -async NStepQ -eta 0.00025 -momentum 0.99 -bootstraps 0 -batchSize 5 -hiddenSize 32 -doubleQ false -duel false -optimiser adam -steps 15000000 -tau 4 -memSize 20000 -epsilonSteps 10000 -valFreq 10000 -valSteps 6000 -bootstraps 0 -PALpha 0 "$@"
+  th main.lua -height 24 -width 24 -async NStepQ -eta 0.00025 -momentum 0.99 -bootstraps 0 -batchSize 5 -hiddenSize 32 -doubleQ false -duel false -optimiser adam -steps 15000000 -tau 4 -memSize 20000 -epsilonSteps 10000 -valFreq 10000 -valSteps 6000 -bootstraps 0 -PALpha 0 "$@"
 elif [ "$PAPER" == "demo-async-a3c" ]; then
   # A3C Catch demo
-  th main.lua -async A3C -eta 0.0007 -momentum 0.99 -bootstraps 0 -batchSize 5 -hiddenSize 32 -doubleQ false -duel false -optimiser adam -steps 15000000 -tau 4 -memSize 20000 -epsilonSteps 10000 -valFreq 10000 -valSteps 6000 -bootstraps 0 -PALpha 0 "$@"
+  th main.lua -height 24 -width 24 -async A3C -eta 0.0007 -momentum 0.99 -bootstraps 0 -batchSize 5 -hiddenSize 32 -doubleQ false -duel false -optimiser adam -steps 15000000 -tau 4 -memSize 20000 -epsilonSteps 10000 -valFreq 10000 -valSteps 6000 -bootstraps 0 -PALpha 0 "$@"
 elif [ "$PAPER" == "async-nstep" ]; then
   # Steps for "1 day" = 80 * 1e6; for "4 days" = 1e9
   th main.lua -async NStepQ -bootstraps 0 -batchSize 5 -momentum 0.99 -rmsEpsilon 0.1 -steps 80000000 -game $GAME -duel false -tau 40000 -optimiser sharedRmsProp -epsilonSteps 4000000 -doubleQ false -PALpha 0 -eta 0.0007 -gradClip 0 "$@"
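Note: the new explicit -height 24 -width 24 flags in the Catch demos appear to compensate for the deleted Catch-specific branch above, which previously wrote opt.height and opt.width back from the environment's state spec.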
