[BugFix] Fix TorchRL demo tutorial (pytorch#721)
vmoens authored Dec 1, 2022
1 parent d29bbec commit 08e3b71
Showing 9 changed files with 148 additions and 115 deletions.
5 changes: 3 additions & 2 deletions test/test_collector.py
@@ -17,6 +17,7 @@
DiscreteActionVecPolicy,
MockSerialEnv,
)
from tensordict.nn import TensorDictModule
from tensordict.tensordict import assert_allclose_td, TensorDict
from torch import nn
from torchrl._utils import seed_generator
@@ -980,12 +981,12 @@ def test_auto_wrap_modules(self, collector_class, multiple_outputs, env_maker):

if collector_class is not SyncDataCollector:
assert all(
isinstance(p, SafeModule) for p in collector._policy_dict.values()
isinstance(p, TensorDictModule) for p in collector._policy_dict.values()
)
assert all(p.out_keys == out_keys for p in collector._policy_dict.values())
assert all(p.module is policy for p in collector._policy_dict.values())
else:
assert isinstance(collector.policy, SafeModule)
assert isinstance(collector.policy, TensorDictModule)
assert collector.policy.out_keys == out_keys
assert collector.policy.module is policy

40 changes: 20 additions & 20 deletions torchrl/collectors/collectors.py
@@ -17,6 +17,7 @@
import numpy as np
import torch
import torch.nn as nn
from tensordict.nn import TensorDictModule
from tensordict.tensordict import TensorDict, TensorDictBase
from torch import multiprocessing as mp
from torch.utils.data import IterableDataset
@@ -29,7 +30,6 @@
from ..data.utils import CloudpickleWrapper, DEVICE_TYPING
from ..envs.common import EnvBase
from ..envs.vec_env import _BatchedEnv
from ..modules.tensordict_module import SafeModule, SafeProbabilisticModule
from .utils import split_trajectories

_TIMEOUT = 1.0
@@ -84,21 +84,21 @@ def recursive_map_to_cpu(dictionary: OrderedDict) -> OrderedDict:
def _policy_is_tensordict_compatible(policy: nn.Module):
sig = inspect.signature(policy.forward)

if isinstance(policy, SafeModule) or (
if isinstance(policy, TensorDictModule) or (
len(sig.parameters) == 1
and hasattr(policy, "in_keys")
and hasattr(policy, "out_keys")
):
# if the policy is a SafeModule or takes a single argument and defines
# if the policy is a TensorDictModule or takes a single argument and defines
# in_keys and out_keys then we assume it can already deal with TensorDict input
# to forward and we return True
return True
elif not hasattr(policy, "in_keys") and not hasattr(policy, "out_keys"):
# if it's not a SafeModule, and in_keys and out_keys are not defined then
# if it's not a TensorDictModule, and in_keys and out_keys are not defined then
# we assume no TensorDict compatibility and will try to wrap it.
return False

# if in_keys or out_keys were defined but policy is not a SafeModule or
# if in_keys or out_keys were defined but policy is not a TensorDictModule or
# accepts multiple arguments then it's likely the user is trying to do something
# that will have undetermined behaviour, we raise an error
raise TypeError(
@@ -107,7 +107,7 @@ def _policy_is_tensordict_compatible(policy: nn.Module):
"should take a single argument of type TensorDict to policy.forward and define "
"both in_keys and out_keys. Alternatively, policy.forward can accept "
"arbitrarily many tensor inputs and leave in_keys and out_keys undefined and "
"TorchRL will attempt to automatically wrap the policy with a SafeModule."
"TorchRL will attempt to automatically wrap the policy with a TensorDictModule."
)


@@ -116,13 +116,13 @@ def _get_policy_and_device(
self,
policy: Optional[
Union[
SafeProbabilisticModule,
TensorDictModule,
Callable[[TensorDictBase], TensorDictBase],
]
] = None,
device: Optional[DEVICE_TYPING] = None,
observation_spec: TensorSpec = None,
) -> Tuple[SafeProbabilisticModule, torch.device, Union[None, Callable[[], dict]]]:
) -> Tuple[TensorDictModule, torch.device, Union[None, Callable[[], dict]]]:
"""Util method to get a policy and its device given the collector __init__ inputs.
From a policy and a device, assigns the self.device attribute to
@@ -133,7 +133,7 @@ def _get_policy_and_device(
create_env_fn (Callable or list of callables): an env creator
function (or a list of creators)
create_env_kwargs (dictionary): kwargs for the env creator
policy (SafeProbabilisticModule, optional): a policy to be used
policy (TensorDictModule, optional): a policy to be used
device (int, str or torch.device, optional): device where to place
the policy
observation_spec (TensorSpec, optional): spec of the observations
@@ -161,13 +161,13 @@ def _get_policy_and_device(
# callables should be supported as policies.
if not _policy_is_tensordict_compatible(policy):
# policy is a nn.Module that doesn't operate on tensordicts directly
# so we attempt to auto-wrap policy with SafeModule
# so we attempt to auto-wrap policy with TensorDictModule
if observation_spec is None:
raise ValueError(
"Unable to read observation_spec from the environment. This is "
"required to check compatibility of the environment and policy "
"since the policy is a nn.Module that operates on tensors "
"rather than a SafeModule or a nn.Module that accepts a "
"rather than a TensorDictModule or a nn.Module that accepts a "
"TensorDict as input and defines in_keys and out_keys."
)
sig = inspect.signature(policy.forward)
@@ -181,18 +181,18 @@
if isinstance(output, tuple):
out_keys.extend(f"output{i+1}" for i in range(len(output) - 1))

policy = SafeModule(
policy = TensorDictModule(
policy, in_keys=list(sig.parameters), out_keys=out_keys
)
else:
raise TypeError(
"Arguments to policy.forward are incompatible with entries in "
"env.observation_spec. If you want TorchRL to automatically "
"wrap your policy with a SafeModule then the arguments "
"wrap your policy with a TensorDictModule then the arguments "
"to policy.forward must correspond one-to-one with entries in "
"env.observation_spec that are prefixed with 'next_'. For more "
"complex behaviour and more control you can consider writing "
"your own SafeModule."
"your own TensorDictModule."
)

try:
@@ -305,7 +305,7 @@ def __init__(
], # noqa: F821
policy: Optional[
Union[
SafeProbabilisticModule,
TensorDictModule,
Callable[[TensorDictBase], TensorDictBase],
]
] = None,
@@ -517,7 +517,7 @@ def iterator(self) -> Iterator[TensorDictBase]:
def _cast_to_policy(self, td: TensorDictBase) -> TensorDictBase:
policy_device = self.device
if hasattr(self.policy, "in_keys"):
# some keys may be absent -- SafeModule is resilient to missing keys
# some keys may be absent -- TensorDictModule is resilient to missing keys
td = td.select(*self.policy.in_keys, strict=False)
if self._td_policy is None:
self._td_policy = td.to(policy_device)
@@ -717,7 +717,7 @@ class _MultiDataCollector(_DataCollector):
Args:
create_env_fn (list of Callables): list of Callables, each returning an instance of EnvBase
policy (Callable, optional): Instance of SafeProbabilisticModule class.
policy (Callable, optional): Instance of TensorDictModule class.
Must accept TensorDictBase object as input.
total_frames (int): lower bound of the total number of frames returned by the collector. In parallel settings,
the actual number of frames may well be greater than this as the closing signals are sent to the
@@ -776,7 +776,7 @@ def __init__(
create_env_fn: Sequence[Callable[[], EnvBase]],
policy: Optional[
Union[
SafeProbabilisticModule,
TensorDictModule,
Callable[[TensorDictBase], TensorDictBase],
]
] = None,
@@ -1303,7 +1303,7 @@ class aSyncDataCollector(MultiaSyncDataCollector):
Args:
create_env_fn (Callable): Callable returning an instance of EnvBase
policy (Callable, optional): Instance of SafeProbabilisticModule class.
policy (Callable, optional): Instance of TensorDictModule class.
Must accept TensorDictBase object as input.
total_frames (int): lower bound of the total number of frames returned
by the collector. In parallel settings, the actual number of
@@ -1358,7 +1358,7 @@ def __init__(
create_env_fn: Callable[[], EnvBase],
policy: Optional[
Union[
SafeProbabilisticModule,
TensorDictModule,
Callable[[TensorDictBase], TensorDictBase],
]
] = None,
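Note (not part of the commit): the hunks above replace the SafeModule/SafeProbabilisticModule annotations with tensordict.nn.TensorDictModule, which is what the collectors now use to auto-wrap plain nn.Module policies. Below is a minimal sketch of that wrapping pattern; the toy network and key names are illustrative assumptions, none of this code appears in the diff.

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule
from torch import nn

# a plain nn.Module policy that maps an observation tensor to an action tensor
policy_net = nn.Linear(4, 2)

# wrapping it makes it read and write entries of a TensorDict, which is what
# _get_policy_and_device does automatically for bare nn.Module policies
policy = TensorDictModule(policy_net, in_keys=["observation"], out_keys=["action"])

td = TensorDict({"observation": torch.randn(3, 4)}, batch_size=[3])
policy(td)  # fills td["action"] in place
assert td["action"].shape == (3, 2)
assert policy.module is policy_net  # the wrapped module stays accessible, as the test above checks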
8 changes: 7 additions & 1 deletion tutorials/sphinx-tutorials/coding_ddpg.py
@@ -32,6 +32,12 @@

# Make all the necessary imports for training

# sphinx_gallery_start_ignore
import warnings

warnings.filterwarnings("ignore")
# sphinx_gallery_end_ignore

from copy import deepcopy
from typing import Optional

@@ -40,6 +46,7 @@
import torch.cuda
import tqdm
from matplotlib import pyplot as plt
from tensordict.nn import TensorDictModule
from torch import nn, optim
from torchrl.collectors import MultiaSyncDataCollector
from torchrl.data import (
@@ -64,7 +71,6 @@
MLP,
OrnsteinUhlenbeckProcessWrapper,
ProbabilisticActor,
TensorDictModule,
ValueOperator,
)
from torchrl.modules.distributions.continuous import TanhDelta
63 changes: 25 additions & 38 deletions tutorials/sphinx-tutorials/coding_dqn.py
@@ -42,11 +42,19 @@
# to provide a high-level illustration of TorchRL features in the context
# of this algorithm.

# sphinx_gallery_start_ignore
import warnings

warnings.filterwarnings("ignore")
# sphinx_gallery_end_ignore

import torch
import tqdm
from functorch import vmap
from IPython import display
from matplotlib import pyplot as plt
from tensordict import TensorDict
from tensordict.nn import get_functional
from torch import nn
from torchrl.collectors import MultiaSyncDataCollector
from torchrl.data import LazyMemmapStorage, TensorDictReplayBuffer
@@ -251,18 +259,16 @@ def make_model():
print("Q-value network results:", tensordict)

# make functional
factor, (_, buffers) = actor.make_functional_with_buffers(clone=True, native=True)
# making functional creates a copy of the params, which we don't want (i.e. we want the parameters from `actor` to match those in the params object),
# hence we create the params object in a second step
params = TensorDict({k: v for k, v in net.named_parameters()}, []).unflatten_keys(
"."
)
# here's an explicit way of creating the parameters and buffer tensordict.
# Alternatively, we could have used `params = make_functional(actor)` from
# tensordict.nn
params = TensorDict({k: v for k, v in actor.named_parameters()}, [])
buffers = TensorDict({k: v for k, v in actor.named_buffers()}, [])
params = params.update(buffers).unflatten_keys(".") # creates a nested TensorDict
factor = get_functional(actor)

# creating the target parameters is fairly easy with tensordict:
params_target, buffers_target = (
params.to_tensordict().detach(),
buffers.to_tensordict().detach(),
)
(params_target,) = (params.to_tensordict().detach(),)

# we wrap our actor in an EGreedyWrapper for data collection
actor_explore = EGreedyWrapper(
@@ -272,7 +278,7 @@ def make_model():
eps_end=eps_greedy_val_env,
)

return factor, actor, actor_explore, params, buffers, params_target, buffers_target
return factor, actor, actor_explore, params, params_target


###############################################################################
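Note (not part of the commit): the make_model hunk above switches from actor.make_functional_with_buffers to get_functional plus an explicitly built parameter TensorDict. Below is a minimal sketch of that pattern with a toy module; the key names and shapes are illustrative assumptions, and the exact behaviour depends on the tensordict version in use at the time of this commit.

import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, get_functional
from torch import nn

net = nn.Linear(4, 2)
module = TensorDictModule(net, in_keys=["observation"], out_keys=["action_value"])

# explicit parameter TensorDict, mirroring the hunk above (nn.Linear has no
# buffers, so only named_parameters is gathered here)
params = TensorDict({k: v for k, v in module.named_parameters()}, []).unflatten_keys(".")

factor = get_functional(module)  # the module can now be called with external params

td = TensorDict({"observation": torch.randn(5, 4)}, batch_size=[5])
factor(td, params=params)

# a detached copy of the parameters serves as the target parameters
params_target = params.to_tensordict().detach()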
@@ -286,14 +292,10 @@ def make_model():
actor,
actor_explore,
params,
buffers,
params_target,
buffers_target,
) = make_model()
params_flat = params.flatten_keys(".")
buffers_flat = buffers.flatten_keys(".")
params_target_flat = params_target.flatten_keys(".")
buffers_target_flat = buffers_target.flatten_keys(".")

###############################################################################
# Regular DQN
@@ -393,7 +395,7 @@ def make_model():

# Compute action value (of the action actually taken) at time t
sampled_data_out = sampled_data.select(*actor.in_keys)
sampled_data_out = factor(sampled_data_out, params=params, buffers=buffers)
sampled_data_out = factor(sampled_data_out, params=params)
action_value = sampled_data_out["action_value"]
action_value = (action_value * action.to(action_value.dtype)).sum(-1)
with torch.no_grad():
@@ -402,7 +404,6 @@
next_value = factor(
tdstep.select(*actor.in_keys),
params=params_target,
buffers=buffers_target,
)["chosen_action_value"].squeeze(-1)
exp_value = reward + gamma * next_value * (1 - done)
assert exp_value.shape == action_value.shape
@@ -420,9 +421,6 @@ def make_model():
for (key, p1) in params_flat.items():
p2 = params_target_flat[key]
params_target_flat.set_(key, tau * p1.data + (1 - tau) * p2.data)
for (key, p1) in buffers_flat.items():
p2 = buffers_target_flat[key]
buffers_target_flat.set_(key, tau * p1.data + (1 - tau) * p2.data)

pbar.set_description(
f"error: {error: 4.4f}, value: {action_value.mean(): 4.4f}"
@@ -513,7 +511,7 @@ def make_model():
"grad_vals": grad_vals,
"traj_lengths_training": traj_lengths,
"traj_count": traj_count,
"weights": (params, buffers),
"weights": (params,),
},
"saved_results_td0.pt",
)
@@ -548,14 +546,10 @@ def make_model():
actor,
actor_explore,
params,
buffers,
params_target,
buffers_target,
) = make_model()
params_flat = params.flatten_keys(".")
buffers_flat = buffers.flatten_keys(".")
params_target_flat = params_target.flatten_keys(".")
buffers_target_flat = buffers_target.flatten_keys(".")

###############################################################################

Expand Down Expand Up @@ -632,19 +626,15 @@ def make_model():
action = sampled_data["action"].clone()

sampled_data_out = sampled_data.select(*actor.in_keys)
sampled_data_out = factor(
sampled_data_out, params=params, buffers=buffers, vmap=(None, None, 0)
)
sampled_data_out = vmap(factor, (0, None))(sampled_data_out, params)
action_value = sampled_data_out["action_value"]
action_value = (action_value * action.to(action_value.dtype)).sum(-1, True)
with torch.no_grad():
tdstep = step_mdp(sampled_data)
next_value = factor(
tdstep.select(*actor.in_keys),
params=params_target,
buffers=buffers_target,
vmap=(None, None, 0),
)["chosen_action_value"]
next_value = vmap(factor, (0, None))(
tdstep.select(*actor.in_keys), params
)
next_value = next_value["chosen_action_value"]
error = vec_td_lambda_advantage_estimate(
gamma,
lmbda,
@@ -671,9 +661,6 @@ def make_model():
for (key, p1) in params_flat.items():
p2 = params_target_flat[key]
params_target_flat.set_(key, tau * p1.data + (1 - tau) * p2.data)
for (key, p1) in buffers_flat.items():
p2 = buffers_target_flat[key]
buffers_target_flat.set_(key, tau * p1.data + (1 - tau) * p2.data)

pbar.set_description(
f"error: {error: 4.4f}, value: {action_value.mean(): 4.4f}"
@@ -765,7 +752,7 @@ def make_model():
"grad_vals": grad_vals,
"traj_lengths_training": traj_lengths,
"traj_count": traj_count,
"weights": (params, buffers),
"weights": (params,),
},
"saved_results_tdlambda.pt",
)
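Note (not part of the commit): the TD(lambda) hunks above swap the old factor(..., vmap=(None, None, 0)) call for functorch.vmap over the sampled data with shared parameters, and keep the Polyak-style soft update but only for the merged params (buffers no longer live in a separate tensordict). Below is a minimal sketch of both patterns with a toy module; the shapes and the tau value are illustrative assumptions.

import torch
from functorch import vmap
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, get_functional
from torch import nn

net = nn.Linear(4, 2)
module = TensorDictModule(net, in_keys=["observation"], out_keys=["action_value"])
factor = get_functional(module)

params = TensorDict({k: v for k, v in module.named_parameters()}, []).unflatten_keys(".")
params_target = params.to_tensordict().detach()

# batch of 10 sub-trajectories of length 5: vmap over dim 0 of the data while
# broadcasting the same params, as in vmap(factor, (0, None)) above
data = TensorDict({"observation": torch.randn(10, 5, 4)}, batch_size=[10, 5])
out = vmap(factor, (0, None))(data, params)

# Polyak-style soft update of the target parameters, mirroring the loop above
tau = 0.02  # illustrative value
params_flat = params.flatten_keys(".")
params_target_flat = params_target.flatten_keys(".")
for key, p1 in params_flat.items():
    p2 = params_target_flat[key]
    params_target_flat.set_(key, tau * p1.data + (1 - tau) * p2.data)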
(diffs for the remaining 5 changed files not shown)
