-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
330 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
'use strict' | ||
|
||
const Game = require('./game.js') | ||
const MonteCarlo = require('./monte-carlo.js') | ||
|
||
// Driver script: plays one full game in which every play is chosen by Monte Carlo search.

// Seconds of simulation to run from the current state before each play is chosen
const SIM_SECONDS = 0.1

const game = new Game()
let state = game.start()
console.log(state.board)

const monteCarlo = new MonteCarlo(game)

let winner = game.winner(state)
while (winner === null) {

  // Build statistics from the current state, then ask for the play with the highest win rate
  monteCarlo.runSims(state, SIM_SECONDS)
  const play = monteCarlo.getPlay(state, "best")
  if (play === null || play === undefined) {
    // getPlay yields no play when the search has not gathered enough statistics;
    // fail here with a clear message instead of passing a missing play to nextState
    throw new Error('Monte Carlo search did not produce a play for the current state')
  }

  state = game.nextState(state, play)
  // Cells holding -1 are printed as 2 (presumably so both players print as positive numbers)
  const printBoard = state.board.map((row) => row.map((cell) => cell === -1 ? 2 : cell))
  console.log(printBoard)

  winner = game.winner(state)
}
console.log(winner)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
'use strict' | ||
|
||
/**
 * Class representing a node in the search tree.
 * Stores the win/simulation counters used by the UCB1 formula.
 */
class MonteCarloNode {
  /**
   * Create a new MonteCarloNode in the search tree.
   * @param {MonteCarloNode|null} parent - The parent node (null for a node without a parent).
   * @param {Play|null} play - The last play made to reach this node's state.
   * @param {State} state - The state this node corresponds to.
   * @param {Play[]} unexpandedPlays - Plays available from this state; each must provide hash().
   */
  constructor(parent, play, state, unexpandedPlays) {
    this.play = play // Last play played to get to this state
    this.state = state // Corresponding state

    // Monte Carlo statistics
    this.n_plays = 0
    this.n_wins = 0

    // Tree structure
    this.parent = parent // Parent MonteCarloNode
    // Map: play.hash() => { play, node }; node stays null until the play is expanded
    this.children = new Map()
    for (const childPlay of unexpandedPlays) {
      this.children.set(childPlay.hash(), { play: childPlay, node: null })
    }
  }

  /**
   * Get the MonteCarloNode corresponding to the given play.
   * @param {Play} play - The play leading to the child node.
   * @return {MonteCarloNode} The child node corresponding to the play given.
   * @throws {Error} If the play is not a child of this node, or its node is not expanded yet.
   */
  childNode(play) {
    const child = this.children.get(play.hash())
    if (child === undefined) {
      throw new Error('No such play!')
    }
    if (child.node === null) {
      throw new Error("Child is not expanded!")
    }
    return child.node
  }

  /**
   * Expand the given child play: create its node and store it in the children map,
   * so the play no longer counts as unexpanded.
   * @param {Play} play - The play to expand.
   * @param {State} childState - The child state corresponding to the given play.
   * @param {Play[]} childUnexpandedPlays - Plays available from the child state.
   * @return {MonteCarloNode} The new child node.
   * @throws {Error} If the play is not a child of this node.
   */
  expand(play, childState, childUnexpandedPlays) {
    const key = play.hash()
    if (!this.children.has(key)) throw new Error("No such play!")
    const childNode = new MonteCarloNode(this, play, childState, childUnexpandedPlays)
    this.children.set(key, { play: play, node: childNode })
    return childNode
  }

  /**
   * @return {Play[]} Every play available from this node, expanded or not.
   */
  allPlays() {
    return Array.from(this.children.values(), (child) => child.play)
  }

  /**
   * @return {Play[]} The plays of this node that have no child node yet.
   */
  unexpandedPlays() {
    const plays = []
    for (const child of this.children.values()) {
      if (child.node === null) plays.push(child.play)
    }
    return plays
  }

  /**
   * @return {boolean} Whether all the children plays have expanded nodes
   */
  isFullyExpanded() {
    for (const child of this.children.values()) {
      if (child.node === null) return false
    }
    return true
  }

  /**
   * @return {boolean} Whether this node has no child plays at all
   */
  isLeaf() {
    return this.children.size === 0
  }

  /**
   * Get the UCB1 value for this node. Reads the parent's n_plays, so the node must have a parent
   * once it has been visited.
   * @param {number} biasParam - The square of the bias parameter in the UCB1 algorithm, defaults to 2.
   * @return {number} The UCB1 value of this node; Infinity if the node has never been visited.
   */
  getUCB1(biasParam = 2) {
    // 0 / 0 would yield NaN, which loses every comparison; an unvisited node ranks first instead
    if (this.n_plays === 0) return Infinity
    const winRate = this.n_wins / this.n_plays
    const exploration = Math.sqrt(biasParam * Math.log(this.parent.n_plays) / this.n_plays)
    return winRate + exploration
  }

}
|
||
module.exports = MonteCarloNode |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
'use strict' | ||
|
||
const MonteCarloNode = require('./monte-carlo-node.js') | ||
|
||
/** Class representing the Monte Carlo search tree. */
class MonteCarlo {

  /**
   * Create a Monte Carlo search tree.
   * @param {Game} game - The game to query regarding legal moves and state advancement.
   * @param {number} UCB1ExploreParam - The square of the bias parameter in the UCB1 algorithm, defaults to 2.
   */
  constructor(game, UCB1ExploreParam = 2) {
    this.game = game
    this.UCB1ExploreParam = UCB1ExploreParam
    this.nodes = new Map() // map: state.hash() => MonteCarloNode
  }

