Skip to content

Commit

Permalink
[BugFix] Fix TD3 and compat with pytorch/tensordict#482 (pytorch#1375)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Jul 9, 2023
1 parent 7c5ba3d commit 771ef81
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 38 deletions.
2 changes: 1 addition & 1 deletion benchmarks/test_objectives_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,7 +439,7 @@ def test_td3_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=
)

loss(td)
benchmark(loss, td)
benchmark.pedantic(loss, args=(td,), rounds=100, iterations=10)


def test_cql_speed(benchmark, n_obs=8, n_act=4, ncells=128, batch=128, n_hidden=64):
Expand Down
8 changes: 6 additions & 2 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -1771,7 +1771,6 @@ def test_constructor(self, spec, bounds):
@pytest.mark.parametrize("reward_key", ["reward", "reward2"])
@pytest.mark.parametrize("done_key", ["done", "done2"])
def test_td3_notensordict(self, observation_key, reward_key, done_key):

torch.manual_seed(self.seed)
actor = self._create_mock_actor(in_keys=[observation_key])
qvalue = self._create_mock_value(
Expand All @@ -1793,12 +1792,15 @@ def test_td3_notensordict(self, observation_key, reward_key, done_key):
td = TensorDict(kwargs, td.batch_size).unflatten_keys("_")

with pytest.warns(UserWarning, match="No target network updater has been"):
torch.manual_seed(0)
loss_val_td = loss(td)
torch.manual_seed(0)
loss_val = loss(**kwargs)
for i, key in enumerate(loss_val_td.keys()):
torch.testing.assert_close(loss_val_td.get(key), loss_val[i])
# test select
loss.select_out_keys("loss_actor", "loss_qvalue")
torch.manual_seed(0)
if torch.__version__ >= "2.0.0":
loss_actor, loss_qvalue = loss(**kwargs)
else:
Expand Down Expand Up @@ -3646,6 +3648,7 @@ def test_redq_batched(self, delay_qvalue, num_qvalue, device, td_est):
def test_redq_batcher(self, n, delay_qvalue, num_qvalue, device, gamma=0.9):
torch.manual_seed(self.seed)
td = self._create_seq_mock_data_redq(device=device)
assert td.names == td.get("next").names

actor = self._create_mock_actor(device=device)
qvalue = self._create_mock_qvalue(device=device)
Expand All @@ -3661,7 +3664,9 @@ def test_redq_batcher(self, n, delay_qvalue, num_qvalue, device, gamma=0.9):
ms = MultiStep(gamma=gamma, n_steps=n).to(device)

td_clone = td.clone()
assert td_clone.names == td_clone.get("next").names
ms_td = ms(td_clone)
assert ms_td.names == ms_td.get("next").names

torch.manual_seed(0)
np.random.seed(0)
Expand Down Expand Up @@ -5506,7 +5511,6 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value, td_est
retain_graph=True,
allow_unused=False,
)
print(advantage, gradient_mode, delay_value, td_est)

@pytest.mark.parametrize(
"td_est",
Expand Down
27 changes: 4 additions & 23 deletions torchrl/envs/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from copy import copy

import pkg_resources
import torch

Expand All @@ -20,12 +18,7 @@
set_interaction_mode as set_exploration_mode,
set_interaction_type as set_exploration_type,
)
from tensordict.tensordict import (
LazyStackedTensorDict,
NestedKey,
TensorDict,
TensorDictBase,
)
from tensordict.tensordict import LazyStackedTensorDict, NestedKey, TensorDictBase

__all__ = [
"exploration_mode",
Expand Down Expand Up @@ -204,7 +197,7 @@ def step_mdp(
if exclude_action:
excluded = excluded.union({action_key})
next_td = tensordict.get("next")
out = _clone_no_keys(next_td)
out = next_td.empty()

total_key = ()
if keep_other:
Expand All @@ -230,7 +223,7 @@ def _set_single_key(source, dest, key, clone=False):
if is_tensor_collection(val):
new_val = dest.get(k, None)
if new_val is None:
new_val = _clone_no_keys(val)
new_val = val.empty()
# dest.set(k, new_val)
dest._set_str(k, new_val, inplace=False, validated=True)
source = val
Expand All @@ -250,7 +243,7 @@ def _set(source, dest, key, total_key, excluded):
if is_tensor_collection(val):
new_val = dest.get(key, None)
if new_val is None:
new_val = _clone_no_keys(val)
new_val = val.empty()
non_empty_local = False
for subkey in val.keys():
non_empty_local = (
Expand All @@ -267,18 +260,6 @@ def _set(source, dest, key, total_key, excluded):
return non_empty


def _clone_no_keys(td):
return TensorDict(
source={},
batch_size=td.batch_size,
device=td.device,
names=copy(td._td_dim_names),
_run_checks=False,
_is_shared=td.is_shared(),
_is_memmap=td.is_memmap(),
)


def get_available_libraries():
"""Returns all the supported libraries."""
return SUPPORTED_LIBRARIES
Expand Down
2 changes: 1 addition & 1 deletion torchrl/objectives/redq.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
*self.actor_network.in_keys
) # next_observation ->
tensordict_actor = torch.stack([tensordict_actor_grad, next_td_actor], 0)
tensordict_actor = tensordict_actor.contiguous()
# tensordict_actor = tensordict_actor.contiguous()

with set_exploration_type(ExplorationType.RANDOM):
if self.gSDE:
Expand Down
24 changes: 13 additions & 11 deletions torchrl/objectives/td3.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
obs_keys = self.actor_network.in_keys
tensordict_save = tensordict
tensordict = tensordict.clone(False)
act = tensordict.get(self.tensor_keys.action)
action_shape = act.shape
action_device = act.device
# computing early for reprod
noise = torch.normal(
mean=torch.zeros(action_shape),
std=torch.full(action_shape, self.policy_noise),
).to(action_device)
noise = noise.clamp(-self.noise_clip, self.noise_clip)

tensordict_actor_grad = tensordict.select(
*obs_keys
Expand All @@ -351,24 +360,17 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
*self.actor_network.in_keys
) # next_observation ->
tensordict_actor = torch.stack([tensordict_actor_grad, next_td_actor], 0)
tensordict_actor = tensordict_actor.contiguous()

# DO NOT call contiguous bc we'll update the tds later
actor_output_td = self._vmap_actor_network00(
tensordict_actor,
self._cached_stack_actor_params,
)
# add noise to target policy
action = actor_output_td[1].get(self.tensor_keys.action)
noise = torch.normal(
mean=torch.zeros(action.shape),
std=torch.full(action.shape, self.policy_noise),
).to(action.device)
noise = noise.clamp(-self.noise_clip, self.noise_clip)

next_action = (actor_output_td[1][self.tensor_keys.action] + noise).clamp(
actor_output_td1 = actor_output_td[1]
next_action = (actor_output_td1.get(self.tensor_keys.action) + noise).clamp(
self.min_action, self.max_action
)
actor_output_td[1].set(self.tensor_keys.action, next_action)
actor_output_td1.set(self.tensor_keys.action, next_action)
tensordict_actor.set(
self.tensor_keys.action,
actor_output_td.get(self.tensor_keys.action),
Expand Down

0 comments on commit 771ef81

Please sign in to comment.