Skip to content

Commit

Permalink
Clean up requires
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaixhin committed Aug 14, 2016
1 parent 0010e27 commit c648b5e
Show file tree
Hide file tree
Showing 4 changed files with 7 additions and 23 deletions.
5 changes: 4 additions & 1 deletion Master.lua
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,10 @@ function Master:_init(opt)
end

-- Start gaming
log.info('Starting game: ' .. opt.game)
log.info('Starting ' .. opt.env)
if opt.game then
log.info('Starting game: ' .. opt.game)
end
local state = self.env:start()

-- Set up display (if available)
Expand Down
11 changes: 1 addition & 10 deletions examples/GridWorldNet.lua
Original file line number Diff line number Diff line change
@@ -1,24 +1,15 @@
local nn = require 'nn'
local nninit = require 'nninit'
require 'classic.torch' -- Enables serialisation
--require 'rnn'
--require 'dpnn' -- Adds gradParamClip method

local Body = classic.class('Body')

-- Constructor
function Body:_init(opts)
opts = opts or {}

--self.recurrent = opts.recurrent
--self.histLen = opts.histLen
--self.stateSpec = opts.stateSpec
end

function Body:createBody()
local net

net = nn.Sequential()
local net = nn.Sequential()
net:add(nn.View(2))
net:add(nn.Linear(2, 32))
net:add(nn.ReLU(true))
Expand Down
7 changes: 1 addition & 6 deletions models/Atari.lua
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
local nn = require 'nn'
local nninit = require 'nninit'
require 'classic.torch' -- Enables serialisation
require 'rnn'
require 'dpnn' -- Adds gradParamClip method

local Body = classic.class('Body')

Expand All @@ -18,9 +15,7 @@ end
function Body:createBody()
-- Number of input frames for recurrent networks is always 1
local histLen = self.recurrent and 1 or self.histLen
local net

net = nn.Sequential()
local net = nn.Sequential()
net:add(nn.View(histLen*self.stateSpec[2][1], self.stateSpec[2][2], self.stateSpec[2][3])) -- Concatenate history in channel dimension
net:add(nn.SpatialConvolution(histLen*self.stateSpec[2][1], 32, 8, 8, 4, 4, 1, 1))
net:add(nn.ReLU(true))
Expand Down
7 changes: 1 addition & 6 deletions models/Catch.lua
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
local nn = require 'nn'
local nninit = require 'nninit'
require 'classic.torch' -- Enables serialisation
require 'rnn'
require 'dpnn' -- Adds gradParamClip method

local Body = classic.class('Body')

Expand All @@ -18,9 +15,7 @@ end
function Body:createBody()
-- Number of input frames for recurrent networks is always 1
local histLen = self.recurrent and 1 or self.histLen
local net

net = nn.Sequential()
local net = nn.Sequential()
net:add(nn.View(histLen*self.stateSpec[2][1], self.stateSpec[2][2], self.stateSpec[2][3]))
net:add(nn.SpatialConvolution(histLen*self.stateSpec[2][1], 32, 5, 5, 2, 2, 1, 1))
net:add(nn.ReLU(true))
Expand Down

0 comments on commit c648b5e

Please sign in to comment.