Skip to content

Commit

Permalink
Chess Generation 002 (official-monty#48)
Browse files Browse the repository at this point in the history
Value Network
Score of dev vs main: 106 - 22 - 18  [0.788] 146
...      dev playing White: 60 - 6 - 7  [0.870] 73
...      dev playing Black: 46 - 16 - 11  [0.705] 73
...      White vs Black: 76 - 52 - 18  [0.582] 146
Elo difference: 227.7 +/- 63.5, LOS: 100.0 %, DrawRatio: 12.3 %
SPRT: llr 2.9 (100.4%), lbound -2.25, ubound 2.89 - H1 was accepted

Policy Network
Score of dev vs main: 167 - 81 - 51  [0.644] 299
...      dev playing White: 92 - 35 - 23  [0.690] 150
...      dev playing Black: 75 - 46 - 28  [0.597] 149
...      White vs Black: 138 - 110 - 51  [0.547] 299
Elo difference: 102.8 +/- 37.2, LOS: 100.0 %, DrawRatio: 17.1 %
SPRT: llr 2.89 (100.1%), lbound -2.25, ubound 2.89 - H1 was accepted

Horizontally Mirror Value Network
Score of dev vs main: 555 - 458 - 283  [0.537] 1296
...      dev playing White: 314 - 177 - 157  [0.606] 648
...      dev playing Black: 241 - 281 - 126  [0.469] 648
...      White vs Black: 595 - 418 - 283  [0.568] 1296
Elo difference: 26.1 +/- 16.7, LOS: 99.9 %, DrawRatio: 21.8 %
SPRT: llr 2.89 (100.0%), lbound -2.25, ubound 2.89 - H1 was accepted

Bench: 592803
  • Loading branch information
jw1912 authored Apr 7, 2024
1 parent 88f9610 commit fb1b588
Show file tree
Hide file tree
Showing 11 changed files with 24 additions and 14 deletions.
5 changes: 3 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ and the executable will be located at `target/release/monty[.exe]`.
## Supported Games
- Ataxx
- Chess
- Shatranj

## How it works

Expand All @@ -34,7 +35,7 @@ To begin with, only the root node is in the tree.
Unfortunately, MCTS in its purest form (random selection and random simulation to the end of the game)
is really bad.

Instead **selection** is replaced with PUCT, a combination of a **policy network** which indicates the quality of the child nodes,
Instead, **selection** is done via PUCT, which combines a **policy network** that indicates the quality of the child nodes
with the PUCT formula, which controls exploration vs exploitation of these child nodes.

And **simulation** is replaced with quiescence search of the node, backed by a neural network evaluation, called the **value network**.
And **simulation** is replaced with a neural network evaluation, called the **value network**.
Binary file removed resources/chess-policy001.bin
Binary file not shown.
Binary file added resources/chess-policy002.bin
Binary file not shown.
Binary file removed resources/chess-value001.bin
Binary file not shown.
Binary file added resources/chess-value002.bin
Binary file not shown.
6 changes: 3 additions & 3 deletions src/chess.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ pub use self::{
const STARTPOS: &str = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1";

static VALUE: ValueNetwork<768, 16> =
unsafe { std::mem::transmute(*include_bytes!("../resources/chess-value001.bin")) };
unsafe { std::mem::transmute(*include_bytes!("../resources/chess-value002.bin")) };

impl ValueFeatureMap for Board {
fn value_feature_map<F: FnMut(usize)>(&self, f: F) {
self.map_features(f);
self.map_value_features(f);
}
}

Expand Down Expand Up @@ -131,7 +131,7 @@ impl GameRep for Chess {

fn get_policy_feats(&self) -> goober::SparseVector {
let mut feats = goober::SparseVector::with_capacity(32);
self.board.map_features(|feat| feats.push(feat));
self.board.map_policy_features(|feat| feats.push(feat));
feats
}

Expand Down
15 changes: 12 additions & 3 deletions src/chess/board.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,17 @@ impl Board {
}
}

pub fn map_features<F: FnMut(usize)>(&self, mut f: F) {
/// Passes this board's feature indices to `f`, one call per feature, by
/// delegating to `map_features` with the `HM` const flag set to `true`.
///
/// With `HM` enabled, `map_features` XORs every square index with 7 whenever
/// `king_index() % 8 > 3`, and leaves the squares unchanged otherwise.
/// NOTE(review): this is presumably the horizontal mirroring used by the
/// value network; confirm against the network's training inputs.
pub fn map_value_features<F: FnMut(usize)>(&self, f: F) {
self.map_features::<F, true>(f);
}

/// Passes this board's feature indices to `f`, one call per feature, by
/// delegating to `map_features` with the `HM` const flag set to `false`.
///
/// With `HM` disabled, the XOR mask in `map_features` is always 0, so square
/// indices are forwarded to `f` without the king-dependent flip applied by
/// `map_value_features`.
pub fn map_policy_features<F: FnMut(usize)>(&self, f: F) {
self.map_features::<F, false>(f);
}

fn map_features<F: FnMut(usize), const HM: bool>(&self, mut f: F) {
let flip = self.stm() == Side::BLACK;
let hm = if HM && self.king_index() % 8 > 3 {7} else {0};

for piece in Piece::PAWN..=Piece::KING {
let pc = 64 * (piece - 2);
Expand All @@ -157,12 +166,12 @@ impl Board {

while our_bb > 0 {
pop_lsb!(sq, our_bb);
f(pc + usize::from(sq));
f(pc + usize::from(sq ^ hm));
}

while opp_bb > 0 {
pop_lsb!(sq, opp_bb);
f(384 + pc + usize::from(sq));
f(384 + pc + usize::from(sq ^ hm));
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/chess/policy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use super::moves::Move;
use goober::{activation, layer, FeedForwardNetwork, Matrix, SparseVector, Vector};

pub static POLICY: PolicyNetwork =
unsafe { std::mem::transmute(*include_bytes!("../../resources/chess-policy001.bin")) };
unsafe { std::mem::transmute(*include_bytes!("../../resources/chess-policy002.bin")) };

#[repr(C)]
#[derive(Clone, Copy, FeedForwardNetwork)]
Expand Down
2 changes: 1 addition & 1 deletion train/policy/src/chess.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl TrainablePolicy for PolicyNetwork {
let board = pos.board;

let mut feats = SparseVector::with_capacity(32);
board.map_features(|feat| feats.push(feat));
board.map_policy_features(|feat| feats.push(feat));

let mut policies = Vec::with_capacity(pos.num);
let mut total = 0.0;
Expand Down
6 changes: 3 additions & 3 deletions train/value/src/bin/chess.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,15 @@ fn main() {
let mut trainer = TrainerBuilder::default()
.single_perspective()
.quantisations(&[255, 64])
.input(inputs::Chess768)
.input(inputs::ChessBucketsMirrored::new([0; 32]))
.output_buckets(outputs::Single)
.feature_transformer(HIDDEN_SIZE)
.activate(Activation::SCReLU)
.add_layer(1)
.build();

let schedule = TrainingSchedule {
net_id: "chess-value001".to_string(),
net_id: "chess-value002".to_string(),
eval_scale: 400.0,
ft_regularisation: 0.0,
batch_size: 16_384,
Expand All @@ -36,7 +36,7 @@ fn main() {

let settings = LocalSettings {
threads: 4,
data_file_paths: vec!["data/chess/value001.data"],
data_file_paths: vec!["data/chess/value002.data"],
output_directory: "checkpoints",
};

Expand Down
2 changes: 1 addition & 1 deletion train/value/src/bin/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ fn filter<T: BulletFormat>() {

let err = (score - result).abs();

if err < 0.8 && raw_score.abs() < 1500 {
if err < 0.7 && raw_score.abs() < 1500 {
new.push(*pos);
} else {
filtered += 1;
Expand Down

0 comments on commit fb1b588

Please sign in to comment.