actual batched action sampling. all in the main process for 100% gpu
mikhail-vlasenko committed Jul 29, 2024
1 parent ad7db63 commit e3c7d8b
Showing 5 changed files with 31 additions and 13 deletions.
7 changes: 7 additions & 0 deletions ffi/src/lib.rs
@@ -167,6 +167,12 @@ pub extern "C" fn set_record_replays(value: bool) {
     SETTINGS.write().unwrap().record_replays = value;
 }
 
+#[ffi_function]
+#[no_mangle]
+pub extern "C" fn get_batch_size() -> i32 {
+    *BATCH_SIZE.lock().unwrap() as i32
+}
+
 #[ffi_function]
 #[no_mangle]
 pub extern "C" fn num_actions() -> i32 {
@@ -193,6 +199,7 @@ pub fn ffi_inventory() -> Inventory {
         .register(function!(close_one))
         .register(function!(valid_actions_mask))
         .register(function!(set_record_replays))
+        .register(function!(get_batch_size))
         .register(function!(num_actions))
         .register(function!(action_name))
         .inventory()
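The new export can be smoke-tested without the generated bindings; a minimal sketch using raw ctypes, assuming the crate builds to a cdylib (the path below is a placeholder, adjust per platform):

```python
import ctypes

# Placeholder path to the compiled library (ffi.dll on Windows,
# libffi.so on Linux, libffi.dylib on macOS).
lib = ctypes.CDLL("target/release/ffi.dll")

# Mirror the declarations that the generated ffi_elements.py sets up.
lib.get_batch_size.argtypes = []
lib.get_batch_size.restype = ctypes.c_int32

# Prints 1 until set_batch_size() has been called in this process.
print(lib.get_batch_size())
```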
5 changes: 5 additions & 0 deletions python_wrapper/ffi_elements.py
@@ -19,12 +19,14 @@ def init_lib(path):
     c_lib.close_one.argtypes = [ctypes.c_int32]
     c_lib.valid_actions_mask.argtypes = [ctypes.c_int32]
     c_lib.set_record_replays.argtypes = [ctypes.c_bool]
+    c_lib.get_batch_size.argtypes = []
     c_lib.num_actions.argtypes = []
     c_lib.action_name.argtypes = [ctypes.c_int32]
 
     c_lib.connect_env.restype = ctypes.c_int32
     c_lib.get_one_observation.restype = Observation
     c_lib.valid_actions_mask.restype = ActionMask
+    c_lib.get_batch_size.restype = ctypes.c_int32
     c_lib.num_actions.restype = ctypes.c_int32
     c_lib.action_name.restype = ctypes.POINTER(ctypes.c_int8)
 
@@ -100,6 +102,9 @@ def set_record_replays(value: bool):
     * `value` - The value to set record_replays to."""
     return c_lib.set_record_replays(value)
 
+def get_batch_size() -> int:
+    return c_lib.get_batch_size()
+
 def num_actions() -> int:
     return c_lib.num_actions()
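With the wrapper in place, initialization can be made idempotent from Python; a short usage sketch, with a placeholder library path:

```python
from python_wrapper.ffi_elements import init_lib, get_batch_size, set_batch_size

init_lib("target/release/ffi.dll")  # placeholder path to the compiled library

# The batch size defaults to 1; a non-default value means some earlier code
# in this process already configured the FFI environment, so skip setup.
if get_batch_size() == 1:
    set_batch_size(64)
```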
16 changes: 12 additions & 4 deletions python_wrapper/minecraft_2d_env.py
@@ -5,7 +5,7 @@
 
 from python_wrapper.ffi_elements import (
     init_lib, reset_one, step_one, num_actions, set_batch_size,
-    set_record_replays, connect_env, close_one, c_lib
+    set_record_replays, connect_env, close_one, c_lib, get_batch_size
 )
 from python_wrapper.observation import get_processed_observation, NUM_MATERIALS, get_actions_mask
 
@@ -35,16 +35,24 @@ def __init__(
             include_actions_in_obs (bool): Whether to include available actions in the observation.
         """
         if c_lib is None:
-            print("Initializing Minecraft connection in Minecraft2dEnv init. "
-                  "This should not happen twice in the same process")
-            initialize_minecraft_connection(num_envs=num_total_envs, record_replays=record_replays, lib_path=lib_path)
+            init_lib(lib_path)
+            if get_batch_size() == 1:
+                # In Ray, env creation seems to be isolated somehow, so init_lib has to be called every time,
+                # but after the first call it connects to an existing FFI environment.
+                # If the batch size is no longer its default value (1), calling set_batch_size again
+                # would invalidate previous connections.
+                print("Initializing Minecraft connection in Minecraft2dEnv init. "
+                      "This should not happen twice in the same process")
+                set_batch_size(num_total_envs)
+                set_record_replays(record_replays)
         super().__init__()
 
         if render_mode is not None:
             raise ValueError("Rendering is not supported. Use the 2d-minecraft binary for replays.")
 
         self.render_mode = render_mode
         self.c_lib_index = connect_env()
+        print(f"Connected to environment with index {self.c_lib_index}.")
 
         self.num_actions = num_actions()
         self.current_score = 0
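The `get_batch_size() == 1` guard makes env construction safe to repeat within a process: only the first instance configures the shared FFI state. A sketch of the resulting behavior (remaining constructor arguments assumed to keep their defaults):

```python
from python_wrapper.minecraft_2d_env import Minecraft2dEnv

# Fresh process: get_batch_size() is still 1, so this call sets the
# batch size and the record_replays flag before connecting.
env_a = Minecraft2dEnv(num_total_envs=64)

# Same process: the batch size is already 64, so setup is skipped and the
# env just connects, leaving env_a's connection intact.
env_b = Minecraft2dEnv(num_total_envs=64)
```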
4 changes: 2 additions & 2 deletions reinforcement_learning/config.py
@@ -6,7 +6,7 @@
 
 @dataclass
 class EnvConfig:
-    num_envs: int = 8
+    num_envs: int = 64
     lib_path: str = 'C:/Users/Mikhail/RustProjects/2d-minecraft/target/release/ffi.dll'
     discovered_actions_reward: float = 50.
     include_actions_in_obs: bool = True
@@ -45,7 +45,7 @@ class PPOConfig:
 @dataclass
 class Config:
     wandb_resume_id: str = ""
-    num_runners: int = 8
+    num_runners: int = 0
     env: EnvConfig = field(default_factory=EnvConfig)
     ppo_train: PPOTrainConfig = field(default_factory=PPOTrainConfig)
     evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
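The new defaults carry the commit's intent: with `num_runners = 0` there are no remote rollout workers, so all 64 envs are stepped, and their actions sampled, in the main process. A sketch of adjusting this at run time (CONFIG is the module-level instance that main.py imports):

```python
from reinforcement_learning.config import CONFIG

CONFIG.num_runners = 0    # no remote rollout workers; sample in the main process
CONFIG.env.num_envs = 64  # number of envs batched through the FFI connection
```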
12 changes: 5 additions & 7 deletions reinforcement_learning/main.py
@@ -4,7 +4,7 @@
 from ray.tune.registry import register_env
 from ray.air.integrations.wandb import WandbLoggerCallback
 
-from python_wrapper.minecraft_2d_env import Minecraft2dEnv
+from python_wrapper.minecraft_2d_env import Minecraft2dEnv, initialize_minecraft_connection
 from reinforcement_learning.config import CONFIG
 from reinforcement_learning.metrics_callback import MinecraftMetricsCallback
 
@@ -34,6 +34,8 @@ def main():
         wandb_kwargs['resume'] = "must"
         wandb_kwargs['id'] = CONFIG.wandb_resume_id
 
+    initialize_minecraft_connection(num_envs=CONFIG.env.num_envs, record_replays=False, lib_path=CONFIG.env.lib_path)
+
     ppo_config = (
         PPOConfig()
         .environment("Minecraft2D", env_config={
@@ -54,15 +56,11 @@
             sgd_minibatch_size=CONFIG.ppo.batch_size,
             train_batch_size=CONFIG.ppo_train.iter_env_steps * CONFIG.env.num_envs,
         )
-        .rollouts(num_rollout_workers=CONFIG.num_runners)
-        .resources(num_gpus=1)
-        .env_runners(num_cpus_per_env_runner=1)
+        .resources(num_gpus=1, num_cpus_for_main_process=4)
+        .env_runners(num_env_runners=CONFIG.num_runners, num_envs_per_env_runner=CONFIG.env.num_envs)  # num_cpus_per_env_runner=2
         .callbacks(MinecraftMetricsCallback)
     )
 
-    if CONFIG.ppo_train.load_from:
-        ppo_config = ppo_config.restore(CONFIG.ppo_train.load_from)
-
     stop_conditions = {
         "timesteps_total": CONFIG.ppo_train.env_steps,
     }
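For context on the rollout change: passing `num_env_runners=0` makes RLlib sample on the main (learner) process instead of remote workers, so action sampling for the whole env batch happens in single GPU forward passes. A minimal sketch of just these knobs, assuming a recent RLlib with the env-runner API:

```python
from ray.rllib.algorithms.ppo import PPOConfig

config = (
    PPOConfig()
    .environment("Minecraft2D")
    # Zero env runners: the main process steps all envs itself and the
    # policy samples actions for the full batch on the GPU.
    .env_runners(num_env_runners=0, num_envs_per_env_runner=64)
    .resources(num_gpus=1)
)
```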
