From 771ef814f98b30dbe0e1b7acb2625a0bf16a1e08 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 9 Jul 2023 14:04:11 +0100 Subject: [PATCH] [BugFix] Fix TD3 and compat with https://github.com/pytorch-labs/tensordict/pull/482 (#1375) --- benchmarks/test_objectives_benchmarks.py | 2 +- test/test_cost.py | 8 +++++-- torchrl/envs/utils.py | 27 ++++-------------------- torchrl/objectives/redq.py | 2 +- torchrl/objectives/td3.py | 24 +++++++++++---------- 5 files changed, 25 insertions(+), 38 deletions(-) diff --git a/benchmarks/test_objectives_benchmarks.py b/benchmarks/test_objectives_benchmarks.py index 308720e8a5f..4fd365afe37 100644 --- a/benchmarks/test_objectives_benchmarks.py +++ b/benchmarks/test_objectives_benchmarks.py @@ -439,7 +439,7 @@ def test_td3_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden= ) loss(td) - benchmark(loss, td) + benchmark.pedantic(loss, args=(td,), rounds=100, iterations=10) def test_cql_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64): diff --git a/test/test_cost.py b/test/test_cost.py index 4c55527717b..d087e7073bf 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -1771,7 +1771,6 @@ def test_constructor(self, spec, bounds): @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) @pytest.mark.parametrize("done_key", ["done", "done2"]) def test_td3_notensordict(self, observation_key, reward_key, done_key): - torch.manual_seed(self.seed) actor = self._create_mock_actor(in_keys=[observation_key]) qvalue = self._create_mock_value( @@ -1793,12 +1792,15 @@ def test_td3_notensordict(self, observation_key, reward_key, done_key): td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") with pytest.warns(UserWarning, match="No target network updater has been"): + torch.manual_seed(0) loss_val_td = loss(td) + torch.manual_seed(0) loss_val = loss(**kwargs) for i, key in enumerate(loss_val_td.keys()): torch.testing.assert_close(loss_val_td.get(key), loss_val[i]) # test select loss.select_out_keys("loss_actor", "loss_qvalue") + torch.manual_seed(0) if torch.__version__ >= "2.0.0": loss_actor, loss_qvalue = loss(**kwargs) else: @@ -3646,6 +3648,7 @@ def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est): def test_redq_batcher(self, n, delay_qvalue, num_qvalue, device, gamma=0.9): torch.manual_seed(self.seed) td = self._create_seq_mock_data_redq(device=device) + assert td.names == td.get("next").names actor = self._create_mock_actor(device=device) qvalue = self._create_mock_qvalue(device=device) @@ -3661,7 +3664,9 @@ def test_redq_batcher(self, n, delay_qvalue, num_qvalue, device, gamma=0.9): ms = MultiStep(gamma=gamma, n_steps=n).to(device) td_clone = td.clone() + assert td_clone.names == td_clone.get("next").names ms_td = ms(td_clone) + assert ms_td.names == ms_td.get("next").names torch.manual_seed(0) np.random.seed(0) @@ -5506,7 +5511,6 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est retain_graph=True, allow_unused=False, ) - print(advantage, gradient_mode, delay_value, td_est) @pytest.mark.parametrize( "td_est", diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index d9ac341ca9d..45205c79ab9 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -4,8 +4,6 @@ # LICENSE file in the root directory of this source tree. 
 from __future__ import annotations
 
-from copy import copy
-
 import pkg_resources
 
 import torch
 
@@ -20,12 +18,7 @@
     set_interaction_mode as set_exploration_mode,
     set_interaction_type as set_exploration_type,
 )
