impala but its not good
mikhail-vlasenko committed Jul 29, 2024
1 parent ba58ec2 commit 941aa59
Showing 2 changed files with 41 additions and 28 deletions.
17 changes: 12 additions & 5 deletions reinforcement_learning/config.py
@@ -7,20 +7,20 @@

 @dataclass
 class EnvConfig:
-    num_envs: int = 64
+    num_envs: int = 8
     lib_path: str = 'C:/Users/Mikhail/RustProjects/2d-minecraft/target/release/ffi.dll'
     discovered_actions_reward: float = 100.
     include_actions_in_obs: bool = True
 
 
 @dataclass
-class PPOTrainConfig:
+class TrainConfig:
     env_steps: int = 16000000
-    iter_env_steps: int = 512
+    iter_env_steps: int = 1024
     load_from: str = None
     save_to: str = f'reinforcement_learning/saved_models/sb3_ppo.pt'
     checkpoints_per_training: int = 16
-    num_runners: int = 0
+    num_runners: int = 8
 
 
 @dataclass
@@ -41,14 +41,21 @@ class PPOConfig:
     dimensions: List[int] = field(default_factory=lambda: [512, 256, 128, 64])
 
 
+@dataclass
+class IMPALAConfig:
+    gamma: float = 0.995
+    rollout_fragment_length: int = 256
+
+
 @dataclass
 class Config:
     storage_path: str = f"{os.getcwd()}/reinforcement_learning/ray_results"
     wandb_resume_id: str = ""
     env: EnvConfig = field(default_factory=EnvConfig)
-    ppo_train: PPOTrainConfig = field(default_factory=PPOTrainConfig)
+    train: TrainConfig = field(default_factory=TrainConfig)
     evaluation: EvaluationConfig = field(default_factory=EvaluationConfig)
     ppo: PPOConfig = field(default_factory=PPOConfig)
+    impala: IMPALAConfig = field(default_factory=IMPALAConfig)
     device: torch.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
 
     def as_dict(self):
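For reference, with these new defaults the effective train batch size computed in main.py (below) comes out to 32768 steps per training iteration. A minimal sketch of that arithmetic, assuming the Config defaults above are used unchanged:

# Sketch of the batch-size and checkpoint-frequency arithmetic from main.py,
# assuming the default values introduced in this commit.
iter_env_steps = 1024         # TrainConfig.iter_env_steps
num_envs = 8                  # EnvConfig.num_envs (environments per runner)
num_runners = 8               # TrainConfig.num_runners
env_steps = 16_000_000        # TrainConfig.env_steps
checkpoints_per_training = 16 # TrainConfig.checkpoints_per_training

train_batch_size = iter_env_steps * num_envs * max(1, num_runners) // 2
print(train_batch_size)       # 32768

checkpoint_frequency = env_steps // train_batch_size // checkpoints_per_training
print(checkpoint_frequency)   # 30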
52 changes: 29 additions & 23 deletions reinforcement_learning/main.py
@@ -1,6 +1,7 @@
 import ray
 from ray import tune
 from ray.air import CheckpointConfig
+from ray.rllib.algorithms import ImpalaConfig
 from ray.rllib.algorithms.ppo import PPOConfig
 from ray.tune.registry import register_env
 from ray.air.integrations.wandb import WandbLoggerCallback
@@ -35,48 +36,53 @@ def main():
         wandb_kwargs['resume'] = "must"
         wandb_kwargs['id'] = CONFIG.wandb_resume_id
 
-    train_batch_size = CONFIG.ppo_train.iter_env_steps * CONFIG.env.num_envs
+    train_batch_size = CONFIG.train.iter_env_steps * CONFIG.env.num_envs * max(1, CONFIG.train.num_runners) // 2
 
-    ppo_config = (
-        PPOConfig()
+    impala_config = (
+        ImpalaConfig()
         .environment("Minecraft2D", env_config={
             "discovered_actions_reward": CONFIG.env.discovered_actions_reward,
             "include_actions_in_obs": CONFIG.env.include_actions_in_obs,
         })
         .framework("torch")
         .training(
-            model={
-                "fcnet_hiddens": CONFIG.ppo.dimensions,
-                "fcnet_activation": CONFIG.ppo.nonlinear,
-            },
-            lr=CONFIG.ppo.lr,
+            # model={
+            #     "fcnet_hiddens": CONFIG.ppo.dimensions,
+            #     "fcnet_activation": CONFIG.ppo.nonlinear,
+            # },
             gamma=CONFIG.ppo.gamma,
-            lambda_=0.95,
-            entropy_coeff=CONFIG.ppo.ent_coef,
-            num_sgd_iter=CONFIG.ppo.update_epochs,
-            sgd_minibatch_size=CONFIG.ppo.batch_size,
-            vf_loss_coeff=0.5,
+            # entropy_coeff=CONFIG.ppo.ent_coef,
             train_batch_size=train_batch_size,
+            vtrace=True,
+            vtrace_clip_rho_threshold=1.0,
+            vtrace_clip_pg_rho_threshold=1.0,
         )
-        .resources(num_gpus=1, num_cpus_for_main_process=2)
+        .env_runners(
+            rollout_fragment_length=CONFIG.impala.rollout_fragment_length,
+            num_env_runners=CONFIG.train.num_runners,
+            num_envs_per_env_runner=CONFIG.env.num_envs,
+            num_cpus_per_env_runner=1,
+        )
+        .resources(num_gpus=1, num_cpus_for_main_process=4)
-        .env_runners(num_env_runners=CONFIG.ppo_train.num_runners, num_envs_per_env_runner=CONFIG.env.num_envs)  # num_cpus_per_env_runner=2
         .callbacks(MinecraftMetricsCallback)
     )
 
     stop_conditions = {
-        "timesteps_total": CONFIG.ppo_train.env_steps,
+        "timesteps_total": CONFIG.train.env_steps,
     }
 
     checkpoint_config = CheckpointConfig(
         num_to_keep=2,
         checkpoint_score_attribute="env_runners/episode_return_mean",
         checkpoint_at_end=True,
-        checkpoint_frequency=CONFIG.ppo_train.env_steps // train_batch_size // CONFIG.ppo_train.checkpoints_per_training
+        checkpoint_frequency=CONFIG.train.env_steps // train_batch_size // CONFIG.train.checkpoints_per_training
     )
 
     analysis = tune.run(
-        "PPO",
+        "IMPALA",
         storage_path=CONFIG.storage_path,
-        config=ppo_config.to_dict(),
+        config=impala_config.to_dict(),
         stop=stop_conditions,
         checkpoint_config=checkpoint_config,
         callbacks=[
@@ -87,14 +93,14 @@ def main():
         ],
     )
 
-    best_trial = analysis.get_best_trial("episode_reward_mean", "max", "last")
+    best_trial = analysis.get_best_trial("env_runners/episode_return_mean", "max", "last")
     print(f"Best trial config: {best_trial.config}")
-    print(f"Best trial final episode reward mean: {best_trial.last_result['episode_reward_mean']}")
+    print(f"Best trial final episode reward mean: {best_trial.last_result['env_runners/episode_return_mean']}")
 
     best_checkpoint = analysis.best_checkpoint
-    if CONFIG.ppo_train.save_to:
-        best_checkpoint.to_directory(CONFIG.ppo_train.save_to)
-        print(f"Best checkpoint saved to: {CONFIG.ppo_train.save_to}")
+    if CONFIG.train.save_to:
+        best_checkpoint.to_directory(CONFIG.train.save_to)
+        print(f"Best checkpoint saved to: {CONFIG.train.save_to}")
 
     ray.shutdown()
 
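The same builder pattern can be exercised outside of tune.run to sanity-check the IMPALA setup. A minimal, self-contained sketch on a toy environment (CartPole-v1 stands in for Minecraft2D, which requires the custom FFI environment to be registered first); it assumes a recent Ray/RLlib with the env_runners API, as used in the diff above:

# Standalone IMPALA sketch; assumes `pip install "ray[rllib]" torch`.
# CartPole-v1 is a stand-in for the "Minecraft2D" env used in main.py.
import ray
from ray.rllib.algorithms.impala import ImpalaConfig

ray.init(ignore_reinit_error=True)

config = (
    ImpalaConfig()
    .environment("CartPole-v1")
    .framework("torch")
    .training(
        gamma=0.995,                  # mirrors IMPALAConfig.gamma
        train_batch_size=4096,
        vtrace=True,                  # V-trace off-policy correction, as in the diff
    )
    .env_runners(
        rollout_fragment_length=256,  # mirrors IMPALAConfig.rollout_fragment_length
        num_env_runners=2,
        num_envs_per_env_runner=4,
    )
    .resources(num_gpus=0)
)

algo = config.build()
result = algo.train()
# Newer RLlib reports episode returns under the "env_runners" result key.
print(result.get("env_runners", {}).get("episode_return_mean"))
algo.stop()
ray.shutdown()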
