Skip to content

Commit

Permalink
Merge pull request #10 from mikhail-vlasenko/dict-observation
Browse files Browse the repository at this point in the history
Environment returns a Dict observation. Task-aware feature encoder leads to better learning
  • Loading branch information
mikhail-vlasenko authored Dec 27, 2024
2 parents 552ac4f + 544b250 commit 49c86d1
Show file tree
Hide file tree
Showing 27 changed files with 643 additions and 292 deletions.
24 changes: 14 additions & 10 deletions ffi/src/game_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use game_logic::map_generation::field::{absolute_to_relative, Field, RelativePos
use game_logic::map_generation::field_observation::get_tile_observation;
use game_logic::map_generation::mobs::mob_kind::MobKind;
use game_logic::map_generation::save_load::load_game;
use crate::observation::{LOOT_INFO_SIZE, MAX_MOBS, MOB_INFO_SIZE, Observation};
use crate::observation::{LOOT_INFO_SIZE, make_action_mask, MAX_MOBS, MOB_INFO_SIZE, NUM_MOB_KINDS, Observation};


#[derive(Debug)]
Expand Down Expand Up @@ -61,21 +61,25 @@ impl GameState {
pub fn get_observation(&self) -> Observation {
let (top_materials, tile_heights) = get_tile_observation(&self.field, &self.player);
let top_materials = top_materials.iter().map(|row| row.iter().map(|mat| (*mat).into()).collect()).collect();
Observation::new(top_materials, tile_heights, &self.field, &self.player, self.get_closest_mobs(), self.make_loot_array())
Observation::new(top_materials, tile_heights, &self.field, &self.player, self.get_closest_mobs(), self.make_loot_array(), make_action_mask(&self))
}

/// Produces a vector of `MOB_INFO_SIZE`-arrays of mob information: x, y, health, one-hot mob kind
/// The vector is sorted by Manhattan distance from the player
/// x and y are player-relative
/// mob kind is one-hot encoded: slot `3 + i` is set to 1 for the i-th kind in `MobKind::iter()` order
/// health is an integer percentage of max health (fraction truncated), so from 0 to 100
pub fn get_closest_mobs(&self) -> Vec<[i32; MOB_INFO_SIZE]> {
pub fn get_closest_mobs(&self) -> Vec<[i32; MOB_INFO_SIZE as usize]> {
let mob_kinds = MobKind::iter().collect::<Vec<MobKind>>();
let mut mobs = self.field.close_mob_info(|mob| {
let pos = absolute_to_relative((mob.pos.x, mob.pos.y), &self.player);
[pos.0, pos.1,
mob_kinds.iter().position(| kind | { kind == mob.get_kind() }).unwrap() as i32,
(mob.get_hp_share() * 100.0) as i32]
let mut arr = [0; MOB_INFO_SIZE as usize];
let idx = mob_kinds.iter().position(| kind | { kind == mob.get_kind() }).unwrap();
arr[0] = pos.0;
arr[1] = pos.1;
arr[2] = (mob.get_hp_share() * 100.0) as i32;
arr[3 + idx] = 1;
arr
}, &self.player);

// Sorting the mobs by manhattan distance from the player
Expand All @@ -92,8 +96,8 @@ impl GameState {
/// The array is sorted by manhattan distance from the player
/// x and y are player-relative
/// content (1: arrow, 2: other loot, 3: arrow and other loot). content -1 for no loot
pub fn make_loot_array(&self) -> [[i32; LOOT_INFO_SIZE]; MAX_MOBS] {
let mut loot = [[0, 0, -1]; MAX_MOBS];
pub fn make_loot_array(&self) -> [[i32; LOOT_INFO_SIZE as usize]; MAX_MOBS as usize] {
let mut loot = [[0, 0, -1]; MAX_MOBS as usize];
let player_dist_cmp = |a: &RelativePos, b: &RelativePos| {
let dist_a = a.0.abs() + a.1.abs();
let dist_b = b.0.abs() + b.1.abs();
Expand All @@ -105,7 +109,7 @@ impl GameState {
arrow_indices.sort_by(player_dist_cmp);
let mut min_empty_loot_position = 0;
// record loot and loot+arrow positions
for i in 0..min(loot_indices.len(), MAX_MOBS) {
for i in 0..min(loot_indices.len(), MAX_MOBS as usize) {
let idx = loot_indices[i];
if arrow_indices.contains(&idx) {
loot[i] = [idx.0, idx.1, 3];
Expand All @@ -117,7 +121,7 @@ impl GameState {
min_empty_loot_position = i + 1;
}
// in the remaining slots, record arrow-only positions
for i in 0..min(arrow_indices.len(), MAX_MOBS - min_empty_loot_position) {
for i in 0..min(arrow_indices.len(), MAX_MOBS as usize - min_empty_loot_position) {
let idx = arrow_indices[i];
loot[min_empty_loot_position + i] = [idx.0, idx.1, 1];
}
Expand Down
47 changes: 13 additions & 34 deletions ffi/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use std::os::raw::c_char;
use std::sync::Mutex;
use interoptopus::{ffi_function, Inventory, InventoryBuilder, function};
use interoptopus::{ffi_function, Inventory, InventoryBuilder, function, constant};
use interoptopus::{Error, Interop};
use lazy_static::lazy_static;
use game_logic::auxiliary::actions::Action;
use game_logic::SETTINGS;

use crate::game_state::GameState;
use crate::observation::{ActionMask, NUM_ACTIONS, Observation};
use crate::observation::Observation;

pub mod game_state;
pub mod observation;
Expand Down Expand Up @@ -144,28 +144,6 @@ pub extern "C" fn close_one(index: i32) {
}
}

/// Gets the actions mask for the game state at the specified index.
/// The mask is an array of integers where 1 means the action will lead to something happening with the games state,
/// and 0 means taking the action will yield the same observation.
///
/// # Arguments
///
/// * `index` - The index of the game state to get the actions mask for.
///
/// # Returns
///
/// * `ActionMask` - The actions mask for the game state.
#[ffi_function]
#[no_mangle]
pub extern "C" fn valid_actions_mask(index: i32) -> ActionMask {
let state = STATE.lock().unwrap();
if let Some(game_state) = state.get(index as usize) {
ActionMask::new(game_state)
} else {
panic!("Index {} out of bounds for batch size {}", index, state.len());
}
}

/// Sets the record_replays setting to the given value.
/// Training is better done with record_replays set to false, as it saves memory and time.
/// For evaluation and assessment one can consider setting it to true.
Expand Down Expand Up @@ -216,12 +194,6 @@ pub extern "C" fn get_batch_size() -> i32 {
*BATCH_SIZE.lock().unwrap() as i32
}

#[ffi_function]
#[no_mangle]
pub extern "C" fn num_actions() -> i32 {
NUM_ACTIONS as i32
}

#[ffi_function]
#[no_mangle]
pub extern "C" fn action_name(action: i32) -> *mut c_char {
Expand All @@ -241,13 +213,20 @@ pub fn ffi_inventory() -> Inventory {
.register(function!(step_one))
.register(function!(get_one_observation))
.register(function!(close_one))
.register(function!(valid_actions_mask))
.register(function!(set_record_replays))
.register(function!(set_start_loadout))
.register(function!(set_save_on_milestone))
.register(function!(get_batch_size))
.register(function!(num_actions))
.register(function!(action_name))

.register(constant!(observation::OBSERVATION_GRID_SIZE))
.register(constant!(observation::INVENTORY_SIZE))
.register(constant!(observation::NUM_ACTIONS))
.register(constant!(observation::MOB_INFO_SIZE))
.register(constant!(observation::MAX_MOBS))
.register(constant!(observation::LOOT_INFO_SIZE))
.register(constant!(observation::NUM_MATERIALS))

.inventory()
}

Expand All @@ -268,7 +247,7 @@ fn verify_num_inventory_items() {
use crate::observation::INVENTORY_SIZE;
use game_logic::crafting::storable::ALL_STORABLES;
println!("Actual length of the inventory: {}", ALL_STORABLES.len());
assert_eq!(INVENTORY_SIZE, ALL_STORABLES.len(), "change INVENTORY_SIZE manually");
assert_eq!(INVENTORY_SIZE as usize, ALL_STORABLES.len(), "change INVENTORY_SIZE manually");
}

#[test]
Expand All @@ -278,5 +257,5 @@ fn verify_num_actions() {
max += 1;
}
println!("Actual number of actions: {}", max);
assert_eq!(max, NUM_ACTIONS as i32, "change NUM_ACTIONS manually");
assert_eq!(max, observation::NUM_ACTIONS as i32, "change NUM_ACTIONS manually");
}
91 changes: 49 additions & 42 deletions ffi/src/observation.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::cmp::min;
use std::ffi::CString;
use std::os::raw::c_char;
use interoptopus::ffi_type;
use interoptopus::{ffi_constant, ffi_type};
use game_logic::character::player::Player;
use game_logic::crafting::storable::{ALL_STORABLES, Storable};
use game_logic::is_game_over;
Expand All @@ -11,28 +11,39 @@ use crate::game_state::GameState;


// use constants for array sizes to avoid dynamically sized arrays that may leak memory during FFI
pub const OBSERVATION_GRID_SIZE: usize = ((DEFAULT_SETTINGS.window.render_distance * 2) + 1) as usize;
pub const INVENTORY_SIZE: usize = 26;
pub const NUM_ACTIONS: usize = 39;
pub const MOB_INFO_SIZE: usize = 4;
pub const MAX_MOBS: usize = 16; // also max number of loot items
pub const LOOT_INFO_SIZE: usize = 3;
#[ffi_constant]
pub const OBSERVATION_GRID_SIZE: u32 = ((DEFAULT_SETTINGS.window.render_distance * 2) + 1) as u32;
#[ffi_constant]
pub const INVENTORY_SIZE: u32 = 27;
#[ffi_constant]
pub const NUM_ACTIONS: u32 = 39;
#[ffi_constant]
pub const NUM_MOB_KINDS: u32 = 5;
#[ffi_constant]
pub const MOB_INFO_SIZE: u32 = 3 + NUM_MOB_KINDS; // x, y, health share (0 to 100), [one-hot-encoded type]
#[ffi_constant]
pub const MAX_MOBS: u32 = 16; // also max number of loot items
#[ffi_constant]
pub const LOOT_INFO_SIZE: u32 = 3;
#[ffi_constant]
pub const NUM_MATERIALS: u32 = 13;


#[ffi_type]
#[repr(C)]
pub struct Observation {
pub top_materials: [[i32; OBSERVATION_GRID_SIZE]; OBSERVATION_GRID_SIZE],
pub tile_heights: [[i32; OBSERVATION_GRID_SIZE]; OBSERVATION_GRID_SIZE],
pub top_materials: [[i32; OBSERVATION_GRID_SIZE as usize]; OBSERVATION_GRID_SIZE as usize],
pub tile_heights: [[i32; OBSERVATION_GRID_SIZE as usize]; OBSERVATION_GRID_SIZE as usize],
pub player_pos: [i32; 3], // x, y, z
pub player_rot: i32, // one of 4 values: 0, 1, 2, 3
pub hp: i32,
pub time: f32,
pub inventory_state: [i32; INVENTORY_SIZE], // amount of storables of i-th type
// player-relative x, player-relative y, type, health. for the 16 closest mobs that are visible. mob type -1 is no mob
pub mobs: [[i32; MOB_INFO_SIZE]; MAX_MOBS],
pub inventory_state: [i32; INVENTORY_SIZE as usize], // amount of storables of i-th type
// player-relative x, player-relative y, health share (0 to 100), [one-hot-encoded type]. for the 16 closest mobs that are visible. an all-zero row (no type bit set) means no mob
pub mobs: [[i32; MOB_INFO_SIZE as usize]; MAX_MOBS as usize],
// player-relative x, player-relative y, content (1: arrow, 2: other loot, 3: arrow and other loot). content -1 for no loot
pub loot: [[i32; LOOT_INFO_SIZE]; MAX_MOBS],
pub loot: [[i32; LOOT_INFO_SIZE as usize]; MAX_MOBS as usize],
pub action_mask: [i32; NUM_ACTIONS as usize],
pub score: i32,
pub message: *mut c_char,
pub done: bool,
Expand All @@ -43,16 +54,17 @@ impl Observation {
vec_top_materials: Vec<Vec<i32>>,
vec_tile_heights: Vec<Vec<i32>>,
field: &Field, player: &Player,
close_mobs: Vec<[i32; MOB_INFO_SIZE]>,
close_loot: [[i32; LOOT_INFO_SIZE]; MAX_MOBS]
close_mobs: Vec<[i32; MOB_INFO_SIZE as usize]>,
close_loot: [[i32; LOOT_INFO_SIZE as usize]; MAX_MOBS as usize],
action_mask: [i32; NUM_ACTIONS as usize],
) -> Self {
if vec_top_materials.len() != OBSERVATION_GRID_SIZE || vec_tile_heights.len() != OBSERVATION_GRID_SIZE {
if vec_top_materials.len() != OBSERVATION_GRID_SIZE as usize || vec_tile_heights.len() != OBSERVATION_GRID_SIZE as usize {
panic!("Invalid observation size");
}
let mut top_materials = [[0; OBSERVATION_GRID_SIZE]; OBSERVATION_GRID_SIZE];
let mut tile_heights = [[0; OBSERVATION_GRID_SIZE]; OBSERVATION_GRID_SIZE];
for i in 0..OBSERVATION_GRID_SIZE {
for j in 0..OBSERVATION_GRID_SIZE {
let mut top_materials = [[0; OBSERVATION_GRID_SIZE as usize]; OBSERVATION_GRID_SIZE as usize];
let mut tile_heights = [[0; OBSERVATION_GRID_SIZE as usize]; OBSERVATION_GRID_SIZE as usize];
for i in 0..OBSERVATION_GRID_SIZE as usize {
for j in 0..OBSERVATION_GRID_SIZE as usize {
top_materials[i][j] = vec_top_materials[i][j];
tile_heights[i][j] = vec_tile_heights[i][j];
}
Expand All @@ -62,14 +74,14 @@ impl Observation {
let hp = player.get_hp();
let time = field.get_time();

let mut inventory_state = [0; INVENTORY_SIZE];
let mut inventory_state = [0; INVENTORY_SIZE as usize];
for (storable, n) in player.get_inventory() {
let idx = storable_to_inv_index(storable);
inventory_state[idx] = *n as i32;
}

let mut mobs = [[0, 0, -1, 0]; MAX_MOBS];
for i in 0..min(close_mobs.len(), MAX_MOBS) {
let mut mobs = [[0; MOB_INFO_SIZE as usize]; MAX_MOBS as usize];
for i in 0..min(close_mobs.len(), MAX_MOBS as usize) {
for j in 0..4 {
mobs[i][j] = close_mobs[i][j];
}
Expand All @@ -88,6 +100,7 @@ impl Observation {
inventory_state,
mobs,
loot,
action_mask,
score,
message,
done,
Expand All @@ -98,15 +111,16 @@ impl Observation {
impl Default for Observation {
fn default() -> Self {
Self {
top_materials: [[0; OBSERVATION_GRID_SIZE]; OBSERVATION_GRID_SIZE],
tile_heights: [[0; OBSERVATION_GRID_SIZE]; OBSERVATION_GRID_SIZE],
top_materials: [[0; OBSERVATION_GRID_SIZE as usize]; OBSERVATION_GRID_SIZE as usize],
tile_heights: [[0; OBSERVATION_GRID_SIZE as usize]; OBSERVATION_GRID_SIZE as usize],
player_pos: [0; 3],
player_rot: 0,
hp: 0,
time: 0.0,
inventory_state: [0; INVENTORY_SIZE],
mobs: [[0, 0, -1, 0]; MAX_MOBS],
loot: [[0, 0, -1]; MAX_MOBS],
inventory_state: [0; INVENTORY_SIZE as usize],
mobs: [[0; MOB_INFO_SIZE as usize]; MAX_MOBS as usize],
loot: [[0, 0, -1]; MAX_MOBS as usize],
action_mask: [0; NUM_ACTIONS as usize],
score: 0,
message: CString::new(String::from("")).unwrap().into_raw(),
done: false,
Expand All @@ -119,22 +133,15 @@ fn storable_to_inv_index(storable: &Storable) -> usize {
ALL_STORABLES.iter().position(|s| s == storable).unwrap()
}

#[ffi_type]
#[repr(C)]
pub struct ActionMask {
pub mask: [i32; NUM_ACTIONS],
}

impl ActionMask {
pub fn new(game_state: &GameState) -> Self {
let mut mask = [0; NUM_ACTIONS];
if !game_state.is_done() {
for i in 0..NUM_ACTIONS {
if game_state.can_take_action(i as i32) {
mask[i] = 1;
}
pub fn make_action_mask(game_state: &GameState) -> [i32; NUM_ACTIONS as usize] {
let mut mask = [0; NUM_ACTIONS as usize];
if !game_state.is_done() {
for i in 0..NUM_ACTIONS as usize {
if game_state.can_take_action(i as i32) {
mask[i] = 1;
}
}
Self { mask }
}
mask
}
2 changes: 1 addition & 1 deletion game_logic/src/auxiliary/i32_enum_conversions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ impl Into<i32> for Material {
Bedrock => 4,
IronOre => 5,
CraftTable => 6,
Diamond => 7,
DiamondOre => 7,
Texture(t) => t as i32 + 8,
}
}
Expand Down
4 changes: 2 additions & 2 deletions game_logic/src/character/cheats.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::crafting::consumable::Consumable::{RawMeat, SpeedPotion};
use crate::crafting::items::Item::{Arrow, DiamondSword, IronIngot, IronPickaxe, Stick};
use crate::crafting::material::Material::{CraftTable, Diamond, Plank};
use crate::crafting::items::Item::{Arrow, Diamond, DiamondSword, IronIngot, IronPickaxe, Stick};
use crate::crafting::material::Material::{CraftTable, Plank};
use crate::crafting::ranged_weapon::RangedWeapon::Bow;
use crate::character::player::Player;
use crate::crafting::interactable::InteractableKind::CrossbowTurret;
Expand Down
2 changes: 1 addition & 1 deletion game_logic/src/character/game_score.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ impl Player {

pub fn score_mined(&mut self, material: &Material) {
match material {
Material::Diamond => self.add_to_score(self.score_values().blocks.mined.diamond),
Material::DiamondOre => self.add_to_score(self.score_values().blocks.mined.diamond),
_ => {}
}
}
Expand Down
6 changes: 3 additions & 3 deletions game_logic/src/character/milestones.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use serde::{Deserialize, Serialize};
use crate::character::player::Player;
use crate::crafting::interactable::InteractableKind::CrossbowTurret;
use crate::crafting::items::Item::{DiamondSword, IronIngot, IronPickaxe, WoodenPickaxe};
use crate::crafting::material::Material::{Diamond, IronOre};
use crate::crafting::items::Item::{Diamond, DiamondSword, IronIngot, IronPickaxe, WoodenPickaxe};
use crate::crafting::material::Material::IronOre;
use crate::crafting::storable::Storable;


Expand Down Expand Up @@ -76,7 +76,7 @@ impl MilestoneTracker {
fn mined_diamond() -> Milestone {
Milestone::new(
"Diamonds!".to_string(),
vec![(Storable::M(Diamond), 1)],
vec![(Storable::I(Diamond), 1)],
0.
)
}
Expand Down
3 changes: 1 addition & 2 deletions game_logic/src/crafting/interactable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use crate::character::acting_with_speed::ActingWithSpeed;
use crate::character::player::Player;
use crate::crafting::items::Item::*;
use crate::crafting::material::Material;
use crate::crafting::material::Material::Diamond;
use crate::crafting::storable::{Craftable, CraftMenuSection, Storable};
use crate::crafting::storable::CraftMenuSection::*;
use crate::crafting::interactable::InteractableKind::*;
Expand Down Expand Up @@ -114,7 +113,7 @@ impl Craftable for InteractableKind {
}
fn craft_requirements(&self) -> &[(&Storable, u32)] {
match self {
CrossbowTurret => &[(&I(Stick), 6), (&M(Diamond), 1), (&I(TargetingModule), 1)]
CrossbowTurret => &[(&I(Stick), 6), (&I(Diamond), 1), (&I(TargetingModule), 1)]
}
}
fn craft_yield(&self) -> u32 {
Expand Down
Loading

0 comments on commit 49c86d1

Please sign in to comment.