Skip to content

Commit

Permalink
Merge branch 'dict-observation'
Browse files Browse the repository at this point in the history
  • Loading branch information
mikhail-vlasenko committed Dec 29, 2024
2 parents 49c86d1 + edb83a5 commit 0b4caaf
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 49 deletions.
40 changes: 15 additions & 25 deletions game_logic/src/map_generation/save_load.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::{fs, io};
use std::error::Error;
use std::fs::{create_dir_all, File};
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
Expand Down Expand Up @@ -87,42 +88,31 @@ pub fn load_game(path: &Path) -> Result<(Field, Player, Replay, MilestoneTracker
Ok((field, player, Replay::new(), milestone_tracker))
}

pub fn get_directories(path: &Path) -> io::Result<Vec<String>> {
let mut directories = Vec::new();
/// Returns a list of directories or files in the given path.
/// Name and seconds since epoch of last modification are returned.
///
/// # Arguments
///
/// * `path` - The path to list directories or files in.
/// * `directories` - If true, directories are listed. If false, files are listed.
pub fn list_directory(path: &Path, directories: bool) -> Result<Vec<(String, i32)>, Box<dyn Error>> {
let mut result = Vec::new();

for entry in fs::read_dir(path)? {
let entry = entry?;
let path = entry.path();

// If the entry is a directory, add it to the list
if path.is_dir() {
if (directories && path.is_dir()) || (!directories && path.is_file()) {
if let Some(filename) = path.file_name() {
directories.push(filename.to_string_lossy().into_owned());
let metadata = fs::metadata(&path)?;
let modified = metadata.modified()?.duration_since(std::time::SystemTime::UNIX_EPOCH)?.as_secs() as i32;
result.push((filename.to_string_lossy().into_owned(), modified));
}
}
}
directories.sort();

Ok(directories)
}

pub fn get_files(path: &Path) -> io::Result<Vec<String>> {
let mut files = Vec::new();

for entry in fs::read_dir(path)? {
let entry = entry?;
let path = entry.path();

// If the entry is a file, add it to the list
if path.is_file() {
if let Some(filename) = path.file_name() {
files.push(filename.to_string_lossy().into_owned());
}
}
}
files.sort();

Ok(files)
Ok(result)
}

pub fn get_full_path(path: &Path) -> PathBuf {
Expand Down
85 changes: 72 additions & 13 deletions minecraft/src/graphics/ui/main_menu.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::path::Path;
use std::process::exit;
use egui::{ScrollArea, Slider, TextEdit};
use egui::{ComboBox, ScrollArea, Slider, TextEdit};
use egui::{Align2, Color32, FontId, RichText, Context};
use game_logic::character::player::Player;
use game_logic::map_generation::save_load::{get_directories, get_files, get_full_path};
use game_logic::map_generation::save_load::{list_directory, get_full_path};
use game_logic::SETTINGS;
use game_logic::settings::Settings;

Expand All @@ -15,6 +15,8 @@ pub struct MainMenu {
pub save_name: String,
pub replay_name: String,
pub world_seed_buffer: String,
substring_search: String,
sorting_regime: SortingRegime,
}

impl MainMenu {
Expand All @@ -29,6 +31,8 @@ impl MainMenu {
save_name: String::from("default_save"),
replay_name: String::from("default_replay"),
world_seed_buffer,
substring_search: String::new(),
sorting_regime: SortingRegime::DateDescending,
}
}

Expand Down Expand Up @@ -99,10 +103,14 @@ impl MainMenu {

let save_path_string = settings.save_folder.clone().into_owned();
let save_path = Path::new(&save_path_string);
let save_names = get_directories(&save_path).unwrap_or(vec![]);

let mut save_directories = list_directory(&save_path, true).unwrap_or(vec![]);
save_directories.retain(|(name, _)| name.contains(&self.substring_search));
self.sorting_regime.sort_directories(&mut save_directories);

match self.second_panel {
SecondPanelState::SaveGame => {
self.render_search_bar(&mut columns[1]);
columns[1].horizontal(|ui| {
ui.label("Save name:");
ui.text_edit_singleline(&mut self.save_name);
Expand All @@ -111,15 +119,18 @@ impl MainMenu {
self.selected_option = SelectedOption::SaveGame;
}
columns[1].label("Existing saves:");
for name in save_names.iter() {
columns[1].label(RichText::new(name).font(FontId::proportional(20.0)));
}
ScrollArea::vertical().show(&mut columns[1], |scroll| {
for (name, _epoch_time) in save_directories.iter() {
scroll.label(RichText::new(name).font(FontId::proportional(20.0)));
}
});

self.back_button(&mut columns[1]);
}
SecondPanelState::LoadGame => {
self.render_search_bar(&mut columns[1]);
ScrollArea::vertical().show(&mut columns[1], |scroll| {
for name in save_names.iter() {
for (name, _epoch_time) in save_directories.iter() {
if scroll.button(RichText::new(name)
.font(FontId::proportional(20.0))).clicked() {
self.save_name = name.clone();
Expand All @@ -128,22 +139,23 @@ impl MainMenu {
}
});

if save_names.is_empty() {
if save_directories.is_empty() {
let full_path = get_full_path(&save_path);
columns[1].label(format!("No saves found in \n{}", full_path.to_string_lossy()));
}

self.back_button(&mut columns[1]);
}
SecondPanelState::Replays => {
self.render_search_bar(&mut columns[1]);
let replay_path_string = settings.replay_folder.clone().into_owned();
let replay_path = Path::new(&replay_path_string);
let mut replay_names = get_files(&replay_path).unwrap_or(vec![]);
// files are already sorted alphabetically but recent replays should come first
replay_names.reverse();
let mut replay_names = list_directory(&replay_path, false).unwrap_or(vec![]);
replay_names.retain(|(name, _)| name.contains(&self.substring_search));
self.sorting_regime.sort_directories(&mut replay_names);
ScrollArea::vertical().show(&mut columns[1], |scroll| {
for name in replay_names.iter() {
if scroll.button(RichText::new(name)
for (name, _epoch_time) in replay_names.iter() {
if scroll.button(RichText::new(name.clone().replace(".postcard", ""))
.font(FontId::proportional(20.0))).clicked() {
self.replay_name = name.clone();
self.selected_option = SelectedOption::WatchReplay;
Expand Down Expand Up @@ -320,6 +332,25 @@ impl MainMenu {
);
}

fn render_search_bar(&mut self, ui: &mut egui::Ui) {
use SortingRegime::*;
ui.horizontal(|ui| {
ui.label("Search:");
ui.add(TextEdit::singleline(&mut self.substring_search).desired_width(ui.available_width() / 1.75));
ComboBox::from_label("")
.selected_text(match self.sorting_regime {
ref regime => regime.name(),
}).width(ui.available_width())
.show_ui(ui, |ui| {
for regime in &[AlphaAscending, AlphaDescending, DateAscending, DateDescending] {
if ui.selectable_label(*regime == self.sorting_regime, regime.name()).clicked() {
self.sorting_regime = regime.clone();
}
}
});
});
}

fn back_button(&mut self, ui: &mut egui::Ui) {
ui.horizontal(|ui| {
if ui.button("Back").clicked() {
Expand Down Expand Up @@ -347,3 +378,31 @@ pub enum SecondPanelState {
Settings,
Controls,
}

#[derive(PartialEq, Clone)]
enum SortingRegime {
AlphaAscending,
AlphaDescending,
DateAscending,
DateDescending,
}

impl SortingRegime {
fn sort_directories(&self, directories: &mut Vec<(String, i32)>) {
match self {
SortingRegime::AlphaAscending => directories.sort_by(|a, b| a.0.cmp(&b.0)),
SortingRegime::AlphaDescending => directories.sort_by(|a, b| b.0.cmp(&a.0)),
SortingRegime::DateAscending => directories.sort_by(|a, b| a.1.cmp(&b.1)),
SortingRegime::DateDescending => directories.sort_by(|a, b| b.1.cmp(&a.1)),
}
}

fn name(&self) -> &str {
match self {
SortingRegime::AlphaAscending => "Name Asc.",
SortingRegime::AlphaDescending => "Name Desc.",
SortingRegime::DateAscending => "Date Asc.",
SortingRegime::DateDescending => "Date Desc.",
}
}
}
2 changes: 1 addition & 1 deletion reinforcement_learning/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class EvaluationConfig:
@dataclass
class ModelConfig:
nonlinear: str = 'tanh'
dimensions: List[int] = field(default_factory=lambda: [512, 512, 256, 256])
dimensions: List[int] = field(default_factory=lambda: [1024, 512, 512, 512])


@dataclass
Expand Down
29 changes: 19 additions & 10 deletions reinforcement_learning/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,21 @@ def __init__(self, observation_space):

self.material_channels = 8
self.height_channels = 8
self.mob_heads = 12
self.mob_dim = 8 * self.mob_heads
self.loot_heads = 12
self.loot_dim = 4 * self.loot_heads
self.mob_heads = 16
self.mob_dim = 16 * self.mob_heads
self.loot_heads = 8
self.loot_dim = 8 * self.loot_heads

self.block_encoder = nn.Sequential(
nn.Linear(NUM_MATERIALS, self.material_channels),
nn.Tanh(), # relu will likely lead to dead neurons, as input is one-hot
)
self.height_conv = nn.Sequential(
nn.Conv2d(1, self.height_channels, kernel_size=3, stride=(1, 1), padding=(0, 0)),
self.height_conv_hor = nn.Sequential(
nn.Conv2d(1, self.height_channels, kernel_size=(1, 3), stride=(1, 1), padding=(0, 0)),
nn.GELU(),
)
self.height_conv_vert = nn.Sequential(
nn.Conv2d(1, self.height_channels, kernel_size=(3, 1), stride=(1, 1), padding=(0, 0)),
nn.GELU(),
)
self.mob_encoder = nn.Sequential(
Expand Down Expand Up @@ -85,8 +89,10 @@ def forward(self, observation: dict) -> torch.Tensor:
tile_heights = observation["tile_heights"] / 4 # max height is 4
tile_heights = tile_heights[:, self.height_grid_start:self.height_grid_end,
self.height_grid_start:self.height_grid_end].unsqueeze(1)
height_features = self.height_conv(tile_heights)
height_features = height_features.view(height_features.size(0), -1)
height_features_hor = self.height_conv_hor(tile_heights)
height_features_hor = height_features_hor.view(height_features_hor.size(0), -1)
height_features_vert = self.height_conv_vert(tile_heights)
height_features_vert = height_features_vert.view(height_features_vert.size(0), -1)

mobs = observation["mobs"] # (batch_size, NUM_MOBS, MOB_INFO_SIZE)
# first two mob features are x and y positions, which should be log-scaled
Expand All @@ -106,8 +112,11 @@ def forward(self, observation: dict) -> torch.Tensor:
inventory = observation["inventory_state"] / 4 # arbitrary downscaling
inventory = self.inventory_encoder(inventory)

return torch.cat([near_materials, pooled_materials, height_features, mob_pool, loot_pool, inventory,
self.extract_flat_features(observation)], dim=1)
return torch.cat([
near_materials, pooled_materials, height_features_hor, height_features_vert,
mob_pool, loot_pool, inventory,
self.extract_flat_features(observation)
], dim=1)

def extract_flat_features(self, observation: dict) -> torch.Tensor:
player_pos = observation["player_pos"]
Expand Down

0 comments on commit 0b4caaf

Please sign in to comment.