-from tensordict.tensordict import (
-    LazyStackedTensorDict,
-    NestedKey,
-    TensorDict,
-    TensorDictBase,
-)
+from tensordict.tensordict import LazyStackedTensorDict, NestedKey, TensorDictBase
 
 __all__ = [
     "exploration_mode",
@@ -204,7 +197,7 @@ def step_mdp(
     if exclude_action:
         excluded = excluded.union({action_key})
     next_td = tensordict.get("next")
-    out = _clone_no_keys(next_td)
+    out = next_td.empty()
 
     total_key = ()
     if keep_other:
@@ -230,7 +223,7 @@ def _set_single_key(source, dest, key, clone=False):
         if is_tensor_collection(val):
             new_val = dest.get(k, None)
             if new_val is None:
-                new_val = _clone_no_keys(val)
+                new_val = val.empty()
                 # dest.set(k, new_val)
                 dest._set_str(k, new_val, inplace=False, validated=True)
             source = val
@@ -250,7 +243,7 @@ def _set(source, dest, key, total_key, excluded):
         if is_tensor_collection(val):
             new_val = dest.get(key, None)
             if new_val is None:
-                new_val = _clone_no_keys(val)
+                new_val = val.empty()
             non_empty_local = False
             for subkey in val.keys():
                 non_empty_local = (
@@ -267,18 +260,6 @@ def _set(source, dest, key, total_key, excluded):
     return non_empty
 
 
-def _clone_no_keys(td):
-    return TensorDict(
-        source={},
-        batch_size=td.batch_size,
-        device=td.device,
-        names=copy(td._td_dim_names),
-        _run_checks=False,
-        _is_shared=td.is_shared(),
-        _is_memmap=td.is_memmap(),
-    )
-
-
 def get_available_libraries():
     """Returns all the supported libraries."""
     return SUPPORTED_LIBRARIES
diff --git a/torchrl/objectives/redq.py b/torchrl/objectives/redq.py
index 0429cac96b9..039b5c65b9d 100644
--- a/torchrl/objectives/redq.py
+++ b/torchrl/objectives/redq.py
@@ -437,7 +437,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             *self.actor_network.in_keys
         )  # next_observation ->
         tensordict_actor = torch.stack([tensordict_actor_grad, next_td_actor], 0)
-        tensordict_actor = tensordict_actor.contiguous()
+        # tensordict_actor = tensordict_actor.contiguous()
 
         with set_exploration_type(ExplorationType.RANDOM):
             if self.gSDE:
diff --git a/torchrl/objectives/td3.py b/torchrl/objectives/td3.py
index 78445593f4d..62f0e793f29 100644
--- a/torchrl/objectives/td3.py
+++ b/torchrl/objectives/td3.py
@@ -343,6 +343,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
         obs_keys = self.actor_network.in_keys
         tensordict_save = tensordict
         tensordict = tensordict.clone(False)
+        act = tensordict.get(self.tensor_keys.action)
+        action_shape = act.shape
+        action_device = act.device
+        # computing the noise early for reproducibility
+        noise = torch.normal(
+            mean=torch.zeros(action_shape),
+            std=torch.full(action_shape, self.policy_noise),
+        ).to(action_device)
+        noise = noise.clamp(-self.noise_clip, self.noise_clip)
 
         tensordict_actor_grad = tensordict.select(
             *obs_keys
@@ -351,24 +360,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
             *self.actor_network.in_keys
         )  # next_observation ->
         tensordict_actor = torch.stack([tensordict_actor_grad, next_td_actor], 0)
-        tensordict_actor = tensordict_actor.contiguous()
-
+        # DO NOT call contiguous because we'll update the tensordicts later
        actor_output_td = self._vmap_actor_network00(
            tensordict_actor,
            self._cached_stack_actor_params,
        )
         # add noise to target policy
-        action = actor_output_td[1].get(self.tensor_keys.action)
-        noise = torch.normal(
-            mean=torch.zeros(action.shape),
-            std=torch.full(action.shape, self.policy_noise),
-        ).to(action.device)
-        noise = noise.clamp(-self.noise_clip, self.noise_clip)
-
-        next_action = (actor_output_td[1][self.tensor_keys.action] + noise).clamp(
+        actor_output_td1 = actor_output_td[1]
+        next_action = (actor_output_td1.get(self.tensor_keys.action) + noise).clamp(
             self.min_action, self.max_action
         )
-        actor_output_td[1].set(self.tensor_keys.action, next_action)
+        actor_output_td1.set(self.tensor_keys.action, next_action)
         tensordict_actor.set(
             self.tensor_keys.action,
             actor_output_td.get(self.tensor_keys.action),