forked from TikhonJelvis/RL-book
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexperience_replay.py
47 lines (38 loc) · 1.47 KB
/
experience_replay.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from typing import Generic, Iterable, Iterator, List, TypeVar, Callable, \
Sequence
from rl.distribution import Categorical
T = TypeVar('T')
class ExperienceReplayMemory(Generic[T]):
saved_transitions: List[T]
time_weights_func: Callable[[int], float]
weights: List[float]
weights_sum: float
def __init__(
self,
time_weights_func: Callable[[int], float] = lambda _: 1.0,
):
self.saved_transitions = []
self.time_weights_func = time_weights_func
self.weights = []
self.weights_sum = 0.0
def add_data(self, transition: T) -> None:
self.saved_transitions.append(transition)
weight: float = self.time_weights_func(len(self.saved_transitions) - 1)
self.weights.append(weight)
self.weights_sum += weight
def sample_mini_batch(self, mini_batch_size: int) -> Sequence[T]:
num_transitions: int = len(self.saved_transitions)
return Categorical(
{tr: self.weights[num_transitions - 1 - i] / self.weights_sum
for i, tr in enumerate(self.saved_transitions)}
).sample_n(min(mini_batch_size, num_transitions))
def replay(
self,
transitions: Iterable[T],
mini_batch_size: int
) -> Iterator[Sequence[T]]:
for transition in transitions:
self.add_data(transition)
yield self.sample_mini_batch(mini_batch_size)
while True:
yield self.sample_mini_batch(mini_batch_size)