Skip to content

Commit

Permalink
Add modelBody option
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaixhin committed Jun 16, 2016
1 parent 5099da4 commit c28443e
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 25 deletions.
72 changes: 47 additions & 25 deletions Model.lua
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
local _ = require 'moses'
local paths = require 'paths'
local classic = require 'classic'
local nn = require 'nn'
local nninit = require 'nninit'
Expand All @@ -17,11 +18,13 @@ local Model = classic.class('Model')
-- Creates a Model (a helper for the network it creates)
function Model:_init(opt)
-- Extract relevant options
self.tensorType = opt.tensorType
self.gpu = opt.gpu
self.colorSpace = opt.colorSpace
self.width = opt.width
self.height = opt.height
self.nChannels = opt.nChannels
self.modelBody = opt.modelBody
self.hiddenSize = opt.hiddenSize
self.histLen = opt.histLen
self.duel = opt.duel
Expand Down Expand Up @@ -50,38 +53,57 @@ function Model:preprocess(observation)
end
end

-- Calculates network output size
-- Runs a forward pass with an uninitialised tensor of the given dimensions
-- and returns the resulting output size as a plain Lua table.
-- NOTE(review): mutates the network's internal output buffers via forward();
-- only call during construction, not during training
local function getOutputSize(net, inputDims)
  return net:forward(torch.Tensor(torch.LongStorage(inputDims))):size():totable()
end

-- Creates a DQN/AC model body
-- Returns an nn module that maps raw input frames to a feature map; the
-- fully-connected "head" layers are added separately in Model:create.
-- Priority: a user-supplied model file (-modelBody), then the ALE (Atari)
-- convolutional body, then a smaller body for simpler environments.
function Model:createBody()
  -- Number of input frames for recurrent networks is always 1
  local histLen = self.recurrent and 1 or self.histLen
  local net

  if paths.filep(self.modelBody) then
    -- Custom body: model must take in TxCxHxW; can use VolumetricConvolution etc.
    net = torch.load(self.modelBody)
    net:type(self.tensorType) -- Cast loaded (possibly pretrained) weights to the configured tensor type
  elseif self.ale then
    net = nn.Sequential()
    net:add(nn.View(histLen*self.nChannels, self.height, self.width)) -- Concatenate history in channel dimension
    net:add(nn.SpatialConvolution(histLen*self.nChannels, 32, 8, 8, 4, 4, 1, 1))
    net:add(nn.ReLU(true))
    net:add(nn.SpatialConvolution(32, 64, 4, 4, 2, 2))
    net:add(nn.ReLU(true))
    net:add(nn.SpatialConvolution(64, 64, 3, 3, 1, 1))
    net:add(nn.ReLU(true))
  else
    -- Smaller network for non-ALE (e.g. Catch) environments
    net = nn.Sequential()
    net:add(nn.View(histLen*self.nChannels, self.height, self.width)) -- Concatenate history in channel dimension
    net:add(nn.SpatialConvolution(histLen*self.nChannels, 32, 5, 5, 2, 2, 1, 1))
    net:add(nn.ReLU(true))
    net:add(nn.SpatialConvolution(32, 32, 5, 5, 2, 2))
    net:add(nn.ReLU(true))
  end

  return net
end

-- Calculates network output size
-- Forward-propagates a dummy tensor of the requested dimensions through the
-- network and returns its output size as a Lua table
local function getOutputSize(net, inputDims)
  local dummyInput = torch.Tensor(torch.LongStorage(inputDims))
  local output = net:forward(dummyInput)
  return output:size():totable()
end

-- Creates a DQN/AC model based on a number of discrete actions
function Model:create(m)
-- Number of input frames for recurrent networks is always 1
local histLen = self.recurrent and 1 or self.histLen

-- Network starting with convolutional layers/model body
local net = nn.Sequential()
if self.recurrent then
net:add(nn.Copy(nil, nil, true)) -- Needed when splitting batch x seq x input over seq for DRQN; better than nn.Contiguous
end

-- Add network body
net:add(self:createBody())
-- Calculate body output size
local bodyOutputSize = torch.prod(torch.Tensor(getOutputSize(net, {histLen, self.nChannels, self.height, self.width})))
net:add(nn.View(bodyOutputSize))

-- Network head
local head = nn.Sequential()
Expand All @@ -90,23 +112,23 @@ function Model:create(m)
-- Value approximator V^(s)
local valStream = nn.Sequential()
if self.recurrent then
local lstm = nn.FastLSTM(convOutputSize, self.hiddenSize, self.histLen)
local lstm = nn.FastLSTM(bodyOutputSize, self.hiddenSize, self.histLen)
lstm.i2g:init({'bias', {{3*self.hiddenSize+1, 4*self.hiddenSize}}}, nninit.constant, 1)
valStream:add(lstm)
else
valStream:add(nn.Linear(convOutputSize, self.hiddenSize))
valStream:add(nn.Linear(bodyOutputSize, self.hiddenSize))
valStream:add(nn.ReLU(true))
end
valStream:add(nn.Linear(self.hiddenSize, 1)) -- Predicts value for state

