Skip to content

Commit

Permalink
Tweak Setup
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaixhin committed Jun 5, 2016
1 parent b84552a commit 77ffd85
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 65 deletions.
14 changes: 7 additions & 7 deletions ExperienceReplay.lua → Master.lua
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@ local Evaluator = require 'Evaluator'
local classic = require 'classic'


local ExperienceReplay = classic.class('ExperienceReplay')
local Master = classic.class('Master')

function ExperienceReplay:_init(opt)
function Master:_init(opt)
self.opt = opt

-- Set up singleton global object for transferring step
Expand Down Expand Up @@ -76,7 +76,7 @@ function ExperienceReplay:_init(opt)
end


function ExperienceReplay:train()
function Master:train()
self:catchSigInt()

local reward, state, terminal = 0, self.env:start(), false
Expand Down Expand Up @@ -149,7 +149,7 @@ function ExperienceReplay:train()
end


function ExperienceReplay:validate()
function Master:validate()
log.info('Validating')
-- Set environment and agent to evaluation mode
if self.opt.ale then self.env:evaluate() end
Expand Down Expand Up @@ -230,7 +230,7 @@ function ExperienceReplay:validate()
end


function ExperienceReplay:evaluate()
function Master:evaluate()
log.info('Evaluation mode')
-- Set environment and agent to evaluation mode
if self.opt.ale then self.env:evaluate() end
Expand Down Expand Up @@ -261,7 +261,7 @@ end