  /**
   * If no node exists for the given state, create a dangling node for it.
   * @param {State} state - The state to make a node for; its parent is set to null.
   */
  makeNode(state) {
    const key = state.hash()
    if (!this.nodes.has(key)) {
      // Copy the array so the node never shares it with the game
      const unexpandedPlays = this.game.legalPlays(state).slice()
      this.nodes.set(key, new MonteCarloNode(null, null, state, unexpandedPlays))
    }
  }

  /**
   * From given state, run as many simulations as possible until the time limit, building statistics.
   * Prints a summary of the run to the console when done.
   * @param {State} state - The state to search from.
   * @param {number} timeout - The time to run the simulations for, in seconds.
   */
  runSims(state, timeout) {

    this.makeNode(state)

    let draws = 0
    let totalSims = 0

    const end = Date.now() + timeout * 1000

    while (Date.now() < end) {

      let node = this.select(state)
      // A node without child plays cannot be expanded; simulate from it directly
      if (!node.isLeaf()) {
        node = this.expand(node)
      }
      const winner = this.simulate(node)
      this.backpropagate(node, winner)

      if (winner === 0) draws++ // a winner of 0 is counted as a draw
      totalSims++
    }

    console.log('time(s) ' + timeout + '/' + timeout + ' (FINISHED)')
    console.log('total sims : ' + totalSims)
    console.log('total rate(sims/s) : ' + (totalSims/timeout).toFixed(1))
    console.log('draws : ' + draws) // no winner
  }

  /**
   * From the available statistics, calculate the best move from the given state.
   * @param {State} state - The state to choose a play for.
   * @param {string} policy - "robust" picks the most visited child, "best" the highest win rate.
   * @return {Play|null|undefined} The chosen play; null if some child of the state is still
   *   unexpanded (not enough information); undefined if the policy is unknown or the state
   *   has no plays.
   */
  getPlay(state, policy = "robust") {

    this.makeNode(state)
    const node = this.nodes.get(state.hash())

    // If not all children are expanded, not enough information
    if (!node.isFullyExpanded()) return null

    let scoreOf
    if (policy === "robust") {
      // Most visits (Chaslot's robust child)
      scoreOf = (child) => child.n_plays
    } else if (policy === "best") {
      // Highest winrate (Best child)
      scoreOf = (child) => child.n_wins / child.n_plays
    } else {
      return undefined
    }

    // Start below every possible score: starting at 0 left the result undefined
    // whenever every child scored exactly 0 (e.g. no child has a single win yet)
    let bestPlay
    let bestScore = -Infinity
    for (const play of node.allPlays()) {
      const score = scoreOf(node.childNode(play))
      if (score > bestScore) {
        bestPlay = play
        bestScore = score
      }
    }

    return bestPlay
  }

  /**
   * Phase 1: Selection
   * Descend by highest UCB1 until EITHER a not fully expanded node OR a leaf node is reached.
   * @param {State} state - The state to start from; a node for it must already exist.
   * @return {MonteCarloNode} The selected node.
   */
  select(state) {
    let node = this.nodes.get(state.hash())
    while (node.isFullyExpanded() && !node.isLeaf()) {
      // UCB1 can be exactly 0 (no wins and a parent visited once, since log(1) = 0),
      // so the running maximum must start below 0 or no child would be picked
      let bestPlay
      let bestUCB1 = -Infinity
      for (const play of node.allPlays()) {
        const childUCB1 = node.childNode(play).getUCB1(this.UCB1ExploreParam)
        if (childUCB1 > bestUCB1) {
          bestPlay = play
          bestUCB1 = childUCB1
        }
      }
      node = node.childNode(bestPlay)
    }
    return node
  }

  /**
   * Phase 2: Expansion
   * Of the given node, expand a random unexpanded child node and register it by state hash.
   * Assume given node is not a leaf and has at least one unexpanded play.
   * @param {MonteCarloNode} node - The node to expand.
   * @return {MonteCarloNode} The newly created child node.
   */
  expand(node) {
    const plays = node.unexpandedPlays()
    const play = plays[Math.floor(Math.random() * plays.length)]

    const childState = this.game.nextState(node.state, play)
    const childUnexpandedPlays = this.game.legalPlays(childState)
    const childNode = node.expand(play, childState, childUnexpandedPlays)
    this.nodes.set(childState.hash(), childNode)

    return childNode
  }

  /**
   * Phase 3: Simulation
   * From given node, play uniformly random plays until the game reports a winner, then return it.
   * @param {MonteCarloNode} node - The node to simulate from.
   * @return {*} Whatever game.winner returns for the final state (never null).
   */
  simulate(node) {
    let state = node.state
    let winner = this.game.winner(state)
    while (winner === null) {
      const plays = this.game.legalPlays(state)
      const play = plays[Math.floor(Math.random() * plays.length)]
      state = this.game.nextState(state, play)
      winner = this.game.winner(state)
    }
    return winner
  }

  /**
   * Phase 4: Backpropagation
   * From given node, propagate winner to ancestors' statistics.
   * @param {MonteCarloNode} node - The node the simulation started from.
   * @param {number} winner - The simulation result.
   */
  backpropagate(node, winner) {
    while (node !== null) {
      node.n_plays += 1
      // Flip for parent's choice: a node scores a win when the player at its state is the
      // negation of the winner (assumes players are encoded as 1 / -1 — confirm in game.js)
      if (node.state.player === -winner) {
        node.n_wins += 1
      }
      node = node.parent
    }
  }

}
|
||
module.exports = MonteCarlo |