Skip to content

Commit

Permalink
vector env updates (openai#1706)
Browse files Browse the repository at this point in the history
* make daemon=True an option of async_vector_env

* custom worker in async_vector_env

* add compatibility methods to SyncVectorEnv

* fix name in sync_vector_env

* vectorenv api cleanup

* add docstrings for daemon and worker options in AsyncVectorEnv
  • Loading branch information
pzhokhov authored Oct 9, 2019
1 parent 1d31c12 commit 51136b1
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 50 deletions.
29 changes: 17 additions & 12 deletions gym/vector/async_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,23 @@ class AsyncVectorEnv(VectorEnv):
context : str, optional
Context for multiprocessing. If `None`, then the default context is used.
Only available in Python 3.
daemon : bool (default: `True`)
If `True`, then the subprocesses have the `daemon` flag turned on; that is,
they will quit if the head process quits. However, `daemon=True` prevents
subprocesses from spawning children, so for some environments you may want
to set it to `False`.
worker : function, optional
WARNING - advanced mode option! If set, then use that worker in a subprocess
instead of a default one. Can be useful to override some inner vector env
logic, for instance, how resets on done are handled. Provides high
degree of flexibility and a high chance to shoot yourself in the foot; thus,
if you are writing your own worker, it is recommended to start from the code
for the `_worker` (or `_worker_shared_memory`) function below and adapt it.
"""
def __init__(self, env_fns, observation_space=None, action_space=None,
shared_memory=True, copy=True, context=None):
shared_memory=True, copy=True, context=None, daemon=True, worker=None):
try:
ctx = mp.get_context(context)
except AttributeError:
Expand Down Expand Up @@ -86,6 +100,7 @@ def __init__(self, env_fns, observation_space=None, action_space=None,
self.parent_pipes, self.processes = [], []
self.error_queue = ctx.Queue()
target = _worker_shared_memory if self.shared_memory else _worker
target = worker or target
with clear_mpi_env_vars():
for idx, env_fn in enumerate(self.env_fns):
parent_pipe, child_pipe = ctx.Pipe()
Expand All @@ -97,24 +112,14 @@ def __init__(self, env_fns, observation_space=None, action_space=None,
self.parent_pipes.append(parent_pipe)
self.processes.append(process)

process.daemon = True
process.daemon = daemon
process.start()
child_pipe.close()

self._state = AsyncState.DEFAULT
self._check_observation_spaces()

def seed(self, seeds=None):
"""
Parameters
----------
seeds : list of int, or int, optional
Random seed for each individual environment. If `seeds` is a list of
length `num_envs`, then the items of the list are chosen as random
seeds. If `seeds` is an int, then each environment uses the random
seed `seeds + n`, where `n` is the index of the environment (between
`0` and `num_envs - 1`).
"""
self._assert_is_running()
if seeds is None:
seeds = [None for _ in range(self.num_envs)]
Expand Down
44 changes: 6 additions & 38 deletions gym/vector/sync_vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,18 +45,9 @@ def __init__(self, env_fns, observation_space=None, action_space=None,
n=self.num_envs, fn=np.zeros)
self._rewards = np.zeros((self.num_envs,), dtype=np.float64)
self._dones = np.zeros((self.num_envs,), dtype=np.bool_)
self._actions = None

def seed(self, seeds=None):
"""
Parameters
----------
seeds : list of int, or int, optional
Random seed for each individual environment. If `seeds` is a list of
length `num_envs`, then the items of the list are chosen as random
seeds. If `seeds` is an int, then each environment uses the random
seed `seeds + n`, where `n` is the index of the environment (between
`0` and `num_envs - 1`).
"""
if seeds is None:
seeds = [None for _ in range(self.num_envs)]
if isinstance(seeds, int):
Expand All @@ -66,13 +57,7 @@ def seed(self, seeds=None):
for env, seed in zip(self.envs, seeds):
env.seed(seed)

def reset(self):
"""
Returns
-------
observations : sample from `observation_space`
A batch of observations from the vectorized environment.
"""
def reset_wait(self):
self._dones[:] = False
observations = []
for env in self.envs:
Expand All @@ -82,29 +67,12 @@ def reset(self):

return np.copy(self.observations) if self.copy else self.observations

def step(self, actions):
"""
Parameters
----------
actions : iterable of samples from `action_space`
List of actions.
def step_async(self, actions):
self._actions = actions

Returns
-------
observations : sample from `observation_space`
A batch of observations from the vectorized environment.
rewards : `np.ndarray` instance (dtype `np.float_`)
A vector of rewards from the vectorized environment.
dones : `np.ndarray` instance (dtype `np.bool_`)
A vector whose entries indicate whether the episode has ended.
infos : list of dict
A list of auxiliary diagnostic informations.
"""
def step_wait(self):
observations, infos = [], []
for i, (env, action) in enumerate(zip(self.envs, actions)):
for i, (env, action) in enumerate(zip(self.envs, self._actions)):
observation, self._rewards[i], self._dones[i], info = env.step(action)
if self._dones[i]:
observation = env.reset()
Expand Down
40 changes: 40 additions & 0 deletions gym/vector/vector_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ def reset_wait(self, **kwargs):
raise NotImplementedError()

def reset(self):
"""
Returns
-------
observations : sample from `observation_space`
A batch of observations from the vectorized environment.
"""
self.reset_async()
return self.reset_wait()

Expand All @@ -50,9 +56,43 @@ def step_wait(self, **kwargs):
raise NotImplementedError()

def step(self, actions):
"""
Parameters
----------
actions : iterable of samples from `action_space`
List of actions.

Returns
-------
observations : sample from `observation_space`
A batch of observations from the vectorized environment.
rewards : `np.ndarray` instance (dtype `np.float_`)
A vector of rewards from the vectorized environment.
dones : `np.ndarray` instance (dtype `np.bool_`)
A vector whose entries indicate whether the episode has ended.
infos : list of dict
A list of auxiliary diagnostic information.
"""
self.step_async(actions)
return self.step_wait()

def seed(self, seeds=None):
"""
Parameters
----------
seeds : list of int, or int, optional
Random seed for each individual environment. If `seeds` is a list of
length `num_envs`, then the items of the list are chosen as random
seeds. If `seeds` is an int, then each environment uses the random
seed `seeds + n`, where `n` is the index of the environment (between
`0` and `num_envs - 1`).
"""
pass


def __del__(self):
if hasattr(self, 'closed'):
if not self.closed:
Expand Down

0 comments on commit 51136b1

Please sign in to comment.