-- Set up SIGINT (Ctrl+C) handler to save network before quitting
function ExperienceReplay:catchSigInt()
function Master:catchSigInt()
signal.signal(signal.SIGINT, function(signum)
log.warn('SIGINT received')
log.info('Save agent (y/n)?')
Expand All @@ -275,4 +275,4 @@ function ExperienceReplay:catchSigInt()
end


return ExperienceReplay
return Master
97 changes: 51 additions & 46 deletions Setup.lua
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
----- General Setup -----
require 'logroll'
local cjson = require 'cjson'
local classic = require 'classic'
local _ = require 'moses'
local classic = require 'classic'
local cjson = require 'cjson'

local Setup = classic.class('Setup')

-- Performs global setup
function Setup:_init(arg)
-- Create log10 for Lua 5.2
if not math.log10 then
Expand All @@ -14,49 +14,61 @@ function Setup:_init(arg)
end
end

local opt = self:options(arg)
-- Parse command-line options
self.opt = self:parseOptions(arg)

-- Create experiment directory
if not paths.dirp(self.opt.experiments) then
paths.mkdir(self.opt.experiments)
end
paths.mkdir(paths.concat(self.opt.experiments, self.opt._id))
-- Save options for reference
local file = torch.DiskFile(paths.concat(self.opt.experiments, self.opt._id, 'opts.json'), 'w')
file:writeString(cjson.encode(self.opt))
file:close()

-- Set up logs
local flog = logroll.file_logger(paths.concat(opt.experiments, opt._id, 'log.txt'))
-- Set up logging
local flog = logroll.file_logger(paths.concat(self.opt.experiments, self.opt._id, 'log.txt'))
local plog = logroll.print_logger()
log = logroll.combine(flog, plog)
log = logroll.combine(flog, plog) -- Global logger

self:validateOptions(opt)
-- Validate command-line options (logging errors)
self:validateOptions()

-- Torch setup
log.info('Setting up Torch7')
-- Use enhanced garbage collector
torch.setheaptracking(true)
-- Set number of BLAS threads
torch.setnumthreads(opt.threads)
torch.setnumthreads(self.opt.threads)
-- Set default Tensor type (float is more efficient than double)
torch.setdefaulttensortype(opt.tensorType)
torch.setdefaulttensortype(self.opt.tensorType)
-- Set manual seed
torch.manualSeed(opt.seed)
torch.manualSeed(self.opt.seed)

-- Tensor creation function for removing need to cast to CUDA if GPU is enabled
opt.Tensor = function(...)
-- TODO: Replace with local functions across codebase
self.opt.Tensor = function(...)
return torch.Tensor(...)
end

-- GPU setup
if opt.gpu > 0 then
if self.opt.gpu > 0 then
log.info('Setting up GPU')
cutorch.setDevice(opt.gpu)
cutorch.setDevice(self.opt.gpu)
-- Set manual seeds using random numbers to reduce correlations
cutorch.manualSeed(torch.random())
-- Replace tensor creation function
opt.Tensor = function(...)
self.opt.Tensor = function(...)
return torch.CudaTensor(...)
end
end

self.opt = opt
classic.strict(self)
end


function Setup:options(arg)
-- Parses command-line options
function Setup:parseOptions(arg)
-- Detect and use GPU 1 by default
local cuda = pcall(require, 'cutorch')

Expand Down Expand Up @@ -113,16 +125,16 @@ function Setup:options(arg)
cmd:option('-valFreq', 250000, 'Interval of steps between validating agent') -- valFreq steps is used as an epoch, hence #epochs = steps/valFreq
cmd:option('-valSteps', 125000, 'Number of steps to use for validation')
cmd:option('-valSize', 500, 'Number of transitions to use for calculating validation statistics')
-- Async options
cmd:option('-async', 'false', 'Async agent: false|Sarsa|OneStepQ|NStepQ|A3C') -- TODO: Change names
cmd:option('-rmsEpsilon', 0.1, 'Epsilon for sharedRmsProp')
cmd:option('-noValidation', 'false', 'Disable asynchronous agent validation thread') -- TODO: Make experiment option (not just for async)
-- ALEWrap options
cmd:option('-fullActions', 'false', 'Use full set of 18 actions')
cmd:option('-actRep', 4, 'Times to repeat action') -- Independent of history length
cmd:option('-randomStarts', 30, 'Max number of no-op actions played before presenting the start of each training episode')
cmd:option('-poolFrmsType', 'max', 'Type of pooling over previous emulator frames: max|mean')
cmd:option('-poolFrmsSize', 2, 'Number of emulator frames to pool over')
-- Async options
cmd:option('-async', 'false', 'async method') -- OneStepQ|NStepQ|Sarsa|A3C
cmd:option('-rmsEpsilon', 0.1, 'Epsilon for sharedRmsProp')
cmd:option('-novalidation', 'false', 'dont run validation thread in async') -- for debugging
-- Experiment options
cmd:option('-experiments', 'experiments', 'Base directory to store experiments')
cmd:option('-_id', '', 'ID of experiment (used to store saved results, defaults to game name)')
Expand All @@ -140,7 +152,9 @@ function Setup:options(arg)
opt.fullActions = opt.fullActions == 'true'
opt.verbose = opt.verbose == 'true'
opt.record = opt.record == 'true'
opt.novalidation = opt.novalidation == 'true'
opt.noValidation = opt.noValidation == 'true'

-- Process async agent options
if opt.async == 'false' then opt.async = false end
if opt.async then opt.gpu = 0 end

Expand All @@ -149,70 +163,61 @@ function Setup:options(arg)
opt._id = opt.game
end

-- Set ALE flag
-- TODO: Make environment more independent
opt.ale = opt.game ~= 'catch'

-- Create experiment directory
if not paths.dirp(opt.experiments) then
paths.mkdir(opt.experiments)
end
paths.mkdir(paths.concat(opt.experiments, opt._id))
-- Save options for reference
local file = torch.DiskFile(paths.concat(opt.experiments, opt._id, 'opts.json'), 'w')
file:writeString(cjson.encode(opt))
file:close()

return opt
end


function Setup:validateOptions(opt)
-- Validates command-line options
function Setup:validateOptions()
-- Calculate number of colour channels
if not _.contains({'rgb', 'y', 'lab', 'yuv', 'hsl', 'hsv', 'nrgb'}, opt.colorSpace) then
if not _.contains({'rgb', 'y', 'lab', 'yuv', 'hsl', 'hsv', 'nrgb'}, self.opt.colorSpace) then
self:abort('Unsupported colour space for conversion')
end
opt.nChannels = opt.colorSpace == 'y' and 1 or 3
self.opt.nChannels = self.opt.colorSpace == 'y' and 1 or 3

-- Check start of learning occurs after at least one minibatch of data has been collected
if opt.learnStart <= opt.batchSize then
if self.opt.learnStart <= self.opt.batchSize then
self:abort('learnStart must be greater than batchSize')
end

-- Check enough validation transitions will be collected before first validation
if opt.valFreq <= opt.valSize then
if self.opt.valFreq <= self.opt.valSize then
self:abort('valFreq must be greater than valSize')
end

-- Check prioritised experience replay options
if not _.contains({'none', 'rank', 'proportional'}, opt.memPriority) then
if not _.contains({'none', 'rank', 'proportional'}, self.opt.memPriority) then
self:abort('Type of prioritised experience replay unrecognised')
end

-- Check start of learning occurs after at least 1/100 of memory has been filled
if opt.learnStart <= opt.memSize/100 then
if self.opt.learnStart <= self.opt.memSize/100 then
self:abort('learnStart must be greater than memSize/100')
end

-- Check memory size is multiple of 100 (makes prioritised sampling partitioning simpler)
if opt.memSize % 100 ~= 0 then
if self.opt.memSize % 100 ~= 0 then
self:abort('memSize must be a multiple of 100')
end

-- Check learning occurs after first progress report
if opt.learnStart < opt.progFreq then
if self.opt.learnStart < self.opt.progFreq then
self:abort('learnStart must be greater than progFreq')
end

-- Check saliency map options
if not _.contains({'none', 'normal', 'guided', 'deconvnet'}, opt.saliency) then
if not _.contains({'none', 'normal', 'guided', 'deconvnet'}, self.opt.saliency) then
self:abort('Unrecognised method for visualising saliency maps')
end
end


-- Aborts setup (if options are invalid)
-- Writes the message at error level through the global `log` logger, then
-- raises it as a Lua error via error().
-- @param err Message describing why setup cannot continue
function Setup:abort(err)
log.error(err)
error(err)
end

return Setup

4 changes: 2 additions & 2 deletions async/AsyncMaster.lua
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ function AsyncMaster:_init(opt)
local signal = require 'posix.signal'
local ValidationAgent = require 'async/ValidationAgent'
validAgent = ValidationAgent(opt, theta, atomic)
if not opt.novalidation then
if not opt.noValidation then
signal.signal(signal.SIGINT, function(signum)
log.warn('SIGINT received')
log.info('Saving agent')
Expand Down Expand Up @@ -204,7 +204,7 @@ function AsyncMaster:start()
end
end

if not self.opt.novalidation then
if not self.opt.noValidation then
self.controlPool:addjob(validator)
end

Expand Down
18 changes: 8 additions & 10 deletions main.lua
Original file line number Diff line number Diff line change
@@ -1,24 +1,22 @@
local Setup = require 'Setup'
local ExperienceReplay = require 'ExperienceReplay'
local Master = require 'Master'
local AsyncMaster = require 'async/AsyncMaster'

-- Parse options and perform setup
local setup = Setup(arg)
local opt = setup.opt

-- Start master experiment runner
if opt.async then
log.info(opt)
local master = AsyncMaster(opt)
master:start()

master:start() -- TODO: Use same API as normal master
else
local experienceReplay = ExperienceReplay(opt)
local master = Master(opt)

if opt.mode == 'train' then
experienceReplay:train()

master:train()
elseif opt.mode == 'eval' then
experienceReplay:evaluate()

master:evaluate()
end

end
end

0 comments on commit 77ffd85

Please sign in to comment.