Skip to content

Commit

Permalink
Add MCTS classes & index file
Browse files Browse the repository at this point in the history
  • Loading branch information
quasimik committed Jul 9, 2018
1 parent 125c813 commit e18284f
Show file tree
Hide file tree
Showing 3 changed files with 330 additions and 0 deletions.
25 changes: 25 additions & 0 deletions index.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
'use strict'

const Game = require('./game.js')
const MonteCarlo = require('./monte-carlo.js')

// Driver script: let the Monte Carlo search play a full game against itself,
// printing the board after every play and the winner at the end.
const game = new Game()
let state = game.start()
console.log(state.board)

const monteCarlo = new MonteCarlo(game)

let winner = game.winner(state)
while (winner === null) {

  // Build statistics from the current state; the time budget is 0.1 seconds
  monteCarlo.runSims(state, 0.1)
  // "best" policy: pick the play whose child has the highest win rate
  const play = monteCarlo.getPlay(state, "best")

  state = game.nextState(state, play)
  // Cells holding -1 are shown as 2 so the printed board has no negative numbers
  const printBoard = state.board.map((row) => row.map((cell) => (cell === -1 ? 2 : cell)))
  console.log(printBoard)

  winner = game.winner(state)
}
console.log(winner)
108 changes: 108 additions & 0 deletions monte-carlo-node.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
'use strict'

/**
 * Class representing a node in the search tree.
 * Stores the wins/simulations statistics used by UCB1.
 */
class MonteCarloNode {
  /**
   * Create a new MonteCarloNode in the search tree.
   * @param {?MonteCarloNode} parent - The parent node, or null for a node without a parent.
   * @param {?Play} play - The last play played to get to this state, or null if there is none.
   * @param {State} state - The game state corresponding to this node.
   * @param {Play[]} unexpandedPlays - The legal plays from this state; each must provide hash().
   */
  constructor(parent, play, state, unexpandedPlays) {
    this.play = play // Last play played to get to this state
    this.state = state // Corresponding state

    // Monte Carlo stuff
    this.n_plays = 0
    this.n_wins = 0

    // Tree stuff
    this.parent = parent // Parent MonteCarloNode
    // Map: hash(play) => { play, node }, where node stays null until the play is expanded
    this.children = new Map()
    for (const unexpandedPlay of unexpandedPlays) {
      this.children.set(unexpandedPlay.hash(), { play: unexpandedPlay, node: null })
    }
  }

  /**
   * Get the MonteCarloNode corresponding to the given play.
   * @param {Play} play - The play leading to the child node.
   * @return {MonteCarloNode} The child node corresponding to the play given.
   * @throws {Error} If the play is not a child of this node, or its node is not expanded yet.
   */
  childNode(play) {
    const child = this.children.get(play.hash())
    if (child === undefined) {
      throw new Error('No such play!')
    }
    if (child.node === null) {
      throw new Error("Child is not expanded!")
    }
    return child.node
  }

  /**
   * Expand the given child play: create its node, store it in the children map
   * (replacing the unexpanded entry) and return it.
   * @param {Play} play - The play to expand.
   * @param {State} childState - The child state corresponding to the given play.
   * @param {Play[]} childUnexpandedPlays - Legal plays of the given child.
   * @return {MonteCarloNode} The new child node.
   * @throws {Error} If the play is not a child of this node.
   */
  expand(play, childState, childUnexpandedPlays) {
    if (!this.children.has(play.hash())) throw new Error("No such play!")
    const childNode = new MonteCarloNode(this, play, childState, childUnexpandedPlays)
    this.children.set(play.hash(), { play: play, node: childNode })
    return childNode
  }

  /**
   * @return {Play[]} All plays available from this node, expanded or not.
   */
  allPlays() {
    const ret = []
    for (const child of this.children.values()) {
      ret.push(child.play)
    }
    return ret
  }

  /**
   * @return {Play[]} The plays from this node that do not have a child node yet.
   */
  unexpandedPlays() {
    const ret = []
    for (const child of this.children.values()) {
      if (child.node === null) ret.push(child.play)
    }
    return ret
  }

  /**
   * @return {boolean} Whether all the children plays have expanded nodes
   */
  isFullyExpanded() {
    for (const child of this.children.values()) {
      if (child.node === null) return false
    }
    return true
  }

  /**
   * @return {boolean} Whether this node has no plays at all (no entries in the children map)
   */
  isLeaf() {
    return this.children.size === 0
  }

  /**
   * Get the UCB1 value for this node.
   * Reads this.parent.n_plays, so it must only be called on nodes that have a parent.
   * @param {number} [biasParam=2] - The square of the bias parameter in the UCB1 algorithm.
   * @return {number} The UCB1 value of this node.
   */
  getUCB1(biasParam = 2) {
    // Exploitation term (win rate) + exploration term
    return (this.n_wins / this.n_plays) + Math.sqrt(biasParam * Math.log(this.parent.n_plays) / this.n_plays)
  }

}

module.exports = MonteCarloNode
197 changes: 197 additions & 0 deletions monte-carlo.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
'use strict'

const MonteCarloNode = require('./monte-carlo-node.js')

/** Class representing the Monte Carlo search tree. */
class MonteCarlo {

  /**
   * Create a Monte Carlo search tree.
   * @param {Game} game - The game to query regarding legal moves and state advancement.
   * @param {number} UCB1ExploreParam - The square of the bias parameter in the UCB1 algorithm, defaults to 2.
   */
  constructor(game, UCB1ExploreParam = 2) {
    this.game = game
    this.UCB1ExploreParam = UCB1ExploreParam
    this.nodes = new Map() // map: hash(State) => MonteCarloNode
  }

