Skip to content

Commit

Permalink
[Feature] Remove the Nd*TensorSpec classes (pytorch#772)
Browse files (browse the repository at this point in the history)
  • Loading branch information
riiswa authored Dec 31, 2022
1 parent 578938a commit f6df86c
Show file tree
Hide file tree
Showing 31 changed files with 361 additions and 485 deletions.
4 changes: 2 additions & 2 deletions docs/source/reference/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ as shape, device, dtype and domain.
BoundedTensorSpec
OneHotDiscreteTensorSpec
UnboundedContinuousTensorSpec
NdBoundedTensorSpec
NdUnboundedContinuousTensorSpec
BoundedTensorSpec
UnboundedContinuousTensorSpec
BinaryDiscreteTensorSpec
MultOneHotDiscreteTensorSpec
DiscreteTensorSpec
Expand Down
6 changes: 3 additions & 3 deletions examples/dreamer/dreamer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass, field as dataclass_field
from typing import Any, Callable, Optional, Sequence, Union

from torchrl.data import NdUnboundedContinuousTensorSpec
from torchrl.data import UnboundedContinuousTensorSpec
from torchrl.envs import ParallelEnv
from torchrl.envs.common import EnvBase
from torchrl.envs.env_creator import env_creator, EnvCreator
Expand Down Expand Up @@ -125,8 +125,8 @@ def make_env_transforms(
)

default_dict = {
"state": NdUnboundedContinuousTensorSpec(cfg.state_dim),
"belief": NdUnboundedContinuousTensorSpec(cfg.rssm_hidden_dim),
"state": UnboundedContinuousTensorSpec(cfg.state_dim),
"belief": UnboundedContinuousTensorSpec(cfg.rssm_hidden_dim),
}
env.append_transform(
TensorDictPrimer(random=False, default_value=0, **default_dict)
Expand Down
66 changes: 27 additions & 39 deletions test/mocking_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
CompositeSpec,
DiscreteTensorSpec,
MultOneHotDiscreteTensorSpec,
NdBoundedTensorSpec,
NdUnboundedContinuousTensorSpec,
OneHotDiscreteTensorSpec,
UnboundedContinuousTensorSpec,
)
Expand All @@ -26,20 +24,16 @@
"one_hot": OneHotDiscreteTensorSpec,
"categorical": DiscreteTensorSpec,
"unbounded": UnboundedContinuousTensorSpec,
"ndbounded": NdBoundedTensorSpec,
"ndunbounded": NdUnboundedContinuousTensorSpec,
"binary": BinaryDiscreteTensorSpec,
"mult_one_hot": MultOneHotDiscreteTensorSpec,
"composite": CompositeSpec,
}

