Skip to content

Commit

Permalink
Add MultiDiscrete.__setstate__ for new start attribute (Farama-Fo…
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts authored Jul 3, 2023
1 parent 96d05af commit f399319
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 27 deletions.
19 changes: 17 additions & 2 deletions gymnasium/spaces/multi_discrete.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Implementation of a space that represents the cartesian product of `Discrete` spaces."""
from __future__ import annotations

from typing import Any, Sequence
from typing import Any, Iterable, Mapping, Sequence

import numpy as np
from numpy.typing import NDArray
Expand Down Expand Up @@ -180,7 +180,7 @@ def __repr__(self):
return f"MultiDiscrete({self.nvec}, start={self.start})"
return f"MultiDiscrete({self.nvec})"

def __getitem__(self, index: int):
def __getitem__(self, index: int | tuple[int, ...]):
"""Extract a subspace from this ``MultiDiscrete`` space."""
nvec = self.nvec[index]
start = self.start[index]
Expand Down Expand Up @@ -209,3 +209,18 @@ def __eq__(self, other: Any) -> bool:
and np.all(self.nvec == other.nvec)
and np.all(self.start == other.start)
)

def __setstate__(self, state: Iterable[tuple[str, Any]] | Mapping[str, Any]):
"""Used when loading a pickled space.
This method has to be implemented explicitly to allow for loading of legacy states.
Args:
state: The new state
"""
state = dict(state)

if "start" not in state:
state["start"] = np.zeros(state["_shape"], dtype=state["dtype"])

super().__setstate__(state)
44 changes: 19 additions & 25 deletions tests/spaces/test_discrete.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,29 @@
from copy import deepcopy

import numpy as np

from gymnasium.spaces import Discrete


def test_space_legacy_pickling():
"""Test the legacy pickle of Discrete that is missing the `start` parameter."""
legacy_state = {
"shape": (
1,
2,
3,
),
"dtype": np.int64,
"np_random": np.random.default_rng(),
"n": 3,
}
space = Discrete(1)
space.__setstate__(legacy_state)

assert space.shape == legacy_state["shape"]
assert space.np_random == legacy_state["np_random"]
assert space.n == 3
assert space.dtype == legacy_state["dtype"]

# Test that start is missing
assert "start" in space.__dict__
del space.__dict__["start"] # legacy did not include start param
assert "start" not in space.__dict__

space.__setstate__(legacy_state)
assert space.start == 0
# Test that start is corrected passed
space = Discrete(1, start=2)
state = space.__dict__

new_space = Discrete(1)
new_space.__setstate__(state)
assert new_space == space
assert new_space.start == 2

legacy_space = Discrete(1)
legacy_state = deepcopy(legacy_space.__dict__)
del legacy_state["start"]

new_legacy_space = Discrete(2)
new_legacy_space.__setstate__(legacy_state)
assert new_legacy_space == legacy_space
assert new_legacy_space.start == 0


def test_sample_mask():
Expand Down
23 changes: 23 additions & 0 deletions tests/spaces/test_multidiscrete.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from copy import deepcopy

import numpy as np
import pytest

Expand Down Expand Up @@ -153,3 +155,24 @@ def test_multidiscrete_start_contains():

assert [12, 23, 34] in space
assert [13, 23, 34] not in space


def test_space_legacy_pickling():
"""Test the legacy pickle of Discrete that is missing the `start` parameter."""
# Test that start is corrected passed
space = MultiDiscrete([1, 2, 3], start=[4, 5, 6])
state = space.__dict__

new_space = MultiDiscrete([1, 2, 3])
new_space.__setstate__(state)
assert new_space == space
assert np.all(new_space.start == np.array([4, 5, 6]))

legacy_space = MultiDiscrete([1, 2, 3])
legacy_state = deepcopy(legacy_space.__dict__)
del legacy_state["start"]

new_legacy_space = MultiDiscrete([1, 2, 3])
new_legacy_space.__setstate__(legacy_state)
assert new_legacy_space == legacy_space
assert np.all(new_legacy_space.start == np.array([0, 0, 0]))

0 comments on commit f399319

Please sign in to comment.