[NOMRG] Example: adding custom PettingZoo env #84

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benchmarl/conf/config.yaml
@@ -1,7 +1,7 @@
 defaults:
   - experiment: base_experiment
-  - algorithm: ???
-  - task: ???
+  - algorithm: ippo
+  - task: myenv/my_task
   - model: layers/mlp
   - model@critic_model: layers/mlp
   - _self_
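
With these defaults, running the standard Hydra entry point with no overrides already selects IPPO and the new custom task. A minimal sketch of the intended invocation, assuming the usual benchmarl/run.py launcher and that a myenv/my_task task config is registered as described below:

python benchmarl/run.py
# equivalent to spelling out the overrides explicitly:
python benchmarl/run.py algorithm=ippo task=myenv/my_task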
6 changes: 3 additions & 3 deletions benchmarl/conf/experiment/base_experiment.yaml
@@ -48,7 +48,7 @@ on_policy_collected_frames_per_batch: 6000
 # Number of environments used for collection
 # If the environment is vectorized, this will be the number of batched environments.
 # Otherwise batching will be simulated and each env will be run sequentially.
-on_policy_n_envs_per_worker: 10
+on_policy_n_envs_per_worker: 1
 # This is the number of times collected_frames_per_batch will be split into minibatches and trained
 on_policy_n_minibatch_iters: 45
 # In on-policy algorithms the train_batch_size will be equal to the on_policy_collected_frames_per_batch
@@ -60,7 +60,7 @@ off_policy_collected_frames_per_batch: 6000
 # Number of environments used for collection
 # If the environment is vectorized, this will be the number of batched environments.
 # Otherwise batching will be simulated and each env will be run sequentially.
-off_policy_n_envs_per_worker: 10
+off_policy_n_envs_per_worker: 1
 # This is the number of times off_policy_train_batch_size will be sampled from the buffer and trained over.
 off_policy_n_optimizer_steps: 1000
 # Number of frames used for each off_policy_n_optimizer_steps when training off-policy algorithms
@@ -77,7 +77,7 @@ render: True
 # Frequency of evaluation in terms of collected frames (this should be a multiple of on/off_policy_collected_frames_per_batch)
 evaluation_interval: 120_000
 # Number of episodes that evaluation is run on
-evaluation_episodes: 10
+evaluation_episodes: 1
 # If True, when stochastic policies are evaluated, their mode is taken, otherwise, if False, they are sampled
 evaluation_deterministic_actions: True
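
The same experiment settings can be overridden programmatically instead of editing the yaml. A small sketch, assuming BenchMARL's ExperimentConfig dataclass exposes these yaml keys as fields (as it does for the fields shown in the README examples):

from benchmarl.experiment import ExperimentConfig

# Load the defaults from base_experiment.yaml, then shrink collection and
# evaluation to a single environment / episode for quick debugging runs.
experiment_config = ExperimentConfig.get_from_yaml()
experiment_config.on_policy_n_envs_per_worker = 1
experiment_config.off_policy_n_envs_per_worker = 1
experiment_config.evaluation_episodes = 1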
Empty file.
3 changes: 2 additions & 1 deletion benchmarl/environments/__init__.py
@@ -6,14 +6,15 @@

 from .common import Task
 from .meltingpot.common import MeltingPotTask
+from .myenv.common import MyenvTask
 from .pettingzoo.common import PettingZooTask
 from .smacv2.common import Smacv2Task
 from .vmas.common import VmasTask

 # This is a registry mapping "envname/task_name" to the EnvNameTask.TASK_NAME enum
 # It is used to automatically load task enums from yaml files
 task_config_registry = {}
-for env in [VmasTask, Smacv2Task, PettingZooTask, MeltingPotTask]:
+for env in [VmasTask, Smacv2Task, PettingZooTask, MeltingPotTask, MyenvTask]:
     env_config_registry = {
         f"{env.env_name()}/{task.name.lower()}": task for task in env
     }
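
Adding MyenvTask to this loop is what makes the myenv/my_task string in config.yaml resolvable. A minimal sketch of the lookup it enables, assuming the MyenvTask.MY_TASK member defined in the new common.py below:

from benchmarl.environments import task_config_registry

# "myenv" comes from MyenvTask.env_name(), "my_task" from the MY_TASK member name
task = task_config_registry["myenv/my_task"]  # MyenvTask.MY_TASK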
208 changes: 208 additions & 0 deletions benchmarl/environments/myenv/common.py
@@ -0,0 +1,208 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#


from typing import Callable, Dict, List, Optional

import numpy as np

from benchmarl.environments.common import Task

from benchmarl.utils import DEVICE_TYPING

from gymnasium import spaces
from pettingzoo import ParallelEnv

from torchrl.data import CompositeSpec
from torchrl.envs import EnvBase, PettingZooWrapper


class MyCustomEnv2(ParallelEnv):
    """
    Multi-agent version of my single agent class.
    """

    metadata = {"render_modes": ["human"], "name": "myclass_v0"}

    def __init__(self, num_envs=2):
        super(MyCustomEnv2, self).__init__()
        self.t = 1
        num_agents = 3
        self.possible_agents = ["player_" + str(r + 1) for r in range(num_agents)]

        self.agent_name_mapping = dict(
            zip(self.possible_agents, list(range(len(self.possible_agents))))
        )
        self.render_mode = None

    def observation_space(self, agent):
        state_low = np.concatenate(
            (
                np.zeros(1),  # weights
                np.full(2, -np.inf),
            )
        )
        state_high = np.concatenate(
            (
                np.ones(1),
                np.full(2, np.inf),
            )
        )

        return spaces.Box(
            low=state_low,
            high=state_high,
            shape=(3,),
            dtype=np.float32,  # float32 is required here: float64 originally caused
            # a dtype error in BenchMARL (the same applies to action_space below)
        )

    def action_space(self, agent):
        return spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)

    def close(self):
        """
        Close should release any graphical displays, subprocesses, network connections
        or any other environment data which should not be kept around after the
        user is no longer using the environment.
        """
        pass

    def reset(self, seed=None, options=None):
        self.t = 1
        self.agents = self.possible_agents[:]
        # Dummy observation, cast to float32 to match the declared observation space
        state_dummy = np.concatenate(
            (
                np.ones(1),
                np.full(2, np.inf),
            )
        ).astype(np.float32)
        observations = {agent: state_dummy for agent in self.agents}
        infos = {agent: {} for agent in self.agents}
        self.state = observations
        print("RESET DONE")

        return observations, infos

    def step(self, actions):

        self.t += 1

        env_truncation = self.t >= 5

        print(f"step, t: {self.t}")
        print(f"env_truncation: {env_truncation}")
        print()
        self.done = env_truncation
        if not actions:
            self.agents = []

            return {}, {}, {}, {}, {}

        rewards = {}
        observations = {}

        for agent in self.agents:

            # Dummy observation and constant reward for every agent
            state_dummy = np.concatenate(
                (
                    np.ones(1),
                    np.full(2, np.inf),
                )
            ).astype(np.float32)
            reward = 10

            # Store the observation and reward for this agent
            observations[agent] = state_dummy
            rewards[agent] = reward

        # self.state = observations
        terminations = {agent: env_truncation for agent in self.agents}
        truncations = {agent: env_truncation for agent in self.agents}

        self.state = observations

        infos = {agent: {} for agent in self.agents}

        if env_truncation:
            self.agents = []

        return observations, rewards, terminations, truncations, infos


class MyenvTask(Task):

    MY_TASK = None

    def get_env_fun(
        self,
        num_envs: int,
        continuous_actions: bool,
        seed: Optional[int],
        device: DEVICE_TYPING,
    ) -> Callable[[], EnvBase]:

        return lambda: PettingZooWrapper(
            MyCustomEnv2(),
            categorical_actions=True,
            device=device,
            seed=seed,
            return_state=False,
            **self.config,
        )

    def supports_continuous_actions(self) -> bool:
        return True

    def supports_discrete_actions(self) -> bool:
        return False

    def has_state(self) -> bool:
        return False

    def has_render(self, env: EnvBase) -> bool:
        return False

    def max_steps(self, env: EnvBase) -> int:
        return 100

    def group_map(self, env: EnvBase) -> Dict[str, List[str]]:
        return env.group_map

    def state_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
        return None

    def action_mask_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
        return None

    def observation_spec(self, env: EnvBase) -> CompositeSpec:
        observation_spec = env.observation_spec.clone()
        for group in self.group_map(env):
            group_obs_spec = observation_spec[group]
            for key in list(group_obs_spec.keys()):
                if key != "observation":
                    del group_obs_spec[key]
        if "state" in observation_spec.keys():
            del observation_spec["state"]
        return observation_spec

    def info_spec(self, env: EnvBase) -> Optional[CompositeSpec]:
        observation_spec = env.observation_spec.clone()
        for group in self.group_map(env):
            group_obs_spec = observation_spec[group]
            for key in list(group_obs_spec.keys()):
                if key != "info":
                    del group_obs_spec[key]
        if "state" in observation_spec.keys():
            del observation_spec["state"]
        return observation_spec

    def action_spec(self, env: EnvBase) -> CompositeSpec:
        return env.full_action_spec

    @staticmethod
    def env_name() -> str:
        return "myenv"