  /**
   * If no node exists for the given state, create a dangling node for it
   * (parent and play are both null).
   * @param {State} state - The state to make a node for.
   */
  makeNode(state) {
    if (!this.nodes.has(state.hash())) {
      // Copy the array so the node does not share it with the game
      const unexpandedPlays = this.game.legalPlays(state).slice()
      const node = new MonteCarloNode(null, null, state, unexpandedPlays)
      this.nodes.set(state.hash(), node)
    }
  }

  /**
   * From given state, run as many simulations as possible until the time limit, building statistics.
   * Prints a short summary (simulation count, rate, draws) to the console when done.
   * @param {State} state - The state to run the simulations from.
   * @param {number} timeout - The time to run the simulations for, in seconds.
   */
  runSims(state, timeout) {

    this.makeNode(state)

    let draws = 0
    let totalSims = 0

    const end = Date.now() + timeout * 1000

    while (Date.now() < end) {

      let node = this.select(state)
      // Only grow the tree below undecided states: a state that already has a
      // winner may still report legal plays, and expanding it would add
      // children that can never be reached in a real game.
      if (!node.isLeaf() && this.game.winner(node.state) === null) {
        node = this.expand(node)
      }
      const winner = this.simulate(node)
      this.backpropagate(node, winner)

      if (winner === 0) draws++ // a winner of 0 is counted as a draw
      totalSims++
    }

    console.log(`time(s) ${timeout}/${timeout} (FINISHED)`)
    console.log(`total sims : ${totalSims}`)
    console.log(`total rate(sims/s) : ${(totalSims / timeout).toFixed(1)}`)
    console.log(`draws : ${draws}`) // no winner
  }

  /**
   * From the available statistics, calculate the best move from the given state.
   * @param {State} state - The state to pick a play from.
   * @param {string} policy - "robust" (most visited child, the default) or
   *   "best" (child with the highest win rate).
   * @return {?Play} The chosen play, or null when not every child of the state
   *   has been expanded yet (not enough information). An unrecognised policy
   *   yields undefined.
   */
  getPlay(state, policy = "robust") {

    this.makeNode(state)

    const node = this.nodes.get(state.hash())

    // If not all children are expanded, not enough information
    if (node.isFullyExpanded() === false)
      return null

    const allPlays = node.allPlays()
    let bestPlay

    // The running maximum starts at -Infinity so that a play is still chosen
    // when every candidate scores 0 (e.g. no child has a single win).

    // Most visits (Chaslot's robust child)
    if (policy === "robust") {
      let max = -Infinity
      for (const play of allPlays) {
        const childNode = node.childNode(play)
        if (childNode.n_plays > max) {
          bestPlay = play
          max = childNode.n_plays
        }
      }
    }
    // Highest winrate (Best child)
    else if (policy === "best") {
      let max = -Infinity
      for (const play of allPlays) {
        const childNode = node.childNode(play)
        const ratio = childNode.n_wins / childNode.n_plays
        if (ratio > max) {
          bestPlay = play
          max = ratio
        }
      }
    }

    return bestPlay
  }

  /**
   * Phase 1: Selection
   * Starting at the node of the given state, descend by highest UCB1 value
   * until reaching a node that is EITHER not fully expanded OR a leaf.
   * @param {State} state - The state whose node to start from; the node must already exist.
   * @return {MonteCarloNode} The selected node.
   */
  select(state) {
    let node = this.nodes.get(state.hash())
    while (node.isFullyExpanded() && !node.isLeaf()) {
      let bestPlay
      // UCB1 can be exactly 0 (no wins, and log(1) = 0 for a parent visited
      // once), so the running maximum must start below 0.
      let bestUCB1 = -Infinity
      for (const play of node.allPlays()) {
        const childUCB1 = node.childNode(play).getUCB1(this.UCB1ExploreParam)
        if (childUCB1 > bestUCB1) {
          bestPlay = play
          bestUCB1 = childUCB1
        }
      }
      node = node.childNode(bestPlay)
    }
    return node
  }

  /**
   * Phase 2: Expansion
   * Of the given node, expand a random unexpanded child node.
   * Assumes the given node has at least one unexpanded play.
   * @param {MonteCarloNode} node - The node to expand.
   * @return {MonteCarloNode} The newly created child node.
   */
  expand(node) {
    const plays = node.unexpandedPlays()
    const play = plays[Math.floor(Math.random() * plays.length)]

    const childState = this.game.nextState(node.state, play)
    const childUnexpandedPlays = this.game.legalPlays(childState)
    const childNode = node.expand(play, childState, childUnexpandedPlays)
    this.nodes.set(childState.hash(), childNode)

    return childNode
  }

  /**
   * Phase 3: Simulation
   * From given node, play random legal plays until the game reports a winner, then return it.
   * @param {MonteCarloNode} node - The node to start the playout from.
   * @return {number} The winner as reported by game.winner().
   */
  simulate(node) {
    let state = node.state
    let winner = this.game.winner(state)
    while (winner === null) {
      const plays = this.game.legalPlays(state)
      const play = plays[Math.floor(Math.random() * plays.length)]
      state = this.game.nextState(state, play)
      winner = this.game.winner(state)
    }
    return winner
  }

  /**
   * Phase 4: Backpropagation
   * From given node, propagate winner to ancestors' statistics.
   * @param {MonteCarloNode} node - The node the simulation started from.
   * @param {number} winner - The winner of the simulation.
   */
  backpropagate(node, winner) {
    while (node !== null) {
      node.n_plays += 1
      // Flip for parent's choice: the win is credited when the winner is the
      // negation of node.state.player (presumably the player who moved into
      // this state; confirm against Game).
      if (node.state.player === -winner) {
        node.n_wins += 1
      }
      node = node.parent
    }
  }

}

module.exports = MonteCarlo

0 comments on commit e18284f

Please sign in to comment.