
Commit f2f8b4f

In __init__, iteration over FiniteDistribution simplified
1 parent 09958fe commit f2f8b4f
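
The change drops the explicit .table().items() calls when iterating finite distributions, on the assumption that a FiniteDistribution is itself iterable and yields (outcome, probability) pairs. The sketch below illustrates that assumption only; FiniteDistributionSketch is a hypothetical stand-in, not the library's actual implementation.

from typing import Generic, Iterator, Mapping, Tuple, TypeVar

A = TypeVar('A')


class FiniteDistributionSketch(Generic[A]):
    """Hypothetical stand-in for a finite distribution over outcomes of type A."""

    def __init__(self, probabilities: Mapping[A, float]):
        self.probabilities = dict(probabilities)

    def table(self) -> Mapping[A, float]:
        # Outcome -> probability mapping, as consumed by the old code path.
        return self.probabilities

    def __iter__(self) -> Iterator[Tuple[A, float]]:
        # Yield (outcome, probability) pairs, so iterating the distribution
        # directly matches iterating dist.table().items().
        return iter(self.probabilities.items())


dist = FiniteDistributionSketch({('s1', 1.0): 0.3, ('s2', 0.0): 0.7})
assert list(dist) == list(dist.table().items())

Under that assumption, every comprehension of the form "for ..., p in v.table().items()" in the diffs below can be shortened to "for ..., p in v", which is what this commit does.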

File tree

2 files changed, +4 -4 lines changed


rl/markov_decision_process.py

+2 -2
@@ -134,7 +134,7 @@ def __init__(
         non_terminals: Set[S] = set(mapping.keys())
         self.mapping = {NonTerminal(s): {a: Categorical(
             {(NonTerminal(s1) if s1 in non_terminals else Terminal(s1), r): p
-             for (s1, r), p in v.table().items()}
+             for (s1, r), p in v}
         ) for a, v in d.items()} for s, d in mapping.items()}
         self.non_terminal_states = list(self.mapping.keys())
 
@@ -165,7 +165,7 @@ def apply_finite_policy(self, policy: FinitePolicy[S, A])\
                 = defaultdict(float)
             actions = policy.act(state)
             for action, p_action in actions:
-                for (s1, r), p in action_map[action].table().items():
+                for (s1, r), p in action_map[action]:
                     outcomes[(s1.state, r)] += p_action * p
 
             transition_mapping[state.state] = Categorical(outcomes)
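
For context on the apply_finite_policy hunk above, the sketch below reproduces its marginalization over actions with plain tuples and lists standing in for the policy's action distribution and the state's action map; all names, states and probabilities here are hypothetical, and plain strings replace the NonTerminal wrappers.

from collections import defaultdict
from typing import DefaultDict, Tuple

# Hypothetical stand-ins: each action carries its probability under the
# policy, and maps to a finite distribution over (next_state, reward)
# pairs, written here as lists of ((next_state, reward), probability).
actions = [('a', 0.5), ('b', 0.5)]
action_map = {
    'a': [(('s1', 1.0), 0.25), (('s2', 0.0), 0.75)],
    'b': [(('s1', 1.0), 0.5), (('s2', 0.0), 0.5)],
}

outcomes: DefaultDict[Tuple[str, float], float] = defaultdict(float)
for action, p_action in actions:
    for (s1, r), p in action_map[action]:
        # Marginalize over actions: P(s1, r) = sum over a of pi(a) * P(s1, r | a)
        outcomes[(s1, r)] += p_action * p

print(dict(outcomes))  # {('s1', 1.0): 0.375, ('s2', 0.0): 0.625}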

rl/markov_process.py

+2 -2
@@ -104,7 +104,7 @@ def __init__(self, transition_map: Mapping[S, FiniteDistribution[S]]):
         self.transition_map = {
             NonTerminal(s): Categorical(
                 {(NonTerminal(s1) if s1 in non_terminals else Terminal(s1)): p
-                 for s1, p in v.table().items()}
+                 for s1, p in v}
             ) for s, v in transition_map.items()
         }
         self.non_terminal_states = list(self.transition_map.keys())
@@ -272,7 +272,7 @@ def __init__(
         self.transition_reward_map = {
             NonTerminal(s): Categorical(
                 {(NonTerminal(s1) if s1 in nt else Terminal(s1), r): p
-                 for (s1, r), p in v.table().items()}
+                 for (s1, r), p in v}
             ) for s, v in transition_reward_map.items()
         }
 
