Skip to content

Commit

Permalink
[BugFix] Pickable buffer (pytorch#1410)
Browse files Browse the repository at this point in the history
  • Loading branch information
albertbou92 authored Jul 26, 2023
1 parent 3558061 commit 3f4e9aa
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 0 deletions.
12 changes: 12 additions & 0 deletions test/test_rb.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import argparse
import importlib
import pickle
import sys
from functools import partial
from unittest import mock
Expand Down Expand Up @@ -242,6 +243,17 @@ def test_index(self, rb_type, sampler, writer, storage, size):
b = b.all()
assert b

def test_pickable(self, rb_type, sampler, writer, storage, size):

rb = self._get_rb(
rb_type=rb_type, sampler=sampler, writer=writer, storage=storage, size=size
)
serialized = pickle.dumps(rb)
rb2 = pickle.loads(serialized)
assert rb.__dict__.keys() == rb2.__dict__.keys()
for key in sorted(rb.__dict__.keys()):
assert isinstance(rb.__dict__[key], type(rb2.__dict__[key]))


@pytest.mark.parametrize("storage_type", [TensorStorage])
class TestStorages:
Expand Down
21 changes: 21 additions & 0 deletions torchrl/data/replay_buffers/replay_buffers.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,27 @@ def __iter__(self):
data = self.sample()
yield data

def __getstate__(self) -> Dict[str, Any]:
state = self.__dict__.copy()
_replay_lock = state.pop("_replay_lock", None)
_futures_lock = state.pop("_futures_lock", None)
if _replay_lock is not None:
state["_replay_lock_placeholder"] = None
if _futures_lock is not None:
state["_futures_lock_placeholder"] = None
return state

def __setstate__(self, state: Dict[str, Any]):
if "_replay_lock_placeholder" in state:
state.pop("_replay_lock_placeholder")
_replay_lock = threading.RLock()
state["_replay_lock"] = _replay_lock
if "_futures_lock_placeholder" in state:
state.pop("_futures_lock_placeholder")
_futures_lock = threading.RLock()
state["_futures_lock"] = _futures_lock
self.__dict__.update(state)


class PrioritizedReplayBuffer(ReplayBuffer):
"""Prioritized replay buffer.
Expand Down

0 comments on commit 3f4e9aa

Please sign in to comment.