-- Advantage approximator A^(s, a)
local advStream = nn.Sequential()
if self.recurrent then
local lstm = nn.FastLSTM(convOutputSize, self.hiddenSize, self.histLen)
local lstm = nn.FastLSTM(bodyOutputSize, self.hiddenSize, self.histLen)
lstm.i2g:init({'bias', {{3*self.hiddenSize+1, 4*self.hiddenSize}}}, nninit.constant, 1)
advStream:add(lstm)
else
advStream:add(nn.Linear(convOutputSize, self.hiddenSize))
advStream:add(nn.Linear(bodyOutputSize, self.hiddenSize))
advStream:add(nn.ReLU(true))
end
advStream:add(nn.Linear(self.hiddenSize, m)) -- Predicts action-conditional advantage
Expand All @@ -124,15 +146,15 @@ function Model:create(m)
head:add(DuelAggregator(m))
else
if self.recurrent then
local lstm = nn.FastLSTM(convOutputSize, self.hiddenSize, self.histLen)
local lstm = nn.FastLSTM(bodyOutputSize, self.hiddenSize, self.histLen)
lstm.i2g:init({'bias', {{3*self.hiddenSize+1, 4*self.hiddenSize}}}, nninit.constant, 1) -- Extra: high forget gate bias (Gers et al., 2000)
head:add(lstm)
if self.async then
lstm:remember('both')
head:add(nn.ReLU(true)) -- DRQN paper reports worse performance with ReLU after LSTM, but lets do it anyway...
end
else
head:add(nn.Linear(convOutputSize, self.hiddenSize))
head:add(nn.Linear(bodyOutputSize, self.hiddenSize))
head:add(nn.ReLU(true)) -- DRQN paper reports worse performance with ReLU after LSTM
end
head:add(nn.Linear(self.hiddenSize, m)) -- Note: Tuned DDQN uses shared bias at last layer
Expand All @@ -154,7 +176,7 @@ function Model:create(m)
net:add(nn.GradientRescale(1/self.bootstraps)) -- Normalise gradients by number of heads
net:add(headConcat)
elseif self.a3c then
net:add(nn.Linear(convOutputSize, self.hiddenSize))
net:add(nn.Linear(bodyOutputSize, self.hiddenSize))
net:add(nn.ReLU(true))

local valueAndPolicy = nn.ConcatTable()
Expand Down
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ Run `th main.lua` to run headless, or `qlua main.lua` to display the game. The m

To run experiments based on hyperparameters specified in the individual papers, use `./run.sh <paper> <game> <args>`. `<args>` can be used to overwrite arguments specified earlier (in the script); for more details see the script itself. By default the code trains on a demo environment called Catch - use `./run.sh demo` to run the demo with good default parameters. Note that this code uses CUDA by default if available, but the Catch network is small enough that it runs faster on CPU.

You can use a custom (visual) environment using `-env`, as long as the class provided respects the `rlenvs` [API](https://github.com/Kaixhin/rlenvs#api). If it has separate behaviour during training and testing it should also implement `training` and `evaluate` methods - otherwise these will be added as empty methods during runtime.

You can also use a custom model (body) with `-modelBody`, which replaces the usual DQN convolutional layers with a saved Torch model (which may include pretrained weights). The model will receive the previous frames in a separate dimension to the colour channels, and must reshape them manually if needed; this allows the use of spatiotemporal convolutions. The DQN "heads" will then be constructed as normal, with `-hiddenSize` used to change the size of the fully connected layer if needed.

In training mode if you want to quit using `Ctrl+C` then this will be caught and you will be asked if you would like to save the agent. Note that for non-asynchronous agents the experience replay memory will be included, totalling ~7GB. The main script also automatically saves the weights of the best performing DQN (according to the average validation score).

In evaluation mode you can create recordings with `-record true` (requires FFmpeg); this does not require using `qlua`. Recordings will be stored in the videos directory.
Expand Down
1 change: 1 addition & 0 deletions Setup.lua
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ function Setup:parseOptions(arg)
cmd:option('-width', 84, 'Resize screen width')
cmd:option('-colorSpace', 'y', 'Colour space conversion (screen is RGB): rgb|y|lab|yuv|hsl|hsv|nrgb')
-- Model options
cmd:option('-modelBody', '', 'Path to Torch nn model to be used as DQN "body"')
cmd:option('-hiddenSize', 512, 'Number of units in the hidden fully connected layer')
cmd:option('-histLen', 4, 'Number of consecutive states processed/used for backpropagation-through-time') -- DQN standard is 4, DRQN is 10
cmd:option('-duel', 'true', 'Use dueling network architecture (learns advantage function)')
Expand Down

0 comments on commit c28443e

Please sign in to comment.