Residual policy network; train run 292 gets a game score of 2000 on average.
mikhail-vlasenko committed Dec 29, 2024
1 parent 0b4caaf commit b3468f2
Showing 4 changed files with 119 additions and 9 deletions.
8 changes: 6 additions & 2 deletions reinforcement_learning/config.py
@@ -23,7 +23,7 @@ class EnvConfig:

@dataclass
class TrainConfig:
-    env_steps: int = 32_000_000
+    env_steps: int = 64_000_000
    time_total_s: Optional[int] = None # if None, then env_steps is used
    iter_env_steps: int = 256
    load_from: str = None
@@ -51,13 +51,17 @@ class EvaluationConfig:
class ModelConfig:
    nonlinear: str = 'tanh'
    dimensions: List[int] = field(default_factory=lambda: [1024, 512, 512, 512])
+    residual: bool = True
+    residual_main_dim: int = 1024
+    residual_hidden_dim: int = 1536
+    residual_num_blocks: int = 3


@dataclass
class PPOConfig:
    lr: float = 3e-4
    gamma: float = 0.995
-    update_epochs: int = 5 # todo: 1
+    update_epochs: int = 2 # todo: 1
    ent_coef: float = 0.01
    batch_size: int = 512
    rollout_fragment_length: Union[int, str] = 'auto'
File renamed without changes: reinforcement_learning/model/model.py → reinforcement_learning/model/feature_extractor.py
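
For reference, the new defaults above can be inspected directly. This is a small sketch rather than code from the repository; it only assumes that CONFIG is importable from reinforcement_learning.config, as sb3_ppo.py already does:

from reinforcement_learning.config import CONFIG

# New ModelConfig fields control the residual policy network:
print(CONFIG.model.residual)             # True -> CustomActorCriticPolicy is used
print(CONFIG.model.residual_main_dim)    # 1024
print(CONFIG.model.residual_hidden_dim)  # 1536
print(CONFIG.model.residual_num_blocks)  # 3

# Training budget doubled and PPO epochs reduced in the same commit:
print(CONFIG.train.env_steps)      # 64_000_000 (was 32_000_000)
print(CONFIG.ppo.update_epochs)    # 2 (was 5)
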
98 changes: 98 additions & 0 deletions reinforcement_learning/model/policy_network.py
@@ -0,0 +1,98 @@
from typing import Callable, Dict, List, Optional, Tuple, Type, Union

from gymnasium import spaces
import torch
from torch import nn

from stable_baselines3 import PPO
from stable_baselines3.common.policies import MultiInputActorCriticPolicy

from reinforcement_learning.config import CONFIG


class ResidualNetwork(nn.Module):
    """
    Custom network for policy and value function.
    It receives as input the features extracted by the features extractor.
    :param feature_dim: dimension of the features extracted with the features_extractor (e.g. features from a CNN)
    :param last_layer_dim_pi: (int) number of units for the last layer of the policy network
    :param last_layer_dim_vf: (int) number of units for the last layer of the value network
    """

    def __init__(
        self,
        feature_dim: int,
        last_layer_dim_pi: int = CONFIG.model.residual_main_dim,
        last_layer_dim_vf: int = CONFIG.model.residual_main_dim,
    ):
        super().__init__()

        # IMPORTANT:
        # Save output dimensions, used to create the distributions
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        residual_kwargs = dict(
            main_dim=CONFIG.model.residual_main_dim,
            hidden_dim=CONFIG.model.residual_hidden_dim,
            num_residual_blocks=CONFIG.model.residual_num_blocks,
            activation=nn.Tanh
        )

        self.policy_net = nn.Sequential(
            nn.Linear(feature_dim, CONFIG.model.residual_main_dim),
            nn.Tanh(),
            ResidualMLP(**residual_kwargs),
        )
        self.value_net = nn.Sequential(
            nn.Linear(feature_dim, CONFIG.model.residual_main_dim),
            nn.Tanh(),
            ResidualMLP(**residual_kwargs),
        )

    def forward(self, features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
            If all layers are shared, then ``latent_policy == latent_value``
        """
        return self.forward_actor(features), self.forward_critic(features)

    def forward_actor(self, features: torch.Tensor) -> torch.Tensor:
        return self.policy_net(features)

    def forward_critic(self, features: torch.Tensor) -> torch.Tensor:
        return self.value_net(features)


class CustomActorCriticPolicy(MultiInputActorCriticPolicy):
    def _build_mlp_extractor(self) -> None:
        self.mlp_extractor = ResidualNetwork(self.features_dim)


class ResidualMLP(nn.Module):
    def __init__(
        self,
        main_dim: int,
        hidden_dim: int,
        num_residual_blocks: int,
        activation: Type[nn.Module] = nn.Tanh
    ):
        super().__init__()

        self.activation = activation()
        layers = []
        for i in range(num_residual_blocks):
            layers.append(nn.Sequential(
                nn.Linear(main_dim, hidden_dim),
                self.activation,
                nn.Linear(hidden_dim, main_dim),
            ))
        self.hidden_layers = nn.ModuleList(
            layers
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.hidden_layers:
            x = self.activation(x + layer(x))
        return x
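
A quick shape check of the module above (a sketch, not part of the commit; feature_dim=256 and batch size 4 are arbitrary values). Each ResidualMLP block widens the 1024-dim main stream to 1536 and back, and its output is added to the block input before a final tanh, so the latent dimensionality never changes across blocks:

import torch

from reinforcement_learning.model.policy_network import ResidualNetwork

# Arbitrary illustration values: 4 samples with 256 extracted features each.
net = ResidualNetwork(feature_dim=256)
features = torch.randn(4, 256)

latent_pi, latent_vf = net(features)
# Both heads project to residual_main_dim (1024 by default); SB3 reads
# latent_dim_pi / latent_dim_vf to size the action and value heads.
assert latent_pi.shape == (4, net.latent_dim_pi)
assert latent_vf.shape == (4, net.latent_dim_vf)
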
22 changes: 15 additions & 7 deletions reinforcement_learning/sb3_ppo.py
@@ -9,7 +9,8 @@
from python_wrapper.minecraft_2d_env import Minecraft2dEnv
from python_wrapper.simplified_actions import ActionSimplificationWrapper
from reinforcement_learning.config import CONFIG, ENV_KWARGS, WANDB_KWARGS
-from reinforcement_learning.model.model import FeatureExtractor
+from reinforcement_learning.model.feature_extractor import FeatureExtractor
+from reinforcement_learning.model.policy_network import CustomActorCriticPolicy


class LoggingCallback(BaseCallback):
@@ -84,17 +84,24 @@ def main():

    env = make_vec_env(Minecraft2dEnv, n_envs=CONFIG.env.num_envs, env_kwargs=ENV_KWARGS, wrapper_class=wrapper_class)

-    policy_kwargs = dict(
-        net_arch=dict(pi=CONFIG.model.dimensions, vf=CONFIG.model.dimensions),
-        features_extractor_class=FeatureExtractor,
-        features_extractor_kwargs=dict(),
-    )
+    if CONFIG.model.residual:
+        policy = CustomActorCriticPolicy
+        policy_kwargs = dict(
+            features_extractor_class=FeatureExtractor,
+        )
+    else:
+        policy = "MultiInputPolicy"
+        policy_kwargs = dict(
+            net_arch=dict(pi=CONFIG.model.dimensions, vf=CONFIG.model.dimensions),
+            features_extractor_class=FeatureExtractor,
+            features_extractor_kwargs=dict(),
+        )

    if CONFIG.train.load_checkpoint is not None:
        print(f"Loading checkpoint from {CONFIG.train.load_checkpoint}")
        model = PPO.load(CONFIG.train.load_checkpoint, env, tensorboard_log=f"runs/{run.id}")
    else:
-        model = PPO("MultiInputPolicy", env, policy_kwargs=policy_kwargs,
+        model = PPO(policy, env, policy_kwargs=policy_kwargs,
                    verbose=0, tensorboard_log=f"runs/{run.id}",
                    learning_rate=CONFIG.ppo.lr,
                    n_steps=CONFIG.train.iter_env_steps,
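
The branch above selects CustomActorCriticPolicy when CONFIG.model.residual is set and otherwise falls back to the plain "MultiInputPolicy" with net_arch taken from ModelConfig.dimensions. The following smoke test is a sketch under assumptions, not project code: CartPole-v1 wrapped into a one-key Dict observation space stands in for Minecraft2dEnv, and no custom feature extractor is passed, so SB3's default CombinedExtractor simply flattens the observation:

import gymnasium as gym
from gymnasium import spaces
from stable_baselines3 import PPO

from reinforcement_learning.model.policy_network import CustomActorCriticPolicy


class DictObsWrapper(gym.ObservationWrapper):
    """Wraps a Box observation into a one-key Dict so the MultiInput policy applies."""

    def __init__(self, env):
        super().__init__(env)
        self.observation_space = spaces.Dict({"obs": env.observation_space})

    def observation(self, observation):
        return {"obs": observation}


env = DictObsWrapper(gym.make("CartPole-v1"))
# The real training run passes FeatureExtractor via policy_kwargs, as in the diff above.
model = PPO(CustomActorCriticPolicy, env, n_steps=256, verbose=0)
model.learn(total_timesteps=512)  # tiny budget: only checks that the forward/backward passes run
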