default_spec_kwargs = {
BoundedTensorSpec: {"minimum": -1.0, "maximum": 1.0},
OneHotDiscreteTensorSpec: {"n": 7},
DiscreteTensorSpec: {"n": 7},
UnboundedContinuousTensorSpec: {},
NdBoundedTensorSpec: {"minimum": -torch.ones(4), "maxmimum": torch.ones(4)},
NdUnboundedContinuousTensorSpec: {
BoundedTensorSpec: {"minimum": -torch.ones(4), "maximum": torch.ones(4)},
UnboundedContinuousTensorSpec: {
"shape": [
7,
]
Expand Down Expand Up @@ -114,13 +108,13 @@ def __new__(
**kwargs,
):
if action_spec is None:
action_spec = NdUnboundedContinuousTensorSpec((1,))
action_spec = UnboundedContinuousTensorSpec((1,))
if observation_spec is None:
observation_spec = CompositeSpec(
observation=NdUnboundedContinuousTensorSpec((1,))
observation=UnboundedContinuousTensorSpec((1,))
)
if reward_spec is None:
reward_spec = NdUnboundedContinuousTensorSpec((1,))
reward_spec = UnboundedContinuousTensorSpec((1,))
if input_spec is None:
input_spec = CompositeSpec(action=action_spec)
cls._reward_spec = reward_spec
Expand Down Expand Up @@ -175,18 +169,18 @@ def __new__(
**kwargs,
):
if action_spec is None:
action_spec = NdUnboundedContinuousTensorSpec((1,))
action_spec = UnboundedContinuousTensorSpec((1,))
if input_spec is None:
input_spec = CompositeSpec(
action=action_spec,
observation=NdUnboundedContinuousTensorSpec((1,)),
observation=UnboundedContinuousTensorSpec((1,)),
)
if observation_spec is None:
observation_spec = CompositeSpec(
observation=NdUnboundedContinuousTensorSpec((1,))
observation=UnboundedContinuousTensorSpec((1,))
)
if reward_spec is None:
reward_spec = NdUnboundedContinuousTensorSpec((1,))
reward_spec = UnboundedContinuousTensorSpec((1,))
cls._reward_spec = reward_spec
cls._observation_spec = observation_spec
cls._input_spec = input_spec
Expand Down Expand Up @@ -283,8 +277,8 @@ def __new__(
if observation_spec is None:
cls.out_key = "observation"
observation_spec = CompositeSpec(
observation=NdUnboundedContinuousTensorSpec(shape=torch.Size([size])),
observation_orig=NdUnboundedContinuousTensorSpec(
observation=UnboundedContinuousTensorSpec(shape=torch.Size([size])),
observation_orig=UnboundedContinuousTensorSpec(
shape=torch.Size([size])
),
)
Expand Down Expand Up @@ -370,13 +364,13 @@ def __new__(
if observation_spec is None:
cls.out_key = "observation"
observation_spec = CompositeSpec(
observation=NdUnboundedContinuousTensorSpec(shape=torch.Size([size])),
observation_orig=NdUnboundedContinuousTensorSpec(
observation=UnboundedContinuousTensorSpec(shape=torch.Size([size])),
observation_orig=UnboundedContinuousTensorSpec(
shape=torch.Size([size])
),
)
if action_spec is None:
action_spec = NdBoundedTensorSpec(-1, 1, (7,))
action_spec = BoundedTensorSpec(-1, 1, (7,))
if reward_spec is None:
reward_spec = UnboundedContinuousTensorSpec()

Expand Down Expand Up @@ -471,10 +465,8 @@ def __new__(
if observation_spec is None:
cls.out_key = "pixels"
observation_spec = CompositeSpec(
pixels=NdUnboundedContinuousTensorSpec(shape=torch.Size([1, 7, 7])),
pixels_orig=NdUnboundedContinuousTensorSpec(
shape=torch.Size([1, 7, 7])
),
pixels=UnboundedContinuousTensorSpec(shape=torch.Size([1, 7, 7])),
pixels_orig=UnboundedContinuousTensorSpec(shape=torch.Size([1, 7, 7])),
)
if action_spec is None:
action_spec = OneHotDiscreteTensorSpec(7)
Expand Down Expand Up @@ -523,10 +515,8 @@ def __new__(
if observation_spec is None:
cls.out_key = "pixels"
observation_spec = CompositeSpec(
pixels=NdUnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3])),
pixels_orig=NdUnboundedContinuousTensorSpec(
shape=torch.Size([7, 7, 3])
),
pixels=UnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3])),
pixels_orig=UnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3])),
)
if action_spec is None:
action_spec_cls = (
Expand Down Expand Up @@ -585,14 +575,14 @@ def __new__(
if observation_spec is None:
cls.out_key = "pixels"
observation_spec = CompositeSpec(
pixels=NdUnboundedContinuousTensorSpec(shape=torch.Size(pixel_shape)),
pixels_orig=NdUnboundedContinuousTensorSpec(
pixels=UnboundedContinuousTensorSpec(shape=torch.Size(pixel_shape)),
pixels_orig=UnboundedContinuousTensorSpec(
shape=torch.Size(pixel_shape)
),
)

if action_spec is None:
action_spec = NdBoundedTensorSpec(-1, 1, pixel_shape[-1])
action_spec = BoundedTensorSpec(-1, 1, pixel_shape[-1])

if reward_spec is None:
reward_spec = UnboundedContinuousTensorSpec()
Expand Down Expand Up @@ -634,10 +624,8 @@ def __new__(
if observation_spec is None:
cls.out_key = "pixels"
observation_spec = CompositeSpec(
pixels=NdUnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3])),
pixels_orig=NdUnboundedContinuousTensorSpec(
shape=torch.Size([7, 7, 3])
),
pixels=UnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3])),
pixels_orig=UnboundedContinuousTensorSpec(shape=torch.Size([7, 7, 3])),
)
return super().__new__(
*args,
Expand Down Expand Up @@ -696,13 +684,13 @@ def __init__(
batch_size=batch_size,
)
self.observation_spec = CompositeSpec(
hidden_observation=NdUnboundedContinuousTensorSpec((4,))
hidden_observation=UnboundedContinuousTensorSpec((4,))
)
self.input_spec = CompositeSpec(
hidden_observation=NdUnboundedContinuousTensorSpec((4,)),
action=NdUnboundedContinuousTensorSpec((1,)),
hidden_observation=UnboundedContinuousTensorSpec((4,)),
action=UnboundedContinuousTensorSpec((1,)),
)
self.reward_spec = NdUnboundedContinuousTensorSpec((1,))
self.reward_spec = UnboundedContinuousTensorSpec((1,))

def _reset(self, tensordict: TensorDict, **kwargs) -> TensorDict:
td = TensorDict(
Expand Down
8 changes: 2 additions & 6 deletions test/test_collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,7 @@
RandomPolicy,
)
from torchrl.collectors.utils import split_trajectories
from torchrl.data import (
CompositeSpec,
NdUnboundedContinuousTensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.data import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs import EnvCreator, ParallelEnv, SerialEnv
from torchrl.envs.libs.gym import _has_gym, GymEnv
from torchrl.envs.transforms import TransformedEnv, VecNorm
Expand Down Expand Up @@ -942,7 +938,7 @@ def test_collector_output_keys(collector_class, init_random_frames, explicit_spe
],
}
if explicit_spec:
hidden_spec = NdUnboundedContinuousTensorSpec((1, hidden_size))
hidden_spec = UnboundedContinuousTensorSpec((1, hidden_size))
policy_kwargs["spec"] = CompositeSpec(
action=UnboundedContinuousTensorSpec(),
hidden1=hidden_spec,
Expand Down
32 changes: 16 additions & 16 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,12 @@
from tensordict.tensordict import assert_allclose_td, TensorDict
from torch import autograd, nn
from torchrl.data import (
BoundedTensorSpec,
CompositeSpec,
DiscreteTensorSpec,
MultOneHotDiscreteTensorSpec,
NdBoundedTensorSpec,
NdUnboundedContinuousTensorSpec,
OneHotDiscreteTensorSpec,
UnboundedContinuousTensorSpec,
)
from torchrl.data.postprocs.postprocs import MultiStep
from torchrl.envs.model_based.dreamer import DreamerEnv
Expand Down Expand Up @@ -130,7 +130,7 @@ def _create_mock_actor(
elif action_spec_type == "categorical":
action_spec = DiscreteTensorSpec(action_dim)
elif action_spec_type == "nd_bounded":
action_spec = NdBoundedTensorSpec(
action_spec = BoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
else:
Expand Down Expand Up @@ -417,7 +417,7 @@ class TestDDPG:

def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
# Actor
action_spec = NdBoundedTensorSpec(
action_spec = BoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
module = nn.Linear(obs_dim, action_dim)
Expand Down Expand Up @@ -647,7 +647,7 @@ class TestSAC:

def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
# Actor
action_spec = NdBoundedTensorSpec(
action_spec = BoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
Expand Down Expand Up @@ -1026,7 +1026,7 @@ class TestREDQ:

def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
# Actor
action_spec = NdBoundedTensorSpec(
action_spec = BoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
Expand Down Expand Up @@ -1481,7 +1481,7 @@ class TestPPO:

def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
# Actor
action_spec = NdBoundedTensorSpec(
action_spec = BoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
Expand All @@ -1504,7 +1504,7 @@ def _create_mock_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):

def _create_mock_actor_value(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
# Actor
action_spec = NdBoundedTensorSpec(
action_spec = BoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
base_layer = nn.Linear(obs_dim, 5)
Expand Down Expand Up @@ -1808,7 +1808,7 @@ class TestA2C:

def _create_mock_actor(self, batch=2, obs_dim=3, action_dim=4, device="cpu"):
# Actor
action_spec = NdBoundedTensorSpec(
action_spec = BoundedTensorSpec(
-torch.ones(action_dim), torch.ones(action_dim), (action_dim,)
)
net = NormalParamWrapper(nn.Linear(obs_dim, 2 * action_dim))
Expand Down Expand Up @@ -2022,7 +2022,7 @@ def test_reinforce_value_net(self, advantage, gradient_mode, delay_value):
distribution_class=TanhNormal,
return_log_prob=True,
in_keys=["loc", "scale"],
spec=NdUnboundedContinuousTensorSpec(n_act),
spec=UnboundedContinuousTensorSpec(n_act),
)
if advantage == "gae":
advantage = GAE(
Expand Down Expand Up @@ -2146,8 +2146,8 @@ def _create_value_data(
def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200):
mock_env = TransformedEnv(ContinuousActionConvMockEnv(pixel_shape=[3, 64, 64]))
default_dict = {
"state": NdUnboundedContinuousTensorSpec(state_dim),
"belief": NdUnboundedContinuousTensorSpec(rssm_hidden_dim),
"state": UnboundedContinuousTensorSpec(state_dim),
"belief": UnboundedContinuousTensorSpec(rssm_hidden_dim),
}
mock_env.append_transform(
TensorDictPrimer(random=False, default_value=0, **default_dict)
Expand Down Expand Up @@ -2221,8 +2221,8 @@ def _create_world_model_model(self, rssm_hidden_dim, state_dim, mlp_num_units=20
def _create_mb_env(self, rssm_hidden_dim, state_dim, mlp_num_units=200):
mock_env = TransformedEnv(ContinuousActionConvMockEnv(pixel_shape=[3, 64, 64]))
default_dict = {
"state": NdUnboundedContinuousTensorSpec(state_dim),
"belief": NdUnboundedContinuousTensorSpec(rssm_hidden_dim),
"state": UnboundedContinuousTensorSpec(state_dim),
"belief": UnboundedContinuousTensorSpec(rssm_hidden_dim),
}
mock_env.append_transform(
TensorDictPrimer(random=False, default_value=0, **default_dict)
Expand Down Expand Up @@ -2270,8 +2270,8 @@ def _create_mb_env(self, rssm_hidden_dim, state_dim, mlp_num_units=200):
def _create_actor_model(self, rssm_hidden_dim, state_dim, mlp_num_units=200):
mock_env = TransformedEnv(ContinuousActionConvMockEnv(pixel_shape=[3, 64, 64]))
default_dict = {
"state": NdUnboundedContinuousTensorSpec(state_dim),
"belief": NdUnboundedContinuousTensorSpec(rssm_hidden_dim),
"state": UnboundedContinuousTensorSpec(state_dim),
"belief": UnboundedContinuousTensorSpec(rssm_hidden_dim),
}
mock_env.append_transform(
TensorDictPrimer(random=False, default_value=0, **default_dict)
Expand Down
10 changes: 5 additions & 5 deletions test/test_exploration.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from scipy.stats import ttest_1samp
from tensordict.tensordict import TensorDict
from torch import nn
from torchrl.data import CompositeSpec, NdBoundedTensorSpec
from torchrl.data import BoundedTensorSpec, CompositeSpec
from torchrl.envs.transforms.transforms import gSDENoise
from torchrl.envs.utils import set_exploration_mode
from torchrl.modules import SafeModule, SafeSequential
Expand Down Expand Up @@ -61,7 +61,7 @@ def test_ou_wrapper(device, d_obs=4, d_act=6, batch=32, n_steps=100, seed=0):
torch.manual_seed(seed)
net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device)
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
action_spec = NdBoundedTensorSpec(-torch.ones(d_act), torch.ones(d_act), (d_act,))
action_spec = BoundedTensorSpec(-torch.ones(d_act), torch.ones(d_act), (d_act,))
policy = ProbabilisticActor(
spec=action_spec,
module=module,
Expand Down Expand Up @@ -106,7 +106,7 @@ def test_additivegaussian_sd(
):
torch.manual_seed(seed)
net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device)
action_spec = NdBoundedTensorSpec(
action_spec = BoundedTensorSpec(
-torch.ones(d_act, device=device),
torch.ones(d_act, device=device),
(d_act,),
Expand Down Expand Up @@ -173,7 +173,7 @@ def test_additivegaussian_wrapper(
torch.manual_seed(seed)
net = NormalParamWrapper(nn.Linear(d_obs, 2 * d_act)).to(device)
module = SafeModule(net, in_keys=["observation"], out_keys=["loc", "scale"])
action_spec = NdBoundedTensorSpec(
action_spec = BoundedTensorSpec(
-torch.ones(d_act, device=device),
torch.ones(d_act, device=device),
(d_act,),
Expand Down Expand Up @@ -244,7 +244,7 @@ def test_gsde(
module = SafeModule(wrapper, in_keys=in_keys, out_keys=["loc", "scale"])
distribution_class = TanhNormal
distribution_kwargs = {"min": -bound, "max": bound}
spec = NdBoundedTensorSpec(
spec = BoundedTensorSpec(
-torch.ones(action_dim) * bound, torch.ones(action_dim) * bound, (action_dim,)
).to(device)

Expand Down
Loading

0 comments on commit f6df86c

Please sign in to comment.