import abc
import os
from collections import Counter
from functools import cached_property
from typing import Union

import numpy as np
import tqdm
from imitation.data.types import TrajectoryWithRew
from imitation.rewards import reward_function, reward_nets
from scipy import sparse
class DeterministicMDP(abc.ABC):
    """
    A deterministic MDP.
    """

    def __init__(self, *args, **kwargs):
        """
        Initialize the DeterministicMDP with generic arguments and keyword arguments.

        This allows for flexible initialization of subclasses with different parameters.

        Args:
            *args: Variable length argument list.
            **kwargs: Arbitrary keyword arguments.
        """
        self._args = args
        self._kwargs = kwargs

    @property
    @abc.abstractmethod
    def actions(self):
        """
        Return a list of all possible actions in the MDP.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def successor(self, state, action):
        """
        Given a state and action, return the successor state and reward.
        """
        raise NotImplementedError

    def reward(self, state, action):
        """
        Given a state and action, return the reward.
        """
        _, reward = self.successor(state, action)
        return reward

    @abc.abstractmethod
    def get_start_states(self):
        """
        Return a list of all possible starting states of the MDP.
        """
        raise NotImplementedError

    @cached_property
    def states(self):
        """
        Return a list of all states in the MDP.
        """
        visited = []
        queue = self.get_start_states()
        # Seed the visited set with the start states so they are not enqueued (and listed) a second time if they are
        # reachable from other states.
        visited_states_encoded = {self.encode_state(state) for state in queue}
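        # Depth-first traversal of the reachable state graph (list.pop() removes from the end of the queue).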
        while queue:
            state = queue.pop()
            visited.append(state)
            for action in self.actions:
                next_state = self.successor(state, action)[0]
                encoded_next_state = self.encode_state(next_state)
                if encoded_next_state not in visited_states_encoded:
                    visited_states_encoded.add(encoded_next_state)
                    queue.append(next_state)
        return visited
    @abc.abstractmethod
    def encode_state(self, state):
        """
        Encode a state as a string.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def encode_action(self, action):
        """
        Encode an action as a string.
        """
        raise NotImplementedError

    def encode_mdp_params(self):
        """
        Encode the MDP parameters as a string. Used for saving/loading the transition matrix and reward vector for
        the MDP, greatly speeding up the process of computing optimal policies.

        By default, this just returns the name of the class. Override this if you want to save multiple MDPs of the
        same class with different parameters.

        Returns:
            A string encoding the MDP parameters.
        """
        return self.__class__.__name__

    def get_state_index(self, state):
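        # The encoded-state -> index lookup is built lazily on first use and then cached on the instance.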
        if not hasattr(self, "_state_index"):
            self._state_index = {
                self.encode_state(state): i
                for i, state in tqdm.tqdm(
                    enumerate(self.states), desc="Constructing state index", total=len(self.states)
                )
            }
        return self._state_index[self.encode_state(state)]

    def get_action_index(self, action):
        if not hasattr(self, "_action_index"):
            self._action_index = {self.encode_action(action): i for i, action in enumerate(self.actions)}
        return self._action_index[self.encode_action(action)]
    @property
    def reward_fn_vectorized(self):
        """
        Return a vectorized version of the reward function. Useful for setting reward_fn in a TrajectoryGenerator to
        the true reward function.
        """

        def env_reward_fn(state, action, *args):
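            # Accepts batches in the style of imitation's RewardFn (state, action, next_state, done); the extra
            # positional arguments are ignored because the reward here depends only on (state, action).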
            return np.array([self.reward(state, action) for state, action in zip(state, action)])

        return env_reward_fn
    def get_sparse_transition_matrix_and_reward_vector(
        self,
        alt_reward_fn: Union[reward_function.RewardFn, reward_nets.RewardNet] = None,
    ):
        """
        Produce the data structures needed to run value iteration. Specifically, the sparse transition matrix and the
        reward vector. The transition matrix is a sparse matrix of shape (num_states * num_actions, num_states), and
        the reward vector is a vector of length num_states * num_actions.

        Args:
            alt_reward_fn (reward_function.RewardFn or reward_nets.RewardNet): If not None, this reward function will
                be used instead of the one specified in the MDP.

        Returns:
            A tuple of (transition_matrix, reward_vector).
        """
        if not (hasattr(self, "_sparse_transition_matrix") and hasattr(self, "_reward_vector")):
            if not self._load_sparse_transition_matrix_and_reward_vector_from_file():
                self._compute_sparse_transition_matrix_and_reward_vector()
                self._save_sparse_transition_matrix_and_reward_vector_to_file()
        if alt_reward_fn is not None:
            # TODO: might have a batch size issue here; trying to predict for |S| * |A| inputs.
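            # Evaluate the alternative reward on every (state, action) pair, in the same state-major order used for
            # the transition-matrix rows.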
            state_inputs = np.repeat(self.states, len(self.actions), axis=0)
            action_inputs = np.tile(self.actions, (len(self.states)))
            if isinstance(alt_reward_fn, reward_nets.RewardNet):
                rewards = alt_reward_fn.predict(
                    state=state_inputs,
                    action=action_inputs,
                    next_state=state_inputs,
                    done=np.zeros_like(state_inputs, dtype=np.bool_),
                )
            else:
                # Use the reward_function.RewardFn protocol.
                rewards = np.array(
                    alt_reward_fn(
                        state_inputs,
                        action_inputs,
                        state_inputs,
                        np.zeros_like(state_inputs, dtype=np.bool_),
                    )
                )
            return self._sparse_transition_matrix, rewards
        return self._sparse_transition_matrix, self._reward_vector
    def _compute_sparse_transition_matrix_and_reward_vector(self):
        """
        Compute the sparse transition matrix and reward vector for this MDP. This is a helper function for
        get_sparse_transition_matrix_and_reward_vector, which caches the results.
        """
        num_states = len(self.states)
        num_actions = len(self.actions)
        transitions = []
        rewards = []
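        # Rows are laid out state-major: row (state_index * num_actions + action_index) describes taking that action
        # in that state.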
        for state in tqdm.tqdm(self.states, desc="Constructing transition matrix"):
            for action in self.actions:
                successor_state, reward = self.successor(state, action)
                transitions.append(self.get_state_index(successor_state))
                rewards.append(reward)
        transitions = np.array(transitions, dtype=np.int32)
        rewards = np.array(rewards, dtype=np.float32)
        self._reward_vector = rewards

        data = np.ones_like(transitions, dtype=np.float32)
        row_indices = np.arange(num_states * num_actions, dtype=np.int32)
        col_indices = transitions
        transition_matrix = sparse.csr_matrix(
            (data, (row_indices, col_indices)), shape=(num_states * num_actions, num_states)
        )
        self._sparse_transition_matrix = transition_matrix
    def _save_sparse_transition_matrix_and_reward_vector_to_file(self, env_matrix_dir="env_matrices"):
        """
        Save the sparse transition matrix and reward vector to a file.

        Args:
            env_matrix_dir (str): Path to the directory containing the matrices. Defaults to "env_matrices".
        """
        os.makedirs(env_matrix_dir, exist_ok=True)
        sparse.save_npz(
            os.path.join(env_matrix_dir, f"{self.encode_mdp_params()}_transition.npz"),
            self._sparse_transition_matrix,
        )
        np.save(os.path.join(env_matrix_dir, f"{self.encode_mdp_params()}_reward.npy"), self._reward_vector)

    def _load_sparse_transition_matrix_and_reward_vector_from_file(self, env_matrix_dir="env_matrices"):
        """
        Load the sparse transition matrix and reward vector from a file.

        Args:
            env_matrix_dir (str): Path to the directory containing the matrices. Defaults to "env_matrices".

        Returns:
            True if the cached matrices were found and loaded, False otherwise.
        """
        try:
            self._sparse_transition_matrix = sparse.load_npz(
                os.path.join(env_matrix_dir, f"{self.encode_mdp_params()}_transition.npz")
            )
            self._reward_vector = np.load(os.path.join(env_matrix_dir, f"{self.encode_mdp_params()}_reward.npy"))
            return True
        except FileNotFoundError:
            return False
    def rollout_with_policy(
        self,
        policy,
        fixed_horizon=None,
        epsilon=None,
        seed=None,
        render=False,
        logging_callback=None,
    ) -> TrajectoryWithRew:
        """
        Runs a rollout of the environment using the given tabular policy.

        Args:
            policy (value_iteration.TabularPolicy): TabularPolicy for this environment.
            fixed_horizon (int): If not None, the rollout will end after this many steps, regardless of whether the
                agent gets stuck.
            epsilon (float): If not None, the policy will be epsilon-greedy.
            seed (int): If not None, the environment will be seeded with this value.
            render (bool): If True, the environment will be rendered after each step.
            logging_callback (callable): If not None, this function will be called after each step with the current
                state, action, and reward.

        Returns:
            The trajectory generated by the policy as a TrajectoryWithRew object.
        """
        state = self.reset(seed=seed)
        # TODO: Might need to store state as a numerical array instead of human-readable dicts
        states = [state]
        state_indices = [self.get_state_index(state)]
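        # Visited state indices are tracked so that loops under a deterministic policy can be detected below.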
        actions = []
        rewards = []
        done = False
        if render:
            self.render()
        # Roll out until the environment terminates, or for exactly fixed_horizon steps when a fixed horizon is given.
        while (fixed_horizon is None and not done) or (fixed_horizon is not None and len(actions) < fixed_horizon):
            if epsilon is not None and np.random.random() < epsilon:
                action = np.random.choice(self.actions)
            else:
                action = policy.predict(state)
            next_state, reward, done, _ = self.step(self.get_action_index(action))
            if logging_callback is not None:
                logging_callback(state, action, reward)
            states.append(next_state)
            state_indices.append(self.get_state_index(next_state))
            actions.append(action)
            rewards.append(reward)
            if self.get_state_index(next_state) in state_indices[:-1] and fixed_horizon is None:
                if render:
                    print("Repeated state, ending rollout early.")
                # Policy is deterministic, so if we've been here before, we're in a loop.
                # StealingGridworld only has one-off rewards, so we can just terminate as long as it's the only
                # environment in use.
                break
            state = next_state
            if render:
                self.render()
        if render:
            print(f"Total reward: {sum(rewards)}")
        return TrajectoryWithRew(
            obs=np.array(states, dtype=np.int16),
            acts=np.array(actions, dtype=np.int16),
            rews=np.array(rewards, dtype=float),
            terminal=done,
            infos=None,
        )
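

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: a hypothetical two-state MDP used only to illustrate
    # the DeterministicMDP interface. The class "ToyTwoStateMDP", its states "a"/"b", and its rewards are invented
    # for this example; everything else relies only on the methods defined above. Note that
    # get_sparse_transition_matrix_and_reward_vector() caches its results under ./env_matrices.

    class ToyTwoStateMDP(DeterministicMDP):
        """Two states ("a", "b") and two actions: 0 stays put, 1 toggles. Landing in "b" yields reward 1."""

        @property
        def actions(self):
            return [0, 1]

        def successor(self, state, action):
            next_state = state if action == 0 else ("b" if state == "a" else "a")
            reward = 1.0 if next_state == "b" else 0.0
            return next_state, reward

        def get_start_states(self):
            return ["a"]

        def encode_state(self, state):
            return state

        def encode_action(self, action):
            return str(action)

    mdp = ToyTwoStateMDP()
    transition_matrix, reward_vector = mdp.get_sparse_transition_matrix_and_reward_vector()
    print(f"States: {mdp.states}")  # ['a', 'b']
    print(f"Transition matrix shape: {transition_matrix.shape}")  # (|S| * |A|, |S|) = (4, 2)
    print(f"Reward vector: {reward_vector}")  # one entry per (state, action) pair, state-major order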