diff --git a/docs/source/reference/data.rst b/docs/source/reference/data.rst index c5ee9820b5f..abcd96e6457 100644 --- a/docs/source/reference/data.rst +++ b/docs/source/reference/data.rst @@ -1147,16 +1147,17 @@ Utils :toctree: generated/ :template: rl_template.rst - MultiStep - consolidate_spec - check_no_exclusive_keys - contains_lazy_spec - Nested2TED + DensifyReward Flat2TED H5Combine H5Split + MultiStep + Nested2TED TED2Flat TED2Nested + check_no_exclusive_keys + consolidate_spec + contains_lazy_spec .. currentmodule:: torchrl.envs.transforms.rb_transforms diff --git a/test/test_postprocs.py b/test/test_postprocs.py index b357e94acdd..1948fd8cca4 100644 --- a/test/test_postprocs.py +++ b/test/test_postprocs.py @@ -12,7 +12,7 @@ from torchrl._utils import _ends_with from torchrl.collectors.utils import split_trajectories -from torchrl.data.postprocs.postprocs import MultiStep +from torchrl.data.postprocs.postprocs import DensifyReward, MultiStep if os.getenv("PYTORCH_TEST_FBCODE"): from pytorch.rl.test._utils_internal import get_default_devices @@ -20,205 +20,245 @@ from _utils_internal import get_default_devices -@pytest.mark.parametrize("n", range(1, 14)) -@pytest.mark.parametrize("device", get_default_devices()) -@pytest.mark.parametrize("key", ["observation", "pixels", "observation_whatever"]) -def test_multistep(n, key, device, T=11): - torch.manual_seed(0) +class TestMultiStep: + @pytest.mark.parametrize("n", range(1, 14)) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("key", ["observation", "pixels", "observation_whatever"]) + def test_multistep(self, n, key, device, T=11): + torch.manual_seed(0) - # mock data - b = 5 + # mock data + b = 5 - done = torch.zeros(b, T, 1, dtype=torch.bool, device=device) - done[0, -1] = True - done[1, -2] = True - done[2, -3] = True - done[3, -4] = True + done = torch.zeros(b, T, 1, dtype=torch.bool, device=device) + done[0, -1] = True + done[1, -2] = True + done[2, -3] = True + done[3, -4] = True - terminal = done.clone() - terminal[:, -1] = done.sum(1) != 1 + terminal = done.clone() + terminal[:, -1] = done.sum(1) != 1 - mask = done.clone().cumsum(1).cumsum(1) >= 2 - mask = ~mask + mask = done.clone().cumsum(1).cumsum(1) >= 2 + mask = ~mask - total_obs = torch.randn(1, T + 1, 1, device=device).expand(b, T + 1, 1) - tensordict = TensorDict( - source={ - key: total_obs[:, :T] * mask.to(torch.float), - "done": done, - "next": { - key: total_obs[:, 1:] * mask.to(torch.float), + total_obs = torch.randn(1, T + 1, 1, device=device).expand(b, T + 1, 1) + tensordict = TensorDict( + source={ + key: total_obs[:, :T] * mask.to(torch.float), "done": done, - "reward": torch.randn(1, T, 1, device=device).expand(b, T, 1) - * mask.to(torch.float), + "next": { + key: total_obs[:, 1:] * mask.to(torch.float), + "done": done, + "reward": torch.randn(1, T, 1, device=device).expand(b, T, 1) + * mask.to(torch.float), + }, + "collector": {"mask": mask}, }, - "collector": {"mask": mask}, - }, - batch_size=(b, T), - ).to(device) - - ms = MultiStep( - 0.9, - n, - ).to(device) - ms_tensordict = ms(tensordict.clone()) - - assert ms_tensordict.get("done").max() == 1 - - if n == 1: - assert_allclose_td( - tensordict, ms_tensordict.select(*list(tensordict.keys(True, True))) - ) + batch_size=(b, T), + ).to(device) - # assert that done at last step is similar to unterminated traj - torch.testing.assert_close( - ms_tensordict.get("gamma")[4], ms_tensordict.get("gamma")[0] - ) - torch.testing.assert_close( - ms_tensordict.get(("next", 
key))[4], ms_tensordict.get(("next", key))[0] - ) - torch.testing.assert_close( - ms_tensordict.get("steps_to_next_obs")[4], - ms_tensordict.get("steps_to_next_obs")[0], - ) + ms = MultiStep( + 0.9, + n, + ).to(device) + ms_tensordict = ms(tensordict.clone()) - # check that next obs is properly replaced, or that it is terminated - next_obs = ms_tensordict.get(key)[:, (ms.n_steps) :] - true_next_obs = ms_tensordict.get(("next", key))[:, : -(ms.n_steps)] - terminated = ~ms_tensordict.get("nonterminal") - assert ((next_obs == true_next_obs).all(-1) | terminated[:, (ms.n_steps) :]).all() + assert ms_tensordict.get("done").max() == 1 - # test gamma computation - torch.testing.assert_close( - ms_tensordict.get("gamma"), ms.gamma ** ms_tensordict.get("steps_to_next_obs") - ) + if n == 1: + assert_allclose_td( + tensordict, ms_tensordict.select(*list(tensordict.keys(True, True))) + ) + + # assert that done at last step is similar to unterminated traj + torch.testing.assert_close( + ms_tensordict.get("gamma")[4], ms_tensordict.get("gamma")[0] + ) + torch.testing.assert_close( + ms_tensordict.get(("next", key))[4], ms_tensordict.get(("next", key))[0] + ) + torch.testing.assert_close( + ms_tensordict.get("steps_to_next_obs")[4], + ms_tensordict.get("steps_to_next_obs")[0], + ) - # test reward - if n > 1: + # check that next obs is properly replaced, or that it is terminated + next_obs = ms_tensordict.get(key)[:, (ms.n_steps) :] + true_next_obs = ms_tensordict.get(("next", key))[:, : -(ms.n_steps)] + terminated = ~ms_tensordict.get("nonterminal") assert ( - ms_tensordict.get(("next", "reward")) - != ms_tensordict.get(("next", "original_reward")) - ).any() - else: + (next_obs == true_next_obs).all(-1) | terminated[:, (ms.n_steps) :] + ).all() + + # test gamma computation torch.testing.assert_close( - ms_tensordict.get(("next", "reward")), - ms_tensordict.get(("next", "original_reward")), + ms_tensordict.get("gamma"), + ms.gamma ** ms_tensordict.get("steps_to_next_obs"), ) + # test reward + if n > 1: + assert ( + ms_tensordict.get(("next", "reward")) + != ms_tensordict.get(("next", "original_reward")) + ).any() + else: + torch.testing.assert_close( + ms_tensordict.get(("next", "reward")), + ms_tensordict.get(("next", "original_reward")), + ) -@pytest.mark.parametrize("device", get_default_devices()) -@pytest.mark.parametrize( - "batch_size", - [ - [4], - [], - [1], - [2, 3], - ], -) -@pytest.mark.parametrize("T", [10, 1, 2]) -@pytest.mark.parametrize("obs_dim", [[1], []]) -@pytest.mark.parametrize("unsq_reward", [True, False]) -@pytest.mark.parametrize("last_done", [True, False]) -@pytest.mark.parametrize("n_steps", [4, 2, 1]) -def test_mutistep_cattrajs( - batch_size, T, obs_dim, unsq_reward, last_done, device, n_steps -): - # tests multi-step in the presence of consecutive trajectories. 
- obs = torch.randn(*batch_size, T + 1, *obs_dim) - reward = torch.rand(*batch_size, T) - action = torch.rand(*batch_size, T) - done = torch.zeros(*batch_size, T + 1, dtype=torch.bool) - done[..., T // 2] = 1 - if last_done: - done[..., -1] = 1 - if unsq_reward: - reward = reward.unsqueeze(-1) - done = done.unsqueeze(-1) - - td = TensorDict( - { - "obs": obs[..., :-1] if not obs_dim else obs[..., :-1, :], - "action": action, - "done": done[..., :-1] if not unsq_reward else done[..., :-1, :], - "next": { - "obs": obs[..., 1:] if not obs_dim else obs[..., 1:, :], - "done": done[..., 1:] if not unsq_reward else done[..., 1:, :], - "reward": reward, - }, - }, - batch_size=[*batch_size, T], - device=device, + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize( + "batch_size", + [ + [4], + [], + [1], + [2, 3], + ], ) - ms = MultiStep(0.98, n_steps) - tdm = ms(td) - if n_steps == 1: - # n_steps = 0 has no effect - for k in td["next"].keys(): - assert (tdm["next", k] == td["next", k]).all() - else: - next_obs = [] - obs = td["next", "obs"] - done = td["next", "done"] - if obs_dim: - obs = obs.squeeze(-1) + @pytest.mark.parametrize("T", [10, 1, 2]) + @pytest.mark.parametrize("obs_dim", [[1], []]) + @pytest.mark.parametrize("unsq_reward", [True, False]) + @pytest.mark.parametrize("last_done", [True, False]) + @pytest.mark.parametrize("n_steps", [4, 2, 1]) + def test_mutistep_cattrajs( + self, batch_size, T, obs_dim, unsq_reward, last_done, device, n_steps + ): + # tests multi-step in the presence of consecutive trajectories. + obs = torch.randn(*batch_size, T + 1, *obs_dim) + reward = torch.rand(*batch_size, T) + action = torch.rand(*batch_size, T) + done = torch.zeros(*batch_size, T + 1, dtype=torch.bool) + done[..., T // 2] = 1 + if last_done: + done[..., -1] = 1 if unsq_reward: - done = done.squeeze(-1) - for t in range(T): - idx = t + n_steps - 1 - while (done[..., t:idx].any() and idx > t) or idx > done.shape[-1] - 1: - idx = idx - 1 - next_obs.append(obs[..., idx]) - true_next_obs = tdm.get(("next", "obs")) - if obs_dim: - true_next_obs = true_next_obs.squeeze(-1) - next_obs = torch.stack(next_obs, -1) - assert (next_obs == true_next_obs).all() - - -@pytest.mark.parametrize("unsq_reward", [True, False]) -def test_unusual_done(unsq_reward): - batch_size = [10, 3] - T = 10 - obs_dim = [ - 1, - ] - last_done = True - device = torch.device("cpu") - n_steps = 3 - - obs = torch.randn(*batch_size, T + 1, 5, *obs_dim) - reward = torch.rand(*batch_size, T, 5) - action = torch.rand(*batch_size, T, 5) - done = torch.zeros(*batch_size, T + 1, 5, dtype=torch.bool) - done[..., T // 2, :] = 1 - if last_done: - done[..., -1, :] = 1 - if unsq_reward: - reward = reward.unsqueeze(-1) - done = done.unsqueeze(-1) - - td = TensorDict( - { - "obs": obs[..., :-1, :] if not obs_dim else obs[..., :-1, :, :], - "action": action, - "done": done[..., :-1, :] if not unsq_reward else done[..., :-1, :, :], - "next": { - "obs": obs[..., 1:, :] if not obs_dim else obs[..., 1:, :, :], - "done": done[..., 1:, :] if not unsq_reward else done[..., 1:, :, :], - "reward": reward, + reward = reward.unsqueeze(-1) + done = done.unsqueeze(-1) + + td = TensorDict( + { + "obs": obs[..., :-1] if not obs_dim else obs[..., :-1, :], + "action": action, + "done": done[..., :-1] if not unsq_reward else done[..., :-1, :], + "next": { + "obs": obs[..., 1:] if not obs_dim else obs[..., 1:, :], + "done": done[..., 1:] if not unsq_reward else done[..., 1:, :], + "reward": reward, + }, }, - }, - 
batch_size=[*batch_size, T], - device=device, - ) - ms = MultiStep(0.98, n_steps) - if unsq_reward: - with pytest.raises(RuntimeError, match="tensordict shape must be compatible"): + batch_size=[*batch_size, T], + device=device, + ) + ms = MultiStep(0.98, n_steps) + tdm = ms(td) + if n_steps == 1: + # n_steps = 0 has no effect + for k in td["next"].keys(): + assert (tdm["next", k] == td["next", k]).all() + else: + next_obs = [] + obs = td["next", "obs"] + done = td["next", "done"] + if obs_dim: + obs = obs.squeeze(-1) + if unsq_reward: + done = done.squeeze(-1) + for t in range(T): + idx = t + n_steps - 1 + while (done[..., t:idx].any() and idx > t) or idx > done.shape[-1] - 1: + idx = idx - 1 + next_obs.append(obs[..., idx]) + true_next_obs = tdm.get(("next", "obs")) + if obs_dim: + true_next_obs = true_next_obs.squeeze(-1) + next_obs = torch.stack(next_obs, -1) + assert (next_obs == true_next_obs).all() + + @pytest.mark.parametrize("unsq_reward", [True, False]) + def test_unusual_done(self, unsq_reward): + batch_size = [10, 3] + T = 10 + obs_dim = [ + 1, + ] + last_done = True + device = torch.device("cpu") + n_steps = 3 + + obs = torch.randn(*batch_size, T + 1, 5, *obs_dim) + reward = torch.rand(*batch_size, T, 5) + action = torch.rand(*batch_size, T, 5) + done = torch.zeros(*batch_size, T + 1, 5, dtype=torch.bool) + done[..., T // 2, :] = 1 + if last_done: + done[..., -1, :] = 1 + if unsq_reward: + reward = reward.unsqueeze(-1) + done = done.unsqueeze(-1) + + td = TensorDict( + { + "obs": obs[..., :-1, :] if not obs_dim else obs[..., :-1, :, :], + "action": action, + "done": done[..., :-1, :] if not unsq_reward else done[..., :-1, :, :], + "next": { + "obs": obs[..., 1:, :] if not obs_dim else obs[..., 1:, :, :], + "done": done[..., 1:, :] + if not unsq_reward + else done[..., 1:, :, :], + "reward": reward, + }, + }, + batch_size=[*batch_size, T], + device=device, + ) + ms = MultiStep(0.98, n_steps) + if unsq_reward: + with pytest.raises( + RuntimeError, match="tensordict shape must be compatible" + ): + _ = ms(td) + else: + # we just check that it runs _ = ms(td) - else: - # we just check that it runs - _ = ms(td) + + +class TestDensifyReward: + def test_densify_reward(self): + # Create a sample TensorDict + tensordict = TensorDict( + { + "next": { + "reward": torch.zeros(10, 1), + "done": torch.zeros(10, 1, dtype=torch.bool), + } + }, + batch_size=[10], + ) + # Set some done flags and rewards + tensordict["next", "done"][[3, 7]] = True + tensordict["next", "reward"][3] = 3 + tensordict["next", "reward"][7] = 7 + # Create an instance of LastRewardToTraj + last_reward_to_traj = DensifyReward() + # Apply the transform + new_tensordict = last_reward_to_traj(tensordict) + assert new_tensordict is tensordict + assert ( + new_tensordict["next", "reward"] + == torch.cat( + [ + torch.full((4, 1), 3.0), + torch.full((4, 1), 7.0), + torch.full((2, 1), 0.0), + ], + 0, + ) + ).all() class TestSplits: diff --git a/torchrl/data/__init__.py b/torchrl/data/__init__.py index 0083938530b..8ae832b257e 100644 --- a/torchrl/data/__init__.py +++ b/torchrl/data/__init__.py @@ -30,7 +30,7 @@ TensorMap, Tree, ) -from .postprocs import MultiStep +from .postprocs import DensifyReward, MultiStep from .replay_buffers import ( Flat2TED, FlatStorageCheckpointer, @@ -119,6 +119,7 @@ "CompositeSpec", "ConstantKLController", "DEVICE_TYPING", + "DensifyReward", "DiscreteTensorSpec", "Flat2TED", "FlatStorageCheckpointer", diff --git a/torchrl/data/postprocs/__init__.py b/torchrl/data/postprocs/__init__.py index 
afa0f73ecfb..4ec37500489 100644
--- a/torchrl/data/postprocs/__init__.py
+++ b/torchrl/data/postprocs/__init__.py
@@ -3,6 +3,6 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
-from .postprocs import MultiStep
+from .postprocs import DensifyReward, MultiStep
 
-__all__ = ["MultiStep"]
+__all__ = ["MultiStep", "DensifyReward"]
diff --git a/torchrl/data/postprocs/postprocs.py b/torchrl/data/postprocs/postprocs.py
index 7814c6cce14..1868deb2c12 100644
--- a/torchrl/data/postprocs/postprocs.py
+++ b/torchrl/data/postprocs/postprocs.py
@@ -6,10 +6,13 @@
 from __future__ import annotations
 
 import torch
-from tensordict import TensorDictBase
+from tensordict import NestedKey, TensorDictBase, unravel_key
+from tensordict.nn import TensorDictModuleBase
 from tensordict.utils import expand_right
 from torch import nn
 
+from torchrl.objectives.value.functional import reward2go
+
 
 def _get_reward(
     gamma: float,
@@ -291,3 +294,97 @@ def _multi_step_func(
     )
     tensordict.batch_size = tensordict.batch_size[:ndim]
     return tensordict
+
+
+class DensifyReward(TensorDictModuleBase):
+    """A utility that reassigns the reward collected at the done state to the rest of the trajectory.
+
+    This transform is meant for sparse rewards, where only the reward at `done` is non-null: it assigns a reward to
+    each step of the trajectory via a (discounted) reward-to-go computation.
+
+    .. note:: The class calls the :func:`~torchrl.objectives.value.functional.reward2go` function, which will
+        also sum intermediate rewards. Make sure you understand what the `reward2go` function returns before using
+        this module.
+
+    Args:
+        reward_key (NestedKey, optional): The key in the input TensorDict where the reward is stored.
+            Defaults to `"reward"`.
+        done_key (NestedKey, optional): The key in the input TensorDict where the done flag is stored.
+            Defaults to `"done"`.
+        reward_key_out (NestedKey | None, optional): The key in the output TensorDict where the reassigned reward
+            will be stored. If None, it defaults to the value of `reward_key`.
+            Defaults to `None`.
+        time_dim (int, optional): The dimension of the input TensorDict along which time is unrolled.
+            Defaults to `2`.
+        discount (float, optional): The discount factor to use for computing the discounted cumulative sum of rewards.
+            Defaults to `1.0` (no discounting).
+
+    Returns:
+        TensorDict: The input TensorDict with the reassigned reward stored under the key specified by `reward_key_out`.
+
+    Examples:
+        >>> import torch
+        >>> from tensordict import TensorDict
+        >>>
+        >>> from torchrl.data import DensifyReward
+        >>>
+        >>> # Create a sample TensorDict
+        >>> tensordict = TensorDict({
+        ...     "next": {
+        ...         "reward": torch.zeros(10, 1),
+        ...         "done": torch.zeros(10, 1, dtype=torch.bool)
+        ...     }
+        ... 
}, batch_size=[10])
+        >>> # Set some done flags and rewards
+        >>> tensordict["next", "done"][[3, 7]] = True
+        >>> tensordict["next", "reward"][3] = 3
+        >>> tensordict["next", "reward"][7] = 7
+        >>> # Create an instance of DensifyReward
+        >>> densify_reward = DensifyReward()
+        >>> # Apply the transform
+        >>> new_tensordict = densify_reward(tensordict)
+        >>> # Print the reassigned rewards
+        >>> print(new_tensordict["next", "reward"])
+        tensor([[3.],
+                [3.],
+                [3.],
+                [3.],
+                [7.],
+                [7.],
+                [7.],
+                [7.],
+                [0.],
+                [0.]])
+
+    """
+
+    def __init__(
+        self,
+        *,
+        reward_key: NestedKey = "reward",
+        done_key: NestedKey = "done",
+        reward_key_out: NestedKey | None = None,
+        time_dim: int = 2,
+        discount: float = 1.0,
+    ):
+        super().__init__()
+        self.in_keys = [unravel_key(reward_key), unravel_key(done_key)]
+        if reward_key_out is None:
+            reward_key_out = reward_key
+        self.out_keys = [unravel_key(reward_key_out)]
+        self.time_dim = time_dim
+        self.discount = discount
+
+    def forward(self, tensordict):
+        # Get the done flags from the "next" sub-tensordict
+        done = tensordict.get(("next", self.in_keys[1]))
+        # Get the (sparse) reward from the "next" sub-tensordict
+        reward = tensordict.get(("next", self.in_keys[0]))
+        if reward.shape != done.shape:
+            raise RuntimeError(
+                f"reward and done state are expected to have the same shape. Got reward.shape={reward.shape} "
+                f"and done.shape={done.shape}."
+            )
+        reward = reward2go(reward, done, time_dim=-2, gamma=self.discount)
+        tensordict.set(("next", self.out_keys[0]), reward)
+        return tensordict
diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py
index fa0ca17317c..4930393f4b5 100644
--- a/torchrl/envs/transforms/transforms.py
+++ b/torchrl/envs/transforms/transforms.py
@@ -10304,10 +10304,15 @@ class TrajCounter(Transform):
 
     """
 
-    def __init__(self, out_key: NestedKey = "traj_count"):
+    def __init__(
+        self, out_key: NestedKey = "traj_count", *, repeats: int | None = None
+    ):
         super().__init__(in_keys=[], out_keys=[out_key])
         self._make_shared_value()
         self._initialized = False
+        if repeats is None:
+            repeats = 0
+        self.repeats = repeats
 
     def _make_shared_value(self):
         self._traj_count = mp.Value("i", 0)
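
Usage sketch for the new `DensifyReward` post-processing module (illustration only, not part of the patch above): it mirrors the docstring example but uses a non-unit discount; the expected behaviour is inferred from the documented reward-to-go semantics, where each step receives the discounted sum of the future rewards of its own trajectory segment.

    import torch
    from tensordict import TensorDict

    from torchrl.data import DensifyReward

    # Sparse-reward rollout of 10 steps containing two trajectory segments
    # (done at steps 3 and 7), each rewarded only at its terminal step.
    data = TensorDict(
        {
            "next": {
                "reward": torch.zeros(10, 1),
                "done": torch.zeros(10, 1, dtype=torch.bool),
            }
        },
        batch_size=[10],
    )
    data["next", "done"][[3, 7]] = True
    data["next", "reward"][3] = 3.0
    data["next", "reward"][7] = 7.0

    # With a discount < 1, earlier steps receive a discounted share of the
    # terminal reward: step 0 gets 3 * 0.99**3, step 3 keeps 3, and the
    # trailing steps (8-9) stay at 0 since their segment has no reward.
    densify = DensifyReward(discount=0.99)
    data = densify(data)  # modifies the input in place and returns it
    print(data["next", "reward"])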